diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala index 6370925..c627246 100644 --- a/src/main/wacc/Error.scala +++ b/src/main/wacc/Error.scala @@ -1,8 +1,14 @@ package wacc +import wacc.ast.Position +import wacc.types._ + enum Error { case DuplicateDeclaration(ident: ast.Ident) case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) - case FunctionParamsMismatch(expected: Int, got: Int) - case TypeMismatch(expected: types.SemType, got: types.SemType) + + case FunctionParamsMismatch(pos: Position, expected: Int, got: Int) + case SemanticError(pos: Position, msg: String) + case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) + case InternalError(pos: Position, msg: String) } diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index a271c3c..59796ff 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -32,12 +32,11 @@ val cliParser = { def compile(contents: String): Int = { parser.parse(contents) match { - case Success(ast) => + case Success(prog) => given errors: mutable.Builder[Error, List[Error]] = List.newBuilder - renamer.rename(ast) - // given ctx: types.TypeCheckerCtx[List[Error]] = - // types.TypeCheckerCtx(names, errors) - // types.check(ast) + val names = renamer.rename(prog) + given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, errors) + typeChecker.check(prog) if (errors.result.nonEmpty) { errors.result.foreach(println) 200 diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 123adb5..8ea99d7 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -9,7 +9,9 @@ import cats.data.NonEmptyList object ast { // Expressions - sealed trait Expr extends RValue + sealed trait Expr extends RValue { + val pos: Position + } sealed trait Expr1 extends Expr sealed trait Expr2 extends Expr1 sealed trait Expr3 extends Expr2 @@ -18,43 +20,43 @@ object ast { sealed trait Expr6 extends Expr5 // Atoms - case class IntLiter(v: Int)(pos: Position) extends Expr6 + case class IntLiter(v: Int)(val pos: Position) extends Expr6 object IntLiter extends ParserBridgePos1[Int, IntLiter] - case class BoolLiter(v: Boolean)(pos: Position) extends Expr6 + case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6 object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] - case class CharLiter(v: Char)(pos: Position) extends Expr6 + case class CharLiter(v: Char)(val pos: Position) extends Expr6 object CharLiter extends ParserBridgePos1[Char, CharLiter] - case class StrLiter(v: String)(pos: Position) extends Expr6 + case class StrLiter(v: String)(val pos: Position) extends Expr6 object StrLiter extends ParserBridgePos1[String, StrLiter] - case class PairLiter()(pos: Position) extends Expr6 + case class PairLiter()(val pos: Position) extends Expr6 object PairLiter extends ParserBridgePos0[PairLiter] - case class Ident(v: String, var uid: Int = -1)(pos: Position) extends Expr6 with LValue + case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue object Ident extends ParserBridgePos1[String, Ident] { def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) } - case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) + case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position) extends Expr6 with LValue object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = name => ArrayElem(name, a)(pos) } - case class Parens(expr: Expr)(pos: Position) extends Expr6 + case class Parens(expr: Expr)(val pos: Position) extends Expr6 object Parens extends ParserBridgePos1[Expr, Parens] // Unary operators sealed trait UnaryOp extends Expr { val x: Expr } - case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Negate(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Negate extends ParserBridgePos1[Expr6, Negate] - case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Not(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Not extends ParserBridgePos1[Expr6, Not] - case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Len(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Len extends ParserBridgePos1[Expr6, Len] - case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Ord(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Ord extends ParserBridgePos1[Expr6, Ord] - case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Chr(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators @@ -62,59 +64,59 @@ object ast { val x: Expr val y: Expr } - case class Add(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + case class Add(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp object Add extends ParserBridgePos2[Expr4, Expr5, Add] - case class Sub(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + case class Sub(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp object Sub extends ParserBridgePos2[Expr4, Expr5, Sub] - case class Mul(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + case class Mul(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp object Mul extends ParserBridgePos2[Expr5, Expr6, Mul] - case class Div(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + case class Div(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp object Div extends ParserBridgePos2[Expr5, Expr6, Div] - case class Mod(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + case class Mod(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp object Mod extends ParserBridgePos2[Expr5, Expr6, Mod] - case class Greater(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class Greater(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object Greater extends ParserBridgePos2[Expr4, Expr4, Greater] - case class GreaterEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class GreaterEq(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq] - case class Less(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class Less(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object Less extends ParserBridgePos2[Expr4, Expr4, Less] - case class LessEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + case class LessEq(x: Expr4, y: Expr4)(val pos: Position) extends Expr3 with BinaryOp object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq] - case class Eq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + case class Eq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp object Eq extends ParserBridgePos2[Expr3, Expr3, Eq] - case class Neq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + case class Neq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp object Neq extends ParserBridgePos2[Expr3, Expr3, Neq] - case class And(x: Expr2, y: Expr1)(pos: Position) extends Expr1 with BinaryOp + case class And(x: Expr2, y: Expr1)(val pos: Position) extends Expr1 with BinaryOp object And extends ParserBridgePos2[Expr2, Expr1, And] - case class Or(x: Expr1, y: Expr)(pos: Position) extends Expr with BinaryOp + case class Or(x: Expr1, y: Expr)(val pos: Position) extends Expr with BinaryOp object Or extends ParserBridgePos2[Expr1, Expr, Or] // Types sealed trait Type sealed trait BaseType extends Type with PairElemType - case class IntType()(pos: Position) extends BaseType + case class IntType()(val pos: Position) extends BaseType object IntType extends ParserBridgePos0[IntType] - case class BoolType()(pos: Position) extends BaseType + case class BoolType()(val pos: Position) extends BaseType object BoolType extends ParserBridgePos0[BoolType] - case class CharType()(pos: Position) extends BaseType + case class CharType()(val pos: Position) extends BaseType object CharType extends ParserBridgePos0[CharType] - case class StringType()(pos: Position) extends BaseType + case class StringType()(val pos: Position) extends BaseType object StringType extends ParserBridgePos0[StringType] - case class ArrayType(elemType: Type, dimensions: Int)(pos: Position) + case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position) extends Type with PairElemType object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) } - case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type + case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] sealed trait PairElemType - case class UntypedPairType()(pos: Position) extends PairElemType + case class UntypedPairType()(val pos: Position) extends PairElemType object UntypedPairType extends ParserBridgePos0[UntypedPairType] // waccadoodledo - case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(pos: Position) + case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position) object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] // Function Definitions @@ -123,7 +125,7 @@ object ast { name: Ident, params: List[Param], body: NonEmptyList[Stmt] - )(pos: Position) + )(val pos: Position) object FuncDecl extends ParserBridgePos2[ List[Param], @@ -136,50 +138,52 @@ object ast { (returnType, name) => FuncDecl(returnType, name, params, body)(pos) } - case class Param(paramType: Type, name: Ident)(pos: Position) + case class Param(paramType: Type, name: Ident)(val pos: Position) object Param extends ParserBridgePos2[Type, Ident, Param] // Statements sealed trait Stmt - case class Skip()(pos: Position) extends Stmt + case class Skip()(val pos: Position) extends Stmt object Skip extends ParserBridgePos0[Skip] - case class VarDecl(varType: Type, name: Ident, value: RValue)(pos: Position) extends Stmt + case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl] - case class Assign(lhs: LValue, value: RValue)(pos: Position) extends Stmt + case class Assign(lhs: LValue, value: RValue)(val pos: Position) extends Stmt object Assign extends ParserBridgePos2[LValue, RValue, Assign] - case class Read(lhs: LValue)(pos: Position) extends Stmt + case class Read(lhs: LValue)(val pos: Position) extends Stmt object Read extends ParserBridgePos1[LValue, Read] - case class Free(expr: Expr)(pos: Position) extends Stmt + case class Free(expr: Expr)(val pos: Position) extends Stmt object Free extends ParserBridgePos1[Expr, Free] - case class Return(expr: Expr)(pos: Position) extends Stmt + case class Return(expr: Expr)(val pos: Position) extends Stmt object Return extends ParserBridgePos1[Expr, Return] - case class Exit(expr: Expr)(pos: Position) extends Stmt + case class Exit(expr: Expr)(val pos: Position) extends Stmt object Exit extends ParserBridgePos1[Expr, Exit] - case class Print(expr: Expr, newline: Boolean)(pos: Position) extends Stmt + case class Print(expr: Expr, newline: Boolean)(val pos: Position) extends Stmt object Print extends ParserBridgePos2[Expr, Boolean, Print] case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])( - pos: Position + val pos: Position ) extends Stmt object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] - case class While(cond: Expr, body: NonEmptyList[Stmt])(pos: Position) extends Stmt + case class While(cond: Expr, body: NonEmptyList[Stmt])(val pos: Position) extends Stmt object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While] - case class Block(stmt: NonEmptyList[Stmt])(pos: Position) extends Stmt + case class Block(stmt: NonEmptyList[Stmt])(val pos: Position) extends Stmt object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block] - sealed trait LValue + sealed trait LValue { + val pos: Position + } sealed trait RValue - case class ArrayLiter(elems: List[Expr])(pos: Position) extends RValue + case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter] - case class NewPair(fst: Expr, snd: Expr)(pos: Position) extends RValue + case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue object NewPair extends ParserBridgePos2[Expr, Expr, NewPair] - case class Call(name: Ident, args: List[Expr])(pos: Position) extends RValue + case class Call(name: Ident, args: List[Expr])(val pos: Position) extends RValue object Call extends ParserBridgePos2[Ident, List[Expr], Call] sealed trait PairElem extends LValue with RValue - case class Fst(elem: LValue)(pos: Position) extends PairElem + case class Fst(elem: LValue)(val pos: Position) extends PairElem object Fst extends ParserBridgePos1[LValue, Fst] - case class Snd(elem: LValue)(pos: Position) extends PairElem + case class Snd(elem: LValue)(val pos: Position) extends PairElem object Snd extends ParserBridgePos1[LValue, Snd] // Parser bridges diff --git a/src/main/wacc/typeChecker.scala b/src/main/wacc/typeChecker.scala new file mode 100644 index 0000000..0cfe9d7 --- /dev/null +++ b/src/main/wacc/typeChecker.scala @@ -0,0 +1,273 @@ +package wacc + +import cats.syntax.all._ +import scala.collection.mutable + +object typeChecker { + import wacc.ast._ + import wacc.types._ + + case class TypeCheckerCtx( + globalNames: Map[Ident, SemType], + errors: mutable.Builder[Error, List[Error]] + ) { + def typeOf(ident: Ident): SemType = globalNames.withDefault { case Ident(_, -1) => ? }(ident) + + def error(err: Error): SemType = + errors += err + ? + } + + enum Constraint { + case Unconstrained + case Is(ty: SemType, msg: String) + case IsSymmetricCompatible(ty: SemType, msg: String) + case IsUnweakanable(ty: SemType, msg: String) + case IsVar(msg: String) + case IsEither(ty1: SemType, ty2: SemType, msg: String) + case Never(msg: String) + } + + extension (ty: SemType) { + def satisfies(constraint: Constraint, pos: Position)(using ctx: TypeCheckerCtx): SemType = + (ty, constraint) match { + case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => + KnownType.String + case ( + KnownType.String, + Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _) + ) => + KnownType.String + case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => + ty.satisfies(Constraint.Is(ty2, msg), pos) + case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos) + case (ty, Constraint.Unconstrained) => ty + case (KnownType.Func(_, _), Constraint.IsVar(msg)) => + ctx.error(Error.SemanticError(pos, msg)) + case (ty, Constraint.IsVar(msg)) => ty + case (ty, Constraint.Never(msg)) => + ctx.error(Error.SemanticError(pos, msg)) + case (ty, Constraint.IsEither(ty1, ty2, msg)) => + (ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse { + ctx.error(Error.TypeMismatch(pos, ty1, ty, msg)) + } + case (ty, Constraint.IsUnweakanable(ty2, msg)) => + (ty moreSpecific ty2).getOrElse { + ctx.error(Error.TypeMismatch(pos, ty2, ty, msg)) + } + } + + infix def moreSpecific(ty2: SemType): Option[SemType] = + (ty, ty2) match { + case (ty, ?) => Some(ty) + case (?, ty) => Some(ty) + case (ty1, ty2) if ty1 == ty2 => Some(ty1) + case (KnownType.Array(inn1), KnownType.Array(inn2)) => + (inn1 moreSpecific inn2).map(KnownType.Array(_)) + case (KnownType.Pair(fst1, snd1), KnownType.Pair(fst2, snd2)) => + (fst1 moreSpecific fst2, snd1 moreSpecific snd2).mapN(KnownType.Pair(_, _)) + case _ => None + } + } + + def check(prog: Program)(using + ctx: TypeCheckerCtx + ): Unit = { + prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => + ctx.typeOf(name) match { + case KnownType.Func(retType, _) => + stmts.toList.foreach( + checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) + ) + case _ => ctx.error(Error.InternalError(name.pos, "function declaration with non-function")) + } + } + prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) + } + + private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using + ctx: TypeCheckerCtx + ): Unit = stmt match { + case VarDecl(_, name, value) => + val expectedTy = ctx.typeOf(name) + checkValue( + value, + Constraint.Is( + expectedTy, + s"variable ${name.v} must be assigned a value of type $expectedTy" + ) + ) + case Assign(lhs, rhs) => + val lhsTy = checkValue(lhs, Constraint.Unconstrained) + checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy")) match { + case ? => + ctx.error( + Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal") + ) + case _ => () + } + case Read(lhs) => + val lhsTy = checkValue(lhs, Constraint.Unconstrained) + lhsTy match { + case ? => + ctx.error( + Error.SemanticError(lhs.pos, "cannot read into a destination with an unknown type") + ) + case _ => + lhsTy.satisfies( + Constraint.IsEither( + KnownType.Int, + KnownType.Char, + "read must be applied to an int or char" + ), + lhs.pos + ) + } + case Free(lhs) => + checkValue( + lhs, + Constraint.IsEither( + KnownType.Array(?), + KnownType.Pair(?, ?), + "free must be an array or pair" + ) + ) + case Return(expr) => + checkValue(expr, returnConstraint) + case Exit(expr) => + checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) + case Print(expr, _) => + // This constraint should never fail, the scope-checker should have caught it already + checkValue(expr, Constraint.IsVar("print value must be a variable")) + case If(cond, thenStmt, elseStmt) => + checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) + thenStmt.toList.foreach(checkStmt(_, returnConstraint)) + elseStmt.toList.foreach(checkStmt(_, returnConstraint)) + case While(cond, body) => + checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) + body.toList.foreach(checkStmt(_, returnConstraint)) + case Block(body) => + body.toList.foreach(checkStmt(_, returnConstraint)) + case Skip() => () + } + + private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using + ctx: TypeCheckerCtx + ): SemType = value match { + case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) + case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) + case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) + case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) + case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) + case id: Ident => + ctx.typeOf(id).satisfies(constraint, id.pos) + case ArrayElem(id, indices) => + val arrayTy = ctx.typeOf(id) + val elemTy = indices.toList.foldRight(arrayTy) { (elem, acc) => + checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + acc match { + case KnownType.Array(innerTy) => innerTy + case _ => + ctx.error( + Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") + ) + } + } + elemTy.satisfies(constraint, id.pos) + case Parens(expr) => checkValue(expr, constraint) + case l @ ArrayLiter(elems) => + KnownType + .Array(elems.foldRight[SemType](?) { case (elem, acc) => + checkValue( + elem, + Constraint.IsSymmetricCompatible(acc, "array elements must have the same type") + ) + }) + .satisfies(constraint, l.pos) + case l @ NewPair(fst, snd) => + KnownType + .Pair( + checkValue(fst, Constraint.Unconstrained), + checkValue(snd, Constraint.Unconstrained) + ) + .satisfies(constraint, l.pos) + case Call(id, args) => + val funcTy = ctx.typeOf(id) + funcTy match { + case KnownType.Func(retTy, paramTys) => + if (args.length != paramTys.length) { + ctx.error(Error.FunctionParamsMismatch(id.pos, paramTys.length, args.length)) + } + args.zip(paramTys).foreach { case (arg, paramTy) => + checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) + } + retTy.satisfies(constraint, id.pos) + // Should never happen, the scope-checker should have caught this already + // ctx error had it not + case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function")) + } + case Fst(elem) => + checkValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") + ) match { + case what @ KnownType.Pair(left, _) => + left.satisfies(constraint, elem.pos) + case ? => ?.satisfies(constraint, elem.pos) + case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) + } // satisfies constraint + case Snd(elem) => + checkValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") + ) match { + case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) + case ? => ?.satisfies(constraint, elem.pos) + case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) + } + + // Unary operators + case Negate(x) => + checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) + KnownType.Int.satisfies(constraint, x.pos) + case Not(x) => + checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) + KnownType.Bool.satisfies(constraint, x.pos) + case Len(x) => + checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) + KnownType.Int.satisfies(constraint, x.pos) + case Ord(x) => + checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) + KnownType.Int.satisfies(constraint, x.pos) + case Chr(x) => + checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) + KnownType.Char.satisfies(constraint, x.pos) + + // Binary operators + case op: (Add | Sub | Mul | Div | Mod) => + val operand = Constraint.Is(KnownType.Int, "binary operator must be applied to an int") + checkValue(op.x, operand) + checkValue(op.y, operand) + KnownType.Int.satisfies(constraint, op.pos) + case op: (Eq | Neq) => + val xTy = checkValue(op.x, Constraint.Unconstrained) + checkValue(op.y, Constraint.Is(xTy, "equality must be applied to values of the same type")) + KnownType.Bool.satisfies(constraint, op.pos) + case op: (Less | LessEq | Greater | GreaterEq) => + val xTy = checkValue( + op.x, + Constraint.IsEither( + KnownType.Int, + KnownType.Char, + "comparison must be applied to an int or char" + ) + ) + checkValue(op.y, Constraint.Is(xTy, "comparison must be applied to values of the same type")) + KnownType.Bool.satisfies(constraint, op.pos) + case op: (And | Or) => + val operand = Constraint.Is(KnownType.Bool, "logical operator must be applied to a bool") + checkValue(op.x, operand) + checkValue(op.y, operand) + KnownType.Bool.satisfies(constraint, op.pos) + } +} diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index e2a7988..ebe8517 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -3,7 +3,19 @@ package wacc object types { import ast._ - sealed trait SemType + sealed trait SemType { + override def toString(): String = this match { + case KnownType.Int => "int" + case KnownType.Bool => "bool" + case KnownType.Char => "char" + case KnownType.String => "string" + case KnownType.Array(elem) => s"$elem[]" + case KnownType.Pair(left, right) => s"pair($left, $right)" + case KnownType.Func(ret, params) => s"function returning $ret with params $params" + case ? => "?" + } + } + case object ? extends SemType enum KnownType extends SemType { case Int diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 8d2b55e..f62d537 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -16,7 +16,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral (p.toString, List(200)) } ++ allWaccFiles("wacc-examples/invalid/whack").map { p => - (p.toString, List(0, 100, 200)) + (p.toString, List(100, 200)) } // tests go here @@ -68,21 +68,21 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral // "wacc-examples/invalid/syntaxErr/variables", // "wacc-examples/invalid/syntaxErr/while", // invalid (semantic) - "wacc-examples/invalid/semanticErr/array", - "wacc-examples/invalid/semanticErr/exit", - "wacc-examples/invalid/semanticErr/expressions", - "wacc-examples/invalid/semanticErr/function", - "wacc-examples/invalid/semanticErr/if", - "wacc-examples/invalid/semanticErr/IO", - "wacc-examples/invalid/semanticErr/multiple", - "wacc-examples/invalid/semanticErr/pairs", - "wacc-examples/invalid/semanticErr/print", - "wacc-examples/invalid/semanticErr/read", - "wacc-examples/invalid/semanticErr/scope", - "wacc-examples/invalid/semanticErr/variables", - "wacc-examples/invalid/semanticErr/while", + // "wacc-examples/invalid/semanticErr/array", + // "wacc-examples/invalid/semanticErr/exit", + // "wacc-examples/invalid/semanticErr/expressions", + // "wacc-examples/invalid/semanticErr/function", + // "wacc-examples/invalid/semanticErr/if", + // "wacc-examples/invalid/semanticErr/IO", + // "wacc-examples/invalid/semanticErr/multiple", + // "wacc-examples/invalid/semanticErr/pairs", + // "wacc-examples/invalid/semanticErr/print", + // "wacc-examples/invalid/semanticErr/read", + // "wacc-examples/invalid/semanticErr/scope", + // "wacc-examples/invalid/semanticErr/variables", + // "wacc-examples/invalid/semanticErr/while", // invalid (whack) - "wacc-examples/invalid/whack" + // "wacc-examples/invalid/whack" // format: on // format: on ).find(filename.contains).isDefined