Compare commits

...

10 Commits

10 changed files with 442 additions and 262 deletions

View File

@ -1,6 +1,7 @@
package wacc package wacc
import scala.collection.mutable import scala.collection.mutable
import cats.data.Chain
import parsley.{Failure, Success} import parsley.{Failure, Success}
import scopt.OParser import scopt.OParser
import java.io.File import java.io.File
@ -63,7 +64,7 @@ def frontend(
} }
val s = "enter an integer to echo" val s = "enter an integer to echo"
def backend(typedProg: microWacc.Program): List[asm.AsmLine] = def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
asmGenerator.generateAsm(typedProg) asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using def compile(filename: String, outFile: Option[File] = None)(using

View File

@ -0,0 +1,90 @@
package wacc
import scala.collection.mutable.LinkedHashMap
import cats.data.Chain
/** Compile-time model of the runtime stack used while generating assembly.
  *
  * Every tracked entry occupies one fixed-width slot of `elemBytes` bytes. Each mutating method
  * updates the model and returns the single [[assemblyIR.AsmLine]] that performs the matching
  * adjustment at runtime, so the model stays in step with RSP only if the caller emits every
  * returned line.
  */
class Stack {
  import assemblyIR._
  import assemblyIR.Size._
  import sizeExtensions.size
  import microWacc as mw

  private val RSP = Register(Q64, RegName.SP)

  /** A tracked slot: the operand size of the value stored in it and the slot's byte offset,
    * measured from the first tracked slot.
    */
  private class StackValue(val size: Size, val offset: Int) {
    /** Byte offset just past this slot (`offset` plus one slot width). */
    def bottom: Int = offset + elemBytes
  }

  // Insertion-ordered, so `stack.last` is always the most recently pushed/reserved slot.
  // Named values are keyed by their expression; anonymous ones by their positional index.
  private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()

  /** Width in bytes of every slot, regardless of the size of the value stored in it. */
  private val elemBytes: Int = Q64.toInt

  /** Total bytes currently covered by tracked slots. */
  private def sizeBytes: Int = stack.size * elemBytes

  /** The number of slots currently tracked (a count of entries, not a byte size). */
  def size: Int = stack.size

  /** Push an expression onto the stack, tracking the slot under `expr`. */
  def push(expr: mw.Expr, src: Register): AsmLine = {
    stack += expr -> StackValue(src.size, sizeBytes)
    Push(src)
  }

  /** Push an anonymous value onto the stack.
    *
    * @param itemSize
    *   The operand size recorded for the value (used later by [[head]]).
    * @param addr
    *   The source operand to push.
    */
  def push(itemSize: Size, addr: Src): AsmLine = {
    stack += stack.size -> StackValue(itemSize, sizeBytes)
    Push(addr)
  }

  /** Reserve one slot for a variable on the stack, tracked under `ident`. */
  def reserve(ident: mw.Ident): AsmLine = {
    stack += ident -> StackValue(ident.ty.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve one anonymous slot sized for a register on the stack. */
  def reserve(src: Register): AsmLine = {
    stack += stack.size -> StackValue(src.size, sizeBytes)
    // Move RSP by a whole slot, not by the register's own width: offsets, `drop` and the other
    // `reserve` overloads all assume one `elemBytes`-wide slot per tracked entry.
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve space for values on the stack.
    *
    * @param sizes
    *   The sizes of the values to reserve space for; one slot is reserved per entry.
    */
  def reserve(sizes: List[Size]): AsmLine = {
    sizes.foreach { itemSize =>
      stack += stack.size -> StackValue(itemSize, sizeBytes)
    }
    Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
  }

  /** Pop the top value from the stack into a register. Sizes MUST match.
    *
    * Throws `NoSuchElementException` if no slot is tracked.
    */
  def pop(dest: Register): AsmLine = {
    stack.remove(stack.last._1)
    Pop(dest)
  }

  /** Drop the top n values from the stack without reading them.
    *
    * Throws `NoSuchElementException` if fewer than `n` slots are tracked.
    */
  def drop(n: Int = 1): AsmLine = {
    (1 to n).foreach { _ =>
      stack.remove(stack.last._1)
    }
    Add(RSP, ImmediateVal(n * elemBytes))
  }

  /** Generate AsmLines within a scope, which is reset after the block.
    *
    * The slot count is captured before `block` runs; afterwards every slot the block left behind
    * is dropped and the corresponding RSP adjustment is appended to the block's output.
    */
  def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
    val resetToSize = stack.size
    val lines = block()
    lines :+ drop(stack.size - resetToSize)
  }

  /** Get an IndexAddress for a variable in the stack, expressed as an offset from RSP.
    *
    * Throws `NoSuchElementException` if `ident` is not tracked; check with [[contains]] first.
    */
  def accessVar(ident: mw.Ident): IndexAddress =
    IndexAddress(RSP, sizeBytes - stack(ident).bottom)

  /** Whether a slot is currently tracked for `ident`. */
  def contains(ident: mw.Ident): Boolean = stack.contains(ident)

  /** The memory operand at RSP, using the operand size recorded for the top slot. */
  def head: MemLocation = MemLocation(RSP, stack.last._2.size)

  override def toString(): String = stack.toString
}

View File

@ -1,15 +1,16 @@
package wacc package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
// import parsley.token.errors.Label
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
import assemblyIR._ import assemblyIR._
import wacc.types._ import assemblyIR.Size._
import assemblyIR.RegName._
import types._
import sizeExtensions._
import lexer.escapedChars import lexer.escapedChars
abstract case class Error() { abstract case class Error() {
@ -30,26 +31,20 @@ object asmGenerator {
def errLabel = ".L._errDivZero" def errLabel = ".L._errDivZero"
} }
val RAX = Register(RegSize.R64, RegName.AX) private val RAX = Register(Q64, AX)
val EAX = Register(RegSize.E32, RegName.AX) private val EAX = Register(D32, AX)
val ESP = Register(RegSize.E32, RegName.SP) private val RDI = Register(Q64, DI)
val EDX = Register(RegSize.E32, RegName.DX) private val RIP = Register(Q64, IP)
val RDI = Register(RegSize.R64, RegName.DI) private val RBP = Register(Q64, BP)
val RIP = Register(RegSize.R64, RegName.IP) private val RSI = Register(Q64, SI)
val RBP = Register(RegSize.R64, RegName.BP) private val RDX = Register(Q64, DX)
val RSI = Register(RegSize.R64, RegName.SI) private val RCX = Register(Q64, CX)
val RDX = Register(RegSize.R64, RegName.DX) private val argRegs = List(DI, SI, DX, CX, R8, R9)
val RCX = Register(RegSize.R64, RegName.CX)
val R8 = Register(RegSize.R64, RegName.Reg8)
val R9 = Register(RegSize.R64, RegName.Reg9)
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val _8_BIT_MASK = 0xff extension [T](chain: Chain[T])
def +(item: T): Chain[T] = chain.append(item)
extension (chain: Chain[AsmLine]) def concatAll(chains: Chain[T]*): Chain[T] =
def +(line: AsmLine): Chain[AsmLine] = chain.append(line)
def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
chains.foldLeft(chain)(_ ++ _) chains.foldLeft(chain)(_ ++ _)
class LabelGenerator { class LabelGenerator {
@ -64,7 +59,7 @@ object asmGenerator {
} }
} }
def generateAsm(microProg: Program): List[AsmLine] = { def generateAsm(microProg: Program): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]() given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
@ -72,9 +67,8 @@ object asmGenerator {
val progAsm = Chain(LabelDef("main")).concatAll( val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(), funcPrologue(),
Chain.one(stack.align()),
main.foldMap(generateStmt(_)), main.foldMap(generateStmt(_)),
Chain.one(Move(RAX, ImmediateVal(0))), Chain.one(Xor(RAX, RAX)),
funcEpilogue(), funcEpilogue(),
generateBuiltInFuncs(), generateBuiltInFuncs(),
funcs.foldMap(generateUserFunc(_)) funcs.foldMap(generateUserFunc(_))
@ -96,78 +90,84 @@ object asmGenerator {
strDirs, strDirs,
Chain.one(Directive.Text), Chain.one(Directive.Text),
progAsm progAsm
).toList )
} }
def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = {
stack: Stack, var chain = Chain.one[AsmLine](LabelDef(labelName))
strings: ListBuffer[String]
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += LabelDef(labelName)
chain ++= funcPrologue() chain ++= funcPrologue()
chain ++= funcBody chain ++= funcBody
chain ++= funcEpilogue() chain ++= funcEpilogue()
chain chain
} }
def generateUserFunc(func: FuncDecl)(using private def generateUserFunc(func: FuncDecl)(using
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
// Setup the stack with param 7 and up // Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_)) func.params.drop(argRegs.size).foreach(stack.reserve(_))
var chain = Chain.empty[AsmLine] var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity // Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) => argRegs.zip(func.params).foreach { (reg, param) =>
chain += stack.push(param, reg) chain += stack.push(param, Register(Q64, reg))
} }
chain ++= func.body.foldMap(generateStmt(_)) chain ++= func.body.foldMap(generateStmt(_))
wrapFunc(labelGenerator.getLabel(func.name), chain) // No need for epilogue here since all user functions must return explicitly
chain
} }
def generateBuiltInFuncs()(using private def generateBuiltInFuncs()(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain ++= wrapFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Exit), labelGenerator.getLabel(Builtin.Exit),
Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
) )
chain ++= wrapFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Printf), labelGenerator.getLabel(Builtin.Printf),
Chain( Chain(
stack.align(), stackAlign,
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)), Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
) )
) )
chain ++= wrapFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc), labelGenerator.getLabel(Builtin.PrintCharArray),
Chain.one(stack.align()) Chain(
stackAlign,
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
Move(Register(D32, SI), MemLocation(RSI, D32)),
assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush)
)
) )
chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc),
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
// Out of memory check is optional
)
chain ++= wrapFunc( chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Read), labelGenerator.getLabel(Builtin.Read),
Chain( Chain(
stack.align(), stackAlign,
stack.reserve(), Subtract(Register(Q64, SP), ImmediateVal(8)),
stack.push(RSI), Push(RSI),
Load(RSI, stack.head), Load(RSI, MemLocation(Register(Q64, SP), Q64)),
assemblyIR.Call(CLibFunc.Scanf), assemblyIR.Call(CLibFunc.Scanf),
stack.pop(RAX), Pop(RAX)
stack.drop()
) )
) )
@ -175,7 +175,7 @@ object asmGenerator {
// TODO can this be done with a call to generateStmt? // TODO can this be done with a call to generateStmt?
// Consider other error cases -> look to generalise // Consider other error cases -> look to generalise
LabelDef(zeroDivError.errLabel), LabelDef(zeroDivError.errLabel),
stack.align(), stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)), Move(RDI, ImmediateVal(-1)),
@ -185,28 +185,34 @@ object asmGenerator {
chain chain
} }
def generateStmt(stmt: Stmt)(using private def generateStmt(stmt: Stmt)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match { stmt match {
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
lhs match { lhs match {
case ident: Ident => case ident: Ident =>
dest = stack.accessVar(ident)
if (!stack.contains(ident)) chain += stack.reserve(ident) if (!stack.contains(ident)) chain += stack.reserve(ident)
// TODO lhs = arrayElem chain ++= evalExprOntoStack(rhs)
case _ => chain += stack.pop(RAX)
} chain += Move(stack.accessVar(ident), RAX)
case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(rhs)
chain ++= evalExprOntoStack(i)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
chain += stack.pop(RCX)
chain += stack.pop(RDX)
chain ++= evalExprOntoStack(rhs) chain += Move(
chain += stack.pop(RAX) IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
chain += Move(dest(), RAX) Register(x.ty.elemSize, DX)
)
}
case If(cond, thenBranch, elseBranch) => case If(cond, thenBranch, elseBranch) =>
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
@ -217,11 +223,11 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0)) chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(elseLabel), Cond.Equal) chain += Jump(LabelArg(elseLabel), Cond.Equal)
chain ++= thenBranch.foldMap(generateStmt) chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
chain += Jump(LabelArg(endLabel)) chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel) chain += LabelDef(elseLabel)
chain ++= elseBranch.foldMap(generateStmt) chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
chain += LabelDef(endLabel) chain += LabelDef(endLabel)
case While(cond, body) => case While(cond, body) =>
@ -234,106 +240,147 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0)) chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(endLabel), Cond.Equal) chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= body.foldMap(generateStmt) chain ++= stack.withScope(() => body.foldMap(generateStmt))
chain += Jump(LabelArg(startLabel)) chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel) chain += LabelDef(endLabel)
case microWacc.Return(expr) =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
case call: microWacc.Call => case call: microWacc.Call =>
chain ++= generateCall(call) chain ++= generateCall(call, isTail = false)
case microWacc.Return(expr) =>
expr match {
case call: microWacc.Call =>
chain ++= generateCall(call, isTail = true) // tco
case _ =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
}
} }
chain chain
} }
def evalExprOntoStack(expr: Expr)(using private def evalExprOntoStack(expr: Expr)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
val stackSizeStart = stack.size
expr match { expr match {
case IntLiter(v) => chain += stack.push(ImmediateVal(v)) case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v))
case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(stack.accessVar(ident)()) case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident))
case ArrayLiter(elems) => case array @ ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString strings += elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
chain += stack.push(RAX) chain += stack.push(Q64, RAX)
case _ => // Other array types TODO case ty =>
chain ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
isTail = false
)
chain += stack.push(Q64, RAX)
// Store the length of the array at the start
chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
elems.zipWithIndex.foldMap { (elem, i) =>
chain ++= evalExprOntoStack(elem)
chain += stack.pop(RCX)
chain += stack.pop(RAX)
chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
chain += stack.push(Q64, RAX)
}
} }
case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0)) case BoolLiter(true) =>
case NullLiter() => chain += stack.push(ImmediateVal(0)) chain += stack.push(KnownType.Bool.size, ImmediateVal(1))
case ArrayElem(_, _) => // TODO: Implement handling case BoolLiter(false) =>
chain += Xor(RAX, RAX)
chain += stack.push(KnownType.Bool.size, RAX)
case NullLiter() =>
chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += stack.pop(RAX)
// + Int because we store the length of the array at the start
chain += Move(
Register(x.ty.elemSize, AX),
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
)
chain += stack.push(x.ty.elemSize, RAX)
case UnaryOp(x, op) => case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
op match { op match {
case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed
case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) case UnaryOperator.Len =>
chain += stack.pop(RAX)
chain += Move(EAX, MemLocation(RAX, D32))
chain += stack.push(D32, RAX)
case UnaryOperator.Negate =>
chain += Negate(stack.head)
case UnaryOperator.Not => case UnaryOperator.Not =>
chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) chain += Xor(stack.head, ImmediateVal(1))
} }
case BinaryOp(x, y, op) => case BinaryOp(x, y, op) =>
val destX = Register(x.ty.size, AX)
chain ++= evalExprOntoStack(y) chain ++= evalExprOntoStack(y)
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX) chain += stack.pop(RAX)
op match { op match {
case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX) case BinaryOperator.Add =>
chain += Add(stack.head, destX)
case BinaryOperator.Sub => case BinaryOperator.Sub =>
chain += Subtract(EAX, stack.head(SizeDir.DWord)) chain += Subtract(destX, stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
chain += Multiply(EAX, stack.head(SizeDir.DWord)) chain += Multiply(destX, stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Div => case BinaryOperator.Div =>
chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head(SizeDir.DWord)) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mod => case BinaryOperator.Mod =>
chain += CDQ() chain += CDQ()
chain += Divide(stack.head(SizeDir.DWord)) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RDX) chain += stack.push(destX.size, RDX)
case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal) case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal)
case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual) case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual)
case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater) case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater)
case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual) case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual)
case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less) case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less)
case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual) case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual)
case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) case BinaryOperator.And => chain += And(stack.head, destX)
case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) case BinaryOperator.Or => chain += Or(stack.head, destX)
} }
case call: microWacc.Call => case call: microWacc.Call =>
chain ++= generateCall(call) chain ++= generateCall(call, isTail = false)
chain += stack.push(RAX) chain += stack.push(call.ty.size, RAX)
} }
if chain.isEmpty then chain += stack.push(ImmediateVal(0)) assert(stack.size == stackSizeStart + 1)
chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
chain chain
} }
def generateCall(call: microWacc.Call)(using private def generateCall(call: microWacc.Call, isTail: Boolean)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
@ -343,14 +390,19 @@ object asmGenerator {
argRegs.zip(args).foldMap { (reg, expr) => argRegs.zip(args).foldMap { (reg, expr) =>
chain ++= evalExprOntoStack(expr) chain ++= evalExprOntoStack(expr)
chain += stack.pop(reg) chain += stack.pop(Register(Q64, reg))
} }
args.drop(argRegs.size).foldMap { args.drop(argRegs.size).foldMap {
chain ++= evalExprOntoStack(_) chain ++= evalExprOntoStack(_)
} }
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // Tail Call Optimisation (TCO)
if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
} else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
}
if (args.size > argRegs.size) { if (args.size > argRegs.size) {
chain += stack.drop(args.size - argRegs.size) chain += stack.drop(args.size - argRegs.size)
@ -359,79 +411,39 @@ object asmGenerator {
chain chain
} }
def generateComparison(x: Expr, y: Expr, cond: Cond)(using private def generateComparison(destX: Register, cond: Cond)(using
stack: Stack, stack: Stack
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain ++= evalExprOntoStack(x) chain += Compare(destX, stack.head)
chain ++= evalExprOntoStack(y) chain += Set(Register(B8, AX), cond)
chain += stack.pop(RAX) chain ++= zeroRest(RAX, B8)
chain += Compare(stack.head(SizeDir.DWord), EAX)
chain += Set(Register(RegSize.Byte, RegName.AL), cond)
chain += And(RAX, ImmediateVal(_8_BIT_MASK))
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(B8, RAX)
chain chain
} }
// Missing a sub instruction but dont think we need it private def funcPrologue(): Chain[AsmLine] = {
def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += stack.push(RBP) chain += Push(RBP)
chain += Move(RBP, Register(RegSize.R64, RegName.SP)) chain += Move(RBP, Register(Q64, SP))
chain chain
} }
def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { private def funcEpilogue(): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Move(Register(RegSize.R64, RegName.SP), RBP) chain += Move(Register(Q64, SP), RBP)
chain += Pop(RBP) chain += Pop(RBP)
chain += assemblyIR.Return() chain += assemblyIR.Return()
chain chain
} }
class Stack { private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
private val stack = LinkedHashMap[Expr | Int, Int]() private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
private val RSP = Register(RegSize.R64, RegName.SP) case Q64 | D32 => Chain.empty
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
private def next: Int = stack.size + 1
def push(expr: Expr, src: Src): AsmLine = {
stack += expr -> next
Push(src)
}
def push(src: Src): AsmLine = {
stack += stack.size -> next
Push(src)
}
def pop(dest: Src): AsmLine = {
stack.remove(stack.last._1)
Pop(dest)
}
def reserve(ident: Ident): AsmLine = {
stack += ident -> next
Subtract(RSP, ImmediateVal(8))
}
def reserve(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack += stack.size -> next)
Subtract(RSP, ImmediateVal(n * 8))
}
def drop(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack.remove(stack.last._1))
Add(RSP, ImmediateVal(n * 8))
}
def accessVar(ident: Ident): () => IndexAddress = () => {
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
}
def head: MemLocation = MemLocation(RSP)
def head(size: SizeDir): MemLocation = MemLocation(RSP, size)
def contains(ident: Ident): Boolean = stack.contains(ident)
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
} }
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

