fix: variable-sized values, heap-allocated arrays (and printCharArray)

This commit is contained in:
Gleb Koval 2025-02-27 14:48:24 +00:00
parent 691d989b92
commit 4ffe85be91
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
8 changed files with 185 additions and 238 deletions

View File

@ -1,6 +1,7 @@
package wacc package wacc
import scala.collection.mutable import scala.collection.mutable
import cats.data.Chain
import parsley.{Failure, Success} import parsley.{Failure, Success}
import scopt.OParser import scopt.OParser
import java.io.File import java.io.File
@ -63,7 +64,7 @@ def frontend(
} }
val s = "enter an integer to echo" val s = "enter an integer to echo"
def backend(typedProg: microWacc.Program): List[asm.AsmLine] = def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
asmGenerator.generateAsm(typedProg) asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using def compile(filename: String, outFile: Option[File] = None)(using

View File

@ -1,37 +1,48 @@
package wacc package wacc
import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.LinkedHashMap
import cats.data.Chain
class Stack { class Stack {
import assemblyIR._ import assemblyIR._
import assemblyIR.Size._
import sizeExtensions.size import sizeExtensions.size
import microWacc as mw import microWacc as mw
private val RSP = Register(Size.Q64, RegName.SP) private val RSP = Register(Q64, RegName.SP)
private class StackValue(val size: Size, val offset: Int) { private class StackValue(val size: Size, val offset: Int) {
def bottom: Int = offset + size.toInt def bottom: Int = offset + elemBytes
} }
private val stack = LinkedHashMap[mw.Expr | Int, StackValue]() private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
private val elemBytes: Int = Q64.toInt
private def sizeBytes: Int = stack.size * elemBytes
/** The stack's size in bytes. */ /** The stack's size in bytes. */
def size: Int = if stack.isEmpty then 0 else stack.last._2.bottom def size: Int = stack.size
/** Push an expression onto the stack. */ /** Push an expression onto the stack. */
def push(expr: mw.Expr, src: Register): AsmLine = { def push(expr: mw.Expr, src: Register): AsmLine = {
stack += expr -> StackValue(src.size, size) stack += expr -> StackValue(src.size, sizeBytes)
Push(src) Push(src)
} }
/** Push an arbitrary register onto the stack. */ /** Push a value onto the stack. */
def push(src: Register): AsmLine = { def push(itemSize: Size, addr: Src): AsmLine = {
stack += stack.size -> StackValue(src.size, size) stack += stack.size -> StackValue(itemSize, sizeBytes)
Push(src) Push(addr)
} }
/** Reserve space for a variable on the stack. */ /** Reserve space for a variable on the stack. */
def reserve(ident: mw.Ident): AsmLine = { def reserve(ident: mw.Ident): AsmLine = {
stack += ident -> StackValue(ident.ty.size, size) stack += ident -> StackValue(ident.ty.size, sizeBytes)
Subtract(RSP, ImmediateVal(ident.ty.size.toInt)) Subtract(RSP, ImmediateVal(elemBytes))
}
/** Reserve space for a register on the stack. */
def reserve(src: Register): AsmLine = {
stack += stack.size -> StackValue(src.size, sizeBytes)
Subtract(RSP, ImmediateVal(src.size.toInt))
} }
/** Reserve space for values on the stack. /** Reserve space for values on the stack.
@ -40,45 +51,40 @@ class Stack {
* The sizes of the values to reserve space for. * The sizes of the values to reserve space for.
*/ */
def reserve(sizes: List[Size]): AsmLine = { def reserve(sizes: List[Size]): AsmLine = {
val totalSize = sizes sizes.foreach { itemSize =>
.map(itemSize => stack += stack.size -> StackValue(itemSize, sizeBytes)
stack += stack.size -> StackValue(itemSize, size) }
itemSize.toInt Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
)
.sum
Subtract(RSP, ImmediateVal(totalSize))
} }
/** Pop a value from the stack into a register. Sizes MUST match. */ /** Pop a value from the stack into a register. Sizes MUST match. */
def pop(dest: Register): AsmLine = { def pop(dest: Register): AsmLine = {
if (dest.size != stack.last._2.size) {
throw new IllegalArgumentException(
s"Cannot pop ${stack.last._2.size} bytes into $dest (${dest.size} bytes) register"
)
}
stack.remove(stack.last._1) stack.remove(stack.last._1)
Pop(dest) Pop(dest)
} }
/** Drop the top n values from the stack. */ /** Drop the top n values from the stack. */
def drop(n: Int = 1): AsmLine = { def drop(n: Int = 1): AsmLine = {
val totalSize = (1 to n) (1 to n).foreach { _ =>
.map(_ => stack.remove(stack.last._1)
val itemSize = stack.last._2.size.toInt }
stack.remove(stack.last._1) Add(RSP, ImmediateVal(n * elemBytes))
itemSize
)
.sum
Add(RSP, ImmediateVal(totalSize))
} }
/** Get a lazy IndexAddress for a variable in the stack. */ /** Generate AsmLines within a scope, which is reset after the block. */
def accessVar(ident: mw.Ident): () => IndexAddress = () => { def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
IndexAddress(RSP, stack.size - stack(ident).bottom) val resetToSize = stack.size
var lines = block()
lines :+= drop(stack.size - resetToSize)
lines
} }
/** Get an IndexAddress for a variable in the stack. */
def accessVar(ident: mw.Ident): IndexAddress =
IndexAddress(RSP, sizeBytes - stack(ident).bottom)
def contains(ident: mw.Ident): Boolean = stack.contains(ident) def contains(ident: mw.Ident): Boolean = stack.contains(ident)
def head: MemLocation = MemLocation(RSP) def head: MemLocation = MemLocation(RSP, stack.last._2.size)
def head(offset: Size): MemLocation = MemLocation(RSP, Some(offset))
// TODO: Might want to actually properly handle this with the LinkedHashMap too override def toString(): String = stack.toString
def align(): AsmLine = And(RSP, ImmediateVal(-16))
} }

