From 02e741c52e5d71654d8c4486c3710f002a2c55c4 Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Fri, 21 Feb 2025 22:53:20 +0000 Subject: [PATCH 01/54] feat: implemented println and exit --- src/main/wacc/backend/asmGenerator.scala | 111 +++++++++++++++++++++++ src/main/wacc/backend/assemblyIR.scala | 56 ++++++++++-- 2 files changed, 158 insertions(+), 9 deletions(-) create mode 100644 src/main/wacc/backend/asmGenerator.scala diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala new file mode 100644 index 0000000..7440791 --- /dev/null +++ b/src/main/wacc/backend/asmGenerator.scala @@ -0,0 +1,111 @@ +package wacc + +import scala.collection.mutable.LinkedHashMap +import scala.collection.mutable.ListBuffer + +object asmGenerator { + import microWacc._ + import assemblyIR._ + import wacc.types._ + + def generateAsm(microProg: Program): List[AsmLine] = { + given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]() + given strings: ListBuffer[String] = ListBuffer[String]() + val Program(funcs, main) = microProg + + val progAsm = + LabelDef("main") :: + main.flatMap(generateStmt) ++ + List(assemblyIR.Return()) ++ + generateFuncs() + + val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => + List(Directive.Int(str.size), LabelDef(s".L.str$i:"), Directive.Asciz(str)) + } + + List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ + strDirs ++ + List(Directive.Text) ++ + progAsm + } + +//TODO + def generateFuncs()(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = { + List() + } + + def generateStmt(stmt: Stmt)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = + stmt match { + case microWacc.Call(Builtin.Exit, code :: _) => + alignStack() ++ + evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++ + List(assemblyIR.Call(CLibFunc.Exit)) + + case microWacc.Call(Builtin.Println, expr :: _) => + alignStack() ++ + evalExprIntoReg(expr, Register(RegSize.R64, RegName.DI)) ++ + List( + assemblyIR.Call(CLibFunc.Puts), + Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush)) ++ + restoreStack() + + case microWacc.Call(Builtin.ReadInt, expr :: _) => + List() + + case Assign(lhs, rhs) => + lhs match { + case ident: Ident => + stack += (ident -> stack.size) + evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ + List(Push(Register(RegSize.R64, RegName.AX))) + case _ => List() + } + case _ => List() + } + + def evalExprIntoReg(expr: Expr, dest: Register) + (using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = { + var src: Src = ImmediateVal(0) // Placeholder + (expr match { + case IntLiter(v) => + src = ImmediateVal(v) + List() + case ident: Ident => + List( + Move( + dest, + IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 4) + ) + ) + case ArrayLiter(elems) => expr.ty match { + case KnownType.Char => + strings += elems.mkString + List( + Load(dest, IndexAddress(Register(RegSize.R64, RegName.IP),LabelArg(s".L.str${strings.size - 1}"))) + ) + case _ => List() + } + case _ => List() + }) ++ List(Move(dest, src)) + } + + def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { + List( + And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)), + // Store stack pointer in rbp as it is callee saved + Push(Register(RegSize.R64, RegName.BP)), + Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) + ) + } + + def restoreStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { + List( + Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), + Pop(Register(RegSize.R64, RegName.BP)) + ) + } + + // def saveRegs(regList: List[Register]): List[AsmLine] = regList.map(Push(_)) + // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_)) +} diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 323742b..e1b16ee 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -16,6 +16,29 @@ object assemblyIR { } } + enum RegName { + case AX, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 + override def toString = this match { + case AX => "ax" + case BX => "bx" + case CX => "cx" + case DX => "dx" + case SI => "si" + case DI => "di" + case SP => "sp" + case BP => "bp" + case IP => "ip" + case Reg8 => "8" + case Reg9 => "9" + case Reg10 => "10" + case Reg11 => "11" + case Reg12 => "12" + case Reg13 => "13" + case Reg14 => "14" + case Reg15 => "15" + } + } + // arguments enum CLibFunc extends Operand { case Scanf, @@ -35,13 +58,8 @@ object assemblyIR { } } - enum Register extends Dest with Src { - case Named(name: String, size: RegSize) - case Scratch(num: Int, size: RegSize) - override def toString = this match { - case Named(name, size) => s"${size}${name.toLowerCase()}" - case Scratch(num, size) => s"r${num}${if (size == RegSize.E32) "d" else ""}" - } + case class Register(size: RegSize, name: RegName) extends Dest with Src { + override def toString = s"${size}${name}" } case class MemLocation(pointer: Long | Register) extends Dest with Src { override def toString = pointer match { @@ -49,6 +67,9 @@ object assemblyIR { case reg: Register => s"[$reg]" } } + case class IndexAddress(base: Register, offset: Int | LabelArg) extends Dest with Src { + override def toString = s"[$base + $offset]" + } case class ImmediateVal(value: Int) extends Src { override def toString = value.toString @@ -74,10 +95,10 @@ object assemblyIR { // stack operations case class Push(op1: Src) extends Operation("push", op1) case class Pop(op1: Src) extends Operation("pop", op1) - case class Call(op1: CLibFunc) extends Operation("call", op1) + case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1) case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2) - case class Load(op1: Register, op2: MemLocation) extends Operation("lea ", op1, op2) + case class Load(op1: Register, op2: MemLocation | IndexAddress) extends Operation("lea ", op1, op2) case class Return() extends Operation("ret") @@ -108,4 +129,21 @@ object assemblyIR { case Always => "mp" } } + + enum Directive extends AsmLine { + case IntelSyntax, RoData, Text + case Global(name: String) + case Int(value: scala.Int) + case Asciz(string: String) + + override def toString(): String = this match { + case IntelSyntax => ".intel_syntax noprefix" + case Global(name) => s".globl $name" + case Text => ".text" + case RoData => ".section .rodata" + case Int(value) => s".int $value" + case Asciz(string) => s".asciz $string" + + } + } } From ee4109e9cd94ffe7d17aacbb69abc74b993470b1 Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Fri, 21 Feb 2025 23:00:59 +0000 Subject: [PATCH 02/54] style: fix style --- src/main/wacc/backend/asmGenerator.scala | 45 ++++++++++++++++-------- src/main/wacc/backend/assemblyIR.scala | 3 +- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 7440791..ab350f4 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -30,11 +30,16 @@ object asmGenerator { } //TODO - def generateFuncs()(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = { + def generateFuncs()(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { List() } - def generateStmt(stmt: Stmt)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = + def generateStmt( + stmt: Stmt + )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { case microWacc.Call(Builtin.Exit, code :: _) => alignStack() ++ @@ -47,9 +52,10 @@ object asmGenerator { List( assemblyIR.Call(CLibFunc.Puts), Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush)) ++ - restoreStack() - + assemblyIR.Call(CLibFunc.Fflush) + ) ++ + restoreStack() + case microWacc.Call(Builtin.ReadInt, expr :: _) => List() @@ -64,8 +70,10 @@ object asmGenerator { case _ => List() } - def evalExprIntoReg(expr: Expr, dest: Register) - (using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = { + def evalExprIntoReg(expr: Expr, dest: Register)(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { var src: Src = ImmediateVal(0) // Placeholder (expr match { case IntLiter(v) => @@ -78,14 +86,21 @@ object asmGenerator { IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 4) ) ) - case ArrayLiter(elems) => expr.ty match { - case KnownType.Char => - strings += elems.mkString - List( - Load(dest, IndexAddress(Register(RegSize.R64, RegName.IP),LabelArg(s".L.str${strings.size - 1}"))) - ) - case _ => List() - } + case ArrayLiter(elems) => + expr.ty match { + case KnownType.Char => + strings += elems.mkString + List( + Load( + dest, + IndexAddress( + Register(RegSize.R64, RegName.IP), + LabelArg(s".L.str${strings.size - 1}") + ) + ) + ) + case _ => List() + } case _ => List() }) ++ List(Move(dest, src)) } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index e1b16ee..73cdeaf 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -98,7 +98,8 @@ object assemblyIR { case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1) case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2) - case class Load(op1: Register, op2: MemLocation | IndexAddress) extends Operation("lea ", op1, op2) + case class Load(op1: Register, op2: MemLocation | IndexAddress) + extends Operation("lea ", op1, op2) case class Return() extends Operation("ret") From 1ce36dd8da8e4b217dfce542f63968353a642a25 Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Fri, 21 Feb 2025 23:34:37 +0000 Subject: [PATCH 03/54] refactor: unit tests now work with asm ir refactor --- src/test/wacc/instructionSpec.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/wacc/instructionSpec.scala b/src/test/wacc/instructionSpec.scala index 6d427a5..b7452a0 100644 --- a/src/test/wacc/instructionSpec.scala +++ b/src/test/wacc/instructionSpec.scala @@ -3,28 +3,28 @@ import wacc.assemblyIR._ class instructionSpec extends AnyFunSuite { - val named64BitRegister = Register.Named("ax", RegSize.R64) + val named64BitRegister = Register(RegSize.R64, RegName.AX) test("named 64-bit register toString") { assert(named64BitRegister.toString == "rax") } - val named32BitRegister = Register.Named("ax", RegSize.E32) + val named32BitRegister = Register(RegSize.E32, RegName.AX) test("named 32-bit register toString") { assert(named32BitRegister.toString == "eax") } - val scratch64BitRegister = Register.Scratch(1, RegSize.R64) + val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8) test("scratch 64-bit register toString") { - assert(scratch64BitRegister.toString == "r1") + assert(scratch64BitRegister.toString == "r8") } - val scratch32BitRegister = Register.Scratch(1, RegSize.E32) + val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8) test("scratch 32-bit register toString") { - assert(scratch32BitRegister.toString == "r1d") + assert(scratch32BitRegister.toString == "e8") } val memLocationWithHex = MemLocation(0x12345678) @@ -54,7 +54,7 @@ class instructionSpec extends AnyFunSuite { val subInstruction = Subtract(scratch64BitRegister, named64BitRegister) test("x86: sub instruction toString") { - assert(subInstruction.toString == "\tsub r1, rax") + assert(subInstruction.toString == "\tsub r8, rax") } val callInstruction = Call(CLibFunc.Scanf) From 7f2870e340c633a36f965309b102c437041686c4 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 21 Feb 2025 23:30:17 +0000 Subject: [PATCH 04/54] feat: generate assembly from main --- src/main/wacc/Main.scala | 162 +---------------------------------- src/test/wacc/examples.scala | 3 +- 2 files changed, 6 insertions(+), 159 deletions(-) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index e8e7b7b..89cfd98 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -64,157 +64,8 @@ def frontend( } val s = "enter an integer to echo" -def backend(typedProg: microWacc.Program): List[asm.AsmLine] | String = - typedProg match { - case microWacc.Program( - Nil, - microWacc.Call(microWacc.Builtin.Exit, microWacc.IntLiter(v) :: Nil) :: Nil - ) => - s""".intel_syntax noprefix -.globl main -main: - mov edi, ${v} - call exit@plt -""" - case microWacc.Program( - Nil, - microWacc.Assign(microWacc.Ident("x", _), microWacc.IntLiter(1)) :: - microWacc.Call(microWacc.Builtin.Println, _) :: - microWacc.Assign( - microWacc.Ident("x", _), - microWacc.Call(microWacc.Builtin.ReadInt, Nil) - ) :: - microWacc.Call(microWacc.Builtin.Println, microWacc.Ident("x", _) :: Nil) :: Nil - ) => - """.intel_syntax noprefix -.globl main -.section .rodata -# length of .L.str0 - .int 24 -.L.str0: - .asciz "enter an integer to echo" -.text -main: - push rbp - # push {rbx, r12} - sub rsp, 16 - mov qword ptr [rsp], rbx - mov qword ptr [rsp + 8], r12 - mov rbp, rsp - mov r12d, 1 - lea rdi, [rip + .L.str0] - # statement primitives do not return results (but will clobber r0/rax) - call _prints - call _println - # load the current value in the destination of the read so it supports defaults - mov edi, r12d - call _readi - mov r12d, eax - mov edi, eax - # statement primitives do not return results (but will clobber r0/rax) - call _printi - call _println - mov rax, 0 - # pop/peek {rbx, r12} - mov rbx, qword ptr [rsp] - mov r12, qword ptr [rsp + 8] - add rsp, 16 - pop rbp - ret - -.section .rodata -# length of .L._printi_str0 - .int 2 -.L._printi_str0: - .asciz "%d" -.text -_printi: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - mov esi, edi - lea rdi, [rip + .L._printi_str0] - # on x86, al represents the number of SIMD registers used as variadic arguments - mov al, 0 - call printf@plt - mov rdi, 0 - call fflush@plt - mov rsp, rbp - pop rbp - ret - -.section .rodata -# length of .L._prints_str0 - .int 4 -.L._prints_str0: - .asciz "%.*s" -.text -_prints: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - mov rdx, rdi - mov esi, dword ptr [rdi - 4] - lea rdi, [rip + .L._prints_str0] - # on x86, al represents the number of SIMD registers used as variadic arguments - mov al, 0 - call printf@plt - mov rdi, 0 - call fflush@plt - mov rsp, rbp - pop rbp - ret - -.section .rodata -# length of .L._println_str0 - .int 0 -.L._println_str0: - .asciz "" -.text -_println: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - lea rdi, [rip + .L._println_str0] - call puts@plt - mov rdi, 0 - call fflush@plt - mov rsp, rbp - pop rbp - ret - -.section .rodata -# length of .L._readi_str0 - .int 2 -.L._readi_str0: - .asciz "%d" -.text -_readi: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - # RDI contains the "original" value of the destination of the read - # allocate space on the stack to store the read: preserve alignment! - # the passed default argument should be stored in case of EOF - sub rsp, 16 - mov dword ptr [rsp], edi - lea rsi, qword ptr [rsp] - lea rdi, [rip + .L._readi_str0] - # on x86, al represents the number of SIMD registers used as variadic arguments - mov al, 0 - call scanf@plt - mov eax, dword ptr [rsp] - add rsp, 16 - mov rsp, rbp - pop rbp - ret - """ - case _ => List() - } +def backend(typedProg: microWacc.Program): List[asm.AsmLine] = + asmGenerator.generateAsm(typedProg) def compile(filename: String, outFile: Option[File] = None)(using stdout: PrintStream = Console.out @@ -222,13 +73,8 @@ def compile(filename: String, outFile: Option[File] = None)(using frontend(os.read(os.Path(filename))) match { case Left(typedProg) => val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s")) - backend(typedProg) match { - case s: String => - os.write.over(os.Path(asmFile.getAbsolutePath), s) - case ops: List[asm.AsmLine] => { - writer.writeTo(ops, PrintStream(asmFile)) - } - } + val asm = backend(typedProg) + writer.writeTo(asm, PrintStream(asmFile)) 0 case Right(exitCode) => exitCode } diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index abff693..970bcb6 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -81,13 +81,14 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral // disable formatting to avoid binPack "^.*wacc-examples/valid/advanced.*$", "^.*wacc-examples/valid/array.*$", + "^.*wacc-examples/valid/basic/exit.*$", "^.*wacc-examples/valid/basic/skip.*$", "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$", "^.*wacc-examples/valid/if.*$", "^.*wacc-examples/valid/IO/print.*$", - "^.*wacc-examples/valid/IO/read(?!echoInt\\.wacc).*$", + "^.*wacc-examples/valid/IO/read.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", "^.*wacc-examples/valid/pairs.*$", From 24dddcadabe6764dcedd1e0e501cf145b48ad8e3 Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Sat, 22 Feb 2025 21:38:12 +0000 Subject: [PATCH 05/54] feat: almost complete clib calls --- src/main/wacc/backend/asmGenerator.scala | 204 +++++++++++++++++++---- src/main/wacc/backend/assemblyIR.scala | 11 +- src/test/wacc/examples.scala | 38 ++--- 3 files changed, 198 insertions(+), 55 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index ab350f4..7964cf4 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -15,12 +15,15 @@ object asmGenerator { val progAsm = LabelDef("main") :: + funcPrologue() ++ + alignStack() ++ main.flatMap(generateStmt) ++ + funcEpilogue() ++ List(assemblyIR.Return()) ++ generateFuncs() val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => - List(Directive.Int(str.size), LabelDef(s".L.str$i:"), Directive.Asciz(str)) + List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) } List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ @@ -42,31 +45,47 @@ object asmGenerator { )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { case microWacc.Call(Builtin.Exit, code :: _) => - alignStack() ++ - evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++ + // alignStack() ++ + evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++ List(assemblyIR.Call(CLibFunc.Exit)) case microWacc.Call(Builtin.Println, expr :: _) => - alignStack() ++ - evalExprIntoReg(expr, Register(RegSize.R64, RegName.DI)) ++ - List( - assemblyIR.Call(CLibFunc.Puts), - Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) ++ - restoreStack() + // alignStack() ++ + printF(expr) ++ + printLn() - case microWacc.Call(Builtin.ReadInt, expr :: _) => - List() + case microWacc.Call(Builtin.Print, expr :: _) => + // alignStack() ++ + printF(expr) case Assign(lhs, rhs) => - lhs match { + var dest: IndexAddress = + IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten + (lhs match { case ident: Ident => - stack += (ident -> stack.size) - evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ - List(Push(Register(RegSize.R64, RegName.AX))) - case _ => List() - } + if (!stack.contains(ident)) { + stack += (ident -> (stack.size + 1)) + dest = accessVar(ident) + List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16))) + } else { + dest = accessVar(ident) + List() + } + // TODO lhs = arrayElem + case _ => + // dest = ??? + List() + }) ++ + (rhs match { + case microWacc.Call(Builtin.ReadInt, _) => + readIntoVar(dest, Builtin.ReadInt) + case microWacc.Call(Builtin.ReadChar, _) => + readIntoVar(dest, Builtin.ReadChar) + case _ => + evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ + List(Move(dest, Register(RegSize.R64, RegName.AX))) + }) + // TODO other statements case _ => List() } @@ -74,22 +93,20 @@ object asmGenerator { stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { - var src: Src = ImmediateVal(0) // Placeholder - (expr match { + expr match { case IntLiter(v) => - src = ImmediateVal(v) - List() + List(Move(dest, ImmediateVal(v))) + case CharLiter(v) => + List(Move(dest, ImmediateVal(v.toInt))) case ident: Ident => - List( - Move( - dest, - IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 4) - ) - ) + List(Move(dest, accessVar(ident))) case ArrayLiter(elems) => expr.ty match { - case KnownType.Char => - strings += elems.mkString + case KnownType.String => + strings += elems.map { + case CharLiter(v) => v + case _ => "" + }.mkString List( Load( dest, @@ -99,22 +116,59 @@ object asmGenerator { ) ) ) + // TODO other array types case _ => List() } + // TODO other expr types case _ => List() - }) ++ List(Move(dest, src)) + } } + // TODO make sure EOF doenst override the value in the stack + // probably need labels implemented for conditional jumps + def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { + readType match { + case Builtin.ReadInt => + strings += PrintFormat.Int.toString + case Builtin.ReadChar => + strings += PrintFormat.Char.toString + } + List( + Load( + Register(RegSize.R64, RegName.DI), + IndexAddress( + Register(RegSize.R64, RegName.IP), + LabelArg(s".L.str${strings.size - 1}") + ) + ), + Load(Register(RegSize.R64, RegName.SI), dest) + ) ++ + // alignStack() ++ + List(assemblyIR.Call(CLibFunc.Scanf)) + + } + + def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = + IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16) + def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { List( - And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)), - // Store stack pointer in rbp as it is callee saved + And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)) + ) + } + + // Missing a sub instruction but dont think we need it + def funcPrologue(): List[AsmLine] = { + List( Push(Register(RegSize.R64, RegName.BP)), Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) ) } - def restoreStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { + def funcEpilogue(): List[AsmLine] = { List( Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), Pop(Register(RegSize.R64, RegName.BP)) @@ -123,4 +177,84 @@ object asmGenerator { // def saveRegs(regList: List[Register]): List[AsmLine] = regList.map(Push(_)) // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_)) + +// TODO: refactor, really ugly function + def printF(expr: Expr)(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { +// determine the format string + expr.ty match { + case KnownType.String => + strings += PrintFormat.String.toString + case KnownType.Char => + strings += PrintFormat.Char.toString + case KnownType.Int => + strings += PrintFormat.Int.toString + case _ => + strings += PrintFormat.String.toString + } + List( + Load( + Register(RegSize.R64, RegName.DI), + IndexAddress( + Register(RegSize.R64, RegName.IP), + LabelArg(s".L.str${strings.size - 1}") + ) + ) + ) + ++ + // determine the actual value to print + (if (expr.ty == KnownType.Bool) { + expr match { + case BoolLiter(true) => { + strings += "true" + } + case _ => { + strings += "false" + } + } + List( + Load( + Register(RegSize.R64, RegName.DI), + IndexAddress( + Register(RegSize.R64, RegName.IP), + LabelArg(s".L.str${strings.size - 1}") + ) + ) + ) + + } else { + evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI)) + }) + // print the value + ++ + List( + assemblyIR.Call(CLibFunc.PrintF), + Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) + } + +// prints a new line + def printLn()(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { + strings += "" + Load( + Register(RegSize.R64, RegName.DI), + IndexAddress( + Register(RegSize.R64, RegName.IP), + LabelArg(s".L.str${strings.size - 1}") + ) + ) + :: + List( + assemblyIR.Call(CLibFunc.Puts), + Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) + + } } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 73cdeaf..c48daac 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -143,8 +143,17 @@ object assemblyIR { case Text => ".text" case RoData => ".section .rodata" case Int(value) => s".int $value" - case Asciz(string) => s".asciz $string" + case Asciz(string) => s".asciz \"$string\"" + } + } + enum PrintFormat { + case Int, Char, String + + override def toString(): String = this match { + case Int => "%d" + case Char => "%c" + case String => "%s" } } } diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 970bcb6..abf6769 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -47,7 +47,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral .drop(outputLineIdx + 1) .takeWhile(_.startsWith("#")) .map(_.stripPrefix("#").stripLeading) - .mkString("\n") + .mkString("") val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$")) val expectedExit = @@ -79,24 +79,24 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral Seq( // format: off // disable formatting to avoid binPack - "^.*wacc-examples/valid/advanced.*$", - "^.*wacc-examples/valid/array.*$", - "^.*wacc-examples/valid/basic/exit.*$", - "^.*wacc-examples/valid/basic/skip.*$", - "^.*wacc-examples/valid/expressions.*$", - "^.*wacc-examples/valid/function/nested_functions.*$", - "^.*wacc-examples/valid/function/simple_functions.*$", - "^.*wacc-examples/valid/if.*$", - "^.*wacc-examples/valid/IO/print.*$", - "^.*wacc-examples/valid/IO/read.*$", - "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", - "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - "^.*wacc-examples/valid/pairs.*$", - "^.*wacc-examples/valid/runtimeErr.*$", - "^.*wacc-examples/valid/scope.*$", - "^.*wacc-examples/valid/sequence.*$", - "^.*wacc-examples/valid/variables.*$", - "^.*wacc-examples/valid/while.*$", + // "^.*wacc-examples/valid/advanced.*$", + // "^.*wacc-examples/valid/array.*$", + // "^.*wacc-examples/valid/basic/exit.*$", + // "^.*wacc-examples/valid/basic/skip.*$", + // "^.*wacc-examples/valid/expressions.*$", + // "^.*wacc-examples/valid/function/nested_functions.*$", + // "^.*wacc-examples/valid/function/simple_functions.*$", + // "^.*wacc-examples/valid/if.*$", + // "^.*wacc-examples/valid/IO/print.*$", + // "^.*wacc-examples/valid/IO/read.*$", + // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", + // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", + // "^.*wacc-examples/valid/pairs.*$", + // "^.*wacc-examples/valid/runtimeErr.*$", + // "^.*wacc-examples/valid/scope.*$", + // "^.*wacc-examples/valid/sequence.*$", + // "^.*wacc-examples/valid/variables.*$", + // "^.*wacc-examples/valid/while.*$", // format: on ).find(filename.matches).isDefined } From 1255a2e74c9a24ac045093dbec89e64a803130fa Mon Sep 17 00:00:00 2001 From: Guy C Date: Sat, 22 Feb 2025 22:53:17 +0000 Subject: [PATCH 06/54] feat: add initialization of AX register in function prologue --- src/main/wacc/backend/asmGenerator.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 7964cf4..df863a7 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -164,7 +164,8 @@ object asmGenerator { def funcPrologue(): List[AsmLine] = { List( Push(Register(RegSize.R64, RegName.BP)), - Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) + Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)), + Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)) ) } From 82230a5f66146543b1df7a0bcba1f28eb51210cb Mon Sep 17 00:00:00 2001 From: Guy C Date: Sat, 22 Feb 2025 22:53:42 +0000 Subject: [PATCH 07/54] refactor: remove unused import of IntLiter in Main.scala --- src/main/wacc/Main.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 89cfd98..52c40aa 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -7,7 +7,6 @@ import java.io.File import java.io.PrintStream import assemblyIR as asm -import wacc.microWacc.IntLiter case class CliConfig( file: File = new File(".") From c59c28ecbdbb5d7437e6c6fc0090c210863ba8e6 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Sat, 22 Feb 2025 23:38:19 +0000 Subject: [PATCH 08/54] test: retrieve raw stdout in example tests, rather than lines --- src/test/wacc/examples.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index abf6769..e7397b6 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -6,6 +6,7 @@ import org.scalatest.Inspectors.forEvery import java.io.File import sys.process._ import java.io.PrintStream +import scala.io.Source class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with ParallelTestExecution { val files = @@ -47,7 +48,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral .drop(outputLineIdx + 1) .takeWhile(_.startsWith("#")) .map(_.stripPrefix("#").stripLeading) - .mkString("") + .mkString("\n") val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$")) val expectedExit = @@ -62,11 +63,16 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral // Run the executable with the provided input val stdout = new StringBuilder - // val execResult = s"$execFilename".!(ProcessLogger(stdout.append(_))) - val execResult = - s"echo $inputLine" #| s"timeout 5s $execFilename" ! ProcessLogger(stdout.append(_)) + val process = s"timeout 5s $execFilename" run ProcessIO( + in = w => { + w.write(inputLine.getBytes) + w.close() + }, + out = Source.fromInputStream(_).addString(stdout), + err = _ => () + ) - assert(execResult == expectedExit) + assert(process.exitValue == expectedExit) assert(stdout.toString == expectedOutput) } } From dc61b1e390ac4a33e096e96122cf7cf8b1603186 Mon Sep 17 00:00:00 2001 From: Guy C Date: Mon, 24 Feb 2025 02:00:35 +0000 Subject: [PATCH 09/54] feat: implement label generation and basic conditional branching in asmGenerator --- src/main/wacc/backend/asmGenerator.scala | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index df863a7..55da169 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,6 +3,13 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer +object labelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } +} object asmGenerator { import microWacc._ import assemblyIR._ @@ -85,7 +92,18 @@ object asmGenerator { evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ List(Move(dest, Register(RegSize.R64, RegName.AX))) }) - // TODO other statements + case If(cond, thenBranch, elseBranch) => { + val elseLabel = labelGenerator.getLabel() + val endLabel = labelGenerator.getLabel() + evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ + List(Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Jump(LabelArg(elseLabel), Cond.Equal)) ++ + thenBranch.flatMap(generateStmt) ++ + List(Jump(LabelArg(endLabel)), + LabelDef(elseLabel)) ++ + elseBranch.flatMap(generateStmt) ++ + List(LabelDef(endLabel)) + } case _ => List() } @@ -120,6 +138,7 @@ object asmGenerator { case _ => List() } // TODO other expr types + case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0))) case _ => List() } } @@ -165,12 +184,12 @@ object asmGenerator { List( Push(Register(RegSize.R64, RegName.BP)), Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)), - Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)) ) } def funcEpilogue(): List[AsmLine] = { List( + Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), Pop(Register(RegSize.R64, RegName.BP)) ) From 2bed722a4fe0d5219f56fa6c293d4ab107917dd4 Mon Sep 17 00:00:00 2001 From: Guy C Date: Mon, 24 Feb 2025 02:04:21 +0000 Subject: [PATCH 10/54] style: improve code formatting and readability in asmGenerator --- src/main/wacc/backend/asmGenerator.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 55da169..ec16435 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -96,11 +96,12 @@ object asmGenerator { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ - List(Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), - Jump(LabelArg(elseLabel), Cond.Equal)) ++ - thenBranch.flatMap(generateStmt) ++ - List(Jump(LabelArg(endLabel)), - LabelDef(elseLabel)) ++ + List( + Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Jump(LabelArg(elseLabel), Cond.Equal) + ) ++ + thenBranch.flatMap(generateStmt) ++ + List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++ elseBranch.flatMap(generateStmt) ++ List(LabelDef(endLabel)) } @@ -139,7 +140,7 @@ object asmGenerator { } // TODO other expr types case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0))) - case _ => List() + case _ => List() } } @@ -183,7 +184,7 @@ object asmGenerator { def funcPrologue(): List[AsmLine] = { List( Push(Register(RegSize.R64, RegName.BP)), - Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)), + Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) ) } From 909114bdce3930b3f2433903cb5690a584750684 Mon Sep 17 00:00:00 2001 From: Guy C Date: Mon, 24 Feb 2025 04:47:21 +0000 Subject: [PATCH 11/54] feat: implement basic while loop generation in asmGenerator --- src/main/wacc/backend/asmGenerator.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index ec16435..c8981af 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -105,6 +105,18 @@ object asmGenerator { elseBranch.flatMap(generateStmt) ++ List(LabelDef(endLabel)) } + case While(cond, body) => { + val startLabel = labelGenerator.getLabel() + val endLabel = labelGenerator.getLabel() + List(LabelDef(startLabel)) ++ + evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ + List( + Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Jump(LabelArg(endLabel), Cond.Equal) + ) ++ + body.flatMap(generateStmt) ++ + List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) + } case _ => List() } From 9d78caf6d918489b04bb9add2a14254a876e2d51 Mon Sep 17 00:00:00 2001 From: Guy C Date: Mon, 24 Feb 2025 18:57:13 +0000 Subject: [PATCH 12/54] feat: add support for return statements in asmGenerator --- src/main/wacc/backend/asmGenerator.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index c8981af..16b9a41 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -117,6 +117,8 @@ object asmGenerator { body.flatMap(generateStmt) ++ List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) } + case microWacc.Return(expr) => + evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX)) case _ => List() } From 668d7338aec029bf789ce8c0e299943b8299aa5c Mon Sep 17 00:00:00 2001 From: Guy C Date: Mon, 24 Feb 2025 19:47:06 +0000 Subject: [PATCH 13/54] feat: move default return out of functionEpilogue into main def --- src/main/wacc/backend/asmGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 16b9a41..4c1a924 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -25,6 +25,7 @@ object asmGenerator { funcPrologue() ++ alignStack() ++ main.flatMap(generateStmt) ++ + List(Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0))) ++ funcEpilogue() ++ List(assemblyIR.Return()) ++ generateFuncs() @@ -204,7 +205,6 @@ object asmGenerator { def funcEpilogue(): List[AsmLine] = { List( - Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), Pop(Register(RegSize.R64, RegName.BP)) ) From 148828122363f92a36e6a4110e907e77ccb64a46 Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 25 Feb 2025 00:00:12 +0000 Subject: [PATCH 14/54] feat: implements binary operators in asmGenerator Co-authored-by: Gleb Koval Co-authored-by: Barf-Vader --- src/main/wacc/backend/asmGenerator.scala | 304 ++++++++++++++++------- src/main/wacc/backend/assemblyIR.scala | 41 ++- src/main/wacc/frontend/microWacc.scala | 14 +- src/main/wacc/frontend/typeChecker.scala | 35 ++- 4 files changed, 285 insertions(+), 109 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 4c1a924..1c3a2a9 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -2,19 +2,37 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer +import cats.syntax.align -object labelGenerator { - var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } -} object asmGenerator { import microWacc._ import assemblyIR._ import wacc.types._ + val RAX = Register(RegSize.R64, RegName.AX) + val EAX = Register(RegSize.E32, RegName.AX) + val RSP = Register(RegSize.R64, RegName.SP) + val ESP = Register(RegSize.E32, RegName.SP) + val EDX = Register(RegSize.E32, RegName.DX) + val RDI = Register(RegSize.R64, RegName.DI) + val RIP = Register(RegSize.R64, RegName.IP) + val RBP = Register(RegSize.R64, RegName.BP) + val RSI = Register(RegSize.R64, RegName.SI) + + val _8_BIT_MASK = 0xFF + + object labelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + def getLabel(target: CallTarget): String = target match{ + case Ident(v,_) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + } + def generateAsm(microProg: Program): List[AsmLine] = { given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]() given strings: ListBuffer[String] = ListBuffer[String]() @@ -25,10 +43,8 @@ object asmGenerator { funcPrologue() ++ alignStack() ++ main.flatMap(generateStmt) ++ - List(Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0))) ++ - funcEpilogue() ++ - List(assemblyIR.Return()) ++ - generateFuncs() + List(Move(RAX, ImmediateVal(0))) ++ + funcEpilogue() val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) @@ -40,12 +56,38 @@ object asmGenerator { progAsm } -//TODO - def generateFuncs()(using + def wrapFunc(labelName: String, funcBody: List[AsmLine])(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { + LabelDef(labelName) :: + funcPrologue() ++ + funcBody ++ + funcEpilogue() + } + + def generateBuiltInFuncs()(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { + wrapFunc(labelGenerator.getLabel(Builtin.Exit), + alignStack() ++ + List(Pop(RDI), + assemblyIR.Call(CLibFunc.Exit)) + ) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Printf), + alignStack() ++ + List(assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) + )++ + wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List() + )++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), + List() + ) } def generateStmt( @@ -53,28 +95,16 @@ object asmGenerator { )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { case microWacc.Call(Builtin.Exit, code :: _) => - // alignStack() ++ - evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++ - List(assemblyIR.Call(CLibFunc.Exit)) - - case microWacc.Call(Builtin.Println, expr :: _) => - // alignStack() ++ - printF(expr) ++ - printLn() - - case microWacc.Call(Builtin.Print, expr :: _) => - // alignStack() ++ - printF(expr) - + List() case Assign(lhs, rhs) => var dest: IndexAddress = - IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten + IndexAddress(RSP, 0) // gets overrwitten (lhs match { case ident: Ident => if (!stack.contains(ident)) { stack += (ident -> (stack.size + 1)) dest = accessVar(ident) - List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16))) + List(Subtract(RSP, ImmediateVal(8))) } else { dest = accessVar(ident) List() @@ -90,15 +120,16 @@ object asmGenerator { case microWacc.Call(Builtin.ReadChar, _) => readIntoVar(dest, Builtin.ReadChar) case _ => - evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ - List(Move(dest, Register(RegSize.R64, RegName.AX))) + evalExprOntoStack(rhs) ++ + List(Pop(dest)) }) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ + evalExprOntoStack(cond) ++ List( - Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), + Add(RSP, ImmediateVal(8)), Jump(LabelArg(elseLabel), Cond.Equal) ) ++ thenBranch.flatMap(generateStmt) ++ @@ -110,30 +141,33 @@ object asmGenerator { val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() List(LabelDef(startLabel)) ++ - evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ + evalExprOntoStack(cond) ++ List( - Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), + Add(RSP, ImmediateVal(8)), Jump(LabelArg(endLabel), Cond.Equal) ) ++ body.flatMap(generateStmt) ++ List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) } case microWacc.Return(expr) => - evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX)) + evalExprOntoStack(expr) ++ + List(Pop(RAX), assemblyIR.Return()) + case _ => List() } - def evalExprIntoReg(expr: Expr, dest: Register)(using + def evalExprOntoStack(expr: Expr)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { expr match { case IntLiter(v) => - List(Move(dest, ImmediateVal(v))) + List(Push(ImmediateVal(v))) case CharLiter(v) => - List(Move(dest, ImmediateVal(v.toInt))) + List(Push(ImmediateVal(v.toInt))) case ident: Ident => - List(Move(dest, accessVar(ident))) + List(Push(accessVar(ident))) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -143,70 +177,173 @@ object asmGenerator { }.mkString List( Load( - dest, + RAX, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) - ) + ), + Push(RAX) ) // TODO other array types case _ => List() } - // TODO other expr types - case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0))) - case _ => List() + case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => List(Push(ImmediateVal(0))) + case ArrayElem(value, indices) => List() + case UnaryOp(x, op) => op match { + // TODO: chr and ord are TYPE CASTS. They do not change the internal value, + // but will need bound checking e.t.c. + case UnaryOperator.Chr => List() + case UnaryOperator.Ord => List() + case UnaryOperator.Len => List() + case UnaryOperator.Negate => List( + Negate(MemLocation(RSP,SizeDir.Word)) + ) + case UnaryOperator.Not => + evalExprOntoStack(x) ++ + List( + Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1)) + ) + + } + case BinaryOp(x, y, op) => op match { + case BinaryOperator.Add => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Add(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Sub => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Subtract(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Mul => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Multiply(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Div => + evalExprOntoStack(y) ++ + evalExprOntoStack(x) ++ + List( + Pop(EAX), + Divide(MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(EAX) + // TODO CHECK DIVISOR IS NOT 0 + ) + case BinaryOperator.Mod => + evalExprOntoStack(y) ++ + evalExprOntoStack(x) ++ + List( + Pop(EAX), + Divide(MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(EDX) + // TODO CHECK DIVISOR IS NOT 0 + ) + case BinaryOperator.Eq => + generateComparison(x, y, Cond.Equal) + case BinaryOperator.Neq => + generateComparison(x, y, Cond.NotEqual) + case BinaryOperator.Greater => + generateComparison(x, y, Cond.Greater) + case BinaryOperator.GreaterEq => + generateComparison(x, y, Cond.GreaterEqual) + case BinaryOperator.Less => + generateComparison(x, y, Cond.Less) + case BinaryOperator.LessEq => + generateComparison(x, y, Cond.LessEqual) + case BinaryOperator.And => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + And(MemLocation(RSP, SizeDir.Word), EAX), + ) + case BinaryOperator.Or => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Or(MemLocation(RSP, SizeDir.Word), EAX), + ) + } + case microWacc.Call(target, args) => List() } } - // TODO make sure EOF doenst override the value in the stack - // probably need labels implemented for conditional jumps - def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using + // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using + // stack: LinkedHashMap[Ident, Int], + // strings: ListBuffer[String] + // ): List[AsmLine] = { + // readType match { + // case Builtin.ReadInt => + // strings += PrintFormat.Int.toString + // case Builtin.ReadChar => + // strings += PrintFormat.Char.toString + // } + // List( + // Load( + // RDI, + // IndexAddress( + // RIP, + // LabelArg(s".L.str${strings.size - 1}") + // ) + // ), + // Load(RSI, dest) + // ) ++ + // // alignStack() ++ + // List(assemblyIR.Call(CLibFunc.Scanf)) + + // } + + def generateComparison(x : Expr, y: Expr, cond: Cond)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { - readType match { - case Builtin.ReadInt => - strings += PrintFormat.Int.toString - case Builtin.ReadChar => - strings += PrintFormat.Char.toString - } - List( - Load( - Register(RegSize.R64, RegName.DI), - IndexAddress( - Register(RegSize.R64, RegName.IP), - LabelArg(s".L.str${strings.size - 1}") - ) - ), - Load(Register(RegSize.R64, RegName.SI), dest) - ) ++ - // alignStack() ++ - List(assemblyIR.Call(CLibFunc.Scanf)) - + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Compare(MemLocation(RSP, SizeDir.Word), EAX), + Set(Register(RegSize.Byte, RegName.AL), cond), + And(EAX, ImmediateVal(_8_BIT_MASK)), + Push(EAX) + ) } - def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = - IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16) + IndexAddress(RSP, (stack.size - stack(ident)) * 8) def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { List( - And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)) + And(RSP, ImmediateVal(-16)) ) } // Missing a sub instruction but dont think we need it def funcPrologue(): List[AsmLine] = { List( - Push(Register(RegSize.R64, RegName.BP)), - Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) + Push(RBP), + Move(RBP, RSP) ) } def funcEpilogue(): List[AsmLine] = { List( - Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), - Pop(Register(RegSize.R64, RegName.BP)) + Move(RSP, RBP), + Pop(RBP), + assemblyIR.Return() ) } @@ -231,9 +368,9 @@ object asmGenerator { } List( Load( - Register(RegSize.R64, RegName.DI), + RDI, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) ) @@ -251,22 +388,23 @@ object asmGenerator { } List( Load( - Register(RegSize.R64, RegName.DI), + RDI, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) ) ) } else { - evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI)) + evalExprOntoStack(expr) ++ + List(Pop(RSI)) }) // print the value ++ List( assemblyIR.Call(CLibFunc.PrintF), - Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + Move(RDI, ImmediateVal(0)), assemblyIR.Call(CLibFunc.Fflush) ) } @@ -278,16 +416,16 @@ object asmGenerator { ): List[AsmLine] = { strings += "" Load( - Register(RegSize.R64, RegName.DI), + RDI, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) ) :: List( assemblyIR.Call(CLibFunc.Puts), - Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + Move(RDI, ImmediateVal(0)), assemblyIR.Call(CLibFunc.Fflush) ) diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index c48daac..22ca36b 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -9,17 +9,20 @@ object assemblyIR { enum RegSize { case R64 case E32 + case Byte override def toString = this match { case R64 => "r" case E32 => "e" + case Byte => "" } } enum RegName { - case AX, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 + case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 override def toString = this match { case AX => "ax" + case AL => "al" case BX => "bx" case CX => "cx" case DX => "dx" @@ -42,7 +45,6 @@ object assemblyIR { // arguments enum CLibFunc extends Operand { case Scanf, - Puts, Fflush, Exit, PrintF @@ -51,24 +53,24 @@ object assemblyIR { override def toString = this match { case Scanf => "scanf" + plt - case Puts => "puts" + plt case Fflush => "fflush" + plt case Exit => "exit" + plt case PrintF => "printf" + plt } } + //TODO register naming conventions are wrong case class Register(size: RegSize, name: RegName) extends Dest with Src { override def toString = s"${size}${name}" } - case class MemLocation(pointer: Long | Register) extends Dest with Src { + case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src { override def toString = pointer match { - case hex: Long => f"[0x$hex%X]" - case reg: Register => s"[$reg]" + case hex: Long => opSize.toString + f"[0x$hex%X]" + case reg: Register => opSize.toString + s"[$reg]" } } - case class IndexAddress(base: Register, offset: Int | LabelArg) extends Dest with Src { - override def toString = s"[$base + $offset]" + case class IndexAddress(base: Register, offset: Int | LabelArg, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src { + override def toString = s"$opSize[$base + $offset]" } case class ImmediateVal(value: Int) extends Src { @@ -85,11 +87,13 @@ object assemblyIR { } case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2) case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2) - case class Multiply(ops: Operand*) extends Operation("mul", ops*) - case class Divide(op1: Src) extends Operation("div", op1) + case class Multiply(ops: Operand*) extends Operation("imul", ops*) + case class Divide(op1: Src) extends Operation("idiv", op1) + case class Negate(op: Dest) extends Operation("neg", op) case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2) case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2) + case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2) case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2) // stack operations @@ -106,6 +110,9 @@ object assemblyIR { case class Jump(op1: LabelArg, condition: Cond = Cond.Always) extends Operation(s"j${condition.toString}", op1) + case class Set(op1: Dest, condition: Cond = Cond.Always) + extends Operation(s"set${condition.toString}", op1) + case class LabelDef(name: String) extends AsmLine { override def toString = s"$name:" } @@ -156,4 +163,16 @@ object assemblyIR { case String => "%s" } } -} + + enum SizeDir { + case Byte, Word, Unspecified + + private val ptr = "ptr " + + override def toString(): String = this match { + case Byte => "byte " + ptr + case Word => "word " + ptr + case Unspecified => "" + } + } +} \ No newline at end of file diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index b9f6635..c558b6d 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -69,13 +69,15 @@ object microWacc { // Statements sealed trait Stmt + case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) { + override def toString(): String = name + } object Builtin { - case object ReadInt extends CallTarget(KnownType.Int) - case object ReadChar extends CallTarget(KnownType.Char) - case object Print extends CallTarget(?) - case object Println extends CallTarget(?) - case object Exit extends CallTarget(?) - case object Free extends CallTarget(?) + object Read extends Builtin("read")(?) + object Printf extends Builtin("printf")(?) + object Exit extends Builtin("exit")(?) + object Free extends Builtin("free")(?) + object Malloc extends Builtin("malloc")(?) } case class Assign(lhs: LValue, rhs: Expr) extends Stmt diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index e3960cb..b854272 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -177,12 +177,14 @@ object typeChecker { microWacc.Assign( destTyped, microWacc.Call( - destTy match { - case KnownType.Int => microWacc.Builtin.ReadInt - case KnownType.Char => microWacc.Builtin.ReadChar - case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway - }, - Nil + microWacc.Builtin.Read, + List( + destTy match { + case KnownType.Int => "%d".toMicroWaccCharArray + case KnownType.Char | _ => "%c".toMicroWaccCharArray + }, + destTyped + ) ) ) ) @@ -213,10 +215,20 @@ object typeChecker { ) case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already + val exprTyped = checkValue(expr, Constraint.Unconstrained) + val format = exprTyped.ty match { + case KnownType.Bool | KnownType.String => "%s" + case KnownType.Char => "%c" + case KnownType.Int => "%d" + case _ => "%p" + } List( microWacc.Call( - if newline then microWacc.Builtin.Println else microWacc.Builtin.Print, - List(checkValue(expr, Constraint.Unconstrained)) + microWacc.Builtin.Printf, + List( + s"$format${if newline then "\n" else ""}".toMicroWaccCharArray, + exprTyped + ) ) ) case ast.If(cond, thenStmt, elseStmt) => @@ -262,7 +274,7 @@ object typeChecker { microWacc.CharLiter(v) case l @ ast.StrLiter(v) => KnownType.String.satisfies(constraint, l.pos) - microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String) + v.toMicroWaccCharArray case l @ ast.PairLiter() => microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) case ast.Parens(expr) => checkValue(expr, constraint) @@ -441,4 +453,9 @@ object typeChecker { case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) }) } + + extension (s: String) { + def toMicroWaccCharArray: microWacc.ArrayLiter = + microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String) + } } From f30cf42c4b73d50b79c7fe1e3829313839c7ecc8 Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 25 Feb 2025 00:05:10 +0000 Subject: [PATCH 15/54] style: improve code formatting and consistency in typeChecker and assemblyIR --- src/main/wacc/backend/asmGenerator.scala | 247 ++++++++++++----------- src/main/wacc/backend/assemblyIR.scala | 30 ++- src/main/wacc/frontend/typeChecker.scala | 8 +- 3 files changed, 147 insertions(+), 138 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 1c3a2a9..8d5e746 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -19,7 +19,7 @@ object asmGenerator { val RBP = Register(RegSize.R64, RegName.BP) val RSI = Register(RegSize.R64, RegName.SI) - val _8_BIT_MASK = 0xFF + val _8_BIT_MASK = 0xff object labelGenerator { var labelVal = -1 @@ -27,8 +27,8 @@ object asmGenerator { labelVal += 1 s".L$labelVal" } - def getLabel(target: CallTarget): String = target match{ - case Ident(v,_) => s"wacc_$v" + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" case Builtin(name) => s"_$name" } } @@ -61,33 +61,31 @@ object asmGenerator { strings: ListBuffer[String] ): List[AsmLine] = { LabelDef(labelName) :: - funcPrologue() ++ + funcPrologue() ++ funcBody ++ - funcEpilogue() + funcEpilogue() } def generateBuiltInFuncs()(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { - wrapFunc(labelGenerator.getLabel(Builtin.Exit), + wrapFunc( + labelGenerator.getLabel(Builtin.Exit), alignStack() ++ - List(Pop(RDI), - assemblyIR.Call(CLibFunc.Exit)) + List(Pop(RDI), assemblyIR.Call(CLibFunc.Exit)) ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Printf), - alignStack() ++ - List(assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) - )++ - wrapFunc(labelGenerator.getLabel(Builtin.Malloc), - List() - )++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), - List() - ) + wrapFunc( + labelGenerator.getLabel(Builtin.Printf), + alignStack() ++ + List( + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) + ) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List()) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) } def generateStmt( @@ -95,7 +93,7 @@ object asmGenerator { )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { case microWacc.Call(Builtin.Exit, code :: _) => - List() + List() case Assign(lhs, rhs) => var dest: IndexAddress = IndexAddress(RSP, 0) // gets overrwitten @@ -188,97 +186,100 @@ object asmGenerator { // TODO other array types case _ => List() } - case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => List(Push(ImmediateVal(0))) + case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => List(Push(ImmediateVal(0))) case ArrayElem(value, indices) => List() - case UnaryOp(x, op) => op match { - // TODO: chr and ord are TYPE CASTS. They do not change the internal value, - // but will need bound checking e.t.c. - case UnaryOperator.Chr => List() - case UnaryOperator.Ord => List() - case UnaryOperator.Len => List() - case UnaryOperator.Negate => List( - Negate(MemLocation(RSP,SizeDir.Word)) - ) - case UnaryOperator.Not => - evalExprOntoStack(x) ++ - List( - Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1)) - ) - - } - case BinaryOp(x, y, op) => op match { - case BinaryOperator.Add => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - List( - Pop(EAX), - Add(MemLocation(RSP, SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Sub => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - List( - Pop(EAX), - Subtract(MemLocation(RSP, SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Mul => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - List( - Pop(EAX), - Multiply(MemLocation(RSP, SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Div => - evalExprOntoStack(y) ++ - evalExprOntoStack(x) ++ - List( - Pop(EAX), - Divide(MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(EAX) - // TODO CHECK DIVISOR IS NOT 0 - ) - case BinaryOperator.Mod => - evalExprOntoStack(y) ++ - evalExprOntoStack(x) ++ - List( - Pop(EAX), - Divide(MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(EDX) - // TODO CHECK DIVISOR IS NOT 0 + case UnaryOp(x, op) => + op match { + // TODO: chr and ord are TYPE CASTS. They do not change the internal value, + // but will need bound checking e.t.c. + case UnaryOperator.Chr => List() + case UnaryOperator.Ord => List() + case UnaryOperator.Len => List() + case UnaryOperator.Negate => + List( + Negate(MemLocation(RSP, SizeDir.Word)) ) - case BinaryOperator.Eq => - generateComparison(x, y, Cond.Equal) - case BinaryOperator.Neq => - generateComparison(x, y, Cond.NotEqual) - case BinaryOperator.Greater => - generateComparison(x, y, Cond.Greater) - case BinaryOperator.GreaterEq => - generateComparison(x, y, Cond.GreaterEqual) - case BinaryOperator.Less => - generateComparison(x, y, Cond.Less) - case BinaryOperator.LessEq => - generateComparison(x, y, Cond.LessEqual) - case BinaryOperator.And => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - List( - Pop(EAX), - And(MemLocation(RSP, SizeDir.Word), EAX), - ) - case BinaryOperator.Or => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - List( - Pop(EAX), - Or(MemLocation(RSP, SizeDir.Word), EAX), - ) - } + case UnaryOperator.Not => + evalExprOntoStack(x) ++ + List( + Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1)) + ) + + } + case BinaryOp(x, y, op) => + op match { + case BinaryOperator.Add => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Add(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Sub => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Subtract(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Mul => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Multiply(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Div => + evalExprOntoStack(y) ++ + evalExprOntoStack(x) ++ + List( + Pop(EAX), + Divide(MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(EAX) + // TODO CHECK DIVISOR IS NOT 0 + ) + case BinaryOperator.Mod => + evalExprOntoStack(y) ++ + evalExprOntoStack(x) ++ + List( + Pop(EAX), + Divide(MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(EDX) + // TODO CHECK DIVISOR IS NOT 0 + ) + case BinaryOperator.Eq => + generateComparison(x, y, Cond.Equal) + case BinaryOperator.Neq => + generateComparison(x, y, Cond.NotEqual) + case BinaryOperator.Greater => + generateComparison(x, y, Cond.Greater) + case BinaryOperator.GreaterEq => + generateComparison(x, y, Cond.GreaterEqual) + case BinaryOperator.Less => + generateComparison(x, y, Cond.Less) + case BinaryOperator.LessEq => + generateComparison(x, y, Cond.LessEqual) + case BinaryOperator.And => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + And(MemLocation(RSP, SizeDir.Word), EAX) + ) + case BinaryOperator.Or => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Or(MemLocation(RSP, SizeDir.Word), EAX) + ) + } case microWacc.Call(target, args) => List() } } @@ -308,19 +309,19 @@ object asmGenerator { // } - def generateComparison(x : Expr, y: Expr, cond: Cond)(using + def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - List( - Pop(EAX), - Compare(MemLocation(RSP, SizeDir.Word), EAX), - Set(Register(RegSize.Byte, RegName.AL), cond), - And(EAX, ImmediateVal(_8_BIT_MASK)), - Push(EAX) - ) + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Compare(MemLocation(RSP, SizeDir.Word), EAX), + Set(Register(RegSize.Byte, RegName.AL), cond), + And(EAX, ImmediateVal(_8_BIT_MASK)), + Push(EAX) + ) } def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = IndexAddress(RSP, (stack.size - stack(ident)) * 8) @@ -397,8 +398,8 @@ object asmGenerator { ) } else { - evalExprOntoStack(expr) ++ - List(Pop(RSI)) + evalExprOntoStack(expr) ++ + List(Pop(RSI)) }) // print the value ++ diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 22ca36b..14e4f4f 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -12,14 +12,15 @@ object assemblyIR { case Byte override def toString = this match { - case R64 => "r" - case E32 => "e" + case R64 => "r" + case E32 => "e" case Byte => "" } } enum RegName { - case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 + case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, + Reg15 override def toString = this match { case AX => "ax" case AL => "al" @@ -59,17 +60,24 @@ object assemblyIR { } } - //TODO register naming conventions are wrong + // TODO register naming conventions are wrong case class Register(size: RegSize, name: RegName) extends Dest with Src { override def toString = s"${size}${name}" } - case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src { + case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) + extends Dest + with Src { override def toString = pointer match { case hex: Long => opSize.toString + f"[0x$hex%X]" case reg: Register => opSize.toString + s"[$reg]" } } - case class IndexAddress(base: Register, offset: Int | LabelArg, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src { + case class IndexAddress( + base: Register, + offset: Int | LabelArg, + opSize: SizeDir = SizeDir.Unspecified + ) extends Dest + with Src { override def toString = s"$opSize[$base + $offset]" } @@ -111,7 +119,7 @@ object assemblyIR { extends Operation(s"j${condition.toString}", op1) case class Set(op1: Dest, condition: Cond = Cond.Always) - extends Operation(s"set${condition.toString}", op1) + extends Operation(s"set${condition.toString}", op1) case class LabelDef(name: String) extends AsmLine { override def toString = s"$name:" @@ -165,14 +173,14 @@ object assemblyIR { } enum SizeDir { - case Byte, Word, Unspecified + case Byte, Word, Unspecified private val ptr = "ptr " override def toString(): String = this match { - case Byte => "byte " + ptr - case Word => "word " + ptr + case Byte => "byte " + ptr + case Word => "word " + ptr case Unspecified => "" } } -} \ No newline at end of file +} diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index b854272..c3f2ba8 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -180,7 +180,7 @@ object typeChecker { microWacc.Builtin.Read, List( destTy match { - case KnownType.Int => "%d".toMicroWaccCharArray + case KnownType.Int => "%d".toMicroWaccCharArray case KnownType.Char | _ => "%c".toMicroWaccCharArray }, destTyped @@ -218,9 +218,9 @@ object typeChecker { val exprTyped = checkValue(expr, Constraint.Unconstrained) val format = exprTyped.ty match { case KnownType.Bool | KnownType.String => "%s" - case KnownType.Char => "%c" - case KnownType.Int => "%d" - case _ => "%p" + case KnownType.Char => "%c" + case KnownType.Int => "%d" + case _ => "%p" } List( microWacc.Call( From 58d280462ec79935db560b7f85729d4decc1e8c0 Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 25 Feb 2025 02:02:57 +0000 Subject: [PATCH 16/54] feat: enhance asmGenerator with additional registers and improve function call generation Co-authored-by: Barf-Vader Co-authored-by: Gleb Koval --- src/main/wacc/backend/asmGenerator.scala | 241 +++++++++++++---------- src/main/wacc/backend/assemblyIR.scala | 2 +- src/main/wacc/frontend/lexer.scala | 2 +- 3 files changed, 136 insertions(+), 109 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 8d5e746..2a4a2ad 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -2,7 +2,6 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer -import cats.syntax.align object asmGenerator { import microWacc._ @@ -18,6 +17,10 @@ object asmGenerator { val RIP = Register(RegSize.R64, RegName.IP) val RBP = Register(RegSize.R64, RegName.BP) val RSI = Register(RegSize.R64, RegName.SI) + val RDX = Register(RegSize.R64, RegName.DX) + val RCX = Register(RegSize.R64, RegName.CX) + val R8 = Register(RegSize.R64, RegName.Reg8) + val R9 = Register(RegSize.R64, RegName.Reg9) val _8_BIT_MASK = 0xff @@ -44,10 +47,11 @@ object asmGenerator { alignStack() ++ main.flatMap(generateStmt) ++ List(Move(RAX, ImmediateVal(0))) ++ - funcEpilogue() + funcEpilogue() ++ + generateBuiltInFuncs() val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => - List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) + List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str.replace("\"", "\\\""))) } List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ @@ -73,7 +77,7 @@ object asmGenerator { wrapFunc( labelGenerator.getLabel(Builtin.Exit), alignStack() ++ - List(Pop(RDI), assemblyIR.Call(CLibFunc.Exit)) + List(assemblyIR.Call(CLibFunc.Exit)) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Printf), @@ -84,16 +88,28 @@ object asmGenerator { assemblyIR.Call(CLibFunc.Fflush) ) ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List()) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) + wrapFunc( + labelGenerator.getLabel(Builtin.Malloc), + alignStack() ++ + List() + ) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ + wrapFunc( + labelGenerator.getLabel(Builtin.Read), + alignStack() ++ + List( + Push(RSI), + Load(RSI, MemLocation(RSP)), + assemblyIR.Call(CLibFunc.Scanf), + Pop(RAX) + ) + ) } def generateStmt( stmt: Stmt )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { - case microWacc.Call(Builtin.Exit, code :: _) => - List() case Assign(lhs, rhs) => var dest: IndexAddress = IndexAddress(RSP, 0) // gets overrwitten @@ -112,15 +128,8 @@ object asmGenerator { // dest = ??? List() }) ++ - (rhs match { - case microWacc.Call(Builtin.ReadInt, _) => - readIntoVar(dest, Builtin.ReadInt) - case microWacc.Call(Builtin.ReadChar, _) => - readIntoVar(dest, Builtin.ReadChar) - case _ => - evalExprOntoStack(rhs) ++ - List(Pop(dest)) - }) + evalExprOntoStack(rhs) ++ + List(Pop(dest)) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() @@ -151,8 +160,7 @@ object asmGenerator { case microWacc.Return(expr) => evalExprOntoStack(expr) ++ List(Pop(RAX), assemblyIR.Return()) - - case _ => List() + case call: microWacc.Call => generateCall(call) } def evalExprOntoStack(expr: Expr)(using @@ -213,7 +221,7 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Add(MemLocation(RSP, SizeDir.Word), EAX) // TODO OVERFLOWING ) @@ -221,7 +229,7 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Subtract(MemLocation(RSP, SizeDir.Word), EAX) // TODO OVERFLOWING ) @@ -229,28 +237,30 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), - Multiply(MemLocation(RSP, SizeDir.Word), EAX) + Pop(RAX), + Multiply(EAX, MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(RAX) // TODO OVERFLOWING ) case BinaryOperator.Div => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(EAX), + Pop(RAX), Divide(MemLocation(RSP, SizeDir.Word)), Add(RSP, ImmediateVal(8)), - Push(EAX) + Push(RAX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Mod => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(EAX), + Pop(RAX), Divide(MemLocation(RSP, SizeDir.Word)), Add(RSP, ImmediateVal(8)), - Push(EDX) + Push(RDX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Eq => @@ -269,21 +279,38 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), And(MemLocation(RSP, SizeDir.Word), EAX) ) case BinaryOperator.Or => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Or(MemLocation(RSP, SizeDir.Word), EAX) ) } - case microWacc.Call(target, args) => List() + case call: microWacc.Call => generateCall(call) } } + def generateCall(call: microWacc.Call)(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { + val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) + val microWacc.Call(target, args) = call + argRegs.zip(args).flatMap { (reg, expr) => + evalExprOntoStack(expr) ++ + List(Pop(reg)) + } ++ + args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ + List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ + (if (args.size > argRegs.size) { + List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8))) + } else Nil) + } + // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using // stack: LinkedHashMap[Ident, Int], // strings: ListBuffer[String] @@ -316,11 +343,11 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(EAX), + Pop(RAX), Compare(MemLocation(RSP, SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), And(EAX, ImmediateVal(_8_BIT_MASK)), - Push(EAX) + Push(RAX) ) } def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = @@ -352,83 +379,83 @@ object asmGenerator { // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_)) // TODO: refactor, really ugly function - def printF(expr: Expr)(using - stack: LinkedHashMap[Ident, Int], - strings: ListBuffer[String] - ): List[AsmLine] = { -// determine the format string - expr.ty match { - case KnownType.String => - strings += PrintFormat.String.toString - case KnownType.Char => - strings += PrintFormat.Char.toString - case KnownType.Int => - strings += PrintFormat.Int.toString - case _ => - strings += PrintFormat.String.toString - } - List( - Load( - RDI, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ) - ) - ++ - // determine the actual value to print - (if (expr.ty == KnownType.Bool) { - expr match { - case BoolLiter(true) => { - strings += "true" - } - case _ => { - strings += "false" - } - } - List( - Load( - RDI, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ) - ) +// def printF(expr: Expr)(using +// stack: LinkedHashMap[Ident, Int], +// strings: ListBuffer[String] +// ): List[AsmLine] = { +// // determine the format string +// expr.ty match { +// case KnownType.String => +// strings += PrintFormat.String.toString +// case KnownType.Char => +// strings += PrintFormat.Char.toString +// case KnownType.Int => +// strings += PrintFormat.Int.toString +// case _ => +// strings += PrintFormat.String.toString +// } +// List( +// Load( +// RDI, +// IndexAddress( +// RIP, +// LabelArg(s".L.str${strings.size - 1}") +// ) +// ) +// ) +// ++ +// // determine the actual value to print +// (if (expr.ty == KnownType.Bool) { +// expr match { +// case BoolLiter(true) => { +// strings += "true" +// } +// case _ => { +// strings += "false" +// } +// } +// List( +// Load( +// RDI, +// IndexAddress( +// RIP, +// LabelArg(s".L.str${strings.size - 1}") +// ) +// ) +// ) - } else { - evalExprOntoStack(expr) ++ - List(Pop(RSI)) - }) - // print the value - ++ - List( - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) - } +// } else { +// evalExprOntoStack(expr) ++ +// List(Pop(RSI)) +// }) +// // print the value +// ++ +// List( +// assemblyIR.Call(CLibFunc.PrintF), +// Move(RDI, ImmediateVal(0)), +// assemblyIR.Call(CLibFunc.Fflush) +// ) +// } // prints a new line - def printLn()(using - stack: LinkedHashMap[Ident, Int], - strings: ListBuffer[String] - ): List[AsmLine] = { - strings += "" - Load( - RDI, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ) - :: - List( - assemblyIR.Call(CLibFunc.Puts), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) + // def printLn()(using + // stack: LinkedHashMap[Ident, Int], + // strings: ListBuffer[String] + // ): List[AsmLine] = { + // strings += "" + // Load( + // RDI, + // IndexAddress( + // RIP, + // LabelArg(s".L.str${strings.size - 1}") + // ) + // ) + // :: + // List( + // assemblyIR.Call(CLibFunc.Puts), + // Move(RDI, ImmediateVal(0)), + // assemblyIR.Call(CLibFunc.Fflush) + // ) - } + // } } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 14e4f4f..0921fd8 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -179,7 +179,7 @@ object assemblyIR { override def toString(): String = this match { case Byte => "byte " + ptr - case Word => "word " + ptr + case Word => "dword " + ptr // TODO check word/doubleword/quadword case Unspecified => "" } } diff --git a/src/main/wacc/frontend/lexer.scala b/src/main/wacc/frontend/lexer.scala index 2efe517..e0b0a44 100644 --- a/src/main/wacc/frontend/lexer.scala +++ b/src/main/wacc/frontend/lexer.scala @@ -39,7 +39,7 @@ val errConfig = new ErrorConfig { ) } object lexer { - + /** Language description for the WACC lexer */ private val desc = LexicalDesc.plain.copy( From 8ed94e4df3e3a8e8bdcd6bfcb57ceb6ac6231bad Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Tue, 25 Feb 2025 03:17:05 +0000 Subject: [PATCH 17/54] fix: initial exprs on stack --- src/main/wacc/backend/asmGenerator.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 2a4a2ad..0e8eb55 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -111,8 +111,8 @@ object asmGenerator { )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { case Assign(lhs, rhs) => - var dest: IndexAddress = - IndexAddress(RSP, 0) // gets overrwitten + var dest: () => IndexAddress = + () => IndexAddress(RSP, 0) // gets overrwitten (lhs match { case ident: Ident => if (!stack.contains(ident)) { @@ -129,7 +129,9 @@ object asmGenerator { List() }) ++ evalExprOntoStack(rhs) ++ - List(Pop(dest)) + List(Pop(RAX), + Move(dest(), RAX), + ) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() @@ -173,7 +175,7 @@ object asmGenerator { case CharLiter(v) => List(Push(ImmediateVal(v.toInt))) case ident: Ident => - List(Push(accessVar(ident))) + List(Push(accessVar(ident)())) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -347,11 +349,12 @@ object asmGenerator { Compare(MemLocation(RSP, SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), And(EAX, ImmediateVal(_8_BIT_MASK)), + Load(RSP, IndexAddress(RSP, 8)), Push(RAX) ) } - def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = - IndexAddress(RSP, (stack.size - stack(ident)) * 8) + def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): () => IndexAddress = + () => IndexAddress(RSP, (stack.size - stack(ident)) * 8) def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { List( From 3f76a2c5bf29f7234cf7c9f794ff692bc1c2c5c5 Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Tue, 25 Feb 2025 04:44:08 +0000 Subject: [PATCH 18/54] refactor: extract stack into seperate class --- src/main/wacc/backend/asmGenerator.scala | 214 +++++++++++++---------- 1 file changed, 121 insertions(+), 93 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0e8eb55..deeaf8e 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -10,7 +10,6 @@ object asmGenerator { val RAX = Register(RegSize.R64, RegName.AX) val EAX = Register(RegSize.E32, RegName.AX) - val RSP = Register(RegSize.R64, RegName.SP) val ESP = Register(RegSize.E32, RegName.SP) val EDX = Register(RegSize.E32, RegName.DX) val RDI = Register(RegSize.R64, RegName.DI) @@ -37,14 +36,14 @@ object asmGenerator { } def generateAsm(microProg: Program): List[AsmLine] = { - given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]() + given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() val Program(funcs, main) = microProg val progAsm = LabelDef("main") :: funcPrologue() ++ - alignStack() ++ + List(stack.align()) ++ main.flatMap(generateStmt) ++ List(Move(RAX, ImmediateVal(0))) ++ funcEpilogue() ++ @@ -61,7 +60,7 @@ object asmGenerator { } def wrapFunc(labelName: String, funcBody: List[AsmLine])(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { LabelDef(labelName) :: @@ -71,74 +70,71 @@ object asmGenerator { } def generateBuiltInFuncs()(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { wrapFunc( labelGenerator.getLabel(Builtin.Exit), - alignStack() ++ - List(assemblyIR.Call(CLibFunc.Exit)) + List(stack.align(), assemblyIR.Call(CLibFunc.Exit)) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Printf), - alignStack() ++ - List( - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) + List( + stack.align(), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Malloc), - alignStack() ++ - List() + List( + stack.align(), + ) ) ++ wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ wrapFunc( labelGenerator.getLabel(Builtin.Read), - alignStack() ++ - List( - Push(RSI), - Load(RSI, MemLocation(RSP)), - assemblyIR.Call(CLibFunc.Scanf), - Pop(RAX) - ) + List( + stack.align(), + stack.push(RSI), + Load(RSI, stack.head), + assemblyIR.Call(CLibFunc.Scanf), + stack.pop(RAX) + ) ) } def generateStmt( stmt: Stmt - )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = + )(using stack: Stack, strings: ListBuffer[String]): List[AsmLine] = stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = - () => IndexAddress(RSP, 0) // gets overrwitten + () => IndexAddress(RAX, 0) // gets overrwitten (lhs match { case ident: Ident => + dest = stack.accessVar(ident) if (!stack.contains(ident)) { - stack += (ident -> (stack.size + 1)) - dest = accessVar(ident) - List(Subtract(RSP, ImmediateVal(8))) - } else { - dest = accessVar(ident) - List() - } + List(stack.reserve(ident)) + } else Nil // TODO lhs = arrayElem case _ => // dest = ??? List() }) ++ evalExprOntoStack(rhs) ++ - List(Pop(RAX), - Move(dest(), RAX), + List( + stack.pop(RAX), + Move(dest(), RAX), ) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() evalExprOntoStack(cond) ++ List( - Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), - Add(RSP, ImmediateVal(8)), + Compare(stack.head(SizeDir.Word), ImmediateVal(0)), + stack.drop(), Jump(LabelArg(elseLabel), Cond.Equal) ) ++ thenBranch.flatMap(generateStmt) ++ @@ -152,8 +148,8 @@ object asmGenerator { List(LabelDef(startLabel)) ++ evalExprOntoStack(cond) ++ List( - Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), - Add(RSP, ImmediateVal(8)), + Compare(stack.head(SizeDir.Word), ImmediateVal(0)), + stack.drop(), Jump(LabelArg(endLabel), Cond.Equal) ) ++ body.flatMap(generateStmt) ++ @@ -161,21 +157,21 @@ object asmGenerator { } case microWacc.Return(expr) => evalExprOntoStack(expr) ++ - List(Pop(RAX), assemblyIR.Return()) + List(stack.pop(RAX), assemblyIR.Return()) case call: microWacc.Call => generateCall(call) } def evalExprOntoStack(expr: Expr)(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { expr match { case IntLiter(v) => - List(Push(ImmediateVal(v))) + List(stack.push(ImmediateVal(v))) case CharLiter(v) => - List(Push(ImmediateVal(v.toInt))) + List(stack.push(ImmediateVal(v.toInt))) case ident: Ident => - List(Push(accessVar(ident)())) + List(stack.push(stack.accessVar(ident)())) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -191,13 +187,13 @@ object asmGenerator { LabelArg(s".L.str${strings.size - 1}") ) ), - Push(RAX) + stack.push(RAX) ) // TODO other array types case _ => List() } - case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => List(Push(ImmediateVal(0))) + case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => List(stack.push(ImmediateVal(0))) case ArrayElem(value, indices) => List() case UnaryOp(x, op) => op match { @@ -208,12 +204,12 @@ object asmGenerator { case UnaryOperator.Len => List() case UnaryOperator.Negate => List( - Negate(MemLocation(RSP, SizeDir.Word)) + Negate(stack.head(SizeDir.Word)) ) case UnaryOperator.Not => evalExprOntoStack(x) ++ List( - Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1)) + Xor(stack.head(SizeDir.Word), ImmediateVal(1)) ) } @@ -223,46 +219,46 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Add(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + Add(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING ) case BinaryOperator.Sub => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Subtract(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + Subtract(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING ) case BinaryOperator.Mul => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Multiply(EAX, MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(RAX) + stack.pop(RAX), + Multiply(EAX, stack.head(SizeDir.Word)), + stack.drop(), + stack.push(RAX) // TODO OVERFLOWING ) case BinaryOperator.Div => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(RAX), - Divide(MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(RAX) + stack.pop(RAX), + Divide(stack.head(SizeDir.Word)), + stack.drop(), + stack.push(RAX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Mod => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List( - Pop(RAX), - Divide(MemLocation(RSP, SizeDir.Word)), - Add(RSP, ImmediateVal(8)), - Push(RDX) + stack.pop(RAX), + Divide(stack.head(SizeDir.Word)), + stack.drop(), + stack.push(RDX) // TODO CHECK DIVISOR IS NOT 0 ) case BinaryOperator.Eq => @@ -281,15 +277,15 @@ object asmGenerator { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - And(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + And(stack.head(SizeDir.Word), EAX) ) case BinaryOperator.Or => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Or(MemLocation(RSP, SizeDir.Word), EAX) + stack.pop(RAX), + Or(stack.head(SizeDir.Word), EAX) ) } case call: microWacc.Call => generateCall(call) @@ -297,24 +293,24 @@ object asmGenerator { } def generateCall(call: microWacc.Call)(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call argRegs.zip(args).flatMap { (reg, expr) => evalExprOntoStack(expr) ++ - List(Pop(reg)) + List(stack.pop(reg)) } ++ args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ (if (args.size > argRegs.size) { - List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8))) + List(stack.reserve(args.size - argRegs.size)) } else Nil) } // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using - // stack: LinkedHashMap[Ident, Int], + // stack: Stack, // strings: ListBuffer[String] // ): List[AsmLine] = { // readType match { @@ -339,41 +335,33 @@ object asmGenerator { // } def generateComparison(x: Expr, y: Expr, cond: Cond)(using - stack: LinkedHashMap[Ident, Int], + stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List( - Pop(RAX), - Compare(MemLocation(RSP, SizeDir.Word), EAX), + stack.pop(RAX), + Compare(stack.head(SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), - And(EAX, ImmediateVal(_8_BIT_MASK)), - Load(RSP, IndexAddress(RSP, 8)), - Push(RAX) + And(RAX, ImmediateVal(_8_BIT_MASK)), + stack.drop(), + stack.push(RAX) ) } - def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): () => IndexAddress = - () => IndexAddress(RSP, (stack.size - stack(ident)) * 8) - - def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { - List( - And(RSP, ImmediateVal(-16)) - ) - } // Missing a sub instruction but dont think we need it - def funcPrologue(): List[AsmLine] = { + def funcPrologue()(using stack: Stack): List[AsmLine] = { List( - Push(RBP), - Move(RBP, RSP) + stack.push(RBP), + Move(RBP, Register(RegSize.R64, RegName.SP)) ) } - def funcEpilogue(): List[AsmLine] = { + def funcEpilogue()(using stack: Stack): List[AsmLine] = { List( - Move(RSP, RBP), - Pop(RBP), + Move(Register(RegSize.R64, RegName.SP), RBP), + stack.pop(RBP), assemblyIR.Return() ) } @@ -383,7 +371,7 @@ object asmGenerator { // TODO: refactor, really ugly function // def printF(expr: Expr)(using -// stack: LinkedHashMap[Ident, Int], +// stack: Stack, // strings: ListBuffer[String] // ): List[AsmLine] = { // // determine the format string @@ -442,7 +430,7 @@ object asmGenerator { // prints a new line // def printLn()(using - // stack: LinkedHashMap[Ident, Int], + // stack: Stack, // strings: ListBuffer[String] // ): List[AsmLine] = { // strings += "" @@ -461,4 +449,44 @@ object asmGenerator { // ) // } + + + class Stack { + private val stack = LinkedHashMap[Expr | Int, Int]() + private val RSP = Register(RegSize.R64, RegName.SP) + + def next: Int = stack.size + 1 + def push(expr: Expr, src: Src): AsmLine = { + stack += expr -> next + Push(src) + } + def push(src: Src): AsmLine = { + stack += stack.size -> next + Push(src) + } + def pop(dest: Src): AsmLine = { + stack.remove(stack.last._1) + Pop(dest) + } + def reserve(ident: Ident): AsmLine = { + stack += ident -> next + Subtract(RSP, ImmediateVal(8)) + } + def reserve(n: Int = 1): AsmLine = { + (1 to n).foreach(_ => stack += stack.size -> next) + Subtract(RSP, ImmediateVal(n*8)) + } + def drop(n : Int = 1): AsmLine = { + (1 to n).foreach(_ => stack.remove(stack.last._1)) + Add(RSP, ImmediateVal(n*8)) + } + def accessVar(ident: Ident): () => IndexAddress = () => { + IndexAddress(RSP, (stack.size - stack(ident)) * 8) + } + def head: MemLocation = MemLocation(RSP) + def head(size: SizeDir): MemLocation = MemLocation(RSP, size) + def contains(ident: Ident): Boolean = stack.contains(ident) + // TODO: Might want to actually properly handle this with the LinkedHashMap too + def align(): AsmLine = And(RSP, ImmediateVal(-16)) + } } From f628d16d3d26b8f057740faaf55b4e60ca9db9c5 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 16:27:47 +0000 Subject: [PATCH 19/54] fix: always push a value onto stack on expr evaluation --- src/main/wacc/backend/asmGenerator.scala | 164 +++++------------------ src/test/wacc/examples.scala | 40 +++--- 2 files changed, 51 insertions(+), 153 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index deeaf8e..27eb1ec 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -50,7 +50,11 @@ object asmGenerator { generateBuiltInFuncs() val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => - List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str.replace("\"", "\\\""))) + List( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.replace("\"", "\\\"")) + ) } List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ @@ -89,7 +93,7 @@ object asmGenerator { wrapFunc( labelGenerator.getLabel(Builtin.Malloc), List( - stack.align(), + stack.align() ) ) ++ wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ @@ -126,7 +130,7 @@ object asmGenerator { evalExprOntoStack(rhs) ++ List( stack.pop(RAX), - Move(dest(), RAX), + Move(dest(), RAX) ) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() @@ -165,7 +169,7 @@ object asmGenerator { stack: Stack, strings: ListBuffer[String] ): List[AsmLine] = { - expr match { + val out = expr match { case IntLiter(v) => List(stack.push(ImmediateVal(v))) case CharLiter(v) => @@ -196,23 +200,24 @@ object asmGenerator { case NullLiter() => List(stack.push(ImmediateVal(0))) case ArrayElem(value, indices) => List() case UnaryOp(x, op) => - op match { - // TODO: chr and ord are TYPE CASTS. They do not change the internal value, - // but will need bound checking e.t.c. - case UnaryOperator.Chr => List() - case UnaryOperator.Ord => List() - case UnaryOperator.Len => List() - case UnaryOperator.Negate => - List( - Negate(stack.head(SizeDir.Word)) - ) - case UnaryOperator.Not => - evalExprOntoStack(x) ++ + evalExprOntoStack(x) ++ + (op match { + // TODO: chr and ord are TYPE CASTS. They do not change the internal value, + // but will need bound checking e.t.c. + case UnaryOperator.Chr => List() + case UnaryOperator.Ord => List() + case UnaryOperator.Len => List() + case UnaryOperator.Negate => List( - Xor(stack.head(SizeDir.Word), ImmediateVal(1)) + Negate(stack.head(SizeDir.Word)) ) + case UnaryOperator.Not => + evalExprOntoStack(x) ++ + List( + Xor(stack.head(SizeDir.Word), ImmediateVal(1)) + ) - } + }) case BinaryOp(x, y, op) => op match { case BinaryOperator.Add => @@ -288,8 +293,11 @@ object asmGenerator { Or(stack.head(SizeDir.Word), EAX) ) } - case call: microWacc.Call => generateCall(call) + case call: microWacc.Call => + generateCall(call) ++ + List(stack.push(RAX)) } + if out.isEmpty then List(stack.push(ImmediateVal(0))) else out } def generateCall(call: microWacc.Call)(using @@ -305,35 +313,10 @@ object asmGenerator { args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ (if (args.size > argRegs.size) { - List(stack.reserve(args.size - argRegs.size)) + List(stack.drop(args.size - argRegs.size)) } else Nil) } - // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using - // stack: Stack, - // strings: ListBuffer[String] - // ): List[AsmLine] = { - // readType match { - // case Builtin.ReadInt => - // strings += PrintFormat.Int.toString - // case Builtin.ReadChar => - // strings += PrintFormat.Char.toString - // } - // List( - // Load( - // RDI, - // IndexAddress( - // RIP, - // LabelArg(s".L.str${strings.size - 1}") - // ) - // ), - // Load(RSI, dest) - // ) ++ - // // alignStack() ++ - // List(assemblyIR.Call(CLibFunc.Scanf)) - - // } - def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: Stack, strings: ListBuffer[String] @@ -366,91 +349,6 @@ object asmGenerator { ) } - // def saveRegs(regList: List[Register]): List[AsmLine] = regList.map(Push(_)) - // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_)) - -// TODO: refactor, really ugly function -// def printF(expr: Expr)(using -// stack: Stack, -// strings: ListBuffer[String] -// ): List[AsmLine] = { -// // determine the format string -// expr.ty match { -// case KnownType.String => -// strings += PrintFormat.String.toString -// case KnownType.Char => -// strings += PrintFormat.Char.toString -// case KnownType.Int => -// strings += PrintFormat.Int.toString -// case _ => -// strings += PrintFormat.String.toString -// } -// List( -// Load( -// RDI, -// IndexAddress( -// RIP, -// LabelArg(s".L.str${strings.size - 1}") -// ) -// ) -// ) -// ++ -// // determine the actual value to print -// (if (expr.ty == KnownType.Bool) { -// expr match { -// case BoolLiter(true) => { -// strings += "true" -// } -// case _ => { -// strings += "false" -// } -// } -// List( -// Load( -// RDI, -// IndexAddress( -// RIP, -// LabelArg(s".L.str${strings.size - 1}") -// ) -// ) -// ) - -// } else { -// evalExprOntoStack(expr) ++ -// List(Pop(RSI)) -// }) -// // print the value -// ++ -// List( -// assemblyIR.Call(CLibFunc.PrintF), -// Move(RDI, ImmediateVal(0)), -// assemblyIR.Call(CLibFunc.Fflush) -// ) -// } - -// prints a new line - // def printLn()(using - // stack: Stack, - // strings: ListBuffer[String] - // ): List[AsmLine] = { - // strings += "" - // Load( - // RDI, - // IndexAddress( - // RIP, - // LabelArg(s".L.str${strings.size - 1}") - // ) - // ) - // :: - // List( - // assemblyIR.Call(CLibFunc.Puts), - // Move(RDI, ImmediateVal(0)), - // assemblyIR.Call(CLibFunc.Fflush) - // ) - - // } - - class Stack { private val stack = LinkedHashMap[Expr | Int, Int]() private val RSP = Register(RegSize.R64, RegName.SP) @@ -474,11 +372,11 @@ object asmGenerator { } def reserve(n: Int = 1): AsmLine = { (1 to n).foreach(_ => stack += stack.size -> next) - Subtract(RSP, ImmediateVal(n*8)) + Subtract(RSP, ImmediateVal(n * 8)) } - def drop(n : Int = 1): AsmLine = { + def drop(n: Int = 1): AsmLine = { (1 to n).foreach(_ => stack.remove(stack.last._1)) - Add(RSP, ImmediateVal(n*8)) + Add(RSP, ImmediateVal(n * 8)) } def accessVar(ident: Ident): () => IndexAddress = () => { IndexAddress(RSP, (stack.size - stack(ident)) * 8) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index e7397b6..1fae3d7 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -27,13 +27,13 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral forEvery(files) { (filename, expectedResult) => val baseFilename = filename.stripSuffix(".wacc") given stdout: PrintStream = PrintStream(File(baseFilename + ".out")) - val result = compile(filename) s"$filename" should "be compiled with correct result" in { + val result = compile(filename) assert(expectedResult.contains(result)) } - if (result == 0) it should "run with correct result" in { + if (expectedResult == List(0)) it should "run with correct result" in { if (fileIsDisallowedBackend(filename)) pending // Retrieve contents to get input and expected output + exit code @@ -85,24 +85,24 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral Seq( // format: off // disable formatting to avoid binPack - // "^.*wacc-examples/valid/advanced.*$", - // "^.*wacc-examples/valid/array.*$", - // "^.*wacc-examples/valid/basic/exit.*$", - // "^.*wacc-examples/valid/basic/skip.*$", - // "^.*wacc-examples/valid/expressions.*$", - // "^.*wacc-examples/valid/function/nested_functions.*$", - // "^.*wacc-examples/valid/function/simple_functions.*$", - // "^.*wacc-examples/valid/if.*$", - // "^.*wacc-examples/valid/IO/print.*$", - // "^.*wacc-examples/valid/IO/read.*$", - // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", - // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - // "^.*wacc-examples/valid/pairs.*$", - // "^.*wacc-examples/valid/runtimeErr.*$", - // "^.*wacc-examples/valid/scope.*$", - // "^.*wacc-examples/valid/sequence.*$", - // "^.*wacc-examples/valid/variables.*$", - // "^.*wacc-examples/valid/while.*$", + "^.*wacc-examples/valid/advanced.*$", + "^.*wacc-examples/valid/array.*$", + "^.*wacc-examples/valid/basic/exit.*$", + "^.*wacc-examples/valid/basic/skip.*$", + "^.*wacc-examples/valid/expressions.*$", + "^.*wacc-examples/valid/function/nested_functions.*$", + "^.*wacc-examples/valid/function/simple_functions.*$", + "^.*wacc-examples/valid/if.*$", + "^.*wacc-examples/valid/IO/print.*$", + "^.*wacc-examples/valid/IO/read.*$", + "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", + "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", + "^.*wacc-examples/valid/pairs.*$", + "^.*wacc-examples/valid/runtimeErr.*$", + "^.*wacc-examples/valid/scope.*$", + "^.*wacc-examples/valid/sequence.*$", + "^.*wacc-examples/valid/variables.*$", + "^.*wacc-examples/valid/while.*$", // format: on ).find(filename.matches).isDefined } From 8558733719173143265ed3e06d831d966038e36d Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 16:33:17 +0000 Subject: [PATCH 20/54] style: scala format lexer --- src/main/wacc/frontend/lexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/wacc/frontend/lexer.scala b/src/main/wacc/frontend/lexer.scala index e0b0a44..2efe517 100644 --- a/src/main/wacc/frontend/lexer.scala +++ b/src/main/wacc/frontend/lexer.scala @@ -39,7 +39,7 @@ val errConfig = new ErrorConfig { ) } object lexer { - + /** Language description for the WACC lexer */ private val desc = LexicalDesc.plain.copy( From 5f8b87221ccf34afdfb3b257db13c12594cf44e6 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 17:07:21 +0000 Subject: [PATCH 21/54] fix: escape characters within assembly --- src/main/wacc/backend/asmGenerator.scala | 9 ++++++++- src/main/wacc/frontend/lexer.scala | 23 ++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 27eb1ec..3a73ce9 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -7,6 +7,7 @@ object asmGenerator { import microWacc._ import assemblyIR._ import wacc.types._ + import lexer.escapedChars val RAX = Register(RegSize.R64, RegName.AX) val EAX = Register(RegSize.E32, RegName.AX) @@ -53,7 +54,7 @@ object asmGenerator { List( Directive.Int(str.size), LabelDef(s".L.str$i"), - Directive.Asciz(str.replace("\"", "\\\"")) + Directive.Asciz(str.escaped) ) } @@ -387,4 +388,10 @@ object asmGenerator { // TODO: Might want to actually properly handle this with the LinkedHashMap too def align(): AsmLine = And(RSP, ImmediateVal(-16)) } + + private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } + extension (s: String) { + private def escaped: String = + s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString)) + } } diff --git a/src/main/wacc/frontend/lexer.scala b/src/main/wacc/frontend/lexer.scala index 2efe517..4cb51a8 100644 --- a/src/main/wacc/frontend/lexer.scala +++ b/src/main/wacc/frontend/lexer.scala @@ -39,6 +39,17 @@ val errConfig = new ErrorConfig { ) } object lexer { + val escapedChars: Map[String, Int] = Map( + "0" -> '\u0000', + "b" -> '\b', + "t" -> '\t', + "n" -> '\n', + "f" -> '\f', + "r" -> '\r', + "\\" -> '\\', + "'" -> '\'', + "\"" -> '\"' + ) /** Language description for the WACC lexer */ @@ -63,15 +74,9 @@ object lexer { textDesc = TextDesc.plain.copy( graphicCharacter = Basic(c => c >= ' ' && c != '\\' && c != '\'' && c != '"'), escapeSequences = EscapeDesc.plain.copy( - literals = Set('\\', '"', '\''), - mapping = Map( - "0" -> '\u0000', - "b" -> '\b', - "t" -> '\t', - "n" -> '\n', - "f" -> '\f', - "r" -> '\r' - ) + literals = + escapedChars.filter { (s, chr) => chr.toChar.toString == s }.map(_._2.toChar).toSet, + mapping = escapedChars.filter { (s, chr) => chr.toChar.toString != s } ) ), numericDesc = NumericDesc.plain.copy( From efe9f91303e369d9dcf6474668c54649a7c7d4e0 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 17:10:56 +0000 Subject: [PATCH 22/54] refactor: use non-singleton labelgenerator (instead use class) --- src/main/wacc/backend/asmGenerator.scala | 43 ++++++++++++++---------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 3a73ce9..e2ef7c9 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -24,21 +24,10 @@ object asmGenerator { val _8_BIT_MASK = 0xff - object labelGenerator { - var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" - } - } - def generateAsm(microProg: Program): List[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() + given labelGenerator: LabelGenerator = LabelGenerator() val Program(funcs, main) = microProg val progAsm = @@ -76,7 +65,8 @@ object asmGenerator { def generateBuiltInFuncs()(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): List[AsmLine] = { wrapFunc( labelGenerator.getLabel(Builtin.Exit), @@ -112,7 +102,11 @@ object asmGenerator { def generateStmt( stmt: Stmt - )(using stack: Stack, strings: ListBuffer[String]): List[AsmLine] = + )(using + stack: Stack, + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): List[AsmLine] = stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = @@ -168,7 +162,8 @@ object asmGenerator { def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): List[AsmLine] = { val out = expr match { case IntLiter(v) => @@ -303,7 +298,8 @@ object asmGenerator { def generateCall(call: microWacc.Call)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): List[AsmLine] = { val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call @@ -320,7 +316,8 @@ object asmGenerator { def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): List[AsmLine] = { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ @@ -350,6 +347,18 @@ object asmGenerator { ) } + class LabelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + } + class Stack { private val stack = LinkedHashMap[Expr | Int, Int]() private val RSP = Register(RegSize.R64, RegName.SP) From 4f3596b48ac1b20883217bded63018d6594df8fa Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 17:35:04 +0000 Subject: [PATCH 23/54] ci: fix check commits --- .gitlab-ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 98a6044..f7d5454 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -30,9 +30,9 @@ check_commits: before_script: - apk add git - npm install -g @commitlint/cli @commitlint/config-conventional - - git pull origin master + - git checkout origin/master script: - - npx commitlint --from origin/master --to HEAD --verbose + - npx commitlint --from origin/master --to ${CI_COMMIT_SHA} --verbose compile_jvm: stage: compile @@ -48,10 +48,10 @@ test_jvm: image: gumjoe/wacc-ci-scala:x86 stage: test # Use our own runner (not cloud VM or shared) to ensure we have multiple cores. - tags: [ large ] + tags: [large] # This is expensive, so do use `dependencies` instead of `needs` to # ensure all previous stages pass. - dependencies: [ compile_jvm ] + dependencies: [compile_jvm] before_script: - git clone https://$EXAMPLES_AUTH@gitlab.doc.ic.ac.uk/lab2425_spring/wacc-examples.git script: From 87a239f37c86a7d1355e274950972bb42079d441 Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Tue, 25 Feb 2025 18:20:50 +0000 Subject: [PATCH 24/54] fix: alignment issue with stack in read --- src/main/wacc/backend/asmGenerator.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index e2ef7c9..0d02653 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -92,10 +92,12 @@ object asmGenerator { labelGenerator.getLabel(Builtin.Read), List( stack.align(), + stack.reserve(), stack.push(RSI), Load(RSI, stack.head), assemblyIR.Call(CLibFunc.Scanf), - stack.pop(RAX) + stack.pop(RAX), + stack.drop() ) ) } From 7fd92b4212f024874cb9b3f73147145a1e887090 Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Tue, 25 Feb 2025 18:25:34 +0000 Subject: [PATCH 25/54] refactor: passed exit and read tests --- src/test/wacc/examples.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 1fae3d7..dc3f817 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -87,14 +87,14 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral // disable formatting to avoid binPack "^.*wacc-examples/valid/advanced.*$", "^.*wacc-examples/valid/array.*$", - "^.*wacc-examples/valid/basic/exit.*$", + // "^.*wacc-examples/valid/basic/exit.*$", "^.*wacc-examples/valid/basic/skip.*$", "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$", "^.*wacc-examples/valid/if.*$", "^.*wacc-examples/valid/IO/print.*$", - "^.*wacc-examples/valid/IO/read.*$", + // "^.*wacc-examples/valid/IO/read.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", "^.*wacc-examples/valid/pairs.*$", From 7953790f4d540b4734108aabe4cbd35cc09acba5 Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 18:44:11 +0000 Subject: [PATCH 26/54] feat: used Chains instead of Lists --- src/main/wacc/backend/asmGenerator.scala | 190 ++++++++++++----------- 1 file changed, 98 insertions(+), 92 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0d02653..a951051 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -2,6 +2,8 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer +import cats.data.Chain +import cats.syntax.foldable._ object asmGenerator { import microWacc._ @@ -31,33 +33,38 @@ object asmGenerator { val Program(funcs, main) = microProg val progAsm = - LabelDef("main") :: + Chain.one(LabelDef("main")) ++ funcPrologue() ++ - List(stack.align()) ++ - main.flatMap(generateStmt) ++ - List(Move(RAX, ImmediateVal(0))) ++ + Chain(stack.align()) ++ + main.foldLeft(Chain.empty[AsmLine])(_ ++ generateStmt(_)) ++ + Chain.one(Move(RAX, ImmediateVal(0))) ++ funcEpilogue() ++ generateBuiltInFuncs() - val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => - List( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) + val strDirs = strings.toList.zipWithIndex.foldLeft(Chain.empty[AsmLine]) { + case (acc, (str, i)) => + acc ++ Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) } - List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ - strDirs ++ - List(Directive.Text) ++ - progAsm + val finalChain = Chain( + Directive.IntelSyntax, + Directive.Global("main"), + Directive.RoData + ) ++ strDirs ++ Chain.one(Directive.Text) ++ progAsm + + finalChain.toList + } - def wrapFunc(labelName: String, funcBody: List[AsmLine])(using + def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using stack: Stack, strings: ListBuffer[String] - ): List[AsmLine] = { - LabelDef(labelName) :: + ): Chain[AsmLine] = { + Chain.one(LabelDef(labelName)) ++ funcPrologue() ++ funcBody ++ funcEpilogue() @@ -65,16 +72,15 @@ object asmGenerator { def generateBuiltInFuncs()(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { wrapFunc( labelGenerator.getLabel(Builtin.Exit), - List(stack.align(), assemblyIR.Call(CLibFunc.Exit)) + Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Printf), - List( + Chain( stack.align(), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(0)), @@ -83,14 +89,14 @@ object asmGenerator { ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Malloc), - List( + Chain.one( stack.align() ) ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) ++ wrapFunc( labelGenerator.getLabel(Builtin.Read), - List( + Chain( stack.align(), stack.reserve(), stack.push(RSI), @@ -104,11 +110,7 @@ object asmGenerator { def generateStmt( stmt: Stmt - )(using - stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = + )(using stack: Stack, strings: ListBuffer[String]): Chain[AsmLine] = stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = @@ -117,15 +119,15 @@ object asmGenerator { case ident: Ident => dest = stack.accessVar(ident) if (!stack.contains(ident)) { - List(stack.reserve(ident)) - } else Nil + Chain.one(stack.reserve(ident)) + } else Chain.empty // TODO lhs = arrayElem case _ => // dest = ??? - List() + Chain.empty }) ++ evalExprOntoStack(rhs) ++ - List( + Chain( stack.pop(RAX), Move(dest(), RAX) ) @@ -133,47 +135,48 @@ object asmGenerator { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() evalExprOntoStack(cond) ++ - List( - Compare(stack.head(SizeDir.Word), ImmediateVal(0)), - stack.drop(), - Jump(LabelArg(elseLabel), Cond.Equal) + Chain.fromSeq( + List( + Compare(stack.head(SizeDir.Word), ImmediateVal(0)), + stack.drop(), + Jump(LabelArg(elseLabel), Cond.Equal) + ) ) ++ - thenBranch.flatMap(generateStmt) ++ - List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++ - elseBranch.flatMap(generateStmt) ++ - List(LabelDef(endLabel)) + Chain.fromSeq(thenBranch).flatMap(generateStmt) ++ + Chain.fromSeq(List(Jump(LabelArg(endLabel)), LabelDef(elseLabel))) ++ + Chain.fromSeq(elseBranch).flatMap(generateStmt) ++ + Chain.one(LabelDef(endLabel)) } case While(cond, body) => { val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - List(LabelDef(startLabel)) ++ + Chain.one(LabelDef(startLabel)) ++ evalExprOntoStack(cond) ++ - List( + Chain( Compare(stack.head(SizeDir.Word), ImmediateVal(0)), stack.drop(), Jump(LabelArg(endLabel), Cond.Equal) ) ++ - body.flatMap(generateStmt) ++ - List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) + Chain.fromSeq(body).flatMap(generateStmt) ++ + Chain(Jump(LabelArg(startLabel)), LabelDef(endLabel)) } case microWacc.Return(expr) => evalExprOntoStack(expr) ++ - List(stack.pop(RAX), assemblyIR.Return()) + Chain(stack.pop(RAX), assemblyIR.Return()) case call: microWacc.Call => generateCall(call) } def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { val out = expr match { case IntLiter(v) => - List(stack.push(ImmediateVal(v))) + Chain.one(stack.push(ImmediateVal(v))) case CharLiter(v) => - List(stack.push(ImmediateVal(v.toInt))) + Chain.one(stack.push(ImmediateVal(v.toInt))) case ident: Ident => - List(stack.push(stack.accessVar(ident)())) + Chain.one(stack.push(stack.accessVar(ident)())) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -181,7 +184,7 @@ object asmGenerator { case CharLiter(v) => v case _ => "" }.mkString - List( + Chain( Load( RAX, IndexAddress( @@ -192,26 +195,24 @@ object asmGenerator { stack.push(RAX) ) // TODO other array types - case _ => List() + case _ => Chain.empty } - case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => List(stack.push(ImmediateVal(0))) - case ArrayElem(value, indices) => List() + case BoolLiter(v) => Chain.one(stack.push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => Chain.one(stack.push(ImmediateVal(0))) + case ArrayElem(value, indices) => Chain.empty case UnaryOp(x, op) => evalExprOntoStack(x) ++ (op match { // TODO: chr and ord are TYPE CASTS. They do not change the internal value, // but will need bound checking e.t.c. - case UnaryOperator.Chr => List() - case UnaryOperator.Ord => List() - case UnaryOperator.Len => List() + case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => Chain.empty case UnaryOperator.Negate => - List( + Chain.one( Negate(stack.head(SizeDir.Word)) ) case UnaryOperator.Not => evalExprOntoStack(x) ++ - List( + Chain.one( Xor(stack.head(SizeDir.Word), ImmediateVal(1)) ) @@ -221,7 +222,7 @@ object asmGenerator { case BinaryOperator.Add => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Add(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING @@ -229,7 +230,7 @@ object asmGenerator { case BinaryOperator.Sub => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Subtract(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING @@ -237,7 +238,7 @@ object asmGenerator { case BinaryOperator.Mul => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Multiply(EAX, stack.head(SizeDir.Word)), stack.drop(), @@ -247,7 +248,7 @@ object asmGenerator { case BinaryOperator.Div => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ - List( + Chain( stack.pop(RAX), Divide(stack.head(SizeDir.Word)), stack.drop(), @@ -257,7 +258,7 @@ object asmGenerator { case BinaryOperator.Mod => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ - List( + Chain( stack.pop(RAX), Divide(stack.head(SizeDir.Word)), stack.drop(), @@ -279,51 +280,56 @@ object asmGenerator { case BinaryOperator.And => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), And(stack.head(SizeDir.Word), EAX) ) case BinaryOperator.Or => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Or(stack.head(SizeDir.Word), EAX) ) } case call: microWacc.Call => generateCall(call) ++ - List(stack.push(RAX)) + Chain.one(stack.push(RAX)) } - if out.isEmpty then List(stack.push(ImmediateVal(0))) else out + if out.isEmpty then Chain.one(stack.push(ImmediateVal(0))) else out } def generateCall(call: microWacc.Call)(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call - argRegs.zip(args).flatMap { (reg, expr) => - evalExprOntoStack(expr) ++ - List(stack.pop(reg)) - } ++ - args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ - List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ - (if (args.size > argRegs.size) { - List(stack.drop(args.size - argRegs.size)) - } else Nil) + + val regMoves = argRegs + .zip(args) + .map { (reg, expr) => + evalExprOntoStack(expr) ++ + Chain.one(stack.pop(reg)) + } + .combineAll + + val stackPushes = args.drop(argRegs.size).map(evalExprOntoStack).combineAll + + regMoves ++ + stackPushes ++ + Chain.one(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ + (if (args.size > argRegs.size) Chain.one(stack.drop(args.size - argRegs.size)) + else Chain.empty) } def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Compare(stack.head(SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), @@ -334,15 +340,15 @@ object asmGenerator { } // Missing a sub instruction but dont think we need it - def funcPrologue()(using stack: Stack): List[AsmLine] = { - List( + def funcPrologue()(using stack: Stack): Chain[AsmLine] = { + Chain( stack.push(RBP), Move(RBP, Register(RegSize.R64, RegName.SP)) ) } - def funcEpilogue()(using stack: Stack): List[AsmLine] = { - List( + def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { + Chain( Move(Register(RegSize.R64, RegName.SP), RBP), stack.pop(RBP), assemblyIR.Return() From edbc03ee25fcecd6feaddba9aeabc9aea39a515c Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 19:39:55 +0000 Subject: [PATCH 27/54] feat: used local mutable Chains. Also implemented new LabelGenerator --- src/main/wacc/backend/asmGenerator.scala | 447 +++++++++++------------ 1 file changed, 215 insertions(+), 232 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index a951051..7f22e20 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -4,6 +4,7 @@ import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ +import parsley.token.errors.Label object asmGenerator { import microWacc._ @@ -26,10 +27,27 @@ object asmGenerator { val _8_BIT_MASK = 0xff + extension (chain: Chain[AsmLine]) + def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) + + class LabelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + } + + def generateAsm(microProg: Program): List[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() + val Program(funcs, main) = microProg val progAsm = @@ -64,295 +82,260 @@ object asmGenerator { stack: Stack, strings: ListBuffer[String] ): Chain[AsmLine] = { - Chain.one(LabelDef(labelName)) ++ - funcPrologue() ++ - funcBody ++ - funcEpilogue() + var chain = Chain.empty[AsmLine] + + chain += LabelDef(labelName) + chain ++= funcPrologue() + chain ++= funcBody + chain ++= funcEpilogue() + + chain } def generateBuiltInFuncs()(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - wrapFunc( + var chain = Chain.empty[AsmLine] + + chain ++= wrapFunc( labelGenerator.getLabel(Builtin.Exit), Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) - ) ++ - wrapFunc( - labelGenerator.getLabel(Builtin.Printf), - Chain( - stack.align(), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) - ) ++ - wrapFunc( - labelGenerator.getLabel(Builtin.Malloc), - Chain.one( - stack.align() - ) - ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) ++ - wrapFunc( - labelGenerator.getLabel(Builtin.Read), - Chain( - stack.align(), - stack.reserve(), - stack.push(RSI), - Load(RSI, stack.head), - assemblyIR.Call(CLibFunc.Scanf), - stack.pop(RAX), - stack.drop() - ) + ) + + chain ++= wrapFunc( + labelGenerator.getLabel(Builtin.Printf), + Chain( + stack.align(), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) ) + ) + + chain ++= wrapFunc( + labelGenerator.getLabel(Builtin.Malloc), + Chain.one(stack.align()) + ) + + chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) + + chain ++= wrapFunc( + labelGenerator.getLabel(Builtin.Read), + Chain( + stack.align(), + stack.reserve(), + stack.push(RSI), + Load(RSI, stack.head), + assemblyIR.Call(CLibFunc.Scanf), + stack.pop(RAX), + stack.drop() + ) + ) + + chain } - def generateStmt( - stmt: Stmt - )(using stack: Stack, strings: ListBuffer[String]): Chain[AsmLine] = + def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + stmt match { case Assign(lhs, rhs) => - var dest: () => IndexAddress = - () => IndexAddress(RAX, 0) // gets overrwitten - (lhs match { + var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below + + lhs match { case ident: Ident => dest = stack.accessVar(ident) - if (!stack.contains(ident)) { - Chain.one(stack.reserve(ident)) - } else Chain.empty + if (!stack.contains(ident)) chain += stack.reserve(ident) // TODO lhs = arrayElem case _ => - // dest = ??? - Chain.empty - }) ++ - evalExprOntoStack(rhs) ++ - Chain( - stack.pop(RAX), - Move(dest(), RAX) - ) - case If(cond, thenBranch, elseBranch) => { + } + + chain ++= evalExprOntoStack(rhs) + chain += stack.pop(RAX) + chain += Move(dest(), RAX) + + case If(cond, thenBranch, elseBranch) => val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - evalExprOntoStack(cond) ++ - Chain.fromSeq( - List( - Compare(stack.head(SizeDir.Word), ImmediateVal(0)), - stack.drop(), - Jump(LabelArg(elseLabel), Cond.Equal) - ) - ) ++ - Chain.fromSeq(thenBranch).flatMap(generateStmt) ++ - Chain.fromSeq(List(Jump(LabelArg(endLabel)), LabelDef(elseLabel))) ++ - Chain.fromSeq(elseBranch).flatMap(generateStmt) ++ - Chain.one(LabelDef(endLabel)) - } - case While(cond, body) => { + + chain ++= evalExprOntoStack(cond) + chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) + chain += stack.drop() + chain += Jump(LabelArg(elseLabel), Cond.Equal) + + chain ++= Chain.fromSeq(thenBranch).flatMap(generateStmt) + chain += Jump(LabelArg(endLabel)) + chain += LabelDef(elseLabel) + + chain ++= Chain.fromSeq(elseBranch).flatMap(generateStmt) + chain += LabelDef(endLabel) + + case While(cond, body) => val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - Chain.one(LabelDef(startLabel)) ++ - evalExprOntoStack(cond) ++ - Chain( - Compare(stack.head(SizeDir.Word), ImmediateVal(0)), - stack.drop(), - Jump(LabelArg(endLabel), Cond.Equal) - ) ++ - Chain.fromSeq(body).flatMap(generateStmt) ++ - Chain(Jump(LabelArg(startLabel)), LabelDef(endLabel)) - } + + chain += LabelDef(startLabel) + chain ++= evalExprOntoStack(cond) + chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) + chain += stack.drop() + chain += Jump(LabelArg(endLabel), Cond.Equal) + + chain ++= Chain.fromSeq(body).flatMap(generateStmt) + chain += Jump(LabelArg(startLabel)) + chain += LabelDef(endLabel) + case microWacc.Return(expr) => - evalExprOntoStack(expr) ++ - Chain(stack.pop(RAX), assemblyIR.Return()) - case call: microWacc.Call => generateCall(call) + chain ++= evalExprOntoStack(expr) + chain += stack.pop(RAX) + chain += assemblyIR.Return() + + case call: microWacc.Call => + chain ++= generateCall(call) } + chain + } + def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - val out = expr match { - case IntLiter(v) => - Chain.one(stack.push(ImmediateVal(v))) - case CharLiter(v) => - Chain.one(stack.push(ImmediateVal(v.toInt))) - case ident: Ident => - Chain.one(stack.push(stack.accessVar(ident)())) + var chain = Chain.empty[AsmLine] + + expr match { + case IntLiter(v) => chain += stack.push(ImmediateVal(v)) + case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) + case ident: Ident => chain += stack.push(stack.accessVar(ident)()) + case ArrayLiter(elems) => expr.ty match { case KnownType.String => - strings += elems.map { - case CharLiter(v) => v - case _ => "" - }.mkString - Chain( - Load( - RAX, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ), - stack.push(RAX) - ) - // TODO other array types - case _ => Chain.empty + strings += elems.collect { case CharLiter(v) => v }.mkString + chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + chain += stack.push(RAX) + case _ => // Other array types TODO } - case BoolLiter(v) => Chain.one(stack.push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => Chain.one(stack.push(ImmediateVal(0))) - case ArrayElem(value, indices) => Chain.empty - case UnaryOp(x, op) => - evalExprOntoStack(x) ++ - (op match { - // TODO: chr and ord are TYPE CASTS. They do not change the internal value, - // but will need bound checking e.t.c. - case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => Chain.empty - case UnaryOperator.Negate => - Chain.one( - Negate(stack.head(SizeDir.Word)) - ) - case UnaryOperator.Not => - evalExprOntoStack(x) ++ - Chain.one( - Xor(stack.head(SizeDir.Word), ImmediateVal(1)) - ) - }) - case BinaryOp(x, y, op) => + case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) + case NullLiter() => chain += stack.push(ImmediateVal(0)) + case ArrayElem(_, _) => // TODO: Implement handling + + case UnaryOp(x, op) => + chain ++= evalExprOntoStack(x) op match { - case BinaryOperator.Add => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Add(stack.head(SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Sub => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Subtract(stack.head(SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Mul => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Multiply(EAX, stack.head(SizeDir.Word)), - stack.drop(), - stack.push(RAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Div => - evalExprOntoStack(y) ++ - evalExprOntoStack(x) ++ - Chain( - stack.pop(RAX), - Divide(stack.head(SizeDir.Word)), - stack.drop(), - stack.push(RAX) - // TODO CHECK DIVISOR IS NOT 0 - ) - case BinaryOperator.Mod => - evalExprOntoStack(y) ++ - evalExprOntoStack(x) ++ - Chain( - stack.pop(RAX), - Divide(stack.head(SizeDir.Word)), - stack.drop(), - stack.push(RDX) - // TODO CHECK DIVISOR IS NOT 0 - ) - case BinaryOperator.Eq => - generateComparison(x, y, Cond.Equal) - case BinaryOperator.Neq => - generateComparison(x, y, Cond.NotEqual) - case BinaryOperator.Greater => - generateComparison(x, y, Cond.Greater) - case BinaryOperator.GreaterEq => - generateComparison(x, y, Cond.GreaterEqual) - case BinaryOperator.Less => - generateComparison(x, y, Cond.Less) - case BinaryOperator.LessEq => - generateComparison(x, y, Cond.LessEqual) - case BinaryOperator.And => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - And(stack.head(SizeDir.Word), EAX) - ) - case BinaryOperator.Or => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Or(stack.head(SizeDir.Word), EAX) - ) + case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed + case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word)) + case UnaryOperator.Not => + chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1)) } + + case BinaryOp(x, y, op) => + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(y) + + chain += stack.pop(RAX) + + op match { + case BinaryOperator.Add => chain += Add(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Sub => chain += Subtract(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Mul => + chain += Multiply(EAX, stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RAX) + + case BinaryOperator.Div => + chain += Divide(stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RAX) + + case BinaryOperator.Mod => + chain += Divide(stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RDX) + + case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual) + case BinaryOperator.And => chain += And(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Or => chain += Or(stack.head(SizeDir.Word), EAX) + } + case call: microWacc.Call => - generateCall(call) ++ - Chain.one(stack.push(RAX)) + chain ++= generateCall(call) + chain += stack.push(RAX) } - if out.isEmpty then Chain.one(stack.push(ImmediateVal(0))) else out + + if chain.isEmpty then chain += stack.push(ImmediateVal(0)) + chain } def generateCall(call: microWacc.Call)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call - val regMoves = argRegs - .zip(args) - .map { (reg, expr) => - evalExprOntoStack(expr) ++ - Chain.one(stack.pop(reg)) - } - .combineAll + argRegs.zip(args).foreach { (reg, expr) => + chain ++= evalExprOntoStack(expr) + chain += stack.pop(reg) + } - val stackPushes = args.drop(argRegs.size).map(evalExprOntoStack).combineAll + args.drop(argRegs.size).foreach { expr => + chain ++= evalExprOntoStack(expr) + } - regMoves ++ - stackPushes ++ - Chain.one(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ - (if (args.size > argRegs.size) Chain.one(stack.drop(args.size - argRegs.size)) - else Chain.empty) + chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) + + if (args.size > argRegs.size) { + chain += stack.drop(args.size - argRegs.size) + } + + chain } def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Compare(stack.head(SizeDir.Word), EAX), - Set(Register(RegSize.Byte, RegName.AL), cond), - And(RAX, ImmediateVal(_8_BIT_MASK)), - stack.drop(), - stack.push(RAX) - ) + + var chain = Chain.empty[AsmLine] + + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(y) + chain += stack.pop(RAX) + chain += Compare(stack.head(SizeDir.Word), EAX) + chain += Set(Register(RegSize.Byte, RegName.AL), cond) + chain += And(RAX, ImmediateVal(_8_BIT_MASK)) + chain += stack.drop() + chain += stack.push(RAX) + + chain } // Missing a sub instruction but dont think we need it def funcPrologue()(using stack: Stack): Chain[AsmLine] = { - Chain( - stack.push(RBP), - Move(RBP, Register(RegSize.R64, RegName.SP)) - ) + val chain = Chain.empty[AsmLine] + chain += stack.push(RBP) + chain += Move(RBP, Register(RegSize.R64, RegName.SP)) + chain } def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { - Chain( - Move(Register(RegSize.R64, RegName.SP), RBP), - stack.pop(RBP), - assemblyIR.Return() - ) + val chain = Chain.empty[AsmLine] + chain += Move(Register(RegSize.R64, RegName.SP), RBP) + chain += stack.pop(RBP) + chain += assemblyIR.Return() + chain } class LabelGenerator { From bd0eb76bec0fe745f474eec20aeae2727e618719 Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 19:43:43 +0000 Subject: [PATCH 28/54] fix: alignment issue with stack in read --- src/main/wacc/backend/asmGenerator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 7f22e20..c962c71 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,8 +3,8 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain -import cats.syntax.foldable._ -import parsley.token.errors.Label +// import cats.syntax.foldable._ +// import parsley.token.errors.Label object asmGenerator { import microWacc._ From ebc65af981223e9634075cf9daa4c101e7f6e8b0 Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 19:53:31 +0000 Subject: [PATCH 29/54] feat: extension method concatAll defined on Chain implemented --- src/main/wacc/backend/asmGenerator.scala | 62 +++++++++++++----------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index c962c71..bf4e404 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,7 +3,7 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain -// import cats.syntax.foldable._ +import cats.syntax.foldable._ // import parsley.token.errors.Label object asmGenerator { @@ -30,6 +30,9 @@ object asmGenerator { extension (chain: Chain[AsmLine]) def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) + def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = + chains.foldLeft(chain)(_ ++ _) + class LabelGenerator { var labelVal = -1 def getLabel(): String = { @@ -42,40 +45,38 @@ object asmGenerator { } } - def generateAsm(microProg: Program): List[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() - val Program(funcs, main) = microProg - val progAsm = - Chain.one(LabelDef("main")) ++ - funcPrologue() ++ - Chain(stack.align()) ++ - main.foldLeft(Chain.empty[AsmLine])(_ ++ generateStmt(_)) ++ - Chain.one(Move(RAX, ImmediateVal(0))) ++ - funcEpilogue() ++ - generateBuiltInFuncs() - - val strDirs = strings.toList.zipWithIndex.foldLeft(Chain.empty[AsmLine]) { - case (acc, (str, i)) => - acc ++ Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) + val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => + Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) } - val finalChain = Chain( + val progAsm = Chain(LabelDef("main")).concatAll( + funcPrologue(), + Chain.one(stack.align()), + main.foldMap(generateStmt(_)), + Chain.one(Move(RAX, ImmediateVal(0))), + funcEpilogue(), + generateBuiltInFuncs() + ) + + Chain( Directive.IntelSyntax, Directive.Global("main"), Directive.RoData - ) ++ strDirs ++ Chain.one(Directive.Text) ++ progAsm - - finalChain.toList - + ).concatAll( + strDirs, + Chain.one(Directive.Text), + progAsm + ).toList } def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using @@ -137,7 +138,11 @@ object asmGenerator { chain } - def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator): Chain[AsmLine] = { + def generateStmt(stmt: Stmt)(using + stack: Stack, + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] stmt match { @@ -214,21 +219,20 @@ object asmGenerator { expr.ty match { case KnownType.String => strings += elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) - chain += stack.push(RAX) + chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + chain += stack.push(RAX) case _ => // Other array types TODO } case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) case NullLiter() => chain += stack.push(ImmediateVal(0)) case ArrayElem(_, _) => // TODO: Implement handling - case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word)) - case UnaryOperator.Not => + case UnaryOperator.Not => chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1)) } From 11c483439c4a954471bda72f48e3f4465a42709c Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 21:05:21 +0000 Subject: [PATCH 30/54] fix: generate strDirs after prog, change `+=` to `+` --- src/main/wacc/backend/asmGenerator.scala | 34 ++++++++---------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index bf4e404..6d10b24 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -28,7 +28,7 @@ object asmGenerator { val _8_BIT_MASK = 0xff extension (chain: Chain[AsmLine]) - def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) + def +(line: AsmLine): Chain[AsmLine] = chain.append(line) def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = chains.foldLeft(chain)(_ ++ _) @@ -51,14 +51,6 @@ object asmGenerator { given labelGenerator: LabelGenerator = LabelGenerator() val Program(funcs, main) = microProg - val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => - Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) - } - val progAsm = Chain(LabelDef("main")).concatAll( funcPrologue(), Chain.one(stack.align()), @@ -68,6 +60,14 @@ object asmGenerator { generateBuiltInFuncs() ) + val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => + Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) + } + Chain( Directive.IntelSyntax, Directive.Global("main"), @@ -328,32 +328,20 @@ object asmGenerator { // Missing a sub instruction but dont think we need it def funcPrologue()(using stack: Stack): Chain[AsmLine] = { - val chain = Chain.empty[AsmLine] + var chain = Chain.empty[AsmLine] chain += stack.push(RBP) chain += Move(RBP, Register(RegSize.R64, RegName.SP)) chain } def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { - val chain = Chain.empty[AsmLine] + var chain = Chain.empty[AsmLine] chain += Move(Register(RegSize.R64, RegName.SP), RBP) chain += stack.pop(RBP) chain += assemblyIR.Return() chain } - class LabelGenerator { - var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" - } - } - class Stack { private val stack = LinkedHashMap[Expr | Int, Int]() private val RSP = Register(RegSize.R64, RegName.SP) From c9723f9359c149dd8c0b24ec6a9dc04e433807ad Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 25 Feb 2025 21:37:18 +0000 Subject: [PATCH 31/54] feat: implements sign extension operation for division --- src/main/wacc/backend/asmGenerator.scala | 12 +++++++++++- src/main/wacc/backend/assemblyIR.scala | 1 + src/test/wacc/examples.scala | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 6d10b24..d8dd5b0 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -237,8 +237,8 @@ object asmGenerator { } case BinaryOp(x, y, op) => - chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(y) + chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) @@ -251,11 +251,21 @@ object asmGenerator { chain += stack.push(RAX) case BinaryOperator.Div => + // chain += stack.pop(RDX) + // chain += stack.pop(RAX) + // chain += stack.push(RDX) + // chain += stack.push(RAX) + chain += CDQ() chain += Divide(stack.head(SizeDir.Word)) chain += stack.drop() chain += stack.push(RAX) case BinaryOperator.Mod => + // chain += stack.pop(RDX) + // chain += stack.pop(RAX) + // chain += stack.push(RDX) + // chain += stack.push(RAX) + chain += CDQ() chain += Divide(stack.head(SizeDir.Word)) chain += stack.drop() chain += stack.push(RDX) diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 0921fd8..5a59fd1 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -112,6 +112,7 @@ object assemblyIR { case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2) case class Load(op1: Register, op2: MemLocation | IndexAddress) extends Operation("lea ", op1, op2) + case class CDQ() extends Operation("cdq") case class Return() extends Operation("ret") diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index dc3f817..c3189bb 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -88,7 +88,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral "^.*wacc-examples/valid/advanced.*$", "^.*wacc-examples/valid/array.*$", // "^.*wacc-examples/valid/basic/exit.*$", - "^.*wacc-examples/valid/basic/skip.*$", + // "^.*wacc-examples/valid/basic/skip.*$", "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$", @@ -100,7 +100,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", "^.*wacc-examples/valid/scope.*$", - "^.*wacc-examples/valid/sequence.*$", + // "^.*wacc-examples/valid/sequence.*$", "^.*wacc-examples/valid/variables.*$", "^.*wacc-examples/valid/while.*$", // format: on From 64b015e4942a15bae2265b41978a57cf4f675b87 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 21:49:23 +0000 Subject: [PATCH 32/54] ci: use JS commitlint configuration --- .commitlintrc.js | 4 ++++ .commitlintrc.yml | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 .commitlintrc.js delete mode 100644 .commitlintrc.yml diff --git a/.commitlintrc.js b/.commitlintrc.js new file mode 100644 index 0000000..a295b1d --- /dev/null +++ b/.commitlintrc.js @@ -0,0 +1,4 @@ +export default { + extends: ['@commitlint/config-conventional'], + ignores: [commit => commit.startsWith("Local Mutable Chains\n")] +} diff --git a/.commitlintrc.yml b/.commitlintrc.yml deleted file mode 100644 index 175ef04..0000000 --- a/.commitlintrc.yml +++ /dev/null @@ -1 +0,0 @@ -extends: "@commitlint/config-conventional" From da0ef9ec24563eb914b907430c827bb1591aa6da Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 22:03:53 +0000 Subject: [PATCH 33/54] ci: checkout commitlint back to current commit --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index f7d5454..e61541c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -32,6 +32,7 @@ check_commits: - npm install -g @commitlint/cli @commitlint/config-conventional - git checkout origin/master script: + - git checkout ${CI_COMMIT_SHA} - npx commitlint --from origin/master --to ${CI_COMMIT_SHA} --verbose compile_jvm: From f76b7a9dc21aca1080a994c1b5fd2aa138aab10d Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 22:46:48 +0000 Subject: [PATCH 34/54] refactor: replace explicit loops and flatMap with foldMap --- src/main/wacc/backend/asmGenerator.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index d8dd5b0..da34592 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -170,11 +170,11 @@ object asmGenerator { chain += stack.drop() chain += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= Chain.fromSeq(thenBranch).flatMap(generateStmt) + chain ++= thenBranch.foldMap(generateStmt) chain += Jump(LabelArg(endLabel)) chain += LabelDef(elseLabel) - chain ++= Chain.fromSeq(elseBranch).flatMap(generateStmt) + chain ++= elseBranch.foldMap(generateStmt) chain += LabelDef(endLabel) case While(cond, body) => @@ -187,7 +187,7 @@ object asmGenerator { chain += stack.drop() chain += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= Chain.fromSeq(body).flatMap(generateStmt) + chain ++= body.foldMap(generateStmt) chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) @@ -298,13 +298,13 @@ object asmGenerator { val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call - argRegs.zip(args).foreach { (reg, expr) => + argRegs.zip(args).foldMap { (reg, expr) => chain ++= evalExprOntoStack(expr) chain += stack.pop(reg) } - args.drop(argRegs.size).foreach { expr => - chain ++= evalExprOntoStack(expr) + args.drop(argRegs.size).foldMap { + chain ++= evalExprOntoStack(_) } chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) From 9ca50540e6b9d77f4e3737c575e101fc9b6b3604 Mon Sep 17 00:00:00 2001 From: Guy C Date: Wed, 26 Feb 2025 01:10:14 +0000 Subject: [PATCH 35/54] fix: fixed implementation of if statement code gen --- src/main/wacc/backend/asmGenerator.scala | 16 ++++------------ src/test/wacc/examples.scala | 2 +- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index da34592..f3df8a5 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -166,8 +166,8 @@ object asmGenerator { val endLabel = labelGenerator.getLabel() chain ++= evalExprOntoStack(cond) - chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) - chain += stack.drop() + chain += stack.pop(RAX) + chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(elseLabel), Cond.Equal) chain ++= thenBranch.foldMap(generateStmt) @@ -183,8 +183,8 @@ object asmGenerator { chain += LabelDef(startLabel) chain ++= evalExprOntoStack(cond) - chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) - chain += stack.drop() + chain += stack.pop(RAX) + chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(endLabel), Cond.Equal) chain ++= body.foldMap(generateStmt) @@ -251,20 +251,12 @@ object asmGenerator { chain += stack.push(RAX) case BinaryOperator.Div => - // chain += stack.pop(RDX) - // chain += stack.pop(RAX) - // chain += stack.push(RDX) - // chain += stack.push(RAX) chain += CDQ() chain += Divide(stack.head(SizeDir.Word)) chain += stack.drop() chain += stack.push(RAX) case BinaryOperator.Mod => - // chain += stack.pop(RDX) - // chain += stack.pop(RAX) - // chain += stack.push(RDX) - // chain += stack.push(RAX) chain += CDQ() chain += Divide(stack.head(SizeDir.Word)) chain += stack.drop() diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index c3189bb..e8a0e28 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -92,7 +92,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$", - "^.*wacc-examples/valid/if.*$", + // "^.*wacc-examples/valid/if.*$", "^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/read.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", From f15530149edadd4999a6549f57692631076a91ea Mon Sep 17 00:00:00 2001 From: Guy C Date: Wed, 26 Feb 2025 01:43:12 +0000 Subject: [PATCH 36/54] fix: fix sub instruction code gen --- src/main/wacc/backend/asmGenerator.scala | 5 ++++- src/test/wacc/examples.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index f3df8a5..a9b5fba 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -244,7 +244,10 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head(SizeDir.Word), EAX) - case BinaryOperator.Sub => chain += Subtract(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Sub => + chain += Subtract(EAX, stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RAX) case BinaryOperator.Mul => chain += Multiply(EAX, stack.head(SizeDir.Word)) chain += stack.drop() diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index e8a0e28..5887def 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -96,7 +96,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral "^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/read.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", - "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", + // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", "^.*wacc-examples/valid/scope.*$", From fc2c58002eb12e1f4c325498fc1c6cb56c53290c Mon Sep 17 00:00:00 2001 From: Guy C Date: Wed, 26 Feb 2025 02:00:28 +0000 Subject: [PATCH 37/54] test: include variables tests in suite --- src/test/wacc/examples.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 5887def..88cdad5 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -101,7 +101,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral "^.*wacc-examples/valid/runtimeErr.*$", "^.*wacc-examples/valid/scope.*$", // "^.*wacc-examples/valid/sequence.*$", - "^.*wacc-examples/valid/variables.*$", + // "^.*wacc-examples/valid/variables.*$", "^.*wacc-examples/valid/while.*$", // format: on ).find(filename.matches).isDefined From 39f88f6f8a88542d02985fbd43804da5563713e3 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 03:31:11 +0000 Subject: [PATCH 38/54] test: disable parallel test execution to avoid race conditions --- src/test/wacc/examples.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 88cdad5..2ef9100 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -1,6 +1,6 @@ package wacc -import org.scalatest.{ParallelTestExecution, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.Inspectors.forEvery import java.io.File @@ -8,7 +8,7 @@ import sys.process._ import java.io.PrintStream import scala.io.Source -class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with ParallelTestExecution { +class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { val files = allWaccFiles("wacc-examples/valid").map { p => (p.toString, List(0)) From 07c67dbef6cde2cadcdeaf2a27a70acefb34f288 Mon Sep 17 00:00:00 2001 From: Guy C Date: Wed, 26 Feb 2025 07:13:12 +0000 Subject: [PATCH 39/54] feat: add zero division error handling in asmGenerator --- src/main/wacc/backend/asmGenerator.scala | 33 +++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index a9b5fba..7212e56 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -12,6 +12,24 @@ object asmGenerator { import wacc.types._ import lexer.escapedChars + abstract case class Error() { + def strLabel: String + def errStr: String + def errLabel: String + + def stringDef: Chain[AsmLine] = Chain( + Directive.Int(errStr.size), + LabelDef(strLabel), + Directive.Asciz(errStr) + ) + } + object zeroDivError extends Error { + // TODO: is this bad? Can we make an error case class/some other structure? + def strLabel = ".L._errDivZero_str0" + def errStr = "fatal error: division or modulo by zero" + def errLabel = ".L._errDivZero" + } + val RAX = Register(RegSize.R64, RegName.AX) val EAX = Register(RegSize.E32, RegName.AX) val ESP = Register(RegSize.E32, RegName.SP) @@ -66,7 +84,7 @@ object asmGenerator { LabelDef(s".L.str$i"), Directive.Asciz(str.escaped) ) - } + } ++ zeroDivError.stringDef Chain( Directive.IntelSyntax, @@ -135,6 +153,17 @@ object asmGenerator { ) ) + chain ++= Chain( + // TODO can this be done with a call to generateStmt? + // Consider other error cases -> look to generalise + LabelDef(zeroDivError.errLabel), + stack.align(), + Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(-1)), + assemblyIR.Call(CLibFunc.Exit) + ) + chain } @@ -254,6 +283,8 @@ object asmGenerator { chain += stack.push(RAX) case BinaryOperator.Div => + chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) + chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += CDQ() chain += Divide(stack.head(SizeDir.Word)) chain += stack.drop() From 62df2c2244f90cddba8c41e34f2beae40509713f Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Wed, 26 Feb 2025 16:58:01 +0000 Subject: [PATCH 40/54] fix: added dword in sizedir --- src/main/wacc/backend/asmGenerator.scala | 22 +++++++++++----------- src/main/wacc/backend/assemblyIR.scala | 5 +++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 7212e56..4f3c4c7 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -260,9 +260,9 @@ object asmGenerator { chain ++= evalExprOntoStack(x) op match { case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed - case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word)) + case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) case UnaryOperator.Not => - chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1)) + chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) } case BinaryOp(x, y, op) => @@ -272,27 +272,27 @@ object asmGenerator { chain += stack.pop(RAX) op match { - case BinaryOperator.Add => chain += Add(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX) case BinaryOperator.Sub => - chain += Subtract(EAX, stack.head(SizeDir.Word)) + chain += Subtract(EAX, stack.head(SizeDir.DWord)) chain += stack.drop() chain += stack.push(RAX) case BinaryOperator.Mul => - chain += Multiply(EAX, stack.head(SizeDir.Word)) + chain += Multiply(EAX, stack.head(SizeDir.DWord)) chain += stack.drop() chain += stack.push(RAX) case BinaryOperator.Div => - chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) + chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0)) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += CDQ() - chain += Divide(stack.head(SizeDir.Word)) + chain += Divide(stack.head(SizeDir.DWord)) chain += stack.drop() chain += stack.push(RAX) case BinaryOperator.Mod => chain += CDQ() - chain += Divide(stack.head(SizeDir.Word)) + chain += Divide(stack.head(SizeDir.DWord)) chain += stack.drop() chain += stack.push(RDX) @@ -302,8 +302,8 @@ object asmGenerator { case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual) case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less) case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual) - case BinaryOperator.And => chain += And(stack.head(SizeDir.Word), EAX) - case BinaryOperator.Or => chain += Or(stack.head(SizeDir.Word), EAX) + case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) } case call: microWacc.Call => @@ -353,7 +353,7 @@ object asmGenerator { chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(y) chain += stack.pop(RAX) - chain += Compare(stack.head(SizeDir.Word), EAX) + chain += Compare(stack.head(SizeDir.DWord), EAX) chain += Set(Register(RegSize.Byte, RegName.AL), cond) chain += And(RAX, ImmediateVal(_8_BIT_MASK)) chain += stack.drop() diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 5a59fd1..5038be8 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -174,13 +174,14 @@ object assemblyIR { } enum SizeDir { - case Byte, Word, Unspecified + case Byte, Word, DWord, Unspecified private val ptr = "ptr " override def toString(): String = this match { case Byte => "byte " + ptr - case Word => "dword " + ptr // TODO check word/doubleword/quadword + case Word => "word " + ptr // TODO check word/doubleword/quadword + case DWord => "dword " + ptr case Unspecified => "" } } From 85190ce17467c45fea802aa5ace052ff1638359e Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 18:08:30 +0000 Subject: [PATCH 41/54] refactor: make microWacc.ArrayElem recursive rather than flat --- src/main/wacc/frontend/microWacc.scala | 6 +----- src/main/wacc/frontend/typeChecker.scala | 13 +++++++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index c558b6d..099fcc3 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -1,7 +1,5 @@ package wacc -import cats.data.NonEmptyList - object microWacc { import wacc.types._ @@ -19,9 +17,7 @@ object microWacc { extends Expr(identTy) with CallTarget(identTy) with LValue - case class ArrayElem(value: LValue, indices: NonEmptyList[Expr])(ty: SemType) - extends Expr(ty) - with LValue + case class ArrayElem(value: LValue, index: Expr)(ty: SemType) extends Expr(ty) with LValue // Operators case class UnaryOp(x: Expr, op: UnaryOperator)(ty: SemType) extends Expr(ty) diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index c3f2ba8..ca95342 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -422,10 +422,15 @@ object typeChecker { } (next, idxTyped) } - microWacc.ArrayElem( + val firstArrayElem = microWacc.ArrayElem( microWacc.Ident(id.v, id.uid)(arrayTy), - indicesTyped + indicesTyped.head )(elemTy.satisfies(constraint, value.pos)) + val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) => + microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty)) + } + // Need to type-check the final arrayElem with the constraint + microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos)) case ast.Fst(elem) => val elemTyped = checkLValue( elem, @@ -433,7 +438,7 @@ object typeChecker { ) microWacc.ArrayElem( elemTyped, - NonEmptyList.of(microWacc.IntLiter(0)) + microWacc.IntLiter(0) )(elemTyped.ty match { case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos) @@ -446,7 +451,7 @@ object typeChecker { ) microWacc.ArrayElem( elemTyped, - NonEmptyList.of(microWacc.IntLiter(1)) + microWacc.IntLiter(1) )(elemTyped.ty match { case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) From c748a34e4ce4514403d659dcf801850ad1dd5bde Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 18:31:15 +0000 Subject: [PATCH 42/54] feat: user functions and calls --- src/main/wacc/backend/asmGenerator.scala | 27 +++++++++++++++++++----- src/main/wacc/frontend/typeChecker.scala | 2 +- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 4f3c4c7..fe30af7 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -42,6 +42,7 @@ object asmGenerator { val RCX = Register(RegSize.R64, RegName.CX) val R8 = Register(RegSize.R64, RegName.Reg8) val R9 = Register(RegSize.R64, RegName.Reg9) + val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val _8_BIT_MASK = 0xff @@ -75,7 +76,8 @@ object asmGenerator { main.foldMap(generateStmt(_)), Chain.one(Move(RAX, ImmediateVal(0))), funcEpilogue(), - generateBuiltInFuncs() + generateBuiltInFuncs(), + funcs.foldMap(generateUserFunc(_)) ) val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => @@ -111,6 +113,22 @@ object asmGenerator { chain } + def generateUserFunc(func: FuncDecl)(using + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + given stack: Stack = Stack() + // Setup the stack with param 7 and up + func.params.drop(argRegs.size).foreach(stack.reserve(_)) + var chain = Chain.empty[AsmLine] + // Push the rest of params onto the stack for simplicity + argRegs.zip(func.params).foreach { (reg, param) => + chain += stack.push(param, reg) + } + chain ++= func.body.foldMap(generateStmt(_)) + wrapFunc(labelGenerator.getLabel(func.name), chain) + } + def generateBuiltInFuncs()(using stack: Stack, strings: ListBuffer[String], @@ -223,7 +241,7 @@ object asmGenerator { case microWacc.Return(expr) => chain ++= evalExprOntoStack(expr) chain += stack.pop(RAX) - chain += assemblyIR.Return() + chain ++= funcEpilogue() case call: microWacc.Call => chain ++= generateCall(call) @@ -321,7 +339,6 @@ object asmGenerator { labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call argRegs.zip(args).foldMap { (reg, expr) => @@ -373,7 +390,7 @@ object asmGenerator { def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] chain += Move(Register(RegSize.R64, RegName.SP), RBP) - chain += stack.pop(RBP) + chain += Pop(RBP) chain += assemblyIR.Return() chain } @@ -382,7 +399,7 @@ object asmGenerator { private val stack = LinkedHashMap[Expr | Int, Int]() private val RSP = Register(RegSize.R64, RegName.SP) - def next: Int = stack.size + 1 + private def next: Int = stack.size + 1 def push(expr: Expr, src: Src): AsmLine = { stack += expr -> next Push(src) diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index ca95342..8c11550 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -110,7 +110,7 @@ object typeChecker { microWacc.FuncDecl( microWacc.Ident(name.v, name.uid)(retType), params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => - microWacc.Ident(ident.v, name.uid)(ty) + microWacc.Ident(ident.v, ident.uid)(ty) }, stmts.toList .flatMap( From 16de964f74dca97eba4593d3805f39c519a72eb0 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 18:45:03 +0000 Subject: [PATCH 43/54] refactor: do not append epilogue to user functions since they all return anyway --- src/main/wacc/backend/asmGenerator.scala | 26 +++++++++++------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index fe30af7..6fbbd82 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -99,17 +99,13 @@ object asmGenerator { ).toList } - def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using - stack: Stack, - strings: ListBuffer[String] + private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using + stack: Stack ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] - - chain += LabelDef(labelName) + var chain = Chain.one[AsmLine](LabelDef(labelName)) chain ++= funcPrologue() chain ++= funcBody chain ++= funcEpilogue() - chain } @@ -120,13 +116,15 @@ object asmGenerator { given stack: Stack = Stack() // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) - var chain = Chain.empty[AsmLine] + var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) + chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => chain += stack.push(param, reg) } chain ++= func.body.foldMap(generateStmt(_)) - wrapFunc(labelGenerator.getLabel(func.name), chain) + // No need for epilogue here since all user functions must return explicitly + chain } def generateBuiltInFuncs()(using @@ -136,12 +134,12 @@ object asmGenerator { ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain ++= wrapFunc( + chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Exit), Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) ) - chain ++= wrapFunc( + chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Printf), Chain( stack.align(), @@ -151,14 +149,14 @@ object asmGenerator { ) ) - chain ++= wrapFunc( + chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), Chain.one(stack.align()) ) - chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) + chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) - chain ++= wrapFunc( + chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Read), Chain( stack.align(), From 4fb399a5e1106e4605740666f1c6561c18e56bfd Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 19:14:12 +0000 Subject: [PATCH 44/54] feat: generate microWacc for printing booleans --- src/main/wacc/frontend/typeChecker.scala | 37 ++++++++++++++++-------- src/test/wacc/examples.scala | 2 +- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index 8c11550..f571e11 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -216,21 +216,34 @@ object typeChecker { case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already val exprTyped = checkValue(expr, Constraint.Unconstrained) - val format = exprTyped.ty match { - case KnownType.Bool | KnownType.String => "%s" - case KnownType.Char => "%c" - case KnownType.Int => "%d" - case _ => "%p" + val exprFormat = exprTyped.ty match { + case KnownType.Bool | KnownType.String => "%s" + case KnownType.Char => "%c" + case KnownType.Int => "%d" + case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" } - List( - microWacc.Call( - microWacc.Builtin.Printf, - List( - s"$format${if newline then "\n" else ""}".toMicroWaccCharArray, - exprTyped + val printfCall = { (value: microWacc.Expr) => + List( + microWacc.Call( + microWacc.Builtin.Printf, + List( + s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, + value + ) ) ) - ) + } + exprTyped.ty match { + case KnownType.Bool => + List( + microWacc.If( + exprTyped, + printfCall("true".toMicroWaccCharArray), + printfCall("false".toMicroWaccCharArray) + ) + ) + case _ => printfCall(exprTyped) + } case ast.If(cond, thenStmt, elseStmt) => List( microWacc.If( diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 2ef9100..0da3659 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -89,7 +89,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { "^.*wacc-examples/valid/array.*$", // "^.*wacc-examples/valid/basic/exit.*$", // "^.*wacc-examples/valid/basic/skip.*$", - "^.*wacc-examples/valid/expressions.*$", + // "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$", // "^.*wacc-examples/valid/if.*$", From 631f9ddca555ec8a4c6bf34848e167dc469c60c9 Mon Sep 17 00:00:00 2001 From: Jonny Date: Wed, 26 Feb 2025 19:49:10 +0000 Subject: [PATCH 45/54] feat: (maybe) tail call optimisation --- src/main/wacc/backend/asmGenerator.scala | 28 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 6fbbd82..1b8a39b 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -236,13 +236,18 @@ object asmGenerator { chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) - case microWacc.Return(expr) => - chain ++= evalExprOntoStack(expr) - chain += stack.pop(RAX) - chain ++= funcEpilogue() - case call: microWacc.Call => - chain ++= generateCall(call) + chain ++= generateCall(call, isTail = false) + + case microWacc.Return(expr) => + expr match { + case call: microWacc.Call => + chain ++= generateCall(call, isTail = true) // tco + case _ => + chain ++= evalExprOntoStack(expr) + chain += stack.pop(RAX) + chain ++= funcEpilogue() + } } chain @@ -323,7 +328,7 @@ object asmGenerator { } case call: microWacc.Call => - chain ++= generateCall(call) + chain ++= generateCall(call, isTail = false) chain += stack.push(RAX) } @@ -331,7 +336,7 @@ object asmGenerator { chain } - def generateCall(call: microWacc.Call)(using + def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -348,7 +353,12 @@ object asmGenerator { chain ++= evalExprOntoStack(_) } - chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) + // Tail Call Optimisation (TCO) + if (isTail) { + chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + } else { + chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + } if (args.size > argRegs.size) { chain += stack.drop(args.size - argRegs.size) From 2cf18a47a8f392b5aa2d3fb87fa921d925a1663e Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 20:00:42 +0000 Subject: [PATCH 46/54] fix: only push one item to stack on comparisons --- src/main/wacc/backend/asmGenerator.scala | 20 ++++++++------------ src/main/wacc/backend/assemblyIR.scala | 5 +++++ src/test/wacc/examples.scala | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 1b8a39b..0d894bc 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -317,12 +317,12 @@ object asmGenerator { chain += stack.drop() chain += stack.push(RDX) - case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal) - case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual) - case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater) - case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual) - case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less) - case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual) + case BinaryOperator.Eq => chain ++= generateComparison(Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(Cond.LessEqual) case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) } @@ -367,18 +367,14 @@ object asmGenerator { chain } - def generateComparison(x: Expr, y: Expr, cond: Cond)(using + def generateComparison(cond: Cond)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] - chain ++= evalExprOntoStack(x) - chain ++= evalExprOntoStack(y) - chain += stack.pop(RAX) - chain += Compare(stack.head(SizeDir.DWord), EAX) + chain += Compare(EAX, stack.head(SizeDir.DWord)) chain += Set(Register(RegSize.Byte, RegName.AL), cond) chain += And(RAX, ImmediateVal(_8_BIT_MASK)) chain += stack.drop() diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 5038be8..1ff8906 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -126,6 +126,11 @@ object assemblyIR { override def toString = s"$name:" } + case class Comment(comment: String) extends AsmLine { + override def toString = + comment.split("\n").map(line => s"# ${line}").mkString("\n") + } + enum Cond { case Equal, NotEqual, diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 0da3659..7f8538d 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -102,7 +102,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { "^.*wacc-examples/valid/scope.*$", // "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/variables.*$", - "^.*wacc-examples/valid/while.*$", + // "^.*wacc-examples/valid/while.*$", // format: on ).find(filename.matches).isDefined } From 09df7af2ab58a164f1fab5c6168342f38fa637ee Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 20:25:27 +0000 Subject: [PATCH 47/54] fix: reset scope after all branching --- src/main/wacc/backend/asmGenerator.scala | 34 +++++++++++++++++++----- src/test/wacc/examples.scala | 3 ++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0d894bc..0e0643e 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -4,7 +4,6 @@ import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ -// import parsley.token.errors.Label object asmGenerator { import microWacc._ @@ -183,13 +182,33 @@ object asmGenerator { chain } + /** Wraps a chain in a stack reset. + * + * This is useful for ensuring that the stack size at the death of scope is the same as the stack + * size at the start of the scope. See branching (If / While) + * + * @param genChain + * Function that generates the scope AsmLines + * @param stack + * The stack to reset + * @return + * The generated scope AsmLines + */ + private def generateScope(genChain: () => Chain[AsmLine])(using + stack: Stack + ): Chain[AsmLine] = { + val stackSizeStart = stack.size + var chain = genChain() + chain += stack.drop(stack.size - stackSizeStart) + chain + } + def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below @@ -215,11 +234,11 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= thenBranch.foldMap(generateStmt) + chain ++= generateScope(() => thenBranch.foldMap(generateStmt)) chain += Jump(LabelArg(endLabel)) chain += LabelDef(elseLabel) - chain ++= elseBranch.foldMap(generateStmt) + chain ++= generateScope(() => elseBranch.foldMap(generateStmt)) chain += LabelDef(endLabel) case While(cond, body) => @@ -232,7 +251,7 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= body.foldMap(generateStmt) + chain ++= generateScope(() => body.foldMap(generateStmt)) chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) @@ -259,7 +278,7 @@ object asmGenerator { labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - + val stackSizeStart = stack.size expr match { case IntLiter(v) => chain += stack.push(ImmediateVal(v)) case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) @@ -333,6 +352,8 @@ object asmGenerator { } if chain.isEmpty then chain += stack.push(ImmediateVal(0)) + + assert(stack.size == stackSizeStart + 1) chain } @@ -404,6 +425,7 @@ object asmGenerator { private val RSP = Register(RegSize.R64, RegName.SP) private def next: Int = stack.size + 1 + def size: Int = stack.size def push(expr: Expr, src: Src): AsmLine = { stack += expr -> next Push(src) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 7f8538d..8ac0aa4 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -99,7 +99,8 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", - "^.*wacc-examples/valid/scope.*$", + // "^.*wacc-examples/valid/scope.*$", + "^.*wacc-examples/valid/scope/printAllTypes.wacc$", // while we still don't have arrays implemented // "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/variables.*$", // "^.*wacc-examples/valid/while.*$", From bdee6ba756f8756625d69cbc13a85ba68dd3ac0c Mon Sep 17 00:00:00 2001 From: Jonny Date: Wed, 26 Feb 2025 21:39:23 +0000 Subject: [PATCH 48/54] feat: zero registers via xor --- src/main/wacc/backend/asmGenerator.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0e0643e..54fc7db 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -73,7 +73,7 @@ object asmGenerator { funcPrologue(), Chain.one(stack.align()), main.foldMap(generateStmt(_)), - Chain.one(Move(RAX, ImmediateVal(0))), + Chain.one(Xor(RAX, RAX)), funcEpilogue(), generateBuiltInFuncs(), funcs.foldMap(generateUserFunc(_)) @@ -143,7 +143,7 @@ object asmGenerator { Chain( stack.align(), assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), + Xor(RDI, RDI), assemblyIR.Call(CLibFunc.Fflush) ) ) @@ -293,7 +293,10 @@ object asmGenerator { case _ => // Other array types TODO } - case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) + case BoolLiter(true) => chain += stack.push(ImmediateVal(1)) + case BoolLiter(false) => + chain += Xor(RAX, RAX) + chain += stack.push(RAX) case NullLiter() => chain += stack.push(ImmediateVal(0)) case ArrayElem(_, _) => // TODO: Implement handling case UnaryOp(x, op) => From 808a59f58ac211e987d920d76814a47a78d18548 Mon Sep 17 00:00:00 2001 From: Alex Ling Date: Wed, 26 Feb 2025 21:12:50 +0000 Subject: [PATCH 49/54] feat: almost implemented arrays --- src/main/wacc/backend/asmGenerator.scala | 69 +++++++++++++++++++----- src/main/wacc/backend/assemblyIR.scala | 19 +++++-- src/main/wacc/frontend/microWacc.scala | 1 + src/main/wacc/frontend/typeChecker.scala | 1 + src/test/wacc/examples.scala | 4 +- 5 files changed, 75 insertions(+), 19 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 54fc7db..688e474 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -148,9 +148,22 @@ object asmGenerator { ) ) + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.PrintCharArray), + Chain( + stack.align(), + Load(RDX, IndexAddress(RSI, 8)), + Move(RSI, MemLocation(RSI)), + assemblyIR.Call(CLibFunc.PrintF), + Xor(RDI, RDI), + assemblyIR.Call(CLibFunc.Fflush) + ) + ) + chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), - Chain.one(stack.align()) + Chain(stack.align(), assemblyIR.Call(CLibFunc.Malloc)) + // Out of memory check is optional ) chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) @@ -211,19 +224,25 @@ object asmGenerator { var chain = Chain.empty[AsmLine] stmt match { case Assign(lhs, rhs) => - var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below lhs match { case ident: Ident => - dest = stack.accessVar(ident) + val dest = stack.accessVar(ident) if (!stack.contains(ident)) chain += stack.reserve(ident) - // TODO lhs = arrayElem - case _ => - } - chain ++= evalExprOntoStack(rhs) - chain += stack.pop(RAX) - chain += Move(dest(), RAX) + chain ++= evalExprOntoStack(rhs) + chain += stack.pop(RDX) + chain += Move(dest(), RDX) + case ArrayElem(x, i) => + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(i) + chain ++= evalExprOntoStack(rhs) + chain += stack.pop(RAX) + chain += stack.pop(RCX) + chain += stack.pop(RDX) + + chain += Move(IndexAddress(RDX, 8, RCX, 8), RAX) + } case If(cond, thenBranch, elseBranch) => val elseLabel = labelGenerator.getLabel() @@ -290,19 +309,43 @@ object asmGenerator { strings += elems.collect { case CharLiter(v) => v }.mkString chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) chain += stack.push(RAX) - case _ => // Other array types TODO + case _ => + chain ++= generateCall( + microWacc.Call(Builtin.Malloc, List(IntLiter((elems.size + 1) * 8))), + isTail = false + ) + chain += stack.push(RAX) + // Store the length of the array at the start + chain += Move(MemLocation(RAX, SizeDir.DWord), ImmediateVal(elems.size)) + elems.zipWithIndex.foldMap { (elem, i) => + chain ++= evalExprOntoStack(elem) + chain += stack.pop(RCX) + chain += stack.pop(RAX) + chain += Move(IndexAddress(RAX, 8 * (i + 1)), RCX) + chain += stack.push(RAX) + } } case BoolLiter(true) => chain += stack.push(ImmediateVal(1)) case BoolLiter(false) => chain += Xor(RAX, RAX) chain += stack.push(RAX) - case NullLiter() => chain += stack.push(ImmediateVal(0)) - case ArrayElem(_, _) => // TODO: Implement handling + case NullLiter() => chain += stack.push(ImmediateVal(0)) + case ArrayElem(x, i) => + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(i) + chain += stack.pop(RCX) + chain += stack.pop(RAX) + // + 1 because we store the length of the array at the start + chain += stack.push(IndexAddress(RAX, 8, RCX, 8)) case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { - case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed + case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed + case UnaryOperator.Len => + // Access the elem + chain += stack.pop(RAX) + chain += Push(MemLocation(RAX)) case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) case UnaryOperator.Not => chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 1ff8906..f8bbf38 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -48,7 +48,9 @@ object assemblyIR { case Scanf, Fflush, Exit, - PrintF + PrintF, + Malloc, + Free private val plt = "@plt" @@ -57,6 +59,8 @@ object assemblyIR { case Fflush => "fflush" + plt case Exit => "exit" + plt case PrintF => "printf" + plt + case Malloc => "malloc" + plt + case Free => "free" + plt } } @@ -72,13 +76,20 @@ object assemblyIR { case reg: Register => opSize.toString + s"[$reg]" } } + + // TODO to string is wacky case class IndexAddress( base: Register, offset: Int | LabelArg, - opSize: SizeDir = SizeDir.Unspecified + indexReg: Register = Register(RegSize.R64, RegName.AX), + scale: Int = 0 ) extends Dest with Src { - override def toString = s"$opSize[$base + $offset]" + override def toString = if (scale != 0) { + s"[$base + $indexReg * $scale + $offset]" + } else { + s"[$base + $offset]" + } } case class ImmediateVal(value: Int) extends Src { @@ -185,7 +196,7 @@ object assemblyIR { override def toString(): String = this match { case Byte => "byte " + ptr - case Word => "word " + ptr // TODO check word/doubleword/quadword + case Word => "word " + ptr case DWord => "dword " + ptr case Unspecified => "" } diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index 099fcc3..e2c1bdc 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -74,6 +74,7 @@ object microWacc { object Exit extends Builtin("exit")(?) object Free extends Builtin("free")(?) object Malloc extends Builtin("malloc")(?) + object PrintCharArray extends Builtin("printCharArray")(?) } case class Assign(lhs: LValue, rhs: Expr) extends Stmt diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index f571e11..2c430e5 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -218,6 +218,7 @@ object typeChecker { val exprTyped = checkValue(expr, Constraint.Unconstrained) val exprFormat = exprTyped.ty match { case KnownType.Bool | KnownType.String => "%s" + case KnownType.Array(KnownType.Char) => "%.*s" case KnownType.Char => "%c" case KnownType.Int => "%d" case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 8ac0aa4..87def2a 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -73,7 +73,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { ) assert(process.exitValue == expectedExit) - assert(stdout.toString == expectedOutput) + assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput) } } @@ -86,7 +86,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // format: off // disable formatting to avoid binPack "^.*wacc-examples/valid/advanced.*$", - "^.*wacc-examples/valid/array.*$", + // "^.*wacc-examples/valid/array.*$", // "^.*wacc-examples/valid/basic/exit.*$", // "^.*wacc-examples/valid/basic/skip.*$", // "^.*wacc-examples/valid/expressions.*$", From 58df1d7bb968b47d6ea942c5cabe89e1f1e3f69f Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 01:49:22 +0000 Subject: [PATCH 50/54] refactor: extract Stack, proper register naming and sizes --- src/main/wacc/backend/Stack.scala | 84 ++++++++++++++ src/main/wacc/backend/assemblyIR.scala | 122 ++++++++++++--------- src/main/wacc/backend/sizeExtensions.scala | 29 +++++ 3 files changed, 181 insertions(+), 54 deletions(-) create mode 100644 src/main/wacc/backend/Stack.scala create mode 100644 src/main/wacc/backend/sizeExtensions.scala diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala new file mode 100644 index 0000000..72aa5ef --- /dev/null +++ b/src/main/wacc/backend/Stack.scala @@ -0,0 +1,84 @@ +package wacc + +import scala.collection.mutable.LinkedHashMap + +class Stack { + import assemblyIR._ + import sizeExtensions.size + import microWacc as mw + + private val RSP = Register(Size.Q64, RegName.SP) + private class StackValue(val size: Size, val offset: Int) { + def bottom: Int = offset + size.toInt + } + private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() + + /** The stack's size in bytes. */ + def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom + + /** Push an expression onto the stack. */ + def push(expr: mw.Expr, src: Register): AsmLine = { + stack += expr -> StackValue(src.size, size) + Push(src) + } + + /** Push an arbitrary register onto the stack. */ + def push(src: Register): AsmLine = { + stack += stack.size -> StackValue(src.size, size) + Push(src) + } + + /** Reserve space for a variable on the stack. */ + def reserve(ident: mw.Ident): AsmLine = { + stack += ident -> StackValue(ident.ty.size, size) + Subtract(RSP, ImmediateVal(ident.ty.size.toInt)) + } + + /** Reserve space for values on the stack. + * + * @param sizes + * The sizes of the values to reserve space for. + */ + def reserve(sizes: List[Size]): AsmLine = { + val totalSize = sizes + .map(itemSize => + stack += stack.size -> StackValue(itemSize, size) + itemSize.toInt + ) + .sum + Subtract(RSP, ImmediateVal(totalSize)) + } + + /** Pop a value from the stack into a register. Sizes MUST match. */ + def pop(dest: Register): AsmLine = { + if (dest.size != stack.last._2.size) { + throw new IllegalArgumentException( + s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register" + ) + } + stack.remove(stack.last._1) + Pop(dest) + } + + /** Drop the top n values from the stack. */ + def drop(n: Int = 1): AsmLine = { + val totalSize = (1 to n) + .map(_ => + val itemSize = stack.last._2.size.toInt + stack.remove(stack.last._1) + itemSize + ) + .sum + Add(RSP, ImmediateVal(totalSize)) + } + + /** Get a lazy IndexAddress for a variable in the stack. */ + def accessVar(ident: mw.Ident): () => IndexAddress = () => { + IndexAddress(RSP, stack.size - stack(ident).bottom) + } + def contains(ident: mw.Ident): Boolean = stack.contains(ident) + def head: MemLocation = MemLocation(RSP) + def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset)) + // TODO: Might want to actually properly handle this with the LinkedHashMap too + def align(): AsmLine = And(RSP, ImmediateVal(-16)) +} diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index f8bbf38..2946dcb 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -6,40 +6,73 @@ object assemblyIR { sealed trait Operand sealed trait Src extends Operand // mem location, register and imm value sealed trait Dest extends Operand // mem location and register - enum RegSize { - case R64 - case E32 - case Byte - override def toString = this match { - case R64 => "r" - case E32 => "e" - case Byte => "" + enum Size { + case Q64, D32, W16, B8 + + def toInt: Int = this match { + case Q64 => 8 + case D32 => 4 + case W16 => 2 + case B8 => 1 + } + + private val ptr = "ptr " + + override def toString(): String = this match { + case Q64 => "qword " + ptr + case D32 => "dword " + ptr + case W16 => "word " + ptr + case B8 => "byte " + ptr } } enum RegName { - case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, - Reg15 - override def toString = this match { - case AX => "ax" - case AL => "al" - case BX => "bx" - case CX => "cx" - case DX => "dx" - case SI => "si" - case DI => "di" - case SP => "sp" - case BP => "bp" - case IP => "ip" - case Reg8 => "8" - case Reg9 => "9" - case Reg10 => "10" - case Reg11 => "11" - case Reg12 => "12" - case Reg13 => "13" - case Reg14 => "14" - case Reg15 => "15" + case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15 + } + + case class Register(size: Size, name: RegName) extends Dest with Src { + import RegName._ + + if (size == Size.B8 && name == RegName.IP) { + throw new IllegalArgumentException("Cannot have 8 bit register for IP") + } + override def toString = name match { + case AX => tradToString("ax", "al") + case BX => tradToString("bx", "bl") + case CX => tradToString("cx", "cl") + case DX => tradToString("dx", "dl") + case SI => tradToString("si", "sil") + case DI => tradToString("di", "dil") + case SP => tradToString("sp", "spl") + case BP => tradToString("bp", "bpl") + case IP => tradToString("ip", "#INVALID") + case R8 => newToString(8) + case R9 => newToString(9) + case R10 => newToString(10) + case R11 => newToString(11) + case R12 => newToString(12) + case R13 => newToString(13) + case R14 => newToString(14) + case R15 => newToString(15) + } + + private def tradToString(base: String, byteName: String): String = + size match { + case Size.Q64 => "r" + base + case Size.D32 => "e" + base + case Size.W16 => base + case Size.B8 => byteName + } + + private def newToString(base: Int): String = { + val b = base.toString + "r" + (size match { + case Size.Q64 => b + case Size.D32 => b + "d" + case Size.W16 => b + "w" + case Size.B8 => b + "b" + }) } } @@ -64,24 +97,18 @@ object assemblyIR { } } - // TODO register naming conventions are wrong - case class Register(size: RegSize, name: RegName) extends Dest with Src { - override def toString = s"${size}${name}" - } - case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) - extends Dest - with Src { - override def toString = pointer match { - case hex: Long => opSize.toString + f"[0x$hex%X]" - case reg: Register => opSize.toString + s"[$reg]" - } + case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src { + def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize)) + + override def toString = + opSize.getOrElse("").toString + s"[$pointer]" } // TODO to string is wacky case class IndexAddress( base: Register, offset: Int | LabelArg, - indexReg: Register = Register(RegSize.R64, RegName.AX), + indexReg: Register = Register(Size.Q64, RegName.AX), scale: Int = 0 ) extends Dest with Src { @@ -188,17 +215,4 @@ object assemblyIR { case String => "%s" } } - - enum SizeDir { - case Byte, Word, DWord, Unspecified - - private val ptr = "ptr " - - override def toString(): String = this match { - case Byte => "byte " + ptr - case Word => "word " + ptr - case DWord => "dword " + ptr - case Unspecified => "" - } - } } diff --git a/src/main/wacc/backend/sizeExtensions.scala b/src/main/wacc/backend/sizeExtensions.scala new file mode 100644 index 0000000..59d3930 --- /dev/null +++ b/src/main/wacc/backend/sizeExtensions.scala @@ -0,0 +1,29 @@ +package wacc + +object sizeExtensions { + import microWacc._ + import types._ + import assemblyIR.Size + + extension (expr: Expr) { + + /** Calculate the size (bytes) of the heap required for the expression. */ + def heapSize: Int = (expr, expr.ty) match { + case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => + KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt + case (ArrayLiter(elems), _) => + KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum + case _ => expr.ty.size.toInt + } + } + + extension (ty: SemType) { + + /** Calculate the size (bytes) of a type in a register. */ + def size: Size = ty match { + case KnownType.Int => Size.D32 + case KnownType.Bool | KnownType.Char => Size.B8 + case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 + } + } +} From 887b982331339498965aa9214983cbf8cd44f4de Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 14:48:24 +0000 Subject: [PATCH 51/54] fix: variable-sized values, heap-allocated arrays (and printCharArray) --- src/main/wacc/Main.scala | 3 +- src/main/wacc/backend/Stack.scala | 80 +++--- src/main/wacc/backend/asmGenerator.scala | 285 ++++++++------------- src/main/wacc/backend/assemblyIR.scala | 6 +- src/main/wacc/backend/sizeExtensions.scala | 10 +- src/main/wacc/backend/writer.scala | 5 +- src/main/wacc/frontend/typeChecker.scala | 12 +- src/test/wacc/instructionSpec.scala | 22 +- 8 files changed, 185 insertions(+), 238 deletions(-) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 52c40aa..020cbcd 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,6 +1,7 @@ package wacc import scala.collection.mutable +import cats.data.Chain import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -63,7 +64,7 @@ def frontend( } val s = "enter an integer to echo" -def backend(typedProg: microWacc.Program): List[asm.AsmLine] = +def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = asmGenerator.generateAsm(typedProg) def compile(filename: String, outFile: Option[File] = None)(using diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala index 72aa5ef..0949ecb 100644 --- a/src/main/wacc/backend/Stack.scala +++ b/src/main/wacc/backend/Stack.scala @@ -1,37 +1,48 @@ package wacc import scala.collection.mutable.LinkedHashMap +import cats.data.Chain class Stack { import assemblyIR._ + import assemblyIR.Size._ import sizeExtensions.size import microWacc as mw - private val RSP = Register(Size.Q64, RegName.SP) + private val RSP = Register(Q64, RegName.SP) private class StackValue(val size: Size, val offset: Int) { - def bottom: Int = offset + size.toInt + def bottom: Int = offset + elemBytes } private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() + private val elemBytes: Int = Q64.toInt + private def sizeBytes: Int = stack.size * elemBytes + /** The stack's size in bytes. */ - def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom + def size: Int = stack.size /** Push an expression onto the stack. */ def push(expr: mw.Expr, src: Register): AsmLine = { - stack += expr -> StackValue(src.size, size) + stack += expr -> StackValue(src.size, sizeBytes) Push(src) } - /** Push an arbitrary register onto the stack. */ - def push(src: Register): AsmLine = { - stack += stack.size -> StackValue(src.size, size) - Push(src) + /** Push a value onto the stack. */ + def push(itemSize: Size, addr: Src): AsmLine = { + stack += stack.size -> StackValue(itemSize, sizeBytes) + Push(addr) } /** Reserve space for a variable on the stack. */ def reserve(ident: mw.Ident): AsmLine = { - stack += ident -> StackValue(ident.ty.size, size) - Subtract(RSP, ImmediateVal(ident.ty.size.toInt)) + stack += ident -> StackValue(ident.ty.size, sizeBytes) + Subtract(RSP, ImmediateVal(elemBytes)) + } + + /** Reserve space for a register on the stack. */ + def reserve(src: Register): AsmLine = { + stack += stack.size -> StackValue(src.size, sizeBytes) + Subtract(RSP, ImmediateVal(src.size.toInt)) } /** Reserve space for values on the stack. @@ -40,45 +51,40 @@ class Stack { * The sizes of the values to reserve space for. */ def reserve(sizes: List[Size]): AsmLine = { - val totalSize = sizes - .map(itemSize => - stack += stack.size -> StackValue(itemSize, size) - itemSize.toInt - ) - .sum - Subtract(RSP, ImmediateVal(totalSize)) + sizes.foreach { itemSize => + stack += stack.size -> StackValue(itemSize, sizeBytes) + } + Subtract(RSP, ImmediateVal(elemBytes * sizes.size)) } /** Pop a value from the stack into a register. Sizes MUST match. */ def pop(dest: Register): AsmLine = { - if (dest.size != stack.last._2.size) { - throw new IllegalArgumentException( - s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register" - ) - } stack.remove(stack.last._1) Pop(dest) } /** Drop the top n values from the stack. */ def drop(n: Int = 1): AsmLine = { - val totalSize = (1 to n) - .map(_ => - val itemSize = stack.last._2.size.toInt - stack.remove(stack.last._1) - itemSize - ) - .sum - Add(RSP, ImmediateVal(totalSize)) + (1 to n).foreach { _ => + stack.remove(stack.last._1) + } + Add(RSP, ImmediateVal(n * elemBytes)) } - /** Get a lazy IndexAddress for a variable in the stack. */ - def accessVar(ident: mw.Ident): () => IndexAddress = () => { - IndexAddress(RSP, stack.size - stack(ident).bottom) + /** Generate AsmLines within a scope, which is reset after the block. */ + def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = { + val resetToSize = stack.size + var lines = block() + lines :+= drop(stack.size - resetToSize) + lines } + + /** Get an IndexAddress for a variable in the stack. */ + def accessVar(ident: mw.Ident): IndexAddress = + IndexAddress(RSP, sizeBytes - stack(ident).bottom) + def contains(ident: mw.Ident): Boolean = stack.contains(ident) - def head: MemLocation = MemLocation(RSP) - def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset)) - // TODO: Might want to actually properly handle this with the LinkedHashMap too - def align(): AsmLine = And(RSP, ImmediateVal(-16)) + def head: MemLocation = MemLocation(RSP, stack.last._2.size) + + override def toString(): String = stack.toString } diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 688e474..cd53b30 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ @@ -8,7 +7,10 @@ import cats.syntax.foldable._ object asmGenerator { import microWacc._ import assemblyIR._ - import wacc.types._ + import assemblyIR.Size._ + import assemblyIR.RegName._ + import types._ + import sizeExtensions._ import lexer.escapedChars abstract case class Error() { @@ -29,26 +31,22 @@ object asmGenerator { def errLabel = ".L._errDivZero" } - val RAX = Register(RegSize.R64, RegName.AX) - val EAX = Register(RegSize.E32, RegName.AX) - val ESP = Register(RegSize.E32, RegName.SP) - val EDX = Register(RegSize.E32, RegName.DX) - val RDI = Register(RegSize.R64, RegName.DI) - val RIP = Register(RegSize.R64, RegName.IP) - val RBP = Register(RegSize.R64, RegName.BP) - val RSI = Register(RegSize.R64, RegName.SI) - val RDX = Register(RegSize.R64, RegName.DX) - val RCX = Register(RegSize.R64, RegName.CX) - val R8 = Register(RegSize.R64, RegName.Reg8) - val R9 = Register(RegSize.R64, RegName.Reg9) - val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) + private val RAX = Register(Q64, AX) + private val EAX = Register(D32, AX) + private val RDI = Register(Q64, DI) + private val RIP = Register(Q64, IP) + private val RBP = Register(Q64, BP) + private val RSI = Register(Q64, SI) + private val RDX = Register(Q64, DX) + private val RCX = Register(Q64, CX) + private val argRegs = List(DI, SI, DX, CX, R8, R9) - val _8_BIT_MASK = 0xff + private val _8_BIT_MASK = 0xff - extension (chain: Chain[AsmLine]) - def +(line: AsmLine): Chain[AsmLine] = chain.append(line) + extension [T](chain: Chain[T]) + def +(item: T): Chain[T] = chain.append(item) - def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = + def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) class LabelGenerator { @@ -63,7 +61,7 @@ object asmGenerator { } } - def generateAsm(microProg: Program): List[AsmLine] = { + def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() @@ -71,7 +69,6 @@ object asmGenerator { val progAsm = Chain(LabelDef("main")).concatAll( funcPrologue(), - Chain.one(stack.align()), main.foldMap(generateStmt(_)), Chain.one(Xor(RAX, RAX)), funcEpilogue(), @@ -95,12 +92,10 @@ object asmGenerator { strDirs, Chain.one(Directive.Text), progAsm - ).toList + ) } - private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using - stack: Stack - ): Chain[AsmLine] = { + private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = { var chain = Chain.one[AsmLine](LabelDef(labelName)) chain ++= funcPrologue() chain ++= funcBody @@ -108,7 +103,7 @@ object asmGenerator { chain } - def generateUserFunc(func: FuncDecl)(using + private def generateUserFunc(func: FuncDecl)(using strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { @@ -119,29 +114,27 @@ object asmGenerator { chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => - chain += stack.push(param, reg) + chain += stack.push(param, Register(Q64, reg)) } chain ++= func.body.foldMap(generateStmt(_)) // No need for epilogue here since all user functions must return explicitly chain } - def generateBuiltInFuncs()(using - stack: Stack, - strings: ListBuffer[String], + private def generateBuiltInFuncs()(using labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Exit), - Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) + Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Printf), Chain( - stack.align(), + stackAlign, assemblyIR.Call(CLibFunc.PrintF), Xor(RDI, RDI), assemblyIR.Call(CLibFunc.Fflush) @@ -151,9 +144,9 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.PrintCharArray), Chain( - stack.align(), - Load(RDX, IndexAddress(RSI, 8)), - Move(RSI, MemLocation(RSI)), + stackAlign, + Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), + Move(Register(D32, SI), MemLocation(RSI, D32)), assemblyIR.Call(CLibFunc.PrintF), Xor(RDI, RDI), assemblyIR.Call(CLibFunc.Fflush) @@ -162,7 +155,7 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), - Chain(stack.align(), assemblyIR.Call(CLibFunc.Malloc)) + Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) // Out of memory check is optional ) @@ -171,13 +164,12 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Read), Chain( - stack.align(), - stack.reserve(), - stack.push(RSI), - Load(RSI, stack.head), + stackAlign, + Subtract(Register(Q64, SP), ImmediateVal(8)), + Push(RSI), + Load(RSI, MemLocation(Register(Q64, SP), Q64)), assemblyIR.Call(CLibFunc.Scanf), - stack.pop(RAX), - stack.drop() + Pop(RAX) ) ) @@ -185,7 +177,7 @@ object asmGenerator { // TODO can this be done with a call to generateStmt? // Consider other error cases -> look to generalise LabelDef(zeroDivError.errLabel), - stack.align(), + stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), @@ -195,53 +187,33 @@ object asmGenerator { chain } - /** Wraps a chain in a stack reset. - * - * This is useful for ensuring that the stack size at the death of scope is the same as the stack - * size at the start of the scope. See branching (If / While) - * - * @param genChain - * Function that generates the scope AsmLines - * @param stack - * The stack to reset - * @return - * The generated scope AsmLines - */ - private def generateScope(genChain: () => Chain[AsmLine])(using - stack: Stack - ): Chain[AsmLine] = { - val stackSizeStart = stack.size - var chain = genChain() - chain += stack.drop(stack.size - stackSizeStart) - chain - } - - def generateStmt(stmt: Stmt)(using + private def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] + chain += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => - lhs match { case ident: Ident => - val dest = stack.accessVar(ident) if (!stack.contains(ident)) chain += stack.reserve(ident) - chain ++= evalExprOntoStack(rhs) - chain += stack.pop(RDX) - chain += Move(dest(), RDX) + chain += stack.pop(RAX) + chain += Move(stack.accessVar(ident), RAX) case ArrayElem(x, i) => - chain ++= evalExprOntoStack(x) - chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(rhs) + chain ++= evalExprOntoStack(i) + chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) chain += stack.pop(RDX) - chain += Move(IndexAddress(RDX, 8, RCX, 8), RAX) + chain += Move( + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt), + Register(x.ty.elemSize, DX) + ) } case If(cond, thenBranch, elseBranch) => @@ -253,11 +225,11 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= generateScope(() => thenBranch.foldMap(generateStmt)) + chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) chain += Jump(LabelArg(endLabel)) chain += LabelDef(elseLabel) - chain ++= generateScope(() => elseBranch.foldMap(generateStmt)) + chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) chain += LabelDef(endLabel) case While(cond, body) => @@ -270,7 +242,7 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= generateScope(() => body.foldMap(generateStmt)) + chain ++= stack.withScope(() => body.foldMap(generateStmt)) chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) @@ -291,7 +263,7 @@ object asmGenerator { chain } - def evalExprOntoStack(expr: Expr)(using + private def evalExprOntoStack(expr: Expr)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -299,111 +271,117 @@ object asmGenerator { var chain = Chain.empty[AsmLine] val stackSizeStart = stack.size expr match { - case IntLiter(v) => chain += stack.push(ImmediateVal(v)) - case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) - case ident: Ident => chain += stack.push(stack.accessVar(ident)()) + case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v)) + case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt)) + case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident)) - case ArrayLiter(elems) => + case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => strings += elems.collect { case CharLiter(v) => v }.mkString chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) - chain += stack.push(RAX) - case _ => + chain += stack.push(Q64, RAX) + case ty => chain ++= generateCall( - microWacc.Call(Builtin.Malloc, List(IntLiter((elems.size + 1) * 8))), + microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), isTail = false ) - chain += stack.push(RAX) + chain += stack.push(Q64, RAX) // Store the length of the array at the start - chain += Move(MemLocation(RAX, SizeDir.DWord), ImmediateVal(elems.size)) + chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size)) elems.zipWithIndex.foldMap { (elem, i) => chain ++= evalExprOntoStack(elem) chain += stack.pop(RCX) chain += stack.pop(RAX) - chain += Move(IndexAddress(RAX, 8 * (i + 1)), RCX) - chain += stack.push(RAX) + chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) + chain += stack.push(Q64, RAX) } } - case BoolLiter(true) => chain += stack.push(ImmediateVal(1)) + case BoolLiter(true) => + chain += stack.push(KnownType.Bool.size, ImmediateVal(1)) case BoolLiter(false) => chain += Xor(RAX, RAX) - chain += stack.push(RAX) - case NullLiter() => chain += stack.push(ImmediateVal(0)) + chain += stack.push(KnownType.Bool.size, RAX) + case NullLiter() => + chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0)) case ArrayElem(x, i) => chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += stack.pop(RAX) - // + 1 because we store the length of the array at the start - chain += stack.push(IndexAddress(RAX, 8, RCX, 8)) + // + Int because we store the length of the array at the start + chain += Move( + Register(x.ty.elemSize, AX), + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt) + ) + chain += stack.push(x.ty.elemSize, RAX) case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed - case UnaryOperator.Len => - // Access the elem + case UnaryOperator.Len => chain += stack.pop(RAX) - chain += Push(MemLocation(RAX)) - case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) + chain += Move(EAX, MemLocation(RAX, D32)) + chain += stack.push(D32, RAX) + case UnaryOperator.Negate => + chain += Negate(stack.head) case UnaryOperator.Not => - chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) + chain += Xor(stack.head, ImmediateVal(1)) } case BinaryOp(x, y, op) => + val destX = Register(x.ty.size, AX) chain ++= evalExprOntoStack(y) chain ++= evalExprOntoStack(x) - chain += stack.pop(RAX) op match { - case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Add => + chain += Add(stack.head, destX) case BinaryOperator.Sub => - chain += Subtract(EAX, stack.head(SizeDir.DWord)) + chain += Subtract(destX, stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => - chain += Multiply(EAX, stack.head(SizeDir.DWord)) + chain += Multiply(destX, stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Div => - chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0)) + chain += Compare(stack.head, ImmediateVal(0)) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += CDQ() - chain += Divide(stack.head(SizeDir.DWord)) + chain += Divide(stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Mod => chain += CDQ() - chain += Divide(stack.head(SizeDir.DWord)) + chain += Divide(stack.head) chain += stack.drop() - chain += stack.push(RDX) + chain += stack.push(destX.size, RDX) - case BinaryOperator.Eq => chain ++= generateComparison(Cond.Equal) - case BinaryOperator.Neq => chain ++= generateComparison(Cond.NotEqual) - case BinaryOperator.Greater => chain ++= generateComparison(Cond.Greater) - case BinaryOperator.GreaterEq => chain ++= generateComparison(Cond.GreaterEqual) - case BinaryOperator.Less => chain ++= generateComparison(Cond.Less) - case BinaryOperator.LessEq => chain ++= generateComparison(Cond.LessEqual) - case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) - case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual) + case BinaryOperator.And => chain += And(stack.head, destX) + case BinaryOperator.Or => chain += Or(stack.head, destX) } case call: microWacc.Call => chain ++= generateCall(call, isTail = false) - chain += stack.push(RAX) + chain += stack.push(call.ty.size, RAX) } - if chain.isEmpty then chain += stack.push(ImmediateVal(0)) - assert(stack.size == stackSizeStart + 1) chain } - def generateCall(call: microWacc.Call, isTail: Boolean)(using + private def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -413,7 +391,7 @@ object asmGenerator { argRegs.zip(args).foldMap { (reg, expr) => chain ++= evalExprOntoStack(expr) - chain += stack.pop(reg) + chain += stack.pop(Register(Q64, reg)) } args.drop(argRegs.size).foldMap { @@ -434,77 +412,36 @@ object asmGenerator { chain } - def generateComparison(cond: Cond)(using - stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator + private def generateComparison(destX: Register, cond: Cond)(using + stack: Stack ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Compare(EAX, stack.head(SizeDir.DWord)) - chain += Set(Register(RegSize.Byte, RegName.AL), cond) + chain += Compare(destX, stack.head) + chain += Set(Register(B8, AX), cond) chain += And(RAX, ImmediateVal(_8_BIT_MASK)) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(B8, RAX) chain } - // Missing a sub instruction but dont think we need it - def funcPrologue()(using stack: Stack): Chain[AsmLine] = { + private def funcPrologue(): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += stack.push(RBP) - chain += Move(RBP, Register(RegSize.R64, RegName.SP)) + chain += Push(RBP) + chain += Move(RBP, Register(Q64, SP)) chain } - def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { + private def funcEpilogue(): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Move(Register(RegSize.R64, RegName.SP), RBP) + chain += Move(Register(Q64, SP), RBP) chain += Pop(RBP) chain += assemblyIR.Return() chain } - class Stack { - private val stack = LinkedHashMap[Expr | Int, Int]() - private val RSP = Register(RegSize.R64, RegName.SP) - - private def next: Int = stack.size + 1 - def size: Int = stack.size - def push(expr: Expr, src: Src): AsmLine = { - stack += expr -> next - Push(src) - } - def push(src: Src): AsmLine = { - stack += stack.size -> next - Push(src) - } - def pop(dest: Src): AsmLine = { - stack.remove(stack.last._1) - Pop(dest) - } - def reserve(ident: Ident): AsmLine = { - stack += ident -> next - Subtract(RSP, ImmediateVal(8)) - } - def reserve(n: Int = 1): AsmLine = { - (1 to n).foreach(_ => stack += stack.size -> next) - Subtract(RSP, ImmediateVal(n * 8)) - } - def drop(n: Int = 1): AsmLine = { - (1 to n).foreach(_ => stack.remove(stack.last._1)) - Add(RSP, ImmediateVal(n * 8)) - } - def accessVar(ident: Ident): () => IndexAddress = () => { - IndexAddress(RSP, (stack.size - stack(ident)) * 8) - } - def head: MemLocation = MemLocation(RSP) - def head(size: SizeDir): MemLocation = MemLocation(RSP, size) - def contains(ident: Ident): Boolean = stack.contains(ident) - // TODO: Might want to actually properly handle this with the LinkedHashMap too - def align(): AsmLine = And(RSP, ImmediateVal(-16)) - } + private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } extension (s: String) { diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 2946dcb..fbf51f5 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -97,11 +97,9 @@ object assemblyIR { } } - case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src { - def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize)) - + case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src { override def toString = - opSize.getOrElse("").toString + s"[$pointer]" + opSize.toString + s"[$pointer]" } // TODO to string is wacky diff --git a/src/main/wacc/backend/sizeExtensions.scala b/src/main/wacc/backend/sizeExtensions.scala index 59d3930..798e290 100644 --- a/src/main/wacc/backend/sizeExtensions.scala +++ b/src/main/wacc/backend/sizeExtensions.scala @@ -11,8 +11,8 @@ object sizeExtensions { def heapSize: Int = (expr, expr.ty) match { case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt - case (ArrayLiter(elems), _) => - KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum + case (ArrayLiter(elems), ty) => + KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt case _ => expr.ty.size.toInt } } @@ -25,5 +25,11 @@ object sizeExtensions { case KnownType.Bool | KnownType.Char => Size.B8 case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 } + + def elemSize: Size = ty match { + case KnownType.Array(elem) => elem.size + case KnownType.Pair(_, _) => Size.Q64 + case _ => ty.size + } } } diff --git a/src/main/wacc/backend/writer.scala b/src/main/wacc/backend/writer.scala index b798af3..3c8dcfd 100644 --- a/src/main/wacc/backend/writer.scala +++ b/src/main/wacc/backend/writer.scala @@ -1,11 +1,12 @@ package wacc import java.io.PrintStream +import cats.data.Chain object writer { import assemblyIR._ - def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = { - asmList.foreach(printStream.println) + def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { + asmList.iterator.foreach(printStream.println) } } diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index 2c430e5..002876d 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -223,10 +223,10 @@ object typeChecker { case KnownType.Int => "%d" case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" } - val printfCall = { (value: microWacc.Expr) => + val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) => List( microWacc.Call( - microWacc.Builtin.Printf, + func, List( s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, value @@ -239,11 +239,13 @@ object typeChecker { List( microWacc.If( exprTyped, - printfCall("true".toMicroWaccCharArray), - printfCall("false".toMicroWaccCharArray) + printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), + printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) ) ) - case _ => printfCall(exprTyped) + case KnownType.Array(KnownType.Char) => + printfCall(microWacc.Builtin.PrintCharArray, exprTyped) + case _ => printfCall(microWacc.Builtin.Printf, exprTyped) } case ast.If(cond, thenStmt, elseStmt) => List( diff --git a/src/test/wacc/instructionSpec.scala b/src/test/wacc/instructionSpec.scala index b7452a0..feef0d4 100644 --- a/src/test/wacc/instructionSpec.scala +++ b/src/test/wacc/instructionSpec.scala @@ -1,42 +1,38 @@ import org.scalatest.funsuite.AnyFunSuite import wacc.assemblyIR._ +import wacc.assemblyIR.Size._ +import wacc.assemblyIR.RegName._ class instructionSpec extends AnyFunSuite { - val named64BitRegister = Register(RegSize.R64, RegName.AX) + val named64BitRegister = Register(Q64, AX) test("named 64-bit register toString") { assert(named64BitRegister.toString == "rax") } - val named32BitRegister = Register(RegSize.E32, RegName.AX) + val named32BitRegister = Register(D32, AX) test("named 32-bit register toString") { assert(named32BitRegister.toString == "eax") } - val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8) + val scratch64BitRegister = Register(Q64, R8) test("scratch 64-bit register toString") { assert(scratch64BitRegister.toString == "r8") } - val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8) + val scratch32BitRegister = Register(D32, R8) test("scratch 32-bit register toString") { - assert(scratch32BitRegister.toString == "e8") + assert(scratch32BitRegister.toString == "r8d") } - val memLocationWithHex = MemLocation(0x12345678) - - test("mem location with hex toString") { - assert(memLocationWithHex.toString == "[0x12345678]") - } - - val memLocationWithRegister = MemLocation(named64BitRegister) + val memLocationWithRegister = MemLocation(named64BitRegister, Q64) test("mem location with register toString") { - assert(memLocationWithRegister.toString == "[rax]") + assert(memLocationWithRegister.toString == "qword ptr [rax]") } val immediateVal = ImmediateVal(123) From 507cb7dd9bb0066cbb471b78f88cdf177065025e Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 15:46:01 +0000 Subject: [PATCH 52/54] fix: zero-out sub-32 bit expressions --- src/main/wacc/backend/asmGenerator.scala | 9 ++++++--- src/test/wacc/examples.scala | 5 ++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index cd53b30..8dc67f6 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -41,8 +41,6 @@ object asmGenerator { private val RCX = Register(Q64, CX) private val argRegs = List(DI, SI, DX, CX, R8, R9) - private val _8_BIT_MASK = 0xff - extension [T](chain: Chain[T]) def +(item: T): Chain[T] = chain.append(item) @@ -378,6 +376,7 @@ object asmGenerator { } assert(stack.size == stackSizeStart + 1) + chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size) chain } @@ -419,7 +418,7 @@ object asmGenerator { chain += Compare(destX, stack.head) chain += Set(Register(B8, AX), cond) - chain += And(RAX, ImmediateVal(_8_BIT_MASK)) + chain ++= zeroRest(RAX, B8) chain += stack.drop() chain += stack.push(B8, RAX) @@ -442,6 +441,10 @@ object asmGenerator { } private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) + private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { + case Q64 | D32 => Chain.empty + case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1))) + } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } extension (s: String) { diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 87def2a..988a6d0 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -93,14 +93,13 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$", // "^.*wacc-examples/valid/if.*$", - "^.*wacc-examples/valid/IO/print.*$", + // "^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/read.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - "^.*wacc-examples/valid/pairs.*$", + // "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", // "^.*wacc-examples/valid/scope.*$", - "^.*wacc-examples/valid/scope/printAllTypes.wacc$", // while we still don't have arrays implemented // "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/variables.*$", // "^.*wacc-examples/valid/while.*$", From c472c7a62cb39a7e46b8fd880405e4731d9ec60e Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 16:02:39 +0000 Subject: [PATCH 53/54] fix: reserve return pointer and RBP on stack for user func bodies --- src/main/wacc/backend/Stack.scala | 2 +- src/main/wacc/backend/asmGenerator.scala | 10 +++++++--- src/test/wacc/examples.scala | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala index 0949ecb..94b329a 100644 --- a/src/main/wacc/backend/Stack.scala +++ b/src/main/wacc/backend/Stack.scala @@ -50,7 +50,7 @@ class Stack { * @param sizes * The sizes of the values to reserve space for. */ - def reserve(sizes: List[Size]): AsmLine = { + def reserve(sizes: Size*): AsmLine = { sizes.foreach { itemSize => stack += stack.size -> StackValue(itemSize, sizeBytes) } diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 8dc67f6..60e0b47 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -93,7 +93,9 @@ object asmGenerator { ) } - private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = { + private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using + stack: Stack + ): Chain[AsmLine] = { var chain = Chain.one[AsmLine](LabelDef(labelName)) chain ++= funcPrologue() chain ++= funcBody @@ -108,6 +110,7 @@ object asmGenerator { given stack: Stack = Stack() // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) + stack.reserve(Q64) // Reserve return pointer slot var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity @@ -120,6 +123,7 @@ object asmGenerator { } private def generateBuiltInFuncs()(using + stack: Stack, labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] @@ -425,9 +429,9 @@ object asmGenerator { chain } - private def funcPrologue(): Chain[AsmLine] = { + private def funcPrologue()(using stack: Stack): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Push(RBP) + chain += stack.push(Q64, RBP) chain += Move(RBP, Register(Q64, SP)) chain } diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 988a6d0..01e4c30 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -90,8 +90,8 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // "^.*wacc-examples/valid/basic/exit.*$", // "^.*wacc-examples/valid/basic/skip.*$", // "^.*wacc-examples/valid/expressions.*$", - "^.*wacc-examples/valid/function/nested_functions.*$", - "^.*wacc-examples/valid/function/simple_functions.*$", + // "^.*wacc-examples/valid/function/nested_functions.*$", + // "^.*wacc-examples/valid/function/simple_functions.*$", // "^.*wacc-examples/valid/if.*$", // "^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/read.*$", From c0f2473db1710d7f93e57b4a38a99f330fc7e4f4 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 18:25:12 +0000 Subject: [PATCH 54/54] test: fix input handling for IOLoop example --- src/test/wacc/examples.scala | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 01e4c30..7c895ab 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -39,7 +39,26 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // Retrieve contents to get input and expected output + exit code val contents = scala.io.Source.fromFile(File(filename)).getLines.toList val inputLine = - contents.find(_.matches("^# ?[Ii]nput:.*$")).map(_.split(":").last.strip).getOrElse("") + contents + .find(_.matches("^# ?[Ii]nput:.*$")) + .map(line => + ("" :: line.split(":").last.strip.split(" ").toList) + .sliding(2) + .flatMap { arr => + if ( + // First entry has no space in front + arr(0) == "" || + // int followed by non-digit, space can be removed + arr(0).toIntOption.nonEmpty && !arr(1)(0).isDigit || + // non-int followed by int, space can be removed + !arr(0).last.isDigit && arr(1).toIntOption.nonEmpty + ) + then List(arr(1)) + else List(" ", arr(1)) + } + .mkString + ) + .getOrElse("") val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) val expectedOutput = if (outputLineIdx == -1) "" @@ -95,7 +114,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // "^.*wacc-examples/valid/if.*$", // "^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/read.*$", - "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", + // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", // "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$",