66 Commits

Author SHA1 Message Date
fda4e17327 feat: fs2 instead of parTraverse
Merge request lab2425_spring/WACC_37!46
2025-03-14 18:35:59 +00:00
bb1b6a3b23 feat: add wacc-compiler back in 2025-03-14 18:32:25 +00:00
07afc2d59f fix: remove wacc-compiler 2025-03-14 18:31:42 +00:00
8dd23f9e5c fix: add parTraverse semantics back in 2025-03-14 18:31:29 +00:00
8b6e959d11 feat: parTraverse -> fs2 2025-03-14 18:31:29 +00:00
Connolly, Guy
df7a287801 Ext fixes
Merge request lab2425_spring/WACC_37!45

Co-authored-by: Guy C <gc1523@ic.ac.uk>
2025-03-14 18:05:00 +00:00
Guy C
fa399e7721 chore: update wacc-syntax extension package 2025-03-14 17:52:29 +00:00
Guy C
cf495e9d7f chore: update .gitignore to include all temporary WACC files 2025-03-14 17:48:02 +00:00
Guy C
4d8064dc61 chore: add temporary file pattern to .gitignore 2025-03-14 17:45:38 +00:00
Guy C
fde34e88b2 fix: extension fixes for publishing 2025-03-14 17:36:51 +00:00
28ee7a2a32 feat: debug symbols for line file/line location and functions in main file
Merge request lab2425_spring/WACC_37!44
2025-03-14 17:12:13 +00:00
68435207fe fix: include correct main position and don't re-create label 2025-03-14 15:51:34 +00:00
8f7c902ed5 feat: implement .loc, .file and .func debug directives 2025-03-14 15:40:09 +00:00
07f02e61d7 feat: pass stmt position information to microwacc 2025-03-14 14:02:42 +00:00
Connolly, Guy
af514b3363 intelliwacc ide
Merge request lab2425_spring/WACC_37!43

Co-authored-by: Guy C <gc1523@ic.ac.uk>
2025-03-14 13:22:52 +00:00
Guy C
447f29ce4c docs: update README.md 2025-03-14 13:19:14 +00:00
0368daef00 feat: parallel type checking
Merge request lab2425_spring/WACC_37!42

Co-authored-by: Jonny <j.sinteix@gmail.com>
2025-03-14 06:11:53 +00:00
084081de7e style: scala format 2025-03-14 05:40:21 +00:00
46f526c680 feat: success logging by default 2025-03-14 05:39:42 +00:00
53d47fda63 feat: initial parallel type-checker implementation 2025-03-14 04:09:34 +00:00
Guy C
6ad1a9059d refactor: package cleanup and formatting 2025-03-14 02:05:45 +00:00
Guy C
5778b3145d feat: include exe and updated filepath 2025-03-14 01:46:25 +00:00
Guy C
051ef02011 feat: update error generation to consider file paths 2025-03-14 01:28:59 +00:00
Jonny
42515abf2a refactor: remove pattern match in for comprehension 2025-03-14 00:00:43 +00:00
Connolly, Guy
d44eb24086 feat: add option flag, greedy compilation of multiple files, and refactor to...
Merge request lab2425_spring/WACC_37!41

Co-authored-by: Gleb Koval <gleb@koval.net>
Co-authored-by: Jonny <j.sinteix@gmail.com>
2025-03-13 23:28:07 +00:00
191c5df824 feat: imports and parallelised renamer
Merge request lab2425_spring/WACC_37!40

Co-authored-by: Jonny <j.sinteix@gmail.com>
2025-03-13 23:10:38 +00:00
Jonny
68211fd877 feat: parallelised the renamer 2025-03-13 23:00:28 +00:00
a3895dca2c style: scala format 2025-03-13 22:26:56 +00:00
6e592e7d9b feat: functional single-threaded imports 2025-03-13 22:24:41 +00:00
ee54a1201c fix: return proper AST from renamer 2025-03-13 20:47:56 +00:00
c73b073f23 feat: initial attempt at renamer parallelisation 2025-03-13 20:45:57 +00:00
8d8df3357d refactor: use getCanonicalPath instead of toRealPath 2025-03-13 18:39:11 +00:00
00df2dc546 feat: filenames in errors 2025-03-13 15:03:26 +00:00
67e85688b2 refactor: fMap to replace fOption, fList and fNonEmptyList 2025-03-13 14:03:53 +00:00
0497dd34a0 fix: use GOps to avoid scala error 2025-03-13 13:37:17 +00:00
6904aa37e4 style: scala format 2025-03-13 13:31:47 +00:00
5141a2369f fix: convert parser to use FParsley 2025-03-13 13:26:35 +00:00
3fff9d3825 feat: file parser bridges 2025-03-13 13:26:19 +00:00
f11fb9f881 test: integration tests for imports 2025-03-13 09:43:29 +00:00
e881b736f8 feat: imports parser 2025-03-13 08:18:44 +00:00
905a5e5b61 feat: parallelised compilation
Merge request lab2425_spring/WACC_37!39

Co-authored-by: Jonny <j.sinteix@gmail.com>
2025-03-13 01:08:58 +00:00
0d8be53ae4 fix: set output to . for labts compiler 2025-03-12 23:34:39 +00:00
Guy C
36ddd025b2 feat: include the compiler exe and use working relative filepath 2025-03-11 16:39:09 +00:00
Guy C
bad6e47e46 refactor: update error handling and diagnostics in IntelliWACC extension 2025-03-10 18:23:40 +00:00
96ba81e24a refactor: consistent error handling in Main.scala 2025-03-09 23:37:04 +00:00
Guy C
54d6e7143b fix: add 'is' keyword to WACC syntax highlighting 2025-03-03 12:38:25 +00:00
Guy C
c2259334c1 feat: setup for intelliwacc ide with syntax highlights 2025-03-03 12:26:53 +00:00
Jonny
94ee489faf feat: greedy cli argument implemented, parallel compilation now by default, but no fail fast behaviour 2025-03-03 02:58:04 +00:00
Jonny
f24aecffa3 fix: remove implicit val causing conflicts with parsing cli arguments 2025-03-03 02:10:18 +00:00
f896cbb0dd fix: add opSize back in to stack
Merge request lab2425_spring/WACC_37!38
2025-03-02 14:10:18 +00:00
Jonny
19e7ce4c11 fix: fix output flag not reading path passed in 2025-03-02 06:20:19 +00:00
Jonny
473189342b refactor: remove commented out code in main.scala 2025-03-02 03:49:21 +00:00
Jonny
f66f1ab3ac refactor: compile function split up into smaller functions 2025-03-02 03:48:37 +00:00
Jonny
abb43b560d refactor: improve resource safety and structure of writer 2025-03-02 03:26:28 +00:00
Jonny
9a5ccea1f6 style: fix formatting 2025-03-02 03:14:58 +00:00
Jonny
85a82aabb4 feat: add option flag, greedy compilation of multiple files, and refactor to use paths instead of files 2025-03-02 03:12:53 +00:00
1b6d81dfca style: scala format 2025-03-01 02:15:01 +00:00
ae52fa653c fix: add opSize back in to stack 2025-03-01 02:07:45 +00:00
Jonny
01b38b1445 fix: fix incorrect semantic error logging by refactoring error.scala from frontend 2025-03-01 01:34:05 +00:00
Jonny
667fbf4949 feat: introduction of logger to eliminate printstreams 2025-03-01 01:19:50 +00:00
Jonny
d214723f35 feat: parallelise compilation of multiple files given to cli 2025-02-28 19:36:22 +00:00
Jonny
d56be9249a refactor: introduce decline to integrate command-line parsing with cats-effect 2025-02-28 18:00:18 +00:00
Jonny
1a72decf55 feat: remove unsaferunsync and integrate io in tests instead 2025-02-28 16:24:53 +00:00
Jonny
e54e5ce151 refactor: style fixes and fold combinator used instead of explicit pattern match 2025-02-28 15:50:53 +00:00
Jonny
cf1028454d fix: fix frontend tests failing due to expecting error codes instead of runtime exceptions 2025-02-28 15:20:32 +00:00
Jonny
345c652a57 feat: introduce cats-effect and io 2025-02-28 15:18:24 +00:00
56 changed files with 2090 additions and 755 deletions

2
.gitignore vendored
View File

@@ -4,4 +4,4 @@
.vscode/
wacc-examples/
.idea/
**/.temp_wacc_file.*

View File

@@ -4,6 +4,6 @@
# but do *not* change its name.
# feel free to adjust to suit the specific internal flags of your compiler
./wacc-compiler "$@"
./wacc-compiler --output . "$@"
exit $?

View File

@@ -0,0 +1,10 @@
begin
int main() is
int a = 5 ;
string b = "Hello" ;
return a + b
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "./doesNotExist.wacc" (main)
begin
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "../../../valid/sum.wacc" (mult)
begin
int result = call mult(3, 2) ;
exit result
end

View File

@@ -0,0 +1,10 @@
import "../badWacc.wacc" (main)
begin
int sum(int a, int b) is
return a + b
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "./importBadSem.wacc" (sum)
begin
int result = call sum(1, 2) ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "../../../valid/imports/basic.wacc" (sum)
begin
int result = call sum(3, 2) ;
exit result
end

View File

@@ -0,0 +1,6 @@
int main() is
println "Hello World!" ;
return 0
end
skip

View File

@@ -0,0 +1,8 @@
import "../../../valid/sum.wacc" sum, main
begin
int result1 = call sum(5, 10) ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,5 @@
import "../../../valid/sum.wacc" ()
begin
exit 0
end

View File

@@ -0,0 +1,10 @@
import "../badWacc.wacc" (main)
begin
int sum(int a, int b) is
return a + b
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "./importBadSyntax.wacc" (sum)
begin
int result = call sum(1, 2) ;
exit result
end

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,9 @@
import "../../../valid/sum.wacc" (sum) ;
import "../../../valid/sum.wacc" (main) ;
begin
int result1 = call sum(5, 10) ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,5 @@
import "../../../valid/sum.wacc" *
begin
exit 0
end

View File

@@ -0,0 +1,5 @@
import "../../../valid/sum.wacc" (*)
begin
exit 0
end

