From 691d989b9223bbf334ae0b1d0b1afff777c82203 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 27 Feb 2025 01:49:22 +0000 Subject: [PATCH] refactor: extract Stack, proper register naming and sizes --- src/main/wacc/backend/Stack.scala | 84 ++++++++++++++ src/main/wacc/backend/assemblyIR.scala | 122 ++++++++++++--------- src/main/wacc/backend/sizeExtensions.scala | 29 +++++ 3 files changed, 181 insertions(+), 54 deletions(-) create mode 100644 src/main/wacc/backend/Stack.scala create mode 100644 src/main/wacc/backend/sizeExtensions.scala diff --git a/src/main/wacc/backend/Stack.scala b/src/main/wacc/backend/Stack.scala new file mode 100644 index 0000000..72aa5ef --- /dev/null +++ b/src/main/wacc/backend/Stack.scala @@ -0,0 +1,84 @@ +package wacc + +import scala.collection.mutable.LinkedHashMap + +class Stack { + import assemblyIR._ + import sizeExtensions.size + import microWacc as mw + + private val RSP = Register(Size.Q64, RegName.SP) + private class StackValue(val size: Size, val offset: Int) { + def bottom: Int = offset + size.toInt + } + private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() + + /** The stack's size in bytes. */ + def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom + + /** Push an expression onto the stack. */ + def push(expr: mw.Expr, src: Register): AsmLine = { + stack += expr -> StackValue(src.size, size) + Push(src) + } + + /** Push an arbitrary register onto the stack. */ + def push(src: Register): AsmLine = { + stack += stack.size -> StackValue(src.size, size) + Push(src) + } + + /** Reserve space for a variable on the stack. */ + def reserve(ident: mw.Ident): AsmLine = { + stack += ident -> StackValue(ident.ty.size, size) + Subtract(RSP, ImmediateVal(ident.ty.size.toInt)) + } + + /** Reserve space for values on the stack. + * + * @param sizes + * The sizes of the values to reserve space for. + */ + def reserve(sizes: List[Size]): AsmLine = { + val totalSize = sizes + .map(itemSize => + stack += stack.size -> StackValue(itemSize, size) + itemSize.toInt + ) + .sum + Subtract(RSP, ImmediateVal(totalSize)) + } + + /** Pop a value from the stack into a register. Sizes MUST match. */ + def pop(dest: Register): AsmLine = { + if (dest.size != stack.last._2.size) { + throw new IllegalArgumentException( + s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register" + ) + } + stack.remove(stack.last._1) + Pop(dest) + } + + /** Drop the top n values from the stack. */ + def drop(n: Int = 1): AsmLine = { + val totalSize = (1 to n) + .map(_ => + val itemSize = stack.last._2.size.toInt + stack.remove(stack.last._1) + itemSize + ) + .sum + Add(RSP, ImmediateVal(totalSize)) + } + + /** Get a lazy IndexAddress for a variable in the stack. */ + def accessVar(ident: mw.Ident): () => IndexAddress = () => { + IndexAddress(RSP, stack.size - stack(ident).bottom) + } + def contains(ident: mw.Ident): Boolean = stack.contains(ident) + def head: MemLocation = MemLocation(RSP) + def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset)) + // TODO: Might want to actually properly handle this with the LinkedHashMap too + def align(): AsmLine = And(RSP, ImmediateVal(-16)) +} diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index f8bbf38..2946dcb 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -6,40 +6,73 @@ object assemblyIR { sealed trait Operand sealed trait Src extends Operand // mem location, register and imm value sealed trait Dest extends Operand // mem location and register - enum RegSize { - case R64 - case E32 - case Byte - override def toString = this match { - case R64 => "r" - case E32 => "e" - case Byte => "" + enum Size { + case Q64, D32, W16, B8 + + def toInt: Int = this match { + case Q64 => 8 + case D32 => 4 + case W16 => 2 + case B8 => 1 + } + + private val ptr = "ptr " + + override def toString(): String = this match { + case Q64 => "qword " + ptr + case D32 => "dword " + ptr + case W16 => "word " + ptr + case B8 => "byte " + ptr } } enum RegName { - case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, - Reg15 - override def toString = this match { - case AX => "ax" - case AL => "al" - case BX => "bx" - case CX => "cx" - case DX => "dx" - case SI => "si" - case DI => "di" - case SP => "sp" - case BP => "bp" - case IP => "ip" - case Reg8 => "8" - case Reg9 => "9" - case Reg10 => "10" - case Reg11 => "11" - case Reg12 => "12" - case Reg13 => "13" - case Reg14 => "14" - case Reg15 => "15" + case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15 + } + + case class Register(size: Size, name: RegName) extends Dest with Src { + import RegName._ + + if (size == Size.B8 && name == RegName.IP) { + throw new IllegalArgumentException("Cannot have 8 bit register for IP") + } + override def toString = name match { + case AX => tradToString("ax", "al") + case BX => tradToString("bx", "bl") + case CX => tradToString("cx", "cl") + case DX => tradToString("dx", "dl") + case SI => tradToString("si", "sil") + case DI => tradToString("di", "dil") + case SP => tradToString("sp", "spl") + case BP => tradToString("bp", "bpl") + case IP => tradToString("ip", "#INVALID") + case R8 => newToString(8) + case R9 => newToString(9) + case R10 => newToString(10) + case R11 => newToString(11) + case R12 => newToString(12) + case R13 => newToString(13) + case R14 => newToString(14) + case R15 => newToString(15) + } + + private def tradToString(base: String, byteName: String): String = + size match { + case Size.Q64 => "r" + base + case Size.D32 => "e" + base + case Size.W16 => base + case Size.B8 => byteName + } + + private def newToString(base: Int): String = { + val b = base.toString + "r" + (size match { + case Size.Q64 => b + case Size.D32 => b + "d" + case Size.W16 => b + "w" + case Size.B8 => b + "b" + }) } } @@ -64,24 +97,18 @@ object assemblyIR { } } - // TODO register naming conventions are wrong - case class Register(size: RegSize, name: RegName) extends Dest with Src { - override def toString = s"${size}${name}" - } - case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) - extends Dest - with Src { - override def toString = pointer match { - case hex: Long => opSize.toString + f"[0x$hex%X]" - case reg: Register => opSize.toString + s"[$reg]" - } + case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src { + def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize)) + + override def toString = + opSize.getOrElse("").toString + s"[$pointer]" } // TODO to string is wacky case class IndexAddress( base: Register, offset: Int | LabelArg, - indexReg: Register = Register(RegSize.R64, RegName.AX), + indexReg: Register = Register(Size.Q64, RegName.AX), scale: Int = 0 ) extends Dest with Src { @@ -188,17 +215,4 @@ object assemblyIR { case String => "%s" } } - - enum SizeDir { - case Byte, Word, DWord, Unspecified - - private val ptr = "ptr " - - override def toString(): String = this match { - case Byte => "byte " + ptr - case Word => "word " + ptr - case DWord => "dword " + ptr - case Unspecified => "" - } - } } diff --git a/src/main/wacc/backend/sizeExtensions.scala b/src/main/wacc/backend/sizeExtensions.scala new file mode 100644 index 0000000..59d3930 --- /dev/null +++ b/src/main/wacc/backend/sizeExtensions.scala @@ -0,0 +1,29 @@ +package wacc + +object sizeExtensions { + import microWacc._ + import types._ + import assemblyIR.Size + + extension (expr: Expr) { + + /** Calculate the size (bytes) of the heap required for the expression. */ + def heapSize: Int = (expr, expr.ty) match { + case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => + KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt + case (ArrayLiter(elems), _) => + KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum + case _ => expr.ty.size.toInt + } + } + + extension (ty: SemType) { + + /** Calculate the size (bytes) of a type in a register. */ + def size: Size = ty match { + case KnownType.Int => Size.D32 + case KnownType.Bool | KnownType.Char => Size.B8 + case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 + } + } +}