feat: type checker without satisfies implemented

Co-Authored-By: jt2622
This commit is contained in:
Gleb Koval 2025-02-06 20:26:15 +00:00
parent 88ec08a023
commit 6548d895d5
Signed by: cyclane
GPG Key ID: 15E168A8B332382C
4 changed files with 277 additions and 121 deletions

View File

@ -1,14 +1,13 @@
package wacc package wacc
import wacc.ast.Expr import wacc.ast.Position
import wacc.types._
enum Error { enum Error {
case DuplicateDeclaration(ident: ast.Ident) case DuplicateDeclaration(ident: ast.Ident)
case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType)
case FunctionParamsMismatch(expected: Int, got: Int) case FunctionParamsMismatch(expected: Int, got: Int) // TODO not fine
case TypeMismatch(expected: types.SemType, got: types.SemType)
case InvalidArrayAccess(ty: types.SemType) case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
case InvalidPairAccess(ty: types.SemType) case InternalError(pos: Position, msg: String)
case ReturnTypeMismatch(expected: types.SemType, got: types.SemType)
case NonBooleanCondition(expr: Expr)
} }

View File

@ -9,7 +9,9 @@ import cats.data.NonEmptyList
object ast { object ast {
// Expressions // Expressions
sealed trait Expr extends RValue sealed trait Expr extends RValue {
val pos: Position
}
sealed trait Expr1 extends Expr sealed trait Expr1 extends Expr
sealed trait Expr2 extends Expr1 sealed trait Expr2 extends Expr1
sealed trait Expr3 extends Expr2 sealed trait Expr3 extends Expr2
@ -18,43 +20,43 @@ object ast {
sealed trait Expr6 extends Expr5 sealed trait Expr6 extends Expr5
// Atoms // Atoms
case class IntLiter(v: Int)(pos: Position) extends Expr6 case class IntLiter(v: Int)(val pos: Position) extends Expr6
object IntLiter extends ParserBridgePos1[Int, IntLiter] object IntLiter extends ParserBridgePos1[Int, IntLiter]
case class BoolLiter(v: Boolean)(pos: Position) extends Expr6 case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6
object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter]
case class CharLiter(v: Char)(pos: Position) extends Expr6 case class CharLiter(v: Char)(val pos: Position) extends Expr6
object CharLiter extends ParserBridgePos1[Char, CharLiter] object CharLiter extends ParserBridgePos1[Char, CharLiter]
case class StrLiter(v: String)(pos: Position) extends Expr6 case class StrLiter(v: String)(val pos: Position) extends Expr6
object StrLiter extends ParserBridgePos1[String, StrLiter] object StrLiter extends ParserBridgePos1[String, StrLiter]
case class PairLiter()(pos: Position) extends Expr6 case class PairLiter()(val pos: Position) extends Expr6
object PairLiter extends ParserBridgePos0[PairLiter] object PairLiter extends ParserBridgePos0[PairLiter]
case class Ident(v: String, var uid: Int = -1)(pos: Position) extends Expr6 with LValue case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue
object Ident extends ParserBridgePos1[String, Ident] { object Ident extends ParserBridgePos1[String, Ident] {
def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) def apply(v: String)(pos: Position): Ident = new Ident(v)(pos)
} }
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position)
extends Expr6 extends Expr6
with LValue with LValue
object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
name => ArrayElem(name, a)(pos) name => ArrayElem(name, a)(pos)
} }
case class Parens(expr: Expr)(pos: Position) extends Expr6 case class Parens(expr: Expr)(val pos: Position) extends Expr6
object Parens extends ParserBridgePos1[Expr, Parens] object Parens extends ParserBridgePos1[Expr, Parens]
// Unary operators // Unary operators
sealed trait UnaryOp extends Expr { sealed trait UnaryOp extends Expr {
val x: Expr val x: Expr
} }
case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp case class Negate(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Negate extends ParserBridgePos1[Expr6, Negate] object Negate extends ParserBridgePos1[Expr6, Negate]
case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp case class Not(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Not extends ParserBridgePos1[Expr6, Not] object Not extends ParserBridgePos1[Expr6, Not]
case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp case class Len(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Len extends ParserBridgePos1[Expr6, Len] object Len extends ParserBridgePos1[Expr6, Len]
case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp case class Ord(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Ord extends ParserBridgePos1[Expr6, Ord] object Ord extends ParserBridgePos1[Expr6, Ord]
case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp case class Chr(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Chr extends ParserBridgePos1[Expr6, Chr] object Chr extends ParserBridgePos1[Expr6, Chr]
// Binary operators // Binary operators
@ -62,59 +64,59 @@ object ast {
val x: Expr val x: Expr
val y: Expr val y: Expr
} }
case class Add(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp case class Add(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp
object Add extends ParserBridgePos2[Expr4, Expr5, Add] object Add extends ParserBridgePos2[Expr4, Expr5, Add]
case class Sub(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp case class Sub(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp
object Sub extends ParserBridgePos2[Expr4, Expr5, Sub] object Sub extends ParserBridgePos2[Expr4, Expr5, Sub]
case class Mul(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp case class Mul(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp
object Mul extends ParserBridgePos2[Expr5, Expr6, Mul] object Mul extends ParserBridgePos2[Expr5, Expr6, Mul]
case class Div(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp case class Div(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp
object Div extends ParserBridgePos2[Expr5, Expr6, Div] object Div extends ParserBridgePos2[Expr5, Expr6, Div]
case class Mod(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp case class Mod(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp
object Mod extends ParserBridgePos2[Expr5, Expr6, Mod] object Mod extends ParserBridgePos2[Expr5, Expr6, Mod]
case class Greater(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp case class Greater(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp
object Greater extends ParserBridgePos2[Expr4, Expr4, Greater] object Greater extends ParserBridgePos2[Expr4, Expr4, Greater]
case class GreaterEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp case class GreaterEq(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp
object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq] object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq]
case class Less(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp case class Less(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp
object Less extends ParserBridgePos2[Expr4, Expr4, Less] object Less extends ParserBridgePos2[Expr4, Expr4, Less]
case class LessEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp case class LessEq(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp
object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq] object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq]
case class Eq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp case class Eq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp
object Eq extends ParserBridgePos2[Expr3, Expr3, Eq] object Eq extends ParserBridgePos2[Expr3, Expr3, Eq]
case class Neq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp case class Neq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp
object Neq extends ParserBridgePos2[Expr3, Expr3, Neq] object Neq extends ParserBridgePos2[Expr3, Expr3, Neq]
case class And(x: Expr2, y: Expr1)(pos: Position) extends Expr1 with BinaryOp case class And(x: Expr2, y: Expr1)(val pos: Position) extends Expr1 with BinaryOp
object And extends ParserBridgePos2[Expr2, Expr1, And] object And extends ParserBridgePos2[Expr2, Expr1, And]
case class Or(x: Expr1, y: Expr)(pos: Position) extends Expr with BinaryOp case class Or(x: Expr1, y: Expr)(val pos: Position) extends Expr with BinaryOp
object Or extends ParserBridgePos2[Expr1, Expr, Or] object Or extends ParserBridgePos2[Expr1, Expr, Or]
// Types // Types
sealed trait Type sealed trait Type
sealed trait BaseType extends Type with PairElemType sealed trait BaseType extends Type with PairElemType
case class IntType()(pos: Position) extends BaseType case class IntType()(val pos: Position) extends BaseType
object IntType extends ParserBridgePos0[IntType] object IntType extends ParserBridgePos0[IntType]
case class BoolType()(pos: Position) extends BaseType case class BoolType()(val pos: Position) extends BaseType
object BoolType extends ParserBridgePos0[BoolType] object BoolType extends ParserBridgePos0[BoolType]
case class CharType()(pos: Position) extends BaseType case class CharType()(val pos: Position) extends BaseType
object CharType extends ParserBridgePos0[CharType] object CharType extends ParserBridgePos0[CharType]
case class StringType()(pos: Position) extends BaseType case class StringType()(val pos: Position) extends BaseType
object StringType extends ParserBridgePos0[StringType] object StringType extends ParserBridgePos0[StringType]
case class ArrayType(elemType: Type, dimensions: Int)(pos: Position) case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position)
extends Type extends Type
with PairElemType with PairElemType
object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
} }
case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type
object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
sealed trait PairElemType sealed trait PairElemType
case class UntypedPairType()(pos: Position) extends PairElemType case class UntypedPairType()(val pos: Position) extends PairElemType
object UntypedPairType extends ParserBridgePos0[UntypedPairType] object UntypedPairType extends ParserBridgePos0[UntypedPairType]
// waccadoodledo // waccadoodledo
case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(pos: Position) case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position)
object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program]
// Function Definitions // Function Definitions
@ -123,7 +125,7 @@ object ast {
name: Ident, name: Ident,
params: List[Param], params: List[Param],
body: NonEmptyList[Stmt] body: NonEmptyList[Stmt]
)(pos: Position) )(val pos: Position)
object FuncDecl object FuncDecl
extends ParserBridgePos2[ extends ParserBridgePos2[
List[Param], List[Param],
@ -136,50 +138,52 @@ object ast {
(returnType, name) => FuncDecl(returnType, name, params, body)(pos) (returnType, name) => FuncDecl(returnType, name, params, body)(pos)
} }
case class Param(paramType: Type, name: Ident)(pos: Position) case class Param(paramType: Type, name: Ident)(val pos: Position)
object Param extends ParserBridgePos2[Type, Ident, Param] object Param extends ParserBridgePos2[Type, Ident, Param]
// Statements // Statements
sealed trait Stmt sealed trait Stmt
case class Skip()(pos: Position) extends Stmt case class Skip()(val pos: Position) extends Stmt
object Skip extends ParserBridgePos0[Skip] object Skip extends ParserBridgePos0[Skip]
case class VarDecl(varType: Type, name: Ident, value: RValue)(pos: Position) extends Stmt case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt
object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl] object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl]
case class Assign(lhs: LValue, value: RValue)(pos: Position) extends Stmt case class Assign(lhs: LValue, value: RValue)(val pos: Position) extends Stmt
object Assign extends ParserBridgePos2[LValue, RValue, Assign] object Assign extends ParserBridgePos2[LValue, RValue, Assign]
case class Read(lhs: LValue)(pos: Position) extends Stmt case class Read(lhs: LValue)(val pos: Position) extends Stmt
object Read extends ParserBridgePos1[LValue, Read] object Read extends ParserBridgePos1[LValue, Read]
case class Free(expr: Expr)(pos: Position) extends Stmt case class Free(expr: Expr)(val pos: Position) extends Stmt
object Free extends ParserBridgePos1[Expr, Free] object Free extends ParserBridgePos1[Expr, Free]
case class Return(expr: Expr)(pos: Position) extends Stmt case class Return(expr: Expr)(val pos: Position) extends Stmt
object Return extends ParserBridgePos1[Expr, Return] object Return extends ParserBridgePos1[Expr, Return]
case class Exit(expr: Expr)(pos: Position) extends Stmt case class Exit(expr: Expr)(val pos: Position) extends Stmt
object Exit extends ParserBridgePos1[Expr, Exit] object Exit extends ParserBridgePos1[Expr, Exit]
case class Print(expr: Expr, newline: Boolean)(pos: Position) extends Stmt case class Print(expr: Expr, newline: Boolean)(val pos: Position) extends Stmt
object Print extends ParserBridgePos2[Expr, Boolean, Print] object Print extends ParserBridgePos2[Expr, Boolean, Print]
case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])( case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])(
pos: Position val pos: Position
) extends Stmt ) extends Stmt
object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If]
case class While(cond: Expr, body: NonEmptyList[Stmt])(pos: Position) extends Stmt case class While(cond: Expr, body: NonEmptyList[Stmt])(val pos: Position) extends Stmt
object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While] object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While]
case class Block(stmt: NonEmptyList[Stmt])(pos: Position) extends Stmt case class Block(stmt: NonEmptyList[Stmt])(val pos: Position) extends Stmt
object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block] object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block]
sealed trait LValue sealed trait LValue {
val pos: Position
}
sealed trait RValue sealed trait RValue
case class ArrayLiter(elems: List[Expr])(pos: Position) extends RValue case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue
object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter] object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter]
case class NewPair(fst: Expr, snd: Expr)(pos: Position) extends RValue case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue
object NewPair extends ParserBridgePos2[Expr, Expr, NewPair] object NewPair extends ParserBridgePos2[Expr, Expr, NewPair]
case class Call(name: Ident, args: List[Expr])(pos: Position) extends RValue case class Call(name: Ident, args: List[Expr])(val pos: Position) extends RValue
object Call extends ParserBridgePos2[Ident, List[Expr], Call] object Call extends ParserBridgePos2[Ident, List[Expr], Call]
sealed trait PairElem extends LValue with RValue sealed trait PairElem extends LValue with RValue
case class Fst(elem: LValue)(pos: Position) extends PairElem case class Fst(elem: LValue)(val pos: Position) extends PairElem
object Fst extends ParserBridgePos1[LValue, Fst] object Fst extends ParserBridgePos1[LValue, Fst]
case class Snd(elem: LValue)(pos: Position) extends PairElem case class Snd(elem: LValue)(val pos: Position) extends PairElem
object Snd extends ParserBridgePos1[LValue, Snd] object Snd extends ParserBridgePos1[LValue, Snd]
// Parser bridges // Parser bridges