7
extension/examples/valid/.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
*
!imports/
imports/*
!.gitignore
!*.wacc

View File

@@ -0,0 +1,22 @@
# import main from ../sum.wacc and ./basic.wacc
# Output:
# 15
# 0
# -33
#
# Exit:
# 0
# Program:
import "../sum.wacc" (main as sumMain)
import "./basic.wacc" (main)
begin
int result1 = call sumMain() ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,21 @@
# import sum from ../sum.wacc
# Output:
# -33
#
# Exit:
# 0
# Program:
import "../sum.wacc" (sum)
begin
int main() is
int result = call sum(-10, -23) ;
return result
end
int result = call main() ;
println result
end

View File

@@ -0,0 +1,33 @@
# import all the mains
# Output:
# 15
# -33
# 0
# -33
# 0
#
# Exit:
# 99
# Program:
import "../sum.wacc" (main as sumMain)
import "./basic.wacc" (main as basicMain)
import "./multiFunc.wacc" (main as multiFuncMain)
begin
int main() is
int result1 = call sumMain() ;
int result2 = call basicMain() ;
int result3 = call multiFuncMain() ;
println result1 ;
println result2 ;
println result3 ;
return 99
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,27 @@
# import sum, main from ../sum.wacc
# Output:
# 15
# -33
# 0
# 0
#
# Exit:
# 0
# Program:
import "../sum.wacc" (sum, main as sumMain)
begin
int main() is
int result = call sum(-10, -23) ;
println result ;
return 0
end
int result1 = call sumMain() ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,27 @@
# simple sum program
# Output:
# 15
#
# Exit:
# 0
# Program:
begin
int sum(int a, int b) is
return a + b
end
int main() is
int a = 5 ;
int b = 10 ;
int result = call sum(a, b) ;
println result ;
return 0
end
int result = call main() ;
exit result
end

View File

@@ -5,23 +5,20 @@
//> using dep com.github.j-mie6::parsley::5.0.0-M10
//> using dep com.github.j-mie6::parsley-cats::1.5.0
//> using dep com.lihaoyi::os-lib::0.11.4
//> using dep com.github.scopt::scopt::4.1.0
//> using dep org.typelevel::cats-core::2.13.0
//> using dep org.typelevel::cats-effect::3.5.7
//> using dep com.monovore::decline::2.5.0
//> using dep com.monovore::decline-effect::2.5.0
//> using dep org.typelevel::log4cats-slf4j::2.7.0
//> using dep org.slf4j:slf4j-simple:2.0.17
//> using test.dep org.scalatest::scalatest::3.2.19
//> using dep org.typelevel::cats-effect-testing-scalatest::1.6.0
//> using dep "co.fs2::fs2-core:3.11.0"
//> using dep co.fs2::fs2-io:3.11.0
// these are all sensible defaults to catch annoying issues
// sensible defaults for warnings and compiler checks
//> using options -deprecation -unchecked -feature
//> using options -Wimplausible-patterns -Wunused:all
//> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores
// these will help ensure you have access to the latest parsley releases
// even before they land on maven proper, or snapshot versions, if necessary.
// just in case they cause problems, however, keep them turned off unless you
// specifically need them.
// using repositories sonatype-s01:releases
// using repositories sonatype-s01:snapshots
// these are flags used by Scala native: if you aren't using scala-native, then they do nothing
// lto-thin has decent linking times, and release-fast does not too much optimisation.
// using nativeLto thin
// using nativeGc commix
// using nativeMode release-fast
// repositories for pre-release versions if needed

View File

@@ -1,92 +1,177 @@
package wacc
import scala.collection.mutable
import cats.data.Chain
import cats.data.{Chain, NonEmptyList}
import parsley.{Failure, Success}
import scopt.OParser
import java.io.File
import java.io.PrintStream
import java.nio.file.{Files, Path}
import cats.syntax.all._
import cats.effect.IO
import cats.effect.ExitCode
import com.monovore.decline._
import com.monovore.decline.effect._
import org.typelevel.log4cats.slf4j.Slf4jLogger
import org.typelevel.log4cats.Logger
import fs2.Stream
import assemblyIR as asm
import cats.data.ValidatedNel
import java.io.File
import cats.data.NonEmptySeq
case class CliConfig(
file: File = new File(".")
)
/*
TODO:
1) IO correctness
2) Errors can be handled more gracefully - currently, parallelised compilation is not fail fast as far as I am aware
3) splitting the file up and nicer refactoring
4) logging could be removed
5) general cleanup and comments (things like replacing home/<user> with ~ , and names of parameters and args, descriptions etc)
*/
val cliBuilder = OParser.builder[CliConfig]
val cliParser = {
import cliBuilder._
OParser.sequence(
programName("wacc-compiler"),
help('h', "help")
.text("Prints this help message"),
arg[File]("<file>")
.text("Input WACC source file")
.required()
.action((f, c) => c.copy(file = f))
.validate(f =>
if (!f.exists) failure("File does not exist")
else if (!f.isFile) failure("File must be a regular file")
else if (!f.getName.endsWith(".wacc"))
failure("File must have .wacc extension")
else success
)
)
private val SUCCESS = ExitCode.Success.code
private val ERROR = ExitCode.Error.code
given logger: Logger[IO] = Slf4jLogger.getLogger[IO]
val logOpt: Opts[Boolean] =
Opts.flag("log", "Enable logging for additional compilation details", short = "l").orFalse
def validateFile(path: Path): ValidatedNel[String, Path] = {
(for {
// TODO: redundant 2nd parameter :(
_ <- Either.cond(Files.exists(path), (), s"File '${path}' does not exist")
_ <- Either.cond(Files.isRegularFile(path), (), s"File '${path}' must be a regular file")
_ <- Either.cond(path.toString.endsWith(".wacc"), (), "File must have .wacc extension")
} yield path).toValidatedNel
}
val filesOpt: Opts[NonEmptyList[Path]] =
Opts.arguments[Path]("files").mapValidated {
_.traverse(validateFile)
}
val outputOpt: Opts[Option[Path]] =
Opts
.option[Path]("output", metavar = "path", help = "Output directory for compiled files.")
.validate("Must have permissions to create & access the output path") { path =>
try {
Files.createDirectories(path)
true
} catch {
case e: java.nio.file.AccessDeniedException =>
false
}
}
.validate("Output path must be a directory") { path =>
Files.isDirectory(path)
}
.orNone
def frontend(
contents: String
)(using stdout: PrintStream): Either[microWacc.Program, Int] = {
contents: String,
file: File
): IO[Either[NonEmptySeq[Error], microWacc.Program]] =
parser.parse(contents) match {
case Success(prog) =>
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
val (names, funcs) = renamer.rename(prog)
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
val typedProg = typeChecker.check(prog)
if (errors.result.nonEmpty) {
given errorContent: String = contents
Right(
errors.result
.map { error =>
printError(error)
error match {
case _: Error.InternalError => 201
case _ => 200
}
}
.max()
)
} else Left(typedProg)
case Failure(msg) =>
stdout.println(msg)
Right(100)
}
case Failure(msg) => IO.pure(Left(NonEmptySeq.one(Error.SyntaxError(file, msg))))
case Success(fn) =>
val partialProg = fn(file)
for {
(typedProg, errors) <- semantics.check(partialProg)
res = NonEmptySeq.fromSeq(errors.iterator.toSeq).map(Left(_)).getOrElse(Right(typedProg))
} yield res
}
val s = "enter an integer to echo"
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
asmGenerator.generateAsm(typedProg)
def compile(filename: String, outFile: Option[File] = None)(using
stdout: PrintStream = Console.out
): Int =
frontend(os.read(os.Path(filename))) match {
case Left(typedProg) =>
val asmFile = outFile.getOrElse(File(filename.stripSuffix(".wacc") + ".s"))
val asm = backend(typedProg)
writer.writeTo(asm, PrintStream(asmFile))
0
case Right(exitCode) => exitCode
def compile(
filePath: Path,
outputDir: Option[Path],
log: Boolean
): IO[Int] = {
val logAction: String => IO[Unit] =
if (log) logger.info(_)
else (_ => IO.unit)
def readSourceFile: IO[String] =
IO.blocking(os.read(os.Path(filePath)))
// TODO: path, file , the names are confusing (when Path is the type but we are working with files)
def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] =
val backendStart = System.nanoTime()
val asmLines = backend(typedProg)
val backendEnd = System.nanoTime()
writer.writeTo(asmLines, outputPath) *>
logAction(
s"Backend time (${filePath.toRealPath()}): ${(backendEnd - backendStart).toFloat / 1e6} ms"
) *>
IO.blocking(println(s"Success: ${outputPath.toRealPath()}"))
def processProgram(contents: String, file: File, outDir: Path): IO[Int] =
val frontendStart = System.nanoTime()
for {
frontendResult <- frontend(contents, file)
frontendEnd = System.nanoTime()
_ <- logAction(
s"Frontend time (${filePath.toRealPath()}): ${(frontendEnd - frontendStart).toFloat / 1e6} ms"
)
res <- frontendResult match {
case Left(errors) =>
val code = errors.map(err => err.exitCode).toList.min
val errorMsg = errors.map(formatError).toIterable.mkString("\n")
for {
_ <- logAction(s"Compilation failed for $filePath\nExit code: $code")
_ <- IO.blocking(
// Explicit println since we want this to always show without logger thread info e.t.c.
println(s"Compilation failed for ${file.getCanonicalPath}:\n$errorMsg")
)
} yield code
case Right(typedProg) =>
val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s")
writeOutputFile(typedProg, outputFile).as(SUCCESS)
}
} yield res
for {
contents <- readSourceFile
_ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}")
exitCode <- processProgram(contents, filePath.toFile, outputDir.getOrElse(filePath.getParent))
} yield exitCode
}
def compileCommandParallel(
files: NonEmptyList[Path],
log: Boolean,
outDir: Option[Path]
): IO[ExitCode] =
Stream
.emits(files.toList)
.parEvalMapUnordered(Runtime.getRuntime.availableProcessors()) { file =>
compile(file.toAbsolutePath, outDir, log)
}
.compile
.toList
.map { exitCodes =>
exitCodes.filter(_ != 0) match {
case Nil => ExitCode.Success
case errorCodes => ExitCode(errorCodes.min)
}
}
object Main
extends CommandIOApp(
name = "wacc",
header = "The ultimate WACC compiler",
version = "1.0"
) {
def main: Opts[IO[ExitCode]] =
(filesOpt, logOpt, outputOpt).mapN { (files, log, outDir) =>
compileCommandParallel(files, log, outDir)
}
def main(args: Array[String]): Unit =
OParser.parse(cliParser, args, CliConfig()) match {
case Some(config) =>
System.exit(
compile(
config.file.getAbsolutePath,
outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s"))
)
)
case None =>
}

View File

@@ -2,6 +2,7 @@ package wacc
import scala.collection.mutable
import cats.data.Chain
import wacc.ast.Position
private class LabelGenerator {
import assemblyIR._
@@ -9,7 +10,9 @@ private class LabelGenerator {
import asmGenerator.escaped
private val strings = mutable.HashMap[String, String]()
private val files = mutable.HashMap[String, Int]()
private var labelVal = -1
private var permittedFuncFile: Option[String] = None
/** Get an arbitrary label. */
def getLabel(): String = {
@@ -18,7 +21,7 @@ private class LabelGenerator {
}
private def getLabel(target: CallTarget | RuntimeError): String = target match {
case Ident(v, _) => s"wacc_$v"
case Ident(v, guid) => s"wacc_${v}_$guid"
case Builtin(name) => s"_$name"
case err: RuntimeError => s".L.${err.name}"
}
@@ -39,6 +42,25 @@ private class LabelGenerator {
def getLabelArg(src: String, name: String): LabelArg =
LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}"))
/** Get a debug directive for a file. */
def getDebugFile(file: java.io.File): Int =
files.getOrElseUpdate(file.getCanonicalPath, files.size)
/** Get a debug directive for a function. */
def getDebugFunc(pos: Position, name: String, label: LabelDef): Chain[AsmLine] = {
permittedFuncFile match {
case Some(f) if f != pos.file.getCanonicalPath => Chain.empty
case _ =>
val customLabel = if name == "main" then Chain.empty else Chain(LabelDef(name))
permittedFuncFile = Some(pos.file.getCanonicalPath)
customLabel ++ Chain(
Directive.Location(getDebugFile(pos.file), pos.line, None),
Directive.Type(label, SymbolType.Function),
Directive.Func(name, label)
)
}
}
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
def generateConstants: Chain[AsmLine] =
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
@@ -47,4 +69,10 @@ private class LabelGenerator {
Directive.Asciz(str.escaped)
)
}
/** Generates debug directives that were created using the LabelGenerator. */
def generateDebug: Chain[AsmLine] =
files.foldLeft(Chain.empty) { case (acc, (file, no)) =>
acc :+ Directive.File(no, file)
}
}

View File

@@ -23,7 +23,7 @@ class Stack {
/** Push an expression onto the stack. */
def push(expr: mw.Expr, src: Register): AsmLine = {
stack += expr -> StackValue(src.size, sizeBytes)
stack += expr -> StackValue(expr.ty.size, sizeBytes)
Push(src)
}
@@ -81,7 +81,7 @@ class Stack {
/** Get an MemLocation for a variable in the stack. */
def accessVar(ident: mw.Ident): MemLocation =
MemLocation(RSP, sizeBytes - stack(ident).bottom)
MemLocation(RSP, sizeBytes - stack(ident).bottom, opSize = Some(stack(ident).size))
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size))

View File

@@ -36,11 +36,14 @@ object asmGenerator {
given labelGenerator: LabelGenerator = LabelGenerator()
val Program(funcs, main) = microProg
val progAsm = Chain(LabelDef("main")).concatAll(
val mainLabel = LabelDef("main")
val mainAsm = labelGenerator.getDebugFunc(microProg.pos, "main", mainLabel) + mainLabel
val progAsm = mainAsm.concatAll(
funcPrologue(),
main.foldMap(generateStmt(_)),
Chain.one(Xor(RAX, RAX)),
funcEpilogue(),
Chain(Directive.Size(mainLabel, SizeExpr.Relative(mainLabel)), Directive.EndFunc),
generateBuiltInFuncs(),
RuntimeError.all.foldMap(_.generate),
funcs.foldMap(generateUserFunc(_))
@@ -51,6 +54,7 @@ object asmGenerator {
Directive.Global("main"),
Directive.RoData
).concatAll(
labelGenerator.generateDebug,
labelGenerator.generateConstants,
Chain.one(Directive.Text),
progAsm
@@ -75,7 +79,10 @@ object asmGenerator {
// Setup the stack with param 7 and up
func.params.drop(argRegs.size).foreach(stack.reserve(_))
stack.reserve(Size.Q64) // Reserve return pointer slot
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(func.name))
val funcLabel = labelGenerator.getLabelDef(func.name)
var asm = labelGenerator.getDebugFunc(func.pos, func.name.name, funcLabel)
val debugFunc = asm.size > 0
asm += funcLabel
asm ++= funcPrologue()
// Push the rest of params onto the stack for simplicity
argRegs.zip(func.params).foreach { (reg, param) =>
@@ -83,6 +90,10 @@ object asmGenerator {
}
asm ++= func.body.foldMap(generateStmt(_))
// No need for epilogue here since all user functions must return explicitly
if (debugFunc) {
asm += Directive.Size(funcLabel, SizeExpr.Relative(funcLabel))
asm += Directive.EndFunc
}
asm
}
@@ -159,8 +170,8 @@ object asmGenerator {
stack: Stack,
labelGenerator: LabelGenerator
): Chain[AsmLine] = {
var asm = Chain.empty[AsmLine]
asm += Comment(stmt.toString)
val fileNo = labelGenerator.getDebugFile(stmt.pos.file)
var asm = Chain.one[AsmLine](Directive.Location(fileNo, stmt.pos.line, None))
stmt match {
case Assign(lhs, rhs) =>
lhs match {
@@ -261,7 +272,7 @@ object asmGenerator {
asm += stack.push(KnownType.String.size, RAX)
case ty =>
asm ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))),
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize)))(array.pos),
isTail = false
)
asm += stack.push(KnownType.Array(?).size, RAX)

View File

@@ -199,10 +199,15 @@ object assemblyIR {
}
enum Directive extends AsmLine {
case IntelSyntax, RoData, Text
case IntelSyntax, RoData, Text, EndFunc
case Global(name: String)
case Int(value: scala.Int)
case Asciz(string: String)
case File(no: scala.Int, file: String)
case Location(fileNo: scala.Int, lineNo: scala.Int, colNo: Option[scala.Int])
case Func(name: String, label: LabelDef)
case Type(label: LabelDef, symbolType: SymbolType)
case Size(label: LabelDef, expr: SizeExpr)
override def toString(): String = this match {
case IntelSyntax => ".intel_syntax noprefix"
@@ -211,6 +216,32 @@ object assemblyIR {
case RoData => ".section .rodata"
case Int(value) => s"\t.int $value"
case Asciz(string) => s"\t.asciz \"$string\""
case File(no, file) => s".file $no \"${file}\""
case Location(fileNo, lineNo, colNo) =>
s"\t.loc $fileNo $lineNo" + colNo.map(c => s" $c").getOrElse("")
case Func(name, label) =>
s".func $name, ${label.name}"
case EndFunc => ".endfunc"
case Type(label, symbolType) =>
s".type ${label.name}, @${symbolType.toString}"
case Directive.Size(label, expr) =>
s".size ${label.name}, ${expr.toString}"
}
}
enum SymbolType {
case Function
override def toString(): String = this match {
case Function => "function"
}
}
enum SizeExpr {
case Relative(label: LabelDef)
override def toString(): String = this match {
case Relative(label) => s".-${label.name}"
}
}

View File

@@ -1,12 +1,42 @@
package wacc
import java.io.PrintStream
import cats.effect.Resource
import java.nio.charset.StandardCharsets
import java.io.BufferedWriter
import java.io.FileWriter
import cats.data.Chain
import cats.effect.IO
import org.typelevel.log4cats.Logger
import java.nio.file.Path
object writer {
import assemblyIR._
def writeTo(asmList: Chain[AsmLine], printStream: PrintStream): Unit = {
asmList.iterator.foreach(printStream.println)
// TODO: Judging from documentation it seems as though IO.blocking is the correct choice
// But needs checking
/** Creates a resource safe BufferedWriter */
private def bufferedWriter(outputPath: Path): Resource[IO, BufferedWriter] =
Resource.make {
IO.blocking(new BufferedWriter(new FileWriter(outputPath.toFile, StandardCharsets.UTF_8)))
} { writer =>
IO.blocking(writer.close())
.handleErrorWith(_ => IO.unit) // TODO: ensures writer is closed even if an error occurs
}
/** Write line safely into a BufferedWriter */
private def writeLines(writer: BufferedWriter, lines: Chain[AsmLine]): IO[Unit] =
IO.blocking {
lines.iterator.foreach { line =>
writer.write(line.toString)
writer.newLine()
}
}
/** Main function to write assembly to a file */
def writeTo(asmList: Chain[AsmLine], outputPath: Path)(using logger: Logger[IO]): IO[Unit] =
bufferedWriter(outputPath).use {
writeLines(_, asmList)
}
}

View File

@@ -2,7 +2,10 @@ package wacc
import wacc.ast.Position
import wacc.types._
import java.io.PrintStream
import java.io.File
private val SYNTAX_ERROR = 100
private val SEMANTIC_ERROR = 200
/** Error types for semantic errors
*/
@@ -15,6 +18,15 @@ enum Error {
case SemanticError(pos: Position, msg: String)
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
case InternalError(pos: Position, msg: String)
case SyntaxError(file: File, msg: String)
}
extension (e: Error) {
def exitCode: Int = e match {
case Error.SyntaxError(_, _) => SYNTAX_ERROR
case _ => SEMANTIC_ERROR
}
}
/** Function to handle printing the details of a given semantic error
@@ -24,42 +36,26 @@ enum Error {
* @param errorContent
* Contents of the file to generate code snippets
*/
def printError(error: Error)(using errorContent: String, stdout: PrintStream): Unit = {
stdout.println("Semantic error:")
error match {
case Error.DuplicateDeclaration(ident) =>
printPosition(ident.pos)
stdout.println(s"Duplicate declaration of identifier ${ident.v}")
highlight(ident.pos, ident.v.length)
case Error.UndeclaredVariable(ident) =>
printPosition(ident.pos)
stdout.println(s"Undeclared variable ${ident.v}")
highlight(ident.pos, ident.v.length)
case Error.UndefinedFunction(ident) =>
printPosition(ident.pos)
stdout.println(s"Undefined function ${ident.v}")
highlight(ident.pos, ident.v.length)
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
printPosition(id.pos)
stdout.println(s"Function expects $expected parameters, got $got")
stdout.println(
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
)
highlight(id.pos, 1)
case Error.TypeMismatch(pos, expected, got, msg) =>
printPosition(pos)
stdout.println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
highlight(pos, 1)
case Error.SemanticError(pos, msg) =>
printPosition(pos)
stdout.println(msg)
highlight(pos, 1)
case wacc.Error.InternalError(pos, msg) =>
printPosition(pos)
stdout.println(s"Internal error: $msg")
highlight(pos, 1)
def formatError(error: Error): String = {
val sb = new StringBuilder()
/** Format the file of an error
*
* @param file
* File of the error
*/
def formatFile(file: File): Unit = {
sb.append(s"File: ${file.getCanonicalPath}\n")
}
/** Function to format the position of an error
*
* @param pos
* Position of the error
*/
def formatPosition(pos: Position): Unit = {
formatFile(pos.file)
sb.append(s"(line ${pos.line}, column ${pos.column}):\n")
}
/** Function to highlight a section of code for an error message
@@ -68,27 +64,63 @@ def printError(error: Error)(using errorContent: String, stdout: PrintStream): U
* Position of the error
* @param size
* Size(in chars) of section to highlight
* @param errorContent
* Contents of the file to generate code snippets
*/
def highlight(pos: Position, size: Int)(using errorContent: String, stdout: PrintStream): Unit = {
val lines = errorContent.split("\n")
def formatHighlight(pos: Position, size: Int): Unit = {
val lines = os.read(os.Path(pos.file.getCanonicalPath)).split("\n")
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
val midLine = lines(pos.line - 1)
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
stdout.println(
s" >$preLine\n >$midLine\n$linePointer >$postLine"
sb.append(
s" >$preLine\n >$midLine\n$linePointer >$postLine\netscape"
)
}
/** Function to print the position of an error
*
* @param pos
* Position of the error
*/
def printPosition(pos: Position)(using stdout: PrintStream): Unit = {
stdout.println(s"(line ${pos.line}, column ${pos.column}):")
error match {
case Error.SyntaxError(_, _) =>
sb.append("Syntax error:\n")
case _ =>
sb.append("Semantic error:\n")
}
error match {
case Error.DuplicateDeclaration(ident) =>
formatPosition(ident.pos)
sb.append(s"Duplicate declaration of identifier ${ident.v}\n")
formatHighlight(ident.pos, ident.v.length)
case Error.UndeclaredVariable(ident) =>
formatPosition(ident.pos)
sb.append(s"Undeclared variable ${ident.v}\n")
formatHighlight(ident.pos, ident.v.length)
case Error.UndefinedFunction(ident) =>
formatPosition(ident.pos)
sb.append(s"Undefined function ${ident.v}\n")
formatHighlight(ident.pos, ident.v.length)
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
formatPosition(id.pos)
sb.append(s"Function expects $expected parameters, got $got\n")
sb.append(
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})\n"
)
formatHighlight(id.pos, 1)
case Error.TypeMismatch(pos, expected, got, msg) =>
formatPosition(pos)
sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got\n")
formatHighlight(pos, 1)
case Error.SemanticError(pos, msg) =>
formatPosition(pos)
sb.append(msg + "\n")
formatHighlight(pos, 1)
case wacc.Error.InternalError(pos, msg) =>
formatPosition(pos)
sb.append(s"Internal error: $msg\n")
formatHighlight(pos, 1)
case Error.SyntaxError(file, msg) =>
formatFile(file)
sb.append(msg + "\n")
sb.append("\n")
}
sb.toString()
}

View File

@@ -1,10 +1,10 @@
package wacc
import java.io.File
import parsley.Parsley
import parsley.generic.ErrorBridge
import parsley.ap._
import parsley.position._
import parsley.syntax.zipped._
import cats.data.NonEmptyList
object ast {
@@ -23,26 +23,42 @@ object ast {
/* ============================ ATOMIC EXPRESSIONS ============================ */
case class IntLiter(v: Int)(val pos: Position) extends Expr6
object IntLiter extends ParserBridgePos1[Int, IntLiter]
object IntLiter extends ParserBridgePos1Atom[Int, IntLiter]
case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6
object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter]
object BoolLiter extends ParserBridgePos1Atom[Boolean, BoolLiter]
case class CharLiter(v: Char)(val pos: Position) extends Expr6
object CharLiter extends ParserBridgePos1[Char, CharLiter]
object CharLiter extends ParserBridgePos1Atom[Char, CharLiter]
case class StrLiter(v: String)(val pos: Position) extends Expr6
object StrLiter extends ParserBridgePos1[String, StrLiter]
object StrLiter extends ParserBridgePos1Atom[String, StrLiter]
case class PairLiter()(val pos: Position) extends Expr6
object PairLiter extends ParserBridgePos0[PairLiter]
case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue
object Ident extends ParserBridgePos1[String, Ident] {
case class Ident(var v: String, var guid: Int = -1, var ty: types.RenamerType = types.?)(
val pos: Position
) extends Expr6
with LValue
object Ident extends ParserBridgePos1Atom[String, Ident] {
def apply(v: String)(pos: Position): Ident = new Ident(v)(pos)
}
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position)
extends Expr6
with LValue
object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
name => ArrayElem(name, a)(pos)
object ArrayElem extends ParserBridgePos2Chain[NonEmptyList[Expr], Ident, ArrayElem] {
def apply(indices: NonEmptyList[Expr], name: Ident)(pos: Position): ArrayElem =
new ArrayElem(name, indices)(pos)
}
// object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], (File => Ident) => ArrayElem] {
// def apply(a: NonEmptyList[Expr])(pos: Position): (File => Ident) => ArrayElem =
// name => ArrayElem(name(pos.file), a)(pos)
// }
// object ArrayElem extends ParserSingletonBridgePos[(File => NonEmptyList[Expr]) => (File => Ident) => File => ArrayElem] {
// // def apply(indices: NonEmptyList[Expr]): (File => Ident) => File => ArrayElem =
// // name => file => new ArrayElem(name(file), )
// def apply(indices: Parsley[File => NonEmptyList[Expr]]): Parsley[(File => Ident) => File => ArrayElem] =
// // error(ap1(pos.map(con),))
// override final def con(pos: (Int, Int)): (File => NonEmptyList[Expr]) => => C =
// (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
// }
case class Parens(expr: Expr)(val pos: Position) extends Expr6
object Parens extends ParserBridgePos1[Expr, Parens]
@@ -120,8 +136,9 @@ object ast {
case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position)
extends Type
with PairElemType
object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
object ArrayType extends ParserBridgePos2Chain[Int, Type, ArrayType] {
def apply(dimensions: Int, elemType: Type)(pos: Position): ArrayType =
ArrayType(elemType, dimensions)(pos)
}
case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type
object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
@@ -132,6 +149,18 @@ object ast {
/* ============================ PROGRAM STRUCTURE ============================ */
case class ImportedFunc(sourceName: Ident, importName: Ident)(val pos: Position)
object ImportedFunc extends ParserBridgePos2[Ident, Option[Ident], ImportedFunc] {
def apply(a: Ident, b: Option[Ident])(pos: Position): ImportedFunc =
new ImportedFunc(a, b.getOrElse(a))(pos)
}
case class Import(source: StrLiter, funcs: NonEmptyList[ImportedFunc])(val pos: Position)
object Import extends ParserBridgePos2[StrLiter, NonEmptyList[ImportedFunc], Import]
case class PartialProgram(imports: List[Import], self: Program)(val pos: Position)
object PartialProgram extends ParserBridgePos2[List[Import], Program, PartialProgram]
case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position)
object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program]
@@ -144,15 +173,15 @@ object ast {
body: NonEmptyList[Stmt]
)(val pos: Position)
object FuncDecl
extends ParserBridgePos2[
List[Param],
NonEmptyList[Stmt],
((Type, Ident)) => FuncDecl
extends ParserBridgePos2Chain[
(List[Param], NonEmptyList[Stmt]),
((Type, Ident)),
FuncDecl
] {
def apply(params: List[Param], body: NonEmptyList[Stmt])(
def apply(paramsBody: (List[Param], NonEmptyList[Stmt]), retTyName: (Type, Ident))(
pos: Position
): ((Type, Ident)) => FuncDecl =
(returnType, name) => FuncDecl(returnType, name, params, body)(pos)
): FuncDecl =
new FuncDecl(retTyName._1, retTyName._2, paramsBody._1, paramsBody._2)(pos)
}
case class Param(paramType: Type, name: Ident)(val pos: Position)
@@ -160,7 +189,9 @@ object ast {
/* ============================ STATEMENTS ============================ */
sealed trait Stmt
sealed trait Stmt {
val pos: Position
}
case class Skip()(val pos: Position) extends Stmt
object Skip extends ParserBridgePos0[Skip]
case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt
@@ -192,7 +223,9 @@ object ast {
val pos: Position
}
sealed trait RValue
sealed trait RValue {
val pos: Position
}
case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue
object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter]
case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue
@@ -208,7 +241,7 @@ object ast {
/* ============================ PARSER BRIDGES ============================ */
case class Position(line: Int, column: Int)
case class Position(line: Int, column: Int, file: File)
trait ParserSingletonBridgePos[+A] extends ErrorBridge {
protected def con(pos: (Int, Int)): A
@@ -216,38 +249,63 @@ object ast {
final def <#(op: Parsley[?]): Parsley[A] = this from op
}
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] {
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[File => A] {
def apply()(pos: Position): A
override final def con(pos: (Int, Int)): A =
apply()(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): File => A =
file => apply()(Position(pos._1, pos._2, file))
}
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] {
trait ParserBridgePos1Atom[-A, +B] extends ParserSingletonBridgePos[A => File => B] {
def apply(a: A)(pos: Position): B
def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a))
def apply(a: Parsley[A]): Parsley[File => B] = error(ap1(pos.map(con), a))
override final def con(pos: (Int, Int)): A => B =
this.apply(_)(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): A => File => B =
a => file => this.apply(a)(Position(pos._1, pos._2, file))
}
trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] {
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[(File => A) => File => B] {
def apply(a: A)(pos: Position): B
def apply(a: Parsley[File => A]): Parsley[File => B] = error(ap1(pos.map(con), a))
override final def con(pos: (Int, Int)): (File => A) => File => B =
a => file => this.apply(a(file))(Position(pos._1, pos._2, file))
}
trait ParserBridgePos2Chain[-A, -B, +C]
extends ParserSingletonBridgePos[(File => A) => (File => B) => File => C] {
def apply(a: A, b: B)(pos: Position): C
def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error(
def apply(a: Parsley[File => A]): Parsley[(File => B) => File => C] = error(
ap1(pos.map(con), a)
)
override final def con(pos: (Int, Int)): (File => A) => (File => B) => File => C =
a => b => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
}
trait ParserBridgePos2[-A, -B, +C]
extends ParserSingletonBridgePos[(File => A, File => B) => File => C] {
def apply(a: A, b: B)(pos: Position): C
def apply(a: Parsley[File => A], b: => Parsley[File => B]): Parsley[File => C] = error(
ap2(pos.map(con), a, b)
)
override final def con(pos: (Int, Int)): (A, B) => C =
apply(_, _)(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): (File => A, File => B) => File => C =
(a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
}
trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] {
trait ParserBridgePos3[-A, -B, -C, +D]
extends ParserSingletonBridgePos[(File => A, File => B, File => C) => File => D] {
def apply(a: A, b: B, c: C)(pos: Position): D
def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error(
def apply(
a: Parsley[File => A],
b: => Parsley[File => B],
c: => Parsley[File => C]
): Parsley[File => D] = error(
ap3(pos.map(con), a, b, c)
)
override final def con(pos: (Int, Int)): (A, B, C) => D =
apply(_, _, _)(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): (File => A, File => B, File => C) => File => D =
(a, b, c) => file => apply(a(file), b(file), c(file))(Position(pos._1, pos._2, file))
}
}

View File

@@ -1,6 +1,9 @@
package wacc
import cats.data.Chain
object microWacc {
import wacc.ast.Position
import wacc.types._
sealed trait CallTarget(val retTy: SemType)
@@ -11,7 +14,7 @@ object microWacc {
case class IntLiter(v: Int) extends Expr(KnownType.Int)
case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool)
case class CharLiter(v: Char) extends Expr(KnownType.Char)
case class ArrayLiter(elems: List[Expr])(ty: SemType) extends Expr(ty)
case class ArrayLiter(elems: List[Expr])(ty: SemType, val pos: Position) extends Expr(ty)
case class NullLiter()(ty: SemType) extends Expr(ty)
case class Ident(name: String, uid: Int)(identTy: SemType)
extends Expr(identTy)
@@ -63,7 +66,9 @@ object microWacc {
}
// Statements
sealed trait Stmt
sealed trait Stmt {
val pos: Position
}
case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) {
override def toString(): String = name
@@ -77,13 +82,16 @@ object microWacc {
object PrintCharArray extends Builtin("printCharArray")(?)
}
case class Assign(lhs: LValue, rhs: Expr) extends Stmt
case class If(cond: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt
case class While(cond: Expr, body: List[Stmt]) extends Stmt
case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy)
case class Return(expr: Expr) extends Stmt
case class Assign(lhs: LValue, rhs: Expr)(val pos: Position) extends Stmt
case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt])(val pos: Position)
extends Stmt
case class While(cond: Expr, body: Chain[Stmt])(val pos: Position) extends Stmt
case class Call(target: CallTarget, args: List[Expr])(val pos: Position)
extends Stmt
with Expr(target.retTy)
case class Return(expr: Expr)(val pos: Position) extends Stmt
// Program
case class FuncDecl(name: Ident, params: List[Ident], body: List[Stmt])
case class Program(funcs: List[FuncDecl], stmts: List[Stmt])
case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt])(val pos: Position)
case class Program(funcs: Chain[FuncDecl], stmts: Chain[Stmt])(val pos: Position)
}

View File

@@ -1,18 +1,22 @@
package wacc
import java.io.File
import parsley.Result
import parsley.Parsley
import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit}
import parsley.combinator.{countSome, sepBy}
import parsley.combinator.{countSome, sepBy, option}
import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms}
import parsley.errors.combinator._
import parsley.errors.patterns.VerifiedErrors
import parsley.syntax.zipped._
import parsley.cats.combinator.{some}
import parsley.cats.combinator.{some, sepBy1}
import cats.syntax.all._
import cats.data.NonEmptyList
import parsley.errors.DefaultErrorBuilder
import parsley.errors.ErrorBuilder
import parsley.errors.tokenextractors.LexToken
import parsley.expr.GOps
import cats.Functor
object parser {
import lexer.implicits.implicitSymbol
@@ -52,13 +56,24 @@ object parser {
implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken {
def tokens = errTokens
}
def parse(input: String): Result[String, Program] = parser.parse(input)
private val parser = lexer.fully(`<program>`)
def parse(input: String): Result[String, File => PartialProgram] = parser.parse(input)
private val parser = lexer.fully(`<partial-program>`)
private type FParsley[A] = Parsley[File => A]
private def fParsley[A](p: Parsley[A]): FParsley[A] =
p map { a => file => a }
private def fPair[A, B](p: Parsley[(File => A, File => B)]): FParsley[(A, B)] =
p map { case (a, b) => file => (a(file), b(file)) }
private def fMap[A, F[_]: Functor](p: Parsley[F[File => A]]): FParsley[F[A]] =
p map { funcs => file => funcs.map(_(file)) }
// Expressions
private lazy val `<expr>`: Parsley[Expr] = precedence {
SOps(InfixR)(Or from "||") +:
SOps(InfixR)(And from "&&") +:
private lazy val `<expr>`: FParsley[Expr] = precedence {
GOps(InfixR)(Or from "||") +:
GOps(InfixR)(And from "&&") +:
SOps(InfixN)(Eq from "==", Neq from "!=") +:
SOps(InfixN)(
Less from "<",
@@ -83,32 +98,33 @@ object parser {
}
// Atoms
private lazy val `<atom>`: Atoms[Expr6] = Atoms(
private lazy val `<atom>`: Atoms[File => Expr6] = Atoms(
IntLiter(integer).label("integer literal"),
BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"),
CharLiter(charLit).label("character literal"),
StrLiter(stringLit).label("string literal"),
`<str-liter>`.label("string literal"),
PairLiter from "null",
`<ident-or-array-elem>`,
Parens("(" ~> `<expr>` <~ ")")
)
private val `<ident>` =
private lazy val `<str-liter>` = StrLiter(stringLit)
private lazy val `<ident>` =
Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed")
private lazy val `<ident-or-array-elem>` =
(`<ident>` <~ ("(".verifiedExplain(
"functions can only be called using 'call' keyword"
) | unit)) <**> (`<array-indices>` </> identity)
private val `<array-indices>` = ArrayElem(some("[" ~> `<expr>` <~ "]"))
private lazy val `<array-indices>` = ArrayElem(fMap(some("[" ~> `<expr>` <~ "]")))
// Types
private lazy val `<type>`: Parsley[Type] =
private lazy val `<type>`: FParsley[Type] =
(`<base-type>` | (`<pair-type>` ~> `<pair-elems-type>`)) <**> (`<array-type>` </> identity)
private val `<base-type>` =
(IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string")
private lazy val `<array-type>` =
ArrayType(countSome("[" ~> "]"))
ArrayType(fParsley(countSome("[" ~> "]")))
private val `<pair-type>` = "pair"
private val `<pair-elems-type>`: Parsley[PairType] = PairType(
private val `<pair-elems-type>`: FParsley[PairType] = PairType(
"(" ~> `<pair-elem-type>` <~ ",",
`<pair-elem-type>` <~ ")"
)
@@ -116,7 +132,7 @@ object parser {
(`<base-type>` <**> (`<array-type>` </> identity)) |
((UntypedPairType from `<pair-type>`) <**>
((`<pair-elems-type>` <**> `<array-type>`)
.map(arr => (_: UntypedPairType) => arr) </> identity))
.map(arr => (_: File => UntypedPairType) => arr) </> identity))
/* Statements
Atomic is used in two places here:
@@ -127,13 +143,30 @@ object parser {
invalid syntax check, this only happens at most once per program so this is not a major
concern.
*/
private lazy val `<partial-program>` = PartialProgram(
fMap(many(`<import>`)),
`<program>`
)
private lazy val `<import>` = Import(
"import" ~> `<import-filename>`,
"(" ~> fMap(sepBy1(`<imported-func>`, ",")) <~ ")"
)
private lazy val `<import-filename>` = `<str-liter>`.label("import file name")
private lazy val `<imported-func>` = ImportedFunc(
`<ident>`.label("imported function name"),
fMap(option("as" ~> `<ident>`)).label("imported function alias")
)
private lazy val `<program>` = Program(
"begin" ~> (
fMap(
many(
fPair(
atomic(
`<type>`.label("function declaration") <~> `<ident>` <~ "("
)
) <**> `<partial-func-decl>`
).label("function declaration") |
).label("function declaration")
) |
atomic(`<ident>` <~ "(").verifiedExplain("function declaration is missing return type")
),
`<stmt>`.label(
@@ -142,17 +175,23 @@ object parser {
)
private lazy val `<partial-func-decl>` =
FuncDecl(
sepBy(`<param>`, ",") <~ ")" <~ "is",
`<stmt>`.guardAgainst {
case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement")
} <~ "end"
fPair(
(fMap(sepBy(`<param>`, ",")) <~ ")" <~ "is") <~>
(`<stmt>`.guardAgainst {
// TODO: passing in an arbitrary file works but is ugly
case stmts if !(stmts(File("."))).isReturning =>
Seq("all functions must end in a returning statement")
} <~ "end")
)
)
private lazy val `<param>` = Param(`<type>`, `<ident>`)
private lazy val `<stmt>`: Parsley[NonEmptyList[Stmt]] =
private lazy val `<stmt>`: FParsley[NonEmptyList[Stmt]] =
fMap(
(
`<basic-stmt>`.label("main program body"),
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
).zipped(NonEmptyList.apply)
)
private lazy val `<basic-stmt>` =
(Skip from "skip")
@@ -160,8 +199,8 @@ object parser {
| Free("free" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Return("return" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Exit("exit" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(false))
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(true))
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), fParsley(pure(false)))
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), fParsley(pure(true)))
| If(
"if" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "then",
`<stmt>` <~ "else",
@@ -185,9 +224,9 @@ object parser {
("call" ~> `<ident>`).verifiedExplain(
"function calls' results must be assigned to a variable"
)
private lazy val `<lvalue>`: Parsley[LValue] =
private lazy val `<lvalue>`: FParsley[LValue] =
`<pair-elem>` | `<ident-or-array-elem>`
private lazy val `<rvalue>`: Parsley[RValue] =
private lazy val `<rvalue>`: FParsley[RValue] =
`<array-liter>` |
NewPair(
"newpair" ~> "(" ~> `<expr>` <~ ",",
@@ -196,13 +235,13 @@ object parser {
`<pair-elem>` |
Call(
"call" ~> `<ident>` <~ "(",
sepBy(`<expr>`, ",") <~ ")"
fMap(sepBy(`<expr>`, ",")) <~ ")"
) | `<expr>`.labelWithType(LabelType.Expr)
private lazy val `<pair-elem>` =
Fst("fst" ~> `<lvalue>`.label("valid pair"))
| Snd("snd" ~> `<lvalue>`.label("valid pair"))
private lazy val `<array-liter>` = ArrayLiter(
"[" ~> sepBy(`<expr>`, ",") <~ "]"
"[" ~> fMap(sepBy(`<expr>`, ",")) <~ "]"
)
extension (stmts: NonEmptyList[Stmt]) {

View File

@@ -1,126 +1,276 @@
package wacc
import java.io.File
import scala.collection.mutable
import cats.effect.IO
import cats.implicits._
import cats.data.Chain
import cats.data.NonEmptyList
import parsley.{Failure, Success}
object renamer {
import ast._
import types._
private enum IdentType {
val MAIN = "$main"
enum IdentType {
case Func
case Var
}
private class Scope(
val current: mutable.Map[(String, IdentType), Ident],
val parent: Map[(String, IdentType), Ident]
case class ScopeKey(path: String, name: String, identType: IdentType)
case class ScopeValue(id: Ident, public: Boolean)
class Scope(
private val current: mutable.Map[ScopeKey, ScopeValue],
private val parent: Map[ScopeKey, ScopeValue],
guidStart: Int = 0,
val guidInc: Int = 1
) {
private var guid = guidStart
private var immutable = false
private def nextGuid(): Int = {
val id = guid
guid += guidInc
id
}
private def verifyMutable(): Unit = {
if (immutable) throw new IllegalStateException("Cannot modify an immutable scope")
}
/** Create a new scope with the current scope as its parent.
*
* To be used for single-threaded applications.
*
* @return
* A new scope with an empty current scope, and this scope flattened into the parent scope.
*/
def subscope: Scope =
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)))
def withSubscope[T](f: Scope => T): T = {
val subscope =
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)), guid, guidInc)
immutable = true
val result = f(subscope)
guid = subscope.guid // Sync GUID
immutable = false
result
}
/** Create new scopes with the current scope as its parent and GUID numbering adjusted
* correctly.
*
* This will permanently mark the current scope as immutable, for thread safety.
*
* To be used for multi-threaded applications.
*
* @return
* New scopes with an empty current scope, and this scope flattened into the parent scope.
*/
def subscopes(n: Int): Seq[Scope] = {
verifyMutable()
immutable = true
(0 until n).map { i =>
Scope(
mutable.Map.empty,
Map.empty.withDefault(current.withDefault(parent)),
guid + i * guidInc,
guidInc * n
)
}
}
/** Attempt to add a new identifier to the current scope. If the identifier already exists in
* the current scope, add an error to the error list.
*
* @param ty
* The semantic type of the variable identifier, or function identifier type.
* @param name
* The name of the identifier.
* @param globalNames
* The global map of identifiers to semantic types - the identifier will be added to this
* map.
* @param globalNumbering
* The global map of identifier names to the number of times they have been declared - will
* used to rename this identifier, and will be incremented.
* @param errors
* The list of errors to append to.
* @return
* An error, if one occurred.
*/
def add(ty: SemType | FuncType, name: Ident)(using
globalNames: mutable.Map[Ident, SemType],
globalFuncs: mutable.Map[Ident, FuncType],
globalNumbering: mutable.Map[String, Int],
errors: mutable.Builder[Error, List[Error]]
) = {
val identType = ty match {
def add(name: Ident, public: Boolean = false): Chain[Error] = {
verifyMutable()
val path = name.pos.file.getCanonicalPath
val identType = name.ty match {
case _: SemType => IdentType.Var
case _: FuncType => IdentType.Func
}
current.get((name.v, identType)) match {
case Some(Ident(_, uid)) =>
errors += Error.DuplicateDeclaration(name)
name.uid = uid
val key = ScopeKey(path, name.v, identType)
current.get(key) match {
case Some(ScopeValue(Ident(_, id, _), _)) =>
name.guid = id
Chain.one(Error.DuplicateDeclaration(name))
case None =>
val uid = globalNumbering.getOrElse(name.v, 0)
name.uid = uid
current((name.v, identType)) = name
ty match {
case semType: SemType =>
globalNames(name) = semType
case funcType: FuncType =>
globalFuncs(name) = funcType
}
globalNumbering(name.v) = uid + 1
name.guid = nextGuid()
current(key) = ScopeValue(name, public)
Chain.empty
}
}
private def get(name: String, identType: IdentType): Option[Ident] =
/** Attempt to add a new identifier as an alias to another to the existing scope.
*
* @param alias
* The (new) alias identifier.
* @param orig
* The (existing) original identifier.
*
* @return
* An error, if one occurred.
*/
def addAlias(alias: Ident, orig: ScopeValue, public: Boolean = false): Chain[Error] = {
verifyMutable()
val path = alias.pos.file.getCanonicalPath
val identType = alias.ty match {
case _: SemType => IdentType.Var
case _: FuncType => IdentType.Func
}
val key = ScopeKey(path, alias.v, identType)
current.get(key) match {
case Some(ScopeValue(Ident(_, id, _), _)) =>
alias.guid = id
Chain.one(Error.DuplicateDeclaration(alias))
case None =>
alias.guid = nextGuid()
current(key) = ScopeValue(orig.id, public)
Chain.empty
}
}
def get(path: String, name: String, identType: IdentType): Option[ScopeValue] =
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
try {
Some(current.withDefault(parent)((name, identType)))
Some(current.withDefault(parent)(ScopeKey(path, name, identType)))
} catch {
case _: NoSuchElementException => None
}
def getVar(name: String): Option[Ident] = get(name, IdentType.Var)
def getFunc(name: String): Option[Ident] = get(name, IdentType.Func)
def getVar(name: Ident): Option[Ident] =
get(name.pos.file.getCanonicalPath, name.v, IdentType.Var).map(_.id)
def getFunc(name: Ident): Option[Ident] =
get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id)
}
/** Check scoping of all variables and functions in the program. Also generate semantic types for
* all identifiers.
*
* @param prog
* AST of the program
* @param errors
* List of errors to append to
* @return
* Map of all (renamed) identifies to their semantic types
*/
def rename(prog: Program)(using
errors: mutable.Builder[Error, List[Error]]
): (Map[Ident, SemType], Map[Ident, FuncType]) = {
given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty
given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty
given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty
val scope = Scope(mutable.Map.empty, Map.empty)
def prepareGlobalScope(
partialProg: PartialProgram
)(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = {
def readImportFile(file: File): IO[String] =
IO.blocking(os.read(os.Path(file.getCanonicalPath)))
def prepareImport(contents: String, file: File)(using
scope: Scope
): IO[(Chain[FuncDecl], Chain[Error])] = {
parser.parse(contents) match {
case Failure(msg) =>
IO.pure(Chain.empty, Chain.one(Error.SyntaxError(file, msg)))
case Success(fn) =>
val partialProg = fn(file)
for {
(main, chunks, errors) <- prepareGlobalScope(partialProg)
} yield (main +: chunks, errors)
}
}
def addImportsToScope(importFile: File, funcs: NonEmptyList[ImportedFunc])(using
scope: Scope
): Chain[Error] =
funcs.foldMap { case ImportedFunc(srcName, aliasName) =>
scope.get(importFile.getCanonicalPath, srcName.v, IdentType.Func) match {
case Some(src) if src.public =>
aliasName.ty = src.id.ty
scope.addAlias(aliasName, src)
case _ =>
Chain.one(Error.UndefinedFunction(srcName))
}
}
val PartialProgram(imports, prog) = partialProg
// First prepare this file's functions...
val Program(funcs, main) = prog
funcs
// First add all function declarations to the scope
.map { case FuncDecl(retType, name, params, body) =>
val (funcChunks, funcErrors) = funcs.foldLeft((Chain.empty[FuncDecl], Chain.empty[Error])) {
case ((chunks, errors), func @ FuncDecl(retType, name, params, body)) =>
val paramTypes = params.map { param =>
val paramType = SemType(param.paramType)
param.name.ty = paramType
paramType
}
scope.add(FuncType(SemType(retType), paramTypes), name)
(params zip paramTypes, body)
name.ty = FuncType(SemType(retType), paramTypes)
(chunks :+ func, errors ++ scope.add(name, public = true))
}
// Only then rename the function bodies
// (functions can call one-another regardless of order of declaration)
.foreach { case (params, body) =>
val functionScope = scope.subscope
params.foreach { case (param, paramType) =>
functionScope.add(paramType, param.name)
// ...and main body.
val mainBodyIdent = Ident(MAIN, ty = FuncType(?, Nil))(main.head.pos)
val mainBodyErrors = scope.add(mainBodyIdent, public = false)
val mainBodyChunk = FuncDecl(IntType()(prog.pos), mainBodyIdent, Nil, main)(prog.pos)
// Now handle imports
val file = prog.pos.file
val preparedImports = imports.foldLeftM[IO, (Chain[FuncDecl], Chain[Error])](
(Chain.empty[FuncDecl], Chain.empty[Error])
) { case ((chunks, errors), Import(name, funcs)) =>
val importFile = File(file.getParent, name.v)
if (!importFile.exists()) {
IO.pure(
(
chunks,
errors :+ Error.SemanticError(
name.pos,
s"File not found: ${importFile.getCanonicalPath}"
)
)
)
} else if (!importFile.canRead()) {
IO.pure(
(
chunks,
errors :+ Error.SemanticError(
name.pos,
s"File not readable: ${importFile.getCanonicalPath}"
)
)
)
} else if (importFile.getCanonicalPath == file.getCanonicalPath) {
IO.pure(
(
chunks,
errors :+ Error.SemanticError(
name.pos,
s"Cannot import self: ${importFile.getCanonicalPath}"
)
)
)
} else if (scope.get(importFile.getCanonicalPath, MAIN, IdentType.Func).isDefined) {
IO.pure(chunks, errors ++ addImportsToScope(importFile, funcs))
} else {
for {
contents <- readImportFile(importFile)
(importChunks, importErrors) <- prepareImport(contents, importFile)
importAliasErrors = addImportsToScope(importFile, funcs)
} yield (chunks ++ importChunks, errors ++ importErrors)
}
body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params
}
main.toList.foreach(rename(scope))
(globalNames.toMap, globalFuncs.toMap)
for {
(importChunks, importErrors) <- preparedImports
allChunks = importChunks ++ funcChunks
allErrors = importErrors ++ funcErrors ++ mainBodyErrors
} yield (mainBodyChunk, allChunks, allErrors)
}
/** Check scoping of all variables and flatten a program. Also generates semantic types and parses
* any imported files.
*
* @param partialProg
* AST of the program
* @return
* (flattenedProg, errors)
*/
def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = {
val (FuncDecl(_, _, params, body), subscope) = funcScopePair
val paramErrors = params.foldMap(param => subscope.add(param.name))
IO(subscope.withSubscope { s => body.foldMap(rename(s)) })
.map(bodyErrors => paramErrors ++ bodyErrors)
}
/** Check scoping of all identifies in a given AST node.
@@ -129,49 +279,38 @@ object renamer {
* The current scope and flattened parent scope.
* @param node
* The AST node.
* @param globalNames
* The global map of identifiers to semantic types - renamed identifiers will be added to this
* map.
* @param globalNumbering
* The global map of identifier names to the number of times they have been declared - used and
* updated during identifier renaming.
* @param errors
*/
private def rename(scope: Scope)(
node: Ident | Stmt | LValue | RValue | Expr
)(using
globalNames: mutable.Map[Ident, SemType],
globalFuncs: mutable.Map[Ident, FuncType],
globalNumbering: mutable.Map[String, Int],
errors: mutable.Builder[Error, List[Error]]
): Unit = node match {
// These cases are more interesting because the involve making subscopes
private def rename(scope: Scope)(node: Ident | Stmt | LValue | RValue | Expr): Chain[Error] =
node match {
// These cases are more interes/globting because the involve making subscopes
// or modifying the current scope.
case VarDecl(synType, name, value) => {
// Order matters here. Variable isn't declared until after the value is evaluated.
rename(scope)(value)
val errors = rename(scope)(value)
// Attempt to add the new variable to the current scope.
scope.add(SemType(synType), name)
name.ty = SemType(synType)
errors ++ scope.add(name)
}
case If(cond, thenStmt, elseStmt) => {
rename(scope)(cond)
val condErrors = rename(scope)(cond)
// then and else both have their own scopes
thenStmt.toList.foreach(rename(scope.subscope))
elseStmt.toList.foreach(rename(scope.subscope))
val thenErrors = scope.withSubscope(s => thenStmt.foldMap(rename(s)))
val elseErrors = scope.withSubscope(s => elseStmt.foldMap(rename(s)))
condErrors ++ thenErrors ++ elseErrors
}
case While(cond, body) => {
rename(scope)(cond)
val condErrors = rename(scope)(cond)
// while bodies have their own scopes
body.toList.foreach(rename(scope.subscope))
val bodyErrors = scope.withSubscope(s => body.foldMap(rename(s)))
condErrors ++ bodyErrors
}
// begin-end blocks have their own scopes
case Block(body) => body.toList.foreach(rename(scope.subscope))
case Block(body) => scope.withSubscope(s => body.foldMap(rename(s)))
// These cases are simpler, mostly just recursive calls to rename()
case Assign(lhs, value) => {
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
rename(scope)(lhs)
rename(scope)(value)
rename(scope)(lhs) ++ rename(scope)(value)
}
case Read(lhs) => rename(scope)(lhs)
case Free(expr) => rename(scope)(expr)
@@ -179,41 +318,51 @@ object renamer {
case Exit(expr) => rename(scope)(expr)
case Print(expr, _) => rename(scope)(expr)
case NewPair(fst, snd) => {
rename(scope)(fst)
rename(scope)(snd)
rename(scope)(fst) ++ rename(scope)(snd)
}
case Call(name, args) => {
scope.getFunc(name.v) match {
case Some(Ident(_, uid)) => name.uid = uid
val nameErrors = scope.getFunc(name) match {
case Some(Ident(realName, guid, ty)) =>
name.v = realName
name.ty = ty
name.guid = guid
Chain.empty
case None =>
errors += Error.UndefinedFunction(name)
scope.add(FuncType(?, args.map(_ => ?)), name)
name.ty = FuncType(?, args.map(_ => ?))
scope.add(name)
Chain.one(Error.UndefinedFunction(name))
}
args.foreach(rename(scope))
val argsErrors = args.foldMap(rename(scope))
nameErrors ++ argsErrors
}
case Fst(elem) => rename(scope)(elem)
case Snd(elem) => rename(scope)(elem)
case ArrayLiter(elems) => elems.foreach(rename(scope))
case ArrayLiter(elems) => elems.foldMap(rename(scope))
case ArrayElem(name, indices) => {
rename(scope)(name)
indices.toList.foreach(rename(scope))
val nameErrors = rename(scope)(name)
val indicesErrors = indices.foldMap(rename(scope))
nameErrors ++ indicesErrors
}
case Parens(expr) => rename(scope)(expr)
case op: UnaryOp => rename(scope)(op.x)
case op: BinaryOp => {
rename(scope)(op.x)
rename(scope)(op.y)
rename(scope)(op.x) ++ rename(scope)(op.y)
}
// Default to variables. Only `call` uses IdentType.Func.
case id: Ident => {
scope.getVar(id.v) match {
case Some(Ident(_, uid)) => id.uid = uid
scope.getVar(id) match {
case Some(Ident(_, guid, ty)) =>
id.ty = ty
id.guid = guid
Chain.empty
case None =>
errors += Error.UndeclaredVariable(id)
scope.add(?, id)
id.ty = ?
scope.add(id)
Chain.one(Error.UndeclaredVariable(id))
}
}
// These literals cannot contain identifies, exit immediately.
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => ()
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() =>
Chain.empty
}
}

View File

@@ -0,0 +1,42 @@
package wacc
import scala.collection.mutable
import cats.implicits._
import cats.data.Chain
import cats.effect.IO
object semantics {
import renamer.{Scope, prepareGlobalScope, renameFunction}
import typeChecker.checkFuncDecl
private def checkFunc(
funcDecl: ast.FuncDecl,
scope: Scope
): IO[(microWacc.FuncDecl, Chain[Error])] = {
for {
renamerErrors <- renameFunction(funcDecl, scope)
(microWaccFunc, typeErrors) = checkFuncDecl(funcDecl)
} yield (microWaccFunc, renamerErrors ++ typeErrors)
}
def check(partialProg: ast.PartialProgram): IO[(microWacc.Program, Chain[Error])] = {
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
for {
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
toRename = (main +: chunks).toList
res <- toRename
.zip(scope.subscopes(toRename.size))
.parTraverse(checkFunc)
(typedChunks, errors) = res.foldLeft((Chain.empty[microWacc.FuncDecl], Chain.empty[Error])) {
case ((acc, err), (funcDecl, errors)) =>
(acc :+ funcDecl, err ++ errors)
}
(typedMain, funcs) = typedChunks.uncons match {
case Some((head, tail)) => (head.body, tail)
case None => (Chain.empty, Chain.empty)
}
} yield (microWacc.Program(funcs, typedMain)(main.pos), globalErrors ++ errors)
}
}

View File

@@ -1,25 +1,12 @@
package wacc
import cats.syntax.all._
import scala.collection.mutable
import cats.data.NonEmptyList
import cats.data.Chain
object typeChecker {
import wacc.types._
case class TypeCheckerCtx(
globalNames: Map[ast.Ident, SemType],
globalFuncs: Map[ast.Ident, FuncType],
errors: mutable.Builder[Error, List[Error]]
) {
def typeOf(ident: ast.Ident): SemType = globalNames(ident)
def funcType(ident: ast.Ident): FuncType = globalFuncs(ident)
def error(err: Error): SemType =
errors += err
?
}
private enum Constraint {
case Unconstrained
// Allows weakening in one direction
@@ -43,31 +30,29 @@ object typeChecker {
* @return
* The type if the constraint was satisfied, or ? if it was not.
*/
private def satisfies(constraint: Constraint, pos: ast.Position)(using
ctx: TypeCheckerCtx
): SemType =
private def satisfies(constraint: Constraint, pos: ast.Position): (SemType, Chain[Error]) =
(ty, constraint) match {
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String
(KnownType.String, Chain.empty)
case (
KnownType.String,
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
) =>
KnownType.String
(KnownType.String, Chain.empty)
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
ty.satisfies(Constraint.Is(ty2, msg), pos)
// Change to IsUnweakenable to disallow recursive weakening
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty
case (ty, Constraint.Unconstrained) => (ty, Chain.empty)
case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg))
(?, Chain.one(Error.SemanticError(pos, msg)))
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
(?, Chain.one(Error.TypeMismatch(pos, ty1, ty, msg)))
}
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
(?, Chain.one(Error.TypeMismatch(pos, ty2, ty, msg)))
}
}
@@ -91,36 +76,29 @@ object typeChecker {
}
}
/** Type-check a WACC program.
/** Type-check a function declaration.
*
* @param prog
* The AST of the program to type-check.
* @param ctx
* The type checker context which includes the global names and functions, and an errors
* builder.
* @param func
* The AST of the function to type-check.
*/
def check(prog: ast.Program)(using
ctx: TypeCheckerCtx
): microWacc.Program =
microWacc.Program(
// Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer.
prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) =>
val FuncType(retType, paramTypes) = ctx.funcType(name)
def checkFuncDecl(func: ast.FuncDecl): (microWacc.FuncDecl, Chain[Error]) = {
val ast.FuncDecl(_, name, params, stmts) = func
val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType]
val returnConstraint =
if func.name.v == renamer.MAIN then Constraint.Never("main body must not return")
else Constraint.Is(retType, s"function ${name.v} must return $retType")
val (body, bodyErrors) = stmts.foldMap(checkStmt(_, returnConstraint))
(
microWacc.FuncDecl(
microWacc.Ident(name.v, name.uid)(retType),
microWacc.Ident(name.v, name.guid)(retType),
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
microWacc.Ident(ident.v, ident.uid)(ty)
microWacc.Ident(ident.v, ident.guid)(ty)
},
stmts.toList
.flatMap(
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
)
)
},
prog.main.toList
.flatMap(checkStmt(_, Constraint.Never("main function must not return")))
body
)(func.pos),
bodyErrors
)
}
/** Type-check an AST statement node.
*
@@ -129,40 +107,54 @@ object typeChecker {
* @param returnConstraint
* The constraint that any `return <expr>` statements must satisfy.
*/
private def checkStmt(stmt: ast.Stmt, returnConstraint: Constraint)(using
ctx: TypeCheckerCtx
): List[microWacc.Stmt] = stmt match {
private def checkStmt(
stmt: ast.Stmt,
returnConstraint: Constraint
): (Chain[microWacc.Stmt], Chain[Error]) = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case ast.VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name)
val typedValue = checkValue(
val expectedTy = name.ty
val (typedValue, valueErrors) = checkValue(
value,
Constraint.Is(
expectedTy,
expectedTy.asInstanceOf[SemType],
s"variable ${name.v} must be assigned a value of type $expectedTy"
)
)
List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue))
(
Chain.one(
microWacc.Assign(
microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]),
typedValue
)(stmt.pos)
),
valueErrors
)
case ast.Assign(lhs, rhs) =>
val lhsTyped = checkLValue(lhs, Constraint.Unconstrained)
val rhsTyped =
val (lhsTyped, lhsErrors) = checkLValue(lhs, Constraint.Unconstrained)
val (rhsTyped, rhsErrors) =
checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}"))
(lhsTyped.ty, rhsTyped.ty) match {
val unknownError = (lhsTyped.ty, rhsTyped.ty) match {
case (?, ?) =>
ctx.error(
Chain.one(
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
)
case _ => ()
case _ => Chain.empty
}
List(microWacc.Assign(lhsTyped, rhsTyped))
(
Chain.one(microWacc.Assign(lhsTyped, rhsTyped)(stmt.pos)),
lhsErrors ++ rhsErrors ++ unknownError
)
case ast.Read(dest) =>
val destTyped = checkLValue(dest, Constraint.Unconstrained)
val destTy = destTyped.ty match {
val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained)
val (destTy, destTyErrors) = destTyped.ty match {
case ? =>
ctx.error(
(
?,
Chain.one(
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
)
?
)
case destTy =>
destTy.satisfies(
Constraint.IsEither(
@@ -173,27 +165,26 @@ object typeChecker {
dest.pos
)
}
List(
(
Chain.one(
microWacc.Assign(
destTyped,
microWacc.Call(
microWacc.Builtin.Read,
List(
destTy match {
case KnownType.Int => " %d".toMicroWaccCharArray
case KnownType.Char | _ => " %c".toMicroWaccCharArray
case KnownType.Int => " %d".toMicroWaccCharArray(stmt.pos)
case KnownType.Char | _ => " %c".toMicroWaccCharArray(stmt.pos)
},
destTyped
)
)
)
)(dest.pos)
)(stmt.pos)
),
destErrors ++ destTyErrors
)
case ast.Free(lhs) =>
List(
microWacc.Call(
microWacc.Builtin.Free,
List(
checkValue(
val (lhsTyped, lhsErrors) = checkValue(
lhs,
Constraint.IsEither(
KnownType.Array(?),
@@ -201,21 +192,17 @@ object typeChecker {
"free must be applied to an array or pair"
)
)
)
)
)
(Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))(stmt.pos)), lhsErrors)
case ast.Return(expr) =>
List(microWacc.Return(checkValue(expr, returnConstraint)))
val (exprTyped, exprErrors) = checkValue(expr, returnConstraint)
(Chain.one(microWacc.Return(exprTyped)(stmt.pos)), exprErrors)
case ast.Exit(expr) =>
List(
microWacc.Call(
microWacc.Builtin.Exit,
List(checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")))
)
)
val (exprTyped, exprErrors) =
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
(Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))(stmt.pos)), exprErrors)
case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained)
val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained)
val exprFormat = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Array(KnownType.Char) => "%.*s"
@@ -224,46 +211,48 @@ object typeChecker {
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
}
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
List(
Chain.one(
microWacc.Call(
func,
List(
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray,
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray(stmt.pos),
value
)
)
)(stmt.pos)
)
}
(
exprTyped.ty match {
case KnownType.Bool =>
List(
Chain.one(
microWacc.If(
exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray)
)
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray(stmt.pos)),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray(stmt.pos))
)(stmt.pos)
)
case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
}
case ast.If(cond, thenStmt, elseStmt) =>
List(
microWacc.If(
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool")),
thenStmt.toList.flatMap(checkStmt(_, returnConstraint)),
elseStmt.toList.flatMap(checkStmt(_, returnConstraint))
},
exprErrors
)
case ast.If(cond, thenStmt, elseStmt) =>
val (condTyped, condErrors) =
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint))
val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint))
(
Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)(cond.pos)),
condErrors ++ thenErrors ++ elseErrors
)
case ast.While(cond, body) =>
List(
microWacc.While(
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")),
body.toList.flatMap(checkStmt(_, returnConstraint))
)
)
case ast.Block(body) => body.toList.flatMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => List.empty
val (condTyped, condErrors) =
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint))
(Chain.one(microWacc.While(condTyped, bodyTyped)(cond.pos)), condErrors ++ bodyErrors)
case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => (Chain.empty, Chain.empty)
}
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
@@ -276,127 +265,145 @@ object typeChecker {
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkValue(value: ast.LValue | ast.RValue | ast.Expr, constraint: Constraint)(using
ctx: TypeCheckerCtx
): microWacc.Expr = value match {
private def checkValue(
value: ast.LValue | ast.RValue | ast.Expr,
constraint: Constraint
): (microWacc.Expr, Chain[Error]) = value match {
case l @ ast.IntLiter(v) =>
KnownType.Int.satisfies(constraint, l.pos)
microWacc.IntLiter(v)
val (_, errors) = KnownType.Int.satisfies(constraint, l.pos)
(microWacc.IntLiter(v), errors)
case l @ ast.BoolLiter(v) =>
KnownType.Bool.satisfies(constraint, l.pos)
microWacc.BoolLiter(v)
val (_, errors) = KnownType.Bool.satisfies(constraint, l.pos)
(microWacc.BoolLiter(v), errors)
case l @ ast.CharLiter(v) =>
KnownType.Char.satisfies(constraint, l.pos)
microWacc.CharLiter(v)
val (_, errors) = KnownType.Char.satisfies(constraint, l.pos)
(microWacc.CharLiter(v), errors)
case l @ ast.StrLiter(v) =>
KnownType.String.satisfies(constraint, l.pos)
v.toMicroWaccCharArray
val (_, errors) = KnownType.String.satisfies(constraint, l.pos)
(v.toMicroWaccCharArray(l.pos), errors)
case l @ ast.PairLiter() =>
microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos))
val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos)
(microWacc.NullLiter()(ty), errors)
case ast.Parens(expr) => checkValue(expr, constraint)
case l @ ast.ArrayLiter(elems) =>
val (elemTy, elemsTyped) = elems.mapAccumulate[SemType, microWacc.Expr](?) {
case (acc, elem) =>
val elemTyped = checkValue(
val ((elemTy, elemsErrors), elemsTyped) =
elems.mapAccumulate[(SemType, Chain[Error]), microWacc.Expr]((?, Chain.empty)) {
case ((acc, errors), elem) =>
val (elemTyped, elemErrors) = checkValue(
elem,
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
)
(elemTyped.ty, elemTyped)
((elemTyped.ty, errors ++ elemErrors), elemTyped)
}
val arrayTy = KnownType
val (arrayTy, arrayErrors) = KnownType
// Start with an unknown param type, make it more specific while checking the elements.
.Array(elemTy)
.satisfies(constraint, l.pos)
microWacc.ArrayLiter(elemsTyped)(arrayTy)
(microWacc.ArrayLiter(elemsTyped)(arrayTy, l.pos), elemsErrors ++ arrayErrors)
case l @ ast.NewPair(fst, snd) =>
val fstTyped = checkValue(fst, Constraint.Unconstrained)
val sndTyped = checkValue(snd, Constraint.Unconstrained)
microWacc.ArrayLiter(List(fstTyped, sndTyped))(
val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained)
val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained)
val (pairTy, pairErrors) =
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
(
microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy, l.pos),
fstErrors ++ sndErrors ++ pairErrors
)
case ast.Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
}
val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType]
val lenError =
if (args.length == paramTys.length) then Chain.empty
else Chain.one(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
// Even if the number of arguments is wrong, we still check the types of the arguments
// in the best way we can (by taking a zip).
val argsTyped = args.zip(paramTys).map { case (arg, paramTy) =>
val (argsErrors, argsTyped) =
args.zip(paramTys).mapAccumulate(Chain.empty[Error]) { case (errors, (arg, paramTy)) =>
val (argTyped, argErrors) =
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
(errors ++ argErrors, argTyped)
}
microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped)
val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos)
(
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped)(id.pos),
lenError ++ argsErrors ++ retErrors
)
// Unary operators
case ast.Negate(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int")),
microWacc.UnaryOperator.Negate
)(KnownType.Int.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Negate)(retTy), argErrors ++ retErrors)
case ast.Not(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool")),
microWacc.UnaryOperator.Not
)(KnownType.Bool.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Not)(retTy), argErrors ++ retErrors)
case ast.Len(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array")),
microWacc.UnaryOperator.Len
)(KnownType.Int.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Len)(retTy), argErrors ++ retErrors)
case ast.Ord(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char")),
microWacc.UnaryOperator.Ord
)(KnownType.Int.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Ord)(retTy), argErrors ++ retErrors)
case ast.Chr(x) =>
microWacc.UnaryOp(
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int")),
microWacc.UnaryOperator.Chr
)(KnownType.Char.satisfies(constraint, x.pos))
val (argTyped, argErrors) =
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
val (retTy, retErrors) = KnownType.Char.satisfies(constraint, x.pos)
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Chr)(retTy), argErrors ++ retErrors)
// Binary operators
case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) =>
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
microWacc.BinaryOp(
checkValue(op.x, operand),
checkValue(op.y, operand),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Int.satisfies(constraint, op.pos))
val (xTyped, xErrors) = checkValue(op.x, operand)
val (yTyped, yErrors) = checkValue(op.y, operand)
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.Eq | ast.Neq) =>
val xTyped = checkValue(op.x, Constraint.Unconstrained)
microWacc.BinaryOp(
xTyped,
checkValue(
val (xTyped, xErrors) = checkValue(op.x, Constraint.Unconstrained)
val (yTyped, yErrors) = checkValue(
op.y,
Constraint
.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
Constraint.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
)
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) =>
val xConstraint = Constraint.IsEither(
KnownType.Int,
KnownType.Char,
s"${op.name} operator must be applied to an int or char"
)
val xTyped = checkValue(op.x, xConstraint)
val (xTyped, xErrors) = checkValue(op.x, xConstraint)
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
val yConstraint = xTyped.ty match {
case ? => xConstraint
case xTy =>
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
}
microWacc.BinaryOp(
xTyped,
checkValue(op.y, yConstraint),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
val (yTyped, yErrors) = checkValue(op.y, yConstraint)
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case op: (ast.And | ast.Or) =>
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
microWacc.BinaryOp(
checkValue(op.x, operand),
checkValue(op.y, operand),
microWacc.BinaryOperator.fromAst(op)
)(KnownType.Bool.satisfies(constraint, op.pos))
val (xTyped, xErrors) = checkValue(op.x, operand)
val (yTyped, yErrors) = checkValue(op.y, operand)
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
(
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
xErrors ++ yErrors ++ retErrors
)
case lvalue: ast.LValue => checkLValue(lvalue, constraint)
}
@@ -413,20 +420,27 @@ object typeChecker {
* @return
* The most specific type of the value if it could be determined, or ? if it could not.
*/
private def checkLValue(value: ast.LValue, constraint: Constraint)(using
ctx: TypeCheckerCtx
): microWacc.LValue = value match {
case id @ ast.Ident(name, uid) =>
microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos))
private def checkLValue(
value: ast.LValue,
constraint: Constraint
): (microWacc.LValue, Chain[Error]) = value match {
case id @ ast.Ident(name, guid, ty) =>
val (idTy, idErrors) = ty.asInstanceOf[SemType].satisfies(constraint, id.pos)
(microWacc.Ident(name, guid)(idTy), idErrors)
case ast.ArrayElem(id, indices) =>
val arrayTy = ctx.typeOf(id)
val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) =>
val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
val next = acc match {
case KnownType.Array(innerTy) => innerTy
case ? => ? // we can keep indexing an unknown type
val arrayTy = id.ty.asInstanceOf[SemType]
val ((elemTy, elemErrors), indicesTyped) =
indices.mapAccumulate((arrayTy.asInstanceOf[SemType], Chain.empty[Error])) {
case ((acc, errors), elem) =>
val (idxTyped, idxErrors) =
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
val (next, nextError) = acc match {
case KnownType.Array(innerTy) => (innerTy, Chain.empty)
case ? => (?, Chain.empty) // we can keep indexing an unknown type
case nonArrayTy =>
ctx.error(
(
?,
Chain.one(
Error.TypeMismatch(
elem.pos,
KnownType.Array(?),
@@ -434,49 +448,45 @@ object typeChecker {
"cannot index into a non-array"
)
)
?
)
}
(next, idxTyped)
((next, errors ++ idxErrors ++ nextError), idxTyped)
}
val (retTy, retErrors) = elemTy.satisfies(constraint, value.pos)
val firstArrayElem = microWacc.ArrayElem(
microWacc.Ident(id.v, id.uid)(arrayTy),
microWacc.Ident(id.v, id.guid)(arrayTy),
indicesTyped.head
)(elemTy.satisfies(constraint, value.pos))
)(retTy)
val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) =>
microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty))
}
// Need to type-check the final arrayElem with the constraint
microWacc.ArrayElem(arrayElem.value, arrayElem.index)(elemTy.satisfies(constraint, value.pos))
// TODO: What
(microWacc.ArrayElem(arrayElem.value, arrayElem.index)(retTy), elemErrors ++ retErrors)
case ast.Fst(elem) =>
val elemTyped = checkLValue(
val (elemTyped, elemErrors) = checkLValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
)
microWacc.ArrayElem(
elemTyped,
microWacc.IntLiter(0)
)(elemTyped.ty match {
case KnownType.Pair(left, _) =>
left.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
})
val (retTy, retErrors) = elemTyped.ty match {
case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos)
case _ => (?, Chain.one(Error.InternalError(elem.pos, "fst must be applied to a pair")))
}
(microWacc.ArrayElem(elemTyped, microWacc.IntLiter(0))(retTy), elemErrors ++ retErrors)
case ast.Snd(elem) =>
val elemTyped = checkLValue(
val (elemTyped, elemErrors) = checkLValue(
elem,
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
)
microWacc.ArrayElem(
elemTyped,
microWacc.IntLiter(1)
)(elemTyped.ty match {
case KnownType.Pair(_, right) =>
right.satisfies(constraint, elem.pos)
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
})
val (retTy, retErrors) = elemTyped.ty match {
case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
case _ => (?, Chain.one(Error.InternalError(elem.pos, "snd must be applied to a pair")))
}
(microWacc.ArrayElem(elemTyped, microWacc.IntLiter(1))(retTy), elemErrors ++ retErrors)
}
extension (s: String) {
def toMicroWaccCharArray: microWacc.ArrayLiter =
microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String)
def toMicroWaccCharArray(pos: ast.Position): microWacc.ArrayLiter =
microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String, pos)
}
}

