package wacc

import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer

import cats.data.Chain
import cats.syntax.foldable._
// import parsley.token.errors.Label

/** Lowers a microWacc program into a list of assembly lines (Intel syntax).
  *
  * Code generation is stack based: every expression is evaluated onto the machine stack, and a
  * compile-time [[asmGenerator.Stack]] model mirrors each push/pop/reserve/drop so variables can
  * be addressed relative to RSP.
  */
object asmGenerator {
  import microWacc._
  import assemblyIR._
  import wacc.types._
  import lexer.escapedChars

  /** A runtime error with a message string in read-only data and a handler label in the text. */
  abstract case class Error() {

    /** Label of the message string. */
    def strLabel: String

    /** The message text. */
    def errStr: String

    /** Label of the handler code that prints the message and exits. */
    def errLabel: String

    /** Data directives for the message: its length, its label, then the NUL-terminated text. */
    def stringDef: Chain[AsmLine] = Chain(
      Directive.Int(errStr.size),
      LabelDef(strLabel),
      Directive.Asciz(errStr)
    )
  }

  object zeroDivError extends Error {
    // TODO: is this bad? Can we make an error case class/some other structure?
    def strLabel = ".L._errDivZero_str0"
    def errStr = "fatal error: division or modulo by zero"
    def errLabel = ".L._errDivZero"
  }

  val RAX = Register(RegSize.R64, RegName.AX)
  val EAX = Register(RegSize.E32, RegName.AX)
  val ESP = Register(RegSize.E32, RegName.SP)
  val EDX = Register(RegSize.E32, RegName.DX)
  val RDI = Register(RegSize.R64, RegName.DI)
  val RIP = Register(RegSize.R64, RegName.IP)
  val RBP = Register(RegSize.R64, RegName.BP)
  val RSI = Register(RegSize.R64, RegName.SI)
  val RDX = Register(RegSize.R64, RegName.DX)
  val RCX = Register(RegSize.R64, RegName.CX)
  val R8 = Register(RegSize.R64, RegName.Reg8)
  val R9 = Register(RegSize.R64, RegName.Reg9)

