441 lines
14 KiB
Scala
441 lines
14 KiB
Scala
package wacc
|
|
|
|
import scala.collection.mutable.LinkedHashMap
|
|
import scala.collection.mutable.ListBuffer
|
|
import cats.data.Chain
|
|
import cats.syntax.foldable._
|
|
// import parsley.token.errors.Label
|
|
|
|
object asmGenerator {
|
|
import microWacc._
|
|
import assemblyIR._
|
|
import wacc.types._
|
|
import lexer.escapedChars
|
|
|
|
/** A runtime error that generated programs can report.
  *
  * Each error owns a message placed in the read-only data section under `strLabel`, and a
  * handler label `errLabel` that generated code jumps to when the error condition occurs.
  *
  * This is a plain trait rather than an `abstract case class`: a case class with no fields
  * makes every instance of every subclass compare equal to every other, and extending case
  * classes is discouraged.
  */
trait Error {

  /** Label of the message string in the read-only data section. */
  def strLabel: String

  /** The message text. */
  def errStr: String

  /** Label of the handler code for this error. */
  def errLabel: String

  /** Read-only data for the message: its length, its label, then the NUL-terminated text. */
  def stringDef: Chain[AsmLine] = Chain(
    Directive.Int(errStr.size),
    LabelDef(strLabel),
    Directive.Asciz(errStr)
  )
}
|
|
object zeroDivError extends Error {
  // Raised when the divisor of an integer division or modulo is zero.
  // TODO: consider a more general structure for errors (e.g. one case class per error kind).
  val strLabel: String = ".L._errDivZero_str0"
  val errStr: String = "fatal error: division or modulo by zero"
  val errLabel: String = ".L._errDivZero"
}
|
|
|
|
// Full 64-bit general-purpose registers.
val RAX = Register(RegSize.R64, RegName.AX)
val RDI = Register(RegSize.R64, RegName.DI)
val RSI = Register(RegSize.R64, RegName.SI)
val RDX = Register(RegSize.R64, RegName.DX)
val RCX = Register(RegSize.R64, RegName.CX)
val R8 = Register(RegSize.R64, RegName.Reg8)
val R9 = Register(RegSize.R64, RegName.Reg9)

// Frame pointer and instruction pointer (the latter for RIP-relative label addressing).
val RBP = Register(RegSize.R64, RegName.BP)
val RIP = Register(RegSize.R64, RegName.IP)

// 32-bit views used by the DWord-sized integer arithmetic.
val EAX = Register(RegSize.E32, RegName.AX)
val EDX = Register(RegSize.E32, RegName.DX)
val ESP = Register(RegSize.E32, RegName.SP)

// Integer argument registers, in calling-convention order. Arguments beyond these six are
// passed on the stack. Must stay after the registers it lists (object vals initialise in order).
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)

// Keeps only the low byte of a register (used after a `set<cc>` into AL).
val _8_BIT_MASK = 0xff
|
|
|
|
extension (chain: Chain[AsmLine]) {

  /** Appends a single line. Also makes `chain += line` work on a `var chain`. */
  def +(line: AsmLine): Chain[AsmLine] = chain :+ line

  /** Concatenates `chains`, in order, after this chain. */
  def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
    chains.foldLeft(chain)((acc, next) => acc.concat(next))
}
|
|
|
|
/** Produces assembly labels: fresh numbered local labels, and symbol names for call targets. */
class LabelGenerator {
  // Number of the most recently issued local label; -1 until the first one is handed out.
  var labelVal = -1

  /** Returns a fresh local label `.L<n>`, numbered from 0 upwards. */
  def getLabel(): String = {
    labelVal = labelVal + 1
    ".L" + labelVal
  }

  /** Returns the symbol for a call target: `wacc_<name>` for identifiers, `_<name>` for builtins. */
  def getLabel(target: CallTarget): String =
    target match {
      case Ident(v, _)   => s"wacc_$v"
      case Builtin(name) => s"_$name"
    }
}
|
|
|
|
/** Compiles a microWacc program into a flat list of Intel-syntax assembly lines.
  *
  * Output order: global directives; `.rodata` with the string literals followed by the error
  * messages; then `.text` with `main`, the builtin wrappers and every user function.
  */
