Local Mutable Chains

Merge request lab2425_spring/WACC_37!27

Co-authored-by: Jonny <j.sinteix@gmail.com>
This commit is contained in:
Gleb Koval 2025-02-25 21:13:46 +00:00
commit 70e023f27a

View File

@ -2,6 +2,9 @@ package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer
import cats.data.Chain
import cats.syntax.foldable._
// import parsley.token.errors.Label
object asmGenerator {
import microWacc._
@ -24,330 +27,11 @@ object asmGenerator {
val _8_BIT_MASK = 0xff
def generateAsm(microProg: Program): List[AsmLine] = {
given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg
extension (chain: Chain[AsmLine])
def +(line: AsmLine): Chain[AsmLine] = chain.append(line)
val progAsm =
LabelDef("main") ::
funcPrologue() ++
List(stack.align()) ++
main.flatMap(generateStmt) ++
List(Move(RAX, ImmediateVal(0))) ++
funcEpilogue() ++
generateBuiltInFuncs()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
List(
Directive.Int(str.size),
LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped)
)
}
List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
strDirs ++
List(Directive.Text) ++
progAsm
}
def wrapFunc(labelName: String, funcBody: List[AsmLine])(using
stack: Stack,
strings: ListBuffer[String]
): List[AsmLine] = {
LabelDef(labelName) ::
funcPrologue() ++
funcBody ++
funcEpilogue()
}
def generateBuiltInFuncs()(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): List[AsmLine] = {
wrapFunc(
labelGenerator.getLabel(Builtin.Exit),
List(stack.align(), assemblyIR.Call(CLibFunc.Exit))
) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Printf),
List(
stack.align(),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Malloc),
List(
stack.align()
)
) ++
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Read),
List(
stack.align(),
stack.reserve(),
stack.push(RSI),
Load(RSI, stack.head),
assemblyIR.Call(CLibFunc.Scanf),
stack.pop(RAX),
stack.drop()
)
)
}
def generateStmt(
stmt: Stmt
)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): List[AsmLine] =
stmt match {
case Assign(lhs, rhs) =>
var dest: () => IndexAddress =
() => IndexAddress(RAX, 0) // gets overrwitten
(lhs match {
case ident: Ident =>
dest = stack.accessVar(ident)
if (!stack.contains(ident)) {
List(stack.reserve(ident))
} else Nil
// TODO lhs = arrayElem
case _ =>
// dest = ???
List()
}) ++
evalExprOntoStack(rhs) ++
List(
stack.pop(RAX),
Move(dest(), RAX)
)
case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
evalExprOntoStack(cond) ++
List(
Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
stack.drop(),
Jump(LabelArg(elseLabel), Cond.Equal)
) ++
thenBranch.flatMap(generateStmt) ++
List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++
elseBranch.flatMap(generateStmt) ++
List(LabelDef(endLabel))
}
case While(cond, body) => {
val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
List(LabelDef(startLabel)) ++
evalExprOntoStack(cond) ++
List(
Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
stack.drop(),
Jump(LabelArg(endLabel), Cond.Equal)
) ++
body.flatMap(generateStmt) ++
List(Jump(LabelArg(startLabel)), LabelDef(endLabel))
}
case microWacc.Return(expr) =>
evalExprOntoStack(expr) ++
List(stack.pop(RAX), assemblyIR.Return())
case call: microWacc.Call => generateCall(call)
}
def evalExprOntoStack(expr: Expr)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): List[AsmLine] = {
val out = expr match {
case IntLiter(v) =>
List(stack.push(ImmediateVal(v)))
case CharLiter(v) =>
List(stack.push(ImmediateVal(v.toInt)))
case ident: Ident =>
List(stack.push(stack.accessVar(ident)()))
case ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
strings += elems.map {
case CharLiter(v) => v
case _ => ""
}.mkString
List(
Load(
RAX,
IndexAddress(
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
),
stack.push(RAX)
)
// TODO other array types
case _ => List()
}
case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => List(stack.push(ImmediateVal(0)))
case ArrayElem(value, indices) => List()
case UnaryOp(x, op) =>
evalExprOntoStack(x) ++
(op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c.
case UnaryOperator.Chr => List()
case UnaryOperator.Ord => List()
case UnaryOperator.Len => List()
case UnaryOperator.Negate =>
List(
Negate(stack.head(SizeDir.Word))
)
case UnaryOperator.Not =>
evalExprOntoStack(x) ++
List(
Xor(stack.head(SizeDir.Word), ImmediateVal(1))
)
})
case BinaryOp(x, y, op) =>
op match {
case BinaryOperator.Add =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
stack.pop(RAX),
Add(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Sub =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
stack.pop(RAX),
Subtract(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Mul =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
stack.pop(RAX),
Multiply(EAX, stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RAX)
// TODO OVERFLOWING
)
case BinaryOperator.Div =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
stack.pop(RAX),
Divide(stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RAX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Mod =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
stack.pop(RAX),
Divide(stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RDX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Eq =>
generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq =>
generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater =>
generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq =>
generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less =>
generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq =>
generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
stack.pop(RAX),
And(stack.head(SizeDir.Word), EAX)
)
case BinaryOperator.Or =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
stack.pop(RAX),
Or(stack.head(SizeDir.Word), EAX)
)
}
case call: microWacc.Call =>
generateCall(call) ++
List(stack.push(RAX))
}
if out.isEmpty then List(stack.push(ImmediateVal(0))) else out
}
def generateCall(call: microWacc.Call)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): List[AsmLine] = {
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val microWacc.Call(target, args) = call
argRegs.zip(args).flatMap { (reg, expr) =>
evalExprOntoStack(expr) ++
List(stack.pop(reg))
} ++
args.drop(argRegs.size).flatMap(evalExprOntoStack) ++
List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
(if (args.size > argRegs.size) {
List(stack.drop(args.size - argRegs.size))
} else Nil)
}
def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): List[AsmLine] = {
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
stack.pop(RAX),
Compare(stack.head(SizeDir.Word), EAX),
Set(Register(RegSize.Byte, RegName.AL), cond),
And(RAX, ImmediateVal(_8_BIT_MASK)),
stack.drop(),
stack.push(RAX)
)
}
// Missing a sub instruction but dont think we need it
def funcPrologue()(using stack: Stack): List[AsmLine] = {
List(
stack.push(RBP),
Move(RBP, Register(RegSize.R64, RegName.SP))
)
}
def funcEpilogue()(using stack: Stack): List[AsmLine] = {
List(
Move(Register(RegSize.R64, RegName.SP), RBP),
stack.pop(RBP),
assemblyIR.Return()
)
}
def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
chains.foldLeft(chain)(_ ++ _)
class LabelGenerator {
var labelVal = -1
@ -361,6 +45,303 @@ object asmGenerator {
}
}
def generateAsm(microProg: Program): List[AsmLine] = {
given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg
val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(),
Chain.one(stack.align()),
main.foldMap(generateStmt(_)),
Chain.one(Move(RAX, ImmediateVal(0))),
funcEpilogue(),
generateBuiltInFuncs()
)
val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
Chain(
Directive.Int(str.size),
LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped)
)
}
Chain(
Directive.IntelSyntax,
Directive.Global("main"),
Directive.RoData
).concatAll(
strDirs,
Chain.one(Directive.Text),
progAsm
).toList
}
def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using
stack: Stack,
strings: ListBuffer[String]
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += LabelDef(labelName)
chain ++= funcPrologue()
chain ++= funcBody
chain ++= funcEpilogue()
chain
}
def generateBuiltInFuncs()(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Exit),
Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
)
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Printf),
Chain(
stack.align(),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
)
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Malloc),
Chain.one(stack.align())
)
chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Read),
Chain(
stack.align(),
stack.reserve(),
stack.push(RSI),
Load(RSI, stack.head),
assemblyIR.Call(CLibFunc.Scanf),
stack.pop(RAX),
stack.drop()
)
)
chain
}
def generateStmt(stmt: Stmt)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
stmt match {
case Assign(lhs, rhs) =>
var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
lhs match {
case ident: Ident =>
dest = stack.accessVar(ident)
if (!stack.contains(ident)) chain += stack.reserve(ident)
// TODO lhs = arrayElem
case _ =>
}
chain ++= evalExprOntoStack(rhs)
chain += stack.pop(RAX)
chain += Move(dest(), RAX)
case If(cond, thenBranch, elseBranch) =>
val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
chain ++= evalExprOntoStack(cond)
chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0))
chain += stack.drop()
chain += Jump(LabelArg(elseLabel), Cond.Equal)
chain ++= Chain.fromSeq(thenBranch).flatMap(generateStmt)
chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel)
chain ++= Chain.fromSeq(elseBranch).flatMap(generateStmt)
chain += LabelDef(endLabel)
case While(cond, body) =>
val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
chain += LabelDef(startLabel)
chain ++= evalExprOntoStack(cond)
chain += Compare(stack.head(SizeDir.Word), ImmediateVal(0))
chain += stack.drop()
chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= Chain.fromSeq(body).flatMap(generateStmt)
chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel)
case microWacc.Return(expr) =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain += assemblyIR.Return()
case call: microWacc.Call =>
chain ++= generateCall(call)
}
chain
}
def evalExprOntoStack(expr: Expr)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
expr match {
case IntLiter(v) => chain += stack.push(ImmediateVal(v))
case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(stack.accessVar(ident)())
case ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
chain += stack.push(RAX)
case _ => // Other array types TODO
}
case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0))
case NullLiter() => chain += stack.push(ImmediateVal(0))
case ArrayElem(_, _) => // TODO: Implement handling
case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x)
op match {
case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed
case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.Word))
case UnaryOperator.Not =>
chain += Xor(stack.head(SizeDir.Word), ImmediateVal(1))
}
case BinaryOp(x, y, op) =>
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(y)
chain += stack.pop(RAX)
op match {
case BinaryOperator.Add => chain += Add(stack.head(SizeDir.Word), EAX)
case BinaryOperator.Sub => chain += Subtract(stack.head(SizeDir.Word), EAX)
case BinaryOperator.Mul =>
chain += Multiply(EAX, stack.head(SizeDir.Word))
chain += stack.drop()
chain += stack.push(RAX)
case BinaryOperator.Div =>
chain += Divide(stack.head(SizeDir.Word))
chain += stack.drop()
chain += stack.push(RAX)
case BinaryOperator.Mod =>
chain += Divide(stack.head(SizeDir.Word))
chain += stack.drop()
chain += stack.push(RDX)
case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And => chain += And(stack.head(SizeDir.Word), EAX)
case BinaryOperator.Or => chain += Or(stack.head(SizeDir.Word), EAX)
}
case call: microWacc.Call =>
chain ++= generateCall(call)
chain += stack.push(RAX)
}
if chain.isEmpty then chain += stack.push(ImmediateVal(0))
chain
}
def generateCall(call: microWacc.Call)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val microWacc.Call(target, args) = call
argRegs.zip(args).foreach { (reg, expr) =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(reg)
}
args.drop(argRegs.size).foreach { expr =>
chain ++= evalExprOntoStack(expr)
}
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))
if (args.size > argRegs.size) {
chain += stack.drop(args.size - argRegs.size)
}
chain
}
def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(y)
chain += stack.pop(RAX)
chain += Compare(stack.head(SizeDir.Word), EAX)
chain += Set(Register(RegSize.Byte, RegName.AL), cond)
chain += And(RAX, ImmediateVal(_8_BIT_MASK))
chain += stack.drop()
chain += stack.push(RAX)
chain
}
// Missing a sub instruction but dont think we need it
def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += stack.push(RBP)
chain += Move(RBP, Register(RegSize.R64, RegName.SP))
chain
}
def funcEpilogue()(using stack: Stack): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += Move(Register(RegSize.R64, RegName.SP), RBP)
chain += stack.pop(RBP)
chain += assemblyIR.Return()
chain
}
class Stack {
private val stack = LinkedHashMap[Expr | Int, Int]()
private val RSP = Register(RegSize.R64, RegName.SP)