diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index e78d4bd..c29de9d 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,6 +1,5 @@ package wacc -import scala.collection.mutable import cats.data.{Chain, NonEmptyList} import parsley.{Failure, Success} @@ -19,6 +18,7 @@ import org.typelevel.log4cats.Logger import assemblyIR as asm import cats.data.ValidatedNel import java.io.File +import cats.data.NonEmptySeq /* TODO: @@ -71,22 +71,15 @@ val outputOpt: Opts[Option[Path]] = def frontend( contents: String, file: File -): IO[Either[NonEmptyList[Error], microWacc.Program]] = +): IO[Either[NonEmptySeq[Error], microWacc.Program]] = parser.parse(contents) match { - case Failure(msg) => IO.pure(Left(NonEmptyList.one(Error.SyntaxError(file, msg)))) + case Failure(msg) => IO.pure(Left(NonEmptySeq.one(Error.SyntaxError(file, msg)))) case Success(fn) => val partialProg = fn(file) - given errors: mutable.Builder[Error, List[Error]] = List.newBuilder for { - (prog, renameErrors) <- renamer.rename(partialProg) - _ = errors.addAll(renameErrors.toList) - typedProg = typeChecker.check(prog, errors) - - res = NonEmptyList.fromList(errors.result) match { - case Some(errors) => Left(errors) - case None => Right(typedProg) - } + (typedProg, errors) <- semantics.check(partialProg) + res = NonEmptySeq.fromSeq(errors.iterator.toSeq).map(Left(_)).getOrElse(Right(typedProg)) } yield res } @@ -107,12 +100,23 @@ def compile( // TODO: path, file , the names are confusing (when Path is the type but we are working with files) def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] = - writer.writeTo(backend(typedProg), outputPath) *> - logger.info(s"Success: ${outputPath.toAbsolutePath}") + val backendStart = System.nanoTime() + val asmLines = backend(typedProg) + val backendEnd = System.nanoTime() + writer.writeTo(asmLines, outputPath) *> + logAction( + s"Backend time (${filePath.toRealPath()}): ${(backendEnd - backendStart).toFloat / 1e6} ms" + ) *> + IO.blocking(println(s"Success: ${outputPath.toRealPath()}")) def processProgram(contents: String, file: File, outDir: Path): IO[Int] = + val frontendStart = System.nanoTime() for { frontendResult <- frontend(contents, file) + frontendEnd = System.nanoTime() + _ <- logAction( + s"Frontend time (${filePath.toRealPath()}): ${(frontendEnd - frontendStart).toFloat / 1e6} ms" + ) res <- frontendResult match { case Left(errors) => val code = errors.map(err => err.exitCode).toList.min diff --git a/src/main/wacc/frontend/microWacc.scala b/src/main/wacc/frontend/microWacc.scala index e2c1bdc..36c5d16 100644 --- a/src/main/wacc/frontend/microWacc.scala +++ b/src/main/wacc/frontend/microWacc.scala @@ -1,5 +1,7 @@ package wacc +import cats.data.Chain + object microWacc { import wacc.types._ @@ -78,12 +80,12 @@ object microWacc { } case class Assign(lhs: LValue, rhs: Expr) extends Stmt - case class If(cond: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt - case class While(cond: Expr, body: List[Stmt]) extends Stmt + case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt]) extends Stmt + case class While(cond: Expr, body: Chain[Stmt]) extends Stmt case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy) case class Return(expr: Expr) extends Stmt // Program - case class FuncDecl(name: Ident, params: List[Ident], body: List[Stmt]) - case class Program(funcs: List[FuncDecl], stmts: List[Stmt]) + case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt]) + case class Program(funcs: Chain[FuncDecl], stmts: Chain[Stmt]) } diff --git a/src/main/wacc/frontend/renamer.scala b/src/main/wacc/frontend/renamer.scala index 4893d42..b16df04 100644 --- a/src/main/wacc/frontend/renamer.scala +++ b/src/main/wacc/frontend/renamer.scala @@ -3,27 +3,26 @@ package wacc import java.io.File import scala.collection.mutable import cats.effect.IO -import cats.syntax.all._ import cats.implicits._ import cats.data.Chain import cats.data.NonEmptyList import parsley.{Failure, Success} -private val MAIN = "$main" - object renamer { import ast._ import types._ - private enum IdentType { + val MAIN = "$main" + + enum IdentType { case Func case Var } - private case class ScopeKey(path: String, name: String, identType: IdentType) - private case class ScopeValue(id: Ident, public: Boolean) + case class ScopeKey(path: String, name: String, identType: IdentType) + case class ScopeValue(id: Ident, public: Boolean) - private class Scope( + class Scope( private val current: mutable.Map[ScopeKey, ScopeValue], private val parent: Map[ScopeKey, ScopeValue], guidStart: Int = 0, @@ -153,7 +152,7 @@ object renamer { get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id) } - private def prepareGlobalScope( + def prepareGlobalScope( partialProg: PartialProgram )(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = { def readImportFile(file: File): IO[String] = @@ -267,26 +266,13 @@ object renamer { * @return * (flattenedProg, errors) */ - private def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = { + def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = { val (FuncDecl(_, _, params, body), subscope) = funcScopePair val paramErrors = params.foldMap(param => subscope.add(param.name)) IO(subscope.withSubscope { s => body.foldMap(rename(s)) }) .map(bodyErrors => paramErrors ++ bodyErrors) } - def rename(partialProg: PartialProgram): IO[(Program, Chain[Error])] = { - given scope: Scope = Scope(mutable.Map.empty, Map.empty) - - for { - (main, chunks, globalErrors) <- prepareGlobalScope(partialProg) - toRename = (main +: chunks).toList - allErrors <- toRename - .zip(scope.subscopes(toRename.size)) - .parFoldMapA(renameFunction) - // .map(x => x.combineAll) - } yield (Program(chunks.toList, main.body)(main.pos), globalErrors ++ allErrors) - } - /** Check scoping of all identifies in a given AST node. * * @param scope diff --git a/src/main/wacc/frontend/semantics.scala b/src/main/wacc/frontend/semantics.scala new file mode 100644 index 0000000..dcc5b94 --- /dev/null +++ b/src/main/wacc/frontend/semantics.scala @@ -0,0 +1,42 @@ +package wacc + +import scala.collection.mutable +import cats.implicits._ +import cats.data.Chain +import cats.effect.IO + +object semantics { + import renamer.{Scope, prepareGlobalScope, renameFunction} + import typeChecker.checkFuncDecl + + private def checkFunc( + funcDecl: ast.FuncDecl, + scope: Scope + ): IO[(microWacc.FuncDecl, Chain[Error])] = { + for { + renamerErrors <- renameFunction(funcDecl, scope) + (microWaccFunc, typeErrors) = checkFuncDecl(funcDecl) + } yield (microWaccFunc, renamerErrors ++ typeErrors) + } + + def check(partialProg: ast.PartialProgram): IO[(microWacc.Program, Chain[Error])] = { + given scope: Scope = Scope(mutable.Map.empty, Map.empty) + + for { + (main, chunks, globalErrors) <- prepareGlobalScope(partialProg) + toRename = (main +: chunks).toList + res <- toRename + .zip(scope.subscopes(toRename.size)) + .parTraverse(checkFunc) + (typedChunks, errors) = res.foldLeft((Chain.empty[microWacc.FuncDecl], Chain.empty[Error])) { + case ((acc, err), (funcDecl, errors)) => + (acc :+ funcDecl, err ++ errors) + } + (typedMain, funcs) = typedChunks.uncons match { + case Some((head, tail)) => (head.body, tail) + case None => (Chain.empty, Chain.empty) + } + } yield (microWacc.Program(funcs, typedMain), globalErrors ++ errors) + } + +} diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index 6f5804b..60fd924 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -1,20 +1,12 @@ package wacc import cats.syntax.all._ -import scala.collection.mutable import cats.data.NonEmptyList +import cats.data.Chain object typeChecker { import wacc.types._ - case class TypeCheckerCtx( - errors: mutable.Builder[Error, List[Error]] - ) { - def error(err: Error): SemType = - errors += err - ? - } - private enum Constraint { case Unconstrained // Allows weakening in one direction @@ -38,31 +30,29 @@ object typeChecker { * @return * The type if the constraint was satisfied, or ? if it was not. */ - private def satisfies(constraint: Constraint, pos: ast.Position)(using - ctx: TypeCheckerCtx - ): SemType = + private def satisfies(constraint: Constraint, pos: ast.Position): (SemType, Chain[Error]) = (ty, constraint) match { case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => - KnownType.String + (KnownType.String, Chain.empty) case ( KnownType.String, Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _) ) => - KnownType.String + (KnownType.String, Chain.empty) case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => ty.satisfies(Constraint.Is(ty2, msg), pos) // Change to IsUnweakenable to disallow recursive weakening case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos) - case (ty, Constraint.Unconstrained) => ty + case (ty, Constraint.Unconstrained) => (ty, Chain.empty) case (ty, Constraint.Never(msg)) => - ctx.error(Error.SemanticError(pos, msg)) + (?, Chain.one(Error.SemanticError(pos, msg))) case (ty, Constraint.IsEither(ty1, ty2, msg)) => - (ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse { - ctx.error(Error.TypeMismatch(pos, ty1, ty, msg)) + (ty moreSpecific ty1).orElse(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse { + (?, Chain.one(Error.TypeMismatch(pos, ty1, ty, msg))) } case (ty, Constraint.IsUnweakenable(ty2, msg)) => - (ty moreSpecific ty2).getOrElse { - ctx.error(Error.TypeMismatch(pos, ty2, ty, msg)) + (ty moreSpecific ty2).map((_, Chain.empty)).getOrElse { + (?, Chain.one(Error.TypeMismatch(pos, ty2, ty, msg))) } } @@ -86,35 +76,29 @@ object typeChecker { } } - /** Type-check a WACC program. + /** Type-check a function declaration. * - * @param prog - * The AST of the program to type-check. - * @param ctx - * The type checker context which includes the global names and functions, and an errors - * builder. + * @param func + * The AST of the function to type-check. */ - def check(prog: ast.Program, errors: mutable.Builder[Error, List[Error]]): microWacc.Program = - given ctx: TypeCheckerCtx = TypeCheckerCtx(errors) - microWacc.Program( - // Ignore function syntax types for return value and params, since those have been converted - // to SemTypes by the renamer. - prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) => - val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType] - microWacc.FuncDecl( - microWacc.Ident(name.v, name.guid)(retType), - params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => - microWacc.Ident(ident.v, ident.guid)(ty) - }, - stmts.toList - .flatMap( - checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) - ) - ) - }, - prog.main.toList - .flatMap(checkStmt(_, Constraint.Never("main function must not return"))) + def checkFuncDecl(func: ast.FuncDecl): (microWacc.FuncDecl, Chain[Error]) = { + val ast.FuncDecl(_, name, params, stmts) = func + val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType] + val returnConstraint = + if func.name.v == renamer.MAIN then Constraint.Never("main body must not return") + else Constraint.Is(retType, s"function ${name.v} must return $retType") + val (body, bodyErrors) = stmts.foldMap(checkStmt(_, returnConstraint)) + ( + microWacc.FuncDecl( + microWacc.Ident(name.v, name.guid)(retType), + params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => + microWacc.Ident(ident.v, ident.guid)(ty) + }, + body + ), + bodyErrors ) + } /** Type-check an AST statement node. * @@ -123,45 +107,51 @@ object typeChecker { * @param returnConstraint * The constraint that any `return ` statements must satisfy. */ - private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using - ctx: TypeCheckerCtx - ): List[microWacc.Stmt] = stmt match { + private def checkStmt( + stmt: ast.Stmt, + returnConstraint: Constraint + ): (Chain[microWacc.Stmt], Chain[Error]) = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. case ast.VarDecl(_, name, value) => val expectedTy = name.ty - val typedValue = checkValue( + val (typedValue, valueErrors) = checkValue( value, Constraint.Is( expectedTy.asInstanceOf[SemType], s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) - List( - microWacc.Assign( - microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]), - typedValue - ) + ( + Chain.one( + microWacc.Assign( + microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]), + typedValue + ) + ), + valueErrors ) case ast.Assign(lhs, rhs) => - val lhsTyped = checkLValue(lhs, Constraint.Unconstrained) - val rhsTyped = + val (lhsTyped, lhsErrors) = checkLValue(lhs, Constraint.Unconstrained) + val (rhsTyped, rhsErrors) = checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}")) - (lhsTyped.ty, rhsTyped.ty) match { + val unknownError = (lhsTyped.ty, rhsTyped.ty) match { case (?, ?) => - ctx.error( + Chain.one( Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal") ) - case _ => () + case _ => Chain.empty } - List(microWacc.Assign(lhsTyped, rhsTyped)) + (Chain.one(microWacc.Assign(lhsTyped, rhsTyped)), lhsErrors ++ rhsErrors ++ unknownError) case ast.Read(dest) => - val destTyped = checkLValue(dest, Constraint.Unconstrained) - val destTy = destTyped.ty match { + val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained) + val (destTy, destTyErrors) = destTyped.ty match { case ? => - ctx.error( - Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type") + ( + ?, + Chain.one( + Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type") + ) ) - ? case destTy => destTy.satisfies( Constraint.IsEither( @@ -172,49 +162,44 @@ object typeChecker { dest.pos ) } - List( - microWacc.Assign( - destTyped, - microWacc.Call( - microWacc.Builtin.Read, - List( - destTy match { - case KnownType.Int => " %d".toMicroWaccCharArray - case KnownType.Char | _ => " %c".toMicroWaccCharArray - }, - destTyped - ) - ) - ) - ) - case ast.Free(lhs) => - List( - microWacc.Call( - microWacc.Builtin.Free, - List( - checkValue( - lhs, - Constraint.IsEither( - KnownType.Array(?), - KnownType.Pair(?, ?), - "free must be applied to an array or pair" + ( + Chain.one( + microWacc.Assign( + destTyped, + microWacc.Call( + microWacc.Builtin.Read, + List( + destTy match { + case KnownType.Int => " %d".toMicroWaccCharArray + case KnownType.Char | _ => " %c".toMicroWaccCharArray + }, + destTyped ) ) ) + ), + destErrors ++ destTyErrors + ) + case ast.Free(lhs) => + val (lhsTyped, lhsErrors) = checkValue( + lhs, + Constraint.IsEither( + KnownType.Array(?), + KnownType.Pair(?, ?), + "free must be applied to an array or pair" ) ) + (Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))), lhsErrors) case ast.Return(expr) => - List(microWacc.Return(checkValue(expr, returnConstraint))) + val (exprTyped, exprErrors) = checkValue(expr, returnConstraint) + (Chain.one(microWacc.Return(exprTyped)), exprErrors) case ast.Exit(expr) => - List( - microWacc.Call( - microWacc.Builtin.Exit, - List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))) - ) - ) + val (exprTyped, exprErrors) = + checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) + (Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))), exprErrors) case ast.Print(expr, newline) => // This constraint should never fail, the scope-checker should have caught it already - val exprTyped = checkValue(expr, Constraint.Unconstrained) + val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained) val exprFormat = exprTyped.ty match { case KnownType.Bool | KnownType.String => "%s" case KnownType.Array(KnownType.Char) => "%.*s" @@ -223,7 +208,7 @@ object typeChecker { case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" } val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) => - List( + Chain.one( microWacc.Call( func, List( @@ -233,36 +218,38 @@ object typeChecker { ) ) } - exprTyped.ty match { - case KnownType.Bool => - List( - microWacc.If( - exprTyped, - printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), - printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) + ( + exprTyped.ty match { + case KnownType.Bool => + Chain.one( + microWacc.If( + exprTyped, + printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), + printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) + ) ) - ) - case KnownType.Array(KnownType.Char) => - printfCall(microWacc.Builtin.PrintCharArray, exprTyped) - case _ => printfCall(microWacc.Builtin.Printf, exprTyped) - } + case KnownType.Array(KnownType.Char) => + printfCall(microWacc.Builtin.PrintCharArray, exprTyped) + case _ => printfCall(microWacc.Builtin.Printf, exprTyped) + }, + exprErrors + ) case ast.If(cond, thenStmt, elseStmt) => - List( - microWacc.If( - checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")), - thenStmt.toList.flatMap(checkStmt(_, returnConstraint)), - elseStmt.toList.flatMap(checkStmt(_, returnConstraint)) - ) + val (condTyped, condErrors) = + checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) + val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint)) + val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint)) + ( + Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)), + condErrors ++ thenErrors ++ elseErrors ) case ast.While(cond, body) => - List( - microWacc.While( - checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")), - body.toList.flatMap(checkStmt(_, returnConstraint)) - ) - ) - case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint)) - case skip @ ast.Skip() => List.empty + val (condTyped, condErrors) = + checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) + val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint)) + (Chain.one(microWacc.While(condTyped, bodyTyped)), condErrors ++ bodyErrors) + case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint)) + case skip @ ast.Skip() => (Chain.empty, Chain.empty) } /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits @@ -275,127 +262,142 @@ object typeChecker { * @return * The most specific type of the value if it could be determined, or ? if it could not. */ - private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using - ctx: TypeCheckerCtx - ): microWacc.Expr = value match { + private def checkValue( + value: ast.LValue | ast.RValue | ast.Expr, + constraint: Constraint + ): (microWacc.Expr, Chain[Error]) = value match { case l @ ast.IntLiter(v) => - KnownType.Int.satisfies(constraint, l.pos) - microWacc.IntLiter(v) + val (_, errors) = KnownType.Int.satisfies(constraint, l.pos) + (microWacc.IntLiter(v), errors) case l @ ast.BoolLiter(v) => - KnownType.Bool.satisfies(constraint, l.pos) - microWacc.BoolLiter(v) + val (_, errors) = KnownType.Bool.satisfies(constraint, l.pos) + (microWacc.BoolLiter(v), errors) case l @ ast.CharLiter(v) => - KnownType.Char.satisfies(constraint, l.pos) - microWacc.CharLiter(v) + val (_, errors) = KnownType.Char.satisfies(constraint, l.pos) + (microWacc.CharLiter(v), errors) case l @ ast.StrLiter(v) => - KnownType.String.satisfies(constraint, l.pos) - v.toMicroWaccCharArray + val (_, errors) = KnownType.String.satisfies(constraint, l.pos) + (v.toMicroWaccCharArray, errors) case l @ ast.PairLiter() => - microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) + val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos) + (microWacc.NullLiter()(ty), errors) case ast.Parens(expr) => checkValue(expr, constraint) case l @ ast.ArrayLiter(elems) => - val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) { - case (acc, elem) => - val elemTyped = checkValue( - elem, - Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type") - ) - (elemTyped.ty, elemTyped) - } - val arrayTy = KnownType + val ((elemTy, elemsErrors), elemsTyped) = + elems.mapAccumulate[(SemType, Chain[Error]), microWacc.Expr]((?, Chain.empty)) { + case ((acc, errors), elem) => + val (elemTyped, elemErrors) = checkValue( + elem, + Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type") + ) + ((elemTyped.ty, errors ++ elemErrors), elemTyped) + } + val (arrayTy, arrayErrors) = KnownType // Start with an unknown param type, make it more specific while checking the elements. .Array(elemTy) .satisfies(constraint, l.pos) - microWacc.ArrayLiter(elemsTyped)(arrayTy) + (microWacc.ArrayLiter(elemsTyped)(arrayTy), elemsErrors ++ arrayErrors) case l @ ast.NewPair(fst, snd) => - val fstTyped = checkValue(fst, Constraint.Unconstrained) - val sndTyped = checkValue(snd, Constraint.Unconstrained) - microWacc.ArrayLiter(List(fstTyped, sndTyped))( + val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained) + val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained) + val (pairTy, pairErrors) = KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) - ) + (microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy), fstErrors ++ sndErrors ++ pairErrors) case ast.Call(id, args) => val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType] - if (args.length != paramTys.length) { - ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) - } + val lenError = + if (args.length == paramTys.length) then Chain.empty + else Chain.one(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) // Even if the number of arguments is wrong, we still check the types of the arguments // in the best way we can (by taking a zip). - val argsTyped = args.zip(paramTys).map { case (arg, paramTy) => - checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) - } - microWacc.Call(microWacc.Ident(id.v, id.guid)(retTy.satisfies(constraint, id.pos)), argsTyped) + val (argsErrors, argsTyped) = + args.zip(paramTys).mapAccumulate(Chain.empty[Error]) { case (errors, (arg, paramTy)) => + val (argTyped, argErrors) = + checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) + (errors ++ argErrors, argTyped) + } + val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos) + ( + microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped), + lenError ++ argsErrors ++ retErrors + ) // Unary operators case ast.Negate(x) => - microWacc.UnaryOp( - checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")), - microWacc.UnaryOperator.Negate - )(KnownType.Int.satisfies(constraint, x.pos)) + val (argTyped, argErrors) = + checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")) + val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos) + (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Negate)(retTy), argErrors ++ retErrors) case ast.Not(x) => - microWacc.UnaryOp( - checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")), - microWacc.UnaryOperator.Not - )(KnownType.Bool.satisfies(constraint, x.pos)) + val (argTyped, argErrors) = + checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")) + val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, x.pos) + (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Not)(retTy), argErrors ++ retErrors) case ast.Len(x) => - microWacc.UnaryOp( - checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")), - microWacc.UnaryOperator.Len - )(KnownType.Int.satisfies(constraint, x.pos)) + val (argTyped, argErrors) = + checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")) + val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos) + (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Len)(retTy), argErrors ++ retErrors) case ast.Ord(x) => - microWacc.UnaryOp( - checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")), - microWacc.UnaryOperator.Ord - )(KnownType.Int.satisfies(constraint, x.pos)) + val (argTyped, argErrors) = + checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")) + val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos) + (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Ord)(retTy), argErrors ++ retErrors) case ast.Chr(x) => - microWacc.UnaryOp( - checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")), - microWacc.UnaryOperator.Chr - )(KnownType.Char.satisfies(constraint, x.pos)) + val (argTyped, argErrors) = + checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")) + val (retTy, retErrors) = KnownType.Char.satisfies(constraint, x.pos) + (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Chr)(retTy), argErrors ++ retErrors) // Binary operators case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") - microWacc.BinaryOp( - checkValue(op.x, operand), - checkValue(op.y, operand), - microWacc.BinaryOperator.fromAst(op) - )(KnownType.Int.satisfies(constraint, op.pos)) + val (xTyped, xErrors) = checkValue(op.x, operand) + val (yTyped, yErrors) = checkValue(op.y, operand) + val (retTy, retErrors) = KnownType.Int.satisfies(constraint, op.pos) + ( + microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy), + xErrors ++ yErrors ++ retErrors + ) case op: (ast.Eq | ast.Neq) => - val xTyped = checkValue(op.x, Constraint.Unconstrained) - microWacc.BinaryOp( - xTyped, - checkValue( - op.y, - Constraint - .Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type") - ), - microWacc.BinaryOperator.fromAst(op) - )(KnownType.Bool.satisfies(constraint, op.pos)) + val (xTyped, xErrors) = checkValue(op.x, Constraint.Unconstrained) + val (yTyped, yErrors) = checkValue( + op.y, + Constraint.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type") + ) + val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos) + ( + microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy), + xErrors ++ yErrors ++ retErrors + ) case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => val xConstraint = Constraint.IsEither( KnownType.Int, KnownType.Char, s"${op.name} operator must be applied to an int or char" ) - val xTyped = checkValue(op.x, xConstraint) + val (xTyped, xErrors) = checkValue(op.x, xConstraint) // If x type-check failed, we still want to check y is an Int or Char (rather than ?) val yConstraint = xTyped.ty match { case ? => xConstraint case xTy => Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") } - microWacc.BinaryOp( - xTyped, - checkValue(op.y, yConstraint), - microWacc.BinaryOperator.fromAst(op) - )(KnownType.Bool.satisfies(constraint, op.pos)) + val (yTyped, yErrors) = checkValue(op.y, yConstraint) + val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos) + ( + microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy), + xErrors ++ yErrors ++ retErrors + ) case op: (ast.And | ast.Or) => val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") - microWacc.BinaryOp( - checkValue(op.x, operand), - checkValue(op.y, operand), - microWacc.BinaryOperator.fromAst(op) - )(KnownType.Bool.satisfies(constraint, op.pos)) + val (xTyped, xErrors) = checkValue(op.x, operand) + val (yTyped, yErrors) = checkValue(op.y, operand) + val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos) + ( + microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy), + xErrors ++ yErrors ++ retErrors + ) case lvalue: ast.LValue => checkLValue(lvalue, constraint) } @@ -412,68 +414,69 @@ object typeChecker { * @return * The most specific type of the value if it could be determined, or ? if it could not. */ - private def checkLValue(value: ast.LValue, constraint: Constraint)(using - ctx: TypeCheckerCtx - ): microWacc.LValue = value match { + private def checkLValue( + value: ast.LValue, + constraint: Constraint + ): (microWacc.LValue, Chain[Error]) = value match { case id @ ast.Ident(name, guid, ty) => - microWacc.Ident(name, guid)(ty.asInstanceOf[SemType].satisfies(constraint, id.pos)) + val (idTy, idErrors) = ty.asInstanceOf[SemType].satisfies(constraint, id.pos) + (microWacc.Ident(name, guid)(idTy), idErrors) case ast.ArrayElem(id, indices) => val arrayTy = id.ty.asInstanceOf[SemType] - val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy.asInstanceOf[SemType]) { - (acc, elem) => - val idxTyped = - checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) - val next = acc match { - case KnownType.Array(innerTy) => innerTy - case ? => ? // we can keep indexing an unknown type - case nonArrayTy => - ctx.error( - Error.TypeMismatch( - elem.pos, - KnownType.Array(?), - acc, - "cannot index into a non-array" + val ((elemTy, elemErrors), indicesTyped) = + indices.mapAccumulate((arrayTy.asInstanceOf[SemType], Chain.empty[Error])) { + case ((acc, errors), elem) => + val (idxTyped, idxErrors) = + checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + val (next, nextError) = acc match { + case KnownType.Array(innerTy) => (innerTy, Chain.empty) + case ? => (?, Chain.empty) // we can keep indexing an unknown type + case nonArrayTy => + ( + ?, + Chain.one( + Error.TypeMismatch( + elem.pos, + KnownType.Array(?), + acc, + "cannot index into a non-array" + ) + ) ) - ) - ? - } - (next, idxTyped) - } + } + ((next, errors ++ idxErrors ++ nextError), idxTyped) + } + val (retTy, retErrors) = elemTy.satisfies(constraint, value.pos) val firstArrayElem = microWacc.ArrayElem( microWacc.Ident(id.v, id.guid)(arrayTy), indicesTyped.head - )(elemTy.satisfies(constraint, value.pos)) + )(retTy) val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) => microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty)) } // Need to type-check the final arrayElem with the constraint - microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos)) + // TODO: What + (microWacc.ArrayElem(arrayElem.value, arrayElem.index)(retTy), elemErrors ++ retErrors) case ast.Fst(elem) => - val elemTyped = checkLValue( + val (elemTyped, elemErrors) = checkLValue( elem, Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") ) - microWacc.ArrayElem( - elemTyped, - microWacc.IntLiter(0) - )(elemTyped.ty match { - case KnownType.Pair(left, _) => - left.satisfies(constraint, elem.pos) - case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) - }) + val (retTy, retErrors) = elemTyped.ty match { + case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos) + case _ => (?, Chain.one(Error.InternalError(elem.pos, "fst must be applied to a pair"))) + } + (microWacc.ArrayElem(elemTyped, microWacc.IntLiter(0))(retTy), elemErrors ++ retErrors) case ast.Snd(elem) => - val elemTyped = checkLValue( + val (elemTyped, elemErrors) = checkLValue( elem, Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") ) - microWacc.ArrayElem( - elemTyped, - microWacc.IntLiter(1) - )(elemTyped.ty match { - case KnownType.Pair(_, right) => - right.satisfies(constraint, elem.pos) - case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) - }) + val (retTy, retErrors) = elemTyped.ty match { + case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) + case _ => (?, Chain.one(Error.InternalError(elem.pos, "snd must be applied to a pair"))) + } + (microWacc.ArrayElem(elemTyped, microWacc.IntLiter(1))(retTy), elemErrors ++ retErrors) } extension (s: String) {