diff --git a/.gitignore b/.gitignore index 03801cc..dab1691 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,3 @@ .vscode/ wacc-examples/ .idea/ - diff --git a/compile b/compile index 5505e42..444f591 100755 --- a/compile +++ b/compile @@ -4,6 +4,6 @@ # but do *not* change its name. # feel free to adjust to suit the specific internal flags of your compiler -./wacc-compiler "$@" +./wacc-compiler --output . "$@" exit $? diff --git a/extension/examples/invalid/semantics/badWacc.wacc b/extension/examples/invalid/semantics/badWacc.wacc new file mode 100644 index 0000000..a334d57 --- /dev/null +++ b/extension/examples/invalid/semantics/badWacc.wacc @@ -0,0 +1,10 @@ +begin + int main() is + int a = 5 ; + string b = "Hello" ; + return a + b + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadFile.wacc b/extension/examples/invalid/semantics/imports/importBadFile.wacc new file mode 100644 index 0000000..b116ccd --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadFile.wacc @@ -0,0 +1,6 @@ +import "./doesNotExist.wacc" (main) + +begin + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadFunc.wacc b/extension/examples/invalid/semantics/imports/importBadFunc.wacc new file mode 100644 index 0000000..bf3c9a0 --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadFunc.wacc @@ -0,0 +1,6 @@ +import "../../../valid/sum.wacc" (mult) + +begin + int result = call mult(3, 2) ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadSem.wacc b/extension/examples/invalid/semantics/imports/importBadSem.wacc new file mode 100644 index 0000000..d20e3a6 --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadSem.wacc @@ -0,0 +1,10 @@ +import "../badWacc.wacc" (main) + +begin + int sum(int a, int b) is + return a + b + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/importBadSem2.wacc b/extension/examples/invalid/semantics/imports/importBadSem2.wacc new file mode 100644 index 0000000..4bd330e --- /dev/null +++ b/extension/examples/invalid/semantics/imports/importBadSem2.wacc @@ -0,0 +1,6 @@ +import "./importBadSem.wacc" (sum) + +begin + int result = call sum(1, 2) ; + exit result +end diff --git a/extension/examples/invalid/semantics/imports/inderect.wacc b/extension/examples/invalid/semantics/imports/inderect.wacc new file mode 100644 index 0000000..120f9ba --- /dev/null +++ b/extension/examples/invalid/semantics/imports/inderect.wacc @@ -0,0 +1,6 @@ +import "../../../valid/imports/basic.wacc" (sum) + +begin + int result = call sum(3, 2) ; + exit result +end diff --git a/extension/examples/invalid/syntax/badWacc.wacc b/extension/examples/invalid/syntax/badWacc.wacc new file mode 100644 index 0000000..a375309 --- /dev/null +++ b/extension/examples/invalid/syntax/badWacc.wacc @@ -0,0 +1,6 @@ +int main() is + println "Hello World!" ; + return 0 +end + +skip diff --git a/extension/examples/invalid/syntax/imports/emptyImport.wacc b/extension/examples/invalid/syntax/imports/emptyImport.wacc new file mode 100644 index 0000000..ec9dbd0 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/emptyImport.wacc @@ -0,0 +1,8 @@ +import "../../../valid/sum.wacc" sum, main + +begin + int result1 = call sum(5, 10) ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/invalid/syntax/imports/emptyImport2.wacc b/extension/examples/invalid/syntax/imports/emptyImport2.wacc new file mode 100644 index 0000000..99d38b9 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/emptyImport2.wacc @@ -0,0 +1,5 @@ +import "../../../valid/sum.wacc" () + +begin + exit 0 +end diff --git a/extension/examples/invalid/syntax/imports/importBadSyntax.wacc b/extension/examples/invalid/syntax/imports/importBadSyntax.wacc new file mode 100644 index 0000000..d20e3a6 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importBadSyntax.wacc @@ -0,0 +1,10 @@ +import "../badWacc.wacc" (main) + +begin + int sum(int a, int b) is + return a + b + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/invalid/syntax/imports/importBadSyntax2.wacc b/extension/examples/invalid/syntax/imports/importBadSyntax2.wacc new file mode 100644 index 0000000..0e0e0e1 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importBadSyntax2.wacc @@ -0,0 +1,6 @@ +import "./importBadSyntax.wacc" (sum) + +begin + int result = call sum(1, 2) ; + exit result +end diff --git a/extension/examples/invalid/syntax/imports/importNoParens.wacc b/extension/examples/invalid/syntax/imports/importNoParens.wacc new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importNoParens.wacc @@ -0,0 +1 @@ + diff --git a/extension/examples/invalid/syntax/imports/importSemis.wacc b/extension/examples/invalid/syntax/imports/importSemis.wacc new file mode 100644 index 0000000..f127844 --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importSemis.wacc @@ -0,0 +1,9 @@ +import "../../../valid/sum.wacc" (sum) ; +import "../../../valid/sum.wacc" (main) ; + +begin + int result1 = call sum(5, 10) ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/invalid/syntax/imports/importStar.wacc b/extension/examples/invalid/syntax/imports/importStar.wacc new file mode 100644 index 0000000..e027caa --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importStar.wacc @@ -0,0 +1,5 @@ +import "../../../valid/sum.wacc" * + +begin + exit 0 +end diff --git a/extension/examples/invalid/syntax/imports/importStar2.wacc b/extension/examples/invalid/syntax/imports/importStar2.wacc new file mode 100644 index 0000000..bae08ef --- /dev/null +++ b/extension/examples/invalid/syntax/imports/importStar2.wacc @@ -0,0 +1,5 @@ +import "../../../valid/sum.wacc" (*) + +begin + exit 0 +end diff --git a/extension/examples/valid/.gitignore b/extension/examples/valid/.gitignore new file mode 100644 index 0000000..ed87167 --- /dev/null +++ b/extension/examples/valid/.gitignore @@ -0,0 +1,7 @@ +* + +!imports/ +imports/* + +!.gitignore +!*.wacc diff --git a/extension/examples/valid/imports/alias.wacc b/extension/examples/valid/imports/alias.wacc new file mode 100644 index 0000000..91496a1 --- /dev/null +++ b/extension/examples/valid/imports/alias.wacc @@ -0,0 +1,22 @@ +# import main from ../sum.wacc and ./basic.wacc + +# Output: +# 15 +# 0 +# -33 +# + +# Exit: +# 0 + +# Program: + +import "../sum.wacc" (main as sumMain) +import "./basic.wacc" (main) + +begin + int result1 = call sumMain() ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/valid/imports/basic.wacc b/extension/examples/valid/imports/basic.wacc new file mode 100644 index 0000000..d34a34a --- /dev/null +++ b/extension/examples/valid/imports/basic.wacc @@ -0,0 +1,21 @@ +# import sum from ../sum.wacc + +# Output: +# -33 +# + +# Exit: +# 0 + +# Program: + +import "../sum.wacc" (sum) + +begin + int main() is + int result = call sum(-10, -23) ; + return result + end + int result = call main() ; + println result +end diff --git a/extension/examples/valid/imports/manyMains.wacc b/extension/examples/valid/imports/manyMains.wacc new file mode 100644 index 0000000..fc3bc7c --- /dev/null +++ b/extension/examples/valid/imports/manyMains.wacc @@ -0,0 +1,33 @@ +# import all the mains + +# Output: +# 15 +# -33 +# 0 +# -33 +# 0 +# + +# Exit: +# 99 + +# Program: + +import "../sum.wacc" (main as sumMain) +import "./basic.wacc" (main as basicMain) +import "./multiFunc.wacc" (main as multiFuncMain) + +begin + int main() is + int result1 = call sumMain() ; + int result2 = call basicMain() ; + int result3 = call multiFuncMain() ; + println result1 ; + println result2 ; + println result3 ; + return 99 + end + + int result = call main() ; + exit result +end diff --git a/extension/examples/valid/imports/multiFunc.wacc b/extension/examples/valid/imports/multiFunc.wacc new file mode 100644 index 0000000..22d6e4d --- /dev/null +++ b/extension/examples/valid/imports/multiFunc.wacc @@ -0,0 +1,27 @@ +# import sum, main from ../sum.wacc + +# Output: +# 15 +# -33 +# 0 +# 0 +# + +# Exit: +# 0 + +# Program: + +import "../sum.wacc" (sum, main as sumMain) + +begin + int main() is + int result = call sum(-10, -23) ; + println result ; + return 0 + end + int result1 = call sumMain() ; + int result2 = call main() ; + println result1 ; + println result2 +end diff --git a/extension/examples/valid/sum.wacc b/extension/examples/valid/sum.wacc new file mode 100644 index 0000000..dc62e24 --- /dev/null +++ b/extension/examples/valid/sum.wacc @@ -0,0 +1,27 @@ +# simple sum program + +# Output: +# 15 +# + +# Exit: +# 0 + +# Program: + +begin + int sum(int a, int b) is + return a + b + end + + int main() is + int a = 5 ; + int b = 10 ; + int result = call sum(a, b) ; + println result ; + return 0 + end + + int result = call main() ; + exit result +end diff --git a/project.scala b/project.scala index 4deaa08..8edf035 100644 --- a/project.scala +++ b/project.scala @@ -5,23 +5,18 @@ //> using dep com.github.j-mie6::parsley::5.0.0-M10 //> using dep com.github.j-mie6::parsley-cats::1.5.0 //> using dep com.lihaoyi::os-lib::0.11.4 -//> using dep com.github.scopt::scopt::4.1.0 +//> using dep org.typelevel::cats-core::2.13.0 +//> using dep org.typelevel::cats-effect::3.5.7 +//> using dep com.monovore::decline::2.5.0 +//> using dep com.monovore::decline-effect::2.5.0 +//> using dep org.typelevel::log4cats-slf4j::2.7.0 +//> using dep org.slf4j:slf4j-simple:2.0.17 //> using test.dep org.scalatest::scalatest::3.2.19 +//> using dep org.typelevel::cats-effect-testing-scalatest::1.6.0 -// these are all sensible defaults to catch annoying issues +// sensible defaults for warnings and compiler checks //> using options -deprecation -unchecked -feature //> using options -Wimplausible-patterns -Wunused:all //> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores -// these will help ensure you have access to the latest parsley releases -// even before they land on maven proper, or snapshot versions, if necessary. -// just in case they cause problems, however, keep them turned off unless you -// specifically need them. -// using repositories sonatype-s01:releases -// using repositories sonatype-s01:snapshots - -// these are flags used by Scala native: if you aren't using scala-native, then they do nothing -// lto-thin has decent linking times, and release-fast does not too much optimisation. -// using nativeLto thin -// using nativeGc commix -// using nativeMode release-fast +// repositories for pre-release versions if needed diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index fc9fb45..e78d4bd 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,92 +1,166 @@ package wacc import scala.collection.mutable -import cats.data.Chain +import cats.data.{Chain, NonEmptyList} import parsley.{Failure, Success} -import scopt.OParser -import java.io.File -import java.io.PrintStream + +import java.nio.file.{Files, Path} +import cats.syntax.all._ + +import cats.effect.IO +import cats.effect.ExitCode + +import com.monovore.decline._ +import com.monovore.decline.effect._ + +import org.typelevel.log4cats.slf4j.Slf4jLogger +import org.typelevel.log4cats.Logger import assemblyIR as asm +import cats.data.ValidatedNel +import java.io.File -case class CliConfig( - file: File = new File(".") -) +/* +TODO: + 1) IO correctness + 2) Errors can be handled more gracefully - currently, parallelised compilation is not fail fast as far as I am aware + 3) splitting the file up and nicer refactoring + 4) logging could be removed + 5) general cleanup and comments (things like replacing home/ with ~ , and names of parameters and args, descriptions etc) + */ -val cliBuilder = OParser.builder[CliConfig] -val cliParser = { - import cliBuilder._ - OParser.sequence( - programName("wacc-compiler"), - help('h', "help") - .text("Prints this help message"), - arg[File]("") - .text("Input WACC source file") - .required() - .action((f, c) => c.copy(file = f)) - .validate(f => - if (!f.exists) failure("File does not exist") - else if (!f.isFile) failure("File must be a regular file") - else if (!f.getName.endsWith(".wacc")) - failure("File must have .wacc extension") - else success - ) - ) +private val SUCCESS = ExitCode.Success.code +private val ERROR = ExitCode.Error.code + +given logger: Logger[IO] = Slf4jLogger.getLogger[IO] + +val logOpt: Opts[Boolean] = + Opts.flag("log", "Enable logging for additional compilation details", short = "l").orFalse + +def validateFile(path: Path): ValidatedNel[String, Path] = { + (for { + // TODO: redundant 2nd parameter :( + _ <- Either.cond(Files.exists(path), (), s"File '${path}' does not exist") + _ <- Either.cond(Files.isRegularFile(path), (), s"File '${path}' must be a regular file") + _ <- Either.cond(path.toString.endsWith(".wacc"), (), "File must have .wacc extension") + } yield path).toValidatedNel } +val filesOpt: Opts[NonEmptyList[Path]] = + Opts.arguments[Path]("files").mapValidated { + _.traverse(validateFile) + } + +val outputOpt: Opts[Option[Path]] = + Opts + .option[Path]("output", metavar = "path", help = "Output directory for compiled files.") + .validate("Must have permissions to create & access the output path") { path => + try { + Files.createDirectories(path) + true + } catch { + case e: java.nio.file.AccessDeniedException => + false + } + } + .validate("Output path must be a directory") { path => + Files.isDirectory(path) + } + .orNone + def frontend( - contents: String -)(using stdout: PrintStream): Either[microWacc.Program, Int] = { + contents: String, + file: File +): IO[Either[NonEmptyList[Error], microWacc.Program]] = parser.parse(contents) match { - case Success(prog) => + case Failure(msg) => IO.pure(Left(NonEmptyList.one(Error.SyntaxError(file, msg)))) + case Success(fn) => + val partialProg = fn(file) given errors: mutable.Builder[Error, List[Error]] = List.newBuilder - val (names, funcs) = renamer.rename(prog) - given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) - val typedProg = typeChecker.check(prog) - if (errors.result.nonEmpty) { - given errorContent: String = contents - Right( - errors.result - .map { error => - printError(error) - error match { - case _: Error.InternalError => 201 - case _ => 200 - } - } - .max() - ) - } else Left(typedProg) - case Failure(msg) => - stdout.println(msg) - Right(100) - } -} -val s = "enter an integer to echo" + for { + (prog, renameErrors) <- renamer.rename(partialProg) + _ = errors.addAll(renameErrors.toList) + typedProg = typeChecker.check(prog, errors) + + res = NonEmptyList.fromList(errors.result) match { + case Some(errors) => Left(errors) + case None => Right(typedProg) + } + } yield res + } + def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = asmGenerator.generateAsm(typedProg) -def compile(filename: String, outFile: Option[File] = None)(using - stdout: PrintStream = Console.out -): Int = - frontend(os.read(os.Path(filename))) match { - case Left(typedProg) => - val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s")) - val asm = backend(typedProg) - writer.writeTo(asm, PrintStream(asmFile)) - 0 - case Right(exitCode) => exitCode - } +def compile( + filePath: Path, + outputDir: Option[Path], + log: Boolean +): IO[Int] = { + val logAction: String => IO[Unit] = + if (log) logger.info(_) + else (_ => IO.unit) -def main(args: Array[String]): Unit = - OParser.parse(cliParser, args, CliConfig()) match { - case Some(config) => - System.exit( - compile( - config.file.getAbsolutePath, - outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s")) - ) - ) - case None => - } + def readSourceFile: IO[String] = + IO.blocking(os.read(os.Path(filePath))) + + // TODO: path, file , the names are confusing (when Path is the type but we are working with files) + def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] = + writer.writeTo(backend(typedProg), outputPath) *> + logger.info(s"Success: ${outputPath.toAbsolutePath}") + + def processProgram(contents: String, file: File, outDir: Path): IO[Int] = + for { + frontendResult <- frontend(contents, file) + res <- frontendResult match { + case Left(errors) => + val code = errors.map(err => err.exitCode).toList.min + val errorMsg = errors.map(formatError).toIterable.mkString("\n") + for { + _ <- logAction(s"Compilation failed for $filePath\nExit code: $code") + _ <- IO.blocking( + // Explicit println since we want this to always show without logger thread info e.t.c. + println(s"Compilation failed for ${file.getCanonicalPath}:\n$errorMsg") + ) + } yield code + + case Right(typedProg) => + val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s") + writeOutputFile(typedProg, outputFile).as(SUCCESS) + } + } yield res + + for { + contents <- readSourceFile + _ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}") + exitCode <- processProgram(contents, filePath.toFile, outputDir.getOrElse(filePath.getParent)) + } yield exitCode +} + +def compileCommandParallel( + files: NonEmptyList[Path], + log: Boolean, + outDir: Option[Path] +): IO[ExitCode] = + files + .parTraverse { file => compile(file.toAbsolutePath, outDir, log) } + .map { exitCodes => + exitCodes.filter(_ != 0) match { + case Nil => ExitCode.Success + case errorCodes => ExitCode(errorCodes.min) + } + } + +object Main + extends CommandIOApp( + name = "wacc", + header = "The ultimate WACC compiler", + version = "1.0" + ) { + def main: Opts[IO[ExitCode]] = + (filesOpt, logOpt, outputOpt).mapN { (files, log, outDir) => + compileCommandParallel(files, log, outDir) + } + +} diff --git a/src/main/wacc/backend/LabelGenerator.scala b/src/main/wacc/backend/LabelGenerator.scala index 3b5169b..fd0006f 100644 --- a/src/main/wacc/backend/LabelGenerator.scala +++ b/src/main/wacc/backend/LabelGenerator.scala @@ -18,7 +18,7 @@ private class LabelGenerator { } private def getLabel(target: CallTarget | RuntimeError): String = target match { - case Ident(v, _) => s"wacc_$v" + case Ident(v, guid) => s"wacc_${v}_$guid" case Builtin(name) => s"_$name" case err: RuntimeError => s".L.${err.name}" } diff --git a/src/main/wacc/backend/writer.scala b/src/main/wacc/backend/writer.scala index 3c8dcfd..a339f55 100644 --- a/src/main/wacc/backend/writer.scala +++ b/src/main/wacc/backend/writer.scala @@ -1,12 +1,42 @@ package wacc -import java.io.PrintStream +import cats.effect.Resource +import java.nio.charset.StandardCharsets +import java.io.BufferedWriter +import java.io.FileWriter import cats.data.Chain +import cats.effect.IO + +import org.typelevel.log4cats.Logger +import java.nio.file.Path object writer { import assemblyIR._ - def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { - asmList.iterator.foreach(printStream.println) - } + // TODO: Judging from documentation it seems as though IO.blocking is the correct choice + // But needs checking + + /** Creates a resource safe BufferedWriter */ + private def bufferedWriter(outputPath: Path): Resource[IO, BufferedWriter] = + Resource.make { + IO.blocking(new BufferedWriter(new FileWriter(outputPath.toFile, StandardCharsets.UTF_8))) + } { writer => + IO.blocking(writer.close()) + .handleErrorWith(_ => IO.unit) // TODO: ensures writer is closed even if an error occurs + } + + /** Write line safely into a BufferedWriter */ + private def writeLines(writer: BufferedWriter, lines: Chain[AsmLine]): IO[Unit] = + IO.blocking { + lines.iterator.foreach { line => + writer.write(line.toString) + writer.newLine() + } + } + + /** Main function to write assembly to a file */ + def writeTo(asmList: Chain[AsmLine], outputPath: Path)(using logger: Logger[IO]): IO[Unit] = + bufferedWriter(outputPath).use { + writeLines(_, asmList) + } } diff --git a/src/main/wacc/frontend/Error.scala b/src/main/wacc/frontend/Error.scala index 9c02a60..188e91c 100644 --- a/src/main/wacc/frontend/Error.scala +++ b/src/main/wacc/frontend/Error.scala @@ -2,7 +2,10 @@ package wacc import wacc.ast.Position import wacc.types._ -import java.io.PrintStream +import java.io.File + +private val SYNTAX_ERROR = 100 +private val SEMANTIC_ERROR = 200 /** Error types for semantic errors */ @@ -15,6 +18,15 @@ enum Error { case SemanticError(pos: Position, msg: String) case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) case InternalError(pos: Position, msg: String) + + case SyntaxError(file: File, msg: String) +} + +extension (e: Error) { + def exitCode: Int = e match { + case Error.SyntaxError(_, _) => SYNTAX_ERROR + case _ => SEMANTIC_ERROR + } } /** Function to handle printing the details of a given semantic error @@ -24,71 +36,91 @@ enum Error { * @param errorContent * Contents of the file to generate code snippets */ -def printError(error: Error)(using errorContent: String, stdout: PrintStream): Unit = { - stdout.println("Semantic error:") - error match { - case Error.DuplicateDeclaration(ident) => - printPosition(ident.pos) - stdout.println(s"Duplicate declaration of identifier ${ident.v}") - highlight(ident.pos, ident.v.length) - case Error.UndeclaredVariable(ident) => - printPosition(ident.pos) - stdout.println(s"Undeclared variable ${ident.v}") - highlight(ident.pos, ident.v.length) - case Error.UndefinedFunction(ident) => - printPosition(ident.pos) - stdout.println(s"Undefined function ${ident.v}") - highlight(ident.pos, ident.v.length) - case Error.FunctionParamsMismatch(id, expected, got, funcType) => - printPosition(id.pos) - stdout.println(s"Function expects $expected parameters, got $got") - stdout.println( - s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})" - ) - highlight(id.pos, 1) - case Error.TypeMismatch(pos, expected, got, msg) => - printPosition(pos) - stdout.println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got") - highlight(pos, 1) - case Error.SemanticError(pos, msg) => - printPosition(pos) - stdout.println(msg) - highlight(pos, 1) - case wacc.Error.InternalError(pos, msg) => - printPosition(pos) - stdout.println(s"Internal error: $msg") - highlight(pos, 1) +def formatError(error: Error): String = { + val sb = new StringBuilder() + + /** Format the file of an error + * + * @param file + * File of the error + */ + def formatFile(file: File): Unit = { + sb.append(s"File: ${file.getCanonicalPath}\n") } -} - -/** Function to highlight a section of code for an error message - * - * @param pos - * Position of the error - * @param size - * Size(in chars) of section to highlight - * @param errorContent - * Contents of the file to generate code snippets - */ -def highlight(pos: Position, size: Int)(using errorContent: String, stdout: PrintStream): Unit = { - val lines = errorContent.split("\n") - - val preLine = if (pos.line > 1) lines(pos.line - 2) else "" - val midLine = lines(pos.line - 1) - val postLine = if (pos.line < lines.size) lines(pos.line) else "" - val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n" - - stdout.println( - s" >$preLine\n >$midLine\n$linePointer >$postLine" - ) -} - -/** Function to print the position of an error - * - * @param pos - * Position of the error - */ -def printPosition(pos: Position)(using stdout: PrintStream): Unit = { - stdout.println(s"(line ${pos.line}, column ${pos.column}):") + /** Function to format the position of an error + * + * @param pos + * Position of the error + */ + def formatPosition(pos: Position): Unit = { + formatFile(pos.file) + sb.append(s"(line ${pos.line}, column ${pos.column}):\n") + } + + /** Function to highlight a section of code for an error message + * + * @param pos + * Position of the error + * @param size + * Size(in chars) of section to highlight + */ + def formatHighlight(pos: Position, size: Int): Unit = { + val lines = os.read(os.Path(pos.file.getCanonicalPath)).split("\n") + val preLine = if (pos.line > 1) lines(pos.line - 2) else "" + val midLine = lines(pos.line - 1) + val postLine = if (pos.line < lines.size) lines(pos.line) else "" + val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n" + + sb.append( + s" >$preLine\n >$midLine\n$linePointer >$postLine\netscape" + ) + } + + error match { + case Error.SyntaxError(_, _) => + sb.append("Syntax error:\n") + case _ => + sb.append("Semantic error:\n") + } + + error match { + case Error.DuplicateDeclaration(ident) => + formatPosition(ident.pos) + sb.append(s"Duplicate declaration of identifier ${ident.v}\n") + formatHighlight(ident.pos, ident.v.length) + case Error.UndeclaredVariable(ident) => + formatPosition(ident.pos) + sb.append(s"Undeclared variable ${ident.v}\n") + formatHighlight(ident.pos, ident.v.length) + case Error.UndefinedFunction(ident) => + formatPosition(ident.pos) + sb.append(s"Undefined function ${ident.v}\n") + formatHighlight(ident.pos, ident.v.length) + case Error.FunctionParamsMismatch(id, expected, got, funcType) => + formatPosition(id.pos) + sb.append(s"Function expects $expected parameters, got $got\n") + sb.append( + s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})\n" + ) + formatHighlight(id.pos, 1) + case Error.TypeMismatch(pos, expected, got, msg) => + formatPosition(pos) + sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got\n") + formatHighlight(pos, 1) + case Error.SemanticError(pos, msg) => + formatPosition(pos) + sb.append(msg + "\n") + formatHighlight(pos, 1) + case wacc.Error.InternalError(pos, msg) => + formatPosition(pos) + sb.append(s"Internal error: $msg\n") + formatHighlight(pos, 1) + case Error.SyntaxError(file, msg) => + formatFile(file) + sb.append(msg + "\n") + sb.append("\n") + } + + sb.toString() } diff --git a/src/main/wacc/frontend/ast.scala b/src/main/wacc/frontend/ast.scala index 9b14b13..e39f931 100644 --- a/src/main/wacc/frontend/ast.scala +++ b/src/main/wacc/frontend/ast.scala @@ -1,5 +1,6 @@ package wacc +import java.io.File import parsley.Parsley import parsley.generic.ErrorBridge import parsley.ap._ @@ -22,26 +23,42 @@ object ast { /* ============================ ATOMIC EXPRESSIONS ============================ */ case class IntLiter(v: Int)(val pos: Position) extends Expr6 - object IntLiter extends ParserBridgePos1[Int, IntLiter] + object IntLiter extends ParserBridgePos1Atom[Int, IntLiter] case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6 - object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter] + object BoolLiter extends ParserBridgePos1Atom[Boolean, BoolLiter] case class CharLiter(v: Char)(val pos: Position) extends Expr6 - object CharLiter extends ParserBridgePos1[Char, CharLiter] + object CharLiter extends ParserBridgePos1Atom[Char, CharLiter] case class StrLiter(v: String)(val pos: Position) extends Expr6 - object StrLiter extends ParserBridgePos1[String, StrLiter] + object StrLiter extends ParserBridgePos1Atom[String, StrLiter] case class PairLiter()(val pos: Position) extends Expr6 object PairLiter extends ParserBridgePos0[PairLiter] - case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue - object Ident extends ParserBridgePos1[String, Ident] { + case class Ident(var v: String, var guid: Int = -1, var ty: types.RenamerType = types.?)( + val pos: Position + ) extends Expr6 + with LValue + object Ident extends ParserBridgePos1Atom[String, Ident] { def apply(v: String)(pos: Position): Ident = new Ident(v)(pos) } case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position) extends Expr6 with LValue - object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] { - def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem = - name => ArrayElem(name, a)(pos) + object ArrayElem extends ParserBridgePos2Chain[NonEmptyList[Expr], Ident, ArrayElem] { + def apply(indices: NonEmptyList[Expr], name: Ident)(pos: Position): ArrayElem = + new ArrayElem(name, indices)(pos) } + // object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], (File => Ident) => ArrayElem] { + // def apply(a: NonEmptyList[Expr])(pos: Position): (File => Ident) => ArrayElem = + // name => ArrayElem(name(pos.file), a)(pos) + // } + // object ArrayElem extends ParserSingletonBridgePos[(File => NonEmptyList[Expr]) => (File => Ident) => File => ArrayElem] { + // // def apply(indices: NonEmptyList[Expr]): (File => Ident) => File => ArrayElem = + // // name => file => new ArrayElem(name(file), ) + // def apply(indices: Parsley[File => NonEmptyList[Expr]]): Parsley[(File => Ident) => File => ArrayElem] = + // // error(ap1(pos.map(con),)) + + // override final def con(pos: (Int, Int)): (File => NonEmptyList[Expr]) => => C = + // (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file)) + // } case class Parens(expr: Expr)(val pos: Position) extends Expr6 object Parens extends ParserBridgePos1[Expr, Parens] @@ -119,8 +136,9 @@ object ast { case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position) extends Type with PairElemType - object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] { - def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos) + object ArrayType extends ParserBridgePos2Chain[Int, Type, ArrayType] { + def apply(dimensions: Int, elemType: Type)(pos: Position): ArrayType = + ArrayType(elemType, dimensions)(pos) } case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType] @@ -131,6 +149,18 @@ object ast { /* ============================ PROGRAM STRUCTURE ============================ */ + case class ImportedFunc(sourceName: Ident, importName: Ident)(val pos: Position) + object ImportedFunc extends ParserBridgePos2[Ident, Option[Ident], ImportedFunc] { + def apply(a: Ident, b: Option[Ident])(pos: Position): ImportedFunc = + new ImportedFunc(a, b.getOrElse(a))(pos) + } + + case class Import(source: StrLiter, funcs: NonEmptyList[ImportedFunc])(val pos: Position) + object Import extends ParserBridgePos2[StrLiter, NonEmptyList[ImportedFunc], Import] + + case class PartialProgram(imports: List[Import], self: Program)(val pos: Position) + object PartialProgram extends ParserBridgePos2[List[Import], Program, PartialProgram] + case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position) object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program] @@ -143,15 +173,15 @@ object ast { body: NonEmptyList[Stmt] )(val pos: Position) object FuncDecl - extends ParserBridgePos2[ - List[Param], - NonEmptyList[Stmt], - ((Type, Ident)) => FuncDecl + extends ParserBridgePos2Chain[ + (List[Param], NonEmptyList[Stmt]), + ((Type, Ident)), + FuncDecl ] { - def apply(params: List[Param], body: NonEmptyList[Stmt])( + def apply(paramsBody: (List[Param], NonEmptyList[Stmt]), retTyName: (Type, Ident))( pos: Position - ): ((Type, Ident)) => FuncDecl = - (returnType, name) => FuncDecl(returnType, name, params, body)(pos) + ): FuncDecl = + new FuncDecl(retTyName._1, retTyName._2, paramsBody._1, paramsBody._2)(pos) } case class Param(paramType: Type, name: Ident)(val pos: Position) @@ -159,7 +189,9 @@ object ast { /* ============================ STATEMENTS ============================ */ - sealed trait Stmt + sealed trait Stmt { + val pos: Position + } case class Skip()(val pos: Position) extends Stmt object Skip extends ParserBridgePos0[Skip] case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt @@ -207,7 +239,7 @@ object ast { /* ============================ PARSER BRIDGES ============================ */ - case class Position(line: Int, column: Int) + case class Position(line: Int, column: Int, file: File) trait ParserSingletonBridgePos[+A] extends ErrorBridge { protected def con(pos: (Int, Int)): A @@ -215,38 +247,63 @@ object ast { final def <#(op: Parsley[?]): Parsley[A] = this from op } - trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] { + trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[File => A] { def apply()(pos: Position): A - override final def con(pos: (Int, Int)): A = - apply()(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): File => A = + file => apply()(Position(pos._1, pos._2, file)) } - trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] { + trait ParserBridgePos1Atom[-A, +B] extends ParserSingletonBridgePos[A => File => B] { def apply(a: A)(pos: Position): B - def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a)) + def apply(a: Parsley[A]): Parsley[File => B] = error(ap1(pos.map(con), a)) - override final def con(pos: (Int, Int)): A => B = - this.apply(_)(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): A => File => B = + a => file => this.apply(a)(Position(pos._1, pos._2, file)) } - trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] { + trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[(File => A) => File => B] { + def apply(a: A)(pos: Position): B + def apply(a: Parsley[File => A]): Parsley[File => B] = error(ap1(pos.map(con), a)) + + override final def con(pos: (Int, Int)): (File => A) => File => B = + a => file => this.apply(a(file))(Position(pos._1, pos._2, file)) + } + + trait ParserBridgePos2Chain[-A, -B, +C] + extends ParserSingletonBridgePos[(File => A) => (File => B) => File => C] { def apply(a: A, b: B)(pos: Position): C - def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error( + def apply(a: Parsley[File => A]): Parsley[(File => B) => File => C] = error( + ap1(pos.map(con), a) + ) + + override final def con(pos: (Int, Int)): (File => A) => (File => B) => File => C = + a => b => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file)) + } + + trait ParserBridgePos2[-A, -B, +C] + extends ParserSingletonBridgePos[(File => A, File => B) => File => C] { + def apply(a: A, b: B)(pos: Position): C + def apply(a: Parsley[File => A], b: => Parsley[File => B]): Parsley[File => C] = error( ap2(pos.map(con), a, b) ) - override final def con(pos: (Int, Int)): (A, B) => C = - apply(_, _)(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): (File => A, File => B) => File => C = + (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file)) } - trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] { + trait ParserBridgePos3[-A, -B, -C, +D] + extends ParserSingletonBridgePos[(File => A, File => B, File => C) => File => D] { def apply(a: A, b: B, c: C)(pos: Position): D - def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error( + def apply( + a: Parsley[File => A], + b: => Parsley[File => B], + c: => Parsley[File => C] + ): Parsley[File => D] = error( ap3(pos.map(con), a, b, c) ) - override final def con(pos: (Int, Int)): (A, B, C) => D = - apply(_, _, _)(Position(pos._1, pos._2)) + override final def con(pos: (Int, Int)): (File => A, File => B, File => C) => File => D = + (a, b, c) => file => apply(a(file), b(file), c(file))(Position(pos._1, pos._2, file)) } } diff --git a/src/main/wacc/frontend/parser.scala b/src/main/wacc/frontend/parser.scala index e798284..ce9283c 100644 --- a/src/main/wacc/frontend/parser.scala +++ b/src/main/wacc/frontend/parser.scala @@ -1,18 +1,22 @@ package wacc +import java.io.File import parsley.Result import parsley.Parsley import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit} -import parsley.combinator.{countSome, sepBy} +import parsley.combinator.{countSome, sepBy, option} import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms} import parsley.errors.combinator._ import parsley.errors.patterns.VerifiedErrors import parsley.syntax.zipped._ -import parsley.cats.combinator.{some} +import parsley.cats.combinator.{some, sepBy1} +import cats.syntax.all._ import cats.data.NonEmptyList import parsley.errors.DefaultErrorBuilder import parsley.errors.ErrorBuilder import parsley.errors.tokenextractors.LexToken +import parsley.expr.GOps +import cats.Functor object parser { import lexer.implicits.implicitSymbol @@ -52,13 +56,24 @@ object parser { implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken { def tokens = errTokens } - def parse(input: String): Result[String, Program] = parser.parse(input) - private val parser = lexer.fully(``) + def parse(input: String): Result[String, File => PartialProgram] = parser.parse(input) + private val parser = lexer.fully(``) + + private type FParsley[A] = Parsley[File => A] + + private def fParsley[A](p: Parsley[A]): FParsley[A] = + p map { a => file => a } + + private def fPair[A, B](p: Parsley[(File => A, File => B)]): FParsley[(A, B)] = + p map { case (a, b) => file => (a(file), b(file)) } + + private def fMap[A, F[_]: Functor](p: Parsley[F[File => A]]): FParsley[F[A]] = + p map { funcs => file => funcs.map(_(file)) } // Expressions - private lazy val ``: Parsley[Expr] = precedence { - SOps(InfixR)(Or from "||") +: - SOps(InfixR)(And from "&&") +: + private lazy val ``: FParsley[Expr] = precedence { + GOps(InfixR)(Or from "||") +: + GOps(InfixR)(And from "&&") +: SOps(InfixN)(Eq from "==", Neq from "!=") +: SOps(InfixN)( Less from "<", @@ -83,32 +98,33 @@ object parser { } // Atoms - private lazy val ``: Atoms[Expr6] = Atoms( + private lazy val ``: Atoms[File => Expr6] = Atoms( IntLiter(integer).label("integer literal"), BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"), CharLiter(charLit).label("character literal"), - StrLiter(stringLit).label("string literal"), + ``.label("string literal"), PairLiter from "null", ``, Parens("(" ~> `` <~ ")") ) - private val `` = + private lazy val `` = StrLiter(stringLit) + private lazy val `` = Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed") private lazy val `` = (`` <~ ("(".verifiedExplain( "functions can only be called using 'call' keyword" ) | unit)) <**> (`` identity) - private val `` = ArrayElem(some("[" ~> `` <~ "]")) + private lazy val `` = ArrayElem(fMap(some("[" ~> `` <~ "]"))) // Types - private lazy val ``: Parsley[Type] = + private lazy val ``: FParsley[Type] = (`` | (`` ~> ``)) <**> (`` identity) private val `` = (IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string") private lazy val `` = - ArrayType(countSome("[" ~> "]")) + ArrayType(fParsley(countSome("[" ~> "]"))) private val `` = "pair" - private val ``: Parsley[PairType] = PairType( + private val ``: FParsley[PairType] = PairType( "(" ~> `` <~ ",", `` <~ ")" ) @@ -116,7 +132,7 @@ object parser { (`` <**> (`` identity)) | ((UntypedPairType from ``) <**> ((`` <**> ``) - .map(arr => (_: UntypedPairType) => arr) identity)) + .map(arr => (_: File => UntypedPairType) => arr) identity)) /* Statements Atomic is used in two places here: @@ -127,13 +143,30 @@ object parser { invalid syntax check, this only happens at most once per program so this is not a major concern. */ + private lazy val `` = PartialProgram( + fMap(many(``)), + `` + ) + private lazy val `` = Import( + "import" ~> ``, + "(" ~> fMap(sepBy1(``, ",")) <~ ")" + ) + private lazy val `` = ``.label("import file name") + private lazy val `` = ImportedFunc( + ``.label("imported function name"), + fMap(option("as" ~> ``)).label("imported function alias") + ) private lazy val `` = Program( "begin" ~> ( - many( - atomic( - ``.label("function declaration") <~> `` <~ "(" - ) <**> `` - ).label("function declaration") | + fMap( + many( + fPair( + atomic( + ``.label("function declaration") <~> `` <~ "(" + ) + ) <**> `` + ).label("function declaration") + ) | atomic(`` <~ "(").verifiedExplain("function declaration is missing return type") ), ``.label( @@ -142,17 +175,23 @@ object parser { ) private lazy val `` = FuncDecl( - sepBy(``, ",") <~ ")" <~ "is", - ``.guardAgainst { - case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement") - } <~ "end" + fPair( + (fMap(sepBy(``, ",")) <~ ")" <~ "is") <~> + (``.guardAgainst { + // TODO: passing in an arbitrary file works but is ugly + case stmts if !(stmts(File("."))).isReturning => + Seq("all functions must end in a returning statement") + } <~ "end") + ) ) private lazy val `` = Param(``, ``) - private lazy val ``: Parsley[NonEmptyList[Stmt]] = - ( - ``.label("main program body"), - (many(";" ~> ``.label("statement after ';'"))) Nil - ).zipped(NonEmptyList.apply) + private lazy val ``: FParsley[NonEmptyList[Stmt]] = + fMap( + ( + ``.label("main program body"), + (many(";" ~> ``.label("statement after ';'"))) Nil + ).zipped(NonEmptyList.apply) + ) private lazy val `` = (Skip from "skip") @@ -160,8 +199,8 @@ object parser { | Free("free" ~> ``.labelAndExplain(LabelType.Expr)) | Return("return" ~> ``.labelAndExplain(LabelType.Expr)) | Exit("exit" ~> ``.labelAndExplain(LabelType.Expr)) - | Print("print" ~> ``.labelAndExplain(LabelType.Expr), pure(false)) - | Print("println" ~> ``.labelAndExplain(LabelType.Expr), pure(true)) + | Print("print" ~> ``.labelAndExplain(LabelType.Expr), fParsley(pure(false))) + | Print("println" ~> ``.labelAndExplain(LabelType.Expr), fParsley(pure(true))) | If( "if" ~> ``.labelWithType(LabelType.Expr) <~ "then", `` <~ "else", @@ -185,9 +224,9 @@ object parser { ("call" ~> ``).verifiedExplain( "function calls' results must be assigned to a variable" ) - private lazy val ``: Parsley[LValue] = + private lazy val ``: FParsley[LValue] = `` | `` - private lazy val ``: Parsley[RValue] = + private lazy val ``: FParsley[RValue] = `` | NewPair( "newpair" ~> "(" ~> `` <~ ",", @@ -196,13 +235,13 @@ object parser { `` | Call( "call" ~> `` <~ "(", - sepBy(``, ",") <~ ")" + fMap(sepBy(``, ",")) <~ ")" ) | ``.labelWithType(LabelType.Expr) private lazy val `` = Fst("fst" ~> ``.label("valid pair")) | Snd("snd" ~> ``.label("valid pair")) private lazy val `` = ArrayLiter( - "[" ~> sepBy(``, ",") <~ "]" + "[" ~> fMap(sepBy(``, ",")) <~ "]" ) extension (stmts: NonEmptyList[Stmt]) { diff --git a/src/main/wacc/frontend/renamer.scala b/src/main/wacc/frontend/renamer.scala index b281283..4893d42 100644 --- a/src/main/wacc/frontend/renamer.scala +++ b/src/main/wacc/frontend/renamer.scala @@ -1,6 +1,15 @@ package wacc +import java.io.File import scala.collection.mutable +import cats.effect.IO +import cats.syntax.all._ +import cats.implicits._ +import cats.data.Chain +import cats.data.NonEmptyList +import parsley.{Failure, Success} + +private val MAIN = "$main" object renamer { import ast._ @@ -11,116 +20,271 @@ object renamer { case Var } + private case class ScopeKey(path: String, name: String, identType: IdentType) + private case class ScopeValue(id: Ident, public: Boolean) + private class Scope( - val current: mutable.Map[(String, IdentType), Ident], - val parent: Map[(String, IdentType), Ident] + private val current: mutable.Map[ScopeKey, ScopeValue], + private val parent: Map[ScopeKey, ScopeValue], + guidStart: Int = 0, + val guidInc: Int = 1 ) { + private var guid = guidStart + private var immutable = false + + private def nextGuid(): Int = { + val id = guid + guid += guidInc + id + } + + private def verifyMutable(): Unit = { + if (immutable) throw new IllegalStateException("Cannot modify an immutable scope") + } /** Create a new scope with the current scope as its parent. + * + * To be used for single-threaded applications. * * @return * A new scope with an empty current scope, and this scope flattened into the parent scope. */ - def subscope: Scope = - Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent))) + def withSubscope[T](f: Scope => T): T = { + val subscope = + Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)), guid, guidInc) + immutable = true + val result = f(subscope) + guid = subscope.guid // Sync GUID + immutable = false + result + } + + /** Create new scopes with the current scope as its parent and GUID numbering adjusted + * correctly. + * + * This will permanently mark the current scope as immutable, for thread safety. + * + * To be used for multi-threaded applications. + * + * @return + * New scopes with an empty current scope, and this scope flattened into the parent scope. + */ + def subscopes(n: Int): Seq[Scope] = { + verifyMutable() + immutable = true + (0 until n).map { i => + Scope( + mutable.Map.empty, + Map.empty.withDefault(current.withDefault(parent)), + guid + i * guidInc, + guidInc * n + ) + } + } /** Attempt to add a new identifier to the current scope. If the identifier already exists in * the current scope, add an error to the error list. * - * @param ty - * The semantic type of the variable identifier, or function identifier type. * @param name * The name of the identifier. - * @param globalNames - * The global map of identifiers to semantic types - the identifier will be added to this - * map. - * @param globalNumbering - * The global map of identifier names to the number of times they have been declared - will - * used to rename this identifier, and will be incremented. - * @param errors - * The list of errors to append to. + * @return + * An error, if one occurred. */ - def add(ty: SemType | FuncType, name: Ident)(using - globalNames: mutable.Map[Ident, SemType], - globalFuncs: mutable.Map[Ident, FuncType], - globalNumbering: mutable.Map[String, Int], - errors: mutable.Builder[Error, List[Error]] - ) = { - val identType = ty match { + def add(name: Ident, public: Boolean = false): Chain[Error] = { + verifyMutable() + val path = name.pos.file.getCanonicalPath + val identType = name.ty match { case _: SemType => IdentType.Var case _: FuncType => IdentType.Func } - current.get((name.v, identType)) match { - case Some(Ident(_, uid)) => - errors += Error.DuplicateDeclaration(name) - name.uid = uid + val key = ScopeKey(path, name.v, identType) + current.get(key) match { + case Some(ScopeValue(Ident(_, id, _), _)) => + name.guid = id + Chain.one(Error.DuplicateDeclaration(name)) case None => - val uid = globalNumbering.getOrElse(name.v, 0) - name.uid = uid - current((name.v, identType)) = name - - ty match { - case semType: SemType => - globalNames(name) = semType - case funcType: FuncType => - globalFuncs(name) = funcType - } - globalNumbering(name.v) = uid + 1 + name.guid = nextGuid() + current(key) = ScopeValue(name, public) + Chain.empty } } - private def get(name: String, identType: IdentType): Option[Ident] = + /** Attempt to add a new identifier as an alias to another to the existing scope. + * + * @param alias + * The (new) alias identifier. + * @param orig + * The (existing) original identifier. + * + * @return + * An error, if one occurred. + */ + def addAlias(alias: Ident, orig: ScopeValue, public: Boolean = false): Chain[Error] = { + verifyMutable() + val path = alias.pos.file.getCanonicalPath + val identType = alias.ty match { + case _: SemType => IdentType.Var + case _: FuncType => IdentType.Func + } + val key = ScopeKey(path, alias.v, identType) + current.get(key) match { + case Some(ScopeValue(Ident(_, id, _), _)) => + alias.guid = id + Chain.one(Error.DuplicateDeclaration(alias)) + case None => + alias.guid = nextGuid() + current(key) = ScopeValue(orig.id, public) + Chain.empty + } + } + + def get(path: String, name: String, identType: IdentType): Option[ScopeValue] = // Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found. // Neither is there a way to check whether a default exists, so we have to use a try-catch. try { - Some(current.withDefault(parent)((name, identType))) + Some(current.withDefault(parent)(ScopeKey(path, name, identType))) } catch { case _: NoSuchElementException => None } - def getVar(name: String): Option[Ident] = get(name, IdentType.Var) - def getFunc(name: String): Option[Ident] = get(name, IdentType.Func) + def getVar(name: Ident): Option[Ident] = + get(name.pos.file.getCanonicalPath, name.v, IdentType.Var).map(_.id) + def getFunc(name: Ident): Option[Ident] = + get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id) } - /** Check scoping of all variables and functions in the program. Also generate semantic types for - * all identifiers. - * - * @param prog - * AST of the program - * @param errors - * List of errors to append to - * @return - * Map of all (renamed) identifies to their semantic types - */ - def rename(prog: Program)(using - errors: mutable.Builder[Error, List[Error]] - ): (Map[Ident, SemType], Map[Ident, FuncType]) = { - given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty - given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty - given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty - val scope = Scope(mutable.Map.empty, Map.empty) + private def prepareGlobalScope( + partialProg: PartialProgram + )(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = { + def readImportFile(file: File): IO[String] = + IO.blocking(os.read(os.Path(file.getCanonicalPath))) + + def prepareImport(contents: String, file: File)(using + scope: Scope + ): IO[(Chain[FuncDecl], Chain[Error])] = { + parser.parse(contents) match { + case Failure(msg) => + IO.pure(Chain.empty, Chain.one(Error.SyntaxError(file, msg))) + case Success(fn) => + val partialProg = fn(file) + for { + (main, chunks, errors) <- prepareGlobalScope(partialProg) + } yield (main +: chunks, errors) + } + } + + def addImportsToScope(importFile: File, funcs: NonEmptyList[ImportedFunc])(using + scope: Scope + ): Chain[Error] = + funcs.foldMap { case ImportedFunc(srcName, aliasName) => + scope.get(importFile.getCanonicalPath, srcName.v, IdentType.Func) match { + case Some(src) if src.public => + aliasName.ty = src.id.ty + scope.addAlias(aliasName, src) + case _ => + Chain.one(Error.UndefinedFunction(srcName)) + } + } + + val PartialProgram(imports, prog) = partialProg + + // First prepare this file's functions... val Program(funcs, main) = prog - funcs - // First add all function declarations to the scope - .map { case FuncDecl(retType, name, params, body) => + val (funcChunks, funcErrors) = funcs.foldLeft((Chain.empty[FuncDecl], Chain.empty[Error])) { + case ((chunks, errors), func @ FuncDecl(retType, name, params, body)) => val paramTypes = params.map { param => val paramType = SemType(param.paramType) + param.name.ty = paramType paramType } - scope.add(FuncType(SemType(retType), paramTypes), name) - (params zip paramTypes, body) + name.ty = FuncType(SemType(retType), paramTypes) + (chunks :+ func, errors ++ scope.add(name, public = true)) + } + // ...and main body. + val mainBodyIdent = Ident(MAIN, ty = FuncType(?, Nil))(prog.pos) + val mainBodyErrors = scope.add(mainBodyIdent, public = false) + val mainBodyChunk = FuncDecl(IntType()(prog.pos), mainBodyIdent, Nil, main)(prog.pos) + + // Now handle imports + val file = prog.pos.file + val preparedImports = imports.foldLeftM[IO, (Chain[FuncDecl], Chain[Error])]( + (Chain.empty[FuncDecl], Chain.empty[Error]) + ) { case ((chunks, errors), Import(name, funcs)) => + val importFile = File(file.getParent, name.v) + if (!importFile.exists()) { + IO.pure( + ( + chunks, + errors :+ Error.SemanticError( + name.pos, + s"File not found: ${importFile.getCanonicalPath}" + ) + ) + ) + } else if (!importFile.canRead()) { + IO.pure( + ( + chunks, + errors :+ Error.SemanticError( + name.pos, + s"File not readable: ${importFile.getCanonicalPath}" + ) + ) + ) + } else if (importFile.getCanonicalPath == file.getCanonicalPath) { + IO.pure( + ( + chunks, + errors :+ Error.SemanticError( + name.pos, + s"Cannot import self: ${importFile.getCanonicalPath}" + ) + ) + ) + } else if (scope.get(importFile.getCanonicalPath, MAIN, IdentType.Func).isDefined) { + IO.pure(chunks, errors ++ addImportsToScope(importFile, funcs)) + } else { + for { + contents <- readImportFile(importFile) + (importChunks, importErrors) <- prepareImport(contents, importFile) + importAliasErrors = addImportsToScope(importFile, funcs) + } yield (chunks ++ importChunks, errors ++ importErrors) } - // Only then rename the function bodies - // (functions can call one-another regardless of order of declaration) - .foreach { case (params, body) => - val functionScope = scope.subscope - params.foreach { case (param, paramType) => - functionScope.add(paramType, param.name) - } - body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params - } - main.toList.foreach(rename(scope)) - (globalNames.toMap, globalFuncs.toMap) + } + + for { + (importChunks, importErrors) <- preparedImports + allChunks = importChunks ++ funcChunks + allErrors = importErrors ++ funcErrors ++ mainBodyErrors + } yield (mainBodyChunk, allChunks, allErrors) + } + + /** Check scoping of all variables and flatten a program. Also generates semantic types and parses + * any imported files. + * + * @param partialProg + * AST of the program + * @return + * (flattenedProg, errors) + */ + private def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = { + val (FuncDecl(_, _, params, body), subscope) = funcScopePair + val paramErrors = params.foldMap(param => subscope.add(param.name)) + IO(subscope.withSubscope { s => body.foldMap(rename(s)) }) + .map(bodyErrors => paramErrors ++ bodyErrors) + } + + def rename(partialProg: PartialProgram): IO[(Program, Chain[Error])] = { + given scope: Scope = Scope(mutable.Map.empty, Map.empty) + + for { + (main, chunks, globalErrors) <- prepareGlobalScope(partialProg) + toRename = (main +: chunks).toList + allErrors <- toRename + .zip(scope.subscopes(toRename.size)) + .parFoldMapA(renameFunction) + // .map(x => x.combineAll) + } yield (Program(chunks.toList, main.body)(main.pos), globalErrors ++ allErrors) } /** Check scoping of all identifies in a given AST node. @@ -129,91 +293,90 @@ object renamer { * The current scope and flattened parent scope. * @param node * The AST node. - * @param globalNames - * The global map of identifiers to semantic types - renamed identifiers will be added to this - * map. - * @param globalNumbering - * The global map of identifier names to the number of times they have been declared - used and - * updated during identifier renaming. - * @param errors */ - private def rename(scope: Scope)( - node: Ident | Stmt | LValue | RValue | Expr - )(using - globalNames: mutable.Map[Ident, SemType], - globalFuncs: mutable.Map[Ident, FuncType], - globalNumbering: mutable.Map[String, Int], - errors: mutable.Builder[Error, List[Error]] - ): Unit = node match { - // These cases are more interesting because the involve making subscopes - // or modifying the current scope. - case VarDecl(synType, name, value) => { - // Order matters here. Variable isn't declared until after the value is evaluated. - rename(scope)(value) - // Attempt to add the new variable to the current scope. - scope.add(SemType(synType), name) - } - case If(cond, thenStmt, elseStmt) => { - rename(scope)(cond) - // then and else both have their own scopes - thenStmt.toList.foreach(rename(scope.subscope)) - elseStmt.toList.foreach(rename(scope.subscope)) - } - case While(cond, body) => { - rename(scope)(cond) - // while bodies have their own scopes - body.toList.foreach(rename(scope.subscope)) - } - // begin-end blocks have their own scopes - case Block(body) => body.toList.foreach(rename(scope.subscope)) + private def rename(scope: Scope)(node: Ident | Stmt | LValue | RValue | Expr): Chain[Error] = + node match { + // These cases are more interes/globting because the involve making subscopes + // or modifying the current scope. + case VarDecl(synType, name, value) => { + // Order matters here. Variable isn't declared until after the value is evaluated. + val errors = rename(scope)(value) + // Attempt to add the new variable to the current scope. + name.ty = SemType(synType) + errors ++ scope.add(name) + } + case If(cond, thenStmt, elseStmt) => { + val condErrors = rename(scope)(cond) + // then and else both have their own scopes + val thenErrors = scope.withSubscope(s => thenStmt.foldMap(rename(s))) + val elseErrors = scope.withSubscope(s => elseStmt.foldMap(rename(s))) + condErrors ++ thenErrors ++ elseErrors + } + case While(cond, body) => { + val condErrors = rename(scope)(cond) + // while bodies have their own scopes + val bodyErrors = scope.withSubscope(s => body.foldMap(rename(s))) + condErrors ++ bodyErrors + } + // begin-end blocks have their own scopes + case Block(body) => scope.withSubscope(s => body.foldMap(rename(s))) - // These cases are simpler, mostly just recursive calls to rename() - case Assign(lhs, value) => { - // Variables may be reassigned with their value in the rhs, so order doesn't matter here. - rename(scope)(lhs) - rename(scope)(value) - } - case Read(lhs) => rename(scope)(lhs) - case Free(expr) => rename(scope)(expr) - case Return(expr) => rename(scope)(expr) - case Exit(expr) => rename(scope)(expr) - case Print(expr, _) => rename(scope)(expr) - case NewPair(fst, snd) => { - rename(scope)(fst) - rename(scope)(snd) - } - case Call(name, args) => { - scope.getFunc(name.v) match { - case Some(Ident(_, uid)) => name.uid = uid - case None => - errors += Error.UndefinedFunction(name) - scope.add(FuncType(?, args.map(_ => ?)), name) + // These cases are simpler, mostly just recursive calls to rename() + case Assign(lhs, value) => { + // Variables may be reassigned with their value in the rhs, so order doesn't matter here. + rename(scope)(lhs) ++ rename(scope)(value) } - args.foreach(rename(scope)) - } - case Fst(elem) => rename(scope)(elem) - case Snd(elem) => rename(scope)(elem) - case ArrayLiter(elems) => elems.foreach(rename(scope)) - case ArrayElem(name, indices) => { - rename(scope)(name) - indices.toList.foreach(rename(scope)) - } - case Parens(expr) => rename(scope)(expr) - case op: UnaryOp => rename(scope)(op.x) - case op: BinaryOp => { - rename(scope)(op.x) - rename(scope)(op.y) - } - // Default to variables. Only `call` uses IdentType.Func. - case id: Ident => { - scope.getVar(id.v) match { - case Some(Ident(_, uid)) => id.uid = uid - case None => - errors += Error.UndeclaredVariable(id) - scope.add(?, id) + case Read(lhs) => rename(scope)(lhs) + case Free(expr) => rename(scope)(expr) + case Return(expr) => rename(scope)(expr) + case Exit(expr) => rename(scope)(expr) + case Print(expr, _) => rename(scope)(expr) + case NewPair(fst, snd) => { + rename(scope)(fst) ++ rename(scope)(snd) } + case Call(name, args) => { + val nameErrors = scope.getFunc(name) match { + case Some(Ident(realName, guid, ty)) => + name.v = realName + name.ty = ty + name.guid = guid + Chain.empty + case None => + name.ty = FuncType(?, args.map(_ => ?)) + scope.add(name) + Chain.one(Error.UndefinedFunction(name)) + } + val argsErrors = args.foldMap(rename(scope)) + nameErrors ++ argsErrors + } + case Fst(elem) => rename(scope)(elem) + case Snd(elem) => rename(scope)(elem) + case ArrayLiter(elems) => elems.foldMap(rename(scope)) + case ArrayElem(name, indices) => { + val nameErrors = rename(scope)(name) + val indicesErrors = indices.foldMap(rename(scope)) + nameErrors ++ indicesErrors + } + case Parens(expr) => rename(scope)(expr) + case op: UnaryOp => rename(scope)(op.x) + case op: BinaryOp => { + rename(scope)(op.x) ++ rename(scope)(op.y) + } + // Default to variables. Only `call` uses IdentType.Func. + case id: Ident => { + scope.getVar(id) match { + case Some(Ident(_, guid, ty)) => + id.ty = ty + id.guid = guid + Chain.empty + case None => + id.ty = ? + scope.add(id) + Chain.one(Error.UndeclaredVariable(id)) + } + } + // These literals cannot contain identifies, exit immediately. + case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => + Chain.empty } - // These literals cannot contain identifies, exit immediately. - case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => () - } } diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index a628b69..6f5804b 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -8,13 +8,8 @@ object typeChecker { import wacc.types._ case class TypeCheckerCtx( - globalNames: Map[ast.Ident, SemType], - globalFuncs: Map[ast.Ident, FuncType], errors: mutable.Builder[Error, List[Error]] ) { - def typeOf(ident: ast.Ident): SemType = globalNames(ident) - def funcType(ident: ast.Ident): FuncType = globalFuncs(ident) - def error(err: Error): SemType = errors += err ? @@ -99,18 +94,17 @@ object typeChecker { * The type checker context which includes the global names and functions, and an errors * builder. */ - def check(prog: ast.Program)(using - ctx: TypeCheckerCtx - ): microWacc.Program = + def check(prog: ast.Program, errors: mutable.Builder[Error, List[Error]]): microWacc.Program = + given ctx: TypeCheckerCtx = TypeCheckerCtx(errors) microWacc.Program( // Ignore function syntax types for return value and params, since those have been converted // to SemTypes by the renamer. prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) => - val FuncType(retType, paramTypes) = ctx.funcType(name) + val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType] microWacc.FuncDecl( - microWacc.Ident(name.v, name.uid)(retType), + microWacc.Ident(name.v, name.guid)(retType), params.zip(paramTypes).map { case (ast.Param(_, ident), ty) => - microWacc.Ident(ident.v, ident.uid)(ty) + microWacc.Ident(ident.v, ident.guid)(ty) }, stmts.toList .flatMap( @@ -134,15 +128,20 @@ object typeChecker { ): List[microWacc.Stmt] = stmt match { // Ignore the type of the variable, since it has been converted to a SemType by the renamer. case ast.VarDecl(_, name, value) => - val expectedTy = ctx.typeOf(name) + val expectedTy = name.ty val typedValue = checkValue( value, Constraint.Is( - expectedTy, + expectedTy.asInstanceOf[SemType], s"variable ${name.v} must be assigned a value of type $expectedTy" ) ) - List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue)) + List( + microWacc.Assign( + microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]), + typedValue + ) + ) case ast.Assign(lhs, rhs) => val lhsTyped = checkLValue(lhs, Constraint.Unconstrained) val rhsTyped = @@ -315,7 +314,7 @@ object typeChecker { KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) ) case ast.Call(id, args) => - val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id) + val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType] if (args.length != paramTys.length) { ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy)) } @@ -324,7 +323,7 @@ object typeChecker { val argsTyped = args.zip(paramTys).map { case (arg, paramTy) => checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}")) } - microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped) + microWacc.Call(microWacc.Ident(id.v, id.guid)(retTy.satisfies(constraint, id.pos)), argsTyped) // Unary operators case ast.Negate(x) => @@ -416,30 +415,32 @@ object typeChecker { private def checkLValue(value: ast.LValue, constraint: Constraint)(using ctx: TypeCheckerCtx ): microWacc.LValue = value match { - case id @ ast.Ident(name, uid) => - microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos)) + case id @ ast.Ident(name, guid, ty) => + microWacc.Ident(name, guid)(ty.asInstanceOf[SemType].satisfies(constraint, id.pos)) case ast.ArrayElem(id, indices) => - val arrayTy = ctx.typeOf(id) - val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) => - val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) - val next = acc match { - case KnownType.Array(innerTy) => innerTy - case ? => ? // we can keep indexing an unknown type - case nonArrayTy => - ctx.error( - Error.TypeMismatch( - elem.pos, - KnownType.Array(?), - acc, - "cannot index into a non-array" + val arrayTy = id.ty.asInstanceOf[SemType] + val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy.asInstanceOf[SemType]) { + (acc, elem) => + val idxTyped = + checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int")) + val next = acc match { + case KnownType.Array(innerTy) => innerTy + case ? => ? // we can keep indexing an unknown type + case nonArrayTy => + ctx.error( + Error.TypeMismatch( + elem.pos, + KnownType.Array(?), + acc, + "cannot index into a non-array" + ) ) - ) - ? - } - (next, idxTyped) + ? + } + (next, idxTyped) } val firstArrayElem = microWacc.ArrayElem( - microWacc.Ident(id.v, id.uid)(arrayTy), + microWacc.Ident(id.v, id.guid)(arrayTy), indicesTyped.head )(elemTy.satisfies(constraint, value.pos)) val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) => diff --git a/src/main/wacc/frontend/types.scala b/src/main/wacc/frontend/types.scala index 549d8a1..5251396 100644 --- a/src/main/wacc/frontend/types.scala +++ b/src/main/wacc/frontend/types.scala @@ -3,7 +3,9 @@ package wacc object types { import ast._ - sealed trait SemType { + sealed trait RenamerType + + sealed trait SemType extends RenamerType { override def toString(): String = this match { case KnownType.Int => "int" case KnownType.Bool => "bool" @@ -41,5 +43,5 @@ object types { } } - case class FuncType(returnType: SemType, params: List[SemType]) + case class FuncType(returnType: SemType, params: List[SemType]) extends RenamerType } diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 6114afd..a0f4564 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -1,14 +1,19 @@ package wacc import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.Inspectors.forEvery +import org.scalatest.matchers.should.Matchers._ +import org.scalatest.freespec.AsyncFreeSpec +import cats.effect.testing.scalatest.AsyncIOSpec import java.io.File +import java.nio.file.Path import sys.process._ -import java.io.PrintStream import scala.io.Source +import cats.effect.IO +import wacc.{compile as compileWacc} + +class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAndAfterAll { -class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { val files = allWaccFiles("wacc-examples/valid").map { p => (p.toString, List(0)) @@ -21,97 +26,119 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { } ++ allWaccFiles("wacc-examples/invalid/whack").map { p => (p.toString, List(100, 200)) + } ++ + allWaccFiles("extension/examples/valid").map { p => + (p.toString, List(0)) + } ++ + allWaccFiles("extension/examples/invalid/syntax").map { p => + (p.toString, List(100)) + } ++ + allWaccFiles("extension/examples/invalid/semantics").map { p => + (p.toString, List(200)) } - // tests go here forEvery(files) { (filename, expectedResult) => val baseFilename = filename.stripSuffix(".wacc") - given stdout: PrintStream = PrintStream(File(baseFilename + ".out")) - s"$filename" should "be compiled with correct result" in { - val result = compile(filename) - assert(expectedResult.contains(result)) - } - - if (expectedResult == List(0)) it should "run with correct result" in { - if (fileIsDisallowedBackend(filename)) pending - - // Retrieve contents to get input and expected output + exit code - val contents = scala.io.Source.fromFile(File(filename)).getLines.toList - val inputLine = - contents - .find(_.matches("^# ?[Ii]nput:.*$")) - .map(_.split(":").last.strip + "\n") - .getOrElse("") - val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) - val expectedOutput = - if (outputLineIdx == -1) "" + s"$filename" - { + "should be compiled with correct result" in { + if (fileIsPendingFrontend(filename)) + IO.pure(pending) else - contents - .drop(outputLineIdx + 1) - .takeWhile(_.startsWith("#")) - .map(_.stripPrefix("#").stripLeading) - .mkString("\n") + compileWacc(Path.of(filename), outputDir = None, log = false).map { result => + expectedResult should contain(result) + } + } - val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$")) - val expectedExit = - if (exitLineIdx == -1) 0 - else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt + if (expectedResult == List(0)) { + "should run with correct result" in { + if (fileIsDisallowedBackend(filename)) + IO.pure(succeed) + else if (fileIsPendingBackend(filename)) + IO.pure(pending) + else + for { + contents <- IO(Source.fromFile(File(filename)).getLines.toList) + inputLine = extractInput(contents) + expectedOutput = extractOutput(contents) + expectedExit = extractExit(contents) - // Assembly and link using gcc - val asmFilename = baseFilename + ".s" - val execFilename = baseFilename - val gccResult = s"gcc -o $execFilename -z noexecstack $asmFilename".! - assert(gccResult == 0) + asmFilename = baseFilename + ".s" + execFilename = baseFilename + gccResult <- IO(s"gcc -o $execFilename -z noexecstack $asmFilename".!) - // Run the executable with the provided input - val stdout = new StringBuilder - val process = s"timeout 5s $execFilename" run ProcessIO( - in = w => { - w.write(inputLine.getBytes) - w.close() - }, - out = Source.fromInputStream(_).addString(stdout), - err = _ => () - ) + _ = assert(gccResult == 0) - assert(process.exitValue == expectedExit) - assert( - stdout.toString - .replaceAll("0x[0-9a-f]+", "#addrs#") - .replaceAll("fatal error:.*", "#runtime_error#\u0000") - .takeWhile(_ != '\u0000') - == expectedOutput - ) + stdout <- IO.pure(new StringBuilder) + process <- IO { + s"timeout 5s $execFilename" run ProcessIO( + in = w => { + w.write(inputLine.getBytes) + w.close() + }, + out = Source.fromInputStream(_).addString(stdout), + err = _ => () + ) + } + + exitCode <- IO.pure(process.exitValue) + + } yield { + exitCode shouldBe expectedExit + normalizeOutput(stdout.toString) shouldBe expectedOutput + } + } + } } } def allWaccFiles(dir: String): IndexedSeq[os.Path] = val d = java.io.File(dir) - os.walk(os.Path(d.getAbsolutePath)).filter { _.ext == "wacc" } + os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc") - def fileIsDisallowedBackend(filename: String): Boolean = - Seq( - // format: off - // disable formatting to avoid binPack - "^.*wacc-examples/valid/advanced.*$", - // "^.*wacc-examples/valid/array.*$", - // "^.*wacc-examples/valid/basic/exit.*$", - // "^.*wacc-examples/valid/basic/skip.*$", - // "^.*wacc-examples/valid/expressions.*$", - // "^.*wacc-examples/valid/function/nested_functions.*$", - // "^.*wacc-examples/valid/function/simple_functions.*$", - // "^.*wacc-examples/valid/if.*$", - // "^.*wacc-examples/valid/IO/print.*$", - // "^.*wacc-examples/valid/IO/read.*$", - // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", - // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - // "^.*wacc-examples/valid/pairs.*$", - //"^.*wacc-examples/valid/runtimeErr.*$", - // "^.*wacc-examples/valid/scope.*$", - // "^.*wacc-examples/valid/sequence.*$", - // "^.*wacc-examples/valid/variables.*$", - // "^.*wacc-examples/valid/while.*$", - // format: on - ).find(filename.matches).isDefined + private def fileIsDisallowedBackend(filename: String): Boolean = + filename.matches("^.*wacc-examples/valid/advanced.*$") + + private def fileIsPendingFrontend(filename: String): Boolean = + List( + // "^.*extension/examples/invalid/syntax/imports/importBadSyntax.*$", + // "^.*extension/examples/invalid/semantics/imports.*$", + // "^.*extension/examples/valid/imports.*$" + ).exists(filename.matches) + + private def fileIsPendingBackend(filename: String): Boolean = + List( + // "^.*extension/examples/invalid/syntax/imports.*$", + // "^.*extension/examples/invalid/semantics/imports.*$", + // "^.*extension/examples/valid/imports.*$" + ).exists(filename.matches) + + private def extractInput(contents: List[String]): String = + contents + .find(_.matches("^# ?[Ii]nput:.*$")) + .map(_.split(":").last.strip + "\n") + .getOrElse("") + + private def extractOutput(contents: List[String]): String = { + val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) + if (outputLineIdx == -1) "" + else + contents + .drop(outputLineIdx + 1) + .takeWhile(_.startsWith("#")) + .map(_.stripPrefix("#").stripLeading) + .mkString("\n") + } + + private def extractExit(contents: List[String]): Int = { + val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$")) + if (exitLineIdx == -1) 0 + else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt + } + + private def normalizeOutput(output: String): String = + output + .replaceAll("0x[0-9a-f]+", "#addrs#") + .replaceAll("fatal error:.*", "#runtime_error#\u0000") + .takeWhile(_ != '\u0000') }