refactor: merge comments and extracting constants & renaming refactors
Merge request lab2425_spring/WACC_37!36 Co-authored-by: Jonny <j.sinteix@gmail.com> Co-authored-by: Guy C <gc1523@ic.ac.uk>
This commit is contained in:
@@ -17,21 +17,10 @@ sealed trait RuntimeError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
object RuntimeError {
|
object RuntimeError {
|
||||||
|
|
||||||
// TODO: Refactor to mitigate imports and redeclared vals perhaps
|
|
||||||
|
|
||||||
import wacc.asmGenerator.stackAlign
|
import wacc.asmGenerator.stackAlign
|
||||||
import assemblyIR.Size._
|
import assemblyIR.commonRegisters._
|
||||||
import assemblyIR.RegName._
|
|
||||||
|
|
||||||
// private val RAX = Register(Q64, AX)
|
private val ERROR_CODE = 255
|
||||||
// private val EAX = Register(D32, AX)
|
|
||||||
private val RDI = Register(Q64, DI)
|
|
||||||
private val RIP = Register(Q64, IP)
|
|
||||||
// private val RBP = Register(Q64, BP)
|
|
||||||
private val RSI = Register(Q64, SI)
|
|
||||||
// private val RDX = Register(Q64, DX)
|
|
||||||
// private val RCX = Register(Q64, CX)
|
|
||||||
|
|
||||||
case object ZeroDivError extends RuntimeError {
|
case object ZeroDivError extends RuntimeError {
|
||||||
val name = "errDivZero"
|
val name = "errDivZero"
|
||||||
@@ -56,7 +45,7 @@ object RuntimeError {
|
|||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +59,7 @@ object RuntimeError {
|
|||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,7 +73,7 @@ object RuntimeError {
|
|||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,15 +84,29 @@ object RuntimeError {
|
|||||||
protected val errStr = "fatal error: array index %d out of bounds"
|
protected val errStr = "fatal error: array index %d out of bounds"
|
||||||
|
|
||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
Move(RSI, Register(Q64, CX)),
|
Move(RSI, RCX),
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||||
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
case object OutOfMemoryError extends RuntimeError {
|
||||||
|
val name = "errOutOfMemory"
|
||||||
|
protected val errStr = "fatal error: out of memory"
|
||||||
|
|
||||||
|
def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
|
stackAlign,
|
||||||
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
|
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
val all: Chain[RuntimeError] =
|
val all: Chain[RuntimeError] =
|
||||||
Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError)
|
Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError,
|
||||||
|
OutOfMemoryError)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,26 +7,28 @@ import wacc.RuntimeError._
|
|||||||
object asmGenerator {
|
object asmGenerator {
|
||||||
import microWacc._
|
import microWacc._
|
||||||
import assemblyIR._
|
import assemblyIR._
|
||||||
|
import assemblyIR.commonRegisters._
|
||||||
import assemblyIR.Size._
|
import assemblyIR.Size._
|
||||||
import assemblyIR.RegName._
|
import assemblyIR.RegName._
|
||||||
import types._
|
import types._
|
||||||
import sizeExtensions._
|
import sizeExtensions._
|
||||||
import lexer.escapedChars
|
import lexer.escapedChars
|
||||||
|
|
||||||
private val RAX = Register(Q64, AX)
|
|
||||||
private val EAX = Register(D32, AX)
|
|
||||||
private val RDI = Register(Q64, DI)
|
|
||||||
private val RIP = Register(Q64, IP)
|
|
||||||
private val RBP = Register(Q64, BP)
|
|
||||||
private val RSI = Register(Q64, SI)
|
|
||||||
private val RDX = Register(Q64, DX)
|
|
||||||
private val RCX = Register(Q64, CX)
|
|
||||||
private val ECX = Register(D32, CX)
|
|
||||||
private val argRegs = List(DI, SI, DX, CX, R8, R9)
|
private val argRegs = List(DI, SI, DX, CX, R8, R9)
|
||||||
|
|
||||||
|
private val _7_BIT_MASK = 0x7f
|
||||||
|
|
||||||
extension [T](chain: Chain[T])
|
extension [T](chain: Chain[T])
|
||||||
def +(item: T): Chain[T] = chain.append(item)
|
def +(item: T): Chain[T] = chain.append(item)
|
||||||
|
|
||||||
|
/** Concatenates multiple `Chain[T]` instances into a single `Chain[T]`, appending them to the
|
||||||
|
* current `Chain`.
|
||||||
|
*
|
||||||
|
* @param chains
|
||||||
|
* A variable number of `Chain[T]` instances to concatenate.
|
||||||
|
* @return
|
||||||
|
* A new `Chain[T]` containing all elements from `chain` concatenated with `chains`.
|
||||||
|
*/
|
||||||
def concatAll(chains: Chain[T]*): Chain[T] =
|
def concatAll(chains: Chain[T]*): Chain[T] =
|
||||||
chains.foldLeft(chain)(_ ++ _)
|
chains.foldLeft(chain)(_ ++ _)
|
||||||
|
|
||||||
@@ -60,11 +62,11 @@ object asmGenerator {
|
|||||||
stack: Stack,
|
stack: Stack,
|
||||||
labelGenerator: LabelGenerator
|
labelGenerator: LabelGenerator
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
|
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
|
||||||
chain ++= funcPrologue()
|
asm ++= funcPrologue()
|
||||||
chain ++= funcBody
|
asm ++= funcBody
|
||||||
chain ++= funcEpilogue()
|
asm ++= funcEpilogue()
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def generateUserFunc(func: FuncDecl)(using
|
private def generateUserFunc(func: FuncDecl)(using
|
||||||
@@ -74,29 +76,29 @@ object asmGenerator {
|
|||||||
// Setup the stack with param 7 and up
|
// Setup the stack with param 7 and up
|
||||||
func.params.drop(argRegs.size).foreach(stack.reserve(_))
|
func.params.drop(argRegs.size).foreach(stack.reserve(_))
|
||||||
stack.reserve(Q64) // Reserve return pointer slot
|
stack.reserve(Q64) // Reserve return pointer slot
|
||||||
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
|
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
|
||||||
chain ++= funcPrologue()
|
asm ++= funcPrologue()
|
||||||
// Push the rest of params onto the stack for simplicity
|
// Push the rest of params onto the stack for simplicity
|
||||||
argRegs.zip(func.params).foreach { (reg, param) =>
|
argRegs.zip(func.params).foreach { (reg, param) =>
|
||||||
chain += stack.push(param, Register(Q64, reg))
|
asm += stack.push(param, Register(Q64, reg))
|
||||||
}
|
}
|
||||||
chain ++= func.body.foldMap(generateStmt(_))
|
asm ++= func.body.foldMap(generateStmt(_))
|
||||||
// No need for epilogue here since all user functions must return explicitly
|
// No need for epilogue here since all user functions must return explicitly
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def generateBuiltInFuncs()(using
|
private def generateBuiltInFuncs()(using
|
||||||
stack: Stack,
|
stack: Stack,
|
||||||
labelGenerator: LabelGenerator
|
labelGenerator: LabelGenerator
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
asm ++= wrapBuiltinFunc(
|
||||||
Builtin.Exit,
|
Builtin.Exit,
|
||||||
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
|
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
asm ++= wrapBuiltinFunc(
|
||||||
Builtin.Printf,
|
Builtin.Printf,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
@@ -106,7 +108,7 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
asm ++= wrapBuiltinFunc(
|
||||||
Builtin.PrintCharArray,
|
Builtin.PrintCharArray,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
@@ -118,13 +120,18 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
asm ++= wrapBuiltinFunc(
|
||||||
Builtin.Malloc,
|
Builtin.Malloc,
|
||||||
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
|
Chain(
|
||||||
// Out of memory check is optional
|
stackAlign,
|
||||||
|
assemblyIR.Call(CLibFunc.Malloc),
|
||||||
|
// Out of memory check
|
||||||
|
Compare(RAX, ImmediateVal(0)),
|
||||||
|
Jump(labelGenerator.getLabelArg(OutOfMemoryError), Cond.Equal)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
asm ++= wrapBuiltinFunc(
|
||||||
Builtin.Free,
|
Builtin.Free,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
@@ -135,7 +142,7 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
asm ++= wrapBuiltinFunc(
|
||||||
Builtin.Read,
|
Builtin.Read,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
@@ -147,39 +154,39 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def generateStmt(stmt: Stmt)(using
|
private def generateStmt(stmt: Stmt)(using
|
||||||
stack: Stack,
|
stack: Stack,
|
||||||
labelGenerator: LabelGenerator
|
labelGenerator: LabelGenerator
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
stmt match {
|
stmt match {
|
||||||
case Assign(lhs, rhs) =>
|
case Assign(lhs, rhs) =>
|
||||||
lhs match {
|
lhs match {
|
||||||
case ident: Ident =>
|
case ident: Ident =>
|
||||||
if (!stack.contains(ident)) chain += stack.reserve(ident)
|
if (!stack.contains(ident)) asm += stack.reserve(ident)
|
||||||
chain ++= evalExprOntoStack(rhs)
|
asm ++= evalExprOntoStack(rhs)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += Move(stack.accessVar(ident), RAX)
|
asm += Move(stack.accessVar(ident), RAX)
|
||||||
case ArrayElem(x, i) =>
|
case ArrayElem(x, i) =>
|
||||||
chain ++= evalExprOntoStack(rhs)
|
asm ++= evalExprOntoStack(rhs)
|
||||||
chain ++= evalExprOntoStack(i)
|
asm ++= evalExprOntoStack(i)
|
||||||
chain += stack.pop(RCX)
|
asm += stack.pop(RCX)
|
||||||
chain += Compare(ECX, ImmediateVal(0))
|
asm += Compare(ECX, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||||
chain += stack.push(Q64, RCX)
|
asm += stack.push(Q64, RCX)
|
||||||
chain ++= evalExprOntoStack(x)
|
asm ++= evalExprOntoStack(x)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += stack.pop(RCX)
|
asm += stack.pop(RCX)
|
||||||
chain += Compare(EAX, ImmediateVal(0))
|
asm += Compare(EAX, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
asm += Compare(MemLocation(RAX, D32), ECX)
|
||||||
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||||
chain += stack.pop(RDX)
|
asm += stack.pop(RDX)
|
||||||
|
|
||||||
chain += Move(
|
asm += Move(
|
||||||
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
|
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
|
||||||
Register(x.ty.elemSize, DX)
|
Register(x.ty.elemSize, DX)
|
||||||
)
|
)
|
||||||
@@ -189,250 +196,257 @@ object asmGenerator {
|
|||||||
val elseLabel = labelGenerator.getLabel()
|
val elseLabel = labelGenerator.getLabel()
|
||||||
val endLabel = labelGenerator.getLabel()
|
val endLabel = labelGenerator.getLabel()
|
||||||
|
|
||||||
chain ++= evalExprOntoStack(cond)
|
asm ++= evalExprOntoStack(cond)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += Compare(RAX, ImmediateVal(0))
|
asm += Compare(RAX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(elseLabel), Cond.Equal)
|
asm += Jump(LabelArg(elseLabel), Cond.Equal)
|
||||||
|
|
||||||
chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
|
asm ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
|
||||||
chain += Jump(LabelArg(endLabel))
|
asm += Jump(LabelArg(endLabel))
|
||||||
chain += LabelDef(elseLabel)
|
asm += LabelDef(elseLabel)
|
||||||
|
|
||||||
chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
|
asm ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
|
||||||
chain += LabelDef(endLabel)
|
asm += LabelDef(endLabel)
|
||||||
|
|
||||||
case While(cond, body) =>
|
case While(cond, body) =>
|
||||||
val startLabel = labelGenerator.getLabel()
|
val startLabel = labelGenerator.getLabel()
|
||||||
val endLabel = labelGenerator.getLabel()
|
val endLabel = labelGenerator.getLabel()
|
||||||
|
|
||||||
chain += LabelDef(startLabel)
|
asm += LabelDef(startLabel)
|
||||||
chain ++= evalExprOntoStack(cond)
|
asm ++= evalExprOntoStack(cond)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += Compare(RAX, ImmediateVal(0))
|
asm += Compare(RAX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(endLabel), Cond.Equal)
|
asm += Jump(LabelArg(endLabel), Cond.Equal)
|
||||||
|
|
||||||
chain ++= stack.withScope(() => body.foldMap(generateStmt))
|
asm ++= stack.withScope(() => body.foldMap(generateStmt))
|
||||||
chain += Jump(LabelArg(startLabel))
|
asm += Jump(LabelArg(startLabel))
|
||||||
chain += LabelDef(endLabel)
|
asm += LabelDef(endLabel)
|
||||||
|
|
||||||
case call: microWacc.Call =>
|
case call: microWacc.Call =>
|
||||||
chain ++= generateCall(call, isTail = false)
|
asm ++= generateCall(call, isTail = false)
|
||||||
|
|
||||||
case microWacc.Return(expr) =>
|
case microWacc.Return(expr) =>
|
||||||
expr match {
|
expr match {
|
||||||
case call: microWacc.Call =>
|
case call: microWacc.Call =>
|
||||||
chain ++= generateCall(call, isTail = true) // tco
|
asm ++= generateCall(call, isTail = true) // tco
|
||||||
case _ =>
|
case _ =>
|
||||||
chain ++= evalExprOntoStack(expr)
|
asm ++= evalExprOntoStack(expr)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain ++= funcEpilogue()
|
asm ++= funcEpilogue()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def evalExprOntoStack(expr: Expr)(using
|
private def evalExprOntoStack(expr: Expr)(using
|
||||||
stack: Stack,
|
stack: Stack,
|
||||||
labelGenerator: LabelGenerator
|
labelGenerator: LabelGenerator
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
val stackSizeStart = stack.size
|
val stackSizeStart = stack.size
|
||||||
expr match {
|
expr match {
|
||||||
case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v))
|
case IntLiter(v) => asm += stack.push(KnownType.Int.size, ImmediateVal(v))
|
||||||
case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
|
case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
|
||||||
case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident))
|
case ident: Ident => asm += stack.push(ident.ty.size, stack.accessVar(ident))
|
||||||
|
|
||||||
case array @ ArrayLiter(elems) =>
|
case array @ ArrayLiter(elems) =>
|
||||||
expr.ty match {
|
expr.ty match {
|
||||||
case KnownType.String =>
|
case KnownType.String =>
|
||||||
val str = elems.collect { case CharLiter(v) => v }.mkString
|
val str = elems.collect { case CharLiter(v) => v }.mkString
|
||||||
chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
|
asm += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
|
||||||
chain += stack.push(Q64, RAX)
|
asm += stack.push(Q64, RAX)
|
||||||
case ty =>
|
case ty =>
|
||||||
chain ++= generateCall(
|
asm ++= generateCall(
|
||||||
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
|
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
|
||||||
isTail = false
|
isTail = false
|
||||||
)
|
)
|
||||||
chain += stack.push(Q64, RAX)
|
asm += stack.push(Q64, RAX)
|
||||||
// Store the length of the array at the start
|
// Store the length of the array at the start
|
||||||
chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
|
asm += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
|
||||||
elems.zipWithIndex.foldMap { (elem, i) =>
|
elems.zipWithIndex.foldMap { (elem, i) =>
|
||||||
chain ++= evalExprOntoStack(elem)
|
asm ++= evalExprOntoStack(elem)
|
||||||
chain += stack.pop(RCX)
|
asm += stack.pop(RCX)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
|
asm += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
|
||||||
chain += stack.push(Q64, RAX)
|
asm += stack.push(Q64, RAX)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case BoolLiter(true) =>
|
case BoolLiter(true) =>
|
||||||
chain += stack.push(KnownType.Bool.size, ImmediateVal(1))
|
asm += stack.push(KnownType.Bool.size, ImmediateVal(1))
|
||||||
case BoolLiter(false) =>
|
case BoolLiter(false) =>
|
||||||
chain += Xor(RAX, RAX)
|
asm += Xor(RAX, RAX)
|
||||||
chain += stack.push(KnownType.Bool.size, RAX)
|
asm += stack.push(KnownType.Bool.size, RAX)
|
||||||
case NullLiter() =>
|
case NullLiter() =>
|
||||||
chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
|
asm += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
|
||||||
case ArrayElem(x, i) =>
|
case ArrayElem(x, i) =>
|
||||||
chain ++= evalExprOntoStack(x)
|
asm ++= evalExprOntoStack(x)
|
||||||
chain ++= evalExprOntoStack(i)
|
asm ++= evalExprOntoStack(i)
|
||||||
chain += stack.pop(RCX)
|
asm += stack.pop(RCX)
|
||||||
chain += Compare(RCX, ImmediateVal(0))
|
asm += Compare(RCX, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += Compare(EAX, ImmediateVal(0))
|
asm += Compare(EAX, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
asm += Compare(MemLocation(RAX, D32), ECX)
|
||||||
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||||
// + Int because we store the length of the array at the start
|
// + Int because we store the length of the array at the start
|
||||||
chain += Move(
|
asm += Move(
|
||||||
Register(x.ty.elemSize, AX),
|
Register(x.ty.elemSize, AX),
|
||||||
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
|
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
|
||||||
)
|
)
|
||||||
chain += stack.push(x.ty.elemSize, RAX)
|
asm += stack.push(x.ty.elemSize, RAX)
|
||||||
case UnaryOp(x, op) =>
|
case UnaryOp(x, op) =>
|
||||||
chain ++= evalExprOntoStack(x)
|
asm ++= evalExprOntoStack(x)
|
||||||
op match {
|
op match {
|
||||||
case UnaryOperator.Chr =>
|
case UnaryOperator.Chr =>
|
||||||
chain += Move(EAX, stack.head)
|
asm += Move(EAX, stack.head)
|
||||||
chain += And(EAX, ImmediateVal(-128))
|
asm += And(EAX, ImmediateVal(~_7_BIT_MASK))
|
||||||
chain += Compare(EAX, ImmediateVal(0))
|
asm += Compare(EAX, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
|
asm += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
|
||||||
case UnaryOperator.Ord => // No op needed
|
case UnaryOperator.Ord => // No op needed
|
||||||
case UnaryOperator.Len =>
|
case UnaryOperator.Len =>
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
chain += Move(EAX, MemLocation(RAX, D32))
|
asm += Move(EAX, MemLocation(RAX, D32))
|
||||||
chain += stack.push(D32, RAX)
|
asm += stack.push(D32, RAX)
|
||||||
case UnaryOperator.Negate =>
|
case UnaryOperator.Negate =>
|
||||||
chain += Xor(EAX, EAX)
|
asm += Xor(EAX, EAX)
|
||||||
chain += Subtract(EAX, stack.head)
|
asm += Subtract(EAX, stack.head)
|
||||||
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
chain += stack.drop()
|
asm += stack.drop()
|
||||||
chain += stack.push(Q64, RAX)
|
asm += stack.push(Q64, RAX)
|
||||||
case UnaryOperator.Not =>
|
case UnaryOperator.Not =>
|
||||||
chain += Xor(stack.head, ImmediateVal(1))
|
asm += Xor(stack.head, ImmediateVal(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
case BinaryOp(x, y, op) =>
|
case BinaryOp(x, y, op) =>
|
||||||
val destX = Register(x.ty.size, AX)
|
val destX = Register(x.ty.size, AX)
|
||||||
chain ++= evalExprOntoStack(y)
|
asm ++= evalExprOntoStack(y)
|
||||||
chain ++= evalExprOntoStack(x)
|
asm ++= evalExprOntoStack(x)
|
||||||
chain += stack.pop(RAX)
|
asm += stack.pop(RAX)
|
||||||
|
|
||||||
op match {
|
op match {
|
||||||
case BinaryOperator.Add =>
|
case BinaryOperator.Add =>
|
||||||
chain += Add(stack.head, destX)
|
asm += Add(stack.head, destX)
|
||||||
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
case BinaryOperator.Sub =>
|
case BinaryOperator.Sub =>
|
||||||
chain += Subtract(destX, stack.head)
|
asm += Subtract(destX, stack.head)
|
||||||
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
chain += stack.drop()
|
asm += stack.drop()
|
||||||
chain += stack.push(destX.size, RAX)
|
asm += stack.push(destX.size, RAX)
|
||||||
case BinaryOperator.Mul =>
|
case BinaryOperator.Mul =>
|
||||||
chain += Multiply(destX, stack.head)
|
asm += Multiply(destX, stack.head)
|
||||||
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
chain += stack.drop()
|
asm += stack.drop()
|
||||||
chain += stack.push(destX.size, RAX)
|
asm += stack.push(destX.size, RAX)
|
||||||
|
|
||||||
case BinaryOperator.Div =>
|
case BinaryOperator.Div =>
|
||||||
chain += Compare(stack.head, ImmediateVal(0))
|
asm += Compare(stack.head, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||||
chain += CDQ()
|
asm += CDQ()
|
||||||
chain += Divide(stack.head)
|
asm += Divide(stack.head)
|
||||||
chain += stack.drop()
|
asm += stack.drop()
|
||||||
chain += stack.push(destX.size, RAX)
|
asm += stack.push(destX.size, RAX)
|
||||||
|
|
||||||
case BinaryOperator.Mod =>
|
case BinaryOperator.Mod =>
|
||||||
chain += Compare(stack.head, ImmediateVal(0))
|
asm += Compare(stack.head, ImmediateVal(0))
|
||||||
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||||
chain += CDQ()
|
asm += CDQ()
|
||||||
chain += Divide(stack.head)
|
asm += Divide(stack.head)
|
||||||
chain += stack.drop()
|
asm += stack.drop()
|
||||||
chain += stack.push(destX.size, RDX)
|
asm += stack.push(destX.size, RDX)
|
||||||
|
|
||||||
case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal)
|
case BinaryOperator.Eq => asm ++= generateComparison(destX, Cond.Equal)
|
||||||
case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual)
|
case BinaryOperator.Neq => asm ++= generateComparison(destX, Cond.NotEqual)
|
||||||
case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater)
|
case BinaryOperator.Greater => asm ++= generateComparison(destX, Cond.Greater)
|
||||||
case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual)
|
case BinaryOperator.GreaterEq => asm ++= generateComparison(destX, Cond.GreaterEqual)
|
||||||
case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less)
|
case BinaryOperator.Less => asm ++= generateComparison(destX, Cond.Less)
|
||||||
case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual)
|
case BinaryOperator.LessEq => asm ++= generateComparison(destX, Cond.LessEqual)
|
||||||
case BinaryOperator.And => chain += And(stack.head, destX)
|
case BinaryOperator.And => asm += And(stack.head, destX)
|
||||||
case BinaryOperator.Or => chain += Or(stack.head, destX)
|
case BinaryOperator.Or => asm += Or(stack.head, destX)
|
||||||
}
|
}
|
||||||
|
|
||||||
case call: microWacc.Call =>
|
case call: microWacc.Call =>
|
||||||
chain ++= generateCall(call, isTail = false)
|
asm ++= generateCall(call, isTail = false)
|
||||||
chain += stack.push(call.ty.size, RAX)
|
asm += stack.push(call.ty.size, RAX)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(stack.size == stackSizeStart + 1)
|
assert(
|
||||||
chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
|
stack.size == stackSizeStart + 1,
|
||||||
chain
|
"Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
|
||||||
|
)
|
||||||
|
asm ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
|
||||||
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
|
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
|
||||||
stack: Stack,
|
stack: Stack,
|
||||||
labelGenerator: LabelGenerator
|
labelGenerator: LabelGenerator
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
val microWacc.Call(target, args) = call
|
val microWacc.Call(target, args) = call
|
||||||
|
|
||||||
|
// Evaluate arguments 0-6
|
||||||
argRegs
|
argRegs
|
||||||
.zip(args)
|
.zip(args)
|
||||||
.map { (reg, expr) =>
|
.map { (reg, expr) =>
|
||||||
chain ++= evalExprOntoStack(expr)
|
asm ++= evalExprOntoStack(expr)
|
||||||
reg
|
reg
|
||||||
}
|
}
|
||||||
|
// And set the appropriate registers
|
||||||
.reverse
|
.reverse
|
||||||
.foreach { reg =>
|
.foreach { reg =>
|
||||||
chain += stack.pop(Register(Q64, reg))
|
asm += stack.pop(Register(Q64, reg))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Evaluate arguments 7 and up and push them onto the stack
|
||||||
args.drop(argRegs.size).foldMap {
|
args.drop(argRegs.size).foldMap {
|
||||||
chain ++= evalExprOntoStack(_)
|
asm ++= evalExprOntoStack(_)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tail Call Optimisation (TCO)
|
// Tail Call Optimisation (TCO)
|
||||||
if (isTail) {
|
if (isTail) {
|
||||||
chain += Jump(labelGenerator.getLabelArg(target)) // tail call
|
asm += Jump(labelGenerator.getLabelArg(target)) // tail call
|
||||||
} else {
|
} else {
|
||||||
chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
|
asm += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove arguments 7 and up from the stack
|
||||||
if (args.size > argRegs.size) {
|
if (args.size > argRegs.size) {
|
||||||
chain += stack.drop(args.size - argRegs.size)
|
asm += stack.drop(args.size - argRegs.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def generateComparison(destX: Register, cond: Cond)(using
|
private def generateComparison(destX: Register, cond: Cond)(using
|
||||||
stack: Stack
|
stack: Stack
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
|
|
||||||
chain += Compare(destX, stack.head)
|
asm += Compare(destX, stack.head)
|
||||||
chain += Set(Register(B8, AX), cond)
|
asm += Set(Register(B8, AX), cond)
|
||||||
chain ++= zeroRest(RAX, B8)
|
asm ++= zeroRest(RAX, B8)
|
||||||
chain += stack.drop()
|
asm += stack.drop()
|
||||||
chain += stack.push(B8, RAX)
|
asm += stack.push(B8, RAX)
|
||||||
|
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
|
private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
chain += stack.push(Q64, RBP)
|
asm += stack.push(Q64, RBP)
|
||||||
chain += Move(RBP, Register(Q64, SP))
|
asm += Move(RBP, Register(Q64, SP))
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
private def funcEpilogue(): Chain[AsmLine] = {
|
private def funcEpilogue(): Chain[AsmLine] = {
|
||||||
var chain = Chain.empty[AsmLine]
|
var asm = Chain.empty[AsmLine]
|
||||||
chain += Move(Register(Q64, SP), RBP)
|
asm += Move(Register(Q64, SP), RBP)
|
||||||
chain += Pop(RBP)
|
asm += Pop(RBP)
|
||||||
chain += assemblyIR.Return()
|
asm += assemblyIR.Return()
|
||||||
chain
|
asm
|
||||||
}
|
}
|
||||||
|
|
||||||
def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
|
def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
|
||||||
|
|||||||
@@ -102,7 +102,6 @@ object assemblyIR {
|
|||||||
opSize.toString + s"[$pointer]"
|
opSize.toString + s"[$pointer]"
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO to string is wacky
|
|
||||||
case class IndexAddress(
|
case class IndexAddress(
|
||||||
base: Register,
|
base: Register,
|
||||||
offset: Int | LabelArg,
|
offset: Int | LabelArg,
|
||||||
@@ -125,36 +124,37 @@ object assemblyIR {
|
|||||||
override def toString = name
|
override def toString = name
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO Check if dest and src are not both memory locations
|
|
||||||
abstract class Operation(ins: String, ops: Operand*) extends AsmLine {
|
abstract class Operation(ins: String, ops: Operand*) extends AsmLine {
|
||||||
override def toString: String = s"\t$ins ${ops.mkString(", ")}"
|
override def toString: String = s"\t$ins ${ops.mkString(", ")}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// arithmetic operations
|
||||||
case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2)
|
case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2)
|
||||||
case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2)
|
case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2)
|
||||||
case class Multiply(ops: Operand*) extends Operation("imul", ops*)
|
case class Multiply(ops: Operand*) extends Operation("imul", ops*)
|
||||||
case class Divide(op1: Src) extends Operation("idiv", op1)
|
case class Divide(op1: Src) extends Operation("idiv", op1)
|
||||||
case class Negate(op: Dest) extends Operation("neg", op)
|
case class Negate(op: Dest) extends Operation("neg", op)
|
||||||
|
// bitwise operations
|
||||||
case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2)
|
case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2)
|
||||||
case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2)
|
case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2)
|
||||||
case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2)
|
case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2)
|
||||||
case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2)
|
case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2)
|
||||||
|
case class CDQ() extends Operation("cdq")
|
||||||
// stack operations
|
// stack operations
|
||||||
case class Push(op1: Src) extends Operation("push", op1)
|
case class Push(op1: Src) extends Operation("push", op1)
|
||||||
case class Pop(op1: Src) extends Operation("pop", op1)
|
case class Pop(op1: Src) extends Operation("pop", op1)
|
||||||
case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1)
|
// move operations
|
||||||
|
|
||||||
case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2)
|
case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2)
|
||||||
case class Load(op1: Register, op2: MemLocation | IndexAddress)
|
case class Load(op1: Register, op2: MemLocation | IndexAddress)
|
||||||
extends Operation("lea ", op1, op2)
|
extends Operation("lea ", op1, op2)
|
||||||
case class CDQ() extends Operation("cdq")
|
|
||||||
|
|
||||||
|
// function call operations
|
||||||
|
case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1)
|
||||||
case class Return() extends Operation("ret")
|
case class Return() extends Operation("ret")
|
||||||
|
|
||||||
|
// conditional operations
|
||||||
case class Jump(op1: LabelArg, condition: Cond = Cond.Always)
|
case class Jump(op1: LabelArg, condition: Cond = Cond.Always)
|
||||||
extends Operation(s"j${condition.toString}", op1)
|
extends Operation(s"j${condition.toString}", op1)
|
||||||
|
|
||||||
case class Set(op1: Dest, condition: Cond = Cond.Always)
|
case class Set(op1: Dest, condition: Cond = Cond.Always)
|
||||||
extends Operation(s"set${condition.toString}", op1)
|
extends Operation(s"set${condition.toString}", op1)
|
||||||
|
|
||||||
@@ -213,4 +213,19 @@ object assemblyIR {
|
|||||||
case String => "%s"
|
case String => "%s"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object commonRegisters {
|
||||||
|
import Size._
|
||||||
|
import RegName._
|
||||||
|
|
||||||
|
val RAX = Register(Q64, AX)
|
||||||
|
val EAX = Register(D32, AX)
|
||||||
|
val RDI = Register(Q64, DI)
|
||||||
|
val RIP = Register(Q64, IP)
|
||||||
|
val RBP = Register(Q64, BP)
|
||||||
|
val RSI = Register(Q64, SI)
|
||||||
|
val RDX = Register(Q64, DX)
|
||||||
|
val RCX = Register(Q64, CX)
|
||||||
|
val ECX = Register(D32, CX)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user