diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index c900756..381bc00 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -1,6 +1,6 @@ package wacc -import scala.collection.mutable.ListBuffer +import scala.collection.mutable import cats.data.Chain import cats.syntax.foldable._ import wacc.RuntimeError._ @@ -31,8 +31,9 @@ object asmGenerator { def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _) - class LabelGenerator { - var labelVal = -1 + private class LabelGenerator { + private val strings = mutable.HashMap[String, String]() + private var labelVal = -1 def getLabel(): String = { labelVal += 1 s".L$labelVal" @@ -41,11 +42,21 @@ object asmGenerator { case Ident(v, _) => s"wacc_$v" case Builtin(name) => s"_$name" } + def getLabel(str: String): String = + strings.getOrElseUpdate(str, s".L.str${strings.size}") + + def generateConstants: Chain[AsmLine] = + strings.foldLeft(Chain.empty) { case (acc, (str, label)) => + acc ++ Chain( + Directive.Int(str.size), + LabelDef(label), + Directive.Asciz(str.escaped) + ) + } } def generateAsm(microProg: Program): Chain[AsmLine] = { given stack: Stack = Stack() - given strings: ListBuffer[String] = ListBuffer[String]() given labelGenerator: LabelGenerator = LabelGenerator() val Program(funcs, main) = microProg @@ -58,20 +69,13 @@ object asmGenerator { funcs.foldMap(generateUserFunc(_)) ) - val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => - Chain( - Directive.Int(str.size), - LabelDef(s".L.str$i"), - Directive.Asciz(str.escaped) - ) - } ++ RuntimeError.all.foldMap(_.stringDef) - Chain( Directive.IntelSyntax, Directive.Global("main"), Directive.RoData ).concatAll( - strDirs, + labelGenerator.generateConstants, + RuntimeError.all.foldMap(_.stringDef), Chain.one(Directive.Text), progAsm ) @@ -88,7 +92,6 @@ object asmGenerator { } private def generateUserFunc(func: FuncDecl)(using - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { given stack: Stack = Stack() @@ -175,7 +178,6 @@ object asmGenerator { private def generateStmt(stmt: Stmt)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] @@ -259,7 +261,6 @@ object asmGenerator { private def evalExprOntoStack(expr: Expr)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] @@ -272,8 +273,8 @@ object asmGenerator { case array @ ArrayLiter(elems) => expr.ty match { case KnownType.String => - strings += elems.collect { case CharLiter(v) => v }.mkString - chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) + val str = elems.collect { case CharLiter(v) => v }.mkString + chain += Load(RAX, IndexAddress(RIP, LabelArg(labelGenerator.getLabel(str)))) chain += stack.push(Q64, RAX) case ty => chain ++= generateCall( @@ -398,7 +399,6 @@ object asmGenerator { private def generateCall(call: microWacc.Call, isTail: Boolean)(using stack: Stack, - strings: ListBuffer[String], labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine]