refactor: redesigned runtime errors with added functionality

This commit is contained in:
Jonny
2025-02-27 21:20:17 +00:00
parent 9e6970de62
commit 6f5fcd4d85
3 changed files with 115 additions and 86 deletions

View File

@@ -0,0 +1,104 @@
package wacc
import cats.data.Chain
import wacc.assemblyIR._
sealed trait RuntimeError {
def strLabel: String
def errStr: String
def errLabel: String
def stringDef: Chain[AsmLine] = Chain(
Directive.Int(errStr.length),
LabelDef(strLabel),
Directive.Asciz(errStr)
)
def generateHandler: Chain[AsmLine]
}
object RuntimeError {
// TODO: Refactor to mitigate imports and redeclared vals perhaps
import wacc.asmGenerator.stackAlign
import assemblyIR.Size._
import assemblyIR.RegName._
// private val RAX = Register(Q64, AX)
// private val EAX = Register(D32, AX)
private val RDI = Register(Q64, DI)
private val RIP = Register(Q64, IP)
// private val RBP = Register(Q64, BP)
private val RSI = Register(Q64, SI)
// private val RDX = Register(Q64, DX)
// private val RCX = Register(Q64, CX)
case object ZeroDivError extends RuntimeError {
val strLabel = ".L._errDivZero_str0"
val errStr = "fatal error: division or modulo by zero"
val errLabel = ".L._errDivZero"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(ZeroDivError.errLabel),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)),
assemblyIR.Call(CLibFunc.Exit)
)
}
case object BadChrError extends RuntimeError {
val strLabel = ".L._errBadChr_str0"
val errStr = "fatal error: int %d is not an ASCII character 0-127"
val errLabel = ".L._errBadChr"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(BadChrError.errLabel),
Pop(RSI),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
)
}
case object NullPtrError extends RuntimeError {
val strLabel = ".L._errNullPtr_str0"
val errStr = "fatal error: null pair dereferenced or freed"
val errLabel = ".L._errNullPtr"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(NullPtrError.errLabel),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
)
}
case object OverflowError extends RuntimeError {
val strLabel = ".L._errOverflow_str0"
val errStr = "fatal error: integer overflow or underflow occurred"
val errLabel = ".L._errOverflow"
def generateHandler: Chain[AsmLine] = Chain(
LabelDef(OverflowError.errLabel),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
)
}
val all: Chain[RuntimeError] = Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError)
}

View File

@@ -1,35 +0,0 @@
package wacc
import cats.data.Chain
import wacc.assemblyIR._
case class RuntimeError(strLabel: String, errStr: String, errLabel: String) {
def stringDef: Chain[AsmLine] = Chain(
Directive.Int(errStr.size),
LabelDef(strLabel),
Directive.Asciz(errStr)
)
}
object RuntimeErrors {
val zeroDivError = RuntimeError(
".L._errDivZero_str0",
"fatal error: division or modulo by zero",
".L._errDivZero"
)
val badChrError = RuntimeError(
".L._errBadChr_str0",
"fatal error: int %d is not ascii character 0-127",
".L._errBadChr"
)
val nullPtrError = RuntimeError(
".L._errNullPtr_str0",
"fatal error: null pair dereferenced or freed",
".L._errNullPtr"
)
val overflowError = RuntimeError(
".L._errOverflow_str0",
"fatal error: integer overflow or underflow occurred",
".L._errOverflow"
)
}

View File

