feat: used local mutable Chains. Also implemented new LabelGenerator

This commit is contained in:
Jonny 2025-02-25 19:39:55 +00:00 committed by Gleb Koval
parent 7953790f4d
commit edbc03ee25
Signed by: cyclane
GPG Key ID: 15E168A8B332382C

View File

@ -4,6 +4,7 @@ import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
import parsley.token.errors.Label
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
@ -26,10 +27,27 @@ object asmGenerator {
val _8_BIT_MASK = 0xff val _8_BIT_MASK = 0xff
extension (chain: Chain[AsmLine])
def +=(line: AsmLine): Chain[AsmLine] = chain.append(line)
class LabelGenerator {
var labelVal = -1
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
def getLabel(target: CallTarget): String = target match {
case Ident(v, _) => s"wacc_$v"
case Builtin(name) => s"_$name"
}
}
def generateAsm(microProg: Program): List[AsmLine] = { def generateAsm(microProg: Program): List[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]() given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg val Program(funcs, main) = microProg
val progAsm = val progAsm =
@ -64,21 +82,29 @@ object asmGenerator {
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String]
): Chain[AsmLine] = { ): Chain[AsmLine] = {
Chain.one(LabelDef(labelName)) ++ var chain = Chain.empty[AsmLine]
funcPrologue() ++
funcBody ++ chain += LabelDef(labelName)
funcEpilogue() chain ++= funcPrologue()
chain ++= funcBody
chain ++= funcEpilogue()
chain
} }
def generateBuiltInFuncs()(using def generateBuiltInFuncs()(using
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
wrapFunc( var chain = Chain.empty[AsmLine]
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Exit), labelGenerator.getLabel(Builtin.Exit),
Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
) ++ )
wrapFunc(
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Printf), labelGenerator.getLabel(Builtin.Printf),
Chain( Chain(
stack.align(), stack.align(),
@ -86,15 +112,16 @@ object asmGenerator {
Move(RDI, ImmediateVal(0)), Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
) )
) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Malloc),
Chain.one(
stack.align()
) )
) ++
wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) ++ chain ++= wrapFunc(
wrapFunc( labelGenerator.getLabel(Builtin.Malloc),
Chain.one(stack.align())
)
chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Read), labelGenerator.getLabel(Builtin.Read),
Chain( Chain(
stack.align(), stack.align(),
@ -106,253 +133,209 @@ object asmGenerator {
stack.drop() stack.drop()
) )
) )
chain
} }
def generateStmt( def generateStmt(stmt: Stmt)(using stack: Stack, strings: ListBuffer[String], labelGenerator: LabelGenerator): Chain[AsmLine] = {
stmt: Stmt var chain = Chain.empty[AsmLine]
)(using stack: Stack, strings: ListBuffer[String]): Chain[AsmLine] =
stmt match { stmt match {
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
var dest: () => IndexAddress = var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
() => IndexAddress(RAX, 0) // gets overrwitten
(lhs match { lhs match {
case ident: Ident => case ident: Ident =>
dest = stack.accessVar(ident) dest = stack.accessVar(ident)
if (!stack.contains(ident)) { if (!stack.contains(ident)) chain += stack.reserve(ident)
Chain.one(stack.reserve(ident))
} else Chain.empty
// TODO lhs = arrayElem // TODO lhs = arrayElem
case _ => case _ =>
// dest = ??? }
Chain.empty
}) ++ chain ++= evalExprOntoStack(rhs)
evalExprOntoStack(rhs) ++ chain += stack.pop(RAX)
Chain( chain += Move(dest(), RAX)
stack.pop(RAX),
Move(dest(), RAX) case If(cond, thenBranch, elseBranch) =>
)
case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
evalExprOntoStack(cond) ++
Chain.fromSeq( chain ++= evalExprOntoStack(cond)
List( chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0))
Compare(stack.head(SizeDir.Word), ImmediateVal(0)), chain += stack.drop()
stack.drop(), chain += Jump(LabelArg(elseLabel), Cond.Equal)
Jump(LabelArg(elseLabel), Cond.Equal)
) chain ++= Chain.fromSeq(thenBranch).flatMap(generateStmt)
) ++ chain += Jump(LabelArg(endLabel))
Chain.fromSeq(thenBranch).flatMap(generateStmt) ++ chain += LabelDef(elseLabel)
Chain.fromSeq(List(Jump(LabelArg(endLabel)), LabelDef(elseLabel))) ++
Chain.fromSeq(elseBranch).flatMap(generateStmt) ++ chain ++= Chain.fromSeq(elseBranch).flatMap(generateStmt)
Chain.one(LabelDef(endLabel)) chain += LabelDef(endLabel)
}
case While(cond, body) => { case While(cond, body) =>
val startLabel = labelGenerator.getLabel() val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
Chain.one(LabelDef(startLabel)) ++
evalExprOntoStack(cond) ++ chain += LabelDef(startLabel)
Chain( chain ++= evalExprOntoStack(cond)
Compare(stack.head(SizeDir.Word), ImmediateVal(0)), chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0))
stack.drop(), chain += stack.drop()
Jump(LabelArg(endLabel), Cond.Equal) chain += Jump(LabelArg(endLabel), Cond.Equal)
) ++
Chain.fromSeq(body).flatMap(generateStmt) ++ chain ++= Chain.fromSeq(body).flatMap(generateStmt)
Chain(Jump(LabelArg(startLabel)), LabelDef(endLabel)) chain += Jump(LabelArg(startLabel))
} chain += LabelDef(endLabel)
case microWacc.Return(expr) => case microWacc.Return(expr) =>
evalExprOntoStack(expr) ++ chain ++= evalExprOntoStack(expr)
Chain(stack.pop(RAX), assemblyIR.Return()) chain += stack.pop(RAX)
case call: microWacc.Call => generateCall(call) chain += assemblyIR.Return()
case call: microWacc.Call =>
chain ++= generateCall(call)
}
chain
} }
def evalExprOntoStack(expr: Expr)(using def evalExprOntoStack(expr: Expr)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
val out = expr match { var chain = Chain.empty[AsmLine]
case IntLiter(v) =>
Chain.one(stack.push(ImmediateVal(v))) expr match {
case CharLiter(v) => case IntLiter(v) => chain += stack.push(ImmediateVal(v))
Chain.one(stack.push(ImmediateVal(v.toInt))) case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt))
case ident: Ident => case ident: Ident => chain += stack.push(stack.accessVar(ident)())
Chain.one(stack.push(stack.accessVar(ident)()))
case ArrayLiter(elems) => case ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
strings += elems.map { strings += elems.collect { case CharLiter(v) => v }.mkString
case CharLiter(v) => v chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
case _ => "" chain += stack.push(RAX)
}.mkString case _ => // Other array types TODO
Chain(
Load(
RAX,
IndexAddress(
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
),
stack.push(RAX)
)
// TODO other array types
case _ => Chain.empty
} }
case BoolLiter(v) => Chain.one(stack.push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => Chain.one(stack.push(ImmediateVal(0)))
case ArrayElem(value, indices) => Chain.empty
case UnaryOp(x, op) =>
evalExprOntoStack(x) ++
(op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c.
case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => Chain.empty
case UnaryOperator.Negate =>
Chain.one(
Negate(stack.head(SizeDir.Word))
)
case UnaryOperator.Not =>
evalExprOntoStack(x) ++
Chain.one(
Xor(stack.head(SizeDir.Word), ImmediateVal(1))
)
}) case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0))
case BinaryOp(x, y, op) => case NullLiter() => chain += stack.push(ImmediateVal(0))
case ArrayElem(_, _) => // TODO: Implement handling
case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x)
op match { op match {
case BinaryOperator.Add => case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed
evalExprOntoStack(x) ++ case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word))
evalExprOntoStack(y) ++ case UnaryOperator.Not =>
Chain( chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1))
stack.pop(RAX), }
Add(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING case BinaryOp(x, y, op) =>
) chain ++= evalExprOntoStack(x)
case BinaryOperator.Sub => chain ++= evalExprOntoStack(y)
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ chain += stack.pop(RAX)
Chain(
stack.pop(RAX), op match {
Subtract(stack.head(SizeDir.Word), EAX) case BinaryOperator.Add => chain += Add(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING case BinaryOperator.Sub => chain += Subtract(stack.head(SizeDir.Word), EAX)
)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
evalExprOntoStack(x) ++ chain += Multiply(EAX, stack.head(SizeDir.Word))
evalExprOntoStack(y) ++ chain += stack.drop()
Chain( chain += stack.push(RAX)
stack.pop(RAX),
Multiply(EAX, stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RAX)
// TODO OVERFLOWING
)
case BinaryOperator.Div => case BinaryOperator.Div =>
evalExprOntoStack(y) ++ chain += Divide(stack.head(SizeDir.Word))
evalExprOntoStack(x) ++ chain += stack.drop()
Chain( chain += stack.push(RAX)
stack.pop(RAX),
Divide(stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RAX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Mod => case BinaryOperator.Mod =>
evalExprOntoStack(y) ++ chain += Divide(stack.head(SizeDir.Word))
evalExprOntoStack(x) ++ chain += stack.drop()
Chain( chain += stack.push(RDX)
stack.pop(RAX),
Divide(stack.head(SizeDir.Word)), case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal)
stack.drop(), case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual)
stack.push(RDX) case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater)
// TODO CHECK DIVISOR IS NOT 0 case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual)
) case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less)
case BinaryOperator.Eq => case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual)
generateComparison(x, y, Cond.Equal) case BinaryOperator.And => chain += And(stack.head(SizeDir.Word), EAX)
case BinaryOperator.Neq => case BinaryOperator.Or => chain += Or(stack.head(SizeDir.Word), EAX)
generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater =>
generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq =>
generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less =>
generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq =>
generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
Chain(
stack.pop(RAX),
And(stack.head(SizeDir.Word), EAX)
)
case BinaryOperator.Or =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
Chain(
stack.pop(RAX),
Or(stack.head(SizeDir.Word), EAX)
)
} }
case call: microWacc.Call => case call: microWacc.Call =>
generateCall(call) ++ chain ++= generateCall(call)
Chain.one(stack.push(RAX)) chain += stack.push(RAX)
} }
if out.isEmpty then Chain.one(stack.push(ImmediateVal(0))) else out
if chain.isEmpty then chain += stack.push(ImmediateVal(0))
chain
} }
def generateCall(call: microWacc.Call)(using def generateCall(call: microWacc.Call)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val microWacc.Call(target, args) = call val microWacc.Call(target, args) = call
val regMoves = argRegs argRegs.zip(args).foreach { (reg, expr) =>
.zip(args) chain ++= evalExprOntoStack(expr)
.map { (reg, expr) => chain += stack.pop(reg)
evalExprOntoStack(expr) ++
Chain.one(stack.pop(reg))
} }
.combineAll
val stackPushes = args.drop(argRegs.size).map(evalExprOntoStack).combineAll args.drop(argRegs.size).foreach { expr =>
chain ++= evalExprOntoStack(expr)
}
regMoves ++ chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))
stackPushes ++
Chain.one(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++ if (args.size > argRegs.size) {
(if (args.size > argRegs.size) Chain.one(stack.drop(args.size - argRegs.size)) chain += stack.drop(args.size - argRegs.size)
else Chain.empty) }
chain
} }
def generateComparison(x: Expr, y: Expr, cond: Cond)(using def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String] strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ var chain = Chain.empty[AsmLine]
Chain(
stack.pop(RAX), chain ++= evalExprOntoStack(x)
Compare(stack.head(SizeDir.Word), EAX), chain ++= evalExprOntoStack(y)
Set(Register(RegSize.Byte, RegName.AL), cond), chain += stack.pop(RAX)
And(RAX, ImmediateVal(_8_BIT_MASK)), chain += Compare(stack.head(SizeDir.Word), EAX)
stack.drop(), chain += Set(Register(RegSize.Byte, RegName.AL), cond)
stack.push(RAX) chain += And(RAX, ImmediateVal(_8_BIT_MASK))
) chain += stack.drop()
chain += stack.push(RAX)
chain
} }
// Missing a sub instruction but dont think we need it // Missing a sub instruction but dont think we need it
def funcPrologue()(using stack: Stack): Chain[AsmLine] = { def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
Chain( val chain = Chain.empty[AsmLine]
stack.push(RBP), chain += stack.push(RBP)
Move(RBP, Register(RegSize.R64, RegName.SP)) chain += Move(RBP, Register(RegSize.R64, RegName.SP))
) chain
} }
def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { def funcEpilogue()(using stack: Stack): Chain[AsmLine] = {
Chain( val chain = Chain.empty[AsmLine]
Move(Register(RegSize.R64, RegName.SP), RBP), chain += Move(Register(RegSize.R64, RegName.SP), RBP)
stack.pop(RBP), chain += stack.pop(RBP)
assemblyIR.Return() chain += assemblyIR.Return()
) chain
} }
class LabelGenerator { class LabelGenerator {