View File

@ -6,40 +6,73 @@ object assemblyIR {
sealed trait Operand sealed trait Operand
sealed trait Src extends Operand // mem location, register and imm value sealed trait Src extends Operand // mem location, register and imm value
sealed trait Dest extends Operand // mem location and register sealed trait Dest extends Operand // mem location and register
enum RegSize {
case R64
case E32
case Byte
override def toString = this match { enum Size {
case R64 => "r" case Q64, D32, W16, B8
case E32 => "e"
case Byte => "" def toInt: Int = this match {
case Q64 => 8
case D32 => 4
case W16 => 2
case B8 => 1
}
private val ptr = "ptr "
override def toString(): String = this match {
case Q64 => "qword " + ptr
case D32 => "dword " + ptr
case W16 => "word " + ptr
case B8 => "byte " + ptr
} }
} }
enum RegName { enum RegName {
case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15
Reg15 }
override def toString = this match {
case AX => "ax" case class Register(size: Size, name: RegName) extends Dest with Src {
case AL => "al" import RegName._
case BX => "bx"
case CX => "cx" if (size == Size.B8 && name == RegName.IP) {
case DX => "dx" throw new IllegalArgumentException("Cannot have 8 bit register for IP")
case SI => "si" }
case DI => "di" override def toString = name match {
case SP => "sp" case AX => tradToString("ax", "al")
case BP => "bp" case BX => tradToString("bx", "bl")
case IP => "ip" case CX => tradToString("cx", "cl")
case Reg8 => "8" case DX => tradToString("dx", "dl")
case Reg9 => "9" case SI => tradToString("si", "sil")
case Reg10 => "10" case DI => tradToString("di", "dil")
case Reg11 => "11" case SP => tradToString("sp", "spl")
case Reg12 => "12" case BP => tradToString("bp", "bpl")
case Reg13 => "13" case IP => tradToString("ip", "#INVALID")
case Reg14 => "14" case R8 => newToString(8)
case Reg15 => "15" case R9 => newToString(9)
case R10 => newToString(10)
case R11 => newToString(11)
case R12 => newToString(12)
case R13 => newToString(13)
case R14 => newToString(14)
case R15 => newToString(15)
}
private def tradToString(base: String, byteName: String): String =
size match {
case Size.Q64 => "r" + base
case Size.D32 => "e" + base
case Size.W16 => base
case Size.B8 => byteName
}
private def newToString(base: Int): String = {
val b = base.toString
"r" + (size match {
case Size.Q64 => b
case Size.D32 => b + "d"
case Size.W16 => b + "w"
case Size.B8 => b + "b"
})
} }
} }
@ -48,7 +81,9 @@ object assemblyIR {
case Scanf, case Scanf,
Fflush, Fflush,
Exit, Exit,
PrintF PrintF,
Malloc,
Free
private val plt = "@plt" private val plt = "@plt"
@ -57,28 +92,29 @@ object assemblyIR {
case Fflush => "fflush" + plt case Fflush => "fflush" + plt
case Exit => "exit" + plt case Exit => "exit" + plt
case PrintF => "printf" + plt case PrintF => "printf" + plt
case Malloc => "malloc" + plt
case Free => "free" + plt
} }
} }
// TODO register naming conventions are wrong case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src {
case class Register(size: RegSize, name: RegName) extends Dest with Src { override def toString =
override def toString = s"${size}${name}" opSize.toString + s"[$pointer]"
}
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified)
extends Dest
with Src {
override def toString = pointer match {
case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => opSize.toString + s"[$reg]"
}
} }
// TODO to string is wacky
case class IndexAddress( case class IndexAddress(
base: Register, base: Register,
offset: Int | LabelArg, offset: Int | LabelArg,
opSize: SizeDir = SizeDir.Unspecified indexReg: Register = Register(Size.Q64, RegName.AX),
scale: Int = 0
) extends Dest ) extends Dest
with Src { with Src {
override def toString = s"$opSize[$base + $offset]" override def toString = if (scale != 0) {
s"[$base + $indexReg * $scale + $offset]"
} else {
s"[$base + $offset]"
}
} }
case class ImmediateVal(value: Int) extends Src { case class ImmediateVal(value: Int) extends Src {
@ -126,6 +162,11 @@ object assemblyIR {
override def toString = s"$name:" override def toString = s"$name:"
} }
case class Comment(comment: String) extends AsmLine {
override def toString =
comment.split("\n").map(line => s"# ${line}").mkString("\n")
}
enum Cond { enum Cond {
case Equal, case Equal,
NotEqual, NotEqual,
@ -172,17 +213,4 @@ object assemblyIR {
case String => "%s" case String => "%s"
} }
} }
enum SizeDir {
case Byte, Word, DWord, Unspecified
private val ptr = "ptr "
override def toString(): String = this match {
case Byte => "byte " + ptr
case Word => "word " + ptr // TODO check word/doubleword/quadword
case DWord => "dword " + ptr
case Unspecified => ""
}
}
} }

