diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0e8eb55..deeaf8e 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -10,7 +10,6 @@ object asmGenerator { val RAX = Register(RegSize.R64, RegName.AX) val EAX = Register(RegSize.E32, RegName.AX) - val RSP = Register(RegSize.R64, RegName.SP) val ESP = Register(RegSize.E32, RegName.SP) val EDX = Register(RegSize.E32, RegName.DX) val RDI = Register(RegSize.R64, RegName.DI) @@ -37,14 +36,14 @@ object asmGenerator { } def generateAsm(microProg: Program): List[AsmLine] = { - given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]() + given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() val Program(funcs, main) = microProg val progAsm = LabelDef("main") :: funcPrologue() ++ - alignStack() ++ + List(stack.align()) ++ main.flatMap(generateStmt) ++ List(Move(RAX, ImmediateVal(0))) ++ funcEpilogue() ++ @@ -61,7 +60,7 @@ object asmGenerator { } def wrapFunc(labelName: String, funcBody: List[AsmLine])(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { LabelDef(labelName) :: @@ -71,74 +70,71 @@ object asmGenerator { } def generateBuiltInFuncs()(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { wrapFunc( labelGenerator.getLabel(Builtin.Exit), - alignStack() ++ - List(assemblyIR.Call(CLibFunc.Exit)) + List(stack.align(), assemblyIR.Call(CLibFunc.Exit)) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Printf), - alignStack() ++ - List( - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) + List( + stack.align(), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Malloc), - alignStack() ++ - List() + List( + stack.align(), + ) ) ++ wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ wrapFunc( labelGenerator.getLabel(Builtin.Read), - alignStack() ++ - List( - Push(RSI), - Load(RSI, MemLocation(RSP)), - assemblyIR.Call(CLibFunc.Scanf), - Pop(RAX) - ) + List( + stack.align(), + stack.push(RSI), + Load(RSI, stack.head), + assemblyIR.Call(CLibFunc.Scanf), + stack.pop(RAX) + ) ) } def generateStmt( stmt: Stmt - )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = + )(using stack: Stack, strings: ListBuffer[String]): List[AsmLine] = stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = - () => IndexAddress(RSP, 0) // gets overrwitten + () => IndexAddress(RAX, 0) // gets overrwitten (lhs match { case ident: Ident => + dest = stack.accessVar(ident) if (!stack.contains(ident)) { - stack += (ident -> (stack.size + 1)) - dest = accessVar(ident) - List(Subtract(RSP, ImmediateVal(8))) - } else { - dest = accessVar(ident) - List() - } + List(stack.reserve(ident)) + } else Nil // TODO lhs = arrayElem case _ => // dest = ??? List() }) ++ evalExprOntoStack(rhs) ++ - List(Pop(RAX), - Move(dest(), RAX), + List( + stack.pop(RAX), + Move(dest(), RAX), ) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() evalExprOntoStack(cond) ++ List( - Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), - Add(RSP, ImmediateVal(8)), + Compare(stack.head(SizeDir.Word), ImmediateVal(0)), + stack.drop(), Jump(LabelArg(elseLabel), Cond.Equal) ) ++ thenBranch.flatMap(generateStmt) ++ @@ -152,8 +148,8 @@ object asmGenerator { List(LabelDef(startLabel)) ++ evalExprOntoStack(cond) ++ List( - Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), - Add(RSP, ImmediateVal(8)), + Compare(stack.head(SizeDir.Word), ImmediateVal(0)), + stack.drop(), Jump(LabelArg(endLabel), Cond.Equal) ) ++ body.flatMap(generateStmt) ++ @@ -161,21 +157,21 @@ object asmGenerator { } case microWacc.Return(expr) => evalExprOntoStack(expr) ++ - List(Pop(RAX), assemblyIR.Return()) + List(stack.pop(RAX), assemblyIR.Return()) case call: microWacc.Call => generateCall(call) } def evalExprOntoStack(expr: Expr)(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { expr match { case IntLiter(v) => - List(Push(ImmediateVal(v))) + List(stack.push(ImmediateVal(v))) case CharLiter(v) => - List(Push(ImmediateVal(v.toInt))) + List(stack.push(ImmediateVal(v.toInt))) case ident: Ident => - List(Push(accessVar(ident)())) + List(stack.push(stack.accessVar(ident)())) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -191,13 +187,13 @@ object asmGenerator { LabelArg(s".L.str${strings.size - 1}") ) ), - Push(RAX) + stack.push(RAX) ) // TODO other array types case _ => List() } - case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => List(Push(ImmediateVal(0))) + case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => List(stack.push(ImmediateVal(0))) case ArrayElem(value, indices) => List() case UnaryOp(x, op) => op match { @@ -208,12 +204,12 @@ object asmGenerator { case UnaryOperator.Len => List() case UnaryOperator.Negate => List( - Negate(MemLocation(RSP, SizeDir.Word)) + Negate(stack.head(SizeDir.Word)) ) case UnaryOperator.Not => evalExprOntoStack(x) ++ List( - Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1)) + Xor(stack.head(SizeDir.Word), ImmediateVal(1)) ) } @@ -223,46 +219,46 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Add(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + Add(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING ) case BinaryOperator.Sub => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Subtract(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + Subtract(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING ) case BinaryOperator.Mul => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Multiply(EAX, MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(RAX) + stack.pop(RAX), + Multiply(EAX, stack.head(SizeDir.Word)), + stack.drop(), + stack.push(RAX) // TODO OVERFLOWING ) case BinaryOperator.Div => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(RAX), - Divide(MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(RAX) + stack.pop(RAX), + Divide(stack.head(SizeDir.Word)), + stack.drop(), + stack.push(RAX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Mod => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(RAX), - Divide(MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(RDX) + stack.pop(RAX), + Divide(stack.head(SizeDir.Word)), + stack.drop(), + stack.push(RDX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Eq => @@ -281,15 +277,15 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - And(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + And(stack.head(SizeDir.Word), EAX) ) case BinaryOperator.Or => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Or(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + Or(stack.head(SizeDir.Word), EAX) ) } case call: microWacc.Call => generateCall(call) @@ -297,24 +293,24 @@ object asmGenerator { } def generateCall(call: microWacc.Call)(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call argRegs.zip(args).flatMap { (reg, expr) => evalExprOntoStack(expr) ++ - List(Pop(reg)) + List(stack.pop(reg)) } ++ args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ (if (args.size > argRegs.size) { - List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8))) + List(stack.reserve(args.size - argRegs.size)) } else Nil) } // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using - // stack: LinkedHashMap[Ident, Int], + // stack: Stack, // strings: ListBuffer[String] // ): List[AsmLine] = { // readType match { @@ -339,41 +335,33 @@ object asmGenerator { // } def generateComparison(x: Expr, y: Expr, cond: Cond)(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Compare(MemLocation(RSP, SizeDir.Word), EAX), + stack.pop(RAX), + Compare(stack.head(SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), - And(EAX, ImmediateVal(_8_BIT_MASK)), - Load(RSP, IndexAddress(RSP, 8)), - Push(RAX) + And(RAX, ImmediateVal(_8_BIT_MASK)), + stack.drop(), + stack.push(RAX) ) } - def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): () => IndexAddress = - () => IndexAddress(RSP, (stack.size - stack(ident)) * 8) - - def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { - List( - And(RSP, ImmediateVal(-16)) - ) - } // Missing a sub instruction but dont think we need it - def funcPrologue(): List[AsmLine] = { + def funcPrologue()(using stack: Stack): List[AsmLine] = { List( - Push(RBP), - Move(RBP, RSP) + stack.push(RBP), + Move(RBP, Register(RegSize.R64, RegName.SP)) ) } - def funcEpilogue(): List[AsmLine] = { + def funcEpilogue()(using stack: Stack): List[AsmLine] = { List( - Move(RSP, RBP), - Pop(RBP), + Move(Register(RegSize.R64, RegName.SP), RBP), + stack.pop(RBP), assemblyIR.Return() ) } @@ -383,7 +371,7 @@ object asmGenerator { // TODO: refactor, really ugly function // def printF(expr: Expr)(using -// stack: LinkedHashMap[Ident, Int], +// stack: Stack, // strings: ListBuffer[String] // ): List[AsmLine] = { // // determine the format string @@ -442,7 +430,7 @@ object asmGenerator { // prints a new line // def printLn()(using - // stack: LinkedHashMap[Ident, Int], + // stack: Stack, // strings: ListBuffer[String] // ): List[AsmLine] = { // strings += "" @@ -461,4 +449,44 @@ object asmGenerator { // ) // } + + + class Stack { + private val stack = LinkedHashMap[Expr | Int, Int]() + private val RSP = Register(RegSize.R64, RegName.SP) + + def next: Int = stack.size + 1 + def push(expr: Expr, src: Src): AsmLine = { + stack += expr -> next + Push(src) + } + def push(src: Src): AsmLine = { + stack += stack.size -> next + Push(src) + } + def pop(dest: Src): AsmLine = { + stack.remove(stack.last._1) + Pop(dest) + } + def reserve(ident: Ident): AsmLine = { + stack += ident -> next + Subtract(RSP, ImmediateVal(8)) + } + def reserve(n: Int = 1): AsmLine = { + (1 to n).foreach(_ => stack += stack.size -> next) + Subtract(RSP, ImmediateVal(n*8)) + } + def drop(n : Int = 1): AsmLine = { + (1 to n).foreach(_ => stack.remove(stack.last._1)) + Add(RSP, ImmediateVal(n*8)) + } + def accessVar(ident: Ident): () => IndexAddress = () => { + IndexAddress(RSP, (stack.size - stack(ident)) * 8) + } + def head: MemLocation = MemLocation(RSP) + def head(size: SizeDir): MemLocation = MemLocation(RSP, size) + def contains(ident: Ident): Boolean = stack.contains(ident) + // TODO: Might want to actually properly handle this with the LinkedHashMap too + def align(): AsmLine = And(RSP, ImmediateVal(-16)) + } }