From e23ef8da4810e803bf057a688eee19dbbcaa46a7 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 13 Feb 2025 17:19:15 +0000 Subject: [PATCH 1/8] feat: initial microWacc definition --- src/main/wacc/Frontend/microWacc.scala | 67 ++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/main/wacc/Frontend/microWacc.scala diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala new file mode 100644 index 0000000..1faf3a8 --- /dev/null +++ b/src/main/wacc/Frontend/microWacc.scala @@ -0,0 +1,67 @@ +package wacc + +import cats.data.NonEmptyList + +object microWacc { + import wacc.types._ + + sealed trait CallTarget + sealed trait Expr(val ty: SemType) + sealed trait LValue + + // Atomic expressions + case class IntLiter(v: Int) extends Expr(KnownType.Int) + case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool) + case class CharLiter(v: Char) extends Expr(KnownType.Char) + case class ArrayLiter(elems: List[Expr])(ty: SemType) extends Expr(ty) + case class Ident(name: String)(identTy: SemType) extends Expr(identTy) with CallTarget with LValue + case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(ty: SemType) + extends Expr(ty) + with LValue + + // Operators + case class UnaryOp(x: Expr, op: UnaryOperator)(ty: SemType) extends Expr(ty) + enum UnaryOperator { + case Negate + case Not + case Len + case Ord + case Chr + } + case class BinaryOp(x: Expr, y: Expr, op: BinaryOperator)(ty: SemType) extends Expr(ty) + enum BinaryOperator { + case Add + case Sub + case Mul + case Div + case Mod + case Greater + case GreaterEq + case Less + case LessEq + case Eq + case Neq + case And + case Or + } + + // Statements + sealed trait Stmt + + enum Builtin extends CallTarget { + case Read + case Free + case Exit + case Print + } + + case class Assign(lhs: LValue, rhs: Expr) extends Stmt + case class If(cond: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt + case class While(cond: Expr, body: List[Stmt]) extends Stmt + case class Call(target: CallTarget, args: List[Expr]) extends Stmt + case class Return(expr: Expr) extends Stmt + + // Program + case class FuncDecl(name: Ident, params: List[Ident], body: List[Stmt]) + case class Program(funcs: List[FuncDecl], stmts: List[Stmt]) +} From d6aa83a2eab3c9f9d370baf689ac97f34b7fa693 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 13 Feb 2025 23:54:46 +0000 Subject: [PATCH 2/8] fix: add support for return types in micro wacc calls --- src/main/wacc/Frontend/microWacc.scala | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala index 1faf3a8..88e43f2 100644 --- a/src/main/wacc/Frontend/microWacc.scala +++ b/src/main/wacc/Frontend/microWacc.scala @@ -5,16 +5,19 @@ import cats.data.NonEmptyList object microWacc { import wacc.types._ - sealed trait CallTarget + sealed trait CallTarget(val retTy: SemType) sealed trait Expr(val ty: SemType) - sealed trait LValue + sealed trait LValue extends Expr // Atomic expressions case class IntLiter(v: Int) extends Expr(KnownType.Int) case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool) case class CharLiter(v: Char) extends Expr(KnownType.Char) case class ArrayLiter(elems: List[Expr])(ty: SemType) extends Expr(ty) - case class Ident(name: String)(identTy: SemType) extends Expr(identTy) with CallTarget with LValue + case class Ident(name: String)(identTy: SemType) + extends Expr(identTy) + with CallTarget(identTy) + with LValue case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(ty: SemType) extends Expr(ty) with LValue @@ -48,17 +51,18 @@ object microWacc { // Statements sealed trait Stmt - enum Builtin extends CallTarget { - case Read - case Free - case Exit - case Print + object Builtin { + case object ReadInt extends CallTarget(KnownType.Int) + case object ReadChar extends CallTarget(KnownType.Char) + case object Print extends CallTarget(?) + case object Exit extends CallTarget(?) + case object Free extends CallTarget(?) } case class Assign(lhs: LValue, rhs: Expr) extends Stmt case class If(cond: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt case class While(cond: Expr, body: List[Stmt]) extends Stmt - case class Call(target: CallTarget, args: List[Expr]) extends Stmt + case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy) case class Return(expr: Expr) extends Stmt // Program From 03999e00ef31c4c6081ba0db03d25bc5c1e85ccc Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 14 Feb 2025 00:07:43 +0000 Subject: [PATCH 3/8] fix: add support for println in micro wacc --- src/main/wacc/Frontend/microWacc.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala index 88e43f2..a19567e 100644 --- a/src/main/wacc/Frontend/microWacc.scala +++ b/src/main/wacc/Frontend/microWacc.scala @@ -55,6 +55,7 @@ object microWacc { case object ReadInt extends CallTarget(KnownType.Int) case object ReadChar extends CallTarget(KnownType.Char) case object Print extends CallTarget(?) + case object Println extends CallTarget(?) case object Exit extends CallTarget(?) case object Free extends CallTarget(?) } From 6a6aadbbeb0528f13bd84f71e9e5d2f1d0f11994 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 14 Feb 2025 00:21:10 +0000 Subject: [PATCH 4/8] fix: add nulliter to micro wacc --- src/main/wacc/Frontend/microWacc.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala index a19567e..d0b6f9d 100644 --- a/src/main/wacc/Frontend/microWacc.scala +++ b/src/main/wacc/Frontend/microWacc.scala @@ -14,6 +14,7 @@ object microWacc { case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool) case class CharLiter(v: Char) extends Expr(KnownType.Char) case class ArrayLiter(elems: List[Expr])(ty: SemType) extends Expr(ty) + case class NullLiter()(ty: SemType) extends Expr(ty) case class Ident(name: String)(identTy: SemType) extends Expr(identTy) with CallTarget(identTy) From bc25f914ad83c3d1c5584cef6f28f16019f0ab03 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 14 Feb 2025 00:35:48 +0000 Subject: [PATCH 5/8] fix: add uid to microWacc Ident --- src/main/wacc/Frontend/microWacc.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala index d0b6f9d..04b9035 100644 --- a/src/main/wacc/Frontend/microWacc.scala +++ b/src/main/wacc/Frontend/microWacc.scala @@ -15,7 +15,7 @@ object microWacc { case class CharLiter(v: Char) extends Expr(KnownType.Char) case class ArrayLiter(elems: List[Expr])(ty: SemType) extends Expr(ty) case class NullLiter()(ty: SemType) extends Expr(ty) - case class Ident(name: String)(identTy: SemType) + case class Ident(name: String, uid: Int)(identTy: SemType) extends Expr(identTy) with CallTarget(identTy) with LValue From 756b42dd724f184901d688bd9c7a16e993f63943 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 13 Feb 2025 23:12:41 +0000 Subject: [PATCH 6/8] refactor: remove implicit ast from type checker --- src/main/wacc/Frontend/typeChecker.scala | 86 ++++++++++++------------ 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/src/main/wacc/Frontend/typeChecker.scala b/src/main/wacc/Frontend/typeChecker.scala index 62fad3b..21d49ae 100644 --- a/src/main/wacc/Frontend/typeChecker.scala +++ b/src/main/wacc/Frontend/typeChecker.scala @@ -4,17 +4,15 @@ import cats.syntax.all._ import scala.collection.mutable object typeChecker { - import wacc.ast._ import wacc.types._ case class TypeCheckerCtx( - globalNames: Map[Ident, SemType], - globalFuncs: Map[Ident, FuncType], + globalNames: Map[ast.Ident, SemType], + globalFuncs: Map[ast.Ident, FuncType], errors: mutable.Builder[Error, List[Error]] ) { - def typeOf(ident: Ident): SemType = globalNames(ident) - - def funcType(ident: Ident): FuncType = globalFuncs(ident) + def typeOf(ident: ast.Ident): SemType = globalNames(ident) + def funcType(ident: ast.Ident): FuncType = globalFuncs(ident) def error(err: Error): SemType = errors += err @@ -44,7 +42,7 @@ object typeChecker { * @return * The type if the constraint was satisfied, or ? if it was not. */ - private def satisfies(constraint: Constraint, pos: Position)(using + private def satisfies(constraint: Constraint, pos: ast.Position)(using ctx: TypeCheckerCtx ): SemType = (ty, constraint) match { @@ -100,12 +98,12 @@ object typeChecker { * The type checker context which includes the global names and functions, and an errors * builder. */ - def check(prog: Program)(using + def check(prog: ast.Program)(using ctx: TypeCheckerCtx ): Unit = { // Ignore function syntax types for return value and params, since those have been converted // to SemTypes by the renamer. - prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => + prog.funcs.foreach { case ast.FuncDecl(_, name, _, stmts) => val FuncType(retType, _) = ctx.funcType(name) stmts.toList.foreach( checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) @@ -121,11 +119,11 @@ object typeChecker { * @param returnConstraint * The constraint that any `return ` statements must satisfy. */ - private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using + private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using ctx: TypeCheckerCtx ): Unit = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. - case VarDecl(_, name, value) => + case ast.VarDecl(_, name, value) => val expectedTy = ctx.typeOf(name) checkValue( value, @@ -134,7 +132,7 @@ object typeChecker { s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) - case Assign(lhs, rhs) => + case ast.Assign(lhs, rhs) => val lhsTy = checkValue(lhs, Constraint.Unconstrained) (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { case (?, ?) => @@ -143,7 +141,7 @@ object typeChecker { ) case _ => () } - case Read(dest) => + case ast.Read(dest) => checkValue(dest, Constraint.Unconstrained) match { case ? => ctx.error( @@ -159,7 +157,7 @@ object typeChecker { dest.pos ) } - case Free(lhs) => + case ast.Free(lhs) => checkValue( lhs, Constraint.IsEither( @@ -168,23 +166,23 @@ object typeChecker { "free must be applied to an array or pair" ) ) - case Return(expr) => + case ast.Return(expr) => checkValue(expr, returnConstraint) - case Exit(expr) => + case ast.Exit(expr) => checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) - case Print(expr, _) => + case ast.Print(expr, _) => // This constraint should never fail, the scope-checker should have caught it already checkValue(expr, Constraint.Unconstrained) - case If(cond, thenStmt, elseStmt) => + case ast.If(cond, thenStmt, elseStmt) => checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) thenStmt.toList.foreach(checkStmt(_, returnConstraint)) elseStmt.toList.foreach(checkStmt(_, returnConstraint)) - case While(cond, body) => + case ast.While(cond, body) => checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) body.toList.foreach(checkStmt(_, returnConstraint)) - case Block(body) => + case ast.Block(body) => body.toList.foreach(checkStmt(_, returnConstraint)) - case Skip() => () + case ast.Skip() => () } /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits @@ -197,17 +195,17 @@ object typeChecker { * @return * The most specific type of the value if it could be determined, or ? if it could not. */ - private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using + private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using ctx: TypeCheckerCtx ): SemType = value match { - case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) - case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) - case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) - case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) - case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) - case id: Ident => + case l @ ast.IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) + case l @ ast.BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) + case l @ ast.CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) + case l @ ast.StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) + case l @ ast.PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) + case id: ast.Ident => ctx.typeOf(id).satisfies(constraint, id.pos) - case ArrayElem(id, indices) => + case ast.ArrayElem(id, indices) => val arrayTy = ctx.typeOf(id) val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) @@ -222,8 +220,8 @@ object typeChecker { } } elemTy.getOrElse(?).satisfies(constraint, id.pos) - case Parens(expr) => checkValue(expr, constraint) - case l @ ArrayLiter(elems) => + case ast.Parens(expr) => checkValue(expr, constraint) + case l @ ast.ArrayLiter(elems) => KnownType // Start with an unknown param type, make it more specific while checking the elements. .Array(elems.foldLeft[SemType](?) { case (acc, elem) => @@ -233,14 +231,14 @@ object typeChecker { ) }) .satisfies(constraint, l.pos) - case l @ NewPair(fst, snd) => + case l @ ast.NewPair(fst, snd) => KnownType .Pair( checkValue(fst, Constraint.Unconstrained), checkValue(snd, Constraint.Unconstrained) ) .satisfies(constraint, l.pos) - case Call(id, args) => + case ast.Call(id, args) => val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) if (args.length != paramTys.length) { ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) @@ -251,7 +249,7 @@ object typeChecker { checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } retTy.satisfies(constraint, id.pos) - case Fst(elem) => + case ast.Fst(elem) => checkValue( elem, Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") @@ -260,7 +258,7 @@ object typeChecker { left.satisfies(constraint, elem.pos) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) } - case Snd(elem) => + case ast.Snd(elem) => checkValue( elem, Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") @@ -270,36 +268,36 @@ object typeChecker { } // Unary operators - case Negate(x) => + case ast.Negate(x) => checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) KnownType.Int.satisfies(constraint, x.pos) - case Not(x) => + case ast.Not(x) => checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) KnownType.Bool.satisfies(constraint, x.pos) - case Len(x) => + case ast.Len(x) => checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) KnownType.Int.satisfies(constraint, x.pos) - case Ord(x) => + case ast.Ord(x) => checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) KnownType.Int.satisfies(constraint, x.pos) - case Chr(x) => + case ast.Chr(x) => checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) KnownType.Char.satisfies(constraint, x.pos) // Binary operators - case op: (Add | Sub | Mul | Div | Mod) => + case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") checkValue(op.x, operand) checkValue(op.y, operand) KnownType.Int.satisfies(constraint, op.pos) - case op: (Eq | Neq) => + case op: (ast.Eq | ast.Neq) => val xTy = checkValue(op.x, Constraint.Unconstrained) checkValue( op.y, Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") ) KnownType.Bool.satisfies(constraint, op.pos) - case op: (Less | LessEq | Greater | GreaterEq) => + case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => val xConstraint = Constraint.IsEither( KnownType.Int, KnownType.Char, @@ -313,7 +311,7 @@ object typeChecker { } checkValue(op.y, yConstraint) KnownType.Bool.satisfies(constraint, op.pos) - case op: (And | Or) => + case op: (ast.And | ast.Or) => val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") checkValue(op.x, operand) checkValue(op.y, operand) From b7e442b269cef07dea5a20ea2b175e8df1f3a4df Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 13 Feb 2025 23:39:07 +0000 Subject: [PATCH 7/8] refactor: introduce exit-code guard against InternalError --- src/main/wacc/Main.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index f8db02a..5e95424 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -39,8 +39,15 @@ def compile(contents: String): Int = { typeChecker.check(prog) if (errors.result.nonEmpty) { given errorContent: String = contents - errors.result.foreach(printError) - 200 + errors.result + .map { error => + printError(error) + error match { + case _: Error.InternalError => 201 + case _ => 200 + } + } + .max() } else 0 case Failure(msg) => println(msg) From 27cc25cc0d5c3b9254ea1f9a7096629b62be4416 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Mon, 17 Feb 2025 15:26:32 +0000 Subject: [PATCH 8/8] feat: type-checker returns micro wacc --- src/main/wacc/Frontend/microWacc.scala | 19 +- src/main/wacc/Frontend/typeChecker.scala | 348 +++++++++++++++-------- src/main/wacc/Main.scala | 7 +- 3 files changed, 259 insertions(+), 115 deletions(-) diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala index 04b9035..b9f6635 100644 --- a/src/main/wacc/Frontend/microWacc.scala +++ b/src/main/wacc/Frontend/microWacc.scala @@ -19,7 +19,7 @@ object microWacc { extends Expr(identTy) with CallTarget(identTy) with LValue - case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(ty: SemType) + case class ArrayElem(value: LValue, indices: NonEmptyList[Expr])(ty: SemType) extends Expr(ty) with LValue @@ -48,6 +48,23 @@ object microWacc { case And case Or } + object BinaryOperator { + def fromAst(op: ast.BinaryOp): BinaryOperator = op match { + case _: ast.Add => Add + case _: ast.Sub => Sub + case _: ast.Mul => Mul + case _: ast.Div => Div + case _: ast.Mod => Mod + case _: ast.Greater => Greater + case _: ast.GreaterEq => GreaterEq + case _: ast.Less => Less + case _: ast.LessEq => LessEq + case _: ast.Eq => Eq + case _: ast.Neq => Neq + case _: ast.And => And + case _: ast.Or => Or + } + } // Statements sealed trait Stmt diff --git a/src/main/wacc/Frontend/typeChecker.scala b/src/main/wacc/Frontend/typeChecker.scala index 21d49ae..e3960cb 100644 --- a/src/main/wacc/Frontend/typeChecker.scala +++ b/src/main/wacc/Frontend/typeChecker.scala @@ -2,6 +2,7 @@ package wacc import cats.syntax.all._ import scala.collection.mutable +import cats.data.NonEmptyList object typeChecker { import wacc.types._ @@ -100,17 +101,26 @@ object typeChecker { */ def check(prog: ast.Program)(using ctx: TypeCheckerCtx - ): Unit = { - // Ignore function syntax types for return value and params, since those have been converted - // to SemTypes by the renamer. - prog.funcs.foreach { case ast.FuncDecl(_, name, _, stmts) => - val FuncType(retType, _) = ctx.funcType(name) - stmts.toList.foreach( - checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) - ) - } - prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) - } + ): microWacc.Program = + microWacc.Program( + // Ignore function syntax types for return value and params, since those have been converted + // to SemTypes by the renamer. + prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) => + val FuncType(retType, paramTypes) = ctx.funcType(name) + microWacc.FuncDecl( + microWacc.Ident(name.v, name.uid)(retType), + params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => + microWacc.Ident(ident.v, name.uid)(ty) + }, + stmts.toList + .flatMap( + checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) + ) + ) + }, + prog.main.toList + .flatMap(checkStmt(_, Constraint.Never("main function must not return"))) + ) /** Type-check an AST statement node. * @@ -121,32 +131,38 @@ object typeChecker { */ private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using ctx: TypeCheckerCtx - ): Unit = stmt match { + ): List[microWacc.Stmt] = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. case ast.VarDecl(_, name, value) => val expectedTy = ctx.typeOf(name) - checkValue( + val typedValue = checkValue( value, Constraint.Is( expectedTy, s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) + List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue)) case ast.Assign(lhs, rhs) => - val lhsTy = checkValue(lhs, Constraint.Unconstrained) - (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { + val lhsTyped = checkLValue(lhs, Constraint.Unconstrained) + val rhsTyped = + checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}")) + (lhsTyped.ty, rhsTyped.ty) match { case (?, ?) => ctx.error( Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal") ) case _ => () } + List(microWacc.Assign(lhsTyped, rhsTyped)) case ast.Read(dest) => - checkValue(dest, Constraint.Unconstrained) match { + val destTyped = checkLValue(dest, Constraint.Unconstrained) + val destTy = destTyped.ty match { case ? => ctx.error( Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type") ) + ? case destTy => destTy.satisfies( Constraint.IsEither( @@ -157,32 +173,69 @@ object typeChecker { dest.pos ) } + List( + microWacc.Assign( + destTyped, + microWacc.Call( + destTy match { + case KnownType.Int => microWacc.Builtin.ReadInt + case KnownType.Char => microWacc.Builtin.ReadChar + case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway + }, + Nil + ) + ) + ) case ast.Free(lhs) => - checkValue( - lhs, - Constraint.IsEither( - KnownType.Array(?), - KnownType.Pair(?, ?), - "free must be applied to an array or pair" + List( + microWacc.Call( + microWacc.Builtin.Free, + List( + checkValue( + lhs, + Constraint.IsEither( + KnownType.Array(?), + KnownType.Pair(?, ?), + "free must be applied to an array or pair" + ) + ) + ) ) ) case ast.Return(expr) => - checkValue(expr, returnConstraint) + List(microWacc.Return(checkValue(expr, returnConstraint))) case ast.Exit(expr) => - checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) - case ast.Print(expr, _) => + List( + microWacc.Call( + microWacc.Builtin.Exit, + List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))) + ) + ) + case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already - checkValue(expr, Constraint.Unconstrained) + List( + microWacc.Call( + if newline then microWacc.Builtin.Println else microWacc.Builtin.Print, + List(checkValue(expr, Constraint.Unconstrained)) + ) + ) case ast.If(cond, thenStmt, elseStmt) => - checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) - thenStmt.toList.foreach(checkStmt(_, returnConstraint)) - elseStmt.toList.foreach(checkStmt(_, returnConstraint)) + List( + microWacc.If( + checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")), + thenStmt.toList.flatMap(checkStmt(_, returnConstraint)), + elseStmt.toList.flatMap(checkStmt(_, returnConstraint)) + ) + ) case ast.While(cond, body) => - checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) - body.toList.foreach(checkStmt(_, returnConstraint)) - case ast.Block(body) => - body.toList.foreach(checkStmt(_, returnConstraint)) - case ast.Skip() => () + List( + microWacc.While( + checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")), + body.toList.flatMap(checkStmt(_, returnConstraint)) + ) + ) + case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint)) + case skip @ ast.Skip() => List.empty } /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits @@ -197,47 +250,42 @@ object typeChecker { */ private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using ctx: TypeCheckerCtx - ): SemType = value match { - case l @ ast.IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) - case l @ ast.BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) - case l @ ast.CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) - case l @ ast.StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) - case l @ ast.PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) - case id: ast.Ident => - ctx.typeOf(id).satisfies(constraint, id.pos) - case ast.ArrayElem(id, indices) => - val arrayTy = ctx.typeOf(id) - val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => - checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) - acc match { - case KnownType.Array(innerTy) => Some(innerTy) - case ? => Some(?) // we can keep indexing an unknown type - case nonArrayTy => - ctx.error( - Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") - ) - None - } - } - elemTy.getOrElse(?).satisfies(constraint, id.pos) + ): microWacc.Expr = value match { + case l @ ast.IntLiter(v) => + KnownType.Int.satisfies(constraint, l.pos) + microWacc.IntLiter(v) + case l @ ast.BoolLiter(v) => + KnownType.Bool.satisfies(constraint, l.pos) + microWacc.BoolLiter(v) + case l @ ast.CharLiter(v) => + KnownType.Char.satisfies(constraint, l.pos) + microWacc.CharLiter(v) + case l @ ast.StrLiter(v) => + KnownType.String.satisfies(constraint, l.pos) + microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String) + case l @ ast.PairLiter() => + microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) case ast.Parens(expr) => checkValue(expr, constraint) case l @ ast.ArrayLiter(elems) => - KnownType - // Start with an unknown param type, make it more specific while checking the elements. - .Array(elems.foldLeft[SemType](?) { case (acc, elem) => - checkValue( + val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) { + case (acc, elem) => + val elemTyped = checkValue( elem, Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type") ) - }) + (elemTyped.ty, elemTyped) + } + val arrayTy = KnownType + // Start with an unknown param type, make it more specific while checking the elements. + .Array(elemTy) .satisfies(constraint, l.pos) + microWacc.ArrayLiter(elemsTyped)(arrayTy) case l @ ast.NewPair(fst, snd) => - KnownType - .Pair( - checkValue(fst, Constraint.Unconstrained), - checkValue(snd, Constraint.Unconstrained) - ) - .satisfies(constraint, l.pos) + val fstTyped = checkValue(fst, Constraint.Unconstrained) + val sndTyped = checkValue(snd, Constraint.Unconstrained) + microWacc.ArrayLiter(List(fstTyped, sndTyped))( + KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) + ) case ast.Call(id, args) => val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) if (args.length != paramTys.length) { @@ -245,76 +293,152 @@ object typeChecker { } // Even if the number of arguments is wrong, we still check the types of the arguments // in the best way we can (by taking a zip). - args.zip(paramTys).foreach { case (arg, paramTy) => + val argsTyped = args.zip(paramTys).map { case (arg, paramTy) => checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } - retTy.satisfies(constraint, id.pos) - case ast.Fst(elem) => - checkValue( - elem, - Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") - ) match { - case what @ KnownType.Pair(left, _) => - left.satisfies(constraint, elem.pos) - case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) - } - case ast.Snd(elem) => - checkValue( - elem, - Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") - ) match { - case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) - case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) - } + microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped) // Unary operators case ast.Negate(x) => - checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) - KnownType.Int.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")), + microWacc.UnaryOperator.Negate + )(KnownType.Int.satisfies(constraint, x.pos)) case ast.Not(x) => - checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) - KnownType.Bool.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")), + microWacc.UnaryOperator.Not + )(KnownType.Bool.satisfies(constraint, x.pos)) case ast.Len(x) => - checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) - KnownType.Int.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")), + microWacc.UnaryOperator.Len + )(KnownType.Int.satisfies(constraint, x.pos)) case ast.Ord(x) => - checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) - KnownType.Int.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")), + microWacc.UnaryOperator.Ord + )(KnownType.Int.satisfies(constraint, x.pos)) case ast.Chr(x) => - checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) - KnownType.Char.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")), + microWacc.UnaryOperator.Chr + )(KnownType.Char.satisfies(constraint, x.pos)) // Binary operators case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") - checkValue(op.x, operand) - checkValue(op.y, operand) - KnownType.Int.satisfies(constraint, op.pos) + microWacc.BinaryOp( + checkValue(op.x, operand), + checkValue(op.y, operand), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Int.satisfies(constraint, op.pos)) case op: (ast.Eq | ast.Neq) => - val xTy = checkValue(op.x, Constraint.Unconstrained) - checkValue( - op.y, - Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") - ) - KnownType.Bool.satisfies(constraint, op.pos) + val xTyped = checkValue(op.x, Constraint.Unconstrained) + microWacc.BinaryOp( + xTyped, + checkValue( + op.y, + Constraint + .Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type") + ), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Bool.satisfies(constraint, op.pos)) case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => val xConstraint = Constraint.IsEither( KnownType.Int, KnownType.Char, s"${op.name} operator must be applied to an int or char" ) + val xTyped = checkValue(op.x, xConstraint) // If x type-check failed, we still want to check y is an Int or Char (rather than ?) - val yConstraint = checkValue(op.x, xConstraint) match { + val yConstraint = xTyped.ty match { case ? => xConstraint case xTy => Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") } - checkValue(op.y, yConstraint) - KnownType.Bool.satisfies(constraint, op.pos) + microWacc.BinaryOp( + xTyped, + checkValue(op.y, yConstraint), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Bool.satisfies(constraint, op.pos)) case op: (ast.And | ast.Or) => val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") - checkValue(op.x, operand) - checkValue(op.y, operand) - KnownType.Bool.satisfies(constraint, op.pos) + microWacc.BinaryOp( + checkValue(op.x, operand), + checkValue(op.y, operand), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Bool.satisfies(constraint, op.pos)) + + case lvalue: ast.LValue => checkLValue(lvalue, constraint) + } + + /** Type-check an AST LValue node. Separate because microWacc keeps LValues + * + * @param value + * The value to type-check. + * @param constraint + * The type constraint that the value must satisfy. + * @param ctx + * The type checker context which includes the global names and functions, and an errors + * builder. + * @return + * The most specific type of the value if it could be determined, or ? if it could not. + */ + private def checkLValue(value: ast.LValue, constraint: Constraint)(using + ctx: TypeCheckerCtx + ): microWacc.LValue = value match { + case id @ ast.Ident(name, uid) => + microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos)) + case ast.ArrayElem(id, indices) => + val arrayTy = ctx.typeOf(id) + val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) => + val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + val next = acc match { + case KnownType.Array(innerTy) => innerTy + case ? => ? // we can keep indexing an unknown type + case nonArrayTy => + ctx.error( + Error.TypeMismatch( + elem.pos, + KnownType.Array(?), + acc, + "cannot index into a non-array" + ) + ) + ? + } + (next, idxTyped) + } + microWacc.ArrayElem( + microWacc.Ident(id.v, id.uid)(arrayTy), + indicesTyped + )(elemTy.satisfies(constraint, value.pos)) + case ast.Fst(elem) => + val elemTyped = checkLValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") + ) + microWacc.ArrayElem( + elemTyped, + NonEmptyList.of(microWacc.IntLiter(0)) + )(elemTyped.ty match { + case KnownType.Pair(left, _) => + left.satisfies(constraint, elem.pos) + case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) + }) + case ast.Snd(elem) => + val elemTyped = checkLValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") + ) + microWacc.ArrayElem( + elemTyped, + NonEmptyList.of(microWacc.IntLiter(1)) + )(elemTyped.ty match { + case KnownType.Pair(_, right) => + right.satisfies(constraint, elem.pos) + case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) + }) } } diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 5e95424..445d9c1 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -36,7 +36,7 @@ def compile(contents: String): Int = { given errors: mutable.Builder[Error, List[Error]] = List.newBuilder val (names, funcs) = renamer.rename(prog) given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) - typeChecker.check(prog) + val typedProg = typeChecker.check(prog) if (errors.result.nonEmpty) { given errorContent: String = contents errors.result @@ -48,7 +48,10 @@ def compile(contents: String): Int = { } } .max() - } else 0 + } else { + println(typedProg) + 0 + } case Failure(msg) => println(msg) 100