diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala new file mode 100644 index 0000000..6370925 --- /dev/null +++ b/src/main/wacc/Error.scala @@ -0,0 +1,8 @@ +package wacc + +enum Error { + case DuplicateDeclaration(ident: ast.Ident) + case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) + case FunctionParamsMismatch(expected: Int, got: Int) + case TypeMismatch(expected: types.SemType, got: types.SemType) +} diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index d4b070a..a271c3c 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,5 +1,6 @@ package wacc +import scala.collection.mutable import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -32,8 +33,15 @@ val cliParser = { def compile(contents: String): Int = { parser.parse(contents) match { case Success(ast) => - // TODO: Do semantics things - 0 + given errors: mutable.Builder[Error, List[Error]] = List.newBuilder + renamer.rename(ast) + // given ctx: types.TypeCheckerCtx[List[Error]] = + // types.TypeCheckerCtx(names, errors) + // types.check(ast) + if (errors.result.nonEmpty) { + errors.result.foreach(println) + 200 + } else 0 case Failure(msg) => println(msg) 100 diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 0076006..123adb5 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -27,9 +27,11 @@ object ast { case class StrLiter(v: String)(pos: Position) extends Expr6 object StrLiter extends ParserBridgePos1[String, StrLiter] case class PairLiter()(pos: Position) extends Expr6 - object PairLiter extends Expr6 with ParserBridgePos0[PairLiter] - case class Ident(v: String)(pos: Position) extends Expr6 with LValue - object Ident extends ParserBridgePos1[String, Ident] + object PairLiter extends ParserBridgePos0[PairLiter] + case class Ident(v: String, var uid: Int = -1)(pos: Position) extends Expr6 with LValue + object Ident extends ParserBridgePos1[String, Ident] { + def apply(v: 
String)(pos: Position): Ident = new Ident(v)(pos) + } case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) extends Expr6 with LValue diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala new file mode 100644 index 0000000..6b78dc1 --- /dev/null +++ b/src/main/wacc/renamer.scala @@ -0,0 +1,215 @@ +package wacc + +import scala.collection.mutable + +object renamer { + import ast._ + import types._ + + enum IdentType { + case Func + case Var + } + + private case class Scope( + current: mutable.Map[(String, IdentType), Ident], + parent: Map[(String, IdentType), Ident] + ) { + + /** Create a new scope with the current scope as its parent. + * + * @return + * A new scope with an empty current scope, and this scope flattened into the parent scope. + */ + def subscope: Scope = + Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + + /** Attempt to add a new identifier to the current scope. If the identifier already exists in + * the current scope, add an error to the error list. + * + * @param semType + * The semantic type of the identifier. + * @param name + * The name of the identifier. + * @param identType + * The identifier type (function or variable). + * @param globalNames + * The global map of identifiers to semantic types - the identifier will be added to this + * map. + * @param globalNumbering + * The global map of identifier names to the number of times they have been declared - will + * be used to rename this identifier, and will be incremented. + * @param errors + * The list of errors to append to. 
+ */ + def add(semType: SemType, name: Ident, identType: IdentType)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ) = { + if (current.contains((name.v, identType))) { + errors += Error.DuplicateDeclaration(name) + } else { + val uid = globalNumbering.getOrElse(name.v, 0) + name.uid = uid + current((name.v, identType)) = name + + globalNames(name) = semType + globalNumbering(name.v) = uid + 1 + } + } + } + + /** Check scoping of all variables and functions in the program. Also generate semantic types for + * all identifiers. + * + * @param prog + * AST of the program + * @param errors + * List of errors to append to + * @return + * Map of all (renamed) identifiers to their semantic types + */ + def rename(prog: Program)(using + errors: mutable.Builder[Error, List[Error]] + ): Map[Ident, SemType] = + given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty + given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty + val scope = Scope(mutable.Map.empty, Map.empty) + val Program(funcs, main) = prog + funcs + // First add all function declarations to the scope + .map { case FuncDecl(retType, name, params, body) => + val paramTypes = params.map { param => + val paramType = SemType(param.paramType) + paramType + } + scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) + (params zip paramTypes, body) + } + // Only then rename the function bodies + // (functions can call one-another regardless of order of declaration) + .foreach { case (params, body) => + val functionScope = scope.subscope + params.foreach { case (param, paramType) => + functionScope.add(paramType, param.name, IdentType.Var) + } + body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params + } + main.toList.foreach(rename(scope)) + globalNames.toMap + + /** Check scoping of all identifiers in a given AST node. 
+ * + * @param scope + * The current scope and flattened parent scope. + * @param node + * The AST node. + * @param globalNames + * The global map of identifiers to semantic types - renamed identifiers will be added to this + * map. + * @param globalNumbering + * The global map of identifier names to the number of times they have been declared - used and + * updated during identifier renaming. + * @param errors + */ + private def rename(scope: Scope)( + node: Ident | Stmt | LValue | RValue | Expr + )(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = node match { + // These cases are more interesting because they involve making subscopes + // or modifying the current scope. + case VarDecl(synType, name, value) => { + // Order matters here. Variable isn't declared until after the value is evaluated. + rename(scope)(value) + // Attempt to add the new variable to the current scope. + scope.add(SemType(synType), name, IdentType.Var) + } + case If(cond, thenStmt, elseStmt) => { + rename(scope)(cond) + // then and else both have their own scopes + thenStmt.toList.foreach(rename(scope.subscope)) + elseStmt.toList.foreach(rename(scope.subscope)) + } + case While(cond, body) => { + rename(scope)(cond) + // while bodies have their own scopes + body.toList.foreach(rename(scope.subscope)) + } + // begin-end blocks have their own scopes + case Block(body) => body.toList.foreach(rename(scope.subscope)) + + // These cases are simpler, mostly just recursive calls to rename() + case Assign(lhs, value) => { + // Variables may be reassigned with their value in the rhs, so order doesn't matter here. 
+ rename(scope)(lhs) + rename(scope)(value) + } + case Read(lhs) => rename(scope)(lhs) + case Free(expr) => rename(scope)(expr) + case Return(expr) => rename(scope)(expr) + case Exit(expr) => rename(scope)(expr) + case Print(expr, _) => rename(scope)(expr) + case NewPair(fst, snd) => { + rename(scope)(fst) + rename(scope)(snd) + } + case Call(name, args) => { + renameIdent(scope, name, IdentType.Func) + args.foreach(rename(scope)) + } + case Fst(elem) => rename(scope)(elem) + case Snd(elem) => rename(scope)(elem) + case ArrayLiter(elems) => elems.foreach(rename(scope)) + case ArrayElem(name, indices) => { + rename(scope)(name) + indices.toList.foreach(rename(scope)) + } + case Parens(expr) => rename(scope)(expr) + case op: UnaryOp => rename(scope)(op.x) + case op: BinaryOp => { + rename(scope)(op.x) + rename(scope)(op.y) + } + // Default to variables. Only `call` uses IdentType.Func. + case id: Ident => renameIdent(scope, id, IdentType.Var) + // These literals cannot contain identifiers, exit immediately. + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () + } + + /** Lookup an identifier in the current scope and rename it. If the identifier is not found, add + * an error to the error list and add it to the current scope with an unknown type. + * + * @param scope + * The current scope and flattened parent scope. + * @param ident + * The identifier to rename. + * @param identType + * The type of the identifier (function or variable). + * @param globalNames + * Used to add not-found identifiers to scope. + * @param globalNumbering + * Used to add not-found identifiers to scope. + * @param errors + * The list of errors to append to. 
+ */ + private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = { + // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. + // Neither is there a way to check whether a default exists, so we have to use a try-catch. + try { + val Ident(_, uid) = scope.current.withDefault(scope.parent)((ident.v, identType)) + ident.uid = uid + } catch { + case _: NoSuchElementException => + errors += Error.UndefinedIdentifier(ident, identType) + scope.add(?, ident, identType) + } + } +} diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala new file mode 100644 index 0000000..e2a7988 --- /dev/null +++ b/src/main/wacc/types.scala @@ -0,0 +1,31 @@ +package wacc + +object types { + import ast._ + + sealed trait SemType + case object ? extends SemType + enum KnownType extends SemType { + case Int + case Bool + case Char + case String + case Array(elem: SemType) + case Pair(left: SemType, right: SemType) + case Func(ret: SemType, params: List[SemType]) + } + + object SemType { + def apply(synType: Type | PairElemType): KnownType = synType match { + case IntType() => KnownType.Int + case BoolType() => KnownType.Bool + case CharType() => KnownType.Char + case StringType() => KnownType.String + // For semantic types it is easier to work with recursion rather than a fixed size + case ArrayType(elemType, dimension) => + (0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc)) + case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd)) + case UntypedPairType() => KnownType.Pair(?, ?) + } + } +}