feat: implements binary operators in asmGenerator

Co-authored-by: Gleb Koval <gleb@koval.net>
Co-authored-by: Barf-Vader <Barf-Vader@users.noreply.github.com>
This commit is contained in:
Guy C 2025-02-25 00:00:12 +00:00
parent 668d7338ae
commit 1488281223
4 changed files with 285 additions and 109 deletions

View File

@ -2,19 +2,37 @@ package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer
import cats.syntax.align
object labelGenerator {
var labelVal = -1
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
}
object asmGenerator {
import microWacc._
import assemblyIR._
import wacc.types._
// Shorthand Register values so the generator does not have to spell out
// Register(RegSize.…, RegName.…) at every use site.
// R-prefixed names use RegSize.R64, E-prefixed names use RegSize.E32.
val RAX = Register(RegSize.R64, RegName.AX)
val EAX = Register(RegSize.E32, RegName.AX)
val RSP = Register(RegSize.R64, RegName.SP)
val ESP = Register(RegSize.E32, RegName.SP)
val EDX = Register(RegSize.E32, RegName.DX)
val RDI = Register(RegSize.R64, RegName.DI)
val RIP = Register(RegSize.R64, RegName.IP)
val RBP = Register(RegSize.R64, RegName.BP)
val RSI = Register(RegSize.R64, RegName.SI)
// Mask with only the low 8 bits set; generateComparison ANDs EAX with it
// after writing the condition result into AL.
val _8_BIT_MASK = 0xFF
/** Produces assembly label names.
  *
  * Keeps a mutable counter (`labelVal`, starting at -1) for anonymous local
  * labels and derives fixed names for call targets.
  */
object labelGenerator {
  var labelVal = -1

  /** Advances the counter and returns the next local label: ".L0", ".L1", ... */
  def getLabel(): String = {
    labelVal = labelVal + 1
    ".L" + labelVal
  }

