From 24dddcadabe6764dcedd1e0e501cf145b48ad8e3 Mon Sep 17 00:00:00 2001
From: Alex Ling <al4423@ic.ac.uk>
Date: Sat, 22 Feb 2025 21:38:12 +0000
Subject: [PATCH] feat: almost complete clib calls

---
 src/main/wacc/backend/asmGenerator.scala | 204 +++++++++++++++++++----
 src/main/wacc/backend/assemblyIR.scala   |  11 +-
 src/test/wacc/examples.scala             |  38 ++---
 3 files changed, 198 insertions(+), 55 deletions(-)

diff --git a/src/main/wacc/backend/asmGenerator.scala b/src/main/wacc/backend/asmGenerator.scala
index ab350f4..7964cf4 100644
--- a/src/main/wacc/backend/asmGenerator.scala
+++ b/src/main/wacc/backend/asmGenerator.scala
@@ -15,12 +15,15 @@ object asmGenerator {
 
     val progAsm =
       LabelDef("main") ::
+        funcPrologue() ++
+        alignStack() ++
         main.flatMap(generateStmt) ++
+        funcEpilogue() ++
         List(assemblyIR.Return()) ++
         generateFuncs()
 
     val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
-      List(Directive.Int(str.size), LabelDef(s".L.str$i:"), Directive.Asciz(str))
+      List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str))
     }
 
     List(Directive.IntelSyntax, Directive.Global("main"), Directive.RoData) ++
@@ -42,31 +45,47 @@ object asmGenerator {
   )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
     stmt match {
       case microWacc.Call(Builtin.Exit, code :: _) =>
-        alignStack() ++
-          evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++
+        // alignStack() ++
+        evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++
           List(assemblyIR.Call(CLibFunc.Exit))
 
       case microWacc.Call(Builtin.Println, expr :: _) =>
-        alignStack() ++
-          evalExprIntoReg(expr, Register(RegSize.R64, RegName.DI)) ++
-          List(
-            assemblyIR.Call(CLibFunc.Puts),
-            Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
-            assemblyIR.Call(CLibFunc.Fflush)
-          ) ++
-          restoreStack()
+        // alignStack() ++
+        printF(expr) ++
+          printLn()
 
-      case microWacc.Call(Builtin.ReadInt, expr :: _) =>
-        List()
+      case microWacc.Call(Builtin.Print, expr :: _) =>
+        // alignStack() ++
+        printF(expr)
 
       case Assign(lhs, rhs) =>
-        lhs match {
+        var dest: IndexAddress =
+          IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten
+        (lhs match {
           case ident: Ident =>
-            stack += (ident -> stack.size)
-            evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++
-              List(Push(Register(RegSize.R64, RegName.AX)))
-          case _ => List()
-        }
+            if (!stack.contains(ident)) {
+              stack += (ident -> (stack.size + 1))
+              dest = accessVar(ident)
+              List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16)))
+            } else {
+              dest = accessVar(ident)
+              List()
+            }
+          // TODO lhs = arrayElem
+          case _ =>
+            // dest = ???
+            List()
+        }) ++
+          (rhs match {
+            case microWacc.Call(Builtin.ReadInt, _) =>
+              readIntoVar(dest, Builtin.ReadInt)
+            case microWacc.Call(Builtin.ReadChar, _) =>
+              readIntoVar(dest, Builtin.ReadChar)
+            case _ =>
+              evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++
+                List(Move(dest, Register(RegSize.R64, RegName.AX)))
+          })
+      // TODO other statements
       case _ => List()
     }
 
