From 3fbb90322fd727e259ae26f19ecac9461a38495c Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 4 Feb 2025 22:26:38 +0000 Subject: [PATCH] feat: renamer maybe maybe maybe maybe --- src/main/wacc/Error.scala | 8 +++ src/main/wacc/Main.scala | 12 +++- src/main/wacc/ast.scala | 19 +++--- src/main/wacc/renamer.scala | 119 ++++++++++++++++++++++++++++++++++++ src/main/wacc/types.scala | 32 ++++++++++ 5 files changed, 181 insertions(+), 9 deletions(-) create mode 100644 src/main/wacc/Error.scala create mode 100644 src/main/wacc/renamer.scala create mode 100644 src/main/wacc/types.scala diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala new file mode 100644 index 0000000..a9f0490 --- /dev/null +++ b/src/main/wacc/Error.scala @@ -0,0 +1,8 @@ +package wacc + +enum Error { + case DuplicateDeclaration(ident: ast.Ident) + case UndefinedIdentifier(ident: ast.Ident) + case FunctionParamsMismatch(expected: Int, got: Int) + case TypeMismatch(expected: types.SemType, got: types.SemType) +} diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index d4b070a..dcc7e6d 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,5 +1,6 @@ package wacc +import scala.collection.mutable import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -32,8 +33,15 @@ val cliParser = { def compile(contents: String): Int = { parser.parse(contents) match { case Success(ast) => - // TODO: Do semantics things - 0 + given errors: mutable.Builder[Error, List[Error]] = List.newBuilder + val names = renamer.rename(ast) + // given ctx: types.TypeCheckerCtx[List[Error]] = + // types.TypeCheckerCtx(names, errors) + // types.check(ast) + if (errors.result.nonEmpty) { + errors.result.foreach(println) + 200 + } else 0 case Failure(msg) => println(msg) 100 diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 0076006..2457c8f 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -28,8 +28,10 @@ object ast { object StrLiter extends ParserBridgePos1[String, StrLiter] case class PairLiter()(pos: Position) extends Expr6 object PairLiter extends Expr6 with ParserBridgePos0[PairLiter] - case class Ident(v: String)(pos: Position) extends Expr6 with LValue - object Ident extends ParserBridgePos1[String, Ident] + case class Ident(v: String, var uid: Int = -1) extends Expr6 with LValue + object Ident extends ParserBridgePos1[String, Ident] { + def apply(x1: String): Ident = new Ident(x1) + } case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) extends Expr6 with LValue @@ -44,15 +46,18 @@ object ast { sealed trait UnaryOp extends Expr { val x: Expr } - case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + sealed trait UnaryOp extends Expr { + val x: Expr + } + case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Negate extends ParserBridgePos1[Expr6, Negate] - case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Not extends ParserBridgePos1[Expr6, Not] - case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Len extends ParserBridgePos1[Expr6, Len] - case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Ord extends ParserBridgePos1[Expr6, Ord] - case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala new file mode 100644 index 0000000..023b317 --- /dev/null +++ b/src/main/wacc/renamer.scala @@ -0,0 +1,119 @@ +package wacc + +import scala.collection.mutable + +object renamer { + import ast._ + import types._ + + private case class Scope( + current: mutable.Map[String, Ident], + parent: Map[String, Ident] + ) { + def subscope: Scope = + Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + + def add(semType: SemType, name: Ident)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ) = { + if (current.contains(name.v)) { + errors += Error.DuplicateDeclaration(name) + } else { + val uid = globalNumbering.getOrElse(name.v, 0) + name.uid = uid + current(name.v) = name + + globalNames(name) = semType + globalNumbering(name.v) = uid + 1 + } + } + } + + def rename(prog: Program)(using + errors: mutable.Builder[Error, List[Error]] + ): Map[Ident, SemType] = + given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty + given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty + rename(Scope(mutable.Map.empty, Map.empty))(prog) + globalNames.toMap + + private def rename(scope: Scope)( + node: Program | FuncDecl | Ident | Stmt | LValue | RValue + )(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = node match { + case Program(funcs, main) => { + funcs.foreach(rename(scope)) + main.toList.foreach(rename(scope)) + } + case FuncDecl(retType, name, params, body) => { + val functionScope = scope.subscope + val paramTypes = params.map { param => + val paramType = SemType(param.paramType) + functionScope.add(paramType, param.name) + paramType + } + scope.add(KnownType.Func(SemType(retType), paramTypes), name) + body.toList.foreach(rename(functionScope)) + } + case VarDecl(synType, name, value) => { + // Order matters here. Variable isn't declared until after the value is evaluated. + rename(scope)(value) + scope.add(SemType(synType), name) + } + case Assign(lhs, value) => { + rename(scope)(lhs) + rename(scope)(value) + } + case Read(lhs) => rename(scope)(lhs) + case Free(expr) => rename(scope)(expr) + case Return(expr) => rename(scope)(expr) + case Exit(expr) => rename(scope)(expr) + case Print(expr, _) => rename(scope)(expr) + case If(cond, thenStmt, elseStmt) => { + rename(scope)(cond) + thenStmt.toList.foreach(rename(scope.subscope)) + elseStmt.toList.foreach(rename(scope.subscope)) + } + case While(cond, body) => { + rename(scope)(cond) + body.toList.foreach(rename(scope.subscope)) + } + case Block(body) => body.toList.foreach(rename(scope.subscope)) + case NewPair(fst, snd) => { + rename(scope)(fst) + rename(scope)(snd) + } + case Call(name, args) => { + rename(scope)(name) + args.foreach(rename(scope)) + } + case Fst(elem) => rename(scope)(elem) + case Snd(elem) => rename(scope)(elem) + case ArrayLiter(elems) => elems.foreach(rename(scope)) + case ArrayElem(name, indices) => { + rename(scope)(name) + indices.toList.foreach(rename(scope)) + } + case Parens(expr) => rename(scope)(expr) + case op: UnaryOp => rename(scope)(op.x) + case op: BinaryOp => { + rename(scope)(op.x) + rename(scope)(op.y) + } + case id: Ident => { + scope.current.withDefault(scope.parent).get(id.v) match { + case Some(Ident(_, uid)) => id.uid = uid + case None => { + errors += Error.UndefinedIdentifier(id) + scope.add(?, id) + } + } + } + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter | Skip => () + } +} diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala new file mode 100644 index 0000000..2ce5f27 --- /dev/null +++ b/src/main/wacc/types.scala @@ -0,0 +1,32 @@ +package wacc + +import scala.collection.mutable + +object types { + import ast._ + + sealed trait SemType + case object ? extends SemType + enum KnownType extends SemType { + case Int + case Bool + case Char + case String + case Array(elem: SemType) + case Pair(left: SemType, right: SemType) + case Func(ret: SemType, params: List[SemType]) + } + + object SemType { + def apply(synType: Type | PairElemType): KnownType = synType match { + case IntType => KnownType.Int + case BoolType => KnownType.Bool + case CharType => KnownType.Char + case StringType => KnownType.String + case ArrayType(elemType, dimension) => + (0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc)) + case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd)) + case UntypedPairType => KnownType.Pair(?, ?) + } + } +}