feat: initial parallel type-checker implementation

This commit is contained in:
2025-03-14 04:09:34 +00:00
parent 42515abf2a
commit 53d47fda63
5 changed files with 322 additions and 283 deletions

View File

@@ -1,6 +1,5 @@
package wacc package wacc
import scala.collection.mutable
import cats.data.{Chain, NonEmptyList} import cats.data.{Chain, NonEmptyList}
import parsley.{Failure, Success} import parsley.{Failure, Success}
@@ -19,6 +18,7 @@ import org.typelevel.log4cats.Logger
import assemblyIR as asm import assemblyIR as asm
import cats.data.ValidatedNel import cats.data.ValidatedNel
import java.io.File import java.io.File
import cats.data.NonEmptySeq
/* /*
TODO: TODO:
@@ -71,18 +71,15 @@ val outputOpt: Opts[Option[Path]] =
def frontend( def frontend(
contents: String, contents: String,
file: File file: File
): IO[Either[NonEmptyList[Error], microWacc.Program]] = ): IO[Either[NonEmptySeq[Error], microWacc.Program]] =
parser.parse(contents) match { parser.parse(contents) match {
case Failure(msg) => IO.pure(Left(NonEmptyList.one(Error.SyntaxError(file, msg)))) case Failure(msg) => IO.pure(Left(NonEmptySeq.one(Error.SyntaxError(file, msg))))
case Success(fn) => case Success(fn) =>
val partialProg = fn(file) val partialProg = fn(file)
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
for { for {
(prog, renameErrors) <- renamer.rename(partialProg) (typedProg, errors) <- semantics.check(partialProg)
_ = errors.addAll(renameErrors.toList) res = NonEmptySeq.fromSeq(errors.iterator.toSeq).map(Left(_)).getOrElse(Right(typedProg))
typedProg = typeChecker.check(prog, errors)
res = errors.result.toNel.toLeft(typedProg)
} yield res } yield res
} }
@@ -103,12 +100,19 @@ def compile(
// TODO: path, file , the names are confusing (when Path is the type but we are working with files) // TODO: path, file , the names are confusing (when Path is the type but we are working with files)
def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] = def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] =
writer.writeTo(backend(typedProg), outputPath) *> val backendStart = System.nanoTime()
logger.info(s"Success: ${outputPath.toAbsolutePath}") val asmLines = backend(typedProg)
val backendEnd = System.nanoTime()
writer.writeTo(asmLines, outputPath) *>
logAction(s"Backend time (${filePath.toRealPath()}): ${(backendEnd - backendStart).toFloat / 1e6} ms") *>
logAction(s"Success: ${outputPath.toAbsolutePath}")
def processProgram(contents: String, file: File, outDir: Path): IO[Int] = def processProgram(contents: String, file: File, outDir: Path): IO[Int] =
val frontendStart = System.nanoTime()
for { for {
frontendResult <- frontend(contents, file) frontendResult <- frontend(contents, file)
frontendEnd = System.nanoTime()
_ <- logAction(s"Frontend time (${filePath.toRealPath()}): ${(frontendEnd - frontendStart).toFloat / 1e6} ms")
res <- frontendResult match { res <- frontendResult match {
case Left(errors) => case Left(errors) =>
val code = errors.map(err => err.exitCode).toList.min val code = errors.map(err => err.exitCode).toList.min

View File

@@ -1,5 +1,7 @@
package wacc package wacc
import cats.data.Chain
object microWacc { object microWacc {
import wacc.types._ import wacc.types._
@@ -78,12 +80,12 @@ object microWacc {
} }
case class Assign(lhs: LValue, rhs: Expr) extends Stmt case class Assign(lhs: LValue, rhs: Expr) extends Stmt
case class If(cond: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt]) extends Stmt
case class While(cond: Expr, body: List[Stmt]) extends Stmt case class While(cond: Expr, body: Chain[Stmt]) extends Stmt
case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy) case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy)
case class Return(expr: Expr) extends Stmt case class Return(expr: Expr) extends Stmt
// Program // Program
case class FuncDecl(name: Ident, params: List[Ident], body: List[Stmt]) case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt])
case class Program(funcs: List[FuncDecl], stmts: List[Stmt]) case class Program(funcs: Chain[FuncDecl], stmts: Chain[Stmt])
} }

View File

@@ -3,27 +3,27 @@ package wacc
import java.io.File import java.io.File
import scala.collection.mutable import scala.collection.mutable
import cats.effect.IO import cats.effect.IO
import cats.syntax.all._
import cats.implicits._ import cats.implicits._
import cats.data.Chain import cats.data.Chain
import cats.data.NonEmptyList import cats.data.NonEmptyList
import parsley.{Failure, Success} import parsley.{Failure, Success}
private val MAIN = "$main"
object renamer { object renamer {
import ast._ import ast._
import types._ import types._
private enum IdentType { val MAIN = "$main"
enum IdentType {
case Func case Func
case Var case Var
} }
private case class ScopeKey(path: String, name: String, identType: IdentType) case class ScopeKey(path: String, name: String, identType: IdentType)
private case class ScopeValue(id: Ident, public: Boolean) case class ScopeValue(id: Ident, public: Boolean)
private class Scope( class Scope(
private val current: mutable.Map[ScopeKey, ScopeValue], private val current: mutable.Map[ScopeKey, ScopeValue],
private val parent: Map[ScopeKey, ScopeValue], private val parent: Map[ScopeKey, ScopeValue],
guidStart: Int = 0, guidStart: Int = 0,
@@ -153,7 +153,7 @@ object renamer {
get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id) get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id)
} }
private def prepareGlobalScope( def prepareGlobalScope(
partialProg: PartialProgram partialProg: PartialProgram
)(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = { )(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = {
def readImportFile(file: File): IO[String] = def readImportFile(file: File): IO[String] =
@@ -267,25 +267,13 @@ object renamer {
* @return * @return
* (flattenedProg, errors) * (flattenedProg, errors)
*/ */
private def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = { def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = {
val (FuncDecl(_, _, params, body), subscope) = funcScopePair val (FuncDecl(_, _, params, body), subscope) = funcScopePair
val paramErrors = params.foldMap(param => subscope.add(param.name)) val paramErrors = params.foldMap(param => subscope.add(param.name))
IO(subscope.withSubscope { s => body.foldMap(rename(s)) }) IO(subscope.withSubscope { s => body.foldMap(rename(s)) })
.map(bodyErrors => paramErrors ++ bodyErrors) .map(bodyErrors => paramErrors ++ bodyErrors)
} }
def rename(partialProg: PartialProgram): IO[(Program, Chain[Error])] = {
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
for {
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
toRename = (main +: chunks).toList
allErrors <- toRename
.zip(scope.subscopes(toRename.size))
.parFoldMapA(renameFunction)
} yield (Program(chunks.toList, main.body)(main.pos), globalErrors ++ allErrors)
}
/** Check scoping of all identifies in a given AST node. /** Check scoping of all identifies in a given AST node.
* *
* @param scope * @param scope

View File

@@ -0,0 +1,42 @@
package wacc
import scala.collection.mutable
import cats.implicits._
import cats.data.Chain
import cats.effect.IO
object semantics {
import renamer.{Scope, prepareGlobalScope, renameFunction}
import typeChecker.checkFuncDecl
private def checkFunc(
funcDecl: ast.FuncDecl,
scope: Scope
): IO[Chain[(microWacc.FuncDecl, Chain[Error])]] = {
for {
renamerErrors <- renameFunction(funcDecl, scope)
(microWaccFunc, typeErrors) = checkFuncDecl(funcDecl)
} yield Chain.one(microWaccFunc, renamerErrors ++ typeErrors)
}
def check(partialProg: ast.PartialProgram): IO[(microWacc.Program, Chain[Error])] = {
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
for {
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
toRename = (main +: chunks).toList
res <- toRename
.zip(scope.subscopes(toRename.size))
.parFoldMapA(checkFunc)
(typedChunks, errors) = res.foldLeft((Chain.empty[microWacc.FuncDecl], Chain.empty[Error])) {
case ((acc, err), (funcDecl, errors)) =>
(acc :+ funcDecl, err ++ errors)
}
(typedMain, funcs) = typedChunks.uncons match {
case Some((head, tail)) => (head.body, tail)
case None => (Chain.empty, Chain.empty)
}
} yield (microWacc.Program(funcs, typedMain), globalErrors ++ errors)
}
}

View File

@@ -1,20 +1,12 @@
package wacc package wacc
import cats.syntax.all._ import cats.syntax.all._
import scala.collection.mutable
import cats.data.NonEmptyList import cats.data.NonEmptyList
import cats.data.Chain
object typeChecker { object typeChecker {
import wacc.types._ import wacc.types._
case class TypeCheckerCtx(
errors: mutable.Builder[Error, List[Error]]
) {
def error(err: Error): SemType =
errors += err
?
}
private enum Constraint { private enum Constraint {
case Unconstrained case Unconstrained
// Allows weakening in one direction // Allows weakening in one direction
@@ -38,31 +30,29 @@ object typeChecker {
* @return * @return
* The type if the constraint was satisfied, or ? if it was not. * The type if the constraint was satisfied, or ? if it was not.
*/ */
private def satisfies(constraint: Constraint, pos: ast.Position)(using private def satisfies(constraint: Constraint, pos: ast.Position): (SemType, Chain[Error]) =
ctx: TypeCheckerCtx
): SemType =
(ty, constraint) match { (ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) => case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String (KnownType.String, Chain.empty)
case ( case (
KnownType.String, KnownType.String,
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _) Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
) => ) =>
KnownType.String (KnownType.String, Chain.empty)
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) => case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
ty.satisfies(Constraint.Is(ty2, msg), pos) ty.satisfies(Constraint.Is(ty2, msg), pos)
// Change to IsUnweakenable to disallow recursive weakening // Change to IsUnweakenable to disallow recursive weakening
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos) case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty case (ty, Constraint.Unconstrained) => (ty, Chain.empty)
case (ty, Constraint.Never(msg)) => case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg)) (?, Chain.one(Error.SemanticError(pos, msg)))
case (ty, Constraint.IsEither(ty1, ty2, msg)) => case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse { (ty moreSpecific ty1).orElse(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg)) (?, Chain.one(Error.TypeMismatch(pos, ty1, ty, msg)))
} }
case (ty, Constraint.IsUnweakenable(ty2, msg)) => case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse { (ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg)) (?, Chain.one(Error.TypeMismatch(pos, ty2, ty, msg)))
} }
} }
@@ -86,35 +76,29 @@ object typeChecker {
} }
} }
/** Type-check a WACC program. /** Type-check a function declaration.
* *
* @param prog * @param func
* The AST of the program to type-check. * The AST of the function to type-check.
* @param ctx
* The type checker context which includes the global names and functions, and an errors
* builder.
*/ */
def check(prog: ast.Program, errors: mutable.Builder[Error, List[Error]]): microWacc.Program = def checkFuncDecl(func: ast.FuncDecl): (microWacc.FuncDecl, Chain[Error]) = {
given ctx: TypeCheckerCtx = TypeCheckerCtx(errors) val ast.FuncDecl(_, name, params, stmts) = func
microWacc.Program( val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType]
// Ignore function syntax types for return value and params, since those have been converted val returnConstraint =
// to SemTypes by the renamer. if func.name.v == renamer.MAIN then Constraint.Never("main body must not return")
prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) => else Constraint.Is(retType, s"function ${name.v} must return $retType")
val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType] val (body, bodyErrors) = stmts.foldMap(checkStmt(_, returnConstraint))
microWacc.FuncDecl( (
microWacc.Ident(name.v, name.guid)(retType), microWacc.FuncDecl(
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => microWacc.Ident(name.v, name.guid)(retType),
microWacc.Ident(ident.v, ident.guid)(ty) params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
}, microWacc.Ident(ident.v, ident.guid)(ty)
stmts.toList },
.flatMap( body
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) ),
) bodyErrors
)
},
prog.main.toList
.flatMap(checkStmt(_, Constraint.Never("main function must not return")))
) )
}
/** Type-check an AST statement node. /** Type-check an AST statement node.
* *
@@ -123,45 +107,51 @@ object typeChecker {
* @param returnConstraint * @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy. * The constraint that any `return <expr>` statements must satisfy.
*/ */
private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using private def checkStmt(
ctx: TypeCheckerCtx stmt: ast.Stmt,
): List[microWacc.Stmt] = stmt match { returnConstraint: Constraint
): (Chain[microWacc.Stmt], Chain[Error]) = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer. // Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case ast.VarDecl(_, name, value) => case ast.VarDecl(_, name, value) =>
val expectedTy = name.ty val expectedTy = name.ty
val typedValue = checkValue( val (typedValue, valueErrors) = checkValue(
value, value,
Constraint.Is( Constraint.Is(
expectedTy.asInstanceOf[SemType], expectedTy.asInstanceOf[SemType],
s"variable ${name.v} must be assigned a value of type $expectedTy" s"variable ${name.v} must be assigned a value of type $expectedTy"
) )
) )
List( (
microWacc.Assign( Chain.one(
microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]), microWacc.Assign(
typedValue microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]),
) typedValue
)
),
valueErrors
) )
case ast.Assign(lhs, rhs) => case ast.Assign(lhs, rhs) =>
val lhsTyped = checkLValue(lhs, Constraint.Unconstrained) val (lhsTyped, lhsErrors) = checkLValue(lhs, Constraint.Unconstrained)
val rhsTyped = val (rhsTyped, rhsErrors) =
checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}")) checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}"))
(lhsTyped.ty, rhsTyped.ty) match { val unknownError = (lhsTyped.ty, rhsTyped.ty) match {
case (?, ?) => case (?, ?) =>
ctx.error( Chain.one(
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal") Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
) )
case _ => () case _ => Chain.empty
} }
List(microWacc.Assign(lhsTyped, rhsTyped)) (Chain.one(microWacc.Assign(lhsTyped, rhsTyped)), lhsErrors ++ rhsErrors ++ unknownError)
case ast.Read(dest) => case ast.Read(dest) =>
val destTyped = checkLValue(dest, Constraint.Unconstrained) val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained)
val destTy = destTyped.ty match { val (destTy, destTyErrors) = destTyped.ty match {
case ? => case ? =>
ctx.error( (
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type") ?,
Chain.one(
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
)
) )
?
case destTy => case destTy =>
destTy.satisfies( destTy.satisfies(
Constraint.IsEither( Constraint.IsEither(
@@ -172,49 +162,44 @@ object typeChecker {
dest.pos dest.pos
) )
} }
List( (
microWacc.Assign( Chain.one(
destTyped, microWacc.Assign(
microWacc.Call( destTyped,
microWacc.Builtin.Read, microWacc.Call(
List( microWacc.Builtin.Read,
destTy match { List(
case KnownType.Int => " %d".toMicroWaccCharArray destTy match {
case KnownType.Char | _ => " %c".toMicroWaccCharArray case KnownType.Int => " %d".toMicroWaccCharArray
}, case KnownType.Char | _ => " %c".toMicroWaccCharArray
destTyped },
) destTyped
)
)
)
case ast.Free(lhs) =>
List(
microWacc.Call(
microWacc.Builtin.Free,
List(
checkValue(
lhs,
Constraint.IsEither(
KnownType.Array(?),
KnownType.Pair(?, ?),
"free must be applied to an array or pair"
) )
) )
) )
),
destErrors ++ destTyErrors
)
case ast.Free(lhs) =>
val (lhsTyped, lhsErrors) = checkValue(
lhs,
Constraint.IsEither(
KnownType.Array(?),
KnownType.Pair(?, ?),
"free must be applied to an array or pair"
) )
) )
(Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))), lhsErrors)
case ast.Return(expr) => case ast.Return(expr) =>
List(microWacc.Return(checkValue(expr, returnConstraint))) val (exprTyped, exprErrors) = checkValue(expr, returnConstraint)
(Chain.one(microWacc.Return(exprTyped)), exprErrors)
case ast.Exit(expr) => case ast.Exit(expr) =>
List( val (exprTyped, exprErrors) =
microWacc.Call( checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
microWacc.Builtin.Exit, (Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))), exprErrors)
List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")))
)
)
case ast.Print(expr, newline) => case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained) val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained)
val exprFormat = exprTyped.ty match { val exprFormat = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s" case KnownType.Bool | KnownType.String => "%s"
case KnownType.Array(KnownType.Char) => "%.*s" case KnownType.Array(KnownType.Char) => "%.*s"
@@ -223,7 +208,7 @@ object typeChecker {
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p" case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
} }
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) => val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List( Chain.one(
microWacc.Call( microWacc.Call(
func, func,
List( List(
@@ -233,36 +218,38 @@ object typeChecker {
) )
) )
} }
exprTyped.ty match { (
case KnownType.Bool => exprTyped.ty match {
List( case KnownType.Bool =>
microWacc.If( Chain.one(
exprTyped, microWacc.If(
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), exprTyped,
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
)
) )
) case KnownType.Array(KnownType.Char) =>
case KnownType.Array(KnownType.Char) => printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
printfCall(microWacc.Builtin.PrintCharArray, exprTyped) case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped) },
} exprErrors
)
case ast.If(cond, thenStmt, elseStmt) => case ast.If(cond, thenStmt, elseStmt) =>
List( val (condTyped, condErrors) =
microWacc.If( checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")), val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint))
thenStmt.toList.flatMap(checkStmt(_, returnConstraint)), val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint))
elseStmt.toList.flatMap(checkStmt(_, returnConstraint)) (
) Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)),
condErrors ++ thenErrors ++ elseErrors
) )
case ast.While(cond, body) => case ast.While(cond, body) =>
List( val (condTyped, condErrors) =
microWacc.While( checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")), val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint))
body.toList.flatMap(checkStmt(_, returnConstraint)) (Chain.one(microWacc.While(condTyped, bodyTyped)), condErrors ++ bodyErrors)
) case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint))
) case skip @ ast.Skip() => (Chain.empty, Chain.empty)
case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => List.empty
} }
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits /** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
@@ -275,127 +262,142 @@ object typeChecker {
* @return * @return
* The most specific type of the value if it could be determined, or ? if it could not. * The most specific type of the value if it could be determined, or ? if it could not.
*/ */
private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using private def checkValue(
ctx: TypeCheckerCtx value: ast.LValue | ast.RValue | ast.Expr,
): microWacc.Expr = value match { constraint: Constraint
): (microWacc.Expr, Chain[Error]) = value match {
case l @ ast.IntLiter(v) => case l @ ast.IntLiter(v) =>
KnownType.Int.satisfies(constraint, l.pos) val (_, errors) = KnownType.Int.satisfies(constraint, l.pos)
microWacc.IntLiter(v) (microWacc.IntLiter(v), errors)
case l @ ast.BoolLiter(v) => case l @ ast.BoolLiter(v) =>
KnownType.Bool.satisfies(constraint, l.pos) val (_, errors) = KnownType.Bool.satisfies(constraint, l.pos)
microWacc.BoolLiter(v) (microWacc.BoolLiter(v), errors)
case l @ ast.CharLiter(v) => case l @ ast.CharLiter(v) =>
KnownType.Char.satisfies(constraint, l.pos) val (_, errors) = KnownType.Char.satisfies(constraint, l.pos)
microWacc.CharLiter(v) (microWacc.CharLiter(v), errors)
case l @ ast.StrLiter(v) => case l @ ast.StrLiter(v) =>
KnownType.String.satisfies(constraint, l.pos) val (_, errors) = KnownType.String.satisfies(constraint, l.pos)
v.toMicroWaccCharArray (v.toMicroWaccCharArray, errors)
case l @ ast.PairLiter() => case l @ ast.PairLiter() =>
microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos)
(microWacc.NullLiter()(ty), errors)
case ast.Parens(expr) => checkValue(expr, constraint) case ast.Parens(expr) => checkValue(expr, constraint)
case l @ ast.ArrayLiter(elems) => case l @ ast.ArrayLiter(elems) =>
val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) { val ((elemTy, elemsErrors), elemsTyped) =
case (acc, elem) => elems.mapAccumulate[(SemType, Chain[Error]), microWacc.Expr]((?, Chain.empty)) {
val elemTyped = checkValue( case ((acc, errors), elem) =>
elem, val (elemTyped, elemErrors) = checkValue(
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type") elem,
) Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
(elemTyped.ty, elemTyped) )
} ((elemTyped.ty, errors ++ elemErrors), elemTyped)
val arrayTy = KnownType }
val (arrayTy, arrayErrors) = KnownType
// Start with an unknown param type, make it more specific while checking the elements. // Start with an unknown param type, make it more specific while checking the elements.
.Array(elemTy) .Array(elemTy)
.satisfies(constraint, l.pos) .satisfies(constraint, l.pos)
microWacc.ArrayLiter(elemsTyped)(arrayTy) (microWacc.ArrayLiter(elemsTyped)(arrayTy), elemsErrors ++ arrayErrors)
case l @ ast.NewPair(fst, snd) => case l @ ast.NewPair(fst, snd) =>
val fstTyped = checkValue(fst, Constraint.Unconstrained) val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained)
val sndTyped = checkValue(snd, Constraint.Unconstrained) val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained)
microWacc.ArrayLiter(List(fstTyped, sndTyped))( val (pairTy, pairErrors) =
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
) (microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy), fstErrors ++ sndErrors ++ pairErrors)
case ast.Call(id, args) => case ast.Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType] val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType]
if (args.length != paramTys.length) { val lenError =
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) if (args.length == paramTys.length) then Chain.empty
} else Chain.one(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
// Even if the number of arguments is wrong, we still check the types of the arguments // Even if the number of arguments is wrong, we still check the types of the arguments
// in the best way we can (by taking a zip). // in the best way we can (by taking a zip).
val argsTyped = args.zip(paramTys).map { case (arg, paramTy) => val (argsErrors, argsTyped) =
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) args.zip(paramTys).mapAccumulate(Chain.empty[Error]) { case (errors, (arg, paramTy)) =>
} val (argTyped, argErrors) =
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTy.satisfies(constraint, id.pos)), argsTyped) checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
(errors ++ argErrors, argTyped)
}
val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos)
(
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped),
lenError ++ argsErrors ++ retErrors
)
// Unary operators // Unary operators
case ast.Negate(x) => case ast.Negate(x) =>
microWacc.UnaryOp( val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")), checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
microWacc.UnaryOperator.Negate val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
)(KnownType.Int.satisfies(constraint, x.pos)) (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Negate)(retTy), argErrors ++ retErrors)
case ast.Not(x) => case ast.Not(x) =>
microWacc.UnaryOp( val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")), checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
microWacc.UnaryOperator.Not val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, x.pos)
)(KnownType.Bool.satisfies(constraint, x.pos)) (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Not)(retTy), argErrors ++ retErrors)
case ast.Len(x) => case ast.Len(x) =>
microWacc.UnaryOp( val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")), checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
microWacc.UnaryOperator.Len val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
)(KnownType.Int.satisfies(constraint, x.pos)) (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Len)(retTy), argErrors ++ retErrors)
case ast.Ord(x) => case ast.Ord(x) =>
microWacc.UnaryOp( val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")), checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
microWacc.UnaryOperator.Ord val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
)(KnownType.Int.satisfies(constraint, x.pos)) (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Ord)(retTy), argErrors ++ retErrors)
case ast.Chr(x) => case ast.Chr(x) =>
microWacc.UnaryOp( val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")), checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
microWacc.UnaryOperator.Chr val (retTy, retErrors) = KnownType.Char.satisfies(constraint, x.pos)
)(KnownType.Char.satisfies(constraint, x.pos)) (microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Chr)(retTy), argErrors ++ retErrors)
// Binary operators // Binary operators
case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) => case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) =>
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int") val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
microWacc.BinaryOp( val (xTyped, xErrors) = checkValue(op.x, operand)
checkValue(op.x, operand), val (yTyped, yErrors) = checkValue(op.y, operand)
checkValue(op.y, operand), val (retTy, retErrors) = KnownType.Int.satisfies(constraint, op.pos)
microWacc.BinaryOperator.fromAst(op) (
)(KnownType.Int.satisfies(constraint, op.pos)) microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.Eq | ast.Neq) => case op: (ast.Eq | ast.Neq) =>
val xTyped = checkValue(op.x, Constraint.Unconstrained) val (xTyped, xErrors) = checkValue(op.x, Constraint.Unconstrained)
microWacc.BinaryOp( val (yTyped, yErrors) = checkValue(
xTyped, op.y,
checkValue( Constraint.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
op.y, )
Constraint val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type") (
), microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
microWacc.BinaryOperator.fromAst(op) xErrors ++ yErrors ++ retErrors
)(KnownType.Bool.satisfies(constraint, op.pos)) )
case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) => case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) =>
val xConstraint = Constraint.IsEither( val xConstraint = Constraint.IsEither(
KnownType.Int, KnownType.Int,
KnownType.Char, KnownType.Char,
s"${op.name} operator must be applied to an int or char" s"${op.name} operator must be applied to an int or char"
) )
val xTyped = checkValue(op.x, xConstraint) val (xTyped, xErrors) = checkValue(op.x, xConstraint)
// If x type-check failed, we still want to check y is an Int or Char (rather than ?) // If x type-check failed, we still want to check y is an Int or Char (rather than ?)
val yConstraint = xTyped.ty match { val yConstraint = xTyped.ty match {
case ? => xConstraint case ? => xConstraint
case xTy => case xTy =>
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type") Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
} }
microWacc.BinaryOp( val (yTyped, yErrors) = checkValue(op.y, yConstraint)
xTyped, val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
checkValue(op.y, yConstraint), (
microWacc.BinaryOperator.fromAst(op) microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
)(KnownType.Bool.satisfies(constraint, op.pos)) xErrors ++ yErrors ++ retErrors
)
case op: (ast.And | ast.Or) => case op: (ast.And | ast.Or) =>
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool") val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
microWacc.BinaryOp( val (xTyped, xErrors) = checkValue(op.x, operand)
checkValue(op.x, operand), val (yTyped, yErrors) = checkValue(op.y, operand)
checkValue(op.y, operand), val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
microWacc.BinaryOperator.fromAst(op) (
)(KnownType.Bool.satisfies(constraint, op.pos)) microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case lvalue: ast.LValue => checkLValue(lvalue, constraint) case lvalue: ast.LValue => checkLValue(lvalue, constraint)
} }
@@ -412,68 +414,69 @@ object typeChecker {
* @return * @return
* The most specific type of the value if it could be determined, or ? if it could not. * The most specific type of the value if it could be determined, or ? if it could not.
*/ */
private def checkLValue(value: ast.LValue, constraint: Constraint)(using private def checkLValue(
ctx: TypeCheckerCtx value: ast.LValue,
): microWacc.LValue = value match { constraint: Constraint
): (microWacc.LValue, Chain[Error]) = value match {
case id @ ast.Ident(name, guid, ty) => case id @ ast.Ident(name, guid, ty) =>
microWacc.Ident(name, guid)(ty.asInstanceOf[SemType].satisfies(constraint, id.pos)) val (idTy, idErrors) = ty.asInstanceOf[SemType].satisfies(constraint, id.pos)
(microWacc.Ident(name, guid)(idTy), idErrors)
case ast.ArrayElem(id, indices) => case ast.ArrayElem(id, indices) =>
val arrayTy = id.ty.asInstanceOf[SemType] val arrayTy = id.ty.asInstanceOf[SemType]
val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy.asInstanceOf[SemType]) { val ((elemTy, elemErrors), indicesTyped) =
(acc, elem) => indices.mapAccumulate((arrayTy.asInstanceOf[SemType], Chain.empty[Error])) {
val idxTyped = case ((acc, errors), elem) =>
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) val (idxTyped, idxErrors) =
val next = acc match { checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
case KnownType.Array(innerTy) => innerTy val (next, nextError) = acc match {
case ? => ? // we can keep indexing an unknown type case KnownType.Array(innerTy) => (innerTy, Chain.empty)
case nonArrayTy => case ? => (?, Chain.empty) // we can keep indexing an unknown type
ctx.error( case nonArrayTy =>
Error.TypeMismatch( (
elem.pos, ?,
KnownType.Array(?), Chain.one(
acc, Error.TypeMismatch(
"cannot index into a non-array" elem.pos,
KnownType.Array(?),
acc,
"cannot index into a non-array"
)
)
) )
) }
? ((next, errors ++ idxErrors ++ nextError), idxTyped)
} }
(next, idxTyped) val (retTy, retErrors) = elemTy.satisfies(constraint, value.pos)
}
val firstArrayElem = microWacc.ArrayElem( val firstArrayElem = microWacc.ArrayElem(
microWacc.Ident(id.v, id.guid)(arrayTy), microWacc.Ident(id.v, id.guid)(arrayTy),
indicesTyped.head indicesTyped.head
)(elemTy.satisfies(constraint, value.pos)) )(retTy)
val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) => val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) =>
microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty)) microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty))
} }
// Need to type-check the final arrayElem with the constraint // Need to type-check the final arrayElem with the constraint
microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos)) // TODO: What
(microWacc.ArrayElem(arrayElem.value, arrayElem.index)(retTy), elemErrors ++ retErrors)
case ast.Fst(elem) => case ast.Fst(elem) =>
val elemTyped = checkLValue( val (elemTyped, elemErrors) = checkLValue(
elem, elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair") Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
) )
microWacc.ArrayElem( val (retTy, retErrors) = elemTyped.ty match {
elemTyped, case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos)
microWacc.IntLiter(0) case _ => (?, Chain.one(Error.InternalError(elem.pos, "fst must be applied to a pair")))
)(elemTyped.ty match { }
case KnownType.Pair(left, _) => (microWacc.ArrayElem(elemTyped, microWacc.IntLiter(0))(retTy), elemErrors ++ retErrors)
left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
})
case ast.Snd(elem) => case ast.Snd(elem) =>
val elemTyped = checkLValue( val (elemTyped, elemErrors) = checkLValue(
elem, elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
) )
microWacc.ArrayElem( val (retTy, retErrors) = elemTyped.ty match {
elemTyped, case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
microWacc.IntLiter(1) case _ => (?, Chain.one(Error.InternalError(elem.pos, "snd must be applied to a pair")))
)(elemTyped.ty match { }
case KnownType.Pair(_, right) => (microWacc.ArrayElem(elemTyped, microWacc.IntLiter(1))(retTy), elemErrors ++ retErrors)
right.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
})
} }
extension (s: String) { extension (s: String) {