refactor: merge MemLocation with IndexAddress
Merge request lab2425_spring/WACC_37!37
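The assembly IR previously carried two overlapping operand types: MemLocation(pointer, opSize) for sized dereferences and IndexAddress(base, offset, indexReg, scale) for address arithmetic. This commit folds them into a single MemLocation(base, offset, scaledIndex, opSize) in which every field past the base is optional. A minimal before/after sketch, using call sites taken from the diff itself:

    // before: separate types, mandatory operand size
    Load(RDI, IndexAddress(RIP, getErrLabel))
    Compare(MemLocation(RAX, D32), ECX)

    // after: one type; the size is attached only where an instruction needs it
    Load(RDI, MemLocation(RIP, getErrLabel))
    Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
    Move(Register(x.ty.elemSize, AX), MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt)))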
@@ -28,7 +28,7 @@ object RuntimeError {
 
   protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
     stackAlign,
-    Load(RDI, IndexAddress(RIP, getErrLabel)),
+    Load(RDI, MemLocation(RIP, getErrLabel)),
     assemblyIR.Call(CLibFunc.PrintF),
     Move(RDI, ImmediateVal(-1)),
     assemblyIR.Call(CLibFunc.Exit)
@@ -43,7 +43,7 @@ object RuntimeError {
   protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
     Pop(RSI),
     stackAlign,
-    Load(RDI, IndexAddress(RIP, getErrLabel)),
+    Load(RDI, MemLocation(RIP, getErrLabel)),
     assemblyIR.Call(CLibFunc.PrintF),
     Move(RDI, ImmediateVal(ERROR_CODE)),
     assemblyIR.Call(CLibFunc.Exit)
@@ -57,7 +57,7 @@ object RuntimeError {
 
   protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
     stackAlign,
-    Load(RDI, IndexAddress(RIP, getErrLabel)),
+    Load(RDI, MemLocation(RIP, getErrLabel)),
     assemblyIR.Call(CLibFunc.PrintF),
     Move(RDI, ImmediateVal(ERROR_CODE)),
     assemblyIR.Call(CLibFunc.Exit)
@@ -71,7 +71,7 @@ object RuntimeError {
 
   protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
     stackAlign,
-    Load(RDI, IndexAddress(RIP, getErrLabel)),
+    Load(RDI, MemLocation(RIP, getErrLabel)),
     assemblyIR.Call(CLibFunc.PrintF),
     Move(RDI, ImmediateVal(ERROR_CODE)),
     assemblyIR.Call(CLibFunc.Exit)
@@ -86,7 +86,7 @@ object RuntimeError {
   protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
     Move(RSI, RCX),
     stackAlign,
-    Load(RDI, IndexAddress(RIP, getErrLabel)),
+    Load(RDI, MemLocation(RIP, getErrLabel)),
     assemblyIR.Call(CLibFunc.PrintF),
     Move(RDI, ImmediateVal(ERROR_CODE)),
     assemblyIR.Call(CLibFunc.Exit)
@@ -99,7 +99,7 @@ object RuntimeError {
 
   def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
     stackAlign,
-    Load(RDI, IndexAddress(RIP, getErrLabel)),
+    Load(RDI, MemLocation(RIP, getErrLabel)),
     assemblyIR.Call(CLibFunc.PrintF),
     Move(RDI, ImmediateVal(ERROR_CODE)),
     assemblyIR.Call(CLibFunc.Exit)
@@ -79,12 +79,12 @@ class Stack {
     lines
   }
 
-  /** Get an IndexAddress for a variable in the stack. */
-  def accessVar(ident: mw.Ident): IndexAddress =
-    IndexAddress(RSP, sizeBytes - stack(ident).bottom)
+  /** Get a MemLocation for a variable in the stack. */
+  def accessVar(ident: mw.Ident): MemLocation =
+    MemLocation(RSP, sizeBytes - stack(ident).bottom)
 
   def contains(ident: mw.Ident): Boolean = stack.contains(ident)
-  def head: MemLocation = MemLocation(RSP, stack.last._2.size)
+  def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size))
 
   override def toString(): String = stack.toString
 }
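Since accessVar now returns a size-less MemLocation, callers attach or strip the operand size explicitly; both patterns appear in asmGenerator below:

    // storing a full quadword into a variable's Q64 stack slot: pin the size
    asm += Move(stack.accessVar(ident).copy(opSize = Some(Size.Q64)), RAX)
    // re-pushing a variable: slots are always Q64, so the size annotation is dropped
    asm += stack.push(ident.ty.size, stack.accessVar(ident).copy(opSize = None))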
@@ -8,7 +8,6 @@ object asmGenerator {
   import microWacc._
   import assemblyIR._
   import assemblyIR.commonRegisters._
-  import assemblyIR.Size._
   import assemblyIR.RegName._
   import types._
   import sizeExtensions._
@@ -75,12 +74,12 @@ object asmGenerator {
     given stack: Stack = Stack()
     // Setup the stack with param 7 and up
     func.params.drop(argRegs.size).foreach(stack.reserve(_))
-    stack.reserve(Q64) // Reserve return pointer slot
+    stack.reserve(Size.Q64) // Reserve return pointer slot
     var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
     asm ++= funcPrologue()
     // Push the rest of params onto the stack for simplicity
     argRegs.zip(func.params).foreach { (reg, param) =>
-      asm += stack.push(param, Register(Q64, reg))
+      asm += stack.push(param, Register(Size.Q64, reg))
     }
     asm ++= func.body.foldMap(generateStmt(_))
     // No need for epilogue here since all user functions must return explicitly
@@ -112,8 +111,8 @@ object asmGenerator {
       Builtin.PrintCharArray,
       Chain(
         stackAlign,
-        Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
-        Move(Register(D32, SI), MemLocation(RSI, D32)),
+        Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)),
+        Move(Register(KnownType.Int.size, SI), MemLocation(RSI, opSize = Some(KnownType.Int.size))),
         assemblyIR.Call(CLibFunc.PrintF),
         Xor(RDI, RDI),
         assemblyIR.Call(CLibFunc.Fflush)
@@ -146,9 +145,9 @@ object asmGenerator {
       Builtin.Read,
       Chain(
         stackAlign,
-        Subtract(Register(Q64, SP), ImmediateVal(8)),
+        Subtract(Register(Size.Q64, SP), ImmediateVal(8)),
         Push(RSI),
-        Load(RSI, MemLocation(Register(Q64, SP), Q64)),
+        Load(RSI, MemLocation(Register(Size.Q64, SP), opSize = Some(Size.Q64))),
         assemblyIR.Call(CLibFunc.Scanf),
         Pop(RAX)
       )
@@ -162,6 +161,7 @@ object asmGenerator {
       labelGenerator: LabelGenerator
   ): Chain[AsmLine] = {
     var asm = Chain.empty[AsmLine]
+    asm += Comment(stmt.toString)
     stmt match {
       case Assign(lhs, rhs) =>
         lhs match {
@@ -169,25 +169,25 @@ object asmGenerator {
           if (!stack.contains(ident)) asm += stack.reserve(ident)
           asm ++= evalExprOntoStack(rhs)
           asm += stack.pop(RAX)
-          asm += Move(stack.accessVar(ident), RAX)
+          asm += Move(stack.accessVar(ident).copy(opSize = Some(Size.Q64)), RAX)
         case ArrayElem(x, i) =>
           asm ++= evalExprOntoStack(rhs)
           asm ++= evalExprOntoStack(i)
           asm += stack.pop(RCX)
           asm += Compare(ECX, ImmediateVal(0))
           asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
-          asm += stack.push(Q64, RCX)
+          asm += stack.push(KnownType.Int.size, RCX)
          asm ++= evalExprOntoStack(x)
           asm += stack.pop(RAX)
           asm += stack.pop(RCX)
-          asm += Compare(EAX, ImmediateVal(0))
+          asm += Compare(RAX, ImmediateVal(0))
           asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
-          asm += Compare(MemLocation(RAX, D32), ECX)
+          asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
           asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
           asm += stack.pop(RDX)
 
           asm += Move(
-            IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
+            MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt)),
             Register(x.ty.elemSize, DX)
           )
       }
@@ -248,28 +248,38 @@ object asmGenerator {
     expr match {
       case IntLiter(v) => asm += stack.push(KnownType.Int.size, ImmediateVal(v))
      case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
-      case ident: Ident => asm += stack.push(ident.ty.size, stack.accessVar(ident))
+      case ident: Ident =>
+        val location = stack.accessVar(ident)
+        // items in stack are guaranteed to be in Q64 slots,
+        // so we are safe to wipe the opSize from the memory location
+        asm += stack.push(ident.ty.size, location.copy(opSize = None))
 
       case array @ ArrayLiter(elems) =>
         expr.ty match {
           case KnownType.String =>
            val str = elems.collect { case CharLiter(v) => v }.mkString
-            asm += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
-            asm += stack.push(Q64, RAX)
+            asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str)))
+            asm += stack.push(KnownType.String.size, RAX)
           case ty =>
             asm ++= generateCall(
              microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
               isTail = false
             )
-            asm += stack.push(Q64, RAX)
+            asm += stack.push(KnownType.Array(?).size, RAX)
             // Store the length of the array at the start
-            asm += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
+            asm += Move(
+              MemLocation(RAX, opSize = Some(KnownType.Int.size)),
+              ImmediateVal(elems.size)
+            )
             elems.zipWithIndex.foldMap { (elem, i) =>
               asm ++= evalExprOntoStack(elem)
               asm += stack.pop(RCX)
               asm += stack.pop(RAX)
-              asm += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
-              asm += stack.push(Q64, RAX)
+              asm += Move(
+                MemLocation(RAX, KnownType.Int.size.toInt + i * ty.elemSize.toInt),
+                Register(ty.elemSize, CX)
+              )
+              asm += stack.push(KnownType.Array(?).size, RAX)
             }
         }
 
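For orientation, the heap layout these lines produce: the length occupies the first KnownType.Int.size (4) bytes of the malloc'd block, and element i lands at byte offset KnownType.Int.size + i * elemSize, matching heapSize further down. A hedged sketch for an int array literal [1, 2, 3] (registers and element order illustrative):

    // block of 4 + 3 * 4 = 16 bytes returned by Builtin.Malloc in RAX
    Move(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ImmediateVal(3)) // length at +0
    Move(MemLocation(RAX, 4 + 0 * 4), Register(Size.D32, CX))                  // elem 0 at +4
    Move(MemLocation(RAX, 4 + 2 * 4), Register(Size.D32, CX))                  // elem 2 at +12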
@@ -287,14 +297,14 @@ object asmGenerator {
           asm += Compare(RCX, ImmediateVal(0))
           asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
           asm += stack.pop(RAX)
-          asm += Compare(EAX, ImmediateVal(0))
+          asm += Compare(RAX, ImmediateVal(0))
           asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
-          asm += Compare(MemLocation(RAX, D32), ECX)
+          asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
           asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
           // + Int because we store the length of the array at the start
           asm += Move(
             Register(x.ty.elemSize, AX),
-            IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
+            MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt))
           )
           asm += stack.push(x.ty.elemSize, RAX)
         case UnaryOp(x, op) =>
@@ -308,14 +318,14 @@ object asmGenerator {
             case UnaryOperator.Ord => // No op needed
             case UnaryOperator.Len =>
               asm += stack.pop(RAX)
-              asm += Move(EAX, MemLocation(RAX, D32))
-              asm += stack.push(D32, RAX)
+              asm += Move(EAX, MemLocation(RAX, opSize = Some(KnownType.Int.size)))
+              asm += stack.push(KnownType.Int.size, RAX)
             case UnaryOperator.Negate =>
               asm += Xor(EAX, EAX)
               asm += Subtract(EAX, stack.head)
               asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
               asm += stack.drop()
-              asm += stack.push(Q64, RAX)
+              asm += stack.push(KnownType.Int.size, RAX)
             case UnaryOperator.Not =>
               asm += Xor(stack.head, ImmediateVal(1))
           }
@@ -376,7 +386,7 @@ object asmGenerator {
       stack.size == stackSizeStart + 1,
       "Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
     )
-    asm ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
+    asm ++= zeroRest(stack.head.copy(opSize = Some(Size.Q64)), expr.ty.size)
     asm
   }
 
@@ -397,7 +407,7 @@ object asmGenerator {
       // And set the appropriate registers
       .reverse
       .foreach { reg =>
-        asm += stack.pop(Register(Q64, reg))
+        asm += stack.pop(Register(Size.Q64, reg))
       }
 
     // Evaluate arguments 7 and up and push them onto the stack
@@ -426,32 +436,32 @@ object asmGenerator {
     var asm = Chain.empty[AsmLine]
 
     asm += Compare(destX, stack.head)
-    asm += Set(Register(B8, AX), cond)
-    asm ++= zeroRest(RAX, B8)
+    asm += Set(Register(Size.B8, AX), cond)
+    asm ++= zeroRest(RAX, Size.B8)
     asm += stack.drop()
-    asm += stack.push(B8, RAX)
+    asm += stack.push(Size.B8, RAX)
 
     asm
   }
 
   private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
     var asm = Chain.empty[AsmLine]
-    asm += stack.push(Q64, RBP)
-    asm += Move(RBP, Register(Q64, SP))
+    asm += stack.push(Size.Q64, RBP)
+    asm += Move(RBP, Register(Size.Q64, SP))
     asm
   }
 
   private def funcEpilogue(): Chain[AsmLine] = {
     var asm = Chain.empty[AsmLine]
-    asm += Move(Register(Q64, SP), RBP)
+    asm += Move(Register(Size.Q64, SP), RBP)
     asm += Pop(RBP)
     asm += assemblyIR.Return()
     asm
   }
 
-  def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
+  def stackAlign: AsmLine = And(Register(Size.Q64, SP), ImmediateVal(-16))
   private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
-    case Q64 | D32 => Chain.empty
+    case Size.Q64 | Size.D32 => Chain.empty
     case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
   }
 
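A note on stackAlign, unchanged in behaviour here: And(rsp, -16) clears the low four bits of the stack pointer, e.g. 0x7fffffffe438 & ~0xf = 0x7fffffffe430, rounding RSP down to the 16-byte boundary the SysV x86-64 ABI expects before a call.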
@@ -97,22 +97,33 @@ object assemblyIR {
     }
   }
 
-  case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src {
-    override def toString =
-      opSize.toString + s"[$pointer]"
-  }
-
-  case class IndexAddress(
+  case class MemLocation(
       base: Register,
-      offset: Int | LabelArg,
-      indexReg: Register = Register(Size.Q64, RegName.AX),
-      scale: Int = 0
+      offset: Int | LabelArg = 0,
+      // scale 0 will make register irrelevant, no other reason as to why it's RAX
+      scaledIndex: (Register, Int) = (Register(Size.Q64, RegName.AX), 0),
+      opSize: Option[Size] = None
   ) extends Dest
       with Src {
-    override def toString = if (scale != 0) {
-      s"[$base + $indexReg * $scale + $offset]"
-    } else {
-      s"[$base + $offset]"
+    def copy(
+        base: Register = this.base,
+        offset: Int | LabelArg = this.offset,
+        scaledIndex: (Register, Int) = this.scaledIndex,
+        opSize: Option[Size] = this.opSize
+    ): MemLocation = MemLocation(base, offset, scaledIndex, opSize)
+
+    override def toString(): String = {
+      val opSizeStr = opSize.map(_.toString).getOrElse("")
+      val baseStr = base.toString
+      val offsetStr = offset match {
+        case 0 => ""
+        case off => s" + $off"
+      }
+      val scaledIndexStr = scaledIndex match {
+        case (reg, scale) if scale != 0 => s" + $reg * $scale"
+        case _ => ""
+      }
+      s"$opSizeStr[$baseStr$scaledIndexStr$offsetStr]"
     }
   }
 
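A quick reading of the new toString, using the fixtures from instructionSpec at the bottom of this diff (the expected strings are the tests' own):

    MemLocation(named64BitRegister, opSize = Some(Q64)).toString
    // "qword ptr [rax]", size prefix appears only when opSize is set
    MemLocation(named64BitRegister, 32, (scratch64BitRegister, 10), Some(B8)).toString
    // "byte ptr [rax + r8 * 10 + 32]", index * scale is rendered before the offset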
@@ -145,8 +156,7 @@ object assemblyIR {
   case class Pop(op1: Src) extends Operation("pop", op1)
   // move operations
   case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2)
-  case class Load(op1: Register, op2: MemLocation | IndexAddress)
-      extends Operation("lea ", op1, op2)
+  case class Load(op1: Register, op2: MemLocation) extends Operation("lea ", op1, op2)
 
   // function call operations
   case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1)
@@ -9,8 +9,6 @@ object sizeExtensions {
 
   /** Calculate the size (bytes) of the heap required for the expression. */
   def heapSize: Int = (expr, expr.ty) match {
-    case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) =>
-      KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt
     case (ArrayLiter(elems), ty) =>
       KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt
     case _ => expr.ty.size.toInt
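The deleted char-array case appears redundant rather than wrong: assuming ty.elemSize of KnownType.Array(KnownType.Char) equals KnownType.Char.size, both branches compute the same total, e.g. for a three-char array 4 + 3 * 1 = 7 bytes (taking Int.size = 4 as implied elsewhere in this diff, and Char.size = 1 as an assumption).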
@@ -29,12 +29,19 @@ class instructionSpec extends AnyFunSuite {
     assert(scratch32BitRegister.toString == "r8d")
   }
 
-  val memLocationWithRegister = MemLocation(named64BitRegister, Q64)
+  val memLocationWithRegister = MemLocation(named64BitRegister, opSize = Some(Q64))
 
   test("mem location with register toString") {
     assert(memLocationWithRegister.toString == "qword ptr [rax]")
   }
 
+  val memLocationFull =
+    MemLocation(named64BitRegister, 32, (scratch64BitRegister, 10), Some(B8))
+
+  test("mem location with all fields toString") {
+    assert(memLocationFull.toString == "byte ptr [rax + r8 * 10 + 32]")
+  }
+
   val immediateVal = ImmediateVal(123)
 
   test("immediate value toString") {