Compare commits

...

10 Commits

10 changed files with 442 additions and 262 deletions

View File

@ -1,6 +1,7 @@
package wacc
import scala.collection.mutable
import cats.data.Chain
import parsley.{Failure, Success}
import scopt.OParser
import java.io.File
@ -63,7 +64,7 @@ def frontend(
}
val s = "enter an integer to echo"
def backend(typedProg: microWacc.Program): List[asm.AsmLine] =
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using

View File

@ -0,0 +1,90 @@
package wacc
import scala.collection.mutable.LinkedHashMap
import cats.data.Chain
/** Compile-time model of the runtime stack used while generating assembly.
  *
  * Every tracked entry occupies exactly one slot of `elemBytes` bytes, whatever the logical
  * size of the value stored in it. Each mutating method updates the model and returns the
  * single instruction that performs the matching change at runtime, so the model only stays
  * accurate if the caller actually emits the returned line.
  */
class Stack {
  import assemblyIR._
  import assemblyIR.Size._
  import sizeExtensions.size
  import microWacc as mw

  private val RSP = Register(Q64, RegName.SP)

  /** Width in bytes of one stack slot (a full 64-bit word). */
  private val elemBytes: Int = Q64.toInt

  /** Bookkeeping for a single slot.
    *
    * @param size
    *   logical size of the value held in the slot
    * @param offset
    *   number of bytes that were already tracked when the slot was created
    */
  private class StackValue(val size: Size, val offset: Int) {

    /** Byte count of the tracked stack up to and including this slot. */
    def bottom: Int = offset + elemBytes
  }

  // Insertion-ordered so that `last` is always the most recently added slot.
  // Named variables are keyed by their expression, anonymous temporaries by an Int.
  private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()

  /** Total bytes currently tracked (slot count times slot width). */
  private def sizeBytes: Int = stack.size * elemBytes

  /** The number of slots currently tracked on the stack (a count, not a byte total). */
  def size: Int = stack.size

  /** Push an expression onto the stack. */
  def push(expr: mw.Expr, src: Register): AsmLine = {
    stack += expr -> StackValue(src.size, sizeBytes)
    Push(src)
  }

  /** Push a value onto the stack. */
  def push(itemSize: Size, addr: Src): AsmLine = {
    stack += stack.size -> StackValue(itemSize, sizeBytes)
    Push(addr)
  }

