feat: microwacc type checker implementation

Merge request lab2425_spring/WACC_37!22
This commit is contained in:
Gleb Koval 2025-02-18 17:29:10 +00:00
commit bb090ad431
3 changed files with 297 additions and 148 deletions

View File

@ -19,7 +19,7 @@ object microWacc {
extends Expr(identTy) extends Expr(identTy)
with CallTarget(identTy) with CallTarget(identTy)
with LValue with LValue
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(ty: SemType) case class ArrayElem(value: LValue, indices: NonEmptyList[Expr])(ty: SemType)
extends Expr(ty) extends Expr(ty)
with LValue with LValue
@ -48,6 +48,23 @@ object microWacc {
case And case And
case Or case Or
} }
object BinaryOperator {
def fromAst(op: ast.BinaryOp): BinaryOperator = op match {
case _: ast.Add => Add
case _: ast.Sub => Sub
case _: ast.Mul => Mul
case _: ast.Div => Div
case _: ast.Mod => Mod
case _: ast.Greater => Greater
case _: ast.GreaterEq => GreaterEq
case _: ast.Less => Less
case _: ast.LessEq => LessEq
case _: ast.Eq => Eq
case _: ast.Neq => Neq
case _: ast.And => And
case _: ast.Or => Or
}
}
// Statements // Statements
sealed trait Stmt sealed trait Stmt

View File

