feat: parallelised compilation

Merge request lab2425_spring/WACC_37!39

Co-authored-by: Jonny <j.sinteix@gmail.com>
This commit is contained in:
2025-03-13 01:08:58 +00:00
7 changed files with 349 additions and 232 deletions

View File

@@ -1,92 +1,159 @@
package wacc
import scala.collection.mutable
import cats.data.Chain
import cats.data.{Chain, NonEmptyList}
import parsley.{Failure, Success}
import scopt.OParser
import java.io.File
import java.io.PrintStream
import java.nio.file.{Files, Path}
import cats.syntax.all._
import cats.effect.IO
import cats.effect.ExitCode
import com.monovore.decline._
import com.monovore.decline.effect._
import org.typelevel.log4cats.slf4j.Slf4jLogger
import org.typelevel.log4cats.Logger
import assemblyIR as asm
import cats.data.ValidatedNel
case class CliConfig(
file: File = new File(".")
)
/*
TODO:
1) IO correctness
2) Errors can be handled more gracefully - currently, parallelised compilation is not fail fast as far as I am aware
3) splitting the file up and nicer refactoring
4) logging could be removed
5) general cleanup and comments (things like replacing home/<user> with ~ , and names of parameters and args, descriptions etc)
*/
val cliBuilder = OParser.builder[CliConfig]
val cliParser = {
import cliBuilder._
OParser.sequence(
programName("wacc-compiler"),
help('h', "help")
.text("Prints this help message"),
arg[File]("<file>")
.text("Input WACC source file")
.required()
.action((f, c) => c.copy(file = f))
.validate(f =>
if (!f.exists) failure("File does not exist")
else if (!f.isFile) failure("File must be a regular file")
else if (!f.getName.endsWith(".wacc"))
failure("File must have .wacc extension")
else success
)
)
private val SUCCESS = ExitCode.Success.code
private val ERROR = ExitCode.Error.code
given logger: Logger[IO] = Slf4jLogger.getLogger[IO]
val logOpt: Opts[Boolean] =
Opts.flag("log", "Enable logging for additional compilation details", short = "l").orFalse
def validateFile(path: Path): ValidatedNel[String, Path] = {
(for {
// TODO: redundant 2nd parameter :(
_ <- Either.cond(Files.exists(path), (), s"File '${path}' does not exist")
_ <- Either.cond(Files.isRegularFile(path), (), s"File '${path}' must be a regular file")
_ <- Either.cond(path.toString.endsWith(".wacc"), (), "File must have .wacc extension")
} yield path).toValidatedNel
}
val filesOpt: Opts[NonEmptyList[Path]] =
Opts.arguments[Path]("files").mapValidated {
_.traverse(validateFile)
}
val outputOpt: Opts[Option[Path]] =
Opts
.option[Path]("output", metavar = "path", help = "Output directory for compiled files.")
.validate("Must have permissions to create & access the output path") { path =>
try {
Files.createDirectories(path)
true
} catch {
case e: java.nio.file.AccessDeniedException =>
false
}
}
.validate("Output path must be a directory") { path =>
Files.isDirectory(path)
}
.orNone
def frontend(
contents: String
)(using stdout: PrintStream): Either[microWacc.Program, Int] = {
): Either[NonEmptyList[Error], microWacc.Program] =
parser.parse(contents) match {
case Failure(msg) => Left(NonEmptyList.one(Error.SyntaxError(msg)))
case Success(prog) =>
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
val (names, funcs) = renamer.rename(prog)
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
val typedProg = typeChecker.check(prog)
if (errors.result.nonEmpty) {
given errorContent: String = contents
Right(
errors.result
.map { error =>
printError(error)
error match {
case _: Error.InternalError => 201
case _ => 200
}
}
.max()
)
} else Left(typedProg)
case Failure(msg) =>
stdout.println(msg)
Right(100)
}
}
val s = "enter an integer to echo"
NonEmptyList.fromList(errors.result) match {
case Some(errors) => Left(errors)
case None => Right(typedProg)
}
}
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using
stdout: PrintStream = Console.out
): Int =
frontend(os.read(os.Path(filename))) match {
case Left(typedProg) =>
val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s"))
val asm = backend(typedProg)
writer.writeTo(asm, PrintStream(asmFile))
0
case Right(exitCode) => exitCode
}
def compile(
filePath: Path,
outputDir: Option[Path],
log: Boolean
): IO[Int] = {
val logAction: String => IO[Unit] =
if (log) logger.info(_)
else (_ => IO.unit)
def main(args: Array[String]): Unit =
OParser.parse(cliParser, args, CliConfig()) match {
case Some(config) =>
System.exit(
compile(
config.file.getAbsolutePath,
outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s"))
)
)
case None =>
}
def readSourceFile: IO[String] =
IO.blocking(os.read(os.Path(filePath)))
// TODO: path, file , the names are confusing (when Path is the type but we are working with files)
def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] =
writer.writeTo(backend(typedProg), outputPath) *>
logger.info(s"Success: ${outputPath.toAbsolutePath}")
def processProgram(contents: String, outDir: Path): IO[Int] =
frontend(contents) match {
case Left(errors) =>
val code = errors.map(err => err.exitCode).toList.min
given errorContent: String = contents
val errorMsg = errors.map(formatError).toIterable.mkString("\n")
for {
_ <- logAction(s"Compilation failed for $filePath\nExit code: $code")
_ <- IO.blocking(
// Explicit println since we want this to always show without logger thread info e.t.c.
println(s"Compilation failed for ${filePath.toAbsolutePath}:\n$errorMsg")
)
} yield code
case Right(typedProg) =>
val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s")
writeOutputFile(typedProg, outputFile).as(SUCCESS)
}
for {
contents <- readSourceFile
_ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}")
exitCode <- processProgram(contents, outputDir.getOrElse(filePath.getParent))
} yield exitCode
}
def compileCommandParallel(
files: NonEmptyList[Path],
log: Boolean,
outDir: Option[Path]
): IO[ExitCode] =
files
.parTraverse { file => compile(file.toAbsolutePath, outDir, log) }
.map { exitCodes =>
exitCodes.filter(_ != 0) match {
case Nil => ExitCode.Success
case errorCodes => ExitCode(errorCodes.min)
}
}
object Main
extends CommandIOApp(
name = "wacc",
header = "The ultimate WACC compiler",
version = "1.0"
) {
def main: Opts[IO[ExitCode]] =
(filesOpt, logOpt, outputOpt).mapN { (files, log, outDir) =>
compileCommandParallel(files, log, outDir)
}
}

