From 148828122363f92a36e6a4110e907e77ccb64a46 Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 25 Feb 2025 00:00:12 +0000 Subject: [PATCH] feat: implements binary operators in asmGenerator Co-authored-by: Gleb Koval Co-authored-by: Barf-Vader --- src/main/wacc/backend/asmGenerator.scala | 304 ++++++++++++++++------- src/main/wacc/backend/assemblyIR.scala | 41 ++- src/main/wacc/frontend/microWacc.scala | 14 +- src/main/wacc/frontend/typeChecker.scala | 35 ++- 4 files changed, 285 insertions(+), 109 deletions(-) diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 4c1a924..1c3a2a9 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -2,19 +2,37 @@ package wacc import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ListBuffer +import cats.syntax.align -object labelGenerator { - var labelVal = -1 - def getLabel(): String = { - labelVal += 1 - s".L$labelVal" - } -} object asmGenerator { import microWacc._ import assemblyIR._ import wacc.types._ + val RAX = Register(RegSize.R64, RegName.AX) + val EAX = Register(RegSize.E32, RegName.AX) + val RSP = Register(RegSize.R64, RegName.SP) + val ESP = Register(RegSize.E32, RegName.SP) + val EDX = Register(RegSize.E32, RegName.DX) + val RDI = Register(RegSize.R64, RegName.DI) + val RIP = Register(RegSize.R64, RegName.IP) + val RBP = Register(RegSize.R64, RegName.BP) + val RSI = Register(RegSize.R64, RegName.SI) + + val _8_BIT_MASK = 0xFF + + object labelGenerator { + var labelVal = -1 + def getLabel(): String = { + labelVal += 1 + s".L$labelVal" + } + def getLabel(target: CallTarget): String = target match{ + case Ident(v,_) => s"wacc_$v" + case Builtin(name) => s"_$name" + } + } + def generateAsm(microProg: Program): List[AsmLine] = { given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]() given strings: ListBuffer[String] = ListBuffer[String]() @@ -25,10 +43,8 @@ object asmGenerator { funcPrologue() ++ alignStack() ++ main.flatMap(generateStmt) ++ - List(Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0))) ++ - funcEpilogue() ++ - List(assemblyIR.Return()) ++ - generateFuncs() + List(Move(RAX, ImmediateVal(0))) ++ + funcEpilogue() val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) @@ -40,12 +56,38 @@ object asmGenerator { progAsm } -//TODO - def generateFuncs()(using + def wrapFunc(labelName: String, funcBody: List[AsmLine])(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { + LabelDef(labelName) :: + funcPrologue() ++ + funcBody ++ + funcEpilogue() + } + + def generateBuiltInFuncs()(using + stack: LinkedHashMap[Ident, Int], + strings: ListBuffer[String] + ): List[AsmLine] = { + wrapFunc(labelGenerator.getLabel(Builtin.Exit), + alignStack() ++ + List(Pop(RDI), + assemblyIR.Call(CLibFunc.Exit)) + ) ++ + wrapFunc(labelGenerator.getLabel(Builtin.Printf), + alignStack() ++ + List(assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(0)), + assemblyIR.Call(CLibFunc.Fflush) + ) + )++ + wrapFunc(labelGenerator.getLabel(Builtin.Malloc), List() + )++ + wrapFunc(labelGenerator.getLabel(Builtin.Free), + List() + ) } def generateStmt( @@ -53,28 +95,16 @@ object asmGenerator { )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = stmt match { case microWacc.Call(Builtin.Exit, code :: _) => - // alignStack() ++ - evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++ - List(assemblyIR.Call(CLibFunc.Exit)) - - case microWacc.Call(Builtin.Println, expr :: _) => - // alignStack() ++ - printF(expr) ++ - printLn() - - case microWacc.Call(Builtin.Print, expr :: _) => - // alignStack() ++ - printF(expr) - + List() case Assign(lhs, rhs) => var dest: IndexAddress = - IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten + IndexAddress(RSP, 0) // gets overrwitten (lhs match { case ident: Ident => if (!stack.contains(ident)) { stack += (ident -> (stack.size + 1)) dest = accessVar(ident) - List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16))) + List(Subtract(RSP, ImmediateVal(8))) } else { dest = accessVar(ident) List() @@ -90,15 +120,16 @@ object asmGenerator { case microWacc.Call(Builtin.ReadChar, _) => readIntoVar(dest, Builtin.ReadChar) case _ => - evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ - List(Move(dest, Register(RegSize.R64, RegName.AX))) + evalExprOntoStack(rhs) ++ + List(Pop(dest)) }) case If(cond, thenBranch, elseBranch) => { val elseLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() - evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ + evalExprOntoStack(cond) ++ List( - Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), + Add(RSP, ImmediateVal(8)), Jump(LabelArg(elseLabel), Cond.Equal) ) ++ thenBranch.flatMap(generateStmt) ++ @@ -110,30 +141,33 @@ object asmGenerator { val startLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel() List(LabelDef(startLabel)) ++ - evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ + evalExprOntoStack(cond) ++ List( - Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), + Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)), + Add(RSP, ImmediateVal(8)), Jump(LabelArg(endLabel), Cond.Equal) ) ++ body.flatMap(generateStmt) ++ List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) } case microWacc.Return(expr) => - evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX)) + evalExprOntoStack(expr) ++ + List(Pop(RAX), assemblyIR.Return()) + case _ => List() } - def evalExprIntoReg(expr: Expr, dest: Register)(using + def evalExprOntoStack(expr: Expr)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { expr match { case IntLiter(v) => - List(Move(dest, ImmediateVal(v))) + List(Push(ImmediateVal(v))) case CharLiter(v) => - List(Move(dest, ImmediateVal(v.toInt))) + List(Push(ImmediateVal(v.toInt))) case ident: Ident => - List(Move(dest, accessVar(ident))) + List(Push(accessVar(ident))) case ArrayLiter(elems) => expr.ty match { case KnownType.String => @@ -143,70 +177,173 @@ object asmGenerator { }.mkString List( Load( - dest, + RAX, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) - ) + ), + Push(RAX) ) // TODO other array types case _ => List() } - // TODO other expr types - case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0))) - case _ => List() + case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0))) + case NullLiter() => List(Push(ImmediateVal(0))) + case ArrayElem(value, indices) => List() + case UnaryOp(x, op) => op match { + // TODO: chr and ord are TYPE CASTS. They do not change the internal value, + // but will need bound checking e.t.c. + case UnaryOperator.Chr => List() + case UnaryOperator.Ord => List() + case UnaryOperator.Len => List() + case UnaryOperator.Negate => List( + Negate(MemLocation(RSP,SizeDir.Word)) + ) + case UnaryOperator.Not => + evalExprOntoStack(x) ++ + List( + Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1)) + ) + + } + case BinaryOp(x, y, op) => op match { + case BinaryOperator.Add => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Add(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Sub => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Subtract(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Mul => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Multiply(MemLocation(RSP, SizeDir.Word), EAX) + // TODO OVERFLOWING + ) + case BinaryOperator.Div => + evalExprOntoStack(y) ++ + evalExprOntoStack(x) ++ + List( + Pop(EAX), + Divide(MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(EAX) + // TODO CHECK DIVISOR IS NOT 0 + ) + case BinaryOperator.Mod => + evalExprOntoStack(y) ++ + evalExprOntoStack(x) ++ + List( + Pop(EAX), + Divide(MemLocation(RSP, SizeDir.Word)), + Add(RSP, ImmediateVal(8)), + Push(EDX) + // TODO CHECK DIVISOR IS NOT 0 + ) + case BinaryOperator.Eq => + generateComparison(x, y, Cond.Equal) + case BinaryOperator.Neq => + generateComparison(x, y, Cond.NotEqual) + case BinaryOperator.Greater => + generateComparison(x, y, Cond.Greater) + case BinaryOperator.GreaterEq => + generateComparison(x, y, Cond.GreaterEqual) + case BinaryOperator.Less => + generateComparison(x, y, Cond.Less) + case BinaryOperator.LessEq => + generateComparison(x, y, Cond.LessEqual) + case BinaryOperator.And => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + And(MemLocation(RSP, SizeDir.Word), EAX), + ) + case BinaryOperator.Or => + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Or(MemLocation(RSP, SizeDir.Word), EAX), + ) + } + case microWacc.Call(target, args) => List() } } - // TODO make sure EOF doenst override the value in the stack - // probably need labels implemented for conditional jumps - def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using + // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using + // stack: LinkedHashMap[Ident, Int], + // strings: ListBuffer[String] + // ): List[AsmLine] = { + // readType match { + // case Builtin.ReadInt => + // strings += PrintFormat.Int.toString + // case Builtin.ReadChar => + // strings += PrintFormat.Char.toString + // } + // List( + // Load( + // RDI, + // IndexAddress( + // RIP, + // LabelArg(s".L.str${strings.size - 1}") + // ) + // ), + // Load(RSI, dest) + // ) ++ + // // alignStack() ++ + // List(assemblyIR.Call(CLibFunc.Scanf)) + + // } + + def generateComparison(x : Expr, y: Expr, cond: Cond)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String] ): List[AsmLine] = { - readType match { - case Builtin.ReadInt => - strings += PrintFormat.Int.toString - case Builtin.ReadChar => - strings += PrintFormat.Char.toString - } - List( - Load( - Register(RegSize.R64, RegName.DI), - IndexAddress( - Register(RegSize.R64, RegName.IP), - LabelArg(s".L.str${strings.size - 1}") - ) - ), - Load(Register(RegSize.R64, RegName.SI), dest) - ) ++ - // alignStack() ++ - List(assemblyIR.Call(CLibFunc.Scanf)) - + evalExprOntoStack(x) ++ + evalExprOntoStack(y) ++ + List( + Pop(EAX), + Compare(MemLocation(RSP, SizeDir.Word), EAX), + Set(Register(RegSize.Byte, RegName.AL), cond), + And(EAX, ImmediateVal(_8_BIT_MASK)), + Push(EAX) + ) } - def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = - IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16) + IndexAddress(RSP, (stack.size - stack(ident)) * 8) def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { List( - And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)) + And(RSP, ImmediateVal(-16)) ) } // Missing a sub instruction but dont think we need it def funcPrologue(): List[AsmLine] = { List( - Push(Register(RegSize.R64, RegName.BP)), - Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) + Push(RBP), + Move(RBP, RSP) ) } def funcEpilogue(): List[AsmLine] = { List( - Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), - Pop(Register(RegSize.R64, RegName.BP)) + Move(RSP, RBP), + Pop(RBP), + assemblyIR.Return() ) } @@ -231,9 +368,9 @@ object asmGenerator { } List( Load( - Register(RegSize.R64, RegName.DI), + RDI, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) ) @@ -251,22 +388,23 @@ object asmGenerator { } List( Load( - Register(RegSize.R64, RegName.DI), + RDI, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) ) ) } else { - evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI)) + evalExprOntoStack(expr) ++ + List(Pop(RSI)) }) // print the value ++ List( assemblyIR.Call(CLibFunc.PrintF), - Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + Move(RDI, ImmediateVal(0)), assemblyIR.Call(CLibFunc.Fflush) ) } @@ -278,16 +416,16 @@ object asmGenerator { ): List[AsmLine] = { strings += "" Load( - Register(RegSize.R64, RegName.DI), + RDI, IndexAddress( - Register(RegSize.R64, RegName.IP), + RIP, LabelArg(s".L.str${strings.size - 1}") ) ) :: List( assemblyIR.Call(CLibFunc.Puts), - Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), + Move(RDI, ImmediateVal(0)), assemblyIR.Call(CLibFunc.Fflush) ) diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala index c48daac..22ca36b 100644 --- a/src/main/wacc/backend/assemblyIR.scala +++ b/src/main/wacc/backend/assemblyIR.scala @@ -9,17 +9,20 @@ object assemblyIR { enum RegSize { case R64 case E32 + case Byte override def toString = this match { case R64 => "r" case E32 => "e" + case Byte => "" } } enum RegName { - case AX, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 + case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 override def toString = this match { case AX => "ax" + case AL => "al" case BX => "bx" case CX => "cx" case DX => "dx" @@ -42,7 +45,6 @@ object assemblyIR { // arguments enum CLibFunc extends Operand { case Scanf, - Puts, Fflush, Exit, PrintF @@ -51,24 +53,24 @@ object assemblyIR { override def toString = this match { case Scanf => "scanf" + plt - case Puts => "puts" + plt case Fflush => "fflush" + plt case Exit => "exit" + plt case PrintF => "printf" + plt } } + //TODO register naming conventions are wrong case class Register(size: RegSize, name: RegName) extends Dest with Src { override def toString = s"${size}${name}" } - case class MemLocation(pointer: Long | Register) extends Dest with Src { + case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src { override def toString = pointer match { - case hex: Long => f"[0x$hex%X]" - case reg: Register => s"[$reg]" + case hex: Long => opSize.toString + f"[0x$hex%X]" + case reg: Register => opSize.toString + s"[$reg]" } } - case class IndexAddress(base: Register, offset: Int | LabelArg) extends Dest with Src { - override def toString = s"[$base + $offset]" + case class IndexAddress(base: Register, offset: Int | LabelArg, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src { + override def toString = s"$opSize[$base + $offset]" } case class ImmediateVal(value: Int) extends Src { @@ -85,11 +87,13 @@ object assemblyIR { } case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2) case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2) - case class Multiply(ops: Operand*) extends Operation("mul", ops*) - case class Divide(op1: Src) extends Operation("div", op1) + case class Multiply(ops: Operand*) extends Operation("imul", ops*) + case class Divide(op1: Src) extends Operation("idiv", op1) + case class Negate(op: Dest) extends Operation("neg", op) case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2) case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2) + case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2) case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2) // stack operations @@ -106,6 +110,9 @@ object assemblyIR { case class Jump(op1: LabelArg, condition: Cond = Cond.Always) extends Operation(s"j${condition.toString}", op1) + case class Set(op1: Dest, condition: Cond = Cond.Always) + extends Operation(s"set${condition.toString}", op1) + case class LabelDef(name: String) extends AsmLine { override def toString = s"$name:" } @@ -156,4 +163,16 @@ object assemblyIR { case String => "%s" } } -} + + enum SizeDir { + case Byte, Word, Unspecified + + private val ptr = "ptr " + + override def toString(): String = this match { + case Byte => "byte " + ptr + case Word => "word " + ptr + case Unspecified => "" + } + } +} \ No newline at end of file diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index b9f6635..c558b6d 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -69,13 +69,15 @@ object microWacc { // Statements sealed trait Stmt + case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) { + override def toString(): String = name + } object Builtin { - case object ReadInt extends CallTarget(KnownType.Int) - case object ReadChar extends CallTarget(KnownType.Char) - case object Print extends CallTarget(?) - case object Println extends CallTarget(?) - case object Exit extends CallTarget(?) - case object Free extends CallTarget(?) + object Read extends Builtin("read")(?) + object Printf extends Builtin("printf")(?) + object Exit extends Builtin("exit")(?) + object Free extends Builtin("free")(?) + object Malloc extends Builtin("malloc")(?) } case class Assign(lhs: LValue, rhs: Expr) extends Stmt diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index e3960cb..b854272 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -177,12 +177,14 @@ object typeChecker { microWacc.Assign( destTyped, microWacc.Call( - destTy match { - case KnownType.Int => microWacc.Builtin.ReadInt - case KnownType.Char => microWacc.Builtin.ReadChar - case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway - }, - Nil + microWacc.Builtin.Read, + List( + destTy match { + case KnownType.Int => "%d".toMicroWaccCharArray + case KnownType.Char | _ => "%c".toMicroWaccCharArray + }, + destTyped + ) ) ) ) @@ -213,10 +215,20 @@ object typeChecker { ) case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already + val exprTyped = checkValue(expr, Constraint.Unconstrained) + val format = exprTyped.ty match { + case KnownType.Bool | KnownType.String => "%s" + case KnownType.Char => "%c" + case KnownType.Int => "%d" + case _ => "%p" + } List( microWacc.Call( - if newline then microWacc.Builtin.Println else microWacc.Builtin.Print, - List(checkValue(expr, Constraint.Unconstrained)) + microWacc.Builtin.Printf, + List( + s"$format${if newline then "\n" else ""}".toMicroWaccCharArray, + exprTyped + ) ) ) case ast.If(cond, thenStmt, elseStmt) => @@ -262,7 +274,7 @@ object typeChecker { microWacc.CharLiter(v) case l @ ast.StrLiter(v) => KnownType.String.satisfies(constraint, l.pos) - microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String) + v.toMicroWaccCharArray case l @ ast.PairLiter() => microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) case ast.Parens(expr) => checkValue(expr, constraint) @@ -441,4 +453,9 @@ object typeChecker { case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) }) } + + extension (s: String) { + def toMicroWaccCharArray: microWacc.ArrayLiter = + microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String) + } }