View File

@ -0,0 +1,35 @@
package wacc
/** Size-related extension methods for typed expressions and semantic types. */
object sizeExtensions {
  import microWacc._
  import types._
  import assemblyIR.Size

  extension (expr: Expr) {

    /** Number of bytes the expression needs on the heap.
      *
      * An array literal needs an `Int`-sized length header followed by one element-sized cell per
      * element; any other expression needs just the size of its own type.
      */
    def heapSize: Int = expr match {
      case ArrayLiter(elems) =>
        // Char arrays need no special case: their element size is already Char's size.
        val lengthHeader = KnownType.Int.size.toInt
        lengthHeader + elems.size * expr.ty.elemSize.toInt
      case _ =>
        expr.ty.size.toInt
    }
  }

  extension (ty: SemType) {

    /** The operand size used to hold a value of this type. */
    def size: Size = ty match {
      case KnownType.Int =>
        Size.D32
      case KnownType.Bool | KnownType.Char =>
        Size.B8
      case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? =>
        Size.Q64
    }

    /** The operand size of one element of this type: the element type's size for arrays, a
      * quad-word for pairs, and the type's own size otherwise.
      */
    def elemSize: Size = ty match {
      case KnownType.Array(inner) => inner.size
      case KnownType.Pair(_, _)   => Size.Q64
      case _                      => ty.size
    }
  }
}

