diff --git a/.commitlintrc.js b/.commitlintrc.js new file mode 100644 index 0000000..a295b1d --- /dev/null +++ b/.commitlintrc.js @@ -0,0 +1,4 @@ +export default { + extends: ['@commitlint/config-conventional'], + ignores: [commit => commit.startsWith("Local Mutable Chains\n")] +} diff --git a/.commitlintrc.yml b/.commitlintrc.yml deleted file mode 100644 index 175ef04..0000000 --- a/.commitlintrc.yml +++ /dev/null @@ -1 +0,0 @@ -extends: "@commitlint/config-conventional" diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 98a6044..e61541c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -30,9 +30,10 @@ check_commits: before_script: - apk add git - npm install -g @commitlint/cli @commitlint/config-conventional - - git pull origin master + - git checkout origin/master script: - - npx commitlint --from origin/master --to HEAD --verbose + - git checkout ${CI_COMMIT_SHA} + - npx commitlint --from origin/master --to ${CI_COMMIT_SHA} --verbose compile_jvm: stage: compile @@ -48,10 +49,10 @@ test_jvm: image: gumjoe/wacc-ci-scala:x86 stage: test # Use our own runner (not cloud VM or shared) to ensure we have multiple cores. - tags: [ large ] + tags: [large] # This is expensive, so do use `dependencies` instead of `needs` to # ensure all previous stages pass. - dependencies: [ compile_jvm ] + dependencies: [compile_jvm] before_script: - git clone https://$EXAMPLES_AUTH@gitlab.doc.ic.ac.uk/lab2425_spring/wacc-examples.git script: diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index e8e7b7b..020cbcd 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,13 +1,13 @@ package wacc import scala.collection.mutable +import cats.data.Chain import parsley.{Failure, Success} import scopt.OParser import java.io.File import java.io.PrintStream import assemblyIR as asm -import wacc.microWacc.IntLiter case class CliConfig( file: File = new File(".") @@ -64,157 +64,8 @@ def frontend( } val s = "enter an integer to echo" -def backend(typedProg: microWacc.Program): List[asm.AsmLine] | String = - typedProg match { - case microWacc.Program( - Nil, - microWacc.Call(microWacc.Builtin.Exit, microWacc.IntLiter(v) :: Nil) :: Nil - ) => - s""".intel_syntax noprefix -.globl main -main: - mov edi, ${v} - call exit@plt -""" - case microWacc.Program( - Nil, - microWacc.Assign(microWacc.Ident("x", _), microWacc.IntLiter(1)) :: - microWacc.Call(microWacc.Builtin.Println, _) :: - microWacc.Assign( - microWacc.Ident("x", _), - microWacc.Call(microWacc.Builtin.ReadInt, Nil) - ) :: - microWacc.Call(microWacc.Builtin.Println, microWacc.Ident("x", _) :: Nil) :: Nil - ) => - """.intel_syntax noprefix -.globl main -.section .rodata -# length of .L.str0 - .int 24 -.L.str0: - .asciz "enter an integer to echo" -.text -main: - push rbp - # push {rbx, r12} - sub rsp, 16 - mov qword ptr [rsp], rbx - mov qword ptr [rsp + 8], r12 - mov rbp, rsp - mov r12d, 1 - lea rdi, [rip + .L.str0] - # statement primitives do not return results (but will clobber r0/rax) - call _prints - call _println - # load the current value in the destination of the read so it supports defaults - mov edi, r12d - call _readi - mov r12d, eax - mov edi, eax - # statement primitives do not return results (but will clobber r0/rax) - call _printi - call _println - mov rax, 0 - # pop/peek {rbx, r12} - mov rbx, qword ptr [rsp] - mov r12, qword ptr [rsp + 8] - add rsp, 16 - pop rbp - ret - -.section .rodata -# length of .L._printi_str0 - .int 2 -.L._printi_str0: - .asciz "%d" -.text -_printi: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - mov esi, edi - lea rdi, [rip + .L._printi_str0] - # on x86, al represents the number of SIMD registers used as variadic arguments - mov al, 0 - call printf@plt - mov rdi, 0 - call fflush@plt - mov rsp, rbp - pop rbp - ret - -.section .rodata -# length of .L._prints_str0 - .int 4 -.L._prints_str0: - .asciz "%.*s" -.text -_prints: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - mov rdx, rdi - mov esi, dword ptr [rdi - 4] - lea rdi, [rip + .L._prints_str0] - # on x86, al represents the number of SIMD registers used as variadic arguments - mov al, 0 - call printf@plt - mov rdi, 0 - call fflush@plt - mov rsp, rbp - pop rbp - ret - -.section .rodata -# length of .L._println_str0 - .int 0 -.L._println_str0: - .asciz "" -.text -_println: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - lea rdi, [rip + .L._println_str0] - call puts@plt - mov rdi, 0 - call fflush@plt - mov rsp, rbp - pop rbp - ret - -.section .rodata -# length of .L._readi_str0 - .int 2 -.L._readi_str0: - .asciz "%d" -.text -_readi: - push rbp - mov rbp, rsp - # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 - and rsp, -16 - # RDI contains the "original" value of the destination of the read - # allocate space on the stack to store the read: preserve alignment! - # the passed default argument should be stored in case of EOF - sub rsp, 16 - mov dword ptr [rsp], edi - lea rsi, qword ptr [rsp] - lea rdi, [rip + .L._readi_str0] - # on x86, al represents the number of SIMD registers used as variadic arguments - mov al, 0 - call scanf@plt - mov eax, dword ptr [rsp] - add rsp, 16 - mov rsp, rbp - pop rbp - ret - """ - case _ => List() - } +def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = + asmGenerator.generateAsm(typedProg) def compile(filename: String, outFile: Option[File] = None)(using stdout: PrintStream = Console.out @@ -222,13 +73,8 @@ def compile(filename: String, outFile: Option[File] = None)(using frontend(os.read(os.Path(filename))) match { case Left(typedProg) => val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s")) - backend(typedProg) match { - case s: String => - os.write.over(os.Path(asmFile.getAbsolutePath), s) - case ops: List[asm.AsmLine] => { - writer.writeTo(ops, PrintStream(asmFile)) - } - } + val asm = backend(typedProg) + writer.writeTo(asm, PrintStream(asmFile)) 0 case Right(exitCode) => exitCode } diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala new file mode 100644 index 0000000..94b329a --- /dev/null +++ b/src/main/wacc/backend/Stack.scala @@ -0,0 +1,90 @@ +package wacc + +import scala.collection.mutable.LinkedHashMap +import cats.data.Chain + +class Stack { + import assemblyIR._ + import assemblyIR.Size._ + import sizeExtensions.size + import microWacc as mw + + private val RSP = Register(Q64, RegName.SP) + private class StackValue(val size: Size, val offset: Int) { + def bottom: Int = offset + elemBytes + } + private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() + + private val elemBytes: Int = Q64.toInt + private def sizeBytes: Int = stack.size * elemBytes + + /** The stack's size in bytes. */ + def size: Int = stack.size + + /** Push an expression onto the stack. */ + def push(expr: mw.Expr, src: Register): AsmLine = { + stack += expr -> StackValue(src.size, sizeBytes) + Push(src) + } + + /** Push a value onto the stack. */ + def push(itemSize: Size, addr: Src): AsmLine = { + stack += stack.size -> StackValue(itemSize, sizeBytes) + Push(addr) + } + + /** Reserve space for a variable on the stack. */ + def reserve(ident: mw.Ident): AsmLine = { + stack += ident -> StackValue(ident.ty.size, sizeBytes) + Subtract(RSP, ImmediateVal(elemBytes)) + } + + /** Reserve space for a register on the stack. */ + def reserve(src: Register): AsmLine = { + stack += stack.size -> StackValue(src.size, sizeBytes) + Subtract(RSP, ImmediateVal(src.size.toInt)) + } + + /** Reserve space for values on the stack. + * + * @param sizes + * The sizes of the values to reserve space for. + */ + def reserve(sizes: Size*): AsmLine = { + sizes.foreach { itemSize => + stack += stack.size -> StackValue(itemSize, sizeBytes) + } + Subtract(RSP, ImmediateVal(elemBytes * sizes.size)) + } + + /** Pop a value from the stack into a register. Sizes MUST match. */ + def pop(dest: Register): AsmLine = { + stack.remove(stack.last._1) + Pop(dest) + } + + /** Drop the top n values from the stack. */ + def drop(n: Int = 1): AsmLine = { + (1 to n).foreach { _ => + stack.remove(stack.last._1) + } + Add(RSP, ImmediateVal(n * elemBytes)) + } + + /** Generate AsmLines within a scope, which is reset after the block. */ + def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = { + val resetToSize = stack.size + var lines = block() + lines :+= drop(stack.size - resetToSize) + lines + } + + /** Get an IndexAddress for a variable in the stack. */ + def accessVar(ident: mw.Ident): IndexAddress = + IndexAddress(RSP, sizeBytes - stack(ident).bottom) + + def contains(ident: mw.Ident): Boolean = stack.contains(ident) + def head: MemLocation = MemLocation(RSP, stack.last._2.size) + + override def toString(): String = stack.toString +} diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala new file mode 100644 index 0000000..60e0b47 --- /dev/null +++ b/src/main/wacc/backend/asmGenerator.scala @@ -0,0 +1,458 @@ +package wacc + +import scala.collection.mutable.ListBuffer +import cats.data.Chain +import cats.syntax.foldable._ + +object asmGenerator { + import microWacc._ + import assemblyIR._ + import assemblyIR.Size._ + import assemblyIR.RegName._ + import types._ + import sizeExtensions._ + import lexer.escapedChars + + abstract case class Error() { + def strLabel: String + def errStr: String + def errLabel: String + + def stringDef: Chain[AsmLine] = Chain( + Directive.Int(errStr.size), + LabelDef(strLabel), + Directive.Asciz(errStr) + ) + } + object zeroDivError extends Error { + // TODO: is this bad? Can we make an error case class/some other structure? + def strLabel = ".L._errDivZero_str0" + def errStr = "fatal error: division or modulo by zero" + def errLabel = ".L._errDivZero" + } + + private val RAX = Register(Q64, AX) + private val EAX = Register(D32, AX) + private val RDI = Register(Q64, DI) + private val RIP = Register(Q64, IP) + private val RBP = Register(Q64, BP) + private val RSI = Register(Q64, SI) + private val RDX = Register(Q64, DX) + private val RCX = Register(Q64, CX) + private val argRegs = List(DI, SI, DX, CX, R8, R9) + + extension [T](chain: Chain[T]) + def +(item: T): Chain[T] = chain.append(item) + + def concatAll(chains: Chain[T]*): Chain[T] = + chains.foldLeft(chain)(_ ++ _) + + class LabelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + def getLabel(target: CallTarget): String = target match { + case Ident(v, _) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + } + + def generateAsm(microProg: Program): Chain[AsmLine] = { + given stack: Stack = Stack() + given strings: ListBuffer[String] = ListBuffer[String]() + given labelGenerator: LabelGenerator = LabelGenerator() + val Program(funcs, main) = microProg + + val progAsm = Chain(LabelDef("main")).concatAll( + funcPrologue(), + main.foldMap(generateStmt(_)), + Chain.one(Xor(RAX, RAX)), + funcEpilogue(), + generateBuiltInFuncs(), + funcs.foldMap(generateUserFunc(_)) + ) + + val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => + Chain( + Directive.Int(str.size), + LabelDef(s".L.str$i"), + Directive.Asciz(str.escaped) + ) + } ++ zeroDivError.stringDef + + Chain( + Directive.IntelSyntax, + Directive.Global("main"), + Directive.RoData + ).concatAll( + strDirs, + Chain.one(Directive.Text), + progAsm + ) + } + + private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using + stack: Stack + ): Chain[AsmLine] = { + var chain = Chain.one[AsmLine](LabelDef(labelName)) + chain ++= funcPrologue() + chain ++= funcBody + chain ++= funcEpilogue() + chain + } + + private def generateUserFunc(func: FuncDecl)(using + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + given stack: Stack = Stack() + // Setup the stack with param 7 and up + func.params.drop(argRegs.size).foreach(stack.reserve(_)) + stack.reserve(Q64) // Reserve return pointer slot + var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) + chain ++= funcPrologue() + // Push the rest of params onto the stack for simplicity + argRegs.zip(func.params).foreach { (reg, param) => + chain += stack.push(param, Register(Q64, reg)) + } + chain ++= func.body.foldMap(generateStmt(_)) + // No need for epilogue here since all user functions must return explicitly + chain + } + + private def generateBuiltInFuncs()(using + stack: Stack, + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.Exit), + Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) + ) + + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.Printf), + Chain( + stackAlign, + assemblyIR.Call(CLibFunc.PrintF), + Xor(RDI, RDI), + assemblyIR.Call(CLibFunc.Fflush) + ) + ) + + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.PrintCharArray), + Chain( + stackAlign, + Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)), + Move(Register(D32, SI), MemLocation(RSI, D32)), + assemblyIR.Call(CLibFunc.PrintF), + Xor(RDI, RDI), + assemblyIR.Call(CLibFunc.Fflush) + ) + ) + + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.Malloc), + Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) + // Out of memory check is optional + ) + + chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) + + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.Read), + Chain( + stackAlign, + Subtract(Register(Q64, SP), ImmediateVal(8)), + Push(RSI), + Load(RSI, MemLocation(Register(Q64, SP), Q64)), + assemblyIR.Call(CLibFunc.Scanf), + Pop(RAX) + ) + ) + + chain ++= Chain( + // TODO can this be done with a call to generateStmt? + // Consider other error cases -> look to generalise + LabelDef(zeroDivError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(-1)), + assemblyIR.Call(CLibFunc.Exit) + ) + + chain + } + + private def generateStmt(stmt: Stmt)(using + stack: Stack, + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + chain += Comment(stmt.toString) + stmt match { + case Assign(lhs, rhs) => + lhs match { + case ident: Ident => + if (!stack.contains(ident)) chain += stack.reserve(ident) + chain ++= evalExprOntoStack(rhs) + chain += stack.pop(RAX) + chain += Move(stack.accessVar(ident), RAX) + case ArrayElem(x, i) => + chain ++= evalExprOntoStack(rhs) + chain ++= evalExprOntoStack(i) + chain ++= evalExprOntoStack(x) + chain += stack.pop(RAX) + chain += stack.pop(RCX) + chain += stack.pop(RDX) + + chain += Move( + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt), + Register(x.ty.elemSize, DX) + ) + } + + case If(cond, thenBranch, elseBranch) => + val elseLabel = labelGenerator.getLabel() + val endLabel = labelGenerator.getLabel() + + chain ++= evalExprOntoStack(cond) + chain += stack.pop(RAX) + chain += Compare(RAX, ImmediateVal(0)) + chain += Jump(LabelArg(elseLabel), Cond.Equal) + + chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) + chain += Jump(LabelArg(endLabel)) + chain += LabelDef(elseLabel) + + chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) + chain += LabelDef(endLabel) + + case While(cond, body) => + val startLabel = labelGenerator.getLabel() + val endLabel = labelGenerator.getLabel() + + chain += LabelDef(startLabel) + chain ++= evalExprOntoStack(cond) + chain += stack.pop(RAX) + chain += Compare(RAX, ImmediateVal(0)) + chain += Jump(LabelArg(endLabel), Cond.Equal) + + chain ++= stack.withScope(() => body.foldMap(generateStmt)) + chain += Jump(LabelArg(startLabel)) + chain += LabelDef(endLabel) + + case call: microWacc.Call => + chain ++= generateCall(call, isTail = false) + + case microWacc.Return(expr) => + expr match { + case call: microWacc.Call => + chain ++= generateCall(call, isTail = true) // tco + case _ => + chain ++= evalExprOntoStack(expr) + chain += stack.pop(RAX) + chain ++= funcEpilogue() + } + } + + chain + } + + private def evalExprOntoStack(expr: Expr)(using + stack: Stack, + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + val stackSizeStart = stack.size + expr match { + case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v)) + case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt)) + case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident)) + + case array @ ArrayLiter(elems) => + expr.ty match { + case KnownType.String => + strings += elems.collect { case CharLiter(v) => v }.mkString + chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + chain += stack.push(Q64, RAX) + case ty => + chain ++= generateCall( + microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), + isTail = false + ) + chain += stack.push(Q64, RAX) + // Store the length of the array at the start + chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size)) + elems.zipWithIndex.foldMap { (elem, i) => + chain ++= evalExprOntoStack(elem) + chain += stack.pop(RCX) + chain += stack.pop(RAX) + chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) + chain += stack.push(Q64, RAX) + } + } + + case BoolLiter(true) => + chain += stack.push(KnownType.Bool.size, ImmediateVal(1)) + case BoolLiter(false) => + chain += Xor(RAX, RAX) + chain += stack.push(KnownType.Bool.size, RAX) + case NullLiter() => + chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0)) + case ArrayElem(x, i) => + chain ++= evalExprOntoStack(x) + chain ++= evalExprOntoStack(i) + chain += stack.pop(RCX) + chain += stack.pop(RAX) + // + Int because we store the length of the array at the start + chain += Move( + Register(x.ty.elemSize, AX), + IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt) + ) + chain += stack.push(x.ty.elemSize, RAX) + case UnaryOp(x, op) => + chain ++= evalExprOntoStack(x) + op match { + case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed + case UnaryOperator.Len => + chain += stack.pop(RAX) + chain += Move(EAX, MemLocation(RAX, D32)) + chain += stack.push(D32, RAX) + case UnaryOperator.Negate => + chain += Negate(stack.head) + case UnaryOperator.Not => + chain += Xor(stack.head, ImmediateVal(1)) + } + + case BinaryOp(x, y, op) => + val destX = Register(x.ty.size, AX) + chain ++= evalExprOntoStack(y) + chain ++= evalExprOntoStack(x) + chain += stack.pop(RAX) + + op match { + case BinaryOperator.Add => + chain += Add(stack.head, destX) + case BinaryOperator.Sub => + chain += Subtract(destX, stack.head) + chain += stack.drop() + chain += stack.push(destX.size, RAX) + case BinaryOperator.Mul => + chain += Multiply(destX, stack.head) + chain += stack.drop() + chain += stack.push(destX.size, RAX) + + case BinaryOperator.Div => + chain += Compare(stack.head, ImmediateVal(0)) + chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) + chain += CDQ() + chain += Divide(stack.head) + chain += stack.drop() + chain += stack.push(destX.size, RAX) + + case BinaryOperator.Mod => + chain += CDQ() + chain += Divide(stack.head) + chain += stack.drop() + chain += stack.push(destX.size, RDX) + + case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal) + case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual) + case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater) + case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual) + case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less) + case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual) + case BinaryOperator.And => chain += And(stack.head, destX) + case BinaryOperator.Or => chain += Or(stack.head, destX) + } + + case call: microWacc.Call => + chain ++= generateCall(call, isTail = false) + chain += stack.push(call.ty.size, RAX) + } + + assert(stack.size == stackSizeStart + 1) + chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size) + chain + } + + private def generateCall(call: microWacc.Call, isTail: Boolean)(using + stack: Stack, + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + val microWacc.Call(target, args) = call + + argRegs.zip(args).foldMap { (reg, expr) => + chain ++= evalExprOntoStack(expr) + chain += stack.pop(Register(Q64, reg)) + } + + args.drop(argRegs.size).foldMap { + chain ++= evalExprOntoStack(_) + } + + // Tail Call Optimisation (TCO) + if (isTail) { + chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call + } else { + chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call + } + + if (args.size > argRegs.size) { + chain += stack.drop(args.size - argRegs.size) + } + + chain + } + + private def generateComparison(destX: Register, cond: Cond)(using + stack: Stack + ): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + + chain += Compare(destX, stack.head) + chain += Set(Register(B8, AX), cond) + chain ++= zeroRest(RAX, B8) + chain += stack.drop() + chain += stack.push(B8, RAX) + + chain + } + + private def funcPrologue()(using stack: Stack): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + chain += stack.push(Q64, RBP) + chain += Move(RBP, Register(Q64, SP)) + chain + } + + private def funcEpilogue(): Chain[AsmLine] = { + var chain = Chain.empty[AsmLine] + chain += Move(Register(Q64, SP), RBP) + chain += Pop(RBP) + chain += assemblyIR.Return() + chain + } + + private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) + private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { + case Q64 | D32 => Chain.empty + case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1))) + } + + private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } + extension (s: String) { + private def escaped: String = + s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString)) + } +} diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index 323742b..fbf51f5 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -6,47 +6,114 @@ object assemblyIR { sealed trait Operand sealed trait Src extends Operand // mem location, register and imm value sealed trait Dest extends Operand // mem location and register - enum RegSize { - case R64 - case E32 - override def toString = this match { - case R64 => "r" - case E32 => "e" + enum Size { + case Q64, D32, W16, B8 + + def toInt: Int = this match { + case Q64 => 8 + case D32 => 4 + case W16 => 2 + case B8 => 1 + } + + private val ptr = "ptr " + + override def toString(): String = this match { + case Q64 => "qword " + ptr + case D32 => "dword " + ptr + case W16 => "word " + ptr + case B8 => "byte " + ptr + } + } + + enum RegName { + case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15 + } + + case class Register(size: Size, name: RegName) extends Dest with Src { + import RegName._ + + if (size == Size.B8 && name == RegName.IP) { + throw new IllegalArgumentException("Cannot have 8 bit register for IP") + } + override def toString = name match { + case AX => tradToString("ax", "al") + case BX => tradToString("bx", "bl") + case CX => tradToString("cx", "cl") + case DX => tradToString("dx", "dl") + case SI => tradToString("si", "sil") + case DI => tradToString("di", "dil") + case SP => tradToString("sp", "spl") + case BP => tradToString("bp", "bpl") + case IP => tradToString("ip", "#INVALID") + case R8 => newToString(8) + case R9 => newToString(9) + case R10 => newToString(10) + case R11 => newToString(11) + case R12 => newToString(12) + case R13 => newToString(13) + case R14 => newToString(14) + case R15 => newToString(15) + } + + private def tradToString(base: String, byteName: String): String = + size match { + case Size.Q64 => "r" + base + case Size.D32 => "e" + base + case Size.W16 => base + case Size.B8 => byteName + } + + private def newToString(base: Int): String = { + val b = base.toString + "r" + (size match { + case Size.Q64 => b + case Size.D32 => b + "d" + case Size.W16 => b + "w" + case Size.B8 => b + "b" + }) } } // arguments enum CLibFunc extends Operand { case Scanf, - Puts, Fflush, Exit, - PrintF + PrintF, + Malloc, + Free private val plt = "@plt" override def toString = this match { case Scanf => "scanf" + plt - case Puts => "puts" + plt case Fflush => "fflush" + plt case Exit => "exit" + plt case PrintF => "printf" + plt + case Malloc => "malloc" + plt + case Free => "free" + plt } } - enum Register extends Dest with Src { - case Named(name: String, size: RegSize) - case Scratch(num: Int, size: RegSize) - override def toString = this match { - case Named(name, size) => s"${size}${name.toLowerCase()}" - case Scratch(num, size) => s"r${num}${if (size == RegSize.E32) "d" else ""}" - } + case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src { + override def toString = + opSize.toString + s"[$pointer]" } - case class MemLocation(pointer: Long | Register) extends Dest with Src { - override def toString = pointer match { - case hex: Long => f"[0x$hex%X]" - case reg: Register => s"[$reg]" + + // TODO to string is wacky + case class IndexAddress( + base: Register, + offset: Int | LabelArg, + indexReg: Register = Register(Size.Q64, RegName.AX), + scale: Int = 0 + ) extends Dest + with Src { + override def toString = if (scale != 0) { + s"[$base + $indexReg * $scale + $offset]" + } else { + s"[$base + $offset]" } } @@ -64,30 +131,42 @@ object assemblyIR { } case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2) case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2) - case class Multiply(ops: Operand*) extends Operation("mul", ops*) - case class Divide(op1: Src) extends Operation("div", op1) + case class Multiply(ops: Operand*) extends Operation("imul", ops*) + case class Divide(op1: Src) extends Operation("idiv", op1) + case class Negate(op: Dest) extends Operation("neg", op) case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2) case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2) + case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2) case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2) // stack operations case class Push(op1: Src) extends Operation("push", op1) case class Pop(op1: Src) extends Operation("pop", op1) - case class Call(op1: CLibFunc) extends Operation("call", op1) + case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1) case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2) - case class Load(op1: Register, op2: MemLocation) extends Operation("lea ", op1, op2) + case class Load(op1: Register, op2: MemLocation | IndexAddress) + extends Operation("lea ", op1, op2) + case class CDQ() extends Operation("cdq") case class Return() extends Operation("ret") case class Jump(op1: LabelArg, condition: Cond = Cond.Always) extends Operation(s"j${condition.toString}", op1) + case class Set(op1: Dest, condition: Cond = Cond.Always) + extends Operation(s"set${condition.toString}", op1) + case class LabelDef(name: String) extends AsmLine { override def toString = s"$name:" } + case class Comment(comment: String) extends AsmLine { + override def toString = + comment.split("\n").map(line => s"# ${line}").mkString("\n") + } + enum Cond { case Equal, NotEqual, @@ -108,4 +187,30 @@ object assemblyIR { case Always => "mp" } } + + enum Directive extends AsmLine { + case IntelSyntax, RoData, Text + case Global(name: String) + case Int(value: scala.Int) + case Asciz(string: String) + + override def toString(): String = this match { + case IntelSyntax => ".intel_syntax noprefix" + case Global(name) => s".globl $name" + case Text => ".text" + case RoData => ".section .rodata" + case Int(value) => s".int $value" + case Asciz(string) => s".asciz \"$string\"" + } + } + + enum PrintFormat { + case Int, Char, String + + override def toString(): String = this match { + case Int => "%d" + case Char => "%c" + case String => "%s" + } + } } diff --git a/src/main/wacc/backend/sizeExtensions.scala b/src/main/wacc/backend/sizeExtensions.scala new file mode 100644 index 0000000..798e290 --- /dev/null +++ b/src/main/wacc/backend/sizeExtensions.scala @@ -0,0 +1,35 @@ +package wacc + +object sizeExtensions { + import microWacc._ + import types._ + import assemblyIR.Size + + extension (expr: Expr) { + + /** Calculate the size (bytes) of the heap required for the expression. */ + def heapSize: Int = (expr, expr.ty) match { + case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => + KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt + case (ArrayLiter(elems), ty) => + KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt + case _ => expr.ty.size.toInt + } + } + + extension (ty: SemType) { + + /** Calculate the size (bytes) of a type in a register. */ + def size: Size = ty match { + case KnownType.Int => Size.D32 + case KnownType.Bool | KnownType.Char => Size.B8 + case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 + } + + def elemSize: Size = ty match { + case KnownType.Array(elem) => elem.size + case KnownType.Pair(_, _) => Size.Q64 + case _ => ty.size + } + } +} diff --git a/src/main/wacc/backend/writer.scala b/src/main/wacc/backend/writer.scala index b798af3..3c8dcfd 100644 --- a/src/main/wacc/backend/writer.scala +++ b/src/main/wacc/backend/writer.scala @@ -1,11 +1,12 @@ package wacc import java.io.PrintStream +import cats.data.Chain object writer { import assemblyIR._ - def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = { - asmList.foreach(printStream.println) + def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { + asmList.iterator.foreach(printStream.println) } } diff --git a/src/main/wacc/frontend/lexer.scala b/src/main/wacc/frontend/lexer.scala index 2efe517..4cb51a8 100644 --- a/src/main/wacc/frontend/lexer.scala +++ b/src/main/wacc/frontend/lexer.scala @@ -39,6 +39,17 @@ val errConfig = new ErrorConfig { ) } object lexer { + val escapedChars: Map[String, Int] = Map( + "0" -> '\u0000', + "b" -> '\b', + "t" -> '\t', + "n" -> '\n', + "f" -> '\f', + "r" -> '\r', + "\\" -> '\\', + "'" -> '\'', + "\"" -> '\"' + ) /** Language description for the WACC lexer */ @@ -63,15 +74,9 @@ object lexer { textDesc = TextDesc.plain.copy( graphicCharacter = Basic(c => c >= ' ' && c != '\\' && c != '\'' && c != '"'), escapeSequences = EscapeDesc.plain.copy( - literals = Set('\\', '"', '\''), - mapping = Map( - "0" -> '\u0000', - "b" -> '\b', - "t" -> '\t', - "n" -> '\n', - "f" -> '\f', - "r" -> '\r' - ) + literals = + escapedChars.filter { (s, chr) => chr.toChar.toString == s }.map(_._2.toChar).toSet, + mapping = escapedChars.filter { (s, chr) => chr.toChar.toString != s } ) ), numericDesc = NumericDesc.plain.copy( diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index b9f6635..e2c1bdc 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -1,7 +1,5 @@ package wacc -import cats.data.NonEmptyList - object microWacc { import wacc.types._ @@ -19,9 +17,7 @@ object microWacc { extends Expr(identTy) with CallTarget(identTy) with LValue - case class ArrayElem(value: LValue, indices: NonEmptyList[Expr])(ty: SemType) - extends Expr(ty) - with LValue + case class ArrayElem(value: LValue, index: Expr)(ty: SemType) extends Expr(ty) with LValue // Operators case class UnaryOp(x: Expr, op: UnaryOperator)(ty: SemType) extends Expr(ty) @@ -69,13 +65,16 @@ object microWacc { // Statements sealed trait Stmt + case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) { + override def toString(): String = name + } object Builtin { - case object ReadInt extends CallTarget(KnownType.Int) - case object ReadChar extends CallTarget(KnownType.Char) - case object Print extends CallTarget(?) - case object Println extends CallTarget(?) - case object Exit extends CallTarget(?) - case object Free extends CallTarget(?) + object Read extends Builtin("read")(?) + object Printf extends Builtin("printf")(?) + object Exit extends Builtin("exit")(?) + object Free extends Builtin("free")(?) + object Malloc extends Builtin("malloc")(?) + object PrintCharArray extends Builtin("printCharArray")(?) } case class Assign(lhs: LValue, rhs: Expr) extends Stmt diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index e3960cb..002876d 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -110,7 +110,7 @@ object typeChecker { microWacc.FuncDecl( microWacc.Ident(name.v, name.uid)(retType), params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => - microWacc.Ident(ident.v, name.uid)(ty) + microWacc.Ident(ident.v, ident.uid)(ty) }, stmts.toList .flatMap( @@ -177,12 +177,14 @@ object typeChecker { microWacc.Assign( destTyped, microWacc.Call( - destTy match { - case KnownType.Int => microWacc.Builtin.ReadInt - case KnownType.Char => microWacc.Builtin.ReadChar - case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway - }, - Nil + microWacc.Builtin.Read, + List( + destTy match { + case KnownType.Int => "%d".toMicroWaccCharArray + case KnownType.Char | _ => "%c".toMicroWaccCharArray + }, + destTyped + ) ) ) ) @@ -213,12 +215,38 @@ object typeChecker { ) case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already - List( - microWacc.Call( - if newline then microWacc.Builtin.Println else microWacc.Builtin.Print, - List(checkValue(expr, Constraint.Unconstrained)) + val exprTyped = checkValue(expr, Constraint.Unconstrained) + val exprFormat = exprTyped.ty match { + case KnownType.Bool | KnownType.String => "%s" + case KnownType.Array(KnownType.Char) => "%.*s" + case KnownType.Char => "%c" + case KnownType.Int => "%d" + case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" + } + val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) => + List( + microWacc.Call( + func, + List( + s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, + value + ) + ) ) - ) + } + exprTyped.ty match { + case KnownType.Bool => + List( + microWacc.If( + exprTyped, + printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), + printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) + ) + ) + case KnownType.Array(KnownType.Char) => + printfCall(microWacc.Builtin.PrintCharArray, exprTyped) + case _ => printfCall(microWacc.Builtin.Printf, exprTyped) + } case ast.If(cond, thenStmt, elseStmt) => List( microWacc.If( @@ -262,7 +290,7 @@ object typeChecker { microWacc.CharLiter(v) case l @ ast.StrLiter(v) => KnownType.String.satisfies(constraint, l.pos) - microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String) + v.toMicroWaccCharArray case l @ ast.PairLiter() => microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) case ast.Parens(expr) => checkValue(expr, constraint) @@ -410,10 +438,15 @@ object typeChecker { } (next, idxTyped) } - microWacc.ArrayElem( + val firstArrayElem = microWacc.ArrayElem( microWacc.Ident(id.v, id.uid)(arrayTy), - indicesTyped + indicesTyped.head )(elemTy.satisfies(constraint, value.pos)) + val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) => + microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty)) + } + // Need to type-check the final arrayElem with the constraint + microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos)) case ast.Fst(elem) => val elemTyped = checkLValue( elem, @@ -421,7 +454,7 @@ object typeChecker { ) microWacc.ArrayElem( elemTyped, - NonEmptyList.of(microWacc.IntLiter(0)) + microWacc.IntLiter(0) )(elemTyped.ty match { case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos) @@ -434,11 +467,16 @@ object typeChecker { ) microWacc.ArrayElem( elemTyped, - NonEmptyList.of(microWacc.IntLiter(1)) + microWacc.IntLiter(1) )(elemTyped.ty match { case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) }) } + + extension (s: String) { + def toMicroWaccCharArray: microWacc.ArrayLiter = + microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String) + } } diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index abff693..7c895ab 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -1,13 +1,14 @@ package wacc -import org.scalatest.{ParallelTestExecution, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.Inspectors.forEvery import java.io.File import sys.process._ import java.io.PrintStream +import scala.io.Source -class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with ParallelTestExecution { +class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { val files = allWaccFiles("wacc-examples/valid").map { p => (p.toString, List(0)) @@ -26,19 +27,38 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral forEvery(files) { (filename, expectedResult) => val baseFilename = filename.stripSuffix(".wacc") given stdout: PrintStream = PrintStream(File(baseFilename + ".out")) - val result = compile(filename) s"$filename" should "be compiled with correct result" in { + val result = compile(filename) assert(expectedResult.contains(result)) } - if (result == 0) it should "run with correct result" in { + if (expectedResult == List(0)) it should "run with correct result" in { if (fileIsDisallowedBackend(filename)) pending // Retrieve contents to get input and expected output + exit code val contents = scala.io.Source.fromFile(File(filename)).getLines.toList val inputLine = - contents.find(_.matches("^# ?[Ii]nput:.*$")).map(_.split(":").last.strip).getOrElse("") + contents + .find(_.matches("^# ?[Ii]nput:.*$")) + .map(line => + ("" :: line.split(":").last.strip.split(" ").toList) + .sliding(2) + .flatMap { arr => + if ( + // First entry has no space in front + arr(0) == "" || + // int followed by non-digit, space can be removed + arr(0).toIntOption.nonEmpty && !arr(1)(0).isDigit || + // non-int followed by int, space can be removed + !arr(0).last.isDigit && arr(1).toIntOption.nonEmpty + ) + then List(arr(1)) + else List(" ", arr(1)) + } + .mkString + ) + .getOrElse("") val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) val expectedOutput = if (outputLineIdx == -1) "" @@ -62,12 +82,17 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral // Run the executable with the provided input val stdout = new StringBuilder - // val execResult = s"$execFilename".!(ProcessLogger(stdout.append(_))) - val execResult = - s"echo $inputLine" #| s"timeout 5s $execFilename" ! ProcessLogger(stdout.append(_)) + val process = s"timeout 5s $execFilename" run ProcessIO( + in = w => { + w.write(inputLine.getBytes) + w.close() + }, + out = Source.fromInputStream(_).addString(stdout), + err = _ => () + ) - assert(execResult == expectedExit) - assert(stdout.toString == expectedOutput) + assert(process.exitValue == expectedExit) + assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput) } } @@ -80,22 +105,23 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral // format: off // disable formatting to avoid binPack "^.*wacc-examples/valid/advanced.*$", - "^.*wacc-examples/valid/array.*$", - "^.*wacc-examples/valid/basic/skip.*$", - "^.*wacc-examples/valid/expressions.*$", - "^.*wacc-examples/valid/function/nested_functions.*$", - "^.*wacc-examples/valid/function/simple_functions.*$", - "^.*wacc-examples/valid/if.*$", - "^.*wacc-examples/valid/IO/print.*$", - "^.*wacc-examples/valid/IO/read(?!echoInt\\.wacc).*$", - "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", - "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - "^.*wacc-examples/valid/pairs.*$", + // "^.*wacc-examples/valid/array.*$", + // "^.*wacc-examples/valid/basic/exit.*$", + // "^.*wacc-examples/valid/basic/skip.*$", + // "^.*wacc-examples/valid/expressions.*$", + // "^.*wacc-examples/valid/function/nested_functions.*$", + // "^.*wacc-examples/valid/function/simple_functions.*$", + // "^.*wacc-examples/valid/if.*$", + // "^.*wacc-examples/valid/IO/print.*$", + // "^.*wacc-examples/valid/IO/read.*$", + // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", + // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", + // "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/runtimeErr.*$", - "^.*wacc-examples/valid/scope.*$", - "^.*wacc-examples/valid/sequence.*$", - "^.*wacc-examples/valid/variables.*$", - "^.*wacc-examples/valid/while.*$", + // "^.*wacc-examples/valid/scope.*$", + // "^.*wacc-examples/valid/sequence.*$", + // "^.*wacc-examples/valid/variables.*$", + // "^.*wacc-examples/valid/while.*$", // format: on ).find(filename.matches).isDefined } diff --git a/src/test/wacc/instructionSpec.scala b/src/test/wacc/instructionSpec.scala index 6d427a5..feef0d4 100644 --- a/src/test/wacc/instructionSpec.scala +++ b/src/test/wacc/instructionSpec.scala @@ -1,42 +1,38 @@ import org.scalatest.funsuite.AnyFunSuite import wacc.assemblyIR._ +import wacc.assemblyIR.Size._ +import wacc.assemblyIR.RegName._ class instructionSpec extends AnyFunSuite { - val named64BitRegister = Register.Named("ax", RegSize.R64) + val named64BitRegister = Register(Q64, AX) test("named 64-bit register toString") { assert(named64BitRegister.toString == "rax") } - val named32BitRegister = Register.Named("ax", RegSize.E32) + val named32BitRegister = Register(D32, AX) test("named 32-bit register toString") { assert(named32BitRegister.toString == "eax") } - val scratch64BitRegister = Register.Scratch(1, RegSize.R64) + val scratch64BitRegister = Register(Q64, R8) test("scratch 64-bit register toString") { - assert(scratch64BitRegister.toString == "r1") + assert(scratch64BitRegister.toString == "r8") } - val scratch32BitRegister = Register.Scratch(1, RegSize.E32) + val scratch32BitRegister = Register(D32, R8) test("scratch 32-bit register toString") { - assert(scratch32BitRegister.toString == "r1d") + assert(scratch32BitRegister.toString == "r8d") } - val memLocationWithHex = MemLocation(0x12345678) - - test("mem location with hex toString") { - assert(memLocationWithHex.toString == "[0x12345678]") - } - - val memLocationWithRegister = MemLocation(named64BitRegister) + val memLocationWithRegister = MemLocation(named64BitRegister, Q64) test("mem location with register toString") { - assert(memLocationWithRegister.toString == "[rax]") + assert(memLocationWithRegister.toString == "qword ptr [rax]") } val immediateVal = ImmediateVal(123) @@ -54,7 +50,7 @@ class instructionSpec extends AnyFunSuite { val subInstruction = Subtract(scratch64BitRegister, named64BitRegister) test("x86: sub instruction toString") { - assert(subInstruction.toString == "\tsub r1, rax") + assert(subInstruction.toString == "\tsub r8, rax") } val callInstruction = Call(CLibFunc.Scanf)