From 756b42dd724f184901d688bd9c7a16e993f63943 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Thu, 13 Feb 2025 23:12:41 +0000 Subject: [PATCH] refactor: remove implicit ast from type checker --- src/main/wacc/Frontend/typeChecker.scala | 86 ++++++++++++------------ 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/src/main/wacc/Frontend/typeChecker.scala b/src/main/wacc/Frontend/typeChecker.scala index 62fad3b..21d49ae 100644 --- a/src/main/wacc/Frontend/typeChecker.scala +++ b/src/main/wacc/Frontend/typeChecker.scala @@ -4,17 +4,15 @@ import cats.syntax.all._ import scala.collection.mutable object typeChecker { - import wacc.ast._ import wacc.types._ case class TypeCheckerCtx( - globalNames: Map[Ident, SemType], - globalFuncs: Map[Ident, FuncType], + globalNames: Map[ast.Ident, SemType], + globalFuncs: Map[ast.Ident, FuncType], errors: mutable.Builder[Error, List[Error]] ) { - def typeOf(ident: Ident): SemType = globalNames(ident) - - def funcType(ident: Ident): FuncType = globalFuncs(ident) + def typeOf(ident: ast.Ident): SemType = globalNames(ident) + def funcType(ident: ast.Ident): FuncType = globalFuncs(ident) def error(err: Error): SemType = errors += err @@ -44,7 +42,7 @@ object typeChecker { * @return * The type if the constraint was satisfied, or ? if it was not. */ - private def satisfies(constraint: Constraint, pos: Position)(using + private def satisfies(constraint: Constraint, pos: ast.Position)(using ctx: TypeCheckerCtx ): SemType = (ty, constraint) match { @@ -100,12 +98,12 @@ object typeChecker { * The type checker context which includes the global names and functions, and an errors * builder. */ - def check(prog: Program)(using + def check(prog: ast.Program)(using ctx: TypeCheckerCtx ): Unit = { // Ignore function syntax types for return value and params, since those have been converted // to SemTypes by the renamer. - prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => + prog.funcs.foreach { case ast.FuncDecl(_, name, _, stmts) => val FuncType(retType, _) = ctx.funcType(name) stmts.toList.foreach( checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) @@ -121,11 +119,11 @@ object typeChecker { * @param returnConstraint * The constraint that any `return ` statements must satisfy. */ - private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using + private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using ctx: TypeCheckerCtx ): Unit = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. - case VarDecl(_, name, value) => + case ast.VarDecl(_, name, value) => val expectedTy = ctx.typeOf(name) checkValue( value, @@ -134,7 +132,7 @@ object typeChecker { s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) - case Assign(lhs, rhs) => + case ast.Assign(lhs, rhs) => val lhsTy = checkValue(lhs, Constraint.Unconstrained) (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { case (?, ?) => @@ -143,7 +141,7 @@ object typeChecker { ) case _ => () } - case Read(dest) => + case ast.Read(dest) => checkValue(dest, Constraint.Unconstrained) match { case ? => ctx.error( @@ -159,7 +157,7 @@ object typeChecker { dest.pos ) } - case Free(lhs) => + case ast.Free(lhs) => checkValue( lhs, Constraint.IsEither( @@ -168,23 +166,23 @@ object typeChecker { "free must be applied to an array or pair" ) ) - case Return(expr) => + case ast.Return(expr) => checkValue(expr, returnConstraint) - case Exit(expr) => + case ast.Exit(expr) => checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) - case Print(expr, _) => + case ast.Print(expr, _) => // This constraint should never fail, the scope-checker should have caught it already checkValue(expr, Constraint.Unconstrained) - case If(cond, thenStmt, elseStmt) => + case ast.If(cond, thenStmt, elseStmt) => checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) thenStmt.toList.foreach(checkStmt(_, returnConstraint)) elseStmt.toList.foreach(checkStmt(_, returnConstraint)) - case While(cond, body) => + case ast.While(cond, body) => checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) body.toList.foreach(checkStmt(_, returnConstraint)) - case Block(body) => + case ast.Block(body) => body.toList.foreach(checkStmt(_, returnConstraint)) - case Skip() => () + case ast.Skip() => () } /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits @@ -197,17 +195,17 @@ object typeChecker { * @return * The most specific type of the value if it could be determined, or ? if it could not. */ - private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using + private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using ctx: TypeCheckerCtx ): SemType = value match { - case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) - case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) - case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) - case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) - case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) - case id: Ident => + case l @ ast.IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) + case l @ ast.BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) + case l @ ast.CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) + case l @ ast.StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) + case l @ ast.PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) + case id: ast.Ident => ctx.typeOf(id).satisfies(constraint, id.pos) - case ArrayElem(id, indices) => + case ast.ArrayElem(id, indices) => val arrayTy = ctx.typeOf(id) val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) @@ -222,8 +220,8 @@ object typeChecker { } } elemTy.getOrElse(?).satisfies(constraint, id.pos) - case Parens(expr) => checkValue(expr, constraint) - case l @ ArrayLiter(elems) => + case ast.Parens(expr) => checkValue(expr, constraint) + case l @ ast.ArrayLiter(elems) => KnownType // Start with an unknown param type, make it more specific while checking the elements. .Array(elems.foldLeft[SemType](?) { case (acc, elem) => @@ -233,14 +231,14 @@ object typeChecker { ) }) .satisfies(constraint, l.pos) - case l @ NewPair(fst, snd) => + case l @ ast.NewPair(fst, snd) => KnownType .Pair( checkValue(fst, Constraint.Unconstrained), checkValue(snd, Constraint.Unconstrained) ) .satisfies(constraint, l.pos) - case Call(id, args) => + case ast.Call(id, args) => val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) if (args.length != paramTys.length) { ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) @@ -251,7 +249,7 @@ object typeChecker { checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } retTy.satisfies(constraint, id.pos) - case Fst(elem) => + case ast.Fst(elem) => checkValue( elem, Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") @@ -260,7 +258,7 @@ object typeChecker { left.satisfies(constraint, elem.pos) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) } - case Snd(elem) => + case ast.Snd(elem) => checkValue( elem, Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") @@ -270,36 +268,36 @@ object typeChecker { } // Unary operators - case Negate(x) => + case ast.Negate(x) => checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) KnownType.Int.satisfies(constraint, x.pos) - case Not(x) => + case ast.Not(x) => checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) KnownType.Bool.satisfies(constraint, x.pos) - case Len(x) => + case ast.Len(x) => checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) KnownType.Int.satisfies(constraint, x.pos) - case Ord(x) => + case ast.Ord(x) => checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) KnownType.Int.satisfies(constraint, x.pos) - case Chr(x) => + case ast.Chr(x) => checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) KnownType.Char.satisfies(constraint, x.pos) // Binary operators - case op: (Add | Sub | Mul | Div | Mod) => + case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") checkValue(op.x, operand) checkValue(op.y, operand) KnownType.Int.satisfies(constraint, op.pos) - case op: (Eq | Neq) => + case op: (ast.Eq | ast.Neq) => val xTy = checkValue(op.x, Constraint.Unconstrained) checkValue( op.y, Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") ) KnownType.Bool.satisfies(constraint, op.pos) - case op: (Less | LessEq | Greater | GreaterEq) => + case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => val xConstraint = Constraint.IsEither( KnownType.Int, KnownType.Char, @@ -313,7 +311,7 @@ object typeChecker { } checkValue(op.y, yConstraint) KnownType.Bool.satisfies(constraint, op.pos) - case op: (And | Or) => + case op: (ast.And | ast.Or) => val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") checkValue(op.x, operand) checkValue(op.y, operand)