refactor: extract Stack, proper register naming and sizes

This commit is contained in:
Gleb Koval 2025-02-27 01:49:22 +00:00
parent 52ed404a73
commit 691d989b92
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
3 changed files with 181 additions and 54 deletions

View File

@ -0,0 +1,84 @@
package wacc
import scala.collection.mutable.LinkedHashMap
class Stack {
import assemblyIR._
import sizeExtensions.size
import microWacc as mw
private val RSP = Register(Size.Q64, RegName.SP)
private class StackValue(val size: Size, val offset: Int) {
def bottom: Int = offset + size.toInt
}
private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
/** The stack's size in bytes. */
def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom
/** Push an expression onto the stack. */
def push(expr: mw.Expr, src: Register): AsmLine = {
stack += expr -> StackValue(src.size, size)
Push(src)
}
/** Push an arbitrary register onto the stack. */
def push(src: Register): AsmLine = {
stack += stack.size -> StackValue(src.size, size)
Push(src)
}
/** Reserve space for a variable on the stack. */
def reserve(ident: mw.Ident): AsmLine = {
stack += ident -> StackValue(ident.ty.size, size)
Subtract(RSP, ImmediateVal(ident.ty.size.toInt))
}
/** Reserve space for values on the stack.
*
* @param sizes
* The sizes of the values to reserve space for.
*/
def reserve(sizes: List[Size]): AsmLine = {
val totalSize = sizes
.map(itemSize =>
stack += stack.size -> StackValue(itemSize, size)
itemSize.toInt
)
.sum
Subtract(RSP, ImmediateVal(totalSize))
}
/** Pop a value from the stack into a register. Sizes MUST match. */
def pop(dest: Register): AsmLine = {
if (dest.size != stack.last._2.size) {
throw new IllegalArgumentException(
s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register"
)
}
stack.remove(stack.last._1)
Pop(dest)
}
/** Drop the top n values from the stack. */
def drop(n: Int = 1): AsmLine = {
val totalSize = (1 to n)
.map(_ =>
val itemSize = stack.last._2.size.toInt
stack.remove(stack.last._1)
itemSize
)
.sum
Add(RSP, ImmediateVal(totalSize))
}
/** Get a lazy IndexAddress for a variable in the stack. */
def accessVar(ident: mw.Ident): () => IndexAddress = () => {
IndexAddress(RSP, stack.size - stack(ident).bottom)
}
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
def head: MemLocation = MemLocation(RSP)
def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset))
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
}

View File

@ -6,40 +6,73 @@ object assemblyIR {
sealed trait Operand
sealed trait Src extends Operand // mem location, register and imm value
sealed trait Dest extends Operand // mem location and register
enum RegSize {
case R64
case E32
case Byte
override def toString = this match {
case R64 => "r"
case E32 => "e"
case Byte => ""
enum Size {
case Q64, D32, W16, B8
def toInt: Int = this match {
case Q64 => 8
case D32 => 4
case W16 => 2
case B8 => 1
}
private val ptr = "ptr "
override def toString(): String = this match {
case Q64 => "qword " + ptr
case D32 => "dword " + ptr
case W16 => "word " + ptr
case B8 => "byte " + ptr
}
}
enum RegName {
case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14,
Reg15
override def toString = this match {
case AX => "ax"
case AL => "al"
case BX => "bx"
case CX => "cx"
case DX => "dx"
case SI => "si"
case DI => "di"
case SP => "sp"
case BP => "bp"
case IP => "ip"
case Reg8 => "8"
case Reg9 => "9"
case Reg10 => "10"
case Reg11 => "11"
case Reg12 => "12"
case Reg13 => "13"
case Reg14 => "14"
case Reg15 => "15"
case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15
}
case class Register(size: Size, name: RegName) extends Dest with Src {
import RegName._
if (size == Size.B8 && name == RegName.IP) {
throw new IllegalArgumentException("Cannot have 8 bit register for IP")
}
override def toString = name match {
case AX => tradToString("ax", "al")
case BX => tradToString("bx", "bl")
case CX => tradToString("cx", "cl")
case DX => tradToString("dx", "dl")
case SI => tradToString("si", "sil")
case DI => tradToString("di", "dil")
case SP => tradToString("sp", "spl")
case BP => tradToString("bp", "bpl")
case IP => tradToString("ip", "#INVALID")
case R8 => newToString(8)
case R9 => newToString(9)
case R10 => newToString(10)
case R11 => newToString(11)
case R12 => newToString(12)
case R13 => newToString(13)
case R14 => newToString(14)
case R15 => newToString(15)
}
private def tradToString(base: String, byteName: String): String =
size match {
case Size.Q64 => "r" + base
case Size.D32 => "e" + base
case Size.W16 => base
case Size.B8 => byteName
}
private def newToString(base: Int): String = {
val b = base.toString
"r" + (size match {
case Size.Q64 => b
case Size.D32 => b + "d"
case Size.W16 => b + "w"
case Size.B8 => b + "b"
})
}
}
@ -64,24 +97,18 @@ object assemblyIR {
}
}
// TODO register naming conventions are wrong
case class Register(size: RegSize, name: RegName) extends Dest with Src {
override def toString = s"${size}${name}"
}
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified)
extends Dest
with Src {
override def toString = pointer match {
case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => opSize.toString + s"[$reg]"
}
case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src {
def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize))
override def toString =
opSize.getOrElse("").toString + s"[$pointer]"
}
// TODO to string is wacky
case class IndexAddress(
base: Register,
offset: Int | LabelArg,
indexReg: Register = Register(RegSize.R64, RegName.AX),
indexReg: Register = Register(Size.Q64, RegName.AX),
scale: Int = 0
) extends Dest
with Src {
@ -188,17 +215,4 @@ object assemblyIR {
case String => "%s"
}
}
enum SizeDir {
case Byte, Word, DWord, Unspecified
private val ptr = "ptr "
override def toString(): String = this match {
case Byte => "byte " + ptr
case Word => "word " + ptr
case DWord => "dword " + ptr
case Unspecified => ""
}
}
}

View File

@ -0,0 +1,29 @@
package wacc
object sizeExtensions {
import microWacc._
import types._
import assemblyIR.Size
extension (expr: Expr) {
/** Calculate the size (bytes) of the heap required for the expression. */
def heapSize: Int = (expr, expr.ty) match {
case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) =>
KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt
case (ArrayLiter(elems), _) =>
KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum
case _ => expr.ty.size.toInt
}
}
extension (ty: SemType) {
/** Calculate the size (bytes) of a type in a register. */
def size: Size = ty match {
case KnownType.Int => Size.D32
case KnownType.Bool | KnownType.Char => Size.B8
case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
}
}
}