refactor: merge labelGenerator refactors

Merge request lab2425_spring/WACC_37!35

Co-authored-by: Gleb Koval <gleb@koval.net>
This commit is contained in:
Teixeira, Jonny
2025-02-28 14:28:32 +00:00
4 changed files with 111 additions and 101 deletions

View File

@@ -0,0 +1,50 @@
package wacc
import scala.collection.mutable
import cats.data.Chain
private class LabelGenerator {
import assemblyIR._
import microWacc.{CallTarget, Ident, Builtin}
import asmGenerator.escaped
private val strings = mutable.HashMap[String, String]()
private var labelVal = -1
/** Get an arbitrary label. */
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
private def getLabel(target: CallTarget | RuntimeError): String = target match {
case Ident(v, _) => s"wacc_$v"
case Builtin(name) => s"_$name"
case err: RuntimeError => s".L.${err.name}"
}
/** Get a named label def for a function or error. */
def getLabelDef(target: CallTarget | RuntimeError): LabelDef =
LabelDef(getLabel(target))
/** Get a named label for a function or error. */
def getLabelArg(target: CallTarget | RuntimeError): LabelArg =
LabelArg(getLabel(target))
/** Get an arbitrary label for a string. */
def getLabelArg(str: String): LabelArg =
LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}"))
/** Get a named label for a string. */
def getLabelArg(src: String, name: String): LabelArg =
LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}"))
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
def generateConstants: Chain[AsmLine] =
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
acc ++ Chain(
LabelDef(label),
Directive.Asciz(str.escaped)
)
}
}

View File

@@ -4,18 +4,16 @@ import cats.data.Chain
import wacc.assemblyIR._
sealed trait RuntimeError {
def strLabel: String
def errStr: String
def errLabel: String
val name: String
protected val errStr: String
def stringDef: Chain[AsmLine] = Chain(
Directive.Int(errStr.length),
LabelDef(strLabel),
Directive.Asciz(errStr)
)
protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg =
labelGenerator.getLabelArg(errStr, name = name)
def generateHandler: Chain[AsmLine]
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
labelGenerator.getLabelDef(this) +: generateHandler
}
object RuntimeError {
@@ -36,14 +34,12 @@ object RuntimeError {
// private val RCX = Register(Q64, CX)
case object ZeroDivError extends RuntimeError {
val strLabel = ".L._errDivZero_str0"
val errStr = "fatal error: division or modulo by zero"
val errLabel = ".L._errDivZero"
val name = "errDivZero"
protected val errStr = "fatal error: division or modulo by zero"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(ZeroDivError.errLabel),
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))),
Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)),
assemblyIR.Call(CLibFunc.Exit)
@@ -52,15 +48,13 @@ object RuntimeError {
}
case object BadChrError extends RuntimeError {
val strLabel = ".L._errBadChr_str0"
val errStr = "fatal error: int %d is not an ASCII character 0-127"
val errLabel = ".L._errBadChr"
val name = "errBadChr"
protected val errStr = "fatal error: int %d is not an ASCII character 0-127"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(BadChrError.errLabel),
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
Pop(RSI),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))),
Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
@@ -69,14 +63,12 @@ object RuntimeError {
}
case object NullPtrError extends RuntimeError {
val strLabel = ".L._errNullPtr_str0"
val errStr = "fatal error: null pair dereferenced or freed"
val errLabel = ".L._errNullPtr"
val name = "errNullPtr"
protected val errStr = "fatal error: null pair dereferenced or freed"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(NullPtrError.errLabel),
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))),
Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
@@ -85,14 +77,12 @@ object RuntimeError {
}
case object OverflowError extends RuntimeError {
val strLabel = ".L._errOverflow_str0"
val errStr = "fatal error: integer overflow or underflow occurred"
val errLabel = ".L._errOverflow"
val name = "errOverflow"
protected val errStr = "fatal error: integer overflow or underflow occurred"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(OverflowError.errLabel),
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))),
Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
@@ -101,16 +91,13 @@ object RuntimeError {
}
case object OutOfBoundsError extends RuntimeError {
val name = "errOutOfBounds"
protected val errStr = "fatal error: array index %d out of bounds"
val strLabel = ".L._errOutOfBounds_str0"
val errStr = "fatal error: array index %d out of bounds"
val errLabel = ".L._errOutOfBounds"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(OutOfBoundsError.errLabel),
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
Move(RSI, Register(Q64, CX)),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))),
Load(RDI, IndexAddress(RIP, getErrLabel)),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)

View File