  /** Reserve space for a variable on the stack. */
  def reserve(ident: mw.Ident): AsmLine = {
    stack += ident -> StackValue(ident.ty.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve space for a register on the stack.
    *
    * A full slot is reserved even for registers narrower than 64 bits: `drop`, `accessVar`
    * and the other `reserve` overloads all assume one `elemBytes` slot per entry, so moving
    * RSP by the register's own width would desynchronise the model from the real stack.
    */
  def reserve(src: Register): AsmLine = {
    stack += stack.size -> StackValue(src.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve space for values on the stack.
    *
    * @param sizes
    *   The sizes of the values to reserve space for.
    */
  def reserve(sizes: List[Size]): AsmLine = {
    sizes.foreach { itemSize =>
      stack += stack.size -> StackValue(itemSize, sizeBytes)
    }
    Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
  }

  /** Pop a value from the stack into a register. Sizes MUST match.
    *
    * The stack must be non-empty; `stack.last` fails otherwise.
    */
  def pop(dest: Register): AsmLine = {
    stack.remove(stack.last._1)
    Pop(dest)
  }

  /** Drop the top n values from the stack. */
  def drop(n: Int = 1): AsmLine = {
    (1 to n).foreach { _ =>
      stack.remove(stack.last._1)
    }
    Add(RSP, ImmediateVal(n * elemBytes))
  }

  /** Generate AsmLines within a scope, which is reset after the block.
    *
    * Whatever the block leaves on the stack beyond the entry depth is dropped, both from the
    * model and (through the appended instruction) at runtime.
    */
  def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
    val resetToSize = stack.size
    var lines = block()
    lines :+= drop(stack.size - resetToSize)
    lines
  }

  /** Get an IndexAddress for a variable in the stack.
    *
    * The offset is the number of tracked bytes sitting above the variable's slot, so the
    * most recently added slot resolves to offset 0 from RSP.
    */
  def accessVar(ident: mw.Ident): IndexAddress =
    IndexAddress(RSP, sizeBytes - stack(ident).bottom)

  /** Whether the identifier currently has a slot on the stack. */
  def contains(ident: mw.Ident): Boolean = stack.contains(ident)

  /** Memory operand for the top slot, using the size recorded for that slot. */
  def head: MemLocation = MemLocation(RSP, stack.last._2.size)

  override def toString(): String = stack.toString
}

View File

@ -1,15 +1,16 @@
package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer
import cats.data.Chain
import cats.syntax.foldable._
// import parsley.token.errors.Label
object asmGenerator {
import microWacc._
import assemblyIR._
import wacc.types._
import assemblyIR.Size._
import assemblyIR.RegName._
import types._
import sizeExtensions._
import lexer.escapedChars
abstract case class Error() {
@ -30,26 +31,20 @@ object asmGenerator {
def errLabel = ".L._errDivZero"
}
val RAX = Register(RegSize.R64, RegName.AX)
val EAX = Register(RegSize.E32, RegName.AX)
val ESP = Register(RegSize.E32, RegName.SP)
val EDX = Register(RegSize.E32, RegName.DX)
val RDI = Register(RegSize.R64, RegName.DI)
val RIP = Register(RegSize.R64, RegName.IP)
val RBP = Register(RegSize.R64, RegName.BP)
val RSI = Register(RegSize.R64, RegName.SI)
val RDX = Register(RegSize.R64, RegName.DX)
val RCX = Register(RegSize.R64, RegName.CX)
val R8 = Register(RegSize.R64, RegName.Reg8)
val R9 = Register(RegSize.R64, RegName.Reg9)
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
private val RAX = Register(Q64, AX)
private val EAX = Register(D32, AX)
private val RDI = Register(Q64, DI)
private val RIP = Register(Q64, IP)
private val RBP = Register(Q64, BP)
private val RSI = Register(Q64, SI)
private val RDX = Register(Q64, DX)
private val RCX = Register(Q64, CX)
private val argRegs = List(DI, SI, DX, CX, R8, R9)
val _8_BIT_MASK = 0xff
extension [T](chain: Chain[T])
def +(item: T): Chain[T] = chain.append(item)
extension (chain: Chain[AsmLine])
def +(line: AsmLine): Chain[AsmLine] = chain.append(line)
def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
def concatAll(chains: Chain[T]*): Chain[T] =
chains.foldLeft(chain)(_ ++ _)
class LabelGenerator {
@ -64,7 +59,7 @@ object asmGenerator {
}
}
def generateAsm(microProg: Program): List[AsmLine] = {
def generateAsm(microProg: Program): Chain[AsmLine] = {
given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator()
@ -72,9 +67,8 @@ object asmGenerator {
val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(),
Chain.one(stack.align()),
main.foldMap(generateStmt(_)),
Chain.one(Move(RAX, ImmediateVal(0))),
Chain.one(Xor(RAX, RAX)),
funcEpilogue(),
generateBuiltInFuncs(),
funcs.foldMap(generateUserFunc(_))
@ -96,78 +90,84 @@ object asmGenerator {
strDirs,
Chain.one(Directive.Text),
progAsm
).toList
)
}
def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using
stack: Stack,
strings: ListBuffer[String]
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += LabelDef(labelName)
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = {
var chain = Chain.one[AsmLine](LabelDef(labelName))
chain ++= funcPrologue()
chain ++= funcBody
chain ++= funcEpilogue()
chain
}
def generateUserFunc(func: FuncDecl)(using
private def generateUserFunc(func: FuncDecl)(using
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
given stack: Stack = Stack()
// Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_))
var chain = Chain.empty[AsmLine]
var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) =>
chain += stack.push(param, reg)
chain += stack.push(param, Register(Q64, reg))
}
chain ++= func.body.foldMap(generateStmt(_))
wrapFunc(labelGenerator.getLabel(func.name), chain)
// No need for epilogue here since all user functions must return explicitly
chain
}
def generateBuiltInFuncs()(using
stack: Stack,
strings: ListBuffer[String],
private def generateBuiltInFuncs()(using
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain ++= wrapFunc(
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Exit),
Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
)
chain ++= wrapFunc(
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Printf),
Chain(
stack.align(),
stackAlign,
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush)
)
)
chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Malloc),
Chain.one(stack.align())
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.PrintCharArray),
Chain(
stackAlign,
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
Move(Register(D32, SI), MemLocation(RSI, D32)),
assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush)
)
)
chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc),
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
// Out of memory check is optional
)
chain ++= wrapFunc(
chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Read),
Chain(
stack.align(),
stack.reserve(),
stack.push(RSI),
Load(RSI, stack.head),
stackAlign,
Subtract(Register(Q64, SP), ImmediateVal(8)),
Push(RSI),
Load(RSI, MemLocation(Register(Q64, SP), Q64)),
assemblyIR.Call(CLibFunc.Scanf),
stack.pop(RAX),
stack.drop()
Pop(RAX)
)
)
@ -175,7 +175,7 @@ object asmGenerator {
// TODO can this be done with a call to generateStmt?
// Consider other error cases -> look to generalise
LabelDef(zeroDivError.errLabel),
stack.align(),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)),
@ -185,28 +185,34 @@ object asmGenerator {
chain
}
def generateStmt(stmt: Stmt)(using
private def generateStmt(stmt: Stmt)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match {
case Assign(lhs, rhs) =>
var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
lhs match {
case ident: Ident =>
dest = stack.accessVar(ident)
if (!stack.contains(ident)) chain += stack.reserve(ident)
// TODO lhs = arrayElem
case _ =>
}
chain ++= evalExprOntoStack(rhs)
chain += stack.pop(RAX)
chain += Move(stack.accessVar(ident), RAX)
case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(rhs)
chain ++= evalExprOntoStack(i)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
chain += stack.pop(RCX)
chain += stack.pop(RDX)
chain ++= evalExprOntoStack(rhs)
chain += stack.pop(RAX)
chain += Move(dest(), RAX)
chain += Move(
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
Register(x.ty.elemSize, DX)
)
}
case If(cond, thenBranch, elseBranch) =>
val elseLabel = labelGenerator.getLabel()
@ -217,11 +223,11 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(elseLabel), Cond.Equal)
chain ++= thenBranch.foldMap(generateStmt)
chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel)
chain ++= elseBranch.foldMap(generateStmt)
chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
chain += LabelDef(endLabel)
case While(cond, body) =>
@ -234,106 +240,147 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= body.foldMap(generateStmt)
chain ++= stack.withScope(() => body.foldMap(generateStmt))
chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel)
case microWacc.Return(expr) =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
case call: microWacc.Call =>
chain ++= generateCall(call)
chain ++= generateCall(call, isTail = false)
case microWacc.Return(expr) =>
expr match {
case call: microWacc.Call =>
chain ++= generateCall(call, isTail = true) // tco
case _ =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
}
}
chain
}
def evalExprOntoStack(expr: Expr)(using
private def evalExprOntoStack(expr: Expr)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
val stackSizeStart = stack.size
expr match {
case IntLiter(v) => chain += stack.push(ImmediateVal(v))
case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(stack.accessVar(ident)())
case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v))
case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident))
case ArrayLiter(elems) =>
case array @ ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
chain += stack.push(RAX)
case _ => // Other array types TODO
chain += stack.push(Q64, RAX)
case ty =>
chain ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
isTail = false
)
chain += stack.push(Q64, RAX)
// Store the length of the array at the start
chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
elems.zipWithIndex.foldMap { (elem, i) =>
chain ++= evalExprOntoStack(elem)
chain += stack.pop(RCX)
chain += stack.pop(RAX)
chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
chain += stack.push(Q64, RAX)
}
}
case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0))
case NullLiter() => chain += stack.push(ImmediateVal(0))
case ArrayElem(_, _) => // TODO: Implement handling
case BoolLiter(true) =>
chain += stack.push(KnownType.Bool.size, ImmediateVal(1))
case BoolLiter(false) =>
chain += Xor(RAX, RAX)
chain += stack.push(KnownType.Bool.size, RAX)
case NullLiter() =>
chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += stack.pop(RAX)
// + Int because we store the length of the array at the start
chain += Move(
Register(x.ty.elemSize, AX),
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
)
chain += stack.push(x.ty.elemSize, RAX)
case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x)
op match {
case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed
case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord))
case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed
case UnaryOperator.Len =>
chain += stack.pop(RAX)
chain += Move(EAX, MemLocation(RAX, D32))
chain += stack.push(D32, RAX)
case UnaryOperator.Negate =>
chain += Negate(stack.head)
case UnaryOperator.Not =>
chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1))
chain += Xor(stack.head, ImmediateVal(1))
}
case BinaryOp(x, y, op) =>
val destX = Register(x.ty.size, AX)
chain ++= evalExprOntoStack(y)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
op match {
case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX)
case BinaryOperator.Add =>
chain += Add(stack.head, destX)
case BinaryOperator.Sub =>
chain += Subtract(EAX, stack.head(SizeDir.DWord))
chain += Subtract(destX, stack.head)
chain += stack.drop()
chain += stack.push(RAX)
chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul =>
chain += Multiply(EAX, stack.head(SizeDir.DWord))
chain += Multiply(destX, stack.head)
chain += stack.drop()
chain += stack.push(RAX)
chain += stack.push(destX.size, RAX)
case BinaryOperator.Div =>
chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0))
chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
chain += CDQ()
chain += Divide(stack.head(SizeDir.DWord))
chain += Divide(stack.head)
chain += stack.drop()
chain += stack.push(RAX)
chain += stack.push(destX.size, RAX)
case BinaryOperator.Mod =>
chain += CDQ()
chain += Divide(stack.head(SizeDir.DWord))
chain += Divide(stack.head)
chain += stack.drop()
chain += stack.push(RDX)
chain += stack.push(destX.size, RDX)
case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX)
case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX)
case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal)
case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual)
case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater)
case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual)
case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less)
case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual)
case BinaryOperator.And => chain += And(stack.head, destX)
case BinaryOperator.Or => chain += Or(stack.head, destX)
}
case call: microWacc.Call =>
chain ++= generateCall(call)
chain += stack.push(RAX)
chain ++= generateCall(call, isTail = false)
chain += stack.push(call.ty.size, RAX)
}
if chain.isEmpty then chain += stack.push(ImmediateVal(0))
assert(stack.size == stackSizeStart + 1)
chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
chain
}
def generateCall(call: microWacc.Call)(using
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
@ -343,14 +390,19 @@ object asmGenerator {
argRegs.zip(args).foldMap { (reg, expr) =>
chain ++= evalExprOntoStack(expr)
chain += stack.pop(reg)
chain += stack.pop(Register(Q64, reg))
}
args.drop(argRegs.size).foldMap {
chain ++= evalExprOntoStack(_)
}
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))
// Tail Call Optimisation (TCO)
if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
} else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
}
if (args.size > argRegs.size) {
chain += stack.drop(args.size - argRegs.size)
@ -359,79 +411,39 @@ object asmGenerator {
chain
}
def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
private def generateComparison(destX: Register, cond: Cond)(using
stack: Stack
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(y)
chain += stack.pop(RAX)
chain += Compare(stack.head(SizeDir.DWord), EAX)
chain += Set(Register(RegSize.Byte, RegName.AL), cond)
chain += And(RAX, ImmediateVal(_8_BIT_MASK))
chain += Compare(destX, stack.head)
chain += Set(Register(B8, AX), cond)
chain ++= zeroRest(RAX, B8)
chain += stack.drop()
chain += stack.push(RAX)
chain += stack.push(B8, RAX)
chain
}
// Missing a sub instruction but dont think we need it
def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
private def funcPrologue(): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += stack.push(RBP)
chain += Move(RBP, Register(RegSize.R64, RegName.SP))
chain += Push(RBP)
chain += Move(RBP, Register(Q64, SP))
chain
}
def funcEpilogue()(using stack: Stack): Chain[AsmLine] = {
private def funcEpilogue(): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += Move(Register(RegSize.R64, RegName.SP), RBP)
chain += Move(Register(Q64, SP), RBP)
chain += Pop(RBP)
chain += assemblyIR.Return()
chain
}
class Stack {
private val stack = LinkedHashMap[Expr | Int, Int]()
private val RSP = Register(RegSize.R64, RegName.SP)
private def next: Int = stack.size + 1
def push(expr: Expr, src: Src): AsmLine = {
stack += expr -> next
Push(src)
}
def push(src: Src): AsmLine = {
stack += stack.size -> next
Push(src)
}
def pop(dest: Src): AsmLine = {
stack.remove(stack.last._1)
Pop(dest)
}
def reserve(ident: Ident): AsmLine = {
stack += ident -> next
Subtract(RSP, ImmediateVal(8))
}
def reserve(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack += stack.size -> next)
Subtract(RSP, ImmediateVal(n * 8))
}
def drop(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack.remove(stack.last._1))
Add(RSP, ImmediateVal(n * 8))
}
def accessVar(ident: Ident): () => IndexAddress = () => {
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
}
def head: MemLocation = MemLocation(RSP)
def head(size: SizeDir): MemLocation = MemLocation(RSP, size)
def contains(ident: Ident): Boolean = stack.contains(ident)
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
case Q64 | D32 => Chain.empty
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
}
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

