feat: (maybe) tail call optimisation

This commit is contained in:
Jonny 2025-02-26 19:49:10 +00:00
parent 4fb399a5e1
commit 631f9ddca5

View File

@ -236,13 +236,18 @@ object asmGenerator {
chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel)
case microWacc.Return(expr) =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
case call: microWacc.Call =>
chain ++= generateCall(call)
chain ++= generateCall(call, isTail = false)
case microWacc.Return(expr) =>
expr match {
case call: microWacc.Call =>
chain ++= generateCall(call, isTail = true) // tco
case _ =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
}
}
chain
@ -323,7 +328,7 @@ object asmGenerator {
}
case call: microWacc.Call =>
chain ++= generateCall(call)
chain ++= generateCall(call, isTail = false)
chain += stack.push(RAX)
}
@ -331,7 +336,7 @@ object asmGenerator {
chain
}
def generateCall(call: microWacc.Call)(using
def generateCall(call: microWacc.Call, isTail: Boolean)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
@ -348,7 +353,12 @@ object asmGenerator {
chain ++= evalExprOntoStack(_)
}
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))
// Tail Call Optimisation (TCO)
if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
} else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
}
if (args.size > argRegs.size) {
chain += stack.drop(args.size - argRegs.size)