refactor: extract Stack, proper register naming and sizes
This commit is contained in:
parent
52ed404a73
commit
691d989b92
84
src/main/wacc/backend/Stack.scala
Normal file
84
src/main/wacc/backend/Stack.scala
Normal file
@ -0,0 +1,84 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable.LinkedHashMap
|
||||
|
||||
class Stack {
|
||||
import assemblyIR._
|
||||
import sizeExtensions.size
|
||||
import microWacc as mw
|
||||
|
||||
private val RSP = Register(Size.Q64, RegName.SP)
|
||||
private class StackValue(val size: Size, val offset: Int) {
|
||||
def bottom: Int = offset + size.toInt
|
||||
}
|
||||
private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
|
||||
|
||||
/** The stack's size in bytes. */
|
||||
def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom
|
||||
|
||||
/** Push an expression onto the stack. */
|
||||
def push(expr: mw.Expr, src: Register): AsmLine = {
|
||||
stack += expr -> StackValue(src.size, size)
|
||||
Push(src)
|
||||
}
|
||||
|
||||
/** Push an arbitrary register onto the stack. */
|
||||
def push(src: Register): AsmLine = {
|
||||
stack += stack.size -> StackValue(src.size, size)
|
||||
Push(src)
|
||||
}
|
||||
|
||||
/** Reserve space for a variable on the stack. */
|
||||
def reserve(ident: mw.Ident): AsmLine = {
|
||||
stack += ident -> StackValue(ident.ty.size, size)
|
||||
Subtract(RSP, ImmediateVal(ident.ty.size.toInt))
|
||||
}
|
||||
|
||||
/** Reserve space for values on the stack.
|
||||
*
|
||||
* @param sizes
|
||||
* The sizes of the values to reserve space for.
|
||||
*/
|
||||
def reserve(sizes: List[Size]): AsmLine = {
|
||||
val totalSize = sizes
|
||||
.map(itemSize =>
|
||||
stack += stack.size -> StackValue(itemSize, size)
|
||||
itemSize.toInt
|
||||
)
|
||||
.sum
|
||||
Subtract(RSP, ImmediateVal(totalSize))
|
||||
}
|
||||
|
||||
/** Pop a value from the stack into a register. Sizes MUST match. */
|
||||
def pop(dest: Register): AsmLine = {
|
||||
if (dest.size != stack.last._2.size) {
|
||||
throw new IllegalArgumentException(
|
||||
s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register"
|
||||
)
|
||||
}
|
||||
stack.remove(stack.last._1)
|
||||
Pop(dest)
|
||||
}
|
||||
|
||||
/** Drop the top n values from the stack. */
|
||||
def drop(n: Int = 1): AsmLine = {
|
||||
val totalSize = (1 to n)
|
||||
.map(_ =>
|
||||
val itemSize = stack.last._2.size.toInt
|
||||
stack.remove(stack.last._1)
|
||||
itemSize
|
||||
)
|
||||
.sum
|
||||
Add(RSP, ImmediateVal(totalSize))
|
||||
}
|
||||
|
||||
/** Get a lazy IndexAddress for a variable in the stack. */
|
||||
def accessVar(ident: mw.Ident): () => IndexAddress = () => {
|
||||
IndexAddress(RSP, stack.size - stack(ident).bottom)
|
||||
}
|
||||
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
|
||||
def head: MemLocation = MemLocation(RSP)
|
||||
def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset))
|
||||
// TODO: Might want to actually properly handle this with the LinkedHashMap too
|
||||
def align(): AsmLine = And(RSP, ImmediateVal(-16))
|
||||
}
|
@ -6,40 +6,73 @@ object assemblyIR {
|
||||
sealed trait Operand
|
||||
sealed trait Src extends Operand // mem location, register and imm value
|
||||
sealed trait Dest extends Operand // mem location and register
|
||||
enum RegSize {
|
||||
case R64
|
||||
case E32
|
||||
case Byte
|
||||
|
||||
override def toString = this match {
|
||||
case R64 => "r"
|
||||
case E32 => "e"
|
||||
case Byte => ""
|
||||
enum Size {
|
||||
case Q64, D32, W16, B8
|
||||
|
||||
def toInt: Int = this match {
|
||||
case Q64 => 8
|
||||
case D32 => 4
|
||||
case W16 => 2
|
||||
case B8 => 1
|
||||
}
|
||||
|
||||
private val ptr = "ptr "
|
||||
|
||||
override def toString(): String = this match {
|
||||
case Q64 => "qword " + ptr
|
||||
case D32 => "dword " + ptr
|
||||
case W16 => "word " + ptr
|
||||
case B8 => "byte " + ptr
|
||||
}
|
||||
}
|
||||
|
||||
enum RegName {
|
||||
case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14,
|
||||
Reg15
|
||||
override def toString = this match {
|
||||
case AX => "ax"
|
||||
case AL => "al"
|
||||
case BX => "bx"
|
||||
case CX => "cx"
|
||||
case DX => "dx"
|
||||
case SI => "si"
|
||||
case DI => "di"
|
||||
case SP => "sp"
|
||||
case BP => "bp"
|
||||
case IP => "ip"
|
||||
case Reg8 => "8"
|
||||
case Reg9 => "9"
|
||||
case Reg10 => "10"
|
||||
case Reg11 => "11"
|
||||
case Reg12 => "12"
|
||||
case Reg13 => "13"
|
||||
case Reg14 => "14"
|
||||
case Reg15 => "15"
|
||||
case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15
|
||||
}
|
||||
|
||||
case class Register(size: Size, name: RegName) extends Dest with Src {
|
||||
import RegName._
|
||||
|
||||
if (size == Size.B8 && name == RegName.IP) {
|
||||
throw new IllegalArgumentException("Cannot have 8 bit register for IP")
|
||||
}
|
||||
override def toString = name match {
|
||||
case AX => tradToString("ax", "al")
|
||||
case BX => tradToString("bx", "bl")
|
||||
case CX => tradToString("cx", "cl")
|
||||
case DX => tradToString("dx", "dl")
|
||||
case SI => tradToString("si", "sil")
|
||||
case DI => tradToString("di", "dil")
|
||||
case SP => tradToString("sp", "spl")
|
||||
case BP => tradToString("bp", "bpl")
|
||||
case IP => tradToString("ip", "#INVALID")
|
||||
case R8 => newToString(8)
|
||||
case R9 => newToString(9)
|
||||
case R10 => newToString(10)
|
||||
case R11 => newToString(11)
|
||||
case R12 => newToString(12)
|
||||
case R13 => newToString(13)
|
||||
case R14 => newToString(14)
|
||||
case R15 => newToString(15)
|
||||
}
|
||||
|
||||
private def tradToString(base: String, byteName: String): String =
|
||||
size match {
|
||||
case Size.Q64 => "r" + base
|
||||
case Size.D32 => "e" + base
|
||||
case Size.W16 => base
|
||||
case Size.B8 => byteName
|
||||
}
|
||||
|
||||
private def newToString(base: Int): String = {
|
||||
val b = base.toString
|
||||
"r" + (size match {
|
||||
case Size.Q64 => b
|
||||
case Size.D32 => b + "d"
|
||||
case Size.W16 => b + "w"
|
||||
case Size.B8 => b + "b"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,24 +97,18 @@ object assemblyIR {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO register naming conventions are wrong
|
||||
case class Register(size: RegSize, name: RegName) extends Dest with Src {
|
||||
override def toString = s"${size}${name}"
|
||||
}
|
||||
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified)
|
||||
extends Dest
|
||||
with Src {
|
||||
override def toString = pointer match {
|
||||
case hex: Long => opSize.toString + f"[0x$hex%X]"
|
||||
case reg: Register => opSize.toString + s"[$reg]"
|
||||
}
|
||||
case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src {
|
||||
def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize))
|
||||
|
||||
override def toString =
|
||||
opSize.getOrElse("").toString + s"[$pointer]"
|
||||
}
|
||||
|
||||
// TODO to string is wacky
|
||||
case class IndexAddress(
|
||||
base: Register,
|
||||
offset: Int | LabelArg,
|
||||
indexReg: Register = Register(RegSize.R64, RegName.AX),
|
||||
indexReg: Register = Register(Size.Q64, RegName.AX),
|
||||
scale: Int = 0
|
||||
) extends Dest
|
||||
with Src {
|
||||
@ -188,17 +215,4 @@ object assemblyIR {
|
||||
case String => "%s"
|
||||
}
|
||||
}
|
||||
|
||||
enum SizeDir {
|
||||
case Byte, Word, DWord, Unspecified
|
||||
|
||||
private val ptr = "ptr "
|
||||
|
||||
override def toString(): String = this match {
|
||||
case Byte => "byte " + ptr
|
||||
case Word => "word " + ptr
|
||||
case DWord => "dword " + ptr
|
||||
case Unspecified => ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
29
src/main/wacc/backend/sizeExtensions.scala
Normal file
29
src/main/wacc/backend/sizeExtensions.scala
Normal file
@ -0,0 +1,29 @@
|
||||
package wacc
|
||||
|
||||
object sizeExtensions {
|
||||
import microWacc._
|
||||
import types._
|
||||
import assemblyIR.Size
|
||||
|
||||
extension (expr: Expr) {
|
||||
|
||||
/** Calculate the size (bytes) of the heap required for the expression. */
|
||||
def heapSize: Int = (expr, expr.ty) match {
|
||||
case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) =>
|
||||
KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt
|
||||
case (ArrayLiter(elems), _) =>
|
||||
KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum
|
||||
case _ => expr.ty.size.toInt
|
||||
}
|
||||
}
|
||||
|
||||
extension (ty: SemType) {
|
||||
|
||||
/** Calculate the size (bytes) of a type in a register. */
|
||||
def size: Size = ty match {
|
||||
case KnownType.Int => Size.D32
|
||||
case KnownType.Bool | KnownType.Char => Size.B8
|
||||
case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user