diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala index 5aa0ad7..89e03d9 100644 --- a/src/main/wacc/Error.scala +++ b/src/main/wacc/Error.scala @@ -5,7 +5,9 @@ import wacc.types._ enum Error { case DuplicateDeclaration(ident: ast.Ident) - case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) + case UndeclaredVariable(ident: ast.Ident) + case UndefinedFunction(ident: ast.Ident) + case FunctionParamsMismatch(pos: Position, expected: Int, got: Int) case SemanticError(pos: Position, msg: String) case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) @@ -19,9 +21,13 @@ def printError(error: Error)(using errorContent: String): Unit = { printPosition(ident.pos) println(s"Duplicate declaration of identifier ${ident.v}") highlight(ident.pos, ident.v.length) - case Error.UndefinedIdentifier(ident, identType) => + case Error.UndeclaredVariable(ident) => printPosition(ident.pos) - println(s"Undefined ${identType.toString.toLowerCase()} ${ident.v}") + println(s"Undeclared variable ${ident.v}") + highlight(ident.pos, ident.v.length) + case Error.UndefinedFunction(ident) => + printPosition(ident.pos) + println(s"Undefined function ${ident.v}") highlight(ident.pos, ident.v.length) case Error.FunctionParamsMismatch(pos, expected, got) => printPosition(pos) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 93bc158..f8db02a 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -34,8 +34,8 @@ def compile(contents: String): Int = { parser.parse(contents) match { case Success(prog) => given errors: mutable.Builder[Error, List[Error]] = List.newBuilder - val names = renamer.rename(prog) - given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, errors) + val (names, funcs) = renamer.rename(prog) + given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) typeChecker.check(prog) if (errors.result.nonEmpty) { given errorContent: String = contents diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala index 6b78dc1..b281283 100644 --- a/src/main/wacc/renamer.scala +++ b/src/main/wacc/renamer.scala @@ -6,14 +6,14 @@ object renamer { import ast._ import types._ - enum IdentType { + private enum IdentType { case Func case Var } - private case class Scope( - current: mutable.Map[(String, IdentType), Ident], - parent: Map[(String, IdentType), Ident] + private class Scope( + val current: mutable.Map[(String, IdentType), Ident], + val parent: Map[(String, IdentType), Ident] ) { /** Create a new scope with the current scope as its parent. @@ -27,12 +27,10 @@ object renamer { /** Attempt to add a new identifier to the current scope. If the identifier already exists in * the current scope, add an error to the error list. * - * @param semType - * The semantic type of the identifier. + * @param ty + * The semantic type of the variable identifier, or function identifier type. * @param name * The name of the identifier. - * @param identType - * The identifier type (function or variable). * @param globalNames * The global map of identifiers to semantic types - the identifier will be added to this * map. @@ -42,22 +40,46 @@ object renamer { * @param errors * The list of errors to append to. */ - def add(semType: SemType, name: Ident, identType: IdentType)(using + def add(ty: SemType | FuncType, name: Ident)(using globalNames: mutable.Map[Ident, SemType], + globalFuncs: mutable.Map[Ident, FuncType], globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ) = { - if (current.contains((name.v, identType))) { - errors += Error.DuplicateDeclaration(name) - } else { - val uid = globalNumbering.getOrElse(name.v, 0) - name.uid = uid - current((name.v, identType)) = name + val identType = ty match { + case _: SemType => IdentType.Var + case _: FuncType => IdentType.Func + } + current.get((name.v, identType)) match { + case Some(Ident(_, uid)) => + errors += Error.DuplicateDeclaration(name) + name.uid = uid + case None => + val uid = globalNumbering.getOrElse(name.v, 0) + name.uid = uid + current((name.v, identType)) = name - globalNames(name) = semType - globalNumbering(name.v) = uid + 1 + ty match { + case semType: SemType => + globalNames(name) = semType + case funcType: FuncType => + globalFuncs(name) = funcType + } + globalNumbering(name.v) = uid + 1 } } + + private def get(name: String, identType: IdentType): Option[Ident] = + // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. + // Neither is there a way to check whether a default exists, so we have to use a try-catch. + try { + Some(current.withDefault(parent)((name, identType))) + } catch { + case _: NoSuchElementException => None + } + + def getVar(name: String): Option[Ident] = get(name, IdentType.Var) + def getFunc(name: String): Option[Ident] = get(name, IdentType.Func) } /** Check scoping of all variables and functions in the program. Also generate semantic types for @@ -72,8 +94,9 @@ object renamer { */ def rename(prog: Program)(using errors: mutable.Builder[Error, List[Error]] - ): Map[Ident, SemType] = + ): (Map[Ident, SemType], Map[Ident, FuncType]) = { given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty + given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty val scope = Scope(mutable.Map.empty, Map.empty) val Program(funcs, main) = prog @@ -84,7 +107,7 @@ object renamer { val paramType = SemType(param.paramType) paramType } - scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) + scope.add(FuncType(SemType(retType), paramTypes), name) (params zip paramTypes, body) } // Only then rename the function bodies @@ -92,12 +115,13 @@ object renamer { .foreach { case (params, body) => val functionScope = scope.subscope params.foreach { case (param, paramType) => - functionScope.add(paramType, param.name, IdentType.Var) + functionScope.add(paramType, param.name) } body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params } main.toList.foreach(rename(scope)) - globalNames.toMap + (globalNames.toMap, globalFuncs.toMap) + } /** Check scoping of all identifies in a given AST node. * @@ -117,6 +141,7 @@ object renamer { node: Ident | Stmt | LValue | RValue | Expr )(using globalNames: mutable.Map[Ident, SemType], + globalFuncs: mutable.Map[Ident, FuncType], globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ): Unit = node match { @@ -126,7 +151,7 @@ object renamer { // Order matters here. Variable isn't declared until after the value is evaluated. rename(scope)(value) // Attempt to add the new variable to the current scope. - scope.add(SemType(synType), name, IdentType.Var) + scope.add(SemType(synType), name) } case If(cond, thenStmt, elseStmt) => { rename(scope)(cond) @@ -158,7 +183,12 @@ object renamer { rename(scope)(snd) } case Call(name, args) => { - renameIdent(scope, name, IdentType.Func) + scope.getFunc(name.v) match { + case Some(Ident(_, uid)) => name.uid = uid + case None => + errors += Error.UndefinedFunction(name) + scope.add(FuncType(?, args.map(_ => ?)), name) + } args.foreach(rename(scope)) } case Fst(elem) => rename(scope)(elem) @@ -175,41 +205,15 @@ object renamer { rename(scope)(op.y) } // Default to variables. Only `call` uses IdentType.Func. - case id: Ident => renameIdent(scope, id, IdentType.Var) + case id: Ident => { + scope.getVar(id.v) match { + case Some(Ident(_, uid)) => id.uid = uid + case None => + errors += Error.UndeclaredVariable(id) + scope.add(?, id) + } + } // These literals cannot contain identifies, exit immediately. case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () } - - /** Lookup an identifier in the current scope and rename it. If the identifier is not found, add - * an error to the error list and add it to the current scope with an unknown type. - * - * @param scope - * The current scope and flattened parent scope. - * @param ident - * The identifier to rename. - * @param identType - * The type of the identifier (function or variable). - * @param globalNames - * Used to add not-found identifiers to scope. - * @param globalNumbering - * Used to add not-found identifiers to scope. - * @param errors - * The list of errors to append to. - */ - private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using - globalNames: mutable.Map[Ident, SemType], - globalNumbering: mutable.Map[String, Int], - errors: mutable.Builder[Error, List[Error]] - ): Unit = { - // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. - // Neither is there a way to check whether a default exists, so we have to use a try-catch. - try { - val Ident(_, uid) = scope.current.withDefault(scope.parent)((ident.v, identType)) - ident.uid = uid - } catch { - case _: NoSuchElementException => - errors += Error.UndefinedIdentifier(ident, identType) - scope.add(?, ident, identType) - } - } } diff --git a/src/main/wacc/typeChecker.scala b/src/main/wacc/typeChecker.scala index 0cfe9d7..38b34e3 100644 --- a/src/main/wacc/typeChecker.scala +++ b/src/main/wacc/typeChecker.scala @@ -9,9 +9,12 @@ object typeChecker { case class TypeCheckerCtx( globalNames: Map[Ident, SemType], + globalFuncs: Map[Ident, FuncType], errors: mutable.Builder[Error, List[Error]] ) { - def typeOf(ident: Ident): SemType = globalNames.withDefault { case Ident(_, -1) => ? }(ident) + def typeOf(ident: Ident): SemType = globalNames(ident) + + def funcType(ident: Ident): FuncType = globalFuncs(ident) def error(err: Error): SemType = errors += err @@ -23,7 +26,6 @@ object typeChecker { case Is(ty: SemType, msg: String) case IsSymmetricCompatible(ty: SemType, msg: String) case IsUnweakanable(ty: SemType, msg: String) - case IsVar(msg: String) case IsEither(ty1: SemType, ty2: SemType, msg: String) case Never(msg: String) } @@ -42,9 +44,6 @@ object typeChecker { ty.satisfies(Constraint.Is(ty2, msg), pos) case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakanable(ty2, msg), pos) case (ty, Constraint.Unconstrained) => ty - case (KnownType.Func(_, _), Constraint.IsVar(msg)) => - ctx.error(Error.SemanticError(pos, msg)) - case (ty, Constraint.IsVar(msg)) => ty case (ty, Constraint.Never(msg)) => ctx.error(Error.SemanticError(pos, msg)) case (ty, Constraint.IsEither(ty1, ty2, msg)) => @@ -74,13 +73,10 @@ object typeChecker { ctx: TypeCheckerCtx ): Unit = { prog.funcs.foreach { case FuncDecl(_, name, _, stmts) => - ctx.typeOf(name) match { - case KnownType.Func(retType, _) => - stmts.toList.foreach( - checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) - ) - case _ => ctx.error(Error.InternalError(name.pos, "function declaration with non-function")) - } + val FuncType(retType, _) = ctx.funcType(name) + stmts.toList.foreach( + checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType")) + ) } prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return"))) } @@ -138,7 +134,7 @@ object typeChecker { checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) case Print(expr, _) => // This constraint should never fail, the scope-checker should have caught it already - checkValue(expr, Constraint.IsVar("print value must be a variable")) + checkValue(expr, Constraint.Unconstrained) case If(cond, thenStmt, elseStmt) => checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")) thenStmt.toList.foreach(checkStmt(_, returnConstraint)) @@ -192,20 +188,14 @@ object typeChecker { ) .satisfies(constraint, l.pos) case Call(id, args) => - val funcTy = ctx.typeOf(id) - funcTy match { - case KnownType.Func(retTy, paramTys) => - if (args.length != paramTys.length) { - ctx.error(Error.FunctionParamsMismatch(id.pos, paramTys.length, args.length)) - } - args.zip(paramTys).foreach { case (arg, paramTy) => - checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) - } - retTy.satisfies(constraint, id.pos) - // Should never happen, the scope-checker should have caught this already - // ctx error had it not - case _ => ctx.error(Error.InternalError(id.pos, "function call to non-function")) + val FuncType(retTy, paramTys) = ctx.funcType(id) + if (args.length != paramTys.length) { + ctx.error(Error.FunctionParamsMismatch(id.pos, paramTys.length, args.length)) } + args.zip(paramTys).foreach { case (arg, paramTy) => + checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) + } + retTy.satisfies(constraint, id.pos) case Fst(elem) => checkValue( elem, @@ -213,7 +203,6 @@ object typeChecker { ) match { case what @ KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos) - case ? => ?.satisfies(constraint, elem.pos) case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair")) } // satisfies constraint case Snd(elem) => @@ -222,7 +211,6 @@ object typeChecker { Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair") ) match { case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos) - case ? => ?.satisfies(constraint, elem.pos) case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) } diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index ebe8517..41d4124 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -11,7 +11,6 @@ object types { case KnownType.String => "string" case KnownType.Array(elem) => s"$elem[]" case KnownType.Pair(left, right) => s"pair($left, $right)" - case KnownType.Func(ret, params) => s"function returning $ret with params $params" case ? => "?" } } @@ -24,7 +23,6 @@ object types { case String case Array(elem: SemType) case Pair(left: SemType, right: SemType) - case Func(ret: SemType, params: List[SemType]) } object SemType { @@ -40,4 +38,6 @@ object types { case UntypedPairType() => KnownType.Pair(?, ?) } } + + case class FuncType(returnType: SemType, params: List[SemType]) }