def generateAsm(microProg: Program): List[AsmLine] = {
  given stack: Stack = Stack()
  given strings: ListBuffer[String] = ListBuffer[String]()
  given labelGenerator: LabelGenerator = LabelGenerator()

  val Program(funcs, main) = microProg

  // The text section is built first because generating it is what fills `strings`.
  // The pieces are evaluated left to right, which is also the order in which they update `stack`.
  val textSection =
    Chain.one[AsmLine](LabelDef("main")) ++
      funcPrologue() ++
      Chain.one[AsmLine](stack.align()) ++
      main.foldMap(generateStmt(_)) ++
      Chain.one[AsmLine](Move(RAX, ImmediateVal(0))) ++
      funcEpilogue() ++
      generateBuiltInFuncs() ++
      funcs.foldMap(generateUserFunc(_))

  // Literal number i becomes: its length, the label `.L.str<i>`, then its escaped text.
  val literalData = strings.toList.zipWithIndex.foldMap { case (literal, idx) =>
    Chain[AsmLine](
      Directive.Int(literal.size),
      LabelDef(s".L.str$idx"),
      Directive.Asciz(literal.escaped)
    )
  }
  val readOnlyData = literalData ++ zeroDivError.stringDef

  val preamble = Chain[AsmLine](
    Directive.IntelSyntax,
    Directive.Global("main"),
    Directive.RoData
  )

  (preamble ++ readOnlyData ++ Chain.one[AsmLine](Directive.Text) ++ textSection).toList
}
|
|
|
|
/** Wraps `funcBody` as a labelled function with the standard prologue and epilogue.
  *
  * `funcBody` is already built when this is called. Its `stack` bookkeeping therefore happens
  * before the prologue's push of RBP is recorded.
  */
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
    stack: Stack
): Chain[AsmLine] =
  Chain.one[AsmLine](LabelDef(labelName)) ++ funcPrologue() ++ funcBody ++ funcEpilogue()
|
|
|
|
/** Generates the code for a user-defined function, labelled via `labelGenerator.getLabel(func.name)`.
  *
  * Each function gets its own [[Stack]] tracker. On entry to the body, the frame is laid out
  * from higher to lower addresses as follows:
  *   - the stack-passed parameters (7th onwards, pushed by the caller in order);
  *   - the return address pushed by `call`;
  *   - the saved RBP;
  *   - the register-passed parameters, copied onto the stack here.
  */
def generateUserFunc(func: FuncDecl)(using
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  given stack: Stack = Stack()

  // Parameters 7 and up already live on the caller's stack, so only record them in the tracker.
  // The `sub rsp` instructions returned by `reserve` are deliberately discarded.
  func.params.drop(argRegs.size).foreach(stack.reserve(_))
  // Record the return address pushed by `call`. It sits between those parameters and the saved
  // RBP. Without this slot, every stack-passed parameter resolved to an address one slot
  // (8 bytes) too low. Again, only the tracker is updated.
  stack.reserve()

  var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
  chain ++= funcPrologue()

  // Copy the register-passed parameters onto the stack so every variable is RSP-addressable.
  argRegs.zip(func.params).foreach { (reg, param) =>
    chain += stack.push(param, reg)
  }

  chain ++= func.body.foldMap(generateStmt(_))
  // No epilogue here: all user functions must return explicitly, and each return emits one.
  chain
}
|
|
|
|
/** Emits the builtin wrapper functions (exit, printf, malloc, free, read, in that order),
  * followed by the shared division-by-zero error handler.
  */
