Compare commits
201 Commits
WACC_Front
...
stdlib
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
421a7a3b32 | ||
|
|
ab895940bf | ||
|
9ba0e14eb5
|
|||
|
ab0c91463a
|
|||
| fda4e17327 | |||
|
bb1b6a3b23
|
|||
|
07afc2d59f
|
|||
|
8dd23f9e5c
|
|||
|
8b6e959d11
|
|||
|
|
df7a287801 | ||
|
|
fa399e7721 | ||
|
|
cf495e9d7f | ||
|
|
4d8064dc61 | ||
|
|
fde34e88b2 | ||
| 28ee7a2a32 | |||
|
68435207fe
|
|||
|
8f7c902ed5
|
|||
|
07f02e61d7
|
|||
|
|
af514b3363 | ||
|
|
447f29ce4c | ||
| 0368daef00 | |||
|
084081de7e
|
|||
|
46f526c680
|
|||
|
53d47fda63
|
|||
|
|
6ad1a9059d | ||
|
|
5778b3145d | ||
|
|
051ef02011 | ||
|
|
42515abf2a | ||
|
|
d44eb24086 | ||
| 191c5df824 | |||
|
|
68211fd877 | ||
|
a3895dca2c
|
|||
|
6e592e7d9b
|
|||
|
ee54a1201c
|
|||
|
c73b073f23
|
|||
|
8d8df3357d
|
|||
|
00df2dc546
|
|||
|
67e85688b2
|
|||
|
0497dd34a0
|
|||
|
6904aa37e4
|
|||
|
5141a2369f
|
|||
|
3fff9d3825
|
|||
|
f11fb9f881
|
|||
|
e881b736f8
|
|||
| 905a5e5b61 | |||
|
0d8be53ae4
|
|||
|
|
36ddd025b2 | ||
|
|
bad6e47e46 | ||
|
96ba81e24a
|
|||
|
|
54d6e7143b | ||
|
|
c2259334c1 | ||
|
|
94ee489faf | ||
|
|
f24aecffa3 | ||
| f896cbb0dd | |||
|
|
19e7ce4c11 | ||
|
|
473189342b | ||
|
|
f66f1ab3ac | ||
|
|
abb43b560d | ||
|
|
9a5ccea1f6 | ||
|
|
85a82aabb4 | ||
|
1b6d81dfca
|
|||
|
ae52fa653c
|
|||
|
|
01b38b1445 | ||
|
|
667fbf4949 | ||
|
|
d214723f35 | ||
|
3b723392a7
|
|||
| 9a0d3e38a4 | |||
|
37812fb5a7
|
|||
|
578a28a222
|
|||
|
|
d56be9249a | ||
|
0eaf2186b6
|
|||
|
5ae65d3190
|
|||
|
68903f5b69
|
|||
|
e1d90eabf9
|
|||
|
1a39950a7b
|
|||
|
|
1a72decf55 | ||
|
|
e54e5ce151 | ||
| 61643a49eb | |||
|
82997a5a38
|
|||
|
|
cf1028454d | ||
|
|
345c652a57 | ||
|
720d9320e2
|
|||
|
c3f2ce8b19
|
|||
|
|
cf72c5250d | ||
|
7627ec14d2
|
|||
|
|
d0a71c1888 | ||
|
fb5799dbfd
|
|||
|
|
621849dfa4 | ||
|
967a6fe58b
|
|||
|
|
302099ab76 | ||
|
|
30f4309fda | ||
|
|
b733d233b0 | ||
|
|
f2a1eaf24c | ||
| 8b3e9b8380 | |||
|
41f76e50e0
|
|||
|
|
88f89ce761 | ||
|
|
cdf32d93c3 | ||
|
|
edcac2782b | ||
|
|
1b2df507ba | ||
|
|
3a2af6f95d | ||
|
|
4727d6c399 | ||
|
|
9639221a0a | ||
|
|
617f6759d3 | ||
|
|
6f5fcd4d85 | ||
|
c31dd9de25
|
|||
|
edce236158
|
|||
|
|
9e6970de62 | ||
|
cb4f899b8c
|
|||
|
a20f28977b
|
|||
|
9a1728fb3f
|
|||
|
|
ea262e9a56 | ||
|
|
332c00b15b | ||
| 9b9f0a80cb | |||
| ada53e518b | |||
|
c0f2473db1
|
|||
|
c472c7a62c
|
|||
|
507cb7dd9b
|
|||
|
887b982331
|
|||
|
58df1d7bb9
|
|||
|
|
808a59f58a
|
||
|
|
bdee6ba756 | ||
|
09df7af2ab
|
|||
|
2cf18a47a8
|
|||
|
|
631f9ddca5 | ||
|
4fb399a5e1
|
|||
|
16de964f74
|
|||
|
c748a34e4c
|
|||
|
85190ce174
|
|||
|
|
62df2c2244 | ||
|
|
fb91dc89ee | ||
|
|
07c67dbef6 | ||
|
39f88f6f8a
|
|||
|
|
fc2c58002e | ||
|
|
f15530149e | ||
|
|
9ca50540e6 | ||
|
|
f76b7a9dc2 | ||
|
da0ef9ec24
|
|||
|
64b015e494
|
|||
|
|
c9723f9359 | ||
| 70e023f27a | |||
|
11c483439c
|
|||
|
|
ebc65af981
|
||
|
|
bd0eb76bec
|
||
|
|
edbc03ee25
|
||
|
|
7953790f4d
|
||
|
|
7fd92b4212 | ||
|
|
87a239f37c | ||
|
4f3596b48a
|
|||
|
efe9f91303
|
|||
|
5f8b87221c
|
|||
|
8558733719
|
|||
|
f628d16d3d
|
|||
|
|
3f76a2c5bf | ||
|
|
8ed94e4df3 | ||
|
|
58d280462e | ||
|
|
f30cf42c4b | ||
|
|
1488281223 | ||
|
|
668d7338ae | ||
|
|
9d78caf6d9 | ||
|
|
909114bdce | ||
|
|
2bed722a4f | ||
|
|
dc61b1e390 | ||
|
c59c28ecbd
|
|||
|
|
82230a5f66 | ||
|
|
1255a2e74c | ||
|
|
24dddcadab | ||
|
7f2870e340
|
|||
|
|
1ce36dd8da | ||
|
|
de181d161d | ||
|
|
ee4109e9cd | ||
|
|
02e741c52e | ||
|
0391b9deba
|
|||
|
ab28f0950a
|
|||
| 1c6ea16b6e | |||
|
b2da8c2408
|
|||
|
42ff9c9e79
|
|||
|
39c695b1bb
|
|||
|
87691902be
|
|||
|
eb7387b49c
|
|||
|
|
67f7e64b95 | ||
| 2a234f6db8 | |||
|
|
a8420ae45c | ||
|
|
b692982d73 | ||
|
|
006e85d0f8 | ||
|
|
43682b902b | ||
|
|
d49c267d50 | ||
| bb090ad431 | |||
| 8991024d5d | |||
|
|
2c281066a8 | ||
|
27cc25cc0d
|
|||
|
|
7525e523bb | ||
|
|
4e58e41a2a | ||
|
b7e442b269
|
|||
|
756b42dd72
|
|||
|
bc25f914ad
|
|||
|
6a6aadbbeb
|
|||
|
03999e00ef
|
|||
|
d6aa83a2ea
|
|||
|
e23ef8da48
|
|||
|
|
32622cdd7e | ||
|
|
41ed06f91c |
4
.commitlintrc.js
Normal file
4
.commitlintrc.js
Normal file
@@ -0,0 +1,4 @@
|
||||
export default {
|
||||
extends: ['@commitlint/config-conventional'],
|
||||
ignores: [commit => commit.startsWith("Local Mutable Chains\n")]
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
extends: "@commitlint/config-conventional"
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -3,3 +3,4 @@
|
||||
.scala-build/
|
||||
.vscode/
|
||||
wacc-examples/
|
||||
.idea/
|
||||
@@ -30,9 +30,10 @@ check_commits:
|
||||
before_script:
|
||||
- apk add git
|
||||
- npm install -g @commitlint/cli @commitlint/config-conventional
|
||||
- git pull origin master
|
||||
- git checkout origin/master
|
||||
script:
|
||||
- npx commitlint --from origin/master --to HEAD --verbose
|
||||
- git checkout ${CI_COMMIT_SHA}
|
||||
- npx commitlint --from origin/master --to ${CI_COMMIT_SHA} --verbose
|
||||
|
||||
compile_jvm:
|
||||
stage: compile
|
||||
@@ -45,12 +46,13 @@ compile_jvm:
|
||||
- .scala-build/
|
||||
|
||||
test_jvm:
|
||||
image: gumjoe/wacc-ci-scala:x86
|
||||
stage: test
|
||||
# Use our own runner (not cloud VM or shared) to ensure we have multiple cores.
|
||||
tags: [ large ]
|
||||
tags: [large]
|
||||
# This is expensive, so do use `dependencies` instead of `needs` to
|
||||
# ensure all previous stages pass.
|
||||
dependencies: [ compile_jvm ]
|
||||
dependencies: [compile_jvm]
|
||||
before_script:
|
||||
- git clone https://$EXAMPLES_AUTH@gitlab.doc.ic.ac.uk/lab2425_spring/wacc-examples.git
|
||||
script:
|
||||
|
||||
2
compile
2
compile
@@ -4,6 +4,6 @@
|
||||
# but do *not* change its name.
|
||||
|
||||
# feel free to adjust to suit the specific internal flags of your compiler
|
||||
./wacc-compiler "$@"
|
||||
./wacc-compiler --output . "$@"
|
||||
|
||||
exit $?
|
||||
|
||||
10
extension/examples/invalid/semantics/badWacc.wacc
Normal file
10
extension/examples/invalid/semantics/badWacc.wacc
Normal file
@@ -0,0 +1,10 @@
|
||||
begin
|
||||
int main() is
|
||||
int a = 5 ;
|
||||
string b = "Hello" ;
|
||||
return a + b
|
||||
end
|
||||
|
||||
int result = call main() ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1,6 @@
|
||||
import "./doesNotExist.wacc" (main)
|
||||
|
||||
begin
|
||||
int result = call main() ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1,6 @@
|
||||
import "../../../valid/sum.wacc" (mult)
|
||||
|
||||
begin
|
||||
int result = call mult(3, 2) ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1,10 @@
|
||||
import "../badWacc.wacc" (main)
|
||||
|
||||
begin
|
||||
int sum(int a, int b) is
|
||||
return a + b
|
||||
end
|
||||
|
||||
int result = call main() ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1,6 @@
|
||||
import "./importBadSem.wacc" (sum)
|
||||
|
||||
begin
|
||||
int result = call sum(1, 2) ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1,6 @@
|
||||
import "../../../valid/imports/basic.wacc" (sum)
|
||||
|
||||
begin
|
||||
int result = call sum(3, 2) ;
|
||||
exit result
|
||||
end
|
||||
6
extension/examples/invalid/syntax/badWacc.wacc
Normal file
6
extension/examples/invalid/syntax/badWacc.wacc
Normal file
@@ -0,0 +1,6 @@
|
||||
int main() is
|
||||
println "Hello World!" ;
|
||||
return 0
|
||||
end
|
||||
|
||||
skip
|
||||
@@ -0,0 +1,8 @@
|
||||
import "../../../valid/sum.wacc" sum, main
|
||||
|
||||
begin
|
||||
int result1 = call sum(5, 10) ;
|
||||
int result2 = call main() ;
|
||||
println result1 ;
|
||||
println result2
|
||||
end
|
||||
@@ -0,0 +1,5 @@
|
||||
import "../../../valid/sum.wacc" ()
|
||||
|
||||
begin
|
||||
exit 0
|
||||
end
|
||||
@@ -0,0 +1,10 @@
|
||||
import "../badWacc.wacc" (main)
|
||||
|
||||
begin
|
||||
int sum(int a, int b) is
|
||||
return a + b
|
||||
end
|
||||
|
||||
int result = call main() ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1,6 @@
|
||||
import "./importBadSyntax.wacc" (sum)
|
||||
|
||||
begin
|
||||
int result = call sum(1, 2) ;
|
||||
exit result
|
||||
end
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
import "../../../valid/sum.wacc" (sum) ;
|
||||
import "../../../valid/sum.wacc" (main) ;
|
||||
|
||||
begin
|
||||
int result1 = call sum(5, 10) ;
|
||||
int result2 = call main() ;
|
||||
println result1 ;
|
||||
println result2
|
||||
end
|
||||
@@ -0,0 +1,5 @@
|
||||
import "../../../valid/sum.wacc" *
|
||||
|
||||
begin
|
||||
exit 0
|
||||
end
|
||||
@@ -0,0 +1,5 @@
|
||||
import "../../../valid/sum.wacc" (*)
|
||||
|
||||
begin
|
||||
exit 0
|
||||
end
|
||||
7
extension/examples/valid/.gitignore
vendored
Normal file
7
extension/examples/valid/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
*
|
||||
|
||||
!imports/
|
||||
imports/*
|
||||
|
||||
!.gitignore
|
||||
!*.wacc
|
||||
22
extension/examples/valid/imports/alias.wacc
Normal file
22
extension/examples/valid/imports/alias.wacc
Normal file
@@ -0,0 +1,22 @@
|
||||
# import main from ../sum.wacc and ./basic.wacc
|
||||
|
||||
# Output:
|
||||
# 15
|
||||
# 0
|
||||
# -33
|
||||
#
|
||||
|
||||
# Exit:
|
||||
# 0
|
||||
|
||||
# Program:
|
||||
|
||||
import "../sum.wacc" (main as sumMain)
|
||||
import "./basic.wacc" (main)
|
||||
|
||||
begin
|
||||
int result1 = call sumMain() ;
|
||||
int result2 = call main() ;
|
||||
println result1 ;
|
||||
println result2
|
||||
end
|
||||
21
extension/examples/valid/imports/basic.wacc
Normal file
21
extension/examples/valid/imports/basic.wacc
Normal file
@@ -0,0 +1,21 @@
|
||||
# import sum from ../sum.wacc
|
||||
|
||||
# Output:
|
||||
# -33
|
||||
#
|
||||
|
||||
# Exit:
|
||||
# 0
|
||||
|
||||
# Program:
|
||||
|
||||
import "../sum.wacc" (sum)
|
||||
|
||||
begin
|
||||
int main() is
|
||||
int result = call sum(-10, -23) ;
|
||||
return result
|
||||
end
|
||||
int result = call main() ;
|
||||
println result
|
||||
end
|
||||
33
extension/examples/valid/imports/manyMains.wacc
Normal file
33
extension/examples/valid/imports/manyMains.wacc
Normal file
@@ -0,0 +1,33 @@
|
||||
# import all the mains
|
||||
|
||||
# Output:
|
||||
# 15
|
||||
# -33
|
||||
# 0
|
||||
# -33
|
||||
# 0
|
||||
#
|
||||
|
||||
# Exit:
|
||||
# 99
|
||||
|
||||
# Program:
|
||||
|
||||
import "../sum.wacc" (main as sumMain)
|
||||
import "./basic.wacc" (main as basicMain)
|
||||
import "./multiFunc.wacc" (main as multiFuncMain)
|
||||
|
||||
begin
|
||||
int main() is
|
||||
int result1 = call sumMain() ;
|
||||
int result2 = call basicMain() ;
|
||||
int result3 = call multiFuncMain() ;
|
||||
println result1 ;
|
||||
println result2 ;
|
||||
println result3 ;
|
||||
return 99
|
||||
end
|
||||
|
||||
int result = call main() ;
|
||||
exit result
|
||||
end
|
||||
27
extension/examples/valid/imports/multiFunc.wacc
Normal file
27
extension/examples/valid/imports/multiFunc.wacc
Normal file
@@ -0,0 +1,27 @@
|
||||
# import sum, main from ../sum.wacc
|
||||
|
||||
# Output:
|
||||
# 15
|
||||
# -33
|
||||
# 0
|
||||
# 0
|
||||
#
|
||||
|
||||
# Exit:
|
||||
# 0
|
||||
|
||||
# Program:
|
||||
|
||||
import "../sum.wacc" (sum, main as sumMain)
|
||||
|
||||
begin
|
||||
int main() is
|
||||
int result = call sum(-10, -23) ;
|
||||
println result ;
|
||||
return 0
|
||||
end
|
||||
int result1 = call sumMain() ;
|
||||
int result2 = call main() ;
|
||||
println result1 ;
|
||||
println result2
|
||||
end
|
||||
27
extension/examples/valid/sum.wacc
Normal file
27
extension/examples/valid/sum.wacc
Normal file
@@ -0,0 +1,27 @@
|
||||
# simple sum program
|
||||
|
||||
# Output:
|
||||
# 15
|
||||
#
|
||||
|
||||
# Exit:
|
||||
# 0
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int sum(int a, int b) is
|
||||
return a + b
|
||||
end
|
||||
|
||||
int main() is
|
||||
int a = 5 ;
|
||||
int b = 10 ;
|
||||
int result = call sum(a, b) ;
|
||||
println result ;
|
||||
return 0
|
||||
end
|
||||
|
||||
int result = call main() ;
|
||||
exit result
|
||||
end
|
||||
@@ -3,25 +3,22 @@
|
||||
|
||||
// dependencies
|
||||
//> using dep com.github.j-mie6::parsley::5.0.0-M10
|
||||
//> using dep com.github.j-mie6::parsley-cats::1.3.0
|
||||
//> using dep com.lihaoyi::os-lib::0.11.3
|
||||
//> using dep com.github.scopt::scopt::4.1.0
|
||||
//> using dep com.github.j-mie6::parsley-cats::1.5.0
|
||||
//> using dep com.lihaoyi::os-lib::0.11.4
|
||||
//> using dep org.typelevel::cats-core::2.13.0
|
||||
//> using dep org.typelevel::cats-effect::3.5.7
|
||||
//> using dep com.monovore::decline::2.5.0
|
||||
//> using dep com.monovore::decline-effect::2.5.0
|
||||
//> using dep org.typelevel::log4cats-slf4j::2.7.0
|
||||
//> using dep org.slf4j:slf4j-simple:2.0.17
|
||||
//> using test.dep org.scalatest::scalatest::3.2.19
|
||||
//> using dep org.typelevel::cats-effect-testing-scalatest::1.6.0
|
||||
//> using dep "co.fs2::fs2-core:3.11.0"
|
||||
//> using dep co.fs2::fs2-io:3.11.0
|
||||
|
||||
// these are all sensible defaults to catch annoying issues
|
||||
// sensible defaults for warnings and compiler checks
|
||||
//> using options -deprecation -unchecked -feature
|
||||
//> using options -Wimplausible-patterns -Wunused:all
|
||||
//> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores
|
||||
|
||||
// these will help ensure you have access to the latest parsley releases
|
||||
// even before they land on maven proper, or snapshot versions, if necessary.
|
||||
// just in case they cause problems, however, keep them turned off unless you
|
||||
// specifically need them.
|
||||
// using repositories sonatype-s01:releases
|
||||
// using repositories sonatype-s01:snapshots
|
||||
|
||||
// these are flags used by Scala native: if you aren't using scala-native, then they do nothing
|
||||
// lto-thin has decent linking times, and release-fast does not too much optimisation.
|
||||
// using nativeLto thin
|
||||
// using nativeGc commix
|
||||
// using nativeMode release-fast
|
||||
// repositories for pre-release versions if needed
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
package wacc
|
||||
|
||||
import wacc.ast.Position
|
||||
import wacc.types._
|
||||
|
||||
/** Error types for semantic errors
|
||||
*/
|
||||
enum Error {
|
||||
case DuplicateDeclaration(ident: ast.Ident)
|
||||
case UndeclaredVariable(ident: ast.Ident)
|
||||
case UndefinedFunction(ident: ast.Ident)
|
||||
|
||||
case FunctionParamsMismatch(ident: ast.Ident, expected: Int, got: Int, funcType: FuncType)
|
||||
case SemanticError(pos: Position, msg: String)
|
||||
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
|
||||
case InternalError(pos: Position, msg: String)
|
||||
}
|
||||
|
||||
/** Function to handle printing the details of a given semantic error
|
||||
*
|
||||
* @param error
|
||||
* Error object
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def printError(error: Error)(using errorContent: String): Unit = {
|
||||
println("Semantic error:")
|
||||
error match {
|
||||
case Error.DuplicateDeclaration(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Duplicate declaration of identifier ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndeclaredVariable(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Undeclared variable ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedFunction(ident) =>
|
||||
printPosition(ident.pos)
|
||||
println(s"Undefined function ${ident.v}")
|
||||
highlight(ident.pos, ident.v.length)
|
||||
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
|
||||
printPosition(id.pos)
|
||||
println(s"Function expects $expected parameters, got $got")
|
||||
println(
|
||||
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})"
|
||||
)
|
||||
highlight(id.pos, 1)
|
||||
case Error.TypeMismatch(pos, expected, got, msg) =>
|
||||
printPosition(pos)
|
||||
println(s"Type mismatch: $msg\nExpected: $expected\nGot: $got")
|
||||
highlight(pos, 1)
|
||||
case Error.SemanticError(pos, msg) =>
|
||||
printPosition(pos)
|
||||
println(msg)
|
||||
highlight(pos, 1)
|
||||
case wacc.Error.InternalError(pos, msg) =>
|
||||
printPosition(pos)
|
||||
println(s"Internal error: $msg")
|
||||
highlight(pos, 1)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Function to highlight a section of code for an error message
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
* @param size
|
||||
* Size(in chars) of section to highlight
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def highlight(pos: Position, size: Int)(using errorContent: String): Unit = {
|
||||
val lines = errorContent.split("\n")
|
||||
|
||||
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
|
||||
val midLine = lines(pos.line - 1)
|
||||
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
|
||||
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
|
||||
|
||||
println(
|
||||
s" >$preLine\n >$midLine\n$linePointer >$postLine"
|
||||
)
|
||||
}
|
||||
|
||||
/** Function to print the position of an error
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
*/
|
||||
def printPosition(pos: Position): Unit = {
|
||||
println(s"(line ${pos.line}, column ${pos.column}):")
|
||||
}
|
||||
@@ -1,56 +1,177 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
import cats.data.{Chain, NonEmptyList}
|
||||
import parsley.{Failure, Success}
|
||||
import scopt.OParser
|
||||
|
||||
import java.nio.file.{Files, Path}
|
||||
import cats.syntax.all._
|
||||
|
||||
import cats.effect.IO
|
||||
import cats.effect.ExitCode
|
||||
|
||||
import com.monovore.decline._
|
||||
import com.monovore.decline.effect._
|
||||
|
||||
import org.typelevel.log4cats.slf4j.Slf4jLogger
|
||||
import org.typelevel.log4cats.Logger
|
||||
|
||||
import fs2.Stream
|
||||
|
||||
import assemblyIR as asm
|
||||
import cats.data.ValidatedNel
|
||||
import java.io.File
|
||||
import cats.data.NonEmptySeq
|
||||
|
||||
case class CliConfig(
|
||||
file: File = new File(".")
|
||||
)
|
||||
/*
|
||||
TODO:
|
||||
1) IO correctness
|
||||
2) Errors can be handled more gracefully - currently, parallelised compilation is not fail fast as far as I am aware
|
||||
3) splitting the file up and nicer refactoring
|
||||
4) logging could be removed
|
||||
5) general cleanup and comments (things like replacing home/<user> with ~ , and names of parameters and args, descriptions etc)
|
||||
*/
|
||||
|
||||
val cliBuilder = OParser.builder[CliConfig]
|
||||
val cliParser = {
|
||||
import cliBuilder._
|
||||
OParser.sequence(
|
||||
programName("wacc-compiler"),
|
||||
help('h', "help")
|
||||
.text("Prints this help message"),
|
||||
arg[File]("<file>")
|
||||
.text("Input WACC source file")
|
||||
.required()
|
||||
.action((f, c) => c.copy(file = f))
|
||||
.validate(f =>
|
||||
if (!f.exists) failure("File does not exist")
|
||||
else if (!f.isFile) failure("File must be a regular file")
|
||||
else if (!f.getName.endsWith(".wacc"))
|
||||
failure("File must have .wacc extension")
|
||||
else success
|
||||
)
|
||||
)
|
||||
private val SUCCESS = ExitCode.Success.code
|
||||
private val ERROR = ExitCode.Error.code
|
||||
|
||||
given logger: Logger[IO] = Slf4jLogger.getLogger[IO]
|
||||
|
||||
val logOpt: Opts[Boolean] =
|
||||
Opts.flag("log", "Enable logging for additional compilation details", short = "l").orFalse
|
||||
|
||||
def validateFile(path: Path): ValidatedNel[String, Path] = {
|
||||
(for {
|
||||
// TODO: redundant 2nd parameter :(
|
||||
_ <- Either.cond(Files.exists(path), (), s"File '${path}' does not exist")
|
||||
_ <- Either.cond(Files.isRegularFile(path), (), s"File '${path}' must be a regular file")
|
||||
_ <- Either.cond(path.toString.endsWith(".wacc"), (), "File must have .wacc extension")
|
||||
} yield path).toValidatedNel
|
||||
}
|
||||
|
||||
def compile(contents: String): Int = {
|
||||
val filesOpt: Opts[NonEmptyList[Path]] =
|
||||
Opts.arguments[Path]("files").mapValidated {
|
||||
_.traverse(validateFile)
|
||||
}
|
||||
|
||||
val outputOpt: Opts[Option[Path]] =
|
||||
Opts
|
||||
.option[Path]("output", metavar = "path", help = "Output directory for compiled files.")
|
||||
.validate("Must have permissions to create & access the output path") { path =>
|
||||
try {
|
||||
Files.createDirectories(path)
|
||||
true
|
||||
} catch {
|
||||
case e: java.nio.file.AccessDeniedException =>
|
||||
false
|
||||
}
|
||||
}
|
||||
.validate("Output path must be a directory") { path =>
|
||||
Files.isDirectory(path)
|
||||
}
|
||||
.orNone
|
||||
|
||||
def frontend(
|
||||
contents: String,
|
||||
file: File
|
||||
): IO[Either[NonEmptySeq[Error], microWacc.Program]] =
|
||||
parser.parse(contents) match {
|
||||
case Success(prog) =>
|
||||
given errors: mutable.Builder[Error, List[Error]] = List.newBuilder
|
||||
val (names, funcs) = renamer.rename(prog)
|
||||
given ctx: typeChecker.TypeCheckerCtx = typeChecker.TypeCheckerCtx(names, funcs, errors)
|
||||
typeChecker.check(prog)
|
||||
if (errors.result.nonEmpty) {
|
||||
given errorContent: String = contents
|
||||
errors.result.foreach(printError)
|
||||
200
|
||||
} else 0
|
||||
case Failure(msg) =>
|
||||
println(msg)
|
||||
100
|
||||
case Failure(msg) => IO.pure(Left(NonEmptySeq.one(Error.SyntaxError(file, msg))))
|
||||
case Success(fn) =>
|
||||
val partialProg = fn(file)
|
||||
|
||||
for {
|
||||
(typedProg, errors) <- semantics.check(partialProg)
|
||||
res = NonEmptySeq.fromSeq(errors.iterator.toSeq).map(Left(_)).getOrElse(Right(typedProg))
|
||||
} yield res
|
||||
}
|
||||
|
||||
def backend(typedProg: microWacc.Program): Chain[asm.AsmLine] =
|
||||
asmGenerator.generateAsm(typedProg)
|
||||
|
||||
def compile(
|
||||
filePath: Path,
|
||||
outputDir: Option[Path],
|
||||
log: Boolean
|
||||
): IO[Int] = {
|
||||
val logAction: String => IO[Unit] =
|
||||
if (log) logger.info(_)
|
||||
else (_ => IO.unit)
|
||||
|
||||
def readSourceFile: IO[String] =
|
||||
IO.blocking(os.read(os.Path(filePath)))
|
||||
|
||||
// TODO: path, file , the names are confusing (when Path is the type but we are working with files)
|
||||
def writeOutputFile(typedProg: microWacc.Program, outputPath: Path): IO[Unit] =
|
||||
val backendStart = System.nanoTime()
|
||||
val asmLines = backend(typedProg)
|
||||
val backendEnd = System.nanoTime()
|
||||
writer.writeTo(asmLines, outputPath) *>
|
||||
logAction(
|
||||
s"Backend time (${filePath.toRealPath()}): ${(backendEnd - backendStart).toFloat / 1e6} ms"
|
||||
) *>
|
||||
IO.blocking(println(s"Success: ${outputPath.toRealPath()}"))
|
||||
|
||||
def processProgram(contents: String, file: File, outDir: Path): IO[Int] =
|
||||
val frontendStart = System.nanoTime()
|
||||
for {
|
||||
frontendResult <- frontend(contents, file)
|
||||
frontendEnd = System.nanoTime()
|
||||
_ <- logAction(
|
||||
s"Frontend time (${filePath.toRealPath()}): ${(frontendEnd - frontendStart).toFloat / 1e6} ms"
|
||||
)
|
||||
res <- frontendResult match {
|
||||
case Left(errors) =>
|
||||
val code = errors.map(err => err.exitCode).toList.min
|
||||
val errorMsg = errors.map(formatError).toIterable.mkString("\n")
|
||||
for {
|
||||
_ <- logAction(s"Compilation failed for $filePath\nExit code: $code")
|
||||
_ <- IO.blocking(
|
||||
// Explicit println since we want this to always show without logger thread info e.t.c.
|
||||
println(s"Compilation failed for ${file.getCanonicalPath}:\n$errorMsg")
|
||||
)
|
||||
} yield code
|
||||
|
||||
case Right(typedProg) =>
|
||||
val outputFile = outDir.resolve(filePath.getFileName.toString.stripSuffix(".wacc") + ".s")
|
||||
writeOutputFile(typedProg, outputFile).as(SUCCESS)
|
||||
}
|
||||
} yield res
|
||||
|
||||
for {
|
||||
contents <- readSourceFile
|
||||
_ <- logAction(s"Compiling file: ${filePath.toAbsolutePath}")
|
||||
exitCode <- processProgram(contents, filePath.toFile, outputDir.getOrElse(filePath.getParent))
|
||||
} yield exitCode
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit =
|
||||
OParser.parse(cliParser, args, CliConfig()) match {
|
||||
case Some(config) =>
|
||||
System.exit(compile(os.read(os.Path(config.file.getAbsolutePath))))
|
||||
case None =>
|
||||
}
|
||||
def compileCommandParallel(
|
||||
files: NonEmptyList[Path],
|
||||
log: Boolean,
|
||||
outDir: Option[Path]
|
||||
): IO[ExitCode] =
|
||||
Stream
|
||||
.emits(files.toList)
|
||||
.parEvalMapUnordered(Runtime.getRuntime.availableProcessors()) { file =>
|
||||
compile(file.toAbsolutePath, outDir, log)
|
||||
}
|
||||
.compile
|
||||
.toList
|
||||
.map { exitCodes =>
|
||||
exitCodes.filter(_ != 0) match {
|
||||
case Nil => ExitCode.Success
|
||||
case errorCodes => ExitCode(errorCodes.min)
|
||||
}
|
||||
}
|
||||
|
||||
object Main
|
||||
extends CommandIOApp(
|
||||
name = "wacc",
|
||||
header = "The ultimate WACC compiler",
|
||||
version = "1.0"
|
||||
) {
|
||||
def main: Opts[IO[ExitCode]] =
|
||||
(filesOpt, logOpt, outputOpt).mapN { (files, log, outDir) =>
|
||||
compileCommandParallel(files, log, outDir)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
78
src/main/wacc/backend/LabelGenerator.scala
Normal file
78
src/main/wacc/backend/LabelGenerator.scala
Normal file
@@ -0,0 +1,78 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
import cats.data.Chain
|
||||
import wacc.ast.Position
|
||||
|
||||
private class LabelGenerator {
|
||||
import assemblyIR._
|
||||
import microWacc.{CallTarget, Ident, Builtin}
|
||||
import asmGenerator.escaped
|
||||
|
||||
private val strings = mutable.HashMap[String, String]()
|
||||
private val files = mutable.HashMap[String, Int]()
|
||||
private var labelVal = -1
|
||||
private var permittedFuncFile: Option[String] = None
|
||||
|
||||
/** Get an arbitrary label. */
|
||||
def getLabel(): String = {
|
||||
labelVal += 1
|
||||
s".L$labelVal"
|
||||
}
|
||||
|
||||
private def getLabel(target: CallTarget | RuntimeError): String = target match {
|
||||
case Ident(v, guid) => s"wacc_${v}_$guid"
|
||||
case Builtin(name) => s"_$name"
|
||||
case err: RuntimeError => s".L.${err.name}"
|
||||
}
|
||||
|
||||
/** Get a named label def for a function or error. */
|
||||
def getLabelDef(target: CallTarget | RuntimeError): LabelDef =
|
||||
LabelDef(getLabel(target))
|
||||
|
||||
/** Get a named label for a function or error. */
|
||||
def getLabelArg(target: CallTarget | RuntimeError): LabelArg =
|
||||
LabelArg(getLabel(target))
|
||||
|
||||
/** Get an arbitrary label for a string. */
|
||||
def getLabelArg(str: String): LabelArg =
|
||||
LabelArg(strings.getOrElseUpdate(str, s".L.str${strings.size}"))
|
||||
|
||||
/** Get a named label for a string. */
|
||||
def getLabelArg(src: String, name: String): LabelArg =
|
||||
LabelArg(strings.getOrElseUpdate(src, s".L.$name.str${strings.size}"))
|
||||
|
||||
/** Get a debug directive for a file. */
|
||||
def getDebugFile(file: java.io.File): Int =
|
||||
files.getOrElseUpdate(file.getCanonicalPath, files.size)
|
||||
|
||||
/** Get a debug directive for a function. */
|
||||
def getDebugFunc(pos: Position, name: String, label: LabelDef): Chain[AsmLine] = {
|
||||
permittedFuncFile match {
|
||||
case Some(f) if f != pos.file.getCanonicalPath => Chain.empty
|
||||
case _ =>
|
||||
val customLabel = if name == "main" then Chain.empty else Chain(LabelDef(name))
|
||||
permittedFuncFile = Some(pos.file.getCanonicalPath)
|
||||
customLabel ++ Chain(
|
||||
Directive.Location(getDebugFile(pos.file), pos.line, None),
|
||||
Directive.Type(label, SymbolType.Function),
|
||||
Directive.Func(name, label)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/** Generate the assembly labels for constants that were labelled using the LabelGenerator. */
|
||||
def generateConstants: Chain[AsmLine] =
|
||||
strings.foldLeft(Chain.empty) { case (acc, (str, label)) =>
|
||||
acc ++ Chain(
|
||||
LabelDef(label),
|
||||
Directive.Asciz(str.escaped)
|
||||
)
|
||||
}
|
||||
|
||||
/** Generates debug directives that were created using the LabelGenerator. */
|
||||
def generateDebug: Chain[AsmLine] =
|
||||
files.foldLeft(Chain.empty) { case (acc, (file, no)) =>
|
||||
acc :+ Directive.File(no, file)
|
||||
}
|
||||
}
|
||||
112
src/main/wacc/backend/RuntimeError.scala
Normal file
112
src/main/wacc/backend/RuntimeError.scala
Normal file
@@ -0,0 +1,112 @@
|
||||
package wacc
|
||||
|
||||
import cats.data.Chain
|
||||
import wacc.assemblyIR._
|
||||
|
||||
sealed trait RuntimeError {
|
||||
val name: String
|
||||
protected val errStr: String
|
||||
|
||||
protected def getErrLabel(using labelGenerator: LabelGenerator): LabelArg =
|
||||
labelGenerator.getLabelArg(errStr, name = name)
|
||||
|
||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine]
|
||||
|
||||
def generate(using labelGenerator: LabelGenerator): Chain[AsmLine] =
|
||||
labelGenerator.getLabelDef(this) +: generateHandler
|
||||
}
|
||||
|
||||
object RuntimeError {
|
||||
import wacc.asmGenerator.stackAlign
|
||||
import assemblyIR.commonRegisters._
|
||||
|
||||
private val ERROR_CODE = 255
|
||||
|
||||
case object ZeroDivError extends RuntimeError {
|
||||
val name = "errDivZero"
|
||||
protected val errStr = "fatal error: division or modulo by zero"
|
||||
|
||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||
stackAlign,
|
||||
Load(RDI, MemLocation(RIP, getErrLabel)),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(-1)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object BadChrError extends RuntimeError {
|
||||
val name = "errBadChr"
|
||||
protected val errStr = "fatal error: int %d is not an ASCII character 0-127"
|
||||
|
||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||
Pop(RSI),
|
||||
stackAlign,
|
||||
Load(RDI, MemLocation(RIP, getErrLabel)),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object NullPtrError extends RuntimeError {
|
||||
val name = "errNullPtr"
|
||||
protected val errStr = "fatal error: null pair dereferenced or freed"
|
||||
|
||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||
stackAlign,
|
||||
Load(RDI, MemLocation(RIP, getErrLabel)),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object OverflowError extends RuntimeError {
|
||||
val name = "errOverflow"
|
||||
protected val errStr = "fatal error: integer overflow or underflow occurred"
|
||||
|
||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||
stackAlign,
|
||||
Load(RDI, MemLocation(RIP, getErrLabel)),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
case object OutOfBoundsError extends RuntimeError {
|
||||
val name = "errOutOfBounds"
|
||||
protected val errStr = "fatal error: array index %d out of bounds"
|
||||
|
||||
protected def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||
Move(RSI, RCX),
|
||||
stackAlign,
|
||||
Load(RDI, MemLocation(RIP, getErrLabel)),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
}
|
||||
|
||||
case object OutOfMemoryError extends RuntimeError {
|
||||
val name = "errOutOfMemory"
|
||||
protected val errStr = "fatal error: out of memory"
|
||||
|
||||
def generateHandler(using labelGenerator: LabelGenerator): Chain[AsmLine] = Chain(
|
||||
stackAlign,
|
||||
Load(RDI, MemLocation(RIP, getErrLabel)),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Move(RDI, ImmediateVal(ERROR_CODE)),
|
||||
assemblyIR.Call(CLibFunc.Exit)
|
||||
)
|
||||
}
|
||||
|
||||
val all: Chain[RuntimeError] =
|
||||
Chain(ZeroDivError, BadChrError, NullPtrError, OverflowError, OutOfBoundsError,
|
||||
OutOfMemoryError)
|
||||
}
|
||||
90
src/main/wacc/backend/Stack.scala
Normal file
90
src/main/wacc/backend/Stack.scala
Normal file
@@ -0,0 +1,90 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable.LinkedHashMap
|
||||
import cats.data.Chain
|
||||
|
||||
class Stack {
|
||||
import assemblyIR._
|
||||
import assemblyIR.Size._
|
||||
import sizeExtensions.size
|
||||
import microWacc as mw
|
||||
|
||||
private val RSP = Register(Q64, RegName.SP)
|
||||
private class StackValue(val size: Size, val offset: Int) {
|
||||
def bottom: Int = offset + elemBytes
|
||||
}
|
||||
private val stack = LinkedHashMap[mw.Expr | Int, StackValue]()
|
||||
|
||||
private val elemBytes: Int = Q64.toInt
|
||||
private def sizeBytes: Int = stack.size * elemBytes
|
||||
|
||||
/** The stack's size in bytes. */
|
||||
def size: Int = stack.size
|
||||
|
||||
/** Push an expression onto the stack. */
|
||||
def push(expr: mw.Expr, src: Register): AsmLine = {
|
||||
stack += expr -> StackValue(expr.ty.size, sizeBytes)
|
||||
Push(src)
|
||||
}
|
||||
|
||||
/** Push a value onto the stack. */
|
||||
def push(itemSize: Size, addr: Src): AsmLine = {
|
||||
stack += stack.size -> StackValue(itemSize, sizeBytes)
|
||||
Push(addr)
|
||||
}
|
||||
|
||||
/** Reserve space for a variable on the stack. */
|
||||
def reserve(ident: mw.Ident): AsmLine = {
|
||||
stack += ident -> StackValue(ident.ty.size, sizeBytes)
|
||||
Subtract(RSP, ImmediateVal(elemBytes))
|
||||
}
|
||||
|
||||
/** Reserve space for a register on the stack. */
|
||||
def reserve(src: Register): AsmLine = {
|
||||
stack += stack.size -> StackValue(src.size, sizeBytes)
|
||||
Subtract(RSP, ImmediateVal(src.size.toInt))
|
||||
}
|
||||
|
||||
/** Reserve space for values on the stack.
|
||||
*
|
||||
* @param sizes
|
||||
* The sizes of the values to reserve space for.
|
||||
*/
|
||||
def reserve(sizes: Size*): AsmLine = {
|
||||
sizes.foreach { itemSize =>
|
||||
stack += stack.size -> StackValue(itemSize, sizeBytes)
|
||||
}
|
||||
Subtract(RSP, ImmediateVal(elemBytes * sizes.size))
|
||||
}
|
||||
|
||||
/** Pop a value from the stack into a register. Sizes MUST match. */
|
||||
def pop(dest: Register): AsmLine = {
|
||||
stack.remove(stack.last._1)
|
||||
Pop(dest)
|
||||
}
|
||||
|
||||
/** Drop the top n values from the stack. */
|
||||
def drop(n: Int = 1): AsmLine = {
|
||||
(1 to n).foreach { _ =>
|
||||
stack.remove(stack.last._1)
|
||||
}
|
||||
Add(RSP, ImmediateVal(n * elemBytes))
|
||||
}
|
||||
|
||||
/** Generate AsmLines within a scope, which is reset after the block. */
|
||||
def withScope(block: () => Chain[AsmLine]): Chain[AsmLine] = {
|
||||
val resetToSize = stack.size
|
||||
var lines = block()
|
||||
lines :+= drop(stack.size - resetToSize)
|
||||
lines
|
||||
}
|
||||
|
||||
/** Get an MemLocation for a variable in the stack. */
|
||||
def accessVar(ident: mw.Ident): MemLocation =
|
||||
MemLocation(RSP, sizeBytes - stack(ident).bottom, opSize = Some(stack(ident).size))
|
||||
|
||||
def contains(ident: mw.Ident): Boolean = stack.contains(ident)
|
||||
def head: MemLocation = MemLocation(RSP, opSize = Some(stack.last._2.size))
|
||||
|
||||
override def toString(): String = stack.toString
|
||||
}
|
||||
483
src/main/wacc/backend/asmGenerator.scala
Normal file
483
src/main/wacc/backend/asmGenerator.scala
Normal file
@@ -0,0 +1,483 @@
|
||||
package wacc
|
||||
|
||||
import cats.data.Chain
|
||||
import cats.syntax.foldable._
|
||||
import wacc.RuntimeError._
|
||||
|
||||
object asmGenerator {
|
||||
import microWacc._
|
||||
import assemblyIR._
|
||||
import assemblyIR.commonRegisters._
|
||||
import assemblyIR.RegName._
|
||||
import types._
|
||||
import sizeExtensions._
|
||||
import lexer.escapedChars
|
||||
|
||||
private val argRegs = List(DI, SI, DX, CX, R8, R9)
|
||||
|
||||
private val _7_BIT_MASK = 0x7f
|
||||
|
||||
extension [T](chain: Chain[T])
|
||||
def +(item: T): Chain[T] = chain.append(item)
|
||||
|
||||
/** Concatenates multiple `Chain[T]` instances into a single `Chain[T]`, appending them to the
|
||||
* current `Chain`.
|
||||
*
|
||||
* @param chains
|
||||
* A variable number of `Chain[T]` instances to concatenate.
|
||||
* @return
|
||||
* A new `Chain[T]` containing all elements from `chain` concatenated with `chains`.
|
||||
*/
|
||||
def concatAll(chains: Chain[T]*): Chain[T] =
|
||||
chains.foldLeft(chain)(_ ++ _)
|
||||
|
||||
def generateAsm(microProg: Program): Chain[AsmLine] = {
|
||||
given stack: Stack = Stack()
|
||||
given labelGenerator: LabelGenerator = LabelGenerator()
|
||||
val Program(funcs, main) = microProg
|
||||
|
||||
val mainLabel = LabelDef("main")
|
||||
val mainAsm = labelGenerator.getDebugFunc(microProg.pos, "main", mainLabel) + mainLabel
|
||||
val progAsm = mainAsm.concatAll(
|
||||
funcPrologue(),
|
||||
main.foldMap(generateStmt(_)),
|
||||
Chain.one(Xor(RAX, RAX)),
|
||||
funcEpilogue(),
|
||||
Chain(Directive.Size(mainLabel, SizeExpr.Relative(mainLabel)), Directive.EndFunc),
|
||||
generateBuiltInFuncs(),
|
||||
RuntimeError.all.foldMap(_.generate),
|
||||
funcs.foldMap(generateUserFunc(_))
|
||||
)
|
||||
|
||||
Chain(
|
||||
Directive.IntelSyntax,
|
||||
Directive.Global("main"),
|
||||
Directive.RoData
|
||||
).concatAll(
|
||||
labelGenerator.generateDebug,
|
||||
labelGenerator.generateConstants,
|
||||
Chain.one(Directive.Text),
|
||||
progAsm
|
||||
)
|
||||
}
|
||||
|
||||
private def wrapBuiltinFunc(builtin: Builtin, funcBody: Chain[AsmLine])(using
|
||||
stack: Stack,
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.one[AsmLine](labelGenerator.getLabelDef(builtin))
|
||||
asm ++= funcPrologue()
|
||||
asm ++= funcBody
|
||||
asm ++= funcEpilogue()
|
||||
asm
|
||||
}
|
||||
|
||||
private def generateUserFunc(func: FuncDecl)(using
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
given stack: Stack = Stack()
|
||||
// Setup the stack with param 7 and up
|
||||
func.params.drop(argRegs.size).foreach(stack.reserve(_))
|
||||
stack.reserve(Size.Q64) // Reserve return pointer slot
|
||||
val funcLabel = labelGenerator.getLabelDef(func.name)
|
||||
var asm = labelGenerator.getDebugFunc(func.pos, func.name.name, funcLabel)
|
||||
val debugFunc = asm.size > 0
|
||||
asm += funcLabel
|
||||
asm ++= funcPrologue()
|
||||
// Push the rest of params onto the stack for simplicity
|
||||
argRegs.zip(func.params).foreach { (reg, param) =>
|
||||
asm += stack.push(param, Register(Size.Q64, reg))
|
||||
}
|
||||
asm ++= func.body.foldMap(generateStmt(_))
|
||||
// No need for epilogue here since all user functions must return explicitly
|
||||
if (debugFunc) {
|
||||
asm += Directive.Size(funcLabel, SizeExpr.Relative(funcLabel))
|
||||
asm += Directive.EndFunc
|
||||
}
|
||||
asm
|
||||
}
|
||||
|
||||
private def generateBuiltInFuncs()(using
|
||||
stack: Stack,
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
Builtin.Exit,
|
||||
Chain(stackAlign, assemblyIR.Call(CLibFunc.Exit))
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
Builtin.Printf,
|
||||
Chain(
|
||||
stackAlign,
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Xor(RDI, RDI),
|
||||
assemblyIR.Call(CLibFunc.Fflush)
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
Builtin.PrintCharArray,
|
||||
Chain(
|
||||
stackAlign,
|
||||
Load(RDX, MemLocation(RSI, KnownType.Int.size.toInt)),
|
||||
Move(Register(KnownType.Int.size, SI), MemLocation(RSI, opSize = Some(KnownType.Int.size))),
|
||||
assemblyIR.Call(CLibFunc.PrintF),
|
||||
Xor(RDI, RDI),
|
||||
assemblyIR.Call(CLibFunc.Fflush)
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
Builtin.Malloc,
|
||||
Chain(
|
||||
stackAlign,
|
||||
assemblyIR.Call(CLibFunc.Malloc),
|
||||
// Out of memory check
|
||||
Compare(RAX, ImmediateVal(0)),
|
||||
Jump(labelGenerator.getLabelArg(OutOfMemoryError), Cond.Equal)
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
Builtin.Free,
|
||||
Chain(
|
||||
stackAlign,
|
||||
Compare(RDI, ImmediateVal(0)),
|
||||
Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal),
|
||||
assemblyIR.Call(CLibFunc.Free)
|
||||
)
|
||||
)
|
||||
|
||||
asm ++= wrapBuiltinFunc(
|
||||
Builtin.Read,
|
||||
Chain(
|
||||
stackAlign,
|
||||
Subtract(Register(Size.Q64, SP), ImmediateVal(8)),
|
||||
Push(RSI),
|
||||
Load(RSI, MemLocation(Register(Size.Q64, SP), opSize = Some(Size.Q64))),
|
||||
assemblyIR.Call(CLibFunc.Scanf),
|
||||
Pop(RAX)
|
||||
)
|
||||
)
|
||||
|
||||
asm
|
||||
}
|
||||
|
||||
private def generateStmt(stmt: Stmt)(using
|
||||
stack: Stack,
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
val fileNo = labelGenerator.getDebugFile(stmt.pos.file)
|
||||
var asm = Chain.one[AsmLine](Directive.Location(fileNo, stmt.pos.line, None))
|
||||
stmt match {
|
||||
case Assign(lhs, rhs) =>
|
||||
lhs match {
|
||||
case ident: Ident =>
|
||||
if (!stack.contains(ident)) asm += stack.reserve(ident)
|
||||
asm ++= evalExprOntoStack(rhs)
|
||||
asm += stack.pop(RAX)
|
||||
asm += Move(stack.accessVar(ident).copy(opSize = Some(Size.Q64)), RAX)
|
||||
case ArrayElem(x, i) =>
|
||||
asm ++= evalExprOntoStack(rhs)
|
||||
asm ++= evalExprOntoStack(i)
|
||||
asm += stack.pop(RCX)
|
||||
asm += Compare(ECX, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||
asm += stack.push(KnownType.Int.size, RCX)
|
||||
asm ++= evalExprOntoStack(x)
|
||||
asm += stack.pop(RAX)
|
||||
asm += stack.pop(RCX)
|
||||
asm += Compare(RAX, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||
asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||
asm += stack.pop(RDX)
|
||||
|
||||
asm += Move(
|
||||
MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt)),
|
||||
Register(x.ty.elemSize, DX)
|
||||
)
|
||||
}
|
||||
|
||||
case If(cond, thenBranch, elseBranch) =>
|
||||
val elseLabel = labelGenerator.getLabel()
|
||||
val endLabel = labelGenerator.getLabel()
|
||||
|
||||
asm ++= evalExprOntoStack(cond)
|
||||
asm += stack.pop(RAX)
|
||||
asm += Compare(RAX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(elseLabel), Cond.Equal)
|
||||
|
||||
asm ++= stack.withScope(() => thenBranch.foldMap(generateStmt))
|
||||
asm += Jump(LabelArg(endLabel))
|
||||
asm += LabelDef(elseLabel)
|
||||
|
||||
asm ++= stack.withScope(() => elseBranch.foldMap(generateStmt))
|
||||
asm += LabelDef(endLabel)
|
||||
|
||||
case While(cond, body) =>
|
||||
val startLabel = labelGenerator.getLabel()
|
||||
val endLabel = labelGenerator.getLabel()
|
||||
|
||||
asm += LabelDef(startLabel)
|
||||
asm ++= evalExprOntoStack(cond)
|
||||
asm += stack.pop(RAX)
|
||||
asm += Compare(RAX, ImmediateVal(0))
|
||||
asm += Jump(LabelArg(endLabel), Cond.Equal)
|
||||
|
||||
asm ++= stack.withScope(() => body.foldMap(generateStmt))
|
||||
asm += Jump(LabelArg(startLabel))
|
||||
asm += LabelDef(endLabel)
|
||||
|
||||
case call: microWacc.Call =>
|
||||
asm ++= generateCall(call, isTail = false)
|
||||
|
||||
case microWacc.Return(expr) =>
|
||||
expr match {
|
||||
case call: microWacc.Call =>
|
||||
asm ++= generateCall(call, isTail = true) // tco
|
||||
case _ =>
|
||||
asm ++= evalExprOntoStack(expr)
|
||||
asm += stack.pop(RAX)
|
||||
asm ++= funcEpilogue()
|
||||
}
|
||||
}
|
||||
|
||||
asm
|
||||
}
|
||||
|
||||
private def evalExprOntoStack(expr: Expr)(using
|
||||
stack: Stack,
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
val stackSizeStart = stack.size
|
||||
expr match {
|
||||
case IntLiter(v) => asm += stack.push(KnownType.Int.size, ImmediateVal(v))
|
||||
case CharLiter(v) => asm += stack.push(KnownType.Char.size, ImmediateVal(v.toInt))
|
||||
case ident: Ident =>
|
||||
val location = stack.accessVar(ident)
|
||||
// items in stack are guaranteed to be in Q64 slots,
|
||||
// so we are safe to wipe the opSize from the memory location
|
||||
asm += stack.push(ident.ty.size, location.copy(opSize = None))
|
||||
|
||||
case array @ ArrayLiter(elems) =>
|
||||
expr.ty match {
|
||||
case KnownType.String =>
|
||||
val str = elems.collect { case CharLiter(v) => v }.mkString
|
||||
asm += Load(RAX, MemLocation(RIP, labelGenerator.getLabelArg(str)))
|
||||
asm += stack.push(KnownType.String.size, RAX)
|
||||
case ty =>
|
||||
asm ++= generateCall(
|
||||
microWacc.Call(Builtin.Malloc, List(IntLiter(array.heapSize)))(array.pos),
|
||||
isTail = false
|
||||
)
|
||||
asm += stack.push(KnownType.Array(?).size, RAX)
|
||||
// Store the length of the array at the start
|
||||
asm += Move(
|
||||
MemLocation(RAX, opSize = Some(KnownType.Int.size)),
|
||||
ImmediateVal(elems.size)
|
||||
)
|
||||
elems.zipWithIndex.foldMap { (elem, i) =>
|
||||
asm ++= evalExprOntoStack(elem)
|
||||
asm += stack.pop(RCX)
|
||||
asm += stack.pop(RAX)
|
||||
asm += Move(
|
||||
MemLocation(RAX, KnownType.Int.size.toInt + i * ty.elemSize.toInt),
|
||||
Register(ty.elemSize, CX)
|
||||
)
|
||||
asm += stack.push(KnownType.Array(?).size, RAX)
|
||||
}
|
||||
}
|
||||
|
||||
case BoolLiter(true) =>
|
||||
asm += stack.push(KnownType.Bool.size, ImmediateVal(1))
|
||||
case BoolLiter(false) =>
|
||||
asm += Xor(RAX, RAX)
|
||||
asm += stack.push(KnownType.Bool.size, RAX)
|
||||
case NullLiter() =>
|
||||
asm += stack.push(KnownType.Pair(?, ?).size, ImmediateVal(0))
|
||||
case ArrayElem(x, i) =>
|
||||
asm ++= evalExprOntoStack(x)
|
||||
asm ++= evalExprOntoStack(i)
|
||||
asm += stack.pop(RCX)
|
||||
asm += Compare(RCX, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.Less)
|
||||
asm += stack.pop(RAX)
|
||||
asm += Compare(RAX, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(NullPtrError), Cond.Equal)
|
||||
asm += Compare(MemLocation(RAX, opSize = Some(KnownType.Int.size)), ECX)
|
||||
asm += Jump(labelGenerator.getLabelArg(OutOfBoundsError), Cond.LessEqual)
|
||||
// + Int because we store the length of the array at the start
|
||||
asm += Move(
|
||||
Register(x.ty.elemSize, AX),
|
||||
MemLocation(RAX, KnownType.Int.size.toInt, (RCX, x.ty.elemSize.toInt))
|
||||
)
|
||||
asm += stack.push(x.ty.elemSize, RAX)
|
||||
case UnaryOp(x, op) =>
|
||||
asm ++= evalExprOntoStack(x)
|
||||
op match {
|
||||
case UnaryOperator.Chr =>
|
||||
asm += Move(EAX, stack.head)
|
||||
asm += And(EAX, ImmediateVal(~_7_BIT_MASK))
|
||||
asm += Compare(EAX, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(BadChrError), Cond.NotEqual)
|
||||
case UnaryOperator.Ord => // No op needed
|
||||
case UnaryOperator.Len =>
|
||||
asm += stack.pop(RAX)
|
||||
asm += Move(EAX, MemLocation(RAX, opSize = Some(KnownType.Int.size)))
|
||||
asm += stack.push(KnownType.Int.size, RAX)
|
||||
case UnaryOperator.Negate =>
|
||||
asm += Xor(EAX, EAX)
|
||||
asm += Subtract(EAX, stack.head)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(KnownType.Int.size, RAX)
|
||||
case UnaryOperator.Not =>
|
||||
asm += Xor(stack.head, ImmediateVal(1))
|
||||
}
|
||||
|
||||
case BinaryOp(x, y, op) =>
|
||||
val destX = Register(x.ty.size, AX)
|
||||
asm ++= evalExprOntoStack(y)
|
||||
asm ++= evalExprOntoStack(x)
|
||||
asm += stack.pop(RAX)
|
||||
|
||||
op match {
|
||||
case BinaryOperator.Add =>
|
||||
asm += Add(stack.head, destX)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
case BinaryOperator.Sub =>
|
||||
asm += Subtract(destX, stack.head)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(destX.size, RAX)
|
||||
case BinaryOperator.Mul =>
|
||||
asm += Multiply(destX, stack.head)
|
||||
asm += Jump(labelGenerator.getLabelArg(OverflowError), Cond.Overflow)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(destX.size, RAX)
|
||||
|
||||
case BinaryOperator.Div =>
|
||||
asm += Compare(stack.head, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||
asm += CDQ()
|
||||
asm += Divide(stack.head)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(destX.size, RAX)
|
||||
|
||||
case BinaryOperator.Mod =>
|
||||
asm += Compare(stack.head, ImmediateVal(0))
|
||||
asm += Jump(labelGenerator.getLabelArg(ZeroDivError), Cond.Equal)
|
||||
asm += CDQ()
|
||||
asm += Divide(stack.head)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(destX.size, RDX)
|
||||
|
||||
case BinaryOperator.Eq => asm ++= generateComparison(destX, Cond.Equal)
|
||||
case BinaryOperator.Neq => asm ++= generateComparison(destX, Cond.NotEqual)
|
||||
case BinaryOperator.Greater => asm ++= generateComparison(destX, Cond.Greater)
|
||||
case BinaryOperator.GreaterEq => asm ++= generateComparison(destX, Cond.GreaterEqual)
|
||||
case BinaryOperator.Less => asm ++= generateComparison(destX, Cond.Less)
|
||||
case BinaryOperator.LessEq => asm ++= generateComparison(destX, Cond.LessEqual)
|
||||
case BinaryOperator.And => asm += And(stack.head, destX)
|
||||
case BinaryOperator.Or => asm += Or(stack.head, destX)
|
||||
}
|
||||
|
||||
case call: microWacc.Call =>
|
||||
asm ++= generateCall(call, isTail = false)
|
||||
asm += stack.push(call.ty.size, RAX)
|
||||
}
|
||||
|
||||
assert(
|
||||
stack.size == stackSizeStart + 1,
|
||||
"Sanity check: ONLY the evaluated expression should have been pushed onto the stack"
|
||||
)
|
||||
asm ++= zeroRest(stack.head.copy(opSize = Some(Size.Q64)), expr.ty.size)
|
||||
asm
|
||||
}
|
||||
|
||||
private def generateCall(call: microWacc.Call, isTail: Boolean)(using
|
||||
stack: Stack,
|
||||
labelGenerator: LabelGenerator
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
val microWacc.Call(target, args) = call
|
||||
|
||||
// Evaluate arguments 0-6
|
||||
argRegs
|
||||
.zip(args)
|
||||
.map { (reg, expr) =>
|
||||
asm ++= evalExprOntoStack(expr)
|
||||
reg
|
||||
}
|
||||
// And set the appropriate registers
|
||||
.reverse
|
||||
.foreach { reg =>
|
||||
asm += stack.pop(Register(Size.Q64, reg))
|
||||
}
|
||||
|
||||
// Evaluate arguments 7 and up and push them onto the stack
|
||||
args.drop(argRegs.size).foldMap {
|
||||
asm ++= evalExprOntoStack(_)
|
||||
}
|
||||
|
||||
// Tail Call Optimisation (TCO)
|
||||
if (isTail) {
|
||||
asm += Jump(labelGenerator.getLabelArg(target)) // tail call
|
||||
} else {
|
||||
asm += assemblyIR.Call(labelGenerator.getLabelArg(target)) // regular call
|
||||
}
|
||||
|
||||
// Remove arguments 7 and up from the stack
|
||||
if (args.size > argRegs.size) {
|
||||
asm += stack.drop(args.size - argRegs.size)
|
||||
}
|
||||
|
||||
asm
|
||||
}
|
||||
|
||||
private def generateComparison(destX: Register, cond: Cond)(using
|
||||
stack: Stack
|
||||
): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
|
||||
asm += Compare(destX, stack.head)
|
||||
asm += Set(Register(Size.B8, AX), cond)
|
||||
asm ++= zeroRest(RAX, Size.B8)
|
||||
asm += stack.drop()
|
||||
asm += stack.push(Size.B8, RAX)
|
||||
|
||||
asm
|
||||
}
|
||||
|
||||
private def funcPrologue()(using stack: Stack): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
asm += stack.push(Size.Q64, RBP)
|
||||
asm += Move(RBP, Register(Size.Q64, SP))
|
||||
asm
|
||||
}
|
||||
|
||||
private def funcEpilogue(): Chain[AsmLine] = {
|
||||
var asm = Chain.empty[AsmLine]
|
||||
asm += Move(Register(Size.Q64, SP), RBP)
|
||||
asm += Pop(RBP)
|
||||
asm += assemblyIR.Return()
|
||||
asm
|
||||
}
|
||||
|
||||
def stackAlign: AsmLine = And(Register(Size.Q64, SP), ImmediateVal(-16))
|
||||
private def zeroRest(dest: Dest, size: Size): Chain[AsmLine] = size match {
|
||||
case Size.Q64 | Size.D32 => Chain.empty
|
||||
case _ => Chain.one(And(dest, ImmediateVal((1 << (size.toInt * 8)) - 1)))
|
||||
}
|
||||
|
||||
private val escapedCharsMapping = escapedChars.map { case (k, v) => v -> s"\\$k" }
|
||||
extension (s: String) {
|
||||
def escaped: String =
|
||||
s.flatMap(c => escapedCharsMapping.getOrElse(c, c.toString))
|
||||
}
|
||||
}
|
||||
272
src/main/wacc/backend/assemblyIR.scala
Normal file
272
src/main/wacc/backend/assemblyIR.scala
Normal file
@@ -0,0 +1,272 @@
|
||||
package wacc
|
||||
|
||||
object assemblyIR {
|
||||
|
||||
sealed trait AsmLine
|
||||
sealed trait Operand
|
||||
sealed trait Src extends Operand // mem location, register and imm value
|
||||
sealed trait Dest extends Operand // mem location and register
|
||||
|
||||
enum Size {
|
||||
case Q64, D32, W16, B8
|
||||
|
||||
def toInt: Int = this match {
|
||||
case Q64 => 8
|
||||
case D32 => 4
|
||||
case W16 => 2
|
||||
case B8 => 1
|
||||
}
|
||||
|
||||
private val ptr = "ptr "
|
||||
|
||||
override def toString(): String = this match {
|
||||
case Q64 => "qword " + ptr
|
||||
case D32 => "dword " + ptr
|
||||
case W16 => "word " + ptr
|
||||
case B8 => "byte " + ptr
|
||||
}
|
||||
}
|
||||
|
||||
enum RegName {
|
||||
case AX, BX, CX, DX, SI, DI, SP, BP, IP, R8, R9, R10, R11, R12, R13, R14, R15
|
||||
}
|
||||
|
||||
case class Register(size: Size, name: RegName) extends Dest with Src {
|
||||
import RegName._
|
||||
|
||||
if (size == Size.B8 && name == RegName.IP) {
|
||||
throw new IllegalArgumentException("Cannot have 8 bit register for IP")
|
||||
}
|
||||
override def toString = name match {
|
||||
case AX => tradToString("ax", "al")
|
||||
case BX => tradToString("bx", "bl")
|
||||
case CX => tradToString("cx", "cl")
|
||||
case DX => tradToString("dx", "dl")
|
||||
case SI => tradToString("si", "sil")
|
||||
case DI => tradToString("di", "dil")
|
||||
case SP => tradToString("sp", "spl")
|
||||
case BP => tradToString("bp", "bpl")
|
||||
case IP => tradToString("ip", "#INVALID")
|
||||
case R8 => newToString(8)
|
||||
case R9 => newToString(9)
|
||||
case R10 => newToString(10)
|
||||
case R11 => newToString(11)
|
||||
case R12 => newToString(12)
|
||||
case R13 => newToString(13)
|
||||
case R14 => newToString(14)
|
||||
case R15 => newToString(15)
|
||||
}
|
||||
|
||||
private def tradToString(base: String, byteName: String): String =
|
||||
size match {
|
||||
case Size.Q64 => "r" + base
|
||||
case Size.D32 => "e" + base
|
||||
case Size.W16 => base
|
||||
case Size.B8 => byteName
|
||||
}
|
||||
|
||||
private def newToString(base: Int): String = {
|
||||
val b = base.toString
|
||||
"r" + (size match {
|
||||
case Size.Q64 => b
|
||||
case Size.D32 => b + "d"
|
||||
case Size.W16 => b + "w"
|
||||
case Size.B8 => b + "b"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// arguments
|
||||
enum CLibFunc extends Operand {
|
||||
case Scanf,
|
||||
Fflush,
|
||||
Exit,
|
||||
PrintF,
|
||||
Malloc,
|
||||
Free
|
||||
|
||||
private val plt = "@plt"
|
||||
|
||||
override def toString = this match {
|
||||
case Scanf => "scanf" + plt
|
||||
case Fflush => "fflush" + plt
|
||||
case Exit => "exit" + plt
|
||||
case PrintF => "printf" + plt
|
||||
case Malloc => "malloc" + plt
|
||||
case Free => "free" + plt
|
||||
}
|
||||
}
|
||||
|
||||
case class MemLocation(
|
||||
base: Register,
|
||||
offset: Int | LabelArg = 0,
|
||||
// scale 0 will make register irrelevant, no other reason as to why it's RAX
|
||||
scaledIndex: (Register, Int) = (Register(Size.Q64, RegName.AX), 0),
|
||||
opSize: Option[Size] = None
|
||||
) extends Dest
|
||||
with Src {
|
||||
def copy(
|
||||
base: Register = this.base,
|
||||
offset: Int | LabelArg = this.offset,
|
||||
scaledIndex: (Register, Int) = this.scaledIndex,
|
||||
opSize: Option[Size] = this.opSize
|
||||
): MemLocation = MemLocation(base, offset, scaledIndex, opSize)
|
||||
|
||||
override def toString(): String = {
|
||||
val opSizeStr = opSize.map(_.toString).getOrElse("")
|
||||
val baseStr = base.toString
|
||||
val offsetStr = offset match {
|
||||
case 0 => ""
|
||||
case off => s" + $off"
|
||||
}
|
||||
val scaledIndexStr = scaledIndex match {
|
||||
case (reg, scale) if scale != 0 => s" + $reg * $scale"
|
||||
case _ => ""
|
||||
}
|
||||
s"$opSizeStr[$baseStr$scaledIndexStr$offsetStr]"
|
||||
}
|
||||
}
|
||||
|
||||
case class ImmediateVal(value: Int) extends Src {
|
||||
override def toString = value.toString
|
||||
}
|
||||
|
||||
case class LabelArg(name: String) extends Operand {
|
||||
override def toString = name
|
||||
}
|
||||
|
||||
abstract class Operation(ins: String, ops: Operand*) extends AsmLine {
|
||||
override def toString: String = s"\t$ins ${ops.mkString(", ")}"
|
||||
}
|
||||
|
||||
// arithmetic operations
|
||||
case class Add(op1: Dest, op2: Src) extends Operation("add", op1, op2)
|
||||
case class Subtract(op1: Dest, op2: Src) extends Operation("sub", op1, op2)
|
||||
case class Multiply(ops: Operand*) extends Operation("imul", ops*)
|
||||
case class Divide(op1: Src) extends Operation("idiv", op1)
|
||||
case class Negate(op: Dest) extends Operation("neg", op)
|
||||
// bitwise operations
|
||||
case class And(op1: Dest, op2: Src) extends Operation("and", op1, op2)
|
||||
case class Or(op1: Dest, op2: Src) extends Operation("or", op1, op2)
|
||||
case class Xor(op1: Dest, op2: Src) extends Operation("xor", op1, op2)
|
||||
case class Compare(op1: Dest, op2: Src) extends Operation("cmp", op1, op2)
|
||||
case class CDQ() extends Operation("cdq")
|
||||
// stack operations
|
||||
case class Push(op1: Src) extends Operation("push", op1)
|
||||
case class Pop(op1: Src) extends Operation("pop", op1)
|
||||
// move operations
|
||||
case class Move(op1: Dest, op2: Src) extends Operation("mov", op1, op2)
|
||||
case class Load(op1: Register, op2: MemLocation) extends Operation("lea ", op1, op2)
|
||||
|
||||
// function call operations
|
||||
case class Call(op1: CLibFunc | LabelArg) extends Operation("call", op1)
|
||||
case class Return() extends Operation("ret")
|
||||
|
||||
// conditional operations
|
||||
case class Jump(op1: LabelArg, condition: Cond = Cond.Always)
|
||||
extends Operation(s"j${condition.toString}", op1)
|
||||
case class Set(op1: Dest, condition: Cond = Cond.Always)
|
||||
extends Operation(s"set${condition.toString}", op1)
|
||||
|
||||
case class LabelDef(name: String) extends AsmLine {
|
||||
override def toString = s"$name:"
|
||||
}
|
||||
|
||||
case class Comment(comment: String) extends AsmLine {
|
||||
override def toString =
|
||||
comment.split("\n").map(line => s"# ${line}").mkString("\n")
|
||||
}
|
||||
|
||||
enum Cond {
|
||||
case Equal,
|
||||
NotEqual,
|
||||
Greater,
|
||||
GreaterEqual,
|
||||
Less,
|
||||
LessEqual,
|
||||
Overflow,
|
||||
Always
|
||||
override def toString(): String = this match {
|
||||
case Equal => "e"
|
||||
case NotEqual => "ne"
|
||||
case Greater => "g"
|
||||
case GreaterEqual => "ge"
|
||||
case Less => "l"
|
||||
case LessEqual => "le"
|
||||
case Overflow => "o"
|
||||
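// Always renders as "mp" so that an unconditional Jump prints as "jmp"; it is not meaningful for Set.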
case Always => "mp"
|
||||
}
|
||||
}
|
||||
|
||||
enum Directive extends AsmLine {
|
||||
case IntelSyntax, RoData, Text, EndFunc
|
||||
case Global(name: String)
|
||||
case Int(value: scala.Int)
|
||||
case Asciz(string: String)
|
||||
case File(no: scala.Int, file: String)
|
||||
case Location(fileNo: scala.Int, lineNo: scala.Int, colNo: Option[scala.Int])
|
||||
case Func(name: String, label: LabelDef)
|
||||
case Type(label: LabelDef, symbolType: SymbolType)
|
||||
case Size(label: LabelDef, expr: SizeExpr)
|
||||
|
||||
override def toString(): String = this match {
|
||||
case IntelSyntax => ".intel_syntax noprefix"
|
||||
case Global(name) => s".globl $name"
|
||||
case Text => ".text"
|
||||
case RoData => ".section .rodata"
|
||||
case Int(value) => s"\t.int $value"
|
||||
case Asciz(string) => s"\t.asciz \"$string\""
|
||||
case File(no, file) => s".file $no \"${file}\""
|
||||
case Location(fileNo, lineNo, colNo) =>
|
||||
s"\t.loc $fileNo $lineNo" + colNo.map(c => s" $c").getOrElse("")
|
||||
case Func(name, label) =>
|
||||
s".func $name, ${label.name}"
|
||||
case EndFunc => ".endfunc"
|
||||
case Type(label, symbolType) =>
|
||||
s".type ${label.name}, @${symbolType.toString}"
|
||||
case Directive.Size(label, expr) =>
|
||||
s".size ${label.name}, ${expr.toString}"
|
||||
}
|
||||
}
|
||||
|
||||
enum SymbolType {
|
||||
case Function
|
||||
|
||||
override def toString(): String = this match {
|
||||
case Function => "function"
|
||||
}
|
||||
}
|
||||
|
||||
enum SizeExpr {
|
||||
case Relative(label: LabelDef)
|
||||
|
||||
override def toString(): String = this match {
|
||||
case Relative(label) => s".-${label.name}"
|
||||
}
|
||||
}
|
||||
|
||||
enum PrintFormat {
|
||||
case Int, Char, String
|
||||
|
||||
override def toString(): String = this match {
|
||||
case Int => "%d"
|
||||
case Char => "%c"
|
||||
case String => "%s"
|
||||
}
|
||||
}
|
||||
|
||||
object commonRegisters {
|
||||
import Size._
|
||||
import RegName._
|
||||
|
||||
val RAX = Register(Q64, AX)
|
||||
val EAX = Register(D32, AX)
|
||||
val RDI = Register(Q64, DI)
|
||||
val RIP = Register(Q64, IP)
|
||||
val RBP = Register(Q64, BP)
|
||||
val RSI = Register(Q64, SI)
|
||||
val RDX = Register(Q64, DX)
|
||||
val RCX = Register(Q64, CX)
|
||||
val ECX = Register(D32, CX)
|
||||
}
|
||||
}
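
The IR above pretty-prints directly to Intel-syntax text via each node's toString. A minimal sketch of how a caller might render a few lines (illustrative only; a real WACC program needs the full prologue the code generator emits):

import wacc.assemblyIR._
import wacc.assemblyIR.commonRegisters._

val demo: List[AsmLine] = List(
  Directive.IntelSyntax,        // ".intel_syntax noprefix"
  Directive.Global("main"),     // ".globl main"
  LabelDef("main"),             // "main:"
  Move(RAX, ImmediateVal(0)),   // "\tmov rax, 0"
  Return()                      // "\tret"
)
demo.foreach(line => println(line))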
|
||||
33
src/main/wacc/backend/sizeExtensions.scala
Normal file
@@ -0,0 +1,33 @@
|
||||
package wacc
|
||||
|
||||
object sizeExtensions {
|
||||
import microWacc._
|
||||
import types._
|
||||
import assemblyIR.Size
|
||||
|
||||
extension (expr: Expr) {
|
||||
|
||||
/** Calculate the size (bytes) of the heap required for the expression. */
|
||||
def heapSize: Int = (expr, expr.ty) match {
|
||||
case (ArrayLiter(elems), ty) =>
|
||||
KnownType.Int.size.toInt + elems.size * ty.elemSize.toInt
|
||||
case _ => expr.ty.size.toInt
|
||||
}
|
||||
}
|
||||
|
||||
extension (ty: SemType) {
|
||||
|
||||
/** Calculate the size (bytes) of a type in a register. */
|
||||
def size: Size = ty match {
|
||||
case KnownType.Int => Size.D32
|
||||
case KnownType.Bool | KnownType.Char => Size.B8
|
||||
case KnownType.String | KnownType.Array(_) | KnownType.Pair(_, _) | ? => Size.Q64
|
||||
}
|
||||
|
||||
def elemSize: Size = ty match {
|
||||
case KnownType.Array(elem) => elem.size
|
||||
case KnownType.Pair(_, _) => Size.Q64
|
||||
case _ => ty.size
|
||||
}
|
||||
}
|
||||
}
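
As a quick worked example of the sizing rules above (a sketch; it assumes Size.D32 counts as 4 bytes, which is not shown in this file): an int[] literal with three elements needs 4 bytes for the length prefix plus 3 * 4 bytes of payload, i.e. heapSize == 16, while the array reference itself occupies a Q64 register and its elemSize is Size.D32.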
|
||||
42
src/main/wacc/backend/writer.scala
Normal file
@@ -0,0 +1,42 @@
|
||||
package wacc
|
||||
|
||||
import cats.effect.Resource
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.io.BufferedWriter
|
||||
import java.io.FileWriter
|
||||
import cats.data.Chain
|
||||
import cats.effect.IO
|
||||
|
||||
import org.typelevel.log4cats.Logger
|
||||
import java.nio.file.Path
|
||||
|
||||
object writer {
|
||||
import assemblyIR._
|
||||
|
||||
// TODO: Judging from documentation it seems as though IO.blocking is the correct choice
|
||||
// But needs checking
|
||||
|
||||
/** Creates a resource-safe BufferedWriter */
|
||||
private def bufferedWriter(outputPath: Path): Resource[IO, BufferedWriter] =
|
||||
Resource.make {
|
||||
IO.blocking(new BufferedWriter(new FileWriter(outputPath.toFile, StandardCharsets.UTF_8)))
|
||||
} { writer =>
|
||||
IO.blocking(writer.close())
|
||||
.handleErrorWith(_ => IO.unit) // TODO: ensures writer is closed even if an error occurs
|
||||
}
|
||||
|
||||
/** Write lines safely into a BufferedWriter */
|
||||
private def writeLines(writer: BufferedWriter, lines: Chain[AsmLine]): IO[Unit] =
|
||||
IO.blocking {
|
||||
lines.iterator.foreach { line =>
|
||||
writer.write(line.toString)
|
||||
writer.newLine()
|
||||
}
|
||||
}
|
||||
|
||||
/** Main function to write assembly to a file */
|
||||
def writeTo(asmList: Chain[AsmLine], outputPath: Path)(using logger: Logger[IO]): IO[Unit] =
|
||||
bufferedWriter(outputPath).use {
|
||||
writeLines(_, asmList)
|
||||
}
|
||||
}
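
A sketch of driving writeTo from a cats-effect entry point (the Slf4jLogger line assumes the log4cats-slf4j module is on the classpath; the output path is illustrative):

import cats.data.Chain
import cats.effect.{IO, IOApp}
import java.nio.file.Paths
import org.typelevel.log4cats.Logger
import org.typelevel.log4cats.slf4j.Slf4jLogger
import wacc.{assemblyIR, writer}

object WriteDemo extends IOApp.Simple {
  given Logger[IO] = Slf4jLogger.getLogger[IO]
  def run: IO[Unit] =
    writer.writeTo(Chain(assemblyIR.Directive.IntelSyntax), Paths.get("out.s"))
}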
|
||||
126
src/main/wacc/frontend/Error.scala
Normal file
@@ -0,0 +1,126 @@
|
||||
package wacc
|
||||
|
||||
import wacc.ast.Position
|
||||
import wacc.types._
|
||||
import java.io.File
|
||||
|
||||
private val SYNTAX_ERROR = 100
|
||||
private val SEMANTIC_ERROR = 200
|
||||
|
||||
/** Error types for semantic errors
|
||||
*/
|
||||
enum Error {
|
||||
case DuplicateDeclaration(ident: ast.Ident)
|
||||
case UndeclaredVariable(ident: ast.Ident)
|
||||
case UndefinedFunction(ident: ast.Ident)
|
||||
|
||||
case FunctionParamsMismatch(ident: ast.Ident, expected: Int, got: Int, funcType: FuncType)
|
||||
case SemanticError(pos: Position, msg: String)
|
||||
case TypeMismatch(pos: Position, expected: SemType, got: SemType, msg: String)
|
||||
case InternalError(pos: Position, msg: String)
|
||||
|
||||
case SyntaxError(file: File, msg: String)
|
||||
}
|
||||
|
||||
extension (e: Error) {
|
||||
def exitCode: Int = e match {
|
||||
case Error.SyntaxError(_, _) => SYNTAX_ERROR
|
||||
case _ => SEMANTIC_ERROR
|
||||
}
|
||||
}
|
||||
|
||||
/** Function to handle printing the details of a given semantic error
|
||||
*
|
||||
* @param error
|
||||
* Error object
|
||||
* @param errorContent
|
||||
* Contents of the file to generate code snippets
|
||||
*/
|
||||
def formatError(error: Error): String = {
|
||||
val sb = new StringBuilder()
|
||||
|
||||
/** Format the file of an error
|
||||
*
|
||||
* @param file
|
||||
* File of the error
|
||||
*/
|
||||
def formatFile(file: File): Unit = {
|
||||
sb.append(s"File: ${file.getCanonicalPath}\n")
|
||||
}
|
||||
|
||||
/** Function to format the position of an error
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
*/
|
||||
def formatPosition(pos: Position): Unit = {
|
||||
formatFile(pos.file)
|
||||
sb.append(s"(line ${pos.line}, column ${pos.column}):\n")
|
||||
}
|
||||
|
||||
/** Function to highlight a section of code for an error message
|
||||
*
|
||||
* @param pos
|
||||
* Position of the error
|
||||
* @param size
|
||||
* Size (in chars) of the section to highlight
|
||||
*/
|
||||
def formatHighlight(pos: Position, size: Int): Unit = {
|
||||
val lines = os.read(os.Path(pos.file.getCanonicalPath)).split("\n")
|
||||
val preLine = if (pos.line > 1) lines(pos.line - 2) else ""
|
||||
val midLine = lines(pos.line - 1)
|
||||
val postLine = if (pos.line < lines.size) lines(pos.line) else ""
|
||||
val linePointer = " " * (pos.column + 2) + ("^" * (size)) + "\n"
|
||||
|
||||
sb.append(
|
||||
s" >$preLine\n >$midLine\n$linePointer >$postLine\netscape"
|
||||
)
|
||||
}
|
||||
|
||||
error match {
|
||||
case Error.SyntaxError(_, _) =>
|
||||
sb.append("Syntax error:\n")
|
||||
case _ =>
|
||||
sb.append("Semantic error:\n")
|
||||
}
|
||||
|
||||
error match {
|
||||
case Error.DuplicateDeclaration(ident) =>
|
||||
formatPosition(ident.pos)
|
||||
sb.append(s"Duplicate declaration of identifier ${ident.v}\n")
|
||||
formatHighlight(ident.pos, ident.v.length)
|
||||
case Error.UndeclaredVariable(ident) =>
|
||||
formatPosition(ident.pos)
|
||||
sb.append(s"Undeclared variable ${ident.v}\n")
|
||||
formatHighlight(ident.pos, ident.v.length)
|
||||
case Error.UndefinedFunction(ident) =>
|
||||
formatPosition(ident.pos)
|
||||
sb.append(s"Undefined function ${ident.v}\n")
|
||||
formatHighlight(ident.pos, ident.v.length)
|
||||
case Error.FunctionParamsMismatch(id, expected, got, funcType) =>
|
||||
formatPosition(id.pos)
|
||||
sb.append(s"Function expects $expected parameters, got $got\n")
|
||||
sb.append(
|
||||
s"(function ${id.v} has type (${funcType.params.mkString(", ")}) -> ${funcType.returnType})\n"
|
||||
)
|
||||
formatHighlight(id.pos, 1)
|
||||
case Error.TypeMismatch(pos, expected, got, msg) =>
|
||||
formatPosition(pos)
|
||||
sb.append(s"Type mismatch: $msg\nExpected: $expected\nGot: $got\n")
|
||||
formatHighlight(pos, 1)
|
||||
case Error.SemanticError(pos, msg) =>
|
||||
formatPosition(pos)
|
||||
sb.append(msg + "\n")
|
||||
formatHighlight(pos, 1)
|
||||
case wacc.Error.InternalError(pos, msg) =>
|
||||
formatPosition(pos)
|
||||
sb.append(s"Internal error: $msg\n")
|
||||
formatHighlight(pos, 1)
|
||||
case Error.SyntaxError(file, msg) =>
|
||||
formatFile(file)
|
||||
sb.append(msg + "\n")
|
||||
sb.append("\n")
|
||||
}
|
||||
|
||||
sb.toString()
|
||||
}
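
A sketch of how a compiler driver might surface these errors (the report helper and its errors argument are hypothetical; exitCode comes from the extension above):

// assumes this lives in package wacc, next to formatError and the exitCode extension
def report(errors: List[Error]): Unit = {
  errors.foreach(e => Console.err.println(formatError(e)))
  errors.headOption.foreach(e => sys.exit(e.exitCode))
}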
|
||||
@@ -1,10 +1,10 @@
|
||||
package wacc
|
||||
|
||||
import java.io.File
|
||||
import parsley.Parsley
|
||||
import parsley.generic.ErrorBridge
|
||||
import parsley.ap._
|
||||
import parsley.position._
|
||||
import parsley.syntax.zipped._
|
||||
import cats.data.NonEmptyList
|
||||
|
||||
object ast {
|
||||
@@ -23,26 +23,42 @@ object ast {
|
||||
/* ============================ ATOMIC EXPRESSIONS ============================ */
|
||||
|
||||
case class IntLiter(v: Int)(val pos: Position) extends Expr6
|
||||
object IntLiter extends ParserBridgePos1[Int, IntLiter]
|
||||
object IntLiter extends ParserBridgePos1Atom[Int, IntLiter]
|
||||
case class BoolLiter(v: Boolean)(val pos: Position) extends Expr6
|
||||
object BoolLiter extends ParserBridgePos1[Boolean, BoolLiter]
|
||||
object BoolLiter extends ParserBridgePos1Atom[Boolean, BoolLiter]
|
||||
case class CharLiter(v: Char)(val pos: Position) extends Expr6
|
||||
object CharLiter extends ParserBridgePos1[Char, CharLiter]
|
||||
object CharLiter extends ParserBridgePos1Atom[Char, CharLiter]
|
||||
case class StrLiter(v: String)(val pos: Position) extends Expr6
|
||||
object StrLiter extends ParserBridgePos1[String, StrLiter]
|
||||
object StrLiter extends ParserBridgePos1Atom[String, StrLiter]
|
||||
case class PairLiter()(val pos: Position) extends Expr6
|
||||
object PairLiter extends ParserBridgePos0[PairLiter]
|
||||
case class Ident(v: String, var uid: Int = -1)(val pos: Position) extends Expr6 with LValue
|
||||
object Ident extends ParserBridgePos1[String, Ident] {
|
||||
case class Ident(var v: String, var guid: Int = -1, var ty: types.RenamerType = types.?)(
|
||||
val pos: Position
|
||||
) extends Expr6
|
||||
with LValue
|
||||
object Ident extends ParserBridgePos1Atom[String, Ident] {
|
||||
def apply(v: String)(pos: Position): Ident = new Ident(v)(pos)
|
||||
}
|
||||
case class ArrayElem(name: Ident, indices: NonEmptyList[Expr])(val pos: Position)
|
||||
extends Expr6
|
||||
with LValue
|
||||
object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], Ident => ArrayElem] {
|
||||
def apply(a: NonEmptyList[Expr])(pos: Position): Ident => ArrayElem =
|
||||
name => ArrayElem(name, a)(pos)
|
||||
object ArrayElem extends ParserBridgePos2Chain[NonEmptyList[Expr], Ident, ArrayElem] {
|
||||
def apply(indices: NonEmptyList[Expr], name: Ident)(pos: Position): ArrayElem =
|
||||
new ArrayElem(name, indices)(pos)
|
||||
}
|
||||
// object ArrayElem extends ParserBridgePos1[NonEmptyList[Expr], (File => Ident) => ArrayElem] {
|
||||
// def apply(a: NonEmptyList[Expr])(pos: Position): (File => Ident) => ArrayElem =
|
||||
// name => ArrayElem(name(pos.file), a)(pos)
|
||||
// }
|
||||
// object ArrayElem extends ParserSingletonBridgePos[(File => NonEmptyList[Expr]) => (File => Ident) => File => ArrayElem] {
|
||||
// // def apply(indices: NonEmptyList[Expr]): (File => Ident) => File => ArrayElem =
|
||||
// // name => file => new ArrayElem(name(file), )
|
||||
// def apply(indices: Parsley[File => NonEmptyList[Expr]]): Parsley[(File => Ident) => File => ArrayElem] =
|
||||
// // error(ap1(pos.map(con),))
|
||||
|
||||
// override final def con(pos: (Int, Int)): (File => NonEmptyList[Expr]) => => C =
|
||||
// (a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
|
||||
// }
|
||||
case class Parens(expr: Expr)(val pos: Position) extends Expr6
|
||||
object Parens extends ParserBridgePos1[Expr, Parens]
|
||||
|
||||
@@ -120,8 +136,9 @@ object ast {
|
||||
case class ArrayType(elemType: Type, dimensions: Int)(val pos: Position)
|
||||
extends Type
|
||||
with PairElemType
|
||||
object ArrayType extends ParserBridgePos1[Int, Type => ArrayType] {
|
||||
def apply(a: Int)(pos: Position): Type => ArrayType = elemType => ArrayType(elemType, a)(pos)
|
||||
object ArrayType extends ParserBridgePos2Chain[Int, Type, ArrayType] {
|
||||
def apply(dimensions: Int, elemType: Type)(pos: Position): ArrayType =
|
||||
ArrayType(elemType, dimensions)(pos)
|
||||
}
|
||||
case class PairType(fst: PairElemType, snd: PairElemType)(val pos: Position) extends Type
|
||||
object PairType extends ParserBridgePos2[PairElemType, PairElemType, PairType]
|
||||
@@ -132,6 +149,18 @@ object ast {
|
||||
|
||||
/* ============================ PROGRAM STRUCTURE ============================ */
|
||||
|
||||
case class ImportedFunc(sourceName: Ident, importName: Ident)(val pos: Position)
|
||||
object ImportedFunc extends ParserBridgePos2[Ident, Option[Ident], ImportedFunc] {
|
||||
def apply(a: Ident, b: Option[Ident])(pos: Position): ImportedFunc =
|
||||
new ImportedFunc(a, b.getOrElse(a))(pos)
|
||||
}
|
||||
|
||||
case class Import(source: StrLiter, funcs: NonEmptyList[ImportedFunc])(val pos: Position)
|
||||
object Import extends ParserBridgePos2[StrLiter, NonEmptyList[ImportedFunc], Import]
|
||||
|
||||
case class PartialProgram(imports: List[Import], self: Program)(val pos: Position)
|
||||
object PartialProgram extends ParserBridgePos2[List[Import], Program, PartialProgram]
|
||||
|
||||
case class Program(funcs: List[FuncDecl], main: NonEmptyList[Stmt])(val pos: Position)
|
||||
object Program extends ParserBridgePos2[List[FuncDecl], NonEmptyList[Stmt], Program]
|
||||
|
||||
@@ -144,15 +173,15 @@ object ast {
|
||||
body: NonEmptyList[Stmt]
|
||||
)(val pos: Position)
|
||||
object FuncDecl
|
||||
extends ParserBridgePos2[
|
||||
List[Param],
|
||||
NonEmptyList[Stmt],
|
||||
((Type, Ident)) => FuncDecl
|
||||
extends ParserBridgePos2Chain[
|
||||
(List[Param], NonEmptyList[Stmt]),
|
||||
((Type, Ident)),
|
||||
FuncDecl
|
||||
] {
|
||||
def apply(params: List[Param], body: NonEmptyList[Stmt])(
|
||||
def apply(paramsBody: (List[Param], NonEmptyList[Stmt]), retTyName: (Type, Ident))(
|
||||
pos: Position
|
||||
): ((Type, Ident)) => FuncDecl =
|
||||
(returnType, name) => FuncDecl(returnType, name, params, body)(pos)
|
||||
): FuncDecl =
|
||||
new FuncDecl(retTyName._1, retTyName._2, paramsBody._1, paramsBody._2)(pos)
|
||||
}
|
||||
|
||||
case class Param(paramType: Type, name: Ident)(val pos: Position)
|
||||
@@ -160,7 +189,9 @@ object ast {
|
||||
|
||||
/* ============================ STATEMENTS ============================ */
|
||||
|
||||
sealed trait Stmt
|
||||
sealed trait Stmt {
|
||||
val pos: Position
|
||||
}
|
||||
case class Skip()(val pos: Position) extends Stmt
|
||||
object Skip extends ParserBridgePos0[Skip]
|
||||
case class VarDecl(varType: Type, name: Ident, value: RValue)(val pos: Position) extends Stmt
|
||||
@@ -192,7 +223,9 @@ object ast {
|
||||
val pos: Position
|
||||
}
|
||||
|
||||
sealed trait RValue
|
||||
sealed trait RValue {
|
||||
val pos: Position
|
||||
}
|
||||
case class ArrayLiter(elems: List[Expr])(val pos: Position) extends RValue
|
||||
object ArrayLiter extends ParserBridgePos1[List[Expr], ArrayLiter]
|
||||
case class NewPair(fst: Expr, snd: Expr)(val pos: Position) extends RValue
|
||||
@@ -208,7 +241,7 @@ object ast {
|
||||
|
||||
/* ============================ PARSER BRIDGES ============================ */
|
||||
|
||||
case class Position(line: Int, column: Int)
|
||||
case class Position(line: Int, column: Int, file: File)
|
||||
|
||||
trait ParserSingletonBridgePos[+A] extends ErrorBridge {
|
||||
protected def con(pos: (Int, Int)): A
|
||||
@@ -216,38 +249,63 @@ object ast {
|
||||
final def <#(op: Parsley[?]): Parsley[A] = this from op
|
||||
}
|
||||
|
||||
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[A] {
|
||||
trait ParserBridgePos0[+A] extends ParserSingletonBridgePos[File => A] {
|
||||
def apply()(pos: Position): A
|
||||
|
||||
override final def con(pos: (Int, Int)): A =
|
||||
apply()(Position(pos._1, pos._2))
|
||||
override final def con(pos: (Int, Int)): File => A =
|
||||
file => apply()(Position(pos._1, pos._2, file))
|
||||
}
|
||||
|
||||
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[A => B] {
|
||||
trait ParserBridgePos1Atom[-A, +B] extends ParserSingletonBridgePos[A => File => B] {
|
||||
def apply(a: A)(pos: Position): B
|
||||
def apply(a: Parsley[A]): Parsley[B] = error(ap1(pos.map(con), a))
|
||||
def apply(a: Parsley[A]): Parsley[File => B] = error(ap1(pos.map(con), a))
|
||||
|
||||
override final def con(pos: (Int, Int)): A => B =
|
||||
this.apply(_)(Position(pos._1, pos._2))
|
||||
override final def con(pos: (Int, Int)): A => File => B =
|
||||
a => file => this.apply(a)(Position(pos._1, pos._2, file))
|
||||
}
|
||||
|
||||
trait ParserBridgePos2[-A, -B, +C] extends ParserSingletonBridgePos[(A, B) => C] {
|
||||
trait ParserBridgePos1[-A, +B] extends ParserSingletonBridgePos[(File => A) => File => B] {
|
||||
def apply(a: A)(pos: Position): B
|
||||
def apply(a: Parsley[File => A]): Parsley[File => B] = error(ap1(pos.map(con), a))
|
||||
|
||||
override final def con(pos: (Int, Int)): (File => A) => File => B =
|
||||
a => file => this.apply(a(file))(Position(pos._1, pos._2, file))
|
||||
}
|
||||
|
||||
trait ParserBridgePos2Chain[-A, -B, +C]
|
||||
extends ParserSingletonBridgePos[(File => A) => (File => B) => File => C] {
|
||||
def apply(a: A, b: B)(pos: Position): C
|
||||
def apply(a: Parsley[A], b: => Parsley[B]): Parsley[C] = error(
|
||||
def apply(a: Parsley[File => A]): Parsley[(File => B) => File => C] = error(
|
||||
ap1(pos.map(con), a)
|
||||
)
|
||||
|
||||
override final def con(pos: (Int, Int)): (File => A) => (File => B) => File => C =
|
||||
a => b => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
|
||||
}
|
||||
|
||||
trait ParserBridgePos2[-A, -B, +C]
|
||||
extends ParserSingletonBridgePos[(File => A, File => B) => File => C] {
|
||||
def apply(a: A, b: B)(pos: Position): C
|
||||
def apply(a: Parsley[File => A], b: => Parsley[File => B]): Parsley[File => C] = error(
|
||||
ap2(pos.map(con), a, b)
|
||||
)
|
||||
|
||||
override final def con(pos: (Int, Int)): (A, B) => C =
|
||||
apply(_, _)(Position(pos._1, pos._2))
|
||||
override final def con(pos: (Int, Int)): (File => A, File => B) => File => C =
|
||||
(a, b) => file => this.apply(a(file), b(file))(Position(pos._1, pos._2, file))
|
||||
}
|
||||
|
||||
trait ParserBridgePos3[-A, -B, -C, +D] extends ParserSingletonBridgePos[(A, B, C) => D] {
|
||||
trait ParserBridgePos3[-A, -B, -C, +D]
|
||||
extends ParserSingletonBridgePos[(File => A, File => B, File => C) => File => D] {
|
||||
def apply(a: A, b: B, c: C)(pos: Position): D
|
||||
def apply(a: Parsley[A], b: => Parsley[B], c: => Parsley[C]): Parsley[D] = error(
|
||||
def apply(
|
||||
a: Parsley[File => A],
|
||||
b: => Parsley[File => B],
|
||||
c: => Parsley[File => C]
|
||||
): Parsley[File => D] = error(
|
||||
ap3(pos.map(con), a, b, c)
|
||||
)
|
||||
|
||||
override final def con(pos: (Int, Int)): (A, B, C) => D =
|
||||
apply(_, _, _)(Position(pos._1, pos._2))
|
||||
override final def con(pos: (Int, Int)): (File => A, File => B, File => C) => File => D =
|
||||
(a, b, c) => file => apply(a(file), b(file), c(file))(Position(pos._1, pos._2, file))
|
||||
}
|
||||
}
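
The practical effect of the File-threading bridges above: every parser now produces a File => A, and the driver applies the source File once it knows which file it parsed. A sketch mirroring the usage in the renamer (sourceText and the path are illustrative):

parser.parse(sourceText) match {
  case parsley.Success(fn) =>
    val prog: ast.PartialProgram = fn(new java.io.File("main.wacc"))
    // ... continue with semantic checking on prog
  case parsley.Failure(msg) =>
    Console.err.println(msg)
}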
|
||||
@@ -39,6 +39,17 @@ val errConfig = new ErrorConfig {
|
||||
)
|
||||
}
|
||||
object lexer {
|
||||
val escapedChars: Map[String, Int] = Map(
|
||||
"0" -> '\u0000',
|
||||
"b" -> '\b',
|
||||
"t" -> '\t',
|
||||
"n" -> '\n',
|
||||
"f" -> '\f',
|
||||
"r" -> '\r',
|
||||
"\\" -> '\\',
|
||||
"'" -> '\'',
|
||||
"\"" -> '\"'
|
||||
)
|
||||
|
||||
/** Language description for the WACC lexer
|
||||
*/
|
||||
@@ -63,15 +74,9 @@ object lexer {
|
||||
textDesc = TextDesc.plain.copy(
|
||||
graphicCharacter = Basic(c => c >= ' ' && c != '\\' && c != '\'' && c != '"'),
|
||||
escapeSequences = EscapeDesc.plain.copy(
|
||||
literals = Set('\\', '"', '\''),
|
||||
mapping = Map(
|
||||
"0" -> '\u0000',
|
||||
"b" -> '\b',
|
||||
"t" -> '\t',
|
||||
"n" -> '\n',
|
||||
"f" -> '\f',
|
||||
"r" -> '\r'
|
||||
)
|
||||
literals =
|
||||
escapedChars.filter { (s, chr) => chr.toChar.toString == s }.map(_._2.toChar).toSet,
|
||||
mapping = escapedChars.filter { (s, chr) => chr.toChar.toString != s }
|
||||
)
|
||||
),
|
||||
numericDesc = NumericDesc.plain.copy(
|
||||
97
src/main/wacc/frontend/microWacc.scala
Normal file
@@ -0,0 +1,97 @@
|
||||
package wacc
|
||||
|
||||
import cats.data.Chain
|
||||
|
||||
object microWacc {
|
||||
import wacc.ast.Position
|
||||
import wacc.types._
|
||||
|
||||
sealed trait CallTarget(val retTy: SemType)
|
||||
sealed trait Expr(val ty: SemType)
|
||||
sealed trait LValue extends Expr
|
||||
|
||||
// Atomic expressions
|
||||
case class IntLiter(v: Int) extends Expr(KnownType.Int)
|
||||
case class BoolLiter(v: Boolean) extends Expr(KnownType.Bool)
|
||||
case class CharLiter(v: Char) extends Expr(KnownType.Char)
|
||||
case class ArrayLiter(elems: List[Expr])(ty: SemType, val pos: Position) extends Expr(ty)
|
||||
case class NullLiter()(ty: SemType) extends Expr(ty)
|
||||
case class Ident(name: String, uid: Int)(identTy: SemType)
|
||||
extends Expr(identTy)
|
||||
with CallTarget(identTy)
|
||||
with LValue
|
||||
case class ArrayElem(value: LValue, index: Expr)(ty: SemType) extends Expr(ty) with LValue
|
||||
|
||||
// Operators
|
||||
case class UnaryOp(x: Expr, op: UnaryOperator)(ty: SemType) extends Expr(ty)
|
||||
enum UnaryOperator {
|
||||
case Negate
|
||||
case Not
|
||||
case Len
|
||||
case Ord
|
||||
case Chr
|
||||
}
|
||||
case class BinaryOp(x: Expr, y: Expr, op: BinaryOperator)(ty: SemType) extends Expr(ty)
|
||||
enum BinaryOperator {
|
||||
case Add
|
||||
case Sub
|
||||
case Mul
|
||||
case Div
|
||||
case Mod
|
||||
case Greater
|
||||
case GreaterEq
|
||||
case Less
|
||||
case LessEq
|
||||
case Eq
|
||||
case Neq
|
||||
case And
|
||||
case Or
|
||||
}
|
||||
object BinaryOperator {
|
||||
def fromAst(op: ast.BinaryOp): BinaryOperator = op match {
|
||||
case _: ast.Add => Add
|
||||
case _: ast.Sub => Sub
|
||||
case _: ast.Mul => Mul
|
||||
case _: ast.Div => Div
|
||||
case _: ast.Mod => Mod
|
||||
case _: ast.Greater => Greater
|
||||
case _: ast.GreaterEq => GreaterEq
|
||||
case _: ast.Less => Less
|
||||
case _: ast.LessEq => LessEq
|
||||
case _: ast.Eq => Eq
|
||||
case _: ast.Neq => Neq
|
||||
case _: ast.And => And
|
||||
case _: ast.Or => Or
|
||||
}
|
||||
}
|
||||
|
||||
// Statements
|
||||
sealed trait Stmt {
|
||||
val pos: Position
|
||||
}
|
||||
|
||||
case class Builtin(val name: String)(retTy: SemType) extends CallTarget(retTy) {
|
||||
override def toString(): String = name
|
||||
}
|
||||
object Builtin {
|
||||
object Read extends Builtin("read")(?)
|
||||
object Printf extends Builtin("printf")(?)
|
||||
object Exit extends Builtin("exit")(?)
|
||||
object Free extends Builtin("free")(?)
|
||||
object Malloc extends Builtin("malloc")(?)
|
||||
object PrintCharArray extends Builtin("printCharArray")(?)
|
||||
}
|
||||
|
||||
case class Assign(lhs: LValue, rhs: Expr)(val pos: Position) extends Stmt
|
||||
case class If(cond: Expr, thenBranch: Chain[Stmt], elseBranch: Chain[Stmt])(val pos: Position)
|
||||
extends Stmt
|
||||
case class While(cond: Expr, body: Chain[Stmt])(val pos: Position) extends Stmt
|
||||
case class Call(target: CallTarget, args: List[Expr])(val pos: Position)
|
||||
extends Stmt
|
||||
with Expr(target.retTy)
|
||||
case class Return(expr: Expr)(val pos: Position) extends Stmt
|
||||
|
||||
// Program
|
||||
case class FuncDecl(name: Ident, params: List[Ident], body: Chain[Stmt])(val pos: Position)
|
||||
case class Program(funcs: Chain[FuncDecl], stmts: Chain[Stmt])(val pos: Position)
|
||||
}
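
A sketch of building a microWacc expression by hand, showing how the semantic type rides in the second parameter list (values illustrative):

import wacc.microWacc._
import wacc.types.KnownType

val sum = BinaryOp(IntLiter(1), IntLiter(2), BinaryOperator.Add)(KnownType.Int)
assert(sum.ty == KnownType.Int)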
|
||||
@@ -1,18 +1,22 @@
|
||||
package wacc
|
||||
|
||||
import java.io.File
|
||||
import parsley.Result
|
||||
import parsley.Parsley
|
||||
import parsley.Parsley.{atomic, many, notFollowedBy, pure, unit}
|
||||
import parsley.combinator.{countSome, sepBy}
|
||||
import parsley.combinator.{countSome, sepBy, option}
|
||||
import parsley.expr.{precedence, SOps, InfixL, InfixN, InfixR, Prefix, Atoms}
|
||||
import parsley.errors.combinator._
|
||||
import parsley.errors.patterns.VerifiedErrors
|
||||
import parsley.syntax.zipped._
|
||||
import parsley.cats.combinator.{some}
|
||||
import parsley.cats.combinator.{some, sepBy1}
|
||||
import cats.syntax.all._
|
||||
import cats.data.NonEmptyList
|
||||
import parsley.errors.DefaultErrorBuilder
|
||||
import parsley.errors.ErrorBuilder
|
||||
import parsley.errors.tokenextractors.LexToken
|
||||
import parsley.expr.GOps
|
||||
import cats.Functor
|
||||
|
||||
object parser {
|
||||
import lexer.implicits.implicitSymbol
|
||||
@@ -52,13 +56,24 @@ object parser {
|
||||
implicit val builder: ErrorBuilder[String] = new DefaultErrorBuilder with LexToken {
|
||||
def tokens = errTokens
|
||||
}
|
||||
def parse(input: String): Result[String, Program] = parser.parse(input)
|
||||
private val parser = lexer.fully(`<program>`)
|
||||
def parse(input: String): Result[String, File => PartialProgram] = parser.parse(input)
|
||||
private val parser = lexer.fully(`<partial-program>`)
|
||||
|
||||
private type FParsley[A] = Parsley[File => A]
|
||||
|
||||
private def fParsley[A](p: Parsley[A]): FParsley[A] =
|
||||
p map { a => file => a }
|
||||
|
||||
private def fPair[A, B](p: Parsley[(File => A, File => B)]): FParsley[(A, B)] =
|
||||
p map { case (a, b) => file => (a(file), b(file)) }
|
||||
|
||||
private def fMap[A, F[_]: Functor](p: Parsley[F[File => A]]): FParsley[F[A]] =
|
||||
p map { funcs => file => funcs.map(_(file)) }
|
||||
|
||||
// Expressions
|
||||
private lazy val `<expr>`: Parsley[Expr] = precedence {
|
||||
SOps(InfixR)(Or from "||") +:
|
||||
SOps(InfixR)(And from "&&") +:
|
||||
private lazy val `<expr>`: FParsley[Expr] = precedence {
|
||||
GOps(InfixR)(Or from "||") +:
|
||||
GOps(InfixR)(And from "&&") +:
|
||||
SOps(InfixN)(Eq from "==", Neq from "!=") +:
|
||||
SOps(InfixN)(
|
||||
Less from "<",
|
||||
@@ -83,32 +98,33 @@ object parser {
|
||||
}
|
||||
|
||||
// Atoms
|
||||
private lazy val `<atom>`: Atoms[Expr6] = Atoms(
|
||||
private lazy val `<atom>`: Atoms[File => Expr6] = Atoms(
|
||||
IntLiter(integer).label("integer literal"),
|
||||
BoolLiter(("true" as true) | ("false" as false)).label("boolean literal"),
|
||||
CharLiter(charLit).label("character literal"),
|
||||
StrLiter(stringLit).label("string literal"),
|
||||
`<str-liter>`.label("string literal"),
|
||||
PairLiter from "null",
|
||||
`<ident-or-array-elem>`,
|
||||
Parens("(" ~> `<expr>` <~ ")")
|
||||
)
|
||||
private val `<ident>` =
|
||||
private lazy val `<str-liter>` = StrLiter(stringLit)
|
||||
private lazy val `<ident>` =
|
||||
Ident(ident) | some("*" | "&").verifiedExplain("pointer operators are not allowed")
|
||||
private lazy val `<ident-or-array-elem>` =
|
||||
(`<ident>` <~ ("(".verifiedExplain(
|
||||
"functions can only be called using 'call' keyword"
|
||||
) | unit)) <**> (`<array-indices>` </> identity)
|
||||
private val `<array-indices>` = ArrayElem(some("[" ~> `<expr>` <~ "]"))
|
||||
private lazy val `<array-indices>` = ArrayElem(fMap(some("[" ~> `<expr>` <~ "]")))
|
||||
|
||||
// Types
|
||||
private lazy val `<type>`: Parsley[Type] =
|
||||
private lazy val `<type>`: FParsley[Type] =
|
||||
(`<base-type>` | (`<pair-type>` ~> `<pair-elems-type>`)) <**> (`<array-type>` </> identity)
|
||||
private val `<base-type>` =
|
||||
(IntType from "int") | (BoolType from "bool") | (CharType from "char") | (StringType from "string")
|
||||
private lazy val `<array-type>` =
|
||||
ArrayType(countSome("[" ~> "]"))
|
||||
ArrayType(fParsley(countSome("[" ~> "]")))
|
||||
private val `<pair-type>` = "pair"
|
||||
private val `<pair-elems-type>`: Parsley[PairType] = PairType(
|
||||
private val `<pair-elems-type>`: FParsley[PairType] = PairType(
|
||||
"(" ~> `<pair-elem-type>` <~ ",",
|
||||
`<pair-elem-type>` <~ ")"
|
||||
)
|
||||
@@ -116,7 +132,7 @@ object parser {
|
||||
(`<base-type>` <**> (`<array-type>` </> identity)) |
|
||||
((UntypedPairType from `<pair-type>`) <**>
|
||||
((`<pair-elems-type>` <**> `<array-type>`)
|
||||
.map(arr => (_: UntypedPairType) => arr) </> identity))
|
||||
.map(arr => (_: File => UntypedPairType) => arr) </> identity))
|
||||
|
||||
/* Statements
|
||||
Atomic is used in two places here:
|
||||
@@ -127,13 +143,30 @@ object parser {
|
||||
invalid syntax check, this only happens at most once per program so this is not a major
|
||||
concern.
|
||||
*/
|
||||
private lazy val `<partial-program>` = PartialProgram(
|
||||
fMap(many(`<import>`)),
|
||||
`<program>`
|
||||
)
|
||||
private lazy val `<import>` = Import(
|
||||
"import" ~> `<import-filename>`,
|
||||
"(" ~> fMap(sepBy1(`<imported-func>`, ",")) <~ ")"
|
||||
)
|
||||
private lazy val `<import-filename>` = `<str-liter>`.label("import file name")
|
||||
private lazy val `<imported-func>` = ImportedFunc(
|
||||
`<ident>`.label("imported function name"),
|
||||
fMap(option("as" ~> `<ident>`)).label("imported function alias")
|
||||
)
|
||||
private lazy val `<program>` = Program(
|
||||
"begin" ~> (
|
||||
many(
|
||||
atomic(
|
||||
`<type>`.label("function declaration") <~> `<ident>` <~ "("
|
||||
) <**> `<partial-func-decl>`
|
||||
).label("function declaration") |
|
||||
fMap(
|
||||
many(
|
||||
fPair(
|
||||
atomic(
|
||||
`<type>`.label("function declaration") <~> `<ident>` <~ "("
|
||||
)
|
||||
) <**> `<partial-func-decl>`
|
||||
).label("function declaration")
|
||||
) |
|
||||
atomic(`<ident>` <~ "(").verifiedExplain("function declaration is missing return type")
|
||||
),
|
||||
`<stmt>`.label(
|
||||
@@ -142,17 +175,23 @@ object parser {
|
||||
)
|
||||
private lazy val `<partial-func-decl>` =
|
||||
FuncDecl(
|
||||
sepBy(`<param>`, ",") <~ ")" <~ "is",
|
||||
`<stmt>`.guardAgainst {
|
||||
case stmts if !stmts.isReturning => Seq("all functions must end in a returning statement")
|
||||
} <~ "end"
|
||||
fPair(
|
||||
(fMap(sepBy(`<param>`, ",")) <~ ")" <~ "is") <~>
|
||||
(`<stmt>`.guardAgainst {
|
||||
// TODO: passing in an arbitrary file works but is ugly
|
||||
case stmts if !(stmts(File("."))).isReturning =>
|
||||
Seq("all functions must end in a returning statement")
|
||||
} <~ "end")
|
||||
)
|
||||
)
|
||||
private lazy val `<param>` = Param(`<type>`, `<ident>`)
|
||||
private lazy val `<stmt>`: Parsley[NonEmptyList[Stmt]] =
|
||||
(
|
||||
`<basic-stmt>`.label("main program body"),
|
||||
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
|
||||
).zipped(NonEmptyList.apply)
|
||||
private lazy val `<stmt>`: FParsley[NonEmptyList[Stmt]] =
|
||||
fMap(
|
||||
(
|
||||
`<basic-stmt>`.label("main program body"),
|
||||
(many(";" ~> `<basic-stmt>`.label("statement after ';'"))) </> Nil
|
||||
).zipped(NonEmptyList.apply)
|
||||
)
|
||||
|
||||
private lazy val `<basic-stmt>` =
|
||||
(Skip from "skip")
|
||||
@@ -160,8 +199,8 @@ object parser {
|
||||
| Free("free" ~> `<expr>`.labelAndExplain(LabelType.Expr))
|
||||
| Return("return" ~> `<expr>`.labelAndExplain(LabelType.Expr))
|
||||
| Exit("exit" ~> `<expr>`.labelAndExplain(LabelType.Expr))
|
||||
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(false))
|
||||
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), pure(true))
|
||||
| Print("print" ~> `<expr>`.labelAndExplain(LabelType.Expr), fParsley(pure(false)))
|
||||
| Print("println" ~> `<expr>`.labelAndExplain(LabelType.Expr), fParsley(pure(true)))
|
||||
| If(
|
||||
"if" ~> `<expr>`.labelWithType(LabelType.Expr) <~ "then",
|
||||
`<stmt>` <~ "else",
|
||||
@@ -185,9 +224,9 @@ object parser {
|
||||
("call" ~> `<ident>`).verifiedExplain(
|
||||
"function calls' results must be assigned to a variable"
|
||||
)
|
||||
private lazy val `<lvalue>`: Parsley[LValue] =
|
||||
private lazy val `<lvalue>`: FParsley[LValue] =
|
||||
`<pair-elem>` | `<ident-or-array-elem>`
|
||||
private lazy val `<rvalue>`: Parsley[RValue] =
|
||||
private lazy val `<rvalue>`: FParsley[RValue] =
|
||||
`<array-liter>` |
|
||||
NewPair(
|
||||
"newpair" ~> "(" ~> `<expr>` <~ ",",
|
||||
@@ -196,13 +235,13 @@ object parser {
|
||||
`<pair-elem>` |
|
||||
Call(
|
||||
"call" ~> `<ident>` <~ "(",
|
||||
sepBy(`<expr>`, ",") <~ ")"
|
||||
fMap(sepBy(`<expr>`, ",")) <~ ")"
|
||||
) | `<expr>`.labelWithType(LabelType.Expr)
|
||||
private lazy val `<pair-elem>` =
|
||||
Fst("fst" ~> `<lvalue>`.label("valid pair"))
|
||||
| Snd("snd" ~> `<lvalue>`.label("valid pair"))
|
||||
private lazy val `<array-liter>` = ArrayLiter(
|
||||
"[" ~> sepBy(`<expr>`, ",") <~ "]"
|
||||
"[" ~> fMap(sepBy(`<expr>`, ",")) <~ "]"
|
||||
)
|
||||
|
||||
extension (stmts: NonEmptyList[Stmt]) {
|
||||
368
src/main/wacc/frontend/renamer.scala
Normal file
@@ -0,0 +1,368 @@
|
||||
package wacc
|
||||
|
||||
import java.io.File
|
||||
import scala.collection.mutable
|
||||
import cats.effect.IO
|
||||
import cats.implicits._
|
||||
import cats.data.Chain
|
||||
import cats.data.NonEmptyList
|
||||
import parsley.{Failure, Success}
|
||||
|
||||
object renamer {
|
||||
import ast._
|
||||
import types._
|
||||
|
||||
val MAIN = "$main"
|
||||
|
||||
enum IdentType {
|
||||
case Func
|
||||
case Var
|
||||
}
|
||||
|
||||
case class ScopeKey(path: String, name: String, identType: IdentType)
|
||||
case class ScopeValue(id: Ident, public: Boolean)
|
||||
|
||||
class Scope(
|
||||
private val current: mutable.Map[ScopeKey, ScopeValue],
|
||||
private val parent: Map[ScopeKey, ScopeValue],
|
||||
guidStart: Int = 0,
|
||||
val guidInc: Int = 1
|
||||
) {
|
||||
private var guid = guidStart
|
||||
private var immutable = false
|
||||
|
||||
private def nextGuid(): Int = {
|
||||
val id = guid
|
||||
guid += guidInc
|
||||
id
|
||||
}
|
||||
|
||||
private def verifyMutable(): Unit = {
|
||||
if (immutable) throw new IllegalStateException("Cannot modify an immutable scope")
|
||||
}
|
||||
|
||||
/** Create a new scope with the current scope as its parent.
|
||||
*
|
||||
* To be used for single-threaded applications.
|
||||
*
|
||||
* @return
|
||||
* The result of applying f to a fresh subscope (empty current scope, with this scope flattened into its parent).
|
||||
*/
|
||||
def withSubscope[T](f: Scope => T): T = {
|
||||
val subscope =
|
||||
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)), guid, guidInc)
|
||||
immutable = true
|
||||
val result = f(subscope)
|
||||
guid = subscope.guid // Sync GUID
|
||||
immutable = false
|
||||
result
|
||||
}
|
||||
|
||||
/** Create new scopes with the current scope as their parent and GUID numbering adjusted
|
||||
* correctly.
|
||||
*
|
||||
* This will permanently mark the current scope as immutable, for thread safety.
|
||||
*
|
||||
* To be used for multi-threaded applications.
|
||||
*
|
||||
* @return
|
||||
* New scopes with an empty current scope, and this scope flattened into the parent scope.
|
||||
*/
|
||||
def subscopes(n: Int): Seq[Scope] = {
|
||||
verifyMutable()
|
||||
immutable = true
|
||||
(0 until n).map { i =>
|
||||
Scope(
|
||||
mutable.Map.empty,
|
||||
Map.empty.withDefault(current.withDefault(parent)),
|
||||
guid + i * guidInc,
|
||||
guidInc * n
|
||||
)
|
||||
}
|
||||
}
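// For example, on a fresh scope subscopes(2) hands out GUIDs 0, 2, 4, ... to the first
// subscope and 1, 3, 5, ... to the second, so functions renamed in parallel can never
// collide on a GUID (illustrative numbers; the actual start depends on prior additions).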
|
||||
|
||||
/** Attempt to add a new identifier to the current scope. If the identifier already exists in
|
||||
* the current scope, add an error to the error list.
|
||||
*
|
||||
* @param name
|
||||
* The name of the identifier.
|
||||
* @return
|
||||
* An error, if one occurred.
|
||||
*/
|
||||
def add(name: Ident, public: Boolean = false): Chain[Error] = {
|
||||
verifyMutable()
|
||||
val path = name.pos.file.getCanonicalPath
|
||||
val identType = name.ty match {
|
||||
case _: SemType => IdentType.Var
|
||||
case _: FuncType => IdentType.Func
|
||||
}
|
||||
val key = ScopeKey(path, name.v, identType)
|
||||
current.get(key) match {
|
||||
case Some(ScopeValue(Ident(_, id, _), _)) =>
|
||||
name.guid = id
|
||||
Chain.one(Error.DuplicateDeclaration(name))
|
||||
case None =>
|
||||
name.guid = nextGuid()
|
||||
current(key) = ScopeValue(name, public)
|
||||
Chain.empty
|
||||
}
|
||||
}
|
||||
|
||||
/** Attempt to add a new identifier as an alias to another to the existing scope.
|
||||
*
|
||||
* @param alias
|
||||
* The (new) alias identifier.
|
||||
* @param orig
|
||||
* The (existing) original identifier.
|
||||
*
|
||||
* @return
|
||||
* An error, if one occurred.
|
||||
*/
|
||||
def addAlias(alias: Ident, orig: ScopeValue, public: Boolean = false): Chain[Error] = {
|
||||
verifyMutable()
|
||||
val path = alias.pos.file.getCanonicalPath
|
||||
val identType = alias.ty match {
|
||||
case _: SemType => IdentType.Var
|
||||
case _: FuncType => IdentType.Func
|
||||
}
|
||||
val key = ScopeKey(path, alias.v, identType)
|
||||
current.get(key) match {
|
||||
case Some(ScopeValue(Ident(_, id, _), _)) =>
|
||||
alias.guid = id
|
||||
Chain.one(Error.DuplicateDeclaration(alias))
|
||||
case None =>
|
||||
alias.guid = nextGuid()
|
||||
current(key) = ScopeValue(orig.id, public)
|
||||
Chain.empty
|
||||
}
|
||||
}
|
||||
|
||||
def get(path: String, name: String, identType: IdentType): Option[ScopeValue] =
|
||||
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
|
||||
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
|
||||
try {
|
||||
Some(current.withDefault(parent)(ScopeKey(path, name, identType)))
|
||||
} catch {
|
||||
case _: NoSuchElementException => None
|
||||
}
|
||||
|
||||
def getVar(name: Ident): Option[Ident] =
|
||||
get(name.pos.file.getCanonicalPath, name.v, IdentType.Var).map(_.id)
|
||||
def getFunc(name: Ident): Option[Ident] =
|
||||
get(name.pos.file.getCanonicalPath, name.v, IdentType.Func).map(_.id)
|
||||
}
|
||||
|
||||
def prepareGlobalScope(
|
||||
partialProg: PartialProgram
|
||||
)(using scope: Scope): IO[(FuncDecl, Chain[FuncDecl], Chain[Error])] = {
|
||||
def readImportFile(file: File): IO[String] =
|
||||
IO.blocking(os.read(os.Path(file.getCanonicalPath)))
|
||||
|
||||
def prepareImport(contents: String, file: File)(using
|
||||
scope: Scope
|
||||
): IO[(Chain[FuncDecl], Chain[Error])] = {
|
||||
parser.parse(contents) match {
|
||||
case Failure(msg) =>
|
||||
IO.pure(Chain.empty, Chain.one(Error.SyntaxError(file, msg)))
|
||||
case Success(fn) =>
|
||||
val partialProg = fn(file)
|
||||
for {
|
||||
(main, chunks, errors) <- prepareGlobalScope(partialProg)
|
||||
} yield (main +: chunks, errors)
|
||||
}
|
||||
}
|
||||
|
||||
def addImportsToScope(importFile: File, funcs: NonEmptyList[ImportedFunc])(using
|
||||
scope: Scope
|
||||
): Chain[Error] =
|
||||
funcs.foldMap { case ImportedFunc(srcName, aliasName) =>
|
||||
scope.get(importFile.getCanonicalPath, srcName.v, IdentType.Func) match {
|
||||
case Some(src) if src.public =>
|
||||
aliasName.ty = src.id.ty
|
||||
scope.addAlias(aliasName, src)
|
||||
case _ =>
|
||||
Chain.one(Error.UndefinedFunction(srcName))
|
||||
}
|
||||
}
|
||||
|
||||
val PartialProgram(imports, prog) = partialProg
|
||||
|
||||
// First prepare this file's functions...
|
||||
val Program(funcs, main) = prog
|
||||
val (funcChunks, funcErrors) = funcs.foldLeft((Chain.empty[FuncDecl], Chain.empty[Error])) {
|
||||
case ((chunks, errors), func @ FuncDecl(retType, name, params, body)) =>
|
||||
val paramTypes = params.map { param =>
|
||||
val paramType = SemType(param.paramType)
|
||||
param.name.ty = paramType
|
||||
paramType
|
||||
}
|
||||
name.ty = FuncType(SemType(retType), paramTypes)
|
||||
(chunks :+ func, errors ++ scope.add(name, public = true))
|
||||
}
|
||||
// ...and main body.
|
||||
val mainBodyIdent = Ident(MAIN, ty = FuncType(?, Nil))(main.head.pos)
|
||||
val mainBodyErrors = scope.add(mainBodyIdent, public = false)
|
||||
val mainBodyChunk = FuncDecl(IntType()(prog.pos), mainBodyIdent, Nil, main)(prog.pos)
|
||||
|
||||
// Now handle imports
|
||||
val file = prog.pos.file
|
||||
val preparedImports = imports.foldLeftM[IO, (Chain[FuncDecl], Chain[Error])](
|
||||
(Chain.empty[FuncDecl], Chain.empty[Error])
|
||||
) { case ((chunks, errors), Import(name, funcs)) =>
|
||||
val importFile = File(file.getParent, name.v)
|
||||
if (!importFile.exists()) {
|
||||
IO.pure(
|
||||
(
|
||||
chunks,
|
||||
errors :+ Error.SemanticError(
|
||||
name.pos,
|
||||
s"File not found: ${importFile.getCanonicalPath}"
|
||||
)
|
||||
)
|
||||
)
|
||||
} else if (!importFile.canRead()) {
|
||||
IO.pure(
|
||||
(
|
||||
chunks,
|
||||
errors :+ Error.SemanticError(
|
||||
name.pos,
|
||||
s"File not readable: ${importFile.getCanonicalPath}"
|
||||
)
|
||||
)
|
||||
)
|
||||
} else if (importFile.getCanonicalPath == file.getCanonicalPath) {
|
||||
IO.pure(
|
||||
(
|
||||
chunks,
|
||||
errors :+ Error.SemanticError(
|
||||
name.pos,
|
||||
s"Cannot import self: ${importFile.getCanonicalPath}"
|
||||
)
|
||||
)
|
||||
)
|
||||
} else if (scope.get(importFile.getCanonicalPath, MAIN, IdentType.Func).isDefined) {
|
||||
IO.pure(chunks, errors ++ addImportsToScope(importFile, funcs))
|
||||
} else {
|
||||
for {
|
||||
contents <- readImportFile(importFile)
|
||||
(importChunks, importErrors) <- prepareImport(contents, importFile)
|
||||
importAliasErrors = addImportsToScope(importFile, funcs)
|
||||
} yield (chunks ++ importChunks, errors ++ importErrors ++ importAliasErrors)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
(importChunks, importErrors) <- preparedImports
|
||||
allChunks = importChunks ++ funcChunks
|
||||
allErrors = importErrors ++ funcErrors ++ mainBodyErrors
|
||||
} yield (mainBodyChunk, allChunks, allErrors)
|
||||
}
|
||||
|
||||
/** Rename all identifiers within a single function declaration, using the given subscope.
 *
 * @param funcScopePair
 * The function declaration paired with the subscope to rename it in.
 * @return
 * Any scoping errors found in the parameters or body.
 */
|
||||
def renameFunction(funcScopePair: (FuncDecl, Scope)): IO[Chain[Error]] = {
|
||||
val (FuncDecl(_, _, params, body), subscope) = funcScopePair
|
||||
val paramErrors = params.foldMap(param => subscope.add(param.name))
|
||||
IO(subscope.withSubscope { s => body.foldMap(rename(s)) })
|
||||
.map(bodyErrors => paramErrors ++ bodyErrors)
|
||||
}
|
||||
|
||||
/** Check scoping of all identifiers in a given AST node.
|
||||
*
|
||||
* @param scope
|
||||
* The current scope and flattened parent scope.
|
||||
* @param node
|
||||
* The AST node.
|
||||
*/
|
||||
private def rename(scope: Scope)(node: Ident | Stmt | LValue | RValue | Expr): Chain[Error] =
|
||||
node match {
|
||||
// These cases are more interesting because they involve making subscopes
|
||||
// or modifying the current scope.
|
||||
case VarDecl(synType, name, value) => {
|
||||
// Order matters here. Variable isn't declared until after the value is evaluated.
|
||||
val errors = rename(scope)(value)
|
||||
// Attempt to add the new variable to the current scope.
|
||||
name.ty = SemType(synType)
|
||||
errors ++ scope.add(name)
|
||||
}
|
||||
case If(cond, thenStmt, elseStmt) => {
|
||||
val condErrors = rename(scope)(cond)
|
||||
// then and else both have their own scopes
|
||||
val thenErrors = scope.withSubscope(s => thenStmt.foldMap(rename(s)))
|
||||
val elseErrors = scope.withSubscope(s => elseStmt.foldMap(rename(s)))
|
||||
condErrors ++ thenErrors ++ elseErrors
|
||||
}
|
||||
case While(cond, body) => {
|
||||
val condErrors = rename(scope)(cond)
|
||||
// while bodies have their own scopes
|
||||
val bodyErrors = scope.withSubscope(s => body.foldMap(rename(s)))
|
||||
condErrors ++ bodyErrors
|
||||
}
|
||||
// begin-end blocks have their own scopes
|
||||
case Block(body) => scope.withSubscope(s => body.foldMap(rename(s)))
|
||||
|
||||
// These cases are simpler, mostly just recursive calls to rename()
|
||||
case Assign(lhs, value) => {
|
||||
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
|
||||
rename(scope)(lhs) ++ rename(scope)(value)
|
||||
}
|
||||
case Read(lhs) => rename(scope)(lhs)
|
||||
case Free(expr) => rename(scope)(expr)
|
||||
case Return(expr) => rename(scope)(expr)
|
||||
case Exit(expr) => rename(scope)(expr)
|
||||
case Print(expr, _) => rename(scope)(expr)
|
||||
case NewPair(fst, snd) => {
|
||||
rename(scope)(fst) ++ rename(scope)(snd)
|
||||
}
|
||||
case Call(name, args) => {
|
||||
val nameErrors = scope.getFunc(name) match {
|
||||
case Some(Ident(realName, guid, ty)) =>
|
||||
name.v = realName
|
||||
name.ty = ty
|
||||
name.guid = guid
|
||||
Chain.empty
|
||||
case None =>
|
||||
name.ty = FuncType(?, args.map(_ => ?))
|
||||
scope.add(name)
|
||||
Chain.one(Error.UndefinedFunction(name))
|
||||
}
|
||||
val argsErrors = args.foldMap(rename(scope))
|
||||
nameErrors ++ argsErrors
|
||||
}
|
||||
case Fst(elem) => rename(scope)(elem)
|
||||
case Snd(elem) => rename(scope)(elem)
|
||||
case ArrayLiter(elems) => elems.foldMap(rename(scope))
|
||||
case ArrayElem(name, indices) => {
|
||||
val nameErrors = rename(scope)(name)
|
||||
val indicesErrors = indices.foldMap(rename(scope))
|
||||
nameErrors ++ indicesErrors
|
||||
}
|
||||
case Parens(expr) => rename(scope)(expr)
|
||||
case op: UnaryOp => rename(scope)(op.x)
|
||||
case op: BinaryOp => {
|
||||
rename(scope)(op.x) ++ rename(scope)(op.y)
|
||||
}
|
||||
// Default to variables. Only `call` uses IdentType.Func.
|
||||
case id: Ident => {
|
||||
scope.getVar(id) match {
|
||||
case Some(Ident(_, guid, ty)) =>
|
||||
id.ty = ty
|
||||
id.guid = guid
|
||||
Chain.empty
|
||||
case None =>
|
||||
id.ty = ?
|
||||
scope.add(id)
|
||||
Chain.one(Error.UndeclaredVariable(id))
|
||||
}
|
||||
}
|
||||
// These literals cannot contain identifiers, so return immediately.
|
||||
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() =>
|
||||
Chain.empty
|
||||
}
|
||||
}
|
||||
42
src/main/wacc/frontend/semantics.scala
Normal file
@@ -0,0 +1,42 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
import cats.implicits._
|
||||
import cats.data.Chain
|
||||
import cats.effect.IO
|
||||
|
||||
object semantics {
|
||||
import renamer.{Scope, prepareGlobalScope, renameFunction}
|
||||
import typeChecker.checkFuncDecl
|
||||
|
||||
private def checkFunc(
|
||||
funcDecl: ast.FuncDecl,
|
||||
scope: Scope
|
||||
): IO[(microWacc.FuncDecl, Chain[Error])] = {
|
||||
for {
|
||||
renamerErrors <- renameFunction(funcDecl, scope)
|
||||
(microWaccFunc, typeErrors) = checkFuncDecl(funcDecl)
|
||||
} yield (microWaccFunc, renamerErrors ++ typeErrors)
|
||||
}
|
||||
|
||||
def check(partialProg: ast.PartialProgram): IO[(microWacc.Program, Chain[Error])] = {
|
||||
given scope: Scope = Scope(mutable.Map.empty, Map.empty)
|
||||
|
||||
for {
|
||||
(main, chunks, globalErrors) <- prepareGlobalScope(partialProg)
|
||||
toRename = (main +: chunks).toList
|
||||
res <- toRename
|
||||
.zip(scope.subscopes(toRename.size))
|
||||
.parTraverse(checkFunc)
|
||||
(typedChunks, errors) = res.foldLeft((Chain.empty[microWacc.FuncDecl], Chain.empty[Error])) {
|
||||
case ((acc, err), (funcDecl, errors)) =>
|
||||
(acc :+ funcDecl, err ++ errors)
|
||||
}
|
||||
(typedMain, funcs) = typedChunks.uncons match {
|
||||
case Some((head, tail)) => (head.body, tail)
|
||||
case None => (Chain.empty, Chain.empty)
|
||||
}
|
||||
} yield (microWacc.Program(funcs, typedMain)(main.pos), globalErrors ++ errors)
|
||||
}
|
||||
|
||||
}
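
A sketch of calling the checker from a driver (partialProg, codeGen and report are hypothetical stand-ins for the surrounding phases):

semantics.check(partialProg).map { case (prog, errors) =>
  if (errors.isEmpty) codeGen(prog) else errors.toList.foreach(report)
}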
|
||||
492
src/main/wacc/frontend/typeChecker.scala
Normal file
@@ -0,0 +1,492 @@
|
||||
package wacc
|
||||
|
||||
import cats.syntax.all._
|
||||
import cats.data.NonEmptyList
|
||||
import cats.data.Chain
|
||||
|
||||
object typeChecker {
|
||||
import wacc.types._
|
||||
|
||||
private enum Constraint {
|
||||
case Unconstrained
|
||||
// Allows weakening in one direction
|
||||
case Is(ty: SemType, msg: String)
|
||||
// Allows weakening in both directions, useful for array literals
|
||||
case IsSymmetricCompatible(ty: SemType, msg: String)
|
||||
// Does not allow weakening
|
||||
case IsUnweakenable(ty: SemType, msg: String)
|
||||
case IsEither(ty1: SemType, ty2: SemType, msg: String)
|
||||
case Never(msg: String)
|
||||
}
|
||||
|
||||
extension (ty: SemType) {
|
||||
|
||||
/** Check if a type satisfies a constraint.
|
||||
*
|
||||
* @param constraint
|
||||
* Constraint to satisfy.
|
||||
* @param pos
|
||||
* Position to pass to the error, if constraint was not satisfied.
|
||||
* @return
|
||||
* The type if the constraint was satisfied, or ? if it was not.
|
||||
*/
|
||||
private def satisfies(constraint: Constraint, pos: ast.Position): (SemType, Chain[Error]) =
|
||||
(ty, constraint) match {
|
||||
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
|
||||
(KnownType.String, Chain.empty)
|
||||
case (
|
||||
KnownType.String,
|
||||
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
|
||||
) =>
|
||||
(KnownType.String, Chain.empty)
|
||||
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
|
||||
ty.satisfies(Constraint.Is(ty2, msg), pos)
|
||||
// Change to IsUnweakenable to disallow recursive weakening
|
||||
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
|
||||
case (ty, Constraint.Unconstrained) => (ty, Chain.empty)
|
||||
case (ty, Constraint.Never(msg)) =>
|
||||
(?, Chain.one(Error.SemanticError(pos, msg)))
|
||||
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
|
||||
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
|
||||
(?, Chain.one(Error.TypeMismatch(pos, ty1, ty, msg)))
|
||||
}
|
||||
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
|
||||
(ty moreSpecific ty2).map((_, Chain.empty)).getOrElse {
|
||||
(?, Chain.one(Error.TypeMismatch(pos, ty2, ty, msg)))
|
||||
}
|
||||
}
|
||||
|
||||
/** Tries to merge two types, returning the more specific one if possible.
|
||||
*
|
||||
* @param ty2
|
||||
* The other type to merge with.
|
||||
* @return
|
||||
* The more specific type if it could be determined, or None if the types are incompatible.
|
||||
*/
|
||||
private infix def moreSpecific(ty2: SemType): Option[SemType] =
|
||||
(ty, ty2) match {
|
||||
case (ty, ?) => Some(ty)
|
||||
case (?, ty) => Some(ty)
|
||||
case (ty1, ty2) if ty1 == ty2 => Some(ty1)
|
||||
case (KnownType.Array(inn1), KnownType.Array(inn2)) =>
|
||||
(inn1 moreSpecific inn2).map(KnownType.Array(_))
|
||||
case (KnownType.Pair(fst1, snd1), KnownType.Pair(fst2, snd2)) =>
|
||||
(fst1 moreSpecific fst2, snd1 moreSpecific snd2).mapN(KnownType.Pair(_, _))
|
||||
case _ => None
|
||||
}
|
||||
}
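// A couple of worked examples of the rules above (illustrative):
//   KnownType.Pair(?, KnownType.Int) moreSpecific KnownType.Pair(KnownType.Char, ?)
//     == Some(KnownType.Pair(KnownType.Char, KnownType.Int))
//   KnownType.Array(KnownType.Char) satisfies Is(KnownType.String, ...) weakens to KnownType.String,
//     which is what lets a char[] be used where a string is expected.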
|
||||
|
||||
/** Type-check a function declaration.
|
||||
*
|
||||
* @param func
|
||||
* The AST of the function to type-check.
|
||||
*/
|
||||
def checkFuncDecl(func: ast.FuncDecl): (microWacc.FuncDecl, Chain[Error]) = {
|
||||
val ast.FuncDecl(_, name, params, stmts) = func
|
||||
val FuncType(retType, paramTypes) = name.ty.asInstanceOf[FuncType]
|
||||
val returnConstraint =
|
||||
if func.name.v == renamer.MAIN then Constraint.Never("main body must not return")
|
||||
else Constraint.Is(retType, s"function ${name.v} must return $retType")
|
||||
val (body, bodyErrors) = stmts.foldMap(checkStmt(_, returnConstraint))
|
||||
(
|
||||
microWacc.FuncDecl(
|
||||
microWacc.Ident(name.v, name.guid)(retType),
|
||||
params.zip(paramTypes).map { case (ast.Param(_, ident), ty) =>
|
||||
microWacc.Ident(ident.v, ident.guid)(ty)
|
||||
},
|
||||
body
|
||||
)(func.pos),
|
||||
bodyErrors
|
||||
)
|
||||
}
|
||||
|
||||
/** Type-check an AST statement node.
|
||||
*
|
||||
* @param stmt
|
||||
* The statement to type-check.
|
||||
* @param returnConstraint
|
||||
* The constraint that any `return <expr>` statements must satisfy.
|
||||
*/
|
||||
private def checkStmt(
|
||||
stmt: ast.Stmt,
|
||||
returnConstraint: Constraint
|
||||
): (Chain[microWacc.Stmt], Chain[Error]) = stmt match {
|
||||
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
|
||||
case ast.VarDecl(_, name, value) =>
|
||||
val expectedTy = name.ty
|
||||
val (typedValue, valueErrors) = checkValue(
|
||||
value,
|
||||
Constraint.Is(
|
||||
expectedTy.asInstanceOf[SemType],
|
||||
s"variable ${name.v} must be assigned a value of type $expectedTy"
|
||||
)
|
||||
)
|
||||
(
|
||||
Chain.one(
|
||||
microWacc.Assign(
|
||||
microWacc.Ident(name.v, name.guid)(expectedTy.asInstanceOf[SemType]),
|
||||
typedValue
|
||||
)(stmt.pos)
|
||||
),
|
||||
valueErrors
|
||||
)
|
||||
case ast.Assign(lhs, rhs) =>
|
||||
val (lhsTyped, lhsErrors) = checkLValue(lhs, Constraint.Unconstrained)
|
||||
val (rhsTyped, rhsErrors) =
|
||||
checkValue(rhs, Constraint.Is(lhsTyped.ty, s"assignment must have type ${lhsTyped.ty}"))
|
||||
val unknownError = (lhsTyped.ty, rhsTyped.ty) match {
|
||||
case (?, ?) =>
|
||||
Chain.one(
|
||||
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
|
||||
)
|
||||
case _ => Chain.empty
|
||||
}
|
||||
(
|
||||
Chain.one(microWacc.Assign(lhsTyped, rhsTyped)(stmt.pos)),
|
||||
lhsErrors ++ rhsErrors ++ unknownError
|
||||
)
|
||||
case ast.Read(dest) =>
|
||||
val (destTyped, destErrors) = checkLValue(dest, Constraint.Unconstrained)
|
||||
val (destTy, destTyErrors) = destTyped.ty match {
|
||||
case ? =>
|
||||
(
|
||||
?,
|
||||
Chain.one(
|
||||
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
|
||||
)
|
||||
)
|
||||
case destTy =>
|
||||
destTy.satisfies(
|
||||
Constraint.IsEither(
|
||||
KnownType.Int,
|
||||
KnownType.Char,
|
||||
"read must be applied to an int or char"
|
||||
),
|
||||
dest.pos
|
||||
)
|
||||
}
|
||||
(
|
||||
Chain.one(
|
||||
microWacc.Assign(
|
||||
destTyped,
|
||||
microWacc.Call(
|
||||
microWacc.Builtin.Read,
|
||||
List(
|
||||
destTy match {
|
||||
case KnownType.Int => " %d".toMicroWaccCharArray(stmt.pos)
|
||||
case KnownType.Char | _ => " %c".toMicroWaccCharArray(stmt.pos)
|
||||
},
|
||||
destTyped
|
||||
)
|
||||
)(dest.pos)
|
||||
)(stmt.pos)
|
||||
),
|
||||
destErrors ++ destTyErrors
|
||||
)
|
||||
case ast.Free(lhs) =>
|
||||
val (lhsTyped, lhsErrors) = checkValue(
|
||||
lhs,
|
||||
Constraint.IsEither(
|
||||
KnownType.Array(?),
|
||||
KnownType.Pair(?, ?),
|
||||
"free must be applied to an array or pair"
|
||||
)
|
||||
)
|
||||
(Chain.one(microWacc.Call(microWacc.Builtin.Free, List(lhsTyped))(stmt.pos)), lhsErrors)
|
||||
case ast.Return(expr) =>
|
||||
val (exprTyped, exprErrors) = checkValue(expr, returnConstraint)
|
||||
(Chain.one(microWacc.Return(exprTyped)(stmt.pos)), exprErrors)
|
||||
case ast.Exit(expr) =>
|
||||
val (exprTyped, exprErrors) =
|
||||
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
|
||||
(Chain.one(microWacc.Call(microWacc.Builtin.Exit, List(exprTyped))(stmt.pos)), exprErrors)
|
||||
case ast.Print(expr, newline) =>
|
||||
// This constraint should never fail; the scope-checker should have caught it already
|
||||
val (exprTyped, exprErrors) = checkValue(expr, Constraint.Unconstrained)
|
||||
val exprFormat = exprTyped.ty match {
|
||||
case KnownType.Bool | KnownType.String => "%s"
|
||||
case KnownType.Array(KnownType.Char) => "%.*s"
|
||||
case KnownType.Char => "%c"
|
||||
case KnownType.Int => "%d"
|
||||
case KnownType.Pair(_, _) | KnownType.Array(_) | ? => "%p"
|
||||
}
|
||||
val printfCall = { (func: microWacc.Builtin, value: microWacc.Expr) =>
|
||||
Chain.one(
|
||||
microWacc.Call(
|
||||
func,
|
||||
List(
|
||||
s"$exprFormat${if newline then "\n" else ""}".toMicroWaccCharArray(stmt.pos),
|
||||
value
|
||||
)
|
||||
)(stmt.pos)
|
||||
)
|
||||
}
|
||||
(
|
||||
exprTyped.ty match {
|
||||
case KnownType.Bool =>
|
||||
Chain.one(
|
||||
microWacc.If(
|
||||
exprTyped,
|
||||
printfCall(microWacc.Builtin.Printf, "true".toMicroWaccCharArray(stmt.pos)),
|
||||
printfCall(microWacc.Builtin.Printf, "false".toMicroWaccCharArray(stmt.pos))
|
||||
)(stmt.pos)
|
||||
)
|
||||
case KnownType.Array(KnownType.Char) =>
|
||||
printfCall(microWacc.Builtin.PrintCharArray, exprTyped)
|
||||
case _ => printfCall(microWacc.Builtin.Printf, exprTyped)
|
||||
},
|
||||
exprErrors
|
||||
)
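      // Illustrative lowering (added commentary, not original source): with the format
      // table above, `println true` becomes an if-statement choosing between two printf
      // calls, roughly
      //   microWacc.If(exprTyped,
      //     Chain.one(Call(Builtin.Printf, List("%s\n".toMicroWaccCharArray(pos), "true".toMicroWaccCharArray(pos)))),
      //     Chain.one(Call(Builtin.Printf, List("%s\n".toMicroWaccCharArray(pos), "false".toMicroWaccCharArray(pos)))))
      // while `println 42` lowers directly to printf with the "%d\n" format string.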
|
||||
case ast.If(cond, thenStmt, elseStmt) =>
|
||||
val (condTyped, condErrors) =
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
|
||||
val (thenStmtTyped, thenErrors) = thenStmt.foldMap(checkStmt(_, returnConstraint))
|
||||
val (elseStmtTyped, elseErrors) = elseStmt.foldMap(checkStmt(_, returnConstraint))
|
||||
(
|
||||
Chain.one(microWacc.If(condTyped, thenStmtTyped, elseStmtTyped)(cond.pos)),
|
||||
condErrors ++ thenErrors ++ elseErrors
|
||||
)
|
||||
case ast.While(cond, body) =>
|
||||
val (condTyped, condErrors) =
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
|
||||
val (bodyTyped, bodyErrors) = body.foldMap(checkStmt(_, returnConstraint))
|
||||
(Chain.one(microWacc.While(condTyped, bodyTyped)(cond.pos)), condErrors ++ bodyErrors)
|
||||
case ast.Block(body) => body.foldMap(checkStmt(_, returnConstraint))
|
||||
case ast.Skip() => (Chain.empty, Chain.empty)
|
||||
}
|
||||
|
||||
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
|
||||
* overlap in the AST.
|
||||
*
|
||||
* @param value
|
||||
* The value to type-check.
|
||||
* @param constraint
|
||||
* The type constraint that the value must satisfy.
|
||||
* @return
|
||||
* The most specific type of the value if it could be determined, or ? if it could not.
|
||||
*/
|
||||
private def checkValue(
|
||||
value: ast.LValue | ast.RValue | ast.Expr,
|
||||
constraint: Constraint
|
||||
): (microWacc.Expr, Chain[Error]) = value match {
|
||||
case l @ ast.IntLiter(v) =>
|
||||
val (_, errors) = KnownType.Int.satisfies(constraint, l.pos)
|
||||
(microWacc.IntLiter(v), errors)
|
||||
case l @ ast.BoolLiter(v) =>
|
||||
val (_, errors) = KnownType.Bool.satisfies(constraint, l.pos)
|
||||
(microWacc.BoolLiter(v), errors)
|
||||
case l @ ast.CharLiter(v) =>
|
||||
val (_, errors) = KnownType.Char.satisfies(constraint, l.pos)
|
||||
(microWacc.CharLiter(v), errors)
|
||||
case l @ ast.StrLiter(v) =>
|
||||
val (_, errors) = KnownType.String.satisfies(constraint, l.pos)
|
||||
(v.toMicroWaccCharArray(l.pos), errors)
|
||||
case l @ ast.PairLiter() =>
|
||||
val (ty, errors) = KnownType.Pair(?, ?).satisfies(constraint, l.pos)
|
||||
(microWacc.NullLiter()(ty), errors)
|
||||
case ast.Parens(expr) => checkValue(expr, constraint)
|
||||
case l @ ast.ArrayLiter(elems) =>
|
||||
val ((elemTy, elemsErrors), elemsTyped) =
|
||||
elems.mapAccumulate[(SemType, Chain[Error]), microWacc.Expr]((?, Chain.empty)) {
|
||||
case ((acc, errors), elem) =>
|
||||
val (elemTyped, elemErrors) = checkValue(
|
||||
elem,
|
||||
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
|
||||
)
|
||||
((elemTyped.ty, errors ++ elemErrors), elemTyped)
|
||||
}
|
||||
val (arrayTy, arrayErrors) = KnownType
|
||||
// Start with an unknown param type, make it more specific while checking the elements.
|
||||
.Array(elemTy)
|
||||
.satisfies(constraint, l.pos)
|
||||
(microWacc.ArrayLiter(elemsTyped)(arrayTy, l.pos), elemsErrors ++ arrayErrors)
|
||||
case l @ ast.NewPair(fst, snd) =>
|
||||
val (fstTyped, fstErrors) = checkValue(fst, Constraint.Unconstrained)
|
||||
val (sndTyped, sndErrors) = checkValue(snd, Constraint.Unconstrained)
|
||||
val (pairTy, pairErrors) =
|
||||
KnownType.Pair(fstTyped.ty, sndTyped.ty).satisfies(constraint, l.pos)
|
||||
(
|
||||
microWacc.ArrayLiter(List(fstTyped, sndTyped))(pairTy, l.pos),
|
||||
fstErrors ++ sndErrors ++ pairErrors
|
||||
)
|
||||
case ast.Call(id, args) =>
|
||||
val funcTy @ FuncType(retTy, paramTys) = id.ty.asInstanceOf[FuncType]
|
||||
val lenError =
|
||||
if (args.length == paramTys.length) then Chain.empty
|
||||
else Chain.one(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
|
||||
// Even if the number of arguments is wrong, we still check the types of the arguments
|
||||
// in the best way we can (by taking a zip).
|
||||
val (argsErrors, argsTyped) =
|
||||
args.zip(paramTys).mapAccumulate(Chain.empty[Error]) { case (errors, (arg, paramTy)) =>
|
||||
val (argTyped, argErrors) =
|
||||
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
|
||||
(errors ++ argErrors, argTyped)
|
||||
}
|
||||
val (retTyChecked, retErrors) = retTy.satisfies(constraint, id.pos)
|
||||
(
|
||||
microWacc.Call(microWacc.Ident(id.v, id.guid)(retTyChecked), argsTyped)(id.pos),
|
||||
lenError ++ argsErrors ++ retErrors
|
||||
)
|
||||
|
||||
// Unary operators
|
||||
case ast.Negate(x) =>
|
||||
val (argTyped, argErrors) =
|
||||
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
|
||||
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
|
||||
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Negate)(retTy), argErrors ++ retErrors)
|
||||
case ast.Not(x) =>
|
||||
val (argTyped, argErrors) =
|
||||
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
|
||||
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, x.pos)
|
||||
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Not)(retTy), argErrors ++ retErrors)
|
||||
case ast.Len(x) =>
|
||||
val (argTyped, argErrors) =
|
||||
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
|
||||
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
|
||||
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Len)(retTy), argErrors ++ retErrors)
|
||||
case ast.Ord(x) =>
|
||||
val (argTyped, argErrors) =
|
||||
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
|
||||
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, x.pos)
|
||||
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Ord)(retTy), argErrors ++ retErrors)
|
||||
case ast.Chr(x) =>
|
||||
val (argTyped, argErrors) =
|
||||
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
|
||||
val (retTy, retErrors) = KnownType.Char.satisfies(constraint, x.pos)
|
||||
(microWacc.UnaryOp(argTyped, microWacc.UnaryOperator.Chr)(retTy), argErrors ++ retErrors)
|
||||
|
||||
// Binary operators
|
||||
case op: (ast.Add | ast.Sub | ast.Mul | ast.Div | ast.Mod) =>
|
||||
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
|
||||
val (xTyped, xErrors) = checkValue(op.x, operand)
|
||||
val (yTyped, yErrors) = checkValue(op.y, operand)
|
||||
val (retTy, retErrors) = KnownType.Int.satisfies(constraint, op.pos)
|
||||
(
|
||||
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
|
||||
xErrors ++ yErrors ++ retErrors
|
||||
)
|
||||
case op: (ast.Eq | ast.Neq) =>
|
||||
val (xTyped, xErrors) = checkValue(op.x, Constraint.Unconstrained)
|
||||
val (yTyped, yErrors) = checkValue(
|
||||
op.y,
|
||||
Constraint.Is(xTyped.ty, s"${op.name} operator must be applied to values of the same type")
|
||||
)
|
||||
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
|
||||
(
|
||||
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
|
||||
xErrors ++ yErrors ++ retErrors
|
||||
)
|
||||
case op: (ast.Less | ast.LessEq | ast.Greater | ast.GreaterEq) =>
|
||||
val xConstraint = Constraint.IsEither(
|
||||
KnownType.Int,
|
||||
KnownType.Char,
|
||||
s"${op.name} operator must be applied to an int or char"
|
||||
)
|
||||
val (xTyped, xErrors) = checkValue(op.x, xConstraint)
|
||||
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
|
||||
val yConstraint = xTyped.ty match {
|
||||
case ? => xConstraint
|
||||
case xTy =>
|
||||
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
|
||||
}
|
||||
val (yTyped, yErrors) = checkValue(op.y, yConstraint)
|
||||
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
|
||||
(
|
||||
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
|
||||
xErrors ++ yErrors ++ retErrors
|
||||
)
|
||||
case op: (ast.And | ast.Or) =>
|
||||
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
|
||||
val (xTyped, xErrors) = checkValue(op.x, operand)
|
||||
val (yTyped, yErrors) = checkValue(op.y, operand)
|
||||
val (retTy, retErrors) = KnownType.Bool.satisfies(constraint, op.pos)
|
||||
(
|
||||
microWacc.BinaryOp(xTyped, yTyped, microWacc.BinaryOperator.fromAst(op))(retTy),
|
||||
xErrors ++ yErrors ++ retErrors
|
||||
)
|
||||
|
||||
case lvalue: ast.LValue => checkLValue(lvalue, constraint)
|
||||
}
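  // Hedged sketch (illustrative only; `addNode` is a hypothetical ast.Add for `x + 1`
  // where x was declared as an int):
  //   val (typed, errors) = checkValue(addNode, Constraint.Unconstrained)
  //   // typed  : microWacc.BinaryOp(xTyped, microWacc.IntLiter(1),
  //   //            microWacc.BinaryOperator.fromAst(addNode))(KnownType.Int)
  //   // errors : empty, since both operands satisfy Constraint.Is(KnownType.Int, ...)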
|
||||
|
||||
/** Type-check an AST LValue node. Separate because microWacc keeps LValues as a distinct node type.
|
||||
*
|
||||
* @param value
|
||||
* The value to type-check.
|
||||
* @param constraint
|
||||
* The type constraint that the value must satisfy.
|
||||
|
||||
* @return
|
||||
* The most specific type of the value if it could be determined, or ? if it could not.
|
||||
*/
|
||||
private def checkLValue(
|
||||
value: ast.LValue,
|
||||
constraint: Constraint
|
||||
): (microWacc.LValue, Chain[Error]) = value match {
|
||||
case id @ ast.Ident(name, guid, ty) =>
|
||||
val (idTy, idErrors) = ty.asInstanceOf[SemType].satisfies(constraint, id.pos)
|
||||
(microWacc.Ident(name, guid)(idTy), idErrors)
|
||||
case ast.ArrayElem(id, indices) =>
|
||||
val arrayTy = id.ty.asInstanceOf[SemType]
|
||||
val ((elemTy, elemErrors), indicesTyped) =
|
||||
indices.mapAccumulate((arrayTy, Chain.empty[Error])) {
|
||||
case ((acc, errors), elem) =>
|
||||
val (idxTyped, idxErrors) =
|
||||
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
|
||||
val (next, nextError) = acc match {
|
||||
case KnownType.Array(innerTy) => (innerTy, Chain.empty)
|
||||
case ? => (?, Chain.empty) // we can keep indexing an unknown type
|
||||
case nonArrayTy =>
|
||||
(
|
||||
?,
|
||||
Chain.one(
|
||||
Error.TypeMismatch(
|
||||
elem.pos,
|
||||
KnownType.Array(?),
|
||||
acc,
|
||||
"cannot index into a non-array"
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
((next, errors ++ idxErrors ++ nextError), idxTyped)
|
||||
}
|
||||
val (retTy, retErrors) = elemTy.satisfies(constraint, value.pos)
|
||||
val firstArrayElem = microWacc.ArrayElem(
|
||||
microWacc.Ident(id.v, id.guid)(arrayTy),
|
||||
indicesTyped.head
|
||||
)(retTy)
|
||||
val arrayElem = indicesTyped.tail.foldLeft(firstArrayElem) { (acc, idx) =>
|
||||
microWacc.ArrayElem(acc, idx)(KnownType.Array(acc.ty))
|
||||
}
|
||||
// Need to type-check the final arrayElem with the constraint
|
||||
// TODO: What
|
||||
(microWacc.ArrayElem(arrayElem.value, arrayElem.index)(retTy), elemErrors ++ retErrors)
|
||||
case ast.Fst(elem) =>
|
||||
val (elemTyped, elemErrors) = checkLValue(
|
||||
elem,
|
||||
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
|
||||
)
|
||||
val (retTy, retErrors) = elemTyped.ty match {
|
||||
case KnownType.Pair(left, _) => left.satisfies(constraint, elem.pos)
|
||||
case _ => (?, Chain.one(Error.InternalError(elem.pos, "fst must be applied to a pair")))
|
||||
}
|
||||
(microWacc.ArrayElem(elemTyped, microWacc.IntLiter(0))(retTy), elemErrors ++ retErrors)
|
||||
case ast.Snd(elem) =>
|
||||
val (elemTyped, elemErrors) = checkLValue(
|
||||
elem,
|
||||
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
|
||||
)
|
||||
val (retTy, retErrors) = elemTyped.ty match {
|
||||
case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
|
||||
case _ => (?, Chain.one(Error.InternalError(elem.pos, "snd must be applied to a pair")))
|
||||
}
|
||||
(microWacc.ArrayElem(elemTyped, microWacc.IntLiter(1))(retTy), elemErrors ++ retErrors)
|
||||
}
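  // Illustrative note (added commentary): pairs are materialised as two-element arrays,
  // so `fst p` and `snd p` lower to index 0 and 1 respectively. For p: pair(int, bool),
  //   checkLValue(fstNode, Constraint.Unconstrained)
  //   // ~> (microWacc.ArrayElem(pTyped, microWacc.IntLiter(0))(KnownType.Int), Chain.empty)
  // where `fstNode` and `pTyped` are hypothetical stand-ins for the AST node and the
  // typed pair identifier.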
|
||||
|
||||
extension (s: String) {
|
||||
def toMicroWaccCharArray(pos: ast.Position): microWacc.ArrayLiter =
|
||||
microWacc.ArrayLiter(s.map(microWacc.CharLiter(_)).toList)(KnownType.String, pos)
|
||||
}
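  // Illustrative example (added commentary, not original source):
  //   "hi".toMicroWaccCharArray(pos)
  //   // => microWacc.ArrayLiter(List(CharLiter('h'), CharLiter('i')))(KnownType.String, pos)
  // which is how string literals and the printf/scanf format strings above are built.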
|
||||
}
|
||||
@@ -3,7 +3,9 @@ package wacc
|
||||
object types {
|
||||
import ast._
|
||||
|
||||
sealed trait SemType {
|
||||
sealed trait RenamerType
|
||||
|
||||
sealed trait SemType extends RenamerType {
|
||||
override def toString(): String = this match {
|
||||
case KnownType.Int => "int"
|
||||
case KnownType.Bool => "bool"
|
||||
@@ -41,5 +43,5 @@ object types {
|
||||
}
|
||||
}
|
||||
|
||||
case class FuncType(returnType: SemType, params: List[SemType])
|
||||
case class FuncType(returnType: SemType, params: List[SemType]) extends RenamerType
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
package wacc
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
object renamer {
|
||||
import ast._
|
||||
import types._
|
||||
|
||||
private enum IdentType {
|
||||
case Func
|
||||
case Var
|
||||
}
|
||||
|
||||
private class Scope(
|
||||
val current: mutable.Map[(String, IdentType), Ident],
|
||||
val parent: Map[(String, IdentType), Ident]
|
||||
) {
|
||||
|
||||
/** Create a new scope with the current scope as its parent.
|
||||
*
|
||||
* @return
|
||||
* A new scope with an empty current scope, and this scope flattened into the parent scope.
|
||||
*/
|
||||
def subscope: Scope =
|
||||
Scope(mutable.Map.empty, Map.empty.withDefault(current.withDefault(parent)))
|
||||
|
||||
/** Attempt to add a new identifier to the current scope. If the identifier already exists in
|
||||
* the current scope, add an error to the error list.
|
||||
*
|
||||
* @param ty
|
||||
* The semantic type of the variable identifier, or function identifier type.
|
||||
* @param name
|
||||
* The name of the identifier.
|
||||
* @param globalNames
|
||||
* The global map of identifiers to semantic types - the identifier will be added to this
|
||||
* map.
|
||||
* @param globalNumbering
|
||||
* The global map of identifier names to the number of times they have been declared - will
|
||||
* be used to rename this identifier, and will be incremented.
|
||||
* @param errors
|
||||
* The list of errors to append to.
|
||||
*/
|
||||
def add(ty: SemType | FuncType, name: Ident)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalFuncs: mutable.Map[Ident, FuncType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
) = {
|
||||
val identType = ty match {
|
||||
case _: SemType => IdentType.Var
|
||||
case _: FuncType => IdentType.Func
|
||||
}
|
||||
current.get((name.v, identType)) match {
|
||||
case Some(Ident(_, uid)) =>
|
||||
errors += Error.DuplicateDeclaration(name)
|
||||
name.uid = uid
|
||||
case None =>
|
||||
val uid = globalNumbering.getOrElse(name.v, 0)
|
||||
name.uid = uid
|
||||
current((name.v, identType)) = name
|
||||
|
||||
ty match {
|
||||
case semType: SemType =>
|
||||
globalNames(name) = semType
|
||||
case funcType: FuncType =>
|
||||
globalFuncs(name) = funcType
|
||||
}
|
||||
globalNumbering(name.v) = uid + 1
|
||||
}
|
||||
}
|
||||
|
||||
private def get(name: String, identType: IdentType): Option[Ident] =
|
||||
// Unfortunately map defaults only work with `.apply()`, which throws an error when the key is not found.
|
||||
// Neither is there a way to check whether a default exists, so we have to use a try-catch.
|
||||
try {
|
||||
Some(current.withDefault(parent)((name, identType)))
|
||||
} catch {
|
||||
case _: NoSuchElementException => None
|
||||
}
|
||||
|
||||
def getVar(name: String): Option[Ident] = get(name, IdentType.Var)
|
||||
def getFunc(name: String): Option[Ident] = get(name, IdentType.Func)
|
||||
}
|
||||
|
||||
/** Check scoping of all variables and functions in the program. Also generate semantic types for
|
||||
* all identifiers.
|
||||
*
|
||||
* @param prog
|
||||
* AST of the program
|
||||
* @param errors
|
||||
* List of errors to append to
|
||||
* @return
|
||||
* Map of all (renamed) identifiers to their semantic types
|
||||
*/
|
||||
def rename(prog: Program)(using
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): (Map[Ident, SemType], Map[Ident, FuncType]) = {
|
||||
given globalNames: mutable.Map[Ident, SemType] = mutable.Map.empty
|
||||
given globalFuncs: mutable.Map[Ident, FuncType] = mutable.Map.empty
|
||||
given globalNumbering: mutable.Map[String, Int] = mutable.Map.empty
|
||||
val scope = Scope(mutable.Map.empty, Map.empty)
|
||||
val Program(funcs, main) = prog
|
||||
funcs
|
||||
// First add all function declarations to the scope
|
||||
.map { case FuncDecl(retType, name, params, body) =>
|
||||
val paramTypes = params.map { param =>
|
||||
val paramType = SemType(param.paramType)
|
||||
paramType
|
||||
}
|
||||
scope.add(FuncType(SemType(retType), paramTypes), name)
|
||||
(params zip paramTypes, body)
|
||||
}
|
||||
// Only then rename the function bodies
|
||||
// (functions can call one-another regardless of order of declaration)
|
||||
.foreach { case (params, body) =>
|
||||
val functionScope = scope.subscope
|
||||
params.foreach { case (param, paramType) =>
|
||||
functionScope.add(paramType, param.name)
|
||||
}
|
||||
body.toList.foreach(rename(functionScope.subscope)) // body can shadow function params
|
||||
}
|
||||
main.toList.foreach(rename(scope))
|
||||
(globalNames.toMap, globalFuncs.toMap)
|
||||
}
|
||||
|
||||
/** Check scoping of all identifiers in a given AST node.
|
||||
*
|
||||
* @param scope
|
||||
* The current scope and flattened parent scope.
|
||||
* @param node
|
||||
* The AST node.
|
||||
* @param globalNames
|
||||
* The global map of identifiers to semantic types - renamed identifiers will be added to this
|
||||
* map.
|
||||
* @param globalNumbering
|
||||
* The global map of identifier names to the number of times they have been declared - used and
|
||||
* updated during identifier renaming.
|
||||
* @param errors
|
||||
*/
|
||||
private def rename(scope: Scope)(
|
||||
node: Ident | Stmt | LValue | RValue | Expr
|
||||
)(using
|
||||
globalNames: mutable.Map[Ident, SemType],
|
||||
globalFuncs: mutable.Map[Ident, FuncType],
|
||||
globalNumbering: mutable.Map[String, Int],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
): Unit = node match {
|
||||
// These cases are more interesting because they involve making subscopes
|
||||
// or modifying the current scope.
|
||||
case VarDecl(synType, name, value) => {
|
||||
// Order matters here. Variable isn't declared until after the value is evaluated.
|
||||
rename(scope)(value)
|
||||
// Attempt to add the new variable to the current scope.
|
||||
scope.add(SemType(synType), name)
|
||||
}
|
||||
case If(cond, thenStmt, elseStmt) => {
|
||||
rename(scope)(cond)
|
||||
// then and else both have their own scopes
|
||||
thenStmt.toList.foreach(rename(scope.subscope))
|
||||
elseStmt.toList.foreach(rename(scope.subscope))
|
||||
}
|
||||
case While(cond, body) => {
|
||||
rename(scope)(cond)
|
||||
// while bodies have their own scopes
|
||||
body.toList.foreach(rename(scope.subscope))
|
||||
}
|
||||
// begin-end blocks have their own scopes
|
||||
case Block(body) => body.toList.foreach(rename(scope.subscope))
|
||||
|
||||
// These cases are simpler, mostly just recursive calls to rename()
|
||||
case Assign(lhs, value) => {
|
||||
// Variables may be reassigned with their value in the rhs, so order doesn't matter here.
|
||||
rename(scope)(lhs)
|
||||
rename(scope)(value)
|
||||
}
|
||||
case Read(lhs) => rename(scope)(lhs)
|
||||
case Free(expr) => rename(scope)(expr)
|
||||
case Return(expr) => rename(scope)(expr)
|
||||
case Exit(expr) => rename(scope)(expr)
|
||||
case Print(expr, _) => rename(scope)(expr)
|
||||
case NewPair(fst, snd) => {
|
||||
rename(scope)(fst)
|
||||
rename(scope)(snd)
|
||||
}
|
||||
case Call(name, args) => {
|
||||
scope.getFunc(name.v) match {
|
||||
case Some(Ident(_, uid)) => name.uid = uid
|
||||
case None =>
|
||||
errors += Error.UndefinedFunction(name)
|
||||
scope.add(FuncType(?, args.map(_ => ?)), name)
|
||||
}
|
||||
args.foreach(rename(scope))
|
||||
}
|
||||
case Fst(elem) => rename(scope)(elem)
|
||||
case Snd(elem) => rename(scope)(elem)
|
||||
case ArrayLiter(elems) => elems.foreach(rename(scope))
|
||||
case ArrayElem(name, indices) => {
|
||||
rename(scope)(name)
|
||||
indices.toList.foreach(rename(scope))
|
||||
}
|
||||
case Parens(expr) => rename(scope)(expr)
|
||||
case op: UnaryOp => rename(scope)(op.x)
|
||||
case op: BinaryOp => {
|
||||
rename(scope)(op.x)
|
||||
rename(scope)(op.y)
|
||||
}
|
||||
// Default to variables. Only `call` uses IdentType.Func.
|
||||
case id: Ident => {
|
||||
scope.getVar(id.v) match {
|
||||
case Some(Ident(_, uid)) => id.uid = uid
|
||||
case None =>
|
||||
errors += Error.UndeclaredVariable(id)
|
||||
scope.add(?, id)
|
||||
}
|
||||
}
|
||||
// These literals cannot contain identifiers, exit immediately.
|
||||
case IntLiter(_) | BoolLiter(_) | CharLiter(_) | StrLiter(_) | PairLiter() | Skip() => ()
|
||||
}
|
||||
}
|
||||
@@ -1,322 +0,0 @@
|
||||
package wacc
|
||||
|
||||
import cats.syntax.all._
|
||||
import scala.collection.mutable
|
||||
|
||||
object typeChecker {
|
||||
import wacc.ast._
|
||||
import wacc.types._
|
||||
|
||||
case class TypeCheckerCtx(
|
||||
globalNames: Map[Ident, SemType],
|
||||
globalFuncs: Map[Ident, FuncType],
|
||||
errors: mutable.Builder[Error, List[Error]]
|
||||
) {
|
||||
def typeOf(ident: Ident): SemType = globalNames(ident)
|
||||
|
||||
def funcType(ident: Ident): FuncType = globalFuncs(ident)
|
||||
|
||||
def error(err: Error): SemType =
|
||||
errors += err
|
||||
?
|
||||
}
|
||||
|
||||
private enum Constraint {
|
||||
case Unconstrained
|
||||
// Allows weakening in one direction
|
||||
case Is(ty: SemType, msg: String)
|
||||
// Allows weakening in both directions, useful for array literals
|
||||
case IsSymmetricCompatible(ty: SemType, msg: String)
|
||||
// Does not allow weakening
|
||||
case IsUnweakenable(ty: SemType, msg: String)
|
||||
case IsEither(ty1: SemType, ty2: SemType, msg: String)
|
||||
case Never(msg: String)
|
||||
}
|
||||
|
||||
extension (ty: SemType) {
|
||||
|
||||
/** Check if a type satisfies a constraint.
|
||||
*
|
||||
* @param constraint
|
||||
* Constraint to satisfy.
|
||||
* @param pos
|
||||
* Position to pass to the error, if constraint was not satisfied.
|
||||
* @return
|
||||
* The type if the constraint was satisfied, or ? if it was not.
|
||||
*/
|
||||
private def satisfies(constraint: Constraint, pos: Position)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): SemType =
|
||||
(ty, constraint) match {
|
||||
case (KnownType.Array(KnownType.Char), Constraint.Is(KnownType.String, _)) =>
|
||||
KnownType.String
|
||||
case (
|
||||
KnownType.String,
|
||||
Constraint.IsSymmetricCompatible(KnownType.Array(KnownType.Char), _)
|
||||
) =>
|
||||
KnownType.String
|
||||
case (ty, Constraint.IsSymmetricCompatible(ty2, msg)) =>
|
||||
ty.satisfies(Constraint.Is(ty2, msg), pos)
|
||||
// Change to IsUnweakenable to disallow recursive weakening
|
||||
case (ty, Constraint.Is(ty2, msg)) => ty.satisfies(Constraint.IsUnweakenable(ty2, msg), pos)
|
||||
case (ty, Constraint.Unconstrained) => ty
|
||||
case (ty, Constraint.Never(msg)) =>
|
||||
ctx.error(Error.SemanticError(pos, msg))
|
||||
case (ty, Constraint.IsEither(ty1, ty2, msg)) =>
|
||||
(ty moreSpecific ty1).orElse(ty moreSpecific ty2).getOrElse {
|
||||
ctx.error(Error.TypeMismatch(pos, ty1, ty, msg))
|
||||
}
|
||||
case (ty, Constraint.IsUnweakenable(ty2, msg)) =>
|
||||
(ty moreSpecific ty2).getOrElse {
|
||||
ctx.error(Error.TypeMismatch(pos, ty2, ty, msg))
|
||||
}
|
||||
}
|
||||
|
||||
/** Tries to merge two types, returning the more specific one if possible.
|
||||
*
|
||||
* @param ty2
|
||||
* The other type to merge with.
|
||||
* @return
|
||||
* The more specific type if it could be determined, or None if the types are incompatible.
|
||||
*/
|
||||
private infix def moreSpecific(ty2: SemType): Option[SemType] =
|
||||
(ty, ty2) match {
|
||||
case (ty, ?) => Some(ty)
|
||||
case (?, ty) => Some(ty)
|
||||
case (ty1, ty2) if ty1 == ty2 => Some(ty1)
|
||||
case (KnownType.Array(inn1), KnownType.Array(inn2)) =>
|
||||
(inn1 moreSpecific inn2).map(KnownType.Array(_))
|
||||
case (KnownType.Pair(fst1, snd1), KnownType.Pair(fst2, snd2)) =>
|
||||
(fst1 moreSpecific fst2, snd1 moreSpecific snd2).mapN(KnownType.Pair(_, _))
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
/** Type-check a WACC program.
|
||||
*
|
||||
* @param prog
|
||||
* The AST of the program to type-check.
|
||||
* @param ctx
|
||||
* The type checker context which includes the global names and functions, and an errors
|
||||
* builder.
|
||||
*/
|
||||
def check(prog: Program)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): Unit = {
|
||||
// Ignore function syntax types for return value and params, since those have been converted
|
||||
// to SemTypes by the renamer.
|
||||
prog.funcs.foreach { case FuncDecl(_, name, _, stmts) =>
|
||||
val FuncType(retType, _) = ctx.funcType(name)
|
||||
stmts.toList.foreach(
|
||||
checkStmt(_, Constraint.Is(retType, s"function ${name.v} must return $retType"))
|
||||
)
|
||||
}
|
||||
prog.main.toList.foreach(checkStmt(_, Constraint.Never("main function must not return")))
|
||||
}
|
||||
|
||||
/** Type-check an AST statement node.
|
||||
*
|
||||
* @param stmt
|
||||
* The statement to type-check.
|
||||
* @param returnConstraint
|
||||
* The constraint that any `return <expr>` statements must satisfy.
|
||||
*/
|
||||
private def checkStmt(stmt: Stmt, returnConstraint: Constraint)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): Unit = stmt match {
|
||||
// Ignore the type of the variable, since it has been converted to a SemType by the renamer.
|
||||
case VarDecl(_, name, value) =>
|
||||
val expectedTy = ctx.typeOf(name)
|
||||
checkValue(
|
||||
value,
|
||||
Constraint.Is(
|
||||
expectedTy,
|
||||
s"variable ${name.v} must be assigned a value of type $expectedTy"
|
||||
)
|
||||
)
|
||||
case Assign(lhs, rhs) =>
|
||||
val lhsTy = checkValue(lhs, Constraint.Unconstrained)
|
||||
(lhsTy, checkValue(rhs, Constraint.Is(lhsTy, s"assignment must have type $lhsTy"))) match {
|
||||
case (?, ?) =>
|
||||
ctx.error(
|
||||
Error.SemanticError(lhs.pos, "assignment with both sides of unknown type is illegal")
|
||||
)
|
||||
case _ => ()
|
||||
}
|
||||
case Read(dest) =>
|
||||
checkValue(dest, Constraint.Unconstrained) match {
|
||||
case ? =>
|
||||
ctx.error(
|
||||
Error.SemanticError(dest.pos, "cannot read into a destination with an unknown type")
|
||||
)
|
||||
case destTy =>
|
||||
destTy.satisfies(
|
||||
Constraint.IsEither(
|
||||
KnownType.Int,
|
||||
KnownType.Char,
|
||||
"read must be applied to an int or char"
|
||||
),
|
||||
dest.pos
|
||||
)
|
||||
}
|
||||
case Free(lhs) =>
|
||||
checkValue(
|
||||
lhs,
|
||||
Constraint.IsEither(
|
||||
KnownType.Array(?),
|
||||
KnownType.Pair(?, ?),
|
||||
"free must be applied to an array or pair"
|
||||
)
|
||||
)
|
||||
case Return(expr) =>
|
||||
checkValue(expr, returnConstraint)
|
||||
case Exit(expr) =>
|
||||
checkValue(expr, Constraint.Is(KnownType.Int, "exit value must be int"))
|
||||
case Print(expr, _) =>
|
||||
// This constraint should never fail; the scope-checker should have caught it already
|
||||
checkValue(expr, Constraint.Unconstrained)
|
||||
case If(cond, thenStmt, elseStmt) =>
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "if condition must be a bool"))
|
||||
thenStmt.toList.foreach(checkStmt(_, returnConstraint))
|
||||
elseStmt.toList.foreach(checkStmt(_, returnConstraint))
|
||||
case While(cond, body) =>
|
||||
checkValue(cond, Constraint.Is(KnownType.Bool, "while condition must be a bool"))
|
||||
body.toList.foreach(checkStmt(_, returnConstraint))
|
||||
case Block(body) =>
|
||||
body.toList.foreach(checkStmt(_, returnConstraint))
|
||||
case Skip() => ()
|
||||
}
|
||||
|
||||
/** Type-check an AST LValue, RValue or Expr node. This function does all 3 since these traits
|
||||
* overlap in the AST.
|
||||
*
|
||||
* @param value
|
||||
* The value to type-check.
|
||||
* @param constraint
|
||||
* The type constraint that the value must satisfy.
|
||||
* @return
|
||||
* The most specific type of the value if it could be determined, or ? if it could not.
|
||||
*/
|
||||
private def checkValue(value: LValue | RValue | Expr, constraint: Constraint)(using
|
||||
ctx: TypeCheckerCtx
|
||||
): SemType = value match {
|
||||
case l @ IntLiter(_) => KnownType.Int.satisfies(constraint, l.pos)
|
||||
case l @ BoolLiter(_) => KnownType.Bool.satisfies(constraint, l.pos)
|
||||
case l @ CharLiter(_) => KnownType.Char.satisfies(constraint, l.pos)
|
||||
case l @ StrLiter(_) => KnownType.String.satisfies(constraint, l.pos)
|
||||
case l @ PairLiter() => KnownType.Pair(?, ?).satisfies(constraint, l.pos)
|
||||
case id: Ident =>
|
||||
ctx.typeOf(id).satisfies(constraint, id.pos)
|
||||
case ArrayElem(id, indices) =>
|
||||
val arrayTy = ctx.typeOf(id)
|
||||
val elemTy = indices.foldLeftM(arrayTy) { (acc, elem) =>
|
||||
checkValue(elem, Constraint.Is(KnownType.Int, "array index must be an int"))
|
||||
acc match {
|
||||
case KnownType.Array(innerTy) => Some(innerTy)
|
||||
case ? => Some(?) // we can keep indexing an unknown type
|
||||
case nonArrayTy =>
|
||||
ctx.error(
|
||||
Error.TypeMismatch(elem.pos, KnownType.Array(?), acc, "cannot index into a non-array")
|
||||
)
|
||||
None
|
||||
}
|
||||
}
|
||||
elemTy.getOrElse(?).satisfies(constraint, id.pos)
|
||||
case Parens(expr) => checkValue(expr, constraint)
|
||||
case l @ ArrayLiter(elems) =>
|
||||
KnownType
|
||||
// Start with an unknown param type, make it more specific while checking the elements.
|
||||
.Array(elems.foldLeft[SemType](?) { case (acc, elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
Constraint.IsSymmetricCompatible(acc, s"array elements must have the same type")
|
||||
)
|
||||
})
|
||||
.satisfies(constraint, l.pos)
|
||||
case l @ NewPair(fst, snd) =>
|
||||
KnownType
|
||||
.Pair(
|
||||
checkValue(fst, Constraint.Unconstrained),
|
||||
checkValue(snd, Constraint.Unconstrained)
|
||||
)
|
||||
.satisfies(constraint, l.pos)
|
||||
case Call(id, args) =>
|
||||
val funcTy @ FuncType(retTy, paramTys) = ctx.funcType(id)
|
||||
if (args.length != paramTys.length) {
|
||||
ctx.error(Error.FunctionParamsMismatch(id, paramTys.length, args.length, funcTy))
|
||||
}
|
||||
// Even if the number of arguments is wrong, we still check the types of the arguments
|
||||
// in the best way we can (by taking a zip).
|
||||
args.zip(paramTys).foreach { case (arg, paramTy) =>
|
||||
checkValue(arg, Constraint.Is(paramTy, s"argument type mismatch in function ${id.v}"))
|
||||
}
|
||||
retTy.satisfies(constraint, id.pos)
|
||||
case Fst(elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
Constraint.Is(KnownType.Pair(?, ?), "fst must be applied to a pair")
|
||||
) match {
|
||||
case what @ KnownType.Pair(left, _) =>
|
||||
left.satisfies(constraint, elem.pos)
|
||||
case _ => ctx.error(Error.InternalError(elem.pos, "fst must be applied to a pair"))
|
||||
}
|
||||
case Snd(elem) =>
|
||||
checkValue(
|
||||
elem,
|
||||
Constraint.Is(KnownType.Pair(?, ?), "snd must be applied to a pair")
|
||||
) match {
|
||||
case KnownType.Pair(_, right) => right.satisfies(constraint, elem.pos)
|
||||
case _ => ctx.error(Error.InternalError(elem.pos, "snd must be applied to a pair"))
|
||||
}
|
||||
|
||||
// Unary operators
|
||||
case Negate(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Int, "negation must be applied to an int"))
|
||||
KnownType.Int.satisfies(constraint, x.pos)
|
||||
case Not(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Bool, "logical not must be applied to a bool"))
|
||||
KnownType.Bool.satisfies(constraint, x.pos)
|
||||
case Len(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Array(?), "len must be applied to an array"))
|
||||
KnownType.Int.satisfies(constraint, x.pos)
|
||||
case Ord(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Char, "ord must be applied to a char"))
|
||||
KnownType.Int.satisfies(constraint, x.pos)
|
||||
case Chr(x) =>
|
||||
checkValue(x, Constraint.Is(KnownType.Int, "chr must be applied to an int"))
|
||||
KnownType.Char.satisfies(constraint, x.pos)
|
||||
|
||||
// Binary operators
|
||||
case op: (Add | Sub | Mul | Div | Mod) =>
|
||||
val operand = Constraint.Is(KnownType.Int, s"${op.name} operator must be applied to an int")
|
||||
checkValue(op.x, operand)
|
||||
checkValue(op.y, operand)
|
||||
KnownType.Int.satisfies(constraint, op.pos)
|
||||
case op: (Eq | Neq) =>
|
||||
val xTy = checkValue(op.x, Constraint.Unconstrained)
|
||||
checkValue(
|
||||
op.y,
|
||||
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
|
||||
)
|
||||
KnownType.Bool.satisfies(constraint, op.pos)
|
||||
case op: (Less | LessEq | Greater | GreaterEq) =>
|
||||
val xConstraint = Constraint.IsEither(
|
||||
KnownType.Int,
|
||||
KnownType.Char,
|
||||
s"${op.name} operator must be applied to an int or char"
|
||||
)
|
||||
// If x type-check failed, we still want to check y is an Int or Char (rather than ?)
|
||||
val yConstraint = checkValue(op.x, xConstraint) match {
|
||||
case ? => xConstraint
|
||||
case xTy =>
|
||||
Constraint.Is(xTy, s"${op.name} operator must be applied to values of the same type")
|
||||
}
|
||||
checkValue(op.y, yConstraint)
|
||||
KnownType.Bool.satisfies(constraint, op.pos)
|
||||
case op: (And | Or) =>
|
||||
val operand = Constraint.Is(KnownType.Bool, s"${op.name} operator must be applied to a bool")
|
||||
checkValue(op.x, operand)
|
||||
checkValue(op.y, operand)
|
||||
KnownType.Bool.satisfies(constraint, op.pos)
|
||||
}
|
||||
}
|
||||
66
src/test/wacc/backend/extensionsSpec.scala
Normal file
@@ -0,0 +1,66 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import org.scalatest.Inspectors.forEvery
|
||||
import cats.data.Chain
|
||||
|
||||
class ExtensionsSpec extends AnyFlatSpec {
|
||||
import asmGenerator.concatAll
|
||||
import asmGenerator.escaped
|
||||
|
||||
behavior of "concatAll"
|
||||
|
||||
it should "handle int chains" in {
|
||||
val chain = Chain(1, 2, 3).concatAll(
|
||||
Chain(4, 5, 6),
|
||||
Chain.empty,
|
||||
Chain.one(-1)
|
||||
)
|
||||
assert(chain == Chain(1, 2, 3, 4, 5, 6, -1))
|
||||
}
|
||||
|
||||
it should "handle AsmLine chains" in {
|
||||
object lines {
|
||||
import assemblyIR._
|
||||
import assemblyIR.commonRegisters._
|
||||
val main = LabelDef("main")
|
||||
val pop = Pop(RAX)
|
||||
val add = Add(RAX, ImmediateVal(1))
|
||||
val push = Push(RAX)
|
||||
val ret = Return()
|
||||
}
|
||||
val chain = Chain(lines.main).concatAll(
|
||||
Chain.empty,
|
||||
Chain.one(lines.pop),
|
||||
Chain(lines.add, lines.push),
|
||||
Chain.one(lines.ret)
|
||||
)
|
||||
assert(chain == Chain(lines.main, lines.pop, lines.add, lines.push, lines.ret))
|
||||
}
|
||||
|
||||
behavior of "escaped"
|
||||
|
||||
val escapedStrings = Map(
|
||||
"hello" -> "hello",
|
||||
"world" -> "world",
|
||||
"hello\nworld" -> "hello\\nworld",
|
||||
"hello\tworld" -> "hello\\tworld",
|
||||
"hello\\world" -> "hello\\\\world",
|
||||
"hello\"world" -> "hello\\\"world",
|
||||
"hello'world" -> "hello\\'world",
|
||||
"hello\\nworld" -> "hello\\\\nworld",
|
||||
"hello\\tworld" -> "hello\\\\tworld",
|
||||
"hello\\\\world" -> "hello\\\\\\\\world",
|
||||
"hello\\\"world" -> "hello\\\\\\\"world",
|
||||
"hello\\'world" -> "hello\\\\\\'world",
|
||||
"hello\\n\\t\\'world" -> "hello\\\\n\\\\t\\\\\\'world",
|
||||
"hello\u0000world" -> "hello\\0world",
|
||||
"hello\bworld" -> "hello\\bworld",
|
||||
"hello\fworld" -> "hello\\fworld"
|
||||
)
|
||||
forEvery(escapedStrings) { (input, expected) =>
|
||||
it should s"return $expected" in {
|
||||
assert(input.escaped == expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
68
src/test/wacc/backend/instructionSpec.scala
Normal file
@@ -0,0 +1,68 @@
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import wacc.assemblyIR._
|
||||
import wacc.assemblyIR.Size._
|
||||
import wacc.assemblyIR.RegName._
|
||||
|
||||
class instructionSpec extends AnyFunSuite {
|
||||
|
||||
val named64BitRegister = Register(Q64, AX)
|
||||
|
||||
test("named 64-bit register toString") {
|
||||
assert(named64BitRegister.toString == "rax")
|
||||
}
|
||||
|
||||
val named32BitRegister = Register(D32, AX)
|
||||
|
||||
test("named 32-bit register toString") {
|
||||
assert(named32BitRegister.toString == "eax")
|
||||
}
|
||||
|
||||
val scratch64BitRegister = Register(Q64, R8)
|
||||
|
||||
test("scratch 64-bit register toString") {
|
||||
assert(scratch64BitRegister.toString == "r8")
|
||||
}
|
||||
|
||||
val scratch32BitRegister = Register(D32, R8)
|
||||
|
||||
test("scratch 32-bit register toString") {
|
||||
assert(scratch32BitRegister.toString == "r8d")
|
||||
}
|
||||
|
||||
val memLocationWithRegister = MemLocation(named64BitRegister, opSize = Some(Q64))
|
||||
|
||||
test("mem location with register toString") {
|
||||
assert(memLocationWithRegister.toString == "qword ptr [rax]")
|
||||
}
|
||||
|
||||
val memLocationFull =
|
||||
MemLocation(named64BitRegister, 32, (scratch64BitRegister, 10), Some(B8))
|
||||
|
||||
test("mem location with all fields toString") {
|
||||
assert(memLocationFull.toString == "byte ptr [rax + r8 * 10 + 32]")
|
||||
}
|
||||
|
||||
val immediateVal = ImmediateVal(123)
|
||||
|
||||
test("immediate value toString") {
|
||||
assert(immediateVal.toString == "123")
|
||||
}
|
||||
|
||||
val addInstruction = Add(named64BitRegister, immediateVal)
|
||||
|
||||
test("x86: add instruction toString") {
|
||||
assert(addInstruction.toString == "\tadd rax, 123")
|
||||
}
|
||||
|
||||
val subInstruction = Subtract(scratch64BitRegister, named64BitRegister)
|
||||
|
||||
test("x86: sub instruction toString") {
|
||||
assert(subInstruction.toString == "\tsub r8, rax")
|
||||
}
|
||||
|
||||
val callInstruction = Call(CLibFunc.Scanf)
|
||||
|
||||
test("x86: call instruction toString") {
|
||||
assert(callInstruction.toString == "\tcall scanf@plt")
|
||||
}
|
||||
}
|
||||
85
src/test/wacc/backend/labelGeneratorSpec.scala
Normal file
@@ -0,0 +1,85 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
|
||||
class LabelGeneratorSpec extends AnyFlatSpec {
|
||||
import microWacc._
|
||||
import assemblyIR.{LabelDef, LabelArg, Directive}
|
||||
import types.?
|
||||
|
||||
"getLabel" should "return unique labels" in {
|
||||
val labelGenerator = new LabelGenerator
|
||||
val labels = (1 to 10).map(_ => labelGenerator.getLabel())
|
||||
assert(labels.distinct.length == labels.length)
|
||||
}
|
||||
|
||||
"getLabelDef" should "return unique labels" in {
|
||||
assert(
|
||||
LabelDef("test") == LabelDef("test") &&
|
||||
LabelDef("test").hashCode == LabelDef("test").hashCode,
|
||||
"Sanity check: LabelDef should be case-classes"
|
||||
)
|
||||
|
||||
val labelGenerator = new LabelGenerator
|
||||
val labels = (List(
|
||||
Builtin.Exit,
|
||||
Builtin.Printf,
|
||||
Ident("exit", 0)(?),
|
||||
Ident("test", 0)(?)
|
||||
) ++ RuntimeError.all.toList).map(labelGenerator.getLabelDef(_))
|
||||
assert(labels.distinct.length == labels.length)
|
||||
}
|
||||
|
||||
"getLabelArg" should "return unique labels" in {
|
||||
assert(
|
||||
LabelArg("test") == LabelArg("test") &&
|
||||
LabelArg("test").hashCode == LabelArg("test").hashCode,
|
||||
"Sanity check: LabelArg should be case-classes"
|
||||
)
|
||||
|
||||
val labelGenerator = new LabelGenerator
|
||||
val labels = (List(
|
||||
Builtin.Exit,
|
||||
Builtin.Printf,
|
||||
Ident("exit", 0)(?),
|
||||
Ident("test", 0)(?),
|
||||
"test",
|
||||
"test",
|
||||
"test3"
|
||||
) ++ RuntimeError.all.toList).map {
|
||||
case s: String => labelGenerator.getLabelArg(s)
|
||||
case t: (CallTarget | RuntimeError) => labelGenerator.getLabelArg(t)
|
||||
}
|
||||
assert(labels.distinct.length == labels.distinct.length)
|
||||
}
|
||||
|
||||
it should "return consistent labels to getLabelDef" in {
|
||||
val labelGenerator = new LabelGenerator
|
||||
val targets = (List(
|
||||
Builtin.Exit,
|
||||
Builtin.Printf,
|
||||
Ident("exit", 0)(?),
|
||||
Ident("test", 0)(?)
|
||||
) ++ RuntimeError.all.toList)
|
||||
val labelDefs = targets.map(labelGenerator.getLabelDef(_).toString.dropRight(1)).toSet
|
||||
val labelArgs = targets.map(labelGenerator.getLabelArg(_).toString).toSet
|
||||
assert(labelDefs == labelArgs)
|
||||
}
|
||||
|
||||
"generateConstants" should "generate de-duplicated labels for strings" in {
|
||||
val labelGenerator = new LabelGenerator
|
||||
val strings = List("hello", "world", "hello\u0000world", "hello", "Hello")
|
||||
val distincts = strings.distinct.length
|
||||
val labels = strings.map(labelGenerator.getLabelArg(_).toString).toSet
|
||||
val asmLines = labelGenerator.generateConstants
|
||||
assert(
|
||||
asmLines.collect { case LabelDef(name) =>
|
||||
name
|
||||
}.length == distincts
|
||||
)
|
||||
assert(
|
||||
asmLines.collect { case Directive.Asciz(str) => str }.length == distincts
|
||||
)
|
||||
assert(asmLines.collect { case LabelDef(name) => name }.toList.toSet == labels)
|
||||
}
|
||||
}
|
||||
140
src/test/wacc/backend/stackSpec.scala
Normal file
@@ -0,0 +1,140 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import cats.data.Chain
|
||||
|
||||
class StackSpec extends AnyFlatSpec {
|
||||
import microWacc._
|
||||
import assemblyIR._
|
||||
import assemblyIR.Size._
|
||||
import assemblyIR.commonRegisters._
|
||||
import types.{KnownType, ?}
|
||||
import sizeExtensions.size
|
||||
|
||||
private val RSP = Register(Q64, RegName.SP)
|
||||
|
||||
"size" should "be 0 initially" in {
|
||||
val stack = new Stack
|
||||
assert(stack.size == 0)
|
||||
}
|
||||
|
||||
"push" should "add an expression to the stack" in {
|
||||
val stack = new Stack
|
||||
val expr = Ident("x", 0)(?)
|
||||
val result = stack.push(expr, RAX)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Push(RAX))
|
||||
}
|
||||
|
||||
it should "add 2 expressions to the stack" in {
|
||||
val stack = new Stack
|
||||
val expr1 = Ident("x", 0)(?)
|
||||
val expr2 = Ident("x", 1)(?)
|
||||
val result1 = stack.push(expr1, RAX)
|
||||
val result2 = stack.push(expr2, RCX)
|
||||
assert(stack.size == 2)
|
||||
assert(result1 == Push(RAX))
|
||||
assert(result2 == Push(RCX))
|
||||
}
|
||||
|
||||
it should "add a value to the stack" in {
|
||||
val stack = new Stack
|
||||
val result = stack.push(D32, RAX)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Push(RAX))
|
||||
}
|
||||
|
||||
"reserve" should "reserve space for an identifier" in {
|
||||
val stack = new Stack
|
||||
val ident = Ident("x", 0)(KnownType.Int)
|
||||
val result = stack.reserve(ident)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
|
||||
}
|
||||
|
||||
it should "reserve space for a register" in {
|
||||
val stack = new Stack
|
||||
val result = stack.reserve(RAX)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt)))
|
||||
}
|
||||
|
||||
it should "reserve space for multiple values" in {
|
||||
val stack = new Stack
|
||||
val result = stack.reserve(D32, Q64, B8)
|
||||
assert(stack.size == 3)
|
||||
assert(result == Subtract(RSP, ImmediateVal(Q64.toInt * 3)))
|
||||
}
|
||||
|
||||
"pop" should "remove the last value from the stack" in {
|
||||
val stack = new Stack
|
||||
stack.push(D32, RAX)
|
||||
val result = stack.pop(RAX)
|
||||
assert(stack.size == 0)
|
||||
assert(result == Pop(RAX))
|
||||
}
|
||||
|
||||
"drop" should "remove the last 2 value from the stack" in {
|
||||
val stack = new Stack
|
||||
stack.push(D32, RAX)
|
||||
stack.push(Q64, RAX)
|
||||
stack.push(B8, RAX)
|
||||
val result = stack.drop(2)
|
||||
assert(stack.size == 1)
|
||||
assert(result == Add(RSP, ImmediateVal(Q64.toInt * 2)))
|
||||
}
|
||||
|
||||
"withScope" should "reset stack after block" in {
|
||||
val stack = new Stack
|
||||
stack.push(D32, RAX)
|
||||
stack.push(Q64, RCX)
|
||||
stack.push(B8, RDX)
|
||||
val result = stack.withScope(() =>
|
||||
Chain(
|
||||
stack.push(Q64, RSI),
|
||||
stack.push(B8, RDI),
|
||||
stack.push(B8, RBP)
|
||||
)
|
||||
)
|
||||
assert(stack.size == 3)
|
||||
assert(
|
||||
result == Chain(
|
||||
Push(RSI),
|
||||
Push(RDI),
|
||||
Push(RBP),
|
||||
Add(RSP, ImmediateVal(Q64.toInt * 3))
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
"accessVar" should "return the correctly-sized memory location for the identifier" in {
|
||||
val stack = new Stack
|
||||
val id = Ident("x", 0)(KnownType.Int)
|
||||
stack.push(Q64, RAX)
|
||||
stack.push(id, RCX)
|
||||
stack.push(B8, RDX)
|
||||
stack.push(D32, RSI)
|
||||
val result = stack.accessVar(Ident("x", 0)(KnownType.Int))
|
||||
assert(result == MemLocation(RSP, Q64.toInt * 2, opSize = Some(KnownType.Int.size)))
|
||||
}
|
||||
|
||||
"contains" should "return true if the stack contains the identifier" in {
|
||||
val stack = new Stack
|
||||
val id = Ident("x", 0)(KnownType.Int)
|
||||
stack.push(D32, RAX)
|
||||
stack.push(id, RCX)
|
||||
stack.push(B8, RDX)
|
||||
assert(stack.contains(id))
|
||||
assert(!stack.contains(Ident("x", 1)(KnownType.Int)))
|
||||
}
|
||||
|
||||
"head" should "return the correct memory location for the last element" in {
|
||||
val stack = new Stack
|
||||
val id = Ident("x", 0)(KnownType.Int)
|
||||
stack.push(D32, RAX)
|
||||
stack.push(id, RCX)
|
||||
stack.push(B8, RDX)
|
||||
val result = stack.head
|
||||
assert(result == MemLocation(RSP, opSize = Some(B8)))
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,19 @@
|
||||
package wacc
|
||||
|
||||
import org.scalatest.{ParallelTestExecution, BeforeAndAfterAll}
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.Inspectors.forEvery
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
import org.scalatest.freespec.AsyncFreeSpec
|
||||
import cats.effect.testing.scalatest.AsyncIOSpec
|
||||
import java.io.File
|
||||
import java.nio.file.Path
|
||||
import sys.process._
|
||||
import scala.io.Source
|
||||
import cats.effect.IO
|
||||
import wacc.{compile as compileWacc}
|
||||
|
||||
class ParallelExamplesSpec extends AsyncFreeSpec with AsyncIOSpec with BeforeAndAfterAll {
|
||||
|
||||
class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with ParallelTestExecution {
|
||||
val files =
|
||||
allWaccFiles("wacc-examples/valid").map { p =>
|
||||
(p.toString, List(0))
|
||||
@@ -17,73 +26,119 @@ class ParallelExamplesSpec extends AnyFlatSpec with BeforeAndAfterAll with Paral
|
||||
} ++
|
||||
allWaccFiles("wacc-examples/invalid/whack").map { p =>
|
||||
(p.toString, List(100, 200))
|
||||
} ++
|
||||
allWaccFiles("extension/examples/valid").map { p =>
|
||||
(p.toString, List(0))
|
||||
} ++
|
||||
allWaccFiles("extension/examples/invalid/syntax").map { p =>
|
||||
(p.toString, List(100))
|
||||
} ++
|
||||
allWaccFiles("extension/examples/invalid/semantics").map { p =>
|
||||
(p.toString, List(200))
|
||||
}
|
||||
|
||||
// tests go here
|
||||
forEvery(files.filter { (filename, _) =>
|
||||
!fileIsDissallowed(filename)
|
||||
}) { (filename, expectedResult) =>
|
||||
s"$filename" should "be parsed with correct result" in {
|
||||
val contents = os.read(os.Path(filename))
|
||||
assert(expectedResult.contains(compile(contents)))
|
||||
forEvery(files) { (filename, expectedResult) =>
|
||||
val baseFilename = filename.stripSuffix(".wacc")
|
||||
|
||||
s"$filename" - {
|
||||
"should be compiled with correct result" in {
|
||||
if (fileIsPendingFrontend(filename))
|
||||
IO.pure(pending)
|
||||
else
|
||||
compileWacc(Path.of(filename), outputDir = None, log = false).map { result =>
|
||||
expectedResult should contain(result)
|
||||
}
|
||||
}
|
||||
|
||||
if (expectedResult == List(0)) {
|
||||
"should run with correct result" in {
|
||||
if (fileIsDisallowedBackend(filename))
|
||||
IO.pure(succeed)
|
||||
else if (fileIsPendingBackend(filename))
|
||||
IO.pure(pending)
|
||||
else
|
||||
for {
|
||||
contents <- IO(Source.fromFile(File(filename)).getLines.toList)
|
||||
inputLine = extractInput(contents)
|
||||
expectedOutput = extractOutput(contents)
|
||||
expectedExit = extractExit(contents)
|
||||
|
||||
asmFilename = baseFilename + ".s"
|
||||
execFilename = baseFilename
|
||||
gccResult <- IO(s"gcc -o $execFilename -z noexecstack $asmFilename".!)
|
||||
|
||||
_ = assert(gccResult == 0)
|
||||
|
||||
stdout <- IO.pure(new StringBuilder)
|
||||
process <- IO {
|
||||
s"timeout 5s $execFilename" run ProcessIO(
|
||||
in = w => {
|
||||
w.write(inputLine.getBytes)
|
||||
w.close()
|
||||
},
|
||||
out = Source.fromInputStream(_).addString(stdout),
|
||||
err = _ => ()
|
||||
)
|
||||
}
|
||||
|
||||
exitCode <- IO.pure(process.exitValue)
|
||||
|
||||
} yield {
|
||||
exitCode shouldBe expectedExit
|
||||
normalizeOutput(stdout.toString) shouldBe expectedOutput
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def allWaccFiles(dir: String): IndexedSeq[os.Path] =
|
||||
val d = java.io.File(dir)
|
||||
os.walk(os.Path(d.getAbsolutePath)).filter { _.ext == "wacc" }
|
||||
os.walk(os.Path(d.getAbsolutePath)).filter(_.ext == "wacc")
|
||||
|
||||
def fileIsDissallowed(filename: String): Boolean =
|
||||
Seq(
|
||||
// format: off
|
||||
// disable formatting to avoid binPack
|
||||
// "wacc-examples/valid/advanced",
|
||||
// "wacc-examples/valid/array",
|
||||
// "wacc-examples/valid/basic/exit",
|
||||
// "wacc-examples/valid/basic/skip",
|
||||
// "wacc-examples/valid/expressions",
|
||||
// "wacc-examples/valid/function/nested_functions",
|
||||
// "wacc-examples/valid/function/simple_functions",
|
||||
// "wacc-examples/valid/if",
|
||||
// "wacc-examples/valid/IO/print",
|
||||
// "wacc-examples/valid/IO/read",
|
||||
// "wacc-examples/valid/IO/IOLoop.wacc",
|
||||
// "wacc-examples/valid/IO/IOSequence.wacc",
|
||||
// "wacc-examples/valid/pairs",
|
||||
// "wacc-examples/valid/runtimeErr",
|
||||
// "wacc-examples/valid/scope",
|
||||
// "wacc-examples/valid/sequence",
|
||||
// "wacc-examples/valid/variables",
|
||||
// "wacc-examples/valid/while",
|
||||
// invalid (syntax)
|
||||
// "wacc-examples/invalid/syntaxErr/array",
|
||||
// "wacc-examples/invalid/syntaxErr/basic",
|
||||
// "wacc-examples/invalid/syntaxErr/expressions",
|
||||
// "wacc-examples/invalid/syntaxErr/function",
|
||||
// "wacc-examples/invalid/syntaxErr/if",
|
||||
// "wacc-examples/invalid/syntaxErr/literals",
|
||||
// "wacc-examples/invalid/syntaxErr/pairs",
|
||||
// "wacc-examples/invalid/syntaxErr/print",
|
||||
// "wacc-examples/invalid/syntaxErr/sequence",
|
||||
// "wacc-examples/invalid/syntaxErr/variables",
|
||||
// "wacc-examples/invalid/syntaxErr/while",
|
||||
// invalid (semantic)
|
||||
// "wacc-examples/invalid/semanticErr/array",
|
||||
// "wacc-examples/invalid/semanticErr/exit",
|
||||
// "wacc-examples/invalid/semanticErr/expressions",
|
||||
// "wacc-examples/invalid/semanticErr/function",
|
||||
// "wacc-examples/invalid/semanticErr/if",
|
||||
// "wacc-examples/invalid/semanticErr/IO",
|
||||
// "wacc-examples/invalid/semanticErr/multiple",
|
||||
// "wacc-examples/invalid/semanticErr/pairs",
|
||||
// "wacc-examples/invalid/semanticErr/print",
|
||||
// "wacc-examples/invalid/semanticErr/read",
|
||||
// "wacc-examples/invalid/semanticErr/scope",
|
||||
// "wacc-examples/invalid/semanticErr/variables",
|
||||
// "wacc-examples/invalid/semanticErr/while",
|
||||
// invalid (whack)
|
||||
// "wacc-examples/invalid/whack"
|
||||
// format: on
|
||||
).find(filename.contains).isDefined
|
||||
private def fileIsDisallowedBackend(filename: String): Boolean =
|
||||
filename.matches("^.*wacc-examples/valid/advanced.*$")
|
||||
|
||||
private def fileIsPendingFrontend(filename: String): Boolean =
|
||||
List(
|
||||
// "^.*extension/examples/invalid/syntax/imports/importBadSyntax.*$",
|
||||
// "^.*extension/examples/invalid/semantics/imports.*$",
|
||||
// "^.*extension/examples/valid/imports.*$"
|
||||
).exists(filename.matches)
|
||||
|
||||
private def fileIsPendingBackend(filename: String): Boolean =
|
||||
List(
|
||||
// "^.*extension/examples/invalid/syntax/imports.*$",
|
||||
// "^.*extension/examples/invalid/semantics/imports.*$",
|
||||
// "^.*extension/examples/valid/imports.*$"
|
||||
).exists(filename.matches)
|
||||
|
||||
private def extractInput(contents: List[String]): String =
|
||||
contents
|
||||
.find(_.matches("^# ?[Ii]nput:.*$"))
|
||||
.map(_.split(":").last.strip + "\n")
|
||||
.getOrElse("")
|
||||
|
||||
private def extractOutput(contents: List[String]): String = {
|
||||
val outputLineIdx = contents.indexWhere(_.matches("^# ?[Oo]utput:.*$"))
|
||||
if (outputLineIdx == -1) ""
|
||||
else
|
||||
contents
|
||||
.drop(outputLineIdx + 1)
|
||||
.takeWhile(_.startsWith("#"))
|
||||
.map(_.stripPrefix("#").stripLeading)
|
||||
.mkString("\n")
|
||||
}
|
||||
|
||||
private def extractExit(contents: List[String]): Int = {
|
||||
val exitLineIdx = contents.indexWhere(_.matches("^# ?[Ee]xit:.*$"))
|
||||
if (exitLineIdx == -1) 0
|
||||
else contents(exitLineIdx + 1).stripPrefix("#").strip.toInt
|
||||
}
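  // Hedged example (added commentary, assuming the example-file header format these
  // regexes target): for a test file whose header comments are
  //   # Input: 5
  //   # Output:
  //   # 25
  //   # Exit:
  //   # 0
  // extractInput yields "5\n", extractOutput yields "25" and extractExit yields 0.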
|
||||
|
||||
private def normalizeOutput(output: String): String =
|
||||
output
|
||||
.replaceAll("0x[0-9a-f]+", "#addrs#")
|
||||
.replaceAll("fatal error:.*", "#runtime_error#\u0000")
|
||||
.takeWhile(_ != '\u0000')
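  // Hedged example (added commentary, not original source):
  //   normalizeOutput("pair at 0x2d1f8\nfatal error: null pair dereferenced")
  //   // => "pair at #addrs#\n#runtime_error#"
  // i.e. heap addresses are masked and anything after a runtime-error banner is dropped,
  // so actual output can be compared against the expected text in the example files.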
|
||||
}
|
||||
|
||||
4
wacc-syntax/.vscodeignore
Normal file
@@ -0,0 +1,4 @@
.vscode/**
.vscode-test/**
.gitignore
vsc-extension-quickstart.md
9  wacc-syntax/CHANGELOG.md  Normal file
@@ -0,0 +1,9 @@
# Change Log

All notable changes to the "wacc-syntax" extension will be documented in this file.

Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file.

## [Unreleased]

- Initial release
7  wacc-syntax/README.md  Normal file
@@ -0,0 +1,7 @@
### INTELLIWACC

This is the IntelliWACC extension for WACC development. It provides syntax highlighting, in-editor error messages and error highlighting, and support for imports.

This extension was developed as part of the "WACC Extensions" milestone in 2025.

Authored by Alex L, Gleb K, Guy C and Jonny T
BIN  wacc-syntax/README.pdf  Normal file
Binary file not shown.
71  wacc-syntax/extension.js  Normal file
@@ -0,0 +1,71 @@
// Developed using the VS Code language extension tutorial
// https://code.visualstudio.com/api/language-extensions/overview

const vscode = require('vscode');
const { execSync } = require('child_process');
const { parse } = require('path');

function activate(context) {
    console.log('IntelliWACC is now active!');

    let diagnosticCollection = vscode.languages.createDiagnosticCollection('wacc');
    context.subscriptions.push(diagnosticCollection);

    // Re-run the compiler and refresh diagnostics every time a WACC file is saved.
    vscode.workspace.onDidSaveTextDocument((document) => {
        if (document.languageId !== 'wacc') return;

        let diagnostics = [];
        let errors = generateErrors(document.getText(), document.fileName);
        errors.forEach(error => {
            console.log(error);
            let range = new vscode.Range(error.line - 1, error.column - 1, error.line - 1, error.column + error.size);
            let diagnostic = new vscode.Diagnostic(range, error.errorMessage, vscode.DiagnosticSeverity.Error);
            diagnostics.push(diagnostic);
        });

        diagnosticCollection.set(document.uri, diagnostics);
    });
}

function deactivate() {
    console.log('IntelliWACC is deactivating...');
}

function generateErrors(code, filePath) {
    try {
        console.log("generating errors");
        // Write the current buffer to a temporary file next to the source so the
        // bundled compiler can be run against it.
        const fs = require('fs');
        const tmpFilePath = parse(filePath).dir + '/.temp_wacc_file.wacc';
        fs.writeFileSync(tmpFilePath, code);

        let output;
        try {
            const waccExePath = `${__dirname}/wacc-compiler`;
            output = execSync(`${waccExePath} ${tmpFilePath}`, { encoding: 'utf8', shell: true, stdio: 'pipe' });
        } catch (err) {
            // A non-zero exit code (i.e. a compile error) throws; the error
            // messages we want are still captured on stdout.
            console.log("Error running compiler");
            output = err.stdout;
            console.log(output);
        }

        // Each reported error starts with "(line L, column C):", followed by the
        // message text, the offending source line and a caret underline; the
        // underline's length gives the width of the highlighted range.
        let errors = [];
        const errorRegex = /\(line ([\d]+), column ([\d]+)\):\n([^>]+)([^\^]+)([\^]+)\n([^\n]+)([^\(]*)/g;
        let match;
        while ((match = errorRegex.exec(output)) !== null) {
            console.log(match[5]);
            errors.push({
                line: parseInt(match[1], 10),
                column: parseInt(match[2], 10),
                errorMessage: match[3].trim(),
                size: match[5].length - 1
            });
        }
        return errors;
    } catch (err) {
        console.error('Error running compiler:', err);
        return [];
    }
}

module.exports = {
    activate,
    deactivate
};
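The diagnostic ranges above rely on the compiler reporting positions in a "(line L, column C):" header followed by a caret underline whose length gives the span of the error. A standalone Scala sketch of extracting those positions; the sample error text is invented for illustration and does not reproduce the compiler's exact wording:

object ErrorPositionSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical compiler message; only the "(line L, column C):" header is
    // what this sketch depends on.
    val sampleOutput =
      """Syntax error in example.wacc (line 3, column 7):
        |  unexpected "!"
        |    int x = 5 !
        |              ^
        |  expected an expression
        |""".stripMargin

    val header = raw"\(line (\d+), column (\d+)\):".r
    header.findFirstMatchIn(sampleOutput).foreach { m =>
      // The compiler reports 1-based positions; the extension converts them to
      // 0-based vscode.Range coordinates by subtracting one.
      println(s"error at line ${m.group(1)}, column ${m.group(2)}")
    }
  }
}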
BIN  wacc-syntax/icon.png  Normal file  (571 KiB)
Binary file not shown.
28  wacc-syntax/language-configuration.json  Normal file
@@ -0,0 +1,28 @@
{
  "comments": {
    // symbol used for single line comment. Remove this entry if your language does not support line comments
    "lineComment": "#"
  },
  // symbols used as brackets
  "brackets": [
    ["{", "}"],
    ["[", "]"],
    ["(", ")"]
  ],
  // symbols that are auto closed when typing
  "autoClosingPairs": [
    ["{", "}"],
    ["[", "]"],
    ["(", ")"],
    ["\"", "\""],
    ["'", "'"]
  ],
  // symbols that can be used to surround a selection
  "surroundingPairs": [
    ["{", "}"],
    ["[", "]"],
    ["(", ")"],
    ["\"", "\""],
    ["'", "'"]
  ]
}
15  wacc-syntax/package-lock.json  generated  Normal file
@@ -0,0 +1,15 @@
{
  "name": "wacc-syntax",
  "version": "0.0.1",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "wacc-syntax",
      "version": "0.0.1",
      "engines": {
        "vscode": "^1.97.0"
      }
    }
  }
}
40  wacc-syntax/package.json  Normal file
@@ -0,0 +1,40 @@
{
  "name": "wacc-syntax",
  "displayName": "intelliWACC",
  "description": "WACC language support features",
  "version": "0.0.1",
  "publisher": "WACC-37-2025",
  "icon": "icon.png",
  "engines": {
    "vscode": "^1.97.0"
  },
  "categories": [
    "Programming Languages"
  ],
  "contributes": {
    "languages": [{
      "id": "wacc",
      "aliases": ["WACC", "wacc"],
      "extensions": [".wacc"],
      "configuration": "./language-configuration.json"
    }],
    "grammars": [{
      "language": "wacc",
      "scopeName": "source.wacc",
      "path": "./syntaxes/wacc.tmLanguage.json"
    }],
    "properties": {
      "files.exclude": {
        "type": "object",
        "default": {
          "**/.temp_wacc_file.*": true
        },
        "description": "Configure patterns for excluding files and folders."
      }
    }
  },
  "scripts": {
    "build": "vsce package"
  },
  "main": "./extension.js"
}
56  wacc-syntax/syntaxes/wacc.tmLanguage.json  Normal file
@@ -0,0 +1,56 @@
{
  "$schema": "https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json",
  "name": "WACC",
  "scopeName": "source.wacc",
  "fileTypes": [
    "wacc"
  ],
  "patterns": [
    {
      "match": "\\b(true|false)\\b",
      "name": "keyword.constant.wacc"
    },
    {
      "match": "\\b(int|bool|char|string|pair|null)\\b",
      "name": "storage.type.wacc"
    },
    {
      "match": "\".*?\"",
      "name": "string.quoted.double.wacc"
    },
    {
      "match": "\\b(begin|end)\\b",
      "name": "keyword.other.unit"
    },
    {
      "match": "\\b(if|then|else|fi|while|do|done|skip|is)\\b",
      "name": "keyword.control.wacc"
    },
    {
      "match": "\\b(read|free|print|println|newpair|call|fst|snd|ord|chr|len)\\b",
      "name": "keyword.operator.function.wacc"
    },
    {
      "match": "\\b(return|exit)\\b",
      "name": "keyword.operator.wacc"
    },
    {
      "match": "'[^']{1}'",
      "name": "constant.character.wacc"
    },
    {
      "match": "\\b([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?=\\()",
      "name": "variable.function.wacc"
    },
    {
      "match": "\\b([a-zA-Z_][a-zA-Z0-9_]*)\\b",
      "name": "variable.other.wacc"
    },
    {
      "match": "#.*$",
      "name": "comment.line"
    }
  ]
}
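A note on rule order: TextMate grammars try patterns in the order they are listed, so the keyword rules above take precedence over the generic identifier rule for tokens such as while. A standalone sketch of that first-match-wins ordering over two of the grammar's regexes (the token list is illustrative):

object GrammarOrderSketch {
  def main(args: Array[String]): Unit = {
    // Rules in listing order: keywords first, then the catch-all identifier rule.
    val rules = List(
      ("keyword.control.wacc", raw"\b(if|then|else|fi|while|do|done|skip|is)\b".r),
      ("variable.other.wacc", raw"\b([a-zA-Z_][a-zA-Z0-9_]*)\b".r)
    )
    for (token <- List("while", "counter")) {
      val scope = rules.collectFirst {
        case (name, re) if re.pattern.matcher(token).matches() => name
      }
      println(s"$token -> ${scope.getOrElse("unscoped")}")
    }
  }
}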
29  wacc-syntax/vsc-extension-quickstart.md  Normal file
@@ -0,0 +1,29 @@
# Welcome to your VS Code Extension

## What's in the folder

* This folder contains all of the files necessary for your extension.
* `package.json` - this is the manifest file in which you declare your language support and define the location of the grammar file that has been copied into your extension.
* `syntaxes/wacc.tmLanguage.json` - this is the TextMate grammar file that is used for tokenization.
* `language-configuration.json` - this is the language configuration, defining the tokens that are used for comments and brackets.

## Get up and running straight away

* Make sure the language configuration settings in `language-configuration.json` are accurate.
* Press `F5` to open a new window with your extension loaded.
* Create a new file with a file name suffix matching your language.
* Verify that syntax highlighting works and that the language configuration settings are working.

## Make changes

* You can relaunch the extension from the debug toolbar after making changes to the files listed above.
* You can also reload (`Ctrl+R` or `Cmd+R` on Mac) the VS Code window with your extension to load your changes.

## Add more language features

* To add features such as IntelliSense, hovers and validators, check out the VS Code extension authoring documentation at https://code.visualstudio.com/docs.

## Install your extension

* To start using your extension with Visual Studio Code, copy it into the `<user home>/.vscode/extensions` folder and restart Code.
* To share your extension with the world, read about publishing an extension at https://code.visualstudio.com/docs.
BIN  wacc-syntax/wacc-compiler  Executable file
Binary file not shown.
1  wacc.target  Normal file
@@ -0,0 +1 @@
x86-64