diff --git a/src/main/wacc/Error.scala b/src/main/wacc/Error.scala new file mode 100644 index 0000000..6370925 --- /dev/null +++ b/src/main/wacc/Error.scala @@ -0,0 +1,8 @@ +package wacc + +enum Error { + case DuplicateDeclaration(ident: ast.Ident) + case UndefinedIdentifier(ident: ast.Ident, identType: renamer.IdentType) + case FunctionParamsMismatch(expected: Int, got: Int) + case TypeMismatch(expected: types.SemType, got: types.SemType) +} diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index d4b070a..a271c3c 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,5 +1,6 @@ package wacc +import scala.collection.mutable import parsley.{Failure, Success} import scopt.OParser import java.io.File @@ -32,8 +33,15 @@ val cliParser = { def compile(contents: String): Int = { parser.parse(contents) match { case Success(ast) => - // TODO: Do semantics things - 0 + given errors: mutable.Builder[Error, List[Error]] = List.newBuilder + renamer.rename(ast) + // given ctx: types.TypeCheckerCtx[List[Error]] = + // types.TypeCheckerCtx(names, errors) + // types.check(ast) + if (errors.result.nonEmpty) { + errors.result.foreach(println) + 200 + } else 0 case Failure(msg) => println(msg) 100 diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index c6e743e..123adb5 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -1,6 +1,10 @@ package wacc -import parsley.generic._ +import parsley.Parsley +import parsley.generic.ErrorBridge +import parsley.ap._ +import parsley.position._ +import parsley.syntax.zipped._ import cats.data.NonEmptyList object ast { @@ -14,80 +18,104 @@ object ast { sealed trait Expr6 extends Expr5 // Atoms - case class IntLiter(v: Int) extends Expr6 - object IntLiter extends ParserBridge1[Int, IntLiter] - case class BoolLiter(v: Boolean) extends Expr6 - object BoolLiter extends ParserBridge1[Boolean, BoolLiter] - case class CharLiter(v: Char) extends Expr6 - object CharLiter extends 
ParserBridge1[Char, CharLiter] - case class StrLiter(v: String) extends Expr6 - object StrLiter extends ParserBridge1[String, StrLiter] - case object PairLiter extends Expr6 with ParserBridge0[PairLiter.type] - case class Ident(v: String) extends Expr6 with LValue - object Ident extends ParserBridge1[String, Ident] - case class ArrayElem(name: Ident, indices: NonEmptyList[Expr]) extends Expr6 with LValue - object ArrayElem extends ParserBridge2[Ident, NonEmptyList[Expr], ArrayElem] - case class Parens(expr: Expr) extends Expr6 - object Parens extends ParserBridge1[Expr, Parens] + case class IntLiter(v: Int)(pos: Position) extends Expr6 + object IntLiter extends ParserBridgePos1[Int, IntLiter] + case class BoolLiter(v: Boolean)(pos: Position) extends Expr6 + object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] + case class CharLiter(v: Char)(pos: Position) extends Expr6 + object CharLiter extends ParserBridgePos1[Char, CharLiter] + case class StrLiter(v: String)(pos: Position) extends Expr6 + object StrLiter extends ParserBridgePos1[String, StrLiter] + case class PairLiter()(pos: Position) extends Expr6 + object PairLiter extends ParserBridgePos0[PairLiter] + case class Ident(v: String, var uid: Int = -1)(pos: Position) extends Expr6 with LValue + object Ident extends ParserBridgePos1[String, Ident] { + def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) + } + case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(pos: Position) + extends Expr6 + with LValue + object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { + def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = + name => ArrayElem(name, a)(pos) + } + case class Parens(expr: Expr)(pos: Position) extends Expr6 + object Parens extends ParserBridgePos1[Expr, Parens] // Unary operators - case class Negate(x: Expr6) extends Expr6 - object Negate extends ParserBridge1[Expr6, Negate] - case class Not(x: Expr6) extends Expr6 - object Not extends 
ParserBridge1[Expr6, Not] - case class Len(x: Expr6) extends Expr6 - object Len extends ParserBridge1[Expr6, Len] - case class Ord(x: Expr6) extends Expr6 - object Ord extends ParserBridge1[Expr6, Ord] - case class Chr(x: Expr6) extends Expr6 - object Chr extends ParserBridge1[Expr6, Chr] + sealed trait UnaryOp extends Expr { + val x: Expr + } + case class Negate(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Negate extends ParserBridgePos1[Expr6, Negate] + case class Not(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Not extends ParserBridgePos1[Expr6, Not] + case class Len(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Len extends ParserBridgePos1[Expr6, Len] + case class Ord(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Ord extends ParserBridgePos1[Expr6, Ord] + case class Chr(x: Expr6)(pos: Position) extends Expr6 with UnaryOp + object Chr extends ParserBridgePos1[Expr6, Chr] // Binary operators - case class Add(x: Expr4, y: Expr5) extends Expr4 - object Add extends ParserBridge2[Expr4, Expr5, Add] - case class Sub(x: Expr4, y: Expr5) extends Expr4 - object Sub extends ParserBridge2[Expr4, Expr5, Sub] - case class Mul(x: Expr5, y: Expr6) extends Expr5 - object Mul extends ParserBridge2[Expr5, Expr6, Mul] - case class Div(x: Expr5, y: Expr6) extends Expr5 - object Div extends ParserBridge2[Expr5, Expr6, Div] - case class Mod(x: Expr5, y: Expr6) extends Expr5 - object Mod extends ParserBridge2[Expr5, Expr6, Mod] - case class Greater(x: Expr4, y: Expr4) extends Expr3 - object Greater extends ParserBridge2[Expr4, Expr4, Greater] - case class GreaterEq(x: Expr4, y: Expr4) extends Expr3 - object GreaterEq extends ParserBridge2[Expr4, Expr4, GreaterEq] - case class Less(x: Expr4, y: Expr4) extends Expr3 - object Less extends ParserBridge2[Expr4, Expr4, Less] - case class LessEq(x: Expr4, y: Expr4) extends Expr3 - object LessEq extends ParserBridge2[Expr4, Expr4, LessEq] - case class Eq(x: Expr3, y: Expr3) 
extends Expr2 - object Eq extends ParserBridge2[Expr3, Expr3, Eq] - case class Neq(x: Expr3, y: Expr3) extends Expr2 - object Neq extends ParserBridge2[Expr3, Expr3, Neq] - case class And(x: Expr2, y: Expr1) extends Expr1 - object And extends ParserBridge2[Expr2, Expr1, And] - case class Or(x: Expr1, y: Expr) extends Expr - object Or extends ParserBridge2[Expr1, Expr, Or] + sealed trait BinaryOp extends Expr { + val x: Expr + val y: Expr + } + case class Add(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + object Add extends ParserBridgePos2[Expr4, Expr5, Add] + case class Sub(x: Expr4, y: Expr5)(pos: Position) extends Expr4 with BinaryOp + object Sub extends ParserBridgePos2[Expr4, Expr5, Sub] + case class Mul(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + object Mul extends ParserBridgePos2[Expr5, Expr6, Mul] + case class Div(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + object Div extends ParserBridgePos2[Expr5, Expr6, Div] + case class Mod(x: Expr5, y: Expr6)(pos: Position) extends Expr5 with BinaryOp + object Mod extends ParserBridgePos2[Expr5, Expr6, Mod] + case class Greater(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object Greater extends ParserBridgePos2[Expr4, Expr4, Greater] + case class GreaterEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq] + case class Less(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object Less extends ParserBridgePos2[Expr4, Expr4, Less] + case class LessEq(x: Expr4, y: Expr4)(pos: Position) extends Expr3 with BinaryOp + object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq] + case class Eq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + object Eq extends ParserBridgePos2[Expr3, Expr3, Eq] + case class Neq(x: Expr3, y: Expr3)(pos: Position) extends Expr2 with BinaryOp + object Neq extends ParserBridgePos2[Expr3, Expr3, Neq] + case 
class And(x: Expr2, y: Expr1)(pos: Position) extends Expr1 with BinaryOp + object And extends ParserBridgePos2[Expr2, Expr1, And] + case class Or(x: Expr1, y: Expr)(pos: Position) extends Expr with BinaryOp + object Or extends ParserBridgePos2[Expr1, Expr, Or] // Types sealed trait Type sealed trait BaseType extends Type with PairElemType - case object IntType extends BaseType with ParserBridge0[IntType.type] - case object BoolType extends BaseType with ParserBridge0[BoolType.type] - case object CharType extends BaseType with ParserBridge0[CharType.type] - case object StringType extends BaseType with ParserBridge0[StringType.type] - case class ArrayType(elemType: Type, dimensions: Int) extends Type with PairElemType - object ArrayType extends ParserBridge2[Type, Int, ArrayType] - case class PairType(fst: PairElemType, snd: PairElemType) extends Type - object PairType extends ParserBridge2[PairElemType, PairElemType, PairType] + case class IntType()(pos: Position) extends BaseType + object IntType extends ParserBridgePos0[IntType] + case class BoolType()(pos: Position) extends BaseType + object BoolType extends ParserBridgePos0[BoolType] + case class CharType()(pos: Position) extends BaseType + object CharType extends ParserBridgePos0[CharType] + case class StringType()(pos: Position) extends BaseType + object StringType extends ParserBridgePos0[StringType] + case class ArrayType(elemType: Type, dimensions: Int)(pos: Position) + extends Type + with PairElemType + object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { + def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) + } + case class PairType(fst: PairElemType, snd: PairElemType)(pos: Position) extends Type + object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] sealed trait PairElemType - case object UntypedPairType extends PairElemType with ParserBridge0[UntypedPairType.type] + case class UntypedPairType()(pos: Position) extends 
PairElemType + object UntypedPairType extends ParserBridgePos0[UntypedPairType] // waccadoodledo - case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt]) - object Program extends ParserBridge2[List[FuncDecl], NonEmptyList[Stmt], Program] + case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(pos: Position) + object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] // Function Definitions case class FuncDecl( @@ -95,49 +123,106 @@ object ast { name: Ident, params: List[Param], body: NonEmptyList[Stmt] - ) - object FuncDecl extends ParserBridge4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl] + )(pos: Position) + object FuncDecl + extends ParserBridgePos2[ + List[Param], + NonEmptyList[Stmt], + ((Type, Ident)) => FuncDecl + ] { + def apply(params: List[Param], body: NonEmptyList[Stmt])( + pos: Position + ): ((Type, Ident)) => FuncDecl = + (returnType, name) => FuncDecl(returnType, name, params, body)(pos) + } - case class Param(paramType: Type, name: Ident) - object Param extends ParserBridge2[Type, Ident, Param] + case class Param(paramType: Type, name: Ident)(pos: Position) + object Param extends ParserBridgePos2[Type, Ident, Param] // Statements sealed trait Stmt - case object Skip extends Stmt with ParserBridge0[Skip.type] - case class VarDecl(varType: Type, name: Ident, value: RValue) extends Stmt - object VarDecl extends ParserBridge3[Type, Ident, RValue, VarDecl] - case class Assign(lhs: LValue, value: RValue) extends Stmt - object Assign extends ParserBridge2[LValue, RValue, Assign] - case class Read(lhs: LValue) extends Stmt - object Read extends ParserBridge1[LValue, Read] - case class Free(expr: Expr) extends Stmt - object Free extends ParserBridge1[Expr, Free] - case class Return(expr: Expr) extends Stmt - object Return extends ParserBridge1[Expr, Return] - case class Exit(expr: Expr) extends Stmt - object Exit extends ParserBridge1[Expr, Exit] - case class Print(expr: Expr, newline: Boolean) 
extends Stmt - object Print extends ParserBridge2[Expr, Boolean, Print] - case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt]) extends Stmt - object If extends ParserBridge3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] - case class While(cond: Expr, body: NonEmptyList[Stmt]) extends Stmt - object While extends ParserBridge2[Expr, NonEmptyList[Stmt], While] - case class Block(stmt: NonEmptyList[Stmt]) extends Stmt - object Block extends ParserBridge1[NonEmptyList[Stmt], Block] + case class Skip()(pos: Position) extends Stmt + object Skip extends ParserBridgePos0[Skip] + case class VarDecl(varType: Type, name: Ident, value: RValue)(pos: Position) extends Stmt + object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl] + case class Assign(lhs: LValue, value: RValue)(pos: Position) extends Stmt + object Assign extends ParserBridgePos2[LValue, RValue, Assign] + case class Read(lhs: LValue)(pos: Position) extends Stmt + object Read extends ParserBridgePos1[LValue, Read] + case class Free(expr: Expr)(pos: Position) extends Stmt + object Free extends ParserBridgePos1[Expr, Free] + case class Return(expr: Expr)(pos: Position) extends Stmt + object Return extends ParserBridgePos1[Expr, Return] + case class Exit(expr: Expr)(pos: Position) extends Stmt + object Exit extends ParserBridgePos1[Expr, Exit] + case class Print(expr: Expr, newline: Boolean)(pos: Position) extends Stmt + object Print extends ParserBridgePos2[Expr, Boolean, Print] + case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])( + pos: Position + ) extends Stmt + object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If] + case class While(cond: Expr, body: NonEmptyList[Stmt])(pos: Position) extends Stmt + object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While] + case class Block(stmt: NonEmptyList[Stmt])(pos: Position) extends Stmt + object Block extends 
ParserBridgePos1[NonEmptyList[Stmt], Block] sealed trait LValue sealed trait RValue - case class ArrayLiter(elems: List[Expr]) extends RValue - object ArrayLiter extends ParserBridge1[List[Expr], ArrayLiter] - case class NewPair(fst: Expr, snd: Expr) extends RValue - object NewPair extends ParserBridge2[Expr, Expr, NewPair] - case class Call(name: Ident, args: List[Expr]) extends RValue - object Call extends ParserBridge2[Ident, List[Expr], Call] + case class ArrayLiter(elems: List[Expr])(pos: Position) extends RValue + object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter] + case class NewPair(fst: Expr, snd: Expr)(pos: Position) extends RValue + object NewPair extends ParserBridgePos2[Expr, Expr, NewPair] + case class Call(name: Ident, args: List[Expr])(pos: Position) extends RValue + object Call extends ParserBridgePos2[Ident, List[Expr], Call] sealed trait PairElem extends LValue with RValue - case class Fst(elem: LValue) extends PairElem - object Fst extends ParserBridge1[LValue, Fst] - case class Snd(elem: LValue) extends PairElem - object Snd extends ParserBridge1[LValue, Snd] + case class Fst(elem: LValue)(pos: Position) extends PairElem + object Fst extends ParserBridgePos1[LValue, Fst] + case class Snd(elem: LValue)(pos: Position) extends PairElem + object Snd extends ParserBridgePos1[LValue, Snd] + + // Parser bridges + case class Position(line: Int, column: Int, offset: Int) + + trait ParserSingletonBridgePos[+A] extends ErrorBridge { + protected def con(pos: (Int, Int), offset: Int): A + infix def from(op: Parsley[?]): Parsley[A] = error((pos, offset).zipped(con) <~ op) + final def <#(op: Parsley[?]): Parsley[A] = this from op + } + + trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] { + def apply()(pos: Position): A + + override final def con(pos: (Int, Int), offset: Int): A = + apply()(Position(pos._1, pos._2, offset)) + } + + trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] { + def apply(a: A)(pos: 
Position): B + def apply(a: Parsley[A]): Parsley[B] = error(ap1((pos, offset).zipped(con), a)) + + override final def con(pos: (Int, Int), offset: Int): A => B = + this.apply(_)(Position(pos._1, pos._2, offset)) + } + + trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] { + def apply(a: A, b: B)(pos: Position): C + def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error( + ap2((pos, offset).zipped(con), a, b) + ) + + override final def con(pos: (Int, Int), offset: Int): (A, B) => C = + apply(_, _)(Position(pos._1, pos._2, offset)) + } + + trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] { + def apply(a: A, b: B, c: C)(pos: Position): D + def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error( + ap3((pos, offset).zipped(con), a, b, c) + ) + + override final def con(pos: (Int, Int), offset: Int): (A, B, C) => D = + apply(_, _, _)(Position(pos._1, pos._2, offset)) + } } diff --git a/src/main/wacc/lexer.scala b/src/main/wacc/lexer.scala index 094ac12..4a810c0 100644 --- a/src/main/wacc/lexer.scala +++ b/src/main/wacc/lexer.scala @@ -4,7 +4,36 @@ import parsley.Parsley import parsley.character import parsley.token.{Basic, Lexer} import parsley.token.descriptions.* +import parsley.token.errors._ +val errConfig = new ErrorConfig { + override def labelSymbol = Map( + "!=" -> Label("binary operator"), + "%" -> Label("binary operator"), + "&&" -> Label("binary operator"), + "*" -> Label("binary operator"), + "/" -> Label("binary operator"), + "<" -> Label("binary operator"), + "<=" -> Label("binary operator"), + "==" -> Label("binary operator"), + ">" -> Label("binary operator"), + ">=" -> Label("binary operator"), + "||" -> Label("binary operator"), + "!" 
-> Label("unary operator"), + "len" -> Label("unary operator"), + "ord" -> Label("unary operator"), + "chr" -> Label("unary operator"), + "bool" -> Label("valid type"), + "char" -> Label("valid type"), + "int" -> Label("valid type"), + "pair" -> Label("valid type"), + "string" -> Label("valid type"), + "fst" -> Label("pair extraction"), + "snd" -> Label("pair extraction"), + "false" -> Label("boolean literal"), + "true" -> Label("boolean literal") + ) +} object lexer { private val desc = LexicalDesc.plain.copy( nameDesc = NameDesc.plain.copy( @@ -43,7 +72,7 @@ object lexer { ) ) - private val lexer = Lexer(desc) + private val lexer = Lexer(desc, errConfig) val ident = lexer.lexeme.names.identifier val integer = lexer.lexeme.integer.decimal32[Int] val negateCheck = lexer.nonlexeme.symbol("-") ~> character.digit @@ -51,5 +80,15 @@ object lexer { val stringLit = lexer.lexeme.string.ascii val implicits = lexer.lexeme.symbol.implicits + val errTokens = Seq( + lexer.nonlexeme.names.identifier.map(v => s"identifier $v"), + lexer.nonlexeme.integer.decimal32[Int].map(n => s"integer $n"), + lexer.nonlexeme.character.ascii.map(c => s"character literal $c"), + lexer.nonlexeme.string.ascii.map(s => s"string literal $s"), + character.whitespace.map(_ => "") + ) ++ desc.symbolDesc.hardKeywords.map { k => + lexer.nonlexeme.symbol(k).as(s"keyword $k") + } + def fully[A](p: Parsley[A]): Parsley[A] = lexer.fully(p) } diff --git a/src/main/wacc/parser.scala b/src/main/wacc/parser.scala index 84ee093..5751732 100644 --- a/src/main/wacc/parser.scala +++ b/src/main/wacc/parser.scala @@ -6,14 +6,51 @@ import parsley.Parsley.{atomic, many, notFollowedBy, pure} import parsley.combinator.{countSome, sepBy} import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms} import parsley.errors.combinator._ -import parsley.cats.combinator.{sepBy1, some} +import parsley.syntax.zipped._ +import parsley.cats.combinator.{some} import cats.data.NonEmptyList +import 
parsley.errors.DefaultErrorBuilder +import parsley.errors.ErrorBuilder +import parsley.errors.tokenextractors.LexToken object parser { import lexer.implicits.implicitSymbol - import lexer.{ident, integer, charLit, stringLit, negateCheck} + import lexer.{ident, integer, charLit, stringLit, negateCheck, errTokens} import ast._ + // error extensions + extension [A](p: Parsley[A]) { + // combines label and explain together into one function call + def labelAndExplain(label: String, explanation: String): Parsley[A] = { + p.label(label).explain(explanation) + } + def labelAndExplain(t: LabelType): Parsley[A] = { + t match { + case LabelType.Expr => + labelWithType(t).explain( + "a valid expression can start with: null, literals, identifiers, unary operators, or parentheses. " + + "Expressions can also contain array indexing and binary operators. " + + "Pair extraction is not allowed in expressions, only in assignments." + ) + case _ => labelWithType(t) + } + } + + def labelWithType(t: LabelType): Parsley[A] = { + t match { + case LabelType.Expr => p.label("valid expression") + case LabelType.Pair => p.label("valid pair") + } + } + } + + enum LabelType: + case Expr + case Pair + + implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken { + def tokens = errTokens + } def parse(input: String): Result[String, Program] = parser.parse(input) private val parser = lexer.fully(``) @@ -28,11 +65,14 @@ object parser { Greater from ">", GreaterEq from ">=" ) +: - SOps(InfixL)(Add from "+", Sub from "-") +: + SOps(InfixL)( + (Add from "+").label("binary operator"), + (Sub from "-").label("binary operator") + ) +: SOps(InfixL)(Mul from "*", Div from "/", Mod from "%") +: SOps(Prefix)( Not from "!", - Negate from (notFollowedBy(negateCheck) ~> "-"), + (Negate from (notFollowedBy(negateCheck) ~> "-")).hide, Len from "len", Ord from "ord", Chr from "chr" @@ -42,10 +82,10 @@ object parser { // Atoms private lazy val ``: Atoms[Expr6] = Atoms( - IntLiter(integer), 
- BoolLiter(("true" as true) | ("false" as false)), - CharLiter(charLit), - StrLiter(stringLit), + IntLiter(integer).label("integer literal"), + BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"), + CharLiter(charLit).label("character literal"), + StrLiter(stringLit).label("string literal"), PairLiter from "null", ``, Parens("(" ~> `` <~ ")") @@ -53,10 +93,7 @@ object parser { private val `` = Ident(ident) private lazy val `` = `` <**> (`` identity) - private val `` = - some("[" ~> `` <~ "]") map { indices => - ArrayElem((_: Ident), indices) - } + private val `` = ArrayElem(some("[" ~> `` <~ "]")) // Types private lazy val ``: Parsley[Type] = @@ -64,7 +101,7 @@ object parser { private val `` = (IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string") private lazy val `` = - countSome("[" ~> "]") map { cnt => ArrayType((_: Type), cnt) } + ArrayType(countSome("[" ~> "]")) private val `` = "pair" private val ``: Parsley[PairType] = PairType( "(" ~> `` <~ ",", @@ -72,38 +109,48 @@ object parser { ) private lazy val `` = (`` <**> (`` identity)) | - `` ~> ((`` <**> ``) UntypedPairType) + ((UntypedPairType from ``) <**> + ((`` <**> ``) + .map(arr => (_: UntypedPairType) => arr) identity)) // Statements private lazy val `` = Program( - "begin" ~> many(atomic(`` <~> `` <~ "(") <**> ``), - `` <~ "end" + "begin" ~> many( + atomic(``.label("function declaration") <~> `` <~ "(") <**> `` + ).label("function declaration"), + ``.label("main program body") <~ "end" ) private lazy val `` = - (sepBy(``, ",") <~ ")" <~ "is" <~> ``.guardAgainst { - case stmts if !stmts.isReturning => Seq("All functions must end in a returning statement") - } <~ "end") map { (params, stmt) => - (FuncDecl((_: Type), (_: Ident), params, stmt)).tupled - } + FuncDecl( + sepBy(``, ",") <~ ")" <~ "is", + ``.guardAgainst { + case stmts if !stmts.isReturning => Seq("All functions must end in a returning statement") + } <~ "end" + ) private lazy 
val `` = Param(``, ``) private lazy val ``: Parsley[NonEmptyList[Stmt]] = - sepBy1(``, ";") + ( + ``.label("main program body"), + (many(";" ~> ``.label("statement after ';'"))) Nil + ).zipped(NonEmptyList.apply) + private lazy val `` = (Skip from "skip") | Read("read" ~> ``) - | Free("free" ~> ``) - | Return("return" ~> ``) - | Exit("exit" ~> ``) - | Print("print" ~> ``, pure(false)) - | Print("println" ~> ``, pure(true)) + | Free("free" ~> ``.labelAndExplain(LabelType.Expr)) + | Return("return" ~> ``.labelAndExplain(LabelType.Expr)) + | Exit("exit" ~> ``.labelAndExplain(LabelType.Expr)) + | Print("print" ~> ``.labelAndExplain(LabelType.Expr), pure(false)) + | Print("println" ~> ``.labelAndExplain(LabelType.Expr), pure(true)) | If( - "if" ~> `` <~ "then", + "if" ~> ``.labelWithType(LabelType.Expr) <~ "then", `` <~ "else", `` <~ "fi" ) - | While("while" ~> `` <~ "do", `` <~ "done") + | While("while" ~> ``.labelWithType(LabelType.Expr) <~ "do", `` <~ "done") | Block("begin" ~> `` <~ "end") - | VarDecl(``, `` <~ "=", ``) + | VarDecl(``, `` <~ "=", ``.label("valid initial value for variable")) + // TODO: Can we inline the name of the variable in the message | Assign(`` <~ "=", ``) private lazy val ``: Parsley[LValue] = `` | `` @@ -117,9 +164,10 @@ object parser { Call( "call" ~> `` <~ "(", sepBy(``, ",") <~ ")" - ) | `` + ) | ``.labelWithType(LabelType.Expr) private lazy val `` = - Fst("fst" ~> ``) | Snd("snd" ~> ``) + Fst("fst" ~> ``.label("valid pair")) + | Snd("snd" ~> ``.label("valid pair")) private lazy val `` = ArrayLiter( "[" ~> sepBy(``, ",") <~ "]" ) diff --git a/src/main/wacc/renamer.scala b/src/main/wacc/renamer.scala new file mode 100644 index 0000000..6b78dc1 --- /dev/null +++ b/src/main/wacc/renamer.scala @@ -0,0 +1,215 @@ +package wacc + +import scala.collection.mutable + +object renamer { + import ast._ + import types._ + + enum IdentType { + case Func + case Var + } + + private case class Scope( + current: mutable.Map[(String, IdentType), Ident], + 
parent: Map[(String, IdentType), Ident] + ) { + + /** Create a new scope with the current scope as its parent. + * + * @return + * A new scope with an empty current scope, and this scope flattened into the parent scope. + */ + def subscope: Scope = + Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + + /** Attempt to add a new identifier to the current scope. If the identifier already exists in + * the current scope, add an error to the error list. + * + * @param semType + * The semantic type of the identifier. + * @param name + * The name of the identifier. + * @param identType + * The identifier type (function or variable). + * @param globalNames + * The global map of identifiers to semantic types - the identifier will be added to this + * map. + * @param globalNumbering + * The global map of identifier names to the number of times they have been declared - will be + * used to rename this identifier, and will be incremented. + * @param errors + * The list of errors to append to. + */ + def add(semType: SemType, name: Ident, identType: IdentType)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ) = { + if (current.contains((name.v, identType))) { + errors += Error.DuplicateDeclaration(name) + } else { + val uid = globalNumbering.getOrElse(name.v, 0) + name.uid = uid + current((name.v, identType)) = name + + globalNames(name) = semType + globalNumbering(name.v) = uid + 1 + } + } + } + + /** Check scoping of all variables and functions in the program. Also generate semantic types for + * all identifiers.
+ * + * @param prog + * AST of the program + * @param errors + * List of errors to append to + * @return + * Map of all (renamed) identifiers to their semantic types + */ + def rename(prog: Program)(using + errors: mutable.Builder[Error, List[Error]] + ): Map[Ident, SemType] = + given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty + given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty + val scope = Scope(mutable.Map.empty, Map.empty) + val Program(funcs, main) = prog + funcs + // First add all function declarations to the scope + .map { case FuncDecl(retType, name, params, body) => + val paramTypes = params.map { param => + val paramType = SemType(param.paramType) + paramType + } + scope.add(KnownType.Func(SemType(retType), paramTypes), name, IdentType.Func) + (params zip paramTypes, body) + } + // Only then rename the function bodies + // (functions can call one-another regardless of order of declaration) + .foreach { case (params, body) => + val functionScope = scope.subscope + params.foreach { case (param, paramType) => + functionScope.add(paramType, param.name, IdentType.Var) + } + body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params + } + main.toList.foreach(rename(scope)) + globalNames.toMap + + /** Check scoping of all identifiers in a given AST node. + * + * @param scope + * The current scope and flattened parent scope. + * @param node + * The AST node. + * @param globalNames + * The global map of identifiers to semantic types - renamed identifiers will be added to this + * map. + * @param globalNumbering + * The global map of identifier names to the number of times they have been declared - used and + * updated during identifier renaming.
+ * @param errors + */ + private def rename(scope: Scope)( + node: Ident | Stmt | LValue | RValue | Expr + )(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = node match { + // These cases are more interesting because they involve making subscopes + // or modifying the current scope. + case VarDecl(synType, name, value) => { + // Order matters here. Variable isn't declared until after the value is evaluated. + rename(scope)(value) + // Attempt to add the new variable to the current scope. + scope.add(SemType(synType), name, IdentType.Var) + } + case If(cond, thenStmt, elseStmt) => { + rename(scope)(cond) + // then and else both have their own scopes + thenStmt.toList.foreach(rename(scope.subscope)) + elseStmt.toList.foreach(rename(scope.subscope)) + } + case While(cond, body) => { + rename(scope)(cond) + // while bodies have their own scopes + body.toList.foreach(rename(scope.subscope)) + } + // begin-end blocks have their own scopes + case Block(body) => body.toList.foreach(rename(scope.subscope)) + + // These cases are simpler, mostly just recursive calls to rename() + case Assign(lhs, value) => { + // Variables may be reassigned with their value in the rhs, so order doesn't matter here.
+ rename(scope)(lhs) + rename(scope)(value) + } + case Read(lhs) => rename(scope)(lhs) + case Free(expr) => rename(scope)(expr) + case Return(expr) => rename(scope)(expr) + case Exit(expr) => rename(scope)(expr) + case Print(expr, _) => rename(scope)(expr) + case NewPair(fst, snd) => { + rename(scope)(fst) + rename(scope)(snd) + } + case Call(name, args) => { + renameIdent(scope, name, IdentType.Func) + args.foreach(rename(scope)) + } + case Fst(elem) => rename(scope)(elem) + case Snd(elem) => rename(scope)(elem) + case ArrayLiter(elems) => elems.foreach(rename(scope)) + case ArrayElem(name, indices) => { + rename(scope)(name) + indices.toList.foreach(rename(scope)) + } + case Parens(expr) => rename(scope)(expr) + case op: UnaryOp => rename(scope)(op.x) + case op: BinaryOp => { + rename(scope)(op.x) + rename(scope)(op.y) + } + // Default to variables. Only `call` uses IdentType.Func. + case id: Ident => renameIdent(scope, id, IdentType.Var) + // These literals cannot contain identifiers, exit immediately. + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () + } + + /** Look up an identifier in the current scope and rename it. If the identifier is not found, add + * an error to the error list and add it to the current scope with an unknown type. + * + * @param scope + * The current scope and flattened parent scope. + * @param ident + * The identifier to rename. + * @param identType + * The type of the identifier (function or variable). + * @param globalNames + * Used to add not-found identifiers to scope. + * @param globalNumbering + * Used to add not-found identifiers to scope. + * @param errors + * The list of errors to append to.
+ */ + private def renameIdent(scope: Scope, ident: Ident, identType: IdentType)(using + globalNames: mutable.Map[Ident, SemType], + globalNumbering: mutable.Map[String, Int], + errors: mutable.Builder[Error, List[Error]] + ): Unit = { + // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. + // Neither is there a way to check whether a default exists, so we have to use a try-catch. + try { + val Ident(_, uid) = scope.current.withDefault(scope.parent)((ident.v, identType)) + ident.uid = uid + } catch { + case _: NoSuchElementException => + errors += Error.UndefinedIdentifier(ident, identType) + scope.add(?, ident, identType) + } + } +} diff --git a/src/main/wacc/types.scala b/src/main/wacc/types.scala new file mode 100644 index 0000000..e2a7988 --- /dev/null +++ b/src/main/wacc/types.scala @@ -0,0 +1,31 @@ +package wacc + +object types { + import ast._ + + sealed trait SemType + case object ? extends SemType + enum KnownType extends SemType { + case Int + case Bool + case Char + case String + case Array(elem: SemType) + case Pair(left: SemType, right: SemType) + case Func(ret: SemType, params: List[SemType]) + } + + object SemType { + def apply(synType: Type | PairElemType): KnownType = synType match { + case IntType() => KnownType.Int + case BoolType() => KnownType.Bool + case CharType() => KnownType.Char + case StringType() => KnownType.String + // For semantic types it is easier to work with recursion rather than a fixed size + case ArrayType(elemType, dimension) => + (0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc)) + case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd)) + case UntypedPairType() => KnownType.Pair(?, ?) + } + } +}