refactor: return args and defs from labelGenerator, instead of strings

This commit is contained in:
2025-02-28 14:23:09 +00:00
parent fb5799dbfd
commit 7627ec14d2
3 changed files with 51 additions and 46 deletions

View File

@@ -56,10 +56,11 @@ object asmGenerator {
)
}
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
stack: Stack
private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
stack: Stack,
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.one[AsmLine](LabelDef(labelName))
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
chain ++= funcPrologue()
chain ++= funcBody
chain ++= funcEpilogue()
@@ -73,7 +74,7 @@ object asmGenerator {
// Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_))
stack.reserve(Q64) // Reserve return pointer slot
var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) =>
@@ -91,12 +92,12 @@ object asmGenerator {
var chain = Chain.empty[AsmLine]
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Exit),
Builtin.Exit,
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Printf),
Builtin.Printf,
Chain(
stackAlign,
assemblyIR.Call(CLibFunc.PrintF),
@@ -106,7 +107,7 @@ object asmGenerator {
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.PrintCharArray),
Builtin.PrintCharArray,
Chain(
stackAlign,
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
@@ -118,24 +119,24 @@ object asmGenerator {
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc),
Builtin.Malloc,
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
// Out of memory check is optional
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Free),
Builtin.Free,
Chain(
stackAlign,
Move(RDI, RAX),
Compare(RDI, ImmediateVal(0)),
Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal),
Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
assemblyIR.Call(CLibFunc.Free)
)
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Read),
Builtin.Read,
Chain(
stackAlign,
Subtract(Register(Q64, SP), ImmediateVal(8)),
@@ -167,15 +168,15 @@ object asmGenerator {
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += Compare(ECX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
chain += stack.push(Q64, RCX)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
chain += stack.pop(RCX)
chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
chain += stack.pop(RDX)
chain += Move(
@@ -246,7 +247,7 @@ object asmGenerator {
expr.ty match {
case KnownType.String =>
val str = elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str))))
chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
chain += stack.push(Q64, RAX)
case ty =>
chain ++= generateCall(
@@ -277,12 +278,12 @@ object asmGenerator {
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += Compare(RCX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
chain += stack.pop(RAX)
chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
// + Int because we store the length of the array at the start
chain += Move(
Register(x.ty.elemSize, AX),
@@ -296,7 +297,7 @@ object asmGenerator {
chain += Move(EAX, stack.head)
chain += And(EAX, ImmediateVal(-128))
chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual)
chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
case UnaryOperator.Ord => // No op needed
case UnaryOperator.Len =>
chain += stack.pop(RAX)
@@ -305,7 +306,7 @@ object asmGenerator {
case UnaryOperator.Negate =>
chain += Xor(EAX, EAX)
chain += Subtract(EAX, stack.head)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop()
chain += stack.push(Q64, RAX)
case UnaryOperator.Not =>
@@ -321,21 +322,21 @@ object asmGenerator {
op match {
case BinaryOperator.Add =>
chain += Add(stack.head, destX)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
case BinaryOperator.Sub =>
chain += Subtract(destX, stack.head)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop()
chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul =>
chain += Multiply(destX, stack.head)
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop()
chain += stack.push(destX.size, RAX)
case BinaryOperator.Div =>
chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
chain += CDQ()
chain += Divide(stack.head)
chain += stack.drop()
@@ -343,7 +344,7 @@ object asmGenerator {
case BinaryOperator.Mod =>
chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
chain += CDQ()
chain += Divide(stack.head)
chain += stack.drop()
@@ -393,9 +394,9 @@ object asmGenerator {
// Tail Call Optimisation (TCO)
if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
chain += Jump(labelGenerator.getLabelArg(target)) // tail call
} else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
}
if (args.size > argRegs.size) {