Compare commits

..

No commits in common. "82a3d6068bce5ad794b9f2a2a5405f470860568c" and "c748a34e4ce4514403d659dcf801850ad1dd5bde" have entirely different histories.

10 changed files with 262 additions and 442 deletions

View File

@ -1,7 +1,6 @@
package wacc package wacc
import scala.collection.mutable import scala.collection.mutable
import cats.data.Chain
import parsley.{Failure, Success} import parsley.{Failure, Success}
import scopt.OParser import scopt.OParser
import java.io.File import java.io.File
@ -64,7 +63,7 @@ def frontend(
} }
val s = "enter an integer to echo" val s = "enter an integer to echo"
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = def backend(typedProg: microWacc.Program): List[asm.AsmLine] =
asmGenerator.generateAsm(typedProg) asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using def compile(filename: String, outFile: Option[File] = None)(using

View File

@ -1,90 +0,0 @@
package wacc
import scala.collection.mutable.LinkedHashMap
import cats.data.Chain
/** Compile-time model of the runtime stack used while emitting assembly.
  *
  * Every tracked entry occupies exactly one `elemBytes`-wide slot, regardless of the logical size
  * of the value stored in it. All offset arithmetic (`sizeBytes`, `StackValue.bottom`,
  * `accessVar`, `drop`) relies on that invariant, so every method that moves RSP must move it by a
  * whole number of slots.
  *
  * Entries are keyed either by the expression/identifier they hold, or by the slot index at
  * insertion time for anonymous temporaries.
  */
class Stack {
  import assemblyIR._
  import assemblyIR.Size._
  import sizeExtensions.size
  import microWacc as mw

  private val RSP = Register(Q64, RegName.SP)

  /** A tracked slot.
    *
    * @param size
    *   The operand size of the value kept in the slot.
    * @param offset
    *   Byte distance from the base of the tracked stack to the start of this slot.
    */
  private class StackValue(val size: Size, val offset: Int) {

    /** Byte distance from the base of the tracked stack to the end of this slot. */
    def bottom: Int = offset + elemBytes
  }

  private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()

  /** Width in bytes of a single stack slot. */
  private val elemBytes: Int = Q64.toInt

  /** Total number of bytes occupied by all tracked slots. */
  private def sizeBytes: Int = stack.size * elemBytes

  /** The number of slots currently tracked (a count of entries, not a byte size). */
  def size: Int = stack.size

  /** Push an expression onto the stack.
    *
    * @param expr
    *   The expression used as the key for the new slot.
    * @param src
    *   The register whose value is pushed; its size is recorded for the slot.
    * @return
    *   The `Push` instruction to emit.
    */
  def push(expr: mw.Expr, src: Register): AsmLine = {
    stack += expr -> StackValue(src.size, sizeBytes)
    Push(src)
  }

  /** Push an anonymous value onto the stack.
    *
    * @param itemSize
    *   The operand size recorded for the new slot.
    * @param addr
    *   The source operand to push.
    * @return
    *   The `Push` instruction to emit.
    */
  def push(itemSize: Size, addr: Src): AsmLine = {
    stack += stack.size -> StackValue(itemSize, sizeBytes)
    Push(addr)
  }

  /** Reserve one slot for a variable on the stack, keyed by its identifier. */
  def reserve(ident: mw.Ident): AsmLine = {
    stack += ident -> StackValue(ident.ty.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve one slot for a register on the stack.
    *
    * RSP is moved by a full slot (`elemBytes`) rather than by the register's own width, because
    * every tracked entry is assumed to occupy one full slot when offsets are computed.
    */
  def reserve(src: Register): AsmLine = {
    stack += stack.size -> StackValue(src.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve space for values on the stack, one slot per value.
    *
    * @param sizes
    *   The sizes of the values to reserve space for.
    */
  def reserve(sizes: List[Size]): AsmLine = {
    sizes.foreach { itemSize =>
      stack += stack.size -> StackValue(itemSize, sizeBytes)
    }
    Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
  }

  /** Pop a value from the stack into a register. Sizes MUST match.
    *
    * Fails if no slot is currently tracked, since the most recent entry is looked up first.
    */
  def pop(dest: Register): AsmLine = {
    stack.remove(stack.last._1)
    Pop(dest)
  }

  /** Drop the top n values from the stack.
    *
    * Fails if fewer than `n` slots are currently tracked.
    */
  def drop(n: Int = 1): AsmLine = {
    (1 to n).foreach { _ =>
      stack.remove(stack.last._1)
    }
    Add(RSP, ImmediateVal(n * elemBytes))
  }

  /** Generate AsmLines within a scope, which is reset after the block.
    *
    * Any slots the block leaves behind are dropped, and the matching RSP adjustment is appended
    * to the block's output (also when there is nothing to drop).
    */
  def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
    val resetToSize = stack.size
    var lines = block()
    lines :+= drop(stack.size - resetToSize)
    lines
  }

  /** Get an IndexAddress for a variable in the stack.
    *
    * The offset is measured from the current RSP: the most recently added slot is at offset 0 and
    * each older slot is `elemBytes` further away.
    */
  def accessVar(ident: mw.Ident): IndexAddress =
    IndexAddress(RSP, sizeBytes - stack(ident).bottom)

  /** Whether a slot keyed by `ident` is currently tracked. */
  def contains(ident: mw.Ident): Boolean = stack.contains(ident)

  /** Memory operand at RSP, using the size recorded for the most recently added slot. */
  def head: MemLocation = MemLocation(RSP, stack.last._2.size)

  override def toString(): String = stack.toString
}

View File

@ -1,16 +1,15 @@
package wacc package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
// import parsley.token.errors.Label
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
import assemblyIR._ import assemblyIR._
import assemblyIR.Size._ import wacc.types._
import assemblyIR.RegName._
import types._
import sizeExtensions._
import lexer.escapedChars import lexer.escapedChars
abstract case class Error() { abstract case class Error() {
@ -31,20 +30,26 @@ object asmGenerator {
def errLabel = ".L._errDivZero" def errLabel = ".L._errDivZero"
} }
private val RAX = Register(Q64, AX) val RAX = Register(RegSize.R64, RegName.AX)
private val EAX = Register(D32, AX) val EAX = Register(RegSize.E32, RegName.AX)
private val RDI = Register(Q64, DI) val ESP = Register(RegSize.E32, RegName.SP)
private val RIP = Register(Q64, IP) val EDX = Register(RegSize.E32, RegName.DX)
private val RBP = Register(Q64, BP) val RDI = Register(RegSize.R64, RegName.DI)
private val RSI = Register(Q64, SI) val RIP = Register(RegSize.R64, RegName.IP)
private val RDX = Register(Q64, DX) val RBP = Register(RegSize.R64, RegName.BP)
private val RCX = Register(Q64, CX) val RSI = Register(RegSize.R64, RegName.SI)
private val argRegs = List(DI, SI, DX, CX, R8, R9) val RDX = Register(RegSize.R64, RegName.DX)
val RCX = Register(RegSize.R64, RegName.CX)
val R8 = Register(RegSize.R64, RegName.Reg8)
val R9 = Register(RegSize.R64, RegName.Reg9)
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
extension [T](chain: Chain[T]) val _8_BIT_MASK = 0xff
def +(item: T): Chain[T] = chain.append(item)
def concatAll(chains: Chain[T]*): Chain[T] = extension (chain: Chain[AsmLine])
def +(line: AsmLine): Chain[AsmLine] = chain.append(line)
def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
chains.foldLeft(chain)(_ ++ _) chains.foldLeft(chain)(_ ++ _)
class LabelGenerator { class LabelGenerator {
@ -59,7 +64,7 @@ object asmGenerator {
} }
} }
def generateAsm(microProg: Program): Chain[AsmLine] = { def generateAsm(microProg: Program): List[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]() given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
@ -67,8 +72,9 @@ object asmGenerator {
val progAsm = Chain(LabelDef("main")).concatAll( val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(), funcPrologue(),
Chain.one(stack.align()),
main.foldMap(generateStmt(_)), main.foldMap(generateStmt(_)),
Chain.one(Xor(RAX, RAX)), Chain.one(Move(RAX, ImmediateVal(0))),
funcEpilogue(), funcEpilogue(),
generateBuiltInFuncs(), generateBuiltInFuncs(),
funcs.foldMap(generateUserFunc(_)) funcs.foldMap(generateUserFunc(_))
@ -90,84 +96,78 @@ object asmGenerator {
strDirs, strDirs,
Chain.one(Directive.Text), Chain.one(Directive.Text),
progAsm progAsm
) ).toList
} }
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = { def wrapFunc(labelName: String, funcBody: Chain[AsmLine])(using
var chain = Chain.one[AsmLine](LabelDef(labelName)) stack: Stack,
strings: ListBuffer[String]
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += LabelDef(labelName)
chain ++= funcPrologue() chain ++= funcPrologue()
chain ++= funcBody chain ++= funcBody
chain ++= funcEpilogue() chain ++= funcEpilogue()
chain chain
} }
private def generateUserFunc(func: FuncDecl)(using def generateUserFunc(func: FuncDecl)(using
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
// Setup the stack with param 7 and up // Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_)) func.params.drop(argRegs.size).foreach(stack.reserve(_))
var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) var chain = Chain.empty[AsmLine]
chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity // Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) => argRegs.zip(func.params).foreach { (reg, param) =>
chain += stack.push(param, Register(Q64, reg)) chain += stack.push(param, reg)
} }
chain ++= func.body.foldMap(generateStmt(_)) chain ++= func.body.foldMap(generateStmt(_))
// No need for epilogue here since all user functions must return explicitly wrapFunc(labelGenerator.getLabel(func.name), chain)
chain
} }
private def generateBuiltInFuncs()(using def generateBuiltInFuncs()(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain ++= wrapBuiltinFunc( chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Exit), labelGenerator.getLabel(Builtin.Exit),
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit)) Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
) )
chain ++= wrapBuiltinFunc( chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Printf), labelGenerator.getLabel(Builtin.Printf),
Chain( Chain(
stackAlign, stack.align(),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI), Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
) )
) )
chain ++= wrapBuiltinFunc( chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.PrintCharArray),
Chain(
stackAlign,
Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
Move(Register(D32, SI), MemLocation(RSI, D32)),
assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush)
)
)
chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc), labelGenerator.getLabel(Builtin.Malloc),
Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc)) Chain.one(stack.align())
// Out of memory check is optional
) )
chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) chain ++= wrapFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
chain ++= wrapBuiltinFunc( chain ++= wrapFunc(
labelGenerator.getLabel(Builtin.Read), labelGenerator.getLabel(Builtin.Read),
Chain( Chain(
stackAlign, stack.align(),
Subtract(Register(Q64, SP), ImmediateVal(8)), stack.reserve(),
Push(RSI), stack.push(RSI),
Load(RSI, MemLocation(Register(Q64, SP), Q64)), Load(RSI, stack.head),
assemblyIR.Call(CLibFunc.Scanf), assemblyIR.Call(CLibFunc.Scanf),
Pop(RAX) stack.pop(RAX),
stack.drop()
) )
) )
@ -175,7 +175,7 @@ object asmGenerator {
// TODO can this be done with a call to generateStmt? // TODO can this be done with a call to generateStmt?
// Consider other error cases -> look to generalise // Consider other error cases -> look to generalise
LabelDef(zeroDivError.errLabel), LabelDef(zeroDivError.errLabel),
stackAlign, stack.align(),
Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)), Move(RDI, ImmediateVal(-1)),
@ -185,35 +185,29 @@ object asmGenerator {
chain chain
} }
private def generateStmt(stmt: Stmt)(using def generateStmt(stmt: Stmt)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match { stmt match {
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
lhs match { lhs match {
case ident: Ident => case ident: Ident =>
dest = stack.accessVar(ident)
if (!stack.contains(ident)) chain += stack.reserve(ident) if (!stack.contains(ident)) chain += stack.reserve(ident)
chain ++= evalExprOntoStack(rhs) // TODO lhs = arrayElem
chain += stack.pop(RAX) case _ =>
chain += Move(stack.accessVar(ident), RAX)
case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(rhs)
chain ++= evalExprOntoStack(i)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
chain += stack.pop(RCX)
chain += stack.pop(RDX)
chain += Move(
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
Register(x.ty.elemSize, DX)
)
} }
chain ++= evalExprOntoStack(rhs)
chain += stack.pop(RAX)
chain += Move(dest(), RAX)
case If(cond, thenBranch, elseBranch) => case If(cond, thenBranch, elseBranch) =>
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
@ -223,11 +217,11 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0)) chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(elseLabel), Cond.Equal) chain += Jump(LabelArg(elseLabel), Cond.Equal)
chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt)) chain ++= thenBranch.foldMap(generateStmt)
chain += Jump(LabelArg(endLabel)) chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel) chain += LabelDef(elseLabel)
chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt)) chain ++= elseBranch.foldMap(generateStmt)
chain += LabelDef(endLabel) chain += LabelDef(endLabel)
case While(cond, body) => case While(cond, body) =>
@ -240,147 +234,106 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0)) chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(endLabel), Cond.Equal) chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= stack.withScope(() => body.foldMap(generateStmt)) chain ++= body.foldMap(generateStmt)
chain += Jump(LabelArg(startLabel)) chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel) chain += LabelDef(endLabel)
case call: microWacc.Call =>
chain ++= generateCall(call, isTail = false)
case microWacc.Return(expr) => case microWacc.Return(expr) =>
expr match {
case call: microWacc.Call =>
chain ++= generateCall(call, isTail = true) // tco
case _ =>
chain ++= evalExprOntoStack(expr) chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain ++= funcEpilogue() chain ++= funcEpilogue()
}
case call: microWacc.Call =>
chain ++= generateCall(call)
} }
chain chain
} }
private def evalExprOntoStack(expr: Expr)(using def evalExprOntoStack(expr: Expr)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
val stackSizeStart = stack.size
expr match {
case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v))
case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident))
case array @ ArrayLiter(elems) => expr match {
case IntLiter(v) => chain += stack.push(ImmediateVal(v))
case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(stack.accessVar(ident)())
case ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString strings += elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
chain += stack.push(Q64, RAX) chain += stack.push(RAX)
case ty => case _ => // Other array types TODO
chain ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
isTail = false
)
chain += stack.push(Q64, RAX)
// Store the length of the array at the start
chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
elems.zipWithIndex.foldMap { (elem, i) =>
chain ++= evalExprOntoStack(elem)
chain += stack.pop(RCX)
chain += stack.pop(RAX)
chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
chain += stack.push(Q64, RAX)
}
} }
case BoolLiter(true) => case BoolLiter(v) => chain += stack.push(ImmediateVal(if (v) 1 else 0))
chain += stack.push(KnownType.Bool.size, ImmediateVal(1)) case NullLiter() => chain += stack.push(ImmediateVal(0))
case BoolLiter(false) => case ArrayElem(_, _) => // TODO: Implement handling
chain += Xor(RAX, RAX)
chain += stack.push(KnownType.Bool.size, RAX)
case NullLiter() =>
chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX)
chain += stack.pop(RAX)
// + Int because we store the length of the array at the start
chain += Move(
Register(x.ty.elemSize, AX),
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
)
chain += stack.push(x.ty.elemSize, RAX)
case UnaryOp(x, op) => case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
op match { op match {
case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed
case UnaryOperator.Len => case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord))
chain += stack.pop(RAX)
chain += Move(EAX, MemLocation(RAX, D32))
chain += stack.push(D32, RAX)
case UnaryOperator.Negate =>
chain += Negate(stack.head)
case UnaryOperator.Not => case UnaryOperator.Not =>
chain += Xor(stack.head, ImmediateVal(1)) chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1))
} }
case BinaryOp(x, y, op) => case BinaryOp(x, y, op) =>
val destX = Register(x.ty.size, AX)
chain ++= evalExprOntoStack(y) chain ++= evalExprOntoStack(y)
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX) chain += stack.pop(RAX)
op match { op match {
case BinaryOperator.Add => case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX)
chain += Add(stack.head, destX)
case BinaryOperator.Sub => case BinaryOperator.Sub =>
chain += Subtract(destX, stack.head) chain += Subtract(EAX, stack.head(SizeDir.DWord))
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(RAX)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
chain += Multiply(destX, stack.head) chain += Multiply(EAX, stack.head(SizeDir.DWord))
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(RAX)
case BinaryOperator.Div => case BinaryOperator.Div =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0))
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head(SizeDir.DWord))
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(RAX)
case BinaryOperator.Mod => case BinaryOperator.Mod =>
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head(SizeDir.DWord))
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RDX) chain += stack.push(RDX)
case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal) case BinaryOperator.Eq => chain ++= generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual) case BinaryOperator.Neq => chain ++= generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater) case BinaryOperator.Greater => chain ++= generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual) case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less) case BinaryOperator.Less => chain ++= generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual) case BinaryOperator.LessEq => chain ++= generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And => chain += And(stack.head, destX) case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX)
case BinaryOperator.Or => chain += Or(stack.head, destX) case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX)
} }
case call: microWacc.Call => case call: microWacc.Call =>
chain ++= generateCall(call, isTail = false) chain ++= generateCall(call)
chain += stack.push(call.ty.size, RAX) chain += stack.push(RAX)
} }
assert(stack.size == stackSizeStart + 1) if chain.isEmpty then chain += stack.push(ImmediateVal(0))
chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
chain chain
} }
private def generateCall(call: microWacc.Call, isTail: Boolean)(using def generateCall(call: microWacc.Call)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
@ -390,19 +343,14 @@ object asmGenerator {
argRegs.zip(args).foldMap { (reg, expr) => argRegs.zip(args).foldMap { (reg, expr) =>
chain ++= evalExprOntoStack(expr) chain ++= evalExprOntoStack(expr)
chain += stack.pop(Register(Q64, reg)) chain += stack.pop(reg)
} }
args.drop(argRegs.size).foldMap { args.drop(argRegs.size).foldMap {
chain ++= evalExprOntoStack(_) chain ++= evalExprOntoStack(_)
} }
// Tail Call Optimisation (TCO) chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))
if (isTail) {
chain += Jump(LabelArg(labelGenerator.getLabel(target))) // tail call
} else {
chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target))) // regular call
}
if (args.size > argRegs.size) { if (args.size > argRegs.size) {
chain += stack.drop(args.size - argRegs.size) chain += stack.drop(args.size - argRegs.size)
@ -411,39 +359,79 @@ object asmGenerator {
chain chain
} }
private def generateComparison(destX: Register, cond: Cond)(using def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: Stack stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Compare(destX, stack.head) chain ++= evalExprOntoStack(x)
chain += Set(Register(B8, AX), cond) chain ++= evalExprOntoStack(y)
chain ++= zeroRest(RAX, B8) chain += stack.pop(RAX)
chain += Compare(stack.head(SizeDir.DWord), EAX)
chain += Set(Register(RegSize.Byte, RegName.AL), cond)
chain += And(RAX, ImmediateVal(_8_BIT_MASK))
chain += stack.drop() chain += stack.drop()
chain += stack.push(B8, RAX) chain += stack.push(RAX)
chain chain
} }
private def funcPrologue(): Chain[AsmLine] = { // Missing a sub instruction but dont think we need it
def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Push(RBP) chain += stack.push(RBP)
chain += Move(RBP, Register(Q64, SP)) chain += Move(RBP, Register(RegSize.R64, RegName.SP))
chain chain
} }
private def funcEpilogue(): Chain[AsmLine] = { def funcEpilogue()(using stack: Stack): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Move(Register(Q64, SP), RBP) chain += Move(Register(RegSize.R64, RegName.SP), RBP)
chain += Pop(RBP) chain += Pop(RBP)
chain += assemblyIR.Return() chain += assemblyIR.Return()
chain chain
} }
private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) class Stack {
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { private val stack = LinkedHashMap[Expr | Int, Int]()
case Q64 | D32 => Chain.empty private val RSP = Register(RegSize.R64, RegName.SP)
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
private def next: Int = stack.size + 1
def push(expr: Expr, src: Src): AsmLine = {
stack += expr -> next
Push(src)
}
def push(src: Src): AsmLine = {
stack += stack.size -> next
Push(src)
}
def pop(dest: Src): AsmLine = {
stack.remove(stack.last._1)
Pop(dest)
}
def reserve(ident: Ident): AsmLine = {
stack += ident -> next
Subtract(RSP, ImmediateVal(8))
}
def reserve(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack += stack.size -> next)
Subtract(RSP, ImmediateVal(n * 8))
}
def drop(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack.remove(stack.last._1))
Add(RSP, ImmediateVal(n * 8))
}
def accessVar(ident: Ident): () => IndexAddress = () => {
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
}
def head: MemLocation = MemLocation(RSP)
def head(size: SizeDir): MemLocation = MemLocation(RSP, size)
def contains(ident: Ident): Boolean = stack.contains(ident)
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
} }
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

