refactor: extract Stack, proper register naming and sizes

This commit is contained in:
2025-02-27 01:49:22 +00:00
parent 808a59f58a
commit 58df1d7bb9
3 changed files with 181 additions and 54 deletions

View File

@@ -0,0 +1,84 @@
package wacc
import scala.collection.mutable.LinkedHashMap
class Stack {
import assemblyIR._
import sizeExtensions.size
import microWacc as mw
private val RSP = Register(Size.Q64, RegName.SP)
private class StackValue(val size: Size, val offset: Int) {
def bottom: Int = offset + size.toInt
}
private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
/** The stack's size in bytes. */
def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom
/** Push an expression onto the stack. */
def push(expr: mw.Expr, src: Register): AsmLine = {
stack += expr -> StackValue(src.size, size)
Push(src)
}
/** Push an arbitrary register onto the stack. */
def push(src: Register): AsmLine = {
stack += stack.size -> StackValue(src.size, size)
Push(src)
}
/** Reserve space for a variable on the stack. */
def reserve(ident: mw.Ident): AsmLine = {
stack += ident -> StackValue(ident.ty.size, size)
Subtract(RSP, ImmediateVal(ident.ty.size.toInt))
}
/** Reserve space for values on the stack.
*
* @param sizes
* The sizes of the values to reserve space for.
*/
def reserve(sizes: List[Size]): AsmLine = {
val totalSize = sizes
.map(itemSize =>
stack += stack.size -> StackValue(itemSize, size)
itemSize.toInt
)
.sum
Subtract(RSP, ImmediateVal(totalSize))
}
/** Pop a value from the stack into a register. Sizes MUST match. */
def pop(dest: Register): AsmLine = {
if (dest.size != stack.last._2.size) {
throw new IllegalArgumentException(
s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register"
)
}
stack.remove(stack.last._1)
Pop(dest)
}
/** Drop the top n values from the stack. */
def drop(n: Int = 1): AsmLine = {
val totalSize = (1 to n)
.map(_ =>
val itemSize = stack.last._2.size.toInt
stack.remove(stack.last._1)
itemSize
)
.sum
Add(RSP, ImmediateVal(totalSize))
}
/** Get a lazy IndexAddress for a variable in the stack. */
def accessVar(ident: mw.Ident): () => IndexAddress = () => {
IndexAddress(RSP, stack.size - stack(ident).bottom)
}
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
def head: MemLocation = MemLocation(RSP)
def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset))
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
}