refactor: use LabelGenerator for RuntimeErrors

This commit is contained in:
2025-02-28 14:07:00 +00:00
parent 967a6fe58b
commit fb5799dbfd
4 changed files with 91 additions and 86 deletions

View File

@@ -0,0 +1,46 @@
package wacc
import scala.collection.mutable
import cats.data.Chain
private class LabelGenerator {
import assemblyIR._
import microWacc.{CallTarget, Ident, Builtin}
import asmGenerator.escaped
private val strings = mutable.HashMap[String, String]()
private var labelVal = -1
/** Get an arbitrary label. */
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
/** Get a named label for a function. */
def getLabel(target: CallTarget): String = target match {
case Ident(v, _) => s"wacc_$v"
case Builtin(name) => s"_$name"
}
/** Get a named label for an error. */
def getLabel(target: RuntimeError): String =
s".L.${target.name}"
/** Get an arbitrary label for a string. */
def getLabel(str: String): String =
strings.getOrElseUpdate(str, s".L.str${strings.size}")
/** Get a named label for a string. */
def getLabel(src: String, name: String): String =
strings.getOrElseUpdate(src, s".L.$name.str${strings.size}")
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
def generateConstants: Chain[AsmLine] =
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
acc ++ Chain(
LabelDef(label),
Directive.Asciz(str.escaped)
)
}
}

View File

