diff --git a/compile b/compile index 5505e42..444f591 100755 --- a/compile +++ b/compile @@ -4,6 +4,6 @@ # but do *not* change its name. # feel free to adjust to suit the specific internal flags of your compiler -./wacc-compiler "$@" +./wacc-compiler --output . "$@" exit $? diff --git a/project.scala b/project.scala index 4deaa08..8edf035 100644 --- a/project.scala +++ b/project.scala @@ -5,23 +5,18 @@ //> using dep com.github.j-mie6::parsley::5.0.0-M10 //> using dep com.github.j-mie6::parsley-cats::1.5.0 //> using dep com.lihaoyi::os-lib::0.11.4 -//> using dep com.github.scopt::scopt::4.1.0 +//> using dep org.typelevel::cats-core::2.13.0 +//> using dep org.typelevel::cats-effect::3.5.7 +//> using dep com.monovore::decline::2.5.0 +//> using dep com.monovore::decline-effect::2.5.0 +//> using dep org.typelevel::log4cats-slf4j::2.7.0 +//> using dep org.slf4j:slf4j-simple:2.0.17 //> using test.dep org.scalatest::scalatest::3.2.19 +//> using dep org.typelevel::cats-effect-testing-scalatest::1.6.0 -// these are all sensible defaults to catch annoying issues +// sensible defaults for warnings and compiler checks //> using options -deprecation -unchecked -feature //> using options -Wimplausible-patterns -Wunused:all //> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores -// these will help ensure you have access to the latest parsley releases -// even before they land on maven proper, or snapshot versions, if necessary. -// just in case they cause problems, however, keep them turned off unless you -// specifically need them. -// using repositories sonatype-s01:releases -// using repositories sonatype-s01:snapshots - -// these are flags used by Scala native: if you aren't using scala-native, then they do nothing -// lto-thin has decent linking times, and release-fast does not too much optimisation. -// using nativeLto thin -// using nativeGc commix -// using nativeMode release-fast +// repositories for pre-release versions if needed diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index fc9fb45..d964657 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -1,92 +1,159 @@ package wacc import scala.collection.mutable -import cats.data.Chain +import cats.data.{Chain, NonEmptyList} import parsley.{Failure, Success} -import scopt.OParser -import java.io.File -import java.io.PrintStream + +import java.nio.file.{Files, Path} +import cats.syntax.all._ + +import cats.effect.IO +import cats.effect.ExitCode + +import com.monovore.decline._ +import com.monovore.decline.effect._ + +import org.typelevel.log4cats.slf4j.Slf4jLogger +import org.typelevel.log4cats.Logger import assemblyIR as asm +import cats.data.ValidatedNel -case class CliConfig( - file: File = new File(".") -) +/* +TODO: + 1) IO correctness + 2) Errors can be handled more gracefully - currently, parallelised compilation is not fail fast as far as I am aware + 3) splitting the file up and nicer refactoring + 4) logging could be removed + 5) general cleanup and comments (things like replacing home/ with ~ , and names of parameters and args, descriptions etc) + */ -val cliBuilder = OParser.builder[CliConfig] -val cliParser = { - import cliBuilder._ - OParser.sequence( - programName("wacc-compiler"), - help('h', "help") - .text("Prints this help message"), - arg[File]("") - .text("Input WACC source file") - .required() - .action((f, c) => c.copy(file = f)) - .validate(f => - if (!f.exists) failure("File does not exist") - else if (!f.isFile) failure("File must be a regular file") - else if (!f.getName.endsWith(".wacc")) - failure("File must have .wacc extension") - else success - ) - ) +private val SUCCESS = ExitCode.Success.code +private val ERROR = ExitCode.Error.code + +given logger: Logger[IO] = Slf4jLogger.getLogger[IO] + +val logOpt: Opts[Boolean] = + Opts.flag("log", "Enable logging for additional compilation details", short = "l").orFalse + +def validateFile(path: Path): ValidatedNel[String, Path] = { + (for { + // TODO: redundant 2nd parameter :( + _ <- Either.cond(Files.exists(path), (), s"File '${path}' does not exist") + _ <- Either.cond(Files.isRegularFile(path), (), s"File '${path}' must be a regular file") + _ <- Either.cond(path.toString.endsWith(".wacc"), (), "File must have .wacc extension") + } yield path).toValidatedNel } +val filesOpt: Opts[NonEmptyList[Path]] = + Opts.arguments[Path]("files").mapValidated { + _.traverse(validateFile) + } + +val outputOpt: Opts[Option[Path]] = + Opts + .option[Path]("output", metavar = "path", help = "Output directory for compiled files.") + .validate("Must have permissions to create & access the output path") { path => + try { + Files.createDirectories(path) + true + } catch { + case e: java.nio.file.AccessDeniedException => + false + } + } + .validate("Output path must be a directory") { path => + Files.isDirectory(path) + } + .orNone + def frontend( contents: String -)(using stdout: PrintStream): Either[microWacc.Program, Int] = { +): Either[NonEmptyList[Error], microWacc.Program] = parser.parse(contents) match { + case Failure(msg) => Left(NonEmptyList.one(Error.SyntaxError(msg))) case Success(prog) => given errors: mutable.Builder[Error, List[Error]] = List.newBuilder + val (names, funcs) = renamer.rename(prog) given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors) val typedProg = typeChecker.check(prog) - if (errors.result.nonEmpty) { - given errorContent: String = contents - Right( - errors.result - .map { error => - printError(error) - error match { - case _: Error.InternalError => 201 - case _ => 200 - } - } - .max() - ) - } else Left(typedProg) - case Failure(msg) => - stdout.println(msg) - Right(100) - } -} -val s = "enter an integer to echo" + NonEmptyList.fromList(errors.result) match { + case Some(errors) => Left(errors) + case None => Right(typedProg) + } + } + def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] = asmGenerator.generateAsm(typedProg) -def compile(filename: String, outFile: Option[File] = None)(using - stdout: PrintStream = Console.out -): Int = - frontend(os.read(os.Path(filename))) match { - case Left(typedProg) => - val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s")) - val asm = backend(typedProg) - writer.writeTo(asm, PrintStream(asmFile)) - 0 - case Right(exitCode) => exitCode - } +def compile( + filePath: Path, + outputDir: Option[Path], + log: Boolean +): IO[Int] = { + val logAction: String => IO[Unit] = + if (log) logger.info(_) + else (_ => IO.unit) -def main(args: Array[String]): Unit = - OParser.parse(cliParser, args, CliConfig()) match { - case Some(config) => - System.exit( - compile( - config.file.getAbsolutePath, - outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s")) - ) - ) - case None => - } + def readSourceFile: IO[String] = + IO.blocking(os.read(os.Path(filePath))) + + // TODO: path, file , the names are confusing (when Path is the type but we are working with files) + def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] = + writer.writeTo(backend(typedProg), outputPath) *> + logger.info(s"Success: ${outputPath.toAbsolutePath}") + + def processProgram(contents: String, outDir: Path): IO[Int] = + frontend(contents) match { + case Left(errors) => + val code = errors.map(err => err.exitCode).toList.min + given errorContent: String = contents + val errorMsg = errors.map(formatError).toIterable.mkString("\n") + for { + _ <- logAction(s"Compilation failed for $filePath\nExit code: $code") + _ <- IO.blocking( + // Explicit println since we want this to always show without logger thread info e.t.c. + println(s"Compilation failed for ${filePath.toAbsolutePath}:\n$errorMsg") + ) + } yield code + + case Right(typedProg) => + val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s") + writeOutputFile(typedProg, outputFile).as(SUCCESS) + } + + for { + contents <- readSourceFile + _ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}") + exitCode <- processProgram(contents, outputDir.getOrElse(filePath.getParent)) + } yield exitCode +} + +def compileCommandParallel( + files: NonEmptyList[Path], + log: Boolean, + outDir: Option[Path] +): IO[ExitCode] = + files + .parTraverse { file => compile(file.toAbsolutePath, outDir, log) } + .map { exitCodes => + exitCodes.filter(_ != 0) match { + case Nil => ExitCode.Success + case errorCodes => ExitCode(errorCodes.min) + } + } + +object Main + extends CommandIOApp( + name = "wacc", + header = "The ultimate WACC compiler", + version = "1.0" + ) { + def main: Opts[IO[ExitCode]] = + (filesOpt, logOpt, outputOpt).mapN { (files, log, outDir) => + compileCommandParallel(files, log, outDir) + } + +} diff --git a/src/main/wacc/backend/writer.scala b/src/main/wacc/backend/writer.scala index 3c8dcfd..a339f55 100644 --- a/src/main/wacc/backend/writer.scala +++ b/src/main/wacc/backend/writer.scala @@ -1,12 +1,42 @@ package wacc -import java.io.PrintStream +import cats.effect.Resource +import java.nio.charset.StandardCharsets +import java.io.BufferedWriter +import java.io.FileWriter import cats.data.Chain +import cats.effect.IO + +import org.typelevel.log4cats.Logger +import java.nio.file.Path object writer { import assemblyIR._ - def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = { - asmList.iterator.foreach(printStream.println) - } + // TODO: Judging from documentation it seems as though IO.blocking is the correct choice + // But needs checking + + /** Creates a resource safe BufferedWriter */ + private def bufferedWriter(outputPath: Path): Resource[IO, BufferedWriter] = + Resource.make { + IO.blocking(new BufferedWriter(new FileWriter(outputPath.toFile, StandardCharsets.UTF_8))) + } { writer => + IO.blocking(writer.close()) + .handleErrorWith(_ => IO.unit) // TODO: ensures writer is closed even if an error occurs + } + + /** Write line safely into a BufferedWriter */ + private def writeLines(writer: BufferedWriter, lines: Chain[AsmLine]): IO[Unit] = + IO.blocking { + lines.iterator.foreach { line => + writer.write(line.toString) + writer.newLine() + } + } + + /** Main function to write assembly to a file */ + def writeTo(asmList: Chain[AsmLine], outputPath: Path)(using logger: Logger[IO]): IO[Unit] = + bufferedWriter(outputPath).use { + writeLines(_, asmList) + } } diff --git a/src/main/wacc/frontend/Error.scala b/src/main/wacc/frontend/Error.scala index 9c02a60..e515494 100644 --- a/src/main/wacc/frontend/Error.scala +++ b/src/main/wacc/frontend/Error.scala @@ -2,7 +2,9 @@ package wacc import wacc.ast.Position import wacc.types._ -import java.io.PrintStream + +private val SYNTAX_ERROR = 100 +private val SEMANTIC_ERROR = 200 /** Error types for semantic errors */ @@ -15,6 +17,15 @@ enum Error { case SemanticError(pos: Position, msg: String) case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String) case InternalError(pos: Position, msg: String) + + case SyntaxError(msg: String) +} + +extension (e: Error) { + def exitCode: Int = e match { + case Error.SyntaxError(_) => SYNTAX_ERROR + case _ => SEMANTIC_ERROR + } } /** Function to handle printing the details of a given semantic error @@ -24,71 +35,81 @@ enum Error { * @param errorContent * Contents of the file to generate code snippets */ -def printError(error: Error)(using errorContent: String, stdout: PrintStream): Unit = { - stdout.println("Semantic error:") - error match { - case Error.DuplicateDeclaration(ident) => - printPosition(ident.pos) - stdout.println(s"Duplicate declaration of identifier ${ident.v}") - highlight(ident.pos, ident.v.length) - case Error.UndeclaredVariable(ident) => - printPosition(ident.pos) - stdout.println(s"Undeclared variable ${ident.v}") - highlight(ident.pos, ident.v.length) - case Error.UndefinedFunction(ident) => - printPosition(ident.pos) - stdout.println(s"Undefined function ${ident.v}") - highlight(ident.pos, ident.v.length) - case Error.FunctionParamsMismatch(id, expected, got, funcType) => - printPosition(id.pos) - stdout.println(s"Function expects $expected parameters, got $got") - stdout.println( - s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})" - ) - highlight(id.pos, 1) - case Error.TypeMismatch(pos, expected, got, msg) => - printPosition(pos) - stdout.println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got") - highlight(pos, 1) - case Error.SemanticError(pos, msg) => - printPosition(pos) - stdout.println(msg) - highlight(pos, 1) - case wacc.Error.InternalError(pos, msg) => - printPosition(pos) - stdout.println(s"Internal error: $msg") - highlight(pos, 1) +def formatError(error: Error)(using errorContent: String): String = { + val sb = new StringBuilder() + + /** Function to format the position of an error + * + * @param pos + * Position of the error + */ + def formatPosition(pos: Position): Unit = { + sb.append(s"(line ${pos.line}, column ${pos.column}):\n") } -} - -/** Function to highlight a section of code for an error message - * - * @param pos - * Position of the error - * @param size - * Size(in chars) of section to highlight - * @param errorContent - * Contents of the file to generate code snippets - */ -def highlight(pos: Position, size: Int)(using errorContent: String, stdout: PrintStream): Unit = { - val lines = errorContent.split("\n") - - val preLine = if (pos.line > 1) lines(pos.line - 2) else "" - val midLine = lines(pos.line - 1) - val postLine = if (pos.line < lines.size) lines(pos.line) else "" - val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n" - - stdout.println( - s" >$preLine\n >$midLine\n$linePointer >$postLine" - ) -} - -/** Function to print the position of an error - * - * @param pos - * Position of the error - */ -def printPosition(pos: Position)(using stdout: PrintStream): Unit = { - stdout.println(s"(line ${pos.line}, column ${pos.column}):") + /** Function to highlight a section of code for an error message + * + * @param pos + * Position of the error + * @param size + * Size(in chars) of section to highlight + */ + def formatHighlight(pos: Position, size: Int): Unit = { + val lines = errorContent.split("\n") + val preLine = if (pos.line > 1) lines(pos.line - 2) else "" + val midLine = lines(pos.line - 1) + val postLine = if (pos.line < lines.size) lines(pos.line) else "" + val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n" + + sb.append( + s" >$preLine\n >$midLine\n$linePointer >$postLine\netscape" + ) + } + + error match { + case Error.SyntaxError(_) => + sb.append("Syntax error:\n") + case _ => + sb.append("Semantic error:\n") + } + + error match { + case Error.DuplicateDeclaration(ident) => + formatPosition(ident.pos) + sb.append(s"Duplicate declaration of identifier ${ident.v}") + formatHighlight(ident.pos, ident.v.length) + case Error.UndeclaredVariable(ident) => + formatPosition(ident.pos) + sb.append(s"Undeclared variable ${ident.v}") + formatHighlight(ident.pos, ident.v.length) + case Error.UndefinedFunction(ident) => + formatPosition(ident.pos) + sb.append(s"Undefined function ${ident.v}") + formatHighlight(ident.pos, ident.v.length) + case Error.FunctionParamsMismatch(id, expected, got, funcType) => + formatPosition(id.pos) + sb.append(s"Function expects $expected parameters, got $got") + sb.append( + s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})" + ) + formatHighlight(id.pos, 1) + case Error.TypeMismatch(pos, expected, got, msg) => + formatPosition(pos) + sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got") + formatHighlight(pos, 1) + case Error.SemanticError(pos, msg) => + formatPosition(pos) + sb.append(msg) + formatHighlight(pos, 1) + case wacc.Error.InternalError(pos, msg) => + formatPosition(pos) + sb.append(s"Internal error: $msg") + formatHighlight(pos, 1) + case Error.SyntaxError(msg) => + sb.append(msg) + sb.append("\n") + } + + sb.toString() + } diff --git a/src/main/wacc/frontend/ast.scala b/src/main/wacc/frontend/ast.scala index ac0e585..9b14b13 100644 --- a/src/main/wacc/frontend/ast.scala +++ b/src/main/wacc/frontend/ast.scala @@ -4,7 +4,6 @@ import parsley.Parsley import parsley.generic.ErrorBridge import parsley.ap._ import parsley.position._ -import parsley.syntax.zipped._ import cats.data.NonEmptyList object ast { diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 6114afd..11093d6 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -1,14 +1,19 @@ package wacc import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.Inspectors.forEvery +import org.scalatest.matchers.should.Matchers._ +import org.scalatest.freespec.AsyncFreeSpec +import cats.effect.testing.scalatest.AsyncIOSpec import java.io.File +import java.nio.file.Path import sys.process._ -import java.io.PrintStream import scala.io.Source +import cats.effect.IO +import wacc.{compile as compileWacc} + +class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAndAfterAll { -class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { val files = allWaccFiles("wacc-examples/valid").map { p => (p.toString, List(0)) @@ -23,95 +28,95 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { (p.toString, List(100, 200)) } - // tests go here forEvery(files) { (filename, expectedResult) => val baseFilename = filename.stripSuffix(".wacc") - given stdout: PrintStream = PrintStream(File(baseFilename + ".out")) - s"$filename" should "be compiled with correct result" in { - val result = compile(filename) - assert(expectedResult.contains(result)) - } + s"$filename" - { + "should be compiled with correct result" in { + compileWacc(Path.of(filename), outputDir = None, log = false).map { result => + expectedResult should contain(result) + } + } - if (expectedResult == List(0)) it should "run with correct result" in { - if (fileIsDisallowedBackend(filename)) pending + if (expectedResult == List(0)) { + "should run with correct result" in { + if (fileIsDisallowedBackend(filename)) + IO.pure( + succeed + ) // TODO: remove when advanced tests removed. not sure how to "pending" this otherwise + else { + for { + contents <- IO(Source.fromFile(File(filename)).getLines.toList) + inputLine = extractInput(contents) + expectedOutput = extractOutput(contents) + expectedExit = extractExit(contents) - // Retrieve contents to get input and expected output + exit code - val contents = scala.io.Source.fromFile(File(filename)).getLines.toList - val inputLine = - contents - .find(_.matches("^# ?[Ii]nput:.*$")) - .map(_.split(":").last.strip + "\n") - .getOrElse("") - val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) - val expectedOutput = - if (outputLineIdx == -1) "" - else - contents - .drop(outputLineIdx + 1) - .takeWhile(_.startsWith("#")) - .map(_.stripPrefix("#").stripLeading) - .mkString("\n") + asmFilename = baseFilename + ".s" + execFilename = baseFilename + gccResult <- IO(s"gcc -o $execFilename -z noexecstack $asmFilename".!) - val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$")) - val expectedExit = - if (exitLineIdx == -1) 0 - else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt + _ = assert(gccResult == 0) - // Assembly and link using gcc - val asmFilename = baseFilename + ".s" - val execFilename = baseFilename - val gccResult = s"gcc -o $execFilename -z noexecstack $asmFilename".! - assert(gccResult == 0) + stdout <- IO.pure(new StringBuilder) + process <- IO { + s"timeout 5s $execFilename" run ProcessIO( + in = w => { + w.write(inputLine.getBytes) + w.close() + }, + out = Source.fromInputStream(_).addString(stdout), + err = _ => () + ) + } - // Run the executable with the provided input - val stdout = new StringBuilder - val process = s"timeout 5s $execFilename" run ProcessIO( - in = w => { - w.write(inputLine.getBytes) - w.close() - }, - out = Source.fromInputStream(_).addString(stdout), - err = _ => () - ) + exitCode <- IO.pure(process.exitValue) - assert(process.exitValue == expectedExit) - assert( - stdout.toString - .replaceAll("0x[0-9a-f]+", "#addrs#") - .replaceAll("fatal error:.*", "#runtime_error#\u0000") - .takeWhile(_ != '\u0000') - == expectedOutput - ) + } yield { + exitCode shouldBe expectedExit + normalizeOutput(stdout.toString) shouldBe expectedOutput + } + } + } + } } } def allWaccFiles(dir: String): IndexedSeq[os.Path] = val d = java.io.File(dir) - os.walk(os.Path(d.getAbsolutePath)).filter { _.ext == "wacc" } + os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc") + // TODO: eventually remove this I think def fileIsDisallowedBackend(filename: String): Boolean = Seq( - // format: off - // disable formatting to avoid binPack - "^.*wacc-examples/valid/advanced.*$", - // "^.*wacc-examples/valid/array.*$", - // "^.*wacc-examples/valid/basic/exit.*$", - // "^.*wacc-examples/valid/basic/skip.*$", - // "^.*wacc-examples/valid/expressions.*$", - // "^.*wacc-examples/valid/function/nested_functions.*$", - // "^.*wacc-examples/valid/function/simple_functions.*$", - // "^.*wacc-examples/valid/if.*$", - // "^.*wacc-examples/valid/IO/print.*$", - // "^.*wacc-examples/valid/IO/read.*$", - // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", - // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", - // "^.*wacc-examples/valid/pairs.*$", - //"^.*wacc-examples/valid/runtimeErr.*$", - // "^.*wacc-examples/valid/scope.*$", - // "^.*wacc-examples/valid/sequence.*$", - // "^.*wacc-examples/valid/variables.*$", - // "^.*wacc-examples/valid/while.*$", - // format: on - ).find(filename.matches).isDefined + "^.*wacc-examples/valid/advanced.*$" + ).exists(filename.matches) + + private def extractInput(contents: List[String]): String = + contents + .find(_.matches("^# ?[Ii]nput:.*$")) + .map(_.split(":").last.strip + "\n") + .getOrElse("") + + private def extractOutput(contents: List[String]): String = { + val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) + if (outputLineIdx == -1) "" + else + contents + .drop(outputLineIdx + 1) + .takeWhile(_.startsWith("#")) + .map(_.stripPrefix("#").stripLeading) + .mkString("\n") + } + + private def extractExit(contents: List[String]): Int = { + val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$")) + if (exitLineIdx == -1) 0 + else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt + } + + private def normalizeOutput(output: String): String = + output + .replaceAll("0x[0-9a-f]+", "#addrs#") + .replaceAll("fatal error:.*", "#runtime_error#\u0000") + .takeWhile(_ != '\u0000') }