def generateBuiltInFuncs()(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  // exit: align the stack and call libc exit with the incoming RDI.
  val exitFunc = wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Exit),
    Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
  )

  // printf: call libc printf, then fflush(0) so the output is not left buffered.
  val printfFunc = wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Printf),
    Chain(
      stack.align(),
      assemblyIR.Call(CLibFunc.PrintF),
      Move(RDI, ImmediateVal(0)),
      assemblyIR.Call(CLibFunc.Fflush)
    )
  )

  // NOTE(review): malloc only aligns the stack and free has an empty body. Neither calls into
  // libc yet; presumably these are placeholders. TODO confirm.
  val mallocFunc = wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Malloc),
    Chain.one(stack.align())
  )
  val freeFunc = wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)

  // read, step by step:
  //   1. Reserve one padding slot, then push RSI (8 + 8 bytes keep the alignment).
  //   2. Point RSI at that pushed copy and call scanf.
  //   3. Pop the slot (the scanned or default value) into RAX and release the padding.
  val readFunc = wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Read),
    Chain(
      stack.align(),
      stack.reserve(),
      stack.push(RSI),
      Load(RSI, stack.head),
      assemblyIR.Call(CLibFunc.Scanf),
      stack.pop(RAX),
      stack.drop()
    )
  )

  // Division/modulo-by-zero handler. It is reached by a jump (not a call), so it realigns the
  // stack itself, prints the message and exits with -1.
  // TODO can this be done with a call to generateStmt? Consider other error cases -> look to generalise
  val divZeroHandler = Chain[AsmLine](
    LabelDef(zeroDivError.errLabel),
    stack.align(),
    Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
    assemblyIR.Call(CLibFunc.PrintF),
    Move(RDI, ImmediateVal(-1)),
    assemblyIR.Call(CLibFunc.Exit)
  )

  exitFunc ++ printfFunc ++ mallocFunc ++ freeFunc ++ readFunc ++ divZeroHandler
}
|
|
|
|
/** Generates the code for one statement.
  *
  * Invariant: after each statement, the [[Stack]] tracker and the runtime stack pointer agree.
  * A first assignment to a variable (its declaration) grows both by one slot. Nested blocks
  * (the branches of `If` and the body of `While`) release their own declarations when they end.
  */
def generateStmt(stmt: Stmt)(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  var chain = Chain.empty[AsmLine]

  stmt match {
    case Assign(lhs, rhs) =>
      // A variable's first assignment is its declaration, so give it a stack slot. The address
      // is a thunk so its RSP offset is computed only after rhs has finished pushing and popping.
      val dest: Option[() => IndexAddress] = lhs match {
        case ident: Ident =>
          if (!stack.contains(ident)) chain += stack.reserve(ident)
          Some(stack.accessVar(ident))
        case _ => None // TODO lhs = arrayElem
      }

      chain ++= evalExprOntoStack(rhs)
      chain += stack.pop(RAX)
      // Unsupported targets are skipped instead of being stored through the bogus address [rax].
      dest.foreach(address => chain += Move(address(), RAX))

    case If(cond, thenBranch, elseBranch) =>
      val elseLabel = labelGenerator.getLabel()
      val endLabel = labelGenerator.getLabel()

      chain ++= evalExprOntoStack(cond)
      chain += stack.pop(RAX)
      chain += Compare(RAX, ImmediateVal(0))
      chain += Jump(LabelArg(elseLabel), Cond.Equal)

      chain ++= generateScope(thenBranch)
      chain += Jump(LabelArg(endLabel))
      chain += LabelDef(elseLabel)

      chain ++= generateScope(elseBranch)
      chain += LabelDef(endLabel)

    case While(cond, body) =>
      val startLabel = labelGenerator.getLabel()
      val endLabel = labelGenerator.getLabel()

      chain += LabelDef(startLabel)
      chain ++= evalExprOntoStack(cond)
      chain += stack.pop(RAX)
      chain += Compare(RAX, ImmediateVal(0))
      chain += Jump(LabelArg(endLabel), Cond.Equal)

      // The body must be stack-neutral so the condition sees the same layout on every iteration.
      chain ++= generateScope(body)
      chain += Jump(LabelArg(startLabel))
      chain += LabelDef(endLabel)

    case microWacc.Return(expr) =>
      chain ++= evalExprOntoStack(expr)
      chain += stack.pop(RAX)
      chain ++= funcEpilogue()

    case call: microWacc.Call =>
      chain ++= generateCall(call)
  }

  chain
}

