diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 8d5e746..2a4a2ad 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -2,7 +2,6 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer -import cats.syntax.align object asmGenerator { import microWacc._ @@ -18,6 +17,10 @@ object asmGenerator { val RIP = Register(RegSize.R64, RegName.IP) val RBP = Register(RegSize.R64, RegName.BP) val RSI = Register(RegSize.R64, RegName.SI) + val RDX = Register(RegSize.R64, RegName.DX) + val RCX = Register(RegSize.R64, RegName.CX) + val R8 = Register(RegSize.R64, RegName.Reg8) + val R9 = Register(RegSize.R64, RegName.Reg9) val _8_BIT_MASK = 0xff @@ -44,10 +47,11 @@ object asmGenerator { alignStack() ++ main.flatMap(generateStmt) ++ List(Move(RAX, ImmediateVal(0))) ++ - funcEpilogue() + funcEpilogue() ++ + generateBuiltInFuncs() val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => - List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) + List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str.replace("\"", "\\\""))) } List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ @@ -73,7 +77,7 @@ object asmGenerator { wrapFunc( labelGenerator.getLabel(Builtin.Exit), alignStack() ++ - List(Pop(RDI), assemblyIR.Call(CLibFunc.Exit)) + List(assemblyIR.Call(CLibFunc.Exit)) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Printf), @@ -84,16 +88,28 @@ object asmGenerator { assemblyIR.Call(CLibFunc.Fflush) ) ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List()) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) + wrapFunc( + labelGenerator.getLabel(Builtin.Malloc), + alignStack() ++ + List() + ) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ + wrapFunc( + labelGenerator.getLabel(Builtin.Read), + alignStack() ++ + List( + Push(RSI), + Load(RSI, MemLocation(RSP)), + assemblyIR.Call(CLibFunc.Scanf), + Pop(RAX) + ) + ) } def generateStmt( stmt: Stmt )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { - case microWacc.Call(Builtin.Exit, code :: _) => - List() case Assign(lhs, rhs) => var dest: IndexAddress = IndexAddress(RSP, 0) // gets overrwitten @@ -112,15 +128,8 @@ object asmGenerator { // dest = ??? List() }) ++ - (rhs match { - case microWacc.Call(Builtin.ReadInt, _) => - readIntoVar(dest, Builtin.ReadInt) - case microWacc.Call(Builtin.ReadChar, _) => - readIntoVar(dest, Builtin.ReadChar) - case _ => - evalExprOntoStack(rhs) ++ - List(Pop(dest)) - }) + evalExprOntoStack(rhs) ++ + List(Pop(dest)) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() @@ -151,8 +160,7 @@ object asmGenerator { case microWacc.Return(expr) => evalExprOntoStack(expr) ++ List(Pop(RAX), assemblyIR.Return()) - - case _ => List() + case call: microWacc.Call => generateCall(call) } def evalExprOntoStack(expr: Expr)(using @@ -213,7 +221,7 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Add(MemLocation(RSP, SizeDir.Word), EAX) // TODO OVERFLOWING ) @@ -221,7 +229,7 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Subtract(MemLocation(RSP, SizeDir.Word), EAX) // TODO OVERFLOWING ) @@ -229,28 +237,30 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), - Multiply(MemLocation(RSP, SizeDir.Word), EAX) + Pop(RAX), + Multiply(EAX, MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(RAX) // TODO OVERFLOWING ) case BinaryOperator.Div => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(EAX), + Pop(RAX), Divide(MemLocation(RSP, SizeDir.Word)), Add(RSP, ImmediateVal(8)), - Push(EAX) + Push(RAX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Mod => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(EAX), + Pop(RAX), Divide(MemLocation(RSP, SizeDir.Word)), Add(RSP, ImmediateVal(8)), - Push(EDX) + Push(RDX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Eq => @@ -269,21 +279,38 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), And(MemLocation(RSP, SizeDir.Word), EAX) ) case BinaryOperator.Or => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Or(MemLocation(RSP, SizeDir.Word), EAX) ) } - case microWacc.Call(target, args) => List() + case call: microWacc.Call => generateCall(call) } } + def generateCall(call: microWacc.Call)(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { + val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) + val microWacc.Call(target, args) = call + argRegs.zip(args).flatMap { (reg, expr) => + evalExprOntoStack(expr) ++ + List(Pop(reg)) + } ++ + args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ + List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ + (if (args.size > argRegs.size) { + List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8))) + } else Nil) + } + // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using // stack: LinkedHashMap[Ident, Int], // strings: ListBuffer[String] @@ -316,11 +343,11 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Compare(MemLocation(RSP, SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), And(EAX, ImmediateVal(_8_BIT_MASK)), - Push(EAX) + Push(RAX) ) } def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = @@ -352,83 +379,83 @@ object asmGenerator { // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_)) // TODO: refactor, really ugly function - def printF(expr: Expr)(using - stack: LinkedHashMap[Ident, Int], - strings: ListBuffer[String] - ): List[AsmLine] = { -// determine the format string - expr.ty match { - case KnownType.String => - strings += PrintFormat.String.toString - case KnownType.Char => - strings += PrintFormat.Char.toString - case KnownType.Int => - strings += PrintFormat.Int.toString - case _ => - strings += PrintFormat.String.toString - } - List( - Load( - RDI, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ) - ) - ++ - // determine the actual value to print - (if (expr.ty == KnownType.Bool) { - expr match { - case BoolLiter(true) => { - strings += "true" - } - case _ => { - strings += "false" - } - } - List( - Load( - RDI, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ) - ) +// def printF(expr: Expr)(using +// stack: LinkedHashMap[Ident, Int], +// strings: ListBuffer[String] +// ): List[AsmLine] = { +// // determine the format string +// expr.ty match { +// case KnownType.String => +// strings += PrintFormat.String.toString +// case KnownType.Char => +// strings += PrintFormat.Char.toString +// case KnownType.Int => +// strings += PrintFormat.Int.toString +// case _ => +// strings += PrintFormat.String.toString +// } +// List( +// Load( +// RDI, +// IndexAddress( +// RIP, +// LabelArg(s".L.str${strings.size - 1}") +// ) +// ) +// ) +// ++ +// // determine the actual value to print +// (if (expr.ty == KnownType.Bool) { +// expr match { +// case BoolLiter(true) => { +// strings += "true" +// } +// case _ => { +// strings += "false" +// } +// } +// List( +// Load( +// RDI, +// IndexAddress( +// RIP, +// LabelArg(s".L.str${strings.size - 1}") +// ) +// ) +// ) - } else { - evalExprOntoStack(expr) ++ - List(Pop(RSI)) - }) - // print the value - ++ - List( - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) - } +// } else { +// evalExprOntoStack(expr) ++ +// List(Pop(RSI)) +// }) +// // print the value +// ++ +// List( +// assemblyIR.Call(CLibFunc.PrintF), +// Move(RDI, ImmediateVal(0)), +// assemblyIR.Call(CLibFunc.Fflush) +// ) +// } // prints a new line - def printLn()(using - stack: LinkedHashMap[Ident, Int], - strings: ListBuffer[String] - ): List[AsmLine] = { - strings += "" - Load( - RDI, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ) - :: - List( - assemblyIR.Call(CLibFunc.Puts), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) + // def printLn()(using + // stack: LinkedHashMap[Ident, Int], + // strings: ListBuffer[String] + // ): List[AsmLine] = { + // strings += "" + // Load( + // RDI, + // IndexAddress( + // RIP, + // LabelArg(s".L.str${strings.size - 1}") + // ) + // ) + // :: + // List( + // assemblyIR.Call(CLibFunc.Puts), + // Move(RDI, ImmediateVal(0)), + // assemblyIR.Call(CLibFunc.Fflush) + // ) - } + // } } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 14e4f4f..0921fd8 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -179,7 +179,7 @@ object assemblyIR { override def toString(): String = this match { case Byte => "byte " + ptr - case Word => "word " + ptr + case Word => "dword " + ptr // TODO check word/doubleword/quadword case Unspecified => "" } } diff --git a/src/main/wacc/frontend/lexer.scala b/src/main/wacc/frontend/lexer.scala index 2efe517..e0b0a44 100644 --- a/src/main/wacc/frontend/lexer.scala +++ b/src/main/wacc/frontend/lexer.scala @@ -39,7 +39,7 @@ val errConfig = new ErrorConfig { ) } object lexer { - + /** Language description for the WACC lexer */ private val desc = LexicalDesc.plain.copy(