488 lines
16 KiB
Scala
488 lines
16 KiB
Scala
package wacc
|
|
|
|
import cats.data.Chain
|
|
import cats.syntax.foldable._
|
|
import wacc.RuntimeError._
|
|
|
|
object asmGenerator {

  import microWacc._
  import assemblyIR._
  import assemblyIR.commonRegisters._
  import assemblyIR.RegName._
  import types._
  import sizeExtensions._
  import lexer.escapedChars

  /** Registers holding the first six integer/pointer arguments, in System V AMD64 order.
    * Any further arguments are passed on the stack.
    */
  private val argRegs = List(DI, SI, DX, CX, R8, R9)

  /** Bits a valid `chr` argument may have set; any bit outside this mask is a bad character. */
  private val _7_BIT_MASK = 0x7f

  extension [T](chain: Chain[T]) {

    /** Appends a single item to the end of the `Chain`. */
    def +(item: T): Chain[T] = chain.append(item)

    /** Concatenates multiple `Chain[T]` instances into a single `Chain[T]`, appending them to the
      * current `Chain`.
      *
      * @param chains
      *   A variable number of `Chain[T]` instances to concatenate.
      * @return
      *   A new `Chain[T]` containing all elements from `chain` concatenated with `chains`.
      */
    def concatAll(chains: Chain[T]*): Chain[T] =
      chains.foldLeft(chain)(_ ++ _)
  }

  /** Generates the complete assembly listing for a microWacc program.
    *
    * Output order: header directives, the read-only section (debug info and constants), then the
    * text section holding `main`, the built-in wrappers, the runtime-error handlers and finally
    * the user-defined functions.
    *
    * @param microProg
    *   The lowered program to compile.
    * @return
    *   Every assembly line of the output file, in order.
    */
  def generateAsm(microProg: Program): Chain[AsmLine] = {
    given stack: Stack = Stack()
    given labelGenerator: LabelGenerator = LabelGenerator()
    val Program(funcs, main) = microProg

    val mainLabel = LabelDef("main")
    // Debug info for main needs a source position, taken from its first statement (if any)
    val mainDebug: Chain[AsmLine] = main.headOption
      .map(stmt => labelGenerator.getDebugFunc(stmt.pos, "$main", mainLabel))
      .getOrElse(Chain.empty[AsmLine])
    // Close the function's debug info only if it was opened (same rule as generateUserFunc)
    val mainFooter: Chain[AsmLine] =
      if (mainDebug.nonEmpty)
        Chain[AsmLine](Directive.Size(mainLabel, SizeExpr.Relative(mainLabel)), Directive.EndFunc)
      else Chain.empty[AsmLine]

    val progAsm = (mainDebug + mainLabel).concatAll(
      funcPrologue(),
      main.foldMap(generateStmt(_)),
      Chain.one(Xor(RAX, RAX)), // main returns 0
      funcEpilogue(),
      mainFooter,
      generateBuiltInFuncs(),
      RuntimeError.all.foldMap(_.generate),
      funcs.foldMap(generateUserFunc(_))
    )

    Chain(
      Directive.IntelSyntax,
      Directive.Global("main"),
      Directive.RoData
    ).concatAll(
      labelGenerator.generateDebug,
      labelGenerator.generateConstants,
      Chain.one(Directive.Text),
      progAsm
    )
  }

