diff --git a/src/main/wacc/Main.scala b/src/main/wacc/Main.scala index 445d9c1..8221440 100644 --- a/src/main/wacc/Main.scala +++ b/src/main/wacc/Main.scala @@ -4,6 +4,10 @@ import scala.collection.mutable import parsley.{Failure, Success} import scopt.OParser import java.io.File +import java.io.PrintStream + +import assemblyIR as asm +import wacc.microWacc.IntLiter case class CliConfig( file: File = new File(".") @@ -30,7 +34,9 @@ val cliParser = { ) } -def compile(contents: String): Int = { +def frontend( + contents: String +)(using stdout: PrintStream): Either[microWacc.Program, Int] = { parser.parse(contents) match { case Success(prog) => given errors: mutable.Builder[Error, List[Error]] = List.newBuilder @@ -39,28 +45,194 @@ def compile(contents: String): Int = { val typedProg = typeChecker.check(prog) if (errors.result.nonEmpty) { given errorContent: String = contents - errors.result - .map { error => - printError(error) - error match { - case _: Error.InternalError => 201 - case _ => 200 + Right( + errors.result + .map { error => + printError(error) + error match { + case _: Error.InternalError => 201 + case _ => 200 + } } - } - .max() - } else { - println(typedProg) - 0 - } + .max() + ) + } else Left(typedProg) case Failure(msg) => - println(msg) - 100 + stdout.println(msg) + Right(100) } } +val s = "enter an integer to echo" +def backend(typedProg: microWacc.Program): List[asm.AsmLine] | String = + typedProg match { + case microWacc.Program( + Nil, + microWacc.Call(microWacc.Builtin.Exit, microWacc.IntLiter(v) :: Nil) :: Nil + ) => + s""".intel_syntax noprefix +.globl main +main: + mov edi, ${v} + call exit@plt +""" + case microWacc.Program( + Nil, + microWacc.Assign(microWacc.Ident("x", _), microWacc.IntLiter(1)) :: + microWacc.Call(microWacc.Builtin.Println, _) :: + microWacc.Assign( + microWacc.Ident("x", _), + microWacc.Call(microWacc.Builtin.ReadInt, Nil) + ) :: + microWacc.Call(microWacc.Builtin.Println, microWacc.Ident("x", _) :: Nil) :: Nil + ) => + """.intel_syntax noprefix +.globl main +.section .rodata +# length of .L.str0 + .int 24 +.L.str0: + .asciz "enter an integer to echo" +.text +main: + push rbp + # push {rbx, r12} + sub rsp, 16 + mov qword ptr [rsp], rbx + mov qword ptr [rsp + 8], r12 + mov rbp, rsp + mov r12d, 1 + lea rdi, [rip + .L.str0] + # statement primitives do not return results (but will clobber r0/rax) + call _prints + call _println + # load the current value in the destination of the read so it supports defaults + mov edi, r12d + call _readi + mov r12d, eax + mov edi, eax + # statement primitives do not return results (but will clobber r0/rax) + call _printi + call _println + mov rax, 0 + # pop/peek {rbx, r12} + mov rbx, qword ptr [rsp] + mov r12, qword ptr [rsp + 8] + add rsp, 16 + pop rbp + ret + +.section .rodata +# length of .L._printi_str0 + .int 2 +.L._printi_str0: + .asciz "%d" +.text +_printi: + push rbp + mov rbp, rsp + # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 + and rsp, -16 + mov esi, edi + lea rdi, [rip + .L._printi_str0] + # on x86, al represents the number of SIMD registers used as variadic arguments + mov al, 0 + call printf@plt + mov rdi, 0 + call fflush@plt + mov rsp, rbp + pop rbp + ret + +.section .rodata +# length of .L._prints_str0 + .int 4 +.L._prints_str0: + .asciz "%.*s" +.text +_prints: + push rbp + mov rbp, rsp + # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 + and rsp, -16 + mov rdx, rdi + mov esi, dword ptr [rdi - 4] + lea rdi, [rip + .L._prints_str0] + # on x86, al represents the number of SIMD registers used as variadic arguments + mov al, 0 + call printf@plt + mov rdi, 0 + call fflush@plt + mov rsp, rbp + pop rbp + ret + +.section .rodata +# length of .L._println_str0 + .int 0 +.L._println_str0: + .asciz "" +.text +_println: + push rbp + mov rbp, rsp + # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 + and rsp, -16 + lea rdi, [rip + .L._println_str0] + call puts@plt + mov rdi, 0 + call fflush@plt + mov rsp, rbp + pop rbp + ret + +.section .rodata +# length of .L._readi_str0 + .int 2 +.L._readi_str0: + .asciz "%d" +.text +_readi: + push rbp + mov rbp, rsp + # external calls must be stack-aligned to 16 bytes, accomplished by masking with fffffffffffffff0 + and rsp, -16 + # RDI contains the "original" value of the destination of the read + # allocate space on the stack to store the read: preserve alignment! + # the passed default argument should be stored in case of EOF + sub rsp, 16 + mov dword ptr [rsp], edi + lea rsi, qword ptr [rsp] + lea rdi, [rip + .L._readi_str0] + # on x86, al represents the number of SIMD registers used as variadic arguments + mov al, 0 + call scanf@plt + mov eax, dword ptr [rsp] + add rsp, 16 + mov rsp, rbp + pop rbp + ret + """ + case _ => List() + } + +def compile(filename: String)(using stdout: PrintStream = Console.out): Int = + frontend(os.read(os.Path(filename))) match { + case Left(typedProg) => + backend(typedProg) match { + case s: String => + os.write.over(os.Path(filename.stripSuffix(".wacc") + ".s"), s) + case ops: List[asm.AsmLine] => { + val outFile = File(filename.stripSuffix(".wacc") + ".s") + writer.writeTo(ops, PrintStream(outFile)) + } + } + 0 + case Right(exitCode) => exitCode + } + def main(args: Array[String]): Unit = OParser.parse(cliParser, args, CliConfig()) match { - case Some(config) => - System.exit(compile(os.read(os.Path(config.file.getAbsolutePath)))) - case None => + case Some(config) => compile(config.file.getAbsolutePath) + case None => } diff --git a/src/main/wacc/frontend/Error.scala b/src/main/wacc/frontend/Error.scala index 0f3c01d..9c02a60 100644 --- a/src/main/wacc/frontend/Error.scala +++ b/src/main/wacc/frontend/Error.scala @@ -2,6 +2,7 @@ package wacc import wacc.ast.Position import wacc.types._ +import java.io.PrintStream /** Error types for semantic errors */ @@ -23,39 +24,39 @@ enum Error { * @param errorContent * Contents of the file to generate code snippets */ -def printError(error: Error)(using errorContent: String): Unit = { - println("Semantic error:") +def printError(error: Error)(using errorContent: String, stdout: PrintStream): Unit = { + stdout.println("Semantic error:") error match { case Error.DuplicateDeclaration(ident) => printPosition(ident.pos) - println(s"Duplicate declaration of identifier ${ident.v}") + stdout.println(s"Duplicate declaration of identifier ${ident.v}") highlight(ident.pos, ident.v.length) case Error.UndeclaredVariable(ident) => printPosition(ident.pos) - println(s"Undeclared variable ${ident.v}") + stdout.println(s"Undeclared variable ${ident.v}") highlight(ident.pos, ident.v.length) case Error.UndefinedFunction(ident) => printPosition(ident.pos) - println(s"Undefined function ${ident.v}") + stdout.println(s"Undefined function ${ident.v}") highlight(ident.pos, ident.v.length) case Error.FunctionParamsMismatch(id, expected, got, funcType) => printPosition(id.pos) - println(s"Function expects $expected parameters, got $got") - println( + stdout.println(s"Function expects $expected parameters, got $got") + stdout.println( s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})" ) highlight(id.pos, 1) case Error.TypeMismatch(pos, expected, got, msg) => printPosition(pos) - println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got") + stdout.println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got") highlight(pos, 1) case Error.SemanticError(pos, msg) => printPosition(pos) - println(msg) + stdout.println(msg) highlight(pos, 1) case wacc.Error.InternalError(pos, msg) => printPosition(pos) - println(s"Internal error: $msg") + stdout.println(s"Internal error: $msg") highlight(pos, 1) } @@ -70,7 +71,7 @@ def printError(error: Error)(using errorContent: String): Unit = { * @param errorContent * Contents of the file to generate code snippets */ -def highlight(pos: Position, size: Int)(using errorContent: String): Unit = { +def highlight(pos: Position, size: Int)(using errorContent: String, stdout: PrintStream): Unit = { val lines = errorContent.split("\n") val preLine = if (pos.line > 1) lines(pos.line - 2) else "" @@ -78,7 +79,7 @@ def highlight(pos: Position, size: Int)(using errorContent: String): Unit = { val postLine = if (pos.line < lines.size) lines(pos.line) else "" val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n" - println( + stdout.println( s" >$preLine\n >$midLine\n$linePointer >$postLine" ) } @@ -88,6 +89,6 @@ def highlight(pos: Position, size: Int)(using errorContent: String): Unit = { * @param pos * Position of the error */ -def printPosition(pos: Position): Unit = { - println(s"(line ${pos.line}, column ${pos.column}):") +def printPosition(pos: Position)(using stdout: PrintStream): Unit = { + stdout.println(s"(line ${pos.line}, column ${pos.column}):") }