feat: implements binary operators in asmGenerator

Co-authored-by: Gleb Koval <gleb@koval.net>
Co-authored-by: Barf-Vader <Barf-Vader@users.noreply.github.com>
This commit is contained in:
Guy C 2025-02-25 00:00:12 +00:00
parent 668d7338ae
commit 1488281223
4 changed files with 285 additions and 109 deletions

View File

@ -2,18 +2,36 @@ package wacc
import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import cats.syntax.align
object labelGenerator { object asmGenerator {
import microWacc._
import assemblyIR._
import wacc.types._
val RAX = Register(RegSize.R64, RegName.AX)
val EAX = Register(RegSize.E32, RegName.AX)
val RSP = Register(RegSize.R64, RegName.SP)
val ESP = Register(RegSize.E32, RegName.SP)
val EDX = Register(RegSize.E32, RegName.DX)
val RDI = Register(RegSize.R64, RegName.DI)
val RIP = Register(RegSize.R64, RegName.IP)
val RBP = Register(RegSize.R64, RegName.BP)
val RSI = Register(RegSize.R64, RegName.SI)
val _8_BIT_MASK = 0xFF
object labelGenerator {
var labelVal = -1 var labelVal = -1
def getLabel(): String = { def getLabel(): String = {
labelVal += 1 labelVal += 1
s".L$labelVal" s".L$labelVal"
} }
} def getLabel(target: CallTarget): String = target match{
object asmGenerator { case Ident(v,_) => s"wacc_$v"
import microWacc._ case Builtin(name) => s"_$name"
import assemblyIR._ }
import wacc.types._ }
def generateAsm(microProg: Program): List[AsmLine] = { def generateAsm(microProg: Program): List[AsmLine] = {
given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]() given stack: LinkedHashMap[Ident, Int] = LinkedHashMap[Ident, Int]()
@ -25,10 +43,8 @@ object asmGenerator {
funcPrologue() ++ funcPrologue() ++
alignStack() ++ alignStack() ++
main.flatMap(generateStmt) ++ main.flatMap(generateStmt) ++
List(Move(Register(RegSize.R64, RegName.AX), ImmediateVal(0))) ++ List(Move(RAX, ImmediateVal(0))) ++
funcEpilogue() ++ funcEpilogue()
List(assemblyIR.Return()) ++
generateFuncs()
val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) => val strDirs = strings.toList.zipWithIndex.flatMap { case (str, i) =>
List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str)) List(Directive.Int(str.size), LabelDef(s".L.str$i"), Directive.Asciz(str))
@ -40,12 +56,38 @@ object asmGenerator {
progAsm progAsm
} }
//TODO def wrapFunc(labelName: String, funcBody: List[AsmLine])(using
def generateFuncs()(using
stack: LinkedHashMap[Ident, Int], stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String] strings: ListBuffer[String]
): List[AsmLine] = { ): List[AsmLine] = {
LabelDef(labelName) ::
funcPrologue() ++
funcBody ++
funcEpilogue()
}
def generateBuiltInFuncs()(using
stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String]
): List[AsmLine] = {
wrapFunc(labelGenerator.getLabel(Builtin.Exit),
alignStack() ++
List(Pop(RDI),
assemblyIR.Call(CLibFunc.Exit))
) ++
wrapFunc(labelGenerator.getLabel(Builtin.Printf),
alignStack() ++
List(assemblyIR.Call(CLibFunc.PrintF),
Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush)
)
)++
wrapFunc(labelGenerator.getLabel(Builtin.Malloc),
List() List()
)++
wrapFunc(labelGenerator.getLabel(Builtin.Free),
List()
)
} }
def generateStmt( def generateStmt(
@ -53,28 +95,16 @@ object asmGenerator {
)(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] = )(using stack: LinkedHashMap[Ident, Int], strings: ListBuffer[String]): List[AsmLine] =
stmt match { stmt match {
case microWacc.Call(Builtin.Exit, code :: _) => case microWacc.Call(Builtin.Exit, code :: _) =>
// alignStack() ++ List()
evalExprIntoReg(code, Register(RegSize.R64, RegName.DI)) ++
List(assemblyIR.Call(CLibFunc.Exit))
case microWacc.Call(Builtin.Println, expr :: _) =>
// alignStack() ++
printF(expr) ++
printLn()
case microWacc.Call(Builtin.Print, expr :: _) =>
// alignStack() ++
printF(expr)
case Assign(lhs, rhs) => case Assign(lhs, rhs) =>
var dest: IndexAddress = var dest: IndexAddress =
IndexAddress(Register(RegSize.R64, RegName.SP), 0) // gets overrwitten IndexAddress(RSP, 0) // gets overrwitten
(lhs match { (lhs match {
case ident: Ident => case ident: Ident =>
if (!stack.contains(ident)) { if (!stack.contains(ident)) {
stack += (ident -> (stack.size + 1)) stack += (ident -> (stack.size + 1))
dest = accessVar(ident) dest = accessVar(ident)
List(Subtract(Register(RegSize.R64, RegName.SP), ImmediateVal(16))) List(Subtract(RSP, ImmediateVal(8)))
} else { } else {
dest = accessVar(ident) dest = accessVar(ident)
List() List()
@ -90,15 +120,16 @@ object asmGenerator {
case microWacc.Call(Builtin.ReadChar, _) => case microWacc.Call(Builtin.ReadChar, _) =>
readIntoVar(dest, Builtin.ReadChar) readIntoVar(dest, Builtin.ReadChar)
case _ => case _ =>
evalExprIntoReg(rhs, Register(RegSize.R64, RegName.AX)) ++ evalExprOntoStack(rhs) ++
List(Move(dest, Register(RegSize.R64, RegName.AX))) List(Pop(dest))
}) })
case If(cond, thenBranch, elseBranch) => { case If(cond, thenBranch, elseBranch) => {
val elseLabel = labelGenerator.getLabel() val elseLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ evalExprOntoStack(cond) ++
List( List(
Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Jump(LabelArg(elseLabel), Cond.Equal) Jump(LabelArg(elseLabel), Cond.Equal)
) ++ ) ++
thenBranch.flatMap(generateStmt) ++ thenBranch.flatMap(generateStmt) ++
@ -110,30 +141,33 @@ object asmGenerator {
val startLabel = labelGenerator.getLabel() val startLabel = labelGenerator.getLabel()
val endLabel = labelGenerator.getLabel() val endLabel = labelGenerator.getLabel()
List(LabelDef(startLabel)) ++ List(LabelDef(startLabel)) ++
evalExprIntoReg(cond, Register(RegSize.R64, RegName.AX)) ++ evalExprOntoStack(cond) ++
List( List(
Compare(Register(RegSize.R64, RegName.AX), ImmediateVal(0)), Compare(MemLocation(RSP, SizeDir.Word), ImmediateVal(0)),
Add(RSP, ImmediateVal(8)),
Jump(LabelArg(endLabel), Cond.Equal) Jump(LabelArg(endLabel), Cond.Equal)
) ++ ) ++
body.flatMap(generateStmt) ++ body.flatMap(generateStmt) ++
List(Jump(LabelArg(startLabel)), LabelDef(endLabel)) List(Jump(LabelArg(startLabel)), LabelDef(endLabel))
} }
case microWacc.Return(expr) => case microWacc.Return(expr) =>
evalExprIntoReg(expr, Register(RegSize.R64, RegName.AX)) evalExprOntoStack(expr) ++
List(Pop(RAX), assemblyIR.Return())
case _ => List() case _ => List()
} }
def evalExprIntoReg(expr: Expr, dest: Register)(using def evalExprOntoStack(expr: Expr)(using
stack: LinkedHashMap[Ident, Int], stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String] strings: ListBuffer[String]
): List[AsmLine] = { ): List[AsmLine] = {
expr match { expr match {
case IntLiter(v) => case IntLiter(v) =>
List(Move(dest, ImmediateVal(v))) List(Push(ImmediateVal(v)))
case CharLiter(v) => case CharLiter(v) =>
List(Move(dest, ImmediateVal(v.toInt))) List(Push(ImmediateVal(v.toInt)))
case ident: Ident => case ident: Ident =>
List(Move(dest, accessVar(ident))) List(Push(accessVar(ident)))
case ArrayLiter(elems) => case ArrayLiter(elems) =>
expr.ty match { expr.ty match {
case KnownType.String => case KnownType.String =>
@ -143,70 +177,173 @@ object asmGenerator {
}.mkString }.mkString
List( List(
Load( Load(
dest, RAX,
IndexAddress( IndexAddress(
Register(RegSize.R64, RegName.IP), RIP,
LabelArg(s".L.str${strings.size - 1}") LabelArg(s".L.str${strings.size - 1}")
) )
) ),
Push(RAX)
) )
// TODO other array types // TODO other array types
case _ => List() case _ => List()
} }
// TODO other expr types case BoolLiter(v) => List(Push(ImmediateVal(if (v) 1 else 0)))
case BoolLiter(v) => List(Move(dest, ImmediateVal(if (v) 1 else 0))) case NullLiter() => List(Push(ImmediateVal(0)))
case _ => List() case ArrayElem(value, indices) => List()
case UnaryOp(x, op) => op match {
// TODO: chr and ord are TYPE CASTS. They do not change the internal value,
// but will need bound checking e.t.c.
case UnaryOperator.Chr => List()
case UnaryOperator.Ord => List()
case UnaryOperator.Len => List()
case UnaryOperator.Negate => List(
Negate(MemLocation(RSP,SizeDir.Word))
)
case UnaryOperator.Not =>
evalExprOntoStack(x) ++
List(
Xor(MemLocation(RSP, SizeDir.Word), ImmediateVal(1))
)
}
case BinaryOp(x, y, op) => op match {
case BinaryOperator.Add =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Add(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Sub =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Subtract(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Mul =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Multiply(MemLocation(RSP, SizeDir.Word), EAX)
// TODO OVERFLOWING
)
case BinaryOperator.Div =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(EAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(EAX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Mod =>
evalExprOntoStack(y) ++
evalExprOntoStack(x) ++
List(
Pop(EAX),
Divide(MemLocation(RSP, SizeDir.Word)),
Add(RSP, ImmediateVal(8)),
Push(EDX)
// TODO CHECK DIVISOR IS NOT 0
)
case BinaryOperator.Eq =>
generateComparison(x, y, Cond.Equal)
case BinaryOperator.Neq =>
generateComparison(x, y, Cond.NotEqual)
case BinaryOperator.Greater =>
generateComparison(x, y, Cond.Greater)
case BinaryOperator.GreaterEq =>
generateComparison(x, y, Cond.GreaterEqual)
case BinaryOperator.Less =>
generateComparison(x, y, Cond.Less)
case BinaryOperator.LessEq =>
generateComparison(x, y, Cond.LessEqual)
case BinaryOperator.And =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
And(MemLocation(RSP, SizeDir.Word), EAX),
)
case BinaryOperator.Or =>
evalExprOntoStack(x) ++
evalExprOntoStack(y) ++
List(
Pop(EAX),
Or(MemLocation(RSP, SizeDir.Word), EAX),
)
}
case microWacc.Call(target, args) => List()
} }
} }
// TODO make sure EOF doenst override the value in the stack // def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using
// probably need labels implemented for conditional jumps // stack: LinkedHashMap[Ident, Int],
def readIntoVar(dest: IndexAddress, readType: Builtin.ReadInt.type | Builtin.ReadChar.type)(using // strings: ListBuffer[String]
// ): List[AsmLine] = {
// readType match {
// case Builtin.ReadInt =>
// strings += PrintFormat.Int.toString
// case Builtin.ReadChar =>
// strings += PrintFormat.Char.toString
// }
// List(
// Load(
// RDI,
// IndexAddress(
// RIP,
// LabelArg(s".L.str${strings.size - 1}")
// )
// ),
// Load(RSI, dest)
// ) ++
// // alignStack() ++
// List(assemblyIR.Call(CLibFunc.Scanf))
// }
def generateComparison(x : Expr, y: Expr, cond: Cond)(using
stack: LinkedHashMap[Ident, Int], stack: LinkedHashMap[Ident, Int],
strings: ListBuffer[String] strings: ListBuffer[String]
): List[AsmLine] = { ): List[AsmLine] = {
readType match { evalExprOntoStack(x) ++
case Builtin.ReadInt => evalExprOntoStack(y) ++
strings += PrintFormat.Int.toString
case Builtin.ReadChar =>
strings += PrintFormat.Char.toString
}
List( List(
Load( Pop(EAX),
Register(RegSize.R64, RegName.DI), Compare(MemLocation(RSP, SizeDir.Word), EAX),
IndexAddress( Set(Register(RegSize.Byte, RegName.AL), cond),
Register(RegSize.R64, RegName.IP), And(EAX, ImmediateVal(_8_BIT_MASK)),
LabelArg(s".L.str${strings.size - 1}") Push(EAX)
) )
),
Load(Register(RegSize.R64, RegName.SI), dest)
) ++
// alignStack() ++
List(assemblyIR.Call(CLibFunc.Scanf))
} }
def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress = def accessVar(ident: Ident)(using stack: LinkedHashMap[Ident, Int]): IndexAddress =
IndexAddress(Register(RegSize.R64, RegName.SP), (stack.size - stack(ident)) * 16) IndexAddress(RSP, (stack.size - stack(ident)) * 8)
def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = { def alignStack()(using stack: LinkedHashMap[Ident, Int]): List[AsmLine] = {
List( List(
And(Register(RegSize.R64, RegName.SP), ImmediateVal(-16)) And(RSP, ImmediateVal(-16))
) )
} }
// Missing a sub instruction but dont think we need it // Missing a sub instruction but dont think we need it
def funcPrologue(): List[AsmLine] = { def funcPrologue(): List[AsmLine] = {
List( List(
Push(Register(RegSize.R64, RegName.BP)), Push(RBP),
Move(Register(RegSize.R64, RegName.BP), Register(RegSize.R64, RegName.SP)) Move(RBP, RSP)
) )
} }
def funcEpilogue(): List[AsmLine] = { def funcEpilogue(): List[AsmLine] = {
List( List(
Move(Register(RegSize.R64, RegName.SP), Register(RegSize.R64, RegName.BP)), Move(RSP, RBP),
Pop(Register(RegSize.R64, RegName.BP)) Pop(RBP),
assemblyIR.Return()
) )
} }
@ -231,9 +368,9 @@ object asmGenerator {
} }
List( List(
Load( Load(
Register(RegSize.R64, RegName.DI), RDI,
IndexAddress( IndexAddress(
Register(RegSize.R64, RegName.IP), RIP,
LabelArg(s".L.str${strings.size - 1}") LabelArg(s".L.str${strings.size - 1}")
) )
) )
@ -251,22 +388,23 @@ object asmGenerator {
} }
List( List(
Load( Load(
Register(RegSize.R64, RegName.DI), RDI,
IndexAddress( IndexAddress(
Register(RegSize.R64, RegName.IP), RIP,
LabelArg(s".L.str${strings.size - 1}") LabelArg(s".L.str${strings.size - 1}")
) )
) )
) )
} else { } else {
evalExprIntoReg(expr, Register(RegSize.R64, RegName.SI)) evalExprOntoStack(expr) ++
List(Pop(RSI))
}) })
// print the value // print the value
++ ++
List( List(
assemblyIR.Call(CLibFunc.PrintF), assemblyIR.Call(CLibFunc.PrintF),
Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
) )
} }
@ -278,16 +416,16 @@ object asmGenerator {
): List[AsmLine] = { ): List[AsmLine] = {
strings += "" strings += ""
Load( Load(
Register(RegSize.R64, RegName.DI), RDI,
IndexAddress( IndexAddress(
Register(RegSize.R64, RegName.IP), RIP,
LabelArg(s".L.str${strings.size - 1}") LabelArg(s".L.str${strings.size - 1}")
) )
) )
:: ::
List( List(
assemblyIR.Call(CLibFunc.Puts), assemblyIR.Call(CLibFunc.Puts),
Move(Register(RegSize.R64, RegName.DI), ImmediateVal(0)), Move(RDI, ImmediateVal(0)),
assemblyIR.Call(CLibFunc.Fflush) assemblyIR.Call(CLibFunc.Fflush)
) )

View File

@ -9,17 +9,20 @@ object assemblyIR {
enum RegSize { enum RegSize {
case R64 case R64
case E32 case E32
case Byte
override def toString = this match { override def toString = this match {
case R64 => "r" case R64 => "r"
case E32 => "e" case E32 => "e"
case Byte => ""
} }
} }
enum RegName { enum RegName {
case AX, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15 case AX, AL, BX, CX, DX, SI, DI, SP, BP, IP, Reg8, Reg9, Reg10, Reg11, Reg12, Reg13, Reg14, Reg15
override def toString = this match { override def toString = this match {
case AX => "ax" case AX => "ax"
case AL => "al"
case BX => "bx" case BX => "bx"
case CX => "cx" case CX => "cx"
case DX => "dx" case DX => "dx"
@ -42,7 +45,6 @@ object assemblyIR {
// arguments // arguments
enum CLibFunc extends Operand { enum CLibFunc extends Operand {
case Scanf, case Scanf,
Puts,
Fflush, Fflush,
Exit, Exit,
PrintF PrintF
@ -51,24 +53,24 @@ object assemblyIR {
override def toString = this match { override def toString = this match {
case Scanf => "scanf" + plt case Scanf => "scanf" + plt
case Puts => "puts" + plt
case Fflush => "fflush" + plt case Fflush => "fflush" + plt
case Exit => "exit" + plt case Exit => "exit" + plt
case PrintF => "printf" + plt case PrintF => "printf" + plt
} }
} }
//TODO register naming conventions are wrong
case class Register(size: RegSize, name: RegName) extends Dest with Src { case class Register(size: RegSize, name: RegName) extends Dest with Src {
override def toString = s"${size}${name}" override def toString = s"${size}${name}"
} }
case class MemLocation(pointer: Long | Register) extends Dest with Src { case class MemLocation(pointer: Long | Register, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src {
override def toString = pointer match { override def toString = pointer match {
case hex: Long => f"[0x$hex%X]" case hex: Long => opSize.toString + f"[0x$hex%X]"
case reg: Register => s"[$reg]" case reg: Register => opSize.toString + s"[$reg]"
} }
} }
case class IndexAddress(base: Register, offset: Int | LabelArg) extends Dest with Src { case class IndexAddress(base: Register, offset: Int | LabelArg, opSize: SizeDir = SizeDir.Unspecified) extends Dest with Src {
override def toString = s"[$base + $offset]" override def toString = s"$opSize[$base + $offset]"
} }
case class ImmediateVal(value: Int) extends Src { case class ImmediateVal(value: Int) extends Src {
@ -85,11 +87,13 @@ object assemblyIR {
} }
case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2) case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2)
case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2) case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2)
case class Multiply(ops: Operand*) extends Operation("mul", ops*) case class Multiply(ops: Operand*) extends Operation("imul", ops*)
case class Divide(op1: Src) extends Operation("div", op1) case class Divide(op1: Src) extends Operation("idiv", op1)
case class Negate(op: Dest) extends Operation("neg", op)
case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2) case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2)
case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2) case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2)
case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2)
case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2) case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2)
// stack operations // stack operations
@ -106,6 +110,9 @@ object assemblyIR {
case class Jump(op1: LabelArg, condition: Cond = Cond.Always) case class Jump(op1: LabelArg, condition: Cond = Cond.Always)
extends Operation(s"j${condition.toString}", op1) extends Operation(s"j${condition.toString}", op1)
case class Set(op1: Dest, condition: Cond = Cond.Always)
extends Operation(s"set${condition.toString}", op1)
case class LabelDef(name: String) extends AsmLine { case class LabelDef(name: String) extends AsmLine {
override def toString = s"$name:" override def toString = s"$name:"
} }
@ -156,4 +163,16 @@ object assemblyIR {
case String => "%s" case String => "%s"
} }
} }
enum SizeDir {
case Byte, Word, Unspecified
private val ptr = "ptr "
override def toString(): String = this match {
case Byte => "byte " + ptr
case Word => "word " + ptr
case Unspecified => ""
}
}
} }