/** Generates `stmts` as a nested block, then releases the slots of the variables it declared.
  *
  * Without this, there are two failure modes:
  *   - A declaration inside a loop body runs `sub rsp, 8` on every iteration, while the tracker
  *     records it only once.
  *   - A declaration in one `if` branch is recorded even when the other branch is the one that runs.
  * In both cases every later RSP-relative variable access lands on the wrong slot.
  * This relies on block-scoped variables not being referenced after the block ends.
  */
private def generateScope[F[_]](stmts: F[Stmt])(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator,
    foldable: cats.Foldable[F]
): Chain[AsmLine] = {
  var declared = 0
  val body = stmts.foldMap { stmt =>
    // Check before generating: generateStmt is what adds the new variable to the tracker.
    stmt match {
      case Assign(ident: Ident, _) if !stack.contains(ident) => declared += 1
      case _                                                 =>
    }
    // Nested scopes release their own declarations, so only this level's are counted above.
    generateStmt(stmt)
  }
  if (declared > 0) body :+ stack.drop(declared) else body
}
|
|
|
|
/** Generates code that evaluates `expr` and leaves its value in one new 8-byte stack slot.
  *
  * The net effect on both the runtime stack and the [[Stack]] tracker is exactly one pushed slot.
  * Integer arithmetic works on the low 32 bits (DWord). A zero divisor in `/` or `%` jumps to
  * the division-by-zero handler. Expression forms that are not yet implemented push 0 as a
  * placeholder.
  */
def evalExprOntoStack(expr: Expr)(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  var chain = Chain.empty[AsmLine]

  expr match {
    case IntLiter(v)  => chain += stack.push(ImmediateVal(v))
    case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt))
    case ident: Ident => chain += stack.push(stack.accessVar(ident)())

    case ArrayLiter(elems) =>
      expr.ty match {
        case KnownType.String =>
          // Intern the literal. Its index in `strings` determines its `.L.str<i>` label.
          strings += elems.collect { case CharLiter(v) => v }.mkString
          chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
          chain += stack.push(RAX)
        case _ => // Other array types TODO
      }

    case BoolLiter(v)    => chain += stack.push(ImmediateVal(if (v) 1 else 0))
    case NullLiter()     => chain += stack.push(ImmediateVal(0))
    case ArrayElem(_, _) => // TODO: Implement handling

    case UnaryOp(x, op) =>
      chain ++= evalExprOntoStack(x)
      op match {
        // NOTE(review): Len presumably needs a real implementation once arrays are supported.
        case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed
        case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord))
        case UnaryOperator.Not =>
          chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1))
      }

    case BinaryOp(x, y, op) =>
      val comparison: Option[Cond] = op match {
        case BinaryOperator.Eq        => Some(Cond.Equal)
        case BinaryOperator.Neq       => Some(Cond.NotEqual)
        case BinaryOperator.Greater   => Some(Cond.Greater)
        case BinaryOperator.GreaterEq => Some(Cond.GreaterEqual)
        case BinaryOperator.Less      => Some(Cond.Less)
        case BinaryOperator.LessEq    => Some(Cond.LessEqual)
        case _                        => None
      }

      comparison match {
        // generateComparison evaluates both operands itself. Evaluating them here as well would
        // run them twice and leak a stack slot per comparison (fatal inside loop conditions).
        case Some(cond) => chain ++= generateComparison(x, y, cond)

        case None =>
          // Push y, then x, then pop x into RAX. This leaves y at the top of the stack.
          chain ++= evalExprOntoStack(y)
          chain ++= evalExprOntoStack(x)
          chain += stack.pop(RAX)

          op match {
            case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX)
            case BinaryOperator.Sub =>
              chain += Subtract(EAX, stack.head(SizeDir.DWord))
              chain += stack.drop()
              chain += stack.push(RAX)
            case BinaryOperator.Mul =>
              chain += Multiply(EAX, stack.head(SizeDir.DWord))
              chain += stack.drop()
              chain += stack.push(RAX)

            case BinaryOperator.Div =>
              chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0))
              chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
              chain += CDQ()
              chain += Divide(stack.head(SizeDir.DWord))
              chain += stack.drop()
              chain += stack.push(RAX) // quotient

            case BinaryOperator.Mod =>
              // Modulo needs the same zero-divisor guard as division: idiv by zero would
              // otherwise raise a hardware fault instead of the WACC runtime error.
              chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0))
              chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
              chain += CDQ()
              chain += Divide(stack.head(SizeDir.DWord))
              chain += stack.drop()
              chain += stack.push(RDX) // remainder

            case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX)
            case BinaryOperator.Or  => chain += Or(stack.head(SizeDir.DWord), EAX)
            case _                  => // Comparison operators are dispatched above.
          }
      }

    case call: microWacc.Call =>
      chain ++= generateCall(call)
      chain += stack.push(RAX)
  }

  // Placeholder for unimplemented forms, so callers can always pop exactly one value.
  if chain.isEmpty then chain += stack.push(ImmediateVal(0))
  chain
}
|
|
|
|
/** Generates a call to `call.target`, with its return value left in RAX.
  *
  * The first six arguments go in `argRegs`. The rest are pushed in order, so they sit directly
  * above the return address, and are dropped again after the call. The stack is balanced
  * overall. Stack-passed arguments are evaluated before register-passed ones.
  */
def generateCall(call: microWacc.Call)(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  val microWacc.Call(target, args) = call
  val (regArgs, stackArgs) = args.splitAt(argRegs.size)
  var chain = Chain.empty[AsmLine]

  // Arguments 7 and up are pushed first. Later pushes are popped back off before the call,
  // so these end up immediately above the return address.
  stackArgs.foreach(arg => chain ++= evalExprOntoStack(arg))

  // Evaluate every register argument onto the stack *before* loading any register. Evaluating an
  // argument can clobber argument registers (cdq/idiv write RDX; nested calls clobber them all),
  // so loading the registers one at a time would corrupt earlier arguments.
  regArgs.foreach(arg => chain ++= evalExprOntoStack(arg))
  argRegs.take(regArgs.size).reverse.foreach(reg => chain += stack.pop(reg))

  chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))

  if (stackArgs.nonEmpty) {
    chain += stack.drop(stackArgs.size)
  }

  chain
}
|
|
|
|
/** Evaluates `x` and `y` and pushes 1 if `x <cond> y` holds (32-bit compare), otherwise 0.
  *
  * The net stack effect is exactly one pushed slot.
  */
def generateComparison(x: Expr, y: Expr, cond: Cond)(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  // Stack after this: ..., x, y  (y on top).
  val operands = evalExprOntoStack(x) ++ evalExprOntoStack(y)

  operands ++ Chain[AsmLine](
    stack.pop(RAX),                                 // EAX = y; x is now on top
    Compare(stack.head(SizeDir.DWord), EAX),        // flags from x - y
    Set(Register(RegSize.Byte, RegName.AL), cond),  // AL = 1 if the condition holds
    And(RAX, ImmediateVal(_8_BIT_MASK)),            // clear everything above AL
    stack.drop(),                                   // discard x
    stack.push(RAX)                                 // push the 0/1 result
  )
}
|
|
|
|
/** Standard frame setup: save the caller's RBP (recorded as a tracked stack slot), then set
  * RBP to the current stack pointer.
  *
  * No `sub rsp, N` is emitted for locals. Instead, each local gets its own slot (via
  * `Stack.reserve`) when its first assignment is generated in `generateStmt`.
  */
def funcPrologue()(using stack: Stack): Chain[AsmLine] =
  Chain[AsmLine](
    stack.push(RBP),
    Move(RBP, Register(RegSize.R64, RegName.SP))
  )
