From 967a6fe58b1df8fbd0555009bf8d22c685d94935 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 28 Feb 2025 13:14:29 +0000 Subject: [PATCH 1/3] refactor: replace strings ListBuffer with labelGenerator --- src/main/wacc/backend/asmGenerator.scala | 38 ++++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index c900756..381bc00 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,6 @@ package wacc -import scala.collection.mutable.ListBuffer +import scala.collection.mutable import cats.data.Chain import cats.syntax.foldable._ import wacc.RuntimeError._ @@ -31,8 +31,9 @@ object asmGenerator { def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) - class LabelGenerator { - var labelVal = -1 + private class LabelGenerator { + private val strings = mutable.HashMap[String, String]() + private var labelVal = -1 def getLabel(): String = { labelVal += 1 s".L$labelVal" @@ -41,11 +42,21 @@ object asmGenerator { case Ident(v, _) => s"wacc_$v" case Builtin(name) => s"_$name" } + def getLabel(str: String): String = + strings.getOrElseUpdate(str, s".L.str${strings.size}") + + def generateConstants: Chain[AsmLine] = + strings.foldLeft(Chain.empty) { case (acc, (str, label)) => + acc ++ Chain( + Directive.Int(str.size), + LabelDef(label), + Directive.Asciz(str.escaped) + ) + } } def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() - given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() val Program(funcs, main) = microProg @@ -58,20 +69,13 @@ object asmGenerator { funcs.foldMap(generateUserFunc(_)) ) - val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => - Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) - } ++ RuntimeError.all.foldMap(_.stringDef) - Chain( Directive.IntelSyntax, Directive.Global("main"), Directive.RoData ).concatAll( - strDirs, + labelGenerator.generateConstants, + RuntimeError.all.foldMap(_.stringDef), Chain.one(Directive.Text), progAsm ) @@ -88,7 +92,6 @@ object asmGenerator { } private def generateUserFunc(func: FuncDecl)(using - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { given stack: Stack = Stack() @@ -175,7 +178,6 @@ object asmGenerator { private def generateStmt(stmt: Stmt)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] @@ -259,7 +261,6 @@ object asmGenerator { private def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] @@ -272,8 +273,8 @@ object asmGenerator { case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => - strings += elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + val str = elems.collect { case CharLiter(v) => v }.mkString + chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str)))) chain += stack.push(Q64, RAX) case ty => chain ++= generateCall( @@ -398,7 +399,6 @@ object asmGenerator { private def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] From fb5799dbfdb9d0570b1707444eb2fb0b4de8fac7 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 28 Feb 2025 14:07:00 +0000 Subject: [PATCH 2/3] refactor: use LabelGenerator for RuntimeErrors --- src/main/wacc/backend/LabelGenerator.scala | 46 +++++++++++++++ src/main/wacc/backend/RuntimeError.scala | 67 +++++++++------------- src/main/wacc/backend/asmGenerator.scala | 60 ++++++------------- src/main/wacc/backend/assemblyIR.scala | 4 +- 4 files changed, 91 insertions(+), 86 deletions(-) create mode 100644 src/main/wacc/backend/LabelGenerator.scala diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala new file mode 100644 index 0000000..e618e0b --- /dev/null +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -0,0 +1,46 @@ +package wacc + +import scala.collection.mutable +import cats.data.Chain + +private class LabelGenerator { + import assemblyIR._ + import microWacc.{CallTarget, Ident, Builtin} + import asmGenerator.escaped + + private val strings = mutable.HashMap[String, String]() + private var labelVal = -1 + + /** Get an arbitrary label. */ + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + + /** Get a named label for a function. */ + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + + /** Get a named label for an error. */ + def getLabel(target: RuntimeError): String = + s".L.${target.name}" + + /** Get an arbitrary label for a string. */ + def getLabel(str: String): String = + strings.getOrElseUpdate(str, s".L.str${strings.size}") + + /** Get a named label for a string. */ + def getLabel(src: String, name: String): String = + strings.getOrElseUpdate(src, s".L.$name.str${strings.size}") + + /** Generate the assembly labels for constants that were labelled using the LabelGenerator. */ + def generateConstants: Chain[AsmLine] = + strings.foldLeft(Chain.empty) { case (acc, (str, label)) => + acc ++ Chain( + LabelDef(label), + Directive.Asciz(str.escaped) + ) + } +} diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 9085b63..61f1428 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -4,18 +4,16 @@ import cats.data.Chain import wacc.assemblyIR._ sealed trait RuntimeError { - def strLabel: String - def errStr: String - def errLabel: String + val name: String + protected val errStr: String - def stringDef: Chain[AsmLine] = Chain( - Directive.Int(errStr.length), - LabelDef(strLabel), - Directive.Asciz(errStr) - ) + protected def getErrLabel(using labelGenerator: LabelGenerator): String = + labelGenerator.getLabel(errStr, name = name) - def generateHandler: Chain[AsmLine] + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] + def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] = + LabelDef(labelGenerator.getLabel(this)) +: generateHandler } object RuntimeError { @@ -36,14 +34,12 @@ object RuntimeError { // private val RCX = Register(Q64, CX) case object ZeroDivError extends RuntimeError { - val strLabel = ".L._errDivZero_str0" - val errStr = "fatal error: division or modulo by zero" - val errLabel = ".L._errDivZero" + val name = "errDivZero" + protected val errStr = "fatal error: division or modulo by zero" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(ZeroDivError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), assemblyIR.Call(CLibFunc.Exit) @@ -52,15 +48,13 @@ object RuntimeError { } case object BadChrError extends RuntimeError { - val strLabel = ".L._errBadChr_str0" - val errStr = "fatal error: int %d is not an ASCII character 0-127" - val errLabel = ".L._errBadChr" + val name = "errBadChr" + protected val errStr = "fatal error: int %d is not an ASCII character 0-127" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(BadChrError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Pop(RSI), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -69,14 +63,12 @@ object RuntimeError { } case object NullPtrError extends RuntimeError { - val strLabel = ".L._errNullPtr_str0" - val errStr = "fatal error: null pair dereferenced or freed" - val errLabel = ".L._errNullPtr" + val name = "errNullPtr" + protected val errStr = "fatal error: null pair dereferenced or freed" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(NullPtrError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -85,14 +77,12 @@ object RuntimeError { } case object OverflowError extends RuntimeError { - val strLabel = ".L._errOverflow_str0" - val errStr = "fatal error: integer overflow or underflow occurred" - val errLabel = ".L._errOverflow" + val name = "errOverflow" + protected val errStr = "fatal error: integer overflow or underflow occurred" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OverflowError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -101,16 +91,13 @@ object RuntimeError { } case object OutOfBoundsError extends RuntimeError { + val name = "errOutOfBounds" + protected val errStr = "fatal error: array index %d out of bounds" - val strLabel = ".L._errOutOfBounds_str0" - val errStr = "fatal error: array index %d out of bounds" - val errLabel = ".L._errOutOfBounds" - - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OutOfBoundsError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Move(RSI, Register(Q64, CX)), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 381bc00..68bf5b1 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable import cats.data.Chain import cats.syntax.foldable._ import wacc.RuntimeError._ @@ -31,30 +30,6 @@ object asmGenerator { def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) - private class LabelGenerator { - private val strings = mutable.HashMap[String, String]() - private var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" - } - def getLabel(str: String): String = - strings.getOrElseUpdate(str, s".L.str${strings.size}") - - def generateConstants: Chain[AsmLine] = - strings.foldLeft(Chain.empty) { case (acc, (str, label)) => - acc ++ Chain( - Directive.Int(str.size), - LabelDef(label), - Directive.Asciz(str.escaped) - ) - } - } - def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() given labelGenerator: LabelGenerator = LabelGenerator() @@ -66,6 +41,7 @@ object asmGenerator { Chain.one(Xor(RAX, RAX)), funcEpilogue(), generateBuiltInFuncs(), + RuntimeError.all.foldMap(_.generate), funcs.foldMap(generateUserFunc(_)) ) @@ -75,7 +51,6 @@ object asmGenerator { Directive.RoData ).concatAll( labelGenerator.generateConstants, - RuntimeError.all.foldMap(_.stringDef), Chain.one(Directive.Text), progAsm ) @@ -154,7 +129,7 @@ object asmGenerator { stackAlign, Move(RDI, RAX), Compare(RDI, ImmediateVal(0)), - Jump(LabelArg(NullPtrError.errLabel), Cond.Equal), + Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal), assemblyIR.Call(CLibFunc.Free) ) ) @@ -171,8 +146,6 @@ object asmGenerator { ) ) - chain ++= RuntimeError.all.foldMap(_.generateHandler) - chain } @@ -181,7 +154,6 @@ object asmGenerator { labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => lhs match { @@ -195,15 +167,15 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(ECX, ImmediateVal(0)) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) chain += stack.push(Q64, RCX) chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) chain += stack.pop(RDX) chain += Move( @@ -305,12 +277,12 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(RCX, ImmediateVal(0)) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) chain += stack.pop(RAX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) // + Int because we store the length of the array at the start chain += Move( Register(x.ty.elemSize, AX), @@ -324,7 +296,7 @@ object asmGenerator { chain += Move(EAX, stack.head) chain += And(EAX, ImmediateVal(-128)) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) + chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => chain += stack.pop(RAX) @@ -333,7 +305,7 @@ object asmGenerator { case UnaryOperator.Negate => chain += Xor(EAX, EAX) chain += Subtract(EAX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += stack.drop() chain += stack.push(Q64, RAX) case UnaryOperator.Not => @@ -349,21 +321,21 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head, destX) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) case BinaryOperator.Sub => chain += Subtract(destX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => chain += Multiply(destX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Div => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -371,7 +343,7 @@ object asmGenerator { case BinaryOperator.Mod => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -470,7 +442,7 @@ object asmGenerator { private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } extension (s: String) { - private def escaped: String = + def escaped: String = s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString)) } } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index fbf51f5..d2d374a 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -199,8 +199,8 @@ object assemblyIR { case Global(name) => s".globl $name" case Text => ".text" case RoData => ".section .rodata" - case Int(value) => s".int $value" - case Asciz(string) => s".asciz \"$string\"" + case Int(value) => s"\t.int $value" + case Asciz(string) => s"\t.asciz \"$string\"" } } From 7627ec14d264b0cde4c3417ed26c809e00f5d677 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 28 Feb 2025 14:23:09 +0000 Subject: [PATCH 3/3] refactor: return args and defs from labelGenerator, instead of strings --- src/main/wacc/backend/LabelGenerator.scala | 26 +++++----- src/main/wacc/backend/RuntimeError.scala | 16 +++---- src/main/wacc/backend/asmGenerator.scala | 55 +++++++++++----------- 3 files changed, 51 insertions(+), 46 deletions(-) diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala index e618e0b..3b5169b 100644 --- a/src/main/wacc/backend/LabelGenerator.scala +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -17,23 +17,27 @@ private class LabelGenerator { s".L$labelVal" } - /** Get a named label for a function. */ - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" + private def getLabel(target: CallTarget | RuntimeError): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + case err: RuntimeError => s".L.${err.name}" } - /** Get a named label for an error. */ - def getLabel(target: RuntimeError): String = - s".L.${target.name}" + /** Get a named label def for a function or error. */ + def getLabelDef(target: CallTarget | RuntimeError): LabelDef = + LabelDef(getLabel(target)) + + /** Get a named label for a function or error. */ + def getLabelArg(target: CallTarget | RuntimeError): LabelArg = + LabelArg(getLabel(target)) /** Get an arbitrary label for a string. */ - def getLabel(str: String): String = - strings.getOrElseUpdate(str, s".L.str${strings.size}") + def getLabelArg(str: String): LabelArg = + LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}")) /** Get a named label for a string. */ - def getLabel(src: String, name: String): String = - strings.getOrElseUpdate(src, s".L.$name.str${strings.size}") + def getLabelArg(src: String, name: String): LabelArg = + LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}")) /** Generate the assembly labels for constants that were labelled using the LabelGenerator. */ def generateConstants: Chain[AsmLine] = diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 61f1428..5252f9c 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -7,13 +7,13 @@ sealed trait RuntimeError { val name: String protected val errStr: String - protected def getErrLabel(using labelGenerator: LabelGenerator): String = - labelGenerator.getLabel(errStr, name = name) + protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg = + labelGenerator.getLabelArg(errStr, name = name) protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] = - LabelDef(labelGenerator.getLabel(this)) +: generateHandler + labelGenerator.getLabelDef(this) +: generateHandler } object RuntimeError { @@ -39,7 +39,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), assemblyIR.Call(CLibFunc.Exit) @@ -54,7 +54,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Pop(RSI), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -68,7 +68,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -82,7 +82,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -97,7 +97,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Move(RSI, Register(Q64, CX)), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 68bf5b1..c78c3d1 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -56,10 +56,11 @@ object asmGenerator { ) } - private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using - stack: Stack + private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using + stack: Stack, + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.one[AsmLine](LabelDef(labelName)) + var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin)) chain ++= funcPrologue() chain ++= funcBody chain ++= funcEpilogue() @@ -73,7 +74,7 @@ object asmGenerator { // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) stack.reserve(Q64) // Reserve return pointer slot - var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) + var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name)) chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => @@ -91,12 +92,12 @@ object asmGenerator { var chain = Chain.empty[AsmLine] chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Exit), + Builtin.Exit, Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Printf), + Builtin.Printf, Chain( stackAlign, assemblyIR.Call(CLibFunc.PrintF), @@ -106,7 +107,7 @@ object asmGenerator { ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.PrintCharArray), + Builtin.PrintCharArray, Chain( stackAlign, Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), @@ -118,24 +119,24 @@ object asmGenerator { ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Malloc), + Builtin.Malloc, Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) // Out of memory check is optional ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Free), + Builtin.Free, Chain( stackAlign, Move(RDI, RAX), Compare(RDI, ImmediateVal(0)), - Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal), + Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal), assemblyIR.Call(CLibFunc.Free) ) ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Read), + Builtin.Read, Chain( stackAlign, Subtract(Register(Q64, SP), ImmediateVal(8)), @@ -167,15 +168,15 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(ECX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) chain += stack.push(Q64, RCX) chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) chain += stack.pop(RDX) chain += Move( @@ -246,7 +247,7 @@ object asmGenerator { expr.ty match { case KnownType.String => val str = elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str)))) + chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str))) chain += stack.push(Q64, RAX) case ty => chain ++= generateCall( @@ -277,12 +278,12 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(RCX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) chain += stack.pop(RAX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) // + Int because we store the length of the array at the start chain += Move( Register(x.ty.elemSize, AX), @@ -296,7 +297,7 @@ object asmGenerator { chain += Move(EAX, stack.head) chain += And(EAX, ImmediateVal(-128)) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual) + chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => chain += stack.pop(RAX) @@ -305,7 +306,7 @@ object asmGenerator { case UnaryOperator.Negate => chain += Xor(EAX, EAX) chain += Subtract(EAX, stack.head) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) chain += stack.drop() chain += stack.push(Q64, RAX) case UnaryOperator.Not => @@ -321,21 +322,21 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head, destX) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) case BinaryOperator.Sub => chain += Subtract(destX, stack.head) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => chain += Multiply(destX, stack.head) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Div => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -343,7 +344,7 @@ object asmGenerator { case BinaryOperator.Mod => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -393,9 +394,9 @@ object asmGenerator { // Tail Call Optimisation (TCO) if (isTail) { - chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + chain += Jump(labelGenerator.getLabelArg(target)) // tail call } else { - chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call } if (args.size > argRegs.size) {