refactor: comment non-trivial code in typeChecker.scala
Merge request lab2425_spring/WACC_37!20
This commit is contained in:
commit
2a4c2bc993
@ -21,17 +21,32 @@ object typeChecker {
|
|||||||
?
|
?
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Constraint {
|
private enum Constraint {
|
||||||
case Unconstrained
|
case Unconstrained
|
||||||
|
// Allows weakening in one direction
|
||||||
case Is(ty: SemType, msg: String)
|
case Is(ty: SemType, msg: String)
|
||||||
|
// Allows weakening in both directions, useful for array literals
|
||||||
case IsSymmetricCompatible(ty: SemType, msg: String)
|
case IsSymmetricCompatible(ty: SemType, msg: String)
|
||||||
case IsUnweakanable(ty: SemType, msg: String)
|
// Does not allow weakening
|
||||||
|
case IsUnweakenable(ty: SemType, msg: String)
|
||||||
case IsEither(ty1: SemType, ty2: SemType, msg: String)
|
case IsEither(ty1: SemType, ty2: SemType, msg: String)
|
||||||
case Never(msg: String)
|
case Never(msg: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
extension (ty: SemType) {
|
extension (ty: SemType) {
|
||||||
def satisfies(constraint: Constraint, pos: Position)(using ctx: TypeCheckerCtx): SemType =
|
|
||||||
|
/** Check if a type satisfies a constraint.
|
||||||
|
*
|
||||||
|
* @param constraint
|
||||||
|
* Constraint to satisfy.
|
||||||
|
* @param pos
|
||||||
|
* Position to pass to the error, if constraint was not satisfied.
|
||||||
|
* @return
|
||||||
|
* The type if the constraint was satisfied, or ? if it was not.
|
||||||
|
*/
|
||||||
|
private def satisfies(constraint: Constraint, pos: Position)(using
|
||||||
|
ctx: TypeCheckerCtx
|
||||||
|
): SemType =
|
||||||
(ty, constraint) match {
|
(ty, constraint) match {
|
||||||
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
|
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
|
||||||
KnownType.String
|
KnownType.String
|
||||||
@ -42,7 +57,8 @@ object typeChecker {
|
|||||||
KnownType.String
|
KnownType.String
|
||||||
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
|
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
|
||||||
ty.satisfies(Constraint.Is(ty2, msg), pos)
|
ty.satisfies(Constraint.Is(ty2, msg), pos)
|
||||||
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos)
|
// Change to IsUnweakenable to disallow recursive weakening
|
||||||
|
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
|
||||||
case (ty, Constraint.Unconstrained) => ty
|
case (ty, Constraint.Unconstrained) => ty
|
||||||
case (ty, Constraint.Never(msg)) =>
|
case (ty, Constraint.Never(msg)) =>
|
||||||
ctx.error(Error.SemanticError(pos, msg))
|
ctx.error(Error.SemanticError(pos, msg))
|
||||||
@ -50,13 +66,20 @@ object typeChecker {
|
|||||||
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
|
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
|
||||||
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
|
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
|
||||||
}
|
}
|
||||||
case (ty, Constraint.IsUnweakanable(ty2, msg)) =>
|
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
|
||||||
(ty moreSpecific ty2).getOrElse {
|
(ty moreSpecific ty2).getOrElse {
|
||||||
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
|
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
infix def moreSpecific(ty2: SemType): Option[SemType] =
|
/** Tries to merge two types, returning the more specific one if possible.
|
||||||
|
*
|
||||||
|
* @param ty2
|
||||||
|
* The other type to merge with.
|
||||||
|
* @return
|
||||||
|
* The more specific type if it could be determined, or None if the types are incompatible.
|
||||||
|
*/
|
||||||
|
private infix def moreSpecific(ty2: SemType): Option[SemType] =
|
||||||
(ty, ty2) match {
|
(ty, ty2) match {
|
||||||
case (ty, ?) => Some(ty)
|
case (ty, ?) => Some(ty)
|
||||||
case (?, ty) => Some(ty)
|
case (?, ty) => Some(ty)
|
||||||
@ -69,9 +92,19 @@ object typeChecker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Type-check a WACC program.
|
||||||
|
*
|
||||||
|
* @param prog
|
||||||
|
* The AST of the program to type-check.
|
||||||
|
* @param ctx
|
||||||
|
* The type checker context which includes the global names and functions, and an errors
|
||||||
|
* builder.
|
||||||
|
*/
|
||||||
def check(prog: Program)(using
|
def check(prog: Program)(using
|
||||||
ctx: TypeCheckerCtx
|
ctx: TypeCheckerCtx
|
||||||
): Unit = {
|
): Unit = {
|
||||||
|
// Ignore function syntax types for return value and params, since those have been converted
|
||||||
|
// to SemTypes by the renamer.
|
||||||
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
|
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
|
||||||
val FuncType(retType, _) = ctx.funcType(name)
|
val FuncType(retType, _) = ctx.funcType(name)
|
||||||
stmts.toList.foreach(
|
stmts.toList.foreach(
|
||||||
@ -81,9 +114,17 @@ object typeChecker {
|
|||||||
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
|
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Type-check an AST statement node.
|
||||||
|
*
|
||||||
|
* @param stmt
|
||||||
|
* The statement to type-check.
|
||||||
|
* @param returnConstraint
|
||||||
|
* The constraint that any `return <expr>` statements must satisfy.
|
||||||
|
*/
|
||||||
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
|
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
|
||||||
ctx: TypeCheckerCtx
|
ctx: TypeCheckerCtx
|
||||||
): Unit = stmt match {
|
): Unit = stmt match {
|
||||||
|
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
|
||||||
case VarDecl(_, name, value) =>
|
case VarDecl(_, name, value) =>
|
||||||
val expectedTy = ctx.typeOf(name)
|
val expectedTy = ctx.typeOf(name)
|
||||||
checkValue(
|
checkValue(
|
||||||
@ -146,6 +187,16 @@ object typeChecker {
|
|||||||
case Skip() => ()
|
case Skip() => ()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
|
||||||
|
* overlap in the AST.
|
||||||
|
*
|
||||||
|
* @param value
|
||||||
|
* The value to type-check.
|
||||||
|
* @param constraint
|
||||||
|
* The type constraint that the value must satisfy.
|
||||||
|
* @return
|
||||||
|
* The most specific type of the value if it could be determined, or ? if it could not.
|
||||||
|
*/
|
||||||
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
|
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
|
||||||
ctx: TypeCheckerCtx
|
ctx: TypeCheckerCtx
|
||||||
): SemType = value match {
|
): SemType = value match {
|
||||||
@ -162,7 +213,7 @@ object typeChecker {
|
|||||||
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
|
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
|
||||||
acc match {
|
acc match {
|
||||||
case KnownType.Array(innerTy) => Some(innerTy)
|
case KnownType.Array(innerTy) => Some(innerTy)
|
||||||
case ? => Some(?)
|
case ? => Some(?) // we can keep indexing an unknown type
|
||||||
case nonArrayTy =>
|
case nonArrayTy =>
|
||||||
ctx.error(
|
ctx.error(
|
||||||
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
|
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
|
||||||
@ -174,6 +225,7 @@ object typeChecker {
|
|||||||
case Parens(expr) => checkValue(expr, constraint)
|
case Parens(expr) => checkValue(expr, constraint)
|
||||||
case l @ ArrayLiter(elems) =>
|
case l @ ArrayLiter(elems) =>
|
||||||
KnownType
|
KnownType
|
||||||
|
// Start with an unknown param type, make it more specific while checking the elements.
|
||||||
.Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
|
.Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
|
||||||
checkValue(
|
checkValue(
|
||||||
elem,
|
elem,
|
||||||
@ -193,6 +245,8 @@ object typeChecker {
|
|||||||
if (args.length != paramTys.length) {
|
if (args.length != paramTys.length) {
|
||||||
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
|
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
|
||||||
}
|
}
|
||||||
|
// Even if the number of arguments is wrong, we still check the types of the arguments
|
||||||
|
// in the best way we can (by taking a zip).
|
||||||
args.zip(paramTys).foreach { case (arg, paramTy) =>
|
args.zip(paramTys).foreach { case (arg, paramTy) =>
|
||||||
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
|
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
|
||||||
}
|
}
|
||||||
@ -205,7 +259,7 @@ object typeChecker {
|
|||||||
case what @ KnownType.Pair(left, _) =>
|
case what @ KnownType.Pair(left, _) =>
|
||||||
left.satisfies(constraint, elem.pos)
|
left.satisfies(constraint, elem.pos)
|
||||||
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
|
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
|
||||||
} // satisfies constraint
|
}
|
||||||
case Snd(elem) =>
|
case Snd(elem) =>
|
||||||
checkValue(
|
checkValue(
|
||||||
elem,
|
elem,
|
||||||
@ -251,6 +305,7 @@ object typeChecker {
|
|||||||
KnownType.Char,
|
KnownType.Char,
|
||||||
s"${op.name} operator must be applied to an int or char"
|
s"${op.name} operator must be applied to an int or char"
|
||||||
)
|
)
|
||||||
|
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
|
||||||
val yConstraint = checkValue(op.x, xConstraint) match {
|
val yConstraint = checkValue(op.x, xConstraint) match {
|
||||||
case ? => xConstraint
|
case ? => xConstraint
|
||||||
case xTy =>
|
case xTy =>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user