View File

@@ -3,7 +3,9 @@ package wacc
object types {
import ast._
sealed trait SemType {
sealed trait RenamerType
sealed trait SemType extends RenamerType {
override def toString(): String = this match {
case KnownType.Int => "int"
case KnownType.Bool => "bool"
@@ -41,5 +43,5 @@ object types {
}
}
case class FuncType(returnType: SemType, params: List[SemType])
case class FuncType(returnType: SemType, params: List[SemType]) extends RenamerType
}

View File

@@ -0,0 +1,66 @@
package wacc
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.Inspectors.forEvery
import cats.data.Chain
class ExtensionsSpec extends AnyFlatSpec {
import asmGenerator.concatAll
import asmGenerator.escaped
behavior of "concatAll"
it should "handle int chains" in {
val chain = Chain(1, 2, 3).concatAll(
Chain(4, 5, 6),
Chain.empty,
Chain.one(-1)
)
assert(chain == Chain(1, 2, 3, 4, 5, 6, -1))
}
it should "handle AsmLine chains" in {
object lines {
import assemblyIR._
import assemblyIR.commonRegisters._
val main = LabelDef("main")
val pop = Pop(RAX)
val add = Add(RAX, ImmediateVal(1))
val push = Push(RAX)
val ret = Return()
}
val chain = Chain(lines.main).concatAll(
Chain.empty,
Chain.one(lines.pop),
Chain(lines.add, lines.push),
Chain.one(lines.ret)
)
assert(chain == Chain(lines.main, lines.pop, lines.add, lines.push, lines.ret))
}
behavior of "escaped"
val escapedStrings = Map(
"hello" -> "hello",
"world" -> "world",
"hello\nworld" -> "hello\\nworld",
"hello\tworld" -> "hello\\tworld",
"hello\\world" -> "hello\\\\world",
"hello\"world" -> "hello\\\"world",
"hello'world" -> "hello\\'world",
"hello\\nworld" -> "hello\\\\nworld",
"hello\\tworld" -> "hello\\\\tworld",
"hello\\\\world" -> "hello\\\\\\\\world",
"hello\\\"world" -> "hello\\\\\\\"world",
"hello\\'world" -> "hello\\\\\\'world",
"hello\\n\\t\\'world" -> "hello\\\\n\\\\t\\\\\\'world",
"hello\u0000world" -> "hello\\0world",
"hello\bworld" -> "hello\\bworld",
"hello\fworld" -> "hello\\fworld"
)
forEvery(escapedStrings) { (input, expected) =>
it should s"return $expected" in {
assert(input.escaped == expected)
}
}
}

View File

@@ -0,0 +1,85 @@
package wacc
import org.scalatest.flatspec.AnyFlatSpec
class LabelGeneratorSpec extends AnyFlatSpec {
import microWacc._
import assemblyIR.{LabelDef, LabelArg, Directive}
import types.?
"getLabel" should "return unique labels" in {
val labelGenerator = new LabelGenerator
val labels = (1 to 10).map(_ => labelGenerator.getLabel())
assert(labels.distinct.length == labels.length)
}
"getLabelDef" should "return unique labels" in {
assert(
LabelDef("test") == LabelDef("test") &&
LabelDef("test").hashCode == LabelDef("test").hashCode,
"Sanity check: LabelDef should be case-classes"
)
val labelGenerator = new LabelGenerator
val labels = (List(
Builtin.Exit,
Builtin.Printf,
Ident("exit", 0)(?),
Ident("test", 0)(?)
) ++ RuntimeError.all.toList).map(labelGenerator.getLabelDef(_))
assert(labels.distinct.length == labels.length)
}
"getLabelArg" should "return unique labels" in {
assert(
LabelArg("test") == LabelArg("test") &&
LabelArg("test").hashCode == LabelArg("test").hashCode,
"Sanity check: LabelArg should be case-classes"
)
val labelGenerator = new LabelGenerator
val labels = (List(
Builtin.Exit,
Builtin.Printf,
Ident("exit", 0)(?),
Ident("test", 0)(?),
"test",
"test",
"test3"
) ++ RuntimeError.all.toList).map {
case s: String => labelGenerator.getLabelArg(s)
case t: (CallTarget | RuntimeError) => labelGenerator.getLabelArg(t)
}
assert(labels.distinct.length == labels.distinct.length)
}
it should "return consistent labels to getLabelDef" in {
val labelGenerator = new LabelGenerator
val targets = (List(
Builtin.Exit,
Builtin.Printf,
Ident("exit", 0)(?),
Ident("test", 0)(?)
) ++ RuntimeError.all.toList)
val labelDefs = targets.map(labelGenerator.getLabelDef(_).toString.dropRight(1)).toSet
val labelArgs = targets.map(labelGenerator.getLabelArg(_).toString).toSet
assert(labelDefs == labelArgs)
}
"generateConstants" should "generate de-duplicated labels for strings" in {
val labelGenerator = new LabelGenerator
val strings = List("hello", "world", "hello\u0000world", "hello", "Hello")
val distincts = strings.distinct.length
val labels = strings.map(labelGenerator.getLabelArg(_).toString).toSet
val asmLines = labelGenerator.generateConstants
assert(
asmLines.collect { case LabelDef(name) =>
name
}.length == distincts
)
assert(
asmLines.collect { case Directive.Asciz(str) => str }.length == distincts
)
assert(asmLines.collect { case LabelDef(name) => name }.toList.toSet == labels)
}
}

View File

@@ -0,0 +1,140 @@
package wacc
import org.scalatest.flatspec.AnyFlatSpec
import cats.data.Chain
class StackSpec extends AnyFlatSpec {
import microWacc._
import assemblyIR._
import assemblyIR.Size._
import assemblyIR.commonRegisters._
import types.{KnownType, ?}
import sizeExtensions.size
private val RSP = Register(Q64, RegName.SP)
"size" should "be 0 initially" in {
val stack = new Stack
assert(stack.size == 0)
}
"push" should "add an expression to the stack" in {
val stack = new Stack
val expr = Ident("x", 0)(?)
val result = stack.push(expr, RAX)
assert(stack.size == 1)
assert(result == Push(RAX))
}
it should "add 2 expressions to the stack" in {
val stack = new Stack
val expr1 = Ident("x", 0)(?)
val expr2 = Ident("x", 1)(?)
val result1 = stack.push(expr1, RAX)
val result2 = stack.push(expr2, RCX)
assert(stack.size == 2)
assert(result1 == Push(RAX))
assert(result2 == Push(RCX))
}
it should "add a value to the stack" in {
val stack = new Stack
val result = stack.push(D32, RAX)
assert(stack.size == 1)
assert(result == Push(RAX))
}
"reserve" should "reserve space for an identifier" in {
val stack = new Stack
val ident = Ident("x", 0)(KnownType.Int)
val result = stack.reserve(ident)
assert(stack.size == 1)
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
}
it should "reserve space for a register" in {
val stack = new Stack
val result = stack.reserve(RAX)
assert(stack.size == 1)
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
}
it should "reserve space for multiple values" in {
val stack = new Stack
val result = stack.reserve(D32, Q64, B8)
assert(stack.size == 3)
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt * 3)))
}
"pop" should "remove the last value from the stack" in {
val stack = new Stack
stack.push(D32, RAX)
val result = stack.pop(RAX)
assert(stack.size == 0)
assert(result == Pop(RAX))
}
"drop" should "remove the last 2 value from the stack" in {
val stack = new Stack
stack.push(D32, RAX)
stack.push(Q64, RAX)
stack.push(B8, RAX)
val result = stack.drop(2)
assert(stack.size == 1)
assert(result == Add(RSP, ImmediateVal(Q64.toInt * 2)))
}
"withScope" should "reset stack after block" in {
val stack = new Stack
stack.push(D32, RAX)
stack.push(Q64, RCX)
stack.push(B8, RDX)
val result = stack.withScope(() =>
Chain(
stack.push(Q64, RSI),
stack.push(B8, RDI),
stack.push(B8, RBP)
)
)
assert(stack.size == 3)
assert(
result == Chain(
Push(RSI),
Push(RDI),
Push(RBP),
Add(RSP, ImmediateVal(Q64.toInt * 3))
)
)
}
"accessVar" should "return the correctly-sized memory location for the identifier" in {
val stack = new Stack
val id = Ident("x", 0)(KnownType.Int)
stack.push(Q64, RAX)
stack.push(id, RCX)
stack.push(B8, RDX)
stack.push(D32, RSI)
val result = stack.accessVar(Ident("x", 0)(KnownType.Int))
assert(result == MemLocation(RSP, Q64.toInt * 2, opSize = Some(KnownType.Int.size)))
}
"contains" should "return true if the stack contains the identifier" in {
val stack = new Stack
val id = Ident("x", 0)(KnownType.Int)
stack.push(D32, RAX)
stack.push(id, RCX)
stack.push(B8, RDX)
assert(stack.contains(id))
assert(!stack.contains(Ident("x", 1)(KnownType.Int)))
}
"head" should "return the correct memory location for the last element" in {
val stack = new Stack
val id = Ident("x", 0)(KnownType.Int)
stack.push(D32, RAX)
stack.push(id, RCX)
stack.push(B8, RDX)
val result = stack.head
assert(result == MemLocation(RSP, opSize = Some(B8)))
}
}

