package wacc

import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer

/** Lowers a microWacc program to x86-64 assembly in Intel syntax.
  *
  * Code generation is stack based: evaluating an expression leaves exactly one
  * 8-byte value pushed on the machine stack. A compile-time [[Stack]] model
  * mirrors every push/pop/reserve/drop so that variables can be addressed
  * relative to RSP.
  */
object asmGenerator {
  import microWacc._
  import assemblyIR._
  import wacc.types._
  import lexer.escapedChars

  val RAX = Register(RegSize.R64, RegName.AX)
  val EAX = Register(RegSize.E32, RegName.AX)
  val ESP = Register(RegSize.E32, RegName.SP)
  val EDX = Register(RegSize.E32, RegName.DX)
  val RDI = Register(RegSize.R64, RegName.DI)
  val RIP = Register(RegSize.R64, RegName.IP)
  val RBP = Register(RegSize.R64, RegName.BP)
  val RSI = Register(RegSize.R64, RegName.SI)
  val RDX = Register(RegSize.R64, RegName.DX)
  val RCX = Register(RegSize.R64, RegName.CX)
  val R8 = Register(RegSize.R64, RegName.Reg8)
  val R9 = Register(RegSize.R64, RegName.Reg9)
  val RSP = Register(RegSize.R64, RegName.SP)

  /** Keeps only the low byte of a register; used after `setcc` writes AL. */
  val _8_BIT_MASK = 0xff

  /** Generates the complete assembly listing for `microProg`.
    *
    * The listing is: assembler directives, a read-only section holding every
    * string literal met while lowering (each preceded by its length as an
    * `.int` and labelled `.L.str<index>`), then the text section with `main`
    * followed by the built-in wrapper functions.
    *
    * NOTE(review): `funcs` is destructured but never lowered, so user-defined
    * functions are not emitted yet — TODO.
    */
  def generateAsm(microProg: Program): List[AsmLine] = {
    given stack: Stack = Stack()
    given strings: ListBuffer[String] = ListBuffer[String]()
    given labelGenerator: LabelGenerator = LabelGenerator()
    val Program(funcs, main) = microProg

    // Must be built before `strDirs`: lowering `main` is what fills `strings`.
    val progAsm =
      LabelDef("main") :: funcPrologue() ++
        List(stack.align()) ++
        main.flatMap(generateStmt) ++
        List(Move(RAX, ImmediateVal(0))) ++ // main's exit status
        funcEpilogue() ++
        generateBuiltInFuncs()

    val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
      List(
        Directive.Int(str.size),
        LabelDef(s".L.str$i"),
        Directive.Asciz(str.escaped)
      )
    }

