From ae52fa653ce7aaa4a684a3c48487ecafcdd7d4c3 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Sat, 1 Mar 2025 02:07:45 +0000 Subject: [PATCH] fix: add opSize back in to stack --- src/main/wacc/backend/Stack.scala | 4 +- src/test/wacc/backend/extensionsSpec.scala | 66 +++++++++ .../wacc/{ => backend}/instructionSpec.scala | 0 .../wacc/backend/labelGeneratorSpec.scala | 85 +++++++++++ src/test/wacc/backend/stackSpec.scala | 140 ++++++++++++++++++ 5 files changed, 293 insertions(+), 2 deletions(-) create mode 100644 src/test/wacc/backend/extensionsSpec.scala rename src/test/wacc/{ => backend}/instructionSpec.scala (100%) create mode 100644 src/test/wacc/backend/labelGeneratorSpec.scala create mode 100644 src/test/wacc/backend/stackSpec.scala diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala index 06b52c7..7633b59 100644 --- a/src/main/wacc/backend/Stack.scala +++ b/src/main/wacc/backend/Stack.scala @@ -23,7 +23,7 @@ class Stack { /** Push an expression onto the stack. */ def push(expr: mw.Expr, src: Register): AsmLine = { - stack += expr -> StackValue(src.size, sizeBytes) + stack += expr -> StackValue(expr.ty.size, sizeBytes) Push(src) } @@ -81,7 +81,7 @@ class Stack { /** Get an MemLocation for a variable in the stack. */ def accessVar(ident: mw.Ident): MemLocation = - MemLocation(RSP, sizeBytes - stack(ident).bottom) + MemLocation(RSP, sizeBytes - stack(ident).bottom, opSize = Some(stack(ident).size)) def contains(ident: mw.Ident): Boolean = stack.contains(ident) def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size)) diff --git a/src/test/wacc/backend/extensionsSpec.scala b/src/test/wacc/backend/extensionsSpec.scala new file mode 100644 index 0000000..b0499b6 --- /dev/null +++ b/src/test/wacc/backend/extensionsSpec.scala @@ -0,0 +1,66 @@ +package wacc + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.Inspectors.forEvery +import cats.data.Chain + +class ExtensionsSpec extends AnyFlatSpec { + import asmGenerator.concatAll + import asmGenerator.escaped + + behavior of "concatAll" + + it should "handle int chains" in { + val chain = Chain(1, 2, 3).concatAll( + Chain(4, 5, 6), + Chain.empty, + Chain.one(-1) + ) + assert(chain == Chain(1, 2, 3, 4, 5, 6, -1)) + } + + it should "handle AsmLine chains" in { + object lines { + import assemblyIR._ + import assemblyIR.commonRegisters._ + val main = LabelDef("main") + val pop = Pop(RAX) + val add = Add(RAX, ImmediateVal(1)) + val push = Push(RAX) + val ret = Return() + } + val chain = Chain(lines.main).concatAll( + Chain.empty, + Chain.one(lines.pop), + Chain(lines.add, lines.push), + Chain.one(lines.ret) + ) + assert(chain == Chain(lines.main, lines.pop, lines.add, lines.push, lines.ret)) + } + + behavior of "escaped" + + val escapedStrings = Map( + "hello" -> "hello", + "world" -> "world", + "hello\nworld" -> "hello\\nworld", + "hello\tworld" -> "hello\\tworld", + "hello\\world" -> "hello\\\\world", + "hello\"world" -> "hello\\\"world", + "hello'world" -> "hello\\'world", + "hello\\nworld" -> "hello\\\\nworld", + "hello\\tworld" -> "hello\\\\tworld", + "hello\\\\world" -> "hello\\\\\\\\world", + "hello\\\"world" -> "hello\\\\\\\"world", + "hello\\'world" -> "hello\\\\\\'world", + "hello\\n\\t\\'world" -> "hello\\\\n\\\\t\\\\\\'world", + "hello\u0000world" -> "hello\\0world", + "hello\bworld" -> "hello\\bworld", + "hello\fworld" -> "hello\\fworld" + ) + forEvery(escapedStrings) { (input, expected) => + it should s"return $expected" in { + assert(input.escaped == expected) + } + } +} diff --git a/src/test/wacc/instructionSpec.scala b/src/test/wacc/backend/instructionSpec.scala similarity index 100% rename from src/test/wacc/instructionSpec.scala rename to src/test/wacc/backend/instructionSpec.scala diff --git a/src/test/wacc/backend/labelGeneratorSpec.scala b/src/test/wacc/backend/labelGeneratorSpec.scala new file mode 100644 index 0000000..fc44c12 --- /dev/null +++ b/src/test/wacc/backend/labelGeneratorSpec.scala @@ -0,0 +1,85 @@ +package wacc + +import org.scalatest.flatspec.AnyFlatSpec + +class LabelGeneratorSpec extends AnyFlatSpec { + import microWacc._ + import assemblyIR.{LabelDef, LabelArg, Directive} + import types.? + + "getLabel" should "return unique labels" in { + val labelGenerator = new LabelGenerator + val labels = (1 to 10).map(_ => labelGenerator.getLabel()) + assert(labels.distinct.length == labels.length) + } + + "getLabelDef" should "return unique labels" in { + assert( + LabelDef("test") == LabelDef("test") && + LabelDef("test").hashCode == LabelDef("test").hashCode, + "Sanity check: LabelDef should be case-classes" + ) + + val labelGenerator = new LabelGenerator + val labels = (List( + Builtin.Exit, + Builtin.Printf, + Ident("exit", 0)(?), + Ident("test", 0)(?) + ) ++ RuntimeError.all.toList).map(labelGenerator.getLabelDef(_)) + assert(labels.distinct.length == labels.length) + } + + "getLabelArg" should "return unique labels" in { + assert( + LabelArg("test") == LabelArg("test") && + LabelArg("test").hashCode == LabelArg("test").hashCode, + "Sanity check: LabelArg should be case-classes" + ) + + val labelGenerator = new LabelGenerator + val labels = (List( + Builtin.Exit, + Builtin.Printf, + Ident("exit", 0)(?), + Ident("test", 0)(?), + "test", + "test", + "test3" + ) ++ RuntimeError.all.toList).map { + case s: String => labelGenerator.getLabelArg(s) + case t: (CallTarget | RuntimeError) => labelGenerator.getLabelArg(t) + } + assert(labels.distinct.length == labels.distinct.length) + } + + it should "return consistent labels to getLabelDef" in { + val labelGenerator = new LabelGenerator + val targets = (List( + Builtin.Exit, + Builtin.Printf, + Ident("exit", 0)(?), + Ident("test", 0)(?) + ) ++ RuntimeError.all.toList) + val labelDefs = targets.map(labelGenerator.getLabelDef(_).toString.dropRight(1)).toSet + val labelArgs = targets.map(labelGenerator.getLabelArg(_).toString).toSet + assert(labelDefs == labelArgs) + } + + "generateConstants" should "generate de-duplicated labels for strings" in { + val labelGenerator = new LabelGenerator + val strings = List("hello", "world", "hello\u0000world", "hello", "Hello") + val distincts = strings.distinct.length + val labels = strings.map(labelGenerator.getLabelArg(_).toString).toSet + val asmLines = labelGenerator.generateConstants + assert( + asmLines.collect { case LabelDef(name) => + name + }.length == distincts + ) + assert( + asmLines.collect { case Directive.Asciz(str) => str }.length == distincts + ) + assert(asmLines.collect { case LabelDef(name) => name }.toList.toSet == labels) + } +} diff --git a/src/test/wacc/backend/stackSpec.scala b/src/test/wacc/backend/stackSpec.scala new file mode 100644 index 0000000..7dc290a --- /dev/null +++ b/src/test/wacc/backend/stackSpec.scala @@ -0,0 +1,140 @@ +package wacc + +import org.scalatest.flatspec.AnyFlatSpec +import cats.data.Chain + +class StackSpec extends AnyFlatSpec { + import microWacc._ + import assemblyIR._ + import assemblyIR.Size._ + import assemblyIR.commonRegisters._ + import types.{KnownType, ?} + import sizeExtensions.size + + private val RSP = Register(Q64, RegName.SP) + + "size" should "be 0 initially" in { + val stack = new Stack + assert(stack.size == 0) + } + + "push" should "add an expression to the stack" in { + val stack = new Stack + val expr = Ident("x", 0)(?) + val result = stack.push(expr, RAX) + assert(stack.size == 1) + assert(result == Push(RAX)) + } + + it should "add 2 expressions to the stack" in { + val stack = new Stack + val expr1 = Ident("x", 0)(?) + val expr2 = Ident("x", 1)(?) + val result1 = stack.push(expr1, RAX) + val result2 = stack.push(expr2, RCX) + assert(stack.size == 2) + assert(result1 == Push(RAX)) + assert(result2 == Push(RCX)) + } + + it should "add a value to the stack" in { + val stack = new Stack + val result = stack.push(D32, RAX) + assert(stack.size == 1) + assert(result == Push(RAX)) + } + + "reserve" should "reserve space for an identifier" in { + val stack = new Stack + val ident = Ident("x", 0)(KnownType.Int) + val result = stack.reserve(ident) + assert(stack.size == 1) + assert(result == Subtract(RSP, ImmediateVal(Q64.toInt))) + } + + it should "reserve space for a register" in { + val stack = new Stack + val result = stack.reserve(RAX) + assert(stack.size == 1) + assert(result == Subtract(RSP, ImmediateVal(Q64.toInt))) + } + + it should "reserve space for multiple values" in { + val stack = new Stack + val result = stack.reserve(D32, Q64, B8) + assert(stack.size == 3) + assert(result == Subtract(RSP, ImmediateVal(Q64.toInt * 3))) + } + + "pop" should "remove the last value from the stack" in { + val stack = new Stack + stack.push(D32, RAX) + val result = stack.pop(RAX) + assert(stack.size == 0) + assert(result == Pop(RAX)) + } + + "drop" should "remove the last 2 value from the stack" in { + val stack = new Stack + stack.push(D32, RAX) + stack.push(Q64, RAX) + stack.push(B8, RAX) + val result = stack.drop(2) + assert(stack.size == 1) + assert(result == Add(RSP, ImmediateVal(Q64.toInt * 2))) + } + + "withScope" should "reset stack after block" in { + val stack = new Stack + stack.push(D32, RAX) + stack.push(Q64, RCX) + stack.push(B8, RDX) + val result = stack.withScope(() => + Chain( + stack.push(Q64, RSI), + stack.push(B8, RDI), + stack.push(B8, RBP) + ) + ) + assert(stack.size == 3) + assert( + result == Chain( + Push(RSI), + Push(RDI), + Push(RBP), + Add(RSP, ImmediateVal(Q64.toInt * 3)) + ) + ) + } + + "accessVar" should "return the correctly-sized memory location for the identifier" in { + val stack = new Stack + val id = Ident("x", 0)(KnownType.Int) + stack.push(Q64, RAX) + stack.push(id, RCX) + stack.push(B8, RDX) + stack.push(D32, RSI) + val result = stack.accessVar(Ident("x", 0)(KnownType.Int)) + assert(result == MemLocation(RSP, Q64.toInt * 2, opSize = Some(KnownType.Int.size))) + } + + "contains" should "return true if the stack contains the identifier" in { + val stack = new Stack + val id = Ident("x", 0)(KnownType.Int) + stack.push(D32, RAX) + stack.push(id, RCX) + stack.push(B8, RDX) + assert(stack.contains(id)) + assert(!stack.contains(Ident("x", 1)(KnownType.Int))) + } + + "head" should "return the correct memory location for the last element" in { + val stack = new Stack + val id = Ident("x", 0)(KnownType.Int) + stack.push(D32, RAX) + stack.push(id, RCX) + stack.push(B8, RDX) + val result = stack.head + assert(result == MemLocation(RSP, opSize = Some(B8))) + } +}