From 5016fefc14109dc171ca2abcbd4e3fb767b428f4 Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 4 Feb 2025 03:02:37 +0000 Subject: [PATCH 01/18] feat: lexer implements ErrorConfig for improved error messages --- src/main/wacc/lexer.scala | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/main/wacc/lexer.scala b/src/main/wacc/lexer.scala index 094ac12..ceeec45 100644 --- a/src/main/wacc/lexer.scala +++ b/src/main/wacc/lexer.scala @@ -4,7 +4,39 @@ import parsley.Parsley import parsley.character import parsley.token.{Basic, Lexer} import parsley.token.descriptions.* +import parsley.token.errors._ +import parsley.errors._ +val errConfig = new ErrorConfig { + override def labelSymbol = Map( + "!=" -> Label("binary operator"), + "%" -> Label("binary operator"), + "&&" -> Label("binary operator"), + "*" -> Label("binary operator"), + "+" -> Label("binary operator"), + "-" -> Label("binary operator"), + "/" -> Label("binary operator"), + "<" -> Label("binary operator"), + "<=" -> Label("binary operator"), + "==" -> Label("binary operator"), + ">" -> Label("binary operator"), + ">=" -> Label("binary operator"), + "||" -> Label("binary operator"), + "!" -> Label("unary operator"), + "len" -> Label("unary operator"), + "ord" -> Label("unary operator"), + "chr" -> Label("unary operator"), + "bool" -> Label("valid type"), + "char" -> Label("valid type"), + "int" -> Label("valid type"), + "pair" -> Label("valid type"), + "string" -> Label("valid type"), + "fst" -> Label("pair extraction"), + "snd" -> Label("pair extraction"), + "false" -> Label("boolean value"), + "true" -> Label("boolean value") + ) +} object lexer { private val desc = LexicalDesc.plain.copy( nameDesc = NameDesc.plain.copy( @@ -43,7 +75,7 @@ object lexer { ) ) - private val lexer = Lexer(desc) + private val lexer = Lexer(desc, errConfig) val ident = lexer.lexeme.names.identifier val integer = lexer.lexeme.integer.decimal32[Int] val negateCheck = lexer.nonlexeme.symbol("-") ~> character.digit From 3c236543563f0beda3e51bb483c81cf8c5ad0eda Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 4 Feb 2025 03:06:56 +0000 Subject: [PATCH 02/18] fix: remove redundant imports --- src/main/wacc/lexer.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/wacc/lexer.scala b/src/main/wacc/lexer.scala index ceeec45..b1b1604 100644 --- a/src/main/wacc/lexer.scala +++ b/src/main/wacc/lexer.scala @@ -5,7 +5,6 @@ import parsley.character import parsley.token.{Basic, Lexer} import parsley.token.descriptions.* import parsley.token.errors._ -import parsley.errors._ val errConfig = new ErrorConfig { override def labelSymbol = Map( From 4e50ed35ba89b46411d64c9a483a68c54c25077d Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 4 Feb 2025 03:32:52 +0000 Subject: [PATCH 03/18] feat: more error messages --- src/main/wacc/ast.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index c6e743e..7e84b14 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -119,9 +119,13 @@ object ast { case class Print(expr: Expr, newline: Boolean) extends Stmt object Print extends ParserBridge2[Expr, Boolean, Print] case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt]) extends Stmt - object If extends ParserBridge3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] + object If extends ParserBridge3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] { + override def labels = List("if statement") + } case class While(cond: Expr, body: NonEmptyList[Stmt]) extends Stmt - object While extends ParserBridge2[Expr, NonEmptyList[Stmt], While] + object While extends ParserBridge2[Expr, NonEmptyList[Stmt], While] { + override def labels = List("while statement") + } case class Block(stmt: NonEmptyList[Stmt]) extends Stmt object Block extends ParserBridge1[NonEmptyList[Stmt], Block] From 4602b756284c104f4bc1821c5cf5c00177e2ef31 Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 4 Feb 2025 17:03:46 +0000 Subject: [PATCH 04/18] feat: improved error messages --- src/main/wacc/parser.scala | 39 ++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index 84ee093..c174845 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -6,7 +6,8 @@ import parsley.Parsley.{atomic, many, notFollowedBy, pure} import parsley.combinator.{countSome, sepBy} import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms} import parsley.errors.combinator._ -import parsley.cats.combinator.{sepBy1, some} +import parsley.syntax.zipped._ +import parsley.cats.combinator.{some} import cats.data.NonEmptyList object parser { @@ -72,12 +73,12 @@ object parser { ) private lazy val `` = (`` <**> (`` identity)) | - `` ~> ((`` <**> ``) UntypedPairType) - + `` ~> ((`` <**> ``.explain("for a pair to contain a pair type, it must be an array or erased pair")) UntypedPairType) + // TODO: better explanation here? // Statements private lazy val `` = Program( - "begin" ~> many(atomic(`` <~> `` <~ "(") <**> ``), - `` <~ "end" + "begin" ~> many(atomic(``.label("function declaration") <~> `` <~ "(") <**> ``).label("function declaration"), + ``.label("main program body") <~ "end" ) private lazy val `` = (sepBy(``, ",") <~ ")" <~ "is" <~> ``.guardAgainst { @@ -87,23 +88,28 @@ object parser { } private lazy val `` = Param(``, ``) private lazy val ``: Parsley[NonEmptyList[Stmt]] = - sepBy1(``, ";") + ( + ``.label("main program body"), + (many(";" ~> ``.label("statement after ';'"))) Nil + ).zipped(NonEmptyList.apply) + private lazy val `` = (Skip from "skip") | Read("read" ~> ``) - | Free("free" ~> ``) - | Return("return" ~> ``) - | Exit("exit" ~> ``) - | Print("print" ~> ``, pure(false)) - | Print("println" ~> ``, pure(true)) + | Free("free" ~> ``.label("a valid expression")) + | Return("return" ~> ``.label("a valid expression")) + | Exit("exit" ~> ``.label("a valid expression")) + | Print("print" ~> ``.label("a valid expression"), pure(false)) + | Print("println" ~> ``.label("a valid expression"), pure(true)) | If( - "if" ~> `` <~ "then", + "if" ~> ``.label("a valid expression") <~ "then", `` <~ "else", `` <~ "fi" ) - | While("while" ~> `` <~ "do", `` <~ "done") + | While("while" ~> ``.label("a valid expression") <~ "do", `` <~ "done") | Block("begin" ~> `` <~ "end") - | VarDecl(``, `` <~ "=", ``) + | VarDecl(``, `` <~ "=", ``.label("a valid initial value for variable")) + // TODO: Can we inline the name of the variable in the message | Assign(`` <~ "=", ``) private lazy val ``: Parsley[LValue] = `` | `` @@ -117,9 +123,10 @@ object parser { Call( "call" ~> `` <~ "(", sepBy(``, ",") <~ ")" - ) | `` + ) | ``.label("valid expression") private lazy val `` = - Fst("fst" ~> ``) | Snd("snd" ~> ``) + Fst("fst" ~> ``.label("a valid pair")) + | Snd("snd" ~> ``.label("a valid pair")) private lazy val `` = ArrayLiter( "[" ~> sepBy(``, ",") <~ "]" ) From 057d62546461cf5838bc3834ad48ee83b6eee96a Mon Sep 17 00:00:00 2001 From: Guy C Date: Tue, 4 Feb 2025 17:13:56 +0000 Subject: [PATCH 05/18] fix: style fixes --- src/main/wacc/parser.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index c174845..5b8b06c 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -73,11 +73,15 @@ object parser { ) private lazy val `` = (`` <**> (`` identity)) | - `` ~> ((`` <**> ``.explain("for a pair to contain a pair type, it must be an array or erased pair")) UntypedPairType) + `` ~> ((`` <**> ``.explain( + "for a pair to contain a pair type, it must be an array or erased pair" + )) UntypedPairType) // TODO: better explanation here? // Statements private lazy val `` = Program( - "begin" ~> many(atomic(``.label("function declaration") <~> `` <~ "(") <**> ``).label("function declaration"), + "begin" ~> many( + atomic(``.label("function declaration") <~> `` <~ "(") <**> `` + ).label("function declaration"), ``.label("main program body") <~ "end" ) private lazy val `` = @@ -125,8 +129,8 @@ object parser { sepBy(``, ",") <~ ")" ) | ``.label("valid expression") private lazy val `` = - Fst("fst" ~> ``.label("a valid pair")) - | Snd("snd" ~> ``.label("a valid pair")) + Fst("fst" ~> ``.label("a valid pair")) + | Snd("snd" ~> ``.label("a valid pair")) private lazy val `` = ArrayLiter( "[" ~> sepBy(``, ",") <~ "]" ) From ded35dcc6e4d3b9ba2d34edd2b4bd878eb70e900 Mon Sep 17 00:00:00 2001 From: Guy C Date: Thu, 6 Feb 2025 15:36:24 +0000 Subject: [PATCH 06/18] feat: improved error messages for atom types --- src/main/wacc/lexer.scala | 6 ++---- src/main/wacc/parser.scala | 15 +++++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/main/wacc/lexer.scala b/src/main/wacc/lexer.scala index b1b1604..6b47b63 100644 --- a/src/main/wacc/lexer.scala +++ b/src/main/wacc/lexer.scala @@ -12,8 +12,6 @@ val errConfig = new ErrorConfig { "%" -> Label("binary operator"), "&&" -> Label("binary operator"), "*" -> Label("binary operator"), - "+" -> Label("binary operator"), - "-" -> Label("binary operator"), "/" -> Label("binary operator"), "<" -> Label("binary operator"), "<=" -> Label("binary operator"), @@ -32,8 +30,8 @@ val errConfig = new ErrorConfig { "string" -> Label("valid type"), "fst" -> Label("pair extraction"), "snd" -> Label("pair extraction"), - "false" -> Label("boolean value"), - "true" -> Label("boolean value") + "false" -> Label("boolean literal"), + "true" -> Label("boolean literal") ) } object lexer { diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index 5b8b06c..f4c0293 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -29,11 +29,14 @@ object parser { Greater from ">", GreaterEq from ">=" ) +: - SOps(InfixL)(Add from "+", Sub from "-") +: + SOps(InfixL)( + (Add from "+").label("binary operator"), + (Sub from "-").label("binary operator") + ) +: SOps(InfixL)(Mul from "*", Div from "/", Mod from "%") +: SOps(Prefix)( Not from "!", - Negate from (notFollowedBy(negateCheck) ~> "-"), + (Negate from (notFollowedBy(negateCheck) ~> "-")).hide, Len from "len", Ord from "ord", Chr from "chr" @@ -43,10 +46,10 @@ object parser { // Atoms private lazy val ``: Atoms[Expr6] = Atoms( - IntLiter(integer), - BoolLiter(("true" as true) | ("false" as false)), - CharLiter(charLit), - StrLiter(stringLit), + IntLiter(integer).label("integer literal"), + BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"), + CharLiter(charLit).label("character literal"), + StrLiter(stringLit).label("string literal"), PairLiter from "null", ``, Parens("(" ~> `` <~ ")") From e787d7168fc2b05a0eb54e9b93a2975742e0a2f0 Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:59:04 +0000 Subject: [PATCH 07/18] refactor: implemented labelAndExplain(), combining the two, and provided explanations for expr Co-authored-by: gc1523 --- src/main/wacc/parser.scala | 57 +++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index f4c0293..609af07 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -10,11 +10,42 @@ import parsley.syntax.zipped._ import parsley.cats.combinator.{some} import cats.data.NonEmptyList + object parser { import lexer.implicits.implicitSymbol import lexer.{ident, integer, charLit, stringLit, negateCheck} import ast._ + //error extensions + extension [A](p: Parsley[A]) { + + //combines label and explain together into one function call + def labelAndExplain(label: String, explanation: String): Parsley[A] = { + p.label(label).explain(explanation) + } + def labelAndExplain(t: LabelType): Parsley[A] = { + t match { + case LabelType.Expr => + labelWithType(t).explain( + "a valid expression can start with: null, literals, identifiers, unary operators, or parentheses. " + + "Expressions can also contain array indexing and binary operators. " + + "Pair extraction is not allowed in expressions, only in assignments.") + case _ => labelWithType(t) + } + } + + def labelWithType(t: LabelType): Parsley[A] = { + t match { + case LabelType.Expr => p.label("valid expression") + case LabelType.Pair => p.label("valid pair") + } + } + } + + enum LabelType: + case Expr + case Pair + def parse(input: String): Result[String, Program] = parser.parse(input) private val parser = lexer.fully(``) @@ -77,7 +108,7 @@ object parser { private lazy val `` = (`` <**> (`` identity)) | `` ~> ((`` <**> ``.explain( - "for a pair to contain a pair type, it must be an array or erased pair" + "non-erased pair types cannot be nested" )) UntypedPairType) // TODO: better explanation here? // Statements @@ -89,7 +120,7 @@ object parser { ) private lazy val `` = (sepBy(``, ",") <~ ")" <~ "is" <~> ``.guardAgainst { - case stmts if !stmts.isReturning => Seq("All functions must end in a returning statement") + case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement") } <~ "end") map { (params, stmt) => (FuncDecl((_: Type), (_: Ident), params, stmt)).tupled } @@ -103,19 +134,19 @@ object parser { private lazy val `` = (Skip from "skip") | Read("read" ~> ``) - | Free("free" ~> ``.label("a valid expression")) - | Return("return" ~> ``.label("a valid expression")) - | Exit("exit" ~> ``.label("a valid expression")) - | Print("print" ~> ``.label("a valid expression"), pure(false)) - | Print("println" ~> ``.label("a valid expression"), pure(true)) + | Free("free" ~> ``.labelAndExplain(LabelType.Expr)) + | Return("return" ~> ``.labelAndExplain(LabelType.Expr)) + | Exit("exit" ~> ``.labelAndExplain(LabelType.Expr)) + | Print("print" ~> ``.labelAndExplain(LabelType.Expr), pure(false)) + | Print("println" ~> ``.labelAndExplain(LabelType.Expr), pure(true)) | If( - "if" ~> ``.label("a valid expression") <~ "then", + "if" ~> ``.labelWithType(LabelType.Expr) <~ "then", `` <~ "else", `` <~ "fi" ) - | While("while" ~> ``.label("a valid expression") <~ "do", `` <~ "done") + | While("while" ~> ``.labelWithType(LabelType.Expr) <~ "do", `` <~ "done") | Block("begin" ~> `` <~ "end") - | VarDecl(``, `` <~ "=", ``.label("a valid initial value for variable")) + | VarDecl(``, `` <~ "=", ``.label("valid initial value for variable")) // TODO: Can we inline the name of the variable in the message | Assign(`` <~ "=", ``) private lazy val ``: Parsley[LValue] = @@ -130,10 +161,10 @@ object parser { Call( "call" ~> `` <~ "(", sepBy(``, ",") <~ ")" - ) | ``.label("valid expression") + ) | ``.labelWithType(LabelType.Expr) private lazy val `` = - Fst("fst" ~> ``.label("a valid pair")) - | Snd("snd" ~> ``.label("a valid pair")) + Fst("fst" ~> ``.label("valid pair")) + | Snd("snd" ~> ``.label("valid pair")) private lazy val `` = ArrayLiter( "[" ~> sepBy(``, ",") <~ "]" ) From 19880321d70077a84db72e11dfa6ecb91534234a Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Thu, 6 Feb 2025 17:39:35 +0000 Subject: [PATCH 08/18] feat: implemented lexer-backed error builder, error messages are now based on predefined tokens Co-authored-by: gc1523 --- src/main/wacc/lexer.scala | 10 ++++++++++ src/main/wacc/parser.scala | 10 +++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/main/wacc/lexer.scala b/src/main/wacc/lexer.scala index 6b47b63..4a810c0 100644 --- a/src/main/wacc/lexer.scala +++ b/src/main/wacc/lexer.scala @@ -80,5 +80,15 @@ object lexer { val stringLit = lexer.lexeme.string.ascii val implicits = lexer.lexeme.symbol.implicits + val errTokens = Seq( + lexer.nonlexeme.names.identifier.map(v => s"identifier $v"), + lexer.nonlexeme.integer.decimal32[Int].map(n => s"integer $n"), + lexer.nonlexeme.character.ascii.map(c => s"character literal $c"), + lexer.nonlexeme.string.ascii.map(s => s"string literal $s"), + character.whitespace.map(_ => "") + ) ++ desc.symbolDesc.hardKeywords.map { k => + lexer.nonlexeme.symbol(k).as(s"keyword $k") + } + def fully[A](p: Parsley[A]): Parsley[A] = lexer.fully(p) } diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index 609af07..86eced7 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -9,16 +9,17 @@ import parsley.errors.combinator._ import parsley.syntax.zipped._ import parsley.cats.combinator.{some} import cats.data.NonEmptyList - +import parsley.errors.DefaultErrorBuilder +import parsley.errors.ErrorBuilder +import parsley.errors.tokenextractors.LexToken object parser { import lexer.implicits.implicitSymbol - import lexer.{ident, integer, charLit, stringLit, negateCheck} + import lexer.{ident, integer, charLit, stringLit, negateCheck, errTokens} import ast._ //error extensions extension [A](p: Parsley[A]) { - //combines label and explain together into one function call def labelAndExplain(label: String, explanation: String): Parsley[A] = { p.label(label).explain(explanation) @@ -46,6 +47,9 @@ object parser { case Expr case Pair + implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken { + def tokens = errTokens + } def parse(input: String): Result[String, Program] = parser.parse(input) private val parser = lexer.fully(``) From 8b64f2e352e362afd283208c3394cfdb70b62885 Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Thu, 6 Feb 2025 17:41:58 +0000 Subject: [PATCH 09/18] fix: removed redundant labelling in ast Co-authored-by: gc1523 --- src/main/wacc/ast.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 7e84b14..c6e743e 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -119,13 +119,9 @@ object ast { case class Print(expr: Expr, newline: Boolean) extends Stmt object Print extends ParserBridge2[Expr, Boolean, Print] case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt]) extends Stmt - object If extends ParserBridge3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] { - override def labels = List("if statement") - } + object If extends ParserBridge3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] case class While(cond: Expr, body: NonEmptyList[Stmt]) extends Stmt - object While extends ParserBridge2[Expr, NonEmptyList[Stmt], While] { - override def labels = List("while statement") - } + object While extends ParserBridge2[Expr, NonEmptyList[Stmt], While] case class Block(stmt: NonEmptyList[Stmt]) extends Stmt object Block extends ParserBridge1[NonEmptyList[Stmt], Block] From bd779931b6787da20115f56ca7418bb175a2a783 Mon Sep 17 00:00:00 2001 From: Barf-Vader <47476490+Barf-Vader@users.noreply.github.com> Date: Thu, 6 Feb 2025 17:48:14 +0000 Subject: [PATCH 10/18] refactor: style fixes in parser Co-authored-by: gc1523 --- src/main/wacc/parser.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index 86eced7..9cec896 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -18,9 +18,9 @@ object parser { import lexer.{ident, integer, charLit, stringLit, negateCheck, errTokens} import ast._ - //error extensions + // error extensions extension [A](p: Parsley[A]) { - //combines label and explain together into one function call + // combines label and explain together into one function call def labelAndExplain(label: String, explanation: String): Parsley[A] = { p.label(label).explain(explanation) } @@ -28,9 +28,10 @@ object parser { t match { case LabelType.Expr => labelWithType(t).explain( - "a valid expression can start with: null, literals, identifiers, unary operators, or parentheses. " + - "Expressions can also contain array indexing and binary operators. " + - "Pair extraction is not allowed in expressions, only in assignments.") + "a valid expression can start with: null, literals, identifiers, unary operators, or parentheses. " + + "Expressions can also contain array indexing and binary operators. " + + "Pair extraction is not allowed in expressions, only in assignments." + ) case _ => labelWithType(t) } } From 1486296b40a340e46c0bd3aa1b44030ec7a8d40d Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 02:12:16 +0000 Subject: [PATCH 11/18] refactor: add position tracking to AST, UnaryOp, BinaryOp --- src/main/wacc/ast.scala | 282 ++++++++++++++++++++++++++-------------- 1 file changed, 184 insertions(+), 98 deletions(-) diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index c6e743e..9afd8d9 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -1,6 +1,9 @@ package wacc -import parsley.generic._ +import parsley.Parsley +import parsley.generic.ErrorBridge +import parsley.position._ +import parsley.syntax.zipped._ import cats.data.NonEmptyList object ast { @@ -14,80 +17,97 @@ object ast { sealed trait Expr6 extends Expr5 // Atoms - case class IntLiter(v: Int) extends Expr6 - object IntLiter extends ParserBridge1[Int, IntLiter] - case class BoolLiter(v: Boolean) extends Expr6 - object BoolLiter extends ParserBridge1[Boolean, BoolLiter] - case class CharLiter(v: Char) extends Expr6 - object CharLiter extends ParserBridge1[Char, CharLiter] - case class StrLiter(v: String) extends Expr6 - object StrLiter extends ParserBridge1[String, StrLiter] - case object PairLiter extends Expr6 with ParserBridge0[PairLiter.type] - case class Ident(v: String) extends Expr6 with LValue - object Ident extends ParserBridge1[String, Ident] - case class ArrayElem(name: Ident, indices: NonEmptyList[Expr]) extends Expr6 with LValue - object ArrayElem extends ParserBridge2[Ident, NonEmptyList[Expr], ArrayElem] - case class Parens(expr: Expr) extends Expr6 - object Parens extends ParserBridge1[Expr, Parens] + case class IntLiter(v: Int)(pos: Position) extends Expr6 + object IntLiter extends ParserBridgePos1[Int, IntLiter] + case class BoolLiter(v: Boolean)(pos: Position) extends Expr6 + object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] + case class CharLiter(v: Char)(pos: Position) extends Expr6 + object CharLiter extends ParserBridgePos1[Char, CharLiter] + case class StrLiter(v: String)(pos: Position) extends Expr6 + object StrLiter extends ParserBridgePos1[String, StrLiter] + case class PairLiter()(pos: Position) extends Expr6 + object PairLiter extends Expr6 with ParserBridgePos0[PairLiter] + case class Ident(v: String)(pos: Position) extends Expr6 with LValue + object Ident extends ParserBridgePos1[String, Ident] + case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) + extends Expr6 + with LValue + object ArrayElem extends ParserBridgePos2[Ident, NonEmptyList[Expr], ArrayElem] + case class Parens(expr: Expr)(pos: Position) extends Expr6 + object Parens extends ParserBridgePos1[Expr, Parens] // Unary operators - case class Negate(x: Expr6) extends Expr6 - object Negate extends ParserBridge1[Expr6, Negate] - case class Not(x: Expr6) extends Expr6 - object Not extends ParserBridge1[Expr6, Not] - case class Len(x: Expr6) extends Expr6 - object Len extends ParserBridge1[Expr6, Len] - case class Ord(x: Expr6) extends Expr6 - object Ord extends ParserBridge1[Expr6, Ord] - case class Chr(x: Expr6) extends Expr6 - object Chr extends ParserBridge1[Expr6, Chr] + sealed trait UnaryOp extends Expr { + val x: Expr + } + case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Negate extends ParserBridgePos1[Expr6, Negate] + case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Not extends ParserBridgePos1[Expr6, Not] + case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Len extends ParserBridgePos1[Expr6, Len] + case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Ord extends ParserBridgePos1[Expr6, Ord] + case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators - case class Add(x: Expr4, y: Expr5) extends Expr4 - object Add extends ParserBridge2[Expr4, Expr5, Add] - case class Sub(x: Expr4, y: Expr5) extends Expr4 - object Sub extends ParserBridge2[Expr4, Expr5, Sub] - case class Mul(x: Expr5, y: Expr6) extends Expr5 - object Mul extends ParserBridge2[Expr5, Expr6, Mul] - case class Div(x: Expr5, y: Expr6) extends Expr5 - object Div extends ParserBridge2[Expr5, Expr6, Div] - case class Mod(x: Expr5, y: Expr6) extends Expr5 - object Mod extends ParserBridge2[Expr5, Expr6, Mod] - case class Greater(x: Expr4, y: Expr4) extends Expr3 - object Greater extends ParserBridge2[Expr4, Expr4, Greater] - case class GreaterEq(x: Expr4, y: Expr4) extends Expr3 - object GreaterEq extends ParserBridge2[Expr4, Expr4, GreaterEq] - case class Less(x: Expr4, y: Expr4) extends Expr3 - object Less extends ParserBridge2[Expr4, Expr4, Less] - case class LessEq(x: Expr4, y: Expr4) extends Expr3 - object LessEq extends ParserBridge2[Expr4, Expr4, LessEq] - case class Eq(x: Expr3, y: Expr3) extends Expr2 - object Eq extends ParserBridge2[Expr3, Expr3, Eq] - case class Neq(x: Expr3, y: Expr3) extends Expr2 - object Neq extends ParserBridge2[Expr3, Expr3, Neq] - case class And(x: Expr2, y: Expr1) extends Expr1 - object And extends ParserBridge2[Expr2, Expr1, And] - case class Or(x: Expr1, y: Expr) extends Expr - object Or extends ParserBridge2[Expr1, Expr, Or] + sealed trait BinaryOp extends Expr { + val x: Expr + val y: Expr + } + case class Add(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + object Add extends ParserBridgePos2[Expr4, Expr5, Add] + case class Sub(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + object Sub extends ParserBridgePos2[Expr4, Expr5, Sub] + case class Mul(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + object Mul extends ParserBridgePos2[Expr5, Expr6, Mul] + case class Div(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + object Div extends ParserBridgePos2[Expr5, Expr6, Div] + case class Mod(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + object Mod extends ParserBridgePos2[Expr5, Expr6, Mod] + case class Greater(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object Greater extends ParserBridgePos2[Expr4, Expr4, Greater] + case class GreaterEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq] + case class Less(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object Less extends ParserBridgePos2[Expr4, Expr4, Less] + case class LessEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq] + case class Eq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + object Eq extends ParserBridgePos2[Expr3, Expr3, Eq] + case class Neq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + object Neq extends ParserBridgePos2[Expr3, Expr3, Neq] + case class And(x: Expr2, y: Expr1)(pos: Position) extends Expr1 with BinaryOp + object And extends ParserBridgePos2[Expr2, Expr1, And] + case class Or(x: Expr1, y: Expr)(pos: Position) extends Expr with BinaryOp + object Or extends ParserBridgePos2[Expr1, Expr, Or] // Types sealed trait Type sealed trait BaseType extends Type with PairElemType - case object IntType extends BaseType with ParserBridge0[IntType.type] - case object BoolType extends BaseType with ParserBridge0[BoolType.type] - case object CharType extends BaseType with ParserBridge0[CharType.type] - case object StringType extends BaseType with ParserBridge0[StringType.type] - case class ArrayType(elemType: Type, dimensions: Int) extends Type with PairElemType - object ArrayType extends ParserBridge2[Type, Int, ArrayType] - case class PairType(fst: PairElemType, snd: PairElemType) extends Type - object PairType extends ParserBridge2[PairElemType, PairElemType, PairType] + case class IntType()(pos: Position) extends BaseType + object IntType extends ParserBridgePos0[IntType] + case class BoolType()(pos: Position) extends BaseType + object BoolType extends ParserBridgePos0[BoolType] + case class CharType()(pos: Position) extends BaseType + object CharType extends ParserBridgePos0[CharType] + case class StringType()(pos: Position) extends BaseType + object StringType extends ParserBridgePos0[StringType] + case class ArrayType(elemType: Type, dimensions: Int)(pos: Position) + extends Type + with PairElemType + object ArrayType extends ParserBridgePos2[Type, Int, ArrayType] + case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type + object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] sealed trait PairElemType - case object UntypedPairType extends PairElemType with ParserBridge0[UntypedPairType.type] + case class UntypedPairType()(pos: Position) extends PairElemType + object UntypedPairType extends ParserBridgePos0[UntypedPairType] // waccadoodledo - case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt]) - object Program extends ParserBridge2[List[FuncDecl], NonEmptyList[Stmt], Program] + case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(pos: Position) + object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] // Function Definitions case class FuncDecl( @@ -95,49 +115,115 @@ object ast { name: Ident, params: List[Param], body: NonEmptyList[Stmt] - ) - object FuncDecl extends ParserBridge4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl] + )(pos: Position) + object FuncDecl extends ParserBridgePos4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl] - case class Param(paramType: Type, name: Ident) - object Param extends ParserBridge2[Type, Ident, Param] + case class Param(paramType: Type, name: Ident)(pos: Position) + object Param extends ParserBridgePos2[Type, Ident, Param] // Statements sealed trait Stmt - case object Skip extends Stmt with ParserBridge0[Skip.type] - case class VarDecl(varType: Type, name: Ident, value: RValue) extends Stmt - object VarDecl extends ParserBridge3[Type, Ident, RValue, VarDecl] - case class Assign(lhs: LValue, value: RValue) extends Stmt - object Assign extends ParserBridge2[LValue, RValue, Assign] - case class Read(lhs: LValue) extends Stmt - object Read extends ParserBridge1[LValue, Read] - case class Free(expr: Expr) extends Stmt - object Free extends ParserBridge1[Expr, Free] - case class Return(expr: Expr) extends Stmt - object Return extends ParserBridge1[Expr, Return] - case class Exit(expr: Expr) extends Stmt - object Exit extends ParserBridge1[Expr, Exit] - case class Print(expr: Expr, newline: Boolean) extends Stmt - object Print extends ParserBridge2[Expr, Boolean, Print] - case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt]) extends Stmt - object If extends ParserBridge3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] - case class While(cond: Expr, body: NonEmptyList[Stmt]) extends Stmt - object While extends ParserBridge2[Expr, NonEmptyList[Stmt], While] - case class Block(stmt: NonEmptyList[Stmt]) extends Stmt - object Block extends ParserBridge1[NonEmptyList[Stmt], Block] + case class Skip()(pos: Position) extends Stmt + object Skip extends ParserBridgePos0[Skip] + case class VarDecl(varType: Type, name: Ident, value: RValue)(pos: Position) extends Stmt + object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl] + case class Assign(lhs: LValue, value: RValue)(pos: Position) extends Stmt + object Assign extends ParserBridgePos2[LValue, RValue, Assign] + case class Read(lhs: LValue)(pos: Position) extends Stmt + object Read extends ParserBridgePos1[LValue, Read] + case class Free(expr: Expr)(pos: Position) extends Stmt + object Free extends ParserBridgePos1[Expr, Free] + case class Return(expr: Expr)(pos: Position) extends Stmt + object Return extends ParserBridgePos1[Expr, Return] + case class Exit(expr: Expr)(pos: Position) extends Stmt + object Exit extends ParserBridgePos1[Expr, Exit] + case class Print(expr: Expr, newline: Boolean)(pos: Position) extends Stmt + object Print extends ParserBridgePos2[Expr, Boolean, Print] + case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])( + pos: Position + ) extends Stmt + object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] + case class While(cond: Expr, body: NonEmptyList[Stmt])(pos: Position) extends Stmt + object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While] + case class Block(stmt: NonEmptyList[Stmt])(pos: Position) extends Stmt + object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block] sealed trait LValue sealed trait RValue - case class ArrayLiter(elems: List[Expr]) extends RValue - object ArrayLiter extends ParserBridge1[List[Expr], ArrayLiter] - case class NewPair(fst: Expr, snd: Expr) extends RValue - object NewPair extends ParserBridge2[Expr, Expr, NewPair] - case class Call(name: Ident, args: List[Expr]) extends RValue - object Call extends ParserBridge2[Ident, List[Expr], Call] + case class ArrayLiter(elems: List[Expr])(pos: Position) extends RValue + object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter] + case class NewPair(fst: Expr, snd: Expr)(pos: Position) extends RValue + object NewPair extends ParserBridgePos2[Expr, Expr, NewPair] + case class Call(name: Ident, args: List[Expr])(pos: Position) extends RValue + object Call extends ParserBridgePos2[Ident, List[Expr], Call] sealed trait PairElem extends LValue with RValue - case class Fst(elem: LValue) extends PairElem - object Fst extends ParserBridge1[LValue, Fst] - case class Snd(elem: LValue) extends PairElem - object Snd extends ParserBridge1[LValue, Snd] + case class Fst(elem: LValue)(pos: Position) extends PairElem + object Fst extends ParserBridgePos1[LValue, Fst] + case class Snd(elem: LValue)(pos: Position) extends PairElem + object Snd extends ParserBridgePos1[LValue, Snd] + + // Parser bridges + case class Position(line: Int, column: Int, offset: Int, width: Int) + + private def applyCon[A, B]( + con: ((Int, Int), Int, Int) => A => B + )(ops: Parsley[A]): Parsley[B] = + (pos, offset, withWidth(ops)).zipped.map { (pos, off, res) => + con(pos, off, res._2)(res._1) + } + + trait ParserSingletonBridgePos[+A] extends ErrorBridge { + protected def con(pos: (Int, Int), offset: Int, width: Int): A + def from(op: Parsley[?]): Parsley[A] = error((pos, offset, withWidth(op)).zipped.map { + (pos, off, res) => con(pos, off, res._2) + }) + final def <#(op: Parsley[?]): Parsley[A] = from(op) + } + + trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] { + def apply()(pos: Position): A + + override final def con(pos: (Int, Int), offset: Int, width: Int): A = + apply()(Position(pos._1, pos._2, offset, width)) + } + + trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] { + def apply(a: A)(pos: Position): B + def apply(a: Parsley[A]): Parsley[B] = error((pos, offset, withWidth(a)).zipped.map { + (pos, off, res) => con(pos, off, res._2)(res._1) + }) + + override final def con(pos: (Int, Int), offset: Int, width: Int): A => B = + apply(_)(Position(pos._1, pos._2, offset, width)) + } + + trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[((A, B)) => C] { + def apply(a: A, b: B)(pos: Position): C + def apply(a: Parsley[A], b: Parsley[B]): Parsley[C] = error(applyCon(con)((a, b).zipped)) + + override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B)) => C = + apply(_, _)(Position(pos._1, pos._2, offset, width)) + } + + trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[((A, B, C)) => D] { + def apply(a: A, b: B, c: C)(pos: Position): D + def apply(a: Parsley[A], b: Parsley[B], c: Parsley[C]): Parsley[D] = error( + applyCon(con)((a, b, c).zipped) + ) + + override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B, C)) => D = + apply(_, _, _)(Position(pos._1, pos._2, offset, width)) + } + + trait ParserBridgePos4[-A, -B, -C, -D, +E] extends ParserSingletonBridgePos[((A, B, C, D)) => E] { + def apply(a: A, b: B, c: C, d: D)(pos: Position): E + def apply(a: Parsley[A], b: Parsley[B], c: Parsley[C], d: Parsley[D]): Parsley[E] = error( + applyCon(con)((a, b, c, d).zipped) + ) + + override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B, C, D)) => E = + apply(_, _, _, _)(Position(pos._1, pos._2, offset, width)) + } } From 6d1c0b7a87ded87b5ff3781c4514c500c42f26b5 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 04:47:43 +0000 Subject: [PATCH 12/18] fix: make parser use only parsley parser bridge apply --- src/main/wacc/ast.scala | 76 ++++++++++++++++++++------------------ src/main/wacc/parser.scala | 26 ++++++------- 2 files changed, 52 insertions(+), 50 deletions(-) diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 9afd8d9..e4a9128 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -2,6 +2,7 @@ package wacc import parsley.Parsley import parsley.generic.ErrorBridge +import parsley.ap._ import parsley.position._ import parsley.syntax.zipped._ import cats.data.NonEmptyList @@ -32,7 +33,10 @@ object ast { case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) extends Expr6 with LValue - object ArrayElem extends ParserBridgePos2[Ident, NonEmptyList[Expr], ArrayElem] + object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { + def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = + name => ArrayElem(name, a)(pos) + } case class Parens(expr: Expr)(pos: Position) extends Expr6 object Parens extends ParserBridgePos1[Expr, Parens] @@ -97,7 +101,9 @@ object ast { case class ArrayType(elemType: Type, dimensions: Int)(pos: Position) extends Type with PairElemType - object ArrayType extends ParserBridgePos2[Type, Int, ArrayType] + object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { + def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) + } case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] @@ -116,7 +122,17 @@ object ast { params: List[Param], body: NonEmptyList[Stmt] )(pos: Position) - object FuncDecl extends ParserBridgePos4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl] + object FuncDecl + extends ParserBridgePos2[ + List[Param], + NonEmptyList[Stmt], + ((Type, Ident)) => FuncDecl + ] { + def apply(params: List[Param], body: NonEmptyList[Stmt])( + pos: Position + ): ((Type, Ident)) => FuncDecl = + (returnType, name) => FuncDecl(returnType, name, params, body)(pos) + } case class Param(paramType: Type, name: Ident)(pos: Position) object Param extends ParserBridgePos2[Type, Ident, Param] @@ -165,65 +181,53 @@ object ast { object Snd extends ParserBridgePos1[LValue, Snd] // Parser bridges - case class Position(line: Int, column: Int, offset: Int, width: Int) + case class Position(line: Int, column: Int, offset: Int) private def applyCon[A, B]( con: ((Int, Int), Int, Int) => A => B - )(ops: Parsley[A]): Parsley[B] = + )(ops: => Parsley[A]): Parsley[B] = (pos, offset, withWidth(ops)).zipped.map { (pos, off, res) => con(pos, off, res._2)(res._1) } trait ParserSingletonBridgePos[+A] extends ErrorBridge { - protected def con(pos: (Int, Int), offset: Int, width: Int): A - def from(op: Parsley[?]): Parsley[A] = error((pos, offset, withWidth(op)).zipped.map { - (pos, off, res) => con(pos, off, res._2) - }) - final def <#(op: Parsley[?]): Parsley[A] = from(op) + protected def con(pos: (Int, Int), offset: Int): A + infix def from(op: Parsley[?]): Parsley[A] = error((pos, offset).zipped(con) <~ op) + final def <#(op: Parsley[?]): Parsley[A] = this from op } trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] { def apply()(pos: Position): A - override final def con(pos: (Int, Int), offset: Int, width: Int): A = - apply()(Position(pos._1, pos._2, offset, width)) + override final def con(pos: (Int, Int), offset: Int): A = + apply()(Position(pos._1, pos._2, offset)) } trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] { def apply(a: A)(pos: Position): B - def apply(a: Parsley[A]): Parsley[B] = error((pos, offset, withWidth(a)).zipped.map { - (pos, off, res) => con(pos, off, res._2)(res._1) - }) + def apply(a: Parsley[A]): Parsley[B] = error(ap1((pos, offset).zipped(con), a)) - override final def con(pos: (Int, Int), offset: Int, width: Int): A => B = - apply(_)(Position(pos._1, pos._2, offset, width)) + override final def con(pos: (Int, Int), offset: Int): A => B = + this.apply(_)(Position(pos._1, pos._2, offset)) } - trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[((A, B)) => C] { + trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] { def apply(a: A, b: B)(pos: Position): C - def apply(a: Parsley[A], b: Parsley[B]): Parsley[C] = error(applyCon(con)((a, b).zipped)) + def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error( + ap2((pos, offset).zipped(con), a, b) + ) - override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B)) => C = - apply(_, _)(Position(pos._1, pos._2, offset, width)) + override final def con(pos: (Int, Int), offset: Int): (A, B) => C = + apply(_, _)(Position(pos._1, pos._2, offset)) } - trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[((A, B, C)) => D] { + trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] { def apply(a: A, b: B, c: C)(pos: Position): D - def apply(a: Parsley[A], b: Parsley[B], c: Parsley[C]): Parsley[D] = error( - applyCon(con)((a, b, c).zipped) + def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error( + ap3((pos, offset).zipped(con), a, b, c) ) - override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B, C)) => D = - apply(_, _, _)(Position(pos._1, pos._2, offset, width)) - } - - trait ParserBridgePos4[-A, -B, -C, -D, +E] extends ParserSingletonBridgePos[((A, B, C, D)) => E] { - def apply(a: A, b: B, c: C, d: D)(pos: Position): E - def apply(a: Parsley[A], b: Parsley[B], c: Parsley[C], d: Parsley[D]): Parsley[E] = error( - applyCon(con)((a, b, c, d).zipped) - ) - - override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B, C, D)) => E = - apply(_, _, _, _)(Position(pos._1, pos._2, offset, width)) + override final def con(pos: (Int, Int), offset: Int): (A, B, C) => D = + apply(_, _, _)(Position(pos._1, pos._2, offset)) } } diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index 9cec896..5751732 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -93,10 +93,7 @@ object parser { private val `` = Ident(ident) private lazy val `` = `` <**> (`` identity) - private val `` = - some("[" ~> `` <~ "]") map { indices => - ArrayElem((_: Ident), indices) - } + private val `` = ArrayElem(some("[" ~> `` <~ "]")) // Types private lazy val ``: Parsley[Type] = @@ -104,7 +101,7 @@ object parser { private val `` = (IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string") private lazy val `` = - countSome("[" ~> "]") map { cnt => ArrayType((_: Type), cnt) } + ArrayType(countSome("[" ~> "]")) private val `` = "pair" private val ``: Parsley[PairType] = PairType( "(" ~> `` <~ ",", @@ -112,10 +109,10 @@ object parser { ) private lazy val `` = (`` <**> (`` identity)) | - `` ~> ((`` <**> ``.explain( - "non-erased pair types cannot be nested" - )) UntypedPairType) - // TODO: better explanation here? + ((UntypedPairType from ``) <**> + ((`` <**> ``) + .map(arr => (_: UntypedPairType) => arr) identity)) + // Statements private lazy val `` = Program( "begin" ~> many( @@ -124,11 +121,12 @@ object parser { ``.label("main program body") <~ "end" ) private lazy val `` = - (sepBy(``, ",") <~ ")" <~ "is" <~> ``.guardAgainst { - case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement") - } <~ "end") map { (params, stmt) => - (FuncDecl((_: Type), (_: Ident), params, stmt)).tupled - } + FuncDecl( + sepBy(``, ",") <~ ")" <~ "is", + ``.guardAgainst { + case stmts if !stmts.isReturning => Seq("All functions must end in a returning statement") + } <~ "end" + ) private lazy val `` = Param(``, ``) private lazy val ``: Parsley[NonEmptyList[Stmt]] = ( From e9ed197782bec4fe74f575aa821b3ee648052bff Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 04:49:05 +0000 Subject: [PATCH 13/18] fix: remove unused applyCon from AST --- src/main/wacc/ast.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index e4a9128..0076006 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -183,13 +183,6 @@ object ast { // Parser bridges case class Position(line: Int, column: Int, offset: Int) - private def applyCon[A, B]( - con: ((Int, Int), Int, Int) => A => B - )(ops: => Parsley[A]): Parsley[B] = - (pos, offset, withWidth(ops)).zipped.map { (pos, off, res) => - con(pos, off, res._2)(res._1) - } - trait ParserSingletonBridgePos[+A] extends ErrorBridge { protected def con(pos: (Int, Int), offset: Int): A infix def from(op: Parsley[?]): Parsley[A] = error((pos, offset).zipped(con) <~ op) From 3fbb90322fd727e259ae26f19ecac9461a38495c Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Tue, 4 Feb 2025 22:26:38 +0000 Subject: [PATCH 14/18] feat: renamer maybe maybe maybe maybe --- src/main/wacc/Error.scala | 8 +++ src/main/wacc/Main.scala | 12 +++- src/main/wacc/ast.scala | 19 +++--- src/main/wacc/renamer.scala | 119 ++++++++++++++++++++++++++++++++++++ src/main/wacc/types.scala | 32 ++++++++++ 5 files changed, 181 insertions(+), 9 deletions(-) create mode 100644 src/main/wacc/Error.scala create mode 100644 src/main/wacc/renamer.scala create mode 100644 src/main/wacc/types.scala diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala new file mode 100644 index 0000000..a9f0490 --- /dev/null +++ b/src/main/wacc/Error.scala @@ -0,0 +1,8 @@ +package wacc + +enum Error { + case DuplicateDeclaration(ident: ast.Ident) + case UndefinedIdentifier(ident: ast.Ident) + case FunctionParamsMismatch(expected: Int, got: Int) + case TypeMismatch(expected: types.SemType, got: types.SemType) +} diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index d4b070a..dcc7e6d 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,5 +1,6 @@ package wacc +import scala.collection.mutable import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -32,8 +33,15 @@ val cliParser = { def compile(contents: String): Int = { parser.parse(contents) match { case Success(ast) => - // TODO: Do semantics things - 0 + given errors: mutable.Builder[Error, List[Error]] = List.newBuilder + val names = renamer.rename(ast) + // given ctx: types.TypeCheckerCtx[List[Error]] = + // types.TypeCheckerCtx(names, errors) + // types.check(ast) + if (errors.result.nonEmpty) { + errors.result.foreach(println) + 200 + } else 0 case Failure(msg) => println(msg) 100 diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 0076006..2457c8f 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -28,8 +28,10 @@ object ast { object StrLiter extends ParserBridgePos1[String, StrLiter] case class PairLiter()(pos: Position) extends Expr6 object PairLiter extends Expr6 with ParserBridgePos0[PairLiter] - case class Ident(v: String)(pos: Position) extends Expr6 with LValue - object Ident extends ParserBridgePos1[String, Ident] + case class Ident(v: String, var uid: Int = -1) extends Expr6 with LValue + object Ident extends ParserBridgePos1[String, Ident] { + def apply(x1: String): Ident = new Ident(x1) + } case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) extends Expr6 with LValue @@ -44,15 +46,18 @@ object ast { sealed trait UnaryOp extends Expr { val x: Expr } - case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + sealed trait UnaryOp extends Expr { + val x: Expr + } + case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Negate extends ParserBridgePos1[Expr6, Negate] - case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Not extends ParserBridgePos1[Expr6, Not] - case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Len extends ParserBridgePos1[Expr6, Len] - case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Ord extends ParserBridgePos1[Expr6, Ord] - case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala new file mode 100644 index 0000000..023b317 --- /dev/null +++ b/src/main/wacc/renamer.scala @@ -0,0 +1,119 @@ +package wacc + +import scala.collection.mutable + +object renamer { + import ast._ + import types._ + + private case class Scope( + current: mutable.Map[String, Ident], + parent: Map[String, Ident] + ) { + def subscope: Scope = + Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + + def add(semType: SemType, name: Ident)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ) = { + if (current.contains(name.v)) { + errors += Error.DuplicateDeclaration(name) + } else { + val uid = globalNumbering.getOrElse(name.v, 0) + name.uid = uid + current(name.v) = name + + globalNames(name) = semType + globalNumbering(name.v) = uid + 1 + } + } + } + + def rename(prog: Program)(using + errors: mutable.Builder[Error, List[Error]] + ): Map[Ident, SemType] = + given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty + given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty + rename(Scope(mutable.Map.empty, Map.empty))(prog) + globalNames.toMap + + private def rename(scope: Scope)( + node: Program | FuncDecl | Ident | Stmt | LValue | RValue + )(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = node match { + case Program(funcs, main) => { + funcs.foreach(rename(scope)) + main.toList.foreach(rename(scope)) + } + case FuncDecl(retType, name, params, body) => { + val functionScope = scope.subscope + val paramTypes = params.map { param => + val paramType = SemType(param.paramType) + functionScope.add(paramType, param.name) + paramType + } + scope.add(KnownType.Func(SemType(retType), paramTypes), name) + body.toList.foreach(rename(functionScope)) + } + case VarDecl(synType, name, value) => { + // Order matters here. Variable isn't declared until after the value is evaluated. + rename(scope)(value) + scope.add(SemType(synType), name) + } + case Assign(lhs, value) => { + rename(scope)(lhs) + rename(scope)(value) + } + case Read(lhs) => rename(scope)(lhs) + case Free(expr) => rename(scope)(expr) + case Return(expr) => rename(scope)(expr) + case Exit(expr) => rename(scope)(expr) + case Print(expr, _) => rename(scope)(expr) + case If(cond, thenStmt, elseStmt) => { + rename(scope)(cond) + thenStmt.toList.foreach(rename(scope.subscope)) + elseStmt.toList.foreach(rename(scope.subscope)) + } + case While(cond, body) => { + rename(scope)(cond) + body.toList.foreach(rename(scope.subscope)) + } + case Block(body) => body.toList.foreach(rename(scope.subscope)) + case NewPair(fst, snd) => { + rename(scope)(fst) + rename(scope)(snd) + } + case Call(name, args) => { + rename(scope)(name) + args.foreach(rename(scope)) + } + case Fst(elem) => rename(scope)(elem) + case Snd(elem) => rename(scope)(elem) + case ArrayLiter(elems) => elems.foreach(rename(scope)) + case ArrayElem(name, indices) => { + rename(scope)(name) + indices.toList.foreach(rename(scope)) + } + case Parens(expr) => rename(scope)(expr) + case op: UnaryOp => rename(scope)(op.x) + case op: BinaryOp => { + rename(scope)(op.x) + rename(scope)(op.y) + } + case id: Ident => { + scope.current.withDefault(scope.parent).get(id.v) match { + case Some(Ident(_, uid)) => id.uid = uid + case None => { + errors += Error.UndefinedIdentifier(id) + scope.add(?, id) + } + } + } + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter | Skip => () + } +} diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala new file mode 100644 index 0000000..2ce5f27 --- /dev/null +++ b/src/main/wacc/types.scala @@ -0,0 +1,32 @@ +package wacc + +import scala.collection.mutable + +object types { + import ast._ + + sealed trait SemType + case object ? extends SemType + enum KnownType extends SemType { + case Int + case Bool + case Char + case String + case Array(elem: SemType) + case Pair(left: SemType, right: SemType) + case Func(ret: SemType, params: List[SemType]) + } + + object SemType { + def apply(synType: Type | PairElemType): KnownType = synType match { + case IntType => KnownType.Int + case BoolType => KnownType.Bool + case CharType => KnownType.Char + case StringType => KnownType.String + case ArrayType(elemType, dimension) => + (0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc)) + case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd)) + case UntypedPairType => KnownType.Pair(?, ?) + } + } +} From 30cf42ee3a23d74aa17d2ab06a3dc200c7f0fc8a Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 05:12:32 +0000 Subject: [PATCH 15/18] fix: separate variable and function in scope --- src/main/wacc/Error.scala | 2 +- src/main/wacc/renamer.scala | 48 +++++++++++++++++++++++-------------- src/main/wacc/types.scala | 2 -- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala index a9f0490..6370925 100644 --- a/src/main/wacc/Error.scala +++ b/src/main/wacc/Error.scala @@ -2,7 +2,7 @@ package wacc enum Error { case DuplicateDeclaration(ident: ast.Ident) - case UndefinedIdentifier(ident: ast.Ident) + case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) case FunctionParamsMismatch(expected: Int, got: Int) case TypeMismatch(expected: types.SemType, got: types.SemType) } diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala index 023b317..cdfe19d 100644 --- a/src/main/wacc/renamer.scala +++ b/src/main/wacc/renamer.scala @@ -6,24 +6,29 @@ object renamer { import ast._ import types._ + enum IdentType { + case Func + case Var + } + private case class Scope( - current: mutable.Map[String, Ident], - parent: Map[String, Ident] + current: mutable.Map[(String, IdentType), Ident], + parent: Map[(String, IdentType), Ident] ) { def subscope: Scope = Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) - def add(semType: SemType, name: Ident)(using + def add(semType: SemType, name: Ident, identType: IdentType)(using globalNames: mutable.Map[Ident, SemType], globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ) = { - if (current.contains(name.v)) { + if (current.contains((name.v, identType))) { errors += Error.DuplicateDeclaration(name) } else { val uid = globalNumbering.getOrElse(name.v, 0) name.uid = uid - current(name.v) = name + current((name.v, identType)) = name globalNames(name) = semType globalNumbering(name.v) = uid + 1 @@ -54,16 +59,16 @@ object renamer { val functionScope = scope.subscope val paramTypes = params.map { param => val paramType = SemType(param.paramType) - functionScope.add(paramType, param.name) + functionScope.add(paramType, param.name, IdentType.Var) paramType } - scope.add(KnownType.Func(SemType(retType), paramTypes), name) + scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) body.toList.foreach(rename(functionScope)) } case VarDecl(synType, name, value) => { // Order matters here. Variable isn't declared until after the value is evaluated. rename(scope)(value) - scope.add(SemType(synType), name) + scope.add(SemType(synType), name, IdentType.Var) } case Assign(lhs, value) => { rename(scope)(lhs) @@ -89,7 +94,7 @@ object renamer { rename(scope)(snd) } case Call(name, args) => { - rename(scope)(name) + renameIdent(scope, name, IdentType.Func) args.foreach(rename(scope)) } case Fst(elem) => rename(scope)(elem) @@ -105,15 +110,22 @@ object renamer { rename(scope)(op.x) rename(scope)(op.y) } - case id: Ident => { - scope.current.withDefault(scope.parent).get(id.v) match { - case Some(Ident(_, uid)) => id.uid = uid - case None => { - errors += Error.UndefinedIdentifier(id) - scope.add(?, id) - } - } - } + // Default to variables. Only `call` uses IdentType.Func. + case id: Ident => renameIdent(scope, id, IdentType.Var) case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter | Skip => () } + + private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = { + scope.current.withDefault(scope.parent).get((ident.v, identType)) match { + case Some(Ident(_, uid)) => ident.uid = uid + case None => { + errors += Error.UndefinedIdentifier(ident, identType) + scope.add(?, ident, identType) + } + } + } } diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index 2ce5f27..388416f 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -1,7 +1,5 @@ package wacc -import scala.collection.mutable - object types { import ast._ From ae9625b58634e773f6014d7e366b08acb0e58edb Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 18:04:04 +0000 Subject: [PATCH 16/18] fix: use apply() instead of get() for Maps --- src/main/wacc/renamer.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala index cdfe19d..e46d9dd 100644 --- a/src/main/wacc/renamer.scala +++ b/src/main/wacc/renamer.scala @@ -120,12 +120,15 @@ object renamer { globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ): Unit = { - scope.current.withDefault(scope.parent).get((ident.v, identType)) match { - case Some(Ident(_, uid)) => ident.uid = uid - case None => { + // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. + // Neither is there a way to check whether a default exists, so we have to use a try-catch. + try { + val Ident(_, uid) = scope.current.withDefault(scope.parent)((ident.v, identType)) + ident.uid = uid + } catch { + case _: NoSuchElementException => errors += Error.UndefinedIdentifier(ident, identType) scope.add(?, ident, identType) - } } } } From 74f62ea933bf7fc6ae838318151a8b4056a38a8a Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 20:41:49 +0000 Subject: [PATCH 17/18] fix: fix merge breaks, add function names to scope before renaming bodies --- src/main/wacc/Main.scala | 2 +- src/main/wacc/ast.scala | 19 ++++++++----------- src/main/wacc/renamer.scala | 38 ++++++++++++++++++++----------------- src/main/wacc/types.scala | 10 +++++----- 4 files changed, 35 insertions(+), 34 deletions(-) diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index dcc7e6d..a271c3c 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -34,7 +34,7 @@ def compile(contents: String): Int = { parser.parse(contents) match { case Success(ast) => given errors: mutable.Builder[Error, List[Error]] = List.newBuilder - val names = renamer.rename(ast) + renamer.rename(ast) // given ctx: types.TypeCheckerCtx[List[Error]] = // types.TypeCheckerCtx(names, errors) // types.check(ast) diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index 2457c8f..123adb5 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -27,10 +27,10 @@ object ast { case class StrLiter(v: String)(pos: Position) extends Expr6 object StrLiter extends ParserBridgePos1[String, StrLiter] case class PairLiter()(pos: Position) extends Expr6 - object PairLiter extends Expr6 with ParserBridgePos0[PairLiter] - case class Ident(v: String, var uid: Int = -1) extends Expr6 with LValue + object PairLiter extends ParserBridgePos0[PairLiter] + case class Ident(v: String, var uid: Int = -1)(pos: Position) extends Expr6 with LValue object Ident extends ParserBridgePos1[String, Ident] { - def apply(x1: String): Ident = new Ident(x1) + def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) } case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) extends Expr6 @@ -46,18 +46,15 @@ object ast { sealed trait UnaryOp extends Expr { val x: Expr } - sealed trait UnaryOp extends Expr { - val x: Expr - } - case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp + case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp object Negate extends ParserBridgePos1[Expr6, Negate] - case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp + case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp object Not extends ParserBridgePos1[Expr6, Not] - case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp + case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp object Len extends ParserBridgePos1[Expr6, Len] - case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp + case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp object Ord extends ParserBridgePos1[Expr6, Ord] - case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp with UnaryOp + case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala index e46d9dd..77e0bd6 100644 --- a/src/main/wacc/renamer.scala +++ b/src/main/wacc/renamer.scala @@ -41,30 +41,34 @@ object renamer { ): Map[Ident, SemType] = given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty - rename(Scope(mutable.Map.empty, Map.empty))(prog) + val scope = Scope(mutable.Map.empty, Map.empty) + val Program(funcs, main) = prog + funcs + .map { case FuncDecl(retType, name, params, body) => + val paramTypes = params.map { param => + val paramType = SemType(param.paramType) + paramType + } + scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) + (params zip paramTypes, body) + } + .foreach { case (params, body) => + val functionScope = scope.subscope + params.foreach { case (param, paramType) => + functionScope.add(paramType, param.name, IdentType.Var) + } + body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params + } + main.toList.foreach(rename(scope)) globalNames.toMap private def rename(scope: Scope)( - node: Program | FuncDecl | Ident | Stmt | LValue | RValue + node: Ident | Stmt | LValue | RValue )(using globalNames: mutable.Map[Ident, SemType], globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ): Unit = node match { - case Program(funcs, main) => { - funcs.foreach(rename(scope)) - main.toList.foreach(rename(scope)) - } - case FuncDecl(retType, name, params, body) => { - val functionScope = scope.subscope - val paramTypes = params.map { param => - val paramType = SemType(param.paramType) - functionScope.add(paramType, param.name, IdentType.Var) - paramType - } - scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) - body.toList.foreach(rename(functionScope)) - } case VarDecl(synType, name, value) => { // Order matters here. Variable isn't declared until after the value is evaluated. rename(scope)(value) @@ -112,7 +116,7 @@ object renamer { } // Default to variables. Only `call` uses IdentType.Func. case id: Ident => renameIdent(scope, id, IdentType.Var) - case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter | Skip => () + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () } private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index 388416f..e62f0db 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -17,14 +17,14 @@ object types { object SemType { def apply(synType: Type | PairElemType): KnownType = synType match { - case IntType => KnownType.Int - case BoolType => KnownType.Bool - case CharType => KnownType.Char - case StringType => KnownType.String + case IntType() => KnownType.Int + case BoolType() => KnownType.Bool + case CharType() => KnownType.Char + case StringType() => KnownType.String case ArrayType(elemType, dimension) => (0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc)) case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd)) - case UntypedPairType => KnownType.Pair(?, ?) + case UntypedPairType() => KnownType.Pair(?, ?) } } } From 0e2d1af878f56c5b476f02d9518d49fcdf23c795 Mon Sep 17 00:00:00 2001 From: Gleb Koval Date: Wed, 5 Feb 2025 22:03:26 +0000 Subject: [PATCH 18/18] refactor: add comments to renamer --- src/main/wacc/renamer.scala | 99 ++++++++++++++++++++++++++++++++----- src/main/wacc/types.scala | 1 + 2 files changed, 89 insertions(+), 11 deletions(-) diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala index 77e0bd6..6b78dc1 100644 --- a/src/main/wacc/renamer.scala +++ b/src/main/wacc/renamer.scala @@ -15,9 +15,33 @@ object renamer { current: mutable.Map[(String, IdentType), Ident], parent: Map[(String, IdentType), Ident] ) { + + /** Create a new scope with the current scope as its parent. + * + * @return + * A new scope with an empty current scope, and this scope flattened into the parent scope. + */ def subscope: Scope = Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + /** Attempt to add a new identifier to the current scope. If the identifier already exists in + * the current scope, add an error to the error list. + * + * @param semType + * The semantic type of the identifier. + * @param name + * The name of the identifier. + * @param identType + * The identifier type (function or variable). + * @param globalNames + * The global map of identifiers to semantic types - the identifier will be added to this + * map. + * @param globalNumbering + * The global map of identifier names to the number of times they have been declared - will + * used to rename this identifier, and will be incremented. + * @param errors + * The list of errors to append to. + */ def add(semType: SemType, name: Ident, identType: IdentType)(using globalNames: mutable.Map[Ident, SemType], globalNumbering: mutable.Map[String, Int], @@ -36,6 +60,16 @@ object renamer { } } + /** Check scoping of all variables and functions in the program. Also generate semantic types for + * all identifiers. + * + * @param prog + * AST of the program + * @param errors + * List of errors to append to + * @return + * Map of all (renamed) identifies to their semantic types + */ def rename(prog: Program)(using errors: mutable.Builder[Error, List[Error]] ): Map[Ident, SemType] = @@ -44,6 +78,7 @@ object renamer { val scope = Scope(mutable.Map.empty, Map.empty) val Program(funcs, main) = prog funcs + // First add all function declarations to the scope .map { case FuncDecl(retType, name, params, body) => val paramTypes = params.map { param => val paramType = SemType(param.paramType) @@ -52,6 +87,8 @@ object renamer { scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) (params zip paramTypes, body) } + // Only then rename the function bodies + // (functions can call one-another regardless of order of declaration) .foreach { case (params, body) => val functionScope = scope.subscope params.foreach { case (param, paramType) => @@ -62,19 +99,52 @@ object renamer { main.toList.foreach(rename(scope)) globalNames.toMap + /** Check scoping of all identifies in a given AST node. + * + * @param scope + * The current scope and flattened parent scope. + * @param node + * The AST node. + * @param globalNames + * The global map of identifiers to semantic types - renamed identifiers will be added to this + * map. + * @param globalNumbering + * The global map of identifier names to the number of times they have been declared - used and + * updated during identifier renaming. + * @param errors + */ private def rename(scope: Scope)( - node: Ident | Stmt | LValue | RValue + node: Ident | Stmt | LValue | RValue | Expr )(using globalNames: mutable.Map[Ident, SemType], globalNumbering: mutable.Map[String, Int], errors: mutable.Builder[Error, List[Error]] ): Unit = node match { + // These cases are more interesting because the involve making subscopes + // or modifying the current scope. case VarDecl(synType, name, value) => { // Order matters here. Variable isn't declared until after the value is evaluated. rename(scope)(value) + // Attempt to add the new variable to the current scope. scope.add(SemType(synType), name, IdentType.Var) } + case If(cond, thenStmt, elseStmt) => { + rename(scope)(cond) + // then and else both have their own scopes + thenStmt.toList.foreach(rename(scope.subscope)) + elseStmt.toList.foreach(rename(scope.subscope)) + } + case While(cond, body) => { + rename(scope)(cond) + // while bodies have their own scopes + body.toList.foreach(rename(scope.subscope)) + } + // begin-end blocks have their own scopes + case Block(body) => body.toList.foreach(rename(scope.subscope)) + + // These cases are simpler, mostly just recursive calls to rename() case Assign(lhs, value) => { + // Variables may be reassigned with their value in the rhs, so order doesn't matter here. rename(scope)(lhs) rename(scope)(value) } @@ -83,16 +153,6 @@ object renamer { case Return(expr) => rename(scope)(expr) case Exit(expr) => rename(scope)(expr) case Print(expr, _) => rename(scope)(expr) - case If(cond, thenStmt, elseStmt) => { - rename(scope)(cond) - thenStmt.toList.foreach(rename(scope.subscope)) - elseStmt.toList.foreach(rename(scope.subscope)) - } - case While(cond, body) => { - rename(scope)(cond) - body.toList.foreach(rename(scope.subscope)) - } - case Block(body) => body.toList.foreach(rename(scope.subscope)) case NewPair(fst, snd) => { rename(scope)(fst) rename(scope)(snd) @@ -116,9 +176,26 @@ object renamer { } // Default to variables. Only `call` uses IdentType.Func. case id: Ident => renameIdent(scope, id, IdentType.Var) + // These literals cannot contain identifies, exit immediately. case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () } + /** Lookup an identifier in the current scope and rename it. If the identifier is not found, add + * an error to the error list and add it to the current scope with an unknown type. + * + * @param scope + * The current scope and flattened parent scope. + * @param ident + * The identifier to rename. + * @param identType + * The type of the identifier (function or variable). + * @param globalNames + * Used to add not-found identifiers to scope. + * @param globalNumbering + * Used to add not-found identifiers to scope. + * @param errors + * The list of errors to append to. + */ private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using globalNames: mutable.Map[Ident, SemType], globalNumbering: mutable.Map[String, Int], diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala index e62f0db..e2a7988 100644 --- a/src/main/wacc/types.scala +++ b/src/main/wacc/types.scala @@ -21,6 +21,7 @@ object types { case BoolType() => KnownType.Bool case CharType() => KnownType.Char case StringType() => KnownType.String + // For semantic types it is easier to work with recursion rather than a fixed size case ArrayType(elemType, dimension) => (0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc)) case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd))