Files
WACC_37/src/main/wacc/backend/asmGenerator.scala

296 lines
8.9 KiB
Scala

package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer
/** Source of fresh local assembly labels: `.L0`, `.L1`, `.L2`, ...
  *
  * The counter lives in this global object and is never reset, so every label handed out
  * during the life of the process is distinct.
  */
object labelGenerator {
  // Number of the most recently issued label; -1 means no label has been issued yet.
  var labelVal = -1

  /** Advances the counter and returns the name of the newly reserved label. */
  def getLabel(): String = {
    val next = labelVal + 1
    labelVal = next
    ".L" + next
  }
}
object asmGenerator {
  import microWacc._
  import assemblyIR._
  import wacc.types._

  /** Translates a microWacc program into a flat list of x86-64 assembly lines (Intel syntax).
    *
    * Output layout: the syntax/global directives, a read-only data section containing every
    * string collected while generating code (literals and format strings, labelled
    * `.L.str<index>`), then the text section with `main` followed by the user functions.
    *
    * @param microProg the program to translate
    * @return the assembly lines in emission order
    */
  def generateAsm(microProg: Program): List[AsmLine] = {
    given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]()
    given strings: ListBuffer[String] = ListBuffer[String]()
    val Program(_, main) = microProg
    // Code must be generated before the data section is built: generating it is what
    // fills `strings`.
    val progAsm =
      LabelDef("main") ::
        funcPrologue() ++
        alignStack() ++
        main.flatMap(generateStmt) ++
        funcEpilogue() ++
        List(assemblyIR.Return()) ++
        generateFuncs()
    val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
      List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str))
    }
    List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
      strDirs ++
      List(Directive.Text) ++
      progAsm
  }

  // TODO: generate code for user-defined functions.
  def generateFuncs()(using
      stack: LinkedHashMap[Ident, Int],
      strings: ListBuffer[String]
  ): List[AsmLine] = {
    List()
  }

  /** Generates the assembly for a single statement.
    *
    * Statements that are not handled yet produce no code.
    *
    * NOTE(review): a variable first assigned inside an `If` branch or `While` body gets its
    * `sub rsp, 16` only on that runtime path (and once per loop iteration), while `stack`
    * records it once at compile time, so rsp-relative offsets used afterwards can be wrong
    * at runtime. Confirm whether microWacc hoists declarations out of nested blocks.
    */
  def generateStmt(
      stmt: Stmt
  )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
    stmt match {
      case microWacc.Call(Builtin.Exit, code :: _) =>
        evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++
          List(assemblyIR.Call(CLibFunc.Exit))
      case microWacc.Call(Builtin.Println, expr :: _) =>
        printF(expr) ++ printLn()
      case microWacc.Call(Builtin.Print, expr :: _) =>
        printF(expr)
      case Assign(lhs, rhs) =>
        lhs match {
          case ident: Ident =>
            // A first assignment declares the variable: reserve a fresh 16-byte slot.
            // 16 bytes per slot keeps rsp aligned for the C library calls. The slot must be
            // registered in `stack` before `accessVar` computes its offset.
            val reserve: List[AsmLine] =
              if (stack.contains(ident)) List()
              else {
                stack += (ident -> (stack.size + 1))
                List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16)))
              }
            val dest = accessVar(ident)
            reserve ++ (rhs match {
              case microWacc.Call(Builtin.ReadInt, _) =>
                readIntoVar(dest, Builtin.ReadInt)
              case microWacc.Call(Builtin.ReadChar, _) =>
                readIntoVar(dest, Builtin.ReadChar)
              case _ =>
                evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++
                  List(Move(dest, Register(RegSize.R64, RegName.AX)))
            })
          // TODO lhs = arrayElem. Nothing is emitted until element addressing exists;
          // previously the value was stored at [rsp], clobbering the newest variable.
          case _ => List()
        }
      case If(cond, thenBranch, elseBranch) =>
        val elseLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()
        // A condition value of 0 is false: jump straight to the else branch.
        evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++
          List(
            Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
            Jump(LabelArg(elseLabel), Cond.Equal)
          ) ++
          thenBranch.flatMap(generateStmt) ++
          List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++
          elseBranch.flatMap(generateStmt) ++
          List(LabelDef(endLabel))
      case While(cond, body) =>
        val startLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()
        // The condition is re-evaluated at the top of every iteration; 0 exits the loop.
        List(LabelDef(startLabel)) ++
          evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++
          List(
            Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
            Jump(LabelArg(endLabel), Cond.Equal)
          ) ++
          body.flatMap(generateStmt) ++
          List(Jump(LabelArg(startLabel)), LabelDef(endLabel))
      case microWacc.Return(expr) =>
        // Only places the value in rax; no epilogue/ret is emitted here yet.
        evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX))
      case _ => List()
    }

  /** Emits code that leaves the value of `expr` in `dest`.
    *
    * Literals are loaded as immediates, variables are read from their stack slot, and string
    * literals yield the address of their entry in the read-only data section. Expression
    * kinds that are not handled yet produce no code.
    */
  def evalExprIntoReg(expr: Expr, dest: Register)(using
      stack: LinkedHashMap[Ident, Int],
      strings: ListBuffer[String]
  ): List[AsmLine] =
    expr match {
      case IntLiter(v)  => List(Move(dest, ImmediateVal(v)))
      case CharLiter(v) => List(Move(dest, ImmediateVal(v.toInt)))
      case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0)))
      case ident: Ident => List(Move(dest, accessVar(ident)))
      case ArrayLiter(elems) =>
        expr.ty match {
          case KnownType.String =>
            val text = elems.map {
              case CharLiter(v) => v
              case _            => ""
            }.mkString
            List(loadStrAddr(text, dest))
          // TODO other array types
          case _ => List()
        }
      // TODO other expr types
      case _ => List()
    }

  /** Emits a `scanf` call that reads an int or a char directly into the stack slot `dest`.
    *
    * TODO: make sure EOF doesn't overwrite the value already in the slot (now possible with
    * labels and conditional jumps).
    * NOTE(review): later reads of the slot load 64 bits, but if the format is e.g. "%d" or
    * "%c", scanf writes only 4 or 1 bytes of it, so the remaining bytes keep whatever was
    * there. Confirm the PrintFormat strings.
    */
  def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
      stack: LinkedHashMap[Ident, Int],
      strings: ListBuffer[String]
  ): List[AsmLine] = {
    val format = readType match {
      case Builtin.ReadInt  => PrintFormat.Int.toString
      case Builtin.ReadChar => PrintFormat.Char.toString
    }
    List(
      loadStrAddr(format, Register(RegSize.R64, RegName.DI)),
      Load(Register(RegSize.R64, RegName.SI), dest),
      // scanf is variadic: al must hold the number of vector registers used (none).
      Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
      assemblyIR.Call(CLibFunc.Scanf)
    )
  }

  /** Returns the rsp-relative address of a variable's stack slot.
    *
    * Variable k (1-based declaration order) lives at rsp + (n - k) * 16, where n is the number
    * of declared variables. This holds only while rsp has been lowered by exactly 16 per
    * declared variable since the stack was aligned.
    */
  def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
    IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16)

  /** Rounds rsp down to a multiple of 16 (`and rsp, -16`). */
  def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
    List(
      And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16))
    )
  }

  /** Standard frame setup: save rbp and point it at the current rsp.
    *
    * No `sub rsp, N` is emitted here: local slots are reserved one at a time when a variable
    * is first assigned.
    */
  def funcPrologue(): List[AsmLine] = {
    List(
      Push(Register(RegSize.R64, RegName.BP)),
      Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP))
    )
  }

  /** Sets rax to 0 (the return value), discards all local slots by restoring rsp from rbp,
    * and restores the caller's rbp.
    */
  def funcEpilogue(): List[AsmLine] = {
    List(
      Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
      Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)),
      Pop(Register(RegSize.R64, RegName.BP))
    )
  }

  // def saveRegs(regList: List[Register]): List[AsmLine] = regList.map(Push(_))
  // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_))

  /** Emits a `printf` of `expr` (without a trailing newline) followed by `fflush(NULL)`.
    *
    * The format string is chosen from the expression's type. Booleans are printed as the text
    * "true" / "false", selected at runtime from the evaluated value.
    */
  def printF(expr: Expr)(using
      stack: LinkedHashMap[Ident, Int],
      strings: ListBuffer[String]
  ): List[AsmLine] = {
    val format = expr.ty match {
      case KnownType.Char => PrintFormat.Char.toString
      case KnownType.Int  => PrintFormat.Int.toString
      // Strings, booleans (printed as text) and anything else.
      case _ => PrintFormat.String.toString
    }
    val loadFormat = loadStrAddr(format, Register(RegSize.R64, RegName.DI))
    // The value to print goes in rsi, the second argument register; rdi holds the format.
    val loadValue =
      if (expr.ty == KnownType.Bool) boolTextIntoReg(expr, Register(RegSize.R64, RegName.SI))
      else evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI))
    loadFormat :: loadValue ++ List(
      // printf is variadic: al must hold the number of vector registers used (none).
      Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
      assemblyIR.Call(CLibFunc.PrintF),
      Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
      assemblyIR.Call(CLibFunc.Fflush)
    )
  }

  /** Prints a newline via `puts("")`, then calls `fflush(NULL)`. */
  def printLn()(using
      stack: LinkedHashMap[Ident, Int],
      strings: ListBuffer[String]
  ): List[AsmLine] =
    List(
      loadStrAddr("", Register(RegSize.R64, RegName.DI)),
      assemblyIR.Call(CLibFunc.Puts),
      Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
      assemblyIR.Call(CLibFunc.Fflush)
    )

  /** Appends `str` to the string table and returns an instruction that loads its address
    * (the rip-relative label `.L.str<index>`) into `dest`.
    */
  private def loadStrAddr(str: String, dest: Register)(using
      strings: ListBuffer[String]
  ): AsmLine = {
    strings += str
    Load(
      dest,
      IndexAddress(Register(RegSize.R64, RegName.IP), LabelArg(s".L.str${strings.size - 1}"))
    )
  }

  /** Loads into `dest` the address of the text "false" when the boolean `expr` evaluates to
    * 0, or "true" otherwise. Clobbers rax and the flags.
    */
  private def boolTextIntoReg(expr: Expr, dest: Register)(using
      stack: LinkedHashMap[Ident, Int],
      strings: ListBuffer[String]
  ): List[AsmLine] = {
    val doneLabel = labelGenerator.getLabel()
    evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX)) ++
      List(
        Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
        // lea leaves the flags untouched, so the jump still tests the comparison above.
        loadStrAddr("false", dest),
        Jump(LabelArg(doneLabel), Cond.Equal),
        loadStrAddr("true", dest),
        LabelDef(doneLabel)
      )
  }
}