From 4ffe85be91b0a2b6d3338575ce1dac0475b607c3 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 14:48:24 +0000 Subject: [PATCH] fix: variable-sized values, heap-allocated arrays (and printCharArray) --- src/main/wacc/Main.scala | 3 +- src/main/wacc/backend/Stack.scala | 80 +++--- src/main/wacc/backend/asmGenerator.scala | 285 ++++++++------------- src/main/wacc/backend/assemblyIR.scala | 6 +- src/main/wacc/backend/sizeExtensions.scala | 10 +- src/main/wacc/backend/writer.scala | 5 +- src/main/wacc/frontend/typeChecker.scala | 12 +- src/test/wacc/instructionSpec.scala | 22 +- 8 files changed, 185 insertions(+), 238 deletions(-) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 52c40aa..020cbcd 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,6 +1,7 @@ package wacc import scala.collection.mutable +import cats.data.Chain import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -63,7 +64,7 @@ def frontend( } val s = "enter an integer to echo" -def backend(typedProg: microWacc.Program): List[asm.AsmLine] = +def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = asmGenerator.generateAsm(typedProg) def compile(filename: String, outFile: Option[File] = None)(using diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala index 72aa5ef..0949ecb 100644 --- a/src/main/wacc/backend/Stack.scala +++ b/src/main/wacc/backend/Stack.scala @@ -1,37 +1,48 @@ package wacc import scala.collection.mutable.LinkedHashMap +import cats.data.Chain class Stack { import assemblyIR._ + import assemblyIR.Size._ import sizeExtensions.size import microWacc as mw - private val RSP = Register(Size.Q64, RegName.SP) + private val RSP = Register(Q64, RegName.SP) private class StackValue(val size: Size, val offset: Int) { - def bottom: Int = offset + size.toInt + def bottom: Int = offset + elemBytes } private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() + private val elemBytes: Int = Q64.toInt + private def sizeBytes: Int = stack.size * elemBytes + /** The stack's size in bytes. */ - def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom + def size: Int = stack.size /** Push an expression onto the stack. */ def push(expr: mw.Expr, src: Register): AsmLine = { - stack += expr -> StackValue(src.size, size) + stack += expr -> StackValue(src.size, sizeBytes) Push(src) } - /** Push an arbitrary register onto the stack. */ - def push(src: Register): AsmLine = { - stack += stack.size -> StackValue(src.size, size) - Push(src) + /** Push a value onto the stack. */ + def push(itemSize: Size, addr: Src): AsmLine = { + stack += stack.size -> StackValue(itemSize, sizeBytes) + Push(addr) } /** Reserve space for a variable on the stack. */ def reserve(ident: mw.Ident): AsmLine = { - stack += ident -> StackValue(ident.ty.size, size) - Subtract(RSP, ImmediateVal(ident.ty.size.toInt)) + stack += ident -> StackValue(ident.ty.size, sizeBytes) + Subtract(RSP, ImmediateVal(elemBytes)) + } + + /** Reserve space for a register on the stack. */ + def reserve(src: Register): AsmLine = { + stack += stack.size -> StackValue(src.size, sizeBytes) + Subtract(RSP, ImmediateVal(src.size.toInt)) } /** Reserve space for values on the stack. @@ -40,45 +51,40 @@ class Stack { * The sizes of the values to reserve space for. */ def reserve(sizes: List[Size]): AsmLine = { - val totalSize = sizes - .map(itemSize => - stack += stack.size -> StackValue(itemSize, size) - itemSize.toInt - ) - .sum - Subtract(RSP, ImmediateVal(totalSize)) + sizes.foreach { itemSize => + stack += stack.size -> StackValue(itemSize, sizeBytes) + } + Subtract(RSP, ImmediateVal(elemBytes * sizes.size)) } /** Pop a value from the stack into a register. Sizes MUST match. */ def pop(dest: Register): AsmLine = { - if (dest.size != stack.last._2.size) { - throw new IllegalArgumentException( - s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register" - ) - } stack.remove(stack.last._1) Pop(dest) } /** Drop the top n values from the stack. */ def drop(n: Int = 1): AsmLine = { - val totalSize = (1 to n) - .map(_ => - val itemSize = stack.last._2.size.toInt - stack.remove(stack.last._1) - itemSize - ) - .sum - Add(RSP, ImmediateVal(totalSize)) + (1 to n).foreach { _ => + stack.remove(stack.last._1) + } + Add(RSP, ImmediateVal(n * elemBytes)) } - /** Get a lazy IndexAddress for a variable in the stack. */ - def accessVar(ident: mw.Ident): () => IndexAddress = () => { - IndexAddress(RSP, stack.size - stack(ident).bottom) + /** Generate AsmLines within a scope, which is reset after the block. */ + def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = { + val resetToSize = stack.size + var lines = block() + lines :+= drop(stack.size - resetToSize) + lines } + + /** Get an IndexAddress for a variable in the stack. */ + def accessVar(ident: mw.Ident): IndexAddress = + IndexAddress(RSP, sizeBytes - stack(ident).bottom) + def contains(ident: mw.Ident): Boolean = stack.contains(ident) - def head: MemLocation = MemLocation(RSP) - def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset)) - // TODO: Might want to actually properly handle this with the LinkedHashMap too - def align(): AsmLine = And(RSP, ImmediateVal(-16)) + def head: MemLocation = MemLocation(RSP, stack.last._2.size) + + override def toString(): String = stack.toString } diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 688e474..cd53b30 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ @@ -8,7 +7,10 @@ import cats.syntax.foldable._ object asmGenerator { import microWacc._ import assemblyIR._ - import wacc.types._ + import assemblyIR.Size._ + import assemblyIR.RegName._ + import types._ + import sizeExtensions._ import lexer.escapedChars abstract case class Error() { @@ -29,26 +31,22 @@ object asmGenerator { def errLabel = ".L._errDivZero" } - val RAX = Register(RegSize.R64, RegName.AX) - val EAX = Register(RegSize.E32, RegName.AX) - val ESP = Register(RegSize.E32, RegName.SP) - val EDX = Register(RegSize.E32, RegName.DX) - val RDI = Register(RegSize.R64, RegName.DI) - val RIP = Register(RegSize.R64, RegName.IP) - val RBP = Register(RegSize.R64, RegName.BP) - val RSI = Register(RegSize.R64, RegName.SI) - val RDX = Register(RegSize.R64, RegName.DX) - val RCX = Register(RegSize.R64, RegName.CX) - val R8 = Register(RegSize.R64, RegName.Reg8) - val R9 = Register(RegSize.R64, RegName.Reg9) - val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) + private val RAX = Register(Q64, AX) + private val EAX = Register(D32, AX) + private val RDI = Register(Q64, DI) + private val RIP = Register(Q64, IP) + private val RBP = Register(Q64, BP) + private val RSI = Register(Q64, SI) + private val RDX = Register(Q64, DX) + private val RCX = Register(Q64, CX) + private val argRegs = List(DI, SI, DX, CX, R8, R9) - val _8_BIT_MASK = 0xff + private val _8_BIT_MASK = 0xff - extension (chain: Chain[AsmLine]) - def +(line: AsmLine): Chain[AsmLine] = chain.append(line) + extension [T](chain: Chain[T]) + def +(item: T): Chain[T] = chain.append(item) - def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = + def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) class LabelGenerator { @@ -63,7 +61,7 @@ object asmGenerator { } } - def generateAsm(microProg: Program): List[AsmLine] = { + def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() @@ -71,7 +69,6 @@ object asmGenerator { val progAsm = Chain(LabelDef("main")).concatAll( funcPrologue(), - Chain.one(stack.align()), main.foldMap(generateStmt(_)), Chain.one(Xor(RAX, RAX)), funcEpilogue(), @@ -95,12 +92,10 @@ object asmGenerator { strDirs, Chain.one(Directive.Text), progAsm - ).toList + ) } - private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using - stack: Stack - ): Chain[AsmLine] = { + private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = { var chain = Chain.one[AsmLine](LabelDef(labelName)) chain ++= funcPrologue() chain ++= funcBody @@ -108,7 +103,7 @@ object asmGenerator { chain } - def generateUserFunc(func: FuncDecl)(using + private def generateUserFunc(func: FuncDecl)(using strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { @@ -119,29 +114,27 @@ object asmGenerator { chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => - chain += stack.push(param, reg) + chain += stack.push(param, Register(Q64, reg)) } chain ++= func.body.foldMap(generateStmt(_)) // No need for epilogue here since all user functions must return explicitly chain } - def generateBuiltInFuncs()(using - stack: Stack, - strings: ListBuffer[String], + private def generateBuiltInFuncs()(using labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Exit), - Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) + Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Printf), Chain( - stack.align(), + stackAlign, assemblyIR.Call(CLibFunc.PrintF), Xor(RDI, RDI), assemblyIR.Call(CLibFunc.Fflush) @@ -151,9 +144,9 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.PrintCharArray), Chain( - stack.align(), - Load(RDX, IndexAddress(RSI, 8)), - Move(RSI, MemLocation(RSI)), + stackAlign, + Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), + Move(Register(D32, SI), MemLocation(RSI, D32)), assemblyIR.Call(CLibFunc.PrintF), Xor(RDI, RDI), assemblyIR.Call(CLibFunc.Fflush) @@ -162,7 +155,7 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), - Chain(stack.align(), assemblyIR.Call(CLibFunc.Malloc)) + Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) // Out of memory check is optional ) @@ -171,13 +164,12 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Read), Chain( - stack.align(), - stack.reserve(), - stack.push(RSI), - Load(RSI, stack.head), + stackAlign, + Subtract(Register(Q64, SP), ImmediateVal(8)), + Push(RSI), + Load(RSI, MemLocation(Register(Q64, SP), Q64)), assemblyIR.Call(CLibFunc.Scanf), - stack.pop(RAX), - stack.drop() + Pop(RAX) ) ) @@ -185,7 +177,7 @@ object asmGenerator { // TODO can this be done with a call to generateStmt? // Consider other error cases -> look to generalise LabelDef(zeroDivError.errLabel), - stack.align(), + stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), @@ -195,53 +187,33 @@ object asmGenerator { chain } - /** Wraps a chain in a stack reset. - * - * This is useful for ensuring that the stack size at the death of scope is the same as the stack - * size at the start of the scope. See branching (If / While) - * - * @param genChain - * Function that generates the scope AsmLines - * @param stack - * The stack to reset - * @return - * The generated scope AsmLines - */ - private def generateScope(genChain: () => Chain[AsmLine])(using - stack: Stack - ): Chain[AsmLine] = { - val stackSizeStart = stack.size - var chain = genChain() - chain += stack.drop(stack.size - stackSizeStart) - chain - } - - def generateStmt(stmt: Stmt)(using + private def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] + chain += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => - lhs match { case ident: Ident => - val dest = stack.accessVar(ident) if (!stack.contains(ident)) chain += stack.reserve(ident) - chain ++= evalExprOntoStack(rhs) - chain += stack.pop(RDX) - chain += Move(dest(), RDX) + chain += stack.pop(RAX) + chain += Move(stack.accessVar(ident), RAX) case ArrayElem(x, i) => - chain ++= evalExprOntoStack(x) - chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(rhs) + chain ++= evalExprOntoStack(i) + chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) chain += stack.pop(RDX) - chain += Move(IndexAddress(RDX, 8, RCX, 8), RAX) + chain += Move( + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt), + Register(x.ty.elemSize, DX) + ) } case If(cond, thenBranch, elseBranch) => @@ -253,11 +225,11 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= generateScope(() => thenBranch.foldMap(generateStmt)) + chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) chain += Jump(LabelArg(endLabel)) chain += LabelDef(elseLabel) - chain ++= generateScope(() => elseBranch.foldMap(generateStmt)) + chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) chain += LabelDef(endLabel) case While(cond, body) => @@ -270,7 +242,7 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= generateScope(() => body.foldMap(generateStmt)) + chain ++= stack.withScope(() => body.foldMap(generateStmt)) chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) @@ -291,7 +263,7 @@ object asmGenerator { chain } - def evalExprOntoStack(expr: Expr)(using + private def evalExprOntoStack(expr: Expr)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -299,111 +271,117 @@ object asmGenerator { var chain = Chain.empty[AsmLine] val stackSizeStart = stack.size expr match { - case IntLiter(v) => chain += stack.push(ImmediateVal(v)) - case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) - case ident: Ident => chain += stack.push(stack.accessVar(ident)()) + case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v)) + case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt)) + case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident)) - case ArrayLiter(elems) => + case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => strings += elems.collect { case CharLiter(v) => v }.mkString chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) - chain += stack.push(RAX) - case _ => + chain += stack.push(Q64, RAX) + case ty => chain ++= generateCall( - microWacc.Call(Builtin.Malloc, List(IntLiter((elems.size + 1) * 8))), + microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), isTail = false ) - chain += stack.push(RAX) + chain += stack.push(Q64, RAX) // Store the length of the array at the start - chain += Move(MemLocation(RAX, SizeDir.DWord), ImmediateVal(elems.size)) + chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size)) elems.zipWithIndex.foldMap { (elem, i) => chain ++= evalExprOntoStack(elem) chain += stack.pop(RCX) chain += stack.pop(RAX) - chain += Move(IndexAddress(RAX, 8 * (i + 1)), RCX) - chain += stack.push(RAX) + chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) + chain += stack.push(Q64, RAX) } } - case BoolLiter(true) => chain += stack.push(ImmediateVal(1)) + case BoolLiter(true) => + chain += stack.push(KnownType.Bool.size, ImmediateVal(1)) case BoolLiter(false) => chain += Xor(RAX, RAX) - chain += stack.push(RAX) - case NullLiter() => chain += stack.push(ImmediateVal(0)) + chain += stack.push(KnownType.Bool.size, RAX) + case NullLiter() => + chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0)) case ArrayElem(x, i) => chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) chain += stack.pop(RAX) - // + 1 because we store the length of the array at the start - chain += stack.push(IndexAddress(RAX, 8, RCX, 8)) + // + Int because we store the length of the array at the start + chain += Move( + Register(x.ty.elemSize, AX), + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt) + ) + chain += stack.push(x.ty.elemSize, RAX) case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed - case UnaryOperator.Len => - // Access the elem + case UnaryOperator.Len => chain += stack.pop(RAX) - chain += Push(MemLocation(RAX)) - case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) + chain += Move(EAX, MemLocation(RAX, D32)) + chain += stack.push(D32, RAX) + case UnaryOperator.Negate => + chain += Negate(stack.head) case UnaryOperator.Not => - chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) + chain += Xor(stack.head, ImmediateVal(1)) } case BinaryOp(x, y, op) => + val destX = Register(x.ty.size, AX) chain ++= evalExprOntoStack(y) chain ++= evalExprOntoStack(x) - chain += stack.pop(RAX) op match { - case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Add => + chain += Add(stack.head, destX) case BinaryOperator.Sub => - chain += Subtract(EAX, stack.head(SizeDir.DWord)) + chain += Subtract(destX, stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => - chain += Multiply(EAX, stack.head(SizeDir.DWord)) + chain += Multiply(destX, stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Div => - chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0)) + chain += Compare(stack.head, ImmediateVal(0)) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += CDQ() - chain += Divide(stack.head(SizeDir.DWord)) + chain += Divide(stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Mod => chain += CDQ() - chain += Divide(stack.head(SizeDir.DWord)) + chain += Divide(stack.head) chain += stack.drop() - chain += stack.push(RDX) + chain += stack.push(destX.size, RDX) - case BinaryOperator.Eq => chain ++= generateComparison(Cond.Equal) - case BinaryOperator.Neq => chain ++= generateComparison(Cond.NotEqual) - case BinaryOperator.Greater => chain ++= generateComparison(Cond.Greater) - case BinaryOperator.GreaterEq => chain ++= generateComparison(Cond.GreaterEqual) - case BinaryOperator.Less => chain ++= generateComparison(Cond.Less) - case BinaryOperator.LessEq => chain ++= generateComparison(Cond.LessEqual) - case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) - case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual) + case BinaryOperator.And => chain += And(stack.head, destX) + case BinaryOperator.Or => chain += Or(stack.head, destX) } case call: microWacc.Call => chain ++= generateCall(call, isTail = false) - chain += stack.push(RAX) + chain += stack.push(call.ty.size, RAX) } - if chain.isEmpty then chain += stack.push(ImmediateVal(0)) - assert(stack.size == stackSizeStart + 1) chain } - def generateCall(call: microWacc.Call, isTail: Boolean)(using + private def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -413,7 +391,7 @@ object asmGenerator { argRegs.zip(args).foldMap { (reg, expr) => chain ++= evalExprOntoStack(expr) - chain += stack.pop(reg) + chain += stack.pop(Register(Q64, reg)) } args.drop(argRegs.size).foldMap { @@ -434,77 +412,36 @@ object asmGenerator { chain } - def generateComparison(cond: Cond)(using - stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator + private def generateComparison(destX: Register, cond: Cond)(using + stack: Stack ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Compare(EAX, stack.head(SizeDir.DWord)) - chain += Set(Register(RegSize.Byte, RegName.AL), cond) + chain += Compare(destX, stack.head) + chain += Set(Register(B8, AX), cond) chain += And(RAX, ImmediateVal(_8_BIT_MASK)) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(B8, RAX) chain } - // Missing a sub instruction but dont think we need it - def funcPrologue()(using stack: Stack): Chain[AsmLine] = { + private def funcPrologue(): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += stack.push(RBP) - chain += Move(RBP, Register(RegSize.R64, RegName.SP)) + chain += Push(RBP) + chain += Move(RBP, Register(Q64, SP)) chain } - def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { + private def funcEpilogue(): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Move(Register(RegSize.R64, RegName.SP), RBP) + chain += Move(Register(Q64, SP), RBP) chain += Pop(RBP) chain += assemblyIR.Return() chain } - class Stack { - private val stack = LinkedHashMap[Expr | Int, Int]() - private val RSP = Register(RegSize.R64, RegName.SP) - - private def next: Int = stack.size + 1 - def size: Int = stack.size - def push(expr: Expr, src: Src): AsmLine = { - stack += expr -> next - Push(src) - } - def push(src: Src): AsmLine = { - stack += stack.size -> next - Push(src) - } - def pop(dest: Src): AsmLine = { - stack.remove(stack.last._1) - Pop(dest) - } - def reserve(ident: Ident): AsmLine = { - stack += ident -> next - Subtract(RSP, ImmediateVal(8)) - } - def reserve(n: Int = 1): AsmLine = { - (1 to n).foreach(_ => stack += stack.size -> next) - Subtract(RSP, ImmediateVal(n * 8)) - } - def drop(n: Int = 1): AsmLine = { - (1 to n).foreach(_ => stack.remove(stack.last._1)) - Add(RSP, ImmediateVal(n * 8)) - } - def accessVar(ident: Ident): () => IndexAddress = () => { - IndexAddress(RSP, (stack.size - stack(ident)) * 8) - } - def head: MemLocation = MemLocation(RSP) - def head(size: SizeDir): MemLocation = MemLocation(RSP, size) - def contains(ident: Ident): Boolean = stack.contains(ident) - // TODO: Might want to actually properly handle this with the LinkedHashMap too - def align(): AsmLine = And(RSP, ImmediateVal(-16)) - } + private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } extension (s: String) { diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 2946dcb..fbf51f5 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -97,11 +97,9 @@ object assemblyIR { } } - case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src { - def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize)) - + case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src { override def toString = - opSize.getOrElse("").toString + s"[$pointer]" + opSize.toString + s"[$pointer]" } // TODO to string is wacky diff --git a/src/main/wacc/backend/sizeExtensions.scala b/src/main/wacc/backend/sizeExtensions.scala index 59d3930..798e290 100644 --- a/src/main/wacc/backend/sizeExtensions.scala +++ b/src/main/wacc/backend/sizeExtensions.scala @@ -11,8 +11,8 @@ object sizeExtensions { def heapSize: Int = (expr, expr.ty) match { case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt - case (ArrayLiter(elems), _) => - KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum + case (ArrayLiter(elems), ty) => + KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt case _ => expr.ty.size.toInt } } @@ -25,5 +25,11 @@ object sizeExtensions { case KnownType.Bool | KnownType.Char => Size.B8 case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 } + + def elemSize: Size = ty match { + case KnownType.Array(elem) => elem.size + case KnownType.Pair(_, _) => Size.Q64 + case _ => ty.size + } } } diff --git a/src/main/wacc/backend/writer.scala b/src/main/wacc/backend/writer.scala index b798af3..3c8dcfd 100644 --- a/src/main/wacc/backend/writer.scala +++ b/src/main/wacc/backend/writer.scala @@ -1,11 +1,12 @@ package wacc import java.io.PrintStream +import cats.data.Chain object writer { import assemblyIR._ - def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = { - asmList.foreach(printStream.println) + def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { + asmList.iterator.foreach(printStream.println) } } diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index 2c430e5..002876d 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -223,10 +223,10 @@ object typeChecker { case KnownType.Int => "%d" case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" } - val printfCall = { (value: microWacc.Expr) => + val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) => List( microWacc.Call( - microWacc.Builtin.Printf, + func, List( s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, value @@ -239,11 +239,13 @@ object typeChecker { List( microWacc.If( exprTyped, - printfCall("true".toMicroWaccCharArray), - printfCall("false".toMicroWaccCharArray) + printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), + printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) ) ) - case _ => printfCall(exprTyped) + case KnownType.Array(KnownType.Char) => + printfCall(microWacc.Builtin.PrintCharArray, exprTyped) + case _ => printfCall(microWacc.Builtin.Printf, exprTyped) } case ast.If(cond, thenStmt, elseStmt) => List( diff --git a/src/test/wacc/instructionSpec.scala b/src/test/wacc/instructionSpec.scala index b7452a0..feef0d4 100644 --- a/src/test/wacc/instructionSpec.scala +++ b/src/test/wacc/instructionSpec.scala @@ -1,42 +1,38 @@ import org.scalatest.funsuite.AnyFunSuite import wacc.assemblyIR._ +import wacc.assemblyIR.Size._ +import wacc.assemblyIR.RegName._ class instructionSpec extends AnyFunSuite { - val named64BitRegister = Register(RegSize.R64, RegName.AX) + val named64BitRegister = Register(Q64, AX) test("named 64-bit register toString") { assert(named64BitRegister.toString == "rax") } - val named32BitRegister = Register(RegSize.E32, RegName.AX) + val named32BitRegister = Register(D32, AX) test("named 32-bit register toString") { assert(named32BitRegister.toString == "eax") } - val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8) + val scratch64BitRegister = Register(Q64, R8) test("scratch 64-bit register toString") { assert(scratch64BitRegister.toString == "r8") } - val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8) + val scratch32BitRegister = Register(D32, R8) test("scratch 32-bit register toString") { - assert(scratch32BitRegister.toString == "e8") + assert(scratch32BitRegister.toString == "r8d") } - val memLocationWithHex = MemLocation(0x12345678) - - test("mem location with hex toString") { - assert(memLocationWithHex.toString == "[0x12345678]") - } - - val memLocationWithRegister = MemLocation(named64BitRegister) + val memLocationWithRegister = MemLocation(named64BitRegister, Q64) test("mem location with register toString") { - assert(memLocationWithRegister.toString == "[rax]") + assert(memLocationWithRegister.toString == "qword ptr [rax]") } val immediateVal = ImmediateVal(123)