View File

@ -6,40 +6,73 @@ object assemblyIR {
sealed trait Operand
sealed trait Src extends Operand // mem location, register and imm value
sealed trait Dest extends Operand // mem location and register
enum RegSize {
case R64
case E32
case Byte
override def toString = this match {
case R64 => "r"
case E32 => "e"
case Byte => ""
enum Size {
case Q64, D32, W16, B8
def toInt: Int = this match {
case Q64 => 8
case D32 => 4
case W16 => 2
case B8 => 1
}
private val ptr = "ptr "
override def toString(): String = this match {
case Q64 => "qword " + ptr
case D32 => "dword " + ptr
case W16 => "word " + ptr
case B8 => "byte " + ptr
}
}
enum RegName {
case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14,
Reg15
override def toString = this match {
case AX => "ax"
case AL => "al"
case BX => "bx"
case CX => "cx"
case DX => "dx"
case SI => "si"
case DI => "di"
case SP => "sp"
case BP => "bp"
case IP => "ip"
case Reg8 => "8"
case Reg9 => "9"
case Reg10 => "10"
case Reg11 => "11"
case Reg12 => "12"
case Reg13 => "13"
case Reg14 => "14"
case Reg15 => "15"
case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15
}
case class Register(size: Size, name: RegName) extends Dest with Src {
import RegName._
if (size == Size.B8 && name == RegName.IP) {
throw new IllegalArgumentException("Cannot have 8 bit register for IP")
}
override def toString = name match {
case AX => tradToString("ax", "al")
case BX => tradToString("bx", "bl")
case CX => tradToString("cx", "cl")
case DX => tradToString("dx", "dl")
case SI => tradToString("si", "sil")
case DI => tradToString("di", "dil")
case SP => tradToString("sp", "spl")
case BP => tradToString("bp", "bpl")
case IP => tradToString("ip", "#INVALID")
case R8 => newToString(8)
case R9 => newToString(9)
case R10 => newToString(10)
case R11 => newToString(11)
case R12 => newToString(12)
case R13 => newToString(13)
case R14 => newToString(14)
case R15 => newToString(15)
}
private def tradToString(base: String, byteName: String): String =
size match {
case Size.Q64 => "r" + base
case Size.D32 => "e" + base
case Size.W16 => base
case Size.B8 => byteName
}
private def newToString(base: Int): String = {
val b = base.toString
"r" + (size match {
case Size.Q64 => b
case Size.D32 => b + "d"
case Size.W16 => b + "w"
case Size.B8 => b + "b"
})
}
}
@ -48,7 +81,9 @@ object assemblyIR {
case Scanf,
Fflush,
Exit,
PrintF
PrintF,
Malloc,
Free
private val plt = "@plt"
@ -57,28 +92,29 @@ object assemblyIR {
case Fflush => "fflush" + plt
case Exit => "exit" + plt
case PrintF => "printf" + plt
case Malloc => "malloc" + plt
case Free => "free" + plt
}
}
// TODO register naming conventions are wrong
case class Register(size: RegSize, name: RegName) extends Dest with Src {
override def toString = s"${size}${name}"
}
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified)
extends Dest
with Src {
override def toString = pointer match {
case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => opSize.toString + s"[$reg]"
}
case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src {
override def toString =
opSize.toString + s"[$pointer]"
}
// TODO to string is wacky
case class IndexAddress(
base: Register,
offset: Int | LabelArg,
opSize: SizeDir = SizeDir.Unspecified
indexReg: Register = Register(Size.Q64, RegName.AX),
scale: Int = 0
) extends Dest
with Src {
override def toString = s"$opSize[$base + $offset]"
override def toString = if (scale != 0) {
s"[$base + $indexReg * $scale + $offset]"
} else {
s"[$base + $offset]"
}
}
case class ImmediateVal(value: Int) extends Src {
@ -126,6 +162,11 @@ object assemblyIR {
override def toString = s"$name:"
}
case class Comment(comment: String) extends AsmLine {
override def toString =
comment.split("\n").map(line => s"# ${line}").mkString("\n")
}
enum Cond {
case Equal,
NotEqual,
@ -172,17 +213,4 @@ object assemblyIR {
case String => "%s"
}
}
enum SizeDir {
case Byte, Word, DWord, Unspecified
private val ptr = "ptr "
override def toString(): String = this match {
case Byte => "byte " + ptr
case Word => "word " + ptr // TODO check word/doubleword/quadword
case DWord => "dword " + ptr
case Unspecified => ""
}
}
}

View File

@ -0,0 +1,35 @@
package wacc
/** Size helpers shared by the code generator: how many bytes values and types occupy. */
object sizeExtensions {
  import microWacc._
  import types._
  import assemblyIR.Size

  extension (expr: Expr) {

    /** Bytes of heap storage needed to hold the value of this expression.
      *
      * An array literal needs an Int-sized length header followed by one element-sized cell
      * per element; any other expression needs just the size of its type.
      */
    def heapSize: Int = expr match {
      case ArrayLiter(elems) =>
        val headerBytes = KnownType.Int.size.toInt
        val cellBytes = expr.ty.elemSize.toInt
        headerBytes + elems.size * cellBytes
      case _ =>
        expr.ty.size.toInt
    }
  }

  extension (ty: SemType) {

    /** Operand size used when a value of this type is held in a register. */
    def size: Size = ty match {
      case KnownType.Bool | KnownType.Char => Size.B8
      case KnownType.Int => Size.D32
      case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
    }

    /** Size of one element: the element type's size for arrays, otherwise the type's own size. */
    def elemSize: Size = ty match {
      case KnownType.Array(elemTy) => elemTy.size
      case other => other.size
    }
  }
}

View File

@ -1,11 +1,12 @@
package wacc
import java.io.PrintStream
import cats.data.Chain
object writer {
import assemblyIR._
def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = {
asmList.foreach(printStream.println)
def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
asmList.iterator.foreach(printStream.println)
}
}

View File

@ -74,6 +74,7 @@ object microWacc {
object Exit extends Builtin("exit")(?)
object Free extends Builtin("free")(?)
object Malloc extends Builtin("malloc")(?)
object PrintCharArray extends Builtin("printCharArray")(?)
}
case class Assign(lhs: LValue, rhs: Expr) extends Stmt

View File

@ -216,21 +216,37 @@ object typeChecker {
case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained)
val format = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Char => "%c"
case KnownType.Int => "%d"
case _ => "%p"
val exprFormat = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Array(KnownType.Char) => "%.*s"
case KnownType.Char => "%c"
case KnownType.Int => "%d"
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
}
List(
microWacc.Call(
microWacc.Builtin.Printf,
List(
s"$format${if newline then "\n" else ""}".toMicroWaccCharArray,
exprTyped
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List(
microWacc.Call(
func,
List(
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray,
value
)
)
)
)
}
exprTyped.ty match {
case KnownType.Bool =>
List(
microWacc.If(
exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
)
)
case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
}
case ast.If(cond, thenStmt, elseStmt) =>
List(
microWacc.If(

View File

@ -73,7 +73,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
)
assert(process.exitValue == expectedExit)
assert(stdout.toString == expectedOutput)
assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput)
}
}
@ -86,23 +86,23 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
// format: off
// disable formatting to avoid binPack
"^.*wacc-examples/valid/advanced.*$",
"^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/basic/exit.*$",
// "^.*wacc-examples/valid/basic/skip.*$",
"^.*wacc-examples/valid/expressions.*$",
// "^.*wacc-examples/valid/expressions.*$",
"^.*wacc-examples/valid/function/nested_functions.*$",
"^.*wacc-examples/valid/function/simple_functions.*$",
// "^.*wacc-examples/valid/if.*$",
"^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/read.*$",
"^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
"^.*wacc-examples/valid/pairs.*$",
// "^.*wacc-examples/valid/pairs.*$",
"^.*wacc-examples/valid/runtimeErr.*$",
"^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$",
"^.*wacc-examples/valid/while.*$",
// "^.*wacc-examples/valid/while.*$",
// format: on
).find(filename.matches).isDefined
}

