feat: parallelised compilation
Merge request lab2425_spring/WACC_37!39 Co-authored-by: Jonny <j.sinteix@gmail.com>
This commit is contained in:
2
compile
2
compile
@@ -4,6 +4,6 @@
|
||||
# but do *not* change its name.
|
||||
|
||||
# feel free to adjust to suit the specific internal flags of your compiler
|
||||
./wacc-compiler "$@"
|
||||
./wacc-compiler --output . "$@"
|
||||
|
||||
exit $?
|
||||
|
||||
@@ -5,23 +5,18 @@
|
||||
//> using dep com.github.j-mie6::parsley::5.0.0-M10
|
||||
//> using dep com.github.j-mie6::parsley-cats::1.5.0
|
||||
//> using dep com.lihaoyi::os-lib::0.11.4
|
||||
//> using dep com.github.scopt::scopt::4.1.0
|
||||
//> using dep org.typelevel::cats-core::2.13.0
|
||||
//> using dep org.typelevel::cats-effect::3.5.7
|
||||
//> using dep com.monovore::decline::2.5.0
|
||||
//> using dep com.monovore::decline-effect::2.5.0
|
||||
//> using dep org.typelevel::log4cats-slf4j::2.7.0
|
||||
//> using dep org.slf4j:slf4j-simple:2.0.17
|
||||
//> using test.dep org.scalatest::scalatest::3.2.19
|
||||
//> using dep org.typelevel::cats-effect-testing-scalatest::1.6.0
|
||||
|
||||
// these are all sensible defaults to catch annoying issues
|
||||
// sensible defaults for warnings and compiler checks
|
||||
//> using options -deprecation -unchecked -feature
|
||||
//> using options -Wimplausible-patterns -Wunused:all
|
||||
//> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores
|
||||
|
||||
// these will help ensure you have access to the latest parsley releases
|
||||
// even before they land on maven proper, or snapshot versions, if necessary.
|
||||
// just in case they cause problems, however, keep them turned off unless you
|
||||
// specifically need them.
|
||||
// using repositories sonatype-s01:releases
|
||||
// using repositories sonatype-s01:snapshots
|
||||
|
||||
// these are flags used by Scala native: if you aren't using scala-native, then they do nothing
|
||||
// lto-thin has decent linking times, and release-fast does not too much optimisation.
|
||||
// using nativeLto thin
|
||||
// using nativeGc commix
|
||||
// using nativeMode release-fast
|
||||
// repositories for pre-release versions if needed
|
||||
|
||||
@@ -1,92 +1,159 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
import cats.data.Chain
|
||||
import cats.data.{Chain, NonEmptyList}
|
||||
import parsley.{Failure, Success}
|
||||
import scopt.OParser
|
||||
import java.io.File
|
||||
import java.io.PrintStream
|
||||
|
||||
import java.nio.file.{Files, Path}
|
||||
import cats.syntax.all._
|
||||
|
||||
import cats.effect.IO
|
||||
import cats.effect.ExitCode
|
||||
|
||||
import com.monovore.decline._
|
||||
import com.monovore.decline.effect._
|
||||
|
||||
import org.typelevel.log4cats.slf4j.Slf4jLogger
|
||||
import org.typelevel.log4cats.Logger
|
||||
|
||||
import assemblyIR as asm
|
||||
import cats.data.ValidatedNel
|
||||
|
||||
case class CliConfig(
|
||||
file: File = new File(".")
|
||||
)
|
||||
/*
|
||||
TODO:
|
||||
1) IO correctness
|
||||
2) Errors can be handled more gracefully - currently, parallelised compilation is not fail fast as far as I am aware
|
||||
3) splitting the file up and nicer refactoring
|
||||
4) logging could be removed
|
||||
5) general cleanup and comments (things like replacing home/<user> with ~ , and names of parameters and args, descriptions etc)
|
||||
*/
|
||||
|
||||
val cliBuilder = OParser.builder[CliConfig]
|
||||
val cliParser = {
|
||||
import cliBuilder._
|
||||
OParser.sequence(
|
||||
programName("wacc-compiler"),
|
||||
help('h', "help")
|
||||
.text("Prints this help message"),
|
||||
arg[File]("<file>")
|
||||
.text("Input WACC source file")
|
||||
.required()
|
||||
.action((f, c) => c.copy(file = f))
|
||||
.validate(f =>
|
||||
if (!f.exists) failure("File does not exist")
|
||||
else if (!f.isFile) failure("File must be a regular file")
|
||||
else if (!f.getName.endsWith(".wacc"))
|
||||
failure("File must have .wacc extension")
|
||||
else success
|
||||
)
|
||||
)
|
||||
private val SUCCESS = ExitCode.Success.code
|
||||
private val ERROR = ExitCode.Error.code
|
||||
|
||||
given logger: Logger[IO] = Slf4jLogger.getLogger[IO]
|
||||
|
||||
val logOpt: Opts[Boolean] =
|
||||
Opts.flag("log", "Enable logging for additional compilation details", short = "l").orFalse
|
||||
|
||||
def validateFile(path: Path): ValidatedNel[String, Path] = {
|
||||
(for {
|
||||
// TODO: redundant 2nd parameter :(
|
||||
_ <- Either.cond(Files.exists(path), (), s"File '${path}' does not exist")
|
||||
_ <- Either.cond(Files.isRegularFile(path), (), s"File '${path}' must be a regular file")
|
||||
_ <- Either.cond(path.toString.endsWith(".wacc"), (), "File must have .wacc extension")
|
||||
} yield path).toValidatedNel
|
||||
}
|
||||
|
||||
val filesOpt: Opts[NonEmptyList[Path]] =
|
||||
Opts.arguments[Path]("files").mapValidated {
|
||||
_.traverse(validateFile)
|
||||
}
|
||||
|
||||
val outputOpt: Opts[Option[Path]] =
|
||||
Opts
|
||||
.option[Path]("output", metavar = "path", help = "Output directory for compiled files.")
|
||||
.validate("Must have permissions to create & access the output path") { path =>
|
||||
try {
|
||||
Files.createDirectories(path)
|
||||
true
|
||||
} catch {
|
||||
case e: java.nio.file.AccessDeniedException =>
|
||||
false
|
||||
}
|
||||
}
|
||||
.validate("Output path must be a directory") { path =>
|
||||
Files.isDirectory(path)
|
||||
}
|
||||
.orNone
|
||||
|
||||
def frontend(
|
||||
contents: String
|
||||
)(using stdout: PrintStream): Either[microWacc.Program, Int] = {
|
||||
): Either[NonEmptyList[Error], microWacc.Program] =
|
||||
parser.parse(contents) match {
|
||||
case Failure(msg) => Left(NonEmptyList.one(Error.SyntaxError(msg)))
|
||||
case Success(prog) =>
|
||||
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
|
||||
|
||||
val (names, funcs) = renamer.rename(prog)
|
||||
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
|
||||
val typedProg = typeChecker.check(prog)
|
||||
if (errors.result.nonEmpty) {
|
||||
given errorContent: String = contents
|
||||
Right(
|
||||
errors.result
|
||||
.map { error =>
|
||||
printError(error)
|
||||
error match {
|
||||
case _: Error.InternalError => 201
|
||||
case _ => 200
|
||||
}
|
||||
}
|
||||
.max()
|
||||
)
|
||||
} else Left(typedProg)
|
||||
case Failure(msg) =>
|
||||
stdout.println(msg)
|
||||
Right(100)
|
||||
}
|
||||
}
|
||||
|
||||
val s = "enter an integer to echo"
|
||||
NonEmptyList.fromList(errors.result) match {
|
||||
case Some(errors) => Left(errors)
|
||||
case None => Right(typedProg)
|
||||
}
|
||||
}
|
||||
|
||||
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
|
||||
asmGenerator.generateAsm(typedProg)
|
||||
|
||||
def compile(filename: String, outFile: Option[File] = None)(using
|
||||
stdout: PrintStream = Console.out
|
||||
): Int =
|
||||
frontend(os.read(os.Path(filename))) match {
|
||||
case Left(typedProg) =>
|
||||
val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s"))
|
||||
val asm = backend(typedProg)
|
||||
writer.writeTo(asm, PrintStream(asmFile))
|
||||
0
|
||||
case Right(exitCode) => exitCode
|
||||
}
|
||||
def compile(
|
||||
filePath: Path,
|
||||
outputDir: Option[Path],
|
||||
log: Boolean
|
||||
): IO[Int] = {
|
||||
val logAction: String => IO[Unit] =
|
||||
if (log) logger.info(_)
|
||||
else (_ => IO.unit)
|
||||
|
||||
def main(args: Array[String]): Unit =
|
||||
OParser.parse(cliParser, args, CliConfig()) match {
|
||||
case Some(config) =>
|
||||
System.exit(
|
||||
compile(
|
||||
config.file.getAbsolutePath,
|
||||
outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s"))
|
||||
)
|
||||
)
|
||||
case None =>
|
||||
}
|
||||
def readSourceFile: IO[String] =
|
||||
IO.blocking(os.read(os.Path(filePath)))
|
||||
|
||||
// TODO: path, file , the names are confusing (when Path is the type but we are working with files)
|
||||
def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] =
|
||||
writer.writeTo(backend(typedProg), outputPath) *>
|
||||
logger.info(s"Success: ${outputPath.toAbsolutePath}")
|
||||
|
||||
def processProgram(contents: String, outDir: Path): IO[Int] =
|
||||
frontend(contents) match {
|
||||
case Left(errors) =>
|
||||
val code = errors.map(err => err.exitCode).toList.min
|
||||
given errorContent: String = contents
|
||||
val errorMsg = errors.map(formatError).toIterable.mkString("\n")
|
||||
for {
|
||||
_ <- logAction(s"Compilation failed for $filePath\nExit code: $code")
|
||||
_ <- IO.blocking(
|
||||
// Explicit println since we want this to always show without logger thread info e.t.c.
|
||||
println(s"Compilation failed for ${filePath.toAbsolutePath}:\n$errorMsg")
|
||||
)
|
||||
} yield code
|
||||
|
||||
case Right(typedProg) =>
|
||||
val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s")
|
||||
writeOutputFile(typedProg, outputFile).as(SUCCESS)
|
||||
}
|
||||
|
||||
for {
|
||||
contents <- readSourceFile
|
||||
_ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}")
|
||||
exitCode <- processProgram(contents, outputDir.getOrElse(filePath.getParent))
|
||||
} yield exitCode
|
||||
}
|
||||
|
||||
def compileCommandParallel(
|
||||
files: NonEmptyList[Path],
|
||||
log: Boolean,
|
||||
outDir: Option[Path]
|
||||
): IO[ExitCode] =
|
||||
files
|
||||
.parTraverse { file => compile(file.toAbsolutePath, outDir, log) }
|
||||
.map { exitCodes =>
|
||||
exitCodes.filter(_ != 0) match {
|
||||
case Nil => ExitCode.Success
|
||||
case errorCodes => ExitCode(errorCodes.min)
|
||||
}
|
||||
}
|
||||
|
||||
object Main
|
||||
extends CommandIOApp(
|
||||
name = "wacc",
|
||||
header = "The ultimate WACC compiler",
|
||||
version = "1.0"
|
||||
) {
|
||||
def main: Opts[IO[ExitCode]] =
|
||||
(filesOpt, logOpt, outputOpt).mapN { (files, log, outDir) =>
|
||||
compileCommandParallel(files, log, outDir)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,42 @@
|
||||
package wacc
|
||||
|
||||
import java.io.PrintStream
|
||||
import cats.effect.Resource
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.io.BufferedWriter
|
||||
import java.io.FileWriter
|
||||
import cats.data.Chain
|
||||
import cats.effect.IO
|
||||
|
||||
import org.typelevel.log4cats.Logger
|
||||
import java.nio.file.Path
|
||||
|
||||
object writer {
|
||||
import assemblyIR._
|
||||
|
||||
def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
|
||||
asmList.iterator.foreach(printStream.println)
|
||||
}
|
||||
// TODO: Judging from documentation it seems as though IO.blocking is the correct choice
|
||||
// But needs checking
|
||||
|
||||
/** Creates a resource safe BufferedWriter */
|
||||
private def bufferedWriter(outputPath: Path): Resource[IO, BufferedWriter] =
|
||||
Resource.make {
|
||||
IO.blocking(new BufferedWriter(new FileWriter(outputPath.toFile, StandardCharsets.UTF_8)))
|
||||
} { writer =>
|
||||
IO.blocking(writer.close())
|
||||
.handleErrorWith(_ => IO.unit) // TODO: ensures writer is closed even if an error occurs
|
||||
}
|
||||
|
||||
/** Write line safely into a BufferedWriter */
|
||||
private def writeLines(writer: BufferedWriter, lines: Chain[AsmLine]): IO[Unit] =
|
||||
IO.blocking {
|
||||
lines.iterator.foreach { line =>
|
||||
writer.write(line.toString)
|
||||
writer.newLine()
|
||||
}
|
||||
}
|
||||
|
||||
/** Main function to write assembly to a file */
|
||||
def writeTo(asmList: Chain[AsmLine], outputPath: Path)(using logger: Logger[IO]): IO[Unit] =
|
||||
bufferedWriter(outputPath).use {
|
||||
writeLines(_, asmList)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package wacc
|
||||
|
||||
import wacc.ast.Position
|
||||
import wacc.types._
|
||||
import java.io.PrintStream
|
||||
|
||||
private val SYNTAX_ERROR = 100
|
||||
private val SEMANTIC_ERROR = 200
|
||||
|
||||
/** Error types for semantic errors
|
||||
*/
|
||||
@@ -15,6 +17,15 @@ enum Error {
|
||||
case SemanticError(pos: Position, msg: String)
|
||||
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
|
||||
case InternalError(pos: Position, msg: String)
|
||||
|
||||
case SyntaxError(msg: String)
|
||||
}
|
||||
|
||||
extension (e: Error) {
|
||||
def exitCode: Int = e match {
|
||||
case Error.SyntaxError(_) => SYNTAX_ERROR
|
||||
case _ => SEMANTIC_ERROR
|
||||
}
|
||||
}
|
||||
|
||||
/** Function to handle printing the details of a given semantic error
|
||||
@@ -24,71 +35,81 @@ enum Error {
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def printError(error: Error)(using errorContent: String, stdout: PrintStream): Unit = {
|
||||
stdout.println("Semantic error:")
|
||||
error match {
|
||||
case Error.DuplicateDeclaration(ident) =>
|
||||
printPosition(ident.pos)
|
||||
stdout.println(s"Duplicate declaration of identifier ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndeclaredVariable(ident) =>
|
||||
printPosition(ident.pos)
|
||||
stdout.println(s"Undeclared variable ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedFunction(ident) =>
|
||||
printPosition(ident.pos)
|
||||
stdout.println(s"Undefined function ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
|
||||
printPosition(id.pos)
|
||||
stdout.println(s"Function expects $expected parameters, got $got")
|
||||
stdout.println(
|
||||
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
|
||||
)
|
||||
highlight(id.pos, 1)
|
||||
case Error.TypeMismatch(pos, expected, got, msg) =>
|
||||
printPosition(pos)
|
||||
stdout.println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
|
||||
highlight(pos, 1)
|
||||
case Error.SemanticError(pos, msg) =>
|
||||
printPosition(pos)
|
||||
stdout.println(msg)
|
||||
highlight(pos, 1)
|
||||
case wacc.Error.InternalError(pos, msg) =>
|
||||
printPosition(pos)
|
||||
stdout.println(s"Internal error: $msg")
|
||||
highlight(pos, 1)
|
||||
def formatError(error: Error)(using errorContent: String): String = {
|
||||
val sb = new StringBuilder()
|
||||
|
||||
/** Function to format the position of an error
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
*/
|
||||
def formatPosition(pos: Position): Unit = {
|
||||
sb.append(s"(line ${pos.line}, column ${pos.column}):\n")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Function to highlight a section of code for an error message
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
* @param size
|
||||
* Size(in chars) of section to highlight
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def highlight(pos: Position, size: Int)(using errorContent: String, stdout: PrintStream): Unit = {
|
||||
val lines = errorContent.split("\n")
|
||||
|
||||
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
|
||||
val midLine = lines(pos.line - 1)
|
||||
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
|
||||
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
|
||||
|
||||
stdout.println(
|
||||
s" >$preLine\n >$midLine\n$linePointer >$postLine"
|
||||
)
|
||||
}
|
||||
|
||||
/** Function to print the position of an error
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
*/
|
||||
def printPosition(pos: Position)(using stdout: PrintStream): Unit = {
|
||||
stdout.println(s"(line ${pos.line}, column ${pos.column}):")
|
||||
/** Function to highlight a section of code for an error message
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
* @param size
|
||||
* Size(in chars) of section to highlight
|
||||
*/
|
||||
def formatHighlight(pos: Position, size: Int): Unit = {
|
||||
val lines = errorContent.split("\n")
|
||||
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
|
||||
val midLine = lines(pos.line - 1)
|
||||
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
|
||||
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
|
||||
|
||||
sb.append(
|
||||
s" >$preLine\n >$midLine\n$linePointer >$postLine\netscape"
|
||||
)
|
||||
}
|
||||
|
||||
error match {
|
||||
case Error.SyntaxError(_) =>
|
||||
sb.append("Syntax error:\n")
|
||||
case _ =>
|
||||
sb.append("Semantic error:\n")
|
||||
}
|
||||
|
||||
error match {
|
||||
case Error.DuplicateDeclaration(ident) =>
|
||||
formatPosition(ident.pos)
|
||||
sb.append(s"Duplicate declaration of identifier ${ident.v}")
|
||||
formatHighlight(ident.pos, ident.v.length)
|
||||
case Error.UndeclaredVariable(ident) =>
|
||||
formatPosition(ident.pos)
|
||||
sb.append(s"Undeclared variable ${ident.v}")
|
||||
formatHighlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedFunction(ident) =>
|
||||
formatPosition(ident.pos)
|
||||
sb.append(s"Undefined function ${ident.v}")
|
||||
formatHighlight(ident.pos, ident.v.length)
|
||||
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
|
||||
formatPosition(id.pos)
|
||||
sb.append(s"Function expects $expected parameters, got $got")
|
||||
sb.append(
|
||||
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
|
||||
)
|
||||
formatHighlight(id.pos, 1)
|
||||
case Error.TypeMismatch(pos, expected, got, msg) =>
|
||||
formatPosition(pos)
|
||||
sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
|
||||
formatHighlight(pos, 1)
|
||||
case Error.SemanticError(pos, msg) =>
|
||||
formatPosition(pos)
|
||||
sb.append(msg)
|
||||
formatHighlight(pos, 1)
|
||||
case wacc.Error.InternalError(pos, msg) =>
|
||||
formatPosition(pos)
|
||||
sb.append(s"Internal error: $msg")
|
||||
formatHighlight(pos, 1)
|
||||
case Error.SyntaxError(msg) =>
|
||||
sb.append(msg)
|
||||
sb.append("\n")
|
||||
}
|
||||
|
||||
sb.toString()
|
||||
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import parsley.Parsley
|
||||
import parsley.generic.ErrorBridge
|
||||
import parsley.ap._
|
||||
import parsley.position._
|
||||
import parsley.syntax.zipped._
|
||||
import cats.data.NonEmptyList
|
||||
|
||||
object ast {
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import org.scalatest.Inspectors.forEvery
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
import org.scalatest.freespec.AsyncFreeSpec
|
||||
import cats.effect.testing.scalatest.AsyncIOSpec
|
||||
import java.io.File
|
||||
import java.nio.file.Path
|
||||
import sys.process._
|
||||
import java.io.PrintStream
|
||||
import scala.io.Source
|
||||
import cats.effect.IO
|
||||
import wacc.{compile as compileWacc}
|
||||
|
||||
class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAndAfterAll {
|
||||
|
||||
class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
|
||||
val files =
|
||||
allWaccFiles("wacc-examples/valid").map { p =>
|
||||
(p.toString, List(0))
|
||||
@@ -23,95 +28,95 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
|
||||
(p.toString, List(100, 200))
|
||||
}
|
||||
|
||||
// tests go here
|
||||
forEvery(files) { (filename, expectedResult) =>
|
||||
val baseFilename = filename.stripSuffix(".wacc")
|
||||
given stdout: PrintStream = PrintStream(File(baseFilename + ".out"))
|
||||
|
||||
s"$filename" should "be compiled with correct result" in {
|
||||
val result = compile(filename)
|
||||
assert(expectedResult.contains(result))
|
||||
}
|
||||
s"$filename" - {
|
||||
"should be compiled with correct result" in {
|
||||
compileWacc(Path.of(filename), outputDir = None, log = false).map { result =>
|
||||
expectedResult should contain(result)
|
||||
}
|
||||
}
|
||||
|
||||
if (expectedResult == List(0)) it should "run with correct result" in {
|
||||
if (fileIsDisallowedBackend(filename)) pending
|
||||
if (expectedResult == List(0)) {
|
||||
"should run with correct result" in {
|
||||
if (fileIsDisallowedBackend(filename))
|
||||
IO.pure(
|
||||
succeed
|
||||
) // TODO: remove when advanced tests removed. not sure how to "pending" this otherwise
|
||||
else {
|
||||
for {
|
||||
contents <- IO(Source.fromFile(File(filename)).getLines.toList)
|
||||
inputLine = extractInput(contents)
|
||||
expectedOutput = extractOutput(contents)
|
||||
expectedExit = extractExit(contents)
|
||||
|
||||
// Retrieve contents to get input and expected output + exit code
|
||||
val contents = scala.io.Source.fromFile(File(filename)).getLines.toList
|
||||
val inputLine =
|
||||
contents
|
||||
.find(_.matches("^# ?[Ii]nput:.*$"))
|
||||
.map(_.split(":").last.strip + "\n")
|
||||
.getOrElse("")
|
||||
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
|
||||
val expectedOutput =
|
||||
if (outputLineIdx == -1) ""
|
||||
else
|
||||
contents
|
||||
.drop(outputLineIdx + 1)
|
||||
.takeWhile(_.startsWith("#"))
|
||||
.map(_.stripPrefix("#").stripLeading)
|
||||
.mkString("\n")
|
||||
asmFilename = baseFilename + ".s"
|
||||
execFilename = baseFilename
|
||||
gccResult <- IO(s"gcc -o $execFilename -z noexecstack $asmFilename".!)
|
||||
|
||||
val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$"))
|
||||
val expectedExit =
|
||||
if (exitLineIdx == -1) 0
|
||||
else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt
|
||||
_ = assert(gccResult == 0)
|
||||
|
||||
// Assembly and link using gcc
|
||||
val asmFilename = baseFilename + ".s"
|
||||
val execFilename = baseFilename
|
||||
val gccResult = s"gcc -o $execFilename -z noexecstack $asmFilename".!
|
||||
assert(gccResult == 0)
|
||||
stdout <- IO.pure(new StringBuilder)
|
||||
process <- IO {
|
||||
s"timeout 5s $execFilename" run ProcessIO(
|
||||
in = w => {
|
||||
w.write(inputLine.getBytes)
|
||||
w.close()
|
||||
},
|
||||
out = Source.fromInputStream(_).addString(stdout),
|
||||
err = _ => ()
|
||||
)
|
||||
}
|
||||
|
||||
// Run the executable with the provided input
|
||||
val stdout = new StringBuilder
|
||||
val process = s"timeout 5s $execFilename" run ProcessIO(
|
||||
in = w => {
|
||||
w.write(inputLine.getBytes)
|
||||
w.close()
|
||||
},
|
||||
out = Source.fromInputStream(_).addString(stdout),
|
||||
err = _ => ()
|
||||
)
|
||||
exitCode <- IO.pure(process.exitValue)
|
||||
|
||||
assert(process.exitValue == expectedExit)
|
||||
assert(
|
||||
stdout.toString
|
||||
.replaceAll("0x[0-9a-f]+", "#addrs#")
|
||||
.replaceAll("fatal error:.*", "#runtime_error#\u0000")
|
||||
.takeWhile(_ != '\u0000')
|
||||
== expectedOutput
|
||||
)
|
||||
} yield {
|
||||
exitCode shouldBe expectedExit
|
||||
normalizeOutput(stdout.toString) shouldBe expectedOutput
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def allWaccFiles(dir: String): IndexedSeq[os.Path] =
|
||||
val d = java.io.File(dir)
|
||||
os.walk(os.Path(d.getAbsolutePath)).filter { _.ext == "wacc" }
|
||||
os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc")
|
||||
|
||||
// TODO: eventually remove this I think
|
||||
def fileIsDisallowedBackend(filename: String): Boolean =
|
||||
Seq(
|
||||
// format: off
|
||||
// disable formatting to avoid binPack
|
||||
"^.*wacc-examples/valid/advanced.*$",
|
||||
// "^.*wacc-examples/valid/array.*$",
|
||||
// "^.*wacc-examples/valid/basic/exit.*$",
|
||||
// "^.*wacc-examples/valid/basic/skip.*$",
|
||||
// "^.*wacc-examples/valid/expressions.*$",
|
||||
// "^.*wacc-examples/valid/function/nested_functions.*$",
|
||||
// "^.*wacc-examples/valid/function/simple_functions.*$",
|
||||
// "^.*wacc-examples/valid/if.*$",
|
||||
// "^.*wacc-examples/valid/IO/print.*$",
|
||||
// "^.*wacc-examples/valid/IO/read.*$",
|
||||
// "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
|
||||
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
|
||||
// "^.*wacc-examples/valid/pairs.*$",
|
||||
//"^.*wacc-examples/valid/runtimeErr.*$",
|
||||
// "^.*wacc-examples/valid/scope.*$",
|
||||
// "^.*wacc-examples/valid/sequence.*$",
|
||||
// "^.*wacc-examples/valid/variables.*$",
|
||||
// "^.*wacc-examples/valid/while.*$",
|
||||
// format: on
|
||||
).find(filename.matches).isDefined
|
||||
"^.*wacc-examples/valid/advanced.*$"
|
||||
).exists(filename.matches)
|
||||
|
||||
private def extractInput(contents: List[String]): String =
|
||||
contents
|
||||
.find(_.matches("^# ?[Ii]nput:.*$"))
|
||||
.map(_.split(":").last.strip + "\n")
|
||||
.getOrElse("")
|
||||
|
||||
private def extractOutput(contents: List[String]): String = {
|
||||
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
|
||||
if (outputLineIdx == -1) ""
|
||||
else
|
||||
contents
|
||||
.drop(outputLineIdx + 1)
|
||||
.takeWhile(_.startsWith("#"))
|
||||
.map(_.stripPrefix("#").stripLeading)
|
||||
.mkString("\n")
|
||||
}
|
||||
|
||||
private def extractExit(contents: List[String]): Int = {
|
||||
val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$"))
|
||||
if (exitLineIdx == -1) 0
|
||||
else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt
|
||||
}
|
||||
|
||||
private def normalizeOutput(output: String): String =
|
||||
output
|
||||
.replaceAll("0x[0-9a-f]+", "#addrs#")
|
||||
.replaceAll("fatal error:.*", "#runtime_error#\u0000")
|
||||
.takeWhile(_ != '\u0000')
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user