feat: implement satisfies function in type checker

Co-Authored-By: jt2622
This commit is contained in:
Gleb Koval 2025-02-06 21:04:27 +00:00
parent 6548d895d5
commit f6e734937f
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
4 changed files with 96 additions and 62 deletions

View File

@ -8,6 +8,7 @@ enum Error {
case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType)
case FunctionParamsMismatch(expected: Int, got: Int) // TODO not fine case FunctionParamsMismatch(expected: Int, got: Int) // TODO not fine
case SemanticError(pos: Position, msg: String)
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
case InternalError(pos: Position, msg: String) case InternalError(pos: Position, msg: String)
} }

View File

@ -32,12 +32,11 @@ val cliParser = {
def compile(contents: String): Int = { def compile(contents: String): Int = {
parser.parse(contents) match { parser.parse(contents) match {
case Success(ast) => case Success(prog) =>
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
renamer.rename(ast) val names = renamer.rename(prog)
// given ctx: types.TypeCheckerCtx[List[Error]] = given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, errors)
// types.TypeCheckerCtx(names, errors) typeChecker.check(prog)
// types.check(ast)
if (errors.result.nonEmpty) { if (errors.result.nonEmpty) {
errors.result.foreach(println) errors.result.foreach(println)
200 200

View File

@ -1,5 +1,6 @@
package wacc package wacc
import cats.syntax.all._
import scala.collection.mutable import scala.collection.mutable
object typeChecker { object typeChecker {
@ -27,8 +28,9 @@ object typeChecker {
case Never(msg: String) case Never(msg: String)
} }
extension (ty: SemType) extension (ty: SemType) {
infix def satisfies(constraint: Constraint): SemType = (ty, constraint) match { def satisfies(constraint: Constraint, pos: Position)(using ctx: TypeCheckerCtx): SemType =
(ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String KnownType.String
case ( case (
@ -36,8 +38,36 @@ object typeChecker {
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _) Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
) => ) =>
KnownType.String KnownType.String
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => ty satisfies Constraint.Is(ty2, msg) case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
case (ty, Constraint.Is(ty2, msg)) => ty satisfies Constraint.IsUnweakanable(ty2, msg) ty.satisfies(Constraint.Is(ty2, msg), pos)
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty
case (KnownType.Func(_, _), Constraint.IsVar(msg)) =>
ctx.error(Error.SemanticError(pos, msg))
case (ty, Constraint.IsVar(msg)) => ty
case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg))
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
}
case (ty, Constraint.IsUnweakanable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
}
}
infix def moreSpecific(ty2: SemType): Option[SemType] =
(ty, ty2) match {
case (ty, ?) => Some(ty)
case (?, ty) => Some(ty)
case (ty1, ty2) if ty1 == ty2 => Some(ty1)
case (KnownType.Array(inn1), KnownType.Array(inn2)) =>
(inn1 moreSpecific inn2).map(KnownType.Array(_))
case (KnownType.Pair(fst1, snd1), KnownType.Pair(fst2, snd2)) =>
(fst1 moreSpecific fst2, snd1 moreSpecific snd2).mapN(KnownType.Pair(_, _))
case _ => None
}
} }
def check(prog: Program)(using def check(prog: Program)(using
@ -103,13 +133,13 @@ object typeChecker {
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): SemType = value match { ): SemType = value match {
case IntLiter(_) => KnownType.Int satisfies constraint case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos)
case BoolLiter(_) => KnownType.Bool satisfies constraint case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos)
case CharLiter(_) => KnownType.Char satisfies constraint case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos)
case StrLiter(_) => KnownType.String satisfies constraint case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos)
case PairLiter() => KnownType.Pair(?, ?) satisfies constraint case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos)
case id: Ident => case id: Ident =>
ctx.typeOf(id) satisfies constraint ctx.typeOf(id).satisfies(constraint, id.pos)
case ArrayElem(id, indices) => case ArrayElem(id, indices) =>
val arrayTy = ctx.typeOf(id) val arrayTy = ctx.typeOf(id)
val elemTy = indices.toList.foldRight(arrayTy) { (elem, acc) => val elemTy = indices.toList.foldRight(arrayTy) { (elem, acc) =>
@ -122,20 +152,24 @@ object typeChecker {
) )
} }
} }
elemTy satisfies constraint elemTy.satisfies(constraint, id.pos)
case Parens(expr) => checkValue(expr, constraint) case Parens(expr) => checkValue(expr, constraint)
case ArrayLiter(elems) => case l @ ArrayLiter(elems) =>
KnownType.Array(elems.foldRight[SemType](?) { case (elem, acc) => KnownType
.Array(elems.foldRight[SemType](?) { case (elem, acc) =>
checkValue( checkValue(
elem, elem,
Constraint.IsSymmetricCompatible(acc, "array elements must have the same type") Constraint.IsSymmetricCompatible(acc, "array elements must have the same type")
) )
}) satisfies constraint })
case NewPair(fst, snd) => .satisfies(constraint, l.pos)
KnownType.Pair( case l @ NewPair(fst, snd) =>
KnownType
.Pair(
checkValue(fst, Constraint.Unconstrained), checkValue(fst, Constraint.Unconstrained),
checkValue(snd, Constraint.Unconstrained) checkValue(snd, Constraint.Unconstrained)
) satisfies constraint )
.satisfies(constraint, l.pos)
case Call(id, args) => case Call(id, args) =>
val funcTy = ctx.typeOf(id) val funcTy = ctx.typeOf(id)
funcTy match { funcTy match {
@ -143,7 +177,7 @@ object typeChecker {
args.zip(paramTys).foreach { case (arg, paramTy) => args.zip(paramTys).foreach { case (arg, paramTy) =>
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
} }
retTy satisfies constraint retTy.satisfies(constraint, id.pos)
// Should never happen, the scope-checker should have caught this already // Should never happen, the scope-checker should have caught this already
// ctx error had it not // ctx error had it not
case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function")) case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function"))
@ -153,8 +187,8 @@ object typeChecker {
elem, elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
) match { ) match {
case KnownType.Pair(left, _) => left satisfies constraint case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos)
case ? => ? satisfies constraint case ? => ?.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
} // satisfies constraint } // satisfies constraint
case Snd(elem) => case Snd(elem) =>
@ -162,38 +196,38 @@ object typeChecker {
elem, elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
) match { ) match {
case KnownType.Pair(_, right) => right satisfies constraint case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
case ? => ? satisfies constraint case ? => ?.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
} }
// Unary operators // Unary operators
case Negate(x) => case Negate(x) =>
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
KnownType.Int satisfies constraint KnownType.Int.satisfies(constraint, x.pos)
case Not(x) => case Not(x) =>
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
KnownType.Bool satisfies constraint KnownType.Bool.satisfies(constraint, x.pos)
case Len(x) => case Len(x) =>
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
KnownType.Int satisfies constraint KnownType.Int.satisfies(constraint, x.pos)
case Ord(x) => case Ord(x) =>
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
KnownType.Int satisfies constraint KnownType.Int.satisfies(constraint, x.pos)
case Chr(x) => case Chr(x) =>
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
KnownType.Char satisfies constraint KnownType.Char.satisfies(constraint, x.pos)
// Binary operators // Binary operators
case op: (Add | Sub | Mul | Div | Mod) => case op: (Add | Sub | Mul | Div | Mod) =>
val operand = Constraint.Is(KnownType.Int, "binary operator must be applied to an int") val operand = Constraint.Is(KnownType.Int, "binary operator must be applied to an int")
checkValue(op.x, operand) checkValue(op.x, operand)
checkValue(op.y, operand) checkValue(op.y, operand)
KnownType.Int satisfies constraint KnownType.Int.satisfies(constraint, op.pos)
case op: (Eq | Neq) => case op: (Eq | Neq) =>
val xTy = checkValue(op.x, Constraint.Unconstrained) val xTy = checkValue(op.x, Constraint.Unconstrained)
checkValue(op.y, Constraint.Is(xTy, "equality must be applied to values of the same type")) checkValue(op.y, Constraint.Is(xTy, "equality must be applied to values of the same type"))
KnownType.Bool satisfies constraint KnownType.Bool.satisfies(constraint, op.pos)
case op: (Less | LessEq | Greater | GreaterEq) => case op: (Less | LessEq | Greater | GreaterEq) =>
val xTy = checkValue( val xTy = checkValue(
op.x, op.x,
@ -204,11 +238,11 @@ object typeChecker {
) )
) )
checkValue(op.y, Constraint.Is(xTy, "comparison must be applied to values of the same type")) checkValue(op.y, Constraint.Is(xTy, "comparison must be applied to values of the same type"))
KnownType.Bool satisfies constraint KnownType.Bool.satisfies(constraint, op.pos)
case op: (And | Or) => case op: (And | Or) =>
val operand = Constraint.Is(KnownType.Bool, "logical operator must be applied to a bool") val operand = Constraint.Is(KnownType.Bool, "logical operator must be applied to a bool")
checkValue(op.x, operand) checkValue(op.x, operand)
checkValue(op.y, operand) checkValue(op.y, operand)
KnownType.Bool satisfies constraint KnownType.Bool.satisfies(constraint, op.pos)
} }
} }