@@ -74,22 +93,20 @@ object asmGenerator {
       stack: LinkedHashMap[Ident, Int],
       strings: ListBuffer[String]
   ): List[AsmLine] = {
-    var src: Src = ImmediateVal(0) // Placeholder
-    (expr match {
+    expr match {
       case IntLiter(v) =>
-        src = ImmediateVal(v)
-        List()
+        List(Move(dest, ImmediateVal(v)))
+      case CharLiter(v) =>
+        List(Move(dest, ImmediateVal(v.toInt)))
       case ident: Ident =>
-        List(
-          Move(
-            dest,
-            IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 4)
-          )
-        )
+        List(Move(dest, accessVar(ident)))
       case ArrayLiter(elems) =>
         expr.ty match {
-          case KnownType.Char =>
-            strings += elems.mkString
+          case KnownType.String =>
+            strings += elems.map {
+              case CharLiter(v) => v
+              case _            => ""
+            }.mkString
             List(
               Load(
                 dest,
@@ -99,22 +116,59 @@ object asmGenerator {
                 )
               )
             )
+          // TODO other array types
           case _ => List()
         }
+      // TODO other expr types
       case _ => List()
-    }) ++ List(Move(dest, src))
+    }
   }
 
+  // TODO make sure EOF doenst override the value in the stack
+  // probably need labels implemented for conditional jumps
+  def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
+      stack: LinkedHashMap[Ident, Int],
+      strings: ListBuffer[String]
+  ): List[AsmLine] = {
+    readType match {
+      case Builtin.ReadInt =>
+        strings += PrintFormat.Int.toString
+      case Builtin.ReadChar =>
+        strings += PrintFormat.Char.toString
+    }
+    List(
+      Load(
+        Register(RegSize.R64, RegName.DI),
+        IndexAddress(
+          Register(RegSize.R64, RegName.IP),
+          LabelArg(s".L.str${strings.size - 1}")
+        )
+      ),
+      Load(Register(RegSize.R64, RegName.SI), dest)
+    ) ++
+      // alignStack() ++
+      List(assemblyIR.Call(CLibFunc.Scanf))
+
+  }
+
+  def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
+    IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16)
+
   def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
     List(
-      And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)),
-      // Store stack pointer in rbp as it is callee saved
+      And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16))
+    )
+  }
+
+  // Missing a sub instruction but dont think we need it
+  def funcPrologue(): List[AsmLine] = {
+    List(
       Push(Register(RegSize.R64, RegName.BP)),
       Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP))
     )
   }
 
-  def restoreStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
+  def funcEpilogue(): List[AsmLine] = {
     List(
       Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)),
       Pop(Register(RegSize.R64, RegName.BP))
@@ -123,4 +177,84 @@ object asmGenerator {
 
   // def saveRegs(regList: List[Register]): List[AsmLine] = regList.map(Push(_))
   // def restoreRegs(regList: List[Register]): List[AsmLine] = regList.reverse.map(Pop(_))
+
+// TODO: refactor, really ugly function
+  def printF(expr: Expr)(using
+      stack: LinkedHashMap[Ident, Int],
+      strings: ListBuffer[String]
+  ): List[AsmLine] = {
+// determine the format string
+    expr.ty match {
+      case KnownType.String =>
+        strings += PrintFormat.String.toString
+      case KnownType.Char =>
+        strings += PrintFormat.Char.toString
+      case KnownType.Int =>
+        strings += PrintFormat.Int.toString
+      case _ =>
+        strings += PrintFormat.String.toString
+    }
+    List(
+      Load(
+        Register(RegSize.R64, RegName.DI),
+        IndexAddress(
+          Register(RegSize.R64, RegName.IP),
+          LabelArg(s".L.str${strings.size - 1}")
+        )
+      )
+    )
+      ++
+        // determine the actual value to print
+        (if (expr.ty == KnownType.Bool) {
+           expr match {
+             case BoolLiter(true) => {
+               strings += "true"
+             }
+             case _ => {
+               strings += "false"
+             }
+           }
+           List(
+             Load(
+               Register(RegSize.R64, RegName.DI),
+               IndexAddress(
+                 Register(RegSize.R64, RegName.IP),
+                 LabelArg(s".L.str${strings.size - 1}")
+               )
+             )
+           )
+
+         } else {
+           evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI))
+         })
+        // print the value
+        ++
+        List(
+          assemblyIR.Call(CLibFunc.PrintF),
+          Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
+          assemblyIR.Call(CLibFunc.Fflush)
+        )
+  }
+
+// prints a new line
+  def printLn()(using
+      stack: LinkedHashMap[Ident, Int],
+      strings: ListBuffer[String]
+  ): List[AsmLine] = {
+    strings += ""
+    Load(
+      Register(RegSize.R64, RegName.DI),
+      IndexAddress(
+        Register(RegSize.R64, RegName.IP),
+        LabelArg(s".L.str${strings.size - 1}")
+      )
+    )
+      ::
+        List(
+          assemblyIR.Call(CLibFunc.Puts),
+          Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)),
+          assemblyIR.Call(CLibFunc.Fflush)
+        )
+
+  }
 }