View File

@ -1,6 +1,5 @@
package wacc package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
@ -8,7 +7,10 @@ import cats.syntax.foldable._
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
import assemblyIR._ import assemblyIR._
import wacc.types._ import assemblyIR.Size._
import assemblyIR.RegName._
import types._
import sizeExtensions._
import lexer.escapedChars import lexer.escapedChars
abstract case class Error() { abstract case class Error() {
@ -29,26 +31,22 @@ object asmGenerator {
def errLabel = ".L._errDivZero" def errLabel = ".L._errDivZero"
} }
val RAX = Register(RegSize.R64, RegName.AX) private val RAX = Register(Q64, AX)
val EAX = Register(RegSize.E32, RegName.AX) private val EAX = Register(D32, AX)
val ESP = Register(RegSize.E32, RegName.SP) private val RDI = Register(Q64, DI)
val EDX = Register(RegSize.E32, RegName.DX) private val RIP = Register(Q64, IP)
val RDI = Register(RegSize.R64, RegName.DI) private val RBP = Register(Q64, BP)
val RIP = Register(RegSize.R64, RegName.IP) private val RSI = Register(Q64, SI)
val RBP = Register(RegSize.R64, RegName.BP) private val RDX = Register(Q64, DX)
val RSI = Register(RegSize.R64, RegName.SI) private val RCX = Register(Q64, CX)
val RDX = Register(RegSize.R64, RegName.DX) private val argRegs = List(DI, SI, DX, CX, R8, R9)
val RCX = Register(RegSize.R64, RegName.CX)
val R8 = Register(RegSize.R64, RegName.Reg8)
val R9 = Register(RegSize.R64, RegName.Reg9)
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val _8_BIT_MASK = 0xff private val _8_BIT_MASK = 0xff
extension (chain: Chain[AsmLine]) extension [T](chain: Chain[T])
def +(line: AsmLine): Chain[AsmLine] = chain.append(line) def +(item: T): Chain[T] = chain.append(item)
def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] = def concatAll(chains: Chain[T]*): Chain[T] =
chains.foldLeft(chain)(_ ++ _) chains.foldLeft(chain)(_ ++ _)
class LabelGenerator { class LabelGenerator {
@ -63,7 +61,7 @@ object asmGenerator {
} }
} }
def generateAsm(microProg: Program): List[AsmLine] = { def generateAsm(microProg: Program): Chain[AsmLine] = {
given stack: Stack = Stack() given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]() given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator() given labelGenerator: LabelGenerator = LabelGenerator()
@ -71,7 +69,6 @@ object asmGenerator {
val progAsm = Chain(LabelDef("main")).concatAll( val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(), funcPrologue(),
Chain.one(stack.align()),
main.foldMap(generateStmt(_)), main.foldMap(generateStmt(_)),
Chain.one(Xor(RAX, RAX)), Chain.one(Xor(RAX, RAX)),
funcEpilogue(), funcEpilogue(),
@ -95,12 +92,10 @@ object asmGenerator {
strDirs, strDirs,
Chain.one(Directive.Text), Chain.one(Directive.Text),
progAsm progAsm
).toList )
} }
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine]): Chain[AsmLine] = {
stack: Stack
): Chain[AsmLine] = {
var chain = Chain.one[AsmLine](LabelDef(labelName)) var chain = Chain.one[AsmLine](LabelDef(labelName))
chain ++= funcPrologue() chain ++= funcPrologue()
chain ++= funcBody chain ++= funcBody
@ -108,7 +103,7 @@ object asmGenerator {
chain chain
} }
def generateUserFunc(func: FuncDecl)(using private def generateUserFunc(func: FuncDecl)(using
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
@ -119,29 +114,27 @@ object asmGenerator {
chain ++= funcPrologue() chain ++= funcPrologue()
// Push the rest of params onto the stack for simplicity // Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) => argRegs.zip(func.params).foreach { (reg, param) =>
chain += stack.push(param, reg) chain += stack.push(param, Register(Q64, reg))
} }
chain ++= func.body.foldMap(generateStmt(_)) chain ++= func.body.foldMap(generateStmt(_))
// No need for epilogue here since all user functions must return explicitly // No need for epilogue here since all user functions must return explicitly
chain chain
} }
def generateBuiltInFuncs()(using private def generateBuiltInFuncs()(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Exit), labelGenerator.getLabel(Builtin.Exit),
Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit)) Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
) )
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Printf), labelGenerator.getLabel(Builtin.Printf),
Chain( Chain(
stack.align(), stackAlign,
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI), Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
@ -151,9 +144,9 @@ object asmGenerator {
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.PrintCharArray), labelGenerator.getLabel(Builtin.PrintCharArray),
Chain( Chain(
stack.align(), stackAlign,
Load(RDX, IndexAddress(RSI, 8)), Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
Move(RSI, MemLocation(RSI)), Move(Register(D32, SI), MemLocation(RSI, D32)),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Xor(RDI, RDI), Xor(RDI, RDI),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
@ -162,7 +155,7 @@ object asmGenerator {
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Malloc), labelGenerator.getLabel(Builtin.Malloc),
Chain(stack.align(), assemblyIR.Call(CLibFunc.Malloc)) Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
// Out of memory check is optional // Out of memory check is optional
) )
@ -171,13 +164,12 @@ object asmGenerator {
chain ++= wrapBuiltinFunc( chain ++= wrapBuiltinFunc(
labelGenerator.getLabel(Builtin.Read), labelGenerator.getLabel(Builtin.Read),
Chain( Chain(
stack.align(), stackAlign,
stack.reserve(), Subtract(Register(Q64, SP), ImmediateVal(8)),
stack.push(RSI), Push(RSI),
Load(RSI, stack.head), Load(RSI, MemLocation(Register(Q64, SP), Q64)),
assemblyIR.Call(CLibFunc.Scanf), assemblyIR.Call(CLibFunc.Scanf),
stack.pop(RAX), Pop(RAX)
stack.drop()
) )
) )
@ -185,7 +177,7 @@ object asmGenerator {
// TODO can this be done with a call to generateStmt? // TODO can this be done with a call to generateStmt?
// Consider other error cases -> look to generalise // Consider other error cases -> look to generalise
LabelDef(zeroDivError.errLabel), LabelDef(zeroDivError.errLabel),
stack.align(), stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)), Move(RDI, ImmediateVal(-1)),
@ -195,53 +187,33 @@ object asmGenerator {
chain chain
} }
/** Wraps a chain in a stack reset. private def generateStmt(stmt: Stmt)(using
*
* This is useful for ensuring that the stack size at the death of scope is the same as the stack
* size at the start of the scope. See branching (If / While)
*
* @param genChain
* Function that generates the scope AsmLines
* @param stack
* The stack to reset
* @return
* The generated scope AsmLines
*/
private def generateScope(genChain: () => Chain[AsmLine])(using
stack: Stack
): Chain[AsmLine] = {
val stackSizeStart = stack.size
var chain = genChain()
chain += stack.drop(stack.size - stackSizeStart)
chain
}
def generateStmt(stmt: Stmt)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match { stmt match {
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
lhs match { lhs match {
case ident: Ident => case ident: Ident =>
val dest = stack.accessVar(ident)
if (!stack.contains(ident)) chain += stack.reserve(ident) if (!stack.contains(ident)) chain += stack.reserve(ident)
chain ++= evalExprOntoStack(rhs) chain ++= evalExprOntoStack(rhs)
chain += stack.pop(RDX) chain += stack.pop(RAX)
chain += Move(dest(), RDX) chain += Move(stack.accessVar(ident), RAX)
case ArrayElem(x, i) => case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(i)
chain ++= evalExprOntoStack(rhs) chain ++= evalExprOntoStack(rhs)
chain ++= evalExprOntoStack(i)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += stack.pop(RDX) chain += stack.pop(RDX)
chain += Move(IndexAddress(RDX, 8, RCX, 8), RAX) chain += Move(
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
Register(x.ty.elemSize, DX)
)
} }
case If(cond, thenBranch, elseBranch) => case If(cond, thenBranch, elseBranch) =>
@ -253,11 +225,11 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0)) chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(elseLabel), Cond.Equal) chain += Jump(LabelArg(elseLabel), Cond.Equal)
chain ++= generateScope(() => thenBranch.foldMap(generateStmt)) chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
chain += Jump(LabelArg(endLabel)) chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel) chain += LabelDef(elseLabel)
chain ++= generateScope(() => elseBranch.foldMap(generateStmt)) chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
chain += LabelDef(endLabel) chain += LabelDef(endLabel)
case While(cond, body) => case While(cond, body) =>
@ -270,7 +242,7 @@ object asmGenerator {
chain += Compare(RAX, ImmediateVal(0)) chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(endLabel), Cond.Equal) chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= generateScope(() => body.foldMap(generateStmt)) chain ++= stack.withScope(() => body.foldMap(generateStmt))
chain += Jump(LabelArg(startLabel)) chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel) chain += LabelDef(endLabel)
@ -291,7 +263,7 @@ object asmGenerator {
chain chain
} }
def evalExprOntoStack(expr: Expr)(using private def evalExprOntoStack(expr: Expr)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
@ -299,111 +271,117 @@ object asmGenerator {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
val stackSizeStart = stack.size val stackSizeStart = stack.size
expr match { expr match {
case IntLiter(v) => chain += stack.push(ImmediateVal(v)) case IntLiter(v) => chain += stack.push(KnownType.Int.size, ImmediateVal(v))
case CharLiter(v) => chain += stack.push(ImmediateVal(v.toInt)) case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
case ident: Ident => chain += stack.push(stack.accessVar(ident)()) case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident))
case ArrayLiter(elems) => case array @ ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
strings += elems.collect { case CharLiter(v) => v }.mkString strings += elems.collect { case CharLiter(v) => v }.mkString
chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}"))) chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
chain += stack.push(RAX) chain += stack.push(Q64, RAX)
case _ => case ty =>
chain ++= generateCall( chain ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter((elems.size + 1) * 8))), microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
isTail = false isTail = false
) )
chain += stack.push(RAX) chain += stack.push(Q64, RAX)
// Store the length of the array at the start // Store the length of the array at the start
chain += Move(MemLocation(RAX, SizeDir.DWord), ImmediateVal(elems.size)) chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
elems.zipWithIndex.foldMap { (elem, i) => elems.zipWithIndex.foldMap { (elem, i) =>
chain ++= evalExprOntoStack(elem) chain ++= evalExprOntoStack(elem)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += Move(IndexAddress(RAX, 8 * (i + 1)), RCX) chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
chain += stack.push(RAX) chain += stack.push(Q64, RAX)
} }
} }
case BoolLiter(true) => chain += stack.push(ImmediateVal(1)) case BoolLiter(true) =>
chain += stack.push(KnownType.Bool.size, ImmediateVal(1))
case BoolLiter(false) => case BoolLiter(false) =>
chain += Xor(RAX, RAX) chain += Xor(RAX, RAX)
chain += stack.push(RAX) chain += stack.push(KnownType.Bool.size, RAX)
case NullLiter() => chain += stack.push(ImmediateVal(0)) case NullLiter() =>
chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
case ArrayElem(x, i) => case ArrayElem(x, i) =>
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
chain ++= evalExprOntoStack(i) chain ++= evalExprOntoStack(i)
chain += stack.pop(RCX) chain += stack.pop(RCX)
chain += stack.pop(RAX) chain += stack.pop(RAX)
// + 1 because we store the length of the array at the start // + Int because we store the length of the array at the start
chain += stack.push(IndexAddress(RAX, 8, RCX, 8)) chain += Move(
Register(x.ty.elemSize, AX),
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
)
chain += stack.push(x.ty.elemSize, RAX)
case UnaryOp(x, op) => case UnaryOp(x, op) =>
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
op match { op match {
case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed
case UnaryOperator.Len => case UnaryOperator.Len =>
// Access the elem
chain += stack.pop(RAX) chain += stack.pop(RAX)
chain += Push(MemLocation(RAX)) chain += Move(EAX, MemLocation(RAX, D32))
case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord)) chain += stack.push(D32, RAX)
case UnaryOperator.Negate =>
chain += Negate(stack.head)
case UnaryOperator.Not => case UnaryOperator.Not =>
chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1)) chain += Xor(stack.head, ImmediateVal(1))
} }
case BinaryOp(x, y, op) => case BinaryOp(x, y, op) =>
val destX = Register(x.ty.size, AX)
chain ++= evalExprOntoStack(y) chain ++= evalExprOntoStack(y)
chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX) chain += stack.pop(RAX)
op match { op match {
case BinaryOperator.Add => chain += Add(stack.head(SizeDir.DWord), EAX) case BinaryOperator.Add =>
chain += Add(stack.head, destX)
case BinaryOperator.Sub => case BinaryOperator.Sub =>
chain += Subtract(EAX, stack.head(SizeDir.DWord)) chain += Subtract(destX, stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
chain += Multiply(EAX, stack.head(SizeDir.DWord)) chain += Multiply(destX, stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Div => case BinaryOperator.Div =>
chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head(SizeDir.DWord)) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mod => case BinaryOperator.Mod =>
chain += CDQ() chain += CDQ()
chain += Divide(stack.head(SizeDir.DWord)) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(RDX) chain += stack.push(destX.size, RDX)
case BinaryOperator.Eq => chain ++= generateComparison(Cond.Equal) case BinaryOperator.Eq => chain ++= generateComparison(destX, Cond.Equal)
case BinaryOperator.Neq => chain ++= generateComparison(Cond.NotEqual) case BinaryOperator.Neq => chain ++= generateComparison(destX, Cond.NotEqual)
case BinaryOperator.Greater => chain ++= generateComparison(Cond.Greater) case BinaryOperator.Greater => chain ++= generateComparison(destX, Cond.Greater)
case BinaryOperator.GreaterEq => chain ++= generateComparison(Cond.GreaterEqual) case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual)
case BinaryOperator.Less => chain ++= generateComparison(Cond.Less) case BinaryOperator.Less => chain ++= generateComparison(destX, Cond.Less)
case BinaryOperator.LessEq => chain ++= generateComparison(Cond.LessEqual) case BinaryOperator.LessEq => chain ++= generateComparison(destX, Cond.LessEqual)
case BinaryOperator.And => chain += And(stack.head(SizeDir.DWord), EAX) case BinaryOperator.And => chain += And(stack.head, destX)
case BinaryOperator.Or => chain += Or(stack.head(SizeDir.DWord), EAX) case BinaryOperator.Or => chain += Or(stack.head, destX)
} }
case call: microWacc.Call => case call: microWacc.Call =>
chain ++= generateCall(call, isTail = false) chain ++= generateCall(call, isTail = false)
chain += stack.push(RAX) chain += stack.push(call.ty.size, RAX)
} }
if chain.isEmpty then chain += stack.push(ImmediateVal(0))
assert(stack.size == stackSizeStart + 1) assert(stack.size == stackSizeStart + 1)
chain chain
} }
def generateCall(call: microWacc.Call, isTail: Boolean)(using private def generateCall(call: microWacc.Call, isTail: Boolean)(using
stack: Stack, stack: Stack,
strings: ListBuffer[String], strings: ListBuffer[String],
labelGenerator: LabelGenerator labelGenerator: LabelGenerator
@ -413,7 +391,7 @@ object asmGenerator {
argRegs.zip(args).foldMap { (reg, expr) => argRegs.zip(args).foldMap { (reg, expr) =>
chain ++= evalExprOntoStack(expr) chain ++= evalExprOntoStack(expr)
chain += stack.pop(reg) chain += stack.pop(Register(Q64, reg))
} }
args.drop(argRegs.size).foldMap { args.drop(argRegs.size).foldMap {
@ -434,77 +412,36 @@ object asmGenerator {
chain chain
} }
def generateComparison(cond: Cond)(using private def generateComparison(destX: Register, cond: Cond)(using
stack: Stack, stack: Stack
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = { ): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Compare(EAX, stack.head(SizeDir.DWord)) chain += Compare(destX, stack.head)
chain += Set(Register(RegSize.Byte, RegName.AL), cond) chain += Set(Register(B8, AX), cond)
chain += And(RAX, ImmediateVal(_8_BIT_MASK)) chain += And(RAX, ImmediateVal(_8_BIT_MASK))
chain += stack.drop() chain += stack.drop()
chain += stack.push(RAX) chain += stack.push(B8, RAX)
chain chain
} }
// Missing a sub instruction but dont think we need it private def funcPrologue(): Chain[AsmLine] = {
def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += stack.push(RBP) chain += Push(RBP)
chain += Move(RBP, Register(RegSize.R64, RegName.SP)) chain += Move(RBP, Register(Q64, SP))
chain chain
} }
def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { private def funcEpilogue(): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine] var chain = Chain.empty[AsmLine]
chain += Move(Register(RegSize.R64, RegName.SP), RBP) chain += Move(Register(Q64, SP), RBP)
chain += Pop(RBP) chain += Pop(RBP)
chain += assemblyIR.Return() chain += assemblyIR.Return()
chain chain
} }
class Stack { private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
private val stack = LinkedHashMap[Expr | Int, Int]()
private val RSP = Register(RegSize.R64, RegName.SP)
private def next: Int = stack.size + 1
def size: Int = stack.size
def push(expr: Expr, src: Src): AsmLine = {
stack += expr -> next
Push(src)
}
def push(src: Src): AsmLine = {
stack += stack.size -> next
Push(src)
}
def pop(dest: Src): AsmLine = {
stack.remove(stack.last._1)
Pop(dest)
}
def reserve(ident: Ident): AsmLine = {
stack += ident -> next
Subtract(RSP, ImmediateVal(8))
}
def reserve(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack += stack.size -> next)
Subtract(RSP, ImmediateVal(n * 8))
}
def drop(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack.remove(stack.last._1))
Add(RSP, ImmediateVal(n * 8))
}
def accessVar(ident: Ident): () => IndexAddress = () => {
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
}
def head: MemLocation = MemLocation(RSP)
def head(size: SizeDir): MemLocation = MemLocation(RSP, size)
def contains(ident: Ident): Boolean = stack.contains(ident)
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
}
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" } private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
extension (s: String) { extension (s: String) {

View File

@ -97,11 +97,9 @@ object assemblyIR {
} }
} }
case class MemLocation(pointer: Register, opSize: Option[Size] = None) extends Dest with Src { case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src {
def this(pointer: Register, opSize: Size) = this(pointer, Some(opSize))
override def toString = override def toString =
opSize.getOrElse("").toString + s"[$pointer]" opSize.toString + s"[$pointer]"
} }
// TODO to string is wacky // TODO to string is wacky

