refactor: return args and defs from labelGenerator, instead of strings
This commit is contained in:
@@ -17,23 +17,27 @@ private class LabelGenerator {
|
|||||||
s".L$labelVal"
|
s".L$labelVal"
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get a named label for a function. */
|
private def getLabel(target: CallTarget | RuntimeError): String = target match {
|
||||||
def getLabel(target: CallTarget): String = target match {
|
|
||||||
case Ident(v, _) => s"wacc_$v"
|
case Ident(v, _) => s"wacc_$v"
|
||||||
case Builtin(name) => s"_$name"
|
case Builtin(name) => s"_$name"
|
||||||
|
case err: RuntimeError => s".L.${err.name}"
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get a named label for an error. */
|
/** Get a named label def for a function or error. */
|
||||||
def getLabel(target: RuntimeError): String =
|
def getLabelDef(target: CallTarget | RuntimeError): LabelDef =
|
||||||
s".L.${target.name}"
|
LabelDef(getLabel(target))
|
||||||
|
|
||||||
|
/** Get a named label for a function or error. */
|
||||||
|
def getLabelArg(target: CallTarget | RuntimeError): LabelArg =
|
||||||
|
LabelArg(getLabel(target))
|
||||||
|
|
||||||
/** Get an arbitrary label for a string. */
|
/** Get an arbitrary label for a string. */
|
||||||
def getLabel(str: String): String =
|
def getLabelArg(str: String): LabelArg =
|
||||||
strings.getOrElseUpdate(str, s".L.str${strings.size}")
|
LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}"))
|
||||||
|
|
||||||
/** Get a named label for a string. */
|
/** Get a named label for a string. */
|
||||||
def getLabel(src: String, name: String): String =
|
def getLabelArg(src: String, name: String): LabelArg =
|
||||||
strings.getOrElseUpdate(src, s".L.$name.str${strings.size}")
|
LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}"))
|
||||||
|
|
||||||
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
|
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
|
||||||
def generateConstants: Chain[AsmLine] =
|
def generateConstants: Chain[AsmLine] =
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ sealed trait RuntimeError {
|
|||||||
val name: String
|
val name: String
|
||||||
protected val errStr: String
|
protected val errStr: String
|
||||||
|
|
||||||
protected def getErrLabel(using labelGenerator: LabelGenerator): String =
|
protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg =
|
||||||
labelGenerator.getLabel(errStr, name = name)
|
labelGenerator.getLabelArg(errStr, name = name)
|
||||||
|
|
||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
|
||||||
|
|
||||||
def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
|
def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
|
||||||
LabelDef(labelGenerator.getLabel(this)) +: generateHandler
|
labelGenerator.getLabelDef(this) +: generateHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
object RuntimeError {
|
object RuntimeError {
|
||||||
@@ -39,7 +39,7 @@ object RuntimeError {
|
|||||||
|
|
||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(-1)),
|
Move(RDI, ImmediateVal(-1)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
@@ -54,7 +54,7 @@ object RuntimeError {
|
|||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
Pop(RSI),
|
Pop(RSI),
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(255)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
@@ -68,7 +68,7 @@ object RuntimeError {
|
|||||||
|
|
||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(255)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
@@ -82,7 +82,7 @@ object RuntimeError {
|
|||||||
|
|
||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(255)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
@@ -97,7 +97,7 @@ object RuntimeError {
|
|||||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||||
Move(RSI, Register(Q64, CX)),
|
Move(RSI, Register(Q64, CX)),
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
|
Load(RDI, IndexAddress(RIP, getErrLabel)),
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(255)),
|
Move(RDI, ImmediateVal(255)),
|
||||||
assemblyIR.Call(CLibFunc.Exit)
|
assemblyIR.Call(CLibFunc.Exit)
|
||||||
|
|||||||
@@ -56,10 +56,11 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
|
private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
|
||||||
stack: Stack
|
stack: Stack,
|
||||||
|
labelGenerator: LabelGenerator
|
||||||
): Chain[AsmLine] = {
|
): Chain[AsmLine] = {
|
||||||
var chain = Chain.one[AsmLine](LabelDef(labelName))
|
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
|
||||||
chain ++= funcPrologue()
|
chain ++= funcPrologue()
|
||||||
chain ++= funcBody
|
chain ++= funcBody
|
||||||
chain ++= funcEpilogue()
|
chain ++= funcEpilogue()
|
||||||
@@ -73,7 +74,7 @@ object asmGenerator {
|
|||||||
// Setup the stack with param 7 and up
|
// Setup the stack with param 7 and up
|
||||||
func.params.drop(argRegs.size).foreach(stack.reserve(_))
|
func.params.drop(argRegs.size).foreach(stack.reserve(_))
|
||||||
stack.reserve(Q64) // Reserve return pointer slot
|
stack.reserve(Q64) // Reserve return pointer slot
|
||||||
var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
|
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
|
||||||
chain ++= funcPrologue()
|
chain ++= funcPrologue()
|
||||||
// Push the rest of params onto the stack for simplicity
|
// Push the rest of params onto the stack for simplicity
|
||||||
argRegs.zip(func.params).foreach { (reg, param) =>
|
argRegs.zip(func.params).foreach { (reg, param) =>
|
||||||
@@ -91,12 +92,12 @@ object asmGenerator {
|
|||||||
var chain = Chain.empty[AsmLine]
|
var chain = Chain.empty[AsmLine]
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
chain ++= wrapBuiltinFunc(
|
||||||
labelGenerator.getLabel(Builtin.Exit),
|
Builtin.Exit,
|
||||||
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
|
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
chain ++= wrapBuiltinFunc(
|
||||||
labelGenerator.getLabel(Builtin.Printf),
|
Builtin.Printf,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
assemblyIR.Call(CLibFunc.PrintF),
|
||||||
@@ -106,7 +107,7 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
chain ++= wrapBuiltinFunc(
|
||||||
labelGenerator.getLabel(Builtin.PrintCharArray),
|
Builtin.PrintCharArray,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
|
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
|
||||||
@@ -118,24 +119,24 @@ object asmGenerator {
|
|||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
chain ++= wrapBuiltinFunc(
|
||||||
labelGenerator.getLabel(Builtin.Malloc),
|
Builtin.Malloc,
|
||||||
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
|
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
|
||||||
// Out of memory check is optional
|
// Out of memory check is optional
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
chain ++= wrapBuiltinFunc(
|
||||||
labelGenerator.getLabel(Builtin.Free),
|
Builtin.Free,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Move(RDI, RAX),
|
Move(RDI, RAX),
|
||||||
Compare(RDI, ImmediateVal(0)),
|
Compare(RDI, ImmediateVal(0)),
|
||||||
Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal),
|
Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
|
||||||
assemblyIR.Call(CLibFunc.Free)
|
assemblyIR.Call(CLibFunc.Free)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chain ++= wrapBuiltinFunc(
|
chain ++= wrapBuiltinFunc(
|
||||||
labelGenerator.getLabel(Builtin.Read),
|
Builtin.Read,
|
||||||
Chain(
|
Chain(
|
||||||
stackAlign,
|
stackAlign,
|
||||||
Subtract(Register(Q64, SP), ImmediateVal(8)),
|
Subtract(Register(Q64, SP), ImmediateVal(8)),
|
||||||
@@ -167,15 +168,15 @@ object asmGenerator {
|
|||||||
chain ++= evalExprOntoStack(i)
|
chain ++= evalExprOntoStack(i)
|
||||||
chain += stack.pop(RCX)
|
chain += stack.pop(RCX)
|
||||||
chain += Compare(ECX, ImmediateVal(0))
|
chain += Compare(ECX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
|
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||||
chain += stack.push(Q64, RCX)
|
chain += stack.push(Q64, RCX)
|
||||||
chain ++= evalExprOntoStack(x)
|
chain ++= evalExprOntoStack(x)
|
||||||
chain += stack.pop(RAX)
|
chain += stack.pop(RAX)
|
||||||
chain += stack.pop(RCX)
|
chain += stack.pop(RCX)
|
||||||
chain += Compare(EAX, ImmediateVal(0))
|
chain += Compare(EAX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
|
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
chain += Compare(MemLocation(RAX, D32), ECX)
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
|
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||||
chain += stack.pop(RDX)
|
chain += stack.pop(RDX)
|
||||||
|
|
||||||
chain += Move(
|
chain += Move(
|
||||||
@@ -246,7 +247,7 @@ object asmGenerator {
|
|||||||
expr.ty match {
|
expr.ty match {
|
||||||
case KnownType.String =>
|
case KnownType.String =>
|
||||||
val str = elems.collect { case CharLiter(v) => v }.mkString
|
val str = elems.collect { case CharLiter(v) => v }.mkString
|
||||||
chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str))))
|
chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
|
||||||
chain += stack.push(Q64, RAX)
|
chain += stack.push(Q64, RAX)
|
||||||
case ty =>
|
case ty =>
|
||||||
chain ++= generateCall(
|
chain ++= generateCall(
|
||||||
@@ -277,12 +278,12 @@ object asmGenerator {
|
|||||||
chain ++= evalExprOntoStack(i)
|
chain ++= evalExprOntoStack(i)
|
||||||
chain += stack.pop(RCX)
|
chain += stack.pop(RCX)
|
||||||
chain += Compare(RCX, ImmediateVal(0))
|
chain += Compare(RCX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
|
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||||
chain += stack.pop(RAX)
|
chain += stack.pop(RAX)
|
||||||
chain += Compare(EAX, ImmediateVal(0))
|
chain += Compare(EAX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
|
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
chain += Compare(MemLocation(RAX, D32), ECX)
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
|
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||||
// + Int because we store the length of the array at the start
|
// + Int because we store the length of the array at the start
|
||||||
chain += Move(
|
chain += Move(
|
||||||
Register(x.ty.elemSize, AX),
|
Register(x.ty.elemSize, AX),
|
||||||
@@ -296,7 +297,7 @@ object asmGenerator {
|
|||||||
chain += Move(EAX, stack.head)
|
chain += Move(EAX, stack.head)
|
||||||
chain += And(EAX, ImmediateVal(-128))
|
chain += And(EAX, ImmediateVal(-128))
|
||||||
chain += Compare(EAX, ImmediateVal(0))
|
chain += Compare(EAX, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual)
|
chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
|
||||||
case UnaryOperator.Ord => // No op needed
|
case UnaryOperator.Ord => // No op needed
|
||||||
case UnaryOperator.Len =>
|
case UnaryOperator.Len =>
|
||||||
chain += stack.pop(RAX)
|
chain += stack.pop(RAX)
|
||||||
@@ -305,7 +306,7 @@ object asmGenerator {
|
|||||||
case UnaryOperator.Negate =>
|
case UnaryOperator.Negate =>
|
||||||
chain += Xor(EAX, EAX)
|
chain += Xor(EAX, EAX)
|
||||||
chain += Subtract(EAX, stack.head)
|
chain += Subtract(EAX, stack.head)
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
chain += stack.drop()
|
chain += stack.drop()
|
||||||
chain += stack.push(Q64, RAX)
|
chain += stack.push(Q64, RAX)
|
||||||
case UnaryOperator.Not =>
|
case UnaryOperator.Not =>
|
||||||
@@ -321,21 +322,21 @@ object asmGenerator {
|
|||||||
op match {
|
op match {
|
||||||
case BinaryOperator.Add =>
|
case BinaryOperator.Add =>
|
||||||
chain += Add(stack.head, destX)
|
chain += Add(stack.head, destX)
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
case BinaryOperator.Sub =>
|
case BinaryOperator.Sub =>
|
||||||
chain += Subtract(destX, stack.head)
|
chain += Subtract(destX, stack.head)
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
chain += stack.drop()
|
chain += stack.drop()
|
||||||
chain += stack.push(destX.size, RAX)
|
chain += stack.push(destX.size, RAX)
|
||||||
case BinaryOperator.Mul =>
|
case BinaryOperator.Mul =>
|
||||||
chain += Multiply(destX, stack.head)
|
chain += Multiply(destX, stack.head)
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||||
chain += stack.drop()
|
chain += stack.drop()
|
||||||
chain += stack.push(destX.size, RAX)
|
chain += stack.push(destX.size, RAX)
|
||||||
|
|
||||||
case BinaryOperator.Div =>
|
case BinaryOperator.Div =>
|
||||||
chain += Compare(stack.head, ImmediateVal(0))
|
chain += Compare(stack.head, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
|
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||||
chain += CDQ()
|
chain += CDQ()
|
||||||
chain += Divide(stack.head)
|
chain += Divide(stack.head)
|
||||||
chain += stack.drop()
|
chain += stack.drop()
|
||||||
@@ -343,7 +344,7 @@ object asmGenerator {
|
|||||||
|
|
||||||
case BinaryOperator.Mod =>
|
case BinaryOperator.Mod =>
|
||||||
chain += Compare(stack.head, ImmediateVal(0))
|
chain += Compare(stack.head, ImmediateVal(0))
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
|
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||||
chain += CDQ()
|
chain += CDQ()
|
||||||
chain += Divide(stack.head)
|
chain += Divide(stack.head)
|
||||||
chain += stack.drop()
|
chain += stack.drop()
|
||||||
@@ -393,9 +394,9 @@ object asmGenerator {
|
|||||||
|
|
||||||
// Tail Call Optimisation (TCO)
|
// Tail Call Optimisation (TCO)
|
||||||
if (isTail) {
|
if (isTail) {
|
||||||
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
|
chain += Jump(labelGenerator.getLabelArg(target)) // tail call
|
||||||
} else {
|
} else {
|
||||||
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
|
chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
|
||||||
}
|
}
|
||||||
|
|
||||||
if (args.size > argRegs.size) {
|
if (args.size > argRegs.size) {
|
||||||
|
|||||||
Reference in New Issue
Block a user