View File

@@ -1,14 +1,19 @@
package wacc
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.Inspectors.forEvery
import org.scalatest.matchers.should.Matchers._
import org.scalatest.freespec.AsyncFreeSpec
import cats.effect.testing.scalatest.AsyncIOSpec
import java.io.File
import java.nio.file.Path
import sys.process._
import java.io.PrintStream
import scala.io.Source
import cats.effect.IO
import wacc.{compile as compileWacc}
class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAndAfterAll {
class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
val files =
allWaccFiles("wacc-examples/valid").map { p =>
(p.toString, List(0))
@@ -21,52 +26,52 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
} ++
allWaccFiles("wacc-examples/invalid/whack").map { p =>
(p.toString, List(100, 200))
} ++
allWaccFiles("extension/examples/valid").map { p =>
(p.toString, List(0))
} ++
allWaccFiles("extension/examples/invalid/syntax").map { p =>
(p.toString, List(100))
} ++
allWaccFiles("extension/examples/invalid/semantics").map { p =>
(p.toString, List(200))
}
// tests go here
forEvery(files) { (filename, expectedResult) =>
val baseFilename = filename.stripSuffix(".wacc")
given stdout: PrintStream = PrintStream(File(baseFilename + ".out"))
s"$filename" should "be compiled with correct result" in {
val result = compile(filename)
assert(expectedResult.contains(result))
s"$filename" - {
"should be compiled with correct result" in {
if (fileIsPendingFrontend(filename))
IO.pure(pending)
else
compileWacc(Path.of(filename), outputDir = None, log = false).map { result =>
expectedResult should contain(result)
}
}
if (expectedResult == List(0)) it should "run with correct result" in {
if (fileIsDisallowedBackend(filename)) pending
// Retrieve contents to get input and expected output + exit code
val contents = scala.io.Source.fromFile(File(filename)).getLines.toList
val inputLine =
contents
.find(_.matches("^# ?[Ii]nput:.*$"))
.map(_.split(":").last.strip + "\n")
.getOrElse("")
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
val expectedOutput =
if (outputLineIdx == -1) ""
if (expectedResult == List(0)) {
"should run with correct result" in {
if (fileIsDisallowedBackend(filename))
IO.pure(succeed)
else if (fileIsPendingBackend(filename))
IO.pure(pending)
else
contents
.drop(outputLineIdx + 1)
.takeWhile(_.startsWith("#"))
.map(_.stripPrefix("#").stripLeading)
.mkString("\n")
for {
contents <- IO(Source.fromFile(File(filename)).getLines.toList)
inputLine = extractInput(contents)
expectedOutput = extractOutput(contents)
expectedExit = extractExit(contents)
val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$"))
val expectedExit =
if (exitLineIdx == -1) 0
else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt
asmFilename = baseFilename + ".s"
execFilename = baseFilename
gccResult <- IO(s"gcc -o $execFilename -z noexecstack $asmFilename".!)
// Assembly and link using gcc
val asmFilename = baseFilename + ".s"
val execFilename = baseFilename
val gccResult = s"gcc -o $execFilename -z noexecstack $asmFilename".!
assert(gccResult == 0)
_ = assert(gccResult == 0)
// Run the executable with the provided input
val stdout = new StringBuilder
val process = s"timeout 5s $execFilename" run ProcessIO(
stdout <- IO.pure(new StringBuilder)
process <- IO {
s"timeout 5s $execFilename" run ProcessIO(
in = w => {
w.write(inputLine.getBytes)
w.close()
@@ -74,44 +79,66 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
out = Source.fromInputStream(_).addString(stdout),
err = _ => ()
)
}
assert(process.exitValue == expectedExit)
assert(
stdout.toString
.replaceAll("0x[0-9a-f]+", "#addrs#")
.replaceAll("fatal error:.*", "#runtime_error#\u0000")
.takeWhile(_ != '\u0000')
== expectedOutput
)
exitCode <- IO.pure(process.exitValue)
} yield {
exitCode shouldBe expectedExit
normalizeOutput(stdout.toString) shouldBe expectedOutput
}
}
}
}
}
def allWaccFiles(dir: String): IndexedSeq[os.Path] =
val d = java.io.File(dir)
os.walk(os.Path(d.getAbsolutePath)).filter { _.ext == "wacc" }
os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc")
def fileIsDisallowedBackend(filename: String): Boolean =
Seq(
// format: off
// disable formatting to avoid binPack
"^.*wacc-examples/valid/advanced.*$",
// "^.*wacc-examples/valid/array.*$",
// "^.*wacc-examples/valid/basic/exit.*$",
// "^.*wacc-examples/valid/basic/skip.*$",
// "^.*wacc-examples/valid/expressions.*$",
// "^.*wacc-examples/valid/function/nested_functions.*$",
// "^.*wacc-examples/valid/function/simple_functions.*$",
// "^.*wacc-examples/valid/if.*$",
// "^.*wacc-examples/valid/IO/print.*$",
// "^.*wacc-examples/valid/IO/read.*$",
// "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
// "^.*wacc-examples/valid/pairs.*$",
//"^.*wacc-examples/valid/runtimeErr.*$",
// "^.*wacc-examples/valid/scope.*$",
// "^.*wacc-examples/valid/sequence.*$",
// "^.*wacc-examples/valid/variables.*$",
// "^.*wacc-examples/valid/while.*$",
// format: on
).find(filename.matches).isDefined
private def fileIsDisallowedBackend(filename: String): Boolean =
filename.matches("^.*wacc-examples/valid/advanced.*$")
private def fileIsPendingFrontend(filename: String): Boolean =
List(
// "^.*extension/examples/invalid/syntax/imports/importBadSyntax.*$",
// "^.*extension/examples/invalid/semantics/imports.*$",
// "^.*extension/examples/valid/imports.*$"
).exists(filename.matches)
private def fileIsPendingBackend(filename: String): Boolean =
List(
// "^.*extension/examples/invalid/syntax/imports.*$",
// "^.*extension/examples/invalid/semantics/imports.*$",
// "^.*extension/examples/valid/imports.*$"
).exists(filename.matches)
private def extractInput(contents: List[String]): String =
contents
.find(_.matches("^# ?[Ii]nput:.*$"))
.map(_.split(":").last.strip + "\n")
.getOrElse("")
private def extractOutput(contents: List[String]): String = {
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
if (outputLineIdx == -1) ""
else
contents
.drop(outputLineIdx + 1)
.takeWhile(_.startsWith("#"))
.map(_.stripPrefix("#").stripLeading)
.mkString("\n")
}
private def extractExit(contents: List[String]): Int = {
val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$"))
if (exitLineIdx == -1) 0
else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt
}
private def normalizeOutput(output: String): String =
output
.replaceAll("0x[0-9a-f]+", "#addrs#")
.replaceAll("fatal error:.*", "#runtime_error#\u0000")
.takeWhile(_ != '\u0000')
}

BIN
wacc-compiler Executable file

Binary file not shown.

View File

@@ -0,0 +1,4 @@
.vscode/**
.vscode-test/**
.gitignore
vsc-extension-quickstart.md

9
wacc-syntax/CHANGELOG.md Normal file
View File

@@ -0,0 +1,9 @@
# Change Log
All notable changes to the "wacc-syntax" extension will be documented in this file.
Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file.
## [Unreleased]
- Initial release

7
wacc-syntax/README.md Normal file
View File

@@ -0,0 +1,7 @@
### INTELLIWACC
This is the IntelliWACC extension for WACC code development; featuring syntax highlighting, error messages/highlighting and imports.
This extension was developed as a part of the "WACC Extensions" milestone 2025.
Authored by Alex L, Gleb K, Guy C and Jonny T

BIN
wacc-syntax/README.pdf Normal file

Binary file not shown.

71
wacc-syntax/extension.js Normal file
View File

@@ -0,0 +1,71 @@
// Developed using the VSC language extension tutorial
// https://code.visualstudio.com/api/language-extensions/overview
const vscode = require('vscode');
const { execSync } = require('child_process');
const { parse } = require('path');
function activate(context) {
console.log('IntelliWACC is now active!');
let diagnosticCollection = vscode.languages.createDiagnosticCollection('wacc');
context.subscriptions.push(diagnosticCollection);
vscode.workspace.onDidSaveTextDocument((document) => {
if (document.languageId !== 'wacc') return;
let diagnostics = [];
let errors = generateErrors(document.getText(), document.fileName);
errors.forEach(error => {
console.log(error);
let range = new vscode.Range(error.line - 1 , error.column - 1, error.line - 1, error.column + error.size);
let diagnostic = new vscode.Diagnostic(range, error.errorMessage, vscode.DiagnosticSeverity.Error);
diagnostics.push(diagnostic);
});
diagnosticCollection.set(document.uri, diagnostics);
});
}
function deactivate() {
console.log('IntelliWACC is deactivating...');
}
function generateErrors(code, filePath) {
try {
console.log("generating errors")
const fs = require('fs');
const tmpFilePath = parse(filePath).dir + '/.temp_wacc_file.wacc';
fs.writeFileSync(tmpFilePath, code);
let output;
try {
const waccExePath = `${__dirname}/wacc-compiler`;
output = execSync(`${waccExePath} ${tmpFilePath}`, { encoding: 'utf8', shell: true, stdio: 'pipe'});
} catch (err) {
console.log("Error running compiler");
output = err.stdout;
console.log(output);
}
let errors = [];
errorRegex = /\(line ([\d]+), column ([\d]+)\):\n([^>]+)([^\^]+)([\^]+)\n([^\n]+)([^\(]*)/g
while((match = errorRegex.exec(output)) !== null) {
console.log(match[5]);
errors.push({
line: parseInt(match[1], 10),
column: parseInt(match[2], 10),
errorMessage: match[3].trim(),
size: match[5].length - 1
});
}
return errors;
} catch (err) {
console.error('Error running compiler:', err);
return [];
}
}
module.exports = {
activate,
deactivate
};

BIN
wacc-syntax/icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 571 KiB

View File

@@ -0,0 +1,28 @@
{
"comments": {
// symbol used for single line comment. Remove this entry if your language does not support line comments
"lineComment": "#",
},
// symbols used as brackets
"brackets": [
["{", "}"],
["[", "]"],
["(", ")"]
],
// symbols that are auto closed when typing
"autoClosingPairs": [
["{", "}"],
["[", "]"],
["(", ")"],
["\"", "\""],
["'", "'"]
],
// symbols that can be used to surround a selection
"surroundingPairs": [
["{", "}"],
["[", "]"],
["(", ")"],
["\"", "\""],
["'", "'"]
]
}

15
wacc-syntax/package-lock.json generated Normal file
View File

@@ -0,0 +1,15 @@
{
"name": "wacc-syntax",
"version": "0.0.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "wacc-syntax",
"version": "0.0.1",
"engines": {
"vscode": "^1.97.0"
}
}
}
}

40
wacc-syntax/package.json Normal file
View File

@@ -0,0 +1,40 @@
{
"name": "wacc-syntax",
"displayName": "intelliWACC",
"description": "WACC language support features",
"version": "0.0.1",
"publisher": "WACC-37-2025",
"icon": "icon.png",
"engines": {
"vscode": "^1.97.0"
},
"categories": [
"Programming Languages"
],
"contributes": {
"languages": [{
"id": "wacc",
"aliases": ["WACC", "wacc"],
"extensions": [".wacc"],
"configuration": "./language-configuration.json"
}],
"grammars": [{
"language": "wacc",
"scopeName": "source.wacc",
"path": "./syntaxes/wacc.tmLanguage.json"
}],
"properties": {
"files.exclude": {
"type": "object",
"default": {
"**/.temp_wacc_file.*": true
},
"description": "Configure patterns for excluding files and folders."
}
}
},
"scripts": {
"build": "vsce package"
},
"main": "./extension.js"
}