View File

@ -6,73 +6,40 @@ object assemblyIR {
sealed trait Operand sealed trait Operand
sealed trait Src extends Operand // mem location, register and imm value sealed trait Src extends Operand // mem location, register and imm value
sealed trait Dest extends Operand // mem location and register sealed trait Dest extends Operand // mem location and register
enum RegSize {
case R64
case E32
case Byte
enum Size { override def toString = this match {
case Q64, D32, W16, B8 case R64 => "r"
case E32 => "e"
def toInt: Int = this match { case Byte => ""
case Q64 => 8
case D32 => 4
case W16 => 2
case B8 => 1
}
private val ptr = "ptr "
override def toString(): String = this match {
case Q64 => "qword " + ptr
case D32 => "dword " + ptr
case W16 => "word " + ptr
case B8 => "byte " + ptr
} }
} }
enum RegName { enum RegName {
case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15 case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14,
} Reg15
override def toString = this match {
case class Register(size: Size, name: RegName) extends Dest with Src { case AX => "ax"
import RegName._ case AL => "al"
case BX => "bx"
if (size == Size.B8 && name == RegName.IP) { case CX => "cx"
throw new IllegalArgumentException("Cannot have 8 bit register for IP") case DX => "dx"
} case SI => "si"
override def toString = name match { case DI => "di"
case AX => tradToString("ax", "al") case SP => "sp"
case BX => tradToString("bx", "bl") case BP => "bp"
case CX => tradToString("cx", "cl") case IP => "ip"
case DX => tradToString("dx", "dl") case Reg8 => "8"
case SI => tradToString("si", "sil") case Reg9 => "9"
case DI => tradToString("di", "dil") case Reg10 => "10"
case SP => tradToString("sp", "spl") case Reg11 => "11"
case BP => tradToString("bp", "bpl") case Reg12 => "12"
case IP => tradToString("ip", "#INVALID") case Reg13 => "13"
case R8 => newToString(8) case Reg14 => "14"
case R9 => newToString(9) case Reg15 => "15"
case R10 => newToString(10)
case R11 => newToString(11)
case R12 => newToString(12)
case R13 => newToString(13)
case R14 => newToString(14)
case R15 => newToString(15)
}
private def tradToString(base: String, byteName: String): String =
size match {
case Size.Q64 => "r" + base
case Size.D32 => "e" + base
case Size.W16 => base
case Size.B8 => byteName
}
private def newToString(base: Int): String = {
val b = base.toString
"r" + (size match {
case Size.Q64 => b
case Size.D32 => b + "d"
case Size.W16 => b + "w"
case Size.B8 => b + "b"
})
} }
} }
@ -81,9 +48,7 @@ object assemblyIR {
case Scanf, case Scanf,
Fflush, Fflush,
Exit, Exit,
PrintF, PrintF
Malloc,
Free
private val plt = "@plt" private val plt = "@plt"
@ -92,29 +57,28 @@ object assemblyIR {
case Fflush => "fflush" + plt case Fflush => "fflush" + plt
case Exit => "exit" + plt case Exit => "exit" + plt
case PrintF => "printf" + plt case PrintF => "printf" + plt
case Malloc => "malloc" + plt
case Free => "free" + plt
} }
} }
case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src { // TODO register naming conventions are wrong
override def toString = case class Register(size: RegSize, name: RegName) extends Dest with Src {
opSize.toString + s"[$pointer]" override def toString = s"${size}${name}"
}
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified)
extends Dest
with Src {
override def toString = pointer match {
case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => opSize.toString + s"[$reg]"
}
} }
// TODO to string is wacky
case class IndexAddress( case class IndexAddress(
base: Register, base: Register,
offset: Int | LabelArg, offset: Int | LabelArg,
indexReg: Register = Register(Size.Q64, RegName.AX), opSize: SizeDir = SizeDir.Unspecified
scale: Int = 0
) extends Dest ) extends Dest
with Src { with Src {
override def toString = if (scale != 0) { override def toString = s"$opSize[$base + $offset]"
s"[$base + $indexReg * $scale + $offset]"
} else {
s"[$base + $offset]"
}
} }
case class ImmediateVal(value: Int) extends Src { case class ImmediateVal(value: Int) extends Src {
@ -162,11 +126,6 @@ object assemblyIR {
override def toString = s"$name:" override def toString = s"$name:"
} }
case class Comment(comment: String) extends AsmLine {
override def toString =
comment.split("\n").map(line => s"# ${line}").mkString("\n")
}
enum Cond { enum Cond {
case Equal, case Equal,
NotEqual, NotEqual,
@ -213,4 +172,17 @@ object assemblyIR {
case String => "%s" case String => "%s"
} }
} }
enum SizeDir {
case Byte, Word, DWord, Unspecified
private val ptr = "ptr "
override def toString(): String = this match {
case Byte => "byte " + ptr
case Word => "word " + ptr // TODO check word/doubleword/quadword
case DWord => "dword " + ptr
case Unspecified => ""
}
}
} }

