fix: add opSize back in to stack #38
@@ -23,7 +23,7 @@ class Stack {
|
||||
|
||||
/** Push an expression onto the stack. */
|
||||
def push(expr: mw.Expr, src: Register): AsmLine = {
|
||||
stack += expr -> StackValue(src.size, sizeBytes)
|
||||
stack += expr -> StackValue(expr.ty.size, sizeBytes)
|
||||
Push(src)
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ class Stack {
|
||||
|
||||
/** Get an MemLocation for a variable in the stack. */
|
||||
def accessVar(ident: mw.Ident): MemLocation =
|
||||
MemLocation(RSP, sizeBytes - stack(ident).bottom)
|
||||
MemLocation(RSP, sizeBytes - stack(ident).bottom, opSize = Some(stack(ident).size))
|
||||
|
||||
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
|
||||
def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size))
|
||||
|
||||
66
src/test/wacc/backend/extensionsSpec.scala
Normal file
66
src/test/wacc/backend/extensionsSpec.scala
Normal file
@@ -0,0 +1,66 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import org.scalatest.Inspectors.forEvery
|
||||
import cats.data.Chain
|
||||
|
||||
class ExtensionsSpec extends AnyFlatSpec {
|
||||
import asmGenerator.concatAll
|
||||
import asmGenerator.escaped
|
||||
|
||||
behavior of "concatAll"
|
||||
|
||||
it should "handle int chains" in {
|
||||
val chain = Chain(1, 2, 3).concatAll(
|
||||
Chain(4, 5, 6),
|
||||
Chain.empty,
|
||||
Chain.one(-1)
|
||||
)
|
||||
assert(chain == Chain(1, 2, 3, 4, 5, 6, -1))
|
||||
}
|
||||
|
||||
it should "handle AsmLine chains" in {
|
||||
object lines {
|
||||
import assemblyIR._
|
||||
import assemblyIR.commonRegisters._
|
||||
val main = LabelDef("main")
|
||||
val pop = Pop(RAX)
|
||||
val add = Add(RAX, ImmediateVal(1))
|
||||
val push = Push(RAX)
|
||||
val ret = Return()
|
||||
}
|
||||
val chain = Chain(lines.main).concatAll(
|
||||
Chain.empty,
|
||||
Chain.one(lines.pop),
|
||||
Chain(lines.add, lines.push),
|
||||
Chain.one(lines.ret)
|
||||
)
|
||||
assert(chain == Chain(lines.main, lines.pop, lines.add, lines.push, lines.ret))
|
||||
}
|
||||
|
||||
behavior of "escaped"
|
||||
|
||||
val escapedStrings = Map(
|
||||
"hello" -> "hello",
|
||||
"world" -> "world",
|
||||
"hello\nworld" -> "hello\\nworld",
|
||||
"hello\tworld" -> "hello\\tworld",
|
||||
"hello\\world" -> "hello\\\\world",
|
||||
"hello\"world" -> "hello\\\"world",
|
||||
"hello'world" -> "hello\\'world",
|
||||
"hello\\nworld" -> "hello\\\\nworld",
|
||||
"hello\\tworld" -> "hello\\\\tworld",
|
||||
"hello\\\\world" -> "hello\\\\\\\\world",
|
||||
"hello\\\"world" -> "hello\\\\\\\"world",
|
||||
"hello\\'world" -> "hello\\\\\\'world",
|
||||
"hello\\n\\t\\'world" -> "hello\\\\n\\\\t\\\\\\'world",
|
||||
"hello\u0000world" -> "hello\\0world",
|
||||
"hello\bworld" -> "hello\\bworld",
|
||||
"hello\fworld" -> "hello\\fworld"
|
||||
)
|
||||
forEvery(escapedStrings) { (input, expected) =>
|
||||
it should s"return $expected" in {
|
||||
assert(input.escaped == expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
85
src/test/wacc/backend/labelGeneratorSpec.scala
Normal file
85
src/test/wacc/backend/labelGeneratorSpec.scala
Normal file
@@ -0,0 +1,85 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
|
||||
class LabelGeneratorSpec extends AnyFlatSpec {
|
||||
import microWacc._
|
||||
import assemblyIR.{LabelDef, LabelArg, Directive}
|
||||
import types.?
|
||||
|
||||
"getLabel" should "return unique labels" in {
|
||||
val labelGenerator = new LabelGenerator
|
||||
val labels = (1 to 10).map(_ => labelGenerator.getLabel())
|
||||
assert(labels.distinct.length == labels.length)
|
||||
}
|
||||
|
||||
"getLabelDef" should "return unique labels" in {
|
||||
assert(
|
||||
LabelDef("test") == LabelDef("test") &&
|
||||
LabelDef("test").hashCode == LabelDef("test").hashCode,
|
||||
"Sanity check: LabelDef should be case-classes"
|
||||
)
|
||||
|
||||
val labelGenerator = new LabelGenerator
|
||||
val labels = (List(
|
||||
Builtin.Exit,
|
||||
Builtin.Printf,
|
||||
Ident("exit", 0)(?),
|
||||
Ident("test", 0)(?)
|
||||
) ++ RuntimeError.all.toList).map(labelGenerator.getLabelDef(_))
|
||||
assert(labels.distinct.length == labels.length)
|
||||
}
|
||||
|
||||
"getLabelArg" should "return unique labels" in {
|
||||
assert(
|
||||
LabelArg("test") == LabelArg("test") &&
|
||||
LabelArg("test").hashCode == LabelArg("test").hashCode,
|
||||
"Sanity check: LabelArg should be case-classes"
|
||||
)
|
||||
|
||||
val labelGenerator = new LabelGenerator
|
||||
val labels = (List(
|
||||
Builtin.Exit,
|
||||
Builtin.Printf,
|
||||
Ident("exit", 0)(?),
|
||||
Ident("test", 0)(?),
|
||||
"test",
|
||||
"test",
|
||||
"test3"
|
||||
) ++ RuntimeError.all.toList).map {
|
||||
case s: String => labelGenerator.getLabelArg(s)
|
||||
case t: (CallTarget | RuntimeError) => labelGenerator.getLabelArg(t)
|
||||
}
|
||||
assert(labels.distinct.length == labels.distinct.length)
|
||||
}
|
||||
|
||||
it should "return consistent labels to getLabelDef" in {
|
||||
val labelGenerator = new LabelGenerator
|
||||
val targets = (List(
|
||||
Builtin.Exit,
|
||||
Builtin.Printf,
|
||||
Ident("exit", 0)(?),
|
||||
Ident("test", 0)(?)
|
||||
) ++ RuntimeError.all.toList)
|
||||
val labelDefs = targets.map(labelGenerator.getLabelDef(_).toString.dropRight(1)).toSet
|
||||
val labelArgs = targets.map(labelGenerator.getLabelArg(_).toString).toSet
|
||||
assert(labelDefs == labelArgs)
|
||||
}
|
||||
|
||||
"generateConstants" should "generate de-duplicated labels for strings" in {
|
||||
val labelGenerator = new LabelGenerator
|
||||
val strings = List("hello", "world", "hello\u0000world", "hello", "Hello")
|
||||
val distincts = strings.distinct.length
|
||||
val labels = strings.map(labelGenerator.getLabelArg(_).toString).toSet
|
||||
val asmLines = labelGenerator.generateConstants
|
||||
assert(
|
||||
asmLines.collect { case LabelDef(name) =>
|
||||
name
|
||||
}.length == distincts
|
||||
)
|
||||
assert(
|
||||
asmLines.collect { case Directive.Asciz(str) => str }.length == distincts
|
||||
)
|
||||
assert(asmLines.collect { case LabelDef(name) => name }.toList.toSet == labels)
|
||||
}
|
||||
}
|
||||
140
src/test/wacc/backend/stackSpec.scala
Normal file
140
src/test/wacc/backend/stackSpec.scala
Normal file
@@ -0,0 +1,140 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import cats.data.Chain
|
||||
|
||||
class StackSpec extends AnyFlatSpec {
|
||||
import microWacc._
|
||||
import assemblyIR._
|
||||
import assemblyIR.Size._
|
||||
import assemblyIR.commonRegisters._
|
||||
import types.{KnownType, ?}
|
||||
import sizeExtensions.size
|
||||
|
||||
private val RSP = Register(Q64, RegName.SP)
|
||||
|
||||
"size" should "be 0 initially" in {
|
||||
val stack = new Stack
|
||||
assert(stack.size == 0)
|
||||
}
|
||||
|
||||
"push" should "add an expression to the stack" in {
|
||||
val stack = new Stack
|
||||
val expr = Ident("x", 0)(?)
|
||||
val result = stack.push(expr, RAX)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Push(RAX))
|
||||
}
|
||||
|
||||
it should "add 2 expressions to the stack" in {
|
||||
val stack = new Stack
|
||||
val expr1 = Ident("x", 0)(?)
|
||||
val expr2 = Ident("x", 1)(?)
|
||||
val result1 = stack.push(expr1, RAX)
|
||||
val result2 = stack.push(expr2, RCX)
|
||||
assert(stack.size == 2)
|
||||
assert(result1 == Push(RAX))
|
||||
assert(result2 == Push(RCX))
|
||||
}
|
||||
|
||||
it should "add a value to the stack" in {
|
||||
val stack = new Stack
|
||||
val result = stack.push(D32, RAX)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Push(RAX))
|
||||
}
|
||||
|
||||
"reserve" should "reserve space for an identifier" in {
|
||||
val stack = new Stack
|
||||
val ident = Ident("x", 0)(KnownType.Int)
|
||||
val result = stack.reserve(ident)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
|
||||
}
|
||||
|
||||
it should "reserve space for a register" in {
|
||||
val stack = new Stack
|
||||
val result = stack.reserve(RAX)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
|
||||
}
|
||||
|
||||
it should "reserve space for multiple values" in {
|
||||
val stack = new Stack
|
||||
val result = stack.reserve(D32, Q64, B8)
|
||||
assert(stack.size == 3)
|
||||
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt * 3)))
|
||||
}
|
||||
|
||||
"pop" should "remove the last value from the stack" in {
|
||||
val stack = new Stack
|
||||
stack.push(D32, RAX)
|
||||
val result = stack.pop(RAX)
|
||||
assert(stack.size == 0)
|
||||
assert(result == Pop(RAX))
|
||||
}
|
||||
|
||||
"drop" should "remove the last 2 value from the stack" in {
|
||||
val stack = new Stack
|
||||
stack.push(D32, RAX)
|
||||
stack.push(Q64, RAX)
|
||||
stack.push(B8, RAX)
|
||||
val result = stack.drop(2)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Add(RSP, ImmediateVal(Q64.toInt * 2)))
|
||||
}
|
||||
|
||||
"withScope" should "reset stack after block" in {
|
||||
val stack = new Stack
|
||||
stack.push(D32, RAX)
|
||||
stack.push(Q64, RCX)
|
||||
stack.push(B8, RDX)
|
||||
val result = stack.withScope(() =>
|
||||
Chain(
|
||||
stack.push(Q64, RSI),
|
||||
stack.push(B8, RDI),
|
||||
stack.push(B8, RBP)
|
||||
)
|
||||
)
|
||||
assert(stack.size == 3)
|
||||
assert(
|
||||
result == Chain(
|
||||
Push(RSI),
|
||||
Push(RDI),
|
||||
Push(RBP),
|
||||
Add(RSP, ImmediateVal(Q64.toInt * 3))
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
"accessVar" should "return the correctly-sized memory location for the identifier" in {
|
||||
val stack = new Stack
|
||||
val id = Ident("x", 0)(KnownType.Int)
|
||||
stack.push(Q64, RAX)
|
||||
stack.push(id, RCX)
|
||||
stack.push(B8, RDX)
|
||||
stack.push(D32, RSI)
|
||||
val result = stack.accessVar(Ident("x", 0)(KnownType.Int))
|
||||
assert(result == MemLocation(RSP, Q64.toInt * 2, opSize = Some(KnownType.Int.size)))
|
||||
}
|
||||
|
||||
"contains" should "return true if the stack contains the identifier" in {
|
||||
val stack = new Stack
|
||||
val id = Ident("x", 0)(KnownType.Int)
|
||||
stack.push(D32, RAX)
|
||||
stack.push(id, RCX)
|
||||
stack.push(B8, RDX)
|
||||
assert(stack.contains(id))
|
||||
assert(!stack.contains(Ident("x", 1)(KnownType.Int)))
|
||||
}
|
||||
|
||||
"head" should "return the correct memory location for the last element" in {
|
||||
val stack = new Stack
|
||||
val id = Ident("x", 0)(KnownType.Int)
|
||||
stack.push(D32, RAX)
|
||||
stack.push(id, RCX)
|
||||
stack.push(B8, RDX)
|
||||
val result = stack.head
|
||||
assert(result == MemLocation(RSP, opSize = Some(B8)))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user