From 30cf42ee3a23d74aa17d2ab06a3dc200c7f0fc8a Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 05:12:32 +0000 Subject: [PATCH] fix: separate variable and function in scope --- src/main/wacc/Error.scala | 2 +- src/main/wacc/renamer.scala | 48 +++++++++++++++++++++++-------------- src/main/wacc/types.scala | 2 -- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala index a9f0490..6370925 100644 --- a/src/main/wacc/Error.scala +++ b/src/main/wacc/Error.scala @@ -2,7 +2,7 @@ package wacc enum Error { case DuplicateDeclaration(ident: ast.Ident) - case UndefinedIdentifier(ident: ast.Ident) + case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) case FunctionParamsMismatch(expected: Int, got: Int) case TypeMismatch(expected: types.SemType, got: types.SemType) } diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala index 023b317..cdfe19d 100644 --- a/src/main/wacc/renamer.scala +++ b/src/main/wacc/renamer.scala @@ -6,24 +6,29 @@ object renamer { import ast._ import types._ + enum IdentType { + case Func + case Var + } + private case class Scope( - current: mutable.Map[String, Ident], - parent: Map[String, Ident] + current: mutable.Map[(String, IdentType), Ident], + parent: Map[(String, IdentType), Ident] ) { def subscope: Scope = Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) - def add(semType: SemType, name: Ident)(using + def add(semType: SemType, name: Ident, identType: IdentType)(using globalNames: mutable.Map[Ident, SemType], globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ) = { - if (current.contains(name.v)) { + if (current.contains((name.v, identType))) { errors += Error.DuplicateDeclaration(name) } else { val uid = globalNumbering.getOrElse(name.v, 0) name.uid = uid - current(name.v) = name + current((name.v, identType)) = name globalNames(name) = semType globalNumbering(name.v) = uid + 1 @@ -54,16 +59,16 @@ object renamer { val functionScope = scope.subscope val paramTypes = params.map { param => val paramType = SemType(param.paramType) - functionScope.add(paramType, param.name) + functionScope.add(paramType, param.name, IdentType.Var) paramType } - scope.add(KnownType.Func(SemType(retType), paramTypes), name) + scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) body.toList.foreach(rename(functionScope)) } case VarDecl(synType, name, value) => { // Order matters here. Variable isn't declared until after the value is evaluated. rename(scope)(value) - scope.add(SemType(synType), name) + scope.add(SemType(synType), name, IdentType.Var) } case Assign(lhs, value) => { rename(scope)(lhs) @@ -89,7 +94,7 @@ object renamer { rename(scope)(snd) } case Call(name, args) => { - rename(scope)(name) + renameIdent(scope, name, IdentType.Func) args.foreach(rename(scope)) } case Fst(elem) => rename(scope)(elem) @@ -105,15 +110,22 @@ object renamer { rename(scope)(op.x) rename(scope)(op.y) } - case id: Ident => { - scope.current.withDefault(scope.parent).get(id.v) match { - case Some(Ident(_, uid)) => id.uid = uid - case None => { - errors += Error.UndefinedIdentifier(id) - scope.add(?, id) - } - } - } + // Default to variables. Only `call` uses IdentType.Func. + case id: Ident => renameIdent(scope, id, IdentType.Var) case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter | Skip => () } + + private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = { + scope.current.withDefault(scope.parent).get((ident.v, identType)) match { + case Some(Ident(_, uid)) => ident.uid = uid + case None => { + errors += Error.UndefinedIdentifier(ident, identType) + scope.add(?, ident, identType) + } + } + } } diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index 2ce5f27..388416f 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -1,7 +1,5 @@ package wacc -import scala.collection.mutable - object types { import ast._