View File

@ -1,35 +0,0 @@
package wacc
/** Size helpers for typed expressions and semantic types. */
object sizeExtensions {
  import microWacc._
  import types._
  import assemblyIR.Size

  extension (expr: Expr) {

    /** Number of heap bytes the expression needs.
      *
      * An array literal takes an Int-sized header followed by one element-sized cell per element;
      * any other expression takes the size of its own type.
      */
    def heapSize: Int = expr match {
      case ArrayLiter(elems) =>
        val headerBytes = KnownType.Int.size.toInt
        val cellBytes = expr.ty.elemSize.toInt
        headerBytes + elems.size * cellBytes
      case _ =>
        expr.ty.size.toInt
    }
  }

  extension (ty: SemType) {

    /** Operand size used when a value of this type is held in a register. */
    def size: Size = ty match {
      case KnownType.Bool | KnownType.Char => Size.B8
      case KnownType.Int => Size.D32
      case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
    }

    /** Size of one element of this type: the element type's size for arrays, a quad word for
      * pairs, and the type's own size for everything else.
      */
    def elemSize: Size = ty match {
      case KnownType.Pair(_, _) => Size.Q64
      case KnownType.Array(inner) => inner.size
      case _ => ty.size
    }
  }
}

View File

@ -1,12 +1,11 @@
package wacc package wacc
import java.io.PrintStream import java.io.PrintStream
import cats.data.Chain
object writer { object writer {
import assemblyIR._ import assemblyIR._
def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = {
asmList.iterator.foreach(printStream.println) asmList.foreach(printStream.println)
} }
} }