View File

@ -1,11 +1,12 @@
package wacc package wacc
import java.io.PrintStream import java.io.PrintStream
import cats.data.Chain
object writer { object writer {
import assemblyIR._ import assemblyIR._
def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = { def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
asmList.foreach(printStream.println) asmList.iterator.foreach(printStream.println)
} }
} }

View File

@ -74,6 +74,7 @@ object microWacc {
object Exit extends Builtin("exit")(?) object Exit extends Builtin("exit")(?)
object Free extends Builtin("free")(?) object Free extends Builtin("free")(?)
object Malloc extends Builtin("malloc")(?) object Malloc extends Builtin("malloc")(?)
object PrintCharArray extends Builtin("printCharArray")(?)
} }
case class Assign(lhs: LValue, rhs: Expr) extends Stmt case class Assign(lhs: LValue, rhs: Expr) extends Stmt

View File

@ -216,21 +216,37 @@ object typeChecker {
case ast.Print(expr, newline) => case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained) val exprTyped = checkValue(expr, Constraint.Unconstrained)
val format = exprTyped.ty match { val exprFormat = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s" case KnownType.Bool | KnownType.String => "%s"
case KnownType.Char => "%c" case KnownType.Array(KnownType.Char) => "%.*s"
case KnownType.Int => "%d" case KnownType.Char => "%c"
case _ => "%p" case KnownType.Int => "%d"
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
} }
List( val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
microWacc.Call( List(
microWacc.Builtin.Printf, microWacc.Call(
List( func,
s"$format${if newline then "\n" else ""}".toMicroWaccCharArray, List(
exprTyped s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray,
value
)
) )
) )
) }
exprTyped.ty match {
case KnownType.Bool =>
List(
microWacc.If(
exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
)
)
case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
}
case ast.If(cond, thenStmt, elseStmt) => case ast.If(cond, thenStmt, elseStmt) =>
List( List(
microWacc.If( microWacc.If(

View File

@ -73,7 +73,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
) )
assert(process.exitValue == expectedExit) assert(process.exitValue == expectedExit)
assert(stdout.toString == expectedOutput) assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput)
} }
} }
@ -86,23 +86,23 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
// format: off // format: off
// disable formatting to avoid binPack // disable formatting to avoid binPack
"^.*wacc-examples/valid/advanced.*$", "^.*wacc-examples/valid/advanced.*$",
"^.*wacc-examples/valid/array.*$", // "^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/basic/exit.*$", // "^.*wacc-examples/valid/basic/exit.*$",
// "^.*wacc-examples/valid/basic/skip.*$", // "^.*wacc-examples/valid/basic/skip.*$",
"^.*wacc-examples/valid/expressions.*$", // "^.*wacc-examples/valid/expressions.*$",
"^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/nested_functions.*$",
"^.*wacc-examples/valid/function/simple_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$",
// "^.*wacc-examples/valid/if.*$", // "^.*wacc-examples/valid/if.*$",
"^.*wacc-examples/valid/IO/print.*$", // "^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/read.*$", // "^.*wacc-examples/valid/IO/read.*$",
"^.*wacc-examples/valid/IO/IOLoop.wacc.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
"^.*wacc-examples/valid/pairs.*$", // "^.*wacc-examples/valid/pairs.*$",
"^.*wacc-examples/valid/runtimeErr.*$", "^.*wacc-examples/valid/runtimeErr.*$",
"^.*wacc-examples/valid/scope.*$", // "^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$", // "^.*wacc-examples/valid/variables.*$",
"^.*wacc-examples/valid/while.*$", // "^.*wacc-examples/valid/while.*$",
// format: on // format: on
).find(filename.matches).isDefined ).find(filename.matches).isDefined
} }