View File

@@ -1,12 +1,42 @@
package wacc
import java.io.PrintStream
import cats.effect.Resource
import java.nio.charset.StandardCharsets
import java.io.BufferedWriter
import java.io.FileWriter
import cats.data.Chain
import cats.effect.IO
import org.typelevel.log4cats.Logger
import java.nio.file.Path
object writer {
import assemblyIR._
def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
asmList.iterator.foreach(printStream.println)
}
// TODO: Judging from documentation it seems as though IO.blocking is the correct choice
// But needs checking
/** Creates a resource safe BufferedWriter */
private def bufferedWriter(outputPath: Path): Resource[IO, BufferedWriter] =
Resource.make {
IO.blocking(new BufferedWriter(new FileWriter(outputPath.toFile, StandardCharsets.UTF_8)))
} { writer =>
IO.blocking(writer.close())
.handleErrorWith(_ => IO.unit) // TODO: ensures writer is closed even if an error occurs
}
/** Write line safely into a BufferedWriter */
private def writeLines(writer: BufferedWriter, lines: Chain[AsmLine]): IO[Unit] =
IO.blocking {
lines.iterator.foreach { line =>
writer.write(line.toString)
writer.newLine()
}
}
/** Main function to write assembly to a file */
def writeTo(asmList: Chain[AsmLine], outputPath: Path)(using logger: Logger[IO]): IO[Unit] =
bufferedWriter(outputPath).use {
writeLines(_, asmList)
}
}

View File

