feat: incomplete initial implementation of assemblyIR
This commit is contained in:
93
src/main/wacc/frontend/Error.scala
Normal file
93
src/main/wacc/frontend/Error.scala
Normal file
@@ -0,0 +1,93 @@
|
||||
package wacc
|
||||
|
||||
import wacc.ast.Position
|
||||
import wacc.types._
|
||||
|
||||
/** Error types for semantic errors
|
||||
*/
|
||||
enum Error {
|
||||
case DuplicateDeclaration(ident: ast.Ident)
|
||||
case UndeclaredVariable(ident: ast.Ident)
|
||||
case UndefinedFunction(ident: ast.Ident)
|
||||
|
||||
case FunctionParamsMismatch(ident: ast.Ident, expected: Int, got: Int, funcType: FuncType)
|
||||
case SemanticError(pos: Position, msg: String)
|
||||
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
|
||||
case InternalError(pos: Position, msg: String)
|
||||
}
|
||||
|
||||
/** Function to handle printing the details of a given semantic error
|
||||
*
|
||||
* @param error
|
||||
* Error object
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def printError(error: Error)(using errorContent: String): Unit = {
|
||||
println("Semantic error:")
|
||||
error match {
|
||||
case Error.DuplicateDeclaration(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Duplicate declaration of identifier ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndeclaredVariable(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Undeclared variable ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedFunction(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Undefined function ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
|
||||
printPosition(id.pos)
|
||||
println(s"Function expects $expected parameters, got $got")
|
||||
println(
|
||||
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
|
||||
)
|
||||
highlight(id.pos, 1)
|
||||
case Error.TypeMismatch(pos, expected, got, msg) =>
|
||||
printPosition(pos)
|
||||
println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
|
||||
highlight(pos, 1)
|
||||
case Error.SemanticError(pos, msg) =>
|
||||
printPosition(pos)
|
||||
println(msg)
|
||||
highlight(pos, 1)
|
||||
case wacc.Error.InternalError(pos, msg) =>
|
||||
printPosition(pos)
|
||||
println(s"Internal error: $msg")
|
||||
highlight(pos, 1)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Function to highlight a section of code for an error message
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
* @param size
|
||||
* Size(in chars) of section to highlight
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def highlight(pos: Position, size: Int)(using errorContent: String): Unit = {
|
||||
val lines = errorContent.split("\n")
|
||||
|
||||
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
|
||||
val midLine = lines(pos.line - 1)
|
||||
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
|
||||
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
|
||||
|
||||
println(
|
||||
s" >$preLine\n >$midLine\n$linePointer >$postLine"
|
||||
)
|
||||
}
|
||||
|
||||
/** Function to print the position of an error
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
*/
|
||||
def printPosition(pos: Position): Unit = {
|
||||
println(s"(line ${pos.line}, column ${pos.column}):")
|
||||
}
|
||||
253
src/main/wacc/frontend/ast.scala
Normal file
253
src/main/wacc/frontend/ast.scala
Normal file
@@ -0,0 +1,253 @@
|
||||
package wacc
|
||||
|
||||
import parsley.Parsley
|
||||
import parsley.generic.ErrorBridge
|
||||
import parsley.ap._
|
||||
import parsley.position._
|
||||
import parsley.syntax.zipped._
|
||||
import cats.data.NonEmptyList
|
||||
|
||||
object ast {
|
||||
/* ============================ EXPRESSIONS ============================ */
|
||||
|
||||
sealed trait Expr extends RValue {
|
||||
val pos: Position
|
||||
}
|
||||
sealed trait Expr1 extends Expr
|
||||
sealed trait Expr2 extends Expr1
|
||||
sealed trait Expr3 extends Expr2
|
||||
sealed trait Expr4 extends Expr3
|
||||
sealed trait Expr5 extends Expr4
|
||||
sealed trait Expr6 extends Expr5
|
||||
|
||||
/* ============================ ATOMIC EXPRESSIONS ============================ */
|
||||
|
||||
case class IntLiter(v: Int)(val pos: Position) extends Expr6
|
||||
object IntLiter extends ParserBridgePos1[Int, IntLiter]
|
||||
case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6
|
||||
object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter]
|
||||
case class CharLiter(v: Char)(val pos: Position) extends Expr6
|
||||
object CharLiter extends ParserBridgePos1[Char, CharLiter]
|
||||
case class StrLiter(v: String)(val pos: Position) extends Expr6
|
||||
object StrLiter extends ParserBridgePos1[String, StrLiter]
|
||||
case class PairLiter()(val pos: Position) extends Expr6
|
||||
object PairLiter extends ParserBridgePos0[PairLiter]
|
||||
case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue
|
||||
object Ident extends ParserBridgePos1[String, Ident] {
|
||||
def apply(v: String)(pos: Position): Ident = new Ident(v)(pos)
|
||||
}
|
||||
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position)
|
||||
extends Expr6
|
||||
with LValue
|
||||
object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
|
||||
def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
|
||||
name => ArrayElem(name, a)(pos)
|
||||
}
|
||||
case class Parens(expr: Expr)(val pos: Position) extends Expr6
|
||||
object Parens extends ParserBridgePos1[Expr, Parens]
|
||||
|
||||
/* ============================ UNARY OPERATORS ============================ */
|
||||
|
||||
sealed trait UnaryOp extends Expr {
|
||||
val x: Expr
|
||||
}
|
||||
case class Negate(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
|
||||
object Negate extends ParserBridgePos1[Expr6, Negate]
|
||||
case class Not(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
|
||||
object Not extends ParserBridgePos1[Expr6, Not]
|
||||
case class Len(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
|
||||
object Len extends ParserBridgePos1[Expr6, Len]
|
||||
case class Ord(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
|
||||
object Ord extends ParserBridgePos1[Expr6, Ord]
|
||||
case class Chr(x: Expr6)(val pos: Position) extends Expr6 with UnaryOp
|
||||
object Chr extends ParserBridgePos1[Expr6, Chr]
|
||||
|
||||
/* ============================ BINARY OPERATORS ============================ */
|
||||
|
||||
sealed trait BinaryOp(val name: String) extends Expr {
|
||||
val x: Expr
|
||||
val y: Expr
|
||||
}
|
||||
case class Add(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp("addition")
|
||||
object Add extends ParserBridgePos2[Expr4, Expr5, Add]
|
||||
case class Sub(x: Expr4, y: Expr5)(val pos: Position) extends Expr4 with BinaryOp("subtraction")
|
||||
object Sub extends ParserBridgePos2[Expr4, Expr5, Sub]
|
||||
case class Mul(x: Expr5, y: Expr6)(val pos: Position)
|
||||
extends Expr5
|
||||
with BinaryOp("multiplication")
|
||||
object Mul extends ParserBridgePos2[Expr5, Expr6, Mul]
|
||||
case class Div(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp("division")
|
||||
object Div extends ParserBridgePos2[Expr5, Expr6, Div]
|
||||
case class Mod(x: Expr5, y: Expr6)(val pos: Position) extends Expr5 with BinaryOp("modulus")
|
||||
object Mod extends ParserBridgePos2[Expr5, Expr6, Mod]
|
||||
case class Greater(x: Expr4, y: Expr4)(val pos: Position)
|
||||
extends Expr3
|
||||
with BinaryOp("strictly greater than")
|
||||
object Greater extends ParserBridgePos2[Expr4, Expr4, Greater]
|
||||
case class GreaterEq(x: Expr4, y: Expr4)(val pos: Position)
|
||||
extends Expr3
|
||||
with BinaryOp("greater than or equal to")
|
||||
object GreaterEq extends ParserBridgePos2[Expr4, Expr4, GreaterEq]
|
||||
case class Less(x: Expr4, y: Expr4)(val pos: Position)
|
||||
extends Expr3
|
||||
with BinaryOp("strictly less than")
|
||||
object Less extends ParserBridgePos2[Expr4, Expr4, Less]
|
||||
case class LessEq(x: Expr4, y: Expr4)(val pos: Position)
|
||||
extends Expr3
|
||||
with BinaryOp("less than or equal to")
|
||||
object LessEq extends ParserBridgePos2[Expr4, Expr4, LessEq]
|
||||
case class Eq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp("equality")
|
||||
object Eq extends ParserBridgePos2[Expr3, Expr3, Eq]
|
||||
case class Neq(x: Expr3, y: Expr3)(val pos: Position) extends Expr2 with BinaryOp("inequality")
|
||||
object Neq extends ParserBridgePos2[Expr3, Expr3, Neq]
|
||||
case class And(x: Expr2, y: Expr1)(val pos: Position) extends Expr1 with BinaryOp("logical and")
|
||||
object And extends ParserBridgePos2[Expr2, Expr1, And]
|
||||
case class Or(x: Expr1, y: Expr)(val pos: Position) extends Expr with BinaryOp("logical or")
|
||||
object Or extends ParserBridgePos2[Expr1, Expr, Or]
|
||||
|
||||
/* ============================ TYPES ============================ */
|
||||
|
||||
sealed trait Type
|
||||
sealed trait BaseType extends Type with PairElemType
|
||||
case class IntType()(val pos: Position) extends BaseType
|
||||
object IntType extends ParserBridgePos0[IntType]
|
||||
case class BoolType()(val pos: Position) extends BaseType
|
||||
object BoolType extends ParserBridgePos0[BoolType]
|
||||
case class CharType()(val pos: Position) extends BaseType
|
||||
object CharType extends ParserBridgePos0[CharType]
|
||||
case class StringType()(val pos: Position) extends BaseType
|
||||
object StringType extends ParserBridgePos0[StringType]
|
||||
case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position)
|
||||
extends Type
|
||||
with PairElemType
|
||||
object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
|
||||
def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
|
||||
}
|
||||
case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type
|
||||
object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
|
||||
|
||||
sealed trait PairElemType
|
||||
case class UntypedPairType()(val pos: Position) extends PairElemType
|
||||
object UntypedPairType extends ParserBridgePos0[UntypedPairType]
|
||||
|
||||
/* ============================ PROGRAM STRUCTURE ============================ */
|
||||
|
||||
case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position)
|
||||
object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program]
|
||||
|
||||
/* ============================ FUNCTION STRUCTURE ============================ */
|
||||
|
||||
case class FuncDecl(
|
||||
returnType: Type,
|
||||
name: Ident,
|
||||
params: List[Param],
|
||||
body: NonEmptyList[Stmt]
|
||||
)(val pos: Position)
|
||||
object FuncDecl
|
||||
extends ParserBridgePos2[
|
||||
List[Param],
|
||||
NonEmptyList[Stmt],
|
||||
((Type, Ident)) => FuncDecl
|
||||
] {
|
||||
def apply(params: List[Param], body: NonEmptyList[Stmt])(
|
||||
pos: Position
|
||||
): ((Type, Ident)) => FuncDecl =
|
||||
(returnType, name) => FuncDecl(returnType, name, params, body)(pos)
|
||||
}
|
||||
|
||||
case class Param(paramType: Type, name: Ident)(val pos: Position)
|
||||
object Param extends ParserBridgePos2[Type, Ident, Param]
|
||||
|
||||
/* ============================ STATEMENTS ============================ */
|
||||
|
||||
sealed trait Stmt
|
||||
case class Skip()(val pos: Position) extends Stmt
|
||||
object Skip extends ParserBridgePos0[Skip]
|
||||
case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt
|
||||
object VarDecl extends ParserBridgePos3[Type, Ident, RValue, VarDecl]
|
||||
case class Assign(lhs: LValue, value: RValue)(val pos: Position) extends Stmt
|
||||
object Assign extends ParserBridgePos2[LValue, RValue, Assign]
|
||||
case class Read(lhs: LValue)(val pos: Position) extends Stmt
|
||||
object Read extends ParserBridgePos1[LValue, Read]
|
||||
case class Free(expr: Expr)(val pos: Position) extends Stmt
|
||||
object Free extends ParserBridgePos1[Expr, Free]
|
||||
case class Return(expr: Expr)(val pos: Position) extends Stmt
|
||||
object Return extends ParserBridgePos1[Expr, Return]
|
||||
case class Exit(expr: Expr)(val pos: Position) extends Stmt
|
||||
object Exit extends ParserBridgePos1[Expr, Exit]
|
||||
case class Print(expr: Expr, newline: Boolean)(val pos: Position) extends Stmt
|
||||
object Print extends ParserBridgePos2[Expr, Boolean, Print]
|
||||
case class If(cond: Expr, thenStmt: NonEmptyList[Stmt], elseStmt: NonEmptyList[Stmt])(
|
||||
val pos: Position
|
||||
) extends Stmt
|
||||
object If extends ParserBridgePos3[Expr, NonEmptyList[Stmt], NonEmptyList[Stmt], If]
|
||||
case class While(cond: Expr, body: NonEmptyList[Stmt])(val pos: Position) extends Stmt
|
||||
object While extends ParserBridgePos2[Expr, NonEmptyList[Stmt], While]
|
||||
case class Block(stmt: NonEmptyList[Stmt])(val pos: Position) extends Stmt
|
||||
object Block extends ParserBridgePos1[NonEmptyList[Stmt], Block]
|
||||
|
||||
/* ============================ LVALUES & RVALUES ============================ */
|
||||
|
||||
sealed trait LValue {
|
||||
val pos: Position
|
||||
}
|
||||
|
||||
sealed trait RValue
|
||||
case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue
|
||||
object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter]
|
||||
case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue
|
||||
object NewPair extends ParserBridgePos2[Expr, Expr, NewPair]
|
||||
case class Call(name: Ident, args: List[Expr])(val pos: Position) extends RValue
|
||||
object Call extends ParserBridgePos2[Ident, List[Expr], Call]
|
||||
|
||||
sealed trait PairElem extends LValue with RValue
|
||||
case class Fst(elem: LValue)(val pos: Position) extends PairElem
|
||||
object Fst extends ParserBridgePos1[LValue, Fst]
|
||||
case class Snd(elem: LValue)(val pos: Position) extends PairElem
|
||||
object Snd extends ParserBridgePos1[LValue, Snd]
|
||||
|
||||
/* ============================ PARSER BRIDGES ============================ */
|
||||
|
||||
case class Position(line: Int, column: Int)
|
||||
|
||||
trait ParserSingletonBridgePos[+A] extends ErrorBridge {
|
||||
protected def con(pos: (Int, Int)): A
|
||||
infix def from(op: Parsley[?]): Parsley[A] = error(pos.map(con) <~ op)
|
||||
final def <#(op: Parsley[?]): Parsley[A] = this from op
|
||||
}
|
||||
|
||||
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] {
|
||||
def apply()(pos: Position): A
|
||||
|
||||
override final def con(pos: (Int, Int)): A =
|
||||
apply()(Position(pos._1, pos._2))
|
||||
}
|
||||
|
||||
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] {
|
||||
def apply(a: A)(pos: Position): B
|
||||
def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a))
|
||||
|
||||
override final def con(pos: (Int, Int)): A => B =
|
||||
this.apply(_)(Position(pos._1, pos._2))
|
||||
}
|
||||
|
||||
trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] {
|
||||
def apply(a: A, b: B)(pos: Position): C
|
||||
def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error(
|
||||
ap2(pos.map(con), a, b)
|
||||
)
|
||||
|
||||
override final def con(pos: (Int, Int)): (A, B) => C =
|
||||
apply(_, _)(Position(pos._1, pos._2))
|
||||
}
|
||||
|
||||
trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] {
|
||||
def apply(a: A, b: B, c: C)(pos: Position): D
|
||||
def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error(
|
||||
ap3(pos.map(con), a, b, c)
|
||||
)
|
||||
|
||||
override final def con(pos: (Int, Int)): (A, B, C) => D =
|
||||
apply(_, _, _)(Position(pos._1, pos._2))
|
||||
}
|
||||
}
|
||||
105
src/main/wacc/frontend/lexer.scala
Normal file
105
src/main/wacc/frontend/lexer.scala
Normal file
@@ -0,0 +1,105 @@
|
||||
package wacc
|
||||
|
||||
import parsley.Parsley
|
||||
import parsley.character
|
||||
import parsley.token.{Basic, Lexer}
|
||||
import parsley.token.descriptions.*
|
||||
import parsley.token.errors._
|
||||
|
||||
/** ErrorConfig for producing more informative error messages
|
||||
*/
|
||||
val errConfig = new ErrorConfig {
|
||||
override def labelSymbol = Map(
|
||||
"!=" -> Label("binary operator"),
|
||||
"%" -> Label("binary operator"),
|
||||
"&&" -> Label("binary operator"),
|
||||
"*" -> Label("binary operator"),
|
||||
"/" -> Label("binary operator"),
|
||||
"<" -> Label("binary operator"),
|
||||
"<=" -> Label("binary operator"),
|
||||
"==" -> Label("binary operator"),
|
||||
">" -> Label("binary operator"),
|
||||
">=" -> Label("binary operator"),
|
||||
"||" -> Label("binary operator"),
|
||||
"!" -> Label("unary operator"),
|
||||
"len" -> Label("unary operator"),
|
||||
"ord" -> Label("unary operator"),
|
||||
"chr" -> Label("unary operator"),
|
||||
"bool" -> Label("valid type"),
|
||||
"char" -> Label("valid type"),
|
||||
"int" -> Label("valid type"),
|
||||
"pair" -> Label("valid type"),
|
||||
"string" -> Label("valid type"),
|
||||
"fst" -> Label("pair extraction"),
|
||||
"snd" -> Label("pair extraction"),
|
||||
"false" -> Label("boolean literal"),
|
||||
"true" -> Label("boolean literal"),
|
||||
"=" -> Label("assignment"),
|
||||
"[" -> Label("array index")
|
||||
)
|
||||
}
|
||||
object lexer {
|
||||
|
||||
/** Language description for the WACC lexer
|
||||
*/
|
||||
private val desc = LexicalDesc.plain.copy(
|
||||
nameDesc = NameDesc.plain.copy(
|
||||
identifierStart = Basic(c => c.isLetter || c == '_'),
|
||||
identifierLetter = Basic(c => c.isLetterOrDigit || c == '_')
|
||||
),
|
||||
symbolDesc = SymbolDesc.plain.copy(
|
||||
hardKeywords = Set(
|
||||
"begin", "end", "is", "skip", "if", "then", "else", "fi", "while", "do", "done", "read",
|
||||
"free", "return", "exit", "print", "println", "true", "false", "int", "bool", "char",
|
||||
"string", "pair", "newpair", "fst", "snd", "call", "chr", "ord", "len", "null"
|
||||
),
|
||||
hardOperators = Set(
|
||||
"+", "-", "*", "/", "%", ">", "<", ">=", "<=", "==", "!=", "&&", "||", "!"
|
||||
)
|
||||
),
|
||||
spaceDesc = SpaceDesc.plain.copy(
|
||||
lineCommentStart = "#"
|
||||
),
|
||||
textDesc = TextDesc.plain.copy(
|
||||
graphicCharacter = Basic(c => c >= ' ' && c != '\\' && c != '\'' && c != '"'),
|
||||
escapeSequences = EscapeDesc.plain.copy(
|
||||
literals = Set('\\', '"', '\''),
|
||||
mapping = Map(
|
||||
"0" -> '\u0000',
|
||||
"b" -> '\b',
|
||||
"t" -> '\t',
|
||||
"n" -> '\n',
|
||||
"f" -> '\f',
|
||||
"r" -> '\r'
|
||||
)
|
||||
)
|
||||
),
|
||||
numericDesc = NumericDesc.plain.copy(
|
||||
decimalExponentDesc = ExponentDesc.NoExponents
|
||||
)
|
||||
)
|
||||
|
||||
/** Token definitions for the WACC lexer
|
||||
*/
|
||||
private val lexer = Lexer(desc, errConfig)
|
||||
val ident = lexer.lexeme.names.identifier
|
||||
val integer = lexer.lexeme.integer.decimal32[Int]
|
||||
val negateCheck = lexer.nonlexeme.symbol("-") ~> character.digit
|
||||
val charLit = lexer.lexeme.character.ascii
|
||||
val stringLit = lexer.lexeme.string.ascii
|
||||
val implicits = lexer.lexeme.symbol.implicits
|
||||
|
||||
/** Tokens for producing lexer-backed error messages
|
||||
*/
|
||||
val errTokens = Seq(
|
||||
lexer.nonlexeme.names.identifier.map(v => s"identifier $v"),
|
||||
lexer.nonlexeme.integer.decimal32[Int].map(n => s"integer $n"),
|
||||
lexer.nonlexeme.character.ascii.map(c => s"character literal \'$c\'"),
|
||||
lexer.nonlexeme.string.ascii.map(s => s"string literal \"$s\""),
|
||||
lexer.nonlexeme.symbol("[").as("array literal"),
|
||||
character.whitespace.map(_ => "")
|
||||
) ++ desc.symbolDesc.hardKeywords.map { k =>
|
||||
lexer.nonlexeme.symbol(k).as(s"keyword $k")
|
||||
}
|
||||
def fully[A](p: Parsley[A]): Parsley[A] = lexer.fully(p)
|
||||
}
|
||||
224
src/main/wacc/frontend/parser.scala
Normal file
224
src/main/wacc/frontend/parser.scala
Normal file
@@ -0,0 +1,224 @@
|
||||
package wacc
|
||||
|
||||
import parsley.Result
|
||||
import parsley.Parsley
|
||||
import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit}
|
||||
import parsley.combinator.{countSome, sepBy}
|
||||
import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms}
|
||||
import parsley.errors.combinator._
|
||||
import parsley.errors.patterns.VerifiedErrors
|
||||
import parsley.syntax.zipped._
|
||||
import parsley.cats.combinator.{some}
|
||||
import cats.data.NonEmptyList
|
||||
import parsley.errors.DefaultErrorBuilder
|
||||
import parsley.errors.ErrorBuilder
|
||||
import parsley.errors.tokenextractors.LexToken
|
||||
|
||||
object parser {
|
||||
import lexer.implicits.implicitSymbol
|
||||
import lexer.{ident, integer, charLit, stringLit, negateCheck, errTokens}
|
||||
import ast._
|
||||
|
||||
// error extensions
|
||||
extension [A](p: Parsley[A]) {
|
||||
// combines label and explain together into one function call
|
||||
def labelAndExplain(label: String, explanation: String): Parsley[A] = {
|
||||
p.label(label).explain(explanation)
|
||||
}
|
||||
def labelAndExplain(t: LabelType): Parsley[A] = {
|
||||
t match {
|
||||
case LabelType.Expr =>
|
||||
labelWithType(t).explain(
|
||||
"a valid expression can start with: null, literals, identifiers, unary operators, or parentheses. " +
|
||||
"Expressions can also contain array indexing and binary operators. " +
|
||||
"Pair extraction is not allowed in expressions, only in assignments."
|
||||
)
|
||||
case _ => labelWithType(t)
|
||||
}
|
||||
}
|
||||
|
||||
def labelWithType(t: LabelType): Parsley[A] = {
|
||||
t match {
|
||||
case LabelType.Expr => p.label("valid expression")
|
||||
case LabelType.Pair => p.label("valid pair")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum LabelType:
|
||||
case Expr
|
||||
case Pair
|
||||
|
||||
implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken {
|
||||
def tokens = errTokens
|
||||
}
|
||||
def parse(input: String): Result[String, Program] = parser.parse(input)
|
||||
private val parser = lexer.fully(`<program>`)
|
||||
|
||||
// Expressions
|
||||
private lazy val `<expr>`: Parsley[Expr] = precedence {
|
||||
SOps(InfixR)(Or from "||") +:
|
||||
SOps(InfixR)(And from "&&") +:
|
||||
SOps(InfixN)(Eq from "==", Neq from "!=") +:
|
||||
SOps(InfixN)(
|
||||
Less from "<",
|
||||
LessEq from "<=",
|
||||
Greater from ">",
|
||||
GreaterEq from ">="
|
||||
) +:
|
||||
SOps(InfixL)(
|
||||
(Add from "+").label("binary operator"),
|
||||
(Sub from "-").label("binary operator")
|
||||
) +:
|
||||
SOps(InfixL)(Mul from "*", Div from "/", Mod from "%") +:
|
||||
SOps(Prefix)(
|
||||
Not from "!",
|
||||
// notFollowedBy(negateCheck) ensures that negative numbers are parsed as a single int literal
|
||||
(Negate from (notFollowedBy(negateCheck) ~> "-")).hide,
|
||||
Len from "len",
|
||||
Ord from "ord",
|
||||
Chr from "chr"
|
||||
) +:
|
||||
`<atom>`
|
||||
}
|
||||
|
||||
// Atoms
|
||||
private lazy val `<atom>`: Atoms[Expr6] = Atoms(
|
||||
IntLiter(integer).label("integer literal"),
|
||||
BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"),
|
||||
CharLiter(charLit).label("character literal"),
|
||||
StrLiter(stringLit).label("string literal"),
|
||||
PairLiter from "null",
|
||||
`<ident-or-array-elem>`,
|
||||
Parens("(" ~> `<expr>` <~ ")")
|
||||
)
|
||||
private val `<ident>` =
|
||||
Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed")
|
||||
private lazy val `<ident-or-array-elem>` =
|
||||
(`<ident>` <~ ("(".verifiedExplain(
|
||||
"functions can only be called using 'call' keyword"
|
||||
) | unit)) <**> (`<array-indices>` </> identity)
|
||||
private val `<array-indices>` = ArrayElem(some("[" ~> `<expr>` <~ "]"))
|
||||
|
||||
// Types
|
||||
private lazy val `<type>`: Parsley[Type] =
|
||||
(`<base-type>` | (`<pair-type>` ~> `<pair-elems-type>`)) <**> (`<array-type>` </> identity)
|
||||
private val `<base-type>` =
|
||||
(IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string")
|
||||
private lazy val `<array-type>` =
|
||||
ArrayType(countSome("[" ~> "]"))
|
||||
private val `<pair-type>` = "pair"
|
||||
private val `<pair-elems-type>`: Parsley[PairType] = PairType(
|
||||
"(" ~> `<pair-elem-type>` <~ ",",
|
||||
`<pair-elem-type>` <~ ")"
|
||||
)
|
||||
private lazy val `<pair-elem-type>` =
|
||||
(`<base-type>` <**> (`<array-type>` </> identity)) |
|
||||
((UntypedPairType from `<pair-type>`) <**>
|
||||
((`<pair-elems-type>` <**> `<array-type>`)
|
||||
.map(arr => (_: UntypedPairType) => arr) </> identity))
|
||||
|
||||
/* Statements
|
||||
Atomic is used in two places here:
|
||||
1. Atomic for function return type - code may be a variable declaration instead, If we were
|
||||
to factor out the type, the resulting code would be rather messy. It can only fail once
|
||||
in the entire program so it creates minimal overhead.
|
||||
2. Atomic for function missing return type check - there is no easy way around an explicit
|
||||
invalid syntax check, this only happens at most once per program so this is not a major
|
||||
concern.
|
||||
*/
|
||||
private lazy val `<program>` = Program(
|
||||
"begin" ~> (
|
||||
many(
|
||||
atomic(
|
||||
`<type>`.label("function declaration") <~> `<ident>` <~ "("
|
||||
) <**> `<partial-func-decl>`
|
||||
).label("function declaration") |
|
||||
atomic(`<ident>` <~ "(").verifiedExplain("function declaration is missing return type")
|
||||
),
|
||||
`<stmt>`.label(
|
||||
"main program body"
|
||||
) <~ "end"
|
||||
)
|
||||
private lazy val `<partial-func-decl>` =
|
||||
FuncDecl(
|
||||
sepBy(`<param>`, ",") <~ ")" <~ "is",
|
||||
`<stmt>`.guardAgainst {
|
||||
case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement")
|
||||
} <~ "end"
|
||||
)
|
||||
private lazy val `<param>` = Param(`<type>`, `<ident>`)
|
||||
private lazy val `<stmt>`: Parsley[NonEmptyList[Stmt]] =
|
||||
(
|
||||
`<basic-stmt>`.label("main program body"),
|
||||
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
|
||||
).zipped(NonEmptyList.apply)
|
||||
|
||||
private lazy val `<basic-stmt>` =
|
||||
(Skip from "skip")
|
||||
| Read("read" ~> `<lvalue>`)
|
||||
| Free("free" ~> `<expr>`.labelAndExplain(LabelType.Expr))
|
||||
| Return("return" ~> `<expr>`.labelAndExplain(LabelType.Expr))
|
||||
| Exit("exit" ~> `<expr>`.labelAndExplain(LabelType.Expr))
|
||||
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(false))
|
||||
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(true))
|
||||
| If(
|
||||
"if" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "then",
|
||||
`<stmt>` <~ "else",
|
||||
`<stmt>` <~ "fi"
|
||||
)
|
||||
| While("while" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "do", `<stmt>` <~ "done")
|
||||
| Block("begin" ~> `<stmt>` <~ "end")
|
||||
| VarDecl(
|
||||
`<type>`,
|
||||
`<ident>` <~ ("=" | "(".verifiedExplain(
|
||||
"all function declarations must be above the main program body"
|
||||
)),
|
||||
`<rvalue>`.label("valid initial value for variable")
|
||||
)
|
||||
| Assign(
|
||||
`<lvalue>` <~ ("=" | "(".verifiedExplain(
|
||||
"function calls must use the 'call' keyword and the result must be assigned to a variable"
|
||||
)),
|
||||
`<rvalue>`
|
||||
) |
|
||||
("call" ~> `<ident>`).verifiedExplain(
|
||||
"function calls' results must be assigned to a variable"
|
||||
)
|
||||
private lazy val `<lvalue>`: Parsley[LValue] =
|
||||
`<pair-elem>` | `<ident-or-array-elem>`
|
||||
private lazy val `<rvalue>`: Parsley[RValue] =
|
||||
`<array-liter>` |
|
||||
NewPair(
|
||||
"newpair" ~> "(" ~> `<expr>` <~ ",",
|
||||
`<expr>` <~ ")"
|
||||
) |
|
||||
`<pair-elem>` |
|
||||
Call(
|
||||
"call" ~> `<ident>` <~ "(",
|
||||
sepBy(`<expr>`, ",") <~ ")"
|
||||
) | `<expr>`.labelWithType(LabelType.Expr)
|
||||
private lazy val `<pair-elem>` =
|
||||
Fst("fst" ~> `<lvalue>`.label("valid pair"))
|
||||
| Snd("snd" ~> `<lvalue>`.label("valid pair"))
|
||||
private lazy val `<array-liter>` = ArrayLiter(
|
||||
"[" ~> sepBy(`<expr>`, ",") <~ "]"
|
||||
)
|
||||
|
||||
extension (stmts: NonEmptyList[Stmt]) {
|
||||
|
||||
/** Determines whether a function body is guaranteed to return in all cases This is required as
|
||||
* all functions must end via a "return" or "exit" statement
|
||||
*
|
||||
* @return
|
||||
* true if the statement list ends in a return statement, false otherwise
|
||||
*/
|
||||
def isReturning: Boolean = stmts.last match {
|
||||
case Return(_) | Exit(_) => true
|
||||
case If(_, thenStmt, elseStmt) => thenStmt.isReturning && elseStmt.isReturning
|
||||
case While(_, body) => body.isReturning
|
||||
case Block(body) => body.isReturning
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
}
|
||||
219
src/main/wacc/frontend/renamer.scala
Normal file
219
src/main/wacc/frontend/renamer.scala
Normal file
@@ -0,0 +1,219 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
object renamer {
|
||||
import ast._
|
||||
import types._
|
||||
|
||||
private enum IdentType {
|
||||
case Func
|
||||
case Var
|
||||
}
|
||||
|
||||
private class Scope(
|
||||
val current: mutable.Map[(String, IdentType), Ident],
|
||||
val parent: Map[(String, IdentType), Ident]
|
||||
) {
|
||||
|
||||
/** Create a new scope with the current scope as its parent.
|
||||
*
|
||||
* @return
|
||||
* A new scope with an empty current scope, and this scope flattened into the parent scope.
|
||||
*/
|
||||
def subscope: Scope =
|
||||
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)))
|
||||
|
||||
/** Attempt to add a new identifier to the current scope. If the identifier already exists in
|
||||
* the current scope, add an error to the error list.
|
||||
*
|
||||
* @param ty
|
||||
* The semantic type of the variable identifier, or function identifier type.
|
||||
* @param name
|
||||
* The name of the identifier.
|
||||
* @param globalNames
|
||||
* The global map of identifiers to semantic types - the identifier will be added to this
|
||||
* map.
|
||||
* @param globalNumbering
|
||||
* The global map of identifier names to the number of times they have been declared - will
|
||||
* used to rename this identifier, and will be incremented.
|
||||
* @param errors
|
||||
* The list of errors to append to.
|
||||
*/
|
||||
def add(ty: SemType | FuncType, name: Ident)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalFuncs: mutable.Map[Ident, FuncType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
) = {
|
||||
val identType = ty match {
|
||||
case _: SemType => IdentType.Var
|
||||
case _: FuncType => IdentType.Func
|
||||
}
|
||||
current.get((name.v, identType)) match {
|
||||
case Some(Ident(_, uid)) =>
|
||||
errors += Error.DuplicateDeclaration(name)
|
||||
name.uid = uid
|
||||
case None =>
|
||||
val uid = globalNumbering.getOrElse(name.v, 0)
|
||||
name.uid = uid
|
||||
current((name.v, identType)) = name
|
||||
|
||||
ty match {
|
||||
case semType: SemType =>
|
||||
globalNames(name) = semType
|
||||
case funcType: FuncType =>
|
||||
globalFuncs(name) = funcType
|
||||
}
|
||||
globalNumbering(name.v) = uid + 1
|
||||
}
|
||||
}
|
||||
|
||||
private def get(name: String, identType: IdentType): Option[Ident] =
|
||||
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
|
||||
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
|
||||
try {
|
||||
Some(current.withDefault(parent)((name, identType)))
|
||||
} catch {
|
||||
case _: NoSuchElementException => None
|
||||
}
|
||||
|
||||
def getVar(name: String): Option[Ident] = get(name, IdentType.Var)
|
||||
def getFunc(name: String): Option[Ident] = get(name, IdentType.Func)
|
||||
}
|
||||
|
||||
/** Check scoping of all variables and functions in the program. Also generate semantic types for
|
||||
* all identifiers.
|
||||
*
|
||||
* @param prog
|
||||
* AST of the program
|
||||
* @param errors
|
||||
* List of errors to append to
|
||||
* @return
|
||||
* Map of all (renamed) identifies to their semantic types
|
||||
*/
|
||||
def rename(prog: Program)(using
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): (Map[Ident, SemType], Map[Ident, FuncType]) = {
|
||||
given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty
|
||||
given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty
|
||||
given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty
|
||||
val scope = Scope(mutable.Map.empty, Map.empty)
|
||||
val Program(funcs, main) = prog
|
||||
funcs
|
||||
// First add all function declarations to the scope
|
||||
.map { case FuncDecl(retType, name, params, body) =>
|
||||
val paramTypes = params.map { param =>
|
||||
val paramType = SemType(param.paramType)
|
||||
paramType
|
||||
}
|
||||
scope.add(FuncType(SemType(retType), paramTypes), name)
|
||||
(params zip paramTypes, body)
|
||||
}
|
||||
// Only then rename the function bodies
|
||||
// (functions can call one-another regardless of order of declaration)
|
||||
.foreach { case (params, body) =>
|
||||
val functionScope = scope.subscope
|
||||
params.foreach { case (param, paramType) =>
|
||||
functionScope.add(paramType, param.name)
|
||||
}
|
||||
body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params
|
||||
}
|
||||
main.toList.foreach(rename(scope))
|
||||
(globalNames.toMap, globalFuncs.toMap)
|
||||
}
|
||||
|
||||
/** Check scoping of all identifies in a given AST node.
|
||||
*
|
||||
* @param scope
|
||||
* The current scope and flattened parent scope.
|
||||
* @param node
|
||||
* The AST node.
|
||||
* @param globalNames
|
||||
* The global map of identifiers to semantic types - renamed identifiers will be added to this
|
||||
* map.
|
||||
* @param globalNumbering
|
||||
* The global map of identifier names to the number of times they have been declared - used and
|
||||
* updated during identifier renaming.
|
||||
* @param errors
|
||||
*/
|
||||
private def rename(scope: Scope)(
|
||||
node: Ident | Stmt | LValue | RValue | Expr
|
||||
)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalFuncs: mutable.Map[Ident, FuncType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): Unit = node match {
|
||||
// These cases are more interesting because the involve making subscopes
|
||||
// or modifying the current scope.
|
||||
case VarDecl(synType, name, value) => {
|
||||
// Order matters here. Variable isn't declared until after the value is evaluated.
|
||||
rename(scope)(value)
|
||||
// Attempt to add the new variable to the current scope.
|
||||
scope.add(SemType(synType), name)
|
||||
}
|
||||
case If(cond, thenStmt, elseStmt) => {
|
||||
rename(scope)(cond)
|
||||
// then and else both have their own scopes
|
||||
thenStmt.toList.foreach(rename(scope.subscope))
|
||||
elseStmt.toList.foreach(rename(scope.subscope))
|
||||
}
|
||||
case While(cond, body) => {
|
||||
rename(scope)(cond)
|
||||
// while bodies have their own scopes
|
||||
body.toList.foreach(rename(scope.subscope))
|
||||
}
|
||||
// begin-end blocks have their own scopes
|
||||
case Block(body) => body.toList.foreach(rename(scope.subscope))
|
||||
|
||||
// These cases are simpler, mostly just recursive calls to rename()
|
||||
case Assign(lhs, value) => {
|
||||
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
|
||||
rename(scope)(lhs)
|
||||
rename(scope)(value)
|
||||
}
|
||||
case Read(lhs) => rename(scope)(lhs)
|
||||
case Free(expr) => rename(scope)(expr)
|
||||
case Return(expr) => rename(scope)(expr)
|
||||
case Exit(expr) => rename(scope)(expr)
|
||||
case Print(expr, _) => rename(scope)(expr)
|
||||
case NewPair(fst, snd) => {
|
||||
rename(scope)(fst)
|
||||
rename(scope)(snd)
|
||||
}
|
||||
case Call(name, args) => {
|
||||
scope.getFunc(name.v) match {
|
||||
case Some(Ident(_, uid)) => name.uid = uid
|
||||
case None =>
|
||||
errors += Error.UndefinedFunction(name)
|
||||
scope.add(FuncType(?, args.map(_ => ?)), name)
|
||||
}
|
||||
args.foreach(rename(scope))
|
||||
}
|
||||
case Fst(elem) => rename(scope)(elem)
|
||||
case Snd(elem) => rename(scope)(elem)
|
||||
case ArrayLiter(elems) => elems.foreach(rename(scope))
|
||||
case ArrayElem(name, indices) => {
|
||||
rename(scope)(name)
|
||||
indices.toList.foreach(rename(scope))
|
||||
}
|
||||
case Parens(expr) => rename(scope)(expr)
|
||||
case op: UnaryOp => rename(scope)(op.x)
|
||||
case op: BinaryOp => {
|
||||
rename(scope)(op.x)
|
||||
rename(scope)(op.y)
|
||||
}
|
||||
// Default to variables. Only `call` uses IdentType.Func.
|
||||
case id: Ident => {
|
||||
scope.getVar(id.v) match {
|
||||
case Some(Ident(_, uid)) => id.uid = uid
|
||||
case None =>
|
||||
errors += Error.UndeclaredVariable(id)
|
||||
scope.add(?, id)
|
||||
}
|
||||
}
|
||||
// These literals cannot contain identifies, exit immediately.
|
||||
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => ()
|
||||
}
|
||||
}
|
||||
322
src/main/wacc/frontend/typeChecker.scala
Normal file
322
src/main/wacc/frontend/typeChecker.scala
Normal file
@@ -0,0 +1,322 @@
|
||||
package wacc
|
||||
|
||||
import cats.syntax.all._
|
||||
import scala.collection.mutable
|
||||
|
||||
object typeChecker {
|
||||
import wacc.ast._
|
||||
import wacc.types._
|
||||
|
||||
case class TypeCheckerCtx(
|
||||
globalNames: Map[Ident, SemType],
|
||||
globalFuncs: Map[Ident, FuncType],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
) {
|
||||
def typeOf(ident: Ident): SemType = globalNames(ident)
|
||||
|
||||
def funcType(ident: Ident): FuncType = globalFuncs(ident)
|
||||
|
||||
def error(err: Error): SemType =
|
||||
errors += err
|
||||
?
|
||||
}
|
||||
|
||||
private enum Constraint {
|
||||
case Unconstrained
|
||||
// Allows weakening in one direction
|
||||
case Is(ty: SemType, msg: String)
|
||||
// Allows weakening in both directions, useful for array literals
|
||||
case IsSymmetricCompatible(ty: SemType, msg: String)
|
||||
// Does not allow weakening
|
||||
case IsUnweakenable(ty: SemType, msg: String)
|
||||
case IsEither(ty1: SemType, ty2: SemType, msg: String)
|
||||
case Never(msg: String)
|
||||
}
|
||||
|
||||
extension (ty: SemType) {
|
||||
|
||||
/** Check if a type satisfies a constraint.
|
||||
*
|
||||
* @param constraint
|
||||
* Constraint to satisfy.
|
||||
* @param pos
|
||||
* Position to pass to the error, if constraint was not satisfied.
|
||||
* @return
|
||||
* The type if the constraint was satisfied, or ? if it was not.
|
||||
*/
|
||||
private def satisfies(constraint: Constraint, pos: Position)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): SemType =
|
||||
(ty, constraint) match {
|
||||
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
|
||||
KnownType.String
|
||||
case (
|
||||
KnownType.String,
|
||||
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
|
||||
) =>
|
||||
KnownType.String
|
||||
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
|
||||
ty.satisfies(Constraint.Is(ty2, msg), pos)
|
||||
// Change to IsUnweakenable to disallow recursive weakening
|
||||
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
|
||||
case (ty, Constraint.Unconstrained) => ty
|
||||
case (ty, Constraint.Never(msg)) =>
|
||||
ctx.error(Error.SemanticError(pos, msg))
|
||||
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
|
||||
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
|
||||
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
|
||||
}
|
||||
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
|
||||
(ty moreSpecific ty2).getOrElse {
|
||||
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
|
||||
}
|
||||
}
|
||||
|
||||
/** Tries to merge two types, returning the more specific one if possible.
|
||||
*
|
||||
* @param ty2
|
||||
* The other type to merge with.
|
||||
* @return
|
||||
* The more specific type if it could be determined, or None if the types are incompatible.
|
||||
*/
|
||||
private infix def moreSpecific(ty2: SemType): Option[SemType] =
|
||||
(ty, ty2) match {
|
||||
case (ty, ?) => Some(ty)
|
||||
case (?, ty) => Some(ty)
|
||||
case (ty1, ty2) if ty1 == ty2 => Some(ty1)
|
||||
case (KnownType.Array(inn1), KnownType.Array(inn2)) =>
|
||||
(inn1 moreSpecific inn2).map(KnownType.Array(_))
|
||||
case (KnownType.Pair(fst1, snd1), KnownType.Pair(fst2, snd2)) =>
|
||||
(fst1 moreSpecific fst2, snd1 moreSpecific snd2).mapN(KnownType.Pair(_, _))
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
/** Type-check a WACC program.
|
||||
*
|
||||
* @param prog
|
||||
* The AST of the program to type-check.
|
||||
* @param ctx
|
||||
* The type checker context which includes the global names and functions, and an errors
|
||||
* builder.
|
||||
*/
|
||||
def check(prog: Program)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): Unit = {
|
||||
// Ignore function syntax types for return value and params, since those have been converted
|
||||
// to SemTypes by the renamer.
|
||||
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
|
||||
val FuncType(retType, _) = ctx.funcType(name)
|
||||
stmts.toList.foreach(
|
||||
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
|
||||
)
|
||||
}
|
||||
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
|
||||
}
|
||||
|
||||
/** Type-check an AST statement node.
|
||||
*
|
||||
* @param stmt
|
||||
* The statement to type-check.
|
||||
* @param returnConstraint
|
||||
* The constraint that any `return <expr>` statements must satisfy.
|
||||
*/
|
||||
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): Unit = stmt match {
|
||||
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
|
||||
case VarDecl(_, name, value) =>
|
||||
val expectedTy = ctx.typeOf(name)
|
||||
checkValue(
|
||||
value,
|
||||
Constraint.Is(
|
||||
expectedTy,
|
||||
s"variable ${name.v} must be assigned a value of type $expectedTy"
|
||||
)
|
||||
)
|
||||
case Assign(lhs, rhs) =>
|
||||
val lhsTy = checkValue(lhs, Constraint.Unconstrained)
|
||||
(lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match {
|
||||
case (?, ?) =>
|
||||
ctx.error(
|
||||
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
|
||||
)
|
||||
case _ => ()
|
||||
}
|
||||
case Read(dest) =>
|
||||
checkValue(dest, Constraint.Unconstrained) match {
|
||||
case ? =>
|
||||
ctx.error(
|
||||
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
|
||||
)
|
||||
case destTy =>
|
||||
destTy.satisfies(
|
||||
Constraint.IsEither(
|
||||
KnownType.Int,
|
||||
KnownType.Char,
|
||||
"read must be applied to an int or char"
|
||||
),
|
||||
dest.pos
|
||||
)
|
||||
}
|
||||
case Free(lhs) =>
|
||||
checkValue(
|
||||
lhs,
|
||||
Constraint.IsEither(
|
||||
KnownType.Array(?),
|
||||
KnownType.Pair(?, ?),
|
||||
"free must be applied to an array or pair"
|
||||
)
|
||||
)
|
||||
case Return(expr) =>
|
||||
checkValue(expr, returnConstraint)
|
||||
case Exit(expr) =>
|
||||
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
|
||||
case Print(expr, _) =>
|
||||
// This constraint should never fail, the scope-checker should have caught it already
|
||||
checkValue(expr, Constraint.Unconstrained)
|
||||
case If(cond, thenStmt, elseStmt) =>
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
|
||||
thenStmt.toList.foreach(checkStmt(_, returnConstraint))
|
||||
elseStmt.toList.foreach(checkStmt(_, returnConstraint))
|
||||
case While(cond, body) =>
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
|
||||
body.toList.foreach(checkStmt(_, returnConstraint))
|
||||
case Block(body) =>
|
||||
body.toList.foreach(checkStmt(_, returnConstraint))
|
||||
case Skip() => ()
|
||||
}
|
||||
|
||||
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
|
||||
* overlap in the AST.
|
||||
*
|
||||
* @param value
|
||||
* The value to type-check.
|
||||
* @param constraint
|
||||
* The type constraint that the value must satisfy.
|
||||
* @return
|
||||
* The most specific type of the value if it could be determined, or ? if it could not.
|
||||
*/
|
||||
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): SemType = value match {
|
||||
case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos)
|
||||
case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos)
|
||||
case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos)
|
||||
case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos)
|
||||
case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos)
|
||||
case id: Ident =>
|
||||
ctx.typeOf(id).satisfies(constraint, id.pos)
|
||||
case ArrayElem(id, indices) =>
|
||||
val arrayTy = ctx.typeOf(id)
|
||||
val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) =>
|
||||
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
|
||||
acc match {
|
||||
case KnownType.Array(innerTy) => Some(innerTy)
|
||||
case ? => Some(?) // we can keep indexing an unknown type
|
||||
case nonArrayTy =>
|
||||
ctx.error(
|
||||
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
|
||||
)
|
||||
None
|
||||
}
|
||||
}
|
||||
elemTy.getOrElse(?).satisfies(constraint, id.pos)
|
||||
case Parens(expr) => checkValue(expr, constraint)
|
||||
case l @ ArrayLiter(elems) =>
|
||||
KnownType
|
||||
// Start with an unknown param type, make it more specific while checking the elements.
|
||||
.Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
|
||||
)
|
||||
})
|
||||
.satisfies(constraint, l.pos)
|
||||
case l @ NewPair(fst, snd) =>
|
||||
KnownType
|
||||
.Pair(
|
||||
checkValue(fst, Constraint.Unconstrained),
|
||||
checkValue(snd, Constraint.Unconstrained)
|
||||
)
|
||||
.satisfies(constraint, l.pos)
|
||||
case Call(id, args) =>
|
||||
val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
|
||||
if (args.length != paramTys.length) {
|
||||
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
|
||||
}
|
||||
// Even if the number of arguments is wrong, we still check the types of the arguments
|
||||
// in the best way we can (by taking a zip).
|
||||
args.zip(paramTys).foreach { case (arg, paramTy) =>
|
||||
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
|
||||
}
|
||||
retTy.satisfies(constraint, id.pos)
|
||||
case Fst(elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
|
||||
) match {
|
||||
case what @ KnownType.Pair(left, _) =>
|
||||
left.satisfies(constraint, elem.pos)
|
||||
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
|
||||
}
|
||||
case Snd(elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
|
||||
) match {
|
||||
case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
|
||||
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
|
||||
}
|
||||
|
||||
// Unary operators
|
||||
case Negate(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
|
||||
KnownType.Int.satisfies(constraint, x.pos)
|
||||
case Not(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
|
||||
KnownType.Bool.satisfies(constraint, x.pos)
|
||||
case Len(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
|
||||
KnownType.Int.satisfies(constraint, x.pos)
|
||||
case Ord(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
|
||||
KnownType.Int.satisfies(constraint, x.pos)
|
||||
case Chr(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
|
||||
KnownType.Char.satisfies(constraint, x.pos)
|
||||
|
||||
// Binary operators
|
||||
case op: (Add | Sub | Mul | Div | Mod) =>
|
||||
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
|
||||
checkValue(op.x, operand)
|
||||
checkValue(op.y, operand)
|
||||
KnownType.Int.satisfies(constraint, op.pos)
|
||||
case op: (Eq | Neq) =>
|
||||
val xTy = checkValue(op.x, Constraint.Unconstrained)
|
||||
checkValue(
|
||||
op.y,
|
||||
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
|
||||
)
|
||||
KnownType.Bool.satisfies(constraint, op.pos)
|
||||
case op: (Less | LessEq | Greater | GreaterEq) =>
|
||||
val xConstraint = Constraint.IsEither(
|
||||
KnownType.Int,
|
||||
KnownType.Char,
|
||||
s"${op.name} operator must be applied to an int or char"
|
||||
)
|
||||
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
|
||||
val yConstraint = checkValue(op.x, xConstraint) match {
|
||||
case ? => xConstraint
|
||||
case xTy =>
|
||||
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
|
||||
}
|
||||
checkValue(op.y, yConstraint)
|
||||
KnownType.Bool.satisfies(constraint, op.pos)
|
||||
case op: (And | Or) =>
|
||||
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
|
||||
checkValue(op.x, operand)
|
||||
checkValue(op.y, operand)
|
||||
KnownType.Bool.satisfies(constraint, op.pos)
|
||||
}
|
||||
}
|
||||
45
src/main/wacc/frontend/types.scala
Normal file
45
src/main/wacc/frontend/types.scala
Normal file
@@ -0,0 +1,45 @@
|
||||
package wacc
|
||||
|
||||
object types {
|
||||
import ast._
|
||||
|
||||
sealed trait SemType {
|
||||
override def toString(): String = this match {
|
||||
case KnownType.Int => "int"
|
||||
case KnownType.Bool => "bool"
|
||||
case KnownType.Char => "char"
|
||||
case KnownType.String => "string"
|
||||
case KnownType.Array(?) => "array"
|
||||
case KnownType.Array(elem) => s"$elem[]"
|
||||
case KnownType.Pair(?, ?) => "pair"
|
||||
case KnownType.Pair(left, right) => s"pair($left, $right)"
|
||||
case ? => "<unknown-type>"
|
||||
}
|
||||
}
|
||||
|
||||
case object ? extends SemType
|
||||
enum KnownType extends SemType {
|
||||
case Int
|
||||
case Bool
|
||||
case Char
|
||||
case String
|
||||
case Array(elem: SemType)
|
||||
case Pair(left: SemType, right: SemType)
|
||||
}
|
||||
|
||||
object SemType {
|
||||
def apply(synType: Type | PairElemType): KnownType = synType match {
|
||||
case IntType() => KnownType.Int
|
||||
case BoolType() => KnownType.Bool
|
||||
case CharType() => KnownType.Char
|
||||
case StringType() => KnownType.String
|
||||
// For semantic types it is easier to work with recursion rather than a fixed size
|
||||
case ArrayType(elemType, dimension) =>
|
||||
(0 until dimension).foldLeft(SemType(elemType))((acc, _) => KnownType.Array(acc))
|
||||
case PairType(fst, snd) => KnownType.Pair(SemType(fst), SemType(snd))
|
||||
case UntypedPairType() => KnownType.Pair(?, ?)
|
||||
}
|
||||
}
|
||||
|
||||
case class FuncType(returnType: SemType, params: List[SemType])
|
||||
}
|
||||
Reference in New Issue
Block a user