View File

@ -68,19 +68,19 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
// "wacc-examples/invalid/syntaxErr/variables", // "wacc-examples/invalid/syntaxErr/variables",
// "wacc-examples/invalid/syntaxErr/while", // "wacc-examples/invalid/syntaxErr/while",
// invalid (semantic) // invalid (semantic)
"wacc-examples/invalid/semanticErr/array", // "wacc-examples/invalid/semanticErr/array",
"wacc-examples/invalid/semanticErr/exit", // "wacc-examples/invalid/semanticErr/exit",
"wacc-examples/invalid/semanticErr/expressions", // "wacc-examples/invalid/semanticErr/expressions",
"wacc-examples/invalid/semanticErr/function", // "wacc-examples/invalid/semanticErr/function",
"wacc-examples/invalid/semanticErr/if", // "wacc-examples/invalid/semanticErr/if",
"wacc-examples/invalid/semanticErr/IO", // "wacc-examples/invalid/semanticErr/IO",
"wacc-examples/invalid/semanticErr/multiple", // "wacc-examples/invalid/semanticErr/multiple",
"wacc-examples/invalid/semanticErr/pairs", // "wacc-examples/invalid/semanticErr/pairs",
"wacc-examples/invalid/semanticErr/print", // "wacc-examples/invalid/semanticErr/print",
"wacc-examples/invalid/semanticErr/read", // "wacc-examples/invalid/semanticErr/read",
"wacc-examples/invalid/semanticErr/scope", // "wacc-examples/invalid/semanticErr/scope",
"wacc-examples/invalid/semanticErr/variables", // "wacc-examples/invalid/semanticErr/variables",
"wacc-examples/invalid/semanticErr/while", // "wacc-examples/invalid/semanticErr/while",
// invalid (whack) // invalid (whack)
"wacc-examples/invalid/whack" "wacc-examples/invalid/whack"
// format: on // format: on