refactor: remove implicit ast from type checker

This commit is contained in:
Gleb Koval 2025-02-13 23:12:41 +00:00
parent bc25f914ad
commit 756b42dd72
Signed by: cyclane
GPG Key ID: 15E168A8B332382C

View File

@ -4,17 +4,15 @@ import cats.syntax.all._
import scala.collection.mutable import scala.collection.mutable
object typeChecker { object typeChecker {
import wacc.ast._
import wacc.types._ import wacc.types._
case class TypeCheckerCtx( case class TypeCheckerCtx(
globalNames: Map[Ident, SemType], globalNames: Map[ast.Ident, SemType],
globalFuncs: Map[Ident, FuncType], globalFuncs: Map[ast.Ident, FuncType],
errors: mutable.Builder[Error, List[Error]] errors: mutable.Builder[Error, List[Error]]
) { ) {
def typeOf(ident: Ident): SemType = globalNames(ident) def typeOf(ident: ast.Ident): SemType = globalNames(ident)
def funcType(ident: ast.Ident): FuncType = globalFuncs(ident)
def funcType(ident: Ident): FuncType = globalFuncs(ident)
def error(err: Error): SemType = def error(err: Error): SemType =
errors += err errors += err
@ -44,7 +42,7 @@ object typeChecker {
* @return * @return
* The type if the constraint was satisfied, or ? if it was not. * The type if the constraint was satisfied, or ? if it was not.
*/ */
private def satisfies(constraint: Constraint, pos: Position)(using private def satisfies(constraint: Constraint, pos: ast.Position)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): SemType = ): SemType =
(ty, constraint) match { (ty, constraint) match {
@ -100,12 +98,12 @@ object typeChecker {
* The type checker context which includes the global names and functions, and an errors * The type checker context which includes the global names and functions, and an errors
* builder. * builder.
*/ */
def check(prog: Program)(using def check(prog: ast.Program)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): Unit = { ): Unit = {
// Ignore function syntax types for return value and params, since those have been converted // Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer. // to SemTypes by the renamer.
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => prog.funcs.foreach { case ast.FuncDecl(_, name, _, stmts) =>
val FuncType(retType, _) = ctx.funcType(name) val FuncType(retType, _) = ctx.funcType(name)
stmts.toList.foreach( stmts.toList.foreach(
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
@ -121,11 +119,11 @@ object typeChecker {
* @param returnConstraint * @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy. * The constraint that any `return <expr>` statements must satisfy.
*/ */
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): Unit = stmt match { ): Unit = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer. // Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case VarDecl(_, name, value) => case ast.VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name) val expectedTy = ctx.typeOf(name)
checkValue( checkValue(
value, value,
@ -134,7 +132,7 @@ object typeChecker {
s"variable ${name.v} must be assigned a value of type $expectedTy" s"variable ${name.v} must be assigned a value of type $expectedTy"
) )
) )
case Assign(lhs, rhs) => case ast.Assign(lhs, rhs) =>
val lhsTy = checkValue(lhs, Constraint.Unconstrained) val lhsTy = checkValue(lhs, Constraint.Unconstrained)
(lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match {
case (?, ?) => case (?, ?) =>
@ -143,7 +141,7 @@ object typeChecker {
) )
case _ => () case _ => ()
} }
case Read(dest) => case ast.Read(dest) =>
checkValue(dest, Constraint.Unconstrained) match { checkValue(dest, Constraint.Unconstrained) match {
case ? => case ? =>
ctx.error( ctx.error(
@ -159,7 +157,7 @@ object typeChecker {
dest.pos dest.pos
) )
} }
case Free(lhs) => case ast.Free(lhs) =>
checkValue( checkValue(
lhs, lhs,
Constraint.IsEither( Constraint.IsEither(
@ -168,23 +166,23 @@ object typeChecker {
"free must be applied to an array or pair" "free must be applied to an array or pair"
) )
) )
case Return(expr) => case ast.Return(expr) =>
checkValue(expr, returnConstraint) checkValue(expr, returnConstraint)
case Exit(expr) => case ast.Exit(expr) =>
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
case Print(expr, _) => case ast.Print(expr, _) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
checkValue(expr, Constraint.Unconstrained) checkValue(expr, Constraint.Unconstrained)
case If(cond, thenStmt, elseStmt) => case ast.If(cond, thenStmt, elseStmt) =>
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
thenStmt.toList.foreach(checkStmt(_, returnConstraint)) thenStmt.toList.foreach(checkStmt(_, returnConstraint))
elseStmt.toList.foreach(checkStmt(_, returnConstraint)) elseStmt.toList.foreach(checkStmt(_, returnConstraint))
case While(cond, body) => case ast.While(cond, body) =>
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
body.toList.foreach(checkStmt(_, returnConstraint)) body.toList.foreach(checkStmt(_, returnConstraint))
case Block(body) => case ast.Block(body) =>
body.toList.foreach(checkStmt(_, returnConstraint)) body.toList.foreach(checkStmt(_, returnConstraint))
case Skip() => () case ast.Skip() => ()
} }
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
@ -197,17 +195,17 @@ object typeChecker {
* @return * @return
* The most specific type of the value if it could be determined, or ? if it could not. * The most specific type of the value if it could be determined, or ? if it could not.
*/ */
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): SemType = value match { ): SemType = value match {
case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) case l @ ast.IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos)
case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) case l @ ast.BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos)
case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) case l @ ast.CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos)
case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) case l @ ast.StrLiter(_) => KnownType.String.satisfies(constraint, l.pos)
case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) case l @ ast.PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos)
case id: Ident => case id: ast.Ident =>
ctx.typeOf(id).satisfies(constraint, id.pos) ctx.typeOf(id).satisfies(constraint, id.pos)
case ArrayElem(id, indices) => case ast.ArrayElem(id, indices) =>
val arrayTy = ctx.typeOf(id) val arrayTy = ctx.typeOf(id)
val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) =>
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
@ -222,8 +220,8 @@ object typeChecker {
} }
} }
elemTy.getOrElse(?).satisfies(constraint, id.pos) elemTy.getOrElse(?).satisfies(constraint, id.pos)
case Parens(expr) => checkValue(expr, constraint) case ast.Parens(expr) => checkValue(expr, constraint)
case l @ ArrayLiter(elems) => case l @ ast.ArrayLiter(elems) =>
KnownType KnownType
// Start with an unknown param type, make it more specific while checking the elements. // Start with an unknown param type, make it more specific while checking the elements.
.Array(elems.foldLeft[SemType](?) { case (acc, elem) => .Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
@ -233,14 +231,14 @@ object typeChecker {
) )
}) })
.satisfies(constraint, l.pos) .satisfies(constraint, l.pos)
case l @ NewPair(fst, snd) => case l @ ast.NewPair(fst, snd) =>
KnownType KnownType
.Pair( .Pair(
checkValue(fst, Constraint.Unconstrained), checkValue(fst, Constraint.Unconstrained),
checkValue(snd, Constraint.Unconstrained) checkValue(snd, Constraint.Unconstrained)
) )
.satisfies(constraint, l.pos) .satisfies(constraint, l.pos)
case Call(id, args) => case ast.Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
if (args.length != paramTys.length) { if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
@ -251,7 +249,7 @@ object typeChecker {
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
} }
retTy.satisfies(constraint, id.pos) retTy.satisfies(constraint, id.pos)
case Fst(elem) => case ast.Fst(elem) =>
checkValue( checkValue(
elem, elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
@ -260,7 +258,7 @@ object typeChecker {
left.satisfies(constraint, elem.pos) left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
} }
case Snd(elem) => case ast.Snd(elem) =>
checkValue( checkValue(
elem, elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
@ -270,36 +268,36 @@ object typeChecker {
} }
// Unary operators // Unary operators
case Negate(x) => case ast.Negate(x) =>
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
KnownType.Int.satisfies(constraint, x.pos) KnownType.Int.satisfies(constraint, x.pos)
case Not(x) => case ast.Not(x) =>
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
KnownType.Bool.satisfies(constraint, x.pos) KnownType.Bool.satisfies(constraint, x.pos)
case Len(x) => case ast.Len(x) =>
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
KnownType.Int.satisfies(constraint, x.pos) KnownType.Int.satisfies(constraint, x.pos)
case Ord(x) => case ast.Ord(x) =>
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
KnownType.Int.satisfies(constraint, x.pos) KnownType.Int.satisfies(constraint, x.pos)
case Chr(x) => case ast.Chr(x) =>
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
KnownType.Char.satisfies(constraint, x.pos) KnownType.Char.satisfies(constraint, x.pos)
// Binary operators // Binary operators
case op: (Add | Sub | Mul | Div | Mod) => case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) =>
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
checkValue(op.x, operand) checkValue(op.x, operand)
checkValue(op.y, operand) checkValue(op.y, operand)
KnownType.Int.satisfies(constraint, op.pos) KnownType.Int.satisfies(constraint, op.pos)
case op: (Eq | Neq) => case op: (ast.Eq | ast.Neq) =>
val xTy = checkValue(op.x, Constraint.Unconstrained) val xTy = checkValue(op.x, Constraint.Unconstrained)
checkValue( checkValue(
op.y, op.y,
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
) )
KnownType.Bool.satisfies(constraint, op.pos) KnownType.Bool.satisfies(constraint, op.pos)
case op: (Less | LessEq | Greater | GreaterEq) => case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) =>
val xConstraint = Constraint.IsEither( val xConstraint = Constraint.IsEither(
KnownType.Int, KnownType.Int,
KnownType.Char, KnownType.Char,
@ -313,7 +311,7 @@ object typeChecker {
} }
checkValue(op.y, yConstraint) checkValue(op.y, yConstraint)
KnownType.Bool.satisfies(constraint, op.pos) KnownType.Bool.satisfies(constraint, op.pos)
case op: (And | Or) => case op: (ast.And | ast.Or) =>
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
checkValue(op.x, operand) checkValue(op.x, operand)
checkValue(op.y, operand) checkValue(op.y, operand)