refactor: comment non-trivial code in typeChecker.scala

Merge request lab2425_spring/WACC_37!20
This commit is contained in:
Gleb Koval 2025-02-07 18:45:50 +00:00
commit 2a4c2bc993

View File

@ -21,17 +21,32 @@ object typeChecker {
? ?
} }
enum Constraint { private enum Constraint {
case Unconstrained case Unconstrained
// Allows weakening in one direction
case Is(ty: SemType, msg: String) case Is(ty: SemType, msg: String)
// Allows weakening in both directions, useful for array literals
case IsSymmetricCompatible(ty: SemType, msg: String) case IsSymmetricCompatible(ty: SemType, msg: String)
case IsUnweakanable(ty: SemType, msg: String) // Does not allow weakening
case IsUnweakenable(ty: SemType, msg: String)
case IsEither(ty1: SemType, ty2: SemType, msg: String) case IsEither(ty1: SemType, ty2: SemType, msg: String)
case Never(msg: String) case Never(msg: String)
} }
extension (ty: SemType) { extension (ty: SemType) {
def satisfies(constraint: Constraint, pos: Position)(using ctx: TypeCheckerCtx): SemType =
/** Check if a type satisfies a constraint.
*
* @param constraint
* Constraint to satisfy.
* @param pos
* Position to pass to the error, if constraint was not satisfied.
* @return
* The type if the constraint was satisfied, or ? if it was not.
*/
private def satisfies(constraint: Constraint, pos: Position)(using
ctx: TypeCheckerCtx
): SemType =
(ty, constraint) match { (ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String KnownType.String
@ -42,7 +57,8 @@ object typeChecker {
KnownType.String KnownType.String
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
ty.satisfies(Constraint.Is(ty2, msg), pos) ty.satisfies(Constraint.Is(ty2, msg), pos)
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos) // Change to IsUnweakenable to disallow recursive weakening
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty case (ty, Constraint.Unconstrained) => ty
case (ty, Constraint.Never(msg)) => case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg)) ctx.error(Error.SemanticError(pos, msg))
@ -50,13 +66,20 @@ object typeChecker {
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse { (ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg)) ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
} }
case (ty, Constraint.IsUnweakanable(ty2, msg)) => case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse { (ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg)) ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
} }
} }
infix def moreSpecific(ty2: SemType): Option[SemType] = /** Tries to merge two types, returning the more specific one if possible.
*
* @param ty2
* The other type to merge with.
* @return
* The more specific type if it could be determined, or None if the types are incompatible.
*/
private infix def moreSpecific(ty2: SemType): Option[SemType] =
(ty, ty2) match { (ty, ty2) match {
case (ty, ?) => Some(ty) case (ty, ?) => Some(ty)
case (?, ty) => Some(ty) case (?, ty) => Some(ty)
@ -69,9 +92,19 @@ object typeChecker {
} }
} }
/** Type-check a WACC program.
*
* @param prog
* The AST of the program to type-check.
* @param ctx
* The type checker context which includes the global names and functions, and an errors
* builder.
*/
def check(prog: Program)(using def check(prog: Program)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): Unit = { ): Unit = {
// Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer.
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
val FuncType(retType, _) = ctx.funcType(name) val FuncType(retType, _) = ctx.funcType(name)
stmts.toList.foreach( stmts.toList.foreach(
@ -81,9 +114,17 @@ object typeChecker {
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
} }
/** Type-check an AST statement node.
*
* @param stmt
* The statement to type-check.
* @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy.
*/
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): Unit = stmt match { ): Unit = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case VarDecl(_, name, value) => case VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name) val expectedTy = ctx.typeOf(name)
checkValue( checkValue(
@ -146,6 +187,16 @@ object typeChecker {
case Skip() => () case Skip() => ()
} }
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
* overlap in the AST.
*
* @param value
* The value to type-check.
* @param constraint
* The type constraint that the value must satisfy.
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): SemType = value match { ): SemType = value match {
@ -162,7 +213,7 @@ object typeChecker {
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
acc match { acc match {
case KnownType.Array(innerTy) => Some(innerTy) case KnownType.Array(innerTy) => Some(innerTy)
case ? => Some(?) case ? => Some(?) // we can keep indexing an unknown type
case nonArrayTy => case nonArrayTy =>
ctx.error( ctx.error(
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
@ -174,6 +225,7 @@ object typeChecker {
case Parens(expr) => checkValue(expr, constraint) case Parens(expr) => checkValue(expr, constraint)
case l @ ArrayLiter(elems) => case l @ ArrayLiter(elems) =>
KnownType KnownType
// Start with an unknown param type, make it more specific while checking the elements.
.Array(elems.foldLeft[SemType](?) { case (acc, elem) => .Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
checkValue( checkValue(
elem, elem,
@ -193,6 +245,8 @@ object typeChecker {
if (args.length != paramTys.length) { if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
} }
// Even if the number of arguments is wrong, we still check the types of the arguments
// in the best way we can (by taking a zip).
args.zip(paramTys).foreach { case (arg, paramTy) => args.zip(paramTys).foreach { case (arg, paramTy) =>
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
} }
@ -205,7 +259,7 @@ object typeChecker {
case what @ KnownType.Pair(left, _) => case what @ KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos) left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
} // satisfies constraint }
case Snd(elem) => case Snd(elem) =>
checkValue( checkValue(
elem, elem,
@ -251,6 +305,7 @@ object typeChecker {
KnownType.Char, KnownType.Char,
s"${op.name} operator must be applied to an int or char" s"${op.name} operator must be applied to an int or char"
) )
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
val yConstraint = checkValue(op.x, xConstraint) match { val yConstraint = checkValue(op.x, xConstraint) match {
case ? => xConstraint case ? => xConstraint
case xTy => case xTy =>