  /** Returns the label of a call target: "wacc_<name>" for an [[Ident]],
    * "_<name>" for a [[Builtin]]. Does not touch the counter.
    */
  def getLabel(target: CallTarget): String =
    target match {
      case Builtin(name) => "_" + name
      case Ident(v, _)   => "wacc_" + v
    }
}
def generateAsm(microProg: Program): List[AsmLine] = {
given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]()
given strings: ListBuffer[String] = ListBuffer[String]()
@ -25,10 +43,8 @@ object asmGenerator {
funcPrologue() ++
alignStack() ++
main.flatMap(generateStmt) ++
List(Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0))) ++
funcEpilogue() ++
List(assemblyIR.Return()) ++
generateFuncs()
List(Move(RAX, ImmediateVal(0))) ++
funcEpilogue()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str))
@ -40,12 +56,38 @@ object asmGenerator {
progAsm
}
//TODO
def generateFuncs()(using
/** Wraps a function body into a labelled, framed routine.
  *
  * @param labelName label emitted as the first line of the routine
  * @param funcBody  instructions placed between the prologue and the epilogue
  * @return `LabelDef(labelName)`, then `funcPrologue()`, `funcBody` and
  *         `funcEpilogue()`, in that order
  */
def wrapFunc(labelName: String, funcBody: List[AsmLine])(using
    stack: LinkedHashMap[Ident, Int],
    strings: ListBuffer[String]
): List[AsmLine] = {
  val entry = LabelDef(labelName)
  val framedBody = funcPrologue() ++ funcBody ++ funcEpilogue()
  entry :: framedBody
}
/** Emits the wrapped routines for the built-in call targets, in the order
  * Exit, Printf, Malloc, Free. Each routine is labelled with
  * `labelGenerator.getLabel(target)` and framed by `wrapFunc`.
  *
  * The Malloc and Free bodies are currently empty.
  *
  * NOTE(review): in the Exit body, `Pop(RDI)` runs right after the prologue's
  * `Push(RBP)` and the alignment mask from `alignStack()`, so which value it
  * pops depends on the caller's stack layout — confirm the intended
  * argument-passing convention.
  */
def generateBuiltInFuncs()(using
    stack: LinkedHashMap[Ident, Int],
    strings: ListBuffer[String]
): List[AsmLine] = {
  val exitBody: List[AsmLine] =
    alignStack() ++ List(Pop(RDI), assemblyIR.Call(CLibFunc.Exit))

  val printfBody: List[AsmLine] =
    alignStack() ++ List(
      assemblyIR.Call(CLibFunc.PrintF),
      Move(RDI, ImmediateVal(0)),
      assemblyIR.Call(CLibFunc.Fflush)
    )

  // Target paired with its body; the list order is the emission order.
  val builtins: List[(CallTarget, List[AsmLine])] = List(
    (Builtin.Exit, exitBody),
    (Builtin.Printf, printfBody),
    (Builtin.Malloc, Nil),
    (Builtin.Free, Nil)
  )

  builtins.flatMap { case (target, body) =>
    wrapFunc(labelGenerator.getLabel(target), body)
  }
}
def generateStmt(
@ -53,28 +95,16 @@ object asmGenerator {
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
stmt match {
case microWacc.Call(Builtin.Exit, code :: _) =>
// alignStack() ++
evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++
List(assemblyIR.Call(CLibFunc.Exit))
case microWacc.Call(Builtin.Println, expr :: _) =>
// alignStack() ++
printF(expr) ++
printLn()
case microWacc.Call(Builtin.Print, expr :: _) =>
// alignStack() ++
printF(expr)
List()
case Assign(lhs, rhs) =>
var dest: IndexAddress =
IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten
IndexAddress(RSP, 0) // gets overrwitten
(lhs match {
case ident: Ident =>
if (!stack.contains(ident)) {
stack += (ident -> (stack.size + 1))
dest = accessVar(ident)
List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16)))
List(Subtract(RSP, ImmediateVal(8)))
} else {
dest = accessVar(ident)
List()
@ -90,15 +120,16 @@ object asmGenerator {
case microWacc.Call(Builtin.ReadChar, _) =>
readIntoVar(dest, Builtin.ReadChar)
case _ =>
evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++
List(Move(dest, Register(RegSize.R64, RegName.AX)))
evalExprOntoStack(rhs) ++
List(Pop(dest))
})
case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++
evalExprOntoStack(cond) ++
List(
Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Jump(LabelArg(elseLabel), Cond.Equal)
) ++
thenBranch.flatMap(generateStmt) ++
@ -110,30 +141,33 @@ object asmGenerator {
val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
List(LabelDef(startLabel)) ++
evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++
evalExprOntoStack(cond) ++
List(
Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Jump(LabelArg(endLabel), Cond.Equal)
) ++
body.flatMap(generateStmt) ++
List(Jump(LabelArg(startLabel)), LabelDef(endLabel))
}
case microWacc.Return(expr) =>
evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX))
evalExprOntoStack(expr) ++
List(Pop(RAX), assemblyIR.Return())
case _ => List()
}
def evalExprIntoReg(expr: Expr, dest: Register)(using
def evalExprOntoStack(expr: Expr)(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
expr match {
case IntLiter(v) =>
List(Move(dest, ImmediateVal(v)))
List(Push(ImmediateVal(v)))
case CharLiter(v) =>
List(Move(dest, ImmediateVal(v.toInt)))
List(Push(ImmediateVal(v.toInt)))
case ident: Ident =>
List(Move(dest, accessVar(ident)))
List(Push(accessVar(ident)))
case ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
@ -143,70 +177,173 @@ object asmGenerator {
}.mkString
List(
Load(
dest,
RAX,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
),
Push(RAX)
)
// TODO other array types
case _ => List()
}
// TODO other expr types
case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0)))
case _ => List()
case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => List(Push(ImmediateVal(0)))
case ArrayElem(value, indices) => List()
case UnaryOp(x, op) => op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c.
case UnaryOperator.Chr => List()
case UnaryOperator.Ord => List()
case UnaryOperator.Len => List()
case UnaryOperator.Negate => List(
Negate(MemLocation(RSP,SizeDir.Word))
)
case UnaryOperator.Not =>
evalExprOntoStack(x) ++
List(
Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1))
)
}
case BinaryOp(x, y, op) => op match {
case BinaryOperator.Add =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Add(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Sub =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Subtract(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Mul =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Multiply(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Div =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(EAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(EAX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Mod =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(EAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(EDX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Eq =>
generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq =>
generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater =>
generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq =>
generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less =>
generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq =>
generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
And(MemLocation(RSP, SizeDir.Word), EAX),
)
case BinaryOperator.Or =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Or(MemLocation(RSP, SizeDir.Word), EAX),
)
}
case microWacc.Call(target, args) => List()
}
}
// TODO: make sure EOF doesn't overwrite the value on the stack;
// this probably needs labels implemented for conditional jumps
def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// stack: LinkedHashMap[Ident, Int],
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// readType match {
// case Builtin.ReadInt =>
// strings += PrintFormat.Int.toString
// case Builtin.ReadChar =>
// strings += PrintFormat.Char.toString
// }
// List(
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// ),
// Load(RSI, dest)
// ) ++
// // alignStack() ++
// List(assemblyIR.Call(CLibFunc.Scanf))
// }
/** Emits code that evaluates `x` and `y`, compares them, and leaves exactly
  * one stack slot holding the outcome of `cond` (1 when it holds, 0 otherwise).
  *
  * Both operands are pushed by `evalExprOntoStack`; `y` is popped into RAX and
  * compared against `x`, which is still in the top slot. `x`'s slot is then
  * released before the result is pushed, so the net effect on the stack is a
  * single push, like every other expression form.
  *
  * Push and pop use the 64-bit RAX because x86-64 has no encoding for a
  * 32-bit push/pop; the compare itself still reads EAX.
  *
  * NOTE(review): the compare keeps the `SizeDir.Word` directive used by the
  * other operators — confirm that it matches the 32-bit EAX operand.
  *
  * @param x    left-hand operand
  * @param y    right-hand operand
  * @param cond condition code handed to `Set`
  * @return the instruction sequence described above
  */
def generateComparison(x: Expr, y: Expr, cond: Cond)(using
    stack: LinkedHashMap[Ident, Int],
    strings: ListBuffer[String]
): List[AsmLine] = {
  evalExprOntoStack(x) ++
    evalExprOntoStack(y) ++
    List(
      // y comes off the stack; x stays in the top slot for the compare.
      Pop(RAX),
      Compare(MemLocation(RSP, SizeDir.Word), EAX),
      Set(Register(RegSize.Byte, RegName.AL), cond),
      // Clear everything above AL so the register holds exactly 0 or 1.
      And(EAX, ImmediateVal(_8_BIT_MASK)),
      // Drop x's slot. This must come after Set, because Add rewrites the flags.
      Add(RSP, ImmediateVal(8)),
      Push(RAX)
    )
}
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16)
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
List(
And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16))
And(RSP, ImmediateVal(-16))
)
}
// Missing a `sub` instruction, but we don't think we need it
def funcPrologue(): List[AsmLine] = {
List(
Push(Register(RegSize.R64, RegName.BP)),
Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP))
Push(RBP),
Move(RBP, RSP)
)
}
def funcEpilogue(): List[AsmLine] = {
List(
Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)),
Pop(Register(RegSize.R64, RegName.BP))
Move(RSP, RBP),
Pop(RBP),
assemblyIR.Return()
)
}
@ -231,9 +368,9 @@ object asmGenerator {
}
List(
Load(
Register(RegSize.R64, RegName.DI),
RDI,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
@ -251,22 +388,23 @@ object asmGenerator {
}
List(
Load(
Register(RegSize.R64, RegName.DI),
RDI,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
)
} else {
evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI))
evalExprOntoStack(expr) ++
List(Pop(RSI))
})
// print the value
++
List(
assemblyIR.Call(CLibFunc.PrintF),
Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
}
@ -278,16 +416,16 @@ object asmGenerator {
): List[AsmLine] = {
strings += ""
Load(
Register(RegSize.R64, RegName.DI),
RDI,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
::
List(
assemblyIR.Call(CLibFunc.Puts),
Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)

View File

@ -9,17 +9,20 @@ object assemblyIR {
/** Width selector of a register; its string form is the prefix placed in
  * front of the register name: "r" for R64, "e" for E32, nothing for Byte.
  */
enum RegSize {
  case R64, E32, Byte

  override def toString: String = this match {
    case Byte => ""
    case E32  => "e"
    case R64  => "r"
  }
}
enum RegName {
case AX, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15
case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15
override def toString = this match {
case AX => "ax"
case AL => "al"
case BX => "bx"
case CX => "cx"
case DX => "dx"
@ -42,7 +45,6 @@ object assemblyIR {
// arguments
enum CLibFunc extends Operand {
case Scanf,
Puts,
Fflush,
Exit,
PrintF
@ -51,24 +53,24 @@ object assemblyIR {
override def toString = this match {
case Scanf => "scanf" + plt
case Puts => "puts" + plt
case Fflush => "fflush" + plt
case Exit => "exit" + plt
case PrintF => "printf" + plt
}
}
//TODO register naming conventions are wrong
/** A register operand, usable as both destination and source.
  *
  * Rendered as the size prefix immediately followed by the register name,
  * e.g. RegSize.R64 with RegName.AX gives "rax".
  */
case class Register(size: RegSize, name: RegName) extends Dest with Src {
  override def toString = size.toString + name.toString
}
case class MemLocation(pointer: Long | Register) extends Dest with Src {
case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src {
override def toString = pointer match {
case hex: Long => f"[0x$hex%X]"
case reg: Register => s"[$reg]"
case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => opSize.toString + s"[$reg]"
}
}
case class IndexAddress(base: Register, offset: Int | LabelArg) extends Dest with Src {
override def toString = s"[$base + $offset]"
case class IndexAddress(base: Register, offset: Int | LabelArg, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src {
override def toString = s"$opSize[$base + $offset]"
}
case class ImmediateVal(value: Int) extends Src {
@ -85,11 +87,13 @@ object assemblyIR {
}
case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2)
case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2)
case class Multiply(ops: Operand*) extends Operation("mul", ops*)
case class Divide(op1: Src) extends Operation("div", op1)
case class Multiply(ops: Operand*) extends Operation("imul", ops*)
case class Divide(op1: Src) extends Operation("idiv", op1)
case class Negate(op: Dest) extends Operation("neg", op)
case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2)
case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2)
case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2)
case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2)
// stack operations
@ -106,6 +110,9 @@ object assemblyIR {
case class Jump(op1: LabelArg, condition: Cond = Cond.Always)
extends Operation(s"j${condition.toString}", op1)
// Conditional set: the mnemonic is "set" followed by the condition's string
// form, applied to the single operand op1.
// NOTE(review): the default Cond.Always is the same default Jump uses —
// confirm that "set" + Cond.Always.toString is a real mnemonic, or always
// pass an explicit condition.
// NOTE(review): inside assemblyIR this name shadows the collection type Set.
case class Set(op1: Dest, condition: Cond = Cond.Always)
extends Operation(s"set${condition.toString}", op1)
case class LabelDef(name: String) extends AsmLine {
override def toString = s"$name:"
}
@ -156,4 +163,16 @@ object assemblyIR {
case String => "%s"
}
}
}
/** Operand-size directive placed in front of a memory operand.
  *
  * Each sized case renders as "<size> ptr " (with the trailing space) so it
  * can be concatenated directly before "[...]"; `Unspecified` renders as the
  * empty string.
  *
  * `DWord` and `QWord` are appended after the original cases, so the
  * ordinals and strings of `Byte`, `Word` and `Unspecified` are unchanged.
  * They cover 32-bit and 64-bit memory operands, which `Word` (16-bit in
  * Intel syntax) cannot express.
  */
enum SizeDir {
  case Byte, Word, Unspecified, DWord, QWord

  override def toString(): String = this match {
    case Unspecified => ""
    case Byte        => "byte ptr "
    case Word        => "word ptr "
    case DWord       => "dword ptr "
    case QWord       => "qword ptr "
  }
}
}

View File

@ -69,13 +69,15 @@ object microWacc {
// Statements
sealed trait Stmt
/** A built-in call target identified by `name`; `retTy` is forwarded to
  * [[CallTarget]]. Its string form is just the name.
  */
case class Builtin(name: String)(retTy: SemType) extends CallTarget(retTy) {
  override def toString(): String = this.name
}
object Builtin {
case object ReadInt extends CallTarget(KnownType.Int)
case object ReadChar extends CallTarget(KnownType.Char)
case object Print extends CallTarget(?)
case object Println extends CallTarget(?)
case object Exit extends CallTarget(?)
case object Free extends CallTarget(?)
object Read extends Builtin("read")(?)
object Printf extends Builtin("printf")(?)
object Exit extends Builtin("exit")(?)
object Free extends Builtin("free")(?)
object Malloc extends Builtin("malloc")(?)
}
case class Assign(lhs: LValue, rhs: Expr) extends Stmt

View File

@ -177,12 +177,14 @@ object typeChecker {
microWacc.Assign(
destTyped,
microWacc.Call(
destTy match {
case KnownType.Int => microWacc.Builtin.ReadInt
case KnownType.Char => microWacc.Builtin.ReadChar
case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway
},
Nil
microWacc.Builtin.Read,
List(
destTy match {
case KnownType.Int => "%d".toMicroWaccCharArray
case KnownType.Char | _ => "%c".toMicroWaccCharArray
},
destTyped
)
)
)
)
@ -213,10 +215,20 @@ object typeChecker {
)
case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained)
val format = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Char => "%c"
case KnownType.Int => "%d"
case _ => "%p"
}
List(
microWacc.Call(
if newline then microWacc.Builtin.Println else microWacc.Builtin.Print,
List(checkValue(expr, Constraint.Unconstrained))
microWacc.Builtin.Printf,
List(
s"$format${if newline then "\n" else ""}".toMicroWaccCharArray,
exprTyped
)
)
)
case ast.If(cond, thenStmt, elseStmt) =>
@ -262,7 +274,7 @@ object typeChecker {
microWacc.CharLiter(v)
case l @ ast.StrLiter(v) =>
KnownType.String.satisfies(constraint, l.pos)
microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String)
v.toMicroWaccCharArray
case l @ ast.PairLiter() =>
microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos))
case ast.Parens(expr) => checkValue(expr, constraint)
@ -441,4 +453,9 @@ object typeChecker {
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
})
}
extension (s: String) {

  /** Converts the string into a microWacc array literal typed as
    * `KnownType.String`, with one `CharLiter` per character in order.
    */
  def toMicroWaccCharArray: microWacc.ArrayLiter = {
    val chars = s.toList.map(c => microWacc.CharLiter(c))
    microWacc.ArrayLiter(chars)(KnownType.String)
  }
}
}