View File

@ -1,42 +1,38 @@
import org.scalatest.funsuite.AnyFunSuite
import wacc.assemblyIR._
import wacc.assemblyIR.Size._
import wacc.assemblyIR.RegName._
class instructionSpec extends AnyFunSuite {
val named64BitRegister = Register(RegSize.R64, RegName.AX)
val named64BitRegister = Register(Q64, AX)
test("named 64-bit register toString") {
assert(named64BitRegister.toString == "rax")
}
val named32BitRegister = Register(RegSize.E32, RegName.AX)
val named32BitRegister = Register(D32, AX)
test("named 32-bit register toString") {
assert(named32BitRegister.toString == "eax")
}
val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8)
val scratch64BitRegister = Register(Q64, R8)
test("scratch 64-bit register toString") {
assert(scratch64BitRegister.toString == "r8")
}
val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8)
val scratch32BitRegister = Register(D32, R8)
test("scratch 32-bit register toString") {
assert(scratch32BitRegister.toString == "e8")
assert(scratch32BitRegister.toString == "r8d")
}
val memLocationWithHex = MemLocation(0x12345678)
test("mem location with hex toString") {
assert(memLocationWithHex.toString == "[0x12345678]")
}
val memLocationWithRegister = MemLocation(named64BitRegister)
val memLocationWithRegister = MemLocation(named64BitRegister, Q64)
test("mem location with register toString") {
assert(memLocationWithRegister.toString == "[rax]")
assert(memLocationWithRegister.toString == "qword ptr [rax]")
}
val immediateVal = ImmediateVal(123)