feat: initial parallel type-checker implementation

This commit is contained in:
2025-03-14 04:09:34 +00:00
parent 42515abf2a
commit 53d47fda63
5 changed files with 322 additions and 283 deletions

View File

@@ -1,6 +1,5 @@
package wacc
import scala.collection.mutable
import cats.data.{Chain, NonEmptyList}
import parsley.{Failure, Success}
@@ -19,6 +18,7 @@ import org.typelevel.log4cats.Logger
import assemblyIR as asm
import cats.data.ValidatedNel
import java.io.File
import cats.data.NonEmptySeq
/*
TODO:
@@ -71,18 +71,15 @@ val outputOpt: Opts[Option[Path]] =
def frontend(
contents: String,
file: File
): IO[Either[NonEmptyList[Error], microWacc.Program]] =
): IO[Either[NonEmptySeq[Error], microWacc.Program]] =
parser.parse(contents) match {
case Failure(msg) => IO.pure(Left(NonEmptyList.one(Error.SyntaxError(file, msg))))
case Failure(msg) => IO.pure(Left(NonEmptySeq.one(Error.SyntaxError(file, msg))))
case Success(fn) =>
val partialProg = fn(file)
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
for {
(prog, renameErrors) <- renamer.rename(partialProg)
_ = errors.addAll(renameErrors.toList)
typedProg = typeChecker.check(prog, errors)
res = errors.result.toNel.toLeft(typedProg)
(typedProg, errors) <- semantics.check(partialProg)
res = NonEmptySeq.fromSeq(errors.iterator.toSeq).map(Left(_)).getOrElse(Right(typedProg))
} yield res
}
@@ -103,12 +100,19 @@ def compile(
// TODO: path, file , the names are confusing (when Path is the type but we are working with files)
def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] =
writer.writeTo(backend(typedProg), outputPath) *>
logger.info(s"Success: ${outputPath.toAbsolutePath}")
val backendStart = System.nanoTime()
val asmLines = backend(typedProg)
val backendEnd = System.nanoTime()
writer.writeTo(asmLines, outputPath) *>
logAction(s"Backend time (${filePath.toRealPath()}): ${(backendEnd - backendStart).toFloat / 1e6} ms") *>
logAction(s"Success: ${outputPath.toAbsolutePath}")
def processProgram(contents: String, file: File, outDir: Path): IO[Int] =
val frontendStart = System.nanoTime()
for {
frontendResult <- frontend(contents, file)
frontendEnd = System.nanoTime()
_ <- logAction(s"Frontend time (${filePath.toRealPath()}): ${(frontendEnd - frontendStart).toFloat / 1e6} ms")
res <- frontendResult match {
case Left(errors) =>
val code = errors.map(err => err.exitCode).toList.min

View File

@@ -1,5 +1,7 @@
package wacc
import cats.data.Chain
object microWacc {
import wacc.types._
@@ -78,12 +80,12 @@ object microWacc {
}
case class Assign(lhs: LValue, rhs: Expr) extends Stmt
case class If(cond: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt
case class While(cond: Expr, body: List[Stmt]) extends Stmt
case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt]) extends Stmt
case class While(cond: Expr, body: Chain[Stmt]) extends Stmt
case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy)
case class Return(expr: Expr) extends Stmt
// Program
case class FuncDecl(name: Ident, params: List[Ident], body: List[Stmt])
case class Program(funcs: List[FuncDecl], stmts: List[Stmt])
case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt])
case class Program(funcs: Chain[FuncDecl], stmts: Chain[Stmt])
}

View File

@@ -3,27 +3,27 @@ package wacc
import java.io.File
import scala.collection.mutable
import cats.effect.IO
import cats.syntax.all._
import cats.implicits._
import cats.data.Chain
import cats.data.NonEmptyList
import parsley.{Failure, Success}
private val MAIN = "$main"
object renamer {
import ast._
import types._
private enum IdentType {
val MAIN = "$main"
enum IdentType {
case Func
case Var
}
private case class ScopeKey(path: String, name: String, identType: IdentType)
private case class ScopeValue(id: Ident, public: Boolean)
case class ScopeKey(path: String, name: String, identType: IdentType)
case class ScopeValue(id: Ident, public: Boolean)
private class Scope(
class Scope(
private val current: mutable.Map[ScopeKey, ScopeValue],
private val parent: Map[ScopeKey, ScopeValue],
guidStart: Int = 0,
@@ -153,7 +153,7 @@ object renamer {
get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id)
}
private def prepareGlobalScope(
def prepareGlobalScope(
partialProg: PartialProgram
)(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = {
def readImportFile(file: File): IO[String] =
@@ -267,25 +267,13 @@ object renamer {
* @return
* (flattenedProg, errors)
*/
private def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = {
def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = {
val (FuncDecl(_, _, params, body), subscope) = funcScopePair
val paramErrors = params.foldMap(param => subscope.add(param.name))
IO(subscope.withSubscope { s => body.foldMap(rename(s)) })
.map(bodyErrors => paramErrors ++ bodyErrors)
}
def rename(partialProg: PartialProgram): IO[(Program, Chain[Error])] = {
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
for {
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
toRename = (main +: chunks).toList
allErrors <- toRename
.zip(scope.subscopes(toRename.size))
.parFoldMapA(renameFunction)
} yield (Program(chunks.toList, main.body)(main.pos), globalErrors ++ allErrors)
}
/** Check scoping of all identifies in a given AST node.
*
* @param scope

View File

@@ -0,0 +1,42 @@
package wacc
import scala.collection.mutable
import cats.implicits._
import cats.data.Chain
import cats.effect.IO
object semantics {
import renamer.{Scope, prepareGlobalScope, renameFunction}
import typeChecker.checkFuncDecl
private def checkFunc(
funcDecl: ast.FuncDecl,
scope: Scope
): IO[Chain[(microWacc.FuncDecl, Chain[Error])]] = {
for {
renamerErrors <- renameFunction(funcDecl, scope)
(microWaccFunc, typeErrors) = checkFuncDecl(funcDecl)
} yield Chain.one(microWaccFunc, renamerErrors ++ typeErrors)
}
def check(partialProg: ast.PartialProgram): IO[(microWacc.Program, Chain[Error])] = {
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
for {
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
toRename = (main +: chunks).toList
res <- toRename
.zip(scope.subscopes(toRename.size))
.parFoldMapA(checkFunc)
(typedChunks, errors) = res.foldLeft((Chain.empty[microWacc.FuncDecl], Chain.empty[Error])) {
case ((acc, err), (funcDecl, errors)) =>
(acc :+ funcDecl, err ++ errors)
}
(typedMain, funcs) = typedChunks.uncons match {
case Some((head, tail)) => (head.body, tail)
case None => (Chain.empty, Chain.empty)
}
} yield (microWacc.Program(funcs, typedMain), globalErrors ++ errors)
}
}

View File

@@ -1,20 +1,12 @@
package wacc
import cats.syntax.all._
import scala.collection.mutable
import cats.data.NonEmptyList
import cats.data.Chain
object typeChecker {
import wacc.types._
case class TypeCheckerCtx(
errors: mutable.Builder[Error, List[Error]]
) {
def error(err: Error): SemType =
errors += err
?
}
private enum Constraint {
case Unconstrained
// Allows weakening in one direction
@@ -38,31 +30,29 @@ object typeChecker {
* @return
* The type if the constraint was satisfied, or ? if it was not.
*/
private def satisfies(constraint: Constraint, pos: ast.Position)(using
ctx: TypeCheckerCtx
): SemType =
private def satisfies(constraint: Constraint, pos: ast.Position): (SemType, Chain[Error]) =
(ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String
(KnownType.String, Chain.empty)
case (
KnownType.String,
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
) =>
KnownType.String
(KnownType.String, Chain.empty)
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
ty.satisfies(Constraint.Is(ty2, msg), pos)
// Change to IsUnweakenable to disallow recursive weakening
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty
case (ty, Constraint.Unconstrained) => (ty, Chain.empty)
case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg))
(?, Chain.one(Error.SemanticError(pos, msg)))
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
(?, Chain.one(Error.TypeMismatch(pos, ty1, ty, msg)))
}
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
(?, Chain.one(Error.TypeMismatch(pos, ty2, ty, msg)))
}
}
@@ -86,35 +76,29 @@ object typeChecker {
}
}
/** Type-check a WACC program.
/** Type-check a function declaration.
*
* @param prog
* The AST of the program to type-check.
* @param ctx
* The type checker context which includes the global names and functions, and an errors
* builder.
* @param func
* The AST of the function to type-check.
*/
def check(prog: ast.Program, errors: mutable.Builder[Error, List[Error]]): microWacc.Program =
given ctx: TypeCheckerCtx = TypeCheckerCtx(errors)
microWacc.Program(
// Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer.
prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) =>
def checkFuncDecl(func: ast.FuncDecl): (microWacc.FuncDecl, Chain[Error]) = {
val ast.FuncDecl(_, name, params, stmts) = func
val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType]
val returnConstraint =
if func.name.v == renamer.MAIN then Constraint.Never("main body must not return")
else Constraint.Is(retType, s"function ${name.v} must return $retType")
val (body, bodyErrors) = stmts.foldMap(checkStmt(_, returnConstraint))
(
microWacc.FuncDecl(
microWacc.Ident(name.v, name.guid)(retType),
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
microWacc.Ident(ident.v, ident.guid)(ty)
},
stmts.toList
.flatMap(
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
)
)
},
prog.main.toList
.flatMap(checkStmt(_, Constraint.Never("main function must not return")))
body
),
bodyErrors
)
}
/** Type-check an AST statement node.
*
@@ -123,45 +107,51 @@ object typeChecker {
* @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy.
*/
private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx
): List[microWacc.Stmt] = stmt match {
private def checkStmt(
stmt: ast.Stmt,
returnConstraint: Constraint
): (Chain[microWacc.Stmt], Chain[Error]) = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case ast.VarDecl(_, name, value) =>
val expectedTy = name.ty
val typedValue = checkValue(
val (typedValue, valueErrors) = checkValue(
value,
Constraint.Is(
expectedTy.asInstanceOf[SemType],
s"variable ${name.v} must be assigned a value of type $expectedTy"
)
)
List(
(
Chain.one(
microWacc.Assign(
microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]),
typedValue
)
),
valueErrors
)
case ast.Assign(lhs, rhs) =>
val lhsTyped = checkLValue(lhs, Constraint.Unconstrained)
val rhsTyped =
val (lhsTyped, lhsErrors) = checkLValue(lhs, Constraint.Unconstrained)
val (rhsTyped, rhsErrors) =
checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}"))
(lhsTyped.ty, rhsTyped.ty) match {
val unknownError = (lhsTyped.ty, rhsTyped.ty) match {
case (?, ?) =>
ctx.error(
Chain.one(
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
)
case _ => ()
case _ => Chain.empty
}
List(microWacc.Assign(lhsTyped, rhsTyped))
(Chain.one(microWacc.Assign(lhsTyped, rhsTyped)), lhsErrors ++ rhsErrors ++ unknownError)
case ast.Read(dest) =>
val destTyped = checkLValue(dest, Constraint.Unconstrained)
val destTy = destTyped.ty match {
val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained)
val (destTy, destTyErrors) = destTyped.ty match {
case ? =>
ctx.error(
(
?,
Chain.one(
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
)
?
)
case destTy =>
destTy.satisfies(
Constraint.IsEither(
@@ -172,7 +162,8 @@ object typeChecker {
dest.pos
)
}
List(
(
Chain.one(
microWacc.Assign(
destTyped,
microWacc.Call(
@@ -186,13 +177,11 @@ object typeChecker {
)
)
)
),
destErrors ++ destTyErrors
)
case ast.Free(lhs) =>
List(
microWacc.Call(
microWacc.Builtin.Free,
List(
checkValue(
val (lhsTyped, lhsErrors) = checkValue(
lhs,
Constraint.IsEither(
KnownType.Array(?),
@@ -200,21 +189,17 @@ object typeChecker {
"free must be applied to an array or pair"
)
)
)
)
)
(Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))), lhsErrors)
case ast.Return(expr) =>
List(microWacc.Return(checkValue(expr, returnConstraint)))
val (exprTyped, exprErrors) = checkValue(expr, returnConstraint)
(Chain.one(microWacc.Return(exprTyped)), exprErrors)
case ast.Exit(expr) =>
List(
microWacc.Call(
microWacc.Builtin.Exit,
List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")))
)
)
val (exprTyped, exprErrors) =
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
(Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))), exprErrors)
case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained)
val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained)
val exprFormat = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Array(KnownType.Char) => "%.*s"
@@ -223,7 +208,7 @@ object typeChecker {
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
}
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List(
Chain.one(
microWacc.Call(
func,
List(
@@ -233,9 +218,10 @@ object typeChecker {
)
)
}
(
exprTyped.ty match {
case KnownType.Bool =>
List(
Chain.one(
microWacc.If(
exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
@@ -245,24 +231,25 @@ object typeChecker {
case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
}
case ast.If(cond, thenStmt, elseStmt) =>
List(
microWacc.If(
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")),
thenStmt.toList.flatMap(checkStmt(_, returnConstraint)),
elseStmt.toList.flatMap(checkStmt(_, returnConstraint))
},
exprErrors
)
case ast.If(cond, thenStmt, elseStmt) =>
val (condTyped, condErrors) =
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint))
val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint))
(
Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)),
condErrors ++ thenErrors ++ elseErrors
)
case ast.While(cond, body) =>
List(
microWacc.While(
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")),
body.toList.flatMap(checkStmt(_, returnConstraint))
)
)
case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => List.empty
val (condTyped, condErrors) =
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint))
(Chain.one(microWacc.While(condTyped, bodyTyped)), condErrors ++ bodyErrors)
case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => (Chain.empty, Chain.empty)
}
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
@@ -275,127 +262,142 @@ object typeChecker {
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx
): microWacc.Expr = value match {
private def checkValue(
value: ast.LValue | ast.RValue | ast.Expr,
constraint: Constraint
): (microWacc.Expr, Chain[Error]) = value match {
case l @ ast.IntLiter(v) =>
KnownType.Int.satisfies(constraint, l.pos)
microWacc.IntLiter(v)
val (_, errors) = KnownType.Int.satisfies(constraint, l.pos)
(microWacc.IntLiter(v), errors)
case l @ ast.BoolLiter(v) =>
KnownType.Bool.satisfies(constraint, l.pos)
microWacc.BoolLiter(v)
val (_, errors) = KnownType.Bool.satisfies(constraint, l.pos)
(microWacc.BoolLiter(v), errors)
case l @ ast.CharLiter(v) =>
KnownType.Char.satisfies(constraint, l.pos)
microWacc.CharLiter(v)
val (_, errors) = KnownType.Char.satisfies(constraint, l.pos)
(microWacc.CharLiter(v), errors)
case l @ ast.StrLiter(v) =>
KnownType.String.satisfies(constraint, l.pos)
v.toMicroWaccCharArray
val (_, errors) = KnownType.String.satisfies(constraint, l.pos)
(v.toMicroWaccCharArray, errors)
case l @ ast.PairLiter() =>
microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos))
val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos)
(microWacc.NullLiter()(ty), errors)
case ast.Parens(expr) => checkValue(expr, constraint)
case l @ ast.ArrayLiter(elems) =>
val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) {
case (acc, elem) =>
val elemTyped = checkValue(
val ((elemTy, elemsErrors), elemsTyped) =
elems.mapAccumulate[(SemType, Chain[Error]), microWacc.Expr]((?, Chain.empty)) {
case ((acc, errors), elem) =>
val (elemTyped, elemErrors) = checkValue(
elem,
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
)
(elemTyped.ty, elemTyped)
((elemTyped.ty, errors ++ elemErrors), elemTyped)
}
val arrayTy = KnownType
val (arrayTy, arrayErrors) = KnownType
// Start with an unknown param type, make it more specific while checking the elements.
.Array(elemTy)
.satisfies(constraint, l.pos)
microWacc.ArrayLiter(elemsTyped)(arrayTy)
(microWacc.ArrayLiter(elemsTyped)(arrayTy), elemsErrors ++ arrayErrors)
case l @ ast.NewPair(fst, snd) =>
val fstTyped = checkValue(fst, Constraint.Unconstrained)
val sndTyped = checkValue(snd, Constraint.Unconstrained)
microWacc.ArrayLiter(List(fstTyped, sndTyped))(
val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained)
val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained)
val (pairTy, pairErrors) =
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
)
(microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy), fstErrors ++ sndErrors ++ pairErrors)
case ast.Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType]
if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
}
val lenError =
if (args.length == paramTys.length) then Chain.empty
else Chain.one(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
// Even if the number of arguments is wrong, we still check the types of the arguments
// in the best way we can (by taking a zip).
val argsTyped = args.zip(paramTys).map { case (arg, paramTy) =>
val (argsErrors, argsTyped) =
args.zip(paramTys).mapAccumulate(Chain.empty[Error]) { case (errors, (arg, paramTy)) =>
val (argTyped, argErrors) =
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
(errors ++ argErrors, argTyped)
}
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTy.satisfies(constraint, id.pos)), argsTyped)
val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos)
(
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped),
lenError ++ argsErrors ++ retErrors
)
// Unary operators
case ast.Negate(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")),
microWacc.UnaryOperator.Negate
)(KnownType.Int.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Negate)(retTy), argErrors ++ retErrors)
case ast.Not(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")),
microWacc.UnaryOperator.Not
)(KnownType.Bool.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Not)(retTy), argErrors ++ retErrors)
case ast.Len(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")),
microWacc.UnaryOperator.Len
)(KnownType.Int.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Len)(retTy), argErrors ++ retErrors)
case ast.Ord(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")),
microWacc.UnaryOperator.Ord
)(KnownType.Int.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Ord)(retTy), argErrors ++ retErrors)
case ast.Chr(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")),
microWacc.UnaryOperator.Chr
)(KnownType.Char.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
val (retTy, retErrors) = KnownType.Char.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Chr)(retTy), argErrors ++ retErrors)
// Binary operators
case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) =>
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
microWacc.BinaryOp(
checkValue(op.x, operand),
checkValue(op.y, operand),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Int.satisfies(constraint, op.pos))
val (xTyped, xErrors) = checkValue(op.x, operand)
val (yTyped, yErrors) = checkValue(op.y, operand)
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.Eq | ast.Neq) =>
val xTyped = checkValue(op.x, Constraint.Unconstrained)
microWacc.BinaryOp(
xTyped,
checkValue(
val (xTyped, xErrors) = checkValue(op.x, Constraint.Unconstrained)
val (yTyped, yErrors) = checkValue(
op.y,
Constraint
.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
Constraint.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
)
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) =>
val xConstraint = Constraint.IsEither(
KnownType.Int,
KnownType.Char,
s"${op.name} operator must be applied to an int or char"
)
val xTyped = checkValue(op.x, xConstraint)
val (xTyped, xErrors) = checkValue(op.x, xConstraint)
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
val yConstraint = xTyped.ty match {
case ? => xConstraint
case xTy =>
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
}
microWacc.BinaryOp(
xTyped,
checkValue(op.y, yConstraint),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
val (yTyped, yErrors) = checkValue(op.y, yConstraint)
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.And | ast.Or) =>
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
microWacc.BinaryOp(
checkValue(op.x, operand),
checkValue(op.y, operand),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
val (xTyped, xErrors) = checkValue(op.x, operand)
val (yTyped, yErrors) = checkValue(op.y, operand)
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case lvalue: ast.LValue => checkLValue(lvalue, constraint)
}
@@ -412,22 +414,27 @@ object typeChecker {
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkLValue(value: ast.LValue, constraint: Constraint)(using
ctx: TypeCheckerCtx
): microWacc.LValue = value match {
private def checkLValue(
value: ast.LValue,
constraint: Constraint
): (microWacc.LValue, Chain[Error]) = value match {
case id @ ast.Ident(name, guid, ty) =>
microWacc.Ident(name, guid)(ty.asInstanceOf[SemType].satisfies(constraint, id.pos))
val (idTy, idErrors) = ty.asInstanceOf[SemType].satisfies(constraint, id.pos)
(microWacc.Ident(name, guid)(idTy), idErrors)
case ast.ArrayElem(id, indices) =>
val arrayTy = id.ty.asInstanceOf[SemType]
val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy.asInstanceOf[SemType]) {
(acc, elem) =>
val idxTyped =
val ((elemTy, elemErrors), indicesTyped) =
indices.mapAccumulate((arrayTy.asInstanceOf[SemType], Chain.empty[Error])) {
case ((acc, errors), elem) =>
val (idxTyped, idxErrors) =
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
val next = acc match {
case KnownType.Array(innerTy) => innerTy
case ? => ? // we can keep indexing an unknown type
val (next, nextError) = acc match {
case KnownType.Array(innerTy) => (innerTy, Chain.empty)
case ? => (?, Chain.empty) // we can keep indexing an unknown type
case nonArrayTy =>
ctx.error(
(
?,
Chain.one(
Error.TypeMismatch(
elem.pos,
KnownType.Array(?),
@@ -435,45 +442,41 @@ object typeChecker {
"cannot index into a non-array"
)
)
?
)
}
(next, idxTyped)
((next, errors ++ idxErrors ++ nextError), idxTyped)
}
val (retTy, retErrors) = elemTy.satisfies(constraint, value.pos)
val firstArrayElem = microWacc.ArrayElem(
microWacc.Ident(id.v, id.guid)(arrayTy),
indicesTyped.head
)(elemTy.satisfies(constraint, value.pos))
)(retTy)
val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) =>
microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty))
}
// Need to type-check the final arrayElem with the constraint
microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos))
// TODO: What
(microWacc.ArrayElem(arrayElem.value, arrayElem.index)(retTy), elemErrors ++ retErrors)
case ast.Fst(elem) =>
val elemTyped = checkLValue(
val (elemTyped, elemErrors) = checkLValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
)
microWacc.ArrayElem(
elemTyped,
microWacc.IntLiter(0)
)(elemTyped.ty match {
case KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
})
val (retTy, retErrors) = elemTyped.ty match {
case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos)
case _ => (?, Chain.one(Error.InternalError(elem.pos, "fst must be applied to a pair")))
}
(microWacc.ArrayElem(elemTyped, microWacc.IntLiter(0))(retTy), elemErrors ++ retErrors)
case ast.Snd(elem) =>
val elemTyped = checkLValue(
val (elemTyped, elemErrors) = checkLValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
)
microWacc.ArrayElem(
elemTyped,
microWacc.IntLiter(1)
)(elemTyped.ty match {
case KnownType.Pair(_, right) =>
right.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
})
val (retTy, retErrors) = elemTyped.ty match {
case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
case _ => (?, Chain.one(Error.InternalError(elem.pos, "snd must be applied to a pair")))
}
(microWacc.ArrayElem(elemTyped, microWacc.IntLiter(1))(retTy), elemErrors ++ retErrors)
}
extension (s: String) {