feat: implement all runtime errors
Merge request lab2425_spring/WACC_37!32 Co-authored-by: Guy C <gc1523@ic.ac.uk> Co-authored-by: Jonny <j.sinteix@gmail.com> Co-authored-by: Connolly, Guy <guy.connolly23@imperial.ac.uk> Co-authored-by: Gleb Koval <gleb@koval.net>
This commit is contained in:
@@ -82,9 +82,11 @@ def compile(filename: String, outFile: Option[File] = None)(using
|
||||
def main(args: Array[String]): Unit =
|
||||
OParser.parse(cliParser, args, CliConfig()) match {
|
||||
case Some(config) =>
|
||||
System.exit(
|
||||
compile(
|
||||
config.file.getAbsolutePath,
|
||||
outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s"))
|
||||
)
|
||||
)
|
||||
case None =>
|
||||
}
|
||||
|
||||
122
src/main/wacc/backend/RuntimeError.scala
Normal file
122
src/main/wacc/backend/RuntimeError.scala
Normal file
@@ -0,0 +1,122 @@
|
||||
package wacc
|
||||
|
||||
import cats.data.Chain
|
||||
import wacc.assemblyIR._
|
||||
|
||||
sealed trait RuntimeError {
|
||||
def strLabel: String
|
||||
def errStr: String
|
||||
def errLabel: String
|
||||
|
||||
def stringDef: Chain[AsmLine] = Chain(
|
||||
Directive.Int(errStr.length),
|
||||
LabelDef(strLabel),
|
||||
Directive.Asciz(errStr)
|
||||
)
|
||||
|
||||
def generateHandler: Chain[AsmLine]
|
||||
|
||||
}
|
||||
|
||||
object RuntimeError {
|
||||
|
||||
// TODO: Refactor to mitigate imports and redeclared vals perhaps
|
||||
|
||||
import wacc.asmGenerator.stackAlign
|
||||
import assemblyIR.Size._
|
||||
import assemblyIR.RegName._
|
||||
|
||||
// private val RAX = Register(Q64, AX)
|
||||
// private val EAX = Register(D32, AX)
|
||||
private val RDI = Register(Q64, DI)
|
||||
private val RIP = Register(Q64, IP)
|
||||
// private val RBP = Register(Q64, BP)
|
||||
private val RSI = Register(Q64, SI)
|
||||
// private val RDX = Register(Q64, DX)
|
||||
// private val RCX = Register(Q64, CX)
|
||||
|
||||
case object ZeroDivError extends RuntimeError {
|
||||
val strLabel = ".L._errDivZero_str0"
|
||||
val errStr = "fatal error: division or modulo by zero"
|
||||
val errLabel = ".L._errDivZero"
|
||||
|
||||
def generateHandler: Chain[AsmLine] = Chain(
|
||||
LabelDef(ZeroDivError.errLabel),
|
||||
stackAlign,
|
||||
Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(-1)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object BadChrError extends RuntimeError {
|
||||
val strLabel = ".L._errBadChr_str0"
|
||||
val errStr = "fatal error: int %d is not an ASCII character 0-127"
|
||||
val errLabel = ".L._errBadChr"
|
||||
|
||||
def generateHandler: Chain[AsmLine] = Chain(
|
||||
LabelDef(BadChrError.errLabel),
|
||||
Pop(RSI),
|
||||
stackAlign,
|
||||
Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(255)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object NullPtrError extends RuntimeError {
|
||||
val strLabel = ".L._errNullPtr_str0"
|
||||
val errStr = "fatal error: null pair dereferenced or freed"
|
||||
val errLabel = ".L._errNullPtr"
|
||||
|
||||
def generateHandler: Chain[AsmLine] = Chain(
|
||||
LabelDef(NullPtrError.errLabel),
|
||||
stackAlign,
|
||||
Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(255)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object OverflowError extends RuntimeError {
|
||||
val strLabel = ".L._errOverflow_str0"
|
||||
val errStr = "fatal error: integer overflow or underflow occurred"
|
||||
val errLabel = ".L._errOverflow"
|
||||
|
||||
def generateHandler: Chain[AsmLine] = Chain(
|
||||
LabelDef(OverflowError.errLabel),
|
||||
stackAlign,
|
||||
Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(255)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object OutOfBoundsError extends RuntimeError {
|
||||
|
||||
val strLabel = ".L._errOutOfBounds_str0"
|
||||
val errStr = "fatal error: array index %d out of bounds"
|
||||
val errLabel = ".L._errOutOfBounds"
|
||||
|
||||
def generateHandler: Chain[AsmLine] = Chain(
|
||||
LabelDef(OutOfBoundsError.errLabel),
|
||||
Move(RSI, Register(Q64, CX)),
|
||||
stackAlign,
|
||||
Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(255)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
}
|
||||
|
||||
val all: Chain[RuntimeError] =
|
||||
Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package wacc
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import cats.data.Chain
|
||||
import cats.syntax.foldable._
|
||||
import wacc.RuntimeError._
|
||||
|
||||
object asmGenerator {
|
||||
import microWacc._
|
||||
@@ -13,24 +14,6 @@ object asmGenerator {
|
||||
import sizeExtensions._
|
||||
import lexer.escapedChars
|
||||
|
||||
abstract case class Error() {
|
||||
def strLabel: String
|
||||
def errStr: String
|
||||
def errLabel: String
|
||||
|
||||
def stringDef: Chain[AsmLine] = Chain(
|
||||
Directive.Int(errStr.size),
|
||||
LabelDef(strLabel),
|
||||
Directive.Asciz(errStr)
|
||||
)
|
||||
}
|
||||
object zeroDivError extends Error {
|
||||
// TODO: is this bad? Can we make an error case class/some other structure?
|
||||
def strLabel = ".L._errDivZero_str0"
|
||||
def errStr = "fatal error: division or modulo by zero"
|
||||
def errLabel = ".L._errDivZero"
|
||||
}
|
||||
|
||||
private val RAX = Register(Q64, AX)
|
||||
private val EAX = Register(D32, AX)
|
||||
private val RDI = Register(Q64, DI)
|
||||
@@ -39,6 +22,7 @@ object asmGenerator {
|
||||
private val RSI = Register(Q64, SI)
|
||||
private val RDX = Register(Q64, DX)
|
||||
private val RCX = Register(Q64, CX)
|
||||
private val ECX = Register(D32, CX)
|
||||
private val argRegs = List(DI, SI, DX, CX, R8, R9)
|
||||
|
||||
extension [T](chain: Chain[T])
|
||||
@@ -80,7 +64,7 @@ object asmGenerator {
|
||||
LabelDef(s".L.str$i"),
|
||||
Directive.Asciz(str.escaped)
|
||||
)
|
||||
} ++ zeroDivError.stringDef
|
||||
} ++ RuntimeError.all.foldMap(_.stringDef)
|
||||
|
||||
Chain(
|
||||
Directive.IntelSyntax,
|
||||
@@ -161,7 +145,16 @@ object asmGenerator {
|
||||
// Out of memory check is optional
|
||||
)
|
||||
|
||||
chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty)
|
||||
chain ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Free),
|
||||
Chain(
|
||||
stackAlign,
|
||||
Move(RDI, RAX),
|
||||
Compare(RDI, ImmediateVal(0)),
|
||||
Jump(LabelArg(NullPtrError.errLabel), Cond.Equal),
|
||||
assemblyIR.Call(CLibFunc.Free)
|
||||
)
|
||||
)
|
||||
|
||||
chain ++= wrapBuiltinFunc(
|
||||
labelGenerator.getLabel(Builtin.Read),
|
||||
@@ -175,16 +168,7 @@ object asmGenerator {
|
||||
)
|
||||
)
|
||||
|
||||
chain ++= Chain(
|
||||
// TODO can this be done with a call to generateStmt?
|
||||
// Consider other error cases -> look to generalise
|
||||
LabelDef(zeroDivError.errLabel),
|
||||
stackAlign,
|
||||
Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(-1)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
chain ++= RuntimeError.all.foldMap(_.generateHandler)
|
||||
|
||||
chain
|
||||
}
|
||||
@@ -207,9 +191,17 @@ object asmGenerator {
|
||||
case ArrayElem(x, i) =>
|
||||
chain ++= evalExprOntoStack(rhs)
|
||||
chain ++= evalExprOntoStack(i)
|
||||
chain += stack.pop(RCX)
|
||||
chain += Compare(ECX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
|
||||
chain += stack.push(Q64, RCX)
|
||||
chain ++= evalExprOntoStack(x)
|
||||
chain += stack.pop(RAX)
|
||||
chain += stack.pop(RCX)
|
||||
chain += Compare(EAX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
|
||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
|
||||
chain += stack.pop(RDX)
|
||||
|
||||
chain += Move(
|
||||
@@ -311,7 +303,13 @@ object asmGenerator {
|
||||
chain ++= evalExprOntoStack(x)
|
||||
chain ++= evalExprOntoStack(i)
|
||||
chain += stack.pop(RCX)
|
||||
chain += Compare(RCX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less)
|
||||
chain += stack.pop(RAX)
|
||||
chain += Compare(EAX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal)
|
||||
chain += Compare(MemLocation(RAX, D32), ECX)
|
||||
chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual)
|
||||
// + Int because we store the length of the array at the start
|
||||
chain += Move(
|
||||
Register(x.ty.elemSize, AX),
|
||||
@@ -321,13 +319,22 @@ object asmGenerator {
|
||||
case UnaryOp(x, op) =>
|
||||
chain ++= evalExprOntoStack(x)
|
||||
op match {
|
||||
case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed
|
||||
case UnaryOperator.Chr =>
|
||||
chain += Move(EAX, stack.head)
|
||||
chain += And(EAX, ImmediateVal(-128))
|
||||
chain += Compare(EAX, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual)
|
||||
case UnaryOperator.Ord => // No op needed
|
||||
case UnaryOperator.Len =>
|
||||
chain += stack.pop(RAX)
|
||||
chain += Move(EAX, MemLocation(RAX, D32))
|
||||
chain += stack.push(D32, RAX)
|
||||
case UnaryOperator.Negate =>
|
||||
chain += Negate(stack.head)
|
||||
chain += Xor(EAX, EAX)
|
||||
chain += Subtract(EAX, stack.head)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(Q64, RAX)
|
||||
case UnaryOperator.Not =>
|
||||
chain += Xor(stack.head, ImmediateVal(1))
|
||||
}
|
||||
@@ -341,24 +348,29 @@ object asmGenerator {
|
||||
op match {
|
||||
case BinaryOperator.Add =>
|
||||
chain += Add(stack.head, destX)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
case BinaryOperator.Sub =>
|
||||
chain += Subtract(destX, stack.head)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(destX.size, RAX)
|
||||
case BinaryOperator.Mul =>
|
||||
chain += Multiply(destX, stack.head)
|
||||
chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(destX.size, RAX)
|
||||
|
||||
case BinaryOperator.Div =>
|
||||
chain += Compare(stack.head, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal)
|
||||
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
|
||||
chain += CDQ()
|
||||
chain += Divide(stack.head)
|
||||
chain += stack.drop()
|
||||
chain += stack.push(destX.size, RAX)
|
||||
|
||||
case BinaryOperator.Mod =>
|
||||
chain += Compare(stack.head, ImmediateVal(0))
|
||||
chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
|
||||
chain += CDQ()
|
||||
chain += Divide(stack.head)
|
||||
chain += stack.drop()
|
||||
@@ -444,7 +456,7 @@ object asmGenerator {
|
||||
chain
|
||||
}
|
||||
|
||||
private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
|
||||
def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
|
||||
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
|
||||
case Q64 | D32 => Chain.empty
|
||||
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
|
||||
|
||||
@@ -180,8 +180,8 @@ object typeChecker {
|
||||
microWacc.Builtin.Read,
|
||||
List(
|
||||
destTy match {
|
||||
case KnownType.Int => "%d".toMicroWaccCharArray
|
||||
case KnownType.Char | _ => "%c".toMicroWaccCharArray
|
||||
case KnownType.Int => " %d".toMicroWaccCharArray
|
||||
case KnownType.Char | _ => " %c".toMicroWaccCharArray
|
||||
},
|
||||
destTyped
|
||||
)
|
||||
|
||||
@@ -41,23 +41,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
|
||||
val inputLine =
|
||||
contents
|
||||
.find(_.matches("^# ?[Ii]nput:.*$"))
|
||||
.map(line =>
|
||||
("" :: line.split(":").last.strip.split(" ").toList)
|
||||
.sliding(2)
|
||||
.flatMap { arr =>
|
||||
if (
|
||||
// First entry has no space in front
|
||||
arr(0) == "" ||
|
||||
// int followed by non-digit, space can be removed
|
||||
arr(0).toIntOption.nonEmpty && !arr(1)(0).isDigit ||
|
||||
// non-int followed by int, space can be removed
|
||||
!arr(0).last.isDigit && arr(1).toIntOption.nonEmpty
|
||||
)
|
||||
then List(arr(1))
|
||||
else List(" ", arr(1))
|
||||
}
|
||||
.mkString
|
||||
)
|
||||
.map(_.split(":").last.strip + "\n")
|
||||
.getOrElse("")
|
||||
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
|
||||
val expectedOutput =
|
||||
@@ -92,7 +76,13 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
|
||||
)
|
||||
|
||||
assert(process.exitValue == expectedExit)
|
||||
assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput)
|
||||
assert(
|
||||
stdout.toString
|
||||
.replaceAll("0x[0-9a-f]+", "#addrs#")
|
||||
.replaceAll("fatal error:.*", "#runtime_error#\u0000")
|
||||
.takeWhile(_ != '\u0000')
|
||||
== expectedOutput
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +107,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll {
|
||||
// "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
|
||||
// "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
|
||||
// "^.*wacc-examples/valid/pairs.*$",
|
||||
"^.*wacc-examples/valid/runtimeErr.*$",
|
||||
//"^.*wacc-examples/valid/runtimeErr.*$",
|
||||
// "^.*wacc-examples/valid/scope.*$",
|
||||
// "^.*wacc-examples/valid/sequence.*$",
|
||||
// "^.*wacc-examples/valid/variables.*$",
|
||||
|
||||
Reference in New Issue
Block a user