diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 020cbcd..fc9fb45 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -82,9 +82,11 @@ def compile(filename: String, outFile: Option[File] = None)(using def main(args: Array[String]): Unit = OParser.parse(cliParser, args, CliConfig()) match { case Some(config) => - compile( - config.file.getAbsolutePath, - outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s")) + System.exit( + compile( + config.file.getAbsolutePath, + outFile = Some(File(".", config.file.getName.stripSuffix(".wacc") + ".s")) + ) ) case None => } diff --git a/src/main/wacc/backend/RuntimeError.scala b/src/main/wacc/backend/RuntimeError.scala new file mode 100644 index 0000000..9085b63 --- /dev/null +++ b/src/main/wacc/backend/RuntimeError.scala @@ -0,0 +1,122 @@ +package wacc + +import cats.data.Chain +import wacc.assemblyIR._ + +sealed trait RuntimeError { + def strLabel: String + def errStr: String + def errLabel: String + + def stringDef: Chain[AsmLine] = Chain( + Directive.Int(errStr.length), + LabelDef(strLabel), + Directive.Asciz(errStr) + ) + + def generateHandler: Chain[AsmLine] + +} + +object RuntimeError { + + // TODO: Refactor to mitigate imports and redeclared vals perhaps + + import wacc.asmGenerator.stackAlign + import assemblyIR.Size._ + import assemblyIR.RegName._ + + // private val RAX = Register(Q64, AX) + // private val EAX = Register(D32, AX) + private val RDI = Register(Q64, DI) + private val RIP = Register(Q64, IP) + // private val RBP = Register(Q64, BP) + private val RSI = Register(Q64, SI) + // private val RDX = Register(Q64, DX) + // private val RCX = Register(Q64, CX) + + case object ZeroDivError extends RuntimeError { + val strLabel = ".L._errDivZero_str0" + val errStr = "fatal error: division or modulo by zero" + val errLabel = ".L._errDivZero" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(ZeroDivError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(ZeroDivError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(-1)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object BadChrError extends RuntimeError { + val strLabel = ".L._errBadChr_str0" + val errStr = "fatal error: int %d is not an ASCII character 0-127" + val errLabel = ".L._errBadChr" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(BadChrError.errLabel), + Pop(RSI), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(BadChrError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object NullPtrError extends RuntimeError { + val strLabel = ".L._errNullPtr_str0" + val errStr = "fatal error: null pair dereferenced or freed" + val errLabel = ".L._errNullPtr" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(NullPtrError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(NullPtrError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object OverflowError extends RuntimeError { + val strLabel = ".L._errOverflow_str0" + val errStr = "fatal error: integer overflow or underflow occurred" + val errLabel = ".L._errOverflow" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(OverflowError.errLabel), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(OverflowError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + + } + + case object OutOfBoundsError extends RuntimeError { + + val strLabel = ".L._errOutOfBounds_str0" + val errStr = "fatal error: array index %d out of bounds" + val errLabel = ".L._errOutOfBounds" + + def generateHandler: Chain[AsmLine] = Chain( + LabelDef(OutOfBoundsError.errLabel), + Move(RSI, Register(Q64, CX)), + stackAlign, + Load(RDI, IndexAddress(RIP, LabelArg(OutOfBoundsError.strLabel))), + assemblyIR.Call(CLibFunc.PrintF), + Move(RDI, ImmediateVal(255)), + assemblyIR.Call(CLibFunc.Exit) + ) + } + + val all: Chain[RuntimeError] = + Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError) +} diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala index 60e0b47..71a99b5 100644 --- a/src/main/wacc/backend/asmGenerator.scala +++ b/src/main/wacc/backend/asmGenerator.scala @@ -3,6 +3,7 @@ package wacc import scala.collection.mutable.ListBuffer import cats.data.Chain import cats.syntax.foldable._ +import wacc.RuntimeError._ object asmGenerator { import microWacc._ @@ -13,24 +14,6 @@ object asmGenerator { import sizeExtensions._ import lexer.escapedChars - abstract case class Error() { - def strLabel: String - def errStr: String - def errLabel: String - - def stringDef: Chain[AsmLine] = Chain( - Directive.Int(errStr.size), - LabelDef(strLabel), - Directive.Asciz(errStr) - ) - } - object zeroDivError extends Error { - // TODO: is this bad? Can we make an error case class/some other structure? - def strLabel = ".L._errDivZero_str0" - def errStr = "fatal error: division or modulo by zero" - def errLabel = ".L._errDivZero" - } - private val RAX = Register(Q64, AX) private val EAX = Register(D32, AX) private val RDI = Register(Q64, DI) @@ -39,6 +22,7 @@ object asmGenerator { private val RSI = Register(Q64, SI) private val RDX = Register(Q64, DX) private val RCX = Register(Q64, CX) + private val ECX = Register(D32, CX) private val argRegs = List(DI, SI, DX, CX, R8, R9) extension [T](chain: Chain[T]) @@ -80,7 +64,7 @@ object asmGenerator { LabelDef(s".L.str$i"), Directive.Asciz(str.escaped) ) - } ++ zeroDivError.stringDef + } ++ RuntimeError.all.foldMap(_.stringDef) Chain( Directive.IntelSyntax, @@ -161,7 +145,16 @@ object asmGenerator { // Out of memory check is optional ) - chain ++= wrapBuiltinFunc(labelGenerator.getLabel(Builtin.Free), Chain.empty) + chain ++= wrapBuiltinFunc( + labelGenerator.getLabel(Builtin.Free), + Chain( + stackAlign, + Move(RDI, RAX), + Compare(RDI, ImmediateVal(0)), + Jump(LabelArg(NullPtrError.errLabel), Cond.Equal), + assemblyIR.Call(CLibFunc.Free) + ) + ) chain ++= wrapBuiltinFunc( labelGenerator.getLabel(Builtin.Read), @@ -175,16 +168,7 @@ object asmGenerator { ) ) - chain ++= Chain( - // TODO can this be done with a call to generateStmt? - // Consider other error cases -> look to generalise - LabelDef(zeroDivError.errLabel), - stackAlign, - Load(RDI, IndexAddress(RIP, LabelArg(zeroDivError.strLabel))), - assemblyIR.Call(CLibFunc.PrintF), - Move(RDI, ImmediateVal(-1)), - assemblyIR.Call(CLibFunc.Exit) - ) + chain ++= RuntimeError.all.foldMap(_.generateHandler) chain } @@ -207,9 +191,17 @@ object asmGenerator { case ArrayElem(x, i) => chain ++= evalExprOntoStack(rhs) chain ++= evalExprOntoStack(i) + chain += stack.pop(RCX) + chain += Compare(ECX, ImmediateVal(0)) + chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) + chain += stack.push(Q64, RCX) chain ++= evalExprOntoStack(x) chain += stack.pop(RAX) chain += stack.pop(RCX) + chain += Compare(EAX, ImmediateVal(0)) + chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + chain += Compare(MemLocation(RAX, D32), ECX) + chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) chain += stack.pop(RDX) chain += Move( @@ -311,7 +303,13 @@ object asmGenerator { chain ++= evalExprOntoStack(x) chain ++= evalExprOntoStack(i) chain += stack.pop(RCX) + chain += Compare(RCX, ImmediateVal(0)) + chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.Less) chain += stack.pop(RAX) + chain += Compare(EAX, ImmediateVal(0)) + chain += Jump(LabelArg(NullPtrError.errLabel), Cond.Equal) + chain += Compare(MemLocation(RAX, D32), ECX) + chain += Jump(LabelArg(OutOfBoundsError.errLabel), Cond.LessEqual) // + Int because we store the length of the array at the start chain += Move( Register(x.ty.elemSize, AX), @@ -321,13 +319,22 @@ object asmGenerator { case UnaryOp(x, op) => chain ++= evalExprOntoStack(x) op match { - case UnaryOperator.Chr | UnaryOperator.Ord => // No op needed + case UnaryOperator.Chr => + chain += Move(EAX, stack.head) + chain += And(EAX, ImmediateVal(-128)) + chain += Compare(EAX, ImmediateVal(0)) + chain += Jump(LabelArg(BadChrError.errLabel), Cond.NotEqual) + case UnaryOperator.Ord => // No op needed case UnaryOperator.Len => chain += stack.pop(RAX) chain += Move(EAX, MemLocation(RAX, D32)) chain += stack.push(D32, RAX) case UnaryOperator.Negate => - chain += Negate(stack.head) + chain += Xor(EAX, EAX) + chain += Subtract(EAX, stack.head) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) + chain += stack.drop() + chain += stack.push(Q64, RAX) case UnaryOperator.Not => chain += Xor(stack.head, ImmediateVal(1)) } @@ -341,24 +348,29 @@ object asmGenerator { op match { case BinaryOperator.Add => chain += Add(stack.head, destX) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) case BinaryOperator.Sub => chain += Subtract(destX, stack.head) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mul => chain += Multiply(destX, stack.head) + chain += Jump(LabelArg(OverflowError.errLabel), Cond.Overflow) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Div => chain += Compare(stack.head, ImmediateVal(0)) - chain += Jump(LabelArg(zeroDivError.errLabel), Cond.Equal) + chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() chain += stack.push(destX.size, RAX) case BinaryOperator.Mod => + chain += Compare(stack.head, ImmediateVal(0)) + chain += Jump(LabelArg(ZeroDivError.errLabel), Cond.Equal) chain += CDQ() chain += Divide(stack.head) chain += stack.drop() @@ -444,7 +456,7 @@ object asmGenerator { chain } - private def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) + def stackAlign: AsmLine = And(Register(Q64, SP), ImmediateVal(-16)) private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match { case Q64 | D32 => Chain.empty case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1))) diff --git a/src/main/wacc/frontend/typeChecker.scala b/src/main/wacc/frontend/typeChecker.scala index 002876d..a628b69 100644 --- a/src/main/wacc/frontend/typeChecker.scala +++ b/src/main/wacc/frontend/typeChecker.scala @@ -180,8 +180,8 @@ object typeChecker { microWacc.Builtin.Read, List( destTy match { - case KnownType.Int => "%d".toMicroWaccCharArray - case KnownType.Char | _ => "%c".toMicroWaccCharArray + case KnownType.Int => " %d".toMicroWaccCharArray + case KnownType.Char | _ => " %c".toMicroWaccCharArray }, destTyped ) diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala index 7c895ab..6114afd 100644 --- a/src/test/wacc/examples.scala +++ b/src/test/wacc/examples.scala @@ -41,23 +41,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { val inputLine = contents .find(_.matches("^# ?[Ii]nput:.*$")) - .map(line => - ("" :: line.split(":").last.strip.split(" ").toList) - .sliding(2) - .flatMap { arr => - if ( - // First entry has no space in front - arr(0) == "" || - // int followed by non-digit, space can be removed - arr(0).toIntOption.nonEmpty && !arr(1)(0).isDigit || - // non-int followed by int, space can be removed - !arr(0).last.isDigit && arr(1).toIntOption.nonEmpty - ) - then List(arr(1)) - else List(" ", arr(1)) - } - .mkString - ) + .map(_.split(":").last.strip + "\n") .getOrElse("") val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$")) val expectedOutput = @@ -92,7 +76,13 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { ) assert(process.exitValue == expectedExit) - assert(stdout.toString.replaceAll("0x[0-9a-f]+", "#addrs#") == expectedOutput) + assert( + stdout.toString + .replaceAll("0x[0-9a-f]+", "#addrs#") + .replaceAll("fatal error:.*", "#runtime_error#\u0000") + .takeWhile(_ != '\u0000') + == expectedOutput + ) } } @@ -117,7 +107,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll { // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$", // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$", // "^.*wacc-examples/valid/pairs.*$", - "^.*wacc-examples/valid/runtimeErr.*$", + //"^.*wacc-examples/valid/runtimeErr.*$", // "^.*wacc-examples/valid/scope.*$", // "^.*wacc-examples/valid/sequence.*$", // "^.*wacc-examples/valid/variables.*$",