View File

@ -69,13 +69,15 @@ object microWacc {
// Statements // Statements
sealed trait Stmt sealed trait Stmt
case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) {
override def toString(): String = name
}
object Builtin { object Builtin {
case object ReadInt extends CallTarget(KnownType.Int) object Read extends Builtin("read")(?)
case object ReadChar extends CallTarget(KnownType.Char) object Printf extends Builtin("printf")(?)
case object Print extends CallTarget(?) object Exit extends Builtin("exit")(?)
case object Println extends CallTarget(?) object Free extends Builtin("free")(?)
case object Exit extends CallTarget(?) object Malloc extends Builtin("malloc")(?)
case object Free extends CallTarget(?)
} }
case class Assign(lhs: LValue, rhs: Expr) extends Stmt case class Assign(lhs: LValue, rhs: Expr) extends Stmt

View File

@ -177,12 +177,14 @@ object typeChecker {
microWacc.Assign( microWacc.Assign(
destTyped, destTyped,
microWacc.Call( microWacc.Call(
microWacc.Builtin.Read,
List(
destTy match { destTy match {
case KnownType.Int => microWacc.Builtin.ReadInt case KnownType.Int => "%d".toMicroWaccCharArray
case KnownType.Char => microWacc.Builtin.ReadChar case KnownType.Char | _ => "%c".toMicroWaccCharArray
case _ => microWacc.Builtin.ReadInt // we'll stop due to error anyway
}, },
Nil destTyped
)
) )
) )
) )
@ -213,10 +215,20 @@ object typeChecker {
) )
case ast.Print(expr, newline) => case ast.Print(expr, newline) =>
// This constraint should never fail, the scope-checker should have caught it already // This constraint should never fail, the scope-checker should have caught it already
val exprTyped = checkValue(expr, Constraint.Unconstrained)
val format = exprTyped.ty match {
case KnownType.Bool | KnownType.String => "%s"
case KnownType.Char => "%c"
case KnownType.Int => "%d"
case _ => "%p"
}
List( List(
microWacc.Call( microWacc.Call(
if newline then microWacc.Builtin.Println else microWacc.Builtin.Print, microWacc.Builtin.Printf,
List(checkValue(expr, Constraint.Unconstrained)) List(
s"$format${if newline then "\n" else ""}".toMicroWaccCharArray,
exprTyped
)
) )
) )
case ast.If(cond, thenStmt, elseStmt) => case ast.If(cond, thenStmt, elseStmt) =>
@ -262,7 +274,7 @@ object typeChecker {
microWacc.CharLiter(v) microWacc.CharLiter(v)
case l @ ast.StrLiter(v) => case l @ ast.StrLiter(v) =>
KnownType.String.satisfies(constraint, l.pos) KnownType.String.satisfies(constraint, l.pos)
microWacc.ArrayLiter(v.map(microWacc.CharLiter(_)).toList)(KnownType.String) v.toMicroWaccCharArray
case l @ ast.PairLiter() => case l @ ast.PairLiter() =>
microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos)) microWacc.NullLiter()(KnownType.Pair(?, ?).satisfies(constraint, l.pos))
case ast.Parens(expr) => checkValue(expr, constraint) case ast.Parens(expr) => checkValue(expr, constraint)
@ -441,4 +453,9 @@ object typeChecker {
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair")) case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
}) })
} }
extension (s: String) {
def toMicroWaccCharArray: microWacc.ArrayLiter =
microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String)
}
} }