refactor: return args and defs from labelGenerator, instead of strings

This commit is contained in:
2025-02-28 14:23:09 +00:00
parent fb5799dbfd
commit 7627ec14d2
3 changed files with 51 additions and 46 deletions

View File

@@ -17,23 +17,27 @@ private class LabelGenerator {
s".L$labelVal" s".L$labelVal"
} }
/** Get a named label for a function. */ private def getLabel(target: CallTarget | RuntimeError): String = target match {
def getLabel(target: CallTarget): String = target match { case Ident(v, _) => s"wacc_$v"
case Ident(v, _) => s"wacc_$v" case Builtin(name) => s"_$name"
case Builtin(name) => s"_$name" case err: RuntimeError => s".L.${err.name}"
} }
/** Get a named label for an error. */ /** Get a named label def for a function or error. */
def getLabel(target: RuntimeError): String = def getLabelDef(target: CallTarget | RuntimeError): LabelDef =
s".L.${target.name}" LabelDef(getLabel(target))
/** Get a named label for a function or error. */
def getLabelArg(target: CallTarget | RuntimeError): LabelArg =
LabelArg(getLabel(target))
/** Get an arbitrary label for a string. */ /** Get an arbitrary label for a string. */
def getLabel(str: String): String = def getLabelArg(str: String): LabelArg =
strings.getOrElseUpdate(str, s".L.str${strings.size}") LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}"))
/** Get a named label for a string. */ /** Get a named label for a string. */
def getLabel(src: String, name: String): String = def getLabelArg(src: String, name: String): LabelArg =
strings.getOrElseUpdate(src, s".L.$name.str${strings.size}") LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}"))
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */ /** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
def generateConstants: Chain[AsmLine] = def generateConstants: Chain[AsmLine] =

View File

@@ -7,13 +7,13 @@ sealed trait RuntimeError {
val name: String val name: String
protected val errStr: String protected val errStr: String
protected def getErrLabel(using labelGenerator: LabelGenerator): String = protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg =
labelGenerator.getLabel(errStr, name = name) labelGenerator.getLabelArg(errStr, name = name)
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] = def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
LabelDef(labelGenerator.getLabel(this)) +: generateHandler labelGenerator.getLabelDef(this) +: generateHandler
} }
object RuntimeError { object RuntimeError {
@@ -39,7 +39,7 @@ object RuntimeError {
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)), Move(RDI, ImmediateVal(-1)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -54,7 +54,7 @@ object RuntimeError {
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
Pop(RSI), Pop(RSI),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -68,7 +68,7 @@ object RuntimeError {
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -82,7 +82,7 @@ object RuntimeError {
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -97,7 +97,7 @@ object RuntimeError {
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
Move(RSI, Register(Q64, CX)), Move(RSI, Register(Q64, CX)),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))), Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)

View File

@@ -56,10 +56,11 @@ object asmGenerator {
) )
} }
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
stack: Stack stack: Stack,
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.one[AsmLine](LabelDef(labelName)) var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
chain ++= funcPrologue() chain ++= funcPrologue()
chain ++= funcBody chain ++= funcBody
chain ++= funcEpilogue() chain ++= funcEpilogue()
@@ -73,7 +74,7 @@ object asmGenerator {
// Setup the stack with param 7 and up // Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_)) func.params.drop(argRegs.size).foreach(stack.reserve(_))
stack.reserve(Q64) // Reserve return pointer slot stack.reserve(Q64) // Reserve return pointer slot
var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
chain ++= funcPrologue() chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity // Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) => argRegs.zip(func.params).foreach { (reg, param) =>
@@ -91,12 +92,12 @@ object asmGenerator {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Exit), Builtin.Exit,
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
) )
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Printf), Builtin.Printf,
Chain( Chain(
stackAlign, stackAlign,
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
@@ -106,7 +107,7 @@ object asmGenerator {
) )
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.PrintCharArray), Builtin.PrintCharArray,
Chain( Chain(
stackAlign, stackAlign,
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
@@ -118,24 +119,24 @@ object asmGenerator {
) )
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc), Builtin.Malloc,
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
// Out of memory check is optional // Out of memory check is optional
) )
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Free), Builtin.Free,
Chain( Chain(
stackAlign, stackAlign,
Move(RDI, RAX), Move(RDI, RAX),
Compare(RDI, ImmediateVal(0)), Compare(RDI, ImmediateVal(0)),
Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal), Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
assemblyIR.Call(CLibFunc.Free) assemblyIR.Call(CLibFunc.Free)
) )
) )
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Read), Builtin.Read,
Chain( Chain(
stackAlign, stackAlign,
Subtract(Register(Q64, SP), ImmediateVal(8)), Subtract(Register(Q64, SP), ImmediateVal(8)),
@@ -167,15 +168,15 @@ object asmGenerator {
chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += Compare(ECX, ImmediateVal(0)) chain += Compare(ECX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
chain += stack.push(Q64, RCX) chain += stack.push(Q64, RCX)
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX) chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
chain += stack.pop(RDX) chain += stack.pop(RDX)
chain += Move( chain += Move(
@@ -246,7 +247,7 @@ object asmGenerator {
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
val str = elems.collect { case CharLiter(v) => v }.mkString val str = elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str)))) chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
chain += stack.push(Q64, RAX) chain += stack.push(Q64, RAX)
case ty => case ty =>
chain ++= generateCall( chain ++= generateCall(
@@ -277,12 +278,12 @@ object asmGenerator {
chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += Compare(RCX, ImmediateVal(0)) chain += Compare(RCX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less) chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal) chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX) chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual) chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
// + Int because we store the length of the array at the start // + Int because we store the length of the array at the start
chain += Move( chain += Move(
Register(x.ty.elemSize, AX), Register(x.ty.elemSize, AX),
@@ -296,7 +297,7 @@ object asmGenerator {
chain += Move(EAX, stack.head) chain += Move(EAX, stack.head)
chain += And(EAX, ImmediateVal(-128)) chain += And(EAX, ImmediateVal(-128))
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual) chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
case UnaryOperator.Ord => // No op needed case UnaryOperator.Ord => // No op needed
case UnaryOperator.Len => case UnaryOperator.Len =>
chain += stack.pop(RAX) chain += stack.pop(RAX)
@@ -305,7 +306,7 @@ object asmGenerator {
case UnaryOperator.Negate => case UnaryOperator.Negate =>
chain += Xor(EAX, EAX) chain += Xor(EAX, EAX)
chain += Subtract(EAX, stack.head) chain += Subtract(EAX, stack.head)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(Q64, RAX) chain += stack.push(Q64, RAX)
case UnaryOperator.Not => case UnaryOperator.Not =>
@@ -321,21 +322,21 @@ object asmGenerator {
op match { op match {
case BinaryOperator.Add => case BinaryOperator.Add =>
chain += Add(stack.head, destX) chain += Add(stack.head, destX)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
case BinaryOperator.Sub => case BinaryOperator.Sub =>
chain += Subtract(destX, stack.head) chain += Subtract(destX, stack.head)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
chain += Multiply(destX, stack.head) chain += Multiply(destX, stack.head)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow) chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Div => case BinaryOperator.Div =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
@@ -343,7 +344,7 @@ object asmGenerator {
case BinaryOperator.Mod => case BinaryOperator.Mod =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal) chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
@@ -393,9 +394,9 @@ object asmGenerator {
// Tail Call Optimisation (TCO) // Tail Call Optimisation (TCO)
if (isTail) { if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call chain += Jump(labelGenerator.getLabelArg(target)) // tail call
} else { } else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
} }
if (args.size > argRegs.size) { if (args.size > argRegs.size) {