refactor: use LabelGenerator for RuntimeErrors
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
import cats.data.Chain
|
||||
import cats.syntax.foldable._
|
||||
import wacc.RuntimeError._
|
||||
@@ -31,30 +30,6 @@ object asmGenerator {
|
||||
def concatAll(chains: Chain[T]*): Chain[T] =
|
||||
chains.foldLeft(chain)(_ ++ _)
|
||||
|
||||
private class LabelGenerator {
|
||||
private val strings = mutable.HashMap[String, String]()
|
||||
private var labelVal = -1
|
||||
def getLabel(): String = {
|
||||
labelVal += 1
|
||||
s".L$labelVal"
|
||||
}
|
||||
def getLabel(target: CallTarget): String = target match {
|
||||
case Ident(v, _) => s"wacc_$v"
|
||||
case Builtin(name) => s"_$name"
|
||||
}
|
||||
def getLabel(str: String): String =
|
||||
strings.getOrElseUpdate(str, s".L.str${strings.size}")
|
||||
|
||||
def generateConstants: Chain[AsmLine] =
|
||||
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
|
||||
acc ++ Chain(
|
||||
Directive.Int(str.size),
|
||||
LabelDef(label),
|
||||
Directive.Asciz(str.escaped)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
def generateAsm(microProg: Program): Chain[AsmLine] = {
|
||||
given stack: Stack = Stack()
|
||||
given labelGenerator: LabelGenerator = LabelGenerator()
|
||||
@@ -66,6 +41,7 @@ object asmGenerator {
|
||||
Chain.one(Xor(RAX, RAX)),
|
||||
funcEpilogue(),
|
||||
generateBuiltInFuncs(),
|
||||
RuntimeError.all.foldMap(_.generate),
|
||||
funcs.foldMap(generateUserFunc(_))
|
||||
)
|
||||
|
||||
@@ -75,7 +51,6 @@ object asmGenerator {
|
||||
Directive.RoData
|
||||
).concatAll(
|
||||
labelGenerator.generateConstants,
|
||||
RuntimeError.all.foldMap(_.stringDef),
|
||||
Chain.one(Directive.Text),
|
||||
progAsm
|
||||
)
|
||||
@@ -154,7 +129,7 @@ object asmGenerator {
|
||||
stackAlign,
|
||||
Move(RDI, RAX),
|
||||
Compare(RDI, ImmediateVal(0)),
|
||||
Jump(LabelArg(NullPtrError.errLabel), Cond.Equal),
|
||||
Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal),
|
||||
assemblyIR.Call(CLibFunc.Free)
|
||||
)
|
||||
)
|
||||
@@ -171,8 +146,6 @@ object asmGenerator {
|
||||
)
|
||||
)
|
||||
|
||||
chain ++= RuntimeError.all.foldMap(_.generateHandler)
|
||||
|
||||
chain
|
||||
}
|
||||
|
||||
@@ -181,7 +154,6 @@ object asmGenerator {
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var chain = Chain.empty[AsmLine]
|
||||
chain += Comment(stmt.toString)
|
||||
stmt match {
|
||||
case Assign(lhs, rhs) =>
|
||||
lhs match {
|
||||
@@ -195,15 +167,15 @@ object asmGenerator {
|
||||
chain ++= evalExprOntoStack(i)
|
||||
chain += stack.pop(RCX)
|
||||
chain += Compare(ECX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
|
||||
chain += stack.push(Q64, RCX)
|
||||
chain ++= evalExprOntoStack(x)
|
||||
chain += stack.pop(RAX)
|
||||
chain += stack.pop(RCX)
|
||||
chain += Compare(EAX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
|
||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
|
||||
chain += stack.pop(RDX)
|
||||
|
||||
chain += Move(
|
||||
@@ -305,12 +277,12 @@ object asmGenerator {
|
||||
chain ++= evalExprOntoStack(i)
|
||||
chain += stack.pop(RCX)
|
||||
chain += Compare(RCX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
|
||||
chain += stack.pop(RAX)
|
||||
chain += Compare(EAX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
|
||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
|
||||
// + Int because we store the length of the array at the start
|
||||
chain += Move(
|
||||
Register(x.ty.elemSize, AX),
|
||||
@@ -324,7 +296,7 @@ object asmGenerator {
|
||||
chain += Move(EAX, stack.head)
|
||||
chain += And(EAX, ImmediateVal(-128))
|
||||
chain += Compare(EAX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual)
|
||||
case UnaryOperator.Ord => // No op needed
|
||||
case UnaryOperator.Len =>
|
||||
chain += stack.pop(RAX)
|
||||
@@ -333,7 +305,7 @@ object asmGenerator {
|
||||
case UnaryOperator.Negate =>
|
||||
chain += Xor(EAX, EAX)
|
||||
chain += Subtract(EAX, stack.head)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(Q64, RAX)
|
||||
case UnaryOperator.Not =>
|
||||
@@ -349,21 +321,21 @@ object asmGenerator {
|
||||
op match {
|
||||
case BinaryOperator.Add =>
|
||||
chain += Add(stack.head, destX)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
||||
case BinaryOperator.Sub =>
|
||||
chain += Subtract(destX, stack.head)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(destX.size, RAX)
|
||||
case BinaryOperator.Mul =>
|
||||
chain += Multiply(destX, stack.head)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(destX.size, RAX)
|
||||
|
||||
case BinaryOperator.Div =>
|
||||
chain += Compare(stack.head, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
|
||||
chain += CDQ()
|
||||
chain += Divide(stack.head)
|
||||
chain += stack.drop()
|
||||
@@ -371,7 +343,7 @@ object asmGenerator {
|
||||
|
||||
case BinaryOperator.Mod =>
|
||||
chain += Compare(stack.head, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
|
||||
chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
|
||||
chain += CDQ()
|
||||
chain += Divide(stack.head)
|
||||
chain += stack.drop()
|
||||
@@ -470,7 +442,7 @@ object asmGenerator {
|
||||
|
||||
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
|
||||
extension (s: String) {
|
||||
private def escaped: String =
|
||||
def escaped: String =
|
||||
s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user