package wacc

import cats.data.Chain
import cats.syntax.foldable._
import wacc.RuntimeError._

/** Translates a `microWacc` program into x86-64 assembly lines (Intel syntax).
  *
  * Code generation follows a simple stack-machine scheme: evaluating an expression pushes exactly
  * one value onto the machine stack (tracked by a [[Stack]] model) and consumers pop their operands
  * back off it.
  */
object asmGenerator {
  import microWacc._
  import assemblyIR._
  import assemblyIR.commonRegisters._
  import assemblyIR.RegName._
  import types._
  import sizeExtensions._
  import lexer.escapedChars

  /** Registers carrying the first six call arguments, in System V AMD64 argument order. */
  private val argRegs = List(DI, SI, DX, CX, R8, R9)

  /** `chr` raises `BadChrError` when its operand has any bit set outside this mask. */
  private val _7_BIT_MASK = 0x7f

  extension [T](chain: Chain[T]) {

    /** Appends a single item to the end of the chain. */
    def +(item: T): Chain[T] = chain.append(item)

    /** Concatenates multiple `Chain[T]` instances into a single `Chain[T]`, appending them to the
      * current `Chain`.
      *
      * @param chains
      *   A variable number of `Chain[T]` instances to concatenate.
      * @return
      *   A new `Chain[T]` containing all elements from `chain` concatenated with `chains`.
      */
    def concatAll(chains: Chain[T]*): Chain[T] = chains.foldLeft(chain)(_ ++ _)
  }

  /** Generates the complete assembly listing for `microProg`.
    *
    * Output order: header directives, read-only data (debug info and constants), then the text
    * section containing `main`, the builtin wrappers, the runtime-error handlers and the user
    * functions.
    *
    * @param microProg
    *   the lowered program to compile
    * @return
    *   every line of the output assembly, in order
    */
  def generateAsm(microProg: Program): Chain[AsmLine] = {
    given stack: Stack = Stack()
    given labelGenerator: LabelGenerator = LabelGenerator()
    val Program(funcs, main) = microProg

    val mainLabel = LabelDef("main")
    // Debug info for main is anchored at its first statement; an empty main gets none.
    val mainDebug = main.headOption.fold(Chain.empty[AsmLine]) { stmt =>
      labelGenerator.getDebugFunc(stmt.pos, "$main", mainLabel)
    }
    // Only close a debug-function entry that was actually opened (same rule as generateUserFunc).
    val mainDebugEnd: Chain[AsmLine] =
      if (mainDebug.nonEmpty) {
        Chain(Directive.Size(mainLabel, SizeExpr.Relative(mainLabel)), Directive.EndFunc)
      } else {
        Chain.empty
      }

    val progAsm = (mainDebug + mainLabel).concatAll(
      funcPrologue(),
      main.foldMap(generateStmt(_)),
      Chain.one(Xor(RAX, RAX)), // main returns 0
      funcEpilogue(),
      mainDebugEnd,
      generateBuiltInFuncs(),
      RuntimeError.all.foldMap(_.generate),
      funcs.foldMap(generateUserFunc(_))
    )

    Chain(
      Directive.IntelSyntax,
      Directive.Global("main"),
      Directive.RoData
    ).concatAll(
      labelGenerator.generateDebug,
      labelGenerator.generateConstants,
      Chain.one(Directive.Text),
      progAsm
    )
  }

