refactor: merge MemLocation with IndexAddress #37

Merged
gk1623 merged 7 commits from single-memlocation into master 2025-02-28 18:44:49 +00:00
5 changed files with 62 additions and 40 deletions
Showing only changes of commit 1a39950a7b


@@ -28,7 +28,7 @@ object RuntimeError {
     protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, getErrLabel)),
+      Load(RDI, MemLocation(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(-1)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -43,7 +43,7 @@ object RuntimeError {
     protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       Pop(RSI),
       stackAlign,
-      Load(RDI, IndexAddress(RIP, getErrLabel)),
+      Load(RDI, MemLocation(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(ERROR_CODE)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -57,7 +57,7 @@ object RuntimeError {
     protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, getErrLabel)),
+      Load(RDI, MemLocation(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(ERROR_CODE)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -71,7 +71,7 @@ object RuntimeError {
     protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, getErrLabel)),
+      Load(RDI, MemLocation(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(ERROR_CODE)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -86,7 +86,7 @@ object RuntimeError {
     protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       Move(RSI, RCX),
       stackAlign,
-      Load(RDI, IndexAddress(RIP, getErrLabel)),
+      Load(RDI, MemLocation(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(ERROR_CODE)),
       assemblyIR.Call(CLibFunc.Exit)
@@ -99,7 +99,7 @@ object RuntimeError {
     def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
       stackAlign,
-      Load(RDI, IndexAddress(RIP, getErrLabel)),
+      Load(RDI, MemLocation(RIP, getErrLabel)),
       assemblyIR.Call(CLibFunc.PrintF),
       Move(RDI, ImmediateVal(ERROR_CODE)),
       assemblyIR.Call(CLibFunc.Exit)


@@ -79,12 +79,12 @@ class Stack {
     lines
   }
-  /** Get an IndexAddress for a variable in the stack. */
-  def accessVar(ident: mw.Ident): IndexAddress =
-    IndexAddress(RSP, sizeBytes - stack(ident).bottom)
+  /** Get an MemLocation for a variable in the stack. */
+  def accessVar(ident: mw.Ident): MemLocation =
+    MemLocation(RSP, sizeBytes - stack(ident).bottom, opSize = Some(ident.ty.size))
   def contains(ident: mw.Ident): Boolean = stack.contains(ident)
-  def head: MemLocation = MemLocation(RSP, stack.last._2.size)
+  def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size))
   override def toString(): String = stack.toString
 }
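(Illustration, not part of the diff: accessVar now returns an operand that already carries its operand size, so the size prefix appears in the rendered assembly without the caller supplying it. The offset and the dword ptr spelling below are assumptions for the example, chosen to match the qword ptr and byte ptr cases in the test file.)

    // hypothetical: a 4-byte int variable stored at rsp + 12
    val loc = MemLocation(RSP, 12, opSize = Some(D32))
    loc.toString // "dword ptr [rsp + 12]"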


@@ -112,8 +112,8 @@ object asmGenerator {
       Builtin.PrintCharArray,
       Chain(
         stackAlign,
-        Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
-        Move(Register(D32, SI), MemLocation(RSI, D32)),
+        Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)),
+        Move(Register(D32, SI), MemLocation(RSI, opSize = Some(D32))),
         assemblyIR.Call(CLibFunc.PrintF),
         Xor(RDI, RDI),
         assemblyIR.Call(CLibFunc.Fflush)
@@ -148,7 +148,7 @@ object asmGenerator {
         stackAlign,
         Subtract(Register(Q64, SP), ImmediateVal(8)),
         Push(RSI),
-        Load(RSI, MemLocation(Register(Q64, SP), Q64)),
+        Load(RSI, MemLocation(Register(Q64, SP), opSize = Some(Q64))),
         assemblyIR.Call(CLibFunc.Scanf),
         Pop(RAX)
       )
@@ -167,9 +167,10 @@ object asmGenerator {
       lhs match {
         case ident: Ident =>
           if (!stack.contains(ident)) asm += stack.reserve(ident)
+          val dest = Register(ident.ty.size, AX)
           asm ++= evalExprOntoStack(rhs)
           asm += stack.pop(RAX)
-          asm += Move(stack.accessVar(ident), RAX)
+          asm += Move(stack.accessVar(ident), dest)
         case ArrayElem(x, i) =>
           asm ++= evalExprOntoStack(rhs)
           asm ++= evalExprOntoStack(i)
@@ -182,12 +183,12 @@ object asmGenerator {
          asm += stack.pop(RCX)
          asm += Compare(EAX, ImmediateVal(0))
          asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
-          asm += Compare(MemLocation(RAX, D32), ECX)
+          asm += Compare(MemLocation(RAX, opSize = Some(D32)), ECX)
          asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
          asm += stack.pop(RDX)
          asm += Move(
-            IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
+            MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt)),
            Register(x.ty.elemSize, DX)
          )
      }
@@ -248,13 +249,17 @@ object asmGenerator {
      expr match {
        case IntLiter(v) => asm += stack.push(KnownType.Int.size, ImmediateVal(v))
        case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
-        case ident: Ident => asm += stack.push(ident.ty.size, stack.accessVar(ident))
+        case ident: Ident =>
+          val location = stack.accessVar(ident)
+          // items in stack are guaranteed to be in Q64 slots,
+          // so we are safe to wipe the opSize from the memory location
+          asm += stack.push(ident.ty.size, location.copy(opSize = None))
        case array @ ArrayLiter(elems) =>
          expr.ty match {
            case KnownType.String =>
              val str = elems.collect { case CharLiter(v) => v }.mkString
-              asm += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
+              asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str)))
              asm += stack.push(Q64, RAX)
            case ty =>
              asm ++= generateCall(
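(Sketch of the new Ident branch above, using a made-up location: the operand from accessVar is size-annotated, and copy(opSize = None) strips the annotation because, per the comment, every stack item lives in a full Q64 slot.)

    val location = MemLocation(RSP, 12, opSize = Some(D32)) // stand-in for stack.accessVar(ident)
    location.toString                       // "dword ptr [rsp + 12]"
    location.copy(opSize = None).toString   // "[rsp + 12]", no size prefix on the Q64 slot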
@@ -263,12 +268,12 @@ object asmGenerator {
              )
              asm += stack.push(Q64, RAX)
              // Store the length of the array at the start
-              asm += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
+              asm += Move(MemLocation(RAX, opSize = Some(D32)), ImmediateVal(elems.size))
              elems.zipWithIndex.foldMap { (elem, i) =>
                asm ++= evalExprOntoStack(elem)
                asm += stack.pop(RCX)
                asm += stack.pop(RAX)
-                asm += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
+                asm += Move(MemLocation(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
                asm += stack.push(Q64, RAX)
              }
          }
@@ -289,12 +294,12 @@ object asmGenerator {
          asm += stack.pop(RAX)
          asm += Compare(EAX, ImmediateVal(0))
          asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
-          asm += Compare(MemLocation(RAX, D32), ECX)
+          asm += Compare(MemLocation(RAX, opSize = Some(D32)), ECX)
          asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
          // + Int because we store the length of the array at the start
          asm += Move(
            Register(x.ty.elemSize, AX),
-            IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
+            MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt))
          )
          asm += stack.push(x.ty.elemSize, RAX)
        case UnaryOp(x, op) =>
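(Illustration of the scaled-index form that replaces IndexAddress in the two array hunks above, assuming KnownType.Int.size.toInt is 4 and a 4-byte element type; register spellings follow the test file.)

    // an int array element: skip the 4-byte length header, then index by RCX
    MemLocation(RAX, 4, (RCX, 4)).toString // "[rax + rcx * 4 + 4]"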
@@ -308,7 +313,7 @@ object asmGenerator {
          case UnaryOperator.Ord => // No op needed
          case UnaryOperator.Len =>
            asm += stack.pop(RAX)
-            asm += Move(EAX, MemLocation(RAX, D32))
+            asm += Move(EAX, MemLocation(RAX, opSize = Some(D32)))
            asm += stack.push(D32, RAX)
          case UnaryOperator.Negate =>
            asm += Xor(EAX, EAX)
@@ -376,7 +381,7 @@ object asmGenerator {
      stack.size == stackSizeStart + 1,
      "Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
    )
-    asm ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
+    asm ++= zeroRest(MemLocation(base = stack.head.base, opSize = Some(Q64)), expr.ty.size)
    asm
  }


@@ -97,22 +97,33 @@ object assemblyIR {
    }
  }
-  case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src {
-    override def toString =
-      opSize.toString + s"[$pointer]"
-  }
-  case class IndexAddress(
+  case class MemLocation(
      base: Register,
-      offset: Int | LabelArg,
-      indexReg: Register = Register(Size.Q64, RegName.AX),
-      scale: Int = 0
+      offset: Int | LabelArg = 0,
+      // scale 0 will make register irrelevant, no other reason as to why it's RAX
+      scaledIndex: (Register, Int) = (Register(Size.Q64, RegName.AX), 0),
+      opSize: Option[Size] = None
  ) extends Dest
      with Src {
-    override def toString = if (scale != 0) {
-      s"[$base + $indexReg * $scale + $offset]"
-    } else {
-      s"[$base + $offset]"
+    def copy(
+        base: Register = this.base,
+        offset: Int | LabelArg = this.offset,
+        scaledIndex: (Register, Int) = this.scaledIndex,
+        opSize: Option[Size] = this.opSize
+    ): MemLocation = MemLocation(base, offset, scaledIndex, opSize)
+    override def toString(): String = {
+      val opSizeStr = opSize.map(_.toString).getOrElse("")
+      val baseStr = base.toString
+      val offsetStr = offset match {
+        case 0 => ""
+        case off => s" + $off"
+      }
+      val scaledIndexStr = scaledIndex match {
+        case (reg, scale) if scale != 0 => s" + $reg * $scale"
+        case _ => ""
+      }
+      s"$opSizeStr[$baseStr$scaledIndexStr$offsetStr]"
    }
  }
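(A brief, non-normative summary of the merged type: the default scaledIndex carries scale 0, so the placeholder RAX index register never shows up in the operand string, and leaving opSize as None drops the size prefix. The lines below reuse call sites from this PR; the rendered strings assume Int.size.toInt is 4 and that D32 prints as "dword ptr ", in line with the sizes exercised in the tests.)

    Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)) // operand renders as "[rsi + 4]"; the (RAX, 0) index is suppressed
    Load(RDI, MemLocation(RIP, getErrLabel))              // label offsets also go through the offset field
    Move(EAX, MemLocation(RAX, opSize = Some(D32)))       // "dword ptr [rax]", the sized dereference replacing the old MemLocation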
@@ -145,8 +156,7 @@ object assemblyIR {
  case class Pop(op1: Src) extends Operation("pop", op1)
  // move operations
  case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2)
-  case class Load(op1: Register, op2: MemLocation | IndexAddress)
-      extends Operation("lea ", op1, op2)
+  case class Load(op1: Register, op2: MemLocation) extends Operation("lea ", op1, op2)
  // function call operations
  case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1)


@@ -29,12 +29,19 @@ class instructionSpec extends AnyFunSuite {
    assert(scratch32BitRegister.toString == "r8d")
  }
-  val memLocationWithRegister = MemLocation(named64BitRegister, Q64)
+  val memLocationWithRegister = MemLocation(named64BitRegister, opSize = Some(Q64))
  test("mem location with register toString") {
    assert(memLocationWithRegister.toString == "qword ptr [rax]")
  }
+  val memLocationFull =
+    MemLocation(named64BitRegister, 32, (scratch64BitRegister, 10), Some(B8))
+  test("mem location with all fields toString") {
+    assert(memLocationFull.toString == "byte ptr [rax + r8 * 10 + 32]")
+  }
  val immediateVal = ImmediateVal(123)
  test("immediate value toString") {