diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala index 6370925..08b6e3b 100644 --- a/src/main/wacc/Error.scala +++ b/src/main/wacc/Error.scala @@ -1,8 +1,14 @@ package wacc +import wacc.ast.Expr + enum Error { case DuplicateDeclaration(ident: ast.Ident) case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) case FunctionParamsMismatch(expected: Int, got: Int) case TypeMismatch(expected: types.SemType, got: types.SemType) + case InvalidArrayAccess(ty: types.SemType) + case InvalidPairAccess(ty: types.SemType) + case ReturnTypeMismatch(expected: types.SemType, got: types.SemType) + case NonBooleanCondition(expr: Expr) } diff --git a/src/main/wacc/typeChecker.scala b/src/main/wacc/typeChecker.scala new file mode 100644 index 0000000..f08087b --- /dev/null +++ b/src/main/wacc/typeChecker.scala @@ -0,0 +1,73 @@ +package wacc + +import cats.data.{Validated, ValidatedNel} +import cats.implicits.* +import wacc.ast.* +import wacc.types.* +import wacc.Error.* +import wacc.renamer.IdentType + +import scala.collection.mutable + +case class TypeCheckerCtx(globalNames: Map[Ident, SemType]) + +object typeChecker { + + def checkExpr(expr: Expr)(using ctx: TypeCheckerCtx): ValidatedNel[Error, KnownType] = expr match + case IntLiter(_) => KnownType.Int.validNel + case BoolLiter(_) => KnownType.Bool.validNel + case CharLiter(_) => KnownType.Char.validNel + case StrLiter(_) => KnownType.String.validNel + case id @ Ident(_, _) => + ctx.globalNames + .get(id) + .toValidNel(Error.UndefinedIdentifier(id, IdentType.Var)) + .andThen { + case k: KnownType => Validated.validNel(k) + case _ => + Validated.invalidNel( + Error.TypeMismatch(KnownType.Int, ?) + ) // insert some shenanigans here + } + case Add(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) + case Sub(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) + case Mul(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) + case Div(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) + case Mod(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) + case Eq(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool, allowWeakening = true) + case Neq(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool, allowWeakening = true) + case And(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool) + case Or(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool) + case _ => Error.TypeMismatch(KnownType.Int, KnownType.Bool).invalidNel + + private def checkBinaryOp( + lhs: Expr, + rhs: Expr, + expected: KnownType, + allowWeakening: Boolean = false + )(using ctx: TypeCheckerCtx): ValidatedNel[Error, KnownType] = + (checkExpr(lhs), checkExpr(rhs)).mapN { (lt, rt) => + if (lt == expected && rt == expected) expected + else if (allowWeakening && isCompatible(lt, rt)) KnownType.Bool + else return Error.TypeMismatch(expected, rt).invalidNel + } + + def isCompatible(t1: SemType, t2: SemType): Boolean = (t1, t2) match + case (KnownType.String, KnownType.Array(KnownType.Char)) => true // char[] can weaken to string + case (KnownType.Array(KnownType.Char), KnownType.String) => false // string cannot weaken back + case _ => t1 == t2 + + def checkProgram(prog: Program): ValidatedNel[Error, Unit] = + + given mutable.Builder[Error, List[Error]] = List.newBuilder + + val globalNames = renamer.rename(prog) + + given ctx: TypeCheckerCtx = TypeCheckerCtx(globalNames) + + // TODO not implemented + val funcCheck = prog.funcs.parTraverse(checkFuncDecl) + val mainCheck = prog.main.toList.parTraverse(checkStmt) + (funcCheck, mainCheck).mapN((_, _) => ()) + +}