feat: used Chains instead of Lists

This commit is contained in:
Jonny 2025-02-25 18:44:11 +00:00 committed by Gleb Koval
parent 7fd92b4212
commit 7953790f4d
Signed by: cyclane
GPG Key ID: 15E168A8B332382C

View File

@ -2,6 +2,8 @@ package wacc
import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain
import cats.syntax.foldable._
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
@ -31,33 +33,38 @@ object asmGenerator {
val Program(funcs, main) = microProg val Program(funcs, main) = microProg
val progAsm = val progAsm =
LabelDef("main") :: Chain.one(LabelDef("main")) ++
funcPrologue() ++ funcPrologue() ++
List(stack.align()) ++ Chain(stack.align()) ++
main.flatMap(generateStmt) ++ main.foldLeft(Chain.empty[AsmLine])(_ ++ generateStmt(_)) ++
List(Move(RAX, ImmediateVal(0))) ++ Chain.one(Move(RAX, ImmediateVal(0))) ++
funcEpilogue() ++ funcEpilogue() ++
generateBuiltInFuncs() generateBuiltInFuncs()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => val strDirs = strings.toList.zipWithIndex.foldLeft(Chain.empty[AsmLine]) {
List( case (acc, (str, i)) =>
acc ++ Chain(
Directive.Int(str.size), Directive.Int(str.size),
LabelDef(s".L.str$i"), LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped) Directive.Asciz(str.escaped)
) )
} }
List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ val finalChain = Chain(
strDirs ++ Directive.IntelSyntax,
List(Directive.Text) ++ Directive.Global("main"),
progAsm Directive.RoData
) ++ strDirs ++ Chain.one(Directive.Text) ++ progAsm
finalChain.toList
} }
def wrapFunc(labelName: String, funcBody: List[AsmLine])(using def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String]
): List[AsmLine] = { ): Chain[AsmLine] = {
LabelDef(labelName) :: Chain.one(LabelDef(labelName)) ++
funcPrologue() ++ funcPrologue() ++
funcBody ++ funcBody ++
funcEpilogue() funcEpilogue()
@ -65,16 +72,15 @@ object asmGenerator {
def generateBuiltInFuncs()(using def generateBuiltInFuncs()(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String]
labelGenerator: LabelGenerator ): Chain[AsmLine] = {
): List[AsmLine] = {
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Exit), labelGenerator.getLabel(Builtin.Exit),
List(stack.align(), assemblyIR.Call(CLibFunc.Exit)) Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
) ++ ) ++
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Printf), labelGenerator.getLabel(Builtin.Printf),
List( Chain(
stack.align(), stack.align(),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)), Move(RDI, ImmediateVal(0)),
@ -83,14 +89,14 @@ object asmGenerator {
) ++ ) ++
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Malloc), labelGenerator.getLabel(Builtin.Malloc),
List( Chain.one(
stack.align() stack.align()
) )
) ++ ) ++
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++ wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) ++
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Read), labelGenerator.getLabel(Builtin.Read),
List( Chain(
stack.align(), stack.align(),
stack.reserve(), stack.reserve(),
stack.push(RSI), stack.push(RSI),
@ -104,11 +110,7 @@ object asmGenerator {
def generateStmt( def generateStmt(
stmt: Stmt stmt: Stmt
)(using )(using stack: Stack, strings: ListBuffer[String]): Chain[AsmLine] =
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): List[AsmLine] =
stmt match { stmt match {
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
var dest: () => IndexAddress = var dest: () => IndexAddress =
@ -117,15 +119,15 @@ object asmGenerator {
case ident: Ident => case ident: Ident =>
dest = stack.accessVar(ident) dest = stack.accessVar(ident)
if (!stack.contains(ident)) { if (!stack.contains(ident)) {
List(stack.reserve(ident)) Chain.one(stack.reserve(ident))
} else Nil } else Chain.empty
// TODO lhs = arrayElem // TODO lhs = arrayElem
case _ => case _ =>
// dest = ??? // dest = ???
List() Chain.empty
}) ++ }) ++
evalExprOntoStack(rhs) ++ evalExprOntoStack(rhs) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Move(dest(), RAX) Move(dest(), RAX)
) )
@ -133,47 +135,48 @@ object asmGenerator {
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
evalExprOntoStack(cond) ++ evalExprOntoStack(cond) ++
Chain.fromSeq(
List( List(
Compare(stack.head(SizeDir.Word), ImmediateVal(0)), Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
stack.drop(), stack.drop(),
Jump(LabelArg(elseLabel), Cond.Equal) Jump(LabelArg(elseLabel), Cond.Equal)
)
) ++ ) ++
thenBranch.flatMap(generateStmt) ++ Chain.fromSeq(thenBranch).flatMap(generateStmt) ++
List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++ Chain.fromSeq(List(Jump(LabelArg(endLabel)), LabelDef(elseLabel))) ++
elseBranch.flatMap(generateStmt) ++ Chain.fromSeq(elseBranch).flatMap(generateStmt) ++
List(LabelDef(endLabel)) Chain.one(LabelDef(endLabel))
} }
case While(cond, body) => { case While(cond, body) => {
val startLabel = labelGenerator.getLabel() val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
List(LabelDef(startLabel)) ++ Chain.one(LabelDef(startLabel)) ++
evalExprOntoStack(cond) ++ evalExprOntoStack(cond) ++
List( Chain(
Compare(stack.head(SizeDir.Word), ImmediateVal(0)), Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
stack.drop(), stack.drop(),
Jump(LabelArg(endLabel), Cond.Equal) Jump(LabelArg(endLabel), Cond.Equal)
) ++ ) ++
body.flatMap(generateStmt) ++ Chain.fromSeq(body).flatMap(generateStmt) ++
List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) Chain(Jump(LabelArg(startLabel)), LabelDef(endLabel))
} }
case microWacc.Return(expr) => case microWacc.Return(expr) =>
evalExprOntoStack(expr) ++ evalExprOntoStack(expr) ++
List(stack.pop(RAX), assemblyIR.Return()) Chain(stack.pop(RAX), assemblyIR.Return())
case call: microWacc.Call => generateCall(call) case call: microWacc.Call => generateCall(call)
} }
def evalExprOntoStack(expr: Expr)(using def evalExprOntoStack(expr: Expr)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String]
labelGenerator: LabelGenerator ): Chain[AsmLine] = {
): List[AsmLine] = {
val out = expr match { val out = expr match {
case IntLiter(v) => case IntLiter(v) =>
List(stack.push(ImmediateVal(v))) Chain.one(stack.push(ImmediateVal(v)))
case CharLiter(v) => case CharLiter(v) =>
List(stack.push(ImmediateVal(v.toInt))) Chain.one(stack.push(ImmediateVal(v.toInt)))
case ident: Ident => case ident: Ident =>
List(stack.push(stack.accessVar(ident)())) Chain.one(stack.push(stack.accessVar(ident)()))
case ArrayLiter(elems) => case ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
@ -181,7 +184,7 @@ object asmGenerator {
case CharLiter(v) => v case CharLiter(v) => v
case _ => "" case _ => ""
}.mkString }.mkString
List( Chain(
Load( Load(
RAX, RAX,
IndexAddress( IndexAddress(
@ -192,26 +195,24 @@ object asmGenerator {
stack.push(RAX) stack.push(RAX)
) )
// TODO other array types // TODO other array types
case _ => List() case _ => Chain.empty
} }
case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0))) case BoolLiter(v) => Chain.one(stack.push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => List(stack.push(ImmediateVal(0))) case NullLiter() => Chain.one(stack.push(ImmediateVal(0)))
case ArrayElem(value, indices) => List() case ArrayElem(value, indices) => Chain.empty
case UnaryOp(x, op) => case UnaryOp(x, op) =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
(op match { (op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value, // TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c. // but will need bound checking e.t.c.
case UnaryOperator.Chr => List() case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => Chain.empty
case UnaryOperator.Ord => List()
case UnaryOperator.Len => List()
case UnaryOperator.Negate => case UnaryOperator.Negate =>
List( Chain.one(
Negate(stack.head(SizeDir.Word)) Negate(stack.head(SizeDir.Word))
) )
case UnaryOperator.Not => case UnaryOperator.Not =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
List( Chain.one(
Xor(stack.head(SizeDir.Word), ImmediateVal(1)) Xor(stack.head(SizeDir.Word), ImmediateVal(1))
) )
@ -221,7 +222,7 @@ object asmGenerator {
case BinaryOperator.Add => case BinaryOperator.Add =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Add(stack.head(SizeDir.Word), EAX) Add(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING // TODO OVERFLOWING
@ -229,7 +230,7 @@ object asmGenerator {
case BinaryOperator.Sub => case BinaryOperator.Sub =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Subtract(stack.head(SizeDir.Word), EAX) Subtract(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING // TODO OVERFLOWING
@ -237,7 +238,7 @@ object asmGenerator {
case BinaryOperator.Mul => case BinaryOperator.Mul =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Multiply(EAX, stack.head(SizeDir.Word)), Multiply(EAX, stack.head(SizeDir.Word)),
stack.drop(), stack.drop(),
@ -247,7 +248,7 @@ object asmGenerator {
case BinaryOperator.Div => case BinaryOperator.Div =>
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Divide(stack.head(SizeDir.Word)), Divide(stack.head(SizeDir.Word)),
stack.drop(), stack.drop(),
@ -257,7 +258,7 @@ object asmGenerator {
case BinaryOperator.Mod => case BinaryOperator.Mod =>
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Divide(stack.head(SizeDir.Word)), Divide(stack.head(SizeDir.Word)),
stack.drop(), stack.drop(),
@ -279,51 +280,56 @@ object asmGenerator {
case BinaryOperator.And => case BinaryOperator.And =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
And(stack.head(SizeDir.Word), EAX) And(stack.head(SizeDir.Word), EAX)
) )
case BinaryOperator.Or => case BinaryOperator.Or =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Or(stack.head(SizeDir.Word), EAX) Or(stack.head(SizeDir.Word), EAX)
) )
} }
case call: microWacc.Call => case call: microWacc.Call =>
generateCall(call) ++ generateCall(call) ++
List(stack.push(RAX)) Chain.one(stack.push(RAX))
} }
if out.isEmpty then List(stack.push(ImmediateVal(0))) else out if out.isEmpty then Chain.one(stack.push(ImmediateVal(0))) else out
} }
def generateCall(call: microWacc.Call)(using def generateCall(call: microWacc.Call)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String]
labelGenerator: LabelGenerator ): Chain[AsmLine] = {
): List[AsmLine] = {
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val microWacc.Call(target, args) = call val microWacc.Call(target, args) = call
argRegs.zip(args).flatMap { (reg, expr) =>
val regMoves = argRegs
.zip(args)
.map { (reg, expr) =>
evalExprOntoStack(expr) ++ evalExprOntoStack(expr) ++
List(stack.pop(reg)) Chain.one(stack.pop(reg))
} ++ }
args.drop(argRegs.size).flatMap(evalExprOntoStack) ++ .combineAll
List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
(if (args.size > argRegs.size) { val stackPushes = args.drop(argRegs.size).map(evalExprOntoStack).combineAll
List(stack.drop(args.size - argRegs.size))
} else Nil) regMoves ++
stackPushes ++
Chain.one(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
(if (args.size > argRegs.size) Chain.one(stack.drop(args.size - argRegs.size))
else Chain.empty)
} }
def generateComparison(x: Expr, y: Expr, cond: Cond)(using def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String]
labelGenerator: LabelGenerator ): Chain[AsmLine] = {
): List[AsmLine] = {
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( Chain(
stack.pop(RAX), stack.pop(RAX),
Compare(stack.head(SizeDir.Word), EAX), Compare(stack.head(SizeDir.Word), EAX),
Set(Register(RegSize.Byte, RegName.AL), cond), Set(Register(RegSize.Byte, RegName.AL), cond),
@ -334,15 +340,15 @@ object asmGenerator {
} }
// Missing a sub instruction but dont think we need it // Missing a sub instruction but dont think we need it
def funcPrologue()(using stack: Stack): List[AsmLine] = { def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
List( Chain(
stack.push(RBP), stack.push(RBP),
Move(RBP, Register(RegSize.R64, RegName.SP)) Move(RBP, Register(RegSize.R64, RegName.SP))
) )
} }
def funcEpilogue()(using stack: Stack): List[AsmLine] = { def funcEpilogue()(using stack: Stack): Chain[AsmLine] = {
List( Chain(
Move(Register(RegSize.R64, RegName.SP), RBP), Move(Register(RegSize.R64, RegName.SP), RBP),
stack.pop(RBP), stack.pop(RBP),
assemblyIR.Return() assemblyIR.Return()