From 6d1c0b7a87ded87b5ff3781c4514c500c42f26b5 Mon Sep 17 00:00:00 2001
From: Gleb Koval <gleb@koval.net>
Date: Wed, 5 Feb 2025 04:47:43 +0000
Subject: [PATCH] fix: make parser use only parsley parser bridge apply

---
 src/main/wacc/ast.scala    | 76 ++++++++++++++++++++------------------
 src/main/wacc/parser.scala | 26 ++++++-------
 2 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala
index 9afd8d9..e4a9128 100644
--- a/src/main/wacc/ast.scala
+++ b/src/main/wacc/ast.scala
@@ -2,6 +2,7 @@ package wacc
 
 import parsley.Parsley
 import parsley.generic.ErrorBridge
+import parsley.ap._
 import parsley.position._
 import parsley.syntax.zipped._
 import cats.data.NonEmptyList
@@ -32,7 +33,10 @@ object ast {
   case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position)
       extends Expr6
       with LValue
-  object ArrayElem extends ParserBridgePos2[Ident, NonEmptyList[Expr], ArrayElem]
+  object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
+    def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
+      name => ArrayElem(name, a)(pos)
+  }
   case class Parens(expr: Expr)(pos: Position) extends Expr6
   object Parens extends ParserBridgePos1[Expr, Parens]
 
@@ -97,7 +101,9 @@ object ast {
   case class ArrayType(elemType: Type, dimensions: Int)(pos: Position)
       extends Type
       with PairElemType
-  object ArrayType extends ParserBridgePos2[Type, Int, ArrayType]
+  object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
+    def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
+  }
   case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type
   object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
 
@@ -116,7 +122,17 @@ object ast {
       params: List[Param],
       body: NonEmptyList[Stmt]
   )(pos: Position)
-  object FuncDecl extends ParserBridgePos4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl]
+  object FuncDecl
+      extends ParserBridgePos2[
+        List[Param],
+        NonEmptyList[Stmt],
+        ((Type, Ident)) => FuncDecl
+      ] {
+    def apply(params: List[Param], body: NonEmptyList[Stmt])(
+        pos: Position
+    ): ((Type, Ident)) => FuncDecl =
+      (returnType, name) => FuncDecl(returnType, name, params, body)(pos)
+  }
 
   case class Param(paramType: Type, name: Ident)(pos: Position)
   object Param extends ParserBridgePos2[Type, Ident, Param]
@@ -165,65 +181,53 @@ object ast {
   object Snd extends ParserBridgePos1[LValue, Snd]
 
   // Parser bridges
-  case class Position(line: Int, column: Int, offset: Int, width: Int)
+  case class Position(line: Int, column: Int, offset: Int)
 
   private def applyCon[A, B](
       con: ((Int, Int), Int, Int) => A => B
-  )(ops: Parsley[A]): Parsley[B] =
+  )(ops: => Parsley[A]): Parsley[B] =
     (pos, offset, withWidth(ops)).zipped.map { (pos, off, res) =>
       con(pos, off, res._2)(res._1)
     }
 
   trait ParserSingletonBridgePos[+A] extends ErrorBridge {
-    protected def con(pos: (Int, Int), offset: Int, width: Int): A
-    def from(op: Parsley[?]): Parsley[A] = error((pos, offset, withWidth(op)).zipped.map {
-      (pos, off, res) => con(pos, off, res._2)
-    })
-    final def <#(op: Parsley[?]): Parsley[A] = from(op)
+    protected def con(pos: (Int, Int), offset: Int): A
+    infix def from(op: Parsley[?]): Parsley[A] = error((pos, offset).zipped(con) <~ op)
+    final def <#(op: Parsley[?]): Parsley[A] = this from op
   }
 
   trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] {
     def apply()(pos: Position): A
 
-    override final def con(pos: (Int, Int), offset: Int, width: Int): A =
-      apply()(Position(pos._1, pos._2, offset, width))
+    override final def con(pos: (Int, Int), offset: Int): A =
+      apply()(Position(pos._1, pos._2, offset))
   }
 
   trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] {
     def apply(a: A)(pos: Position): B
-    def apply(a: Parsley[A]): Parsley[B] = error((pos, offset, withWidth(a)).zipped.map {
-      (pos, off, res) => con(pos, off, res._2)(res._1)
-    })
+    def apply(a: Parsley[A]): Parsley[B] = error(ap1((pos, offset).zipped(con), a))
 
-    override final def con(pos: (Int, Int), offset: Int, width: Int): A => B =
-      apply(_)(Position(pos._1, pos._2, offset, width))
+    override final def con(pos: (Int, Int), offset: Int): A => B =
+      this.apply(_)(Position(pos._1, pos._2, offset))
   }
 
