fix: reset scope after all branching

This commit is contained in:
Gleb Koval 2025-02-26 20:25:27 +00:00
parent 2cf18a47a8
commit 09df7af2ab
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
2 changed files with 30 additions and 7 deletions

View File

@ -4,7 +4,6 @@ import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer
import cats.data.Chain
import cats.syntax.foldable._
// import parsley.token.errors.Label
object asmGenerator {
import microWacc._
@ -183,13 +182,33 @@ object asmGenerator {
chain
}
/** Wraps a chain in a stack reset.
*
* This is useful for ensuring that the stack size at the death of scope is the same as the stack
* size at the start of the scope. See branching (If / While)
*
* @param genChain
* Function that generates the scope AsmLines
* @param stack
* The stack to reset
* @return
* The generated scope AsmLines
*/
private def generateScope(genChain: () => Chain[AsmLine])(using
stack: Stack
): Chain[AsmLine] = {
val stackSizeStart = stack.size
var chain = genChain()
chain += stack.drop(stack.size - stackSizeStart)
chain
}
def generateStmt(stmt: Stmt)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
stmt match {
case Assign(lhs, rhs) =>
var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
@ -215,11 +234,11 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(elseLabel), Cond.Equal)
chain ++= thenBranch.foldMap(generateStmt)
chain ++= generateScope(() => thenBranch.foldMap(generateStmt))
chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel)
chain ++= elseBranch.foldMap(generateStmt)
chain ++= generateScope(() => elseBranch.foldMap(generateStmt))
chain += LabelDef(endLabel)
case While(cond, body) =>
@ -232,7 +251,7 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= body.foldMap(generateStmt)
chain ++= generateScope(() => body.foldMap(generateStmt))
chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel)
@ -259,7 +278,7 @@ object asmGenerator {
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
val stackSizeStart = stack.size
expr match {
case IntLiter(v) => chain += stack.push(ImmediateVal(v))
case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt))
@ -333,6 +352,8 @@ object asmGenerator {
}
if chain.isEmpty then chain += stack.push(ImmediateVal(0))
assert(stack.size == stackSizeStart + 1)
chain
}
@ -404,6 +425,7 @@ object asmGenerator {
private val RSP = Register(RegSize.R64, RegName.SP)
private def next: Int = stack.size + 1
def size: Int = stack.size
def push(expr: Expr, src: Src): AsmLine = {
stack += expr -> next
Push(src)

View File

@ -99,7 +99,8 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
"^.*wacc-examples/valid/pairs.*$",
"^.*wacc-examples/valid/runtimeErr.*$",
"^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/scope.*$",
"^.*wacc-examples/valid/scope/printAllTypes.wacc$", // while we still don't have arrays implemented
// "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$",
// "^.*wacc-examples/valid/while.*$",