From f2a1eaf24cacc5927237d9385f7059d8efd8bab6 Mon Sep 17 00:00:00 2001 From: Guy C Date: Fri, 28 Feb 2025 11:37:30 +0000 Subject: [PATCH 1/8] refactor: reorganize operation classes --- src/main/wacc/backend/assemblyIR.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index fbf51f5..e8e7e62 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -102,7 +102,6 @@ object assemblyIR { opSize.toString + s"[$pointer]" } - // TODO to string is wacky case class IndexAddress( base: Register, offset: Int | LabelArg, @@ -125,36 +124,37 @@ object assemblyIR { override def toString = name } - // TODO Check if dest and src are not both memory locations abstract class Operation(ins: String, ops: Operand*) extends AsmLine { override def toString: String = s"\t$ins ${ops.mkString(", ")}" } + + // arithmetic operations case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2) case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2) case class Multiply(ops: Operand*) extends Operation("imul", ops*) case class Divide(op1: Src) extends Operation("idiv", op1) case class Negate(op: Dest) extends Operation("neg", op) - + // bitwise operations case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2) case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2) case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2) case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2) - + case class CDQ() extends Operation("cdq") // stack operations case class Push(op1: Src) extends Operation("push", op1) case class Pop(op1: Src) extends Operation("pop", op1) - case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1) - + // move operations case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2) case class Load(op1: Register, op2: MemLocation | IndexAddress) extends Operation("lea ", op1, op2) - case class CDQ() extends Operation("cdq") + // function call operations + case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1) case class Return() extends Operation("ret") + // conditional operations case class Jump(op1: LabelArg, condition: Cond = Cond.Always) extends Operation(s"j${condition.toString}", op1) - case class Set(op1: Dest, condition: Cond = Cond.Always) extends Operation(s"set${condition.toString}", op1) From b733d233b03c4f2202f7af1302796f19e5031591 Mon Sep 17 00:00:00 2001 From: Guy C Date: Fri, 28 Feb 2025 11:58:38 +0000 Subject: [PATCH 2/8] feat: implements outofmemoryerror handling --- src/main/wacc/backend/RuntimeError.scala | 18 +++++++++++++++++- src/main/wacc/backend/asmGenerator.scala | 9 +++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 9085b63..e6e972c 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -117,6 +117,22 @@ object RuntimeError { ) } + case object OutOfMemoryError extends RuntimeError { + val strLabel = ".L._errOutOfMemory_str0" + val errStr = "fatal error: out of memory" + val errLabel = ".L._errOutOfMemory" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(OutOfMemoryError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(OutOfMemoryError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + } + val all: Chain[RuntimeError] = - Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError) + Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError, + OutOfMemoryError) } diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index c900756..10542f9 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -141,8 +141,13 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), - Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) - // Out of memory check is optional + Chain( + stackAlign, + assemblyIR.Call(CLibFunc.Malloc), + // Out of memory check + Compare(RAX, ImmediateVal(0)), + Jump(LabelArg(OutOfMemoryError.errLabel), Cond.Equal) + ) ) chain ++= wrapBuiltinFunc( From 30f4309fda8652c8582b6f6fa85b308e82aa1d7b Mon Sep 17 00:00:00 2001 From: Guy C Date: Fri, 28 Feb 2025 12:14:55 +0000 Subject: [PATCH 3/8] feat: use errorcode constant in runtime errors --- src/main/wacc/backend/RuntimeError.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index e6e972c..934077a 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -3,6 +3,8 @@ package wacc import cats.data.Chain import wacc.assemblyIR._ +val ERROR_CODE = 255 + sealed trait RuntimeError { def strLabel: String def errStr: String @@ -33,7 +35,7 @@ object RuntimeError { // private val RBP = Register(Q64, BP) private val RSI = Register(Q64, SI) // private val RDX = Register(Q64, DX) - // private val RCX = Register(Q64, CX) + private val RCX = Register(Q64, CX) case object ZeroDivError extends RuntimeError { val strLabel = ".L._errDivZero_str0" @@ -62,7 +64,7 @@ object RuntimeError { stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), + Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) ) @@ -78,7 +80,7 @@ object RuntimeError { stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), + Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) ) @@ -94,7 +96,7 @@ object RuntimeError { stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), + Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) ) @@ -108,11 +110,11 @@ object RuntimeError { def generateHandler: Chain[AsmLine] = Chain( LabelDef(OutOfBoundsError.errLabel), - Move(RSI, Register(Q64, CX)), + Move(RSI, RCX), stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), + Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) ) } @@ -127,7 +129,7 @@ object RuntimeError { stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(OutOfMemoryError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), + Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) ) } From 302099ab760c5e1baa6e360c56078dbf3d62ea42 Mon Sep 17 00:00:00 2001 From: Guy C Date: Fri, 28 Feb 2025 12:23:09 +0000 Subject: [PATCH 4/8] refactor: removes magic numbers in asmgenerator --- src/main/wacc/backend/RuntimeError.scala | 7 +------ src/main/wacc/backend/asmGenerator.scala | 4 +++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 934077a..8159485 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -3,8 +3,6 @@ package wacc import cats.data.Chain import wacc.assemblyIR._ -val ERROR_CODE = 255 - sealed trait RuntimeError { def strLabel: String def errStr: String @@ -28,14 +26,11 @@ object RuntimeError { import assemblyIR.Size._ import assemblyIR.RegName._ - // private val RAX = Register(Q64, AX) - // private val EAX = Register(D32, AX) private val RDI = Register(Q64, DI) private val RIP = Register(Q64, IP) - // private val RBP = Register(Q64, BP) private val RSI = Register(Q64, SI) - // private val RDX = Register(Q64, DX) private val RCX = Register(Q64, CX) + private val ERROR_CODE = 255 case object ZeroDivError extends RuntimeError { val strLabel = ".L._errDivZero_str0" diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 10542f9..2d3fcf0 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -25,6 +25,8 @@ object asmGenerator { private val ECX = Register(D32, CX) private val argRegs = List(DI, SI, DX, CX, R8, R9) + private val _7_BIT_MASK = 0x7f + extension [T](chain: Chain[T]) def +(item: T): Chain[T] = chain.append(item) @@ -326,7 +328,7 @@ object asmGenerator { op match { case UnaryOperator.Chr => chain += Move(EAX, stack.head) - chain += And(EAX, ImmediateVal(-128)) + chain += And(EAX, ImmediateVal(~_7_BIT_MASK)) chain += Compare(EAX, ImmediateVal(0)) chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) case UnaryOperator.Ord => // No op needed From 621849dfa4f7501c172e2e5078a532642cb41418 Mon Sep 17 00:00:00 2001 From: Jonny Date: Fri, 28 Feb 2025 13:55:02 +0000 Subject: [PATCH 5/8] refactor: rename local builder chain to asm --- src/main/wacc/backend/asmGenerator.scala | 336 +++++++++++------------ 1 file changed, 168 insertions(+), 168 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 2d3fcf0..0d67a0f 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -82,11 +82,11 @@ object asmGenerator { private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using stack: Stack ): Chain[AsmLine] = { - var chain = Chain.one[AsmLine](LabelDef(labelName)) - chain ++= funcPrologue() - chain ++= funcBody - chain ++= funcEpilogue() - chain + var asm = Chain.one[AsmLine](LabelDef(labelName)) + asm ++= funcPrologue() + asm ++= funcBody + asm ++= funcEpilogue() + asm } private def generateUserFunc(func: FuncDecl)(using @@ -97,29 +97,29 @@ object asmGenerator { // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) stack.reserve(Q64) // Reserve return pointer slot - var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) - chain ++= funcPrologue() + var asm = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) + asm ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => - chain += stack.push(param, Register(Q64, reg)) + asm += stack.push(param, Register(Q64, reg)) } - chain ++= func.body.foldMap(generateStmt(_)) + asm ++= func.body.foldMap(generateStmt(_)) // No need for epilogue here since all user functions must return explicitly - chain + asm } private def generateBuiltInFuncs()(using stack: Stack, labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] + var asm = Chain.empty[AsmLine] - chain ++= wrapBuiltinFunc( + asm ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Exit), Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) - chain ++= wrapBuiltinFunc( + asm ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Printf), Chain( stackAlign, @@ -129,7 +129,7 @@ object asmGenerator { ) ) - chain ++= wrapBuiltinFunc( + asm ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.PrintCharArray), Chain( stackAlign, @@ -141,7 +141,7 @@ object asmGenerator { ) ) - chain ++= wrapBuiltinFunc( + asm ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), Chain( stackAlign, @@ -152,7 +152,7 @@ object asmGenerator { ) ) - chain ++= wrapBuiltinFunc( + asm ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Free), Chain( stackAlign, @@ -163,7 +163,7 @@ object asmGenerator { ) ) - chain ++= wrapBuiltinFunc( + asm ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Read), Chain( stackAlign, @@ -175,9 +175,9 @@ object asmGenerator { ) ) - chain ++= RuntimeError.all.foldMap(_.generateHandler) + asm ++= RuntimeError.all.foldMap(_.generateHandler) - chain + asm } private def generateStmt(stmt: Stmt)(using @@ -185,33 +185,33 @@ object asmGenerator { strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] - chain += Comment(stmt.toString) + var asm = Chain.empty[AsmLine] + asm += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => lhs match { case ident: Ident => - if (!stack.contains(ident)) chain += stack.reserve(ident) - chain ++= evalExprOntoStack(rhs) - chain += stack.pop(RAX) - chain += Move(stack.accessVar(ident), RAX) + if (!stack.contains(ident)) asm += stack.reserve(ident) + asm ++= evalExprOntoStack(rhs) + asm += stack.pop(RAX) + asm += Move(stack.accessVar(ident), RAX) case ArrayElem(x, i) => - chain ++= evalExprOntoStack(rhs) - chain ++= evalExprOntoStack(i) - chain += stack.pop(RCX) - chain += Compare(ECX, ImmediateVal(0)) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) - chain += stack.push(Q64, RCX) - chain ++= evalExprOntoStack(x) - chain += stack.pop(RAX) - chain += stack.pop(RCX) - chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) - chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) - chain += stack.pop(RDX) + asm ++= evalExprOntoStack(rhs) + asm ++= evalExprOntoStack(i) + asm += stack.pop(RCX) + asm += Compare(ECX, ImmediateVal(0)) + asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + asm += stack.push(Q64, RCX) + asm ++= evalExprOntoStack(x) + asm += stack.pop(RAX) + asm += stack.pop(RCX) + asm += Compare(EAX, ImmediateVal(0)) + asm += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + asm += Compare(MemLocation(RAX, D32), ECX) + asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + asm += stack.pop(RDX) - chain += Move( + asm += Move( IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt), Register(x.ty.elemSize, DX) ) @@ -221,47 +221,47 @@ object asmGenerator { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - chain ++= evalExprOntoStack(cond) - chain += stack.pop(RAX) - chain += Compare(RAX, ImmediateVal(0)) - chain += Jump(LabelArg(elseLabel), Cond.Equal) + asm ++= evalExprOntoStack(cond) + asm += stack.pop(RAX) + asm += Compare(RAX, ImmediateVal(0)) + asm += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) - chain += Jump(LabelArg(endLabel)) - chain += LabelDef(elseLabel) + asm ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) + asm += Jump(LabelArg(endLabel)) + asm += LabelDef(elseLabel) - chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) - chain += LabelDef(endLabel) + asm ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) + asm += LabelDef(endLabel) case While(cond, body) => val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - chain += LabelDef(startLabel) - chain ++= evalExprOntoStack(cond) - chain += stack.pop(RAX) - chain += Compare(RAX, ImmediateVal(0)) - chain += Jump(LabelArg(endLabel), Cond.Equal) + asm += LabelDef(startLabel) + asm ++= evalExprOntoStack(cond) + asm += stack.pop(RAX) + asm += Compare(RAX, ImmediateVal(0)) + asm += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= stack.withScope(() => body.foldMap(generateStmt)) - chain += Jump(LabelArg(startLabel)) - chain += LabelDef(endLabel) + asm ++= stack.withScope(() => body.foldMap(generateStmt)) + asm += Jump(LabelArg(startLabel)) + asm += LabelDef(endLabel) case call: microWacc.Call => - chain ++= generateCall(call, isTail = false) + asm ++= generateCall(call, isTail = false) case microWacc.Return(expr) => expr match { case call: microWacc.Call => - chain ++= generateCall(call, isTail = true) // tco + asm ++= generateCall(call, isTail = true) // tco case _ => - chain ++= evalExprOntoStack(expr) - chain += stack.pop(RAX) - chain ++= funcEpilogue() + asm ++= evalExprOntoStack(expr) + asm += stack.pop(RAX) + asm ++= funcEpilogue() } } - chain + asm } private def evalExprOntoStack(expr: Expr)(using @@ -269,138 +269,138 @@ object asmGenerator { strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] + var asm = Chain.empty[AsmLine] val stackSizeStart = stack.size expr match { - case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v)) - case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt)) - case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident)) + case IntLiter(v) => asm += stack.push(KnownType.Int.size, ImmediateVal(v)) + case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt)) + case ident: Ident => asm += stack.push(ident.ty.size, stack.accessVar(ident)) case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => strings += elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) - chain += stack.push(Q64, RAX) + asm += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + asm += stack.push(Q64, RAX) case ty => - chain ++= generateCall( + asm ++= generateCall( microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), isTail = false ) - chain += stack.push(Q64, RAX) + asm += stack.push(Q64, RAX) // Store the length of the array at the start - chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size)) + asm += Move(MemLocation(RAX, D32), ImmediateVal(elems.size)) elems.zipWithIndex.foldMap { (elem, i) => - chain ++= evalExprOntoStack(elem) - chain += stack.pop(RCX) - chain += stack.pop(RAX) - chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) - chain += stack.push(Q64, RAX) + asm ++= evalExprOntoStack(elem) + asm += stack.pop(RCX) + asm += stack.pop(RAX) + asm += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) + asm += stack.push(Q64, RAX) } } case BoolLiter(true) => - chain += stack.push(KnownType.Bool.size, ImmediateVal(1)) + asm += stack.push(KnownType.Bool.size, ImmediateVal(1)) case BoolLiter(false) => - chain += Xor(RAX, RAX) - chain += stack.push(KnownType.Bool.size, RAX) + asm += Xor(RAX, RAX) + asm += stack.push(KnownType.Bool.size, RAX) case NullLiter() => - chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0)) + asm += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0)) case ArrayElem(x, i) => - chain ++= evalExprOntoStack(x) - chain ++= evalExprOntoStack(i) - chain += stack.pop(RCX) - chain += Compare(RCX, ImmediateVal(0)) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) - chain += stack.pop(RAX) - chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) - chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + asm ++= evalExprOntoStack(x) + asm ++= evalExprOntoStack(i) + asm += stack.pop(RCX) + asm += Compare(RCX, ImmediateVal(0)) + asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + asm += stack.pop(RAX) + asm += Compare(EAX, ImmediateVal(0)) + asm += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + asm += Compare(MemLocation(RAX, D32), ECX) + asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) // + Int because we store the length of the array at the start - chain += Move( + asm += Move( Register(x.ty.elemSize, AX), IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt) ) - chain += stack.push(x.ty.elemSize, RAX) + asm += stack.push(x.ty.elemSize, RAX) case UnaryOp(x, op) => - chain ++= evalExprOntoStack(x) + asm ++= evalExprOntoStack(x) op match { case UnaryOperator.Chr => - chain += Move(EAX, stack.head) - chain += And(EAX, ImmediateVal(~_7_BIT_MASK)) - chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) + asm += Move(EAX, stack.head) + asm += And(EAX, ImmediateVal(~_7_BIT_MASK)) + asm += Compare(EAX, ImmediateVal(0)) + asm += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => - chain += stack.pop(RAX) - chain += Move(EAX, MemLocation(RAX, D32)) - chain += stack.push(D32, RAX) + asm += stack.pop(RAX) + asm += Move(EAX, MemLocation(RAX, D32)) + asm += stack.push(D32, RAX) case UnaryOperator.Negate => - chain += Xor(EAX, EAX) - chain += Subtract(EAX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) - chain += stack.drop() - chain += stack.push(Q64, RAX) + asm += Xor(EAX, EAX) + asm += Subtract(EAX, stack.head) + asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += stack.drop() + asm += stack.push(Q64, RAX) case UnaryOperator.Not => - chain += Xor(stack.head, ImmediateVal(1)) + asm += Xor(stack.head, ImmediateVal(1)) } case BinaryOp(x, y, op) => val destX = Register(x.ty.size, AX) - chain ++= evalExprOntoStack(y) - chain ++= evalExprOntoStack(x) - chain += stack.pop(RAX) + asm ++= evalExprOntoStack(y) + asm ++= evalExprOntoStack(x) + asm += stack.pop(RAX) op match { case BinaryOperator.Add => - chain += Add(stack.head, destX) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += Add(stack.head, destX) + asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) case BinaryOperator.Sub => - chain += Subtract(destX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) - chain += stack.drop() - chain += stack.push(destX.size, RAX) + asm += Subtract(destX, stack.head) + asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += stack.drop() + asm += stack.push(destX.size, RAX) case BinaryOperator.Mul => - chain += Multiply(destX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) - chain += stack.drop() - chain += stack.push(destX.size, RAX) + asm += Multiply(destX, stack.head) + asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += stack.drop() + asm += stack.push(destX.size, RAX) case BinaryOperator.Div => - chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) - chain += CDQ() - chain += Divide(stack.head) - chain += stack.drop() - chain += stack.push(destX.size, RAX) + asm += Compare(stack.head, ImmediateVal(0)) + asm += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + asm += CDQ() + asm += Divide(stack.head) + asm += stack.drop() + asm += stack.push(destX.size, RAX) case BinaryOperator.Mod => - chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) - chain += CDQ() - chain += Divide(stack.head) - chain += stack.drop() - chain += stack.push(destX.size, RDX) + asm += Compare(stack.head, ImmediateVal(0)) + asm += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + asm += CDQ() + asm += Divide(stack.head) + asm += stack.drop() + asm += stack.push(destX.size, RDX) - case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal) - case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual) - case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater) - case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual) - case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less) - case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual) - case BinaryOperator.And => chain += And(stack.head, destX) - case BinaryOperator.Or => chain += Or(stack.head, destX) + case BinaryOperator.Eq => asm ++= generateComparison(destX, Cond.Equal) + case BinaryOperator.Neq => asm ++= generateComparison(destX, Cond.NotEqual) + case BinaryOperator.Greater => asm ++= generateComparison(destX, Cond.Greater) + case BinaryOperator.GreaterEq => asm ++= generateComparison(destX, Cond.GreaterEqual) + case BinaryOperator.Less => asm ++= generateComparison(destX, Cond.Less) + case BinaryOperator.LessEq => asm ++= generateComparison(destX, Cond.LessEqual) + case BinaryOperator.And => asm += And(stack.head, destX) + case BinaryOperator.Or => asm += Or(stack.head, destX) } case call: microWacc.Call => - chain ++= generateCall(call, isTail = false) - chain += stack.push(call.ty.size, RAX) + asm ++= generateCall(call, isTail = false) + asm += stack.push(call.ty.size, RAX) } assert(stack.size == stackSizeStart + 1) - chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size) - chain + asm ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size) + asm } private def generateCall(call: microWacc.Call, isTail: Boolean)(using @@ -408,65 +408,65 @@ object asmGenerator { strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] + var asm = Chain.empty[AsmLine] val microWacc.Call(target, args) = call argRegs .zip(args) .map { (reg, expr) => - chain ++= evalExprOntoStack(expr) + asm ++= evalExprOntoStack(expr) reg } .reverse .foreach { reg => - chain += stack.pop(Register(Q64, reg)) + asm += stack.pop(Register(Q64, reg)) } args.drop(argRegs.size).foldMap { - chain ++= evalExprOntoStack(_) + asm ++= evalExprOntoStack(_) } // Tail Call Optimisation (TCO) if (isTail) { - chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + asm += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call } else { - chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + asm += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call } if (args.size > argRegs.size) { - chain += stack.drop(args.size - argRegs.size) + asm += stack.drop(args.size - argRegs.size) } - chain + asm } private def generateComparison(destX: Register, cond: Cond)(using stack: Stack ): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] + var asm = Chain.empty[AsmLine] - chain += Compare(destX, stack.head) - chain += Set(Register(B8, AX), cond) - chain ++= zeroRest(RAX, B8) - chain += stack.drop() - chain += stack.push(B8, RAX) + asm += Compare(destX, stack.head) + asm += Set(Register(B8, AX), cond) + asm ++= zeroRest(RAX, B8) + asm += stack.drop() + asm += stack.push(B8, RAX) - chain + asm } private def funcPrologue()(using stack: Stack): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] - chain += stack.push(Q64, RBP) - chain += Move(RBP, Register(Q64, SP)) - chain + var asm = Chain.empty[AsmLine] + asm += stack.push(Q64, RBP) + asm += Move(RBP, Register(Q64, SP)) + asm } private def funcEpilogue(): Chain[AsmLine] = { - var chain = Chain.empty[AsmLine] - chain += Move(Register(Q64, SP), RBP) - chain += Pop(RBP) - chain += assemblyIR.Return() - chain + var asm = Chain.empty[AsmLine] + asm += Move(Register(Q64, SP), RBP) + asm += Pop(RBP) + asm += assemblyIR.Return() + asm } def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) From d0a71c188851ae12d3ec3d59a91d3956c4ad7db8 Mon Sep 17 00:00:00 2001 From: Jonny Date: Fri, 28 Feb 2025 14:07:50 +0000 Subject: [PATCH 6/8] docs: add doc for concatall chain extension --- src/main/wacc/backend/asmGenerator.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0d67a0f..3c4ae64 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -30,6 +30,14 @@ object asmGenerator { extension [T](chain: Chain[T]) def +(item: T): Chain[T] = chain.append(item) + /** Concatenates multiple `Chain[T]` instances into a single `Chain[T]`, appending them to the + * current `Chain`. + * + * @param chains + * A variable number of `Chain[T]` instances to concatenate. + * @return + * A new `Chain[T]` containing all elements from `chain` concatenated with `chains`. + */ def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) From c3f2ce8b197de1bf16e32e01be63f3e51f488261 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 28 Feb 2025 14:47:47 +0000 Subject: [PATCH 7/8] refactor: single definition for common registers --- src/main/wacc/backend/RuntimeError.scala | 10 +--------- src/main/wacc/backend/asmGenerator.scala | 10 +--------- src/main/wacc/backend/assemblyIR.scala | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 8159485..4e4cd07 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -19,17 +19,9 @@ sealed trait RuntimeError { } object RuntimeError { - - // TODO: Refactor to mitigate imports and redeclared vals perhaps - import wacc.asmGenerator.stackAlign - import assemblyIR.Size._ - import assemblyIR.RegName._ + import assemblyIR.commonRegisters._ - private val RDI = Register(Q64, DI) - private val RIP = Register(Q64, IP) - private val RSI = Register(Q64, SI) - private val RCX = Register(Q64, CX) private val ERROR_CODE = 255 case object ZeroDivError extends RuntimeError { diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 3c4ae64..1ae9dc4 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -8,21 +8,13 @@ import wacc.RuntimeError._ object asmGenerator { import microWacc._ import assemblyIR._ + import assemblyIR.commonRegisters._ import assemblyIR.Size._ import assemblyIR.RegName._ import types._ import sizeExtensions._ import lexer.escapedChars - private val RAX = Register(Q64, AX) - private val EAX = Register(D32, AX) - private val RDI = Register(Q64, DI) - private val RIP = Register(Q64, IP) - private val RBP = Register(Q64, BP) - private val RSI = Register(Q64, SI) - private val RDX = Register(Q64, DX) - private val RCX = Register(Q64, CX) - private val ECX = Register(D32, CX) private val argRegs = List(DI, SI, DX, CX, R8, R9) private val _7_BIT_MASK = 0x7f diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index e8e7e62..b96325e 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -213,4 +213,19 @@ object assemblyIR { case String => "%s" } } + + object commonRegisters { + import Size._ + import RegName._ + + val RAX = Register(Q64, AX) + val EAX = Register(D32, AX) + val RDI = Register(Q64, DI) + val RIP = Register(Q64, IP) + val RBP = Register(Q64, BP) + val RSI = Register(Q64, SI) + val RDX = Register(Q64, DX) + val RCX = Register(Q64, CX) + val ECX = Register(D32, CX) + } } From 82997a5a389d67f09ff4bbe484d4f2937a3979a5 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 28 Feb 2025 15:29:41 +0000 Subject: [PATCH 8/8] docs: clarify evalExprOntoStack sanity check, explanation comments for generateCall --- src/main/wacc/backend/asmGenerator.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 938efa1..2a7366d 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -372,7 +372,10 @@ object asmGenerator { asm += stack.push(call.ty.size, RAX) } - assert(stack.size == stackSizeStart + 1) + assert( + stack.size == stackSizeStart + 1, + "Sanity check: ONLY the evaluated expression should have been pushed onto the stack" + ) asm ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size) asm } @@ -384,17 +387,20 @@ object asmGenerator { var asm = Chain.empty[AsmLine] val microWacc.Call(target, args) = call + // Evaluate arguments 0-6 argRegs .zip(args) .map { (reg, expr) => asm ++= evalExprOntoStack(expr) reg } + // And set the appropriate registers .reverse .foreach { reg => asm += stack.pop(Register(Q64, reg)) } + // Evaluate arguments 7 and up and push them onto the stack args.drop(argRegs.size).foldMap { asm ++= evalExprOntoStack(_) } @@ -406,6 +412,7 @@ object asmGenerator { asm += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call } + // Remove arguments 7 and up from the stack if (args.size > argRegs.size) { asm += stack.drop(args.size - argRegs.size) }