feat: incomplete initial implementation of assemblyIR

This commit is contained in:
Barf-Vader
2025-02-13 17:42:50 +00:00
parent 41ed06f91c
commit 32622cdd7e
8 changed files with 68 additions and 0 deletions

View File

@@ -0,0 +1,93 @@
package wacc
import wacc.ast.Position
import wacc.types._
/** Error types for semantic errors.
*
* Each case carries either the offending identifier (whose `pos` locates the error) or an
* explicit source position, plus whatever detail `printError` needs to render the message.
*/
enum Error {
// An identifier was declared twice; reported as "Duplicate declaration of identifier".
case DuplicateDeclaration(ident: ast.Ident)
// A variable was used without a declaration; reported as "Undeclared variable".
case UndeclaredVariable(ident: ast.Ident)
// A function was used without a definition; reported as "Undefined function".
case UndefinedFunction(ident: ast.Ident)
// A function was used with `got` parameters where `expected` were required; `funcType` is
// printed alongside the function name.
case FunctionParamsMismatch(ident: ast.Ident, expected: Int, got: Int, funcType: FuncType)
// A free-form semantic error described entirely by `msg`.
case SemanticError(pos: Position, msg: String)
// A type `got` was found where `expected` was required; `msg` gives the context.
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
// Reported with an "Internal error:" prefix.
case InternalError(pos: Position, msg: String)
}
/** Prints a report for a single semantic error to standard output.
  *
  * The report consists of a "Semantic error:" banner, the error position, a message specific to
  * the error case, and a snippet of the surrounding source with the offending span underlined.
  * Errors that carry an identifier underline the whole identifier; errors that only carry a
  * position underline a single character.
  *
  * @param error
  *   Error object
  * @param errorContent
  *   Contents of the file to generate code snippets
  */
def printError(error: Error)(using errorContent: String): Unit = {
  println("Semantic error:")
  error match {
    case Error.DuplicateDeclaration(ident) =>
      printPosition(ident.pos)
      println(s"Duplicate declaration of identifier ${ident.v}")
      highlight(ident.pos, ident.v.length)
    case Error.UndeclaredVariable(ident) =>
      printPosition(ident.pos)
      println(s"Undeclared variable ${ident.v}")
      highlight(ident.pos, ident.v.length)
    case Error.UndefinedFunction(ident) =>
      printPosition(ident.pos)
      println(s"Undefined function ${ident.v}")
      highlight(ident.pos, ident.v.length)
    case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
      printPosition(id.pos)
      println(s"Function expects $expected parameters, got $got")
      println(
        s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
      )
      // Underline the full function name, consistent with the other identifier-based errors.
      highlight(id.pos, id.v.length)
    case Error.TypeMismatch(pos, expected, got, msg) =>
      printPosition(pos)
      println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
      highlight(pos, 1)
    case Error.SemanticError(pos, msg) =>
      printPosition(pos)
      println(msg)
      highlight(pos, 1)
    case Error.InternalError(pos, msg) =>
      printPosition(pos)
      println(s"Internal error: $msg")
      highlight(pos, 1)
  }
}
/** Prints the source line at `pos` together with its neighbouring lines, and a row of `^`
  * characters underneath the error line marking the offending span.
  *
  * Lines are addressed 1-based through `pos.line`. A neighbouring line that does not exist (before
  * the first or after the last line) is printed as empty. The same applies to the error line
  * itself should `pos` point outside the file, so reporting never fails with an index error.
  *
  * @param pos
  *   Position of the error
  * @param size
  *   Size(in chars) of section to highlight
  * @param errorContent
  *   Contents of the file to generate code snippets
  */
def highlight(pos: Position, size: Int)(using errorContent: String): Unit = {
  // A negative limit keeps trailing empty lines, so positions on them remain addressable.
  val lines = errorContent.split("\n", -1)
  // Total lookup by 1-based line number: anything out of range is rendered as an empty line.
  def lineAt(lineNo: Int): String = lines.lift(lineNo - 1).getOrElse("")
  val preLine = lineAt(pos.line - 1)
  val midLine = lineAt(pos.line)
  val postLine = lineAt(pos.line + 1)
  val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
  println(
    s" >$preLine\n >$midLine\n$linePointer >$postLine"
  )
}
/** Prints the "(line L, column C):" header that precedes an error message.
  *
  * @param pos
  *   Position of the error
  */
def printPosition(pos: Position): Unit = {
  val Position(line, column) = pos
  println(s"(line $line, column $column):")
}

View File

