fix: always push a value onto stack on expr evaluation

This commit is contained in:
Gleb Koval 2025-02-25 16:27:47 +00:00
parent 3f76a2c5bf
commit f628d16d3d
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
2 changed files with 51 additions and 153 deletions

View File

@ -50,7 +50,11 @@ object asmGenerator {
generateBuiltInFuncs() generateBuiltInFuncs()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str.replace("\"", "\\\""))) List(
Directive.Int(str.size),
LabelDef(s".L.str$i"),
Directive.Asciz(str.replace("\"", "\\\""))
)
} }
List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
@ -89,7 +93,7 @@ object asmGenerator {
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Malloc), labelGenerator.getLabel(Builtin.Malloc),
List( List(
stack.align(), stack.align()
) )
) ++ ) ++
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++
@ -126,7 +130,7 @@ object asmGenerator {
evalExprOntoStack(rhs) ++ evalExprOntoStack(rhs) ++
List( List(
stack.pop(RAX), stack.pop(RAX),
Move(dest(), RAX), Move(dest(), RAX)
) )
case If(cond, thenBranch, elseBranch) => { case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
@ -165,7 +169,7 @@ object asmGenerator {
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String]
): List[AsmLine] = { ): List[AsmLine] = {
expr match { val out = expr match {
case IntLiter(v) => case IntLiter(v) =>
List(stack.push(ImmediateVal(v))) List(stack.push(ImmediateVal(v)))
case CharLiter(v) => case CharLiter(v) =>
@ -196,7 +200,8 @@ object asmGenerator {
case NullLiter() => List(stack.push(ImmediateVal(0))) case NullLiter() => List(stack.push(ImmediateVal(0)))
case ArrayElem(value, indices) => List() case ArrayElem(value, indices) => List()
case UnaryOp(x, op) => case UnaryOp(x, op) =>
op match { evalExprOntoStack(x) ++
(op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value, // TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c. // but will need bound checking e.t.c.
case UnaryOperator.Chr => List() case UnaryOperator.Chr => List()
@ -212,7 +217,7 @@ object asmGenerator {
Xor(stack.head(SizeDir.Word), ImmediateVal(1)) Xor(stack.head(SizeDir.Word), ImmediateVal(1))
) )
} })
case BinaryOp(x, y, op) => case BinaryOp(x, y, op) =>
op match { op match {
case BinaryOperator.Add => case BinaryOperator.Add =>
@ -288,8 +293,11 @@ object asmGenerator {
Or(stack.head(SizeDir.Word), EAX) Or(stack.head(SizeDir.Word), EAX)
) )
} }
case call: microWacc.Call => generateCall(call) case call: microWacc.Call =>
generateCall(call) ++
List(stack.push(RAX))
} }
if out.isEmpty then List(stack.push(ImmediateVal(0))) else out
} }
def generateCall(call: microWacc.Call)(using def generateCall(call: microWacc.Call)(using
@ -305,35 +313,10 @@ object asmGenerator {
args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ args.drop(argRegs.size).flatMap(evalExprOntoStack) ++
List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
(if (args.size > argRegs.size) { (if (args.size > argRegs.size) {
List(stack.reserve(args.size - argRegs.size)) List(stack.drop(args.size - argRegs.size))
} else Nil) } else Nil)
} }
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// stack: Stack,
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// readType match {
// case Builtin.ReadInt =>
// strings += PrintFormat.Int.toString
// case Builtin.ReadChar =>
// strings += PrintFormat.Char.toString
// }
// List(
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// ),
// Load(RSI, dest)
// ) ++
// // alignStack() ++
// List(assemblyIR.Call(CLibFunc.Scanf))
// }
def generateComparison(x: Expr, y: Expr, cond: Cond)(using def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String]
@ -366,91 +349,6 @@ object asmGenerator {
) )
} }
// def saveRegs(regList: List[Register]): List[AsmLine] = regList.map(Push(_))
// def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_))
// TODO: refactor, really ugly function
// def printF(expr: Expr)(using
// stack: Stack,
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// // determine the format string
// expr.ty match {
// case KnownType.String =>
// strings += PrintFormat.String.toString
// case KnownType.Char =>
// strings += PrintFormat.Char.toString
// case KnownType.Int =>
// strings += PrintFormat.Int.toString
// case _ =>
// strings += PrintFormat.String.toString
// }
// List(
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// )
// )
// ++
// // determine the actual value to print
// (if (expr.ty == KnownType.Bool) {
// expr match {
// case BoolLiter(true) => {
// strings += "true"
// }
// case _ => {
// strings += "false"
// }
// }
// List(
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// )
// )
// } else {
// evalExprOntoStack(expr) ++
// List(Pop(RSI))
// })
// // print the value
// ++
// List(
// assemblyIR.Call(CLibFunc.PrintF),
// Move(RDI, ImmediateVal(0)),
// assemblyIR.Call(CLibFunc.Fflush)
// )
// }
// prints a new line
// def printLn()(using
// stack: Stack,
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// strings += ""
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// )
// ::
// List(
// assemblyIR.Call(CLibFunc.Puts),
// Move(RDI, ImmediateVal(0)),
// assemblyIR.Call(CLibFunc.Fflush)
// )
// }
class Stack { class Stack {
private val stack = LinkedHashMap[Expr | Int, Int]() private val stack = LinkedHashMap[Expr | Int, Int]()
private val RSP = Register(RegSize.R64, RegName.SP) private val RSP = Register(RegSize.R64, RegName.SP)
@ -474,11 +372,11 @@ object asmGenerator {
} }
def reserve(n: Int = 1): AsmLine = { def reserve(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack += stack.size -> next) (1 to n).foreach(_ => stack += stack.size -> next)
Subtract(RSP, ImmediateVal(n*8)) Subtract(RSP, ImmediateVal(n * 8))
} }
def drop(n : Int = 1): AsmLine = { def drop(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack.remove(stack.last._1)) (1 to n).foreach(_ => stack.remove(stack.last._1))
Add(RSP, ImmediateVal(n*8)) Add(RSP, ImmediateVal(n * 8))
} }
def accessVar(ident: Ident): () => IndexAddress = () => { def accessVar(ident: Ident): () => IndexAddress = () => {
IndexAddress(RSP, (stack.size - stack(ident)) * 8) IndexAddress(RSP, (stack.size - stack(ident)) * 8)

View File

@ -27,13 +27,13 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
forEvery(files) { (filename, expectedResult) => forEvery(files) { (filename, expectedResult) =>
val baseFilename = filename.stripSuffix(".wacc") val baseFilename = filename.stripSuffix(".wacc")
given stdout: PrintStream = PrintStream(File(baseFilename + ".out")) given stdout: PrintStream = PrintStream(File(baseFilename + ".out"))
val result = compile(filename)
s"$filename" should "be compiled with correct result" in { s"$filename" should "be compiled with correct result" in {
val result = compile(filename)
assert(expectedResult.contains(result)) assert(expectedResult.contains(result))
} }
if (result == 0) it should "run with correct result" in { if (expectedResult == List(0)) it should "run with correct result" in {
if (fileIsDisallowedBackend(filename)) pending if (fileIsDisallowedBackend(filename)) pending
// Retrieve contents to get input and expected output + exit code // Retrieve contents to get input and expected output + exit code
@ -85,24 +85,24 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
Seq( Seq(
// format: off // format: off
// disable formatting to avoid binPack // disable formatting to avoid binPack
// "^.*wacc-examples/valid/advanced.*$", "^.*wacc-examples/valid/advanced.*$",
// "^.*wacc-examples/valid/array.*$", "^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/basic/exit.*$", "^.*wacc-examples/valid/basic/exit.*$",
// "^.*wacc-examples/valid/basic/skip.*$", "^.*wacc-examples/valid/basic/skip.*$",
// "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/expressions.*$",
// "^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/nested_functions.*$",
// "^.*wacc-examples/valid/function/simple_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$",
// "^.*wacc-examples/valid/if.*$", "^.*wacc-examples/valid/if.*$",
// "^.*wacc-examples/valid/IO/print.*$", "^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/read.*$", "^.*wacc-examples/valid/IO/read.*$",
// "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
// "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/pairs.*$",
// "^.*wacc-examples/valid/runtimeErr.*$", "^.*wacc-examples/valid/runtimeErr.*$",
// "^.*wacc-examples/valid/scope.*$", "^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/sequence.*$", "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$", "^.*wacc-examples/valid/variables.*$",
// "^.*wacc-examples/valid/while.*$", "^.*wacc-examples/valid/while.*$",
// format: on // format: on
).find(filename.matches).isDefined ).find(filename.matches).isDefined
} }