feat: x86 code generation implementation without runtime checking

Merge request lab2425_spring/WACC_37!29

Co-authored-by: Alex Ling <al4423@ic.ac.uk>
Co-authored-by: Jonny <j.sinteix@gmail.com>
Co-authored-by: Guy C <gc1523@ic.ac.uk>
Co-authored-by: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com>
Co-authored-by: Ling, Alex <alex.ling23@imperial.ac.uk>
This commit is contained in:
2025-02-27 18:54:54 +00:00
14 changed files with 870 additions and 267 deletions

4
.commitlintrc.js Normal file
View File

@@ -0,0 +1,4 @@
// commitlint configuration for this repository.
export default {
// Inherit the rule set published as @commitlint/config-conventional.
extends: ['@commitlint/config-conventional'],
// Predicates over the raw commit message; this one matches messages whose first
// line is exactly "Local Mutable Chains" so that they are excluded from linting.
ignores: [commit => commit.startsWith("Local Mutable Chains\n")]
}

View File

@@ -1 +0,0 @@
extends: "@commitlint/config-conventional"

View File

@@ -30,9 +30,10 @@ check_commits:
before_script:
- apk add git
- npm install -g @commitlint/cli @commitlint/config-conventional
- git pull origin master
- git checkout origin/master
script:
- npx commitlint --from origin/master --to HEAD --verbose
- git checkout ${CI_COMMIT_SHA}
- npx commitlint --from origin/master --to ${CI_COMMIT_SHA} --verbose
compile_jvm:
stage: compile
@@ -48,10 +49,10 @@ test_jvm:
image: gumjoe/wacc-ci-scala:x86
stage: test
# Use our own runner (not cloud VM or shared) to ensure we have multiple cores.
tags: [ large ]
tags: [large]
# This is expensive, so do use `dependencies` instead of `needs` to
# ensure all previous stages pass.
dependencies: [ compile_jvm ]
dependencies: [compile_jvm]
before_script:
- git clone https://$EXAMPLES_AUTH@gitlab.doc.ic.ac.uk/lab2425_spring/wacc-examples.git
script:

View File