@@ -0,0 +1,253 @@
package wacc
import parsley.Parsley
import parsley.generic.ErrorBridge
import parsley.ap._
import parsley.position._
import parsley.syntax.zipped._
import cats.data.NonEmptyList
object ast {
/* ============================ EXPRESSIONS ============================ */
// Root of all expression nodes. Every expression is also an RValue and exposes the source
// position it was built with.
sealed trait Expr extends RValue {
val pos: Position
}
// Expr1 .. Expr6 form a chain of successively narrower subtypes of Expr. The operator node
// classes in this file extend one of these levels and restrict their operands to particular
// levels (e.g. Mul takes an Expr5 and an Expr6 and is itself an Expr5), so which expression
// may appear as an operand of which operator is encoded in the types.
sealed trait Expr1 extends Expr
sealed trait Expr2 extends Expr1
sealed trait Expr3 extends Expr2
sealed trait Expr4 extends Expr3
sealed trait Expr5 extends Expr4
sealed trait Expr6 extends Expr5
/* ============================ ATOMIC EXPRESSIONS ============================ */
// Each node is a case class whose position lives in a second parameter list (so it is not
// part of the case-class equality), paired with a companion object that extends a
// ParserBridgePos* trait so the node can be built directly from parsers.
case class IntLiter(v: Int)(val pos: Position) extends Expr6
object IntLiter extends ParserBridgePos1[Int, IntLiter]
case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6
object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter]
case class CharLiter(v: Char)(val pos: Position) extends Expr6
object CharLiter extends ParserBridgePos1[Char, CharLiter]
case class StrLiter(v: String)(val pos: Position) extends Expr6
object StrLiter extends ParserBridgePos1[String, StrLiter]
// The pair literal carries no value, only a position.
case class PairLiter()(val pos: Position) extends Expr6
object PairLiter extends ParserBridgePos0[PairLiter]
// An identifier. `uid` starts at -1 and is mutable so that it can be assigned after parsing
// (the renamer in this project writes to it).
case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue
object Ident extends ParserBridgePos1[String, Ident] {
// The bridge only supplies the name, so construct the node with the default uid.
def apply(v: String)(pos: Position): Ident = new Ident(v)(pos)
}
// An array element access: the array's identifier followed by one or more index expressions.
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position)
extends Expr6
with LValue
// The bridge is built from the indices alone and yields a function that still expects the
// identifier, so the identifier can be parsed first and applied afterwards.
object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
name => ArrayElem(name, a)(pos)
}
// A parenthesised expression, kept as an explicit node.
case class Parens(expr: Expr)(val pos: Position) extends Expr6
object Parens extends ParserBridgePos1[Expr, Parens]
/* ============================ UNARY OPERATORS ============================ */
// Common view over all unary operator nodes: `x` is the single operand.
sealed trait UnaryOp extends Expr {
val x: Expr
}
// All unary operators take an Expr6 operand and are themselves Expr6.
case class Negate(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Negate extends ParserBridgePos1[Expr6, Negate]
case class Not(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Not extends ParserBridgePos1[Expr6, Not]
case class Len(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Len extends ParserBridgePos1[Expr6, Len]
case class Ord(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Ord extends ParserBridgePos1[Expr6, Ord]
case class Chr(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
object Chr extends ParserBridgePos1[Expr6, Chr]
/* ============================ BINARY OPERATORS ============================ */
// Common view over all binary operator nodes: `x` and `y` are the left and right operands and
// `name` is a human-readable name of the operator.
sealed trait BinaryOp(val name: String) extends Expr {
val x: Expr
val y: Expr
}
// Additive operators: Expr4 on the left, Expr5 on the right, result Expr4.
case class Add(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp("addition")
object Add extends ParserBridgePos2[Expr4, Expr5, Add]
case class Sub(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp("subtraction")
object Sub extends ParserBridgePos2[Expr4, Expr5, Sub]
// Multiplicative operators: Expr5 on the left, Expr6 on the right, result Expr5.
case class Mul(x: Expr5, y: Expr6)(val pos: Position)
extends Expr5
with BinaryOp("multiplication")
object Mul extends ParserBridgePos2[Expr5, Expr6, Mul]
case class Div(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp("division")
object Div extends ParserBridgePos2[Expr5, Expr6, Div]
case class Mod(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp("modulus")
object Mod extends ParserBridgePos2[Expr5, Expr6, Mod]
// Ordering comparisons: both operands are Expr4 and the result is Expr3, so the result of a
// comparison cannot itself be an operand of another comparison.
case class Greater(x: Expr4, y: Expr4)(val pos: Position)
extends Expr3
with BinaryOp("strictly greater than")
object Greater extends ParserBridgePos2[Expr4, Expr4, Greater]
case class GreaterEq(x: Expr4, y: Expr4)(val pos: Position)
extends Expr3
with BinaryOp("greater than or equal to")
object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq]
case class Less(x: Expr4, y: Expr4)(val pos: Position)
extends Expr3
with BinaryOp("strictly less than")
object Less extends ParserBridgePos2[Expr4, Expr4, Less]
case class LessEq(x: Expr4, y: Expr4)(val pos: Position)
extends Expr3
with BinaryOp("less than or equal to")
object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq]
// Equality comparisons: both operands are Expr3 and the result is Expr2.
case class Eq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp("equality")
object Eq extends ParserBridgePos2[Expr3, Expr3, Eq]
case class Neq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp("inequality")
object Neq extends ParserBridgePos2[Expr3, Expr3, Neq]
// Logical operators: the left operand is one level narrower than the result, the right operand
// is at the same level as the result.
case class And(x: Expr2, y: Expr1)(val pos: Position) extends Expr1 with BinaryOp("logical and")
object And extends ParserBridgePos2[Expr2, Expr1, And]
case class Or(x: Expr1, y: Expr)(val pos: Position) extends Expr with BinaryOp("logical or")
object Or extends ParserBridgePos2[Expr1, Expr, Or]
/* ============================ TYPES ============================ */
// Syntactic types as written in the source program.
sealed trait Type
// Base types may also be used as the element type of a pair.
sealed trait BaseType extends Type with PairElemType
case class IntType()(val pos: Position) extends BaseType
object IntType extends ParserBridgePos0[IntType]
case class BoolType()(val pos: Position) extends BaseType
object BoolType extends ParserBridgePos0[BoolType]
case class CharType()(val pos: Position) extends BaseType
object CharType extends ParserBridgePos0[CharType]
case class StringType()(val pos: Position) extends BaseType
object StringType extends ParserBridgePos0[StringType]
// An array type: the element type together with the number of array dimensions.
case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position)
extends Type
with PairElemType
// The bridge is built from the dimension count alone and yields a function that still expects
// the element type, so the element type can be parsed first and applied afterwards.
object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
}
// A fully specified pair type. Note that PairType is a Type but not a PairElemType.
case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type
object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
// Types permitted as the components of a pair.
sealed trait PairElemType
// A pair used as a pair component without its own component types.
case class UntypedPairType()(val pos: Position) extends PairElemType
object UntypedPairType extends ParserBridgePos0[UntypedPairType]
/* ============================ PROGRAM STRUCTURE ============================ */
// A whole program: its function declarations followed by the non-empty main statement list.
case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position)
object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program]
/* ============================ FUNCTION STRUCTURE ============================ */
// A function declaration: return type, name, parameter list and non-empty body.
case class FuncDecl(
returnType: Type,
name: Ident,
params: List[Param],
body: NonEmptyList[Stmt]
)(val pos: Position)
// The bridge is built from the parameters and body and yields a function that still expects
// the (return type, name) pair, so that pair can be parsed first and applied afterwards.
object FuncDecl
extends ParserBridgePos2[
List[Param],
NonEmptyList[Stmt],
((Type, Ident)) => FuncDecl
] {
def apply(params: List[Param], body: NonEmptyList[Stmt])(
pos: Position
): ((Type, Ident)) => FuncDecl =
(returnType, name) => FuncDecl(returnType, name, params, body)(pos)
}
// A single function parameter: its declared type and name.
case class Param(paramType: Type, name: Ident)(val pos: Position)
object Param extends ParserBridgePos2[Type, Ident, Param]
/* ============================ STATEMENTS ============================ */
// Root of all statement nodes. Statement sequences are represented as NonEmptyList[Stmt].
sealed trait Stmt
case class Skip()(val pos: Position) extends Stmt
object Skip extends ParserBridgePos0[Skip]
// Declaration of a new variable with its type and initial value.
case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt
object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl]
// Assignment of a value to an existing assignable location.
case class Assign(lhs: LValue, value: RValue)(val pos: Position) extends Stmt
object Assign extends ParserBridgePos2[LValue, RValue, Assign]
case class Read(lhs: LValue)(val pos: Position) extends Stmt
object Read extends ParserBridgePos1[LValue, Read]
case class Free(expr: Expr)(val pos: Position) extends Stmt
object Free extends ParserBridgePos1[Expr, Free]
case class Return(expr: Expr)(val pos: Position) extends Stmt
object Return extends ParserBridgePos1[Expr, Return]
case class Exit(expr: Expr)(val pos: Position) extends Stmt
object Exit extends ParserBridgePos1[Expr, Exit]
// `newline` distinguishes the two print forms (the parser sets it to true for `println`).
case class Print(expr: Expr, newline: Boolean)(val pos: Position) extends Stmt
object Print extends ParserBridgePos2[Expr, Boolean, Print]
// Both branches are mandatory and non-empty.
case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])(
val pos: Position
) extends Stmt
object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If]
case class While(cond: Expr, body: NonEmptyList[Stmt])(val pos: Position) extends Stmt
object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While]
// A nested statement list used as a single statement.
case class Block(stmt: NonEmptyList[Stmt])(val pos: Position) extends Stmt
object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block]
/* ============================ LVALUES & RVALUES ============================ */
// Things that may appear on the left of an assignment (identifiers, array elements and pair
// elements in this file). Every LValue exposes its source position.
sealed trait LValue {
val pos: Position
}
// Things that may appear on the right of an assignment or declaration. Every Expr is an RValue.
sealed trait RValue
case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue
object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter]
case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue
object NewPair extends ParserBridgePos2[Expr, Expr, NewPair]
// A function call: the function's identifier and its argument expressions.
case class Call(name: Ident, args: List[Expr])(val pos: Position) extends RValue
object Call extends ParserBridgePos2[Ident, List[Expr], Call]
// Pair element access is both assignable and readable, but is not an Expr.
sealed trait PairElem extends LValue with RValue
case class Fst(elem: LValue)(val pos: Position) extends PairElem
object Fst extends ParserBridgePos1[LValue, Fst]
case class Snd(elem: LValue)(val pos: Position) extends PairElem
object Snd extends ParserBridgePos1[LValue, Snd]
/* ============================ PARSER BRIDGES ============================ */
// A source position, built by the bridges below from the (Int, Int) pair produced by the
// `pos` parser, first component as `line` and second as `column`.
case class Position(line: Int, column: Int)
// Base bridge: knows how to turn a raw position pair into a value of type A (usually a
// constructor function still awaiting its arguments).
trait ParserSingletonBridgePos[+A] extends ErrorBridge {
protected def con(pos: (Int, Int)): A
// Captures the current position, builds the value with `con`, then runs `op` and discards
// its result. The whole parser is passed through `error` from ErrorBridge.
infix def from(op: Parsley[?]): Parsley[A] = error(pos.map(con) <~ op)
// Symbolic alias for `from`.
final def <#(op: Parsley[?]): Parsley[A] = this from op
}
// Bridge for nodes with no arguments besides the position.
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] {
def apply()(pos: Position): A
override final def con(pos: (Int, Int)): A =
apply()(Position(pos._1, pos._2))
}
// Bridge for nodes with one argument. `apply(a: A)(pos)` builds the node from plain values;
// `apply(a: Parsley[A])` lifts that to parsers, capturing the position before `a` is run.
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] {
def apply(a: A)(pos: Position): B
def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a))
override final def con(pos: (Int, Int)): A => B =
this.apply(_)(Position(pos._1, pos._2))
}
// Bridge for nodes with two arguments. The second parser is taken by name.
trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] {
def apply(a: A, b: B)(pos: Position): C
def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error(
ap2(pos.map(con), a, b)
)
override final def con(pos: (Int, Int)): (A, B) => C =
apply(_, _)(Position(pos._1, pos._2))
}
// Bridge for nodes with three arguments. The second and third parsers are taken by name.
trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] {
def apply(a: A, b: B, c: C)(pos: Position): D
def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error(
ap3(pos.map(con), a, b, c)
)
override final def con(pos: (Int, Int)): (A, B, C) => D =
apply(_, _, _)(Position(pos._1, pos._2))
}
}

