fix: add opSize back in to stack

This commit is contained in:
2025-03-01 02:07:45 +00:00
parent 3b723392a7
commit ae52fa653c
5 changed files with 293 additions and 2 deletions

View File

@@ -23,7 +23,7 @@ class Stack {
/** Push an expression onto the stack. */ /** Push an expression onto the stack. */
def push(expr: mw.Expr, src: Register): AsmLine = { def push(expr: mw.Expr, src: Register): AsmLine = {
stack += expr -> StackValue(src.size, sizeBytes) stack += expr -> StackValue(expr.ty.size, sizeBytes)
Push(src) Push(src)
} }
@@ -81,7 +81,7 @@ class Stack {
/** Get an MemLocation for a variable in the stack. */ /** Get an MemLocation for a variable in the stack. */
def accessVar(ident: mw.Ident): MemLocation = def accessVar(ident: mw.Ident): MemLocation =
MemLocation(RSP, sizeBytes - stack(ident).bottom) MemLocation(RSP, sizeBytes - stack(ident).bottom, opSize = Some(stack(ident).size))
def contains(ident: mw.Ident): Boolean = stack.contains(ident) def contains(ident: mw.Ident): Boolean = stack.contains(ident)
def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size)) def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size))

View File

@@ -0,0 +1,66 @@
package wacc
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.Inspectors.forEvery
import cats.data.Chain
class ExtensionsSpec extends AnyFlatSpec {
import asmGenerator.concatAll
import asmGenerator.escaped
behavior of "concatAll"
it should "handle int chains" in {
val chain = Chain(1, 2, 3).concatAll(
Chain(4, 5, 6),
Chain.empty,
Chain.one(-1)
)
assert(chain == Chain(1, 2, 3, 4, 5, 6, -1))
}
it should "handle AsmLine chains" in {
object lines {
import assemblyIR._
import assemblyIR.commonRegisters._
val main = LabelDef("main")
val pop = Pop(RAX)
val add = Add(RAX, ImmediateVal(1))
val push = Push(RAX)
val ret = Return()
}
val chain = Chain(lines.main).concatAll(
Chain.empty,
Chain.one(lines.pop),
Chain(lines.add, lines.push),
Chain.one(lines.ret)
)
assert(chain == Chain(lines.main, lines.pop, lines.add, lines.push, lines.ret))
}
behavior of "escaped"
val escapedStrings = Map(
"hello" -> "hello",
"world" -> "world",
"hello\nworld" -> "hello\\nworld",
"hello\tworld" -> "hello\\tworld",
"hello\\world" -> "hello\\\\world",
"hello\"world" -> "hello\\\"world",
"hello'world" -> "hello\\'world",
"hello\\nworld" -> "hello\\\\nworld",
"hello\\tworld" -> "hello\\\\tworld",
"hello\\\\world" -> "hello\\\\\\\\world",
"hello\\\"world" -> "hello\\\\\\\"world",
"hello\\'world" -> "hello\\\\\\'world",
"hello\\n\\t\\'world" -> "hello\\\\n\\\\t\\\\\\'world",
"hello\u0000world" -> "hello\\0world",
"hello\bworld" -> "hello\\bworld",
"hello\fworld" -> "hello\\fworld"
)
forEvery(escapedStrings) { (input, expected) =>
it should s"return $expected" in {
assert(input.escaped == expected)
}
}
}

View File

@@ -0,0 +1,85 @@
package wacc
import org.scalatest.flatspec.AnyFlatSpec
class LabelGeneratorSpec extends AnyFlatSpec {
import microWacc._
import assemblyIR.{LabelDef, LabelArg, Directive}
import types.?
"getLabel" should "return unique labels" in {
val labelGenerator = new LabelGenerator
val labels = (1 to 10).map(_ => labelGenerator.getLabel())
assert(labels.distinct.length == labels.length)
}
"getLabelDef" should "return unique labels" in {
assert(
LabelDef("test") == LabelDef("test") &&
LabelDef("test").hashCode == LabelDef("test").hashCode,
"Sanity check: LabelDef should be case-classes"
)
val labelGenerator = new LabelGenerator
val labels = (List(
Builtin.Exit,
Builtin.Printf,
Ident("exit", 0)(?),
Ident("test", 0)(?)
) ++ RuntimeError.all.toList).map(labelGenerator.getLabelDef(_))
assert(labels.distinct.length == labels.length)
}
"getLabelArg" should "return unique labels" in {
assert(
LabelArg("test") == LabelArg("test") &&
LabelArg("test").hashCode == LabelArg("test").hashCode,
"Sanity check: LabelArg should be case-classes"
)
val labelGenerator = new LabelGenerator
val labels = (List(
Builtin.Exit,
Builtin.Printf,
Ident("exit", 0)(?),
Ident("test", 0)(?),
"test",
"test",
"test3"
) ++ RuntimeError.all.toList).map {
case s: String => labelGenerator.getLabelArg(s)
case t: (CallTarget | RuntimeError) => labelGenerator.getLabelArg(t)
}
assert(labels.distinct.length == labels.distinct.length)
}
it should "return consistent labels to getLabelDef" in {
val labelGenerator = new LabelGenerator
val targets = (List(
Builtin.Exit,
Builtin.Printf,
Ident("exit", 0)(?),
Ident("test", 0)(?)
) ++ RuntimeError.all.toList)
val labelDefs = targets.map(labelGenerator.getLabelDef(_).toString.dropRight(1)).toSet
val labelArgs = targets.map(labelGenerator.getLabelArg(_).toString).toSet
assert(labelDefs == labelArgs)
}
"generateConstants" should "generate de-duplicated labels for strings" in {
val labelGenerator = new LabelGenerator
val strings = List("hello", "world", "hello\u0000world", "hello", "Hello")
val distincts = strings.distinct.length
val labels = strings.map(labelGenerator.getLabelArg(_).toString).toSet
val asmLines = labelGenerator.generateConstants
assert(
asmLines.collect { case LabelDef(name) =>
name
}.length == distincts
)
assert(
asmLines.collect { case Directive.Asciz(str) => str }.length == distincts
)
assert(asmLines.collect { case LabelDef(name) => name }.toList.toSet == labels)
}
}

