feat: enhance asmGenerator with additional registers and improve function call generation
Co-authored-by: Barf-Vader <Barf-Vader@users.noreply.github.com> Co-authored-by: Gleb Koval
This commit is contained in:
parent
f30cf42c4b
commit
58d280462e
@ -2,7 +2,6 @@ package wacc
|
|||||||
|
|
||||||
import scala.collection.mutable.LinkedHashMap
|
import scala.collection.mutable.LinkedHashMap
|
||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
import cats.syntax.align
|
|
||||||
|
|
||||||
object asmGenerator {
|
object asmGenerator {
|
||||||
import microWacc._
|
import microWacc._
|
||||||
@ -18,6 +17,10 @@ object asmGenerator {
|
|||||||
val RIP = Register(RegSize.R64, RegName.IP)
|
val RIP = Register(RegSize.R64, RegName.IP)
|
||||||
val RBP = Register(RegSize.R64, RegName.BP)
|
val RBP = Register(RegSize.R64, RegName.BP)
|
||||||
val RSI = Register(RegSize.R64, RegName.SI)
|
val RSI = Register(RegSize.R64, RegName.SI)
|
||||||
|
val RDX = Register(RegSize.R64, RegName.DX)
|
||||||
|
val RCX = Register(RegSize.R64, RegName.CX)
|
||||||
|
val R8 = Register(RegSize.R64, RegName.Reg8)
|
||||||
|
val R9 = Register(RegSize.R64, RegName.Reg9)
|
||||||
|
|
||||||
val _8_BIT_MASK = 0xff
|
val _8_BIT_MASK = 0xff
|
||||||
|
|
||||||
@ -44,10 +47,11 @@ object asmGenerator {
|
|||||||
alignStack() ++
|
alignStack() ++
|
||||||
main.flatMap(generateStmt) ++
|
main.flatMap(generateStmt) ++
|
||||||
List(Move(RAX, ImmediateVal(0))) ++
|
List(Move(RAX, ImmediateVal(0))) ++
|
||||||
funcEpilogue()
|
funcEpilogue() ++
|
||||||
|
generateBuiltInFuncs()
|
||||||
|
|
||||||
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
|
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
|
||||||
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str))
|
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str.replace("\"", "\\\"")))
|
||||||
}
|
}
|
||||||
|
|
||||||
List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
|
List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
|
||||||
@ -73,7 +77,7 @@ object asmGenerator {
|
|||||||
wrapFunc(
|
wrapFunc(
|
||||||
labelGenerator.getLabel(Builtin.Exit),
|
labelGenerator.getLabel(Builtin.Exit),
|
||||||
alignStack() ++
|
alignStack() ++
|
||||||
List(Pop(RDI), assemblyIR.Call(CLibFunc.Exit))
|
List(assemblyIR.Call(CLibFunc.Exit))
|
||||||
) ++
|
) ++
|
||||||
wrapFunc(
|
wrapFunc(
|
||||||
labelGenerator.getLabel(Builtin.Printf),
|
labelGenerator.getLabel(Builtin.Printf),
|
||||||
@ -84,16 +88,28 @@ object asmGenerator {
|
|||||||
assemblyIR.Call(CLibFunc.Fflush)
|
assemblyIR.Call(CLibFunc.Fflush)
|
||||||
)
|
)
|
||||||
) ++
|
) ++
|
||||||
wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List()) ++
|
wrapFunc(
|
||||||
wrapFunc(labelGenerator.getLabel(Builtin.Free), List())
|
labelGenerator.getLabel(Builtin.Malloc),
|
||||||
|
alignStack() ++
|
||||||
|
List()
|
||||||
|
) ++
|
||||||
|
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++
|
||||||
|
wrapFunc(
|
||||||
|
labelGenerator.getLabel(Builtin.Read),
|
||||||
|
alignStack() ++
|
||||||
|
List(
|
||||||
|
Push(RSI),
|
||||||
|
Load(RSI, MemLocation(RSP)),
|
||||||
|
assemblyIR.Call(CLibFunc.Scanf),
|
||||||
|
Pop(RAX)
|
||||||
|
)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
def generateStmt(
|
def generateStmt(
|
||||||
stmt: Stmt
|
stmt: Stmt
|
||||||
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
|
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
|
||||||
stmt match {
|
stmt match {
|
||||||
case microWacc.Call(Builtin.Exit, code :: _) =>
|
|
||||||
List()
|
|
||||||
case Assign(lhs, rhs) =>
|
case Assign(lhs, rhs) =>
|
||||||
var dest: IndexAddress =
|
var dest: IndexAddress =
|
||||||
IndexAddress(RSP, 0) // gets overrwitten
|
IndexAddress(RSP, 0) // gets overrwitten
|
||||||
@ -112,15 +128,8 @@ object asmGenerator {
|
|||||||
// dest = ???
|
// dest = ???
|
||||||
List()
|
List()
|
||||||
}) ++
|
}) ++
|
||||||
(rhs match {
|
|
||||||
case microWacc.Call(Builtin.ReadInt, _) =>
|
|
||||||
readIntoVar(dest, Builtin.ReadInt)
|
|
||||||
case microWacc.Call(Builtin.ReadChar, _) =>
|
|
||||||
readIntoVar(dest, Builtin.ReadChar)
|
|
||||||
case _ =>
|
|
||||||
evalExprOntoStack(rhs) ++
|
evalExprOntoStack(rhs) ++
|
||||||
List(Pop(dest))
|
List(Pop(dest))
|
||||||
})
|
|
||||||
case If(cond, thenBranch, elseBranch) => {
|
case If(cond, thenBranch, elseBranch) => {
|
||||||
val elseLabel = labelGenerator.getLabel()
|
val elseLabel = labelGenerator.getLabel()
|
||||||
val endLabel = labelGenerator.getLabel()
|
val endLabel = labelGenerator.getLabel()
|
||||||
@ -151,8 +160,7 @@ object asmGenerator {
|
|||||||
case microWacc.Return(expr) =>
|
case microWacc.Return(expr) =>
|
||||||
evalExprOntoStack(expr) ++
|
evalExprOntoStack(expr) ++
|
||||||
List(Pop(RAX), assemblyIR.Return())
|
List(Pop(RAX), assemblyIR.Return())
|
||||||
|
case call: microWacc.Call => generateCall(call)
|
||||||
case _ => List()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def evalExprOntoStack(expr: Expr)(using
|
def evalExprOntoStack(expr: Expr)(using
|
||||||
@ -213,7 +221,7 @@ object asmGenerator {
|
|||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Add(MemLocation(RSP, SizeDir.Word), EAX)
|
Add(MemLocation(RSP, SizeDir.Word), EAX)
|
||||||
// TODO OVERFLOWING
|
// TODO OVERFLOWING
|
||||||
)
|
)
|
||||||
@ -221,7 +229,7 @@ object asmGenerator {
|
|||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Subtract(MemLocation(RSP, SizeDir.Word), EAX)
|
Subtract(MemLocation(RSP, SizeDir.Word), EAX)
|
||||||
// TODO OVERFLOWING
|
// TODO OVERFLOWING
|
||||||
)
|
)
|
||||||
@ -229,28 +237,30 @@ object asmGenerator {
|
|||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Multiply(MemLocation(RSP, SizeDir.Word), EAX)
|
Multiply(EAX, MemLocation(RSP, SizeDir.Word)),
|
||||||
|
Add(RSP, ImmediateVal(8)),
|
||||||
|
Push(RAX)
|
||||||
// TODO OVERFLOWING
|
// TODO OVERFLOWING
|
||||||
)
|
)
|
||||||
case BinaryOperator.Div =>
|
case BinaryOperator.Div =>
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Divide(MemLocation(RSP, SizeDir.Word)),
|
Divide(MemLocation(RSP, SizeDir.Word)),
|
||||||
Add(RSP, ImmediateVal(8)),
|
Add(RSP, ImmediateVal(8)),
|
||||||
Push(EAX)
|
Push(RAX)
|
||||||
// TODO CHECK DIVISOR IS NOT 0
|
// TODO CHECK DIVISOR IS NOT 0
|
||||||
)
|
)
|
||||||
case BinaryOperator.Mod =>
|
case BinaryOperator.Mod =>
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Divide(MemLocation(RSP, SizeDir.Word)),
|
Divide(MemLocation(RSP, SizeDir.Word)),
|
||||||
Add(RSP, ImmediateVal(8)),
|
Add(RSP, ImmediateVal(8)),
|
||||||
Push(EDX)
|
Push(RDX)
|
||||||
// TODO CHECK DIVISOR IS NOT 0
|
// TODO CHECK DIVISOR IS NOT 0
|
||||||
)
|
)
|
||||||
case BinaryOperator.Eq =>
|
case BinaryOperator.Eq =>
|
||||||
@ -269,21 +279,38 @@ object asmGenerator {
|
|||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
And(MemLocation(RSP, SizeDir.Word), EAX)
|
And(MemLocation(RSP, SizeDir.Word), EAX)
|
||||||
)
|
)
|
||||||
case BinaryOperator.Or =>
|
case BinaryOperator.Or =>
|
||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Or(MemLocation(RSP, SizeDir.Word), EAX)
|
Or(MemLocation(RSP, SizeDir.Word), EAX)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
case microWacc.Call(target, args) => List()
|
case call: microWacc.Call => generateCall(call)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def generateCall(call: microWacc.Call)(using
|
||||||
|
stack: LinkedHashMap[Ident, Int],
|
||||||
|
strings: ListBuffer[String]
|
||||||
|
): List[AsmLine] = {
|
||||||
|
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
|
||||||
|
val microWacc.Call(target, args) = call
|
||||||
|
argRegs.zip(args).flatMap { (reg, expr) =>
|
||||||
|
evalExprOntoStack(expr) ++
|
||||||
|
List(Pop(reg))
|
||||||
|
} ++
|
||||||
|
args.drop(argRegs.size).flatMap(evalExprOntoStack) ++
|
||||||
|
List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
|
||||||
|
(if (args.size > argRegs.size) {
|
||||||
|
List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8)))
|
||||||
|
} else Nil)
|
||||||
|
}
|
||||||
|
|
||||||
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
|
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
|
||||||
// stack: LinkedHashMap[Ident, Int],
|
// stack: LinkedHashMap[Ident, Int],
|
||||||
// strings: ListBuffer[String]
|
// strings: ListBuffer[String]
|
||||||
@ -316,11 +343,11 @@ object asmGenerator {
|
|||||||
evalExprOntoStack(x) ++
|
evalExprOntoStack(x) ++
|
||||||
evalExprOntoStack(y) ++
|
evalExprOntoStack(y) ++
|
||||||
List(
|
List(
|
||||||
Pop(EAX),
|
Pop(RAX),
|
||||||
Compare(MemLocation(RSP, SizeDir.Word), EAX),
|
Compare(MemLocation(RSP, SizeDir.Word), EAX),
|
||||||
Set(Register(RegSize.Byte, RegName.AL), cond),
|
Set(Register(RegSize.Byte, RegName.AL), cond),
|
||||||
And(EAX, ImmediateVal(_8_BIT_MASK)),
|
And(EAX, ImmediateVal(_8_BIT_MASK)),
|
||||||
Push(EAX)
|
Push(RAX)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
|
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
|
||||||
@ -352,83 +379,83 @@ object asmGenerator {
|
|||||||
// def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_))
|
// def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_))
|
||||||
|
|
||||||
// TODO: refactor, really ugly function
|
// TODO: refactor, really ugly function
|
||||||
def printF(expr: Expr)(using
|
// def printF(expr: Expr)(using
|
||||||
stack: LinkedHashMap[Ident, Int],
|
// stack: LinkedHashMap[Ident, Int],
|
||||||
strings: ListBuffer[String]
|
// strings: ListBuffer[String]
|
||||||
): List[AsmLine] = {
|
// ): List[AsmLine] = {
|
||||||
// determine the format string
|
// // determine the format string
|
||||||
expr.ty match {
|
// expr.ty match {
|
||||||
case KnownType.String =>
|
// case KnownType.String =>
|
||||||
strings += PrintFormat.String.toString
|
// strings += PrintFormat.String.toString
|
||||||
case KnownType.Char =>
|
// case KnownType.Char =>
|
||||||
strings += PrintFormat.Char.toString
|
// strings += PrintFormat.Char.toString
|
||||||
case KnownType.Int =>
|
// case KnownType.Int =>
|
||||||
strings += PrintFormat.Int.toString
|
// strings += PrintFormat.Int.toString
|
||||||
case _ =>
|
// case _ =>
|
||||||
strings += PrintFormat.String.toString
|
// strings += PrintFormat.String.toString
|
||||||
}
|
// }
|
||||||
List(
|
// List(
|
||||||
Load(
|
// Load(
|
||||||
RDI,
|
// RDI,
|
||||||
IndexAddress(
|
// IndexAddress(
|
||||||
RIP,
|
// RIP,
|
||||||
LabelArg(s".L.str${strings.size - 1}")
|
// LabelArg(s".L.str${strings.size - 1}")
|
||||||
)
|
// )
|
||||||
)
|
// )
|
||||||
)
|
// )
|
||||||
++
|
// ++
|
||||||
// determine the actual value to print
|
// // determine the actual value to print
|
||||||
(if (expr.ty == KnownType.Bool) {
|
// (if (expr.ty == KnownType.Bool) {
|
||||||
expr match {
|
// expr match {
|
||||||
case BoolLiter(true) => {
|
// case BoolLiter(true) => {
|
||||||
strings += "true"
|
// strings += "true"
|
||||||
}
|
// }
|
||||||
case _ => {
|
// case _ => {
|
||||||
strings += "false"
|
// strings += "false"
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
List(
|
// List(
|
||||||
Load(
|
// Load(
|
||||||
RDI,
|
// RDI,
|
||||||
IndexAddress(
|
// IndexAddress(
|
||||||
RIP,
|
// RIP,
|
||||||
LabelArg(s".L.str${strings.size - 1}")
|
// LabelArg(s".L.str${strings.size - 1}")
|
||||||
)
|
// )
|
||||||
)
|
// )
|
||||||
)
|
// )
|
||||||
|
|
||||||
} else {
|
// } else {
|
||||||
evalExprOntoStack(expr) ++
|
// evalExprOntoStack(expr) ++
|
||||||
List(Pop(RSI))
|
// List(Pop(RSI))
|
||||||
})
|
// })
|
||||||
// print the value
|
// // print the value
|
||||||
++
|
// ++
|
||||||
List(
|
// List(
|
||||||
assemblyIR.Call(CLibFunc.PrintF),
|
// assemblyIR.Call(CLibFunc.PrintF),
|
||||||
Move(RDI, ImmediateVal(0)),
|
// Move(RDI, ImmediateVal(0)),
|
||||||
assemblyIR.Call(CLibFunc.Fflush)
|
// assemblyIR.Call(CLibFunc.Fflush)
|
||||||
)
|
// )
|
||||||
}
|
// }
|
||||||
|
|
||||||
// prints a new line
|
// prints a new line
|
||||||
def printLn()(using
|
// def printLn()(using
|
||||||
stack: LinkedHashMap[Ident, Int],
|
// stack: LinkedHashMap[Ident, Int],
|
||||||
strings: ListBuffer[String]
|
// strings: ListBuffer[String]
|
||||||
): List[AsmLine] = {
|
// ): List[AsmLine] = {
|
||||||
strings += ""
|
// strings += ""
|
||||||
Load(
|
// Load(
|
||||||
RDI,
|
// RDI,
|
||||||
IndexAddress(
|
// IndexAddress(
|
||||||
RIP,
|
// RIP,
|
||||||
LabelArg(s".L.str${strings.size - 1}")
|
// LabelArg(s".L.str${strings.size - 1}")
|
||||||
)
|
// )
|
||||||
)
|
// )
|
||||||
::
|
// ::
|
||||||
List(
|
// List(
|
||||||
assemblyIR.Call(CLibFunc.Puts),
|
// assemblyIR.Call(CLibFunc.Puts),
|
||||||
Move(RDI, ImmediateVal(0)),
|
// Move(RDI, ImmediateVal(0)),
|
||||||
assemblyIR.Call(CLibFunc.Fflush)
|
// assemblyIR.Call(CLibFunc.Fflush)
|
||||||
)
|
// )
|
||||||
|
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
|
@ -179,7 +179,7 @@ object assemblyIR {
|
|||||||
|
|
||||||
override def toString(): String = this match {
|
override def toString(): String = this match {
|
||||||
case Byte => "byte " + ptr
|
case Byte => "byte " + ptr
|
||||||
case Word => "word " + ptr
|
case Word => "dword " + ptr // TODO check word/doubleword/quadword
|
||||||
case Unspecified => ""
|
case Unspecified => ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user