  /** Registers carrying the first six call arguments, in order (the System V AMD64 integer
    * argument registers). Further arguments are passed on the stack.
    */
  val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)

  /** Mask used to zero-extend the byte written by a SETcc into the full register. */
  val _8_BIT_MASK = 0xff

  extension (chain: Chain[AsmLine]) {

    /** Appends a single line. */
    def +(line: AsmLine): Chain[AsmLine] = chain.append(line)

    /** Appends each of `chains`, in order. */
    def concatAll(chains: Chain[AsmLine]*): Chain[AsmLine] =
      chains.foldLeft(chain)(_ ++ _)
  }

  /** Produces fresh local labels and the labels of call targets. */
  class LabelGenerator {
    var labelVal = -1

    /** Returns a new, unique local label (`.L0`, `.L1`, ...). */
    def getLabel(): String = {
      labelVal += 1
      s".L$labelVal"
    }

    /** Returns the label of a user function (`wacc_` prefix) or a builtin (`_` prefix). */
    def getLabel(target: CallTarget): String = target match {
      case Ident(v, _)   => s"wacc_$v"
      case Builtin(name) => s"_$name"
    }
  }

  /** Generates the complete assembly listing for `microProg`.
    *
    * @return read-only data (string literals and error messages) followed by the text section
    *   containing `main`, the builtin wrappers and every user function.
    */
  def generateAsm(microProg: Program): List[AsmLine] = {
    given stack: Stack = Stack()
    given strings: ListBuffer[String] = ListBuffer[String]()
    given labelGenerator: LabelGenerator = LabelGenerator()

    val Program(funcs, main) = microProg

    // concatAll's arguments are evaluated left to right, and each generator mutates the shared
    // stack model / string pool / label counter, so this order is significant.
    val progAsm = Chain(LabelDef("main")).concatAll(
      funcPrologue(),
      Chain.one(stack.align()),
      main.foldMap(generateStmt(_)),
      Chain.one(Move(RAX, ImmediateVal(0))),
      funcEpilogue(),
      generateBuiltInFuncs(),
      funcs.foldMap(generateUserFunc(_))
    )

    // Built only after progAsm, because generating the program is what fills `strings`.
    val strDirs = strings.toList.zipWithIndex.foldMap { case (str, i) =>
      Chain(
        Directive.Int(str.size),
        LabelDef(s".L.str$i"),
        Directive.Asciz(str.escaped)
      )
    } ++ zeroDivError.stringDef

    Chain(
      Directive.IntelSyntax,
      Directive.Global("main"),
      Directive.RoData
    ).concatAll(
      strDirs,
      Chain.one(Directive.Text),
      progAsm
    ).toList
  }

  /** Wraps `funcBody` in a labelled function with the standard prologue and epilogue. */
  private def wrapBuiltinFunc(labelName: String, funcBody: Chain[AsmLine])(using
      stack: Stack
  ): Chain[AsmLine] = {
    var chain = Chain.one[AsmLine](LabelDef(labelName))
    chain ++= funcPrologue()
    chain ++= funcBody
    chain ++= funcEpilogue()
    chain
  }

  /** Generates code for a user-defined function, using a fresh stack model for its frame. */
  def generateUserFunc(func: FuncDecl)(using
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    given stack: Stack = Stack()

    // Parameters 7 and up were pushed by the caller, in order; only model their slots (the
    // returned `sub` instructions are discarded because that memory already exists).
    func.params.drop(argRegs.size).foreach(stack.reserve(_))
    // `call` pushes the return address between those parameters and the saved RBP. Without
    // modelling that slot, every stack-passed parameter would be addressed 8 bytes too low.
    stack.track()

    var chain = Chain.one[AsmLine](LabelDef(labelGenerator.getLabel(func.name)))
    chain ++= funcPrologue()

    // Spill the register-passed parameters so every variable is addressed through the stack.
    argRegs.zip(func.params).foreach { (reg, param) =>
      chain += stack.push(param, reg)
    }

    chain ++= func.body.foldMap(generateStmt(_))
    // No epilogue here: every user function must return explicitly.
    chain
  }

  /** Generates the builtin wrappers (exit, printf, malloc, free, read) and the
    * division-by-zero handler.
    */
  def generateBuiltInFuncs()(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var chain = Chain.empty[AsmLine]

    chain ++= wrapBuiltinFunc(
      labelGenerator.getLabel(Builtin.Exit),
      Chain(stack.align(), assemblyIR.Call(CLibFunc.Exit))
    )

    chain ++= wrapBuiltinFunc(
      labelGenerator.getLabel(Builtin.Printf),
      Chain(
        stack.align(),
        assemblyIR.Call(CLibFunc.PrintF),
        Move(RDI, ImmediateVal(0)),
        assemblyIR.Call(CLibFunc.Fflush)
      )
    )

    chain ++= wrapBuiltinFunc(
      labelGenerator.getLabel(Builtin.Malloc),
      Chain.one(stack.align())
    )

    chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)

    // Stores the incoming RSI in a fresh slot, points RSI at that slot for scanf, and returns
    // the slot's final contents in RAX.
    chain ++= wrapBuiltinFunc(
      labelGenerator.getLabel(Builtin.Read),
      Chain(
        stack.align(),
        stack.reserve(),
        stack.push(RSI),
        Load(RSI, stack.head),
        assemblyIR.Call(CLibFunc.Scanf),
        stack.pop(RAX),
        stack.drop()
      )
    )

    chain ++= Chain(
      // TODO can this be done with a call to generateStmt?
      // Consider other error cases -> look to generalise
      LabelDef(zeroDivError.errLabel),
      stack.align(),
      Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
      assemblyIR.Call(CLibFunc.PrintF),
      Move(RDI, ImmediateVal(-1)),
      assemblyIR.Call(CLibFunc.Exit)
    )

    chain
  }

  /** Generates `body` as a nested scope.
    *
    * Any stack slots reserved inside it (variables first assigned there) are released at the end,
    * so the physical stack and the model are the same on every path leaving the scope. This matters
    * for both arms of an `if` and for each iteration of a `while`.
    */
  private def withScope(body: => Chain[AsmLine])(using stack: Stack): Chain[AsmLine] = {
    val base = stack.size
    val code = body
    val introduced = stack.size - base
    if (introduced > 0) code.append(stack.drop(introduced)) else code
  }

  /** Generates code for a single statement. */
  def generateStmt(stmt: Stmt)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var chain = Chain.empty[AsmLine]

    stmt match {
      case Assign(lhs, rhs) =>
        var dest: () => IndexAddress = () => IndexAddress(RAX, 0) // overwritten below
        lhs match {
          case ident: Ident =>
            // A thunk: the offset is only computed after the reserve/pop below have run.
            dest = stack.accessVar(ident)
            if (!stack.contains(ident)) chain += stack.reserve(ident)
          // TODO lhs = arrayElem
          case _ =>
        }
        chain ++= evalExprOntoStack(rhs)
        chain += stack.pop(RAX)
        chain += Move(dest(), RAX)

      case If(cond, thenBranch, elseBranch) =>
        val elseLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()

        chain ++= evalExprOntoStack(cond)
        chain += stack.pop(RAX)
        chain += Compare(RAX, ImmediateVal(0))
        chain += Jump(LabelArg(elseLabel), Cond.Equal)

        chain ++= withScope(thenBranch.foldMap(generateStmt))
        chain += Jump(LabelArg(endLabel))

        chain += LabelDef(elseLabel)
        chain ++= withScope(elseBranch.foldMap(generateStmt))

        chain += LabelDef(endLabel)

      case While(cond, body) =>
        val startLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()

        chain += LabelDef(startLabel)
        chain ++= evalExprOntoStack(cond)
        chain += stack.pop(RAX)
        chain += Compare(RAX, ImmediateVal(0))
        chain += Jump(LabelArg(endLabel), Cond.Equal)

        // Scoped so each iteration starts with the same stack depth the loop header expects.
        chain ++= withScope(body.foldMap(generateStmt))
        chain += Jump(LabelArg(startLabel))
        chain += LabelDef(endLabel)

      case microWacc.Return(expr) =>
        chain ++= evalExprOntoStack(expr)
        chain += stack.pop(RAX)
        chain ++= funcEpilogue()

      case call: microWacc.Call =>
        chain ++= generateCall(call)
    }

    chain
  }

  /** Generates code that evaluates `expr` and pushes exactly one 8-byte result onto the stack. */
  def evalExprOntoStack(expr: Expr)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var chain = Chain.empty[AsmLine]

    expr match {
      case IntLiter(v)   => chain += stack.push(ImmediateVal(v))
      case CharLiter(v)  => chain += stack.push(ImmediateVal(v.toInt))
      case ident: Ident  => chain += stack.push(stack.accessVar(ident)())

      case ArrayLiter(elems) =>
        expr.ty match {
          case KnownType.String =>
            strings += elems.collect { case CharLiter(v) => v }.mkString
            chain += Load(RAX, IndexAddress(RIP, LabelArg(s".L.str${strings.size - 1}")))
            chain += stack.push(RAX)
          case _ => // Other array types TODO
        }

      case BoolLiter(v)  => chain += stack.push(ImmediateVal(if (v) 1 else 0))
      case NullLiter()   => chain += stack.push(ImmediateVal(0))
      case ArrayElem(_, _) => // TODO: Implement handling

      case UnaryOp(x, op) =>
        chain ++= evalExprOntoStack(x)
        op match {
          case UnaryOperator.Chr | UnaryOperator.Ord | UnaryOperator.Len => // No op needed
          case UnaryOperator.Negate => chain += Negate(stack.head(SizeDir.DWord))
          case UnaryOperator.Not    => chain += Xor(stack.head(SizeDir.DWord), ImmediateVal(1))
        }

      case BinaryOp(x, y, op) =>
        op match {
          // Comparisons evaluate their own operands; evaluating them here as well would run any
          // side effects twice and leave a stray slot on the stack.
          case BinaryOperator.Eq        => chain ++= generateComparison(x, y, Cond.Equal)
          case BinaryOperator.Neq       => chain ++= generateComparison(x, y, Cond.NotEqual)
          case BinaryOperator.Greater   => chain ++= generateComparison(x, y, Cond.Greater)
          case BinaryOperator.GreaterEq => chain ++= generateComparison(x, y, Cond.GreaterEqual)
          case BinaryOperator.Less      => chain ++= generateComparison(x, y, Cond.Less)
          case BinaryOperator.LessEq    => chain ++= generateComparison(x, y, Cond.LessEqual)

          case _ =>
            // Evaluate y first so that, once x is popped into RAX, y is on top of the stack.
            chain ++= evalExprOntoStack(y)
            chain ++= evalExprOntoStack(x)
            chain += stack.pop(RAX)

            op match {
              case BinaryOperator.Add =>
                chain += Add(stack.head(SizeDir.DWord), EAX)

              case BinaryOperator.Sub =>
                chain += Subtract(EAX, stack.head(SizeDir.DWord))
                chain += stack.drop()
                chain += stack.push(RAX)

              case BinaryOperator.Mul =>
                chain += Multiply(EAX, stack.head(SizeDir.DWord))
                chain += stack.drop()
                chain += stack.push(RAX)

              case BinaryOperator.Div | BinaryOperator.Mod =>
                // Trap a zero divisor (y, on top of the stack) before IDIV faults.
                chain += Compare(stack.head(SizeDir.DWord), ImmediateVal(0))
                chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
                chain += CDQ()
                chain += Divide(stack.head(SizeDir.DWord))
                chain += stack.drop()
                // IDIV leaves the quotient in EAX and the remainder in EDX.
                chain += stack.push(if op == BinaryOperator.Div then RAX else RDX)

              case BinaryOperator.And =>
                chain += And(stack.head(SizeDir.DWord), EAX)

              case BinaryOperator.Or =>
                chain += Or(stack.head(SizeDir.DWord), EAX)

              case other =>
                throw IllegalStateException(
                  s"comparison operator $other must be handled by generateComparison"
                )
            }
        }

      case call: microWacc.Call =>
        chain ++= generateCall(call)
        chain += stack.push(RAX)
    }

    // Unimplemented expression kinds still push a placeholder so the stack stays balanced.
    if chain.isEmpty then chain += stack.push(ImmediateVal(0))
    chain
  }

  /** Generates a call to a user function or builtin. The call's result is left in RAX.
    *
    * Arguments 7 and up are evaluated first and pushed in order, so they sit directly above the
    * return address in the callee. Register arguments are then all evaluated onto the stack
    * before any argument register is written. Evaluating a later argument may clobber those
    * registers (a nested call, or CDQ writing EDX).
    */
  def generateCall(call: microWacc.Call)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    val microWacc.Call(target, args) = call
    val regArgs = args.take(argRegs.size)
    val stackArgs = args.drop(argRegs.size)

    var chain = Chain.empty[AsmLine]
    chain ++= stackArgs.foldMap(evalExprOntoStack(_))
    chain ++= regArgs.foldMap(evalExprOntoStack(_))
    // The last register argument was pushed last, so it is popped first.
    argRegs.take(regArgs.size).reverse.foreach { reg =>
      chain += stack.pop(reg)
    }

    chain += assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))

    // Release the stack-passed arguments once the callee has returned.
    if (stackArgs.nonEmpty) {
      chain += stack.drop(stackArgs.size)
    }
    chain
  }

  /** Evaluates `x` then `y` and pushes 1 if `x cond y` holds (signed 32-bit), otherwise 0. */
  def generateComparison(x: Expr, y: Expr, cond: Cond)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var chain = Chain.empty[AsmLine]

    chain ++= evalExprOntoStack(x)
    chain ++= evalExprOntoStack(y)
    chain += stack.pop(RAX)
    chain += Compare(stack.head(SizeDir.DWord), EAX)
    chain += Set(Register(RegSize.Byte, RegName.AL), cond)
    chain += And(RAX, ImmediateVal(_8_BIT_MASK))
    chain += stack.drop()
    chain += stack.push(RAX)

    chain
  }

  /** Saves the caller's RBP and sets RBP to the current stack pointer.
    *
    * No `sub rsp, N` frame allocation is emitted: locals are reserved one at a time (via
    * `Stack.reserve`) when they are first assigned.
    */
  def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
    var chain = Chain.empty[AsmLine]
    chain += stack.push(RBP)
    chain += Move(RBP, Register(RegSize.R64, RegName.SP))
    chain
  }

  /** Restores RSP and RBP and returns.
    *
    * The stack model is left untouched, because code generated after a return belongs to other
    * control-flow paths.
    */
  def funcEpilogue()(using stack: Stack): Chain[AsmLine] = {
    var chain = Chain.empty[AsmLine]
    chain += Move(Register(RegSize.R64, RegName.SP), RBP)
    chain += Pop(RBP)
    chain += assemblyIR.Return()
    chain
  }

  /** Compile-time model of the machine stack, one 8-byte slot per entry.
    *
    * Named slots are keyed by their expression; anonymous temporaries are keyed by their position.
    * The value is the slot's 1-based position from the bottom of the model. Entries are only ever
    * removed from the top.
    */
  class Stack {
    private val stack = LinkedHashMap[Expr | Int, Int]()
    private val RSP = Register(RegSize.R64, RegName.SP)

    private def next: Int = stack.size + 1

    /** Number of slots currently modelled. */
    def size: Int = stack.size

    /** Pushes `src` into a new slot named `expr`. */
    def push(expr: Expr, src: Src): AsmLine = {
      stack += expr -> next
      Push(src)
    }

    /** Pushes `src` into a new anonymous slot. */
    def push(src: Src): AsmLine = {
      stack += stack.size -> next
      Push(src)
    }

    /** Pops the top slot into `dest`. */
    def pop(dest: Src): AsmLine = {
      stack.remove(stack.last._1)
      Pop(dest)
    }

    /** Allocates an uninitialised slot for `ident`. */
    def reserve(ident: Ident): AsmLine = {
      stack += ident -> next
      Subtract(RSP, ImmediateVal(8))
    }

    /** Allocates `n` uninitialised anonymous slots. */
    def reserve(n: Int = 1): AsmLine = {
      track(n)
      Subtract(RSP, ImmediateVal(n * 8))
    }

    /** Records `n` anonymous slots that something else has already pushed (for example, the
      * return address pushed by `call`). No instruction is emitted.
      */
    def track(n: Int = 1): Unit =
      (1 to n).foreach(_ => stack += stack.size -> next)

    /** Discards the top `n` slots. */
    def drop(n: Int = 1): AsmLine = {
      (1 to n).foreach(_ => stack.remove(stack.last._1))
      Add(RSP, ImmediateVal(n * 8))
    }

    /** A thunk giving `ident`'s address relative to RSP, computed when invoked. */
    def accessVar(ident: Ident): () => IndexAddress = () => {
      IndexAddress(RSP, (stack.size - stack(ident)) * 8)
    }

    def head: MemLocation = MemLocation(RSP)

    def head(size: SizeDir): MemLocation = MemLocation(RSP, size)

    def contains(ident: Ident): Boolean = stack.contains(ident)

    // TODO: Might want to actually properly handle this with the LinkedHashMap too
    // Rounds RSP down to a 16-byte boundary; the model is not updated.
    def align(): AsmLine = And(RSP, ImmediateVal(-16))
  }

  // Reverse of the lexer's escape table: maps each raw character back to its escape sequence.
  private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

  extension (s: String) {
    private def escaped: String =
      s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
  }
}