diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala new file mode 100644 index 0000000..e618e0b --- /dev/null +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -0,0 +1,46 @@ +package wacc + +import scala.collection.mutable +import cats.data.Chain + +private class LabelGenerator { + import assemblyIR._ + import microWacc.{CallTarget, Ident, Builtin} + import asmGenerator.escaped + + private val strings = mutable.HashMap[String, String]() + private var labelVal = -1 + + /** Get an arbitrary label. */ + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + + /** Get a named label for a function. */ + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + + /** Get a named label for an error. */ + def getLabel(target: RuntimeError): String = + s".L.${target.name}" + + /** Get an arbitrary label for a string. */ + def getLabel(str: String): String = + strings.getOrElseUpdate(str, s".L.str${strings.size}") + + /** Get a named label for a string. */ + def getLabel(src: String, name: String): String = + strings.getOrElseUpdate(src, s".L.$name.str${strings.size}") + + /** Generate the assembly labels for constants that were labelled using the LabelGenerator. */ + def generateConstants: Chain[AsmLine] = + strings.foldLeft(Chain.empty) { case (acc, (str, label)) => + acc ++ Chain( + LabelDef(label), + Directive.Asciz(str.escaped) + ) + } +} diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 9085b63..61f1428 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -4,18 +4,16 @@ import cats.data.Chain import wacc.assemblyIR._ sealed trait RuntimeError { - def strLabel: String - def errStr: String - def errLabel: String + val name: String + protected val errStr: String - def stringDef: Chain[AsmLine] = Chain( - Directive.Int(errStr.length), - LabelDef(strLabel), - Directive.Asciz(errStr) - ) + protected def getErrLabel(using labelGenerator: LabelGenerator): String = + labelGenerator.getLabel(errStr, name = name) - def generateHandler: Chain[AsmLine] + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] + def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] = + LabelDef(labelGenerator.getLabel(this)) +: generateHandler } object RuntimeError { @@ -36,14 +34,12 @@ object RuntimeError { // private val RCX = Register(Q64, CX) case object ZeroDivError extends RuntimeError { - val strLabel = ".L._errDivZero_str0" - val errStr = "fatal error: division or modulo by zero" - val errLabel = ".L._errDivZero" + val name = "errDivZero" + protected val errStr = "fatal error: division or modulo by zero" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(ZeroDivError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), assemblyIR.Call(CLibFunc.Exit) @@ -52,15 +48,13 @@ object RuntimeError { } case object BadChrError extends RuntimeError { - val strLabel = ".L._errBadChr_str0" - val errStr = "fatal error: int %d is not an ASCII character 0-127" - val errLabel = ".L._errBadChr" + val name = "errBadChr" + protected val errStr = "fatal error: int %d is not an ASCII character 0-127" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(BadChrError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Pop(RSI), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -69,14 +63,12 @@ object RuntimeError { } case object NullPtrError extends RuntimeError { - val strLabel = ".L._errNullPtr_str0" - val errStr = "fatal error: null pair dereferenced or freed" - val errLabel = ".L._errNullPtr" + val name = "errNullPtr" + protected val errStr = "fatal error: null pair dereferenced or freed" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(NullPtrError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -85,14 +77,12 @@ object RuntimeError { } case object OverflowError extends RuntimeError { - val strLabel = ".L._errOverflow_str0" - val errStr = "fatal error: integer overflow or underflow occurred" - val errLabel = ".L._errOverflow" + val name = "errOverflow" + protected val errStr = "fatal error: integer overflow or underflow occurred" - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OverflowError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -101,16 +91,13 @@ object RuntimeError { } case object OutOfBoundsError extends RuntimeError { + val name = "errOutOfBounds" + protected val errStr = "fatal error: array index %d out of bounds" - val strLabel = ".L._errOutOfBounds_str0" - val errStr = "fatal error: array index %d out of bounds" - val errLabel = ".L._errOutOfBounds" - - def generateHandler: Chain[AsmLine] = Chain( - LabelDef(OutOfBoundsError.errLabel), + protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Move(RSI, Register(Q64, CX)), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))), + Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 381bc00..68bf5b1 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable import cats.data.Chain import cats.syntax.foldable._ import wacc.RuntimeError._ @@ -31,30 +30,6 @@ object asmGenerator { def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) - private class LabelGenerator { - private val strings = mutable.HashMap[String, String]() - private var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" - } - def getLabel(str: String): String = - strings.getOrElseUpdate(str, s".L.str${strings.size}") - - def generateConstants: Chain[AsmLine] = - strings.foldLeft(Chain.empty) { case (acc, (str, label)) => - acc ++ Chain( - Directive.Int(str.size), - LabelDef(label), - Directive.Asciz(str.escaped) - ) - } - } - def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() given labelGenerator: LabelGenerator = LabelGenerator() @@ -66,6 +41,7 @@ object asmGenerator { Chain.one(Xor(RAX, RAX)), funcEpilogue(), generateBuiltInFuncs(), + RuntimeError.all.foldMap(_.generate), funcs.foldMap(generateUserFunc(_)) ) @@ -75,7 +51,6 @@ object asmGenerator { Directive.RoData ).concatAll( labelGenerator.generateConstants, - RuntimeError.all.foldMap(_.stringDef), Chain.one(Directive.Text), progAsm ) @@ -154,7 +129,7 @@ object asmGenerator { stackAlign, Move(RDI, RAX), Compare(RDI, ImmediateVal(0)), - Jump(LabelArg(NullPtrError.errLabel), Cond.Equal), + Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal), assemblyIR.Call(CLibFunc.Free) ) ) @@ -171,8 +146,6 @@ object asmGenerator { ) ) - chain ++= RuntimeError.all.foldMap(_.generateHandler) - chain } @@ -181,7 +154,6 @@ object asmGenerator { labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => lhs match { @@ -195,15 +167,15 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(ECX, ImmediateVal(0)) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) chain += stack.push(Q64, RCX) chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) chain += stack.pop(RDX) chain += Move( @@ -305,12 +277,12 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(RCX, ImmediateVal(0)) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) chain += stack.pop(RAX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) + chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) // + Int because we store the length of the array at the start chain += Move( Register(x.ty.elemSize, AX), @@ -324,7 +296,7 @@ object asmGenerator { chain += Move(EAX, stack.head) chain += And(EAX, ImmediateVal(-128)) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) + chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => chain += stack.pop(RAX) @@ -333,7 +305,7 @@ object asmGenerator { case UnaryOperator.Negate => chain += Xor(EAX, EAX) chain += Subtract(EAX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += stack.drop() chain += stack.push(Q64, RAX) case UnaryOperator.Not => @@ -349,21 +321,21 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head, destX) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) case BinaryOperator.Sub => chain += Subtract(destX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => chain += Multiply(destX, stack.head) - chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Div => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -371,7 +343,7 @@ object asmGenerator { case BinaryOperator.Mod => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -470,7 +442,7 @@ object asmGenerator { private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } extension (s: String) { - private def escaped: String = + def escaped: String = s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString)) } } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index fbf51f5..d2d374a 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -199,8 +199,8 @@ object assemblyIR { case Global(name) => s".globl $name" case Text => ".text" case RoData => ".section .rodata" - case Int(value) => s".int $value" - case Asciz(string) => s".asciz \"$string\"" + case Int(value) => s"\t.int $value" + case Asciz(string) => s"\t.asciz \"$string\"" } }