refactor: extract Stack, proper register naming and sizes

This commit is contained in:
Gleb Koval 2025-02-27 01:49:22 +00:00
parent 52ed404a73
commit 691d989b92
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
3 changed files with 181 additions and 54 deletions

View File

@ -0,0 +1,84 @@
package wacc
import scala.collection.mutable.LinkedHashMap
class Stack {
import assemblyIR._
import sizeExtensions.size
import microWacc as mw
private val RSP = Register(Size.Q64, RegName.SP)
private class StackValue(val size: Size, val offset: Int) {
def bottom: Int = offset + size.toInt
}
private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
/** The stack's size in bytes. */
def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom
/** Push an expression onto the stack. */
def push(expr: mw.Expr, src: Register): AsmLine = {
stack += expr -> StackValue(src.size, size)
Push(src)
}
/** Push an arbitrary register onto the stack. */
def push(src: Register): AsmLine = {
stack += stack.size -> StackValue(src.size, size)
Push(src)
}
/** Reserve space for a variable on the stack. */
def reserve(ident: mw.Ident): AsmLine = {
stack += ident -> StackValue(ident.ty.size, size)
Subtract(RSP, ImmediateVal(ident.ty.size.toInt))
}
/** Reserve space for values on the stack.
*
* @param sizes
* The sizes of the values to reserve space for.
*/
def reserve(sizes: List[Size]): AsmLine = {
val totalSize = sizes
.map(itemSize =>
stack += stack.size -> StackValue(itemSize, size)
itemSize.toInt
)
.sum
Subtract(RSP, ImmediateVal(totalSize))
}
/** Pop a value from the stack into a register. Sizes MUST match. */
def pop(dest: Register): AsmLine = {
if (dest.size != stack.last._2.size) {
throw new IllegalArgumentException(
s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register"
)
}
stack.remove(stack.last._1)
Pop(dest)
}
/** Drop the top n values from the stack. */
def drop(n: Int = 1): AsmLine = {
val totalSize = (1 to n)
.map(_ =>
val itemSize = stack.last._2.size.toInt
stack.remove(stack.last._1)
itemSize
)
.sum
Add(RSP, ImmediateVal(totalSize))
}
/** Get a lazy IndexAddress for a variable in the stack. */
def accessVar(ident: mw.Ident): () => IndexAddress = () => {
IndexAddress(RSP, stack.size - stack(ident).bottom)
}
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
def head: MemLocation = MemLocation(RSP)
def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset))
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
}

View File

@ -6,40 +6,73 @@ object assemblyIR {
sealed trait Operand sealed trait Operand
sealed trait Src extends Operand // mem location, register and imm value sealed trait Src extends Operand // mem location, register and imm value
sealed trait Dest extends Operand // mem location and register sealed trait Dest extends Operand // mem location and register
enum RegSize {
case R64
case E32
case Byte
override def toString = this match { enum Size {
case R64 => "r" case Q64, D32, W16, B8
case E32 => "e"
case Byte => "" def toInt: Int = this match {
case Q64 => 8
case D32 => 4
case W16 => 2
case B8 => 1
}
private val ptr = "ptr "
override def toString(): String = this match {
case Q64 => "qword " + ptr
case D32 => "dword " + ptr
case W16 => "word " + ptr
case B8 => "byte " + ptr
} }
} }
enum RegName { enum RegName {
case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15
Reg15 }
override def toString = this match {
case AX => "ax" case class Register(size: Size, name: RegName) extends Dest with Src {
case AL => "al" import RegName._
case BX => "bx"
case CX => "cx" if (size == Size.B8 && name == RegName.IP) {
case DX => "dx" throw new IllegalArgumentException("Cannot have 8 bit register for IP")
case SI => "si" }
case DI => "di" override def toString = name match {
case SP => "sp" case AX => tradToString("ax", "al")
case BP => "bp" case BX => tradToString("bx", "bl")
case IP => "ip" case CX => tradToString("cx", "cl")
case Reg8 => "8" case DX => tradToString("dx", "dl")
case Reg9 => "9" case SI => tradToString("si", "sil")
case Reg10 => "10" case DI => tradToString("di", "dil")
case Reg11 => "11" case SP => tradToString("sp", "spl")
case Reg12 => "12" case BP => tradToString("bp", "bpl")
case Reg13 => "13" case IP => tradToString("ip", "#INVALID")
case Reg14 => "14" case R8 => newToString(8)
case Reg15 => "15" case R9 => newToString(9)
case R10 => newToString(10)
case R11 => newToString(11)
case R12 => newToString(12)
case R13 => newToString(13)
case R14 => newToString(14)
case R15 => newToString(15)
}
private def tradToString(base: String, byteName: String): String =
size match {
case Size.Q64 => "r" + base
case Size.D32 => "e" + base
case Size.W16 => base
case Size.B8 => byteName
}
private def newToString(base: Int): String = {
val b = base.toString
"r" + (size match {
case Size.Q64 => b
case Size.D32 => b + "d"
case Size.W16 => b + "w"
case Size.B8 => b + "b"
})
} }
} }
@ -64,24 +97,18 @@ object assemblyIR {
} }
} }
// TODO register naming conventions are wrong case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src {
case class Register(size: RegSize, name: RegName) extends Dest with Src { def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize))
override def toString = s"${size}${name}"
} override def toString =
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) opSize.getOrElse("").toString + s"[$pointer]"
extends Dest
with Src {
override def toString = pointer match {
case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => opSize.toString + s"[$reg]"
}
} }
// TODO to string is wacky // TODO to string is wacky
case class IndexAddress( case class IndexAddress(
base: Register, base: Register,
offset: Int | LabelArg, offset: Int | LabelArg,
indexReg: Register = Register(RegSize.R64, RegName.AX), indexReg: Register = Register(Size.Q64, RegName.AX),
scale: Int = 0 scale: Int = 0
) extends Dest ) extends Dest
with Src { with Src {
@ -188,17 +215,4 @@ object assemblyIR {
case String => "%s" case String => "%s"
} }
} }
enum SizeDir {
case Byte, Word, DWord, Unspecified
private val ptr = "ptr "
override def toString(): String = this match {
case Byte => "byte " + ptr
case Word => "word " + ptr
case DWord => "dword " + ptr
case Unspecified => ""
}
}
} }

View File

@ -0,0 +1,29 @@
package wacc
object sizeExtensions {
import microWacc._
import types._
import assemblyIR.Size
extension (expr: Expr) {
/** Calculate the size (bytes) of the heap required for the expression. */
def heapSize: Int = (expr, expr.ty) match {
case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) =>
KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt
case (ArrayLiter(elems), _) =>
KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum
case _ => expr.ty.size.toInt
}
}
extension (ty: SemType) {
/** Calculate the size (bytes) of a type in a register. */
def size: Size = ty match {
case KnownType.Int => Size.D32
case KnownType.Bool | KnownType.Char => Size.B8
case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
}
}
}