feat: imports and parallelised renamer #40

Merged
gk1623 merged 14 commits from imports into master 2025-03-13 23:10:38 +00:00
31 changed files with 860 additions and 323 deletions

1
.gitignore vendored
View File

@@ -4,4 +4,3 @@
.vscode/
wacc-examples/
.idea/

View File

@@ -0,0 +1,10 @@
begin
int main() is
int a = 5 ;
string b = "Hello" ;
return a + b
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "./doesNotExist.wacc" (main)
begin
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "../../../valid/sum.wacc" (mult)
begin
int result = call mult(3, 2) ;
exit result
end

View File

@@ -0,0 +1,10 @@
import "../badWacc.wacc" (main)
begin
int sum(int a, int b) is
return a + b
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "./importBadSem.wacc" (sum)
begin
int result = call sum(1, 2) ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "../../../valid/imports/basic.wacc" (sum)
begin
int result = call sum(3, 2) ;
exit result
end

View File

@@ -0,0 +1,6 @@
int main() is
println "Hello World!" ;
return 0
end
skip

View File

@@ -0,0 +1,8 @@
import "../../../valid/sum.wacc" sum, main
begin
int result1 = call sum(5, 10) ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,5 @@
import "../../../valid/sum.wacc" ()
begin
exit 0
end

View File

@@ -0,0 +1,10 @@
import "../badWacc.wacc" (main)
begin
int sum(int a, int b) is
return a + b
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,6 @@
import "./importBadSyntax.wacc" (sum)
begin
int result = call sum(1, 2) ;
exit result
end

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,9 @@
import "../../../valid/sum.wacc" (sum) ;
import "../../../valid/sum.wacc" (main) ;
begin
int result1 = call sum(5, 10) ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,5 @@
import "../../../valid/sum.wacc" *
begin
exit 0
end

View File

@@ -0,0 +1,5 @@
import "../../../valid/sum.wacc" (*)
begin
exit 0
end

