diff --git a/src/main/wacc/Frontend/microWacc.scala b/src/main/wacc/Frontend/microWacc.scala index 04b9035..b9f6635 100644 --- a/src/main/wacc/Frontend/microWacc.scala +++ b/src/main/wacc/Frontend/microWacc.scala @@ -19,7 +19,7 @@ object microWacc { extends Expr(identTy) with CallTarget(identTy) with LValue - case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(ty: SemType) + case class ArrayElem(value: LValue, indices: NonEmptyList[Expr])(ty: SemType) extends Expr(ty) with LValue @@ -48,6 +48,23 @@ object microWacc { case And case Or } + object BinaryOperator { + def fromAst(op: ast.BinaryOp): BinaryOperator = op match { + case _: ast.Add => Add + case _: ast.Sub => Sub + case _: ast.Mul => Mul + case _: ast.Div => Div + case _: ast.Mod => Mod + case _: ast.Greater => Greater + case _: ast.GreaterEq => GreaterEq + case _: ast.Less => Less + case _: ast.LessEq => LessEq + case _: ast.Eq => Eq + case _: ast.Neq => Neq + case _: ast.And => And + case _: ast.Or => Or + } + } // Statements sealed trait Stmt diff --git a/src/main/wacc/Frontend/typeChecker.scala b/src/main/wacc/Frontend/typeChecker.scala index 21d49ae..e3960cb 100644 --- a/src/main/wacc/Frontend/typeChecker.scala +++ b/src/main/wacc/Frontend/typeChecker.scala @@ -2,6 +2,7 @@ package wacc import cats.syntax.all._ import scala.collection.mutable +import cats.data.NonEmptyList object typeChecker { import wacc.types._ @@ -100,17 +101,26 @@ object typeChecker { */ def check(prog: ast.Program)(using ctx: TypeCheckerCtx - ): Unit = { - // Ignore function syntax types for return value and params, since those have been converted - // to SemTypes by the renamer. - prog.funcs.foreach { case ast.FuncDecl(_, name, _, stmts) => - val FuncType(retType, _) = ctx.funcType(name) - stmts.toList.foreach( - checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) - ) - } - prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) - } + ): microWacc.Program = + microWacc.Program( + // Ignore function syntax types for return value and params, since those have been converted + // to SemTypes by the renamer. + prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) => + val FuncType(retType, paramTypes) = ctx.funcType(name) + microWacc.FuncDecl( + microWacc.Ident(name.v, name.uid)(retType), + params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => + microWacc.Ident(ident.v, name.uid)(ty) + }, + stmts.toList + .flatMap( + checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) + ) + ) + }, + prog.main.toList + .flatMap(checkStmt(_, Constraint.Never("main function must not return"))) + ) /** Type-check an AST statement node. * @@ -121,32 +131,38 @@ object typeChecker { */ private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using ctx: TypeCheckerCtx - ): Unit = stmt match { + ): List[microWacc.Stmt] = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. case ast.VarDecl(_, name, value) => val expectedTy = ctx.typeOf(name) - checkValue( + val typedValue = checkValue( value, Constraint.Is( expectedTy, s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) + List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue)) case ast.Assign(lhs, rhs) => - val lhsTy = checkValue(lhs, Constraint.Unconstrained) - (lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match { + val lhsTyped = checkLValue(lhs, Constraint.Unconstrained) + val rhsTyped = + checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}")) + (lhsTyped.ty, rhsTyped.ty) match { case (?, ?) => ctx.error( Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal") ) case _ => () } + List(microWacc.Assign(lhsTyped, rhsTyped)) case ast.Read(dest) => - checkValue(dest, Constraint.Unconstrained) match { + val destTyped = checkLValue(dest, Constraint.Unconstrained) + val destTy = destTyped.ty match { case ? => ctx.error( Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type") ) + ? case destTy => destTy.satisfies( Constraint.IsEither( @@ -157,32 +173,69 @@ object typeChecker { dest.pos ) } + List( + microWacc.Assign( + destTyped, + microWacc.Call( + destTy match { + case KnownType.Int => microWacc.Builtin.ReadInt + case KnownType.Char => microWacc.Builtin.ReadChar + case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway + }, + Nil + ) + ) + ) case ast.Free(lhs) => - checkValue( - lhs, - Constraint.IsEither( - KnownType.Array(?), - KnownType.Pair(?, ?), - "free must be applied to an array or pair" + List( + microWacc.Call( + microWacc.Builtin.Free, + List( + checkValue( + lhs, + Constraint.IsEither( + KnownType.Array(?), + KnownType.Pair(?, ?), + "free must be applied to an array or pair" + ) + ) + ) ) ) case ast.Return(expr) => - checkValue(expr, returnConstraint) + List(microWacc.Return(checkValue(expr, returnConstraint))) case ast.Exit(expr) => - checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) - case ast.Print(expr, _) => + List( + microWacc.Call( + microWacc.Builtin.Exit, + List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))) + ) + ) + case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already - checkValue(expr, Constraint.Unconstrained) + List( + microWacc.Call( + if newline then microWacc.Builtin.Println else microWacc.Builtin.Print, + List(checkValue(expr, Constraint.Unconstrained)) + ) + ) case ast.If(cond, thenStmt, elseStmt) => - checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) - thenStmt.toList.foreach(checkStmt(_, returnConstraint)) - elseStmt.toList.foreach(checkStmt(_, returnConstraint)) + List( + microWacc.If( + checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")), + thenStmt.toList.flatMap(checkStmt(_, returnConstraint)), + elseStmt.toList.flatMap(checkStmt(_, returnConstraint)) + ) + ) case ast.While(cond, body) => - checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) - body.toList.foreach(checkStmt(_, returnConstraint)) - case ast.Block(body) => - body.toList.foreach(checkStmt(_, returnConstraint)) - case ast.Skip() => () + List( + microWacc.While( + checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")), + body.toList.flatMap(checkStmt(_, returnConstraint)) + ) + ) + case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint)) + case skip @ ast.Skip() => List.empty } /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits @@ -197,47 +250,42 @@ object typeChecker { */ private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using ctx: TypeCheckerCtx - ): SemType = value match { - case l @ ast.IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos) - case l @ ast.BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos) - case l @ ast.CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos) - case l @ ast.StrLiter(_) => KnownType.String.satisfies(constraint, l.pos) - case l @ ast.PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos) - case id: ast.Ident => - ctx.typeOf(id).satisfies(constraint, id.pos) - case ast.ArrayElem(id, indices) => - val arrayTy = ctx.typeOf(id) - val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) => - checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) - acc match { - case KnownType.Array(innerTy) => Some(innerTy) - case ? => Some(?) // we can keep indexing an unknown type - case nonArrayTy => - ctx.error( - Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array") - ) - None - } - } - elemTy.getOrElse(?).satisfies(constraint, id.pos) + ): microWacc.Expr = value match { + case l @ ast.IntLiter(v) => + KnownType.Int.satisfies(constraint, l.pos) + microWacc.IntLiter(v) + case l @ ast.BoolLiter(v) => + KnownType.Bool.satisfies(constraint, l.pos) + microWacc.BoolLiter(v) + case l @ ast.CharLiter(v) => + KnownType.Char.satisfies(constraint, l.pos) + microWacc.CharLiter(v) + case l @ ast.StrLiter(v) => + KnownType.String.satisfies(constraint, l.pos) + microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String) + case l @ ast.PairLiter() => + microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) case ast.Parens(expr) => checkValue(expr, constraint) case l @ ast.ArrayLiter(elems) => - KnownType - // Start with an unknown param type, make it more specific while checking the elements. - .Array(elems.foldLeft[SemType](?) { case (acc, elem) => - checkValue( + val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) { + case (acc, elem) => + val elemTyped = checkValue( elem, Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type") ) - }) + (elemTyped.ty, elemTyped) + } + val arrayTy = KnownType + // Start with an unknown param type, make it more specific while checking the elements. + .Array(elemTy) .satisfies(constraint, l.pos) + microWacc.ArrayLiter(elemsTyped)(arrayTy) case l @ ast.NewPair(fst, snd) => - KnownType - .Pair( - checkValue(fst, Constraint.Unconstrained), - checkValue(snd, Constraint.Unconstrained) - ) - .satisfies(constraint, l.pos) + val fstTyped = checkValue(fst, Constraint.Unconstrained) + val sndTyped = checkValue(snd, Constraint.Unconstrained) + microWacc.ArrayLiter(List(fstTyped, sndTyped))( + KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) + ) case ast.Call(id, args) => val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) if (args.length != paramTys.length) { @@ -245,76 +293,152 @@ object typeChecker { } // Even if the number of arguments is wrong, we still check the types of the arguments // in the best way we can (by taking a zip). - args.zip(paramTys).foreach { case (arg, paramTy) => + val argsTyped = args.zip(paramTys).map { case (arg, paramTy) => checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } - retTy.satisfies(constraint, id.pos) - case ast.Fst(elem) => - checkValue( - elem, - Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") - ) match { - case what @ KnownType.Pair(left, _) => - left.satisfies(constraint, elem.pos) - case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) - } - case ast.Snd(elem) => - checkValue( - elem, - Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") - ) match { - case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) - case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) - } + microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped) // Unary operators case ast.Negate(x) => - checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) - KnownType.Int.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")), + microWacc.UnaryOperator.Negate + )(KnownType.Int.satisfies(constraint, x.pos)) case ast.Not(x) => - checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) - KnownType.Bool.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")), + microWacc.UnaryOperator.Not + )(KnownType.Bool.satisfies(constraint, x.pos)) case ast.Len(x) => - checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) - KnownType.Int.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")), + microWacc.UnaryOperator.Len + )(KnownType.Int.satisfies(constraint, x.pos)) case ast.Ord(x) => - checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) - KnownType.Int.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")), + microWacc.UnaryOperator.Ord + )(KnownType.Int.satisfies(constraint, x.pos)) case ast.Chr(x) => - checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) - KnownType.Char.satisfies(constraint, x.pos) + microWacc.UnaryOp( + checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")), + microWacc.UnaryOperator.Chr + )(KnownType.Char.satisfies(constraint, x.pos)) // Binary operators case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") - checkValue(op.x, operand) - checkValue(op.y, operand) - KnownType.Int.satisfies(constraint, op.pos) + microWacc.BinaryOp( + checkValue(op.x, operand), + checkValue(op.y, operand), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Int.satisfies(constraint, op.pos)) case op: (ast.Eq | ast.Neq) => - val xTy = checkValue(op.x, Constraint.Unconstrained) - checkValue( - op.y, - Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") - ) - KnownType.Bool.satisfies(constraint, op.pos) + val xTyped = checkValue(op.x, Constraint.Unconstrained) + microWacc.BinaryOp( + xTyped, + checkValue( + op.y, + Constraint + .Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type") + ), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Bool.satisfies(constraint, op.pos)) case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => val xConstraint = Constraint.IsEither( KnownType.Int, KnownType.Char, s"${op.name} operator must be applied to an int or char" ) + val xTyped = checkValue(op.x, xConstraint) // If x type-check failed, we still want to check y is an Int or Char (rather than ?) - val yConstraint = checkValue(op.x, xConstraint) match { + val yConstraint = xTyped.ty match { case ? => xConstraint case xTy => Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") } - checkValue(op.y, yConstraint) - KnownType.Bool.satisfies(constraint, op.pos) + microWacc.BinaryOp( + xTyped, + checkValue(op.y, yConstraint), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Bool.satisfies(constraint, op.pos)) case op: (ast.And | ast.Or) => val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") - checkValue(op.x, operand) - checkValue(op.y, operand) - KnownType.Bool.satisfies(constraint, op.pos) + microWacc.BinaryOp( + checkValue(op.x, operand), + checkValue(op.y, operand), + microWacc.BinaryOperator.fromAst(op) + )(KnownType.Bool.satisfies(constraint, op.pos)) + + case lvalue: ast.LValue => checkLValue(lvalue, constraint) + } + + /** Type-check an AST LValue node. Separate because microWacc keeps LValues + * + * @param value + * The value to type-check. + * @param constraint + * The type constraint that the value must satisfy. + * @param ctx + * The type checker context which includes the global names and functions, and an errors + * builder. + * @return + * The most specific type of the value if it could be determined, or ? if it could not. + */ + private def checkLValue(value: ast.LValue, constraint: Constraint)(using + ctx: TypeCheckerCtx + ): microWacc.LValue = value match { + case id @ ast.Ident(name, uid) => + microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos)) + case ast.ArrayElem(id, indices) => + val arrayTy = ctx.typeOf(id) + val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) => + val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + val next = acc match { + case KnownType.Array(innerTy) => innerTy + case ? => ? // we can keep indexing an unknown type + case nonArrayTy => + ctx.error( + Error.TypeMismatch( + elem.pos, + KnownType.Array(?), + acc, + "cannot index into a non-array" + ) + ) + ? + } + (next, idxTyped) + } + microWacc.ArrayElem( + microWacc.Ident(id.v, id.uid)(arrayTy), + indicesTyped + )(elemTy.satisfies(constraint, value.pos)) + case ast.Fst(elem) => + val elemTyped = checkLValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") + ) + microWacc.ArrayElem( + elemTyped, + NonEmptyList.of(microWacc.IntLiter(0)) + )(elemTyped.ty match { + case KnownType.Pair(left, _) => + left.satisfies(constraint, elem.pos) + case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) + }) + case ast.Snd(elem) => + val elemTyped = checkLValue( + elem, + Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") + ) + microWacc.ArrayElem( + elemTyped, + NonEmptyList.of(microWacc.IntLiter(1)) + )(elemTyped.ty match { + case KnownType.Pair(_, right) => + right.satisfies(constraint, elem.pos) + case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) + }) } } diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 5e95424..445d9c1 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -36,7 +36,7 @@ def compile(contents: String): Int = { given errors: mutable.Builder[Error, List[Error]] = List.newBuilder val (names, funcs) = renamer.rename(prog) given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) - typeChecker.check(prog) + val typedProg = typeChecker.check(prog) if (errors.result.nonEmpty) { given errorContent: String = contents errors.result @@ -48,7 +48,10 @@ def compile(contents: String): Int = { } } .max() - } else 0 + } else { + println(typedProg) + 0 + } case Failure(msg) => println(msg) 100