@@ -1,13 +1,13 @@
package wacc
import scala.collection.mutable
import cats.data.Chain
import parsley.{Failure, Success}
import scopt.OParser
import java.io.File
import java.io.PrintStream
import assemblyIR as asm
import wacc.microWacc.IntLiter
case class CliConfig(
file: File = new File(".")
@@ -64,157 +64,8 @@ def frontend(
}
val s = "enter an integer to echo"
def backend(typedProg: microWacc.Program): List[asm.AsmLine] | String =
typedProg match {
case microWacc.Program(
Nil,
microWacc.Call(microWacc.Builtin.Exit, microWacc.IntLiter(v) :: Nil) :: Nil
) =>
s""".intel_syntax noprefix
.globl main
main:
mov edi, ${v}
call exit@plt
"""
case microWacc.Program(
Nil,
microWacc.Assign(microWacc.Ident("x", _), microWacc.IntLiter(1)) ::
microWacc.Call(microWacc.Builtin.Println, _) ::
microWacc.Assign(
microWacc.Ident("x", _),
microWacc.Call(microWacc.Builtin.ReadInt, Nil)
) ::
microWacc.Call(microWacc.Builtin.Println, microWacc.Ident("x", _) :: Nil) :: Nil
) =>
""".intel_syntax noprefix
.globl main
.section .rodata
# length of .L.str0
.int 24
.L.str0:
.asciz "enter an integer to echo"
.text
main:
push rbp
# push {rbx, r12}
sub rsp, 16
mov qword ptr [rsp], rbx
mov qword ptr [rsp + 8], r12
mov rbp, rsp
mov r12d, 1
lea rdi, [rip + .L.str0]
# statement primitives do not return results (but will clobber r0/rax)
call _prints
call _println
# load the current value in the destination of the read so it supports defaults
mov edi, r12d
call _readi
mov r12d, eax
mov edi, eax
# statement primitives do not return results (but will clobber r0/rax)
call _printi
call _println
mov rax, 0
# pop/peek {rbx, r12}
mov rbx, qword ptr [rsp]
mov r12, qword ptr [rsp + 8]
add rsp, 16
pop rbp
ret
.section .rodata
# length of .L._printi_str0
.int 2
.L._printi_str0:
.asciz "%d"
.text
_printi:
push rbp
mov rbp, rsp
# external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0
and rsp, -16
mov esi, edi
lea rdi, [rip + .L._printi_str0]
# on x86, al represents the number of SIMD registers used as variadic arguments
mov al, 0
call printf@plt
mov rdi, 0
call fflush@plt
mov rsp, rbp
pop rbp
ret
.section .rodata
# length of .L._prints_str0
.int 4
.L._prints_str0:
.asciz "%.*s"
.text
_prints:
push rbp
mov rbp, rsp
# external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0
and rsp, -16
mov rdx, rdi
mov esi, dword ptr [rdi - 4]
lea rdi, [rip + .L._prints_str0]
# on x86, al represents the number of SIMD registers used as variadic arguments
mov al, 0
call printf@plt
mov rdi, 0
call fflush@plt
mov rsp, rbp
pop rbp
ret
.section .rodata
# length of .L._println_str0
.int 0
.L._println_str0:
.asciz ""
.text
_println:
push rbp
mov rbp, rsp
# external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0
and rsp, -16
lea rdi, [rip + .L._println_str0]
call puts@plt
mov rdi, 0
call fflush@plt
mov rsp, rbp
pop rbp
ret
.section .rodata
# length of .L._readi_str0
.int 2
.L._readi_str0:
.asciz "%d"
.text
_readi:
push rbp
mov rbp, rsp
# external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0
and rsp, -16
# RDI contains the "original" value of the destination of the read
# allocate space on the stack to store the read: preserve alignment!
# the passed default argument should be stored in case of EOF
sub rsp, 16
mov dword ptr [rsp], edi
lea rsi, qword ptr [rsp]
lea rdi, [rip + .L._readi_str0]
# on x86, al represents the number of SIMD registers used as variadic arguments
mov al, 0
call scanf@plt
mov eax, dword ptr [rsp]
add rsp, 16
mov rsp, rbp
pop rbp
ret
"""
case _ => List()
}
/** Lower a type-checked microWacc program to assembly.
  *
  * Delegates entirely to `asmGenerator.generateAsm`.
  *
  * @param typedProg
  *   The program to generate code for.
  * @return
  *   The generated assembly lines, in output order.
  */
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using
stdout: PrintStream = Console.out
@@ -222,13 +73,8 @@ def compile(filename: String, outFile: Option[File] = None)(using
frontend(os.read(os.Path(filename))) match {
case Left(typedProg) =>
val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s"))
backend(typedProg) match {
case s: String =>
os.write.over(os.Path(asmFile.getAbsolutePath), s)
case ops: List[asm.AsmLine] => {
writer.writeTo(ops, PrintStream(asmFile))
}
}
val asm = backend(typedProg)
writer.writeTo(asm, PrintStream(asmFile))
0
case Right(exitCode) => exitCode
}

View File

@@ -0,0 +1,90 @@
package wacc
import scala.collection.mutable.LinkedHashMap
import cats.data.Chain
/** Compile-time model of the runtime stack used while generating code.
  *
  * Every entry occupies one fixed-width slot of `elemBytes` bytes, regardless of the size of the
  * value stored in it. Each mutating method updates the model AND returns the single instruction
  * that performs the matching change to RSP, so callers must emit the returned line.
  */
class Stack {
  import assemblyIR._
  import assemblyIR.Size._
  import sizeExtensions.size
  import microWacc as mw

  private val RSP = Register(Q64, RegName.SP)

  /** Bookkeeping for one slot.
    *
    * @param size
    *   Operand size of the value held in the slot.
    * @param offset
    *   Number of bytes that were already on the modelled stack when the slot was created.
    */
  private class StackValue(val size: Size, val offset: Int) {

    /** Byte offset just past this slot, measured from the start of the modelled stack. */
    def bottom: Int = offset + elemBytes
  }

  // Insertion-ordered, so `stack.last` is always the top of the stack.
  // Named values are keyed by their expression; anonymous slots by their index.
  private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
  private val elemBytes: Int = Q64.toInt
  private def sizeBytes: Int = stack.size * elemBytes

  /** Forget the top slot of the model (no instruction is produced). */
  private def removeTop(): Unit = stack.remove(stack.last._1)

  /** The number of slots currently on the stack (NOT a byte count). */
  def size: Int = stack.size

  /** Push an expression onto the stack. */
  def push(expr: mw.Expr, src: Register): AsmLine = {
    stack += expr -> StackValue(src.size, sizeBytes)
    Push(src)
  }

  /** Push a value onto the stack. */
  def push(itemSize: Size, addr: Src): AsmLine = {
    stack += stack.size -> StackValue(itemSize, sizeBytes)
    Push(addr)
  }

  /** Reserve space for a variable on the stack. */
  def reserve(ident: mw.Ident): AsmLine = {
    stack += ident -> StackValue(ident.ty.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve space for a register on the stack.
    *
    * A full slot is reserved even for narrow registers: every offset computed by this class
    * assumes slots of `elemBytes` bytes, so moving RSP by the register width instead would put
    * the model out of step with the real stack.
    */
  def reserve(src: Register): AsmLine = {
    stack += stack.size -> StackValue(src.size, sizeBytes)
    Subtract(RSP, ImmediateVal(elemBytes))
  }

  /** Reserve space for values on the stack.
    *
    * @param sizes
    *   The sizes of the values to reserve space for.
    */
  def reserve(sizes: Size*): AsmLine = {
    sizes.foreach { itemSize =>
      stack += stack.size -> StackValue(itemSize, sizeBytes)
    }
    Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
  }

  /** Pop a value from the stack into a register. Sizes MUST match. */
  def pop(dest: Register): AsmLine = {
    removeTop()
    Pop(dest)
  }

  /** Drop the top n values from the stack.
    *
    * @throws IllegalArgumentException
    *   if `n` is negative or larger than the number of slots on the stack.
    */
  def drop(n: Int = 1): AsmLine = {
    require(n >= 0 && n <= stack.size, s"cannot drop $n slot(s) from a stack of ${stack.size}")
    (1 to n).foreach(_ => removeTop())
    Add(RSP, ImmediateVal(n * elemBytes))
  }

  /** Generate AsmLines within a scope, which is reset after the block.
    *
    * Whatever the block left on the stack is dropped by an instruction appended to its output.
    */
  def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
    val resetToSize = stack.size
    val lines = block()
    lines :+ drop(stack.size - resetToSize)
  }

  /** Get an IndexAddress for a variable in the stack, relative to the current RSP. */
  def accessVar(ident: mw.Ident): IndexAddress =
    IndexAddress(RSP, sizeBytes - stack(ident).bottom)

  /** Whether a slot has been recorded for the given variable. */
  def contains(ident: mw.Ident): Boolean = stack.contains(ident)

  /** The memory operand `[rsp]`, sized like the value currently on top of the stack. */
  def head: MemLocation = MemLocation(RSP, stack.last._2.size)

  override def toString(): String = stack.toString
}

View File

@@ -0,0 +1,458 @@
package wacc
import scala.collection.mutable.ListBuffer
import cats.data.Chain
import cats.syntax.foldable._
object asmGenerator {
import microWacc._
import assemblyIR._
import assemblyIR.Size._
import assemblyIR.RegName._
import types._
import sizeExtensions._
import lexer.escapedChars
/** A runtime error the generated program can raise: the message and the labels used for it. */
abstract case class Error() {

  /** Label placed in front of the message string. */
  def strLabel: String

  /** The message text. */
  def errStr: String

  /** Label of the code that reports the error. */
  def errLabel: String

  /** Data lines for the message: its length, its label, then the text itself. */
  def stringDef: Chain[AsmLine] =
    Chain.one[AsmLine](Directive.Int(errStr.size)) :+ LabelDef(strLabel) :+ Directive.Asciz(errStr)
}
/** Labels and message used when a division or modulo by zero is detected. */
object zeroDivError extends Error {
  // TODO: is this bad? Can we make an error case class/some other structure?
  override def strLabel: String = ".L._errDivZero_str0"

  override def errStr: String = "fatal error: division or modulo by zero"

  override def errLabel: String = ".L._errDivZero"
}
// Pre-built register operands used throughout the generator.
// RAX/EAX are the 64-bit and 32-bit views of the same register (AX).
private val RAX = Register(Q64, AX)
private val EAX = Register(D32, AX)
private val RDI = Register(Q64, DI)
private val RIP = Register(Q64, IP)
private val RBP = Register(Q64, BP)
private val RSI = Register(Q64, SI)
private val RDX = Register(Q64, DX)
private val RCX = Register(Q64, CX)
// Registers that receive the leading arguments of a call, in argument order;
// generateCall zips this list with the call's arguments.
private val argRegs = List(DI, SI, DX, CX, R8, R9)
extension [T](chain: Chain[T])
  /** Append a single element to the end of the chain. */
  def +(item: T): Chain[T] = chain :+ item

  /** Append every given chain, in order, after this one. */
  def concatAll(chains: Chain[T]*): Chain[T] =
    chains.foldLeft(chain)((acc, next) => acc ++ next)
/** Produces label names: numbered local labels and fixed names for call targets. */
class LabelGenerator {
  // Number of the most recently issued local label; -1 until the first one is requested.
  var labelVal = -1

  /** Return a fresh local label (`.L0`, `.L1`, ...). */
  def getLabel(): String = {
    labelVal = labelVal + 1
    ".L" + labelVal
  }

  /** Return the label of a call target: `wacc_<name>` for identifiers, `_<name>` for builtins. */
  def getLabel(target: CallTarget): String =
    target match {
      case Builtin(name) => "_" + name
      case Ident(v, _)   => "wacc_" + v
    }
}
/** Generate the complete assembly listing for a microWacc program.
  *
  * The output is: the syntax/global directives, a read-only data section holding every string
  * collected during generation plus the division-by-zero message, then the text section with
  * `main`, the built-in helpers and the user functions.
  *
  * @param microProg
  *   The program to translate.
  * @return
  *   All assembly lines, in output order.
  */
def generateAsm(microProg: Program): Chain[AsmLine] = {
// Fresh generation state, passed implicitly to the helpers below.
given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg
// main: prologue, the top-level statements, RAX zeroed, epilogue; then the builtins and
// every user function.
val progAsm = Chain(LabelDef("main")).concatAll(
funcPrologue(),
main.foldMap(generateStmt(_)),
Chain.one(Xor(RAX, RAX)),
funcEpilogue(),
generateBuiltInFuncs(),
funcs.foldMap(generateUserFunc(_))
)
// Must be computed AFTER progAsm: `strings` is filled as a side effect of generating code.
// Each string is emitted as its length, a `.L.str<i>` label, then the escaped text.
val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
Chain(
Directive.Int(str.size),
LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped)
)
} ++ zeroDivError.stringDef
Chain(
Directive.IntelSyntax,
Directive.Global("main"),
Directive.RoData
).concatAll(
strDirs,
Chain.one(Directive.Text),
progAsm
)
}
/** Wrap a body in a label, the standard prologue and the standard epilogue.
  *
  * @param labelName
  *   Label under which the function is emitted.
  * @param funcBody
  *   Instructions placed between prologue and epilogue.
  */
private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
    stack: Stack
): Chain[AsmLine] =
  Chain.one[AsmLine](LabelDef(labelName)) ++ funcPrologue() ++ funcBody ++ funcEpilogue()
/** Generate the code of one user-defined function, using a stack model of its own.
  *
  * No epilogue is appended: every user function must return explicitly, and the return
  * statement emits it.
  */
private def generateUserFunc(func: FuncDecl)(using
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  given stack: Stack = Stack()

  // Model what is already on the stack at entry: params 7 and up, then the return pointer.
  // Only the bookkeeping matters here, so the returned instructions are not emitted.
  func.params.drop(argRegs.size).foreach(stack.reserve(_))
  stack.reserve(Q64)

  val entry = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name))) ++ funcPrologue()

  // Spill the register-passed params onto the stack so every param lives in a slot.
  val spills = argRegs.zip(func.params).foldLeft(Chain.empty[AsmLine]) {
    case (acc, (reg, param)) => acc :+ stack.push(param, Register(Q64, reg))
  }

  entry ++ spills ++ func.body.foldMap(generateStmt(_))
}
/** Generate the helper routines that generated code calls for builtins, followed by the
  * division-by-zero error handler.
  *
  * Every helper aligns the stack to 16 bytes before calling into the C library. Before each
  * variadic C call (`printf`, `scanf`) AL is set to 0, because on x86-64 AL carries the number
  * of vector registers used for variadic arguments and none are used here.
  */
