diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala new file mode 100644 index 0000000..e179494 --- /dev/null +++ b/src/main/wacc/backend/RuntimeError.scala @@ -0,0 +1,104 @@ +package wacc + +import cats.data.Chain +import wacc.assemblyIR._ + +sealed trait RuntimeError { + def strLabel: String + def errStr: String + def errLabel: String + + def stringDef: Chain[AsmLine] = Chain( + Directive.Int(errStr.length), + LabelDef(strLabel), + Directive.Asciz(errStr) + ) + + def generateHandler: Chain[AsmLine] + +} + +object RuntimeError { + + // TODO: Refactor to mitigate imports and redeclared vals perhaps + + import wacc.asmGenerator.stackAlign + import assemblyIR.Size._ + import assemblyIR.RegName._ + + // private val RAX = Register(Q64, AX) + // private val EAX = Register(D32, AX) + private val RDI = Register(Q64, DI) + private val RIP = Register(Q64, IP) + // private val RBP = Register(Q64, BP) + private val RSI = Register(Q64, SI) + // private val RDX = Register(Q64, DX) + // private val RCX = Register(Q64, CX) + + case object ZeroDivError extends RuntimeError { + val strLabel = ".L._errDivZero_str0" + val errStr = "fatal error: division or modulo by zero" + val errLabel = ".L._errDivZero" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(ZeroDivError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(-1)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object BadChrError extends RuntimeError { + val strLabel = ".L._errBadChr_str0" + val errStr = "fatal error: int %d is not an ASCII character 0-127" + val errLabel = ".L._errBadChr" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(BadChrError.errLabel), + Pop(RSI), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object NullPtrError extends RuntimeError { + val strLabel = ".L._errNullPtr_str0" + val errStr = "fatal error: null pair dereferenced or freed" + val errLabel = ".L._errNullPtr" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(NullPtrError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object OverflowError extends RuntimeError { + val strLabel = ".L._errOverflow_str0" + val errStr = "fatal error: integer overflow or underflow occurred" + val errLabel = ".L._errOverflow" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(OverflowError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + val all: Chain[RuntimeError] = Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError) +} diff --git a/src/main/wacc/backend/RuntimeErrors.scala b/src/main/wacc/backend/RuntimeErrors.scala deleted file mode 100644 index 06062ad..0000000 --- a/src/main/wacc/backend/RuntimeErrors.scala +++ /dev/null @@ -1,35 +0,0 @@ -package wacc - -import cats.data.Chain -import wacc.assemblyIR._ - -case class RuntimeError(strLabel: String, errStr: String, errLabel: String) { - def stringDef: Chain[AsmLine] = Chain( - Directive.Int(errStr.size), - LabelDef(strLabel), - Directive.Asciz(errStr) - ) -} - -object RuntimeErrors { - val zeroDivError = RuntimeError( - ".L._errDivZero_str0", - "fatal error: division or modulo by zero", - ".L._errDivZero" - ) - val badChrError = RuntimeError( - ".L._errBadChr_str0", - "fatal error: int %d is not ascii character 0-127", - ".L._errBadChr" - ) - val nullPtrError = RuntimeError( - ".L._errNullPtr_str0", - "fatal error: null pair dereferenced or freed", - ".L._errNullPtr" - ) - val overflowError = RuntimeError( - ".L._errOverflow_str0", - "fatal error: integer overflow or underflow occurred", - ".L._errOverflow" - ) -} diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 411d02e..5b3db3e 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,7 +3,7 @@ package wacc import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ -import wacc.RuntimeErrors._ +import wacc.RuntimeError._ object asmGenerator { import microWacc._ @@ -63,10 +63,7 @@ object asmGenerator { LabelDef(s".L.str$i"), Directive.Asciz(str.escaped) ) - } ++ zeroDivError.stringDef - ++ badChrError.stringDef - ++ nullPtrError.stringDef // TODO COLLATE TO ONE LIST INSTANCE - ++ overflowError.stringDef + } ++ RuntimeError.all.foldMap(_.stringDef) Chain( Directive.IntelSyntax, @@ -153,7 +150,7 @@ object asmGenerator { stackAlign, Move(RDI, RAX), Compare(RDI, ImmediateVal(0)), - Jump(LabelArg(nullPtrError.errLabel), Cond.Equal), + Jump(LabelArg(NullPtrError.errLabel), Cond.Equal), assemblyIR.Call(CLibFunc.Free) ) ) @@ -170,44 +167,7 @@ object asmGenerator { ) ) - chain ++= Chain( - // TODO can this be done with a call to generateStmt? - // Consider other error cases -> look to generalise - LabelDef(zeroDivError.errLabel), - stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(-1)), - assemblyIR.Call(CLibFunc.Exit) - ) - - chain ++= Chain( - LabelDef(badChrError.errLabel), - Pop(RSI), - stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(badChrError.strLabel))), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), - assemblyIR.Call(CLibFunc.Exit) - ) - - chain ++= Chain( - LabelDef(nullPtrError.errLabel), - stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(nullPtrError.strLabel))), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), - assemblyIR.Call(CLibFunc.Exit) - ) - - chain ++= Chain( - LabelDef(overflowError.errLabel), - stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(overflowError.strLabel))), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(255)), - assemblyIR.Call(CLibFunc.Exit) - ) + chain ++= RuntimeError.all.foldMap(_.generateHandler) chain } @@ -348,7 +308,7 @@ object asmGenerator { chain += Move(EAX, stack.head) chain += And(EAX, ImmediateVal(-128)) chain += Compare(EAX, ImmediateVal(0)) - chain += Jump(LabelArg(badChrError.errLabel), Cond.NotEqual) + chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => chain += stack.pop(RAX) @@ -357,7 +317,7 @@ object asmGenerator { case UnaryOperator.Negate => chain += Xor(EAX, EAX) chain += Subtract(EAX, stack.head) - chain += Jump(LabelArg(overflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += stack.drop() chain += stack.push(Q64, RAX) case UnaryOperator.Not => @@ -373,20 +333,20 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head, destX) - chain += Jump(LabelArg(overflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) case BinaryOperator.Sub => chain += Subtract(destX, stack.head) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => chain += Multiply(destX, stack.head) - chain += Jump(LabelArg(overflowError.errLabel), Cond.Overflow) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Div => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -394,7 +354,7 @@ object asmGenerator { case BinaryOperator.Mod => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -480,7 +440,7 @@ object asmGenerator { chain } - private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) + def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { case Q64 | D32 => Chain.empty case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))