View File

@@ -0,0 +1,105 @@
package wacc
import parsley.Parsley
import parsley.character
import parsley.token.{Basic, Lexer}
import parsley.token.descriptions.*
import parsley.token.errors._
/** ErrorConfig that gives the lexer's symbols descriptive labels in error messages.
  *
  * Symbols are grouped by the label they share; every symbol in a group is mapped to a `Label`
  * carrying that group's description.
  */
val errConfig = new ErrorConfig {
  override def labelSymbol = Seq(
    "binary operator" -> Seq("!=", "%", "&&", "*", "/", "<", "<=", "==", ">", ">=", "||"),
    "unary operator" -> Seq("!", "len", "ord", "chr"),
    "valid type" -> Seq("bool", "char", "int", "pair", "string"),
    "pair extraction" -> Seq("fst", "snd"),
    "boolean literal" -> Seq("false", "true"),
    "assignment" -> Seq("="),
    "array index" -> Seq("[")
  ).flatMap { case (description, symbols) =>
    symbols.map(symbol => symbol -> Label(description))
  }.toMap
}
// Lexer for WACC: the lexical description of the language plus the token parsers built from it.
object lexer {
/** Language description for the WACC lexer
*/
private val desc = LexicalDesc.plain.copy(
// Identifiers start with a letter or underscore and continue with letters, digits or
// underscores.
nameDesc = NameDesc.plain.copy(
identifierStart = Basic(c => c.isLetter || c == '_'),
identifierLetter = Basic(c => c.isLetterOrDigit || c == '_')
),
symbolDesc = SymbolDesc.plain.copy(
hardKeywords = Set(
"begin", "end", "is", "skip", "if", "then", "else", "fi", "while", "do", "done", "read",
"free", "return", "exit", "print", "println", "true", "false", "int", "bool", "char",
"string", "pair", "newpair", "fst", "snd", "call", "chr", "ord", "len", "null"
),
hardOperators = Set(
"+", "-", "*", "/", "%", ">", "<", ">=", "<=", "==", "!=", "&&", "||", "!"
)
),
// Line comments start with '#'.
spaceDesc = SpaceDesc.plain.copy(
lineCommentStart = "#"
),
textDesc = TextDesc.plain.copy(
// Unescaped characters: anything from space upwards except backslash and both quote kinds.
graphicCharacter = Basic(c => c >= ' ' && c != '\\' && c != '\'' && c != '"'),
escapeSequences = EscapeDesc.plain.copy(
// Characters that escape to themselves.
literals = Set('\\', '"', '\''),
// Escape letters and the characters they stand for.
mapping = Map(
"0" -> '\u0000',
"b" -> '\b',
"t" -> '\t',
"n" -> '\n',
"f" -> '\f',
"r" -> '\r'
)
)
),
numericDesc = NumericDesc.plain.copy(
decimalExponentDesc = ExponentDesc.NoExponents
)
)
/** Token definitions for the WACC lexer
*/
private val lexer = Lexer(desc, errConfig)
val ident = lexer.lexeme.names.identifier
val integer = lexer.lexeme.integer.decimal32[Int]
// Matches a "-" immediately followed by a digit, without consuming trailing whitespace.
// The expression parser uses it (under notFollowedBy) to decide whether a "-" belongs to a
// negative integer literal rather than being a unary minus.
val negateCheck = lexer.nonlexeme.symbol("-") ~> character.digit
val charLit = lexer.lexeme.character.ascii
val stringLit = lexer.lexeme.string.ascii
val implicits = lexer.lexeme.symbol.implicits
/** Tokens for producing lexer-backed error messages
*
* Each parser recognises one kind of token and yields the text used to describe it in an
* error message. Every hard keyword from `desc` is appended as a "keyword ..." token.
*/
val errTokens = Seq(
lexer.nonlexeme.names.identifier.map(v => s"identifier $v"),
lexer.nonlexeme.integer.decimal32[Int].map(n => s"integer $n"),
lexer.nonlexeme.character.ascii.map(c => s"character literal \'$c\'"),
lexer.nonlexeme.string.ascii.map(s => s"string literal \"$s\""),
lexer.nonlexeme.symbol("[").as("array literal"),
character.whitespace.map(_ => "")
) ++ desc.symbolDesc.hardKeywords.map { k =>
lexer.nonlexeme.symbol(k).as(s"keyword $k")
}
// Delegates to the underlying Lexer's `fully` for a top-level parser.
def fully[A](p: Parsley[A]): Parsley[A] = lexer.fully(p)
}