private def generateBuiltInFuncs()(using
    stack: Stack,
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  // `mov al, 0`: no vector registers are used as variadic arguments.
  val noVectorArgs: AsmLine = Move(Register(B8, AX), ImmediateVal(0))

  var chain = Chain.empty[AsmLine]

  chain ++= wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Exit),
    Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
  )

  chain ++= wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Printf),
    Chain(
      stackAlign,
      noVectorArgs,
      assemblyIR.Call(CLibFunc.PrintF),
      // fflush(NULL)
      Xor(RDI, RDI),
      assemblyIR.Call(CLibFunc.Fflush)
    )
  )

  chain ++= wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.PrintCharArray),
    Chain(
      stackAlign,
      // RSI points at the array: the characters start after the Int-sized length field,
      // and the length itself is the 32-bit value at the start.
      Load(RDX, IndexAddress(RSI, KnownType.Int.size.toInt)),
      Move(Register(D32, SI), MemLocation(RSI, D32)),
      noVectorArgs,
      assemblyIR.Call(CLibFunc.PrintF),
      // fflush(NULL)
      Xor(RDI, RDI),
      assemblyIR.Call(CLibFunc.Fflush)
    )
  )

  chain ++= wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Malloc),
    Chain(stackAlign, assemblyIR.Call(CLibFunc.Malloc))
    // Out of memory check is optional
  )

  // NOTE(review): the free helper has an empty body, so it never calls CLibFunc.Free.
  chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)

  chain ++= wrapBuiltinFunc(
    labelGenerator.getLabel(Builtin.Read),
    Chain(
      stackAlign,
      // 8 bytes of padding plus the 8-byte push keep the stack 16-byte aligned.
      Subtract(Register(Q64, SP), ImmediateVal(8)),
      // The value passed in RSI is stored first, and scanf is given the address of that slot.
      Push(RSI),
      Load(RSI, MemLocation(Register(Q64, SP), Q64)),
      noVectorArgs,
      assemblyIR.Call(CLibFunc.Scanf),
      Pop(RAX)
    )
  )

  chain ++= Chain(
    // TODO can this be done with a call to generateStmt?
    // Consider other error cases -> look to generalise
    LabelDef(zeroDivError.errLabel),
    stackAlign,
    Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
    noVectorArgs,
    assemblyIR.Call(CLibFunc.PrintF),
    Move(RDI, ImmediateVal(-1)),
    assemblyIR.Call(CLibFunc.Exit)
  )

  chain
}
/** Generate the assembly for a single statement.
  *
  * The statement's `toString` is emitted first as an assembly comment. Values are computed
  * with `evalExprOntoStack` and then popped into registers, so the order of the calls below
  * matters: each one also updates the `Stack` model.
  *
  * @param stmt
  *   The statement to translate.
  * @return
  *   The lines generated for the statement.
  */