View File

@@ -0,0 +1,56 @@
{
"$schema": "https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json",
"name": "WACC",
"scopeName": "source.wacc",
"fileTypes": [
"wacc"
],
"patterns": [
{
"match": "\\b(true|false)\\b",
"name": "keyword.constant.wacc"
},
{
"match": "\\b(int|bool|char|string|pair|null)\\b",
"name": "storage.type.wacc"
},
{
"match": "\".*?\"",
"name": "string.quoted.double.mylang"
},
{
"match": "\\b(begin|end)\\b",
"name": "keyword.other.unit"
},
{
"match": "\\b(if|then|else|fi|while|do|done|skip|is)\\b",
"name": "keyword.control.wacc"
},
{
"match": "\\b(read|free|print|println|newpair|call|fst|snd|ord|chr|len)\\b",
"name": "keyword.operator.function.wacc"
},
{
"match": "\\b(return|exit)\\b",
"name": "keyword.operator.wacc"
},
{
"match": "'[^']{1}'",
"name": "constant.character.wacc"
},
{
"match": "\\b([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?=\\()",
"name": "variable.function.wacc"
},
{
"match": "\\b([a-zA-Z_][a-zA-Z0-9_]*)\\b",
"name": "variable.other.wacc"
},
{
"match": "#.*$",
"name": "comment.line"
}
]
}