|
|
|
|
/** Standard frame teardown and return: restore RSP from RBP, pop the caller's RBP, then `ret`.
  *
  * Only instructions are emitted; the [[Stack]] tracker is not updated (RBP is popped directly).
  * Presumably this is fine because control does not fall through past a return. The `stack`
  * parameter is currently unused.
  */
def funcEpilogue()(using stack: Stack): Chain[AsmLine] =
  Chain[AsmLine](
    Move(Register(RegSize.R64, RegName.SP), RBP),
    Pop(RBP),
    assemblyIR.Return()
  )
|
|
|
|
/** Compile-time model of the runtime stack, used to turn named variables into RSP offsets.
  *
  * Every push/pop/reserve/drop emitted through this class is mirrored in an insertion-ordered
  * map, with one 8-byte slot per entry. A slot's key is either the expression it holds or a
  * synthetic `Int` for anonymous temporaries. Its value is its 1-based position counted from
  * the bottom of the tracked region.
  */
class Stack {
  private val slots = LinkedHashMap[Expr | Int, Int]()
  private val RSP = Register(RegSize.R64, RegName.SP)

  // Records a new topmost slot under `key`.
  private def track(key: Expr | Int): Unit =
    slots += key -> (slots.size + 1)

  // Records an anonymous slot, keyed by the current size. That key cannot collide with an
  // existing Int key, because slots are only ever removed from the top.
  private def trackAnonymous(): Unit = track(slots.size)

  // Forgets the topmost slot.
  private def untrackTop(): Unit = {
    slots.remove(slots.last._1)
  }

  /** Pushes `src` and records the slot as holding `expr`. */
  def push(expr: Expr, src: Src): AsmLine = {
    track(expr)
    Push(src)
  }

  /** Pushes `src` as an anonymous temporary. */
  def push(src: Src): AsmLine = {
    trackAnonymous()
    Push(src)
  }

  /** Pops the top slot into `dest`. */
  def pop(dest: Src): AsmLine = {
    untrackTop()
    Pop(dest)
  }

  /** Allocates one uninitialised slot for `ident`. */
  def reserve(ident: Ident): AsmLine = {
    track(ident)
    Subtract(RSP, ImmediateVal(8))
  }

  /** Allocates `n` anonymous uninitialised slots. */
  def reserve(n: Int = 1): AsmLine = {
    for (_ <- 1 to n) trackAnonymous()
    Subtract(RSP, ImmediateVal(n * 8))
  }

  /** Releases the top `n` slots without reading them. */
  def drop(n: Int = 1): AsmLine = {
    for (_ <- 1 to n) untrackTop()
    Add(RSP, ImmediateVal(n * 8))
  }

  /** Returns a thunk giving `ident`'s RSP-relative address.
    *
    * The offset is computed when the thunk is called, so it reflects the stack at that point.
    */
  def accessVar(ident: Ident): () => IndexAddress =
    () => IndexAddress(RSP, (slots.size - slots(ident)) * 8)

  /** The top-of-stack memory operand. */
  def head: MemLocation = MemLocation(RSP)

  /** The top-of-stack memory operand with an explicit size directive. */
  def head(size: SizeDir): MemLocation = MemLocation(RSP, size)

  /** Whether `ident` currently has a tracked slot. */
  def contains(ident: Ident): Boolean = slots.contains(ident)

  // TODO: Might want to actually properly handle this with the LinkedHashMap too.
  // The alignment adjustment is not reflected in `slots`.
  def align(): AsmLine = And(RSP, ImmediateVal(-16))
}
|
|
|
|
// Reverse of the lexer's escape table: maps each special character back to its
// backslash-escaped spelling, so strings can be embedded in `.asciz` directives.
private val escapedCharsMapping = escapedChars.map { case (name, ch) => ch -> ("\\" + name) }

extension (s: String) {
  // Rewrites every special character into its escaped form. Other characters pass through unchanged.
  private def escaped: String =
    s.iterator.map(c => escapedCharsMapping.getOrElse(c, c.toString)).mkString
}
|
|
}
|