    List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
      strDirs ++
      List(Directive.Text) ++
      progAsm
  }

  /** Wraps `funcBody` in a label plus the standard prologue and epilogue.
    *
    * @param labelName label the function is emitted under
    * @param funcBody  instructions executed between prologue and epilogue
    */
  def wrapFunc(labelName: String, funcBody: List[AsmLine])(using
      stack: Stack,
      strings: ListBuffer[String]
  ): List[AsmLine] = {
    LabelDef(labelName) :: funcPrologue() ++ funcBody ++ funcEpilogue()
  }

  /** Emits the wrapper functions that back the WACC built-ins.
    *
    * Each wrapper that calls into libc first rounds RSP down to a 16-byte
    * boundary. Except for `Read`, which repoints RSI, the argument registers
    * loaded by the caller are passed to the C routine unchanged.
    */
  def generateBuiltInFuncs()(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): List[AsmLine] = {
    wrapFunc(
      labelGenerator.getLabel(Builtin.Exit),
      List(stack.align(), assemblyIR.Call(CLibFunc.Exit))
    ) ++ wrapFunc(
      labelGenerator.getLabel(Builtin.Printf),
      List(
        stack.align(),
        assemblyIR.Call(CLibFunc.PrintF),
        // fflush(NULL): flush all output streams after every print.
        Move(RDI, ImmediateVal(0)),
        assemblyIR.Call(CLibFunc.Fflush)
      )
    ) ++ wrapFunc(
      labelGenerator.getLabel(Builtin.Malloc),
      // NOTE(review): only aligns the stack; no allocation call is emitted — TODO.
      List(
        stack.align()
      )
    ) ++ wrapFunc(
      // NOTE(review): empty body; nothing is freed yet — TODO.
      labelGenerator.getLabel(Builtin.Free),
      List()
    ) ++ wrapFunc(
      labelGenerator.getLabel(Builtin.Read),
      List(
        stack.align(),
        // Padding slot: with the push below, RSP stays 16-byte aligned at the call.
        stack.reserve(),
        // The caller's RSI becomes the slot scanf writes into, so it is what
        // ends up in RAX when scanf stores nothing.
        stack.push(RSI),
        Load(RSI, stack.head),
        assemblyIR.Call(CLibFunc.Scanf),
        stack.pop(RAX),
        stack.drop()
      )
    )
  }

  /** Lowers a single statement.
    *
    * Statements leave the modelled stack as they found it, except that the
    * first assignment to a variable permanently reserves its slot.
    */
  def generateStmt(
      stmt: Stmt
  )(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): List[AsmLine] = stmt match {
    case Assign(lhs, rhs) =>
      // The slot is reserved before the rhs is evaluated. Its address is
      // computed lazily, after the rhs value has been popped, so the
      // RSP-relative offset reflects the stack at the point of the store.
      val (reserveDest, dest): (List[AsmLine], () => IndexAddress) = lhs match {
        case ident: Ident =>
          val reserve =
            if (stack.contains(ident)) Nil else List(stack.reserve(ident))
          (reserve, stack.accessVar(ident))
        // TODO lhs = arrayElem
        case _ =>
          // NOTE(review): placeholder destination — this stores through RAX,
          // which holds the value itself. Array-element targets are TODO.
          (Nil, () => IndexAddress(RAX, 0))
      }
      reserveDest ++ evalExprOntoStack(rhs) ++ List(
        stack.pop(RAX),
        Move(dest(), RAX)
      )

    case If(cond, thenBranch, elseBranch) => {
      val elseLabel = labelGenerator.getLabel()
      val endLabel = labelGenerator.getLabel()
      evalExprOntoStack(cond) ++
        List(
          Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
          stack.drop(),
          Jump(LabelArg(elseLabel), Cond.Equal)
        ) ++
        thenBranch.flatMap(generateStmt) ++
        List(Jump(LabelArg(endLabel)), LabelDef(elseLabel)) ++
        elseBranch.flatMap(generateStmt) ++
        List(LabelDef(endLabel))
    }

    case While(cond, body) => {
      val startLabel = labelGenerator.getLabel()
      val endLabel = labelGenerator.getLabel()
      List(LabelDef(startLabel)) ++
        evalExprOntoStack(cond) ++
        List(
          Compare(stack.head(SizeDir.Word), ImmediateVal(0)),
          stack.drop(),
          Jump(LabelArg(endLabel), Cond.Equal)
        ) ++
        body.flatMap(generateStmt) ++
        List(Jump(LabelArg(startLabel)), LabelDef(endLabel))
    }

    case microWacc.Return(expr) =>
      evalExprOntoStack(expr) ++ List(stack.pop(RAX), assemblyIR.Return())

    case call: microWacc.Call => generateCall(call)
  }

  /** Lowers `expr` so that its value ends up pushed on top of the stack.
    *
    * Every case must leave exactly one new stack entry. Cases that are not
    * implemented yet (they produce no instructions) push a placeholder 0 so
    * the invariant still holds.
    */
  def evalExprOntoStack(expr: Expr)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): List[AsmLine] = {
    val out = expr match {
      case IntLiter(v)  => List(stack.push(ImmediateVal(v)))
      case CharLiter(v) => List(stack.push(ImmediateVal(v.toInt)))
      case ident: Ident => List(stack.push(stack.accessVar(ident)()))
      case ArrayLiter(elems) =>
        expr.ty match {
          case KnownType.String =>
            // Record the literal; generateAsm emits it in .rodata as .L.str<index>.
            strings += elems.map {
              case CharLiter(v) => v
              case _            => ""
            }.mkString
            List(
              Load(
                RAX,
                IndexAddress(
                  RIP,
                  LabelArg(s".L.str${strings.size - 1}")
                )
              ),
              stack.push(RAX)
            )
          // TODO other array types
          case _ => List()
        }
      case BoolLiter(v)              => List(stack.push(ImmediateVal(if (v) 1 else 0)))
      case NullLiter()               => List(stack.push(ImmediateVal(0)))
      case ArrayElem(value, indices) => List()
      case UnaryOp(x, op) =>
        // The operand is pushed once here; each operator then rewrites that
        // top-of-stack slot in place.
        evalExprOntoStack(x) ++
          (op match {
            // TODO: chr and ord are TYPE CASTS. They do not change the internal value,
            // but will need bound checking e.t.c.
            case UnaryOperator.Chr => List()
            case UnaryOperator.Ord => List()
            case UnaryOperator.Len => List()
            case UnaryOperator.Negate =>
              List(
                Negate(stack.head(SizeDir.Word))
              )
            case UnaryOperator.Not =>
              // Operand already on the stack: flip its low bit.
              List(
                Xor(stack.head(SizeDir.Word), ImmediateVal(1))
              )
          })
      case BinaryOp(x, y, op) =>
        op match {
          case BinaryOperator.Add =>
            evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List(
              stack.pop(RAX),
              Add(stack.head(SizeDir.Word), EAX)
              // TODO OVERFLOWING
            )
          case BinaryOperator.Sub =>
            evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List(
              stack.pop(RAX),
              Subtract(stack.head(SizeDir.Word), EAX)
              // TODO OVERFLOWING
            )
          case BinaryOperator.Mul =>
            evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List(
              stack.pop(RAX),
              Multiply(EAX, stack.head(SizeDir.Word)),
              stack.drop(),
              stack.push(RAX)
              // TODO OVERFLOWING
            )
          // Div/Mod evaluate y first so that x (the dividend) is the one popped
          // into EAX, while y (the divisor) stays on top as the operand.
          // NOTE(review): idiv divides EDX:EAX, but EDX is not sign-extended from
          // EAX here (no cdq). Unless the Divide node emits it itself, the
          // result depends on stale EDX — TODO confirm.
          case BinaryOperator.Div =>
            evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List(
              stack.pop(RAX),
              Divide(stack.head(SizeDir.Word)),
              stack.drop(),
              stack.push(RAX)
              // TODO CHECK DIVISOR IS NOT 0
            )
          case BinaryOperator.Mod =>
            evalExprOntoStack(y) ++ evalExprOntoStack(x) ++ List(
              stack.pop(RAX),
              Divide(stack.head(SizeDir.Word)),
              stack.drop(),
              stack.push(RDX) // remainder
              // TODO CHECK DIVISOR IS NOT 0
            )
          case BinaryOperator.Eq        => generateComparison(x, y, Cond.Equal)
          case BinaryOperator.Neq       => generateComparison(x, y, Cond.NotEqual)
          case BinaryOperator.Greater   => generateComparison(x, y, Cond.Greater)
          case BinaryOperator.GreaterEq => generateComparison(x, y, Cond.GreaterEqual)
          case BinaryOperator.Less      => generateComparison(x, y, Cond.Less)
          case BinaryOperator.LessEq    => generateComparison(x, y, Cond.LessEqual)
          case BinaryOperator.And =>
            evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List(
              stack.pop(RAX),
              And(stack.head(SizeDir.Word), EAX)
            )
          case BinaryOperator.Or =>
            evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List(
              stack.pop(RAX),
              Or(stack.head(SizeDir.Word), EAX)
            )
        }
      case call: microWacc.Call =>
        generateCall(call) ++ List(stack.push(RAX))
    }
    if out.isEmpty then List(stack.push(ImmediateVal(0))) else out
  }

  /** Lowers a call. The callee's return value is left in RAX.
    *
    * The first six arguments go in RDI, RSI, RDX, RCX, R8 and R9. Any further
    * arguments stay on the stack during the call, with the last argument on
    * top.
    *
    * All arguments are evaluated onto the stack (left to right) before any
    * argument register is loaded. Loading each register as soon as its
    * argument was evaluated would let a later argument overwrite it: a nested
    * call reloads the argument registers, and `idiv` overwrites EDX/RDX. The
    * net effect on the modelled stack is zero.
    */
  def generateCall(call: microWacc.Call)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): List[AsmLine] = {
    val argRegs = List(RDI, RSI, RDX, RCX, R8, R9)
    val microWacc.Call(target, args) = call

    val evalArgs = args.flatMap(evalExprOntoStack)

    // After evaluation, argument i sits (args.size - 1 - i) slots above RSP.
    val loadArgRegs = argRegs.zip(args.indices).map { (reg, i) =>
      Move(reg, IndexAddress(RSP, (args.size - 1 - i) * 8))
    }

    val callLine = assemblyIR.Call(LabelArg(labelGenerator.getLabel(target)))

    // Release every argument slot, both register-loaded and stack-passed.
    val cleanUp = if (args.nonEmpty) List(stack.drop(args.size)) else Nil

    evalArgs ++ loadArgRegs ++ (callLine :: cleanUp)
  }

  /** Compares `x` with `y` and pushes 1 if `cond` holds, otherwise 0. */
  def generateComparison(x: Expr, y: Expr, cond: Cond)(using
      stack: Stack,
      strings: ListBuffer[String],
      labelGenerator: LabelGenerator
  ): List[AsmLine] = {
    evalExprOntoStack(x) ++ evalExprOntoStack(y) ++ List(
      stack.pop(RAX),
      Compare(stack.head(SizeDir.Word), EAX),
      Set(Register(RegSize.Byte, RegName.AL), cond),
      // setcc writes only AL; clear the rest of RAX.
      And(RAX, ImmediateVal(_8_BIT_MASK)),
      stack.drop(),
      stack.push(RAX)
    )
  }

  /** Saves the caller's RBP and sets RBP to the current RSP.
    *
    * No `sub rsp, N` frame allocation is emitted: locals are reserved one slot
    * at a time by [[Stack.reserve]] as they are first assigned.
    */
  def funcPrologue()(using stack: Stack): List[AsmLine] = {
    List(
      stack.push(RBP),
      Move(RBP, RSP)
    )
  }

  /** Resets RSP to RBP, restores the caller's RBP and returns. */
  def funcEpilogue()(using stack: Stack): List[AsmLine] = {
    List(
      Move(RSP, RBP),
      stack.pop(RBP),
      assemblyIR.Return()
    )
  }

  /** Produces local labels (`.L0`, `.L1`, ...) and the labels of call targets. */
  class LabelGenerator {
    var labelVal = -1

    /** Returns a fresh, never-before-returned local label. */
    def getLabel(): String = {
      labelVal += 1
      s".L$labelVal"
    }

    /** Label for a call target: `wacc_<name>` for user functions, `_<name>` for built-ins. */
    def getLabel(target: CallTarget): String = target match {
      case Ident(v, _)   => s"wacc_$v"
      case Builtin(name) => s"_$name"
    }
  }

  /** Compile-time model of the machine stack.
    *
    * Each entry maps a key to its 1-based position counted from the oldest
    * tracked entry. Keys are:
    *   - an [[Ident]], for a named variable;
    *   - an `Int`, for an anonymous temporary. The key is the number of
    *     entries at the time of the push, which cannot collide because
    *     entries are only ever removed from the top.
    *
    * Every method both updates the model and returns the matching instruction.
    */
  class Stack {
    private val stack = LinkedHashMap[Expr | Int, Int]()

    /** Position the next pushed entry will occupy. */
    def next: Int = stack.size + 1

    /** Pushes `src`, tracking the new slot under `expr`. */
    def push(expr: Expr, src: Src): AsmLine = {
      stack += expr -> next
      Push(src)
    }

    /** Pushes `src` as an anonymous temporary. */
    def push(src: Src): AsmLine = {
      stack += stack.size -> next
      Push(src)
    }

    /** Pops the top entry into `dest`. */
    def pop(dest: Src): AsmLine = {
      stack.remove(stack.last._1)
      Pop(dest)
    }

    /** Reserves one uninitialised 8-byte slot for `ident`. */
    def reserve(ident: Ident): AsmLine = {
      stack += ident -> next
      Subtract(RSP, ImmediateVal(8))
    }

    /** Reserves `n` anonymous 8-byte slots. */
    def reserve(n: Int = 1): AsmLine = {
      (1 to n).foreach(_ => stack += stack.size -> next)
      Subtract(RSP, ImmediateVal(n * 8))
    }

    /** Discards the top `n` entries without reading them. */
    def drop(n: Int = 1): AsmLine = {
      (1 to n).foreach(_ => stack.remove(stack.last._1))
      Add(RSP, ImmediateVal(n * 8))
    }

    /** Returns a thunk for `ident`'s RSP-relative address.
      *
      * The offset is computed each time the thunk is invoked, so it reflects
      * every push and pop made since this method was called.
      */
    def accessVar(ident: Ident): () => IndexAddress = () => {
      IndexAddress(RSP, (stack.size - stack(ident)) * 8)
    }

    /** The top-of-stack slot, at the operand's default size. */
    def head: MemLocation = MemLocation(RSP)

    /** The top-of-stack slot, with an explicit operand size. */
    def head(size: SizeDir): MemLocation = MemLocation(RSP, size)

    /** Whether `ident` already has a slot. */
    def contains(ident: Ident): Boolean = stack.contains(ident)

    /** Rounds RSP down to a multiple of 16 (`and rsp, -16`).
      *
      * NOTE(review): the model is not updated. If this actually moves RSP, the
      * entries pushed before it end up 8 bytes further from RSP than
      * `accessVar` computes.
      * TODO: might want to actually properly handle this with the LinkedHashMap too.
      */
    def align(): AsmLine = And(RSP, ImmediateVal(-16))
  }

  /** Reverse of `lexer.escapedChars`: raw character -> its backslash escape. */
  private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

  extension (s: String) {
    /** `s` with special characters re-escaped for an `.asciz` directive. */
    private def escaped: String =
      s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
  }
}