View File

@ -11,8 +11,8 @@ object sizeExtensions {
def heapSize: Int = (expr, expr.ty) match { def heapSize: Int = (expr, expr.ty) match {
case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) => case (ArrayLiter(elems), KnownType.Array(KnownType.Char)) =>
KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt KnownType.Int.size.toInt + elems.size.toInt * KnownType.Char.size.toInt
case (ArrayLiter(elems), _) => case (ArrayLiter(elems), ty) =>
KnownType.Int.size.toInt + elems.map(_.ty.size.toInt).sum KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt
case _ => expr.ty.size.toInt case _ => expr.ty.size.toInt
} }
} }
@ -25,5 +25,11 @@ object sizeExtensions {
case KnownType.Bool | KnownType.Char => Size.B8 case KnownType.Bool | KnownType.Char => Size.B8
case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64 case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
} }
def elemSize: Size = ty match {
case KnownType.Array(elem) => elem.size
case KnownType.Pair(_, _) => Size.Q64
case _ => ty.size
}
} }
} }

View File

@ -1,11 +1,12 @@
package wacc package wacc
import java.io.PrintStream import java.io.PrintStream
import cats.data.Chain
object writer { object writer {
import assemblyIR._ import assemblyIR._
def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = { def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
asmList.foreach(printStream.println) asmList.iterator.foreach(printStream.println)
} }
} }

