package wacc

import scala.collection.mutable.LinkedHashMap
import cats.data.Chain

/** Tracks the values that generated code has placed on the stack, so that RSP-relative offsets
  * can be computed for them.
  *
  * Every tracked value occupies one `Q64`-wide slot, and the `Subtract`/`Add` adjustments
  * emitted here move RSP by whole slots. NOTE: assumes each emitted `Push`/`Pop` also moves RSP
  * by exactly one slot.
  */
class Stack {
  import assemblyIR._
  import assemblyIR.Size._
  import sizeExtensions.size
  import microWacc as mw

  private val RSP = Register(Q64, RegName.SP)

  /** A tracked slot: the size of the value stored in it, and the number of tracked bytes that
    * were already on the stack when it was added.
    */
  private case class StackValue(size: Size, offset: Int) {
    def bottom: Int = offset + elemBytes
  }

  /** Slots in push order (oldest first). Named slots are keyed by their expression; anonymous
    * slots are keyed by their index, which is always fresh because slots are only ever removed
    * from the top.
    */
  private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()

  private val elemBytes: Int = Q64.toInt
  private def sizeBytes: Int = stack.size * elemBytes

  /** Record a new slot for a value of `itemSize` on top of the stack.
    *
    * If `key` is already tracked, the new slot is stored anonymously under its index instead.
    * Re-inserting an existing key into a `LinkedHashMap` would overwrite the earlier entry in
    * place without growing the map, which desynchronises the tracked size from RSP. Lookups by
    * that key therefore keep resolving to the earlier slot.
    */
  private def track(itemSize: Size, key: Option[mw.Expr]): Unit = {
    val slotKey: mw.Expr | Int = key match {
      case Some(expr) if !stack.contains(expr) => expr
      case _                                   => stack.size
    }
    stack += slotKey -> StackValue(itemSize, sizeBytes)
  }

  /** The number of values currently on the stack (not its size in bytes). */
  def size: Int = stack.size

  /** Push `src` onto the stack, tracking the new slot under `expr`. */
  def push(expr: mw.Expr, src: Register): AsmLine = {
    track(src.size, Some(expr))
    Push(src)
  }

  /** Push an anonymous value of `itemSize`, read from `addr`, onto the stack. */
  def push(itemSize: Size, addr: Src): AsmLine = {
    track(itemSize, None)
    Push(addr)
  }

  /** Reserve one slot for the variable `ident` without initialising it. */
  def reserve(ident: mw.Ident): AsmLine = {
    track(ident.ty.size, Some(ident))
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve one anonymous slot for a value the size of `src`.
    *
    * A full slot is reserved regardless of the register's width. The bookkeeping (`drop`,
    * `accessVar`) assumes every slot is `elemBytes` wide, so moving RSP by less would leave it
    * out of step with the tracked stack.
    */
  def reserve(src: Register): AsmLine = {
    track(src.size, None)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve one anonymous slot per entry of `sizes`.
    *
    * @param sizes
    *   The sizes of the values to reserve space for.
    */
  def reserve(sizes: List[Size]): AsmLine = {
    sizes.foreach(track(_, None))
    Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
  }

  /** Pop the top value into `dest` and stop tracking it. Sizes MUST match.
    *
    * @throws NoSuchElementException
    *   if the stack is empty.
    */
  def pop(dest: Register): AsmLine = {
    stack.remove(stack.last._1)
    Pop(dest)
  }

  /** Drop the top `n` values from the stack without reading them.
    *
    * @throws IllegalArgumentException
    *   if `n` is negative or exceeds the number of values on the stack; the stack is left
    *   untouched in that case.
    */
  def drop(n: Int = 1): AsmLine = {
    require(
      n >= 0 && n <= stack.size,
      s"cannot drop $n value(s) from a stack holding ${stack.size}"
    )
    (1 to n).foreach(_ => stack.remove(stack.last._1))
    Add(RSP, ImmediateVal(n * elemBytes))
  }

  /** Generate AsmLines within a scope: run `block`, then drop every value it left on the stack,
    * restoring the depth observed on entry.
    *
    * @throws IllegalArgumentException
    *   if `block` removed values that were already on the stack before it ran.
    */
  def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
    val resetToSize = stack.size
    // `block()` must run before the drop count is computed; the receiver of `:+` is
    // evaluated first.
    block() :+ drop(stack.size - resetToSize)
  }

  /** An RSP-relative address for the slot holding `ident`.
    *
    * @throws NoSuchElementException
    *   if `ident` is not on the stack.
    */
  def accessVar(ident: mw.Ident): IndexAddress =
    IndexAddress(RSP, sizeBytes - stack(ident).bottom)

  /** Whether `ident` currently has a slot on the stack. */
  def contains(ident: mw.Ident): Boolean = stack.contains(ident)

  /** A memory operand for the value on top of the stack, using its recorded size. */
  def head: MemLocation = MemLocation(RSP, stack.last._2.size)

  override def toString(): String = stack.toString
}