feat: enhance asmGenerator with additional registers and improve function call generation

Co-authored-by: Barf-Vader <Barf-Vader@users.noreply.github.com>
Co-authored-by: Gleb Koval
This commit is contained in:
Guy C 2025-02-25 02:02:57 +00:00
parent f30cf42c4b
commit 58d280462e
3 changed files with 136 additions and 109 deletions

View File

@ -2,7 +2,6 @@ package wacc
import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.syntax.align
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
@ -18,6 +17,10 @@ object asmGenerator {
val RIP = Register(RegSize.R64, RegName.IP) val RIP = Register(RegSize.R64, RegName.IP)
val RBP = Register(RegSize.R64, RegName.BP) val RBP = Register(RegSize.R64, RegName.BP)
val RSI = Register(RegSize.R64, RegName.SI) val RSI = Register(RegSize.R64, RegName.SI)
val RDX = Register(RegSize.R64, RegName.DX)
val RCX = Register(RegSize.R64, RegName.CX)
val R8 = Register(RegSize.R64, RegName.Reg8)
val R9 = Register(RegSize.R64, RegName.Reg9)
val _8_BIT_MASK = 0xff val _8_BIT_MASK = 0xff
@ -44,10 +47,11 @@ object asmGenerator {
alignStack() ++ alignStack() ++
main.flatMap(generateStmt) ++ main.flatMap(generateStmt) ++
List(Move(RAX, ImmediateVal(0))) ++ List(Move(RAX, ImmediateVal(0))) ++
funcEpilogue() funcEpilogue() ++
generateBuiltInFuncs()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str.replace("\"", "\\\"")))
} }
List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++ List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
@ -73,7 +77,7 @@ object asmGenerator {
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Exit), labelGenerator.getLabel(Builtin.Exit),
alignStack() ++ alignStack() ++
List(Pop(RDI), assemblyIR.Call(CLibFunc.Exit)) List(assemblyIR.Call(CLibFunc.Exit))
) ++ ) ++
wrapFunc( wrapFunc(
labelGenerator.getLabel(Builtin.Printf), labelGenerator.getLabel(Builtin.Printf),
@ -84,16 +88,28 @@ object asmGenerator {
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
) )
) ++ ) ++
wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List()) ++ wrapFunc(
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) labelGenerator.getLabel(Builtin.Malloc),
alignStack() ++
List()
) ++
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Read),
alignStack() ++
List(
Push(RSI),
Load(RSI, MemLocation(RSP)),
assemblyIR.Call(CLibFunc.Scanf),
Pop(RAX)
)
)
} }
def generateStmt( def generateStmt(
stmt: Stmt stmt: Stmt
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
stmt match { stmt match {
case microWacc.Call(Builtin.Exit, code :: _) =>
List()
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
var dest: IndexAddress = var dest: IndexAddress =
IndexAddress(RSP, 0) // gets overrwitten IndexAddress(RSP, 0) // gets overrwitten
@ -112,15 +128,8 @@ object asmGenerator {
// dest = ??? // dest = ???
List() List()
}) ++ }) ++
(rhs match {
case microWacc.Call(Builtin.ReadInt, _) =>
readIntoVar(dest, Builtin.ReadInt)
case microWacc.Call(Builtin.ReadChar, _) =>
readIntoVar(dest, Builtin.ReadChar)
case _ =>
evalExprOntoStack(rhs) ++ evalExprOntoStack(rhs) ++
List(Pop(dest)) List(Pop(dest))
})
case If(cond, thenBranch, elseBranch) => { case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
@ -151,8 +160,7 @@ object asmGenerator {
case microWacc.Return(expr) => case microWacc.Return(expr) =>
evalExprOntoStack(expr) ++ evalExprOntoStack(expr) ++
List(Pop(RAX), assemblyIR.Return()) List(Pop(RAX), assemblyIR.Return())
case call: microWacc.Call => generateCall(call)
case _ => List()
} }
def evalExprOntoStack(expr: Expr)(using def evalExprOntoStack(expr: Expr)(using
@ -213,7 +221,7 @@ object asmGenerator {
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( List(
Pop(EAX), Pop(RAX),
Add(MemLocation(RSP, SizeDir.Word), EAX) Add(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING // TODO OVERFLOWING
) )
@ -221,7 +229,7 @@ object asmGenerator {
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( List(
Pop(EAX), Pop(RAX),
Subtract(MemLocation(RSP, SizeDir.Word), EAX) Subtract(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING // TODO OVERFLOWING
) )
@ -229,28 +237,30 @@ object asmGenerator {
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( List(
Pop(EAX), Pop(RAX),
Multiply(MemLocation(RSP, SizeDir.Word), EAX) Multiply(EAX, MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(RAX)
// TODO OVERFLOWING // TODO OVERFLOWING
) )
case BinaryOperator.Div => case BinaryOperator.Div =>
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
List( List(
Pop(EAX), Pop(RAX),
Divide(MemLocation(RSP, SizeDir.Word)), Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)), Add(RSP, ImmediateVal(8)),
Push(EAX) Push(RAX)
// TODO CHECK DIVISOR IS NOT 0 // TODO CHECK DIVISOR IS NOT 0
) )
case BinaryOperator.Mod => case BinaryOperator.Mod =>
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
List( List(
Pop(EAX), Pop(RAX),
Divide(MemLocation(RSP, SizeDir.Word)), Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)), Add(RSP, ImmediateVal(8)),
Push(EDX) Push(RDX)
// TODO CHECK DIVISOR IS NOT 0 // TODO CHECK DIVISOR IS NOT 0
) )
case BinaryOperator.Eq => case BinaryOperator.Eq =>
@ -269,21 +279,38 @@ object asmGenerator {
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( List(
Pop(EAX), Pop(RAX),
And(MemLocation(RSP, SizeDir.Word), EAX) And(MemLocation(RSP, SizeDir.Word), EAX)
) )
case BinaryOperator.Or => case BinaryOperator.Or =>
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( List(
Pop(EAX), Pop(RAX),
Or(MemLocation(RSP, SizeDir.Word), EAX) Or(MemLocation(RSP, SizeDir.Word), EAX)
) )
} }
case microWacc.Call(target, args) => List() case call: microWacc.Call => generateCall(call)
} }
} }
def generateCall(call: microWacc.Call)(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val microWacc.Call(target, args) = call
argRegs.zip(args).flatMap { (reg, expr) =>
evalExprOntoStack(expr) ++
List(Pop(reg))
} ++
args.drop(argRegs.size).flatMap(evalExprOntoStack) ++
List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
(if (args.size > argRegs.size) {
List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8)))
} else Nil)
}
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// stack: LinkedHashMap[Ident, Int], // stack: LinkedHashMap[Ident, Int],
// strings: ListBuffer[String] // strings: ListBuffer[String]
@ -316,11 +343,11 @@ object asmGenerator {
evalExprOntoStack(x) ++ evalExprOntoStack(x) ++
evalExprOntoStack(y) ++ evalExprOntoStack(y) ++
List( List(
Pop(EAX), Pop(RAX),
Compare(MemLocation(RSP, SizeDir.Word), EAX), Compare(MemLocation(RSP, SizeDir.Word), EAX),
Set(Register(RegSize.Byte, RegName.AL), cond), Set(Register(RegSize.Byte, RegName.AL), cond),
And(EAX, ImmediateVal(_8_BIT_MASK)), And(EAX, ImmediateVal(_8_BIT_MASK)),
Push(EAX) Push(RAX)
) )
} }
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
@ -352,83 +379,83 @@ object asmGenerator {
// def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_)) // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_))
// TODO: refactor, really ugly function // TODO: refactor, really ugly function
def printF(expr: Expr)(using // def printF(expr: Expr)(using
stack: LinkedHashMap[Ident, Int], // stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String] // strings: ListBuffer[String]
): List[AsmLine] = { // ): List[AsmLine] = {
// determine the format string // // determine the format string
expr.ty match { // expr.ty match {
case KnownType.String => // case KnownType.String =>
strings += PrintFormat.String.toString // strings += PrintFormat.String.toString
case KnownType.Char => // case KnownType.Char =>
strings += PrintFormat.Char.toString // strings += PrintFormat.Char.toString
case KnownType.Int => // case KnownType.Int =>
strings += PrintFormat.Int.toString // strings += PrintFormat.Int.toString
case _ => // case _ =>
strings += PrintFormat.String.toString // strings += PrintFormat.String.toString
} // }
List( // List(
Load( // Load(
RDI, // RDI,
IndexAddress( // IndexAddress(
RIP, // RIP,
LabelArg(s".L.str${strings.size - 1}") // LabelArg(s".L.str${strings.size - 1}")
) // )
) // )
) // )
++ // ++
// determine the actual value to print // // determine the actual value to print
(if (expr.ty == KnownType.Bool) { // (if (expr.ty == KnownType.Bool) {
expr match { // expr match {
case BoolLiter(true) => { // case BoolLiter(true) => {
strings += "true" // strings += "true"
} // }
case _ => { // case _ => {
strings += "false" // strings += "false"
} // }
} // }
List( // List(
Load( // Load(
RDI, // RDI,
IndexAddress( // IndexAddress(
RIP, // RIP,
LabelArg(s".L.str${strings.size - 1}") // LabelArg(s".L.str${strings.size - 1}")
) // )
) // )
) // )
} else { // } else {
evalExprOntoStack(expr) ++ // evalExprOntoStack(expr) ++
List(Pop(RSI)) // List(Pop(RSI))
}) // })
// print the value // // print the value
++ // ++
List( // List(
assemblyIR.Call(CLibFunc.PrintF), // assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)), // Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush) // assemblyIR.Call(CLibFunc.Fflush)
) // )
} // }
// prints a new line // prints a new line
def printLn()(using // def printLn()(using
stack: LinkedHashMap[Ident, Int], // stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String] // strings: ListBuffer[String]
): List[AsmLine] = { // ): List[AsmLine] = {
strings += "" // strings += ""
Load( // Load(
RDI, // RDI,
IndexAddress( // IndexAddress(
RIP, // RIP,
LabelArg(s".L.str${strings.size - 1}") // LabelArg(s".L.str${strings.size - 1}")
) // )
) // )
:: // ::
List( // List(
assemblyIR.Call(CLibFunc.Puts), // assemblyIR.Call(CLibFunc.Puts),
Move(RDI, ImmediateVal(0)), // Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush) // assemblyIR.Call(CLibFunc.Fflush)
) // )
} // }
} }

View File

@ -179,7 +179,7 @@ object assemblyIR {
override def toString(): String = this match { override def toString(): String = this match {
case Byte => "byte " + ptr case Byte => "byte " + ptr
case Word => "word " + ptr case Word => "dword " + ptr // TODO check word/doubleword/quadword
case Unspecified => "" case Unspecified => ""
} }
} }