refactor: comment non-trivial code in typeChecker.scala

Merge request lab2425_spring/WACC_37!20
This commit is contained in:
Gleb Koval 2025-02-07 18:45:50 +00:00
commit 2a4c2bc993

View File

@ -21,17 +21,32 @@ object typeChecker {
?
}
enum Constraint {
private enum Constraint {
case Unconstrained
// Allows weakening in one direction
case Is(ty: SemType, msg: String)
// Allows weakening in both directions, useful for array literals
case IsSymmetricCompatible(ty: SemType, msg: String)
case IsUnweakanable(ty: SemType, msg: String)
// Does not allow weakening
case IsUnweakenable(ty: SemType, msg: String)
case IsEither(ty1: SemType, ty2: SemType, msg: String)
case Never(msg: String)
}
extension (ty: SemType) {
def satisfies(constraint: Constraint, pos: Position)(using ctx: TypeCheckerCtx): SemType =
/** Check if a type satisfies a constraint.
*
* @param constraint
* Constraint to satisfy.
* @param pos
* Position to pass to the error, if constraint was not satisfied.
* @return
* The type if the constraint was satisfied, or ? if it was not.
*/
private def satisfies(constraint: Constraint, pos: Position)(using
ctx: TypeCheckerCtx
): SemType =
(ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String
@ -42,7 +57,8 @@ object typeChecker {
KnownType.String
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
ty.satisfies(Constraint.Is(ty2, msg), pos)
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos)
// Change to IsUnweakenable to disallow recursive weakening
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty
case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg))
@ -50,13 +66,20 @@ object typeChecker {
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
}
case (ty, Constraint.IsUnweakanable(ty2, msg)) =>
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
}
}
infix def moreSpecific(ty2: SemType): Option[SemType] =
/** Tries to merge two types, returning the more specific one if possible.
*
* @param ty2
* The other type to merge with.
* @return
* The more specific type if it could be determined, or None if the types are incompatible.
*/
private infix def moreSpecific(ty2: SemType): Option[SemType] =
(ty, ty2) match {
case (ty, ?) => Some(ty)
case (?, ty) => Some(ty)
@ -69,9 +92,19 @@ object typeChecker {
}
}
/** Type-check a WACC program.
*
* @param prog
* The AST of the program to type-check.
* @param ctx
* The type checker context which includes the global names and functions, and an errors
* builder.
*/
def check(prog: Program)(using
ctx: TypeCheckerCtx
): Unit = {
// Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer.
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
val FuncType(retType, _) = ctx.funcType(name)
stmts.toList.foreach(
@ -81,9 +114,17 @@ object typeChecker {
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
}
/** Type-check an AST statement node.
*
* @param stmt
* The statement to type-check.
* @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy.
*/
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx
): Unit = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name)
checkValue(
@ -146,6 +187,16 @@ object typeChecker {
case Skip() => ()
}
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
* overlap in the AST.
*
* @param value
* The value to type-check.
* @param constraint
* The type constraint that the value must satisfy.
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx
): SemType = value match {
@ -162,7 +213,7 @@ object typeChecker {
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
acc match {
case KnownType.Array(innerTy) => Some(innerTy)
case ? => Some(?)
case ? => Some(?) // we can keep indexing an unknown type
case nonArrayTy =>
ctx.error(
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
@ -174,6 +225,7 @@ object typeChecker {
case Parens(expr) => checkValue(expr, constraint)
case l @ ArrayLiter(elems) =>
KnownType
// Start with an unknown param type, make it more specific while checking the elements.
.Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
checkValue(
elem,
@ -193,6 +245,8 @@ object typeChecker {
if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
}
// Even if the number of arguments is wrong, we still check the types of the arguments
// in the best way we can (by taking a zip).
args.zip(paramTys).foreach { case (arg, paramTy) =>
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
}
@ -205,7 +259,7 @@ object typeChecker {
case what @ KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
} // satisfies constraint
}
case Snd(elem) =>
checkValue(
elem,
@ -251,6 +305,7 @@ object typeChecker {
KnownType.Char,
s"${op.name} operator must be applied to an int or char"
)
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
val yConstraint = checkValue(op.x, xConstraint) match {
case ? => xConstraint
case xTy =>