-  trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[((A, B)) => C] {
+  trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] {
     def apply(a: A, b: B)(pos: Position): C
-    def apply(a: Parsley[A], b: Parsley[B]): Parsley[C] = error(applyCon(con)((a, b).zipped))
+    def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error(
+      ap2((pos, offset).zipped(con), a, b)
+    )
 
-    override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B)) => C =
-      apply(_, _)(Position(pos._1, pos._2, offset, width))
+    override final def con(pos: (Int, Int), offset: Int): (A, B) => C =
+      apply(_, _)(Position(pos._1, pos._2, offset))
   }
 
-  trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[((A, B, C)) => D] {
+  trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] {
     def apply(a: A, b: B, c: C)(pos: Position): D
-    def apply(a: Parsley[A], b: Parsley[B], c: Parsley[C]): Parsley[D] = error(
-      applyCon(con)((a, b, c).zipped)
+    def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error(
+      ap3((pos, offset).zipped(con), a, b, c)
     )
 
-    override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B, C)) => D =
-      apply(_, _, _)(Position(pos._1, pos._2, offset, width))
-  }
-
-  trait ParserBridgePos4[-A, -B, -C, -D, +E] extends ParserSingletonBridgePos[((A, B, C, D)) => E] {
-    def apply(a: A, b: B, c: C, d: D)(pos: Position): E
-    def apply(a: Parsley[A], b: Parsley[B], c: Parsley[C], d: Parsley[D]): Parsley[E] = error(
-      applyCon(con)((a, b, c, d).zipped)
-    )
-
-    override final def con(pos: (Int, Int), offset: Int, width: Int): ((A, B, C, D)) => E =
-      apply(_, _, _, _)(Position(pos._1, pos._2, offset, width))
+    override final def con(pos: (Int, Int), offset: Int): (A, B, C) => D =
+      apply(_, _, _)(Position(pos._1, pos._2, offset))
   }
 }
diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala
index 9cec896..5751732 100644
--- a/src/main/wacc/parser.scala
+++ b/src/main/wacc/parser.scala
@@ -93,10 +93,7 @@ object parser {
   private val `<ident>` = Ident(ident)
   private lazy val `<ident-or-array-elem>` =
     `<ident>` <**> (`<array-indices>` </> identity)
-  private val `<array-indices>` =
-    some("[" ~> `<expr>` <~ "]") map { indices =>
-      ArrayElem((_: Ident), indices)
-    }
+  private val `<array-indices>` = ArrayElem(some("[" ~> `<expr>` <~ "]"))
 
   // Types
   private lazy val `<type>`: Parsley[Type] =
@@ -104,7 +101,7 @@ object parser {
   private val `<base-type>` =
     (IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string")
   private lazy val `<array-type>` =
-    countSome("[" ~> "]") map { cnt => ArrayType((_: Type), cnt) }
+    ArrayType(countSome("[" ~> "]"))
   private val `<pair-type>` = "pair"
   private val `<pair-elems-type>`: Parsley[PairType] = PairType(
     "(" ~> `<pair-elem-type>` <~ ",",
@@ -112,10 +109,10 @@ object parser {
   )
   private lazy val `<pair-elem-type>` =
     (`<base-type>` <**> (`<array-type>` </> identity)) |
-      `<pair-type>` ~> ((`<pair-elems-type>` <**> `<array-type>`.explain(
-        "non-erased pair types cannot be nested"
-      )) </> UntypedPairType)
-  // TODO: better explanation here?
+      ((UntypedPairType from `<pair-type>`) <**>
+        ((`<pair-elems-type>` <**> `<array-type>`)
+          .map(arr => (_: UntypedPairType) => arr) </> identity))
+
   // Statements
   private lazy val `<program>` = Program(
     "begin" ~> many(
@@ -124,11 +121,12 @@ object parser {
     `<stmt>`.label("main program body") <~ "end"
   )
   private lazy val `<partial-func-decl>` =
-    (sepBy(`<param>`, ",") <~ ")" <~ "is" <~> `<stmt>`.guardAgainst {
-      case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement")
-    } <~ "end") map { (params, stmt) =>
-      (FuncDecl((_: Type), (_: Ident), params, stmt)).tupled
-    }
+    FuncDecl(
+      sepBy(`<param>`, ",") <~ ")" <~ "is",
+      `<stmt>`.guardAgainst {
+        case stmts if !stmts.isReturning => Seq("All functions must end in a returning statement")
+      } <~ "end"
+    )
   private lazy val `<param>` = Param(`<type>`, `<ident>`)
   private lazy val `<stmt>`: Parsley[NonEmptyList[Stmt]] =
     (