private def generateStmt(stmt: Stmt)(using
stack: Stack,
strings: ListBuffer[String],
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var chain = Chain.empty[AsmLine]
chain += Comment(stmt.toString)
stmt match {
case Assign(lhs, rhs) =>
lhs match {
case ident: Ident =>
// A variable gets its slot the first time it is assigned.
if (!stack.contains(ident)) chain += stack.reserve(ident)
chain ++= evalExprOntoStack(rhs)
chain += stack.pop(RAX)
chain += Move(stack.accessVar(ident), RAX)
case ArrayElem(x, i) =>
// Pushed as rhs, index, array; popped in reverse: array -> RAX, index -> RCX, rhs -> RDX.
chain ++= evalExprOntoStack(rhs)
chain ++= evalExprOntoStack(i)
chain ++= evalExprOntoStack(x)
chain += stack.pop(RAX)
chain += stack.pop(RCX)
chain += stack.pop(RDX)
// Store at array + Int-sized offset + index * element size.
chain += Move(
IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt),
Register(x.ty.elemSize, DX)
)
}
case If(cond, thenBranch, elseBranch) =>
val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
chain ++= evalExprOntoStack(cond)
chain += stack.pop(RAX)
// A condition value of 0 selects the else branch.
chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(elseLabel), Cond.Equal)
// Each branch runs in its own scope, so slots it created are dropped at its end.
chain ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
chain += Jump(LabelArg(endLabel))
chain += LabelDef(elseLabel)
chain ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
chain += LabelDef(endLabel)
case While(cond, body) =>
val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
chain += LabelDef(startLabel)
chain ++= evalExprOntoStack(cond)
chain += stack.pop(RAX)
// A condition value of 0 leaves the loop.
chain += Compare(RAX, ImmediateVal(0))
chain += Jump(LabelArg(endLabel), Cond.Equal)
chain ++= stack.withScope(() => body.foldMap(generateStmt))
chain += Jump(LabelArg(startLabel))
chain += LabelDef(endLabel)
case call: microWacc.Call =>
chain ++= generateCall(call, isTail = false)
case microWacc.Return(expr) =>
expr match {
case call: microWacc.Call =>
// Returning the result of a call: generated as a tail call, no epilogue emitted here.
chain ++= generateCall(call, isTail = true) // tco
case _ =>
// Otherwise the value is returned in RAX, followed by the function epilogue.
chain ++= evalExprOntoStack(expr)
chain += stack.pop(RAX)
chain ++= funcEpilogue()
}
}
chain
}
/** Generate code that evaluates an expression and leaves its value in ONE new stack slot.
  *
  * On return the `Stack` model holds exactly one more slot than on entry (asserted). For
  * values narrower than 32 bits, the unused upper bits of the slot are cleared at the end.
  * RAX, RCX and RDX are used as scratch registers.
  *
  * Division and modulo both jump to the division-by-zero handler when the divisor is 0.
  *
  * @param expr
  *   The expression to evaluate.
  * @return
  *   The lines that compute the value and push it.
  */
private def evalExprOntoStack(expr: Expr)(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  var chain = Chain.empty[AsmLine]
  val stackSizeStart = stack.size

  expr match {
    case IntLiter(v)  => chain += stack.push(KnownType.Int.size, ImmediateVal(v))
    case CharLiter(v) => chain += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
    case ident: Ident => chain += stack.push(ident.ty.size, stack.accessVar(ident))

    case array @ ArrayLiter(elems) =>
      expr.ty match {
        case KnownType.String =>
          // Strings live in the read-only data section; push the address of the new entry.
          strings += elems.collect { case CharLiter(v) => v }.mkString
          chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
          chain += stack.push(Q64, RAX)
        case ty =>
          chain ++= generateCall(
            microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
            isTail = false
          )
          chain += stack.push(Q64, RAX)
          // Store the length of the array at the start
          chain += Move(MemLocation(RAX, D32), ImmediateVal(elems.size))
          elems.zipWithIndex.foreach { (elem, i) =>
            // The array pointer waits on the stack while the element is evaluated,
            // then is pushed back for the next element.
            chain ++= evalExprOntoStack(elem)
            chain += stack.pop(RCX)
            chain += stack.pop(RAX)
            chain += Move(IndexAddress(RAX, 4 + i * ty.elemSize.toInt), Register(ty.elemSize, CX))
            chain += stack.push(Q64, RAX)
          }
      }

    case BoolLiter(true) =>
      chain += stack.push(KnownType.Bool.size, ImmediateVal(1))
    case BoolLiter(false) =>
      chain += Xor(RAX, RAX)
      chain += stack.push(KnownType.Bool.size, RAX)
    case NullLiter() =>
      chain += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))

    case ArrayElem(x, i) =>
      chain ++= evalExprOntoStack(x)
      chain ++= evalExprOntoStack(i)
      chain += stack.pop(RCX)
      chain += stack.pop(RAX)
      // + Int because we store the length of the array at the start
      chain += Move(
        Register(x.ty.elemSize, AX),
        IndexAddress(RAX, KnownType.Int.size.toInt, RCX, x.ty.elemSize.toInt)
      )
      chain += stack.push(x.ty.elemSize, RAX)

    case UnaryOp(x, op) =>
      chain ++= evalExprOntoStack(x)
      op match {
        case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed
        case UnaryOperator.Len =>
          // The length is the 32-bit value stored at the start of the array.
          chain += stack.pop(RAX)
          chain += Move(EAX, MemLocation(RAX, D32))
          chain += stack.push(D32, RAX)
        case UnaryOperator.Negate =>
          chain += Negate(stack.head)
        case UnaryOperator.Not =>
          chain += Xor(stack.head, ImmediateVal(1))
      }

    case BinaryOp(x, y, op) =>
      val destX = Register(x.ty.size, AX)
      // y is pushed first, so after popping x into RAX, y is on top of the stack.
      chain ++= evalExprOntoStack(y)
      chain ++= evalExprOntoStack(x)
      chain += stack.pop(RAX)

      // Divide x (in RAX) by y (top of stack) after checking the divisor, then replace
      // y's slot with the requested result register (quotient or remainder).
      def checkedDivide(result: Register): Unit = {
        chain += Compare(stack.head, ImmediateVal(0))
        chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
        chain += CDQ()
        chain += Divide(stack.head)
        chain += stack.drop()
        chain += stack.push(destX.size, result)
      }

      op match {
        case BinaryOperator.Add =>
          chain += Add(stack.head, destX)
        case BinaryOperator.Sub =>
          chain += Subtract(destX, stack.head)
          chain += stack.drop()
          chain += stack.push(destX.size, RAX)
        case BinaryOperator.Mul =>
          chain += Multiply(destX, stack.head)
          chain += stack.drop()
          chain += stack.push(destX.size, RAX)
        case BinaryOperator.Div       => checkedDivide(RAX)
        case BinaryOperator.Mod       => checkedDivide(RDX)
        case BinaryOperator.Eq        => chain ++= generateComparison(destX, Cond.Equal)
        case BinaryOperator.Neq       => chain ++= generateComparison(destX, Cond.NotEqual)
        case BinaryOperator.Greater   => chain ++= generateComparison(destX, Cond.Greater)
        case BinaryOperator.GreaterEq => chain ++= generateComparison(destX, Cond.GreaterEqual)
        case BinaryOperator.Less      => chain ++= generateComparison(destX, Cond.Less)
        case BinaryOperator.LessEq    => chain ++= generateComparison(destX, Cond.LessEqual)
        case BinaryOperator.And       => chain += And(stack.head, destX)
        case BinaryOperator.Or        => chain += Or(stack.head, destX)
      }

    case call: microWacc.Call =>
      chain ++= generateCall(call, isTail = false)
      chain += stack.push(call.ty.size, RAX)
  }

  assert(stack.size == stackSizeStart + 1)
  chain ++= zeroRest(MemLocation(stack.head.pointer, Q64), expr.ty.size)
  chain
}
/** Generate a call to a user function or builtin.
  *
  * All arguments are first evaluated left-to-right, each into its own stack slot. Only then
  * are the leading arguments loaded into `argRegs`. Loading them as they are evaluated would
  * let the evaluation of a later argument (which uses RAX/RCX/RDX as scratch) overwrite an
  * argument register that was already set. Arguments beyond `argRegs.size` are the topmost
  * slots, i.e. directly below the return address once the call is made.
  *
  * After the call (or before a tail jump) every argument slot is dropped again, so the `Stack`
  * model is back at its entry size. The result, if any, is in RAX.
  *
  * @param call
  *   The call to generate.
  * @param isTail
  *   True when the call's result is returned directly by the enclosing function. The current
  *   frame is then torn down and the target is jumped to, so that the target's `ret` returns
  *   to this function's caller. With stack-passed arguments this is not possible, so a regular
  *   call followed by the function epilogue is generated instead.
  */
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
    stack: Stack,
    strings: ListBuffer[String],
    labelGenerator: LabelGenerator
): Chain[AsmLine] = {
  val microWacc.Call(target, args) = call
  val targetLabel = LabelArg(labelGenerator.getLabel(target))
  val stackPointer = Register(Q64, SP)
  val slotBytes = Q64.toInt

  var chain = Chain.empty[AsmLine]

  // 1. Evaluate every argument; the last one ends up on top of the stack.
  args.foreach { arg =>
    chain ++= evalExprOntoStack(arg)
  }

  // 2. Copy the register-passed arguments out of their slots (slot of arg i is
  //    args.size - 1 - i slots below the top).
  argRegs.zip(args).zipWithIndex.foreach { case ((reg, _), i) =>
    chain += Move(
      Register(Q64, reg),
      IndexAddress(stackPointer, (args.size - 1 - i) * slotBytes)
    )
  }

  val hasStackArgs = args.size > argRegs.size

  if (isTail && !hasStackArgs) {
    // Tail Call Optimisation (TCO): release the argument slots, restore the caller's frame
    // (same instructions as the epilogue, minus `ret`) and jump.
    if (args.nonEmpty) chain += stack.drop(args.size)
    chain += Move(stackPointer, RBP)
    chain += Pop(RBP)
    chain += Jump(targetLabel) // tail call
  } else {
    chain += assemblyIR.Call(targetLabel) // regular call
    if (args.nonEmpty) chain += stack.drop(args.size)
    // A call in tail position that could not be turned into a jump still has to return.
    if (isTail) chain ++= funcEpilogue()
  }

  chain
}
/** Compare `destX` with the value on top of the stack and replace that value with the
  * outcome of `cond` (set in AL, upper bits of RAX cleared).
  */