View File

@ -1,42 +1,38 @@
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
import wacc.assemblyIR._ import wacc.assemblyIR._
import wacc.assemblyIR.Size._
import wacc.assemblyIR.RegName._
class instructionSpec extends AnyFunSuite { class instructionSpec extends AnyFunSuite {
val named64BitRegister = Register(RegSize.R64, RegName.AX) val named64BitRegister = Register(Q64, AX)
test("named 64-bit register toString") { test("named 64-bit register toString") {
assert(named64BitRegister.toString == "rax") assert(named64BitRegister.toString == "rax")
} }
val named32BitRegister = Register(RegSize.E32, RegName.AX) val named32BitRegister = Register(D32, AX)
test("named 32-bit register toString") { test("named 32-bit register toString") {
assert(named32BitRegister.toString == "eax") assert(named32BitRegister.toString == "eax")
} }
val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8) val scratch64BitRegister = Register(Q64, R8)
test("scratch 64-bit register toString") { test("scratch 64-bit register toString") {
assert(scratch64BitRegister.toString == "r8") assert(scratch64BitRegister.toString == "r8")
} }
val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8) val scratch32BitRegister = Register(D32, R8)
test("scratch 32-bit register toString") { test("scratch 32-bit register toString") {
assert(scratch32BitRegister.toString == "e8") assert(scratch32BitRegister.toString == "r8d")
} }
val memLocationWithHex = MemLocation(0x12345678) val memLocationWithRegister = MemLocation(named64BitRegister, Q64)
test("mem location with hex toString") {
assert(memLocationWithHex.toString == "[0x12345678]")
}
val memLocationWithRegister = MemLocation(named64BitRegister)
test("mem location with register toString") { test("mem location with register toString") {
assert(memLocationWithRegister.toString == "[rax]") assert(memLocationWithRegister.toString == "qword ptr [rax]")
} }
val immediateVal = ImmediateVal(123) val immediateVal = ImmediateVal(123)