diff --git a/.gitignore b/.gitignore index 03801cc..dab1691 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,3 @@ .vscode/ wacc-examples/ .idea/ - diff --git a/extension/examples/invalid/semantics/badWacc.wacc b/extension/examples/invalid/semantics/badWacc.wacc new file mode 100644 index 0000000..a334d57 --- /dev/null +++ b/extension/examples/invalid/semantics/badWacc.wacc @@ -0,0 +1,10 @@ +begin + int main() is + int a = 5 ; + string b = "Hello" ; + return a + b + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadFile.wacc b/extension/examples/invalid/semantics/imports/importBadFile.wacc new file mode 100644 index 0000000..b116ccd --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadFile.wacc @@ -0,0 +1,6 @@ +import "./doesNotExist.wacc" (main) + +begin + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadFunc.wacc b/extension/examples/invalid/semantics/imports/importBadFunc.wacc new file mode 100644 index 0000000..bf3c9a0 --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadFunc.wacc @@ -0,0 +1,6 @@ +import "../../../valid/sum.wacc" (mult) + +begin + int result = call mult(3, 2) ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadSem.wacc b/extension/examples/invalid/semantics/imports/importBadSem.wacc new file mode 100644 index 0000000..d20e3a6 --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadSem.wacc @@ -0,0 +1,10 @@ +import "../badWacc.wacc" (main) + +begin + int sum(int a, int b) is + return a + b + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadSem2.wacc b/extension/examples/invalid/semantics/imports/importBadSem2.wacc new file mode 100644 index 0000000..4bd330e --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadSem2.wacc @@ -0,0 +1,6 @@ +import "./importBadSem.wacc" (sum) + +begin + int result = call sum(1, 2) ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/inderect.wacc b/extension/examples/invalid/semantics/imports/inderect.wacc new file mode 100644 index 0000000..120f9ba --- /dev/null +++ b/extension/examples/invalid/semantics/imports/inderect.wacc @@ -0,0 +1,6 @@ +import "../../../valid/imports/basic.wacc" (sum) + +begin + int result = call sum(3, 2) ; + exit result +end diff --git a/extension/examples/invalid/syntax/badWacc.wacc b/extension/examples/invalid/syntax/badWacc.wacc new file mode 100644 index 0000000..a375309 --- /dev/null +++ b/extension/examples/invalid/syntax/badWacc.wacc @@ -0,0 +1,6 @@ +int main() is + println "Hello World!" ; + return 0 +end + +skip diff --git a/extension/examples/invalid/syntax/imports/emptyImport.wacc b/extension/examples/invalid/syntax/imports/emptyImport.wacc new file mode 100644 index 0000000..ec9dbd0 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/emptyImport.wacc @@ -0,0 +1,8 @@ +import "../../../valid/sum.wacc" sum, main + +begin + int result1 = call sum(5, 10) ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/invalid/syntax/imports/emptyImport2.wacc b/extension/examples/invalid/syntax/imports/emptyImport2.wacc new file mode 100644 index 0000000..99d38b9 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/emptyImport2.wacc @@ -0,0 +1,5 @@ +import "../../../valid/sum.wacc" () + +begin + exit 0 +end diff --git a/extension/examples/invalid/syntax/imports/importBadSyntax.wacc b/extension/examples/invalid/syntax/imports/importBadSyntax.wacc new file mode 100644 index 0000000..d20e3a6 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importBadSyntax.wacc @@ -0,0 +1,10 @@ +import "../badWacc.wacc" (main) + +begin + int sum(int a, int b) is + return a + b + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/syntax/imports/importBadSyntax2.wacc b/extension/examples/invalid/syntax/imports/importBadSyntax2.wacc new file mode 100644 index 0000000..0e0e0e1 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importBadSyntax2.wacc @@ -0,0 +1,6 @@ +import "./importBadSyntax.wacc" (sum) + +begin + int result = call sum(1, 2) ; + exit result +end diff --git a/extension/examples/invalid/syntax/imports/importNoParens.wacc b/extension/examples/invalid/syntax/imports/importNoParens.wacc new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importNoParens.wacc @@ -0,0 +1 @@ + diff --git a/extension/examples/invalid/syntax/imports/importSemis.wacc b/extension/examples/invalid/syntax/imports/importSemis.wacc new file mode 100644 index 0000000..f127844 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importSemis.wacc @@ -0,0 +1,9 @@ +import "../../../valid/sum.wacc" (sum) ; +import "../../../valid/sum.wacc" (main) ; + +begin + int result1 = call sum(5, 10) ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/invalid/syntax/imports/importStar.wacc b/extension/examples/invalid/syntax/imports/importStar.wacc new file mode 100644 index 0000000..e027caa --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importStar.wacc @@ -0,0 +1,5 @@ +import "../../../valid/sum.wacc" * + +begin + exit 0 +end diff --git a/extension/examples/invalid/syntax/imports/importStar2.wacc b/extension/examples/invalid/syntax/imports/importStar2.wacc new file mode 100644 index 0000000..bae08ef --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importStar2.wacc @@ -0,0 +1,5 @@ +import "../../../valid/sum.wacc" (*) + +begin + exit 0 +end diff --git a/extension/examples/valid/.gitignore b/extension/examples/valid/.gitignore new file mode 100644 index 0000000..ed87167 --- /dev/null +++ b/extension/examples/valid/.gitignore @@ -0,0 +1,7 @@ +* + +!imports/ +imports/* + +!.gitignore +!*.wacc diff --git a/extension/examples/valid/imports/alias.wacc b/extension/examples/valid/imports/alias.wacc new file mode 100644 index 0000000..91496a1 --- /dev/null +++ b/extension/examples/valid/imports/alias.wacc @@ -0,0 +1,22 @@ +# import main from ../sum.wacc and ./basic.wacc + +# Output: +# 15 +# 0 +# -33 +# + +# Exit: +# 0 + +# Program: + +import "../sum.wacc" (main as sumMain) +import "./basic.wacc" (main) + +begin + int result1 = call sumMain() ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/valid/imports/basic.wacc b/extension/examples/valid/imports/basic.wacc new file mode 100644 index 0000000..d34a34a --- /dev/null +++ b/extension/examples/valid/imports/basic.wacc @@ -0,0 +1,21 @@ +# import sum from ../sum.wacc + +# Output: +# -33 +# + +# Exit: +# 0 + +# Program: + +import "../sum.wacc" (sum) + +begin + int main() is + int result = call sum(-10, -23) ; + return result + end + int result = call main() ; + println result +end diff --git a/extension/examples/valid/imports/manyMains.wacc b/extension/examples/valid/imports/manyMains.wacc new file mode 100644 index 0000000..fc3bc7c --- /dev/null +++ b/extension/examples/valid/imports/manyMains.wacc @@ -0,0 +1,33 @@ +# import all the mains + +# Output: +# 15 +# -33 +# 0 +# -33 +# 0 +# + +# Exit: +# 99 + +# Program: + +import "../sum.wacc" (main as sumMain) +import "./basic.wacc" (main as basicMain) +import "./multiFunc.wacc" (main as multiFuncMain) + +begin + int main() is + int result1 = call sumMain() ; + int result2 = call basicMain() ; + int result3 = call multiFuncMain() ; + println result1 ; + println result2 ; + println result3 ; + return 99 + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/valid/imports/multiFunc.wacc b/extension/examples/valid/imports/multiFunc.wacc new file mode 100644 index 0000000..22d6e4d --- /dev/null +++ b/extension/examples/valid/imports/multiFunc.wacc @@ -0,0 +1,27 @@ +# import sum, main from ../sum.wacc + +# Output: +# 15 +# -33 +# 0 +# 0 +# + +# Exit: +# 0 + +# Program: + +import "../sum.wacc" (sum, main as sumMain) + +begin + int main() is + int result = call sum(-10, -23) ; + println result ; + return 0 + end + int result1 = call sumMain() ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/valid/sum.wacc b/extension/examples/valid/sum.wacc new file mode 100644 index 0000000..dc62e24 --- /dev/null +++ b/extension/examples/valid/sum.wacc @@ -0,0 +1,27 @@ +# simple sum program + +# Output: +# 15 +# + +# Exit: +# 0 + +# Program: + +begin + int sum(int a, int b) is + return a + b + end + + int main() is + int a = 5 ; + int b = 10 ; + int result = call sum(a, b) ; + println result ; + return 0 + end + + int result = call main() ; + exit result +end diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index d964657..e78d4bd 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -18,6 +18,7 @@ import org.typelevel.log4cats.Logger import assemblyIR as asm import cats.data.ValidatedNel +import java.io.File /* TODO: @@ -68,21 +69,25 @@ val outputOpt: Opts[Option[Path]] = .orNone def frontend( - contents: String -): Either[NonEmptyList[Error], microWacc.Program] = + contents: String, + file: File +): IO[Either[NonEmptyList[Error], microWacc.Program]] = parser.parse(contents) match { - case Failure(msg) => Left(NonEmptyList.one(Error.SyntaxError(msg))) - case Success(prog) => + case Failure(msg) => IO.pure(Left(NonEmptyList.one(Error.SyntaxError(file, msg)))) + case Success(fn) => + val partialProg = fn(file) given errors: mutable.Builder[Error, List[Error]] = List.newBuilder - val (names, funcs) = renamer.rename(prog) - given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) - val typedProg = typeChecker.check(prog) + for { + (prog, renameErrors) <- renamer.rename(partialProg) + _ = errors.addAll(renameErrors.toList) + typedProg = typeChecker.check(prog, errors) - NonEmptyList.fromList(errors.result) match { - case Some(errors) => Left(errors) - case None => Right(typedProg) - } + res = NonEmptyList.fromList(errors.result) match { + case Some(errors) => Left(errors) + case None => Right(typedProg) + } + } yield res } def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = @@ -105,29 +110,31 @@ def compile( writer.writeTo(backend(typedProg), outputPath) *> logger.info(s"Success: ${outputPath.toAbsolutePath}") - def processProgram(contents: String, outDir: Path): IO[Int] = - frontend(contents) match { - case Left(errors) => - val code = errors.map(err => err.exitCode).toList.min - given errorContent: String = contents - val errorMsg = errors.map(formatError).toIterable.mkString("\n") - for { - _ <- logAction(s"Compilation failed for $filePath\nExit code: $code") - _ <- IO.blocking( - // Explicit println since we want this to always show without logger thread info e.t.c. - println(s"Compilation failed for ${filePath.toAbsolutePath}:\n$errorMsg") - ) - } yield code + def processProgram(contents: String, file: File, outDir: Path): IO[Int] = + for { + frontendResult <- frontend(contents, file) + res <- frontendResult match { + case Left(errors) => + val code = errors.map(err => err.exitCode).toList.min + val errorMsg = errors.map(formatError).toIterable.mkString("\n") + for { + _ <- logAction(s"Compilation failed for $filePath\nExit code: $code") + _ <- IO.blocking( + // Explicit println since we want this to always show without logger thread info e.t.c. + println(s"Compilation failed for ${file.getCanonicalPath}:\n$errorMsg") + ) + } yield code - case Right(typedProg) => - val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s") - writeOutputFile(typedProg, outputFile).as(SUCCESS) - } + case Right(typedProg) => + val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s") + writeOutputFile(typedProg, outputFile).as(SUCCESS) + } + } yield res for { contents <- readSourceFile _ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}") - exitCode <- processProgram(contents, outputDir.getOrElse(filePath.getParent)) + exitCode <- processProgram(contents, filePath.toFile, outputDir.getOrElse(filePath.getParent)) } yield exitCode } diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala index 3b5169b..fd0006f 100644 --- a/src/main/wacc/backend/LabelGenerator.scala +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -18,7 +18,7 @@ private class LabelGenerator { } private def getLabel(target: CallTarget | RuntimeError): String = target match { - case Ident(v, _) => s"wacc_$v" + case Ident(v, guid) => s"wacc_${v}_$guid" case Builtin(name) => s"_$name" case err: RuntimeError => s".L.${err.name}" } diff --git a/src/main/wacc/frontend/Error.scala b/src/main/wacc/frontend/Error.scala index e515494..188e91c 100644 --- a/src/main/wacc/frontend/Error.scala +++ b/src/main/wacc/frontend/Error.scala @@ -2,6 +2,7 @@ package wacc import wacc.ast.Position import wacc.types._ +import java.io.File private val SYNTAX_ERROR = 100 private val SEMANTIC_ERROR = 200 @@ -18,13 +19,13 @@ enum Error { case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) case InternalError(pos: Position, msg: String) - case SyntaxError(msg: String) + case SyntaxError(file: File, msg: String) } extension (e: Error) { def exitCode: Int = e match { - case Error.SyntaxError(_) => SYNTAX_ERROR - case _ => SEMANTIC_ERROR + case Error.SyntaxError(_, _) => SYNTAX_ERROR + case _ => SEMANTIC_ERROR } } @@ -35,15 +36,25 @@ extension (e: Error) { * @param errorContent * Contents of the file to generate code snippets */ -def formatError(error: Error)(using errorContent: String): String = { +def formatError(error: Error): String = { val sb = new StringBuilder() + /** Format the file of an error + * + * @param file + * File of the error + */ + def formatFile(file: File): Unit = { + sb.append(s"File: ${file.getCanonicalPath}\n") + } + /** Function to format the position of an error * * @param pos * Position of the error */ def formatPosition(pos: Position): Unit = { + formatFile(pos.file) sb.append(s"(line ${pos.line}, column ${pos.column}):\n") } @@ -55,7 +66,7 @@ def formatError(error: Error)(using errorContent: String): String = { * Size(in chars) of section to highlight */ def formatHighlight(pos: Position, size: Int): Unit = { - val lines = errorContent.split("\n") + val lines = os.read(os.Path(pos.file.getCanonicalPath)).split("\n") val preLine = if (pos.line > 1) lines(pos.line - 2) else "" val midLine = lines(pos.line - 1) val postLine = if (pos.line < lines.size) lines(pos.line) else "" @@ -67,7 +78,7 @@ def formatError(error: Error)(using errorContent: String): String = { } error match { - case Error.SyntaxError(_) => + case Error.SyntaxError(_, _) => sb.append("Syntax error:\n") case _ => sb.append("Semantic error:\n") @@ -76,40 +87,40 @@ def formatError(error: Error)(using errorContent: String): String = { error match { case Error.DuplicateDeclaration(ident) => formatPosition(ident.pos) - sb.append(s"Duplicate declaration of identifier ${ident.v}") + sb.append(s"Duplicate declaration of identifier ${ident.v}\n") formatHighlight(ident.pos, ident.v.length) case Error.UndeclaredVariable(ident) => formatPosition(ident.pos) - sb.append(s"Undeclared variable ${ident.v}") + sb.append(s"Undeclared variable ${ident.v}\n") formatHighlight(ident.pos, ident.v.length) case Error.UndefinedFunction(ident) => formatPosition(ident.pos) - sb.append(s"Undefined function ${ident.v}") + sb.append(s"Undefined function ${ident.v}\n") formatHighlight(ident.pos, ident.v.length) case Error.FunctionParamsMismatch(id, expected, got, funcType) => formatPosition(id.pos) - sb.append(s"Function expects $expected parameters, got $got") + sb.append(s"Function expects $expected parameters, got $got\n") sb.append( - s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})" + s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})\n" ) formatHighlight(id.pos, 1) case Error.TypeMismatch(pos, expected, got, msg) => formatPosition(pos) - sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got") + sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got\n") formatHighlight(pos, 1) case Error.SemanticError(pos, msg) => formatPosition(pos) - sb.append(msg) + sb.append(msg + "\n") formatHighlight(pos, 1) case wacc.Error.InternalError(pos, msg) => formatPosition(pos) - sb.append(s"Internal error: $msg") + sb.append(s"Internal error: $msg\n") formatHighlight(pos, 1) - case Error.SyntaxError(msg) => - sb.append(msg) + case Error.SyntaxError(file, msg) => + formatFile(file) + sb.append(msg + "\n") sb.append("\n") } sb.toString() - } diff --git a/src/main/wacc/frontend/ast.scala b/src/main/wacc/frontend/ast.scala index 9b14b13..e39f931 100644 --- a/src/main/wacc/frontend/ast.scala +++ b/src/main/wacc/frontend/ast.scala @@ -1,5 +1,6 @@ package wacc +import java.io.File import parsley.Parsley import parsley.generic.ErrorBridge import parsley.ap._ @@ -22,26 +23,42 @@ object ast { /* ============================ ATOMIC EXPRESSIONS ============================ */ case class IntLiter(v: Int)(val pos: Position) extends Expr6 - object IntLiter extends ParserBridgePos1[Int, IntLiter] + object IntLiter extends ParserBridgePos1Atom[Int, IntLiter] case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6 - object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] + object BoolLiter extends ParserBridgePos1Atom[Boolean, BoolLiter] case class CharLiter(v: Char)(val pos: Position) extends Expr6 - object CharLiter extends ParserBridgePos1[Char, CharLiter] + object CharLiter extends ParserBridgePos1Atom[Char, CharLiter] case class StrLiter(v: String)(val pos: Position) extends Expr6 - object StrLiter extends ParserBridgePos1[String, StrLiter] + object StrLiter extends ParserBridgePos1Atom[String, StrLiter] case class PairLiter()(val pos: Position) extends Expr6 object PairLiter extends ParserBridgePos0[PairLiter] - case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue - object Ident extends ParserBridgePos1[String, Ident] { + case class Ident(var v: String, var guid: Int = -1, var ty: types.RenamerType = types.?)( + val pos: Position + ) extends Expr6 + with LValue + object Ident extends ParserBridgePos1Atom[String, Ident] { def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) } case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position) extends Expr6 with LValue - object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { - def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = - name => ArrayElem(name, a)(pos) + object ArrayElem extends ParserBridgePos2Chain[NonEmptyList[Expr], Ident, ArrayElem] { + def apply(indices: NonEmptyList[Expr], name: Ident)(pos: Position): ArrayElem = + new ArrayElem(name, indices)(pos) } + // object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], (File => Ident) => ArrayElem] { + // def apply(a: NonEmptyList[Expr])(pos: Position): (File => Ident) => ArrayElem = + // name => ArrayElem(name(pos.file), a)(pos) + // } + // object ArrayElem extends ParserSingletonBridgePos[(File => NonEmptyList[Expr]) => (File => Ident) => File => ArrayElem] { + // // def apply(indices: NonEmptyList[Expr]): (File => Ident) => File => ArrayElem = + // // name => file => new ArrayElem(name(file), ) + // def apply(indices: Parsley[File => NonEmptyList[Expr]]): Parsley[(File => Ident) => File => ArrayElem] = + // // error(ap1(pos.map(con),)) + + // override final def con(pos: (Int, Int)): (File => NonEmptyList[Expr]) => => C = + // (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file)) + // } case class Parens(expr: Expr)(val pos: Position) extends Expr6 object Parens extends ParserBridgePos1[Expr, Parens] @@ -119,8 +136,9 @@ object ast { case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position) extends Type with PairElemType - object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { - def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) + object ArrayType extends ParserBridgePos2Chain[Int, Type, ArrayType] { + def apply(dimensions: Int, elemType: Type)(pos: Position): ArrayType = + ArrayType(elemType, dimensions)(pos) } case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] @@ -131,6 +149,18 @@ object ast { /* ============================ PROGRAM STRUCTURE ============================ */ + case class ImportedFunc(sourceName: Ident, importName: Ident)(val pos: Position) + object ImportedFunc extends ParserBridgePos2[Ident, Option[Ident], ImportedFunc] { + def apply(a: Ident, b: Option[Ident])(pos: Position): ImportedFunc = + new ImportedFunc(a, b.getOrElse(a))(pos) + } + + case class Import(source: StrLiter, funcs: NonEmptyList[ImportedFunc])(val pos: Position) + object Import extends ParserBridgePos2[StrLiter, NonEmptyList[ImportedFunc], Import] + + case class PartialProgram(imports: List[Import], self: Program)(val pos: Position) + object PartialProgram extends ParserBridgePos2[List[Import], Program, PartialProgram] + case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position) object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] @@ -143,15 +173,15 @@ object ast { body: NonEmptyList[Stmt] )(val pos: Position) object FuncDecl - extends ParserBridgePos2[ - List[Param], - NonEmptyList[Stmt], - ((Type, Ident)) => FuncDecl + extends ParserBridgePos2Chain[ + (List[Param], NonEmptyList[Stmt]), + ((Type, Ident)), + FuncDecl ] { - def apply(params: List[Param], body: NonEmptyList[Stmt])( + def apply(paramsBody: (List[Param], NonEmptyList[Stmt]), retTyName: (Type, Ident))( pos: Position - ): ((Type, Ident)) => FuncDecl = - (returnType, name) => FuncDecl(returnType, name, params, body)(pos) + ): FuncDecl = + new FuncDecl(retTyName._1, retTyName._2, paramsBody._1, paramsBody._2)(pos) } case class Param(paramType: Type, name: Ident)(val pos: Position) @@ -159,7 +189,9 @@ object ast { /* ============================ STATEMENTS ============================ */ - sealed trait Stmt + sealed trait Stmt { + val pos: Position + } case class Skip()(val pos: Position) extends Stmt object Skip extends ParserBridgePos0[Skip] case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt @@ -207,7 +239,7 @@ object ast { /* ============================ PARSER BRIDGES ============================ */ - case class Position(line: Int, column: Int) + case class Position(line: Int, column: Int, file: File) trait ParserSingletonBridgePos[+A] extends ErrorBridge { protected def con(pos: (Int, Int)): A @@ -215,38 +247,63 @@ object ast { final def <#(op: Parsley[?]): Parsley[A] = this from op } - trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] { + trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[File => A] { def apply()(pos: Position): A - override final def con(pos: (Int, Int)): A = - apply()(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): File => A = + file => apply()(Position(pos._1, pos._2, file)) } - trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] { + trait ParserBridgePos1Atom[-A, +B] extends ParserSingletonBridgePos[A => File => B] { def apply(a: A)(pos: Position): B - def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a)) + def apply(a: Parsley[A]): Parsley[File => B] = error(ap1(pos.map(con), a)) - override final def con(pos: (Int, Int)): A => B = - this.apply(_)(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): A => File => B = + a => file => this.apply(a)(Position(pos._1, pos._2, file)) } - trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] { + trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[(File => A) => File => B] { + def apply(a: A)(pos: Position): B + def apply(a: Parsley[File => A]): Parsley[File => B] = error(ap1(pos.map(con), a)) + + override final def con(pos: (Int, Int)): (File => A) => File => B = + a => file => this.apply(a(file))(Position(pos._1, pos._2, file)) + } + + trait ParserBridgePos2Chain[-A, -B, +C] + extends ParserSingletonBridgePos[(File => A) => (File => B) => File => C] { def apply(a: A, b: B)(pos: Position): C - def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error( + def apply(a: Parsley[File => A]): Parsley[(File => B) => File => C] = error( + ap1(pos.map(con), a) + ) + + override final def con(pos: (Int, Int)): (File => A) => (File => B) => File => C = + a => b => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file)) + } + + trait ParserBridgePos2[-A, -B, +C] + extends ParserSingletonBridgePos[(File => A, File => B) => File => C] { + def apply(a: A, b: B)(pos: Position): C + def apply(a: Parsley[File => A], b: => Parsley[File => B]): Parsley[File => C] = error( ap2(pos.map(con), a, b) ) - override final def con(pos: (Int, Int)): (A, B) => C = - apply(_, _)(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): (File => A, File => B) => File => C = + (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file)) } - trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] { + trait ParserBridgePos3[-A, -B, -C, +D] + extends ParserSingletonBridgePos[(File => A, File => B, File => C) => File => D] { def apply(a: A, b: B, c: C)(pos: Position): D - def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error( + def apply( + a: Parsley[File => A], + b: => Parsley[File => B], + c: => Parsley[File => C] + ): Parsley[File => D] = error( ap3(pos.map(con), a, b, c) ) - override final def con(pos: (Int, Int)): (A, B, C) => D = - apply(_, _, _)(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): (File => A, File => B, File => C) => File => D = + (a, b, c) => file => apply(a(file), b(file), c(file))(Position(pos._1, pos._2, file)) } } diff --git a/src/main/wacc/frontend/parser.scala b/src/main/wacc/frontend/parser.scala index e798284..ce9283c 100644 --- a/src/main/wacc/frontend/parser.scala +++ b/src/main/wacc/frontend/parser.scala @@ -1,18 +1,22 @@ package wacc +import java.io.File import parsley.Result import parsley.Parsley import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit} -import parsley.combinator.{countSome, sepBy} +import parsley.combinator.{countSome, sepBy, option} import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms} import parsley.errors.combinator._ import parsley.errors.patterns.VerifiedErrors import parsley.syntax.zipped._ -import parsley.cats.combinator.{some} +import parsley.cats.combinator.{some, sepBy1} +import cats.syntax.all._ import cats.data.NonEmptyList import parsley.errors.DefaultErrorBuilder import parsley.errors.ErrorBuilder import parsley.errors.tokenextractors.LexToken +import parsley.expr.GOps +import cats.Functor object parser { import lexer.implicits.implicitSymbol @@ -52,13 +56,24 @@ object parser { implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken { def tokens = errTokens } - def parse(input: String): Result[String, Program] = parser.parse(input) - private val parser = lexer.fully(``) + def parse(input: String): Result[String, File => PartialProgram] = parser.parse(input) + private val parser = lexer.fully(``) + + private type FParsley[A] = Parsley[File => A] + + private def fParsley[A](p: Parsley[A]): FParsley[A] = + p map { a => file => a } + + private def fPair[A, B](p: Parsley[(File => A, File => B)]): FParsley[(A, B)] = + p map { case (a, b) => file => (a(file), b(file)) } + + private def fMap[A, F[_]: Functor](p: Parsley[F[File => A]]): FParsley[F[A]] = + p map { funcs => file => funcs.map(_(file)) } // Expressions - private lazy val ``: Parsley[Expr] = precedence { - SOps(InfixR)(Or from "||") +: - SOps(InfixR)(And from "&&") +: + private lazy val ``: FParsley[Expr] = precedence { + GOps(InfixR)(Or from "||") +: + GOps(InfixR)(And from "&&") +: SOps(InfixN)(Eq from "==", Neq from "!=") +: SOps(InfixN)( Less from "<", @@ -83,32 +98,33 @@ object parser { } // Atoms - private lazy val ``: Atoms[Expr6] = Atoms( + private lazy val ``: Atoms[File => Expr6] = Atoms( IntLiter(integer).label("integer literal"), BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"), CharLiter(charLit).label("character literal"), - StrLiter(stringLit).label("string literal"), + ``.label("string literal"), PairLiter from "null", ``, Parens("(" ~> `` <~ ")") ) - private val `` = + private lazy val `` = StrLiter(stringLit) + private lazy val `` = Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed") private lazy val `` = (`` <~ ("(".verifiedExplain( "functions can only be called using 'call' keyword" ) | unit)) <**> (`` identity) - private val `` = ArrayElem(some("[" ~> `` <~ "]")) + private lazy val `` = ArrayElem(fMap(some("[" ~> `` <~ "]"))) // Types - private lazy val ``: Parsley[Type] = + private lazy val ``: FParsley[Type] = (`` | (`` ~> ``)) <**> (`` identity) private val `` = (IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string") private lazy val `` = - ArrayType(countSome("[" ~> "]")) + ArrayType(fParsley(countSome("[" ~> "]"))) private val `` = "pair" - private val ``: Parsley[PairType] = PairType( + private val ``: FParsley[PairType] = PairType( "(" ~> `` <~ ",", `` <~ ")" ) @@ -116,7 +132,7 @@ object parser { (`` <**> (`` identity)) | ((UntypedPairType from ``) <**> ((`` <**> ``) - .map(arr => (_: UntypedPairType) => arr) identity)) + .map(arr => (_: File => UntypedPairType) => arr) identity)) /* Statements Atomic is used in two places here: @@ -127,13 +143,30 @@ object parser { invalid syntax check, this only happens at most once per program so this is not a major concern. */ + private lazy val `` = PartialProgram( + fMap(many(``)), + `` + ) + private lazy val `` = Import( + "import" ~> ``, + "(" ~> fMap(sepBy1(``, ",")) <~ ")" + ) + private lazy val `` = ``.label("import file name") + private lazy val `` = ImportedFunc( + ``.label("imported function name"), + fMap(option("as" ~> ``)).label("imported function alias") + ) private lazy val `` = Program( "begin" ~> ( - many( - atomic( - ``.label("function declaration") <~> `` <~ "(" - ) <**> `` - ).label("function declaration") | + fMap( + many( + fPair( + atomic( + ``.label("function declaration") <~> `` <~ "(" + ) + ) <**> `` + ).label("function declaration") + ) | atomic(`` <~ "(").verifiedExplain("function declaration is missing return type") ), ``.label( @@ -142,17 +175,23 @@ object parser { ) private lazy val `` = FuncDecl( - sepBy(``, ",") <~ ")" <~ "is", - ``.guardAgainst { - case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement") - } <~ "end" + fPair( + (fMap(sepBy(``, ",")) <~ ")" <~ "is") <~> + (``.guardAgainst { + // TODO: passing in an arbitrary file works but is ugly + case stmts if !(stmts(File("."))).isReturning => + Seq("all functions must end in a returning statement") + } <~ "end") + ) ) private lazy val `` = Param(``, ``) - private lazy val ``: Parsley[NonEmptyList[Stmt]] = - ( - ``.label("main program body"), - (many(";" ~> ``.label("statement after ';'"))) Nil - ).zipped(NonEmptyList.apply) + private lazy val ``: FParsley[NonEmptyList[Stmt]] = + fMap( + ( + ``.label("main program body"), + (many(";" ~> ``.label("statement after ';'"))) Nil + ).zipped(NonEmptyList.apply) + ) private lazy val `` = (Skip from "skip") @@ -160,8 +199,8 @@ object parser { | Free("free" ~> ``.labelAndExplain(LabelType.Expr)) | Return("return" ~> ``.labelAndExplain(LabelType.Expr)) | Exit("exit" ~> ``.labelAndExplain(LabelType.Expr)) - | Print("print" ~> ``.labelAndExplain(LabelType.Expr), pure(false)) - | Print("println" ~> ``.labelAndExplain(LabelType.Expr), pure(true)) + | Print("print" ~> ``.labelAndExplain(LabelType.Expr), fParsley(pure(false))) + | Print("println" ~> ``.labelAndExplain(LabelType.Expr), fParsley(pure(true))) | If( "if" ~> ``.labelWithType(LabelType.Expr) <~ "then", `` <~ "else", @@ -185,9 +224,9 @@ object parser { ("call" ~> ``).verifiedExplain( "function calls' results must be assigned to a variable" ) - private lazy val ``: Parsley[LValue] = + private lazy val ``: FParsley[LValue] = `` | `` - private lazy val ``: Parsley[RValue] = + private lazy val ``: FParsley[RValue] = `` | NewPair( "newpair" ~> "(" ~> `` <~ ",", @@ -196,13 +235,13 @@ object parser { `` | Call( "call" ~> `` <~ "(", - sepBy(``, ",") <~ ")" + fMap(sepBy(``, ",")) <~ ")" ) | ``.labelWithType(LabelType.Expr) private lazy val `` = Fst("fst" ~> ``.label("valid pair")) | Snd("snd" ~> ``.label("valid pair")) private lazy val `` = ArrayLiter( - "[" ~> sepBy(``, ",") <~ "]" + "[" ~> fMap(sepBy(``, ",")) <~ "]" ) extension (stmts: NonEmptyList[Stmt]) { diff --git a/src/main/wacc/frontend/renamer.scala b/src/main/wacc/frontend/renamer.scala index b281283..4893d42 100644 --- a/src/main/wacc/frontend/renamer.scala +++ b/src/main/wacc/frontend/renamer.scala @@ -1,6 +1,15 @@ package wacc +import java.io.File import scala.collection.mutable +import cats.effect.IO +import cats.syntax.all._ +import cats.implicits._ +import cats.data.Chain +import cats.data.NonEmptyList +import parsley.{Failure, Success} + +private val MAIN = "$main" object renamer { import ast._ @@ -11,116 +20,271 @@ object renamer { case Var } + private case class ScopeKey(path: String, name: String, identType: IdentType) + private case class ScopeValue(id: Ident, public: Boolean) + private class Scope( - val current: mutable.Map[(String, IdentType), Ident], - val parent: Map[(String, IdentType), Ident] + private val current: mutable.Map[ScopeKey, ScopeValue], + private val parent: Map[ScopeKey, ScopeValue], + guidStart: Int = 0, + val guidInc: Int = 1 ) { + private var guid = guidStart + private var immutable = false + + private def nextGuid(): Int = { + val id = guid + guid += guidInc + id + } + + private def verifyMutable(): Unit = { + if (immutable) throw new IllegalStateException("Cannot modify an immutable scope") + } /** Create a new scope with the current scope as its parent. + * + * To be used for single-threaded applications. * * @return * A new scope with an empty current scope, and this scope flattened into the parent scope. */ - def subscope: Scope = - Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + def withSubscope[T](f: Scope => T): T = { + val subscope = + Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)), guid, guidInc) + immutable = true + val result = f(subscope) + guid = subscope.guid // Sync GUID + immutable = false + result + } + + /** Create new scopes with the current scope as its parent and GUID numbering adjusted + * correctly. + * + * This will permanently mark the current scope as immutable, for thread safety. + * + * To be used for multi-threaded applications. + * + * @return + * New scopes with an empty current scope, and this scope flattened into the parent scope. + */ + def subscopes(n: Int): Seq[Scope] = { + verifyMutable() + immutable = true + (0 until n).map { i => + Scope( + mutable.Map.empty, + Map.empty.withDefault(current.withDefault(parent)), + guid + i * guidInc, + guidInc * n + ) + } + } /** Attempt to add a new identifier to the current scope. If the identifier already exists in * the current scope, add an error to the error list. * - * @param ty - * The semantic type of the variable identifier, or function identifier type. * @param name * The name of the identifier. - * @param globalNames - * The global map of identifiers to semantic types - the identifier will be added to this - * map. - * @param globalNumbering - * The global map of identifier names to the number of times they have been declared - will - * used to rename this identifier, and will be incremented. - * @param errors - * The list of errors to append to. + * @return + * An error, if one occurred. */ - def add(ty: SemType | FuncType, name: Ident)(using - globalNames: mutable.Map[Ident, SemType], - globalFuncs: mutable.Map[Ident, FuncType], - globalNumbering: mutable.Map[String, Int], - errors: mutable.Builder[Error, List[Error]] - ) = { - val identType = ty match { + def add(name: Ident, public: Boolean = false): Chain[Error] = { + verifyMutable() + val path = name.pos.file.getCanonicalPath + val identType = name.ty match { case _: SemType => IdentType.Var case _: FuncType => IdentType.Func } - current.get((name.v, identType)) match { - case Some(Ident(_, uid)) => - errors += Error.DuplicateDeclaration(name) - name.uid = uid + val key = ScopeKey(path, name.v, identType) + current.get(key) match { + case Some(ScopeValue(Ident(_, id, _), _)) => + name.guid = id + Chain.one(Error.DuplicateDeclaration(name)) case None => - val uid = globalNumbering.getOrElse(name.v, 0) - name.uid = uid - current((name.v, identType)) = name - - ty match { - case semType: SemType => - globalNames(name) = semType - case funcType: FuncType => - globalFuncs(name) = funcType - } - globalNumbering(name.v) = uid + 1 + name.guid = nextGuid() + current(key) = ScopeValue(name, public) + Chain.empty } } - private def get(name: String, identType: IdentType): Option[Ident] = + /** Attempt to add a new identifier as an alias to another to the existing scope. + * + * @param alias + * The (new) alias identifier. + * @param orig + * The (existing) original identifier. + * + * @return + * An error, if one occurred. + */ + def addAlias(alias: Ident, orig: ScopeValue, public: Boolean = false): Chain[Error] = { + verifyMutable() + val path = alias.pos.file.getCanonicalPath + val identType = alias.ty match { + case _: SemType => IdentType.Var + case _: FuncType => IdentType.Func + } + val key = ScopeKey(path, alias.v, identType) + current.get(key) match { + case Some(ScopeValue(Ident(_, id, _), _)) => + alias.guid = id + Chain.one(Error.DuplicateDeclaration(alias)) + case None => + alias.guid = nextGuid() + current(key) = ScopeValue(orig.id, public) + Chain.empty + } + } + + def get(path: String, name: String, identType: IdentType): Option[ScopeValue] = // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. // Neither is there a way to check whether a default exists, so we have to use a try-catch. try { - Some(current.withDefault(parent)((name, identType))) + Some(current.withDefault(parent)(ScopeKey(path, name, identType))) } catch { case _: NoSuchElementException => None } - def getVar(name: String): Option[Ident] = get(name, IdentType.Var) - def getFunc(name: String): Option[Ident] = get(name, IdentType.Func) + def getVar(name: Ident): Option[Ident] = + get(name.pos.file.getCanonicalPath, name.v, IdentType.Var).map(_.id) + def getFunc(name: Ident): Option[Ident] = + get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id) } - /** Check scoping of all variables and functions in the program. Also generate semantic types for - * all identifiers. - * - * @param prog - * AST of the program - * @param errors - * List of errors to append to - * @return - * Map of all (renamed) identifies to their semantic types - */ - def rename(prog: Program)(using - errors: mutable.Builder[Error, List[Error]] - ): (Map[Ident, SemType], Map[Ident, FuncType]) = { - given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty - given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty - given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty - val scope = Scope(mutable.Map.empty, Map.empty) + private def prepareGlobalScope( + partialProg: PartialProgram + )(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = { + def readImportFile(file: File): IO[String] = + IO.blocking(os.read(os.Path(file.getCanonicalPath))) + + def prepareImport(contents: String, file: File)(using + scope: Scope + ): IO[(Chain[FuncDecl], Chain[Error])] = { + parser.parse(contents) match { + case Failure(msg) => + IO.pure(Chain.empty, Chain.one(Error.SyntaxError(file, msg))) + case Success(fn) => + val partialProg = fn(file) + for { + (main, chunks, errors) <- prepareGlobalScope(partialProg) + } yield (main +: chunks, errors) + } + } + + def addImportsToScope(importFile: File, funcs: NonEmptyList[ImportedFunc])(using + scope: Scope + ): Chain[Error] = + funcs.foldMap { case ImportedFunc(srcName, aliasName) => + scope.get(importFile.getCanonicalPath, srcName.v, IdentType.Func) match { + case Some(src) if src.public => + aliasName.ty = src.id.ty + scope.addAlias(aliasName, src) + case _ => + Chain.one(Error.UndefinedFunction(srcName)) + } + } + + val PartialProgram(imports, prog) = partialProg + + // First prepare this file's functions... val Program(funcs, main) = prog - funcs - // First add all function declarations to the scope - .map { case FuncDecl(retType, name, params, body) => + val (funcChunks, funcErrors) = funcs.foldLeft((Chain.empty[FuncDecl], Chain.empty[Error])) { + case ((chunks, errors), func @ FuncDecl(retType, name, params, body)) => val paramTypes = params.map { param => val paramType = SemType(param.paramType) + param.name.ty = paramType paramType } - scope.add(FuncType(SemType(retType), paramTypes), name) - (params zip paramTypes, body) + name.ty = FuncType(SemType(retType), paramTypes) + (chunks :+ func, errors ++ scope.add(name, public = true)) + } + // ...and main body. + val mainBodyIdent = Ident(MAIN, ty = FuncType(?, Nil))(prog.pos) + val mainBodyErrors = scope.add(mainBodyIdent, public = false) + val mainBodyChunk = FuncDecl(IntType()(prog.pos), mainBodyIdent, Nil, main)(prog.pos) + + // Now handle imports + val file = prog.pos.file + val preparedImports = imports.foldLeftM[IO, (Chain[FuncDecl], Chain[Error])]( + (Chain.empty[FuncDecl], Chain.empty[Error]) + ) { case ((chunks, errors), Import(name, funcs)) => + val importFile = File(file.getParent, name.v) + if (!importFile.exists()) { + IO.pure( + ( + chunks, + errors :+ Error.SemanticError( + name.pos, + s"File not found: ${importFile.getCanonicalPath}" + ) + ) + ) + } else if (!importFile.canRead()) { + IO.pure( + ( + chunks, + errors :+ Error.SemanticError( + name.pos, + s"File not readable: ${importFile.getCanonicalPath}" + ) + ) + ) + } else if (importFile.getCanonicalPath == file.getCanonicalPath) { + IO.pure( + ( + chunks, + errors :+ Error.SemanticError( + name.pos, + s"Cannot import self: ${importFile.getCanonicalPath}" + ) + ) + ) + } else if (scope.get(importFile.getCanonicalPath, MAIN, IdentType.Func).isDefined) { + IO.pure(chunks, errors ++ addImportsToScope(importFile, funcs)) + } else { + for { + contents <- readImportFile(importFile) + (importChunks, importErrors) <- prepareImport(contents, importFile) + importAliasErrors = addImportsToScope(importFile, funcs) + } yield (chunks ++ importChunks, errors ++ importErrors) } - // Only then rename the function bodies - // (functions can call one-another regardless of order of declaration) - .foreach { case (params, body) => - val functionScope = scope.subscope - params.foreach { case (param, paramType) => - functionScope.add(paramType, param.name) - } - body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params - } - main.toList.foreach(rename(scope)) - (globalNames.toMap, globalFuncs.toMap) + } + + for { + (importChunks, importErrors) <- preparedImports + allChunks = importChunks ++ funcChunks + allErrors = importErrors ++ funcErrors ++ mainBodyErrors + } yield (mainBodyChunk, allChunks, allErrors) + } + + /** Check scoping of all variables and flatten a program. Also generates semantic types and parses + * any imported files. + * + * @param partialProg + * AST of the program + * @return + * (flattenedProg, errors) + */ + private def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = { + val (FuncDecl(_, _, params, body), subscope) = funcScopePair + val paramErrors = params.foldMap(param => subscope.add(param.name)) + IO(subscope.withSubscope { s => body.foldMap(rename(s)) }) + .map(bodyErrors => paramErrors ++ bodyErrors) + } + + def rename(partialProg: PartialProgram): IO[(Program, Chain[Error])] = { + given scope: Scope = Scope(mutable.Map.empty, Map.empty) + + for { + (main, chunks, globalErrors) <- prepareGlobalScope(partialProg) + toRename = (main +: chunks).toList + allErrors <- toRename + .zip(scope.subscopes(toRename.size)) + .parFoldMapA(renameFunction) + // .map(x => x.combineAll) + } yield (Program(chunks.toList, main.body)(main.pos), globalErrors ++ allErrors) } /** Check scoping of all identifies in a given AST node. @@ -129,91 +293,90 @@ object renamer { * The current scope and flattened parent scope. * @param node * The AST node. - * @param globalNames - * The global map of identifiers to semantic types - renamed identifiers will be added to this - * map. - * @param globalNumbering - * The global map of identifier names to the number of times they have been declared - used and - * updated during identifier renaming. - * @param errors */ - private def rename(scope: Scope)( - node: Ident | Stmt | LValue | RValue | Expr - )(using - globalNames: mutable.Map[Ident, SemType], - globalFuncs: mutable.Map[Ident, FuncType], - globalNumbering: mutable.Map[String, Int], - errors: mutable.Builder[Error, List[Error]] - ): Unit = node match { - // These cases are more interesting because the involve making subscopes - // or modifying the current scope. - case VarDecl(synType, name, value) => { - // Order matters here. Variable isn't declared until after the value is evaluated. - rename(scope)(value) - // Attempt to add the new variable to the current scope. - scope.add(SemType(synType), name) - } - case If(cond, thenStmt, elseStmt) => { - rename(scope)(cond) - // then and else both have their own scopes - thenStmt.toList.foreach(rename(scope.subscope)) - elseStmt.toList.foreach(rename(scope.subscope)) - } - case While(cond, body) => { - rename(scope)(cond) - // while bodies have their own scopes - body.toList.foreach(rename(scope.subscope)) - } - // begin-end blocks have their own scopes - case Block(body) => body.toList.foreach(rename(scope.subscope)) + private def rename(scope: Scope)(node: Ident | Stmt | LValue | RValue | Expr): Chain[Error] = + node match { + // These cases are more interes/globting because the involve making subscopes + // or modifying the current scope. + case VarDecl(synType, name, value) => { + // Order matters here. Variable isn't declared until after the value is evaluated. + val errors = rename(scope)(value) + // Attempt to add the new variable to the current scope. + name.ty = SemType(synType) + errors ++ scope.add(name) + } + case If(cond, thenStmt, elseStmt) => { + val condErrors = rename(scope)(cond) + // then and else both have their own scopes + val thenErrors = scope.withSubscope(s => thenStmt.foldMap(rename(s))) + val elseErrors = scope.withSubscope(s => elseStmt.foldMap(rename(s))) + condErrors ++ thenErrors ++ elseErrors + } + case While(cond, body) => { + val condErrors = rename(scope)(cond) + // while bodies have their own scopes + val bodyErrors = scope.withSubscope(s => body.foldMap(rename(s))) + condErrors ++ bodyErrors + } + // begin-end blocks have their own scopes + case Block(body) => scope.withSubscope(s => body.foldMap(rename(s))) - // These cases are simpler, mostly just recursive calls to rename() - case Assign(lhs, value) => { - // Variables may be reassigned with their value in the rhs, so order doesn't matter here. - rename(scope)(lhs) - rename(scope)(value) - } - case Read(lhs) => rename(scope)(lhs) - case Free(expr) => rename(scope)(expr) - case Return(expr) => rename(scope)(expr) - case Exit(expr) => rename(scope)(expr) - case Print(expr, _) => rename(scope)(expr) - case NewPair(fst, snd) => { - rename(scope)(fst) - rename(scope)(snd) - } - case Call(name, args) => { - scope.getFunc(name.v) match { - case Some(Ident(_, uid)) => name.uid = uid - case None => - errors += Error.UndefinedFunction(name) - scope.add(FuncType(?, args.map(_ => ?)), name) + // These cases are simpler, mostly just recursive calls to rename() + case Assign(lhs, value) => { + // Variables may be reassigned with their value in the rhs, so order doesn't matter here. + rename(scope)(lhs) ++ rename(scope)(value) } - args.foreach(rename(scope)) - } - case Fst(elem) => rename(scope)(elem) - case Snd(elem) => rename(scope)(elem) - case ArrayLiter(elems) => elems.foreach(rename(scope)) - case ArrayElem(name, indices) => { - rename(scope)(name) - indices.toList.foreach(rename(scope)) - } - case Parens(expr) => rename(scope)(expr) - case op: UnaryOp => rename(scope)(op.x) - case op: BinaryOp => { - rename(scope)(op.x) - rename(scope)(op.y) - } - // Default to variables. Only `call` uses IdentType.Func. - case id: Ident => { - scope.getVar(id.v) match { - case Some(Ident(_, uid)) => id.uid = uid - case None => - errors += Error.UndeclaredVariable(id) - scope.add(?, id) + case Read(lhs) => rename(scope)(lhs) + case Free(expr) => rename(scope)(expr) + case Return(expr) => rename(scope)(expr) + case Exit(expr) => rename(scope)(expr) + case Print(expr, _) => rename(scope)(expr) + case NewPair(fst, snd) => { + rename(scope)(fst) ++ rename(scope)(snd) } + case Call(name, args) => { + val nameErrors = scope.getFunc(name) match { + case Some(Ident(realName, guid, ty)) => + name.v = realName + name.ty = ty + name.guid = guid + Chain.empty + case None => + name.ty = FuncType(?, args.map(_ => ?)) + scope.add(name) + Chain.one(Error.UndefinedFunction(name)) + } + val argsErrors = args.foldMap(rename(scope)) + nameErrors ++ argsErrors + } + case Fst(elem) => rename(scope)(elem) + case Snd(elem) => rename(scope)(elem) + case ArrayLiter(elems) => elems.foldMap(rename(scope)) + case ArrayElem(name, indices) => { + val nameErrors = rename(scope)(name) + val indicesErrors = indices.foldMap(rename(scope)) + nameErrors ++ indicesErrors + } + case Parens(expr) => rename(scope)(expr) + case op: UnaryOp => rename(scope)(op.x) + case op: BinaryOp => { + rename(scope)(op.x) ++ rename(scope)(op.y) + } + // Default to variables. Only `call` uses IdentType.Func. + case id: Ident => { + scope.getVar(id) match { + case Some(Ident(_, guid, ty)) => + id.ty = ty + id.guid = guid + Chain.empty + case None => + id.ty = ? + scope.add(id) + Chain.one(Error.UndeclaredVariable(id)) + } + } + // These literals cannot contain identifies, exit immediately. + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => + Chain.empty } - // These literals cannot contain identifies, exit immediately. - case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () - } } diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index a628b69..6f5804b 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -8,13 +8,8 @@ object typeChecker { import wacc.types._ case class TypeCheckerCtx( - globalNames: Map[ast.Ident, SemType], - globalFuncs: Map[ast.Ident, FuncType], errors: mutable.Builder[Error, List[Error]] ) { - def typeOf(ident: ast.Ident): SemType = globalNames(ident) - def funcType(ident: ast.Ident): FuncType = globalFuncs(ident) - def error(err: Error): SemType = errors += err ? @@ -99,18 +94,17 @@ object typeChecker { * The type checker context which includes the global names and functions, and an errors * builder. */ - def check(prog: ast.Program)(using - ctx: TypeCheckerCtx - ): microWacc.Program = + def check(prog: ast.Program, errors: mutable.Builder[Error, List[Error]]): microWacc.Program = + given ctx: TypeCheckerCtx = TypeCheckerCtx(errors) microWacc.Program( // Ignore function syntax types for return value and params, since those have been converted // to SemTypes by the renamer. prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) => - val FuncType(retType, paramTypes) = ctx.funcType(name) + val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType] microWacc.FuncDecl( - microWacc.Ident(name.v, name.uid)(retType), + microWacc.Ident(name.v, name.guid)(retType), params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => - microWacc.Ident(ident.v, ident.uid)(ty) + microWacc.Ident(ident.v, ident.guid)(ty) }, stmts.toList .flatMap( @@ -134,15 +128,20 @@ object typeChecker { ): List[microWacc.Stmt] = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. case ast.VarDecl(_, name, value) => - val expectedTy = ctx.typeOf(name) + val expectedTy = name.ty val typedValue = checkValue( value, Constraint.Is( - expectedTy, + expectedTy.asInstanceOf[SemType], s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) - List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue)) + List( + microWacc.Assign( + microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]), + typedValue + ) + ) case ast.Assign(lhs, rhs) => val lhsTyped = checkLValue(lhs, Constraint.Unconstrained) val rhsTyped = @@ -315,7 +314,7 @@ object typeChecker { KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) ) case ast.Call(id, args) => - val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) + val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType] if (args.length != paramTys.length) { ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) } @@ -324,7 +323,7 @@ object typeChecker { val argsTyped = args.zip(paramTys).map { case (arg, paramTy) => checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } - microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped) + microWacc.Call(microWacc.Ident(id.v, id.guid)(retTy.satisfies(constraint, id.pos)), argsTyped) // Unary operators case ast.Negate(x) => @@ -416,30 +415,32 @@ object typeChecker { private def checkLValue(value: ast.LValue, constraint: Constraint)(using ctx: TypeCheckerCtx ): microWacc.LValue = value match { - case id @ ast.Ident(name, uid) => - microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos)) + case id @ ast.Ident(name, guid, ty) => + microWacc.Ident(name, guid)(ty.asInstanceOf[SemType].satisfies(constraint, id.pos)) case ast.ArrayElem(id, indices) => - val arrayTy = ctx.typeOf(id) - val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) => - val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) - val next = acc match { - case KnownType.Array(innerTy) => innerTy - case ? => ? // we can keep indexing an unknown type - case nonArrayTy => - ctx.error( - Error.TypeMismatch( - elem.pos, - KnownType.Array(?), - acc, - "cannot index into a non-array" + val arrayTy = id.ty.asInstanceOf[SemType] + val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy.asInstanceOf[SemType]) { + (acc, elem) => + val idxTyped = + checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + val next = acc match { + case KnownType.Array(innerTy) => innerTy + case ? => ? // we can keep indexing an unknown type + case nonArrayTy => + ctx.error( + Error.TypeMismatch( + elem.pos, + KnownType.Array(?), + acc, + "cannot index into a non-array" + ) ) - ) - ? - } - (next, idxTyped) + ? + } + (next, idxTyped) } val firstArrayElem = microWacc.ArrayElem( - microWacc.Ident(id.v, id.uid)(arrayTy), + microWacc.Ident(id.v, id.guid)(arrayTy), indicesTyped.head )(elemTy.satisfies(constraint, value.pos)) val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) => diff --git a/src/main/wacc/frontend/types.scala b/src/main/wacc/frontend/types.scala index 549d8a1..5251396 100644 --- a/src/main/wacc/frontend/types.scala +++ b/src/main/wacc/frontend/types.scala @@ -3,7 +3,9 @@ package wacc object types { import ast._ - sealed trait SemType { + sealed trait RenamerType + + sealed trait SemType extends RenamerType { override def toString(): String = this match { case KnownType.Int => "int" case KnownType.Bool => "bool" @@ -41,5 +43,5 @@ object types { } } - case class FuncType(returnType: SemType, params: List[SemType]) + case class FuncType(returnType: SemType, params: List[SemType]) extends RenamerType } diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 11093d6..a0f4564 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -26,6 +26,15 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd } ++ allWaccFiles("wacc-examples/invalid/whack").map { p => (p.toString, List(100, 200)) + } ++ + allWaccFiles("extension/examples/valid").map { p => + (p.toString, List(0)) + } ++ + allWaccFiles("extension/examples/invalid/syntax").map { p => + (p.toString, List(100)) + } ++ + allWaccFiles("extension/examples/invalid/semantics").map { p => + (p.toString, List(200)) } forEvery(files) { (filename, expectedResult) => @@ -33,18 +42,21 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd s"$filename" - { "should be compiled with correct result" in { - compileWacc(Path.of(filename), outputDir = None, log = false).map { result => - expectedResult should contain(result) - } + if (fileIsPendingFrontend(filename)) + IO.pure(pending) + else + compileWacc(Path.of(filename), outputDir = None, log = false).map { result => + expectedResult should contain(result) + } } if (expectedResult == List(0)) { "should run with correct result" in { if (fileIsDisallowedBackend(filename)) - IO.pure( - succeed - ) // TODO: remove when advanced tests removed. not sure how to "pending" this otherwise - else { + IO.pure(succeed) + else if (fileIsPendingBackend(filename)) + IO.pure(pending) + else for { contents <- IO(Source.fromFile(File(filename)).getLines.toList) inputLine = extractInput(contents) @@ -75,7 +87,6 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd exitCode shouldBe expectedExit normalizeOutput(stdout.toString) shouldBe expectedOutput } - } } } } @@ -85,10 +96,21 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd val d = java.io.File(dir) os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc") - // TODO: eventually remove this I think - def fileIsDisallowedBackend(filename: String): Boolean = - Seq( - "^.*wacc-examples/valid/advanced.*$" + private def fileIsDisallowedBackend(filename: String): Boolean = + filename.matches("^.*wacc-examples/valid/advanced.*$") + + private def fileIsPendingFrontend(filename: String): Boolean = + List( + // "^.*extension/examples/invalid/syntax/imports/importBadSyntax.*$", + // "^.*extension/examples/invalid/semantics/imports.*$", + // "^.*extension/examples/valid/imports.*$" + ).exists(filename.matches) + + private def fileIsPendingBackend(filename: String): Boolean = + List( + // "^.*extension/examples/invalid/syntax/imports.*$", + // "^.*extension/examples/invalid/semantics/imports.*$", + // "^.*extension/examples/valid/imports.*$" ).exists(filename.matches) private def extractInput(contents: List[String]): String =