refactor: use explicit sizes in asmGenerator

This commit is contained in:
2025-02-28 17:21:45 +00:00
parent e1d90eabf9
commit 68903f5b69
2 changed files with 35 additions and 32 deletions

View File

@@ -8,7 +8,6 @@ object asmGenerator {
import microWacc._ import microWacc._
import assemblyIR._ import assemblyIR._
import assemblyIR.commonRegisters._ import assemblyIR.commonRegisters._
import assemblyIR.Size._
import assemblyIR.RegName._ import assemblyIR.RegName._
import types._ import types._
import sizeExtensions._ import sizeExtensions._
@@ -75,12 +74,12 @@ object asmGenerator {
given stack: Stack = Stack() given stack: Stack = Stack()
// Setup the stack with param 7 and up // Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_)) func.params.drop(argRegs.size).foreach(stack.reserve(_))
stack.reserve(Q64) // Reserve return pointer slot stack.reserve(Size.Q64) // Reserve return pointer slot
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name)) var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
asm ++= funcPrologue() asm ++= funcPrologue()
// Push the rest of params onto the stack for simplicity // Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) => argRegs.zip(func.params).foreach { (reg, param) =>
asm += stack.push(param, Register(Q64, reg)) asm += stack.push(param, Register(Size.Q64, reg))
} }
asm ++= func.body.foldMap(generateStmt(_)) asm ++= func.body.foldMap(generateStmt(_))
// No need for epilogue here since all user functions must return explicitly // No need for epilogue here since all user functions must return explicitly
@@ -113,7 +112,7 @@ object asmGenerator {
Chain( Chain(
stackAlign, stackAlign,
Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)), Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)),
Move(Register(D32, SI), MemLocation(RSI, opSize = Some(D32))), Move(Register(KnownType.Int.size, SI), MemLocation(RSI, opSize = Some(KnownType.Int.size))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI), Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
@@ -146,9 +145,9 @@ object asmGenerator {
Builtin.Read, Builtin.Read,
Chain( Chain(
stackAlign, stackAlign,
Subtract(Register(Q64, SP), ImmediateVal(8)), Subtract(Register(Size.Q64, SP), ImmediateVal(8)),
Push(RSI), Push(RSI),
Load(RSI, MemLocation(Register(Q64, SP), opSize = Some(Q64))), Load(RSI, MemLocation(Register(Size.Q64, SP), opSize = Some(Size.Q64))),
assemblyIR.Call(CLibFunc.Scanf), assemblyIR.Call(CLibFunc.Scanf),
Pop(RAX) Pop(RAX)
) )
@@ -177,13 +176,13 @@ object asmGenerator {
asm += stack.pop(RCX) asm += stack.pop(RCX)
asm += Compare(ECX, ImmediateVal(0)) asm += Compare(ECX, ImmediateVal(0))
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
asm += stack.push(Q64, RCX) asm += stack.push(KnownType.Int.size, RCX)
asm ++= evalExprOntoStack(x) asm ++= evalExprOntoStack(x)
asm += stack.pop(RAX) asm += stack.pop(RAX)
asm += stack.pop(RCX) asm += stack.pop(RCX)
asm += Compare(EAX, ImmediateVal(0)) asm += Compare(RAX, ImmediateVal(0))
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
asm += Compare(MemLocation(RAX, opSize = Some(D32)), ECX) asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
asm += stack.pop(RDX) asm += stack.pop(RDX)
@@ -260,21 +259,27 @@ object asmGenerator {
case KnownType.String => case KnownType.String =>
val str = elems.collect { case CharLiter(v) => v }.mkString val str = elems.collect { case CharLiter(v) => v }.mkString
asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str))) asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str)))
asm += stack.push(Q64, RAX) asm += stack.push(KnownType.String.size, RAX)
case ty => case ty =>
asm ++= generateCall( asm ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
isTail = false isTail = false
) )
asm += stack.push(Q64, RAX) asm += stack.push(KnownType.Array(?).size, RAX)
// Store the length of the array at the start // Store the length of the array at the start
asm += Move(MemLocation(RAX, opSize = Some(D32)), ImmediateVal(elems.size)) asm += Move(
MemLocation(RAX, opSize = Some(KnownType.Int.size)),
ImmediateVal(elems.size)
)
elems.zipWithIndex.foldMap { (elem, i) => elems.zipWithIndex.foldMap { (elem, i) =>
asm ++= evalExprOntoStack(elem) asm ++= evalExprOntoStack(elem)
asm += stack.pop(RCX) asm += stack.pop(RCX)
asm += stack.pop(RAX) asm += stack.pop(RAX)
asm += Move(MemLocation(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX)) asm += Move(
asm += stack.push(Q64, RAX) MemLocation(RAX, KnownType.Int.size.toInt + i * ty.elemSize.toInt),
Register(ty.elemSize, CX)
)
asm += stack.push(KnownType.Array(?).size, RAX)
} }
} }
@@ -292,9 +297,9 @@ object asmGenerator {
asm += Compare(RCX, ImmediateVal(0)) asm += Compare(RCX, ImmediateVal(0))
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less) asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
asm += stack.pop(RAX) asm += stack.pop(RAX)
asm += Compare(EAX, ImmediateVal(0)) asm += Compare(RAX, ImmediateVal(0))
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal) asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
asm += Compare(MemLocation(RAX, opSize = Some(D32)), ECX) asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual) asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
// + Int because we store the length of the array at the start // + Int because we store the length of the array at the start
asm += Move( asm += Move(
@@ -313,14 +318,14 @@ object asmGenerator {
case UnaryOperator.Ord => // No op needed case UnaryOperator.Ord => // No op needed
case UnaryOperator.Len => case UnaryOperator.Len =>
asm += stack.pop(RAX) asm += stack.pop(RAX)
asm += Move(EAX, MemLocation(RAX, opSize = Some(D32))) asm += Move(EAX, MemLocation(RAX, opSize = Some(KnownType.Int.size)))
asm += stack.push(D32, RAX) asm += stack.push(KnownType.Int.size, RAX)
case UnaryOperator.Negate => case UnaryOperator.Negate =>
asm += Xor(EAX, EAX) asm += Xor(EAX, EAX)
asm += Subtract(EAX, stack.head) asm += Subtract(EAX, stack.head)
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow) asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
asm += stack.drop() asm += stack.drop()
asm += stack.push(Q64, RAX) asm += stack.push(KnownType.Int.size, RAX)
case UnaryOperator.Not => case UnaryOperator.Not =>
asm += Xor(stack.head, ImmediateVal(1)) asm += Xor(stack.head, ImmediateVal(1))
} }
@@ -381,7 +386,7 @@ object asmGenerator {
stack.size == stackSizeStart + 1, stack.size == stackSizeStart + 1,
"Sanity check: ONLY the evaluated expression should have been pushed onto the stack" "Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
) )
asm ++= zeroRest(stack.head.copy(opSize = Some(Q64)), expr.ty.size) asm ++= zeroRest(stack.head.copy(opSize = Some(Size.Q64)), expr.ty.size)
asm asm
} }
@@ -402,7 +407,7 @@ object asmGenerator {
// And set the appropriate registers // And set the appropriate registers
.reverse .reverse
.foreach { reg => .foreach { reg =>
asm += stack.pop(Register(Q64, reg)) asm += stack.pop(Register(Size.Q64, reg))
} }
// Evaluate arguments 7 and up and push them onto the stack // Evaluate arguments 7 and up and push them onto the stack
@@ -431,33 +436,33 @@ object asmGenerator {
var asm = Chain.empty[AsmLine] var asm = Chain.empty[AsmLine]
asm += Compare(destX, stack.head) asm += Compare(destX, stack.head)
asm += Set(Register(B8, AX), cond) asm += Set(Register(Size.B8, AX), cond)
asm ++= zeroRest(RAX, B8) asm ++= zeroRest(RAX, Size.B8)
asm += stack.drop() asm += stack.drop()
asm += stack.push(B8, RAX) asm += stack.push(Size.B8, RAX)
asm asm
} }
private def funcPrologue()(using stack: Stack): Chain[AsmLine] = { private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
var asm = Chain.empty[AsmLine] var asm = Chain.empty[AsmLine]
asm += stack.push(Q64, RBP) asm += stack.push(Size.Q64, RBP)
asm += Move(RBP, Register(Q64, SP)) asm += Move(RBP, Register(Size.Q64, SP))
asm asm
} }
private def funcEpilogue(): Chain[AsmLine] = { private def funcEpilogue(): Chain[AsmLine] = {
var asm = Chain.empty[AsmLine] var asm = Chain.empty[AsmLine]
asm += Move(Register(Q64, SP), RBP) asm += Move(Register(Size.Q64, SP), RBP)
asm += Pop(RBP) asm += Pop(RBP)
asm += assemblyIR.Return() asm += assemblyIR.Return()
asm asm
} }
def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) def stackAlign: AsmLine = And(Register(Size.Q64, SP), ImmediateVal(-16))
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
case Q64 | D32 => Chain.empty case Size.Q64 | Size.D32 => Chain.empty
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1))) case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
} }
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

View File

@@ -9,8 +9,6 @@ object sizeExtensions {
/** Calculate the size (bytes) of the heap required for the expression. */ /** Calculate the size (bytes) of the heap required for the expression. */
def heapSize: Int = (expr, expr.ty) match { def heapSize: Int = (expr, expr.ty) match {
case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) =>
KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt
case (ArrayLiter(elems), ty) => case (ArrayLiter(elems), ty) =>
KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt
case _ => expr.ty.size.toInt case _ => expr.ty.size.toInt