From 7953790f4d540b4734108aabe4cbd35cc09acba5 Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 18:44:11 +0000 Subject: [PATCH 1/5] feat: used Chains instead of Lists --- src/main/wacc/backend/asmGenerator.scala | 190 ++++++++++++----------- 1 file changed, 98 insertions(+), 92 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 0d02653..a951051 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -2,6 +2,8 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer +import cats.data.Chain +import cats.syntax.foldable._ object asmGenerator { import microWacc._ @@ -31,33 +33,38 @@ object asmGenerator { val Program(funcs, main) = microProg val progAsm = - LabelDef("main") :: + Chain.one(LabelDef("main")) ++ funcPrologue() ++ - List(stack.align()) ++ - main.flatMap(generateStmt) ++ - List(Move(RAX, ImmediateVal(0))) ++ + Chain(stack.align()) ++ + main.foldLeft(Chain.empty[AsmLine])(_ ++ generateStmt(_)) ++ + Chain.one(Move(RAX, ImmediateVal(0))) ++ funcEpilogue() ++ generateBuiltInFuncs() - val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => - List( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) + val strDirs = strings.toList.zipWithIndex.foldLeft(Chain.empty[AsmLine]) { + case (acc, (str, i)) => + acc ++ Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) } - List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ - strDirs ++ - List(Directive.Text) ++ - progAsm + val finalChain = Chain( + Directive.IntelSyntax, + Directive.Global("main"), + Directive.RoData + ) ++ strDirs ++ Chain.one(Directive.Text) ++ progAsm + + finalChain.toList + } - def wrapFunc(labelName: String, funcBody: List[AsmLine])(using + def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using stack: Stack, strings: ListBuffer[String] - ): List[AsmLine] = { - LabelDef(labelName) :: + ): Chain[AsmLine] = { + Chain.one(LabelDef(labelName)) ++ funcPrologue() ++ funcBody ++ funcEpilogue() @@ -65,16 +72,15 @@ object asmGenerator { def generateBuiltInFuncs()(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { wrapFunc( labelGenerator.getLabel(Builtin.Exit), - List(stack.align(), assemblyIR.Call(CLibFunc.Exit)) + Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Printf), - List( + Chain( stack.align(), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(0)), @@ -83,14 +89,14 @@ object asmGenerator { ) ++ wrapFunc( labelGenerator.getLabel(Builtin.Malloc), - List( + Chain.one( stack.align() ) ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) ++ wrapFunc( labelGenerator.getLabel(Builtin.Read), - List( + Chain( stack.align(), stack.reserve(), stack.push(RSI), @@ -104,11 +110,7 @@ object asmGenerator { def generateStmt( stmt: Stmt - )(using - stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = + )(using stack: Stack, strings: ListBuffer[String]): Chain[AsmLine] = stmt match { case Assign(lhs, rhs) => var dest: () => IndexAddress = @@ -117,15 +119,15 @@ object asmGenerator { case ident: Ident => dest = stack.accessVar(ident) if (!stack.contains(ident)) { - List(stack.reserve(ident)) - } else Nil + Chain.one(stack.reserve(ident)) + } else Chain.empty // TODO lhs = arrayElem case _ => // dest = ??? - List() + Chain.empty }) ++ evalExprOntoStack(rhs) ++ - List( + Chain( stack.pop(RAX), Move(dest(), RAX) ) @@ -133,47 +135,48 @@ object asmGenerator { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() evalExprOntoStack(cond) ++ - List( - Compare(stack.head(SizeDir.Word), ImmediateVal(0)), - stack.drop(), - Jump(LabelArg(elseLabel), Cond.Equal) + Chain.fromSeq( + List( + Compare(stack.head(SizeDir.Word), ImmediateVal(0)), + stack.drop(), + Jump(LabelArg(elseLabel), Cond.Equal) + ) ) ++ - thenBranch.flatMap(generateStmt) ++ - List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++ - elseBranch.flatMap(generateStmt) ++ - List(LabelDef(endLabel)) + Chain.fromSeq(thenBranch).flatMap(generateStmt) ++ + Chain.fromSeq(List(Jump(LabelArg(endLabel)), LabelDef(elseLabel))) ++ + Chain.fromSeq(elseBranch).flatMap(generateStmt) ++ + Chain.one(LabelDef(endLabel)) } case While(cond, body) => { val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - List(LabelDef(startLabel)) ++ + Chain.one(LabelDef(startLabel)) ++ evalExprOntoStack(cond) ++ - List( + Chain( Compare(stack.head(SizeDir.Word), ImmediateVal(0)), stack.drop(), Jump(LabelArg(endLabel), Cond.Equal) ) ++ - body.flatMap(generateStmt) ++ - List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) + Chain.fromSeq(body).flatMap(generateStmt) ++ + Chain(Jump(LabelArg(startLabel)), LabelDef(endLabel)) } case microWacc.Return(expr) => evalExprOntoStack(expr) ++ - List(stack.pop(RAX), assemblyIR.Return()) + Chain(stack.pop(RAX), assemblyIR.Return()) case call: microWacc.Call => generateCall(call) } def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { val out = expr match { case IntLiter(v) => - List(stack.push(ImmediateVal(v))) + Chain.one(stack.push(ImmediateVal(v))) case CharLiter(v) => - List(stack.push(ImmediateVal(v.toInt))) + Chain.one(stack.push(ImmediateVal(v.toInt))) case ident: Ident => - List(stack.push(stack.accessVar(ident)())) + Chain.one(stack.push(stack.accessVar(ident)())) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -181,7 +184,7 @@ object asmGenerator { case CharLiter(v) => v case _ => "" }.mkString - List( + Chain( Load( RAX, IndexAddress( @@ -192,26 +195,24 @@ object asmGenerator { stack.push(RAX) ) // TODO other array types - case _ => List() + case _ => Chain.empty } - case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => List(stack.push(ImmediateVal(0))) - case ArrayElem(value, indices) => List() + case BoolLiter(v) => Chain.one(stack.push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => Chain.one(stack.push(ImmediateVal(0))) + case ArrayElem(value, indices) => Chain.empty case UnaryOp(x, op) => evalExprOntoStack(x) ++ (op match { // TODO: chr and ord are TYPE CASTS. They do not change the internal value, // but will need bound checking e.t.c. - case UnaryOperator.Chr => List() - case UnaryOperator.Ord => List() - case UnaryOperator.Len => List() + case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => Chain.empty case UnaryOperator.Negate => - List( + Chain.one( Negate(stack.head(SizeDir.Word)) ) case UnaryOperator.Not => evalExprOntoStack(x) ++ - List( + Chain.one( Xor(stack.head(SizeDir.Word), ImmediateVal(1)) ) @@ -221,7 +222,7 @@ object asmGenerator { case BinaryOperator.Add => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Add(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING @@ -229,7 +230,7 @@ object asmGenerator { case BinaryOperator.Sub => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Subtract(stack.head(SizeDir.Word), EAX) // TODO OVERFLOWING @@ -237,7 +238,7 @@ object asmGenerator { case BinaryOperator.Mul => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Multiply(EAX, stack.head(SizeDir.Word)), stack.drop(), @@ -247,7 +248,7 @@ object asmGenerator { case BinaryOperator.Div => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ - List( + Chain( stack.pop(RAX), Divide(stack.head(SizeDir.Word)), stack.drop(), @@ -257,7 +258,7 @@ object asmGenerator { case BinaryOperator.Mod => evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ - List( + Chain( stack.pop(RAX), Divide(stack.head(SizeDir.Word)), stack.drop(), @@ -279,51 +280,56 @@ object asmGenerator { case BinaryOperator.And => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), And(stack.head(SizeDir.Word), EAX) ) case BinaryOperator.Or => evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Or(stack.head(SizeDir.Word), EAX) ) } case call: microWacc.Call => generateCall(call) ++ - List(stack.push(RAX)) + Chain.one(stack.push(RAX)) } - if out.isEmpty then List(stack.push(ImmediateVal(0))) else out + if out.isEmpty then Chain.one(stack.push(ImmediateVal(0))) else out } def generateCall(call: microWacc.Call)(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call - argRegs.zip(args).flatMap { (reg, expr) => - evalExprOntoStack(expr) ++ - List(stack.pop(reg)) - } ++ - args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ - List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ - (if (args.size > argRegs.size) { - List(stack.drop(args.size - argRegs.size)) - } else Nil) + + val regMoves = argRegs + .zip(args) + .map { (reg, expr) => + evalExprOntoStack(expr) ++ + Chain.one(stack.pop(reg)) + } + .combineAll + + val stackPushes = args.drop(argRegs.size).map(evalExprOntoStack).combineAll + + regMoves ++ + stackPushes ++ + Chain.one(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ + (if (args.size > argRegs.size) Chain.one(stack.drop(args.size - argRegs.size)) + else Chain.empty) } def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator - ): List[AsmLine] = { + strings: ListBuffer[String] + ): Chain[AsmLine] = { evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ - List( + Chain( stack.pop(RAX), Compare(stack.head(SizeDir.Word), EAX), Set(Register(RegSize.Byte, RegName.AL), cond), @@ -334,15 +340,15 @@ object asmGenerator { } // Missing a sub instruction but dont think we need it - def funcPrologue()(using stack: Stack): List[AsmLine] = { - List( + def funcPrologue()(using stack: Stack): Chain[AsmLine] = { + Chain( stack.push(RBP), Move(RBP, Register(RegSize.R64, RegName.SP)) ) } - def funcEpilogue()(using stack: Stack): List[AsmLine] = { - List( + def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { + Chain( Move(Register(RegSize.R64, RegName.SP), RBP), stack.pop(RBP), assemblyIR.Return() From edbc03ee25fcecd6feaddba9aeabc9aea39a515c Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 19:39:55 +0000 Subject: [PATCH 2/5] feat: used local mutable Chains. Also implemented new LabelGenerator --- src/main/wacc/backend/asmGenerator.scala | 447 +++++++++++------------ 1 file changed, 215 insertions(+), 232 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index a951051..7f22e20 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -4,6 +4,7 @@ import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ +import parsley.token.errors.Label object asmGenerator { import microWacc._ @@ -26,10 +27,27 @@ object asmGenerator { val _8_BIT_MASK = 0xff + extension (chain: Chain[AsmLine]) + def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) + + class LabelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + } + + def generateAsm(microProg: Program): List[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() + val Program(funcs, main) = microProg val progAsm = @@ -64,295 +82,260 @@ object asmGenerator { stack: Stack, strings: ListBuffer[String] ): Chain[AsmLine] = { - Chain.one(LabelDef(labelName)) ++ - funcPrologue() ++ - funcBody ++ - funcEpilogue() + var chain = Chain.empty[AsmLine] + + chain += LabelDef(labelName) + chain ++= funcPrologue() + chain ++= funcBody + chain ++= funcEpilogue() + + chain } def generateBuiltInFuncs()(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - wrapFunc( + var chain = Chain.empty[AsmLine] + + chain ++= wrapFunc( labelGenerator.getLabel(Builtin.Exit), Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) - ) ++ - wrapFunc( - labelGenerator.getLabel(Builtin.Printf), - Chain( - stack.align(), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(0)), - assemblyIR.Call(CLibFunc.Fflush) - ) - ) ++ - wrapFunc( - labelGenerator.getLabel(Builtin.Malloc), - Chain.one( - stack.align() - ) - ) ++ - wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) ++ - wrapFunc( - labelGenerator.getLabel(Builtin.Read), - Chain( - stack.align(), - stack.reserve(), - stack.push(RSI), - Load(RSI, stack.head), - assemblyIR.Call(CLibFunc.Scanf), - stack.pop(RAX), - stack.drop() - ) + ) + + chain ++= wrapFunc( + labelGenerator.getLabel(Builtin.Printf), + Chain( + stack.align(), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) ) + ) + + chain ++= wrapFunc( + labelGenerator.getLabel(Builtin.Malloc), + Chain.one(stack.align()) + ) + + chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) + + chain ++= wrapFunc( + labelGenerator.getLabel(Builtin.Read), + Chain( + stack.align(), + stack.reserve(), + stack.push(RSI), + Load(RSI, stack.head), + assemblyIR.Call(CLibFunc.Scanf), + stack.pop(RAX), + stack.drop() + ) + ) + + chain } - def generateStmt( - stmt: Stmt - )(using stack: Stack, strings: ListBuffer[String]): Chain[AsmLine] = + def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + stmt match { case Assign(lhs, rhs) => - var dest: () => IndexAddress = - () => IndexAddress(RAX, 0) // gets overrwitten - (lhs match { + var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below + + lhs match { case ident: Ident => dest = stack.accessVar(ident) - if (!stack.contains(ident)) { - Chain.one(stack.reserve(ident)) - } else Chain.empty + if (!stack.contains(ident)) chain += stack.reserve(ident) // TODO lhs = arrayElem case _ => - // dest = ??? - Chain.empty - }) ++ - evalExprOntoStack(rhs) ++ - Chain( - stack.pop(RAX), - Move(dest(), RAX) - ) - case If(cond, thenBranch, elseBranch) => { + } + + chain ++= evalExprOntoStack(rhs) + chain += stack.pop(RAX) + chain += Move(dest(), RAX) + + case If(cond, thenBranch, elseBranch) => val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - evalExprOntoStack(cond) ++ - Chain.fromSeq( - List( - Compare(stack.head(SizeDir.Word), ImmediateVal(0)), - stack.drop(), - Jump(LabelArg(elseLabel), Cond.Equal) - ) - ) ++ - Chain.fromSeq(thenBranch).flatMap(generateStmt) ++ - Chain.fromSeq(List(Jump(LabelArg(endLabel)), LabelDef(elseLabel))) ++ - Chain.fromSeq(elseBranch).flatMap(generateStmt) ++ - Chain.one(LabelDef(endLabel)) - } - case While(cond, body) => { + + chain ++= evalExprOntoStack(cond) + chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) + chain += stack.drop() + chain += Jump(LabelArg(elseLabel), Cond.Equal) + + chain ++= Chain.fromSeq(thenBranch).flatMap(generateStmt) + chain += Jump(LabelArg(endLabel)) + chain += LabelDef(elseLabel) + + chain ++= Chain.fromSeq(elseBranch).flatMap(generateStmt) + chain += LabelDef(endLabel) + + case While(cond, body) => val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - Chain.one(LabelDef(startLabel)) ++ - evalExprOntoStack(cond) ++ - Chain( - Compare(stack.head(SizeDir.Word), ImmediateVal(0)), - stack.drop(), - Jump(LabelArg(endLabel), Cond.Equal) - ) ++ - Chain.fromSeq(body).flatMap(generateStmt) ++ - Chain(Jump(LabelArg(startLabel)), LabelDef(endLabel)) - } + + chain += LabelDef(startLabel) + chain ++= evalExprOntoStack(cond) + chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0)) + chain += stack.drop() + chain += Jump(LabelArg(endLabel), Cond.Equal) + + chain ++= Chain.fromSeq(body).flatMap(generateStmt) + chain += Jump(LabelArg(startLabel)) + chain += LabelDef(endLabel) + case microWacc.Return(expr) => - evalExprOntoStack(expr) ++ - Chain(stack.pop(RAX), assemblyIR.Return()) - case call: microWacc.Call => generateCall(call) + chain ++= evalExprOntoStack(expr) + chain += stack.pop(RAX) + chain += assemblyIR.Return() + + case call: microWacc.Call => + chain ++= generateCall(call) } + chain + } + def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - val out = expr match { - case IntLiter(v) => - Chain.one(stack.push(ImmediateVal(v))) - case CharLiter(v) => - Chain.one(stack.push(ImmediateVal(v.toInt))) - case ident: Ident => - Chain.one(stack.push(stack.accessVar(ident)())) + var chain = Chain.empty[AsmLine] + + expr match { + case IntLiter(v) => chain += stack.push(ImmediateVal(v)) + case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) + case ident: Ident => chain += stack.push(stack.accessVar(ident)()) + case ArrayLiter(elems) => expr.ty match { case KnownType.String => - strings += elems.map { - case CharLiter(v) => v - case _ => "" - }.mkString - Chain( - Load( - RAX, - IndexAddress( - RIP, - LabelArg(s".L.str${strings.size - 1}") - ) - ), - stack.push(RAX) - ) - // TODO other array types - case _ => Chain.empty + strings += elems.collect { case CharLiter(v) => v }.mkString + chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + chain += stack.push(RAX) + case _ => // Other array types TODO } - case BoolLiter(v) => Chain.one(stack.push(ImmediateVal(if (v) 1 else 0))) - case NullLiter() => Chain.one(stack.push(ImmediateVal(0))) - case ArrayElem(value, indices) => Chain.empty - case UnaryOp(x, op) => - evalExprOntoStack(x) ++ - (op match { - // TODO: chr and ord are TYPE CASTS. They do not change the internal value, - // but will need bound checking e.t.c. - case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => Chain.empty - case UnaryOperator.Negate => - Chain.one( - Negate(stack.head(SizeDir.Word)) - ) - case UnaryOperator.Not => - evalExprOntoStack(x) ++ - Chain.one( - Xor(stack.head(SizeDir.Word), ImmediateVal(1)) - ) - }) - case BinaryOp(x, y, op) => + case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) + case NullLiter() => chain += stack.push(ImmediateVal(0)) + case ArrayElem(_, _) => // TODO: Implement handling + + case UnaryOp(x, op) => + chain ++= evalExprOntoStack(x) op match { - case BinaryOperator.Add => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Add(stack.head(SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Sub => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Subtract(stack.head(SizeDir.Word), EAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Mul => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Multiply(EAX, stack.head(SizeDir.Word)), - stack.drop(), - stack.push(RAX) - // TODO OVERFLOWING - ) - case BinaryOperator.Div => - evalExprOntoStack(y) ++ - evalExprOntoStack(x) ++ - Chain( - stack.pop(RAX), - Divide(stack.head(SizeDir.Word)), - stack.drop(), - stack.push(RAX) - // TODO CHECK DIVISOR IS NOT 0 - ) - case BinaryOperator.Mod => - evalExprOntoStack(y) ++ - evalExprOntoStack(x) ++ - Chain( - stack.pop(RAX), - Divide(stack.head(SizeDir.Word)), - stack.drop(), - stack.push(RDX) - // TODO CHECK DIVISOR IS NOT 0 - ) - case BinaryOperator.Eq => - generateComparison(x, y, Cond.Equal) - case BinaryOperator.Neq => - generateComparison(x, y, Cond.NotEqual) - case BinaryOperator.Greater => - generateComparison(x, y, Cond.Greater) - case BinaryOperator.GreaterEq => - generateComparison(x, y, Cond.GreaterEqual) - case BinaryOperator.Less => - generateComparison(x, y, Cond.Less) - case BinaryOperator.LessEq => - generateComparison(x, y, Cond.LessEqual) - case BinaryOperator.And => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - And(stack.head(SizeDir.Word), EAX) - ) - case BinaryOperator.Or => - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Or(stack.head(SizeDir.Word), EAX) - ) + case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed + case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word)) + case UnaryOperator.Not => + chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1)) } + + case BinaryOp(x, y, op) => + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(y) + + chain += stack.pop(RAX) + + op match { + case BinaryOperator.Add => chain += Add(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Sub => chain += Subtract(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Mul => + chain += Multiply(EAX, stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RAX) + + case BinaryOperator.Div => + chain += Divide(stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RAX) + + case BinaryOperator.Mod => + chain += Divide(stack.head(SizeDir.Word)) + chain += stack.drop() + chain += stack.push(RDX) + + case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual) + case BinaryOperator.And => chain += And(stack.head(SizeDir.Word), EAX) + case BinaryOperator.Or => chain += Or(stack.head(SizeDir.Word), EAX) + } + case call: microWacc.Call => - generateCall(call) ++ - Chain.one(stack.push(RAX)) + chain ++= generateCall(call) + chain += stack.push(RAX) } - if out.isEmpty then Chain.one(stack.push(ImmediateVal(0))) else out + + if chain.isEmpty then chain += stack.push(ImmediateVal(0)) + chain } def generateCall(call: microWacc.Call)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call - val regMoves = argRegs - .zip(args) - .map { (reg, expr) => - evalExprOntoStack(expr) ++ - Chain.one(stack.pop(reg)) - } - .combineAll + argRegs.zip(args).foreach { (reg, expr) => + chain ++= evalExprOntoStack(expr) + chain += stack.pop(reg) + } - val stackPushes = args.drop(argRegs.size).map(evalExprOntoStack).combineAll + args.drop(argRegs.size).foreach { expr => + chain ++= evalExprOntoStack(expr) + } - regMoves ++ - stackPushes ++ - Chain.one(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ - (if (args.size > argRegs.size) Chain.one(stack.drop(args.size - argRegs.size)) - else Chain.empty) + chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) + + if (args.size > argRegs.size) { + chain += stack.drop(args.size - argRegs.size) + } + + chain } def generateComparison(x: Expr, y: Expr, cond: Cond)(using stack: Stack, - strings: ListBuffer[String] + strings: ListBuffer[String], + labelGenerator: LabelGenerator ): Chain[AsmLine] = { - evalExprOntoStack(x) ++ - evalExprOntoStack(y) ++ - Chain( - stack.pop(RAX), - Compare(stack.head(SizeDir.Word), EAX), - Set(Register(RegSize.Byte, RegName.AL), cond), - And(RAX, ImmediateVal(_8_BIT_MASK)), - stack.drop(), - stack.push(RAX) - ) + + var chain = Chain.empty[AsmLine] + + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(y) + chain += stack.pop(RAX) + chain += Compare(stack.head(SizeDir.Word), EAX) + chain += Set(Register(RegSize.Byte, RegName.AL), cond) + chain += And(RAX, ImmediateVal(_8_BIT_MASK)) + chain += stack.drop() + chain += stack.push(RAX) + + chain } // Missing a sub instruction but dont think we need it def funcPrologue()(using stack: Stack): Chain[AsmLine] = { - Chain( - stack.push(RBP), - Move(RBP, Register(RegSize.R64, RegName.SP)) - ) + val chain = Chain.empty[AsmLine] + chain += stack.push(RBP) + chain += Move(RBP, Register(RegSize.R64, RegName.SP)) + chain } def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { - Chain( - Move(Register(RegSize.R64, RegName.SP), RBP), - stack.pop(RBP), - assemblyIR.Return() - ) + val chain = Chain.empty[AsmLine] + chain += Move(Register(RegSize.R64, RegName.SP), RBP) + chain += stack.pop(RBP) + chain += assemblyIR.Return() + chain } class LabelGenerator { From bd0eb76bec0fe745f474eec20aeae2727e618719 Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 19:43:43 +0000 Subject: [PATCH 3/5] fix: alignment issue with stack in read --- src/main/wacc/backend/asmGenerator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 7f22e20..c962c71 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,8 +3,8 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain -import cats.syntax.foldable._ -import parsley.token.errors.Label +// import cats.syntax.foldable._ +// import parsley.token.errors.Label object asmGenerator { import microWacc._ From ebc65af981223e9634075cf9daa4c101e7f6e8b0 Mon Sep 17 00:00:00 2001 From: Jonny Date: Tue, 25 Feb 2025 19:53:31 +0000 Subject: [PATCH 4/5] feat: extension method concatAll defined on Chain implemented --- src/main/wacc/backend/asmGenerator.scala | 62 +++++++++++++----------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index c962c71..bf4e404 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,7 +3,7 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain -// import cats.syntax.foldable._ +import cats.syntax.foldable._ // import parsley.token.errors.Label object asmGenerator { @@ -30,6 +30,9 @@ object asmGenerator { extension (chain: Chain[AsmLine]) def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) + def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = + chains.foldLeft(chain)(_ ++ _) + class LabelGenerator { var labelVal = -1 def getLabel(): String = { @@ -42,40 +45,38 @@ object asmGenerator { } } - def generateAsm(microProg: Program): List[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() - val Program(funcs, main) = microProg - val progAsm = - Chain.one(LabelDef("main")) ++ - funcPrologue() ++ - Chain(stack.align()) ++ - main.foldLeft(Chain.empty[AsmLine])(_ ++ generateStmt(_)) ++ - Chain.one(Move(RAX, ImmediateVal(0))) ++ - funcEpilogue() ++ - generateBuiltInFuncs() - - val strDirs = strings.toList.zipWithIndex.foldLeft(Chain.empty[AsmLine]) { - case (acc, (str, i)) => - acc ++ Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) + val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => + Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) } - val finalChain = Chain( + val progAsm = Chain(LabelDef("main")).concatAll( + funcPrologue(), + Chain.one(stack.align()), + main.foldMap(generateStmt(_)), + Chain.one(Move(RAX, ImmediateVal(0))), + funcEpilogue(), + generateBuiltInFuncs() + ) + + Chain( Directive.IntelSyntax, Directive.Global("main"), Directive.RoData - ) ++ strDirs ++ Chain.one(Directive.Text) ++ progAsm - - finalChain.toList - + ).concatAll( + strDirs, + Chain.one(Directive.Text), + progAsm + ).toList } def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using @@ -137,7 +138,11 @@ object asmGenerator { chain } - def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator): Chain[AsmLine] = { + def generateStmt(stmt: Stmt)(using + stack: Stack, + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] stmt match { @@ -214,21 +219,20 @@ object asmGenerator { expr.ty match { case KnownType.String => strings += elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) - chain += stack.push(RAX) + chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + chain += stack.push(RAX) case _ => // Other array types TODO } case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) case NullLiter() => chain += stack.push(ImmediateVal(0)) case ArrayElem(_, _) => // TODO: Implement handling - case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word)) - case UnaryOperator.Not => + case UnaryOperator.Not => chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1)) } From 11c483439c4a954471bda72f48e3f4465a42709c Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 25 Feb 2025 21:05:21 +0000 Subject: [PATCH 5/5] fix: generate strDirs after prog, change `+=` to `+` --- src/main/wacc/backend/asmGenerator.scala | 34 ++++++++---------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index bf4e404..6d10b24 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -28,7 +28,7 @@ object asmGenerator { val _8_BIT_MASK = 0xff extension (chain: Chain[AsmLine]) - def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) + def +(line: AsmLine): Chain[AsmLine] = chain.append(line) def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = chains.foldLeft(chain)(_ ++ _) @@ -51,14 +51,6 @@ object asmGenerator { given labelGenerator: LabelGenerator = LabelGenerator() val Program(funcs, main) = microProg - val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => - Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) - } - val progAsm = Chain(LabelDef("main")).concatAll( funcPrologue(), Chain.one(stack.align()), @@ -68,6 +60,14 @@ object asmGenerator { generateBuiltInFuncs() ) + val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => + Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) + } + Chain( Directive.IntelSyntax, Directive.Global("main"), @@ -328,32 +328,20 @@ object asmGenerator { // Missing a sub instruction but dont think we need it def funcPrologue()(using stack: Stack): Chain[AsmLine] = { - val chain = Chain.empty[AsmLine] + var chain = Chain.empty[AsmLine] chain += stack.push(RBP) chain += Move(RBP, Register(RegSize.R64, RegName.SP)) chain } def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { - val chain = Chain.empty[AsmLine] + var chain = Chain.empty[AsmLine] chain += Move(Register(RegSize.R64, RegName.SP), RBP) chain += stack.pop(RBP) chain += assemblyIR.Return() chain } - class LabelGenerator { - var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } - def getLabel(target: CallTarget): String = target match { - case Ident(v, _) => s"wacc_$v" - case Builtin(name) => s"_$name" - } - } - class Stack { private val stack = LinkedHashMap[Expr | Int, Int]() private val RSP = Register(RegSize.R64, RegName.SP)