feat: pass stmt position information to microwacc

This commit is contained in:
2025-03-14 14:02:15 +00:00
parent af514b3363
commit 07f02e61d7
4 changed files with 45 additions and 31 deletions

View File

@@ -261,7 +261,7 @@ object asmGenerator {
asm += stack.push(KnownType.String.size, RAX) asm += stack.push(KnownType.String.size, RAX)
case ty => case ty =>
asm ++= generateCall( asm ++= generateCall(
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize))), microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize)))(array.pos),
isTail = false isTail = false
) )
asm += stack.push(KnownType.Array(?).size, RAX) asm += stack.push(KnownType.Array(?).size, RAX)

View File

@@ -223,7 +223,9 @@ object ast {
val pos: Position val pos: Position
} }
sealed trait RValue sealed trait RValue {
val pos: Position
}
case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue
object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter] object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter]
case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue

View File

@@ -3,6 +3,7 @@ package wacc
import cats.data.Chain import cats.data.Chain
object microWacc { object microWacc {
import wacc.ast.Position
import wacc.types._ import wacc.types._
sealed trait CallTarget(val retTy: SemType) sealed trait CallTarget(val retTy: SemType)
@@ -13,7 +14,7 @@ object microWacc {
case class IntLiter(v: Int) extends Expr(KnownType.Int) case class IntLiter(v: Int) extends Expr(KnownType.Int)
case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool) case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool)
case class CharLiter(v: Char) extends Expr(KnownType.Char) case class CharLiter(v: Char) extends Expr(KnownType.Char)
case class ArrayLiter(elems: List[Expr])(ty: SemType) extends Expr(ty) case class ArrayLiter(elems: List[Expr])(ty: SemType, val pos: Position) extends Expr(ty)
case class NullLiter()(ty: SemType) extends Expr(ty) case class NullLiter()(ty: SemType) extends Expr(ty)
case class Ident(name: String, uid: Int)(identTy: SemType) case class Ident(name: String, uid: Int)(identTy: SemType)
extends Expr(identTy) extends Expr(identTy)
@@ -65,7 +66,9 @@ object microWacc {
} }
// Statements // Statements
sealed trait Stmt sealed trait Stmt {
val pos: Position
}
case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) { case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) {
override def toString(): String = name override def toString(): String = name
@@ -79,11 +82,14 @@ object microWacc {
object PrintCharArray extends Builtin("printCharArray")(?) object PrintCharArray extends Builtin("printCharArray")(?)
} }
case class Assign(lhs: LValue, rhs: Expr) extends Stmt case class Assign(lhs: LValue, rhs: Expr)(val pos: Position) extends Stmt
case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt]) extends Stmt case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt])(val pos: Position)
case class While(cond: Expr, body: Chain[Stmt]) extends Stmt extends Stmt
case class Call(target: CallTarget, args: List[Expr]) extends Stmt with Expr(target.retTy) case class While(cond: Expr, body: Chain[Stmt])(val pos: Position) extends Stmt
case class Return(expr: Expr) extends Stmt case class Call(target: CallTarget, args: List[Expr])(val pos: Position)
extends Stmt
with Expr(target.retTy)
case class Return(expr: Expr)(val pos: Position) extends Stmt
// Program // Program
case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt]) case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt])

View File