View File

@ -223,10 +223,10 @@ object typeChecker {
case KnownType.Int => "%d" case KnownType.Int => "%d"
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
} }
val printfCall = { (value: microWacc.Expr) => val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List( List(
microWacc.Call( microWacc.Call(
microWacc.Builtin.Printf, func,
List( List(
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray,
value value
@ -239,11 +239,13 @@ object typeChecker {
List( List(
microWacc.If( microWacc.If(
exprTyped, exprTyped,
printfCall("true".toMicroWaccCharArray), printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall("false".toMicroWaccCharArray) printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
) )
) )
case _ => printfCall(exprTyped) case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
} }
case ast.If(cond, thenStmt, elseStmt) => case ast.If(cond, thenStmt, elseStmt) =>
List( List(

View File

@ -1,42 +1,38 @@
import org.scalatest.funsuite.AnyFunSuite import org.scalatest.funsuite.AnyFunSuite
import wacc.assemblyIR._ import wacc.assemblyIR._
import wacc.assemblyIR.Size._
import wacc.assemblyIR.RegName._
class instructionSpec extends AnyFunSuite { class instructionSpec extends AnyFunSuite {
val named64BitRegister = Register(RegSize.R64, RegName.AX) val named64BitRegister = Register(Q64, AX)
test("named 64-bit register toString") { test("named 64-bit register toString") {
assert(named64BitRegister.toString == "rax") assert(named64BitRegister.toString == "rax")
} }
val named32BitRegister = Register(RegSize.E32, RegName.AX) val named32BitRegister = Register(D32, AX)
test("named 32-bit register toString") { test("named 32-bit register toString") {
assert(named32BitRegister.toString == "eax") assert(named32BitRegister.toString == "eax")
} }
val scratch64BitRegister = Register(RegSize.R64, RegName.Reg8) val scratch64BitRegister = Register(Q64, R8)
test("scratch 64-bit register toString") { test("scratch 64-bit register toString") {
assert(scratch64BitRegister.toString == "r8") assert(scratch64BitRegister.toString == "r8")
} }
val scratch32BitRegister = Register(RegSize.E32, RegName.Reg8) val scratch32BitRegister = Register(D32, R8)
test("scratch 32-bit register toString") { test("scratch 32-bit register toString") {
assert(scratch32BitRegister.toString == "e8") assert(scratch32BitRegister.toString == "r8d")
} }
val memLocationWithHex = MemLocation(0x12345678) val memLocationWithRegister = MemLocation(named64BitRegister, Q64)
test("mem location with hex toString") {
assert(memLocationWithHex.toString == "[0x12345678]")
}
val memLocationWithRegister = MemLocation(named64BitRegister)
test("mem location with register toString") { test("mem location with register toString") {
assert(memLocationWithRegister.toString == "[rax]") assert(memLocationWithRegister.toString == "qword ptr [rax]")
} }
val immediateVal = ImmediateVal(123) val immediateVal = ImmediateVal(123)