private def generateComparison(destX: Register, cond: Cond)(using
    stack: Stack
): Chain[AsmLine] = {
  // Evaluated left to right: `stack.head` is read before the slot is dropped.
  val compareAndSet = Chain[AsmLine](
    Compare(destX, stack.head),
    Set(Register(B8, AX), cond)
  )
  val clearUpperBits = zeroRest(RAX, B8)
  val replaceTop = Chain[AsmLine](stack.drop(), stack.push(B8, RAX))
  compareAndSet ++ clearUpperBits ++ replaceTop
}
/** Function entry: save RBP on the stack (recorded in the stack model) and make RBP the new
  * frame base.
  */
private def funcPrologue()(using stack: Stack): Chain[AsmLine] =
  Chain[AsmLine](
    stack.push(Q64, RBP),
    Move(RBP, Register(Q64, SP))
  )
/** Function exit: reset RSP to the frame base, restore the saved RBP and return. The stack
  * model is not touched.
  */
private def funcEpilogue(): Chain[AsmLine] =
  Chain[AsmLine](
    Move(Register(Q64, SP), RBP),
    Pop(RBP),
    assemblyIR.Return()
  )
// `and rsp, -16`: rounds RSP down to a multiple of 16 by clearing its low four bits.
private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
/** Clear every bit of `dest` above the low `size` bytes. Nothing is emitted for 64- and
  * 32-bit sizes.
  */
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] =
  if (size == Q64 || size == D32) Chain.empty
  else {
    // All ones in the low `size` bytes, zero above.
    val keepMask = (1 << (size.toInt * 8)) - 1
    Chain.one(And(dest, ImmediateVal(keepMask)))
  }
// Reverse of the lexer's escape table: character code -> backslash escape sequence.
private val escapedCharsMapping =
  for ((name, code) <- escapedChars) yield code -> ("\\" + name)

extension (s: String) {

  /** The string with every character that has an escape sequence replaced by that sequence. */
  private def escaped: String =
    s.map(c => escapedCharsMapping.getOrElse(c, c.toString)).mkString
}
}

View File