View File

@@ -0,0 +1,224 @@
package wacc
import parsley.Result
import parsley.Parsley
import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit}
import parsley.combinator.{countSome, sepBy}
import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms}
import parsley.errors.combinator._
import parsley.errors.patterns.VerifiedErrors
import parsley.syntax.zipped._
import parsley.cats.combinator.{some}
import cats.data.NonEmptyList
import parsley.errors.DefaultErrorBuilder
import parsley.errors.ErrorBuilder
import parsley.errors.tokenextractors.LexToken
object parser {
import lexer.implicits.implicitSymbol
import lexer.{ident, integer, charLit, stringLit, negateCheck, errTokens}
import ast._
// Error-reporting helpers available on every parser.
extension [A](p: Parsley[A]) {

  /** Attaches a label and an explanation to `p` in a single call. */
  def labelAndExplain(label: String, explanation: String): Parsley[A] =
    p.label(label).explain(explanation)

  /** Labels `p` according to `t`; expression parsers additionally receive a standard
    * explanation of what a valid expression looks like.
    */
  def labelAndExplain(t: LabelType): Parsley[A] = {
    val labelled = labelWithType(t)
    t match {
      case LabelType.Expr =>
        val exprExplanation =
          "a valid expression can start with: null, literals, identifiers, unary operators, or parentheses. " +
            "Expressions can also contain array indexing and binary operators. " +
            "Pair extraction is not allowed in expressions, only in assignments."
        labelled.explain(exprExplanation)
      case LabelType.Pair => labelled
    }
  }