7
extension/examples/valid/.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
*
!imports/
imports/*
!.gitignore
!*.wacc

View File

@@ -0,0 +1,22 @@
# import main from ../sum.wacc and ./basic.wacc
# Output:
# 15
# 0
# -33
#
# Exit:
# 0
# Program:
import "../sum.wacc" (main as sumMain)
import "./basic.wacc" (main)
begin
int result1 = call sumMain() ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,21 @@
# import sum from ../sum.wacc
# Output:
# -33
#
# Exit:
# 0
# Program:
import "../sum.wacc" (sum)
begin
int main() is
int result = call sum(-10, -23) ;
return result
end
int result = call main() ;
println result
end

View File

@@ -0,0 +1,33 @@
# import all the mains
# Output:
# 15
# -33
# 0
# -33
# 0
#
# Exit:
# 99
# Program:
import "../sum.wacc" (main as sumMain)
import "./basic.wacc" (main as basicMain)
import "./multiFunc.wacc" (main as multiFuncMain)
begin
int main() is
int result1 = call sumMain() ;
int result2 = call basicMain() ;
int result3 = call multiFuncMain() ;
println result1 ;
println result2 ;
println result3 ;
return 99
end
int result = call main() ;
exit result
end

View File

@@ -0,0 +1,27 @@
# import sum, main from ../sum.wacc
# Output:
# 15
# -33
# 0
# 0
#
# Exit:
# 0
# Program:
import "../sum.wacc" (sum, main as sumMain)
begin
int main() is
int result = call sum(-10, -23) ;
println result ;
return 0
end
int result1 = call sumMain() ;
int result2 = call main() ;
println result1 ;
println result2
end

View File

@@ -0,0 +1,27 @@
# simple sum program
# Output:
# 15
#
# Exit:
# 0
# Program:
begin
int sum(int a, int b) is
return a + b
end
int main() is
int a = 5 ;
int b = 10 ;
int result = call sum(a, b) ;
println result ;
return 0
end
int result = call main() ;
exit result
end

View File

@@ -18,6 +18,7 @@ import org.typelevel.log4cats.Logger
import assemblyIR as asm
import cats.data.ValidatedNel
import java.io.File
/*
TODO:
@@ -68,21 +69,25 @@ val outputOpt: Opts[Option[Path]] =
.orNone
def frontend(
contents: String
): Either[NonEmptyList[Error], microWacc.Program] =
contents: String,
file: File
): IO[Either[NonEmptyList[Error], microWacc.Program]] =
parser.parse(contents) match {
case Failure(msg) => Left(NonEmptyList.one(Error.SyntaxError(msg)))
case Success(prog) =>
case Failure(msg) => IO.pure(Left(NonEmptyList.one(Error.SyntaxError(file, msg))))
case Success(fn) =>
val partialProg = fn(file)
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
val (names, funcs) = renamer.rename(prog)
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
val typedProg = typeChecker.check(prog)
for {
(prog, renameErrors) <- renamer.rename(partialProg)
_ = errors.addAll(renameErrors.toList)
typedProg = typeChecker.check(prog, errors)
NonEmptyList.fromList(errors.result) match {
case Some(errors) => Left(errors)
case None => Right(typedProg)
}
res = NonEmptyList.fromList(errors.result) match {
case Some(errors) => Left(errors)
case None => Right(typedProg)
}
} yield res
}
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
@@ -105,29 +110,31 @@ def compile(
writer.writeTo(backend(typedProg), outputPath) *>
logger.info(s"Success: ${outputPath.toAbsolutePath}")
def processProgram(contents: String, outDir: Path): IO[Int] =
frontend(contents) match {
case Left(errors) =>
val code = errors.map(err => err.exitCode).toList.min
given errorContent: String = contents
val errorMsg = errors.map(formatError).toIterable.mkString("\n")
for {
_ <- logAction(s"Compilation failed for $filePath\nExit code: $code")
_ <- IO.blocking(
// Explicit println since we want this to always show without logger thread info e.t.c.
println(s"Compilation failed for ${filePath.toAbsolutePath}:\n$errorMsg")
)
} yield code
def processProgram(contents: String, file: File, outDir: Path): IO[Int] =
for {
frontendResult <- frontend(contents, file)
res <- frontendResult match {
case Left(errors) =>
val code = errors.map(err => err.exitCode).toList.min
val errorMsg = errors.map(formatError).toIterable.mkString("\n")
for {
_ <- logAction(s"Compilation failed for $filePath\nExit code: $code")
_ <- IO.blocking(
// Explicit println since we want this to always show without logger thread info e.t.c.
println(s"Compilation failed for ${file.getCanonicalPath}:\n$errorMsg")
)
} yield code
case Right(typedProg) =>
val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s")
writeOutputFile(typedProg, outputFile).as(SUCCESS)
}
case Right(typedProg) =>
val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s")
writeOutputFile(typedProg, outputFile).as(SUCCESS)
}
} yield res
for {
contents <- readSourceFile
_ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}")
exitCode <- processProgram(contents, outputDir.getOrElse(filePath.getParent))
exitCode <- processProgram(contents, filePath.toFile, outputDir.getOrElse(filePath.getParent))
} yield exitCode
}

View File

@@ -18,7 +18,7 @@ private class LabelGenerator {
}
private def getLabel(target: CallTarget | RuntimeError): String = target match {
case Ident(v, _) => s"wacc_$v"
case Ident(v, guid) => s"wacc_${v}_$guid"
case Builtin(name) => s"_$name"
case err: RuntimeError => s".L.${err.name}"
}

View File

@@ -2,6 +2,7 @@ package wacc
import wacc.ast.Position
import wacc.types._
import java.io.File
private val SYNTAX_ERROR = 100
private val SEMANTIC_ERROR = 200
@@ -18,13 +19,13 @@ enum Error {
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
case InternalError(pos: Position, msg: String)
case SyntaxError(msg: String)
case SyntaxError(file: File, msg: String)
}
extension (e: Error) {
def exitCode: Int = e match {
case Error.SyntaxError(_) => SYNTAX_ERROR
case _ => SEMANTIC_ERROR
case Error.SyntaxError(_, _) => SYNTAX_ERROR
case _ => SEMANTIC_ERROR
}
}
@@ -35,15 +36,25 @@ extension (e: Error) {
* @param errorContent
* Contents of the file to generate code snippets
*/
def formatError(error: Error)(using errorContent: String): String = {
def formatError(error: Error): String = {
val sb = new StringBuilder()
/** Format the file of an error
*
* @param file
* File of the error
*/
def formatFile(file: File): Unit = {
sb.append(s"File: ${file.getCanonicalPath}\n")
}
/** Function to format the position of an error
*
* @param pos
* Position of the error
*/
def formatPosition(pos: Position): Unit = {
formatFile(pos.file)
sb.append(s"(line ${pos.line}, column ${pos.column}):\n")
}
@@ -55,7 +66,7 @@ def formatError(error: Error)(using errorContent: String): String = {
* Size(in chars) of section to highlight
*/
def formatHighlight(pos: Position, size: Int): Unit = {
val lines = errorContent.split("\n")
val lines = os.read(os.Path(pos.file.getCanonicalPath)).split("\n")
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
val midLine = lines(pos.line - 1)
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
@@ -67,7 +78,7 @@ def formatError(error: Error)(using errorContent: String): String = {
}
error match {
case Error.SyntaxError(_) =>
case Error.SyntaxError(_, _) =>
sb.append("Syntax error:\n")
case _ =>
sb.append("Semantic error:\n")
@@ -76,40 +87,40 @@ def formatError(error: Error)(using errorContent: String): String = {
error match {
case Error.DuplicateDeclaration(ident) =>
formatPosition(ident.pos)
sb.append(s"Duplicate declaration of identifier ${ident.v}")
sb.append(s"Duplicate declaration of identifier ${ident.v}\n")
formatHighlight(ident.pos, ident.v.length)
case Error.UndeclaredVariable(ident) =>
formatPosition(ident.pos)
sb.append(s"Undeclared variable ${ident.v}")
sb.append(s"Undeclared variable ${ident.v}\n")
formatHighlight(ident.pos, ident.v.length)
case Error.UndefinedFunction(ident) =>
formatPosition(ident.pos)
sb.append(s"Undefined function ${ident.v}")
sb.append(s"Undefined function ${ident.v}\n")
formatHighlight(ident.pos, ident.v.length)
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
formatPosition(id.pos)
sb.append(s"Function expects $expected parameters, got $got")
sb.append(s"Function expects $expected parameters, got $got\n")
sb.append(
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})\n"
)
formatHighlight(id.pos, 1)
case Error.TypeMismatch(pos, expected, got, msg) =>
formatPosition(pos)
sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got\n")
formatHighlight(pos, 1)
case Error.SemanticError(pos, msg) =>
formatPosition(pos)
sb.append(msg)
sb.append(msg + "\n")
formatHighlight(pos, 1)
case wacc.Error.InternalError(pos, msg) =>
formatPosition(pos)
sb.append(s"Internal error: $msg")
sb.append(s"Internal error: $msg\n")
formatHighlight(pos, 1)
case Error.SyntaxError(msg) =>
sb.append(msg)
case Error.SyntaxError(file, msg) =>
formatFile(file)
sb.append(msg + "\n")
sb.append("\n")
}
sb.toString()
}

View File

@@ -1,5 +1,6 @@
package wacc
import java.io.File
import parsley.Parsley
import parsley.generic.ErrorBridge
import parsley.ap._
@@ -22,26 +23,42 @@ object ast {
/* ============================ ATOMIC EXPRESSIONS ============================ */
case class IntLiter(v: Int)(val pos: Position) extends Expr6
object IntLiter extends ParserBridgePos1[Int, IntLiter]
object IntLiter extends ParserBridgePos1Atom[Int, IntLiter]
case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6
object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter]
object BoolLiter extends ParserBridgePos1Atom[Boolean, BoolLiter]
case class CharLiter(v: Char)(val pos: Position) extends Expr6
object CharLiter extends ParserBridgePos1[Char, CharLiter]
object CharLiter extends ParserBridgePos1Atom[Char, CharLiter]
case class StrLiter(v: String)(val pos: Position) extends Expr6
object StrLiter extends ParserBridgePos1[String, StrLiter]
object StrLiter extends ParserBridgePos1Atom[String, StrLiter]
case class PairLiter()(val pos: Position) extends Expr6
object PairLiter extends ParserBridgePos0[PairLiter]
case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue
object Ident extends ParserBridgePos1[String, Ident] {
case class Ident(var v: String, var guid: Int = -1, var ty: types.RenamerType = types.?)(
val pos: Position
) extends Expr6
with LValue
object Ident extends ParserBridgePos1Atom[String, Ident] {
def apply(v: String)(pos: Position): Ident = new Ident(v)(pos)
}
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position)
extends Expr6
with LValue
object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
name => ArrayElem(name, a)(pos)
object ArrayElem extends ParserBridgePos2Chain[NonEmptyList[Expr], Ident, ArrayElem] {
def apply(indices: NonEmptyList[Expr], name: Ident)(pos: Position): ArrayElem =
new ArrayElem(name, indices)(pos)
}
// object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], (File => Ident) => ArrayElem] {
// def apply(a: NonEmptyList[Expr])(pos: Position): (File => Ident) => ArrayElem =
// name => ArrayElem(name(pos.file), a)(pos)
// }
// object ArrayElem extends ParserSingletonBridgePos[(File => NonEmptyList[Expr]) => (File => Ident) => File => ArrayElem] {
// // def apply(indices: NonEmptyList[Expr]): (File => Ident) => File => ArrayElem =
// // name => file => new ArrayElem(name(file), )
// def apply(indices: Parsley[File => NonEmptyList[Expr]]): Parsley[(File => Ident) => File => ArrayElem] =
// // error(ap1(pos.map(con),))
// override final def con(pos: (Int, Int)): (File => NonEmptyList[Expr]) => => C =
// (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
// }
case class Parens(expr: Expr)(val pos: Position) extends Expr6
object Parens extends ParserBridgePos1[Expr, Parens]
@@ -119,8 +136,9 @@ object ast {
case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position)
extends Type
with PairElemType
object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
object ArrayType extends ParserBridgePos2Chain[Int, Type, ArrayType] {
def apply(dimensions: Int, elemType: Type)(pos: Position): ArrayType =
ArrayType(elemType, dimensions)(pos)
}
case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type
object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
@@ -131,6 +149,18 @@ object ast {
/* ============================ PROGRAM STRUCTURE ============================ */
case class ImportedFunc(sourceName: Ident, importName: Ident)(val pos: Position)
object ImportedFunc extends ParserBridgePos2[Ident, Option[Ident], ImportedFunc] {
def apply(a: Ident, b: Option[Ident])(pos: Position): ImportedFunc =
new ImportedFunc(a, b.getOrElse(a))(pos)
}
case class Import(source: StrLiter, funcs: NonEmptyList[ImportedFunc])(val pos: Position)
object Import extends ParserBridgePos2[StrLiter, NonEmptyList[ImportedFunc], Import]
case class PartialProgram(imports: List[Import], self: Program)(val pos: Position)
object PartialProgram extends ParserBridgePos2[List[Import], Program, PartialProgram]
case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position)
object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program]
@@ -143,15 +173,15 @@ object ast {
body: NonEmptyList[Stmt]
)(val pos: Position)
object FuncDecl
extends ParserBridgePos2[
List[Param],
NonEmptyList[Stmt],
((Type, Ident)) => FuncDecl
extends ParserBridgePos2Chain[
(List[Param], NonEmptyList[Stmt]),
((Type, Ident)),
FuncDecl
] {
def apply(params: List[Param], body: NonEmptyList[Stmt])(
def apply(paramsBody: (List[Param], NonEmptyList[Stmt]), retTyName: (Type, Ident))(
pos: Position
): ((Type, Ident)) => FuncDecl =
(returnType, name) => FuncDecl(returnType, name, params, body)(pos)
): FuncDecl =
new FuncDecl(retTyName._1, retTyName._2, paramsBody._1, paramsBody._2)(pos)
}
case class Param(paramType: Type, name: Ident)(val pos: Position)
@@ -159,7 +189,9 @@ object ast {
/* ============================ STATEMENTS ============================ */
sealed trait Stmt
sealed trait Stmt {
val pos: Position
}
case class Skip()(val pos: Position) extends Stmt
object Skip extends ParserBridgePos0[Skip]
case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt
@@ -207,7 +239,7 @@ object ast {
/* ============================ PARSER BRIDGES ============================ */
case class Position(line: Int, column: Int)
case class Position(line: Int, column: Int, file: File)
trait ParserSingletonBridgePos[+A] extends ErrorBridge {
protected def con(pos: (Int, Int)): A
@@ -215,38 +247,63 @@ object ast {
final def <#(op: Parsley[?]): Parsley[A] = this from op
}
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] {
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[File => A] {
def apply()(pos: Position): A
override final def con(pos: (Int, Int)): A =
apply()(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): File => A =
file => apply()(Position(pos._1, pos._2, file))
}
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] {
trait ParserBridgePos1Atom[-A, +B] extends ParserSingletonBridgePos[A => File => B] {
def apply(a: A)(pos: Position): B
def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a))
def apply(a: Parsley[A]): Parsley[File => B] = error(ap1(pos.map(con), a))
override final def con(pos: (Int, Int)): A => B =
this.apply(_)(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): A => File => B =
a => file => this.apply(a)(Position(pos._1, pos._2, file))
}
trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] {
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[(File => A) => File => B] {
def apply(a: A)(pos: Position): B
def apply(a: Parsley[File => A]): Parsley[File => B] = error(ap1(pos.map(con), a))
override final def con(pos: (Int, Int)): (File => A) => File => B =
a => file => this.apply(a(file))(Position(pos._1, pos._2, file))
}
trait ParserBridgePos2Chain[-A, -B, +C]
extends ParserSingletonBridgePos[(File => A) => (File => B) => File => C] {
def apply(a: A, b: B)(pos: Position): C
def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error(
def apply(a: Parsley[File => A]): Parsley[(File => B) => File => C] = error(
ap1(pos.map(con), a)
)
override final def con(pos: (Int, Int)): (File => A) => (File => B) => File => C =
a => b => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
}
trait ParserBridgePos2[-A, -B, +C]
extends ParserSingletonBridgePos[(File => A, File => B) => File => C] {
def apply(a: A, b: B)(pos: Position): C
def apply(a: Parsley[File => A], b: => Parsley[File => B]): Parsley[File => C] = error(
ap2(pos.map(con), a, b)
)
override final def con(pos: (Int, Int)): (A, B) => C =
apply(_, _)(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): (File => A, File => B) => File => C =
(a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
}
trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] {
trait ParserBridgePos3[-A, -B, -C, +D]
extends ParserSingletonBridgePos[(File => A, File => B, File => C) => File => D] {
def apply(a: A, b: B, c: C)(pos: Position): D
def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error(
def apply(
a: Parsley[File => A],
b: => Parsley[File => B],
c: => Parsley[File => C]
): Parsley[File => D] = error(
ap3(pos.map(con), a, b, c)
)
override final def con(pos: (Int, Int)): (A, B, C) => D =
apply(_, _, _)(Position(pos._1, pos._2))
override final def con(pos: (Int, Int)): (File => A, File => B, File => C) => File => D =
(a, b, c) => file => apply(a(file), b(file), c(file))(Position(pos._1, pos._2, file))
}
}

View File

@@ -1,18 +1,22 @@
package wacc
import java.io.File
import parsley.Result
import parsley.Parsley
import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit}
import parsley.combinator.{countSome, sepBy}
import parsley.combinator.{countSome, sepBy, option}
import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms}
import parsley.errors.combinator._
import parsley.errors.patterns.VerifiedErrors
import parsley.syntax.zipped._
import parsley.cats.combinator.{some}
import parsley.cats.combinator.{some, sepBy1}
import cats.syntax.all._
import cats.data.NonEmptyList
import parsley.errors.DefaultErrorBuilder
import parsley.errors.ErrorBuilder
import parsley.errors.tokenextractors.LexToken
import parsley.expr.GOps
import cats.Functor
object parser {
import lexer.implicits.implicitSymbol
@@ -52,13 +56,24 @@ object parser {
implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken {
def tokens = errTokens
}
def parse(input: String): Result[String, Program] = parser.parse(input)
private val parser = lexer.fully(`<program>`)
def parse(input: String): Result[String, File => PartialProgram] = parser.parse(input)
private val parser = lexer.fully(`<partial-program>`)
private type FParsley[A] = Parsley[File => A]
private def fParsley[A](p: Parsley[A]): FParsley[A] =
p map { a => file => a }
private def fPair[A, B](p: Parsley[(File => A, File => B)]): FParsley[(A, B)] =
p map { case (a, b) => file => (a(file), b(file)) }
private def fMap[A, F[_]: Functor](p: Parsley[F[File => A]]): FParsley[F[A]] =
p map { funcs => file => funcs.map(_(file)) }
// Expressions
private lazy val `<expr>`: Parsley[Expr] = precedence {
SOps(InfixR)(Or from "||") +:
SOps(InfixR)(And from "&&") +:
private lazy val `<expr>`: FParsley[Expr] = precedence {
GOps(InfixR)(Or from "||") +:
GOps(InfixR)(And from "&&") +:
SOps(InfixN)(Eq from "==", Neq from "!=") +:
SOps(InfixN)(
Less from "<",
@@ -83,32 +98,33 @@ object parser {
}
// Atoms
private lazy val `<atom>`: Atoms[Expr6] = Atoms(
private lazy val `<atom>`: Atoms[File => Expr6] = Atoms(
IntLiter(integer).label("integer literal"),
BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"),
CharLiter(charLit).label("character literal"),
StrLiter(stringLit).label("string literal"),
`<str-liter>`.label("string literal"),
PairLiter from "null",
`<ident-or-array-elem>`,
Parens("(" ~> `<expr>` <~ ")")
)
private val `<ident>` =
private lazy val `<str-liter>` = StrLiter(stringLit)
private lazy val `<ident>` =
Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed")
private lazy val `<ident-or-array-elem>` =
(`<ident>` <~ ("(".verifiedExplain(
"functions can only be called using 'call' keyword"
) | unit)) <**> (`<array-indices>` </> identity)
private val `<array-indices>` = ArrayElem(some("[" ~> `<expr>` <~ "]"))
private lazy val `<array-indices>` = ArrayElem(fMap(some("[" ~> `<expr>` <~ "]")))
// Types
private lazy val `<type>`: Parsley[Type] =
private lazy val `<type>`: FParsley[Type] =
(`<base-type>` | (`<pair-type>` ~> `<pair-elems-type>`)) <**> (`<array-type>` </> identity)
private val `<base-type>` =
(IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string")
private lazy val `<array-type>` =
ArrayType(countSome("[" ~> "]"))
ArrayType(fParsley(countSome("[" ~> "]")))
private val `<pair-type>` = "pair"
private val `<pair-elems-type>`: Parsley[PairType] = PairType(
private val `<pair-elems-type>`: FParsley[PairType] = PairType(
"(" ~> `<pair-elem-type>` <~ ",",
`<pair-elem-type>` <~ ")"
)
@@ -116,7 +132,7 @@ object parser {
(`<base-type>` <**> (`<array-type>` </> identity)) |
((UntypedPairType from `<pair-type>`) <**>
((`<pair-elems-type>` <**> `<array-type>`)
.map(arr => (_: UntypedPairType) => arr) </> identity))
.map(arr => (_: File => UntypedPairType) => arr) </> identity))
/* Statements
Atomic is used in two places here:
@@ -127,13 +143,30 @@ object parser {
invalid syntax check, this only happens at most once per program so this is not a major
concern.
*/
private lazy val `<partial-program>` = PartialProgram(
fMap(many(`<import>`)),
`<program>`
)
private lazy val `<import>` = Import(
"import" ~> `<import-filename>`,
"(" ~> fMap(sepBy1(`<imported-func>`, ",")) <~ ")"
)
private lazy val `<import-filename>` = `<str-liter>`.label("import file name")
private lazy val `<imported-func>` = ImportedFunc(
`<ident>`.label("imported function name"),
fMap(option("as" ~> `<ident>`)).label("imported function alias")
)
private lazy val `<program>` = Program(
"begin" ~> (
many(
atomic(
`<type>`.label("function declaration") <~> `<ident>` <~ "("
) <**> `<partial-func-decl>`
).label("function declaration") |
fMap(
many(
fPair(
atomic(
`<type>`.label("function declaration") <~> `<ident>` <~ "("
)
) <**> `<partial-func-decl>`
).label("function declaration")
) |
atomic(`<ident>` <~ "(").verifiedExplain("function declaration is missing return type")
),
`<stmt>`.label(
@@ -142,17 +175,23 @@ object parser {
)
private lazy val `<partial-func-decl>` =
FuncDecl(
sepBy(`<param>`, ",") <~ ")" <~ "is",
`<stmt>`.guardAgainst {
case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement")
} <~ "end"
fPair(
(fMap(sepBy(`<param>`, ",")) <~ ")" <~ "is") <~>
(`<stmt>`.guardAgainst {
// TODO: passing in an arbitrary file works but is ugly
case stmts if !(stmts(File("."))).isReturning =>
Seq("all functions must end in a returning statement")
} <~ "end")
)
)
private lazy val `<param>` = Param(`<type>`, `<ident>`)
private lazy val `<stmt>`: Parsley[NonEmptyList[Stmt]] =
(
`<basic-stmt>`.label("main program body"),
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
).zipped(NonEmptyList.apply)
private lazy val `<stmt>`: FParsley[NonEmptyList[Stmt]] =
fMap(
(
`<basic-stmt>`.label("main program body"),
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
).zipped(NonEmptyList.apply)
)
private lazy val `<basic-stmt>` =
(Skip from "skip")
@@ -160,8 +199,8 @@ object parser {
| Free("free" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Return("return" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Exit("exit" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(false))
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(true))
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), fParsley(pure(false)))
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), fParsley(pure(true)))
| If(
"if" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "then",
`<stmt>` <~ "else",
@@ -185,9 +224,9 @@ object parser {
("call" ~> `<ident>`).verifiedExplain(
"function calls' results must be assigned to a variable"
)
private lazy val `<lvalue>`: Parsley[LValue] =
private lazy val `<lvalue>`: FParsley[LValue] =
`<pair-elem>` | `<ident-or-array-elem>`
private lazy val `<rvalue>`: Parsley[RValue] =
private lazy val `<rvalue>`: FParsley[RValue] =
`<array-liter>` |
NewPair(
"newpair" ~> "(" ~> `<expr>` <~ ",",
@@ -196,13 +235,13 @@ object parser {
`<pair-elem>` |
Call(
"call" ~> `<ident>` <~ "(",
sepBy(`<expr>`, ",") <~ ")"
fMap(sepBy(`<expr>`, ",")) <~ ")"
) | `<expr>`.labelWithType(LabelType.Expr)
private lazy val `<pair-elem>` =
Fst("fst" ~> `<lvalue>`.label("valid pair"))
| Snd("snd" ~> `<lvalue>`.label("valid pair"))
private lazy val `<array-liter>` = ArrayLiter(
"[" ~> sepBy(`<expr>`, ",") <~ "]"
"[" ~> fMap(sepBy(`<expr>`, ",")) <~ "]"
)
extension (stmts: NonEmptyList[Stmt]) {

View File

@@ -1,6 +1,15 @@
package wacc
import java.io.File
import scala.collection.mutable
import cats.effect.IO
import cats.syntax.all._
import cats.implicits._
import cats.data.Chain
import cats.data.NonEmptyList
import parsley.{Failure, Success}
private val MAIN = "$main"
object renamer {
import ast._
@@ -11,116 +20,271 @@ object renamer {
case Var
}
private case class ScopeKey(path: String, name: String, identType: IdentType)
private case class ScopeValue(id: Ident, public: Boolean)
private class Scope(
val current: mutable.Map[(String, IdentType), Ident],
val parent: Map[(String, IdentType), Ident]
private val current: mutable.Map[ScopeKey, ScopeValue],
private val parent: Map[ScopeKey, ScopeValue],
guidStart: Int = 0,
val guidInc: Int = 1
) {
private var guid = guidStart
private var immutable = false
private def nextGuid(): Int = {
val id = guid
guid += guidInc
id
}
private def verifyMutable(): Unit = {
if (immutable) throw new IllegalStateException("Cannot modify an immutable scope")
}
/** Create a new scope with the current scope as its parent.
*
* To be used for single-threaded applications.
*
* @return
* A new scope with an empty current scope, and this scope flattened into the parent scope.
*/
def subscope: Scope =
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)))
def withSubscope[T](f: Scope => T): T = {
val subscope =
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)), guid, guidInc)
immutable = true
val result = f(subscope)
guid = subscope.guid // Sync GUID
immutable = false
result
}
/** Create new scopes with the current scope as its parent and GUID numbering adjusted
* correctly.
*
* This will permanently mark the current scope as immutable, for thread safety.
*
* To be used for multi-threaded applications.
*
* @return
* New scopes with an empty current scope, and this scope flattened into the parent scope.
*/
def subscopes(n: Int): Seq[Scope] = {
verifyMutable()
immutable = true
(0 until n).map { i =>
Scope(
mutable.Map.empty,
Map.empty.withDefault(current.withDefault(parent)),
guid + i * guidInc,
guidInc * n
)
}
}
/** Attempt to add a new identifier to the current scope. If the identifier already exists in
* the current scope, add an error to the error list.
*
* @param ty
* The semantic type of the variable identifier, or function identifier type.
* @param name
* The name of the identifier.
* @param globalNames
* The global map of identifiers to semantic types - the identifier will be added to this
* map.
* @param globalNumbering
* The global map of identifier names to the number of times they have been declared - will
* used to rename this identifier, and will be incremented.
* @param errors
* The list of errors to append to.
* @return
* An error, if one occurred.
*/
def add(ty: SemType | FuncType, name: Ident)(using
globalNames: mutable.Map[Ident, SemType],
globalFuncs: mutable.Map[Ident, FuncType],
globalNumbering: mutable.Map[String, Int],
errors: mutable.Builder[Error, List[Error]]
) = {
val identType = ty match {
def add(name: Ident, public: Boolean = false): Chain[Error] = {
verifyMutable()
val path = name.pos.file.getCanonicalPath
val identType = name.ty match {
case _: SemType => IdentType.Var
case _: FuncType => IdentType.Func
}
current.get((name.v, identType)) match {
case Some(Ident(_, uid)) =>
errors += Error.DuplicateDeclaration(name)
name.uid = uid
val key = ScopeKey(path, name.v, identType)
current.get(key) match {
case Some(ScopeValue(Ident(_, id, _), _)) =>
name.guid = id
Chain.one(Error.DuplicateDeclaration(name))
case None =>
val uid = globalNumbering.getOrElse(name.v, 0)
name.uid = uid
current((name.v, identType)) = name
ty match {
case semType: SemType =>
globalNames(name) = semType
case funcType: FuncType =>
globalFuncs(name) = funcType
}
globalNumbering(name.v) = uid + 1
name.guid = nextGuid()
current(key) = ScopeValue(name, public)
Chain.empty
}
}
private def get(name: String, identType: IdentType): Option[Ident] =
/** Attempt to add a new identifier as an alias to another to the existing scope.
*
* @param alias
* The (new) alias identifier.
* @param orig
* The (existing) original identifier.
*
* @return
* An error, if one occurred.
*/
def addAlias(alias: Ident, orig: ScopeValue, public: Boolean = false): Chain[Error] = {
verifyMutable()
val path = alias.pos.file.getCanonicalPath
val identType = alias.ty match {
case _: SemType => IdentType.Var
case _: FuncType => IdentType.Func
}
val key = ScopeKey(path, alias.v, identType)
current.get(key) match {
case Some(ScopeValue(Ident(_, id, _), _)) =>
alias.guid = id
Chain.one(Error.DuplicateDeclaration(alias))
case None =>
alias.guid = nextGuid()
current(key) = ScopeValue(orig.id, public)
Chain.empty
}
}
def get(path: String, name: String, identType: IdentType): Option[ScopeValue] =
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
try {
Some(current.withDefault(parent)((name, identType)))
Some(current.withDefault(parent)(ScopeKey(path, name, identType)))
} catch {
case _: NoSuchElementException => None
}
def getVar(name: String): Option[Ident] = get(name, IdentType.Var)
def getFunc(name: String): Option[Ident] = get(name, IdentType.Func)
def getVar(name: Ident): Option[Ident] =
get(name.pos.file.getCanonicalPath, name.v, IdentType.Var).map(_.id)
def getFunc(name: Ident): Option[Ident] =
get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id)
}
/** Check scoping of all variables and functions in the program. Also generate semantic types for
* all identifiers.
*
* @param prog
* AST of the program
* @param errors
* List of errors to append to
* @return
* Map of all (renamed) identifies to their semantic types
*/
def rename(prog: Program)(using
errors: mutable.Builder[Error, List[Error]]
): (Map[Ident, SemType], Map[Ident, FuncType]) = {
given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty
given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty
given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty
val scope = Scope(mutable.Map.empty, Map.empty)
private def prepareGlobalScope(
partialProg: PartialProgram
)(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = {
def readImportFile(file: File): IO[String] =
IO.blocking(os.read(os.Path(file.getCanonicalPath)))
def prepareImport(contents: String, file: File)(using
scope: Scope
): IO[(Chain[FuncDecl], Chain[Error])] = {
parser.parse(contents) match {
case Failure(msg) =>
IO.pure(Chain.empty, Chain.one(Error.SyntaxError(file, msg)))
case Success(fn) =>
val partialProg = fn(file)
for {
(main, chunks, errors) <- prepareGlobalScope(partialProg)
} yield (main +: chunks, errors)
}
}
def addImportsToScope(importFile: File, funcs: NonEmptyList[ImportedFunc])(using
scope: Scope
): Chain[Error] =
funcs.foldMap { case ImportedFunc(srcName, aliasName) =>
scope.get(importFile.getCanonicalPath, srcName.v, IdentType.Func) match {
case Some(src) if src.public =>
aliasName.ty = src.id.ty
scope.addAlias(aliasName, src)
case _ =>
Chain.one(Error.UndefinedFunction(srcName))
}
}
val PartialProgram(imports, prog) = partialProg
// First prepare this file's functions...
val Program(funcs, main) = prog
funcs
// First add all function declarations to the scope
.map { case FuncDecl(retType, name, params, body) =>
val (funcChunks, funcErrors) = funcs.foldLeft((Chain.empty[FuncDecl], Chain.empty[Error])) {
case ((chunks, errors), func @ FuncDecl(retType, name, params, body)) =>
val paramTypes = params.map { param =>
val paramType = SemType(param.paramType)
param.name.ty = paramType
paramType
}
scope.add(FuncType(SemType(retType), paramTypes), name)
(params zip paramTypes, body)
name.ty = FuncType(SemType(retType), paramTypes)
(chunks :+ func, errors ++ scope.add(name, public = true))
}
// ...and main body.
val mainBodyIdent = Ident(MAIN, ty = FuncType(?, Nil))(prog.pos)
val mainBodyErrors = scope.add(mainBodyIdent, public = false)
val mainBodyChunk = FuncDecl(IntType()(prog.pos), mainBodyIdent, Nil, main)(prog.pos)
// Now handle imports
val file = prog.pos.file
val preparedImports = imports.foldLeftM[IO, (Chain[FuncDecl], Chain[Error])](
(Chain.empty[FuncDecl], Chain.empty[Error])
) { case ((chunks, errors), Import(name, funcs)) =>
val importFile = File(file.getParent, name.v)
if (!importFile.exists()) {
IO.pure(
(
chunks,
errors :+ Error.SemanticError(
name.pos,
s"File not found: ${importFile.getCanonicalPath}"
)
)
)
} else if (!importFile.canRead()) {
IO.pure(
(
chunks,
errors :+ Error.SemanticError(
name.pos,
s"File not readable: ${importFile.getCanonicalPath}"
)
)
)
} else if (importFile.getCanonicalPath == file.getCanonicalPath) {
IO.pure(
(
chunks,
errors :+ Error.SemanticError(
name.pos,
s"Cannot import self: ${importFile.getCanonicalPath}"
)
)
)
} else if (scope.get(importFile.getCanonicalPath, MAIN, IdentType.Func).isDefined) {
IO.pure(chunks, errors ++ addImportsToScope(importFile, funcs))
} else {
for {
contents <- readImportFile(importFile)
(importChunks, importErrors) <- prepareImport(contents, importFile)
importAliasErrors = addImportsToScope(importFile, funcs)
} yield (chunks ++ importChunks, errors ++ importErrors)
}
// Only then rename the function bodies
// (functions can call one-another regardless of order of declaration)
.foreach { case (params, body) =>
val functionScope = scope.subscope
params.foreach { case (param, paramType) =>
functionScope.add(paramType, param.name)
}
body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params
}
main.toList.foreach(rename(scope))
(globalNames.toMap, globalFuncs.toMap)
}
for {
(importChunks, importErrors) <- preparedImports
allChunks = importChunks ++ funcChunks
allErrors = importErrors ++ funcErrors ++ mainBodyErrors
} yield (mainBodyChunk, allChunks, allErrors)
}
/** Check scoping of all variables and flatten a program. Also generates semantic types and parses
* any imported files.
*
* @param partialProg
* AST of the program
* @return
* (flattenedProg, errors)
*/
private def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = {
val (FuncDecl(_, _, params, body), subscope) = funcScopePair
val paramErrors = params.foldMap(param => subscope.add(param.name))
IO(subscope.withSubscope { s => body.foldMap(rename(s)) })
.map(bodyErrors => paramErrors ++ bodyErrors)
}
def rename(partialProg: PartialProgram): IO[(Program, Chain[Error])] = {
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
for {
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
toRename = (main +: chunks).toList
allErrors <- toRename
.zip(scope.subscopes(toRename.size))
.parFoldMapA(renameFunction)
// .map(x => x.combineAll)
} yield (Program(chunks.toList, main.body)(main.pos), globalErrors ++ allErrors)
}
/** Check scoping of all identifies in a given AST node.
@@ -129,91 +293,90 @@ object renamer {
* The current scope and flattened parent scope.
* @param node
* The AST node.
* @param globalNames
* The global map of identifiers to semantic types - renamed identifiers will be added to this
* map.
* @param globalNumbering
* The global map of identifier names to the number of times they have been declared - used and
* updated during identifier renaming.
* @param errors
*/
private def rename(scope: Scope)(
node: Ident | Stmt | LValue | RValue | Expr
)(using
globalNames: mutable.Map[Ident, SemType],
globalFuncs: mutable.Map[Ident, FuncType],
globalNumbering: mutable.Map[String, Int],
errors: mutable.Builder[Error, List[Error]]
): Unit = node match {
// These cases are more interesting because the involve making subscopes
// or modifying the current scope.
case VarDecl(synType, name, value) => {
// Order matters here. Variable isn't declared until after the value is evaluated.
rename(scope)(value)
// Attempt to add the new variable to the current scope.
scope.add(SemType(synType), name)
}
case If(cond, thenStmt, elseStmt) => {
rename(scope)(cond)
// then and else both have their own scopes
thenStmt.toList.foreach(rename(scope.subscope))
elseStmt.toList.foreach(rename(scope.subscope))
}
case While(cond, body) => {
rename(scope)(cond)
// while bodies have their own scopes
body.toList.foreach(rename(scope.subscope))
}
// begin-end blocks have their own scopes
case Block(body) => body.toList.foreach(rename(scope.subscope))
private def rename(scope: Scope)(node: Ident | Stmt | LValue | RValue | Expr): Chain[Error] =
node match {
// These cases are more interes/globting because the involve making subscopes
// or modifying the current scope.
case VarDecl(synType, name, value) => {
// Order matters here. Variable isn't declared until after the value is evaluated.
val errors = rename(scope)(value)
// Attempt to add the new variable to the current scope.
name.ty = SemType(synType)
errors ++ scope.add(name)
}
case If(cond, thenStmt, elseStmt) => {
val condErrors = rename(scope)(cond)
// then and else both have their own scopes
val thenErrors = scope.withSubscope(s => thenStmt.foldMap(rename(s)))
val elseErrors = scope.withSubscope(s => elseStmt.foldMap(rename(s)))
condErrors ++ thenErrors ++ elseErrors
}
case While(cond, body) => {
val condErrors = rename(scope)(cond)
// while bodies have their own scopes
val bodyErrors = scope.withSubscope(s => body.foldMap(rename(s)))
condErrors ++ bodyErrors
}
// begin-end blocks have their own scopes
case Block(body) => scope.withSubscope(s => body.foldMap(rename(s)))
// These cases are simpler, mostly just recursive calls to rename()
case Assign(lhs, value) => {
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
rename(scope)(lhs)
rename(scope)(value)
}
case Read(lhs) => rename(scope)(lhs)
case Free(expr) => rename(scope)(expr)
case Return(expr) => rename(scope)(expr)
case Exit(expr) => rename(scope)(expr)
case Print(expr, _) => rename(scope)(expr)
case NewPair(fst, snd) => {
rename(scope)(fst)
rename(scope)(snd)
}
case Call(name, args) => {
scope.getFunc(name.v) match {
case Some(Ident(_, uid)) => name.uid = uid
case None =>
errors += Error.UndefinedFunction(name)
scope.add(FuncType(?, args.map(_ => ?)), name)
// These cases are simpler, mostly just recursive calls to rename()
case Assign(lhs, value) => {
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
rename(scope)(lhs) ++ rename(scope)(value)
}
args.foreach(rename(scope))
}
case Fst(elem) => rename(scope)(elem)
case Snd(elem) => rename(scope)(elem)
case ArrayLiter(elems) => elems.foreach(rename(scope))
case ArrayElem(name, indices) => {
rename(scope)(name)
indices.toList.foreach(rename(scope))
}
case Parens(expr) => rename(scope)(expr)
case op: UnaryOp => rename(scope)(op.x)
case op: BinaryOp => {
rename(scope)(op.x)
rename(scope)(op.y)
}
// Default to variables. Only `call` uses IdentType.Func.
case id: Ident => {
scope.getVar(id.v) match {
case Some(Ident(_, uid)) => id.uid = uid
case None =>
errors += Error.UndeclaredVariable(id)
scope.add(?, id)
case Read(lhs) => rename(scope)(lhs)
case Free(expr) => rename(scope)(expr)
case Return(expr) => rename(scope)(expr)
case Exit(expr) => rename(scope)(expr)
case Print(expr, _) => rename(scope)(expr)
case NewPair(fst, snd) => {
rename(scope)(fst) ++ rename(scope)(snd)
}
case Call(name, args) => {
val nameErrors = scope.getFunc(name) match {
case Some(Ident(realName, guid, ty)) =>
name.v = realName
name.ty = ty
name.guid = guid
Chain.empty
case None =>
name.ty = FuncType(?, args.map(_ => ?))
scope.add(name)
Chain.one(Error.UndefinedFunction(name))
}
val argsErrors = args.foldMap(rename(scope))
nameErrors ++ argsErrors
}
case Fst(elem) => rename(scope)(elem)
case Snd(elem) => rename(scope)(elem)
case ArrayLiter(elems) => elems.foldMap(rename(scope))
case ArrayElem(name, indices) => {
val nameErrors = rename(scope)(name)
val indicesErrors = indices.foldMap(rename(scope))
nameErrors ++ indicesErrors
}
case Parens(expr) => rename(scope)(expr)
case op: UnaryOp => rename(scope)(op.x)
case op: BinaryOp => {
rename(scope)(op.x) ++ rename(scope)(op.y)
}
// Default to variables. Only `call` uses IdentType.Func.
case id: Ident => {
scope.getVar(id) match {
case Some(Ident(_, guid, ty)) =>
id.ty = ty
id.guid = guid
Chain.empty
case None =>
id.ty = ?
scope.add(id)
Chain.one(Error.UndeclaredVariable(id))
}
}
// These literals cannot contain identifies, exit immediately.
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() =>
Chain.empty
}
// These literals cannot contain identifies, exit immediately.
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => ()
}
}

View File

@@ -8,13 +8,8 @@ object typeChecker {
import wacc.types._
case class TypeCheckerCtx(
globalNames: Map[ast.Ident, SemType],
globalFuncs: Map[ast.Ident, FuncType],
errors: mutable.Builder[Error, List[Error]]
) {
def typeOf(ident: ast.Ident): SemType = globalNames(ident)
def funcType(ident: ast.Ident): FuncType = globalFuncs(ident)
def error(err: Error): SemType =
errors += err
?
@@ -99,18 +94,17 @@ object typeChecker {
* The type checker context which includes the global names and functions, and an errors
* builder.
*/
def check(prog: ast.Program)(using
ctx: TypeCheckerCtx
): microWacc.Program =
def check(prog: ast.Program, errors: mutable.Builder[Error, List[Error]]): microWacc.Program =
given ctx: TypeCheckerCtx = TypeCheckerCtx(errors)
microWacc.Program(
// Ignore function syntax types for return value and params, since those have been converted
// to SemTypes by the renamer.
prog.funcs.map { case ast.FuncDecl(_, name, params, stmts) =>
val FuncType(retType, paramTypes) = ctx.funcType(name)
val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType]
microWacc.FuncDecl(
microWacc.Ident(name.v, name.uid)(retType),
microWacc.Ident(name.v, name.guid)(retType),
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
microWacc.Ident(ident.v, ident.uid)(ty)
microWacc.Ident(ident.v, ident.guid)(ty)
},
stmts.toList
.flatMap(
@@ -134,15 +128,20 @@ object typeChecker {
): List[microWacc.Stmt] = stmt match {
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
case ast.VarDecl(_, name, value) =>
val expectedTy = ctx.typeOf(name)
val expectedTy = name.ty
val typedValue = checkValue(
value,
Constraint.Is(
expectedTy,
expectedTy.asInstanceOf[SemType],
s"variable ${name.v} must be assigned a value of type $expectedTy"
)
)
List(microWacc.Assign(microWacc.Ident(name.v, name.uid)(expectedTy), typedValue))
List(
microWacc.Assign(
microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]),
typedValue
)
)
case ast.Assign(lhs, rhs) =>
val lhsTyped = checkLValue(lhs, Constraint.Unconstrained)
val rhsTyped =
@@ -315,7 +314,7 @@ object typeChecker {
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
)
case ast.Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType]
if (args.length != paramTys.length) {
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
}
@@ -324,7 +323,7 @@ object typeChecker {
val argsTyped = args.zip(paramTys).map { case (arg, paramTy) =>
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
}
microWacc.Call(microWacc.Ident(id.v, id.uid)(retTy.satisfies(constraint, id.pos)), argsTyped)
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTy.satisfies(constraint, id.pos)), argsTyped)
// Unary operators
case ast.Negate(x) =>
@@ -416,30 +415,32 @@ object typeChecker {
private def checkLValue(value: ast.LValue, constraint: Constraint)(using
ctx: TypeCheckerCtx
): microWacc.LValue = value match {
case id @ ast.Ident(name, uid) =>
microWacc.Ident(name, uid)(ctx.typeOf(id).satisfies(constraint, id.pos))
case id @ ast.Ident(name, guid, ty) =>
microWacc.Ident(name, guid)(ty.asInstanceOf[SemType].satisfies(constraint, id.pos))
case ast.ArrayElem(id, indices) =>
val arrayTy = ctx.typeOf(id)
val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy) { (acc, elem) =>
val idxTyped = checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
val next = acc match {
case KnownType.Array(innerTy) => innerTy
case ? => ? // we can keep indexing an unknown type
case nonArrayTy =>
ctx.error(
Error.TypeMismatch(
elem.pos,
KnownType.Array(?),
acc,
"cannot index into a non-array"
val arrayTy = id.ty.asInstanceOf[SemType]
val (elemTy, indicesTyped) = indices.mapAccumulate(arrayTy.asInstanceOf[SemType]) {
(acc, elem) =>
val idxTyped =
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
val next = acc match {
case KnownType.Array(innerTy) => innerTy
case ? => ? // we can keep indexing an unknown type
case nonArrayTy =>
ctx.error(
Error.TypeMismatch(
elem.pos,
KnownType.Array(?),
acc,
"cannot index into a non-array"
)
)
)
?
}
(next, idxTyped)
?
}
(next, idxTyped)
}
val firstArrayElem = microWacc.ArrayElem(
microWacc.Ident(id.v, id.uid)(arrayTy),
microWacc.Ident(id.v, id.guid)(arrayTy),
indicesTyped.head
)(elemTy.satisfies(constraint, value.pos))
val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) =>

View File

@@ -3,7 +3,9 @@ package wacc
object types {
import ast._
sealed trait SemType {
sealed trait RenamerType
sealed trait SemType extends RenamerType {
override def toString(): String = this match {
case KnownType.Int => "int"
case KnownType.Bool => "bool"
@@ -41,5 +43,5 @@ object types {
}
}
case class FuncType(returnType: SemType, params: List[SemType])
case class FuncType(returnType: SemType, params: List[SemType]) extends RenamerType
}

View File

@@ -26,6 +26,15 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd
} ++
allWaccFiles("wacc-examples/invalid/whack").map { p =>
(p.toString, List(100, 200))
} ++
allWaccFiles("extension/examples/valid").map { p =>
(p.toString, List(0))
} ++
allWaccFiles("extension/examples/invalid/syntax").map { p =>
(p.toString, List(100))
} ++
allWaccFiles("extension/examples/invalid/semantics").map { p =>
(p.toString, List(200))
}
forEvery(files) { (filename, expectedResult) =>
@@ -33,18 +42,21 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd
s"$filename" - {
"should be compiled with correct result" in {
compileWacc(Path.of(filename), outputDir = None, log = false).map { result =>
expectedResult should contain(result)
}
if (fileIsPendingFrontend(filename))
IO.pure(pending)
else
compileWacc(Path.of(filename), outputDir = None, log = false).map { result =>
expectedResult should contain(result)
}
}
if (expectedResult == List(0)) {
"should run with correct result" in {
if (fileIsDisallowedBackend(filename))
IO.pure(
succeed
) // TODO: remove when advanced tests removed. not sure how to "pending" this otherwise
else {
IO.pure(succeed)
else if (fileIsPendingBackend(filename))
IO.pure(pending)
else
for {
contents <- IO(Source.fromFile(File(filename)).getLines.toList)
inputLine = extractInput(contents)
@@ -75,7 +87,6 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd
exitCode shouldBe expectedExit
normalizeOutput(stdout.toString) shouldBe expectedOutput
}
}
}
}
}
@@ -85,10 +96,21 @@ class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAnd
val d = java.io.File(dir)
os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc")
// TODO: eventually remove this I think
def fileIsDisallowedBackend(filename: String): Boolean =
Seq(
"^.*wacc-examples/valid/advanced.*$"
private def fileIsDisallowedBackend(filename: String): Boolean =
filename.matches("^.*wacc-examples/valid/advanced.*$")
private def fileIsPendingFrontend(filename: String): Boolean =
List(
// "^.*extension/examples/invalid/syntax/imports/importBadSyntax.*$",
// "^.*extension/examples/invalid/semantics/imports.*$",
// "^.*extension/examples/valid/imports.*$"
).exists(filename.matches)
private def fileIsPendingBackend(filename: String): Boolean =
List(
// "^.*extension/examples/invalid/syntax/imports.*$",
// "^.*extension/examples/invalid/semantics/imports.*$",
// "^.*extension/examples/valid/imports.*$"
).exists(filename.matches)
private def extractInput(contents: List[String]): String =