diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala index e618e0b..3b5169b 100644 --- a/src/main/wacc/backend/LabelGenerator.scala +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -17,23 +17,27 @@ private class LabelGenerator { s".L$labelVal" } - /** Get a named label for a function. */ - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" + private def getLabel(target: CallTarget | RuntimeError): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + case err: RuntimeError => s".L.${err.name}" } - /** Get a named label for an error. */ - def getLabel(target: RuntimeError): String = - s".L.${target.name}" + /** Get a named label def for a function or error. */ + def getLabelDef(target: CallTarget | RuntimeError): LabelDef = + LabelDef(getLabel(target)) + + /** Get a named label for a function or error. */ + def getLabelArg(target: CallTarget | RuntimeError): LabelArg = + LabelArg(getLabel(target)) /** Get an arbitrary label for a string. */ - def getLabel(str: String): String = - strings.getOrElseUpdate(str, s".L.str${strings.size}") + def getLabelArg(str: String): LabelArg = + LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}")) /** Get a named label for a string. */ - def getLabel(src: String, name: String): String = - strings.getOrElseUpdate(src, s".L.$name.str${strings.size}") + def getLabelArg(src: String, name: String): LabelArg = + LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}")) /** Generate the assembly labels for constants that were labelled using the LabelGenerator. */ def generateConstants: Chain[AsmLine] = diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala index 61f1428..5252f9c 100644 --- a/src/main/wacc/backend/RuntimeError.scala +++ b/src/main/wacc/backend/RuntimeError.scala @@ -7,13 +7,13 @@ sealed trait RuntimeError { val name: String protected val errStr: String - protected def getErrLabel(using labelGenerator: LabelGenerator): String = - labelGenerator.getLabel(errStr, name = name) + protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg = + labelGenerator.getLabelArg(errStr, name = name) protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] = - LabelDef(labelGenerator.getLabel(this)) +: generateHandler + labelGenerator.getLabelDef(this) +: generateHandler } object RuntimeError { @@ -39,7 +39,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), assemblyIR.Call(CLibFunc.Exit) @@ -54,7 +54,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Pop(RSI), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -68,7 +68,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -82,7 +82,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) @@ -97,7 +97,7 @@ object RuntimeError { protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( Move(RSI, Register(Q64, CX)), stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), + Load(RDI, IndexAddress(RIP, getErrLabel)), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(255)), assemblyIR.Call(CLibFunc.Exit) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 68bf5b1..c78c3d1 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -56,10 +56,11 @@ object asmGenerator { ) } - private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using - stack: Stack + private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using + stack: Stack, + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - var chain = Chain.one[AsmLine](LabelDef(labelName)) + var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin)) chain ++= funcPrologue() chain ++= funcBody chain ++= funcEpilogue() @@ -73,7 +74,7 @@ object asmGenerator { // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) stack.reserve(Q64) // Reserve return pointer slot - var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) + var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name)) chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => @@ -91,12 +92,12 @@ object asmGenerator { var chain = Chain.empty[AsmLine] chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Exit), + Builtin.Exit, Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Printf), + Builtin.Printf, Chain( stackAlign, assemblyIR.Call(CLibFunc.PrintF), @@ -106,7 +107,7 @@ object asmGenerator { ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.PrintCharArray), + Builtin.PrintCharArray, Chain( stackAlign, Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), @@ -118,24 +119,24 @@ object asmGenerator { ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Malloc), + Builtin.Malloc, Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) // Out of memory check is optional ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Free), + Builtin.Free, Chain( stackAlign, Move(RDI, RAX), Compare(RDI, ImmediateVal(0)), - Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal), + Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal), assemblyIR.Call(CLibFunc.Free) ) ) chain ++= wrapBuiltinFunc( - labelGenerator.getLabel(Builtin.Read), + Builtin.Read, Chain( stackAlign, Subtract(Register(Q64, SP), ImmediateVal(8)), @@ -167,15 +168,15 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(ECX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) chain += stack.push(Q64, RCX) chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) chain += stack.pop(RDX) chain += Move( @@ -246,7 +247,7 @@ object asmGenerator { expr.ty match { case KnownType.String => val str = elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str)))) + chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str))) chain += stack.push(Q64, RAX) case ty => chain ++= generateCall( @@ -277,12 +278,12 @@ object asmGenerator { chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += Compare(RCX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) chain += stack.pop(RAX) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) chain += Compare(MemLocation(RAX, D32), ECX) - chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) + chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) // + Int because we store the length of the array at the start chain += Move( Register(x.ty.elemSize, AX), @@ -296,7 +297,7 @@ object asmGenerator { chain += Move(EAX, stack.head) chain += And(EAX, ImmediateVal(-128)) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual) + chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => chain += stack.pop(RAX) @@ -305,7 +306,7 @@ object asmGenerator { case UnaryOperator.Negate => chain += Xor(EAX, EAX) chain += Subtract(EAX, stack.head) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) chain += stack.drop() chain += stack.push(Q64, RAX) case UnaryOperator.Not => @@ -321,21 +322,21 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head, destX) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) case BinaryOperator.Sub => chain += Subtract(destX, stack.head) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => chain += Multiply(destX, stack.head) - chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) + chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Div => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -343,7 +344,7 @@ object asmGenerator { case BinaryOperator.Mod => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) + chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -393,9 +394,9 @@ object asmGenerator { // Tail Call Optimisation (TCO) if (isTail) { - chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + chain += Jump(labelGenerator.getLabelArg(target)) // tail call } else { - chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call } if (args.size > argRegs.size) {