refactor: make functions non-semantic types
This commit is contained in:
parent
0f87725f62
commit
f143f685c4
@ -5,7 +5,9 @@ import wacc.types._
|
||||
|
||||
enum Error {
|
||||
case DuplicateDeclaration(ident: ast.Ident)
|
||||
case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType)
|
||||
case UndeclaredVariable(ident: ast.Ident)
|
||||
case UndefinedFunction(ident: ast.Ident)
|
||||
|
||||
case FunctionParamsMismatch(pos: Position, expected: Int, got: Int)
|
||||
case SemanticError(pos: Position, msg: String)
|
||||
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
|
||||
@ -19,9 +21,13 @@ def printError(error: Error)(using errorContent: String): Unit = {
|
||||
printPosition(ident.pos)
|
||||
println(s"Duplicate declaration of identifier ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedIdentifier(ident, identType) =>
|
||||
case Error.UndeclaredVariable(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Undefined ${identType.toString.toLowerCase()} ${ident.v}")
|
||||
println(s"Undeclared variable ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedFunction(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Undefined function ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.FunctionParamsMismatch(pos, expected, got) =>
|
||||
printPosition(pos)
|
||||
|
@ -34,8 +34,8 @@ def compile(contents: String): Int = {
|
||||
parser.parse(contents) match {
|
||||
case Success(prog) =>
|
||||
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
|
||||
val names = renamer.rename(prog)
|
||||
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, errors)
|
||||
val (names, funcs) = renamer.rename(prog)
|
||||
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
|
||||
typeChecker.check(prog)
|
||||
if (errors.result.nonEmpty) {
|
||||
given errorContent: String = contents
|
||||
|
@ -6,14 +6,14 @@ object renamer {
|
||||
import ast._
|
||||
import types._
|
||||
|
||||
enum IdentType {
|
||||
private enum IdentType {
|
||||
case Func
|
||||
case Var
|
||||
}
|
||||
|
||||
private case class Scope(
|
||||
current: mutable.Map[(String, IdentType), Ident],
|
||||
parent: Map[(String, IdentType), Ident]
|
||||
private class Scope(
|
||||
val current: mutable.Map[(String, IdentType), Ident],
|
||||
val parent: Map[(String, IdentType), Ident]
|
||||
) {
|
||||
|
||||
/** Create a new scope with the current scope as its parent.
|
||||
@ -27,12 +27,10 @@ object renamer {
|
||||
/** Attempt to add a new identifier to the current scope. If the identifier already exists in
|
||||
* the current scope, add an error to the error list.
|
||||
*
|
||||
* @param semType
|
||||
* The semantic type of the identifier.
|
||||
* @param ty
|
||||
* The semantic type of the variable identifier, or function identifier type.
|
||||
* @param name
|
||||
* The name of the identifier.
|
||||
* @param identType
|
||||
* The identifier type (function or variable).
|
||||
* @param globalNames
|
||||
* The global map of identifiers to semantic types - the identifier will be added to this
|
||||
* map.
|
||||
@ -42,22 +40,46 @@ object renamer {
|
||||
* @param errors
|
||||
* The list of errors to append to.
|
||||
*/
|
||||
def add(semType: SemType, name: Ident, identType: IdentType)(using
|
||||
def add(ty: SemType | FuncType, name: Ident)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalFuncs: mutable.Map[Ident, FuncType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
) = {
|
||||
if (current.contains((name.v, identType))) {
|
||||
val identType = ty match {
|
||||
case _: SemType => IdentType.Var
|
||||
case _: FuncType => IdentType.Func
|
||||
}
|
||||
current.get((name.v, identType)) match {
|
||||
case Some(Ident(_, uid)) =>
|
||||
errors += Error.DuplicateDeclaration(name)
|
||||
} else {
|
||||
name.uid = uid
|
||||
case None =>
|
||||
val uid = globalNumbering.getOrElse(name.v, 0)
|
||||
name.uid = uid
|
||||
current((name.v, identType)) = name
|
||||
|
||||
ty match {
|
||||
case semType: SemType =>
|
||||
globalNames(name) = semType
|
||||
case funcType: FuncType =>
|
||||
globalFuncs(name) = funcType
|
||||
}
|
||||
globalNumbering(name.v) = uid + 1
|
||||
}
|
||||
}
|
||||
|
||||
private def get(name: String, identType: IdentType): Option[Ident] =
|
||||
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
|
||||
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
|
||||
try {
|
||||
Some(current.withDefault(parent)((name, identType)))
|
||||
} catch {
|
||||
case _: NoSuchElementException => None
|
||||
}
|
||||
|
||||
def getVar(name: String): Option[Ident] = get(name, IdentType.Var)
|
||||
def getFunc(name: String): Option[Ident] = get(name, IdentType.Func)
|
||||
}
|
||||
|
||||
/** Check scoping of all variables and functions in the program. Also generate semantic types for
|
||||
@ -72,8 +94,9 @@ object renamer {
|
||||
*/
|
||||
def rename(prog: Program)(using
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): Map[Ident, SemType] =
|
||||
): (Map[Ident, SemType], Map[Ident, FuncType]) = {
|
||||
given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty
|
||||
given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty
|
||||
given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty
|
||||
val scope = Scope(mutable.Map.empty, Map.empty)
|
||||
val Program(funcs, main) = prog
|
||||
@ -84,7 +107,7 @@ object renamer {
|
||||
val paramType = SemType(param.paramType)
|
||||
paramType
|
||||
}
|
||||
scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func)
|
||||
scope.add(FuncType(SemType(retType), paramTypes), name)
|
||||
(params zip paramTypes, body)
|
||||
}
|
||||
// Only then rename the function bodies
|
||||
@ -92,12 +115,13 @@ object renamer {
|
||||
.foreach { case (params, body) =>
|
||||
val functionScope = scope.subscope
|
||||
params.foreach { case (param, paramType) =>
|
||||
functionScope.add(paramType, param.name, IdentType.Var)
|
||||
functionScope.add(paramType, param.name)
|
||||
}
|
||||
body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params
|
||||
}
|
||||
main.toList.foreach(rename(scope))
|
||||
globalNames.toMap
|
||||
(globalNames.toMap, globalFuncs.toMap)
|
||||
}
|
||||
|
||||
/** Check scoping of all identifies in a given AST node.
|
||||
*
|
||||
@ -117,6 +141,7 @@ object renamer {
|
||||
node: Ident | Stmt | LValue | RValue | Expr
|
||||
)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalFuncs: mutable.Map[Ident, FuncType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): Unit = node match {
|
||||
@ -126,7 +151,7 @@ object renamer {
|
||||
// Order matters here. Variable isn't declared until after the value is evaluated.
|
||||
rename(scope)(value)
|
||||
// Attempt to add the new variable to the current scope.
|
||||
scope.add(SemType(synType), name, IdentType.Var)
|
||||
scope.add(SemType(synType), name)
|
||||
}
|
||||
case If(cond, thenStmt, elseStmt) => {
|
||||
rename(scope)(cond)
|
||||
@ -158,7 +183,12 @@ object renamer {
|
||||
rename(scope)(snd)
|
||||
}
|
||||
case Call(name, args) => {
|
||||
renameIdent(scope, name, IdentType.Func)
|
||||
scope.getFunc(name.v) match {
|
||||
case Some(Ident(_, uid)) => name.uid = uid
|
||||
case None =>
|
||||
errors += Error.UndefinedFunction(name)
|
||||
scope.add(FuncType(?, args.map(_ => ?)), name)
|
||||
}
|
||||
args.foreach(rename(scope))
|
||||
}
|
||||
case Fst(elem) => rename(scope)(elem)
|
||||
@ -175,41 +205,15 @@ object renamer {
|
||||
rename(scope)(op.y)
|
||||
}
|
||||
// Default to variables. Only `call` uses IdentType.Func.
|
||||
case id: Ident => renameIdent(scope, id, IdentType.Var)
|
||||
case id: Ident => {
|
||||
scope.getVar(id.v) match {
|
||||
case Some(Ident(_, uid)) => id.uid = uid
|
||||
case None =>
|
||||
errors += Error.UndeclaredVariable(id)
|
||||
scope.add(?, id)
|
||||
}
|
||||
}
|
||||
// These literals cannot contain identifies, exit immediately.
|
||||
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => ()
|
||||
}
|
||||
|
||||
/** Lookup an identifier in the current scope and rename it. If the identifier is not found, add
|
||||
* an error to the error list and add it to the current scope with an unknown type.
|
||||
*
|
||||
* @param scope
|
||||
* The current scope and flattened parent scope.
|
||||
* @param ident
|
||||
* The identifier to rename.
|
||||
* @param identType
|
||||
* The type of the identifier (function or variable).
|
||||
* @param globalNames
|
||||
* Used to add not-found identifiers to scope.
|
||||
* @param globalNumbering
|
||||
* Used to add not-found identifiers to scope.
|
||||
* @param errors
|
||||
* The list of errors to append to.
|
||||
*/
|
||||
private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): Unit = {
|
||||
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
|
||||
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
|
||||
try {
|
||||
val Ident(_, uid) = scope.current.withDefault(scope.parent)((ident.v, identType))
|
||||
ident.uid = uid
|
||||
} catch {
|
||||
case _: NoSuchElementException =>
|
||||
errors += Error.UndefinedIdentifier(ident, identType)
|
||||
scope.add(?, ident, identType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -9,9 +9,12 @@ object typeChecker {
|
||||
|
||||
case class TypeCheckerCtx(
|
||||
globalNames: Map[Ident, SemType],
|
||||
globalFuncs: Map[Ident, FuncType],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
) {
|
||||
def typeOf(ident: Ident): SemType = globalNames.withDefault { case Ident(_, -1) => ? }(ident)
|
||||
def typeOf(ident: Ident): SemType = globalNames(ident)
|
||||
|
||||
def funcType(ident: Ident): FuncType = globalFuncs(ident)
|
||||
|
||||
def error(err: Error): SemType =
|
||||
errors += err
|
||||
@ -23,7 +26,6 @@ object typeChecker {
|
||||
case Is(ty: SemType, msg: String)
|
||||
case IsSymmetricCompatible(ty: SemType, msg: String)
|
||||
case IsUnweakanable(ty: SemType, msg: String)
|
||||
case IsVar(msg: String)
|
||||
case IsEither(ty1: SemType, ty2: SemType, msg: String)
|
||||
case Never(msg: String)
|
||||
}
|
||||
@ -42,9 +44,6 @@ object typeChecker {
|
||||
ty.satisfies(Constraint.Is(ty2, msg), pos)
|
||||
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos)
|
||||
case (ty, Constraint.Unconstrained) => ty
|
||||
case (KnownType.Func(_, _), Constraint.IsVar(msg)) =>
|
||||
ctx.error(Error.SemanticError(pos, msg))
|
||||
case (ty, Constraint.IsVar(msg)) => ty
|
||||
case (ty, Constraint.Never(msg)) =>
|
||||
ctx.error(Error.SemanticError(pos, msg))
|
||||
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
|
||||
@ -74,13 +73,10 @@ object typeChecker {
|
||||
ctx: TypeCheckerCtx
|
||||
): Unit = {
|
||||
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
|
||||
ctx.typeOf(name) match {
|
||||
case KnownType.Func(retType, _) =>
|
||||
val FuncType(retType, _) = ctx.funcType(name)
|
||||
stmts.toList.foreach(
|
||||
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
|
||||
)
|
||||
case _ => ctx.error(Error.InternalError(name.pos, "function declaration with non-function"))
|
||||
}
|
||||
}
|
||||
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
|
||||
}
|
||||
@ -138,7 +134,7 @@ object typeChecker {
|
||||
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
|
||||
case Print(expr, _) =>
|
||||
// This constraint should never fail, the scope-checker should have caught it already
|
||||
checkValue(expr, Constraint.IsVar("print value must be a variable"))
|
||||
checkValue(expr, Constraint.Unconstrained)
|
||||
case If(cond, thenStmt, elseStmt) =>
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
|
||||
thenStmt.toList.foreach(checkStmt(_, returnConstraint))
|
||||
@ -192,9 +188,7 @@ object typeChecker {
|
||||
)
|
||||
.satisfies(constraint, l.pos)
|
||||
case Call(id, args) =>
|
||||
val funcTy = ctx.typeOf(id)
|
||||
funcTy match {
|
||||
case KnownType.Func(retTy, paramTys) =>
|
||||
val FuncType(retTy, paramTys) = ctx.funcType(id)
|
||||
if (args.length != paramTys.length) {
|
||||
ctx.error(Error.FunctionParamsMismatch(id.pos, paramTys.length, args.length))
|
||||
}
|
||||
@ -202,10 +196,6 @@ object typeChecker {
|
||||
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
|
||||
}
|
||||
retTy.satisfies(constraint, id.pos)
|
||||
// Should never happen, the scope-checker should have caught this already
|
||||
// ctx error had it not
|
||||
case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function"))
|
||||
}
|
||||
case Fst(elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
|
@ -11,7 +11,6 @@ object types {
|
||||
case KnownType.String => "string"
|
||||
case KnownType.Array(elem) => s"$elem[]"
|
||||
case KnownType.Pair(left, right) => s"pair($left, $right)"
|
||||
case KnownType.Func(ret, params) => s"function returning $ret with params $params"
|
||||
case ? => "?"
|
||||
}
|
||||
}
|
||||
@ -24,7 +23,6 @@ object types {
|
||||
case String
|
||||
case Array(elem: SemType)
|
||||
case Pair(left: SemType, right: SemType)
|
||||
case Func(ret: SemType, params: List[SemType])
|
||||
}
|
||||
|
||||
object SemType {
|
||||
@ -40,4 +38,6 @@ object types {
|
||||
case UntypedPairType() => KnownType.Pair(?, ?)
|
||||
}
|
||||
}
|
||||
|
||||
case class FuncType(returnType: SemType, params: List[SemType])
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user