@@ -6,47 +6,114 @@ object assemblyIR {
sealed trait Operand
sealed trait Src extends Operand // mem location, register and imm value
sealed trait Dest extends Operand // mem location and register
enum RegSize {
case R64
case E32
override def toString = this match {
case R64 => "r"
case E32 => "e"
/** Operand widths. */
enum Size {
  case Q64, D32, W16, B8

  /** The width in bytes. */
  def toInt: Int = this match {
    case B8  => 1
    case W16 => 2
    case D32 => 4
    case Q64 => 8
  }

  private val ptr = "ptr "

  /** The Intel-syntax size prefix for a memory operand, including the trailing space. */
  override def toString(): String = {
    val width = this match {
      case B8  => "byte"
      case W16 => "word"
      case D32 => "dword"
      case Q64 => "qword"
    }
    s"$width $ptr"
  }
}
/** Register names, independent of the width at which the register is accessed. */
enum RegName {
  case AX, BX, CX, DX
  case SI, DI, SP, BP, IP
  case R8, R9, R10, R11, R12, R13, R14, R15
}
/** A register accessed at a given width.
  *
  * @throws IllegalArgumentException
  *   when an 8-bit view of IP is requested.
  */
case class Register(size: Size, name: RegName) extends Dest with Src {
  import RegName._

  if (name == RegName.IP && size == Size.B8) {
    throw new IllegalArgumentException("Cannot have 8 bit register for IP")
  }

  /** The register's assembly name for its width, e.g. `rax`, `eax`, `ax`, `al`, `r8d`. */
  override def toString = {
    // Registers whose narrower views are formed with an r/e prefix and a separate byte name.
    def legacy(word: String, byteName: String): String = size match {
      case Size.B8  => byteName
      case Size.W16 => word
      case Size.D32 => "e" + word
      case Size.Q64 => "r" + word
    }

    // Numbered registers, whose narrower views are formed with a d/w/b suffix.
    def numbered(index: Int): String = {
      val suffix = size match {
        case Size.B8  => "b"
        case Size.W16 => "w"
        case Size.D32 => "d"
        case Size.Q64 => ""
      }
      s"r$index$suffix"
    }

    name match {
      case AX  => legacy("ax", "al")
      case BX  => legacy("bx", "bl")
      case CX  => legacy("cx", "cl")
      case DX  => legacy("dx", "dl")
      case SI  => legacy("si", "sil")
      case DI  => legacy("di", "dil")
      case SP  => legacy("sp", "spl")
      case BP  => legacy("bp", "bpl")
      case IP  => legacy("ip", "#INVALID")
      case R8  => numbered(8)
      case R9  => numbered(9)
      case R10 => numbered(10)
      case R11 => numbered(11)
      case R12 => numbered(12)
      case R13 => numbered(13)
      case R14 => numbered(14)
      case R15 => numbered(15)
    }
  }
}
// arguments
enum CLibFunc extends Operand {
case Scanf,
Puts,
Fflush,
Exit,
PrintF
PrintF,
Malloc,
Free
private val plt = "@plt"
override def toString = this match {
case Scanf => "scanf" + plt
case Puts => "puts" + plt
case Fflush => "fflush" + plt
case Exit => "exit" + plt
case PrintF => "printf" + plt
case Malloc => "malloc" + plt
case Free => "free" + plt
}
}
enum Register extends Dest with Src {
case Named(name: String, size: RegSize)
case Scratch(num: Int, size: RegSize)
override def toString = this match {
case Named(name, size) => s"${size}${name.toLowerCase()}"
case Scratch(num, size) => s"r${num}${if (size == RegSize.E32) "d" else ""}"
/** The memory operand addressed by `pointer`, accessed with width `opSize`. */
case class MemLocation(pointer: Register, opSize: Size) extends Dest with Src {
  // Size.toString already ends with a space, e.g. "qword ptr [rsp]".
  override def toString = s"$opSize[$pointer]"
}
}
case class MemLocation(pointer: Long | Register) extends Dest with Src {
override def toString = pointer match {
case hex: Long => f"[0x$hex%X]"
case reg: Register => s"[$reg]"
// TODO to string is wacky
/** An address of the form `base + offset`, or `base + indexReg * scale + offset` when a
  * non-zero `scale` is given. `indexReg` is ignored when `scale` is 0.
  */
case class IndexAddress(
    base: Register,
    offset: Int | LabelArg,
    indexReg: Register = Register(Size.Q64, RegName.AX),
    scale: Int = 0
) extends Dest
    with Src {
  override def toString =
    if (scale == 0) s"[$base + $offset]"
    else s"[$base + $indexReg * $scale + $offset]"
}
@@ -64,30 +131,42 @@ object assemblyIR {
}
case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2)
case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2)
case class Multiply(ops: Operand*) extends Operation("mul", ops*)
case class Divide(op1: Src) extends Operation("div", op1)
case class Multiply(ops: Operand*) extends Operation("imul", ops*)
case class Divide(op1: Src) extends Operation("idiv", op1)
case class Negate(op: Dest) extends Operation("neg", op)
case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2)
case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2)
case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2)
case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2)
// stack operations
case class Push(op1: Src) extends Operation("push", op1)
case class Pop(op1: Src) extends Operation("pop", op1)
case class Call(op1: CLibFunc) extends Operation("call", op1)
case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1)
case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2)
case class Load(op1: Register, op2: MemLocation) extends Operation("lea ", op1, op2)
case class Load(op1: Register, op2: MemLocation | IndexAddress)
extends Operation("lea ", op1, op2)
case class CDQ() extends Operation("cdq")
case class Return() extends Operation("ret")
case class Jump(op1: LabelArg, condition: Cond = Cond.Always)
extends Operation(s"j${condition.toString}", op1)
case class Set(op1: Dest, condition: Cond = Cond.Always)
extends Operation(s"set${condition.toString}", op1)
/** A label definition line: the name followed by a colon. */
case class LabelDef(name: String) extends AsmLine {
  override def toString = name + ":"
}
/** An assembly comment; every line of the text is prefixed with `# `. */
case class Comment(comment: String) extends AsmLine {
  override def toString = {
    val prefixed = for (line <- comment.split("\n")) yield s"# ${line}"
    prefixed.mkString("\n")
  }
}
enum Cond {
case Equal,
NotEqual,
@@ -108,4 +187,30 @@ object assemblyIR {
case Always => "mp"
}
}
/** Assembler directives emitted by the generator. */
enum Directive extends AsmLine {
  case IntelSyntax, RoData, Text
  case Global(name: String)
  case Int(value: scala.Int)
  case Asciz(string: String)

  /** The directive as it is written in the assembly file. `Asciz` text is emitted as given. */
  override def toString(): String = this match {
    case IntelSyntax   => ".intel_syntax noprefix"
    case RoData        => ".section .rodata"
    case Text          => ".text"
    case Global(name)  => ".globl " + name
    case Int(value)    => ".int " + value
    case Asciz(string) => ".asciz \"" + string + "\""
  }
}
/** printf-style conversion specifiers. */
enum PrintFormat {
  case Int, Char, String

  override def toString(): String = {
    val conversion = this match {
      case Int    => 'd'
      case Char   => 'c'
      case String => 's'
    }
    s"%$conversion"
  }
}
}

View File

@@ -0,0 +1,35 @@
package wacc
/** Size computations for microWacc expressions and semantic types. */
object sizeExtensions {
  import microWacc._
  import types._
  import assemblyIR.Size

  extension (expr: Expr) {

    /** Calculate the size (bytes) of the heap required for the expression.
      *
      * Array literals need an Int-sized length field followed by their elements; anything
      * else needs the size of its type.
      */
    def heapSize: Int = {
      def lengthField: Int = KnownType.Int.size.toInt

      expr match {
        case ArrayLiter(elems) =>
          expr.ty match {
            case KnownType.Array(KnownType.Char) =>
              lengthField + elems.size * KnownType.Char.size.toInt
            case ty =>
              lengthField + elems.size * ty.elemSize.toInt
          }
        case _ => expr.ty.size.toInt
      }
    }
  }

