From 09df7af2ab58a164f1fab5c6168342f38fa637ee Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 26 Feb 2025 20:25:27 +0000 Subject: [PATCH] fix: reset scope after all branching --- src/main/wacc/backend/asmGenerator.scala | 34 +++++++++++++++++++----- src/test/wacc/examples.scala | 3 ++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0d894bc..0e0643e 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -4,7 +4,6 @@ import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ -// import parsley.token.errors.Label object asmGenerator { import microWacc._ @@ -183,13 +182,33 @@ object asmGenerator { chain } + /** Wraps a chain in a stack reset. + * + * This is useful for ensuring that the stack size at the death of scope is the same as the stack + * size at the start of the scope. See branching (If / While) + * + * @param genChain + * Function that generates the scope AsmLines + * @param stack + * The stack to reset + * @return + * The generated scope AsmLines + */ + private def generateScope(genChain: () => Chain[AsmLine])(using + stack: Stack + ): Chain[AsmLine] = { + val stackSizeStart = stack.size + var chain = genChain() + chain += stack.drop(stack.size - stackSizeStart) + chain + } + def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below @@ -215,11 +234,11 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= thenBranch.foldMap(generateStmt) + chain ++= generateScope(() => thenBranch.foldMap(generateStmt)) chain += Jump(LabelArg(endLabel)) chain += LabelDef(elseLabel) - chain ++= elseBranch.foldMap(generateStmt) + chain ++= generateScope(() => elseBranch.foldMap(generateStmt)) chain += LabelDef(endLabel) case While(cond, body) => @@ -232,7 +251,7 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= body.foldMap(generateStmt) + chain ++= generateScope(() => body.foldMap(generateStmt)) chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) @@ -259,7 +278,7 @@ object asmGenerator { labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - + val stackSizeStart = stack.size expr match { case IntLiter(v) => chain += stack.push(ImmediateVal(v)) case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) @@ -333,6 +352,8 @@ object asmGenerator { } if chain.isEmpty then chain += stack.push(ImmediateVal(0)) + + assert(stack.size == stackSizeStart + 1) chain } @@ -404,6 +425,7 @@ object asmGenerator { private val RSP = Register(RegSize.R64, RegName.SP) private def next: Int = stack.size + 1 + def size: Int = stack.size def push(expr: Expr, src: Src): AsmLine = { stack += expr -> next Push(src) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 7f8538d..8ac0aa4 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -99,7 +99,8 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", - "^.*wacc-examples/valid/scope.*$", + // "^.*wacc-examples/valid/scope.*$", + "^.*wacc-examples/valid/scope/printAllTypes.wacc$", // while we still don't have arrays implemented // "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/variables.*$", // "^.*wacc-examples/valid/while.*$",