refactor: replace strings ListBuffer with labelGenerator

This commit is contained in:
2025-02-28 13:14:29 +00:00
parent 8b3e9b8380
commit 967a6fe58b

View File

@@ -1,6 +1,6 @@
package wacc package wacc
import scala.collection.mutable.ListBuffer import scala.collection.mutable
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
import wacc.RuntimeError._ import wacc.RuntimeError._
@@ -31,8 +31,9 @@ object asmGenerator {
def concatAll(chains: Chain[T]*): Chain[T] = def concatAll(chains: Chain[T]*): Chain[T] =
chains.foldLeft(chain)(_ ++ _) chains.foldLeft(chain)(_ ++ _)
class LabelGenerator { private class LabelGenerator {
var labelVal = -1 private val strings = mutable.HashMap[String, String]()
private var labelVal = -1
def getLabel(): String = { def getLabel(): String = {
labelVal += 1 labelVal += 1
s".L$labelVal" s".L$labelVal"
@@ -41,11 +42,21 @@ object asmGenerator {
case Ident(v, _) => s"wacc_$v" case Ident(v, _) => s"wacc_$v"
case Builtin(name) => s"_$name" case Builtin(name) => s"_$name"
} }
def getLabel(str: String): String =
strings.getOrElseUpdate(str, s".L.str${strings.size}")
def generateConstants: Chain[AsmLine] =
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
acc ++ Chain(
Directive.Int(str.size),
LabelDef(label),
Directive.Asciz(str.escaped)
)
}
} }
def generateAsm(microProg: Program): Chain[AsmLine] = { def generateAsm(microProg: Program): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg val Program(funcs, main) = microProg
@@ -58,20 +69,13 @@ object asmGenerator {
funcs.foldMap(generateUserFunc(_)) funcs.foldMap(generateUserFunc(_))
) )
val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
Chain(
Directive.Int(str.size),
LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped)
)
} ++ RuntimeError.all.foldMap(_.stringDef)
Chain( Chain(
Directive.IntelSyntax, Directive.IntelSyntax,
Directive.Global("main"), Directive.Global("main"),
Directive.RoData Directive.RoData
).concatAll( ).concatAll(
strDirs, labelGenerator.generateConstants,
RuntimeError.all.foldMap(_.stringDef),
Chain.one(Directive.Text), Chain.one(Directive.Text),
progAsm progAsm
) )
@@ -88,7 +92,6 @@ object asmGenerator {
} }
private def generateUserFunc(func: FuncDecl)(using private def generateUserFunc(func: FuncDecl)(using
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
@@ -175,7 +178,6 @@ object asmGenerator {
private def generateStmt(stmt: Stmt)(using private def generateStmt(stmt: Stmt)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
@@ -259,7 +261,6 @@ object asmGenerator {
private def evalExprOntoStack(expr: Expr)(using private def evalExprOntoStack(expr: Expr)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
@@ -272,8 +273,8 @@ object asmGenerator {
case array @ ArrayLiter(elems) => case array @ ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString val str = elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str))))
chain += stack.push(Q64, RAX) chain += stack.push(Q64, RAX)
case ty => case ty =>
chain ++= generateCall( chain ++= generateCall(
@@ -398,7 +399,6 @@ object asmGenerator {
private def generateCall(call: microWacc.Call, isTail: Boolean)(using private def generateCall(call: microWacc.Call, isTail: Boolean)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]