View File

@@ -0,0 +1,29 @@
# Welcome to your VS Code Extension
## What's in the folder
* This folder contains all of the files necessary for your extension.
* `package.json` - this is the manifest file in which you declare your language support and define the location of the grammar file that has been copied into your extension.
* `syntaxes/wacc.tmLanguage.json` - this is the Text mate grammar file that is used for tokenization.
* `language-configuration.json` - this is the language configuration, defining the tokens that are used for comments and brackets.
## Get up and running straight away
* Make sure the language configuration settings in `language-configuration.json` are accurate.
* Press `F5` to open a new window with your extension loaded.
* Create a new file with a file name suffix matching your language.
* Verify that syntax highlighting works and that the language configuration settings are working.
## Make changes
* You can relaunch the extension from the debug toolbar after making changes to the files listed above.
* You can also reload (`Ctrl+R` or `Cmd+R` on Mac) the VS Code window with your extension to load your changes.
## Add more language features
* To add features such as IntelliSense, hovers and validators check out the VS Code extenders documentation at https://code.visualstudio.com/docs
## Install your extension
* To start using your extension with Visual Studio Code copy it into the `<user home>/.vscode/extensions` folder and restart Code.
* To share your extension with the world, read on https://code.visualstudio.com/docs about publishing an extension.

Binary file not shown.