feat: implements binary operators in asmGenerator

Co-authored-by: Gleb Koval <gleb@koval.net>
Co-authored-by: Barf-Vader <Barf-Vader@users.noreply.github.com>
This commit is contained in:
Guy C
2025-02-25 00:00:12 +00:00
parent 668d7338ae
commit 1488281223
4 changed files with 285 additions and 109 deletions

View File

@@ -2,19 +2,37 @@ package wacc
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer
import cats.syntax.align
object labelGenerator {
var labelVal = -1
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
}
object asmGenerator {
import microWacc._
import assemblyIR._
import wacc.types._
val RAX = Register(RegSize.R64, RegName.AX)
val EAX = Register(RegSize.E32, RegName.AX)
val RSP = Register(RegSize.R64, RegName.SP)
val ESP = Register(RegSize.E32, RegName.SP)
val EDX = Register(RegSize.E32, RegName.DX)
val RDI = Register(RegSize.R64, RegName.DI)
val RIP = Register(RegSize.R64, RegName.IP)
val RBP = Register(RegSize.R64, RegName.BP)
val RSI = Register(RegSize.R64, RegName.SI)
val _8_BIT_MASK = 0xFF
object labelGenerator {
var labelVal = -1
def getLabel(): String = {
labelVal += 1
s".L$labelVal"
}
def getLabel(target: CallTarget): String = target match{
case Ident(v,_) => s"wacc_$v"
case Builtin(name) => s"_$name"
}
}
def generateAsm(microProg: Program): List[AsmLine] = {
given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]()
given strings: ListBuffer[String] = ListBuffer[String]()
@@ -25,10 +43,8 @@ object asmGenerator {
funcPrologue() ++
alignStack() ++
main.flatMap(generateStmt) ++
List(Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0))) ++
funcEpilogue() ++
List(assemblyIR.Return()) ++
generateFuncs()
List(Move(RAX, ImmediateVal(0))) ++
funcEpilogue()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str))
@@ -40,12 +56,38 @@ object asmGenerator {
progAsm
}
//TODO
def generateFuncs()(using
def wrapFunc(labelName: String, funcBody: List[AsmLine])(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
LabelDef(labelName) ::
funcPrologue() ++
funcBody ++
funcEpilogue()
}
def generateBuiltInFuncs()(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
wrapFunc(labelGenerator.getLabel(Builtin.Exit),
alignStack() ++
List(Pop(RDI),
assemblyIR.Call(CLibFunc.Exit))
) ++
wrapFunc(labelGenerator.getLabel(Builtin.Printf),
alignStack() ++
List(assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
)++
wrapFunc(labelGenerator.getLabel(Builtin.Malloc),
List()
)++
wrapFunc(labelGenerator.getLabel(Builtin.Free),
List()
)
}
def generateStmt(
@@ -53,28 +95,16 @@ object asmGenerator {
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
stmt match {
case microWacc.Call(Builtin.Exit, code :: _) =>
// alignStack() ++
evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++
List(assemblyIR.Call(CLibFunc.Exit))
case microWacc.Call(Builtin.Println, expr :: _) =>
// alignStack() ++
printF(expr) ++
printLn()
case microWacc.Call(Builtin.Print, expr :: _) =>
// alignStack() ++
printF(expr)
List()
case Assign(lhs, rhs) =>
var dest: IndexAddress =
IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten
IndexAddress(RSP, 0) // gets overrwitten
(lhs match {
case ident: Ident =>
if (!stack.contains(ident)) {
stack += (ident -> (stack.size + 1))
dest = accessVar(ident)
List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16)))
List(Subtract(RSP, ImmediateVal(8)))
} else {
dest = accessVar(ident)
List()
@@ -90,15 +120,16 @@ object asmGenerator {
case microWacc.Call(Builtin.ReadChar, _) =>
readIntoVar(dest, Builtin.ReadChar)
case _ =>
evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++
List(Move(dest, Register(RegSize.R64, RegName.AX)))
evalExprOntoStack(rhs) ++
List(Pop(dest))
})
case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++
evalExprOntoStack(cond) ++
List(
Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Jump(LabelArg(elseLabel), Cond.Equal)
) ++
thenBranch.flatMap(generateStmt) ++
@@ -110,30 +141,33 @@ object asmGenerator {
val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
List(LabelDef(startLabel)) ++
evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++
evalExprOntoStack(cond) ++
List(
Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)),
Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Jump(LabelArg(endLabel), Cond.Equal)
) ++
body.flatMap(generateStmt) ++
List(Jump(LabelArg(startLabel)), LabelDef(endLabel))
}
case microWacc.Return(expr) =>
evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX))
evalExprOntoStack(expr) ++
List(Pop(RAX), assemblyIR.Return())
case _ => List()
}
def evalExprIntoReg(expr: Expr, dest: Register)(using
def evalExprOntoStack(expr: Expr)(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
expr match {
case IntLiter(v) =>
List(Move(dest, ImmediateVal(v)))
List(Push(ImmediateVal(v)))
case CharLiter(v) =>
List(Move(dest, ImmediateVal(v.toInt)))
List(Push(ImmediateVal(v.toInt)))
case ident: Ident =>
List(Move(dest, accessVar(ident)))
List(Push(accessVar(ident)))
case ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
@@ -143,70 +177,173 @@ object asmGenerator {
}.mkString
List(
Load(
dest,
RAX,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
),
Push(RAX)
)
// TODO other array types
case _ => List()
}
// TODO other expr types
case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0)))
case _ => List()
case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => List(Push(ImmediateVal(0)))
case ArrayElem(value, indices) => List()
case UnaryOp(x, op) => op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c.
case UnaryOperator.Chr => List()
case UnaryOperator.Ord => List()
case UnaryOperator.Len => List()
case UnaryOperator.Negate => List(
Negate(MemLocation(RSP,SizeDir.Word))
)
case UnaryOperator.Not =>
evalExprOntoStack(x) ++
List(
Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1))
)
}
case BinaryOp(x, y, op) => op match {
case BinaryOperator.Add =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Add(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Sub =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Subtract(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Mul =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Multiply(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Div =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(EAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(EAX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Mod =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(EAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(EDX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Eq =>
generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq =>
generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater =>
generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq =>
generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less =>
generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq =>
generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
And(MemLocation(RSP, SizeDir.Word), EAX),
)
case BinaryOperator.Or =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Or(MemLocation(RSP, SizeDir.Word), EAX),
)
}
case microWacc.Call(target, args) => List()
}
}
// TODO make sure EOF doenst override the value in the stack
// probably need labels implemented for conditional jumps
def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// stack: LinkedHashMap[Ident, Int],
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// readType match {
// case Builtin.ReadInt =>
// strings += PrintFormat.Int.toString
// case Builtin.ReadChar =>
// strings += PrintFormat.Char.toString
// }
// List(
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// ),
// Load(RSI, dest)
// ) ++
// // alignStack() ++
// List(assemblyIR.Call(CLibFunc.Scanf))
// }
def generateComparison(x : Expr, y: Expr, cond: Cond)(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
readType match {
case Builtin.ReadInt =>
strings += PrintFormat.Int.toString
case Builtin.ReadChar =>
strings += PrintFormat.Char.toString
}
List(
Load(
Register(RegSize.R64, RegName.DI),
IndexAddress(
Register(RegSize.R64, RegName.IP),
LabelArg(s".L.str${strings.size - 1}")
)
),
Load(Register(RegSize.R64, RegName.SI), dest)
) ++
// alignStack() ++
List(assemblyIR.Call(CLibFunc.Scanf))
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Compare(MemLocation(RSP, SizeDir.Word), EAX),
Set(Register(RegSize.Byte, RegName.AL), cond),
And(EAX, ImmediateVal(_8_BIT_MASK)),
Push(EAX)
)
}
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16)
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
List(
And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16))
And(RSP, ImmediateVal(-16))
)
}
// Missing a sub instruction but dont think we need it
def funcPrologue(): List[AsmLine] = {
List(
Push(Register(RegSize.R64, RegName.BP)),
Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP))
Push(RBP),
Move(RBP, RSP)
)
}
def funcEpilogue(): List[AsmLine] = {
List(
Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)),
Pop(Register(RegSize.R64, RegName.BP))
Move(RSP, RBP),
Pop(RBP),
assemblyIR.Return()
)
}
@@ -231,9 +368,9 @@ object asmGenerator {
}
List(
Load(
Register(RegSize.R64, RegName.DI),
RDI,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
@@ -251,22 +388,23 @@ object asmGenerator {
}
List(
Load(
Register(RegSize.R64, RegName.DI),
RDI,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
)
} else {
evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI))
evalExprOntoStack(expr) ++
List(Pop(RSI))
})
// print the value
++
List(
assemblyIR.Call(CLibFunc.PrintF),
Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
}
@@ -278,16 +416,16 @@ object asmGenerator {
): List[AsmLine] = {
strings += ""
Load(
Register(RegSize.R64, RegName.DI),
RDI,
IndexAddress(
Register(RegSize.R64, RegName.IP),
RIP,
LabelArg(s".L.str${strings.size - 1}")
)
)
::
List(
assemblyIR.Call(CLibFunc.Puts),
Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)