  /** Wraps a built-in's body as a function: its label, the standard prologue, the body, and
    * the standard epilogue.
    */
  private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] =
    Chain
      .one[AsmLine](labelGenerator.getLabelDef(builtin))
      .concatAll(funcPrologue(), funcBody, funcEpilogue())

  /** Generates a user-defined function with its own fresh `Stack` model.
    *
    * Register parameters are spilled to the stack right after the prologue. No epilogue is
    * appended, because every user function body ends in an explicit return.
    */
  private def generateUserFunc(func: FuncDecl)(using
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    given stack: Stack = Stack()
    // Parameters 7 and up are already on the stack (pushed by the caller), so their slots are
    // only recorded here; whatever `reserve` returns is deliberately discarded.
    func.params.drop(argRegs.size).foreach(stack.reserve(_))
    stack.reserve(Size.Q64) // Reserve return pointer slot

    val funcLabel = labelGenerator.getLabelDef(func.name)
    var asm = labelGenerator.getDebugFunc(func.pos, func.name.name, funcLabel)
    val debugFunc = asm.size > 0
    asm += funcLabel
    asm ++= funcPrologue()
    // Spill the register-passed parameters (at most the first six) onto the stack for simplicity
    argRegs.zip(func.params).foreach { (reg, param) =>
      asm += stack.push(param, Register(Size.Q64, reg))
    }
    asm ++= func.body.foldMap(generateStmt(_))
    // No need for epilogue here since all user functions must return explicitly
    if (debugFunc) {
      asm += Directive.Size(funcLabel, SizeExpr.Relative(funcLabel))
      asm += Directive.EndFunc
    }
    asm
  }

  /** Generates the wrappers around the C library functions used by compiled programs. Every
    * wrapper realigns the stack to 16 bytes before calling into libc.
    */
  private def generateBuiltInFuncs()(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]

    asm ++= wrapBuiltinFunc(
      Builtin.Exit,
      Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
    )

    asm ++= wrapBuiltinFunc(
      Builtin.Printf,
      Chain(
        stackAlign,
        assemblyIR.Call(CLibFunc.PrintF),
        Xor(RDI, RDI), // fflush(NULL)
        assemblyIR.Call(CLibFunc.Fflush)
      )
    )

    asm ++= wrapBuiltinFunc(
      Builtin.PrintCharArray,
      Chain(
        stackAlign,
        // RSI points at a length-prefixed char array: RDX <- address of the chars,
        // ESI <- the length. Assumes the format in RDI takes (length, pointer), e.g. "%.*s" - TODO confirm
        Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)),
        Move(Register(KnownType.Int.size, SI), MemLocation(RSI, opSize = Some(KnownType.Int.size))),
        assemblyIR.Call(CLibFunc.PrintF),
        Xor(RDI, RDI), // fflush(NULL)
        assemblyIR.Call(CLibFunc.Fflush)
      )
    )

    asm ++= wrapBuiltinFunc(
      Builtin.Malloc,
      Chain(
        stackAlign,
        assemblyIR.Call(CLibFunc.Malloc),
        // Out of memory check
        Compare(RAX, ImmediateVal(0)),
        Jump(labelGenerator.getLabelArg(OutOfMemoryError), Cond.Equal)
      )
    )

    asm ++= wrapBuiltinFunc(
      Builtin.Free,
      Chain(
        stackAlign,
        Compare(RDI, ImmediateVal(0)),
        Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
        assemblyIR.Call(CLibFunc.Free)
      )
    )

    asm ++= wrapBuiltinFunc(
      Builtin.Read,
      Chain(
        stackAlign,
        // 8 bytes of padding plus the 8-byte push keep RSP 16-byte aligned at the call
        Subtract(Register(Size.Q64, SP), ImmediateVal(8)),
        Push(RSI),
        // scanf writes into the slot just pushed; its (possibly unchanged) value is returned in RAX
        Load(RSI, MemLocation(Register(Size.Q64, SP), opSize = Some(Size.Q64))),
        assemblyIR.Call(CLibFunc.Scanf),
        Pop(RAX)
      )
    )

    asm
  }

  /** Generates code for a single statement, preceded by a source-location directive for its
    * line.
    */
  private def generateStmt(stmt: Stmt)(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    val fileNo = labelGenerator.getDebugFile(stmt.pos.file)
    var asm = Chain.one[AsmLine](Directive.Location(fileNo, stmt.pos.line, None))

    stmt match {
      case Assign(lhs, rhs) =>
        lhs match {
          case ident: Ident =>
            if (!stack.contains(ident)) asm += stack.reserve(ident)
            asm ++= evalExprOntoStack(rhs)
            asm += stack.pop(RAX)
            asm += Move(stack.accessVar(ident).copy(opSize = Some(Size.Q64)), RAX)

          case ArrayElem(x, i) =>
            asm ++= evalExprOntoStack(rhs)
            asm ++= evalExprOntoStack(i)
            asm += stack.pop(RCX)
            asm += Compare(ECX, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
            // Keep the index safe on the stack while the array expression is evaluated
            asm += stack.push(KnownType.Int.size, RCX)
            asm ++= evalExprOntoStack(x)
            asm += stack.pop(RAX)
            asm += stack.pop(RCX)
            asm += Compare(RAX, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
            // The array length is stored at the start: out of bounds when length <= index
            asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
            asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
            asm += stack.pop(RDX)

            asm += Move(
              MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt)),
              Register(x.ty.elemSize, DX)
            )
        }

      case If(cond, thenBranch, elseBranch) =>
        val elseLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()

        asm ++= evalExprOntoStack(cond)
        asm += stack.pop(RAX)
        asm += Compare(RAX, ImmediateVal(0))
        asm += Jump(LabelArg(elseLabel), Cond.Equal)

        asm ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
        asm += Jump(LabelArg(endLabel))
        asm += LabelDef(elseLabel)

        asm ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
        asm += LabelDef(endLabel)

      case While(cond, body) =>
        val startLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()

        asm += LabelDef(startLabel)
        asm ++= evalExprOntoStack(cond)
        asm += stack.pop(RAX)
        asm += Compare(RAX, ImmediateVal(0))
        asm += Jump(LabelArg(endLabel), Cond.Equal)

        asm ++= stack.withScope(() => body.foldMap(generateStmt))
        asm += Jump(LabelArg(startLabel))
        asm += LabelDef(endLabel)

      case call: microWacc.Call =>
        asm ++= generateCall(call, isTail = false)

      case microWacc.Return(expr) =>
        expr match {
          case call: microWacc.Call =>
            asm ++= generateCall(call, isTail = true) // tco
          case _ =>
            asm ++= evalExprOntoStack(expr)
            asm += stack.pop(RAX)
            asm ++= funcEpilogue()
        }
    }

    asm
  }

  /** Generates code that evaluates `expr` and leaves its value in exactly one new stack slot.
    * This is enforced by an assertion.
    *
    * For values narrower than 32 bits, the slot's remaining upper bits are masked to zero
    * afterwards (see `zeroRest`).
    */
  private def evalExprOntoStack(expr: Expr)(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]
    val stackSizeStart = stack.size
    expr match {
      case IntLiter(v)  => asm += stack.push(KnownType.Int.size, ImmediateVal(v))
      case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
      case ident: Ident =>
        val location = stack.accessVar(ident)
        // items in stack are guaranteed to be in Q64 slots,
        // so we are safe to wipe the opSize from the memory location
        asm += stack.push(ident.ty.size, location.copy(opSize = None))

      case array @ ArrayLiter(elems) =>
        expr.ty match {
          case KnownType.String =>
            // String literals live in the constant section; push their address
            val str = elems.collect { case CharLiter(v) => v }.mkString
            asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str)))
            asm += stack.push(KnownType.String.size, RAX)
          case ty =>
            asm ++= generateCall(
              microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize)))(array.pos),
              isTail = false
            )
            asm += stack.push(KnownType.Array(?).size, RAX)
            // Store the length of the array at the start
            asm += Move(
              MemLocation(RAX, opSize = Some(KnownType.Int.size)),
              ImmediateVal(elems.size)
            )
            // The array pointer is kept on the stack because evaluating an element may clobber RAX
            elems.zipWithIndex.foreach { (elem, i) =>
              asm ++= evalExprOntoStack(elem)
              asm += stack.pop(RCX)
              asm += stack.pop(RAX)
              asm += Move(
                MemLocation(RAX, KnownType.Int.size.toInt + i * ty.elemSize.toInt),
                Register(ty.elemSize, CX)
              )
              asm += stack.push(KnownType.Array(?).size, RAX)
            }
        }

      case BoolLiter(true) =>
        asm += stack.push(KnownType.Bool.size, ImmediateVal(1))
      case BoolLiter(false) =>
        asm += Xor(RAX, RAX)
        asm += stack.push(KnownType.Bool.size, RAX)
      case NullLiter() =>
        asm += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
      case ArrayElem(x, i) =>
        asm ++= evalExprOntoStack(x)
        asm ++= evalExprOntoStack(i)
        asm += stack.pop(RCX)
        // Compare the 32-bit index: 32-bit arithmetic zeroes RCX's upper half, so a negative
        // int would look positive as a 64-bit value (matches the check in generateStmt)
        asm += Compare(ECX, ImmediateVal(0))
        asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
        asm += stack.pop(RAX)
        asm += Compare(RAX, ImmediateVal(0))
        asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
        asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
        asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
        // + Int because we store the length of the array at the start
        asm += Move(
          Register(x.ty.elemSize, AX),
          MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt))
        )
        asm += stack.push(x.ty.elemSize, RAX)
      case UnaryOp(x, op) =>
        asm ++= evalExprOntoStack(x)
        op match {
          case UnaryOperator.Chr =>
            // Any bit above the low 7 set means the value is not a valid character
            asm += Move(EAX, stack.head)
            asm += And(EAX, ImmediateVal(~_7_BIT_MASK))
            asm += Compare(EAX, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
          case UnaryOperator.Ord => () // No op needed
          case UnaryOperator.Len =>
            asm += stack.pop(RAX)
            asm += Move(EAX, MemLocation(RAX, opSize = Some(KnownType.Int.size)))
            asm += stack.push(KnownType.Int.size, RAX)
          case UnaryOperator.Negate =>
            asm += Xor(EAX, EAX)
            asm += Subtract(EAX, stack.head)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
            asm += stack.drop()
            asm += stack.push(KnownType.Int.size, RAX)
          case UnaryOperator.Not =>
            asm += Xor(stack.head, ImmediateVal(1))
        }

      case BinaryOp(x, y, op) =>
        val destX = Register(x.ty.size, AX)
        // y is evaluated first so that, after popping x into RAX, y is the stack head
        asm ++= evalExprOntoStack(y)
        asm ++= evalExprOntoStack(x)
        asm += stack.pop(RAX)

        op match {
          case BinaryOperator.Add =>
            // Commutative: accumulate straight into y's slot
            asm += Add(stack.head, destX)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
          case BinaryOperator.Sub =>
            asm += Subtract(destX, stack.head)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
            asm += stack.drop()
            asm += stack.push(destX.size, RAX)
          case BinaryOperator.Mul =>
            asm += Multiply(destX, stack.head)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
            asm += stack.drop()
            asm += stack.push(destX.size, RAX)

          case BinaryOperator.Div =>
            asm += Compare(stack.head, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
            asm += CDQ()
            asm += Divide(stack.head)
            asm += stack.drop()
            asm += stack.push(destX.size, RAX) // quotient

          case BinaryOperator.Mod =>
            asm += Compare(stack.head, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
            asm += CDQ()
            asm += Divide(stack.head)
            asm += stack.drop()
            asm += stack.push(destX.size, RDX) // remainder

          case BinaryOperator.Eq        => asm ++= generateComparison(destX, Cond.Equal)
          case BinaryOperator.Neq       => asm ++= generateComparison(destX, Cond.NotEqual)
          case BinaryOperator.Greater   => asm ++= generateComparison(destX, Cond.Greater)
          case BinaryOperator.GreaterEq => asm ++= generateComparison(destX, Cond.GreaterEqual)
          case BinaryOperator.Less      => asm ++= generateComparison(destX, Cond.Less)
          case BinaryOperator.LessEq    => asm ++= generateComparison(destX, Cond.LessEqual)
          case BinaryOperator.And       => asm += And(stack.head, destX)
          case BinaryOperator.Or        => asm += Or(stack.head, destX)
        }

      case call: microWacc.Call =>
        asm ++= generateCall(call, isTail = false)
        asm += stack.push(call.ty.size, RAX)
    }

    assert(
      stack.size == stackSizeStart + 1,
      "Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
    )
    asm ++= zeroRest(stack.head.copy(opSize = Some(Size.Q64)), expr.ty.size)
    asm
  }

  /** Generates a call to `call.target`.
    *
    * The first `argRegs.size` arguments are passed in registers. The rest are pushed in order,
    * so the last sits just below the return address. Callers read the result from RAX.
    *
    * Stack arguments are evaluated before the register arguments. Evaluating an expression
    * may use RCX/RDX or make nested calls, which would clobber argument registers that had
    * already been loaded.
    *
    * When `isTail` is set and there are no stack arguments, the current frame is torn down and
    * control jumps to the target, which then returns directly to our caller. With stack
    * arguments, tearing the frame down would discard them. In that case a regular call
    * followed by this function's epilogue is emitted instead.
    */
  private def generateCall(call: microWacc.Call, isTail: Boolean)(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]
    val microWacc.Call(target, args) = call
    val (regArgs, stackArgs) = args.splitAt(argRegs.size)

    // Arguments 7 and up stay on the stack for the callee
    stackArgs.foreach { arg => asm ++= evalExprOntoStack(arg) }

    // Evaluate the register arguments, then pop them into their registers
    regArgs.foreach { arg => asm ++= evalExprOntoStack(arg) }
    argRegs.take(regArgs.size).reverse.foreach { reg =>
      asm += stack.pop(Register(Size.Q64, reg))
    }

    if (isTail && stackArgs.isEmpty) {
      // Tail Call Optimisation (TCO): reuse our caller's return address
      asm ++= frameTeardown()
      asm += Jump(labelGenerator.getLabelArg(target))
    } else {
      asm += assemblyIR.Call(labelGenerator.getLabelArg(target))
      // Remove arguments 7 and up from the stack
      if (stackArgs.nonEmpty) asm += stack.drop(stackArgs.size)
      // A tail call that could not be optimised still has to return its result
      if (isTail) asm ++= funcEpilogue()
    }

    asm
  }

  /** Compares RAX-based `destX` with the stack head. The stack head is then replaced with the
    * 0/1 result of `cond`, stored as a single byte.
    */
  private def generateComparison(destX: Register, cond: Cond)(using
      stack: Stack
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]

    asm += Compare(destX, stack.head)
    asm += Set(Register(Size.B8, AX), cond)
    asm ++= zeroRest(RAX, Size.B8)
    asm += stack.drop()
    asm += stack.push(Size.B8, RAX)

    asm
  }

  /** Standard function prologue: save RBP (tracked on `stack`) and set up a new frame. */
  private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]
    asm += stack.push(Size.Q64, RBP)
    asm += Move(RBP, Register(Size.Q64, SP))
    asm
  }

  /** Undoes `funcPrologue`: restores the caller's RSP and RBP, without returning. */
  private def frameTeardown(): Chain[AsmLine] =
    Chain(Move(Register(Size.Q64, SP), RBP), Pop(RBP))

  /** Standard function epilogue: tear down the frame and return. */
  private def funcEpilogue(): Chain[AsmLine] =
    frameTeardown() + assemblyIR.Return()

  /** Rounds RSP down to a multiple of 16, as the C library calls require. */
  def stackAlign: AsmLine = And(Register(Size.Q64, SP), ImmediateVal(-16))

  /** Masks `dest` down to its low `size` bytes. For 32- and 64-bit sizes nothing is emitted. */
  private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
    case Size.Q64 | Size.D32 => Chain.empty
    case _                   => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
  }

  /** Inverse of the lexer's escape table: raw character -> its backslash escape sequence. */
  private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

  extension (s: String) {

    /** Re-escapes every character that has an escape sequence, leaving all others unchanged. */
    def escaped: String =
      s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
  }
}
|