refactor: extract stack into seperate class

This commit is contained in:
Alex Ling 2025-02-25 04:44:08 +00:00
parent 8ed94e4df3
commit 3f76a2c5bf

View File

@ -10,7 +10,6 @@ object asmGenerator {
val RAX = Register(RegSize.R64, RegName.AX)
val EAX = Register(RegSize.E32, RegName.AX)
val RSP = Register(RegSize.R64, RegName.SP)
val ESP = Register(RegSize.E32, RegName.SP)
val EDX = Register(RegSize.E32, RegName.DX)
val RDI = Register(RegSize.R64, RegName.DI)
@ -37,14 +36,14 @@ object asmGenerator {
}
def generateAsm(microProg: Program): List[AsmLine] = {
given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]()
given stack: Stack = Stack()
given strings: ListBuffer[String] = ListBuffer[String]()
val Program(funcs, main) = microProg
val progAsm =
LabelDef("main") ::
funcPrologue() ++
alignStack() ++
List(stack.align()) ++
main.flatMap(generateStmt) ++
List(Move(RAX, ImmediateVal(0))) ++
funcEpilogue() ++
@ -61,7 +60,7 @@ object asmGenerator {
}
def wrapFunc(labelName: String, funcBody: List[AsmLine])(using
stack: LinkedHashMap[Ident, Int],
stack: Stack,
strings: ListBuffer[String]
): List[AsmLine] = {
LabelDef(labelName) ::
@ -71,74 +70,71 @@ object asmGenerator {
}
def generateBuiltInFuncs()(using
stack: LinkedHashMap[Ident, Int],
stack: Stack,
strings: ListBuffer[String]
): List[AsmLine] = {
wrapFunc(
labelGenerator.getLabel(Builtin.Exit),
alignStack() ++
List(assemblyIR.Call(CLibFunc.Exit))
List(stack.align(), assemblyIR.Call(CLibFunc.Exit))
) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Printf),
alignStack() ++
List(
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
List(
stack.align(),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Malloc),
alignStack() ++
List()
List(
stack.align(),
)
) ++
wrapFunc(labelGenerator.getLabel(Builtin.Free), List()) ++
wrapFunc(
labelGenerator.getLabel(Builtin.Read),
alignStack() ++
List(
Push(RSI),
Load(RSI, MemLocation(RSP)),
assemblyIR.Call(CLibFunc.Scanf),
Pop(RAX)
)
List(
stack.align(),
stack.push(RSI),
Load(RSI, stack.head),
assemblyIR.Call(CLibFunc.Scanf),
stack.pop(RAX)
)
)
}
def generateStmt(
stmt: Stmt
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
)(using stack: Stack, strings: ListBuffer[String]): List[AsmLine] =
stmt match {
case Assign(lhs, rhs) =>
var dest: () => IndexAddress =
() => IndexAddress(RSP, 0) // gets overrwitten
() => IndexAddress(RAX, 0) // gets overrwitten
(lhs match {
case ident: Ident =>
dest = stack.accessVar(ident)
if (!stack.contains(ident)) {
stack += (ident -> (stack.size + 1))
dest = accessVar(ident)
List(Subtract(RSP, ImmediateVal(8)))
} else {
dest = accessVar(ident)
List()
}
List(stack.reserve(ident))
} else Nil
// TODO lhs = arrayElem
case _ =>
// dest = ???
List()
}) ++
evalExprOntoStack(rhs) ++
List(Pop(RAX),
Move(dest(), RAX),
List(
stack.pop(RAX),
Move(dest(), RAX),
)
case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel()
evalExprOntoStack(cond) ++
List(
Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
stack.drop(),
Jump(LabelArg(elseLabel), Cond.Equal)
) ++
thenBranch.flatMap(generateStmt) ++
@ -152,8 +148,8 @@ object asmGenerator {
List(LabelDef(startLabel)) ++
evalExprOntoStack(cond) ++
List(
Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
stack.drop(),
Jump(LabelArg(endLabel), Cond.Equal)
) ++
body.flatMap(generateStmt) ++
@ -161,21 +157,21 @@ object asmGenerator {
}
case microWacc.Return(expr) =>
evalExprOntoStack(expr) ++
List(Pop(RAX), assemblyIR.Return())
List(stack.pop(RAX), assemblyIR.Return())
case call: microWacc.Call => generateCall(call)
}
def evalExprOntoStack(expr: Expr)(using
stack: LinkedHashMap[Ident, Int],
stack: Stack,
strings: ListBuffer[String]
): List[AsmLine] = {
expr match {
case IntLiter(v) =>
List(Push(ImmediateVal(v)))
List(stack.push(ImmediateVal(v)))
case CharLiter(v) =>
List(Push(ImmediateVal(v.toInt)))
List(stack.push(ImmediateVal(v.toInt)))
case ident: Ident =>
List(Push(accessVar(ident)()))
List(stack.push(stack.accessVar(ident)()))
case ArrayLiter(elems) =>
expr.ty match {
case KnownType.String =>
@ -191,13 +187,13 @@ object asmGenerator {
LabelArg(s".L.str${strings.size - 1}")
)
),
Push(RAX)
stack.push(RAX)
)
// TODO other array types
case _ => List()
}
case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => List(Push(ImmediateVal(0)))
case BoolLiter(v) => List(stack.push(ImmediateVal(if (v) 1 else 0)))
case NullLiter() => List(stack.push(ImmediateVal(0)))
case ArrayElem(value, indices) => List()
case UnaryOp(x, op) =>
op match {
@ -208,12 +204,12 @@ object asmGenerator {
case UnaryOperator.Len => List()
case UnaryOperator.Negate =>
List(
Negate(MemLocation(RSP, SizeDir.Word))
Negate(stack.head(SizeDir.Word))
)
case UnaryOperator.Not =>
evalExprOntoStack(x) ++
List(
Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1))
Xor(stack.head(SizeDir.Word), ImmediateVal(1))
)
}
@ -223,46 +219,46 @@ object asmGenerator {
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(RAX),
Add(MemLocation(RSP, SizeDir.Word), EAX)
stack.pop(RAX),
Add(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Sub =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(RAX),
Subtract(MemLocation(RSP, SizeDir.Word), EAX)
stack.pop(RAX),
Subtract(stack.head(SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Mul =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(RAX),
Multiply(EAX, MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(RAX)
stack.pop(RAX),
Multiply(EAX, stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RAX)
// TODO OVERFLOWING
)
case BinaryOperator.Div =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(RAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(RAX)
stack.pop(RAX),
Divide(stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RAX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Mod =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(RAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(RDX)
stack.pop(RAX),
Divide(stack.head(SizeDir.Word)),
stack.drop(),
stack.push(RDX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Eq =>
@ -281,15 +277,15 @@ object asmGenerator {
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(RAX),
And(MemLocation(RSP, SizeDir.Word), EAX)
stack.pop(RAX),
And(stack.head(SizeDir.Word), EAX)
)
case BinaryOperator.Or =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(RAX),
Or(MemLocation(RSP, SizeDir.Word), EAX)
stack.pop(RAX),
Or(stack.head(SizeDir.Word), EAX)
)
}
case call: microWacc.Call => generateCall(call)
@ -297,24 +293,24 @@ object asmGenerator {
}
def generateCall(call: microWacc.Call)(using
stack: LinkedHashMap[Ident, Int],
stack: Stack,
strings: ListBuffer[String]
): List[AsmLine] = {
val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
val microWacc.Call(target, args) = call
argRegs.zip(args).flatMap { (reg, expr) =>
evalExprOntoStack(expr) ++
List(Pop(reg))
List(stack.pop(reg))
} ++
args.drop(argRegs.size).flatMap(evalExprOntoStack) ++
List(assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))) ++
(if (args.size > argRegs.size) {
List(Load(RSP, IndexAddress(RSP, (args.size - argRegs.size) * 8)))
List(stack.reserve(args.size - argRegs.size))
} else Nil)
}
// def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// stack: LinkedHashMap[Ident, Int],
// stack: Stack,
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// readType match {
@ -339,41 +335,33 @@ object asmGenerator {
// }
def generateComparison(x: Expr, y: Expr, cond: Cond)(using
stack: LinkedHashMap[Ident, Int],
stack: Stack,
strings: ListBuffer[String]
): List[AsmLine] = {
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(RAX),
Compare(MemLocation(RSP, SizeDir.Word), EAX),
stack.pop(RAX),
Compare(stack.head(SizeDir.Word), EAX),
Set(Register(RegSize.Byte, RegName.AL), cond),
And(EAX, ImmediateVal(_8_BIT_MASK)),
Load(RSP, IndexAddress(RSP, 8)),
Push(RAX)
And(RAX, ImmediateVal(_8_BIT_MASK)),
stack.drop(),
stack.push(RAX)
)
}
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): () => IndexAddress =
() => IndexAddress(RSP, (stack.size - stack(ident)) * 8)
def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
List(
And(RSP, ImmediateVal(-16))
)
}
// Missing a sub instruction but dont think we need it
def funcPrologue(): List[AsmLine] = {
def funcPrologue()(using stack: Stack): List[AsmLine] = {
List(
Push(RBP),
Move(RBP, RSP)
stack.push(RBP),
Move(RBP, Register(RegSize.R64, RegName.SP))
)
}
def funcEpilogue(): List[AsmLine] = {
def funcEpilogue()(using stack: Stack): List[AsmLine] = {
List(
Move(RSP, RBP),
Pop(RBP),
Move(Register(RegSize.R64, RegName.SP), RBP),
stack.pop(RBP),
assemblyIR.Return()
)
}
@ -383,7 +371,7 @@ object asmGenerator {
// TODO: refactor, really ugly function
// def printF(expr: Expr)(using
// stack: LinkedHashMap[Ident, Int],
// stack: Stack,
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// // determine the format string
@ -442,7 +430,7 @@ object asmGenerator {
// prints a new line
// def printLn()(using
// stack: LinkedHashMap[Ident, Int],
// stack: Stack,
// strings: ListBuffer[String]
// ): List[AsmLine] = {
// strings += ""
@ -461,4 +449,44 @@ object asmGenerator {
// )
// }
class Stack {
private val stack = LinkedHashMap[Expr | Int, Int]()
private val RSP = Register(RegSize.R64, RegName.SP)
def next: Int = stack.size + 1
def push(expr: Expr, src: Src): AsmLine = {
stack += expr -> next
Push(src)
}
def push(src: Src): AsmLine = {
stack += stack.size -> next
Push(src)
}
def pop(dest: Src): AsmLine = {
stack.remove(stack.last._1)
Pop(dest)
}
def reserve(ident: Ident): AsmLine = {
stack += ident -> next
Subtract(RSP, ImmediateVal(8))
}
def reserve(n: Int = 1): AsmLine = {
(1 to n).foreach(_ => stack += stack.size -> next)
Subtract(RSP, ImmediateVal(n*8))
}
def drop(n : Int = 1): AsmLine = {
(1 to n).foreach(_ => stack.remove(stack.last._1))
Add(RSP, ImmediateVal(n*8))
}
def accessVar(ident: Ident): () => IndexAddress = () => {
IndexAddress(RSP, (stack.size - stack(ident)) * 8)
}
def head: MemLocation = MemLocation(RSP)
def head(size: SizeDir): MemLocation = MemLocation(RSP, size)
def contains(ident: Ident): Boolean = stack.contains(ident)
// TODO: Might want to actually properly handle this with the LinkedHashMap too
def align(): AsmLine = And(RSP, ImmediateVal(-16))
}
}