  extension (ty: SemType) {

    /** Calculate the size (bytes) of a type in a register. */
    def size: Size = ty match {
      case KnownType.Bool | KnownType.Char => Size.B8
      case KnownType.Int                   => Size.D32
      case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
    }

    /** The size of one element: the element type's size for arrays, 64 bits for pairs, and
      * the type's own size for everything else.
      */
    def elemSize: Size = ty match {
      case KnownType.Array(elem) => elem.size
      case KnownType.Pair(_, _)  => Size.Q64
      case _                     => ty.size
    }
  }
}

View File

@@ -1,11 +1,12 @@
package wacc
import java.io.PrintStream
import cats.data.Chain
object writer {
import assemblyIR._
def writeTo(asmList: List[AsmLine], printStream: PrintStream): Unit = {
asmList.foreach(printStream.println)
def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
asmList.iterator.foreach(printStream.println)
}
}

View File

@@ -39,6 +39,17 @@ val errConfig = new ErrorConfig {
)
}
object lexer {
/** Escape sequences, keyed by the text that follows the backslash and mapped to the code of
  * the character the sequence stands for (the Char values widen to Int).
  *
  * The last three entries map a character to itself.
  */
val escapedChars: Map[String, Int] = Map(
"0" -> '\u0000',
"b" -> '\b',
"t" -> '\t',
"n" -> '\n',
"f" -> '\f',
"r" -> '\r',
"\\" -> '\\',
"'" -> '\'',
"\"" -> '\"'
)
/** Language description for the WACC lexer
*/
@@ -63,15 +74,9 @@ object lexer {
textDesc = TextDesc.plain.copy(
graphicCharacter = Basic(c => c >= ' ' && c != '\\' && c != '\'' && c != '"'),
escapeSequences = EscapeDesc.plain.copy(
literals = Set('\\', '"', '\''),
mapping = Map(
"0" -> '\u0000',
"b" -> '\b',
"t" -> '\t',
"n" -> '\n',
"f" -> '\f',
"r" -> '\r'
)
literals =
escapedChars.filter { (s, chr) => chr.toChar.toString == s }.map(_._2.toChar).toSet,
mapping = escapedChars.filter { (s, chr) => chr.toChar.toString != s }
)
),
numericDesc = NumericDesc.plain.copy(

View File

@@ -1,7 +1,5 @@
package wacc
import cats.data.NonEmptyList
object microWacc {
import wacc.types._
@@ -19,9 +17,7 @@ object microWacc {
extends Expr(identTy)
with CallTarget(identTy)
with LValue
case class ArrayElem(value: LValue, indices: NonEmptyList[Expr])(ty: SemType)
extends Expr(ty)
with LValue
case class ArrayElem(value: LValue, index: Expr)(ty: SemType) extends Expr(ty) with LValue
// Operators
case class UnaryOp(x: Expr, op: UnaryOperator)(ty: SemType) extends Expr(ty)
@@ -69,13 +65,16 @@ object microWacc {
// Statements
sealed trait Stmt
/** A built-in call target, identified by name.
  *
  * @param name
  *   The builtin's name; also its string representation.
  * @param retTy
  *   The type passed on to `CallTarget`.
  */
case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) {
override def toString(): String = name
}
object Builtin {
case object ReadInt extends CallTarget(KnownType.Int)
case object ReadChar extends CallTarget(KnownType.Char)
case object Print extends CallTarget(?)
case object Println extends CallTarget(?)
case object Exit extends CallTarget(?)
case object Free extends CallTarget(?)
object Read extends Builtin("read")(?)
object Printf extends Builtin("printf")(?)
object Exit extends Builtin("exit")(?)
object Free extends Builtin("free")(?)
object Malloc extends Builtin("malloc")(?)
object PrintCharArray extends Builtin("printCharArray")(?)
}
case class Assign(lhs: LValue, rhs: Expr) extends Stmt

View File

@@ -110,7 +110,7 @@ object typeChecker {
microWacc.FuncDecl(
microWacc.Ident(name.v, name.uid)(retType),
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
microWacc.Ident(ident.v, name.uid)(ty)
microWacc.Ident(ident.v, ident.uid)(ty)
},
stmts.toList
.flatMap(
@@ -177,12 +177,14 @@ object typeChecker {
microWacc.Assign(
destTyped,
microWacc.Call(
microWacc.Builtin.Read,
List(
destTy match {
case KnownType.Int => microWacc.Builtin.ReadInt
case KnownType.Char => microWacc.Builtin.ReadChar
case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway
case KnownType.Int => "%d".toMicroWaccCharArray
case KnownType.Char | _ => "%c".toMicroWaccCharArray
},
Nil
destTyped
)
)
)
)
@@ -213,12 +215,38 @@ object typeChecker {
)
case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained)
val exprFormat = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Array(KnownType.Char) => "%.*s"
case KnownType.Char => "%c"
case KnownType.Int => "%d"
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
}
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List(
microWacc.Call(
if newline then microWacc.Builtin.Println else microWacc.Builtin.Print,
List(checkValue(expr, Constraint.Unconstrained))
func,
List(
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray,
value
)
)
)
}
exprTyped.ty match {
case KnownType.Bool =>
List(
microWacc.If(
exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
)
)
case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
}
case ast.If(cond, thenStmt, elseStmt) =>
List(
microWacc.If(
@@ -262,7 +290,7 @@ object typeChecker {
microWacc.CharLiter(v)
case l @ ast.StrLiter(v) =>
KnownType.String.satisfies(constraint, l.pos)
microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String)
v.toMicroWaccCharArray
case l @ ast.PairLiter() =>
microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos))
case ast.Parens(expr) => checkValue(expr, constraint)
@@ -410,10 +438,15 @@ object typeChecker {
}
(next, idxTyped)
}
microWacc.ArrayElem(
val firstArrayElem = microWacc.ArrayElem(
microWacc.Ident(id.v, id.uid)(arrayTy),
indicesTyped
indicesTyped.head
)(elemTy.satisfies(constraint, value.pos))
val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) =>
microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty))
}
// Need to type-check the final arrayElem with the constraint
microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos))
case ast.Fst(elem) =>
val elemTyped = checkLValue(
elem,
@@ -421,7 +454,7 @@ object typeChecker {
)
microWacc.ArrayElem(
elemTyped,
NonEmptyList.of(microWacc.IntLiter(0))
microWacc.IntLiter(0)
)(elemTyped.ty match {
case KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos)
@@ -434,11 +467,16 @@ object typeChecker {
)
microWacc.ArrayElem(
elemTyped,
NonEmptyList.of(microWacc.IntLiter(1))
microWacc.IntLiter(1)
)(elemTyped.ty match {
case KnownType.Pair(_, right) =>
right.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
})
}
extension (s: String) {

  /** Converts this string into a microWacc array literal of characters.
    *
    * Every character of the string becomes one `microWacc.CharLiter`, in the
    * original order, and the resulting `microWacc.ArrayLiter` is tagged with
    * `KnownType.String`.
    *
    * @return the array literal holding one character literal per character of `s`
    */
  def toMicroWaccCharArray: microWacc.ArrayLiter = {
    val charLiters = s.toList.map(ch => microWacc.CharLiter(ch))
    microWacc.ArrayLiter(charLiters)(KnownType.String)
  }
}
}

