diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala index 08b6e3b..cc019d3 100644 --- a/src/main/wacc/Error.scala +++ b/src/main/wacc/Error.scala @@ -1,14 +1,13 @@ package wacc -import wacc.ast.Expr +import wacc.ast.Position +import wacc.types._ enum Error { case DuplicateDeclaration(ident: ast.Ident) case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) - case FunctionParamsMismatch(expected: Int, got: Int) - case TypeMismatch(expected: types.SemType, got: types.SemType) - case InvalidArrayAccess(ty: types.SemType) - case InvalidPairAccess(ty: types.SemType) - case ReturnTypeMismatch(expected: types.SemType, got: types.SemType) - case NonBooleanCondition(expr: Expr) + case FunctionParamsMismatch(expected: Int, got: Int) // TODO not fine + + case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) + case InternalError(pos: Position, msg: String) } diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 123adb5..8ea99d7 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -9,7 +9,9 @@ import cats.data.NonEmptyList object ast { // Expressions - sealed trait Expr extends RValue + sealed trait Expr extends RValue { + val pos: Position + } sealed trait Expr1 extends Expr sealed trait Expr2 extends Expr1 sealed trait Expr3 extends Expr2 @@ -18,43 +20,43 @@ object ast { sealed trait Expr6 extends Expr5 // Atoms - case class IntLiter(v: Int)(pos: Position) extends Expr6 + case class IntLiter(v: Int)(val pos: Position) extends Expr6 object IntLiter extends ParserBridgePos1[Int, IntLiter] - case class BoolLiter(v: Boolean)(pos: Position) extends Expr6 + case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6 object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] - case class CharLiter(v: Char)(pos: Position) extends Expr6 + case class CharLiter(v: Char)(val pos: Position) extends Expr6 object CharLiter extends ParserBridgePos1[Char, CharLiter] - case class StrLiter(v: String)(pos: Position) extends Expr6 + case class StrLiter(v: String)(val pos: Position) extends Expr6 object StrLiter extends ParserBridgePos1[String, StrLiter] - case class PairLiter()(pos: Position) extends Expr6 + case class PairLiter()(val pos: Position) extends Expr6 object PairLiter extends ParserBridgePos0[PairLiter] - case class Ident(v: String, var uid: Int = -1)(pos: Position) extends Expr6 with LValue + case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue object Ident extends ParserBridgePos1[String, Ident] { def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) } - case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) + case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position) extends Expr6 with LValue object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = name => ArrayElem(name, a)(pos) } - case class Parens(expr: Expr)(pos: Position) extends Expr6 + case class Parens(expr: Expr)(val pos: Position) extends Expr6 object Parens extends ParserBridgePos1[Expr, Parens] // Unary operators sealed trait UnaryOp extends Expr { val x: Expr } - case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Negate(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Negate extends ParserBridgePos1[Expr6, Negate] - case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Not(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Not extends ParserBridgePos1[Expr6, Not] - case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Len(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Len extends ParserBridgePos1[Expr6, Len] - case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Ord(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Ord extends ParserBridgePos1[Expr6, Ord] - case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Chr(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators @@ -62,59 +64,59 @@ object ast { val x: Expr val y: Expr } - case class Add(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + case class Add(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp object Add extends ParserBridgePos2[Expr4, Expr5, Add] - case class Sub(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + case class Sub(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp object Sub extends ParserBridgePos2[Expr4, Expr5, Sub] - case class Mul(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + case class Mul(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp object Mul extends ParserBridgePos2[Expr5, Expr6, Mul] - case class Div(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + case class Div(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp object Div extends ParserBridgePos2[Expr5, Expr6, Div] - case class Mod(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + case class Mod(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp object Mod extends ParserBridgePos2[Expr5, Expr6, Mod] - case class Greater(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class Greater(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object Greater extends ParserBridgePos2[Expr4, Expr4, Greater] - case class GreaterEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class GreaterEq(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq] - case class Less(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class Less(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object Less extends ParserBridgePos2[Expr4, Expr4, Less] - case class LessEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class LessEq(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq] - case class Eq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + case class Eq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp object Eq extends ParserBridgePos2[Expr3, Expr3, Eq] - case class Neq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + case class Neq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp object Neq extends ParserBridgePos2[Expr3, Expr3, Neq] - case class And(x: Expr2, y: Expr1)(pos: Position) extends Expr1 with BinaryOp + case class And(x: Expr2, y: Expr1)(val pos: Position) extends Expr1 with BinaryOp object And extends ParserBridgePos2[Expr2, Expr1, And] - case class Or(x: Expr1, y: Expr)(pos: Position) extends Expr with BinaryOp + case class Or(x: Expr1, y: Expr)(val pos: Position) extends Expr with BinaryOp object Or extends ParserBridgePos2[Expr1, Expr, Or] // Types sealed trait Type sealed trait BaseType extends Type with PairElemType - case class IntType()(pos: Position) extends BaseType + case class IntType()(val pos: Position) extends BaseType object IntType extends ParserBridgePos0[IntType] - case class BoolType()(pos: Position) extends BaseType + case class BoolType()(val pos: Position) extends BaseType object BoolType extends ParserBridgePos0[BoolType] - case class CharType()(pos: Position) extends BaseType + case class CharType()(val pos: Position) extends BaseType object CharType extends ParserBridgePos0[CharType] - case class StringType()(pos: Position) extends BaseType + case class StringType()(val pos: Position) extends BaseType object StringType extends ParserBridgePos0[StringType] - case class ArrayType(elemType: Type, dimensions: Int)(pos: Position) + case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position) extends Type with PairElemType object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) } - case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type + case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] sealed trait PairElemType - case class UntypedPairType()(pos: Position) extends PairElemType + case class UntypedPairType()(val pos: Position) extends PairElemType object UntypedPairType extends ParserBridgePos0[UntypedPairType] // waccadoodledo - case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(pos: Position) + case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position) object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] // Function Definitions @@ -123,7 +125,7 @@ object ast { name: Ident, params: List[Param], body: NonEmptyList[Stmt] - )(pos: Position) + )(val pos: Position) object FuncDecl extends ParserBridgePos2[ List[Param], @@ -136,50 +138,52 @@ object ast { (returnType, name) => FuncDecl(returnType, name, params, body)(pos) } - case class Param(paramType: Type, name: Ident)(pos: Position) + case class Param(paramType: Type, name: Ident)(val pos: Position) object Param extends ParserBridgePos2[Type, Ident, Param] // Statements sealed trait Stmt - case class Skip()(pos: Position) extends Stmt + case class Skip()(val pos: Position) extends Stmt object Skip extends ParserBridgePos0[Skip] - case class VarDecl(varType: Type, name: Ident, value: RValue)(pos: Position) extends Stmt + case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl] - case class Assign(lhs: LValue, value: RValue)(pos: Position) extends Stmt + case class Assign(lhs: LValue, value: RValue)(val pos: Position) extends Stmt object Assign extends ParserBridgePos2[LValue, RValue, Assign] - case class Read(lhs: LValue)(pos: Position) extends Stmt + case class Read(lhs: LValue)(val pos: Position) extends Stmt object Read extends ParserBridgePos1[LValue, Read] - case class Free(expr: Expr)(pos: Position) extends Stmt + case class Free(expr: Expr)(val pos: Position) extends Stmt object Free extends ParserBridgePos1[Expr, Free] - case class Return(expr: Expr)(pos: Position) extends Stmt + case class Return(expr: Expr)(val pos: Position) extends Stmt object Return extends ParserBridgePos1[Expr, Return] - case class Exit(expr: Expr)(pos: Position) extends Stmt + case class Exit(expr: Expr)(val pos: Position) extends Stmt object Exit extends ParserBridgePos1[Expr, Exit] - case class Print(expr: Expr, newline: Boolean)(pos: Position) extends Stmt + case class Print(expr: Expr, newline: Boolean)(val pos: Position) extends Stmt object Print extends ParserBridgePos2[Expr, Boolean, Print] case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])( - pos: Position + val pos: Position ) extends Stmt object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] - case class While(cond: Expr, body: NonEmptyList[Stmt])(pos: Position) extends Stmt + case class While(cond: Expr, body: NonEmptyList[Stmt])(val pos: Position) extends Stmt object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While] - case class Block(stmt: NonEmptyList[Stmt])(pos: Position) extends Stmt + case class Block(stmt: NonEmptyList[Stmt])(val pos: Position) extends Stmt object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block] - sealed trait LValue + sealed trait LValue { + val pos: Position + } sealed trait RValue - case class ArrayLiter(elems: List[Expr])(pos: Position) extends RValue + case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter] - case class NewPair(fst: Expr, snd: Expr)(pos: Position) extends RValue + case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue object NewPair extends ParserBridgePos2[Expr, Expr, NewPair] - case class Call(name: Ident, args: List[Expr])(pos: Position) extends RValue + case class Call(name: Ident, args: List[Expr])(val pos: Position) extends RValue object Call extends ParserBridgePos2[Ident, List[Expr], Call] sealed trait PairElem extends LValue with RValue - case class Fst(elem: LValue)(pos: Position) extends PairElem + case class Fst(elem: LValue)(val pos: Position) extends PairElem object Fst extends ParserBridgePos1[LValue, Fst] - case class Snd(elem: LValue)(pos: Position) extends PairElem + case class Snd(elem: LValue)(val pos: Position) extends PairElem object Snd extends ParserBridgePos1[LValue, Snd] // Parser bridges diff --git a/src/main/wacc/typeChecker.scala b/src/main/wacc/typeChecker.scala index f08087b..8ad65a6 100644 --- a/src/main/wacc/typeChecker.scala +++ b/src/main/wacc/typeChecker.scala @@ -1,73 +1,214 @@ package wacc -import cats.data.{Validated, ValidatedNel} -import cats.implicits.* -import wacc.ast.* -import wacc.types.* -import wacc.Error.* -import wacc.renamer.IdentType - import scala.collection.mutable -case class TypeCheckerCtx(globalNames: Map[Ident, SemType]) - object typeChecker { + import wacc.ast._ + import wacc.types._ - def checkExpr(expr: Expr)(using ctx: TypeCheckerCtx): ValidatedNel[Error, KnownType] = expr match - case IntLiter(_) => KnownType.Int.validNel - case BoolLiter(_) => KnownType.Bool.validNel - case CharLiter(_) => KnownType.Char.validNel - case StrLiter(_) => KnownType.String.validNel - case id @ Ident(_, _) => - ctx.globalNames - .get(id) - .toValidNel(Error.UndefinedIdentifier(id, IdentType.Var)) - .andThen { - case k: KnownType => Validated.validNel(k) - case _ => - Validated.invalidNel( - Error.TypeMismatch(KnownType.Int, ?) - ) // insert some shenanigans here - } - case Add(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) - case Sub(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) - case Mul(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) - case Div(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) - case Mod(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Int) - case Eq(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool, allowWeakening = true) - case Neq(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool, allowWeakening = true) - case And(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool) - case Or(lhs, rhs) => checkBinaryOp(lhs, rhs, KnownType.Bool) - case _ => Error.TypeMismatch(KnownType.Int, KnownType.Bool).invalidNel + case class TypeCheckerCtx( + globalNames: Map[Ident, SemType], + errors: mutable.Builder[Error, List[Error]] + ) { + def typeOf(ident: Ident): SemType = globalNames.withDefault { case Ident(_, -1) => ? }(ident) - private def checkBinaryOp( - lhs: Expr, - rhs: Expr, - expected: KnownType, - allowWeakening: Boolean = false - )(using ctx: TypeCheckerCtx): ValidatedNel[Error, KnownType] = - (checkExpr(lhs), checkExpr(rhs)).mapN { (lt, rt) => - if (lt == expected && rt == expected) expected - else if (allowWeakening && isCompatible(lt, rt)) KnownType.Bool - else return Error.TypeMismatch(expected, rt).invalidNel + def error(err: Error): SemType = + errors += err + ? + } + + enum Constraint { + case Unconstrained + case Is(ty: SemType, msg: String) + case IsSymmetricCompatible(ty: SemType, msg: String) + case IsUnweakanable(ty: SemType, msg: String) + case IsVar(msg: String) + case IsEither(ty1: SemType, ty2: SemType, msg: String) + case Never(msg: String) + } + + extension (ty: SemType) + infix def satisfies(constraint: Constraint): SemType = (ty, constraint) match { + case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => + KnownType.String + case ( + KnownType.String, + Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _) + ) => + KnownType.String + case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => ty satisfies Constraint.Is(ty2, msg) + case (ty, Constraint.Is(ty2, msg)) => ty satisfies Constraint.IsUnweakanable(ty2, msg) } - def isCompatible(t1: SemType, t2: SemType): Boolean = (t1, t2) match - case (KnownType.String, KnownType.Array(KnownType.Char)) => true // char[] can weaken to string - case (KnownType.Array(KnownType.Char), KnownType.String) => false // string cannot weaken back - case _ => t1 == t2 + def check(prog: Program)(using + ctx: TypeCheckerCtx + ): Unit = { + prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => + val retType = ctx.typeOf(name) + stmts.toList.foreach( + checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) + ) + } + prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) + } - def checkProgram(prog: Program): ValidatedNel[Error, Unit] = + private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using + ctx: TypeCheckerCtx + ): Unit = stmt match { + case VarDecl(_, name, value) => + val expectedTy = ctx.typeOf(name) + checkValue( + value, + Constraint.Is( + expectedTy, + s"variable ${name.v} must be assigned a value of type $expectedTy" + ) + ) + case Assign(lhs, rhs) => + val lhsTy = checkValue(lhs, Constraint.Unconstrained) + checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy")) + case Read(lhs) => + checkValue( + lhs, + Constraint.IsEither(KnownType.Int, KnownType.Char, "read must be int or char") + ) + case Free(lhs) => + checkValue( + lhs, + Constraint.IsEither( + KnownType.Array(?), + KnownType.Pair(?, ?), + "free must be an array or pair" + ) + ) + case Return(expr) => + checkValue(expr, returnConstraint) + case Exit(expr) => + checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) + case Print(expr, _) => + // This constraint should never fail, the scope-checker should have caught it already + checkValue(expr, Constraint.IsVar("print value must be a variable")) + case If(cond, thenStmt, elseStmt) => + checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) + thenStmt.toList.foreach(checkStmt(_, returnConstraint)) + elseStmt.toList.foreach(checkStmt(_, returnConstraint)) + case While(cond, body) => + checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) + body.toList.foreach(checkStmt(_, returnConstraint)) + case Block(body) => + body.toList.foreach(checkStmt(_, returnConstraint)) + case Skip() => () + } - given mutable.Builder[Error, List[Error]] = List.newBuilder + private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using + ctx: TypeCheckerCtx + ): SemType = value match { + case IntLiter(_) => KnownType.Int satisfies constraint + case BoolLiter(_) => KnownType.Bool satisfies constraint + case CharLiter(_) => KnownType.Char satisfies constraint + case StrLiter(_) => KnownType.String satisfies constraint + case PairLiter() => KnownType.Pair(?, ?) satisfies constraint + case id: Ident => + ctx.typeOf(id) satisfies constraint + case ArrayElem(id, indices) => + val arrayTy = ctx.typeOf(id) + val elemTy = indices.toList.foldRight(arrayTy) { (elem, acc) => + checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + acc match { + case KnownType.Array(innerTy) => innerTy + case _ => + ctx.error( + Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") + ) + } + } + elemTy satisfies constraint + case Parens(expr) => checkValue(expr, constraint) + case ArrayLiter(elems) => + KnownType.Array(elems.foldRight[SemType](?) { case (elem, acc) => + checkValue( + elem, + Constraint.IsSymmetricCompatible(acc, "array elements must have the same type") + ) + }) satisfies constraint + case NewPair(fst, snd) => + KnownType.Pair( + checkValue(fst, Constraint.Unconstrained), + checkValue(snd, Constraint.Unconstrained) + ) satisfies constraint + case Call(id, args) => + val funcTy = ctx.typeOf(id) + funcTy match { + case KnownType.Func(retTy, paramTys) => // TODO do we check argument lengths match + args.zip(paramTys).foreach { case (arg, paramTy) => + checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) + } + retTy satisfies constraint + // Should never happen, the scope-checker should have caught this already + // ctx error had it not + case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function")) + } + case Fst(elem) => + checkValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") + ) match { + case KnownType.Pair(left, _) => left satisfies constraint + case ? => ? satisfies constraint + case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) + } // satisfies constraint + case Snd(elem) => + checkValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") + ) match { + case KnownType.Pair(_, right) => right satisfies constraint + case ? => ? satisfies constraint + case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) + } - val globalNames = renamer.rename(prog) - - given ctx: TypeCheckerCtx = TypeCheckerCtx(globalNames) - - // TODO not implemented - val funcCheck = prog.funcs.parTraverse(checkFuncDecl) - val mainCheck = prog.main.toList.parTraverse(checkStmt) - (funcCheck, mainCheck).mapN((_, _) => ()) + // Unary operators + case Negate(x) => + checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) + KnownType.Int satisfies constraint + case Not(x) => + checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) + KnownType.Bool satisfies constraint + case Len(x) => + checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) + KnownType.Int satisfies constraint + case Ord(x) => + checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) + KnownType.Int satisfies constraint + case Chr(x) => + checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) + KnownType.Char satisfies constraint + // Binary operators + case op: (Add | Sub | Mul | Div | Mod) => + val operand = Constraint.Is(KnownType.Int, "binary operator must be applied to an int") + checkValue(op.x, operand) + checkValue(op.y, operand) + KnownType.Int satisfies constraint + case op: (Eq | Neq) => + val xTy = checkValue(op.x, Constraint.Unconstrained) + checkValue(op.y, Constraint.Is(xTy, "equality must be applied to values of the same type")) + KnownType.Bool satisfies constraint + case op: (Less | LessEq | Greater | GreaterEq) => + val xTy = checkValue( + op.x, + Constraint.IsEither( + KnownType.Int, + KnownType.Char, + "comparison must be applied to an int or char" + ) + ) + checkValue(op.y, Constraint.Is(xTy, "comparison must be applied to values of the same type")) + KnownType.Bool satisfies constraint + case op: (And | Or) => + val operand = Constraint.Is(KnownType.Bool, "logical operator must be applied to a bool") + checkValue(op.x, operand) + checkValue(op.y, operand) + KnownType.Bool satisfies constraint + } } diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index e2a7988..ebe8517 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -3,7 +3,19 @@ package wacc object types { import ast._ - sealed trait SemType + sealed trait SemType { + override def toString(): String = this match { + case KnownType.Int => "int" + case KnownType.Bool => "bool" + case KnownType.Char => "char" + case KnownType.String => "string" + case KnownType.Array(elem) => s"$elem[]" + case KnownType.Pair(left, right) => s"pair($left, $right)" + case KnownType.Func(ret, params) => s"function returning $ret with params $params" + case ? => "?" + } + } + case object ? extends SemType enum KnownType extends SemType { case Int