diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 4f3c4c7..fe30af7 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -42,6 +42,7 @@ object asmGenerator { val RCX = Register(RegSize.R64, RegName.CX) val R8 = Register(RegSize.R64, RegName.Reg8) val R9 = Register(RegSize.R64, RegName.Reg9) + val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val _8_BIT_MASK = 0xff @@ -75,7 +76,8 @@ object asmGenerator { main.foldMap(generateStmt(_)), Chain.one(Move(RAX, ImmediateVal(0))), funcEpilogue(), - generateBuiltInFuncs() + generateBuiltInFuncs(), + funcs.foldMap(generateUserFunc(_)) ) val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) => @@ -111,6 +113,22 @@ object asmGenerator { chain } + def generateUserFunc(func: FuncDecl)(using + strings: ListBuffer[String], + labelGenerator: LabelGenerator + ): Chain[AsmLine] = { + given stack: Stack = Stack() + // Setup the stack with param 7 and up + func.params.drop(argRegs.size).foreach(stack.reserve(_)) + var chain = Chain.empty[AsmLine] + // Push the rest of params onto the stack for simplicity + argRegs.zip(func.params).foreach { (reg, param) => + chain += stack.push(param, reg) + } + chain ++= func.body.foldMap(generateStmt(_)) + wrapFunc(labelGenerator.getLabel(func.name), chain) + } + def generateBuiltInFuncs()(using stack: Stack, strings: ListBuffer[String], @@ -223,7 +241,7 @@ object asmGenerator { case microWacc.Return(expr) => chain ++= evalExprOntoStack(expr) chain += stack.pop(RAX) - chain += assemblyIR.Return() + chain ++= funcEpilogue() case call: microWacc.Call => chain ++= generateCall(call) @@ -321,7 +339,6 @@ object asmGenerator { labelGenerator: LabelGenerator ): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] - val argRegs = List(RDI, RSI, RDX, RCX, R8, R9) val microWacc.Call(target, args) = call argRegs.zip(args).foldMap { (reg, expr) => @@ -373,7 +390,7 @@ object asmGenerator { def funcEpilogue()(using stack: Stack): Chain[AsmLine] = { var chain = Chain.empty[AsmLine] chain += Move(Register(RegSize.R64, RegName.SP), RBP) - chain += stack.pop(RBP) + chain += Pop(RBP) chain += assemblyIR.Return() chain } @@ -382,7 +399,7 @@ object asmGenerator { private val stack = LinkedHashMap[Expr | Int, Int]() private val RSP = Register(RegSize.R64, RegName.SP) - def next: Int = stack.size + 1 + private def next: Int = stack.size + 1 def push(expr: Expr, src: Src): AsmLine = { stack += expr -> next Push(src) diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index ca95342..8c11550 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -110,7 +110,7 @@ object typeChecker { microWacc.FuncDecl( microWacc.Ident(name.v, name.uid)(retType), params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => - microWacc.Ident(ident.v, name.uid)(ty) + microWacc.Ident(ident.v, ident.uid)(ty) }, stmts.toList .flatMap(