View File

@@ -0,0 +1,140 @@
package wacc
import org.scalatest.flatspec.AnyFlatSpec
import cats.data.Chain
class StackSpec extends AnyFlatSpec {
import microWacc._
import assemblyIR._
import assemblyIR.Size._
import assemblyIR.commonRegisters._
import types.{KnownType, ?}
import sizeExtensions.size
private val RSP = Register(Q64, RegName.SP)
"size" should "be 0 initially" in {
val stack = new Stack
assert(stack.size == 0)
}
"push" should "add an expression to the stack" in {
val stack = new Stack
val expr = Ident("x", 0)(?)
val result = stack.push(expr, RAX)
assert(stack.size == 1)
assert(result == Push(RAX))
}
it should "add 2 expressions to the stack" in {
val stack = new Stack
val expr1 = Ident("x", 0)(?)
val expr2 = Ident("x", 1)(?)
val result1 = stack.push(expr1, RAX)
val result2 = stack.push(expr2, RCX)
assert(stack.size == 2)
assert(result1 == Push(RAX))
assert(result2 == Push(RCX))
}
it should "add a value to the stack" in {
val stack = new Stack
val result = stack.push(D32, RAX)
assert(stack.size == 1)
assert(result == Push(RAX))
}
"reserve" should "reserve space for an identifier" in {
val stack = new Stack
val ident = Ident("x", 0)(KnownType.Int)
val result = stack.reserve(ident)
assert(stack.size == 1)
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
}
it should "reserve space for a register" in {
val stack = new Stack
val result = stack.reserve(RAX)
assert(stack.size == 1)
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
}
it should "reserve space for multiple values" in {
val stack = new Stack
val result = stack.reserve(D32, Q64, B8)
assert(stack.size == 3)
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt * 3)))
}
"pop" should "remove the last value from the stack" in {
val stack = new Stack
stack.push(D32, RAX)
val result = stack.pop(RAX)
assert(stack.size == 0)
assert(result == Pop(RAX))
}
"drop" should "remove the last 2 value from the stack" in {
val stack = new Stack
stack.push(D32, RAX)
stack.push(Q64, RAX)
stack.push(B8, RAX)
val result = stack.drop(2)
assert(stack.size == 1)
assert(result == Add(RSP, ImmediateVal(Q64.toInt * 2)))
}
"withScope" should "reset stack after block" in {
val stack = new Stack
stack.push(D32, RAX)
stack.push(Q64, RCX)
stack.push(B8, RDX)
val result = stack.withScope(() =>
Chain(
stack.push(Q64, RSI),
stack.push(B8, RDI),
stack.push(B8, RBP)
)
)
assert(stack.size == 3)
assert(
result == Chain(
Push(RSI),
Push(RDI),
Push(RBP),
Add(RSP, ImmediateVal(Q64.toInt * 3))
)
)
}
"accessVar" should "return the correctly-sized memory location for the identifier" in {
val stack = new Stack
val id = Ident("x", 0)(KnownType.Int)
stack.push(Q64, RAX)
stack.push(id, RCX)
stack.push(B8, RDX)
stack.push(D32, RSI)
val result = stack.accessVar(Ident("x", 0)(KnownType.Int))
assert(result == MemLocation(RSP, Q64.toInt * 2, opSize = Some(KnownType.Int.size)))
}
"contains" should "return true if the stack contains the identifier" in {
val stack = new Stack
val id = Ident("x", 0)(KnownType.Int)
stack.push(D32, RAX)
stack.push(id, RCX)
stack.push(B8, RDX)
assert(stack.contains(id))
assert(!stack.contains(Ident("x", 1)(KnownType.Int)))
}
"head" should "return the correct memory location for the last element" in {
val stack = new Stack
val id = Ident("x", 0)(KnownType.Int)
stack.push(D32, RAX)
stack.push(id, RCX)
stack.push(B8, RDX)
val result = stack.head
assert(result == MemLocation(RSP, opSize = Some(B8)))
}
}