View File

@@ -1,13 +1,14 @@
package wacc
import org.scalatest.{ParallelTestExecution, BeforeAndAfterAll}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.Inspectors.forEvery
import java.io.File
import sys.process._
import java.io.PrintStream
import scala.io.Source
class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with ParallelTestExecution {
class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
val files =
allWaccFiles("wacc-examples/valid").map { p =>
(p.toString, List(0))
@@ -26,19 +27,38 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
forEvery(files) { (filename, expectedResult) =>
val baseFilename = filename.stripSuffix(".wacc")
given stdout: PrintStream = PrintStream(File(baseFilename + ".out"))
val result = compile(filename)
s"$filename" should "be compiled with correct result" in {
val result = compile(filename)
assert(expectedResult.contains(result))
}
if (result == 0) it should "run with correct result" in {
if (expectedResult == List(0)) it should "run with correct result" in {
if (fileIsDisallowedBackend(filename)) pending
// Retrieve contents to get input and expected output + exit code
val contents = scala.io.Source.fromFile(File(filename)).getLines.toList
val inputLine =
contents.find(_.matches("^# ?[Ii]nput:.*$")).map(_.split(":").last.strip).getOrElse("")
contents
.find(_.matches("^# ?[Ii]nput:.*$"))
.map(line =>
("" :: line.split(":").last.strip.split(" ").toList)
.sliding(2)
.flatMap { arr =>
if (
// First entry has no space in front
arr(0) == "" ||
// int followed by non-digit, space can be removed
arr(0).toIntOption.nonEmpty && !arr(1)(0).isDigit ||
// non-int followed by int, space can be removed
!arr(0).last.isDigit && arr(1).toIntOption.nonEmpty
)
then List(arr(1))
else List(" ", arr(1))
}
.mkString
)
.getOrElse("")
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
val expectedOutput =
if (outputLineIdx == -1) ""
@@ -62,12 +82,17 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
// Run the executable with the provided input
val stdout = new StringBuilder
// val execResult = s"$execFilename".!(ProcessLogger(stdout.append(_)))
val execResult =
s"echo $inputLine" #| s"timeout 5s $execFilename" ! ProcessLogger(stdout.append(_))
val process = s"timeout 5s $execFilename" run ProcessIO(
in = w => {
w.write(inputLine.getBytes)
w.close()
},
out = Source.fromInputStream(_).addString(stdout),
err = _ => ()
)
assert(execResult == expectedExit)
assert(stdout.toString == expectedOutput)
assert(process.exitValue == expectedExit)
assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput)
}
}
@@ -80,22 +105,23 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
// format: off
// disable formatting to avoid binPack
"^.*wacc-examples/valid/advanced.*$",
"^.*wacc-examples/valid/array.*$",
"^.*wacc-examples/valid/basic/skip.*$",
"^.*wacc-examples/valid/expressions.*$",
"^.*wacc-examples/valid/function/nested_functions.*$",
"^.*wacc-examples/valid/function/simple_functions.*$",
"^.*wacc-examples/valid/if.*$",
"^.*wacc-examples/valid/IO/print.*$",
"^.*wacc-examples/valid/IO/read(?!echoInt\\.wacc).*$",
"^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
"^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
"^.*wacc-examples/valid/pairs.*$",
// "^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/basic/exit.*$",
// "^.*wacc-examples/valid/basic/skip.*$",
// "^.*wacc-examples/valid/expressions.*$",
// "^.*wacc-examples/valid/function/nested_functions.*$",
// "^.*wacc-examples/valid/function/simple_functions.*$",
// "^.*wacc-examples/valid/if.*$",
// "^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/read.*$",
// "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
// "^.*wacc-examples/valid/pairs.*$",
"^.*wacc-examples/valid/runtimeErr.*$",
"^.*wacc-examples/valid/scope.*$",
"^.*wacc-examples/valid/sequence.*$",
"^.*wacc-examples/valid/variables.*$",
"^.*wacc-examples/valid/while.*$",
// "^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$",
// "^.*wacc-examples/valid/while.*$",
// format: on
).find(filename.matches).isDefined
}

View File

@@ -1,42 +1,38 @@
import org.scalatest.funsuite.AnyFunSuite
import wacc.assemblyIR._
import wacc.assemblyIR.Size._
import wacc.assemblyIR.RegName._
class instructionSpec extends AnyFunSuite {
val named64BitRegister = Register.Named("ax", RegSize.R64)
val named64BitRegister = Register(Q64, AX)
test("named 64-bit register toString") {
assert(named64BitRegister.toString == "rax")
}
val named32BitRegister = Register.Named("ax", RegSize.E32)
val named32BitRegister = Register(D32, AX)
test("named 32-bit register toString") {
assert(named32BitRegister.toString == "eax")
}
val scratch64BitRegister = Register.Scratch(1, RegSize.R64)
val scratch64BitRegister = Register(Q64, R8)
test("scratch 64-bit register toString") {
assert(scratch64BitRegister.toString == "r1")
assert(scratch64BitRegister.toString == "r8")
}
val scratch32BitRegister = Register.Scratch(1, RegSize.E32)
val scratch32BitRegister = Register(D32, R8)
test("scratch 32-bit register toString") {
assert(scratch32BitRegister.toString == "r1d")
assert(scratch32BitRegister.toString == "r8d")
}
val memLocationWithHex = MemLocation(0x12345678)
test("mem location with hex toString") {
assert(memLocationWithHex.toString == "[0x12345678]")
}
val memLocationWithRegister = MemLocation(named64BitRegister)
val memLocationWithRegister = MemLocation(named64BitRegister, Q64)
test("mem location with register toString") {
assert(memLocationWithRegister.toString == "[rax]")
assert(memLocationWithRegister.toString == "qword ptr [rax]")
}
val immediateVal = ImmediateVal(123)
@@ -54,7 +50,7 @@ class instructionSpec extends AnyFunSuite {
val subInstruction = Subtract(scratch64BitRegister, named64BitRegister)
test("x86: sub instruction toString") {
assert(subInstruction.toString == "\tsub r1, rax")
assert(subInstruction.toString == "\tsub r8, rax")
}
val callInstruction = Call(CLibFunc.Scanf)