@@ -1,6 +1,5 @@
package wacc
import scala.collection.mutable.ListBuffer
import cats.data.Chain
import cats.syntax.foldable._
import wacc.RuntimeError._
@@ -31,21 +30,8 @@ object asmGenerator {
def concatAll(chains: Chain[T]*): Chain[T] =
chains.foldLeft(chain)(_ ++ _)
class LabelGenerator {
var labelVal = -1
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
def getLabel(target: CallTarget): String = target match {
case Ident(v, _) => s"wacc_$v"
case Builtin(name) => s"_$name"
}
}
def generateAsm(microProg: Program): Chain[AsmLine] = {
given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg
@@ -55,32 +41,26 @@ object asmGenerator {
Chain.one(Xor(RAX, RAX)),
funcEpilogue(),
generateBuiltInFuncs(),
RuntimeError.all.foldMap(_.generate),
funcs.foldMap(generateUserFunc(_))
)
val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
Chain(
Directive.Int(str.size),
LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped)
)
} ++ RuntimeError.all.foldMap(_.stringDef)
Chain(
Directive.IntelSyntax,
Directive.Global("main"),
Directive.RoData
).concatAll(
strDirs,
labelGenerator.generateConstants,
Chain.one(Directive.Text),
progAsm
)
}
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
stack: Stack
private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
stack: Stack,
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.one[AsmLine](LabelDef(labelName))
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
chain ++= funcPrologue()
chain ++= funcBody
chain ++= funcEpilogue()
@@ -88,14 +68,13 @@ object asmGenerator {
}
private def generateUserFunc(func: FuncDecl)(using
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
given stack: Stack = Stack()
// Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_))
stack.reserve(Q64) // Reserve return pointer slot
var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
var chain = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) =>
@@ -113,12 +92,12 @@ object asmGenerator {
var chain = Chain.empty[AsmLine]
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Exit),
Builtin.Exit,
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Printf),
Builtin.Printf,
Chain(
stackAlign,
assemblyIR.Call(CLibFunc.PrintF),
@@ -128,7 +107,7 @@ object asmGenerator {
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.PrintCharArray),
Builtin.PrintCharArray,
Chain(
stackAlign,
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
@@ -140,24 +119,24 @@ object asmGenerator {
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc),
Builtin.Malloc,
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
// Out of memory check is optional
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Free),
Builtin.Free,
Chain(
stackAlign,
Move(RDI, RAX),
Compare(RDI, ImmediateVal(0)),
Jump(LabelArg(NullPtrError.errLabel), Cond.Equal),
Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
assemblyIR.Call(CLibFunc.Free)
)
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Read),
Builtin.Read,
Chain(
stackAlign,
Subtract(Register(Q64, SP), ImmediateVal(8)),
@@ -168,18 +147,14 @@ object asmGenerator {
)
)
chain ++= RuntimeError.all.foldMap(_.generateHandler)
chain
}
private def generateStmt(stmt: Stmt)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match {
case Assign(lhs, rhs) =>
lhs match {
@@ -193,15 +168,15 @@ object asmGenerator {
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += Compare(ECX, ImmediateVal(0))
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
chain += stack.push(Q64, RCX)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
chain += stack.pop(RCX)
chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
chain += stack.pop(RDX)
chain += Move(
@@ -259,7 +234,6 @@ object asmGenerator {
private def evalExprOntoStack(expr: Expr)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
@@ -272,8 +246,8 @@ object asmGenerator {
case array @ ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
val str = elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
chain += stack.push(Q64, RAX)
case ty =>
chain ++= generateCall(
@@ -304,12 +278,12 @@ object asmGenerator {
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += Compare(RCX, ImmediateVal(0))
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
chain += stack.pop(RAX)
chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
chain += Compare(MemLocation(RAX, D32), ECX)
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
chain += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
// + Int because we store the length of the array at the start
chain += Move(
Register(x.ty.elemSize, AX),
@@ -323,7 +297,7 @@ object asmGenerator {
chain += Move(EAX, stack.head)
chain += And(EAX, ImmediateVal(-128))
chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual)
chain += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
case UnaryOperator.Ord => // No op needed
case UnaryOperator.Len =>
chain += stack.pop(RAX)
@@ -332,7 +306,7 @@ object asmGenerator {
case UnaryOperator.Negate =>
chain += Xor(EAX, EAX)
chain += Subtract(EAX, stack.head)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop()
chain += stack.push(Q64, RAX)
case UnaryOperator.Not =>
@@ -348,21 +322,21 @@ object asmGenerator {
op match {
case BinaryOperator.Add =>
chain += Add(stack.head, destX)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
case BinaryOperator.Sub =>
chain += Subtract(destX, stack.head)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop()
chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul =>
chain += Multiply(destX, stack.head)
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
chain += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
chain += stack.drop()
chain += stack.push(destX.size, RAX)
case BinaryOperator.Div =>
chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
chain += CDQ()
chain += Divide(stack.head)
chain += stack.drop()
@@ -370,7 +344,7 @@ object asmGenerator {
case BinaryOperator.Mod =>
chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
chain += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
chain += CDQ()
chain += Divide(stack.head)
chain += stack.drop()
@@ -398,7 +372,6 @@ object asmGenerator {
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
@@ -421,9 +394,9 @@ object asmGenerator {
// Tail Call Optimisation (TCO)
if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
chain += Jump(labelGenerator.getLabelArg(target)) // tail call
} else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
chain += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
}
if (args.size > argRegs.size) {
@@ -470,7 +443,7 @@ object asmGenerator {
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
extension (s: String) {
private def escaped: String =
def escaped: String =
s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
}
}

View File

@@ -199,8 +199,8 @@ object assemblyIR {
case Global(name) => s".globl $name"
case Text => ".text"
case RoData => ".section .rodata"
case Int(value) => s".int $value"
case Asciz(string) => s".asciz \"$string\""
case Int(value) => s"\t.int $value"
case Asciz(string) => s"\t.asciz \"$string\""
}
}