Merge branch 'master' into comments-and-refactors
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import cats.data.Chain
|
||||
import cats.syntax.foldable._
|
||||
import wacc.RuntimeError._
|
||||
@@ -33,21 +32,8 @@ object asmGenerator {
|
||||
def concatAll(chains: Chain[T]*): Chain[T] =
|
||||
chains.foldLeft(chain)(_ ++ _)
|
||||
|
||||
class LabelGenerator {
|
||||
var labelVal = -1
|
||||
def getLabel(): String = {
|
||||
labelVal += 1
|
||||
s".L$labelVal"
|
||||
}
|
||||
def getLabel(target: CallTarget): String = target match {
|
||||
case Ident(v, _) => s"wacc_$v"
|
||||
case Builtin(name) => s"_$name"
|
||||
}
|
||||
}
|
||||
|
||||
def generateAsm(microProg: Program): Chain[AsmLine] = {
|
||||
given stack: Stack = Stack()
|
||||
given strings: ListBuffer[String] = ListBuffer[String]()
|
||||
given labelGenerator: LabelGenerator = LabelGenerator()
|
||||
val Program(funcs, main) = microProg
|
||||
|
||||
@@ -57,32 +43,26 @@ object asmGenerator {
|
||||
Chain.one(Xor(RAX, RAX)),
|
||||
funcEpilogue(),
|
||||
generateBuiltInFuncs(),
|
||||
RuntimeError.all.foldMap(_.generate),
|
||||
funcs.foldMap(generateUserFunc(_))
|
||||
)
|
||||
|
||||
val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
|
||||
Chain(
|
||||
Directive.Int(str.size),
|
||||
LabelDef(s".L.str$i"),
|
||||
Directive.Asciz(str.escaped)
|
||||
)
|
||||
} ++ RuntimeError.all.foldMap(_.stringDef)
|
||||
|
||||
Chain(
|
||||
Directive.IntelSyntax,
|
||||
Directive.Global("main"),
|
||||
Directive.RoData
|
||||
).concatAll(
|
||||
strDirs,
|
||||
labelGenerator.generateConstants,
|
||||
Chain.one(Directive.Text),
|
||||
progAsm
|
||||
)
|
||||
}
|
||||
|
||||
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
|
||||
stack: Stack
|
||||
private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
|
||||
stack: Stack,
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.one[AsmLine](LabelDef(labelName))
|
||||
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
|
||||
asm ++= funcPrologue()
|
||||
asm ++= funcBody
|
||||
asm ++= funcEpilogue()
|
||||
@@ -90,14 +70,13 @@ object asmGenerator {
|
||||
}
|
||||
|
||||
private def generateUserFunc(func: FuncDecl)(using
|
||||
strings: ListBuffer[String],
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
given stack: Stack = Stack()
|
||||
// Setup the stack with param 7 and up
|
||||
func.params.drop(argRegs.size).foreach(stack.reserve(_))
|
||||
stack.reserve(Q64) // Reserve return pointer slot
|
||||
var asm = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
|
||||
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
|
||||
asm ++= funcPrologue()
|
||||
// Push the rest of params onto the stack for simplicity
|
||||
argRegs.zip(func.params).foreach { (reg, param) =>
|
||||
@@ -115,12 +94,12 @@ object asmGenerator {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Exit),
|
||||
Builtin.Exit,
|
||||
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Printf),
|
||||
Builtin.Printf,
|
||||
Chain(
|
||||
stackAlign,
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
@@ -130,7 +109,7 @@ object asmGenerator {
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.PrintCharArray),
|
||||
Builtin.PrintCharArray,
|
||||
Chain(
|
||||
stackAlign,
|
||||
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
|
||||
@@ -142,29 +121,29 @@ object asmGenerator {
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Malloc),
|
||||
Builtin.Malloc,
|
||||
Chain(
|
||||
stackAlign,
|
||||
assemblyIR.Call(CLibFunc.Malloc),
|
||||
// Out of memory check
|
||||
Compare(RAX, ImmediateVal(0)),
|
||||
Jump(LabelArg(OutOfMemoryError.errLabel), Cond.Equal)
|
||||
Jump(labelGenerator.getLabelArg(OutOfMemoryError), Cond.Equal)
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Free),
|
||||
Builtin.Free,
|
||||
Chain(
|
||||
stackAlign,
|
||||
Move(RDI, RAX),
|
||||
Compare(RDI, ImmediateVal(0)),
|
||||
Jump(LabelArg(NullPtrError.errLabel), Cond.Equal),
|
||||
Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
|
||||
assemblyIR.Call(CLibFunc.Free)
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Read),
|
||||
Builtin.Read,
|
||||
Chain(
|
||||
stackAlign,
|
||||
Subtract(Register(Q64, SP), ImmediateVal(8)),
|
||||
@@ -175,18 +154,14 @@ object asmGenerator {
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= RuntimeError.all.foldMap(_.generateHandler)
|
||||
|
||||
asm
|
||||
}
|
||||
|
||||
private def generateStmt(stmt: Stmt)(using
|
||||
stack: Stack,
|
||||
strings: ListBuffer[String],
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
asm += Comment(stmt.toString)
|
||||
stmt match {
|
||||
case Assign(lhs, rhs) =>
|
||||
lhs match {
|
||||
@@ -200,15 +175,15 @@ object asmGenerator {
|
||||
asm ++= evalExprOntoStack(i)
|
||||
asm += stack.pop(RCX)
|
||||
asm += Compare(ECX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||
asm += stack.push(Q64, RCX)
|
||||
asm ++= evalExprOntoStack(x)
|
||||
asm += stack.pop(RAX)
|
||||
asm += stack.pop(RCX)
|
||||
asm += Compare(EAX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
|
||||
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||
asm += Compare(MemLocation(RAX, D32), ECX)
|
||||
asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||
asm += stack.pop(RDX)
|
||||
|
||||
asm += Move(
|
||||
@@ -266,7 +241,6 @@ object asmGenerator {
|
||||
|
||||
private def evalExprOntoStack(expr: Expr)(using
|
||||
stack: Stack,
|
||||
strings: ListBuffer[String],
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
@@ -279,8 +253,8 @@ object asmGenerator {
|
||||
case array @ ArrayLiter(elems) =>
|
||||
expr.ty match {
|
||||
case KnownType.String =>
|
||||
strings += elems.collect { case CharLiter(v) => v }.mkString
|
||||
asm += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
|
||||
val str = elems.collect { case CharLiter(v) => v }.mkString
|
||||
asm += Load(RAX, IndexAddress(RIP, labelGenerator.getLabelArg(str)))
|
||||
asm += stack.push(Q64, RAX)
|
||||
case ty =>
|
||||
asm ++= generateCall(
|
||||
@@ -311,12 +285,12 @@ object asmGenerator {
|
||||
asm ++= evalExprOntoStack(i)
|
||||
asm += stack.pop(RCX)
|
||||
asm += Compare(RCX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||
asm += stack.pop(RAX)
|
||||
asm += Compare(EAX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
|
||||
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||
asm += Compare(MemLocation(RAX, D32), ECX)
|
||||
asm += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||
// + Int because we store the length of the array at the start
|
||||
asm += Move(
|
||||
Register(x.ty.elemSize, AX),
|
||||
@@ -330,7 +304,7 @@ object asmGenerator {
|
||||
asm += Move(EAX, stack.head)
|
||||
asm += And(EAX, ImmediateVal(~_7_BIT_MASK))
|
||||
asm += Compare(EAX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual)
|
||||
asm += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
|
||||
case UnaryOperator.Ord => // No op needed
|
||||
case UnaryOperator.Len =>
|
||||
asm += stack.pop(RAX)
|
||||
@@ -339,7 +313,7 @@ object asmGenerator {
|
||||
case UnaryOperator.Negate =>
|
||||
asm += Xor(EAX, EAX)
|
||||
asm += Subtract(EAX, stack.head)
|
||||
asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(Q64, RAX)
|
||||
case UnaryOperator.Not =>
|
||||
@@ -355,21 +329,21 @@ object asmGenerator {
|
||||
op match {
|
||||
case BinaryOperator.Add =>
|
||||
asm += Add(stack.head, destX)
|
||||
asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
case BinaryOperator.Sub =>
|
||||
asm += Subtract(destX, stack.head)
|
||||
asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(destX.size, RAX)
|
||||
case BinaryOperator.Mul =>
|
||||
asm += Multiply(destX, stack.head)
|
||||
asm += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(destX.size, RAX)
|
||||
|
||||
case BinaryOperator.Div =>
|
||||
asm += Compare(stack.head, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
|
||||
asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||
asm += CDQ()
|
||||
asm += Divide(stack.head)
|
||||
asm += stack.drop()
|
||||
@@ -377,7 +351,7 @@ object asmGenerator {
|
||||
|
||||
case BinaryOperator.Mod =>
|
||||
asm += Compare(stack.head, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
|
||||
asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||
asm += CDQ()
|
||||
asm += Divide(stack.head)
|
||||
asm += stack.drop()
|
||||
@@ -405,7 +379,6 @@ object asmGenerator {
|
||||
|
||||
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
|
||||
stack: Stack,
|
||||
strings: ListBuffer[String],
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
@@ -428,9 +401,9 @@ object asmGenerator {
|
||||
|
||||
// Tail Call Optimisation (TCO)
|
||||
if (isTail) {
|
||||
asm += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
|
||||
asm += Jump(labelGenerator.getLabelArg(target)) // tail call
|
||||
} else {
|
||||
asm += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
|
||||
asm += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
|
||||
}
|
||||
|
||||
if (args.size > argRegs.size) {
|
||||
@@ -477,7 +450,7 @@ object asmGenerator {
|
||||
|
||||
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
|
||||
extension (s: String) {
|
||||
private def escaped: String =
|
||||
def escaped: String =
|
||||
s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user