diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala
new file mode 100644
index 0000000..3b5169b
--- /dev/null
+++ b/src/main/wacc/backend/LabelGenerator.scala
@@ -0,0 +1,50 @@
+package wacc
+
+import scala.collection.mutable
+import cats.data.Chain
+
+private class LabelGenerator {
+  import assemblyIR._
+  import microWacc.{CallTarget, Ident, Builtin}
+  import asmGenerator.escaped
+
+  private val strings = mutable.LinkedHashMap[String, String]() // content -> label, first-use order
+  private var labelVal = -1
+
+  /** Get a fresh anonymous label (`.L0`, `.L1`, ...). */
+  def getLabel(): String = {
+    labelVal += 1
+    s".L$labelVal"
+  }
+
+  private def getLabel(target: CallTarget | RuntimeError): String = target match {
+    case Ident(v, _)       => s"wacc_$v"
+    case Builtin(name)     => s"_$name"
+    case err: RuntimeError => s".L.${err.name}"
+  }
+
+  /** Get a named label def for a function or error. */
+  def getLabelDef(target: CallTarget | RuntimeError): LabelDef =
+    LabelDef(getLabel(target))
+
+  /** Get a named label for a function or error. */
+  def getLabelArg(target: CallTarget | RuntimeError): LabelArg =
+    LabelArg(getLabel(target))
+
+  /** Get the label for a string constant, registering the string on first use. */
+  def getLabelArg(str: String): LabelArg =
+    LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}"))
+
+  /** Get the label for a string constant, tagged with `name` if this call registers it. */
+  def getLabelArg(src: String, name: String): LabelArg =
+    LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}"))
+
+  /** Emit a label and `.asciz` directive for every registered string, in first-use order. */
+  def generateConstants: Chain[AsmLine] =
+    strings.foldLeft(Chain.empty[AsmLine]) { case (acc, (str, label)) =>
+      acc ++ Chain(
+        LabelDef(label),
+        Directive.Asciz(str.escaped)
+      )
+    }
+}
diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala
index 9085b63..5252f9c 100644
--- a/src/main/wacc/backend/RuntimeError.scala
+++ b/src/main/wacc/backend/RuntimeError.scala
@@ -4,18 +4,16 @@ import cats.data.Chain
 import wacc.assemblyIR._
 
 sealed trait RuntimeError {
-  def strLabel: String
-  def errStr: String
-  def errLabel: String
+  val name: String
+  protected val errStr: String
 
-  def stringDef: Chain[AsmLine] = Chain(
-    Directive.Int(errStr.length),
-    LabelDef(strLabel),
-    Directive.Asciz(errStr)
-  )
+  protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg =
+    labelGenerator.getLabelArg(errStr, name = name)
 
-  def generateHandler: Chain[AsmLine]
+  protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
 
+  def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
+    labelGenerator.getLabelDef(this) +: generateHandler
 }
 
 object RuntimeError {
@@ -36,14 +34,12 @@ object RuntimeError {
   // private val RCX = Register(Q64, CX)
 
   case object ZeroDivError extends RuntimeError {
-    val strLabel = ".L._errDivZero_str0"
-    val errStr = "fatal error: division or modulo by zero"
-    val errLabel = ".L._errDivZero"
+    val name = "errDivZero"
+    protected val errStr = "fatal error: division or modulo by zero"
 
-    def generateHandler: Chain[AsmLine] = Chain(
-      LabelDef(ZeroDivError.errLabel),
+    protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))),
+      Load(RDI, IndexAddress(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(-1)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -52,15 +48,13 @@ object RuntimeError {
   }
 
   case object BadChrError extends RuntimeError {
-    val strLabel = ".L._errBadChr_str0"
-    val errStr = "fatal error: int %d is not an ASCII character 0-127"
-    val errLabel = ".L._errBadChr"
+    val name = "errBadChr"
+    protected val errStr = "fatal error: int %d is not an ASCII character 0-127"
 
-    def generateHandler: Chain[AsmLine] = Chain(
-      LabelDef(BadChrError.errLabel),
+    protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       Pop(RSI),
       stackAlign,
-      Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))),
+      Load(RDI, IndexAddress(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(255)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -69,14 +63,12 @@ object RuntimeError {
   }
 
   case object NullPtrError extends RuntimeError {
-    val strLabel = ".L._errNullPtr_str0"
-    val errStr = "fatal error: null pair dereferenced or freed"
-    val errLabel = ".L._errNullPtr"
+    val name = "errNullPtr"
+    protected val errStr = "fatal error: null pair dereferenced or freed"
 
-    def generateHandler: Chain[AsmLine] = Chain(
-      LabelDef(NullPtrError.errLabel),
+    protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))),
+      Load(RDI, IndexAddress(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(255)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -85,14 +77,12 @@ object RuntimeError {
   }
 
   case object OverflowError extends RuntimeError {
-    val strLabel = ".L._errOverflow_str0"
-    val errStr = "fatal error: integer overflow or underflow occurred"
-    val errLabel = ".L._errOverflow"
+    val name = "errOverflow"
+    protected val errStr = "fatal error: integer overflow or underflow occurred"
 
-    def generateHandler: Chain[AsmLine] = Chain(
-      LabelDef(OverflowError.errLabel),
+    protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))),
+      Load(RDI, IndexAddress(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(255)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -101,16 +91,13 @@ object RuntimeError {
   }
 
   case object OutOfBoundsError extends RuntimeError {
+    val name = "errOutOfBounds"
+    protected val errStr = "fatal error: array index %d out of bounds"
 
-    val strLabel = ".L._errOutOfBounds_str0"
-    val errStr = "fatal error: array index %d out of bounds"
-    val errLabel = ".L._errOutOfBounds"
-
-    def generateHandler: Chain[AsmLine] = Chain(
-      LabelDef(OutOfBoundsError.errLabel),
+    protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       Move(RSI, Register(Q64, CX)),
       stackAlign,
-      Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))),
+      Load(RDI, IndexAddress(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(255)),
       assemblyIR.Call(CLibFunc.Exit)
diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala
index c900756..c78c3d1 100644
--- a/src/main/wacc/backend/asmGenerator.scala
+++ b/src/main/wacc/backend/asmGenerator.scala
@@ -1,6 +1,5 @@
 package wacc
 
-import scala.collection.mutable.ListBuffer
 import cats.data.Chain
 import cats.syntax.foldable._
 import wacc.RuntimeError._
@@ -31,21 +30,8 @@ object asmGenerator {
     def concatAll(chains: Chain[T]*): Chain[T] =
       chains.foldLeft(chain)(_ ++ _)
 
-  class LabelGenerator {
-    var labelVal = -1
-    def getLabel(): String = {
-      labelVal += 1
-      s".L$labelVal"
-    }
-    def getLabel(target: CallTarget): String = target match {
-      case Ident(v, _)   => s"wacc_$v"
-      case Builtin(name) => s"_$name"
-    }
-  }
-
   def generateAsm(microProg: Program): Chain[AsmLine] = {
     given stack: Stack = Stack()
-    given strings: ListBuffer[String] = ListBuffer[String]()
     given labelGenerator: LabelGenerator = LabelGenerator()
     val Program(funcs, main) = microProg
 
@@ -55,32 +41,26 @@ object asmGenerator {
       Chain.one(Xor(RAX, RAX)),
       funcEpilogue(),
       generateBuiltInFuncs(),
+      RuntimeError.all.foldMap(_.generate),
       funcs.foldMap(generateUserFunc(_))
     )
 
-    val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
-      Chain(
-        Directive.Int(str.size),
-        LabelDef(s".L.str$i"),
-        Directive.Asciz(str.escaped)
-      )
-    } ++ RuntimeError.all.foldMap(_.stringDef)
-
     Chain(
       Directive.IntelSyntax,
       Directive.Global("main"),
       Directive.RoData
     ).concatAll(
-      strDirs,
+      labelGenerator.generateConstants,
       Chain.one(Directive.Text),
       progAsm
     )
   }
 
-  private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
-      stack: Stack
+  private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
+      stack: Stack,
+      labelGenerator: LabelGenerator
   ): Chain[AsmLine] = {
-    var chain = Chain.one[AsmLine](LabelDef(labelName))
+    var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
     chain ++= funcPrologue()
     chain ++= funcBody
     chain ++= funcEpilogue()
@@ -88,14 +68,13 @@ object asmGenerator {
   }
 
   private def generateUserFunc(func: FuncDecl)(using
-      strings: ListBuffer[String],
       labelGenerator: LabelGenerator
   ): Chain[AsmLine] = {
     given stack: Stack = Stack()
     // Setup the stack with param 7 and up
     func.params.drop(argRegs.size).foreach(stack.reserve(_))
     stack.reserve(Q64) // Reserve return pointer slot
-    var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
+    var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
     chain ++= funcPrologue()
     // Push the rest of params onto the stack for simplicity
     argRegs.zip(func.params).foreach { (reg, param) =>
@@ -113,12 +92,12 @@ object asmGenerator {
     var chain = Chain.empty[AsmLine]
 
     chain ++= wrapBuiltinFunc(
-      labelGenerator.getLabel(Builtin.Exit),
+      Builtin.Exit,
       Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
     )
 
     chain ++= wrapBuiltinFunc(
-      labelGenerator.getLabel(Builtin.Printf),
+      Builtin.Printf,
       Chain(
         stackAlign,
         assemblyIR.Call(CLibFunc.PrintF),
@@ -128,7 +107,7 @@ object asmGenerator {
     )
 
     chain ++= wrapBuiltinFunc(
-      labelGenerator.getLabel(Builtin.PrintCharArray),
+      Builtin.PrintCharArray,
       Chain(
         stackAlign,
         Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
@@ -140,24 +119,24 @@ object asmGenerator {
     )
 
     chain ++= wrapBuiltinFunc(
-      labelGenerator.getLabel(Builtin.Malloc),
+      Builtin.Malloc,
       Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
       // Out of memory check is optional
     )
 
     chain ++= wrapBuiltinFunc(
-      labelGenerator.getLabel(Builtin.Free),
+      Builtin.Free,
       Chain(
         stackAlign,
         Move(RDI, RAX),
         Compare(RDI, ImmediateVal(0)),
-        Jump(LabelArg(NullPtrError.errLabel), Cond.Equal),
+        Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
         assemblyIR.Call(CLibFunc.Free)
       )
     )
 
     chain ++= wrapBuiltinFunc(
-      labelGenerator.getLabel(Builtin.Read),
+      Builtin.Read,
       Chain(
         stackAlign,
         Subtract(Register(Q64, SP), ImmediateVal(8)),
@@ -168,18 +147,14 @@ object asmGenerator {
       )
     )
 
-    chain ++= RuntimeError.all.foldMap(_.generateHandler)
-
     chain
   }
 
   private def generateStmt(stmt: Stmt)(using
       stack: Stack,
-      strings: ListBuffer[String],
       labelGenerator: LabelGenerator
   ): Chain[AsmLine] = {
     var chain = Chain.empty[AsmLine]
-    chain += Comment(stmt.toString)
     stmt match {
       case Assign(lhs, rhs) =>
         lhs match {
@@ -193,15 +168,15 @@ object asmGenerator {
             chain ++= evalExprOntoStack(i)
             chain += stack.pop(RCX)
             chain += Compare(ECX, ImmediateVal(0))
-            chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
+            chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
             chain += stack.push(Q64, RCX)
             chain ++= evalExprOntoStack(x)
             chain += stack.pop(RAX)
             chain += stack.pop(RCX)
             chain += Compare(EAX, ImmediateVal(0))
-            chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
+            chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
             chain += Compare(MemLocation(RAX, D32), ECX)
-            chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
+            chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
             chain += stack.pop(RDX)
 
             chain += Move(
@@ -259,7 +234,6 @@ object asmGenerator {
 
   private def evalExprOntoStack(expr: Expr)(using
       stack: Stack,
-      strings: ListBuffer[String],
       labelGenerator: LabelGenerator
   ): Chain[AsmLine] = {
     var chain = Chain.empty[AsmLine]
@@ -272,8 +246,8 @@ object asmGenerator {
       case array @ ArrayLiter(elems) =>
         expr.ty match {
           case KnownType.String =>
-            strings += elems.collect { case CharLiter(v) => v }.mkString
-            chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
+            val str = elems.collect { case CharLiter(v) => v }.mkString
+            chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
             chain += stack.push(Q64, RAX)
           case ty =>
             chain ++= generateCall(
@@ -304,12 +278,12 @@ object asmGenerator {
         chain ++= evalExprOntoStack(i)
         chain += stack.pop(RCX)
         chain += Compare(RCX, ImmediateVal(0))
-        chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
+        chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
         chain += stack.pop(RAX)
         chain += Compare(EAX, ImmediateVal(0))
-        chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
+        chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
         chain += Compare(MemLocation(RAX, D32), ECX)
-        chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
+        chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
         // + Int because we store the length of the array at the start
         chain += Move(
           Register(x.ty.elemSize, AX),
@@ -323,7 +297,7 @@ object asmGenerator {
             chain += Move(EAX, stack.head)
             chain += And(EAX, ImmediateVal(-128))
             chain += Compare(EAX, ImmediateVal(0))
-            chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual)
+            chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
           case UnaryOperator.Ord => // No op needed
           case UnaryOperator.Len =>
             chain += stack.pop(RAX)
@@ -332,7 +306,7 @@ object asmGenerator {
           case UnaryOperator.Negate =>
             chain += Xor(EAX, EAX)
             chain += Subtract(EAX, stack.head)
-            chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
+            chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
             chain += stack.drop()
             chain += stack.push(Q64, RAX)
           case UnaryOperator.Not =>
@@ -348,21 +322,21 @@ object asmGenerator {
         op match {
           case BinaryOperator.Add =>
             chain += Add(stack.head, destX)
-            chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
+            chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
           case BinaryOperator.Sub =>
             chain += Subtract(destX, stack.head)
-            chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
+            chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
             chain += stack.drop()
             chain += stack.push(destX.size, RAX)
           case BinaryOperator.Mul =>
             chain += Multiply(destX, stack.head)
-            chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
+            chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
             chain += stack.drop()
             chain += stack.push(destX.size, RAX)
 
           case BinaryOperator.Div =>
             chain += Compare(stack.head, ImmediateVal(0))
-            chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
+            chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
             chain += CDQ()
             chain += Divide(stack.head)
             chain += stack.drop()
@@ -370,7 +344,7 @@ object asmGenerator {
 
           case BinaryOperator.Mod =>
             chain += Compare(stack.head, ImmediateVal(0))
-            chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
+            chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
             chain += CDQ()
             chain += Divide(stack.head)
             chain += stack.drop()
@@ -398,7 +372,6 @@ object asmGenerator {
 
   private def generateCall(call: microWacc.Call, isTail: Boolean)(using
       stack: Stack,
-      strings: ListBuffer[String],
       labelGenerator: LabelGenerator
   ): Chain[AsmLine] = {
     var chain = Chain.empty[AsmLine]
@@ -421,9 +394,9 @@ object asmGenerator {
 
     // Tail Call Optimisation (TCO)
     if (isTail) {
-      chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
+      chain += Jump(labelGenerator.getLabelArg(target)) // tail call
     } else {
-      chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
+      chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
     }
 
     if (args.size > argRegs.size) {
@@ -470,7 +443,7 @@ object asmGenerator {
   private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
 
   extension (s: String) {
-    private def escaped: String =
+    def escaped: String =
       s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
   }
 }
diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala
index fbf51f5..d2d374a 100644
--- a/src/main/wacc/backend/assemblyIR.scala
+++ b/src/main/wacc/backend/assemblyIR.scala
@@ -199,8 +199,8 @@ object assemblyIR {
       case Global(name)  => s".globl $name"
       case Text          => ".text"
       case RoData        => ".section .rodata"
-      case Int(value)    => s".int $value"
-      case Asciz(string) => s".asciz \"$string\""
+      case Int(value)    => s"\t.int $value"
+      case Asciz(string) => s"\t.asciz \"$string\""
     }
   }
 