View File

@ -74,7 +74,6 @@ object microWacc {
object Exit extends Builtin("exit")(?) object Exit extends Builtin("exit")(?)
object Free extends Builtin("free")(?) object Free extends Builtin("free")(?)
object Malloc extends Builtin("malloc")(?) object Malloc extends Builtin("malloc")(?)
object PrintCharArray extends Builtin("printCharArray")(?)
} }
case class Assign(lhs: LValue, rhs: Expr) extends Stmt case class Assign(lhs: LValue, rhs: Expr) extends Stmt

View File

@ -216,37 +216,21 @@ object typeChecker {
case ast.Print(expr, newline) => case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained) val exprTyped = checkValue(expr, Constraint.Unconstrained)
val exprFormat = exprTyped.ty match { val format = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s" case KnownType.Bool | KnownType.String => "%s"
case KnownType.Array(KnownType.Char) => "%.*s"
case KnownType.Char => "%c" case KnownType.Char => "%c"
case KnownType.Int => "%d" case KnownType.Int => "%d"
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" case _ => "%p"
} }
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List( List(
microWacc.Call( microWacc.Call(
func, microWacc.Builtin.Printf,
List( List(
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, s"$format${if newline then "\n" else ""}".toMicroWaccCharArray,
value exprTyped
) )
) )
) )
}
exprTyped.ty match {
case KnownType.Bool =>
List(
microWacc.If(
exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
)
)
case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
}
case ast.If(cond, thenStmt, elseStmt) => case ast.If(cond, thenStmt, elseStmt) =>
List( List(
microWacc.If( microWacc.If(

View File

@ -73,7 +73,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
) )
assert(process.exitValue == expectedExit) assert(process.exitValue == expectedExit)
assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput) assert(stdout.toString == expectedOutput)
} }
} }
@ -86,23 +86,23 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
// format: off // format: off
// disable formatting to avoid binPack // disable formatting to avoid binPack
"^.*wacc-examples/valid/advanced.*$", "^.*wacc-examples/valid/advanced.*$",
// "^.*wacc-examples/valid/array.*$", "^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/basic/exit.*$", // "^.*wacc-examples/valid/basic/exit.*$",
// "^.*wacc-examples/valid/basic/skip.*$", // "^.*wacc-examples/valid/basic/skip.*$",
// "^.*wacc-examples/valid/expressions.*$", "^.*wacc-examples/valid/expressions.*$",
"^.*wacc-examples/valid/function/nested_functions.*$", "^.*wacc-examples/valid/function/nested_functions.*$",
"^.*wacc-examples/valid/function/simple_functions.*$", "^.*wacc-examples/valid/function/simple_functions.*$",
// "^.*wacc-examples/valid/if.*$", // "^.*wacc-examples/valid/if.*$",
// "^.*wacc-examples/valid/IO/print.*$", "^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/read.*$", // "^.*wacc-examples/valid/IO/read.*$",
"^.*wacc-examples/valid/IO/IOLoop.wacc.*$", "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
// "^.*wacc-examples/valid/pairs.*$", "^.*wacc-examples/valid/pairs.*$",
"^.*wacc-examples/valid/runtimeErr.*$", "^.*wacc-examples/valid/runtimeErr.*$",
// "^.*wacc-examples/valid/scope.*$", "^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$", // "^.*wacc-examples/valid/variables.*$",
// "^.*wacc-examples/valid/while.*$", "^.*wacc-examples/valid/while.*$",
// format: on // format: on
).find(filename.matches).isDefined ).find(filename.matches).isDefined
} }