  /** Labels `p` with the standard description for the given kind of construct. */
  def labelWithType(t: LabelType): Parsley[A] = {
    val description = t match {
      case LabelType.Expr => "valid expression"
      case LabelType.Pair => "valid pair"
    }
    p.label(description)
  }
}
/** Kinds of construct for which the parser has a standard error label. */
enum LabelType {
  case Expr, Pair
}
// Error builder producing String messages; unexpected input is described using the lexer's
// `errTokens`.
implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken {
def tokens = errTokens
}
/** Parses a complete WACC program.
*
* @param input
*   Source text of the program
* @return
*   The parsed Program, or an error message rendered as a String
*/
def parse(input: String): Result[String, Program] = parser.parse(input)
// The top-level parser: `<program>` wrapped in the lexer's `fully`. Inside this object the
// name `parser` refers to this value rather than to the enclosing object.
private val parser = lexer.fully(`<program>`)
// Expressions
// Operator table for expressions. The levels are chained with `+:` and end with the atoms:
// `||`, then `&&` (both right-associative), then the non-associative equality and ordering
// comparisons, then left-associative `+`/`-` and `*`/`/`/`%`, then the prefix operators.
private lazy val `<expr>`: Parsley[Expr] = precedence {
SOps(InfixR)(Or from "||") +:
SOps(InfixR)(And from "&&") +:
SOps(InfixN)(Eq from "==", Neq from "!=") +:
SOps(InfixN)(
Less from "<",
LessEq from "<=",
Greater from ">",
GreaterEq from ">="
) +:
SOps(InfixL)(
// "+" and "-" are labelled here rather than in the lexer's ErrorConfig.
(Add from "+").label("binary operator"),
(Sub from "-").label("binary operator")
) +:
SOps(InfixL)(Mul from "*", Div from "/", Mod from "%") +:
SOps(Prefix)(
Not from "!",
// notFollowedBy(negateCheck) ensures that negative numbers are parsed as a single int literal
(Negate from (notFollowedBy(negateCheck) ~> "-")).hide,
Len from "len",
Ord from "ord",
Chr from "chr"
) +:
`<atom>`
}
// Atoms
// The atomic expressions: literals, `null`, identifiers / array elements, and parenthesised
// expressions.
private lazy val `<atom>`: Atoms[Expr6] = Atoms(
IntLiter(integer).label("integer literal"),
BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"),
CharLiter(charLit).label("character literal"),
StrLiter(stringLit).label("string literal"),
PairLiter from "null",
`<ident-or-array-elem>`,
Parens("(" ~> `<expr>` <~ ")")
)
// An identifier. One or more "*" / "&" where an identifier is expected is rejected with a
// dedicated explanation.
private val `<ident>` =
Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed")
// An identifier optionally followed by array indices. A "(" directly after the identifier is
// rejected with an explanation that calls need the 'call' keyword. The identifier is parsed
// first and then either wrapped into an ArrayElem or returned unchanged.
private lazy val `<ident-or-array-elem>` =
(`<ident>` <~ ("(".verifiedExplain(
"functions can only be called using 'call' keyword"
) | unit)) <**> (`<array-indices>` </> identity)
// One or more bracketed index expressions; yields a function still expecting the identifier.
private val `<array-indices>` = ArrayElem(some("[" ~> `<expr>` <~ "]"))
// Types
// A base type or a pair type, optionally followed by array brackets.
private lazy val `<type>`: Parsley[Type] =
(`<base-type>` | (`<pair-type>` ~> `<pair-elems-type>`)) <**> (`<array-type>` </> identity)
private val `<base-type>` =
(IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string")
// One or more "[]" pairs; the count becomes the array's dimensions. Yields a function still
// expecting the element type.
private lazy val `<array-type>` =
ArrayType(countSome("[" ~> "]"))
private val `<pair-type>` = "pair"
// The parenthesised, comma-separated component types that follow the `pair` keyword.
private val `<pair-elems-type>`: Parsley[PairType] = PairType(
"(" ~> `<pair-elem-type>` <~ ",",
`<pair-elem-type>` <~ ")"
)
// A pair component type: a base type (optionally an array of it), or `pair`. A bare `pair`
// yields UntypedPairType; `pair` followed by component types is only accepted when array
// brackets follow, in which case the array type replaces the UntypedPairType.
private lazy val `<pair-elem-type>` =
(`<base-type>` <**> (`<array-type>` </> identity)) |
((UntypedPairType from `<pair-type>`) <**>
((`<pair-elems-type>` <**> `<array-type>`)
.map(arr => (_: UntypedPairType) => arr) </> identity))
/* Statements
Atomic is used in two places here:
1. Atomic for function return type - code may be a variable declaration instead. If we were
to factor out the type, the resulting code would be rather messy. It can only fail once
in the entire program so it creates minimal overhead.
2. Atomic for function missing return type check - there is no easy way around an explicit
invalid syntax check, this only happens at most once per program so this is not a major
concern.
*/
// A program: "begin", zero or more function declarations, the main statement list, "end".
// An identifier followed by "(" where a function declaration could start is reported as a
// declaration that is missing its return type.
private lazy val `<program>` = Program(
"begin" ~> (
many(
atomic(
`<type>`.label("function declaration") <~> `<ident>` <~ "("
) <**> `<partial-func-decl>`
).label("function declaration") |
atomic(`<ident>` <~ "(").verifiedExplain("function declaration is missing return type")
),
`<stmt>`.label(
"main program body"
) <~ "end"
)
// The remainder of a function declaration after "type name (": parameters, ")", "is", the
// body and "end". The body is rejected unless `isReturning` holds for it.
private lazy val `<partial-func-decl>` =
FuncDecl(
sepBy(`<param>`, ",") <~ ")" <~ "is",
`<stmt>`.guardAgainst {
case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement")
} <~ "end"
)
private lazy val `<param>` = Param(`<type>`, `<ident>`)
// A non-empty statement list: one statement followed by any number of ";"-prefixed statements.
private lazy val `<stmt>`: Parsley[NonEmptyList[Stmt]] =
(
`<basic-stmt>`.label("main program body"),
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
).zipped(NonEmptyList.apply)
// A single statement. Besides the valid forms, three common mistakes are matched explicitly so
// that a dedicated explanation can be given: a "(" after the name in a declaration, a "(" after
// an lvalue, and a bare `call` whose result is not assigned.
private lazy val `<basic-stmt>` =
(Skip from "skip")
| Read("read" ~> `<lvalue>`)
| Free("free" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Return("return" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Exit("exit" ~> `<expr>`.labelAndExplain(LabelType.Expr))
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(false))
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(true))
| If(
"if" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "then",
`<stmt>` <~ "else",
`<stmt>` <~ "fi"
)
| While("while" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "do", `<stmt>` <~ "done")
| Block("begin" ~> `<stmt>` <~ "end")
| VarDecl(
`<type>`,
`<ident>` <~ ("=" | "(".verifiedExplain(
"all function declarations must be above the main program body"
)),
`<rvalue>`.label("valid initial value for variable")
)
| Assign(
`<lvalue>` <~ ("=" | "(".verifiedExplain(
"function calls must use the 'call' keyword and the result must be assigned to a variable"
)),
`<rvalue>`
) |
("call" ~> `<ident>`).verifiedExplain(
"function calls' results must be assigned to a variable"
)
// An assignable location: a pair element, an identifier, or an array element.
private lazy val `<lvalue>`: Parsley[LValue] =
`<pair-elem>` | `<ident-or-array-elem>`
// A value-producing form: array literal, `newpair(...)`, pair element, `call f(...)`, or a
// plain expression (tried last).
private lazy val `<rvalue>`: Parsley[RValue] =
`<array-liter>` |
NewPair(
"newpair" ~> "(" ~> `<expr>` <~ ",",
`<expr>` <~ ")"
) |
`<pair-elem>` |
Call(
"call" ~> `<ident>` <~ "(",
sepBy(`<expr>`, ",") <~ ")"
) | `<expr>`.labelWithType(LabelType.Expr)
// `fst` or `snd` applied to an lvalue.
private lazy val `<pair-elem>` =
Fst("fst" ~> `<lvalue>`.label("valid pair"))
| Snd("snd" ~> `<lvalue>`.label("valid pair"))
// A bracketed, comma-separated (possibly empty) list of expressions.
private lazy val `<array-liter>` = ArrayLiter(
"[" ~> sepBy(`<expr>`, ",") <~ "]"
)
extension (stmts: NonEmptyList[Stmt]) {

  /** Determines whether a statement list is guaranteed to return in all cases. This is required
    * as all functions must end via a "return" or "exit" statement.
    *
    * Only the last statement of the list is inspected. It is returning when it is a `return` or
    * `exit`, an `if` whose branches are both returning, or a `begin ... end` block whose body is
    * returning. A trailing `while` is never returning: its body is skipped entirely when the
    * condition is initially false, so it cannot guarantee that the function returns.
    *
    * @return
    *   true if the statement list is guaranteed to end in a return or exit, false otherwise
    */
  def isReturning: Boolean = stmts.last match {
    case Return(_) | Exit(_)       => true
    case If(_, thenStmt, elseStmt) => thenStmt.isReturning && elseStmt.isReturning
    case Block(body)               => body.isReturning
    // Everything else, including `while`, may fall through to the end of the function.
    case _                         => false
  }
}
}

View File

@@ -0,0 +1,219 @@
package wacc
import scala.collection.mutable
object renamer {
import ast._
import types._
// Distinguishes function names from variable names. Scopes key their entries on
// (name, IdentType), so a function and a variable with the same name do not clash.
private enum IdentType {
case Func
case Var
}
/** A lexical scope: identifiers declared directly in it (`current`) and the identifiers visible
  * from enclosing scopes (`parent`). Entries are keyed by name and identifier kind.
  */
private class Scope(
    val current: mutable.Map[(String, IdentType), Ident],
    val parent: Map[(String, IdentType), Ident]
) {

  /** Create a new scope with the current scope as its parent.
    *
    * @return
    *   A new scope with an empty current scope, whose parent falls back to this scope's current
    *   entries and then to this scope's own parent.
    */
  def subscope: Scope = {
    val enclosing =
      Map.empty[(String, IdentType), Ident].withDefault(current.withDefault(parent))
    Scope(mutable.Map.empty, enclosing)
  }

  /** Attempt to add a new identifier to the current scope. If an identifier of the same name and
    * kind already exists in the current scope, a DuplicateDeclaration error is recorded and the
    * identifier takes over the existing declaration's uid.
    *
    * @param ty
    *   The semantic type of the variable identifier, or the function type of the function
    *   identifier.
    * @param name
    *   The identifier being declared; its `uid` is assigned here.
    * @param globalNames
    *   The global map of variable identifiers to semantic types - variables are added to this map.
    * @param globalFuncs
    *   The global map of function identifiers to function types - functions are added to this map.
    * @param globalNumbering
    *   The global map of identifier names to the number of times they have been declared - used
    *   to pick this identifier's uid, and incremented afterwards.
    * @param errors
    *   The list of errors to append to.
    */
  def add(ty: SemType | FuncType, name: Ident)(using
      globalNames: mutable.Map[Ident, SemType],
      globalFuncs: mutable.Map[Ident, FuncType],
      globalNumbering: mutable.Map[String, Int],
      errors: mutable.Builder[Error, List[Error]]
  ): Unit = {
    val kind = ty match {
      case _: SemType  => IdentType.Var
      case _: FuncType => IdentType.Func
    }
    val key = (name.v, kind)
    current.get(key) match {
      case Some(existing) =>
        errors += Error.DuplicateDeclaration(name)
        name.uid = existing.uid
      case None =>
        val freshUid = globalNumbering.getOrElse(name.v, 0)
        // The uid must be set before the identifier is used as a map key below.
        name.uid = freshUid
        current(key) = name
        ty match {
          case semType: SemType   => globalNames(name) = semType
          case funcType: FuncType => globalFuncs(name) = funcType
        }
        globalNumbering(name.v) = freshUid + 1
    }
  }

  /** Looks an identifier up in this scope, falling back to the enclosing scopes. */
  private def get(name: String, identType: IdentType): Option[Ident] = {
    // Map defaults are only consulted by `.apply()`, which throws when no default matches
    // either, and there is no way to ask whether a default exists - hence the try-catch.
    val visible = current.withDefault(parent)
    try Some(visible((name, identType)))
    catch {
      case _: NoSuchElementException => None
    }
  }

  /** Looks up a variable visible from this scope. */
  def getVar(name: String): Option[Ident] = get(name, IdentType.Var)

  /** Looks up a function visible from this scope. */
  def getFunc(name: String): Option[Ident] = get(name, IdentType.Func)
}
/** Check scoping of all variables and functions in the program. Also generate semantic types for
  * all identifiers.
  *
  * Works in two passes over the functions: every function is declared first, and only then are
  * the function bodies renamed, so functions can call one another regardless of declaration
  * order. The main body is renamed last, in the same scope the functions were declared in.
  *
  * @param prog
  *   AST of the program
  * @param errors
  *   List of errors to append to
  * @return
  *   Maps of all (renamed) variable identifiers to their semantic types, and of all function
  *   identifiers to their function types
  */
def rename(prog: Program)(using
    errors: mutable.Builder[Error, List[Error]]
): (Map[Ident, SemType], Map[Ident, FuncType]) = {
  given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty
  given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty
  given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty
  val globalScope = Scope(mutable.Map.empty, Map.empty)

  // Pass 1: declare every function, remembering its typed parameters and body for pass 2.
  val pendingBodies = prog.funcs.map { func =>
    val paramTypes = func.params.map(param => SemType(param.paramType))
    globalScope.add(FuncType(SemType(func.returnType), paramTypes), func.name)
    (func.params zip paramTypes, func.body)
  }

  // Pass 2: rename each function body with its parameters in scope.
  for ((typedParams, body) <- pendingBodies) {
    val functionScope = globalScope.subscope
    for ((param, paramType) <- typedParams) {
      functionScope.add(paramType, param.name)
    }
    body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params
  }

  prog.main.toList.foreach(rename(globalScope))
  (globalNames.toMap, globalFuncs.toMap)
}
/** Check scoping of all identifiers in a given AST node.
*
* Declarations are added to `scope`, uses are resolved against it (copying the declaration's
* uid onto the use), and statements that introduce a nested block are renamed in a fresh
* subscope. An unresolved use is reported once and then declared with unknown type `?`, so
* later uses in the same scope resolve to it.
*
* @param scope
* The current scope and flattened parent scope.
* @param node
* The AST node.
* @param globalNames
* The global map of identifiers to semantic types - renamed identifiers will be added to this
* map.
* @param globalFuncs
* The global map of function identifiers to function types - passed on to `Scope.add`.
* @param globalNumbering
* The global map of identifier names to the number of times they have been declared - used and
* updated during identifier renaming.
* @param errors
* The list of errors to append to.
*/
private def rename(scope: Scope)(
node: Ident | Stmt | LValue | RValue | Expr
)(using
globalNames: mutable.Map[Ident, SemType],
globalFuncs: mutable.Map[Ident, FuncType],
globalNumbering: mutable.Map[String, Int],
errors: mutable.Builder[Error, List[Error]]
): Unit = node match {
// These cases are more interesting because they involve making subscopes
// or modifying the current scope.
case VarDecl(synType, name, value) => {
// Order matters here. Variable isn't declared until after the value is evaluated.
rename(scope)(value)
// Attempt to add the new variable to the current scope.
scope.add(SemType(synType), name)
}
case If(cond, thenStmt, elseStmt) => {
rename(scope)(cond)
// then and else both have their own scopes
thenStmt.toList.foreach(rename(scope.subscope))
elseStmt.toList.foreach(rename(scope.subscope))
}
case While(cond, body) => {
rename(scope)(cond)
// while bodies have their own scopes
body.toList.foreach(rename(scope.subscope))
}
// begin-end blocks have their own scopes
case Block(body) => body.toList.foreach(rename(scope.subscope))
// These cases are simpler, mostly just recursive calls to rename()
case Assign(lhs, value) => {
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
rename(scope)(lhs)
rename(scope)(value)
}
case Read(lhs) => rename(scope)(lhs)
case Free(expr) => rename(scope)(expr)
case Return(expr) => rename(scope)(expr)
case Exit(expr) => rename(scope)(expr)
case Print(expr, _) => rename(scope)(expr)
case NewPair(fst, snd) => {
rename(scope)(fst)
rename(scope)(snd)
}
case Call(name, args) => {
// Function names are resolved in the function namespace.
scope.getFunc(name.v) match {
case Some(Ident(_, uid)) => name.uid = uid
case None =>
errors += Error.UndefinedFunction(name)
// Declare the missing function with unknown return and parameter types.
scope.add(FuncType(?, args.map(_ => ?)), name)
}
args.foreach(rename(scope))
}
case Fst(elem) => rename(scope)(elem)
case Snd(elem) => rename(scope)(elem)
case ArrayLiter(elems) => elems.foreach(rename(scope))
case ArrayElem(name, indices) => {
rename(scope)(name)
indices.toList.foreach(rename(scope))
}
case Parens(expr) => rename(scope)(expr)
case op: UnaryOp => rename(scope)(op.x)
case op: BinaryOp => {
rename(scope)(op.x)
rename(scope)(op.y)
}
// Default to variables. Only `call` uses IdentType.Func.
case id: Ident => {
scope.getVar(id.v) match {
case Some(Ident(_, uid)) => id.uid = uid
case None =>
errors += Error.UndeclaredVariable(id)
// Declare the missing variable with unknown type.
scope.add(?, id)
}
}
// These literals cannot contain identifiers, exit immediately.
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => ()
}
}

View File

@@ -0,0 +1,322 @@
package wacc
import cats.syntax.all._
import scala.collection.mutable
object typeChecker {
import wacc.ast._
import wacc.types._
/** Everything the type checker needs while walking the AST: the types of all variables and
  * functions, and the builder that collects errors.
  */
case class TypeCheckerCtx(
    globalNames: Map[Ident, SemType],
    globalFuncs: Map[Ident, FuncType],
    errors: mutable.Builder[Error, List[Error]]
) {

  /** The semantic type recorded for a variable identifier. */
  def typeOf(ident: Ident): SemType = globalNames(ident)

  /** The function type recorded for a function identifier. */
  def funcType(ident: Ident): FuncType = globalFuncs(ident)

  /** Records an error and yields the unknown type `?` as the result of the failed check. */
  def error(err: Error): SemType = {
    errors.addOne(err)
    ?
  }
}
// Requirement placed on the type of an AST node while it is being checked. Where present,
// `msg` is the message attached to the error raised when the constraint is not satisfied.
private enum Constraint {
// Any type is accepted.
case Unconstrained
// Allows weakening in one direction
case Is(ty: SemType, msg: String)
// Allows weakening in both directions, useful for array literals
case IsSymmetricCompatible(ty: SemType, msg: String)
// Does not allow weakening
case IsUnweakenable(ty: SemType, msg: String)
// Accepts a type compatible with either `ty1` or `ty2`.
case IsEither(ty1: SemType, ty2: SemType, msg: String)
// No type is accepted; always raises a semantic error with `msg`.
case Never(msg: String)
}
extension (ty: SemType) {
/** Check if a type satisfies a constraint.
*
* The cases are tried in order, so the two weakening cases at the top take priority over the
* general handling of `Is` and `IsSymmetricCompatible` below them. `IsSymmetricCompatible`
* falls back to `Is`, and `Is` falls back to `IsUnweakenable`.
*
* @param constraint
* Constraint to satisfy.
* @param pos
* Position to pass to the error, if constraint was not satisfied.
* @return
* The type if the constraint was satisfied, or ? if it was not.
*/
private def satisfies(constraint: Constraint, pos: Position)(using
ctx: TypeCheckerCtx
): SemType =
(ty, constraint) match {
// Weakening: a char array is accepted where a string is required, and becomes a string.
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
KnownType.String
// The same weakening in the other direction, for the symmetric constraint only.
case (
KnownType.String,
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
) =>
KnownType.String
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
ty.satisfies(Constraint.Is(ty2, msg), pos)
// Change to IsUnweakenable to disallow recursive weakening
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
case (ty, Constraint.Unconstrained) => ty
case (ty, Constraint.Never(msg)) =>
ctx.error(Error.SemanticError(pos, msg))
// Tries `ty1` first, then `ty2`; on failure the error reports `ty1` as the expected type.
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
}
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
(ty moreSpecific ty2).getOrElse {
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
}
}
/** Tries to merge two types, returning the more specific one if possible.
*
* The unknown type `?` merges with anything and yields the other type. Equal types merge to
* themselves. Arrays and pairs are merged component-wise and only merge if every component
* does.
*
* @param ty2
* The other type to merge with.
* @return
* The more specific type if it could be determined, or None if the types are incompatible.
*/
private infix def moreSpecific(ty2: SemType): Option[SemType] =
(ty, ty2) match {
case (ty, ?) => Some(ty)
case (?, ty) => Some(ty)
case (ty1, ty2) if ty1 == ty2 => Some(ty1)
case (KnownType.Array(inn1), KnownType.Array(inn2)) =>
(inn1 moreSpecific inn2).map(KnownType.Array(_))
case (KnownType.Pair(fst1, snd1), KnownType.Pair(fst2, snd2)) =>
(fst1 moreSpecific fst2, snd1 moreSpecific snd2).mapN(KnownType.Pair(_, _))
case _ => None
}
}
/** Entry point of the type checker: validates every function body and then the main body.
  *
  * Errors are not returned; they are accumulated in the context's error builder.
  *
  * @param prog
  *   AST of the program to check.
  * @param ctx
  *   Type checker context holding the known variable and function types plus the error builder.
  */
def check(prog: Program)(using
    ctx: TypeCheckerCtx
): Unit = {
  // The syntactic return and parameter types on the declaration are skipped on purpose: the
  // renamer has already turned them into SemTypes, available through the context.
  prog.funcs.foreach { case FuncDecl(_, name, _, body) =>
    val declaredReturn = ctx.funcType(name).returnType
    val mustReturn =
      Constraint.Is(declaredReturn, s"function ${name.v} must return $declaredReturn")
    for (stmt <- body.toList) checkStmt(stmt, mustReturn)
  }

  val mainMustNotReturn = Constraint.Never("main function must not return")
  for (stmt <- prog.main.toList) checkStmt(stmt, mainMustNotReturn)
}
/** Type-check one statement, descending into any nested statement lists.
  *
  * @param stmt
  *   Statement to check.
  * @param returnConstraint
  *   Constraint applied to the expression of every `return` reached from this statement.
  */
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
    ctx: TypeCheckerCtx
): Unit = {
  // Nested bodies inherit the return constraint of the enclosing statement.
  def checkAll(stmts: List[Stmt]): Unit =
    for (inner <- stmts) checkStmt(inner, returnConstraint)

  stmt match {
    // The syntactic type on the declaration is skipped: the renamer already stored its SemType.
    case VarDecl(_, name, value) =>
      val declaredTy = ctx.typeOf(name)
      val mustMatchDeclaration = Constraint.Is(
        declaredTy,
        s"variable ${name.v} must be assigned a value of type $declaredTy"
      )
      checkValue(value, mustMatchDeclaration)

    case Assign(lhs, rhs) =>
      val targetTy = checkValue(lhs, Constraint.Unconstrained)
      val sourceTy =
        checkValue(rhs, Constraint.Is(targetTy, s"assignment must have type $targetTy"))
      if (targetTy == ? && sourceTy == ?) {
        ctx.error(
          Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
        )
      }

    case Read(dest) =>
      val destTy = checkValue(dest, Constraint.Unconstrained)
      if (destTy == ?) {
        ctx.error(
          Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
        )
      } else {
        val intOrChar = Constraint.IsEither(
          KnownType.Int,
          KnownType.Char,
          "read must be applied to an int or char"
        )
        destTy.satisfies(intOrChar, dest.pos)
      }

    case Free(lhs) =>
      val arrayOrPair = Constraint.IsEither(
        KnownType.Array(?),
        KnownType.Pair(?, ?),
        "free must be applied to an array or pair"
      )
      checkValue(lhs, arrayOrPair)

    case Return(expr) =>
      checkValue(expr, returnConstraint)

    case Exit(expr) =>
      checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))

    // Anything may be printed, so the expression is only checked for internal consistency.
    case Print(expr, _) =>
      checkValue(expr, Constraint.Unconstrained)

    case If(cond, thenStmt, elseStmt) =>
      checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
      checkAll(thenStmt.toList)
      checkAll(elseStmt.toList)

    case While(cond, body) =>
      checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
      checkAll(body.toList)

    case Block(body) =>
      checkAll(body.toList)

    case Skip() => ()
  }
}
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
  * overlap in the AST.
  *
  * Every failure is recorded in the context's error builder; checking then carries on with the
  * unknown type so that as many independent errors as possible are found in one pass.
  *
  * @param value
  *   The value to type-check.
  * @param constraint
  *   The type constraint that the value must satisfy.
  * @return
  *   The most specific type of the value if it could be determined, or ? if it could not.
  */
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
    ctx: TypeCheckerCtx
): SemType = value match {
  case l @ IntLiter(_)  => KnownType.Int.satisfies(constraint, l.pos)
  case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos)
  case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos)
  case l @ StrLiter(_)  => KnownType.String.satisfies(constraint, l.pos)
  case l @ PairLiter()  => KnownType.Pair(?, ?).satisfies(constraint, l.pos)
  case id: Ident =>
    ctx.typeOf(id).satisfies(constraint, id.pos)
  case ArrayElem(id, indices) =>
    val arrayTy = ctx.typeOf(id)
    // Peel one array layer per index. The fold runs in Option, so it stops at the first index
    // applied to something that is not an array.
    val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) =>
      checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
      acc match {
        case KnownType.Array(innerTy) => Some(innerTy)
        case ?                        => Some(?) // we can keep indexing an unknown type
        case nonArrayTy =>
          ctx.error(
            Error.TypeMismatch(
              elem.pos,
              KnownType.Array(?),
              nonArrayTy,
              "cannot index into a non-array"
            )
          )
          None
      }
    }
    elemTy.getOrElse(?).satisfies(constraint, id.pos)
  case Parens(expr) => checkValue(expr, constraint)
  case l @ ArrayLiter(elems) =>
    KnownType
      // Start with an unknown param type, make it more specific while checking the elements.
      .Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
        checkValue(
          elem,
          Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
        )
      })
      .satisfies(constraint, l.pos)
  case l @ NewPair(fst, snd) =>
    KnownType
      .Pair(
        checkValue(fst, Constraint.Unconstrained),
        checkValue(snd, Constraint.Unconstrained)
      )
      .satisfies(constraint, l.pos)
  case Call(id, args) =>
    val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
    if (args.length != paramTys.length) {
      ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
    }
    // Even if the number of arguments is wrong, we still check the types of the arguments
    // in the best way we can (by taking a zip).
    args.zip(paramTys).foreach { case (arg, paramTy) =>
      checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
    }
    retTy.satisfies(constraint, id.pos)
  case Fst(elem) =>
    checkValue(
      elem,
      Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
    ) match {
      case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos)
      // Only reachable when the operand is not a pair, in which case the nested checkValue has
      // already reported the mismatch. Propagate the unknown type instead of reporting the same
      // user mistake a second time as an internal error.
      case _ => ?
    }
  case Snd(elem) =>
    checkValue(
      elem,
      Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
    ) match {
      case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
      // Same reasoning as for fst: the mismatch has already been reported.
      case _ => ?
    }

  // Unary operators
  case Negate(x) =>
    checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
    KnownType.Int.satisfies(constraint, x.pos)
  case Not(x) =>
    checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
    KnownType.Bool.satisfies(constraint, x.pos)
  case Len(x) =>
    checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
    KnownType.Int.satisfies(constraint, x.pos)
  case Ord(x) =>
    checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
    KnownType.Int.satisfies(constraint, x.pos)
  case Chr(x) =>
    checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
    KnownType.Char.satisfies(constraint, x.pos)

  // Binary operators
  case op: (Add | Sub | Mul | Div | Mod) =>
    val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
    checkValue(op.x, operand)
    checkValue(op.y, operand)
    KnownType.Int.satisfies(constraint, op.pos)
  case op: (Eq | Neq) =>
    // The left operand may be anything; the right operand must then agree with it.
    val xTy = checkValue(op.x, Constraint.Unconstrained)
    checkValue(
      op.y,
      Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
    )
    KnownType.Bool.satisfies(constraint, op.pos)
  case op: (Less | LessEq | Greater | GreaterEq) =>
    val xConstraint = Constraint.IsEither(
      KnownType.Int,
      KnownType.Char,
      s"${op.name} operator must be applied to an int or char"
    )
    // If x type-check failed, we still want to check y is an Int or Char (rather than ?)
    val yConstraint = checkValue(op.x, xConstraint) match {
      case ? => xConstraint
      case xTy =>
        Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
    }
    checkValue(op.y, yConstraint)
    KnownType.Bool.satisfies(constraint, op.pos)
  case op: (And | Or) =>
    val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
    checkValue(op.x, operand)
    checkValue(op.y, operand)
    KnownType.Bool.satisfies(constraint, op.pos)
}
}