@@ -3,7 +3,7 @@ package wacc
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.data.Chain import cats.data.Chain
import cats.syntax.foldable._ import cats.syntax.foldable._
import wacc.RuntimeErrors._ import wacc.RuntimeError._
object asmGenerator { object asmGenerator {
import microWacc._ import microWacc._
@@ -63,10 +63,7 @@ object asmGenerator {
LabelDef(s".L.str$i"), LabelDef(s".L.str$i"),
Directive.Asciz(str.escaped) Directive.Asciz(str.escaped)
) )
} ++ zeroDivError.stringDef } ++ RuntimeError.all.foldMap(_.stringDef)
++ badChrError.stringDef
++ nullPtrError.stringDef // TODO COLLATE TO ONE LIST INSTANCE
++ overflowError.stringDef
Chain( Chain(
Directive.IntelSyntax, Directive.IntelSyntax,
@@ -153,7 +150,7 @@ object asmGenerator {
stackAlign, stackAlign,
Move(RDI, RAX), Move(RDI, RAX),
Compare(RDI, ImmediateVal(0)), Compare(RDI, ImmediateVal(0)),
Jump(LabelArg(nullPtrError.errLabel), Cond.Equal), Jump(LabelArg(NullPtrError.errLabel), Cond.Equal),
assemblyIR.Call(CLibFunc.Free) assemblyIR.Call(CLibFunc.Free)
) )
) )
@@ -170,44 +167,7 @@ object asmGenerator {
) )
) )
chain ++= Chain( chain ++= RuntimeError.all.foldMap(_.generateHandler)
// TODO can this be done with a call to generateStmt?
// Consider other error cases -> look to generalise
LabelDef(zeroDivError.errLabel),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(-1)),
assemblyIR.Call(CLibFunc.Exit)
)
chain ++= Chain(
LabelDef(badChrError.errLabel),
Pop(RSI),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(badChrError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
)
chain ++= Chain(
LabelDef(nullPtrError.errLabel),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(nullPtrError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
)
chain ++= Chain(
LabelDef(overflowError.errLabel),
stackAlign,
Load(RDI, IndexAddress(RIP, LabelArg(overflowError.strLabel))),
assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(255)),
assemblyIR.Call(CLibFunc.Exit)
)
chain chain
} }
@@ -348,7 +308,7 @@ object asmGenerator {
chain += Move(EAX, stack.head) chain += Move(EAX, stack.head)
chain += And(EAX, ImmediateVal(-128)) chain += And(EAX, ImmediateVal(-128))
chain += Compare(EAX, ImmediateVal(0)) chain += Compare(EAX, ImmediateVal(0))
chain += Jump(LabelArg(badChrError.errLabel), Cond.NotEqual) chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual)
case UnaryOperator.Ord => // No op needed case UnaryOperator.Ord => // No op needed
case UnaryOperator.Len => case UnaryOperator.Len =>
chain += stack.pop(RAX) chain += stack.pop(RAX)
@@ -357,7 +317,7 @@ object asmGenerator {
case UnaryOperator.Negate => case UnaryOperator.Negate =>
chain += Xor(EAX, EAX) chain += Xor(EAX, EAX)
chain += Subtract(EAX, stack.head) chain += Subtract(EAX, stack.head)
chain += Jump(LabelArg(overflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(Q64, RAX) chain += stack.push(Q64, RAX)
case UnaryOperator.Not => case UnaryOperator.Not =>
@@ -373,20 +333,20 @@ object asmGenerator {
op match { op match {
case BinaryOperator.Add => case BinaryOperator.Add =>
chain += Add(stack.head, destX) chain += Add(stack.head, destX)
chain += Jump(LabelArg(overflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
case BinaryOperator.Sub => case BinaryOperator.Sub =>
chain += Subtract(destX, stack.head) chain += Subtract(destX, stack.head)
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Mul => case BinaryOperator.Mul =>
chain += Multiply(destX, stack.head) chain += Multiply(destX, stack.head)
chain += Jump(LabelArg(overflowError.errLabel), Cond.Overflow) chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow)
chain += stack.drop() chain += stack.drop()
chain += stack.push(destX.size, RAX) chain += stack.push(destX.size, RAX)
case BinaryOperator.Div => case BinaryOperator.Div =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
@@ -394,7 +354,7 @@ object asmGenerator {
case BinaryOperator.Mod => case BinaryOperator.Mod =>
chain += Compare(stack.head, ImmediateVal(0)) chain += Compare(stack.head, ImmediateVal(0))
chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal)
chain += CDQ() chain += CDQ()
chain += Divide(stack.head) chain += Divide(stack.head)
chain += stack.drop() chain += stack.drop()
@@ -480,7 +440,7 @@ object asmGenerator {
chain chain
} }
private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16))
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
case Q64 | D32 => Chain.empty case Q64 | D32 => Chain.empty
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1))) case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))