diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala new file mode 100644 index 0000000..3b5169b --- /dev/null +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -0,0 +1,50 @@ +package wacc + +import scala.collection.mutable +import cats.data.Chain + +private class LabelGenerator { + import assemblyIR._ + import microWacc.{CallTarget, Ident, Builtin} + import asmGenerator.escaped + + private val strings = mutable.HashMap[String, String]() + private var labelVal = -1 + + /** Get an arbitrary label. */ + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + + private def getLabel(target: CallTarget | RuntimeError): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + case err: RuntimeError => s".L.${err.name}" + } + + /** Get a named label def for a function or error. */ + def getLabelDef(target: CallTarget | RuntimeError): LabelDef = + LabelDef(getLabel(target)) + + /** Get a named label for a function or error. */ + def getLabelArg(target: CallTarget | RuntimeError): LabelArg = + LabelArg(getLabel(target)) + + /** Get an arbitrary label for a string. */ + def getLabelArg(str: String): LabelArg = + LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}")) + + /** Get a named label for a string. */ + def getLabelArg(src: String, name: String): LabelArg = + LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}")) + + /** Generate the assembly labels for constants that were labelled using the LabelGenerator. */ + def generateConstants: Chain[AsmLine] = + strings.foldLeft(Chain.empty) { case (acc, (str, label)) => + acc ++ Chain( + LabelDef(label), + Directive.Asciz(str.escaped) + ) + } +} diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 4e4cd07..8e12831 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -4,18 +4,16 @@ import cats.data.Chain import wacc.assemblyIR._ sealed trait RuntimeError { - def strLabel: String - def errStr: String - def errLabel: String + val name: String + protected val errStr: String - def stringDef: Chain[AsmLine] = Chain( - Directive.Int(errStr.length), - LabelDef(strLabel), - Directive.Asciz(errStr) - ) + protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg = + labelGenerator.getLabelArg(errStr, name = name) - def generateHandler: Chain[AsmLine] + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] + def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] = + labelGenerator.getLabelDef(this) +: generateHandler } object RuntimeError { @@ -25,14 +23,12 @@ object RuntimeError { private val ERROR_CODE = 255 case object ZeroDivError extends RuntimeError { - val strLabel = ".L._errDivZero_str0" - val errStr = "fatal error: division or modulo by zero" - val errLabel = ".L._errDivZero" + val name = "errDivZero" + protected val errStr = "fatal error: division or modulo by zero" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(ZeroDivError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), assemblyIR.Call(CLibFunc.Exit) @@ -41,15 +37,13 @@ object RuntimeError { } case object BadChrError extends RuntimeError { - val strLabel = ".L._errBadChr_str0" - val errStr = "fatal error: int %d is not an ASCII character 0-127" - val errLabel = ".L._errBadChr" + val name = "errBadChr" + protected val errStr = "fatal error: int %d is not an ASCII character 0-127" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(BadChrError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Pop(RSI), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) @@ -58,14 +52,12 @@ object RuntimeError { } case object NullPtrError extends RuntimeError { - val strLabel = ".L._errNullPtr_str0" - val errStr = "fatal error: null pair dereferenced or freed" - val errLabel = ".L._errNullPtr" + val name = "errNullPtr" + protected val errStr = "fatal error: null pair dereferenced or freed" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(NullPtrError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) @@ -74,14 +66,12 @@ object RuntimeError { } case object OverflowError extends RuntimeError { - val strLabel = ".L._errOverflow_str0" - val errStr = "fatal error: integer overflow or underflow occurred" - val errLabel = ".L._errOverflow" + val name = "errOverflow" + protected val errStr = "fatal error: integer overflow or underflow occurred" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OverflowError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) @@ -90,16 +80,13 @@ object RuntimeError { } case object OutOfBoundsError extends RuntimeError { + val name = "errOutOfBounds" + protected val errStr = "fatal error: array index %d out of bounds" - val strLabel = ".L._errOutOfBounds_str0" - val errStr = "fatal error: array index %d out of bounds" - val errLabel = ".L._errOutOfBounds" - - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OutOfBoundsError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Move(RSI, RCX), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) @@ -107,14 +94,12 @@ object RuntimeError { } case object OutOfMemoryError extends RuntimeError { - val strLabel = ".L._errOutOfMemory_str0" - val errStr = "fatal error: out of memory" - val errLabel = ".L._errOutOfMemory" + val name = "errOutOfMemory" + protected val errStr = "fatal error: out of memory" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OutOfMemoryError.errLabel), + def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OutOfMemoryError.strLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(ERROR_CODE)), assemblyIR.Call(CLibFunc.Exit) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 1ae9dc4..938efa1 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ import wacc.RuntimeError._ @@ -33,21 +32,8 @@ object asmGenerator { def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) - class LabelGenerator { - var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" - } - } - def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() - given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() val Program(funcs, main) = microProg @@ -57,32 +43,26 @@ object asmGenerator { Chain.one(Xor(RAX, RAX)), funcEpilogue(), generateBuiltInFuncs(), + RuntimeError.all.foldMap(_.generate), funcs.foldMap(generateUserFunc(_)) ) - val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => - Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) - } ++ RuntimeError.all.foldMap(_.stringDef) - Chain( Directive.IntelSyntax, Directive.Global("main"), Directive.RoData ).concatAll( - strDirs, + labelGenerator.generateConstants, Chain.one(Directive.Text), progAsm ) } - private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using - stack: Stack + private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using + stack: Stack, + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var asm = Chain.one[AsmLine](LabelDef(labelName)) + var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin)) asm ++= funcPrologue() asm ++= funcBody asm ++= funcEpilogue() @@ -90,14 +70,13 @@ object asmGenerator { } private def generateUserFunc(func: FuncDecl)(using - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { given stack: Stack = Stack() // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) stack.reserve(Q64) // Reserve return pointer slot - var asm = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) + var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name)) asm ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => @@ -115,12 +94,12 @@ object asmGenerator { var asm = Chain.empty[AsmLine] asm ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Exit), + Builtin.Exit, Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) asm ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Printf), + Builtin.Printf, Chain( stackAlign, assemblyIR.Call(CLibFunc.PrintF), @@ -130,7 +109,7 @@ object asmGenerator { ) asm ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.PrintCharArray), + Builtin.PrintCharArray, Chain( stackAlign, Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), @@ -142,29 +121,29 @@ object asmGenerator { ) asm ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Malloc), + Builtin.Malloc, Chain( stackAlign, assemblyIR.Call(CLibFunc.Malloc), // Out of memory check Compare(RAX, ImmediateVal(0)), - Jump(LabelArg(OutOfMemoryError.errLabel), Cond.Equal) + Jump(labelGenerator.getLabelArg(OutOfMemoryError), Cond.Equal) ) ) asm ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Free), + Builtin.Free, Chain( stackAlign, Move(RDI, RAX), Compare(RDI, ImmediateVal(0)), - Jump(LabelArg(NullPtrError.errLabel), Cond.Equal), + Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal), assemblyIR.Call(CLibFunc.Free) ) ) asm ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Read), + Builtin.Read, Chain( stackAlign, Subtract(Register(Q64, SP), ImmediateVal(8)), @@ -175,18 +154,14 @@ object asmGenerator { ) ) - asm ++= RuntimeError.all.foldMap(_.generateHandler) - asm } private def generateStmt(stmt: Stmt)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var asm = Chain.empty[AsmLine] - asm += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => lhs match { @@ -200,15 +175,15 @@ object asmGenerator { asm ++= evalExprOntoStack(i) asm += stack.pop(RCX) asm += Compare(ECX, ImmediateVal(0)) - asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) asm += stack.push(Q64, RCX) asm ++= evalExprOntoStack(x) asm += stack.pop(RAX) asm += stack.pop(RCX) asm += Compare(EAX, ImmediateVal(0)) - asm += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) asm += Compare(MemLocation(RAX, D32), ECX) - asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) asm += stack.pop(RDX) asm += Move( @@ -266,7 +241,6 @@ object asmGenerator { private def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var asm = Chain.empty[AsmLine] @@ -279,8 +253,8 @@ object asmGenerator { case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => - strings += elems.collect { case CharLiter(v) => v }.mkString - asm += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + val str = elems.collect { case CharLiter(v) => v }.mkString + asm += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str))) asm += stack.push(Q64, RAX) case ty => asm ++= generateCall( @@ -311,12 +285,12 @@ object asmGenerator { asm ++= evalExprOntoStack(i) asm += stack.pop(RCX) asm += Compare(RCX, ImmediateVal(0)) - asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) asm += stack.pop(RAX) asm += Compare(EAX, ImmediateVal(0)) - asm += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) asm += Compare(MemLocation(RAX, D32), ECX) - asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) // + Int because we store the length of the array at the start asm += Move( Register(x.ty.elemSize, AX), @@ -330,7 +304,7 @@ object asmGenerator { asm += Move(EAX, stack.head) asm += And(EAX, ImmediateVal(~_7_BIT_MASK)) asm += Compare(EAX, ImmediateVal(0)) - asm += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) + asm += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => asm += stack.pop(RAX) @@ -339,7 +313,7 @@ object asmGenerator { case UnaryOperator.Negate => asm += Xor(EAX, EAX) asm += Subtract(EAX, stack.head) - asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) asm += stack.drop() asm += stack.push(Q64, RAX) case UnaryOperator.Not => @@ -355,21 +329,21 @@ object asmGenerator { op match { case BinaryOperator.Add => asm += Add(stack.head, destX) - asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) case BinaryOperator.Sub => asm += Subtract(destX, stack.head) - asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) asm += stack.drop() asm += stack.push(destX.size, RAX) case BinaryOperator.Mul => asm += Multiply(destX, stack.head) - asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) asm += stack.drop() asm += stack.push(destX.size, RAX) case BinaryOperator.Div => asm += Compare(stack.head, ImmediateVal(0)) - asm += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal) asm += CDQ() asm += Divide(stack.head) asm += stack.drop() @@ -377,7 +351,7 @@ object asmGenerator { case BinaryOperator.Mod => asm += Compare(stack.head, ImmediateVal(0)) - asm += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal) asm += CDQ() asm += Divide(stack.head) asm += stack.drop() @@ -405,7 +379,6 @@ object asmGenerator { private def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var asm = Chain.empty[AsmLine] @@ -428,9 +401,9 @@ object asmGenerator { // Tail Call Optimisation (TCO) if (isTail) { - asm += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + asm += Jump(labelGenerator.getLabelArg(target)) // tail call } else { - asm += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + asm += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call } if (args.size > argRegs.size) { @@ -477,7 +450,7 @@ object asmGenerator { private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } extension (s: String) { - private def escaped: String = + def escaped: String = s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString)) } } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index b96325e..9ba65ea 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -199,8 +199,8 @@ object assemblyIR { case Global(name) => s".globl $name" case Text => ".text" case RoData => ".section .rodata" - case Int(value) => s".int $value" - case Asciz(string) => s".asciz \"$string\"" + case Int(value) => s"\t.int $value" + case Asciz(string) => s"\t.asciz \"$string\"" } }