View File

@@ -0,0 +1,45 @@
package wacc
object types {
  import ast._

  /** A semantic type: either one of the [[KnownType]] cases or the unknown type `?`. */
  sealed trait SemType {

    /** Renders the type for error messages; fully unknown arrays and pairs are shown by their
      * bare kind.
      */
    override def toString(): String = this match {
      case ?               => "<unknown-type>"
      case KnownType.Int    => "int"
      case KnownType.Bool   => "bool"
      case KnownType.Char   => "char"
      case KnownType.String => "string"
      case KnownType.Array(elem) =>
        if (elem == ?) "array" else s"$elem[]"
      case KnownType.Pair(left, right) =>
        if (left == ? && right == ?) "pair" else s"pair($left, $right)"
    }
  }

  /** The unknown type. */
  case object ? extends SemType

  /** Types whose outermost shape is known; array elements and pair components may still be `?`. */
  enum KnownType extends SemType {
    case Int
    case Bool
    case Char
    case String
    case Array(elem: SemType)
    case Pair(left: SemType, right: SemType)
  }

  object SemType {

    /** Convert a syntactic type from the AST into its semantic counterpart. */
    def apply(synType: Type | PairElemType): KnownType = synType match {
      case IntType()    => KnownType.Int
      case BoolType()   => KnownType.Bool
      case CharType()   => KnownType.Char
      case StringType() => KnownType.String
      // The AST stores an element type plus a dimension count; semantic types nest one Array
      // per dimension instead, which is easier to recurse over.
      case ArrayType(elemType, dimension) =>
        (1 to dimension).foldLeft[KnownType](SemType(elemType)) { (inner, _) =>
          KnownType.Array(inner)
        }
      case PairType(fst, snd) =>
        val left = SemType(fst)
        val right = SemType(snd)
        KnownType.Pair(left, right)
      case UntypedPairType() => KnownType.Pair(?, ?)
    }
  }

  /** Signature of a function: its return type and the types of its parameters, in order. */
  case class FuncType(returnType: SemType, params: List[SemType])
}