diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 52c40aa..020cbcd 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,6 +1,7 @@ package wacc import scala.collection.mutable +import cats.data.Chain import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -63,7 +64,7 @@ def frontend( } val s = "enter an integer to echo" -def backend(typedProg: microWacc.Program): List[asm.AsmLine] = +def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = asmGenerator.generateAsm(typedProg) def compile(filename: String, outFile: Option[File] = None)(using diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala new file mode 100644 index 0000000..94b329a --- /dev/null +++ b/src/main/wacc/backend/Stack.scala @@ -0,0 +1,90 @@ +package wacc + +import scala.collection.mutable.LinkedHashMap +import cats.data.Chain + +class Stack { + import assemblyIR._ + import assemblyIR.Size._ + import sizeExtensions.size + import microWacc as mw + + private val RSP = Register(Q64, RegName.SP) + private class StackValue(val size: Size, val offset: Int) { + def bottom: Int = offset + elemBytes + } + private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() + + private val elemBytes: Int = Q64.toInt + private def sizeBytes: Int = stack.size * elemBytes + + /** The stack's size in bytes. */ + def size: Int = stack.size + + /** Push an expression onto the stack. */ + def push(expr: mw.Expr, src: Register): AsmLine = { + stack += expr -> StackValue(src.size, sizeBytes) + Push(src) + } + + /** Push a value onto the stack. */ + def push(itemSize: Size, addr: Src): AsmLine = { + stack += stack.size -> StackValue(itemSize, sizeBytes) + Push(addr) + } + + /** Reserve space for a variable on the stack. */ + def reserve(ident: mw.Ident): AsmLine = { + stack += ident -> StackValue(ident.ty.size, sizeBytes) + Subtract(RSP, ImmediateVal(elemBytes)) + } + + /** Reserve space for a register on the stack. */ + def reserve(src: Register): AsmLine = { + stack += stack.size -> StackValue(src.size, sizeBytes) + Subtract(RSP, ImmediateVal(src.size.toInt)) + } + + /** Reserve space for values on the stack. + * + * @param sizes + * The sizes of the values to reserve space for. + */ + def reserve(sizes: Size*): AsmLine = { + sizes.foreach { itemSize => + stack += stack.size -> StackValue(itemSize, sizeBytes) + } + Subtract(RSP, ImmediateVal(elemBytes * sizes.size)) + } + + /** Pop a value from the stack into a register. Sizes MUST match. */ + def pop(dest: Register): AsmLine = { + stack.remove(stack.last._1) + Pop(dest) + } + + /** Drop the top n values from the stack. */ + def drop(n: Int = 1): AsmLine = { + (1 to n).foreach { _ => + stack.remove(stack.last._1) + } + Add(RSP, ImmediateVal(n * elemBytes)) + } + + /** Generate AsmLines within a scope, which is reset after the block. */ + def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = { + val resetToSize = stack.size + var lines = block() + lines :+= drop(stack.size - resetToSize) + lines + } + + /** Get an IndexAddress for a variable in the stack. */ + def accessVar(ident: mw.Ident): IndexAddress = + IndexAddress(RSP, sizeBytes - stack(ident).bottom) + + def contains(ident: mw.Ident): Boolean = stack.contains(ident) + def head: MemLocation = MemLocation(RSP, stack.last._2.size) + + override def toString(): String = stack.toString +} diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 54fc7db..60e0b47 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ @@ -8,7 +7,10 @@ import cats.syntax.foldable._ object asmGenerator { import microWacc._ import assemblyIR._ - import wacc.types._ + import assemblyIR.Size._ + import assemblyIR.RegName._ + import types._ + import sizeExtensions._ import lexer.escapedChars abstract case class Error() { @@ -29,26 +31,20 @@ object asmGenerator { def errLabel = ".L._errDivZero" } - val RAX = Register(RegSize.R64, RegName.AX) - val EAX = Register(RegSize.E32, RegName.AX) - val ESP = Register(RegSize.E32, RegName.SP) - val EDX = Register(RegSize.E32, RegName.DX) - val RDI = Register(RegSize.R64, RegName.DI) - val RIP = Register(RegSize.R64, RegName.IP) - val RBP = Register(RegSize.R64, RegName.BP) - val RSI = Register(RegSize.R64, RegName.SI) - val RDX = Register(RegSize.R64, RegName.DX) - val RCX = Register(RegSize.R64, RegName.CX) - val R8 = Register(RegSize.R64, RegName.Reg8) - val R9 = Register(RegSize.R64, RegName.Reg9) - val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) + private val RAX = Register(Q64, AX) + private val EAX = Register(D32, AX) + private val RDI = Register(Q64, DI) + private val RIP = Register(Q64, IP) + private val RBP = Register(Q64, BP) + private val RSI = Register(Q64, SI) + private val RDX = Register(Q64, DX) + private val RCX = Register(Q64, CX) + private val argRegs = List(DI, SI, DX, CX, R8, R9) - val _8_BIT_MASK = 0xff + extension [T](chain: Chain[T]) + def +(item: T): Chain[T] = chain.append(item) - extension (chain: Chain[AsmLine]) - def +(line: AsmLine): Chain[AsmLine] = chain.append(line) - - def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = + def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) class LabelGenerator { @@ -63,7 +59,7 @@ object asmGenerator { } } - def generateAsm(microProg: Program): List[AsmLine] = { + def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() @@ -71,7 +67,6 @@ object asmGenerator { val progAsm = Chain(LabelDef("main")).concatAll( funcPrologue(), - Chain.one(stack.align()), main.foldMap(generateStmt(_)), Chain.one(Xor(RAX, RAX)), funcEpilogue(), @@ -95,7 +90,7 @@ object asmGenerator { strDirs, Chain.one(Directive.Text), progAsm - ).toList + ) } private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using @@ -108,40 +103,52 @@ object asmGenerator { chain } - def generateUserFunc(func: FuncDecl)(using + private def generateUserFunc(func: FuncDecl)(using strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { given stack: Stack = Stack() // Setup the stack with param 7 and up func.params.drop(argRegs.size).foreach(stack.reserve(_)) + stack.reserve(Q64) // Reserve return pointer slot var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) chain ++= funcPrologue() // Push the rest of params onto the stack for simplicity argRegs.zip(func.params).foreach { (reg, param) => - chain += stack.push(param, reg) + chain += stack.push(param, Register(Q64, reg)) } chain ++= func.body.foldMap(generateStmt(_)) // No need for epilogue here since all user functions must return explicitly chain } - def generateBuiltInFuncs()(using + private def generateBuiltInFuncs()(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Exit), - Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) + Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) ) chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Printf), Chain( - stack.align(), + stackAlign, + assemblyIR.Call(CLibFunc.PrintF), + Xor(RDI, RDI), + assemblyIR.Call(CLibFunc.Fflush) + ) + ) + + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.PrintCharArray), + Chain( + stackAlign, + Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), + Move(Register(D32, SI), MemLocation(RSI, D32)), assemblyIR.Call(CLibFunc.PrintF), Xor(RDI, RDI), assemblyIR.Call(CLibFunc.Fflush) @@ -150,7 +157,8 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Malloc), - Chain.one(stack.align()) + Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) + // Out of memory check is optional ) chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) @@ -158,13 +166,12 @@ object asmGenerator { chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Read), Chain( - stack.align(), - stack.reserve(), - stack.push(RSI), - Load(RSI, stack.head), + stackAlign, + Subtract(Register(Q64, SP), ImmediateVal(8)), + Push(RSI), + Load(RSI, MemLocation(Register(Q64, SP), Q64)), assemblyIR.Call(CLibFunc.Scanf), - stack.pop(RAX), - stack.drop() + Pop(RAX) ) ) @@ -172,7 +179,7 @@ object asmGenerator { // TODO can this be done with a call to generateStmt? // Consider other error cases -> look to generalise LabelDef(zeroDivError.errLabel), - stack.align(), + stackAlign, Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), assemblyIR.Call(CLibFunc.PrintF), Move(RDI, ImmediateVal(-1)), @@ -182,48 +189,34 @@ object asmGenerator { chain } - /** Wraps a chain in a stack reset. - * - * This is useful for ensuring that the stack size at the death of scope is the same as the stack - * size at the start of the scope. See branching (If / While) - * - * @param genChain - * Function that generates the scope AsmLines - * @param stack - * The stack to reset - * @return - * The generated scope AsmLines - */ - private def generateScope(genChain: () => Chain[AsmLine])(using - stack: Stack - ): Chain[AsmLine] = { - val stackSizeStart = stack.size - var chain = genChain() - chain += stack.drop(stack.size - stackSizeStart) - chain - } - - def generateStmt(stmt: Stmt)(using + private def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] + chain += Comment(stmt.toString) stmt match { case Assign(lhs, rhs) => - var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below - lhs match { case ident: Ident => - dest = stack.accessVar(ident) if (!stack.contains(ident)) chain += stack.reserve(ident) - // TODO lhs = arrayElem - case _ => - } + chain ++= evalExprOntoStack(rhs) + chain += stack.pop(RAX) + chain += Move(stack.accessVar(ident), RAX) + case ArrayElem(x, i) => + chain ++= evalExprOntoStack(rhs) + chain ++= evalExprOntoStack(i) + chain ++= evalExprOntoStack(x) + chain += stack.pop(RAX) + chain += stack.pop(RCX) + chain += stack.pop(RDX) - chain ++= evalExprOntoStack(rhs) - chain += stack.pop(RAX) - chain += Move(dest(), RAX) + chain += Move( + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt), + Register(x.ty.elemSize, DX) + ) + } case If(cond, thenBranch, elseBranch) => val elseLabel = labelGenerator.getLabel() @@ -234,11 +227,11 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(elseLabel), Cond.Equal) - chain ++= generateScope(() => thenBranch.foldMap(generateStmt)) + chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) chain += Jump(LabelArg(endLabel)) chain += LabelDef(elseLabel) - chain ++= generateScope(() => elseBranch.foldMap(generateStmt)) + chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) chain += LabelDef(endLabel) case While(cond, body) => @@ -251,7 +244,7 @@ object asmGenerator { chain += Compare(RAX, ImmediateVal(0)) chain += Jump(LabelArg(endLabel), Cond.Equal) - chain ++= generateScope(() => body.foldMap(generateStmt)) + chain ++= stack.withScope(() => body.foldMap(generateStmt)) chain += Jump(LabelArg(startLabel)) chain += LabelDef(endLabel) @@ -272,7 +265,7 @@ object asmGenerator { chain } - def evalExprOntoStack(expr: Expr)(using + private def evalExprOntoStack(expr: Expr)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -280,87 +273,118 @@ object asmGenerator { var chain = Chain.empty[AsmLine] val stackSizeStart = stack.size expr match { - case IntLiter(v) => chain += stack.push(ImmediateVal(v)) - case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) - case ident: Ident => chain += stack.push(stack.accessVar(ident)()) + case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v)) + case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt)) + case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident)) - case ArrayLiter(elems) => + case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => strings += elems.collect { case CharLiter(v) => v }.mkString chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) - chain += stack.push(RAX) - case _ => // Other array types TODO + chain += stack.push(Q64, RAX) + case ty => + chain ++= generateCall( + microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), + isTail = false + ) + chain += stack.push(Q64, RAX) + // Store the length of the array at the start + chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size)) + elems.zipWithIndex.foldMap { (elem, i) => + chain ++= evalExprOntoStack(elem) + chain += stack.pop(RCX) + chain += stack.pop(RAX) + chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) + chain += stack.push(Q64, RAX) + } } - case BoolLiter(true) => chain += stack.push(ImmediateVal(1)) + case BoolLiter(true) => + chain += stack.push(KnownType.Bool.size, ImmediateVal(1)) case BoolLiter(false) => chain += Xor(RAX, RAX) - chain += stack.push(RAX) - case NullLiter() => chain += stack.push(ImmediateVal(0)) - case ArrayElem(_, _) => // TODO: Implement handling + chain += stack.push(KnownType.Bool.size, RAX) + case NullLiter() => + chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0)) + case ArrayElem(x, i) => + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(i) + chain += stack.pop(RCX) + chain += stack.pop(RAX) + // + Int because we store the length of the array at the start + chain += Move( + Register(x.ty.elemSize, AX), + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt) + ) + chain += stack.push(x.ty.elemSize, RAX) case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { - case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed - case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) + case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed + case UnaryOperator.Len => + chain += stack.pop(RAX) + chain += Move(EAX, MemLocation(RAX, D32)) + chain += stack.push(D32, RAX) + case UnaryOperator.Negate => + chain += Negate(stack.head) case UnaryOperator.Not => - chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) + chain += Xor(stack.head, ImmediateVal(1)) } case BinaryOp(x, y, op) => + val destX = Register(x.ty.size, AX) chain ++= evalExprOntoStack(y) chain ++= evalExprOntoStack(x) - chain += stack.pop(RAX) op match { - case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Add => + chain += Add(stack.head, destX) case BinaryOperator.Sub => - chain += Subtract(EAX, stack.head(SizeDir.DWord)) + chain += Subtract(destX, stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => - chain += Multiply(EAX, stack.head(SizeDir.DWord)) + chain += Multiply(destX, stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Div => - chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0)) + chain += Compare(stack.head, ImmediateVal(0)) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += CDQ() - chain += Divide(stack.head(SizeDir.DWord)) + chain += Divide(stack.head) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(destX.size, RAX) case BinaryOperator.Mod => chain += CDQ() - chain += Divide(stack.head(SizeDir.DWord)) + chain += Divide(stack.head) chain += stack.drop() - chain += stack.push(RDX) + chain += stack.push(destX.size, RDX) - case BinaryOperator.Eq => chain ++= generateComparison(Cond.Equal) - case BinaryOperator.Neq => chain ++= generateComparison(Cond.NotEqual) - case BinaryOperator.Greater => chain ++= generateComparison(Cond.Greater) - case BinaryOperator.GreaterEq => chain ++= generateComparison(Cond.GreaterEqual) - case BinaryOperator.Less => chain ++= generateComparison(Cond.Less) - case BinaryOperator.LessEq => chain ++= generateComparison(Cond.LessEqual) - case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) - case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) + case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual) + case BinaryOperator.And => chain += And(stack.head, destX) + case BinaryOperator.Or => chain += Or(stack.head, destX) } case call: microWacc.Call => chain ++= generateCall(call, isTail = false) - chain += stack.push(RAX) + chain += stack.push(call.ty.size, RAX) } - if chain.isEmpty then chain += stack.push(ImmediateVal(0)) - assert(stack.size == stackSizeStart + 1) + chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size) chain } - def generateCall(call: microWacc.Call, isTail: Boolean)(using + private def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator @@ -370,7 +394,7 @@ object asmGenerator { argRegs.zip(args).foldMap { (reg, expr) => chain ++= evalExprOntoStack(expr) - chain += stack.pop(reg) + chain += stack.pop(Register(Q64, reg)) } args.drop(argRegs.size).foldMap { @@ -391,76 +415,39 @@ object asmGenerator { chain } - def generateComparison(cond: Cond)(using - stack: Stack, - strings: ListBuffer[String], - labelGenerator: LabelGenerator + private def generateComparison(destX: Register, cond: Cond)(using + stack: Stack ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Compare(EAX, stack.head(SizeDir.DWord)) - chain += Set(Register(RegSize.Byte, RegName.AL), cond) - chain += And(RAX, ImmediateVal(_8_BIT_MASK)) + chain += Compare(destX, stack.head) + chain += Set(Register(B8, AX), cond) + chain ++= zeroRest(RAX, B8) chain += stack.drop() - chain += stack.push(RAX) + chain += stack.push(B8, RAX) chain } - // Missing a sub instruction but dont think we need it - def funcPrologue()(using stack: Stack): Chain[AsmLine] = { + private def funcPrologue()(using stack: Stack): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += stack.push(RBP) - chain += Move(RBP, Register(RegSize.R64, RegName.SP)) + chain += stack.push(Q64, RBP) + chain += Move(RBP, Register(Q64, SP)) chain } - def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { + private def funcEpilogue(): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - chain += Move(Register(RegSize.R64, RegName.SP), RBP) + chain += Move(Register(Q64, SP), RBP) chain += Pop(RBP) chain += assemblyIR.Return() chain } - class Stack { - private val stack = LinkedHashMap[Expr | Int, Int]() - private val RSP = Register(RegSize.R64, RegName.SP) - - private def next: Int = stack.size + 1 - def size: Int = stack.size - def push(expr: Expr, src: Src): AsmLine = { - stack += expr -> next - Push(src) - } - def push(src: Src): AsmLine = { - stack += stack.size -> next - Push(src) - } - def pop(dest: Src): AsmLine = { - stack.remove(stack.last._1) - Pop(dest) - } - def reserve(ident: Ident): AsmLine = { - stack += ident -> next - Subtract(RSP, ImmediateVal(8)) - } - def reserve(n: Int = 1): AsmLine = { - (1 to n).foreach(_ => stack += stack.size -> next) - Subtract(RSP, ImmediateVal(n * 8)) - } - def drop(n: Int = 1): AsmLine = { - (1 to n).foreach(_ => stack.remove(stack.last._1)) - Add(RSP, ImmediateVal(n * 8)) - } - def accessVar(ident: Ident): () => IndexAddress = () => { - IndexAddress(RSP, (stack.size - stack(ident)) * 8) - } - def head: MemLocation = MemLocation(RSP) - def head(size: SizeDir): MemLocation = MemLocation(RSP, size) - def contains(ident: Ident): Boolean = stack.contains(ident) - // TODO: Might want to actually properly handle this with the LinkedHashMap too - def align(): AsmLine = And(RSP, ImmediateVal(-16)) + private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) + private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { + case Q64 | D32 => Chain.empty + case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1))) } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 1ff8906..fbf51f5 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -6,40 +6,73 @@ object assemblyIR { sealed trait Operand sealed trait Src extends Operand // mem location, register and imm value sealed trait Dest extends Operand // mem location and register - enum RegSize { - case R64 - case E32 - case Byte - override def toString = this match { - case R64 => "r" - case E32 => "e" - case Byte => "" + enum Size { + case Q64, D32, W16, B8 + + def toInt: Int = this match { + case Q64 => 8 + case D32 => 4 + case W16 => 2 + case B8 => 1 + } + + private val ptr = "ptr " + + override def toString(): String = this match { + case Q64 => "qword " + ptr + case D32 => "dword " + ptr + case W16 => "word " + ptr + case B8 => "byte " + ptr } } enum RegName { - case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, - Reg15 - override def toString = this match { - case AX => "ax" - case AL => "al" - case BX => "bx" - case CX => "cx" - case DX => "dx" - case SI => "si" - case DI => "di" - case SP => "sp" - case BP => "bp" - case IP => "ip" - case Reg8 => "8" - case Reg9 => "9" - case Reg10 => "10" - case Reg11 => "11" - case Reg12 => "12" - case Reg13 => "13" - case Reg14 => "14" - case Reg15 => "15" + case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15 + } + + case class Register(size: Size, name: RegName) extends Dest with Src { + import RegName._ + + if (size == Size.B8 && name == RegName.IP) { + throw new IllegalArgumentException("Cannot have 8 bit register for IP") + } + override def toString = name match { + case AX => tradToString("ax", "al") + case BX => tradToString("bx", "bl") + case CX => tradToString("cx", "cl") + case DX => tradToString("dx", "dl") + case SI => tradToString("si", "sil") + case DI => tradToString("di", "dil") + case SP => tradToString("sp", "spl") + case BP => tradToString("bp", "bpl") + case IP => tradToString("ip", "#INVALID") + case R8 => newToString(8) + case R9 => newToString(9) + case R10 => newToString(10) + case R11 => newToString(11) + case R12 => newToString(12) + case R13 => newToString(13) + case R14 => newToString(14) + case R15 => newToString(15) + } + + private def tradToString(base: String, byteName: String): String = + size match { + case Size.Q64 => "r" + base + case Size.D32 => "e" + base + case Size.W16 => base + case Size.B8 => byteName + } + + private def newToString(base: Int): String = { + val b = base.toString + "r" + (size match { + case Size.Q64 => b + case Size.D32 => b + "d" + case Size.W16 => b + "w" + case Size.B8 => b + "b" + }) } } @@ -48,7 +81,9 @@ object assemblyIR { case Scanf, Fflush, Exit, - PrintF + PrintF, + Malloc, + Free private val plt = "@plt" @@ -57,28 +92,29 @@ object assemblyIR { case Fflush => "fflush" + plt case Exit => "exit" + plt case PrintF => "printf" + plt + case Malloc => "malloc" + plt + case Free => "free" + plt } } - // TODO register naming conventions are wrong - case class Register(size: RegSize, name: RegName) extends Dest with Src { - override def toString = s"${size}${name}" - } - case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) - extends Dest - with Src { - override def toString = pointer match { - case hex: Long => opSize.toString + f"[0x$hex%X]" - case reg: Register => opSize.toString + s"[$reg]" - } + case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src { + override def toString = + opSize.toString + s"[$pointer]" } + + // TODO to string is wacky case class IndexAddress( base: Register, offset: Int | LabelArg, - opSize: SizeDir = SizeDir.Unspecified + indexReg: Register = Register(Size.Q64, RegName.AX), + scale: Int = 0 ) extends Dest with Src { - override def toString = s"$opSize[$base + $offset]" + override def toString = if (scale != 0) { + s"[$base + $indexReg * $scale + $offset]" + } else { + s"[$base + $offset]" + } } case class ImmediateVal(value: Int) extends Src { @@ -177,17 +213,4 @@ object assemblyIR { case String => "%s" } } - - enum SizeDir { - case Byte, Word, DWord, Unspecified - - private val ptr = "ptr " - - override def toString(): String = this match { - case Byte => "byte " + ptr - case Word => "word " + ptr // TODO check word/doubleword/quadword - case DWord => "dword " + ptr - case Unspecified => "" - } - } } diff --git a/src/main/wacc/backend/sizeExtensions.scala b/src/main/wacc/backend/sizeExtensions.scala new file mode 100644 index 0000000..798e290 --- /dev/null +++ b/src/main/wacc/backend/sizeExtensions.scala @@ -0,0 +1,35 @@ +package wacc + +object sizeExtensions { + import microWacc._ + import types._ + import assemblyIR.Size + + extension (expr: Expr) { + + /** Calculate the size (bytes) of the heap required for the expression. */ + def heapSize: Int = (expr, expr.ty) match { + case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => + KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt + case (ArrayLiter(elems), ty) => + KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt + case _ => expr.ty.size.toInt + } + } + + extension (ty: SemType) { + + /** Calculate the size (bytes) of a type in a register. */ + def size: Size = ty match { + case KnownType.Int => Size.D32 + case KnownType.Bool | KnownType.Char => Size.B8 + case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 + } + + def elemSize: Size = ty match { + case KnownType.Array(elem) => elem.size + case KnownType.Pair(_, _) => Size.Q64 + case _ => ty.size + } + } +} diff --git a/src/main/wacc/backend/writer.scala b/src/main/wacc/backend/writer.scala index b798af3..3c8dcfd 100644 --- a/src/main/wacc/backend/writer.scala +++ b/src/main/wacc/backend/writer.scala @@ -1,11 +1,12 @@ package wacc import java.io.PrintStream +import cats.data.Chain object writer { import assemblyIR._ - def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = { - asmList.foreach(printStream.println) + def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { + asmList.iterator.foreach(printStream.println) } } diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index 099fcc3..e2c1bdc 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -74,6 +74,7 @@ object microWacc { object Exit extends Builtin("exit")(?) object Free extends Builtin("free")(?) object Malloc extends Builtin("malloc")(?) + object PrintCharArray extends Builtin("printCharArray")(?) } case class Assign(lhs: LValue, rhs: Expr) extends Stmt diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index f571e11..002876d 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -218,14 +218,15 @@ object typeChecker { val exprTyped = checkValue(expr, Constraint.Unconstrained) val exprFormat = exprTyped.ty match { case KnownType.Bool | KnownType.String => "%s" + case KnownType.Array(KnownType.Char) => "%.*s" case KnownType.Char => "%c" case KnownType.Int => "%d" case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" } - val printfCall = { (value: microWacc.Expr) => + val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) => List( microWacc.Call( - microWacc.Builtin.Printf, + func, List( s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, value @@ -238,11 +239,13 @@ object typeChecker { List( microWacc.If( exprTyped, - printfCall("true".toMicroWaccCharArray), - printfCall("false".toMicroWaccCharArray) + printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), + printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) ) ) - case _ => printfCall(exprTyped) + case KnownType.Array(KnownType.Char) => + printfCall(microWacc.Builtin.PrintCharArray, exprTyped) + case _ => printfCall(microWacc.Builtin.Printf, exprTyped) } case ast.If(cond, thenStmt, elseStmt) => List( diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 8ac0aa4..7c895ab 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -39,7 +39,26 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // Retrieve contents to get input and expected output + exit code val contents = scala.io.Source.fromFile(File(filename)).getLines.toList val inputLine = - contents.find(_.matches("^# ?[Ii]nput:.*$")).map(_.split(":").last.strip).getOrElse("") + contents + .find(_.matches("^# ?[Ii]nput:.*$")) + .map(line => + ("" :: line.split(":").last.strip.split(" ").toList) + .sliding(2) + .flatMap { arr => + if ( + // First entry has no space in front + arr(0) == "" || + // int followed by non-digit, space can be removed + arr(0).toIntOption.nonEmpty && !arr(1)(0).isDigit || + // non-int followed by int, space can be removed + !arr(0).last.isDigit && arr(1).toIntOption.nonEmpty + ) + then List(arr(1)) + else List(" ", arr(1)) + } + .mkString + ) + .getOrElse("") val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) val expectedOutput = if (outputLineIdx == -1) "" @@ -73,7 +92,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { ) assert(process.exitValue == expectedExit) - assert(stdout.toString == expectedOutput) + assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput) } } @@ -86,21 +105,20 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // format: off // disable formatting to avoid binPack "^.*wacc-examples/valid/advanced.*$", - "^.*wacc-examples/valid/array.*$", + // "^.*wacc-examples/valid/array.*$", // "^.*wacc-examples/valid/basic/exit.*$", // "^.*wacc-examples/valid/basic/skip.*$", // "^.*wacc-examples/valid/expressions.*$", - "^.*wacc-examples/valid/function/nested_functions.*$", - "^.*wacc-examples/valid/function/simple_functions.*$", + // "^.*wacc-examples/valid/function/nested_functions.*$", + // "^.*wacc-examples/valid/function/simple_functions.*$", // "^.*wacc-examples/valid/if.*$", - "^.*wacc-examples/valid/IO/print.*$", + // "^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/read.*$", - "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", + // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - "^.*wacc-examples/valid/pairs.*$", + // "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", // "^.*wacc-examples/valid/scope.*$", - "^.*wacc-examples/valid/scope/printAllTypes.wacc$", // while we still don't have arrays implemented // "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/variables.*$", // "^.*wacc-examples/valid/while.*$", diff --git a/src/test/wacc/instructionSpec.scala b/src/test/wacc/instructionSpec.scala index b7452a0..feef0d4 100644 --- a/src/test/wacc/instructionSpec.scala +++ b/src/test/wacc/instructionSpec.scala @@ -1,42 +1,38 @@ import org.scalatest.funsuite.AnyFunSuite import wacc.assemblyIR._ +import wacc.assemblyIR.Size._ +import wacc.assemblyIR.RegName._ class instructionSpec extends AnyFunSuite { - val named64BitRegister = Register(RegSize.R64, RegName.AX) + val named64BitRegister = Register(Q64, AX) test("named 64-bit register toString") { assert(named64BitRegister.toString == "rax") } - val named32BitRegister = Register(RegSize.E32, RegName.AX) + val named32BitRegister = Register(D32, AX) test("named 32-bit register toString") { assert(named32BitRegister.toString == "eax") } - val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8) + val scratch64BitRegister = Register(Q64, R8) test("scratch 64-bit register toString") { assert(scratch64BitRegister.toString == "r8") } - val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8) + val scratch32BitRegister = Register(D32, R8) test("scratch 32-bit register toString") { - assert(scratch32BitRegister.toString == "e8") + assert(scratch32BitRegister.toString == "r8d") } - val memLocationWithHex = MemLocation(0x12345678) - - test("mem location with hex toString") { - assert(memLocationWithHex.toString == "[0x12345678]") - } - - val memLocationWithRegister = MemLocation(named64BitRegister) + val memLocationWithRegister = MemLocation(named64BitRegister, Q64) test("mem location with register toString") { - assert(memLocationWithRegister.toString == "[rax]") + assert(memLocationWithRegister.toString == "qword ptr [rax]") } val immediateVal = ImmediateVal(123)