@@ -2,7 +2,9 @@ package wacc
import wacc.ast.Position
import wacc.types._
import java.io.PrintStream
private val SYNTAX_ERROR = 100
private val SEMANTIC_ERROR = 200
/** Error types for semantic errors
*/
@@ -15,6 +17,15 @@ enum Error {
case SemanticError(pos: Position, msg: String)
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
case InternalError(pos: Position, msg: String)
case SyntaxError(msg: String)
}
extension (e: Error) {
def exitCode: Int = e match {
case Error.SyntaxError(_) => SYNTAX_ERROR
case _ => SEMANTIC_ERROR
}
}
/** Function to handle printing the details of a given semantic error
@@ -24,71 +35,81 @@ enum Error {
* @param errorContent
* Contents of the file to generate code snippets
*/
def printError(error: Error)(using errorContent: String, stdout: PrintStream): Unit = {
stdout.println("Semantic error:")
error match {
case Error.DuplicateDeclaration(ident) =>
printPosition(ident.pos)
stdout.println(s"Duplicate declaration of identifier ${ident.v}")
highlight(ident.pos, ident.v.length)
case Error.UndeclaredVariable(ident) =>
printPosition(ident.pos)
stdout.println(s"Undeclared variable ${ident.v}")
highlight(ident.pos, ident.v.length)
case Error.UndefinedFunction(ident) =>
printPosition(ident.pos)
stdout.println(s"Undefined function ${ident.v}")
highlight(ident.pos, ident.v.length)
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
printPosition(id.pos)
stdout.println(s"Function expects $expected parameters, got $got")
stdout.println(
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
)
highlight(id.pos, 1)
case Error.TypeMismatch(pos, expected, got, msg) =>
printPosition(pos)
stdout.println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
highlight(pos, 1)
case Error.SemanticError(pos, msg) =>
printPosition(pos)
stdout.println(msg)
highlight(pos, 1)
case wacc.Error.InternalError(pos, msg) =>
printPosition(pos)
stdout.println(s"Internal error: $msg")
highlight(pos, 1)
def formatError(error: Error)(using errorContent: String): String = {
val sb = new StringBuilder()
/** Function to format the position of an error
*
* @param pos
* Position of the error
*/
def formatPosition(pos: Position): Unit = {
sb.append(s"(line ${pos.line}, column ${pos.column}):\n")
}
}
/** Function to highlight a section of code for an error message
*
* @param pos
* Position of the error
* @param size
* Size(in chars) of section to highlight
* @param errorContent
* Contents of the file to generate code snippets
*/
def highlight(pos: Position, size: Int)(using errorContent: String, stdout: PrintStream): Unit = {
val lines = errorContent.split("\n")
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
val midLine = lines(pos.line - 1)
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
stdout.println(
s" >$preLine\n >$midLine\n$linePointer >$postLine"
)
}
/** Function to print the position of an error
*
* @param pos
* Position of the error
*/
def printPosition(pos: Position)(using stdout: PrintStream): Unit = {
stdout.println(s"(line ${pos.line}, column ${pos.column}):")
/** Function to highlight a section of code for an error message
*
* @param pos
* Position of the error
* @param size
* Size(in chars) of section to highlight
*/
def formatHighlight(pos: Position, size: Int): Unit = {
val lines = errorContent.split("\n")
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
val midLine = lines(pos.line - 1)
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
sb.append(
s" >$preLine\n >$midLine\n$linePointer >$postLine\netscape"
)
}
error match {
case Error.SyntaxError(_) =>
sb.append("Syntax error:\n")
case _ =>
sb.append("Semantic error:\n")
}
error match {
case Error.DuplicateDeclaration(ident) =>
formatPosition(ident.pos)
sb.append(s"Duplicate declaration of identifier ${ident.v}")
formatHighlight(ident.pos, ident.v.length)
case Error.UndeclaredVariable(ident) =>
formatPosition(ident.pos)
sb.append(s"Undeclared variable ${ident.v}")
formatHighlight(ident.pos, ident.v.length)
case Error.UndefinedFunction(ident) =>
formatPosition(ident.pos)
sb.append(s"Undefined function ${ident.v}")
formatHighlight(ident.pos, ident.v.length)
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
formatPosition(id.pos)
sb.append(s"Function expects $expected parameters, got $got")
sb.append(
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
)
formatHighlight(id.pos, 1)
case Error.TypeMismatch(pos, expected, got, msg) =>
formatPosition(pos)
sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
formatHighlight(pos, 1)
case Error.SemanticError(pos, msg) =>
formatPosition(pos)
sb.append(msg)
formatHighlight(pos, 1)
case wacc.Error.InternalError(pos, msg) =>
formatPosition(pos)
sb.append(s"Internal error: $msg")
formatHighlight(pos, 1)
case Error.SyntaxError(msg) =>
sb.append(msg)
sb.append("\n")
}
sb.toString()
}

View File

@@ -4,7 +4,6 @@ import parsley.Parsley
import parsley.generic.ErrorBridge
import parsley.ap._
import parsley.position._
import parsley.syntax.zipped._
import cats.data.NonEmptyList
object ast {