refactor: remove implicit ast from type checker
This commit is contained in:
		| @@ -4,17 +4,15 @@ import cats.syntax.all._ | |||||||
| import scala.collection.mutable | import scala.collection.mutable | ||||||
|  |  | ||||||
| object typeChecker { | object typeChecker { | ||||||
|   import wacc.ast._ |  | ||||||
|   import wacc.types._ |   import wacc.types._ | ||||||
|  |  | ||||||
|   case class TypeCheckerCtx( |   case class TypeCheckerCtx( | ||||||
|       globalNames: Map[Ident, SemType], |       globalNames: Map[ast.Ident, SemType], | ||||||
|       globalFuncs: Map[Ident, FuncType], |       globalFuncs: Map[ast.Ident, FuncType], | ||||||
|       errors: mutable.Builder[Error, List[Error]] |       errors: mutable.Builder[Error, List[Error]] | ||||||
|   ) { |   ) { | ||||||
|     def typeOf(ident: Ident): SemType = globalNames(ident) |     def typeOf(ident: ast.Ident): SemType = globalNames(ident) | ||||||
|  |     def funcType(ident: ast.Ident): FuncType = globalFuncs(ident) | ||||||
|     def funcType(ident: Ident): FuncType = globalFuncs(ident) |  | ||||||
|  |  | ||||||
|     def error(err: Error): SemType = |     def error(err: Error): SemType = | ||||||
|       errors += err |       errors += err | ||||||
| @@ -44,7 +42,7 @@ object typeChecker { | |||||||
|       * @return |       * @return | ||||||
|       *   The type if the constraint was satisfied, or ? if it was not. |       *   The type if the constraint was satisfied, or ? if it was not. | ||||||
|       */ |       */ | ||||||
|     private def satisfies(constraint: Constraint, pos: Position)(using |     private def satisfies(constraint: Constraint, pos: ast.Position)(using | ||||||
|         ctx: TypeCheckerCtx |         ctx: TypeCheckerCtx | ||||||
|     ): SemType = |     ): SemType = | ||||||
|       (ty, constraint) match { |       (ty, constraint) match { | ||||||
| @@ -100,12 +98,12 @@ object typeChecker { | |||||||
|     *   The type checker context which includes the global names and functions, and an errors |     *   The type checker context which includes the global names and functions, and an errors | ||||||
|     *   builder. |     *   builder. | ||||||
|     */ |     */ | ||||||
|   def check(prog: Program)(using |   def check(prog: ast.Program)(using | ||||||
|       ctx: TypeCheckerCtx |       ctx: TypeCheckerCtx | ||||||
|   ): Unit = { |   ): Unit = { | ||||||
|     // Ignore function syntax types for return value and params, since those have been converted |     // Ignore function syntax types for return value and params, since those have been converted | ||||||
|     // to SemTypes by the renamer. |     // to SemTypes by the renamer. | ||||||
|     prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => |     prog.funcs.foreach { case ast.FuncDecl(_, name, _, stmts) => | ||||||
|       val FuncType(retType, _) = ctx.funcType(name) |       val FuncType(retType, _) = ctx.funcType(name) | ||||||
|       stmts.toList.foreach( |       stmts.toList.foreach( | ||||||
|         checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) |         checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) | ||||||
| @@ -121,11 +119,11 @@ object typeChecker { | |||||||
|     * @param returnConstraint |     * @param returnConstraint | ||||||
|     *   The constraint that any `return <expr>` statements must satisfy. |     *   The constraint that any `return <expr>` statements must satisfy. | ||||||
|     */ |     */ | ||||||
|   private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using |   private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using | ||||||
|       ctx: TypeCheckerCtx |       ctx: TypeCheckerCtx | ||||||
|   ): Unit = stmt match { |   ): Unit = stmt match { | ||||||
|     // Ignore the type of the variable, since it has been converted to a SemType by the renamer. |     // Ignore the type of the variable, since it has been converted to a SemType by the renamer. | ||||||
|     case VarDecl(_, name, value) => |     case ast.VarDecl(_, name, value) => | ||||||
|       val expectedTy = ctx.typeOf(name) |       val expectedTy = ctx.typeOf(name) | ||||||
|       checkValue( |       checkValue( | ||||||
|         value, |         value, | ||||||
| @@ -134,7 +132,7 @@ object typeChecker { | |||||||
|           s"variable ${name.v} must be assigned a value of type $expectedTy" |           s"variable ${name.v} must be assigned a value of type $expectedTy" | ||||||
|         ) |         ) | ||||||
|       ) |       ) | ||||||
|     case Assign(lhs, rhs) => |     case ast.Assign(lhs, rhs) => | ||||||
|       val lhsTy = checkValue(lhs, Constraint.Unconstrained) |       val lhsTy = checkValue(lhs, Constraint.Unconstrained) | ||||||
|       (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { |       (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { | ||||||
|         case (?, ?) => |         case (?, ?) => | ||||||
| @@ -143,7 +141,7 @@ object typeChecker { | |||||||
|           ) |           ) | ||||||
|         case _ => () |         case _ => () | ||||||
|       } |       } | ||||||
|     case Read(dest) => |     case ast.Read(dest) => | ||||||
|       checkValue(dest, Constraint.Unconstrained) match { |       checkValue(dest, Constraint.Unconstrained) match { | ||||||
|         case ? => |         case ? => | ||||||
|           ctx.error( |           ctx.error( | ||||||
| @@ -159,7 +157,7 @@ object typeChecker { | |||||||
|             dest.pos |             dest.pos | ||||||
|           ) |           ) | ||||||
|       } |       } | ||||||
|     case Free(lhs) => |     case ast.Free(lhs) => | ||||||
|       checkValue( |       checkValue( | ||||||
|         lhs, |         lhs, | ||||||
|         Constraint.IsEither( |         Constraint.IsEither( | ||||||
| @@ -168,23 +166,23 @@ object typeChecker { | |||||||
|           "free must be applied to an array or pair" |           "free must be applied to an array or pair" | ||||||
|         ) |         ) | ||||||
|       ) |       ) | ||||||
|     case Return(expr) => |     case ast.Return(expr) => | ||||||
|       checkValue(expr, returnConstraint) |       checkValue(expr, returnConstraint) | ||||||
|     case Exit(expr) => |     case ast.Exit(expr) => | ||||||
|       checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) |       checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) | ||||||
|     case Print(expr, _) => |     case ast.Print(expr, _) => | ||||||
|       // This constraint should never fail, the scope-checker should have caught it already |       // This constraint should never fail, the scope-checker should have caught it already | ||||||
|       checkValue(expr, Constraint.Unconstrained) |       checkValue(expr, Constraint.Unconstrained) | ||||||
|     case If(cond, thenStmt, elseStmt) => |     case ast.If(cond, thenStmt, elseStmt) => | ||||||
|       checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) |       checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) | ||||||
|       thenStmt.toList.foreach(checkStmt(_, returnConstraint)) |       thenStmt.toList.foreach(checkStmt(_, returnConstraint)) | ||||||
|       elseStmt.toList.foreach(checkStmt(_, returnConstraint)) |       elseStmt.toList.foreach(checkStmt(_, returnConstraint)) | ||||||
|     case While(cond, body) => |     case ast.While(cond, body) => | ||||||
|       checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) |       checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) | ||||||
|       body.toList.foreach(checkStmt(_, returnConstraint)) |       body.toList.foreach(checkStmt(_, returnConstraint)) | ||||||
|     case Block(body) => |     case ast.Block(body) => | ||||||
|       body.toList.foreach(checkStmt(_, returnConstraint)) |       body.toList.foreach(checkStmt(_, returnConstraint)) | ||||||
|     case Skip() => () |     case ast.Skip() => () | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits |   /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits | ||||||
| @@ -197,17 +195,17 @@ object typeChecker { | |||||||
|     * @return |     * @return | ||||||
|     *   The most specific type of the value if it could be determined, or ? if it could not. |     *   The most specific type of the value if it could be determined, or ? if it could not. | ||||||
|     */ |     */ | ||||||
|   private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using |   private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using | ||||||
|       ctx: TypeCheckerCtx |       ctx: TypeCheckerCtx | ||||||
|   ): SemType = value match { |   ): SemType = value match { | ||||||
|     case l @ IntLiter(_)  => KnownType.Int.satisfies(constraint, l.pos) |     case l @ ast.IntLiter(_)  => KnownType.Int.satisfies(constraint, l.pos) | ||||||
|     case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) |     case l @ ast.BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) | ||||||
|     case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) |     case l @ ast.CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) | ||||||
|     case l @ StrLiter(_)  => KnownType.String.satisfies(constraint, l.pos) |     case l @ ast.StrLiter(_)  => KnownType.String.satisfies(constraint, l.pos) | ||||||
|     case l @ PairLiter()  => KnownType.Pair(?, ?).satisfies(constraint, l.pos) |     case l @ ast.PairLiter()  => KnownType.Pair(?, ?).satisfies(constraint, l.pos) | ||||||
|     case id: Ident => |     case id: ast.Ident => | ||||||
|       ctx.typeOf(id).satisfies(constraint, id.pos) |       ctx.typeOf(id).satisfies(constraint, id.pos) | ||||||
|     case ArrayElem(id, indices) => |     case ast.ArrayElem(id, indices) => | ||||||
|       val arrayTy = ctx.typeOf(id) |       val arrayTy = ctx.typeOf(id) | ||||||
|       val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => |       val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => | ||||||
|         checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) |         checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) | ||||||
| @@ -222,8 +220,8 @@ object typeChecker { | |||||||
|         } |         } | ||||||
|       } |       } | ||||||
|       elemTy.getOrElse(?).satisfies(constraint, id.pos) |       elemTy.getOrElse(?).satisfies(constraint, id.pos) | ||||||
|     case Parens(expr) => checkValue(expr, constraint) |     case ast.Parens(expr) => checkValue(expr, constraint) | ||||||
|     case l @ ArrayLiter(elems) => |     case l @ ast.ArrayLiter(elems) => | ||||||
|       KnownType |       KnownType | ||||||
|         // Start with an unknown param type, make it more specific while checking the elements. |         // Start with an unknown param type, make it more specific while checking the elements. | ||||||
|         .Array(elems.foldLeft[SemType](?) { case (acc, elem) => |         .Array(elems.foldLeft[SemType](?) { case (acc, elem) => | ||||||
| @@ -233,14 +231,14 @@ object typeChecker { | |||||||
|           ) |           ) | ||||||
|         }) |         }) | ||||||
|         .satisfies(constraint, l.pos) |         .satisfies(constraint, l.pos) | ||||||
|     case l @ NewPair(fst, snd) => |     case l @ ast.NewPair(fst, snd) => | ||||||
|       KnownType |       KnownType | ||||||
|         .Pair( |         .Pair( | ||||||
|           checkValue(fst, Constraint.Unconstrained), |           checkValue(fst, Constraint.Unconstrained), | ||||||
|           checkValue(snd, Constraint.Unconstrained) |           checkValue(snd, Constraint.Unconstrained) | ||||||
|         ) |         ) | ||||||
|         .satisfies(constraint, l.pos) |         .satisfies(constraint, l.pos) | ||||||
|     case Call(id, args) => |     case ast.Call(id, args) => | ||||||
|       val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) |       val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) | ||||||
|       if (args.length != paramTys.length) { |       if (args.length != paramTys.length) { | ||||||
|         ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) |         ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) | ||||||
| @@ -251,7 +249,7 @@ object typeChecker { | |||||||
|         checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) |         checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) | ||||||
|       } |       } | ||||||
|       retTy.satisfies(constraint, id.pos) |       retTy.satisfies(constraint, id.pos) | ||||||
|     case Fst(elem) => |     case ast.Fst(elem) => | ||||||
|       checkValue( |       checkValue( | ||||||
|         elem, |         elem, | ||||||
|         Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") |         Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") | ||||||
| @@ -260,7 +258,7 @@ object typeChecker { | |||||||
|           left.satisfies(constraint, elem.pos) |           left.satisfies(constraint, elem.pos) | ||||||
|         case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) |         case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) | ||||||
|       } |       } | ||||||
|     case Snd(elem) => |     case ast.Snd(elem) => | ||||||
|       checkValue( |       checkValue( | ||||||
|         elem, |         elem, | ||||||
|         Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") |         Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") | ||||||
| @@ -270,36 +268,36 @@ object typeChecker { | |||||||
|       } |       } | ||||||
|  |  | ||||||
|     // Unary operators |     // Unary operators | ||||||
|     case Negate(x) => |     case ast.Negate(x) => | ||||||
|       checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) |       checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) | ||||||
|       KnownType.Int.satisfies(constraint, x.pos) |       KnownType.Int.satisfies(constraint, x.pos) | ||||||
|     case Not(x) => |     case ast.Not(x) => | ||||||
|       checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) |       checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) | ||||||
|       KnownType.Bool.satisfies(constraint, x.pos) |       KnownType.Bool.satisfies(constraint, x.pos) | ||||||
|     case Len(x) => |     case ast.Len(x) => | ||||||
|       checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) |       checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) | ||||||
|       KnownType.Int.satisfies(constraint, x.pos) |       KnownType.Int.satisfies(constraint, x.pos) | ||||||
|     case Ord(x) => |     case ast.Ord(x) => | ||||||
|       checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) |       checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) | ||||||
|       KnownType.Int.satisfies(constraint, x.pos) |       KnownType.Int.satisfies(constraint, x.pos) | ||||||
|     case Chr(x) => |     case ast.Chr(x) => | ||||||
|       checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) |       checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) | ||||||
|       KnownType.Char.satisfies(constraint, x.pos) |       KnownType.Char.satisfies(constraint, x.pos) | ||||||
|  |  | ||||||
|     // Binary operators |     // Binary operators | ||||||
|     case op: (Add | Sub | Mul | Div | Mod) => |     case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => | ||||||
|       val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") |       val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") | ||||||
|       checkValue(op.x, operand) |       checkValue(op.x, operand) | ||||||
|       checkValue(op.y, operand) |       checkValue(op.y, operand) | ||||||
|       KnownType.Int.satisfies(constraint, op.pos) |       KnownType.Int.satisfies(constraint, op.pos) | ||||||
|     case op: (Eq | Neq) => |     case op: (ast.Eq | ast.Neq) => | ||||||
|       val xTy = checkValue(op.x, Constraint.Unconstrained) |       val xTy = checkValue(op.x, Constraint.Unconstrained) | ||||||
|       checkValue( |       checkValue( | ||||||
|         op.y, |         op.y, | ||||||
|         Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") |         Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") | ||||||
|       ) |       ) | ||||||
|       KnownType.Bool.satisfies(constraint, op.pos) |       KnownType.Bool.satisfies(constraint, op.pos) | ||||||
|     case op: (Less | LessEq | Greater | GreaterEq) => |     case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => | ||||||
|       val xConstraint = Constraint.IsEither( |       val xConstraint = Constraint.IsEither( | ||||||
|         KnownType.Int, |         KnownType.Int, | ||||||
|         KnownType.Char, |         KnownType.Char, | ||||||
| @@ -313,7 +311,7 @@ object typeChecker { | |||||||
|       } |       } | ||||||
|       checkValue(op.y, yConstraint) |       checkValue(op.y, yConstraint) | ||||||
|       KnownType.Bool.satisfies(constraint, op.pos) |       KnownType.Bool.satisfies(constraint, op.pos) | ||||||
|     case op: (And | Or) => |     case op: (ast.And | ast.Or) => | ||||||
|       val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") |       val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") | ||||||
|       checkValue(op.x, operand) |       checkValue(op.x, operand) | ||||||
|       checkValue(op.y, operand) |       checkValue(op.y, operand) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user