diff --git a/src/main/wacc/ast.scala b/src/main/wacc/ast.scala index c6e743e..1fe26b7 100644 --- a/src/main/wacc/ast.scala +++ b/src/main/wacc/ast.scala @@ -1,7 +1,9 @@ package wacc +import parsley.errors.combinator._ import parsley.generic._ import cats.data.NonEmptyList +import parsley.Parsley object ast { // Expressions @@ -96,7 +98,18 @@ object ast { params: List[Param], body: NonEmptyList[Stmt] ) - object FuncDecl extends ParserBridge4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl] + object FuncDecl extends ParserBridge4[Type, Ident, List[Param], NonEmptyList[Stmt], FuncDecl] { + override def apply( + x1: Parsley[Type], + x2: => Parsley[Ident], + x3: => Parsley[List[Param]], + x4: => Parsley[NonEmptyList[Stmt]] + ): Parsley[FuncDecl] = + super.apply(x1, x2, x3, x4).guardAgainst { + case FuncDecl(_, _, _, body) if !body.isReturning => + Seq("Function must return on all paths") + } + } case class Param(paramType: Type, name: Ident) object Param extends ParserBridge2[Type, Ident, Param] @@ -140,4 +153,14 @@ object ast { object Fst extends ParserBridge1[LValue, Fst] case class Snd(elem: LValue) extends PairElem object Snd extends ParserBridge1[LValue, Snd] + + extension (stmts: NonEmptyList[Stmt]) { + def isReturning: Boolean = stmts.last match { + case Return(_) | Exit(_) => true + case If(_, thenStmt, elseStmt) => thenStmt.isReturning && elseStmt.isReturning + case While(_, body) => body.isReturning + case Block(body) => body.isReturning + case _ => false + } + } }