From 631f9ddca555ec8a4c6bf34848e167dc469c60c9 Mon Sep 17 00:00:00 2001 From: Jonny Date: Wed, 26 Feb 2025 19:49:10 +0000 Subject: [PATCH] feat: (maybe) tail call optimisation --- src/main/wacc/backend/asmGenerator.scala | 28 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 6fbbd82..1b8a39b 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -236,13 +236,18 @@ object asmGenerator { chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) - case microWacc.Return(expr) => - chain ++= evalExprOntoStack(expr) - chain += stack.pop(RAX) - chain ++= funcEpilogue() - case call: microWacc.Call => - chain ++= generateCall(call) + chain ++= generateCall(call, isTail = false) + + case microWacc.Return(expr) => + expr match { + case call: microWacc.Call => + chain ++= generateCall(call, isTail = true) // tco + case _ => + chain ++= evalExprOntoStack(expr) + chain += stack.pop(RAX) + chain ++= funcEpilogue() + } } chain @@ -323,7 +328,7 @@ object asmGenerator { } case call: microWacc.Call => - chain ++= generateCall(call) + chain ++= generateCall(call, isTail = false) chain += stack.push(RAX) } @@ -331,7 +336,7 @@ object asmGenerator { chain } - def generateCall(call: microWacc.Call)(using + def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -348,7 +353,12 @@ object asmGenerator { chain ++= evalExprOntoStack(_) } - chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) + // Tail Call Optimisation (TCO) + if (isTail) { + chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + } else { + chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + } if (args.size > argRegs.size) { chain += stack.drop(args.size - argRegs.size)