From 9da2744cb928683937e739c427971af2825c781d Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 7 Feb 2025 18:41:16 +0000 Subject: [PATCH 1/2] refactor: comment non-trivial code in typeChecker.scala --- src/main/wacc/typeChecker.scala | 65 ++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/src/main/wacc/typeChecker.scala b/src/main/wacc/typeChecker.scala index c766905..6168480 100644 --- a/src/main/wacc/typeChecker.scala +++ b/src/main/wacc/typeChecker.scala @@ -21,17 +21,32 @@ object typeChecker { ? } - enum Constraint { + private enum Constraint { case Unconstrained + // Allows weakening in one direction case Is(ty: SemType, msg: String) + // Allows weakening in both directions, useful for array literals case IsSymmetricCompatible(ty: SemType, msg: String) + // Does not allow weakening case IsUnweakanable(ty: SemType, msg: String) case IsEither(ty1: SemType, ty2: SemType, msg: String) case Never(msg: String) } extension (ty: SemType) { - def satisfies(constraint: Constraint, pos: Position)(using ctx: TypeCheckerCtx): SemType = + + /** Check if a type satisfies a constraint. + * + * @param constraint + * Constraint to satisfy. + * @param pos + * Position to pass to the error, if constraint was not satisfied. + * @return + * The type if the constraint was satisfied, or ? if it was not. + */ + private def satisfies(constraint: Constraint, pos: Position)(using + ctx: TypeCheckerCtx + ): SemType = (ty, constraint) match { case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => KnownType.String @@ -42,6 +57,7 @@ object typeChecker { KnownType.String case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => ty.satisfies(Constraint.Is(ty2, msg), pos) + // Change to IsUnweakanable to disallow recursive weakening case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos) case (ty, Constraint.Unconstrained) => ty case (ty, Constraint.Never(msg)) => @@ -56,7 +72,14 @@ object typeChecker { } } - infix def moreSpecific(ty2: SemType): Option[SemType] = + /** Tries to merge two types, returning the more specific one if possible. + * + * @param ty2 + * The other type to merge with. + * @return + * The more specific type if it could be determined, or None if the types are incompatible. + */ + private infix def moreSpecific(ty2: SemType): Option[SemType] = (ty, ty2) match { case (ty, ?) => Some(ty) case (?, ty) => Some(ty) @@ -69,9 +92,19 @@ object typeChecker { } } + /** Type-check a WACC program. + * + * @param prog + * The AST of the program to type-check. + * @param ctx + * The type checker context which includes the global names and functions, and an errors + * builder. + */ def check(prog: Program)(using ctx: TypeCheckerCtx ): Unit = { + // Ignore function syntax types for return value and params, since those have been converted + // to SemTypes by the renamer. prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => val FuncType(retType, _) = ctx.funcType(name) stmts.toList.foreach( @@ -81,9 +114,17 @@ object typeChecker { prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) } + /** Type-check an AST statement node. + * + * @param stmt + * The statement to type-check. + * @param returnConstraint + * The constraint that any `return ` statements must satisfy. + */ private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using ctx: TypeCheckerCtx ): Unit = stmt match { + // Ignore the type of the variable, since it has been converted to a SemType by the renamer. case VarDecl(_, name, value) => val expectedTy = ctx.typeOf(name) checkValue( @@ -146,6 +187,16 @@ object typeChecker { case Skip() => () } + /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits + * overlap in the AST. + * + * @param value + * The value to type-check. + * @param constraint + * The type constraint that the value must satisfy. + * @return + * The most specific type of the value if it could be determined, or ? if it could not. + */ private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using ctx: TypeCheckerCtx ): SemType = value match { @@ -162,7 +213,7 @@ object typeChecker { checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) acc match { case KnownType.Array(innerTy) => Some(innerTy) - case ? => Some(?) + case ? => Some(?) // we can keep indexing an unknown type case nonArrayTy => ctx.error( Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") @@ -174,6 +225,7 @@ object typeChecker { case Parens(expr) => checkValue(expr, constraint) case l @ ArrayLiter(elems) => KnownType + // Start with an unknown param type, make it more specific while checking the elements. .Array(elems.foldLeft[SemType](?) { case (acc, elem) => checkValue( elem, @@ -193,6 +245,8 @@ object typeChecker { if (args.length != paramTys.length) { ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) } + // Even if the number of arguments is wrong, we still check the types of the arguments + // in the best way we can (by taking a zip). args.zip(paramTys).foreach { case (arg, paramTy) => checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } @@ -205,7 +259,7 @@ object typeChecker { case what @ KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) - } // satisfies constraint + } case Snd(elem) => checkValue( elem, @@ -251,6 +305,7 @@ object typeChecker { KnownType.Char, s"${op.name} operator must be applied to an int or char" ) + // If x type-check failed, we still want to check y is an Int or Char (rather than ?) val yConstraint = checkValue(op.x, xConstraint) match { case ? => xConstraint case xTy => From 84932edb693574d0f1f2481ca3059043d4ca9b11 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Fri, 7 Feb 2025 18:43:30 +0000 Subject: [PATCH 2/2] refactor: change typo in Unweakenable --- src/main/wacc/typeChecker.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/wacc/typeChecker.scala b/src/main/wacc/typeChecker.scala index 6168480..62fad3b 100644 --- a/src/main/wacc/typeChecker.scala +++ b/src/main/wacc/typeChecker.scala @@ -28,7 +28,7 @@ object typeChecker { // Allows weakening in both directions, useful for array literals case IsSymmetricCompatible(ty: SemType, msg: String) // Does not allow weakening - case IsUnweakanable(ty: SemType, msg: String) + case IsUnweakenable(ty: SemType, msg: String) case IsEither(ty1: SemType, ty2: SemType, msg: String) case Never(msg: String) } @@ -57,8 +57,8 @@ object typeChecker { KnownType.String case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => ty.satisfies(Constraint.Is(ty2, msg), pos) - // Change to IsUnweakanable to disallow recursive weakening - case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos) + // Change to IsUnweakenable to disallow recursive weakening + case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos) case (ty, Constraint.Unconstrained) => ty case (ty, Constraint.Never(msg)) => ctx.error(Error.SemanticError(pos, msg)) @@ -66,7 +66,7 @@ object typeChecker { (ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse { ctx.error(Error.TypeMismatch(pos, ty1, ty, msg)) } - case (ty, Constraint.IsUnweakanable(ty2, msg)) => + case (ty, Constraint.IsUnweakenable(ty2, msg)) => (ty moreSpecific ty2).getOrElse { ctx.error(Error.TypeMismatch(pos, ty2, ty, msg)) }