  /** Wraps `funcBody` in the builtin's label plus a standard prologue/epilogue. */
  private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
    asm ++= funcPrologue()
    asm ++= funcBody
    asm ++= funcEpilogue()
    asm
  }

  /** Generates the code for one user-defined function.
    *
    * Every function gets a fresh [[Stack]] model. Register-passed parameters are spilled onto the
    * stack in the prologue so they can be accessed like any other variable.
    */
  private def generateUserFunc(func: FuncDecl)(using
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    given stack: Stack = Stack()
    // Params 7 and up were pushed by the caller before its `call`, so they sit above the return
    // address; model their slots first.
    func.params.drop(argRegs.size).foreach(stack.reserve(_))
    stack.reserve(Size.Q64) // Reserve return pointer slot

    val funcLabel = labelGenerator.getLabelDef(func.name)
    val debugInfo = labelGenerator.getDebugFunc(func.pos, func.name.name, funcLabel)
    var asm: Chain[AsmLine] = debugInfo + funcLabel
    asm ++= funcPrologue()
    // Push the rest of params onto the stack for simplicity
    argRegs.zip(func.params).foreach { (reg, param) =>
      asm += stack.push(param, Register(Size.Q64, reg))
    }
    asm ++= func.body.foldMap(generateStmt(_))
    // No need for epilogue here since all user functions must return explicitly
    if (debugInfo.nonEmpty) {
      asm += Directive.Size(funcLabel, SizeExpr.Relative(funcLabel))
      asm += Directive.EndFunc
    }
    asm
  }

  /** Generates the wrappers around C library functions that the program may call. */
  private def generateBuiltInFuncs()(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]

    asm ++= wrapBuiltinFunc(
      Builtin.Exit,
      Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
    )

    // Print, then fflush(NULL) so output is not lost if the program later aborts
    asm ++= wrapBuiltinFunc(
      Builtin.Printf,
      Chain(
        stackAlign,
        assemblyIR.Call(CLibFunc.PrintF),
        Xor(RDI, RDI),
        assemblyIR.Call(CLibFunc.Fflush)
      )
    )

    // RSI points at a char array laid out as [Int length][chars...]; pass (length, data ptr)
    // as printf's 2nd and 3rd arguments. NOTE(review): assumes the caller's format in RDI
    // consumes a precision and a string (e.g. "%.*s") — confirm.
    asm ++= wrapBuiltinFunc(
      Builtin.PrintCharArray,
      Chain(
        stackAlign,
        Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)),
        Move(Register(KnownType.Int.size, SI), MemLocation(RSI, opSize = Some(KnownType.Int.size))),
        assemblyIR.Call(CLibFunc.PrintF),
        Xor(RDI, RDI),
        assemblyIR.Call(CLibFunc.Fflush)
      )
    )

    asm ++= wrapBuiltinFunc(
      Builtin.Malloc,
      Chain(
        stackAlign,
        assemblyIR.Call(CLibFunc.Malloc),
        // Out of memory check
        Compare(RAX, ImmediateVal(0)),
        Jump(labelGenerator.getLabelArg(OutOfMemoryError), Cond.Equal)
      )
    )

    asm ++= wrapBuiltinFunc(
      Builtin.Free,
      Chain(
        stackAlign,
        Compare(RDI, ImmediateVal(0)),
        Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
        assemblyIR.Call(CLibFunc.Free)
      )
    )

    // Store the incoming RSI in a 16-byte-aligned stack slot, pass that slot's address to scanf
    // as the destination, then pop the (possibly overwritten) slot into RAX as the result.
    asm ++= wrapBuiltinFunc(
      Builtin.Read,
      Chain(
        stackAlign,
        Subtract(Register(Size.Q64, SP), ImmediateVal(8)),
        Push(RSI),
        Load(RSI, MemLocation(Register(Size.Q64, SP), opSize = Some(Size.Q64))),
        assemblyIR.Call(CLibFunc.Scanf),
        Pop(RAX)
      )
    )

    asm
  }

  /** Generates the code for a single statement, preceded by a source-location directive. */
  private def generateStmt(stmt: Stmt)(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    val fileNo = labelGenerator.getDebugFile(stmt.pos.file)
    var asm = Chain.one[AsmLine](Directive.Location(fileNo, stmt.pos.line, None))

    stmt match {
      case Assign(lhs, rhs) =>
        lhs match {
          case ident: Ident =>
            if (!stack.contains(ident)) asm += stack.reserve(ident)
            asm ++= evalExprOntoStack(rhs)
            asm += stack.pop(RAX)
            asm += Move(stack.accessVar(ident).copy(opSize = Some(Size.Q64)), RAX)

          case ArrayElem(x, i) =>
            asm ++= evalExprOntoStack(rhs)
            asm ++= evalExprOntoStack(i)
            asm += stack.pop(RCX)
            asm += Compare(ECX, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
            asm += stack.push(KnownType.Int.size, RCX)
            asm ++= evalExprOntoStack(x)
            asm += stack.pop(RAX)
            asm += stack.pop(RCX)
            asm += Compare(RAX, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
            // The array length is stored as an Int at offset 0
            asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
            asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
            asm += stack.pop(RDX)
            asm += Move(
              MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt)),
              Register(x.ty.elemSize, DX)
            )
        }

      case If(cond, thenBranch, elseBranch) =>
        val elseLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()

        asm ++= evalExprOntoStack(cond)
        asm += stack.pop(RAX)
        asm += Compare(RAX, ImmediateVal(0))
        asm += Jump(LabelArg(elseLabel), Cond.Equal)

        asm ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
        asm += Jump(LabelArg(endLabel))
        asm += LabelDef(elseLabel)

        asm ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
        asm += LabelDef(endLabel)

      case While(cond, body) =>
        val startLabel = labelGenerator.getLabel()
        val endLabel = labelGenerator.getLabel()

        asm += LabelDef(startLabel)
        asm ++= evalExprOntoStack(cond)
        asm += stack.pop(RAX)
        asm += Compare(RAX, ImmediateVal(0))
        asm += Jump(LabelArg(endLabel), Cond.Equal)

        asm ++= stack.withScope(() => body.foldMap(generateStmt))
        asm += Jump(LabelArg(startLabel))
        asm += LabelDef(endLabel)

      case call: microWacc.Call =>
        asm ++= generateCall(call, isTail = false)

      case microWacc.Return(expr) =>
        expr match {
          case call: microWacc.Call =>
            // generateCall handles returning to our caller when isTail is set
            asm ++= generateCall(call, isTail = true) // tco
          case _ =>
            asm ++= evalExprOntoStack(expr)
            asm += stack.pop(RAX)
            asm ++= funcEpilogue()
        }
    }

    asm
  }

  /** Evaluates `expr` and pushes exactly one result slot onto the stack.
    *
    * Results narrower than 32 bits are masked to their size before returning (see `zeroRest`).
    */
  private def evalExprOntoStack(expr: Expr)(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]
    val stackSizeStart = stack.size
    expr match {
      case IntLiter(v)  => asm += stack.push(KnownType.Int.size, ImmediateVal(v))
      case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
      case ident: Ident =>
        val location = stack.accessVar(ident)
        // items in stack are guaranteed to be in Q64 slots,
        // so we are safe to wipe the opSize from the memory location
        asm += stack.push(ident.ty.size, location.copy(opSize = None))

      case array @ ArrayLiter(elems) =>
        expr.ty match {
          case KnownType.String =>
            // String literals live in read-only data; push their address
            val str = elems.collect { case CharLiter(v) => v }.mkString
            asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str)))
            asm += stack.push(KnownType.String.size, RAX)
          case ty =>
            asm ++= generateCall(
              microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize)))(array.pos),
              isTail = false
            )
            asm += stack.push(KnownType.Array(?).size, RAX)
            // Store the length of the array at the start
            asm += Move(
              MemLocation(RAX, opSize = Some(KnownType.Int.size)),
              ImmediateVal(elems.size)
            )
            // Evaluating an element may clobber RAX (e.g. a nested array literal calls malloc),
            // so the array pointer is kept on the stack and reloaded for every element.
            elems.zipWithIndex.foreach { (elem, i) =>
              asm ++= evalExprOntoStack(elem)
              asm += stack.pop(RCX)
              asm += stack.pop(RAX)
              asm += Move(
                MemLocation(RAX, KnownType.Int.size.toInt + i * ty.elemSize.toInt),
                Register(ty.elemSize, CX)
              )
              asm += stack.push(KnownType.Array(?).size, RAX)
            }
        }

      case BoolLiter(true) =>
        asm += stack.push(KnownType.Bool.size, ImmediateVal(1))

      case BoolLiter(false) =>
        asm += Xor(RAX, RAX)
        asm += stack.push(KnownType.Bool.size, RAX)

      case NullLiter() =>
        asm += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))

      case ArrayElem(x, i) =>
        asm ++= evalExprOntoStack(x)
        asm ++= evalExprOntoStack(i)
        asm += stack.pop(RCX)
        // Int slots are not zero-extended by zeroRest, so the upper half of RCX may be either
        // zero- or sign-extended; test the sign on the 32-bit register (as the store path does).
        asm += Compare(ECX, ImmediateVal(0))
        asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
        asm += stack.pop(RAX)
        asm += Compare(RAX, ImmediateVal(0))
        asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
        asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
        asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
        // + Int because we store the length of the array at the start
        asm += Move(
          Register(x.ty.elemSize, AX),
          MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt))
        )
        asm += stack.push(x.ty.elemSize, RAX)

      case UnaryOp(x, op) =>
        asm ++= evalExprOntoStack(x)
        op match {
          case UnaryOperator.Chr =>
            // Value stays on the stack; only validate that it fits in 7 bits
            asm += Move(EAX, stack.head)
            asm += And(EAX, ImmediateVal(~_7_BIT_MASK))
            asm += Compare(EAX, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)

          case UnaryOperator.Ord => // No op needed
          case UnaryOperator.Len =>
            asm += stack.pop(RAX)
            asm += Move(EAX, MemLocation(RAX, opSize = Some(KnownType.Int.size)))
            asm += stack.push(KnownType.Int.size, RAX)

          case UnaryOperator.Negate =>
            asm += Xor(EAX, EAX)
            asm += Subtract(EAX, stack.head)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
            asm += stack.drop()
            asm += stack.push(KnownType.Int.size, RAX)

          case UnaryOperator.Not =>
            asm += Xor(stack.head, ImmediateVal(1))
        }

      case BinaryOp(x, y, op) =>
        val destX = Register(x.ty.size, AX)
        // y ends up on the stack (stack.head), x in RAX
        asm ++= evalExprOntoStack(y)
        asm ++= evalExprOntoStack(x)
        asm += stack.pop(RAX)

        op match {
          case BinaryOperator.Add =>
            asm += Add(stack.head, destX)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)

          case BinaryOperator.Sub =>
            asm += Subtract(destX, stack.head)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
            asm += stack.drop()
            asm += stack.push(destX.size, RAX)

          case BinaryOperator.Mul =>
            asm += Multiply(destX, stack.head)
            asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
            asm += stack.drop()
            asm += stack.push(destX.size, RAX)

          case BinaryOperator.Div =>
            asm += Compare(stack.head, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
            asm += CDQ()
            asm += Divide(stack.head)
            asm += stack.drop()
            asm += stack.push(destX.size, RAX)

          case BinaryOperator.Mod =>
            asm += Compare(stack.head, ImmediateVal(0))
            asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
            asm += CDQ()
            asm += Divide(stack.head)
            asm += stack.drop()
            asm += stack.push(destX.size, RDX)

          case BinaryOperator.Eq        => asm ++= generateComparison(destX, Cond.Equal)
          case BinaryOperator.Neq       => asm ++= generateComparison(destX, Cond.NotEqual)
          case BinaryOperator.Greater   => asm ++= generateComparison(destX, Cond.Greater)
          case BinaryOperator.GreaterEq => asm ++= generateComparison(destX, Cond.GreaterEqual)
          case BinaryOperator.Less      => asm ++= generateComparison(destX, Cond.Less)
          case BinaryOperator.LessEq    => asm ++= generateComparison(destX, Cond.LessEqual)
          case BinaryOperator.And       => asm += And(stack.head, destX)
          case BinaryOperator.Or        => asm += Or(stack.head, destX)
        }

      case call: microWacc.Call =>
        asm ++= generateCall(call, isTail = false)
        asm += stack.push(call.ty.size, RAX)
    }

    assert(
      stack.size == stackSizeStart + 1,
      "Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
    )
    asm ++= zeroRest(stack.head.copy(opSize = Some(Size.Q64)), expr.ty.size)
    asm
  }

  /** Generates a call to `call.target`, leaving its result in RAX.
    *
    * The first six arguments are passed in `argRegs`; the rest are pushed on the stack and removed
    * again after the call.
    *
    * @param isTail
    *   the call is the operand of a `return`. The generated code then also returns from the
    *   current function. When every argument fits in registers this is done as a real tail call:
    *   the current frame is torn down and the target is jumped to, so the target returns straight
    *   to our caller. Stack-passed arguments would live in the frame being torn down, so in that
    *   case a normal call followed by the function epilogue is emitted instead.
    */
  private def generateCall(call: microWacc.Call, isTail: Boolean)(using
      stack: Stack,
      labelGenerator: LabelGenerator
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]
    val microWacc.Call(target, args) = call
    val (regArgs, stackArgs) = args.splitAt(argRegs.size)

    // Evaluate arguments 1-6 in order, then pop them (last first) into their registers
    regArgs.foreach(arg => asm ++= evalExprOntoStack(arg))
    argRegs
      .take(regArgs.size)
      .reverse
      .foreach(reg => asm += stack.pop(Register(Size.Q64, reg)))

    // Evaluate arguments 7 and up and push them onto the stack
    stackArgs.foreach(arg => asm ++= evalExprOntoStack(arg))

    if (isTail && stackArgs.isEmpty) {
      // Tail Call Optimisation (TCO): drop our frame so the callee's `ret` goes to our caller
      asm ++= frameTeardown
      asm += Jump(labelGenerator.getLabelArg(target))
    } else {
      asm += assemblyIR.Call(labelGenerator.getLabelArg(target))
      // Remove arguments 7 and up from the stack
      if (stackArgs.nonEmpty) asm += stack.drop(stackArgs.size)
      // Result is already in RAX; return it to our caller
      if (isTail) asm ++= funcEpilogue()
    }

    asm
  }

  /** Compares RAX-held `destX` with the stack head and replaces the head with the 0/1 result. */
  private def generateComparison(destX: Register, cond: Cond)(using
      stack: Stack
  ): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]

    asm += Compare(destX, stack.head)
    asm += Set(Register(Size.B8, AX), cond)
    asm ++= zeroRest(RAX, Size.B8)
    asm += stack.drop()
    asm += stack.push(Size.B8, RAX)

    asm
  }

  /** Saves the caller's RBP and makes RBP point at the current frame. */
  private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
    var asm = Chain.empty[AsmLine]
    asm += stack.push(Size.Q64, RBP)
    asm += Move(RBP, Register(Size.Q64, SP))
    asm
  }

  /** Undoes `funcPrologue`: restores RSP and the caller's RBP (does not return). */
  private def frameTeardown: Chain[AsmLine] =
    Chain(Move(Register(Size.Q64, SP), RBP), Pop(RBP))

  /** Tears down the current frame and returns to the caller. */
  private def funcEpilogue(): Chain[AsmLine] =
    frameTeardown + assemblyIR.Return()

  /** Rounds RSP down to a 16-byte boundary before calling into the C library. */
  def stackAlign: AsmLine = And(Register(Size.Q64, SP), ImmediateVal(-16))

  /** Masks `dest` down to its low `size` bytes.
    *
    * D32 and Q64 values are left untouched, so the upper half of an Int slot may hold either zero-
    * or sign-extension; compare Ints through their 32-bit registers.
    */
  private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
    case Size.Q64 | Size.D32 => Chain.empty
    case _                   => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
  }

  /** Maps each escapable character back to its source escape sequence (e.g. newline to `\n`). */
  private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }

  extension (s: String) {

    /** Re-escapes special characters so the string can be emitted as an assembler literal. */
    def escaped: String =
      s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
  }
}