View File

@ -1,38 +1,42 @@
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
import wacc.assemblyIR._ import wacc.assemblyIR._
import wacc.assemblyIR.Size._
import wacc.assemblyIR.RegName._
class instructionSpec extends AnyFunSuite { class instructionSpec extends AnyFunSuite {
val named64BitRegister = Register(Q64, AX) val named64BitRegister = Register(RegSize.R64, RegName.AX)
test("named 64-bit register toString") { test("named 64-bit register toString") {
assert(named64BitRegister.toString == "rax") assert(named64BitRegister.toString == "rax")
} }
val named32BitRegister = Register(D32, AX) val named32BitRegister = Register(RegSize.E32, RegName.AX)
test("named 32-bit register toString") { test("named 32-bit register toString") {
assert(named32BitRegister.toString == "eax") assert(named32BitRegister.toString == "eax")
} }
val scratch64BitRegister = Register(Q64, R8) val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8)
test("scratch 64-bit register toString") { test("scratch 64-bit register toString") {
assert(scratch64BitRegister.toString == "r8") assert(scratch64BitRegister.toString == "r8")
} }
val scratch32BitRegister = Register(D32, R8) val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8)
test("scratch 32-bit register toString") { test("scratch 32-bit register toString") {
assert(scratch32BitRegister.toString == "r8d") assert(scratch32BitRegister.toString == "e8")
} }
val memLocationWithRegister = MemLocation(named64BitRegister, Q64) val memLocationWithHex = MemLocation(0x12345678)
test("mem location with hex toString") {
assert(memLocationWithHex.toString == "[0x12345678]")
}
val memLocationWithRegister = MemLocation(named64BitRegister)
test("mem location with register toString") { test("mem location with register toString") {
assert(memLocationWithRegister.toString == "qword ptr [rax]") assert(memLocationWithRegister.toString == "[rax]")
} }
val immediateVal = ImmediateVal(123) val immediateVal = ImmediateVal(123)