@@ -4,18 +4,16 @@ import cats.data.Chain
import wacc.assemblyIR._ import wacc.assemblyIR._
sealed trait RuntimeError { sealed trait RuntimeError {
def strLabel: String val name: String
def errStr: String protected val errStr: String
def errLabel: String
def stringDef: Chain[AsmLine] = Chain( protected def getErrLabel(using labelGenerator: LabelGenerator): String =
Directive.Int(errStr.length), labelGenerator.getLabel(errStr, name = name)
LabelDef(strLabel),
Directive.Asciz(errStr)
)
def generateHandler: Chain[AsmLine] protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
LabelDef(labelGenerator.getLabel(this)) +: generateHandler
} }
object RuntimeError { object RuntimeError {
@@ -36,14 +34,12 @@ object RuntimeError {
// private val RCX = Register(Q64, CX) // private val RCX = Register(Q64, CX)
case object ZeroDivError extends RuntimeError { case object ZeroDivError extends RuntimeError {
val strLabel = ".L._errDivZero_str0" val name = "errDivZero"
val errStr = "fatal error: division or modulo by zero" protected val errStr = "fatal error: division or modulo by zero"
val errLabel = ".L._errDivZero"
def generateHandler: Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
LabelDef(ZeroDivError.errLabel),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)), Move(RDI, ImmediateVal(-1)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -52,15 +48,13 @@ object RuntimeError {
} }
case object BadChrError extends RuntimeError { case object BadChrError extends RuntimeError {
val strLabel = ".L._errBadChr_str0" val name = "errBadChr"
val errStr = "fatal error: int %d is not an ASCII character 0-127" protected val errStr = "fatal error: int %d is not an ASCII character 0-127"
val errLabel = ".L._errBadChr"
def generateHandler: Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
LabelDef(BadChrError.errLabel),
Pop(RSI), Pop(RSI),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -69,14 +63,12 @@ object RuntimeError {
} }
case object NullPtrError extends RuntimeError { case object NullPtrError extends RuntimeError {
val strLabel = ".L._errNullPtr_str0" val name = "errNullPtr"
val errStr = "fatal error: null pair dereferenced or freed" protected val errStr = "fatal error: null pair dereferenced or freed"
val errLabel = ".L._errNullPtr"
def generateHandler: Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
LabelDef(NullPtrError.errLabel),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -85,14 +77,12 @@ object RuntimeError {
} }
case object OverflowError extends RuntimeError { case object OverflowError extends RuntimeError {
val strLabel = ".L._errOverflow_str0" val name = "errOverflow"
val errStr = "fatal error: integer overflow or underflow occurred" protected val errStr = "fatal error: integer overflow or underflow occurred"
val errLabel = ".L._errOverflow"
def generateHandler: Chain[AsmLine] = Chain( protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
LabelDef(OverflowError.errLabel),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)
@@ -101,16 +91,13 @@ object RuntimeError {
} }
case object OutOfBoundsError extends RuntimeError { case object OutOfBoundsError extends RuntimeError {
val name = "errOutOfBounds"
protected val errStr = "fatal error: array index %d out of bounds"
val strLabel = ".L._errOutOfBounds_str0" protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
val errStr = "fatal error: array index %d out of bounds"
val errLabel = ".L._errOutOfBounds"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(OutOfBoundsError.errLabel),
Move(RSI, Register(Q64, CX)), Move(RSI, Register(Q64, CX)),
stackAlign, stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(getErrLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)), Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit) assemblyIR.Call(CLibFunc.Exit)

View File

@@ -1,6 +1,5 @@
package wacc package wacc
import scala.collection.mutable
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
import wacc.RuntimeError._ import wacc.RuntimeError._
@@ -31,30 +30,6 @@ object asmGenerator {
def concatAll(chains: Chain[T]*): Chain[T] = def concatAll(chains: Chain[T]*): Chain[T] =
chains.foldLeft(chain)(_ ++ _) chains.foldLeft(chain)(_ ++ _)
private class LabelGenerator {
private val strings = mutable.HashMap[String, String]()
private var labelVal = -1
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
def getLabel(target: CallTarget): String = target match {
case Ident(v, _) => s"wacc_$v"
case Builtin(name) => s"_$name"
}
def getLabel(str: String): String =
strings.getOrElseUpdate(str, s".L.str${strings.size}")
def generateConstants: Chain[AsmLine] =
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
acc ++ Chain(
Directive.Int(str.size),
LabelDef(label),
Directive.Asciz(str.escaped)
)
}
}
def generateAsm(microProg: Program): Chain[AsmLine] = { def generateAsm(microProg: Program): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
@@ -66,6 +41,7 @@ object asmGenerator {
Chain.one(Xor(RAX, RAX)), Chain.one(Xor(RAX, RAX)),
funcEpilogue(), funcEpilogue(),
generateBuiltInFuncs(), generateBuiltInFuncs(),
RuntimeError.all.foldMap(_.generate),
funcs.foldMap(generateUserFunc(_)) funcs.foldMap(generateUserFunc(_))
) )
@@ -75,7 +51,6 @@ object asmGenerator {
Directive.RoData Directive.RoData
).concatAll( ).concatAll(
labelGenerator.generateConstants, labelGenerator.generateConstants,
RuntimeError.all.foldMap(_.stringDef),
Chain.one(Directive.Text), Chain.one(Directive.Text),
progAsm progAsm
) )
@@ -154,7 +129,7 @@ object asmGenerator {
stackAlign, stackAlign,
Move(RDI, RAX), Move(RDI, RAX),
Compare(RDI, ImmediateVal(0)), Compare(RDI, ImmediateVal(0)),
Jump(LabelArg(NullPtrError.errLabel), Cond.Equal), Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal),
assemblyIR.Call(CLibFunc.Free) assemblyIR.Call(CLibFunc.Free)
) )
) )
@@ -171,8 +146,6 @@ object asmGenerator {
) )
) )
chain ++= RuntimeError.all.foldMap(_.generateHandler)
chain chain
} }
@@ -181,7 +154,6 @@ object asmGenerator {
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match { stmt match {
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
lhs match { lhs match {
@@ -195,15 +167,15 @@ object asmGenerator {
chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += Compare(ECX, ImmediateVal(0)) chain += Compare(ECX, ImmediateVal(0))
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
chain += stack.push(Q64, RCX) chain += stack.push(Q64, RCX)
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX) chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
chain += stack.pop(RDX) chain += stack.pop(RDX)
chain += Move( chain += Move(
@@ -305,12 +277,12 @@ object asmGenerator {
chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += Compare(RCX, ImmediateVal(0)) chain += Compare(RCX, ImmediateVal(0))
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.Less)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) chain += Jump(LabelArg(labelGenerator.getLabel(NullPtrError)), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX) chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) chain += Jump(LabelArg(labelGenerator.getLabel(OutOfBoundsError)), Cond.LessEqual)
// + Int because we store the length of the array at the start // + Int because we store the length of the array at the start
chain += Move( chain += Move(
Register(x.ty.elemSize, AX), Register(x.ty.elemSize, AX),
@@ -324,7 +296,7 @@ object asmGenerator {
chain += Move(EAX, stack.head) chain += Move(EAX, stack.head)
chain += And(EAX, ImmediateVal(-128)) chain += And(EAX, ImmediateVal(-128))
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) chain += Jump(LabelArg(labelGenerator.getLabel(BadChrError)), Cond.NotEqual)
case UnaryOperator.Ord => // No op needed case UnaryOperator.Ord => // No op needed
case UnaryOperator.Len => case UnaryOperator.Len =>
chain += stack.pop(RAX) chain += stack.pop(RAX)
@@ -333,7 +305,7 @@ object asmGenerator {
case UnaryOperator.Negate => case UnaryOperator.Negate =>
chain += Xor(EAX, EAX) chain += Xor(EAX, EAX)
chain += Subtract(EAX, stack.head) chain += Subtract(EAX, stack.head)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(Q64, RAX) chain += stack.push(Q64, RAX)
case UnaryOperator.Not => case UnaryOperator.Not =>
@@ -349,21 +321,21 @@ object asmGenerator {
op match { op match {
case BinaryOperator.Add => case BinaryOperator.Add =>
chain += Add(stack.head, destX) chain += Add(stack.head, destX)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
case BinaryOperator.Sub => case BinaryOperator.Sub =>
chain += Subtract(destX, stack.head) chain += Subtract(destX, stack.head)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
chain += Multiply(destX, stack.head) chain += Multiply(destX, stack.head)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(labelGenerator.getLabel(OverflowError)), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Div => case BinaryOperator.Div =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
@@ -371,7 +343,7 @@ object asmGenerator {
case BinaryOperator.Mod => case BinaryOperator.Mod =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(labelGenerator.getLabel(ZeroDivError)), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
@@ -470,7 +442,7 @@ object asmGenerator {
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
extension (s: String) { extension (s: String) {
private def escaped: String = def escaped: String =
s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString)) s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
} }
} }

View File

@@ -199,8 +199,8 @@ object assemblyIR {
case Global(name) => s".globl $name" case Global(name) => s".globl $name"
case Text => ".text" case Text => ".text"
case RoData => ".section .rodata" case RoData => ".section .rodata"
case Int(value) => s".int $value" case Int(value) => s"\t.int $value"
case Asciz(string) => s".asciz \"$string\"" case Asciz(string) => s"\t.asciz \"$string\""
} }
} }