diff --git a/src/main/wacc/backend/assemblyIR.scala b/src/main/wacc/backend/assemblyIR.scala
index 73cdeaf..c48daac 100644
--- a/src/main/wacc/backend/assemblyIR.scala
+++ b/src/main/wacc/backend/assemblyIR.scala
@@ -143,8 +143,17 @@ object assemblyIR {
       case Text          => ".text"
       case RoData        => ".section .rodata"
       case Int(value)    => s".int $value"
-      case Asciz(string) => s".asciz $string"
+      case Asciz(string) => s".asciz \"$string\""
+    }
+  }
 
+  enum PrintFormat {
+    case Int, Char, String
+
+    override def toString(): String = this match {
+      case Int    => "%d"
+      case Char   => "%c"
+      case String => "%s"
     }
   }
 }
diff --git a/src/test/wacc/examples.scala b/src/test/wacc/examples.scala
index 970bcb6..abf6769 100644
--- a/src/test/wacc/examples.scala
+++ b/src/test/wacc/examples.scala
@@ -47,7 +47,7 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
             .drop(outputLineIdx + 1)
             .takeWhile(_.startsWith("#"))
             .map(_.stripPrefix("#").stripLeading)
-            .mkString("\n")
+            .mkString("")
 
       val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$"))
       val expectedExit =
@@ -79,24 +79,24 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
     Seq(
       // format: off
       // disable formatting to avoid binPack
-      "^.*wacc-examples/valid/advanced.*$",
-      "^.*wacc-examples/valid/array.*$",
-      "^.*wacc-examples/valid/basic/exit.*$",
-      "^.*wacc-examples/valid/basic/skip.*$",
-      "^.*wacc-examples/valid/expressions.*$",
-      "^.*wacc-examples/valid/function/nested_functions.*$",
-      "^.*wacc-examples/valid/function/simple_functions.*$",
-      "^.*wacc-examples/valid/if.*$",
-      "^.*wacc-examples/valid/IO/print.*$",
-      "^.*wacc-examples/valid/IO/read.*$",
-      "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
-      "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
-      "^.*wacc-examples/valid/pairs.*$",
-      "^.*wacc-examples/valid/runtimeErr.*$",
-      "^.*wacc-examples/valid/scope.*$",
-      "^.*wacc-examples/valid/sequence.*$",
-      "^.*wacc-examples/valid/variables.*$",
-      "^.*wacc-examples/valid/while.*$",
+      // "^.*wacc-examples/valid/advanced.*$",
+      // "^.*wacc-examples/valid/array.*$",
+      // "^.*wacc-examples/valid/basic/exit.*$",
+      // "^.*wacc-examples/valid/basic/skip.*$",
+      // "^.*wacc-examples/valid/expressions.*$",
+      // "^.*wacc-examples/valid/function/nested_functions.*$",
+      // "^.*wacc-examples/valid/function/simple_functions.*$",
+      // "^.*wacc-examples/valid/if.*$",
+      // "^.*wacc-examples/valid/IO/print.*$",
+      // "^.*wacc-examples/valid/IO/read.*$",
+      // "^.*wacc-examples/valid/IO/IOLoop.wacc.*$",
+      // "^.*wacc-examples/valid/IO/IOSequence.wacc.*$",
+      // "^.*wacc-examples/valid/pairs.*$",
+      // "^.*wacc-examples/valid/runtimeErr.*$",
+      // "^.*wacc-examples/valid/scope.*$",
+      // "^.*wacc-examples/valid/sequence.*$",
+      // "^.*wacc-examples/valid/variables.*$",
+      // "^.*wacc-examples/valid/while.*$",
       // format: on
     ).find(filename.matches).isDefined
 }