View File

@ -1,73 +1,214 @@
package wacc package wacc
import cats.data.{Validated, ValidatedNel}
import cats.implicits.*
import wacc.ast.*
import wacc.types.*
import wacc.Error.*
import wacc.renamer.IdentType
import scala.collection.mutable import scala.collection.mutable
case class TypeCheckerCtx(globalNames: Map[Ident, SemType])
object typeChecker { object typeChecker {
import wacc.ast._
import wacc.types._
def checkExpr(expr: Expr)(using ctx: TypeCheckerCtx): ValidatedNel[Error, KnownType] = expr match case class TypeCheckerCtx(
case IntLiter(_) => KnownType.Int.validNel globalNames: Map[Ident, SemType],
case BoolLiter(_) => KnownType.Bool.validNel errors: mutable.Builder[Error, List[Error]]
case CharLiter(_) => KnownType.Char.validNel ) {
case StrLiter(_) => KnownType.String.validNel def typeOf(ident: Ident): SemType = globalNames.withDefault { case Ident(_, -1) => ? }(ident)
case id @ Ident(_, _) =>
ctx.globalNames
.get(id)
.toValidNel(Error.UndefinedIdentifier(id, IdentType.Var))
.andThen {
case k: KnownType => Validated.validNel(k)
case _ =>
Validated.invalidNel(
Error.TypeMismatch(KnownType.Int, ?)
) // insert some shenanigans here
}
case Add(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int)
case Sub(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int)
case Mul(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int)
case Div(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int)
case Mod(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int)
case Eq(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool, allowWeakening = true)
case Neq(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool, allowWeakening = true)
case And(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool)
case Or(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool)
case _ => Error.TypeMismatch(KnownType.Int, KnownType.Bool).invalidNel
private def checkBinaryOp( def error(err: Error): SemType =
lhs: Expr, errors += err
rhs: Expr, ?
expected: KnownType, }
allowWeakening: Boolean = false
)(using ctx: TypeCheckerCtx): ValidatedNel[Error, KnownType] = enum Constraint {
(checkExpr(lhs), checkExpr(rhs)).mapN { (lt, rt) => case Unconstrained
if (lt == expected && rt == expected) expected case Is(ty: SemType, msg: String)
else if (allowWeakening && isCompatible(lt, rt)) KnownType.Bool case IsSymmetricCompatible(ty: SemType, msg: String)
else return Error.TypeMismatch(expected, rt).invalidNel case IsUnweakanable(ty: SemType, msg: String)
case IsVar(msg: String)
case IsEither(ty1: SemType, ty2: SemType, msg: String)
case Never(msg: String)
}
extension (ty: SemType)
infix def satisfies(constraint: Constraint): SemType = (ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String
case (
KnownType.String,
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
) =>
KnownType.String
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => ty satisfies Constraint.Is(ty2, msg)
case (ty, Constraint.Is(ty2, msg)) => ty satisfies Constraint.IsUnweakanable(ty2, msg)
} }
def isCompatible(t1: SemType, t2: SemType): Boolean = (t1, t2) match def check(prog: Program)(using
case (KnownType.String, KnownType.Array(KnownType.Char)) => true // char[] can weaken to string ctx: TypeCheckerCtx
case (KnownType.Array(KnownType.Char), KnownType.String) => false // string cannot weaken back ): Unit = {
case _ => t1 == t2 prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
val retType = ctx.typeOf(name)
stmts.toList.foreach(
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
)
}
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
}
def checkProgram(prog: Program): ValidatedNel[Error, Unit] = private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx
): Unit = stmt match {
case VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name)
checkValue(
value,
Constraint.Is(
expectedTy,
s"variable ${name.v} must be assigned a value of type $expectedTy"
)
)
case Assign(lhs, rhs) =>
val lhsTy = checkValue(lhs, Constraint.Unconstrained)
checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))
case Read(lhs) =>
checkValue(
lhs,
Constraint.IsEither(KnownType.Int, KnownType.Char, "read must be int or char")
)
case Free(lhs) =>
checkValue(
lhs,
Constraint.IsEither(
KnownType.Array(?),
KnownType.Pair(?, ?),
"free must be an array or pair"
)
)
case Return(expr) =>
checkValue(expr, returnConstraint)
case Exit(expr) =>
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
case Print(expr, _) =>
// This constraint should never fail, the scope-checker should have caught it already
checkValue(expr, Constraint.IsVar("print value must be a variable"))
case If(cond, thenStmt, elseStmt) =>
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
thenStmt.toList.foreach(checkStmt(_, returnConstraint))
elseStmt.toList.foreach(checkStmt(_, returnConstraint))
case While(cond, body) =>
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
body.toList.foreach(checkStmt(_, returnConstraint))
case Block(body) =>
body.toList.foreach(checkStmt(_, returnConstraint))
case Skip() => ()
}
given mutable.Builder[Error, List[Error]] = List.newBuilder private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx
): SemType = value match {
case IntLiter(_) => KnownType.Int satisfies constraint
case BoolLiter(_) => KnownType.Bool satisfies constraint
case CharLiter(_) => KnownType.Char satisfies constraint
case StrLiter(_) => KnownType.String satisfies constraint
case PairLiter() => KnownType.Pair(?, ?) satisfies constraint
case id: Ident =>
ctx.typeOf(id) satisfies constraint
case ArrayElem(id, indices) =>
val arrayTy = ctx.typeOf(id)
val elemTy = indices.toList.foldRight(arrayTy) { (elem, acc) =>
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
acc match {
case KnownType.Array(innerTy) => innerTy
case _ =>
ctx.error(
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
)
}
}
elemTy satisfies constraint
case Parens(expr) => checkValue(expr, constraint)
case ArrayLiter(elems) =>
KnownType.Array(elems.foldRight[SemType](?) { case (elem, acc) =>
checkValue(
elem,
Constraint.IsSymmetricCompatible(acc, "array elements must have the same type")
)
}) satisfies constraint
case NewPair(fst, snd) =>
KnownType.Pair(
checkValue(fst, Constraint.Unconstrained),
checkValue(snd, Constraint.Unconstrained)
) satisfies constraint
case Call(id, args) =>
val funcTy = ctx.typeOf(id)
funcTy match {
case KnownType.Func(retTy, paramTys) => // TODO do we check argument lengths match
args.zip(paramTys).foreach { case (arg, paramTy) =>
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
}
retTy satisfies constraint
// Should never happen, the scope-checker should have caught this already
// ctx error had it not
case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function"))
}
case Fst(elem) =>
checkValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
) match {
case KnownType.Pair(left, _) => left satisfies constraint
case ? => ? satisfies constraint
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
} // satisfies constraint
case Snd(elem) =>
checkValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
) match {
case KnownType.Pair(_, right) => right satisfies constraint
case ? => ? satisfies constraint
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
}
val globalNames = renamer.rename(prog) // Unary operators
case Negate(x) =>
given ctx: TypeCheckerCtx = TypeCheckerCtx(globalNames) checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
KnownType.Int satisfies constraint
// TODO not implemented case Not(x) =>
val funcCheck = prog.funcs.parTraverse(checkFuncDecl) checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
val mainCheck = prog.main.toList.parTraverse(checkStmt) KnownType.Bool satisfies constraint
(funcCheck, mainCheck).mapN((_, _) => ()) case Len(x) =>
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
KnownType.Int satisfies constraint
case Ord(x) =>
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
KnownType.Int satisfies constraint
case Chr(x) =>
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
KnownType.Char satisfies constraint
// Binary operators
case op: (Add | Sub | Mul | Div | Mod) =>
val operand = Constraint.Is(KnownType.Int, "binary operator must be applied to an int")
checkValue(op.x, operand)
checkValue(op.y, operand)
KnownType.Int satisfies constraint
case op: (Eq | Neq) =>
val xTy = checkValue(op.x, Constraint.Unconstrained)
checkValue(op.y, Constraint.Is(xTy, "equality must be applied to values of the same type"))
KnownType.Bool satisfies constraint
case op: (Less | LessEq | Greater | GreaterEq) =>
val xTy = checkValue(
op.x,
Constraint.IsEither(
KnownType.Int,
KnownType.Char,
"comparison must be applied to an int or char"
)
)
checkValue(op.y, Constraint.Is(xTy, "comparison must be applied to values of the same type"))
KnownType.Bool satisfies constraint
case op: (And | Or) =>
val operand = Constraint.Is(KnownType.Bool, "logical operator must be applied to a bool")
checkValue(op.x, operand)
checkValue(op.y, operand)
KnownType.Bool satisfies constraint
}
} }

View File

@ -3,7 +3,19 @@ package wacc
object types { object types {
import ast._ import ast._
sealed trait SemType sealed trait SemType {
override def toString(): String = this match {
case KnownType.Int => "int"
case KnownType.Bool => "bool"
case KnownType.Char => "char"
case KnownType.String => "string"
case KnownType.Array(elem) => s"$elem[]"
case KnownType.Pair(left, right) => s"pair($left, $right)"
case KnownType.Func(ret, params) => s"function returning $ret with params $params"
case ? => "?"
}
}
case object ? extends SemType case object ? extends SemType
enum KnownType extends SemType { enum KnownType extends SemType {
case Int case Int