@ -2,19 +2,18 @@ package wacc
import cats.syntax.all._ import cats.syntax.all._
import scala.collection.mutable import scala.collection.mutable
import cats.data.NonEmptyList
object typeChecker { object typeChecker {
import wacc.ast._
import wacc.types._ import wacc.types._
case class TypeCheckerCtx( case class TypeCheckerCtx(
globalNames: Map[Ident, SemType], globalNames: Map[ast.Ident, SemType],
globalFuncs: Map[Ident, FuncType], globalFuncs: Map[ast.Ident, FuncType],
errors: mutable.Builder[Error, List[Error]] errors: mutable.Builder[Error, List[Error]]
) { ) {
def typeOf(ident: Ident): SemType = globalNames(ident) def typeOf(ident: ast.Ident): SemType = globalNames(ident)
def funcType(ident: ast.Ident): FuncType = globalFuncs(ident)
def funcType(ident: Ident): FuncType = globalFuncs(ident)
def error(err: Error): SemType = def error(err: Error): SemType =
errors += err errors += err
@ -44,7 +43,7 @@ object typeChecker {
* @return * @return
* The type if the constraint was satisfied, or ? if it was not. * The type if the constraint was satisfied, or ? if it was not.
*/ */
private def satisfies(constraint: Constraint, pos: Position)(using private def satisfies(constraint: Constraint, pos: ast.Position)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): SemType = ): SemType =
(ty, constraint) match { (ty, constraint) match {
@ -100,19 +99,28 @@ object typeChecker {
* The type checker context which includes the global names and functions, and an errors * The type checker context which includes the global names and functions, and an errors
* builder. * builder.
*/ */
def check(prog: Program)(using def check(prog: ast.Program)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): Unit = { ): microWacc.Program =
microWacc.Program(
// Ignore function syntax types for return value and params, since those have been converted // Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer. // to SemTypes by the renamer.
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) =>
val FuncType(retType, _) = ctx.funcType(name) val FuncType(retType, paramTypes) = ctx.funcType(name)
stmts.toList.foreach( microWacc.FuncDecl(
microWacc.Ident(name.v, name.uid)(retType),
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
microWacc.Ident(ident.v, name.uid)(ty)
},
stmts.toList
.flatMap(
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
) )
} )
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) },
} prog.main.toList
.flatMap(checkStmt(_, Constraint.Never("main function must not return")))
)
/** Type-check an AST statement node. /** Type-check an AST statement node.
* *
@ -121,34 +129,40 @@ object typeChecker {
* @param returnConstraint * @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy. * The constraint that any `return <expr>` statements must satisfy.
*/ */
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): Unit = stmt match { ): List[microWacc.Stmt] = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer. // Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case VarDecl(_, name, value) => case ast.VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name) val expectedTy = ctx.typeOf(name)
checkValue( val typedValue = checkValue(
value, value,
Constraint.Is( Constraint.Is(
expectedTy, expectedTy,
s"variable ${name.v} must be assigned a value of type $expectedTy" s"variable ${name.v} must be assigned a value of type $expectedTy"
) )
) )
case Assign(lhs, rhs) => List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue))
val lhsTy = checkValue(lhs, Constraint.Unconstrained) case ast.Assign(lhs, rhs) =>
(lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { val lhsTyped = checkLValue(lhs, Constraint.Unconstrained)
val rhsTyped =
checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}"))
(lhsTyped.ty, rhsTyped.ty) match {
case (?, ?) => case (?, ?) =>
ctx.error( ctx.error(
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal") Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
) )
case _ => () case _ => ()
} }
case Read(dest) => List(microWacc.Assign(lhsTyped, rhsTyped))
checkValue(dest, Constraint.Unconstrained) match { case ast.Read(dest) =>
val destTyped = checkLValue(dest, Constraint.Unconstrained)
val destTy = destTyped.ty match {
case ? => case ? =>
ctx.error( ctx.error(
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type") Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
) )
?
case destTy => case destTy =>
destTy.satisfies( destTy.satisfies(
Constraint.IsEither( Constraint.IsEither(
@ -159,7 +173,24 @@ object typeChecker {
dest.pos dest.pos
) )
} }
case Free(lhs) => List(
microWacc.Assign(
destTyped,
microWacc.Call(
destTy match {
case KnownType.Int => microWacc.Builtin.ReadInt
case KnownType.Char => microWacc.Builtin.ReadChar
case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway
},
Nil
)
)
)
case ast.Free(lhs) =>
List(
microWacc.Call(
microWacc.Builtin.Free,
List(
checkValue( checkValue(
lhs, lhs,
Constraint.IsEither( Constraint.IsEither(
@ -168,23 +199,43 @@ object typeChecker {
"free must be applied to an array or pair" "free must be applied to an array or pair"
) )
) )
case Return(expr) => )
checkValue(expr, returnConstraint) )
case Exit(expr) => )
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) case ast.Return(expr) =>
case Print(expr, _) => List(microWacc.Return(checkValue(expr, returnConstraint)))
case ast.Exit(expr) =>
List(
microWacc.Call(
microWacc.Builtin.Exit,
List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")))
)
)
case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
checkValue(expr, Constraint.Unconstrained) List(
case If(cond, thenStmt, elseStmt) => microWacc.Call(
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) if newline then microWacc.Builtin.Println else microWacc.Builtin.Print,
thenStmt.toList.foreach(checkStmt(_, returnConstraint)) List(checkValue(expr, Constraint.Unconstrained))
elseStmt.toList.foreach(checkStmt(_, returnConstraint)) )
case While(cond, body) => )
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) case ast.If(cond, thenStmt, elseStmt) =>
body.toList.foreach(checkStmt(_, returnConstraint)) List(
case Block(body) => microWacc.If(
body.toList.foreach(checkStmt(_, returnConstraint)) checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")),
case Skip() => () thenStmt.toList.flatMap(checkStmt(_, returnConstraint)),
elseStmt.toList.flatMap(checkStmt(_, returnConstraint))
)
)
case ast.While(cond, body) =>
List(
microWacc.While(
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")),
body.toList.flatMap(checkStmt(_, returnConstraint))
)
)
case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => List.empty
} }
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
@ -197,126 +248,197 @@ object typeChecker {
* @return * @return
* The most specific type of the value if it could be determined, or ? if it could not. * The most specific type of the value if it could be determined, or ? if it could not.
*/ */
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx ctx: TypeCheckerCtx
): SemType = value match { ): microWacc.Expr = value match {
case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) case l @ ast.IntLiter(v) =>
case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) KnownType.Int.satisfies(constraint, l.pos)
case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) microWacc.IntLiter(v)
case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) case l @ ast.BoolLiter(v) =>
case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) KnownType.Bool.satisfies(constraint, l.pos)
case id: Ident => microWacc.BoolLiter(v)
ctx.typeOf(id).satisfies(constraint, id.pos) case l @ ast.CharLiter(v) =>
case ArrayElem(id, indices) => KnownType.Char.satisfies(constraint, l.pos)
val arrayTy = ctx.typeOf(id) microWacc.CharLiter(v)
val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => case l @ ast.StrLiter(v) =>
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) KnownType.String.satisfies(constraint, l.pos)
acc match { microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String)
case KnownType.Array(innerTy) => Some(innerTy) case l @ ast.PairLiter() =>
case ? => Some(?) // we can keep indexing an unknown type microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos))
case nonArrayTy => case ast.Parens(expr) => checkValue(expr, constraint)
ctx.error( case l @ ast.ArrayLiter(elems) =>
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) {
) case (acc, elem) =>
None val elemTyped = checkValue(
}
}
elemTy.getOrElse(?).satisfies(constraint, id.pos)
case Parens(expr) => checkValue(expr, constraint)
case l @ ArrayLiter(elems) =>
KnownType
// Start with an unknown param type, make it more specific while checking the elements.
.Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
checkValue(
elem, elem,
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type") Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
) )
}) (elemTyped.ty, elemTyped)
}
val arrayTy = KnownType
// Start with an unknown param type, make it more specific while checking the elements.
.Array(elemTy)
.satisfies(constraint, l.pos) .satisfies(constraint, l.pos)
case l @ NewPair(fst, snd) => microWacc.ArrayLiter(elemsTyped)(arrayTy)
KnownType case l @ ast.NewPair(fst, snd) =>
.Pair( val fstTyped = checkValue(fst, Constraint.Unconstrained)
checkValue(fst, Constraint.Unconstrained), val sndTyped = checkValue(snd, Constraint.Unconstrained)
checkValue(snd, Constraint.Unconstrained) microWacc.ArrayLiter(List(fstTyped, sndTyped))(
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
) )
.satisfies(constraint, l.pos) case ast.Call(id, args) =>
case Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
if (args.length != paramTys.length) { if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
} }
// Even if the number of arguments is wrong, we still check the types of the arguments // Even if the number of arguments is wrong, we still check the types of the arguments
// in the best way we can (by taking a zip). // in the best way we can (by taking a zip).
args.zip(paramTys).foreach { case (arg, paramTy) => val argsTyped = args.zip(paramTys).map { case (arg, paramTy) =>
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
} }
retTy.satisfies(constraint, id.pos) microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped)
case Fst(elem) =>
checkValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
) match {
case what @ KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
}
case Snd(elem) =>
checkValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
) match {
case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
}
// Unary operators // Unary operators
case Negate(x) => case ast.Negate(x) =>
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) microWacc.UnaryOp(
KnownType.Int.satisfies(constraint, x.pos) checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")),
case Not(x) => microWacc.UnaryOperator.Negate
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) )(KnownType.Int.satisfies(constraint, x.pos))
KnownType.Bool.satisfies(constraint, x.pos) case ast.Not(x) =>
case Len(x) => microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")),
KnownType.Int.satisfies(constraint, x.pos) microWacc.UnaryOperator.Not
case Ord(x) => )(KnownType.Bool.satisfies(constraint, x.pos))
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) case ast.Len(x) =>
KnownType.Int.satisfies(constraint, x.pos) microWacc.UnaryOp(
case Chr(x) => checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")),
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) microWacc.UnaryOperator.Len
KnownType.Char.satisfies(constraint, x.pos) )(KnownType.Int.satisfies(constraint, x.pos))
case ast.Ord(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")),
microWacc.UnaryOperator.Ord
)(KnownType.Int.satisfies(constraint, x.pos))
case ast.Chr(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")),
microWacc.UnaryOperator.Chr
)(KnownType.Char.satisfies(constraint, x.pos))
// Binary operators // Binary operators
case op: (Add | Sub | Mul | Div | Mod) => case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) =>
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
checkValue(op.x, operand) microWacc.BinaryOp(
checkValue(op.y, operand) checkValue(op.x, operand),
KnownType.Int.satisfies(constraint, op.pos) checkValue(op.y, operand),
case op: (Eq | Neq) => microWacc.BinaryOperator.fromAst(op)
val xTy = checkValue(op.x, Constraint.Unconstrained) )(KnownType.Int.satisfies(constraint, op.pos))
case op: (ast.Eq | ast.Neq) =>
val xTyped = checkValue(op.x, Constraint.Unconstrained)
microWacc.BinaryOp(
xTyped,
checkValue( checkValue(
op.y, op.y,
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") Constraint
) .Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
KnownType.Bool.satisfies(constraint, op.pos) ),
case op: (Less | LessEq | Greater | GreaterEq) => microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) =>
val xConstraint = Constraint.IsEither( val xConstraint = Constraint.IsEither(
KnownType.Int, KnownType.Int,
KnownType.Char, KnownType.Char,
s"${op.name} operator must be applied to an int or char" s"${op.name} operator must be applied to an int or char"
) )
val xTyped = checkValue(op.x, xConstraint)
// If x type-check failed, we still want to check y is an Int or Char (rather than ?) // If x type-check failed, we still want to check y is an Int or Char (rather than ?)
val yConstraint = checkValue(op.x, xConstraint) match { val yConstraint = xTyped.ty match {
case ? => xConstraint case ? => xConstraint
case xTy => case xTy =>
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
} }
checkValue(op.y, yConstraint) microWacc.BinaryOp(
KnownType.Bool.satisfies(constraint, op.pos) xTyped,
case op: (And | Or) => checkValue(op.y, yConstraint),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
case op: (ast.And | ast.Or) =>
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
checkValue(op.x, operand) microWacc.BinaryOp(
checkValue(op.y, operand) checkValue(op.x, operand),
KnownType.Bool.satisfies(constraint, op.pos) checkValue(op.y, operand),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
case lvalue: ast.LValue => checkLValue(lvalue, constraint)
}
/** Type-check an AST LValue node. Separate because microWacc keeps LValues
*
* @param value
* The value to type-check.
* @param constraint
* The type constraint that the value must satisfy.
* @param ctx
* The type checker context which includes the global names and functions, and an errors
* builder.
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkLValue(value: ast.LValue, constraint: Constraint)(using
ctx: TypeCheckerCtx
): microWacc.LValue = value match {
case id @ ast.Ident(name, uid) =>
microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos))
case ast.ArrayElem(id, indices) =>
val arrayTy = ctx.typeOf(id)
val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) =>
val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
val next = acc match {
case KnownType.Array(innerTy) => innerTy
case ? => ? // we can keep indexing an unknown type
case nonArrayTy =>
ctx.error(
Error.TypeMismatch(
elem.pos,
KnownType.Array(?),
acc,
"cannot index into a non-array"
)
)
?
}
(next, idxTyped)
}
microWacc.ArrayElem(
microWacc.Ident(id.v, id.uid)(arrayTy),
indicesTyped
)(elemTy.satisfies(constraint, value.pos))
case ast.Fst(elem) =>
val elemTyped = checkLValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
)
microWacc.ArrayElem(
elemTyped,
NonEmptyList.of(microWacc.IntLiter(0))
)(elemTyped.ty match {
case KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
})
case ast.Snd(elem) =>
val elemTyped = checkLValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
)
microWacc.ArrayElem(
elemTyped,
NonEmptyList.of(microWacc.IntLiter(1))
)(elemTyped.ty match {
case KnownType.Pair(_, right) =>
right.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
})
} }
} }

View File

@ -36,12 +36,22 @@ def compile(contents: String): Int = {
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
val (names, funcs) = renamer.rename(prog) val (names, funcs) = renamer.rename(prog)
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
typeChecker.check(prog) val typedProg = typeChecker.check(prog)
if (errors.result.nonEmpty) { if (errors.result.nonEmpty) {
given errorContent: String = contents given errorContent: String = contents
errors.result.foreach(printError) errors.result
200 .map { error =>
} else 0 printError(error)
error match {
case _: Error.InternalError => 201
case _ => 200
}
}
.max()
} else {
println(typedProg)
0
}
case Failure(msg) => case Failure(msg) =>
println(msg) println(msg)
100 100