@@ -126,7 +126,7 @@ object typeChecker {
microWacc.Assign( microWacc.Assign(
microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]), microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]),
typedValue typedValue
) )(stmt.pos)
), ),
valueErrors valueErrors
) )
@@ -141,7 +141,10 @@ object typeChecker {
) )
case _ => Chain.empty case _ => Chain.empty
} }
(Chain.one(microWacc.Assign(lhsTyped, rhsTyped)), lhsErrors ++ rhsErrors ++ unknownError) (
Chain.one(microWacc.Assign(lhsTyped, rhsTyped)(stmt.pos)),
lhsErrors ++ rhsErrors ++ unknownError
)
case ast.Read(dest) => case ast.Read(dest) =>
val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained) val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained)
val (destTy, destTyErrors) = destTyped.ty match { val (destTy, destTyErrors) = destTyped.ty match {
@@ -170,13 +173,13 @@ object typeChecker {
microWacc.Builtin.Read, microWacc.Builtin.Read,
List( List(
destTy match { destTy match {
case KnownType.Int => " %d".toMicroWaccCharArray case KnownType.Int => " %d".toMicroWaccCharArray(stmt.pos)
case KnownType.Char | _ => " %c".toMicroWaccCharArray case KnownType.Char | _ => " %c".toMicroWaccCharArray(stmt.pos)
}, },
destTyped destTyped
) )
) )(dest.pos)
) )(stmt.pos)
), ),
destErrors ++ destTyErrors destErrors ++ destTyErrors
) )
@@ -189,14 +192,14 @@ object typeChecker {
"free must be applied to an array or pair" "free must be applied to an array or pair"
) )
) )
(Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))), lhsErrors) (Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))(stmt.pos)), lhsErrors)
case ast.Return(expr) => case ast.Return(expr) =>
val (exprTyped, exprErrors) = checkValue(expr, returnConstraint) val (exprTyped, exprErrors) = checkValue(expr, returnConstraint)
(Chain.one(microWacc.Return(exprTyped)), exprErrors) (Chain.one(microWacc.Return(exprTyped)(stmt.pos)), exprErrors)
case ast.Exit(expr) => case ast.Exit(expr) =>
val (exprTyped, exprErrors) = val (exprTyped, exprErrors) =
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int")) checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
(Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))), exprErrors) (Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))(stmt.pos)), exprErrors)
case ast.Print(expr, newline) => case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained) val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained)
@@ -212,10 +215,10 @@ object typeChecker {
microWacc.Call( microWacc.Call(
func, func,
List( List(
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray, s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray(stmt.pos),
value value
) )
) )(stmt.pos)
) )
} }
( (
@@ -224,9 +227,9 @@ object typeChecker {
Chain.one( Chain.one(
microWacc.If( microWacc.If(
exprTyped, exprTyped,
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray), printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray(stmt.pos)),
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray) printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray(stmt.pos))
) )(stmt.pos)
) )
case KnownType.Array(KnownType.Char) => case KnownType.Array(KnownType.Char) =>
printfCall(microWacc.Builtin.PrintCharArray, exprTyped) printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
@@ -240,14 +243,14 @@ object typeChecker {
val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint)) val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint))
val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint)) val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint))
( (
Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)), Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)(cond.pos)),
condErrors ++ thenErrors ++ elseErrors condErrors ++ thenErrors ++ elseErrors
) )
case ast.While(cond, body) => case ast.While(cond, body) =>
val (condTyped, condErrors) = val (condTyped, condErrors) =
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool")) checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint)) val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint))
(Chain.one(microWacc.While(condTyped, bodyTyped)), condErrors ++ bodyErrors) (Chain.one(microWacc.While(condTyped, bodyTyped)(cond.pos)), condErrors ++ bodyErrors)
case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint)) case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint))
case skip @ ast.Skip() => (Chain.empty, Chain.empty) case skip @ ast.Skip() => (Chain.empty, Chain.empty)
} }
@@ -277,7 +280,7 @@ object typeChecker {
(microWacc.CharLiter(v), errors) (microWacc.CharLiter(v), errors)
case l @ ast.StrLiter(v) => case l @ ast.StrLiter(v) =>
val (_, errors) = KnownType.String.satisfies(constraint, l.pos) val (_, errors) = KnownType.String.satisfies(constraint, l.pos)
(v.toMicroWaccCharArray, errors) (v.toMicroWaccCharArray(l.pos), errors)
case l @ ast.PairLiter() => case l @ ast.PairLiter() =>
val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos) val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos)
(microWacc.NullLiter()(ty), errors) (microWacc.NullLiter()(ty), errors)
@@ -296,13 +299,16 @@ object typeChecker {
// Start with an unknown param type, make it more specific while checking the elements. // Start with an unknown param type, make it more specific while checking the elements.
.Array(elemTy) .Array(elemTy)
.satisfies(constraint, l.pos) .satisfies(constraint, l.pos)
(microWacc.ArrayLiter(elemsTyped)(arrayTy), elemsErrors ++ arrayErrors) (microWacc.ArrayLiter(elemsTyped)(arrayTy, l.pos), elemsErrors ++ arrayErrors)
case l @ ast.NewPair(fst, snd) => case l @ ast.NewPair(fst, snd) =>
val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained) val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained)
val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained) val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained)
val (pairTy, pairErrors) = val (pairTy, pairErrors) =
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos) KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
(microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy), fstErrors ++ sndErrors ++ pairErrors) (
microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy, l.pos),
fstErrors ++ sndErrors ++ pairErrors
)
case ast.Call(id, args) => case ast.Call(id, args) =>
val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType] val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType]
val lenError = val lenError =
@@ -318,7 +324,7 @@ object typeChecker {
} }
val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos) val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos)
( (
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped), microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped)(id.pos),
lenError ++ argsErrors ++ retErrors lenError ++ argsErrors ++ retErrors
) )
@@ -480,7 +486,7 @@ object typeChecker {
} }
extension (s: String) { extension (s: String) {
def toMicroWaccCharArray: microWacc.ArrayLiter = def toMicroWaccCharArray(pos: ast.Position): microWacc.ArrayLiter =
microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String) microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String, pos)
} }
} }