feat: extension method concatAll defined on Chain implemented

This commit is contained in:
Jonny 2025-02-25 19:53:31 +00:00 committed by Gleb Koval
parent bd0eb76bec
commit ebc65af981
Signed by: cyclane
GPG Key ID: 15E168A8B332382C

View File

@ -3,7 +3,7 @@ package wacc
import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain import cats.data.Chain
// import cats.syntax.foldable._ import cats.syntax.foldable._
// import parsley.token.errors.Label // import parsley.token.errors.Label
object asmGenerator { object asmGenerator {
@ -30,6 +30,9 @@ object asmGenerator {
extension (chain: Chain[AsmLine]) extension (chain: Chain[AsmLine])
def +=(line: AsmLine): Chain[AsmLine] = chain.append(line) def +=(line: AsmLine): Chain[AsmLine] = chain.append(line)
def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
chains.foldLeft(chain)(_ ++ _)
class LabelGenerator { class LabelGenerator {
var labelVal = -1 var labelVal = -1
def getLabel(): String = { def getLabel(): String = {
@ -42,40 +45,38 @@ object asmGenerator {
} }
} }
def generateAsm(microProg: Program): List[AsmLine] = { def generateAsm(microProg: Program): List[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]() given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg val Program(funcs, main) = microProg
val progAsm = val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
Chain.one(LabelDef("main")) ++ Chain(
funcPrologue() ++
Chain(stack.align()) ++
main.foldLeft(Chain.empty[AsmLine])(_ ++ generateStmt(_)) ++
Chain.one(Move(RAX, ImmediateVal(0))) ++
funcEpilogue() ++
generateBuiltInFuncs()
val strDirs = strings.toList.zipWithIndex.foldLeft(Chain.empty[AsmLine]) {
case (acc, (str, i)) =>
acc ++ Chain(
Directive.Int(str.size), Directive.Int(str.size),
LabelDef(s".L.str$i"), LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped) Directive.Asciz(str.escaped)
) )
} }
val finalChain = Chain( val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(),
Chain.one(stack.align()),
main.foldMap(generateStmt(_)),
Chain.one(Move(RAX, ImmediateVal(0))),
funcEpilogue(),
generateBuiltInFuncs()
)
Chain(
Directive.IntelSyntax, Directive.IntelSyntax,
Directive.Global("main"), Directive.Global("main"),
Directive.RoData Directive.RoData
) ++ strDirs ++ Chain.one(Directive.Text) ++ progAsm ).concatAll(
strDirs,
finalChain.toList Chain.one(Directive.Text),
progAsm
).toList
} }
def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using
@ -137,7 +138,11 @@ object asmGenerator {
chain chain
} }
def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator): Chain[AsmLine] = { def generateStmt(stmt: Stmt)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
stmt match { stmt match {
@ -222,7 +227,6 @@ object asmGenerator {
case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0))
case NullLiter() => chain += stack.push(ImmediateVal(0)) case NullLiter() => chain += stack.push(ImmediateVal(0))
case ArrayElem(_, _) => // TODO: Implement handling case ArrayElem(_, _) => // TODO: Implement handling
case UnaryOp(x, op) => case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
op match { op match {