diff --git a/.gitignore b/.gitignore index 291a3dd611addcc307a74a444814a32d914982e9..959fc7f691c367d0cab841748d21965fea8b7698 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ *.ll *.c *.o + +.*.swp diff --git a/Cargo.lock b/Cargo.lock index b5a72ec6d4576aedb78b75f5cf099017e9111c1a..1ec60bdb9f1cca198bcbec7292e6f0ccf51e179d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,11 +2,20 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + [[package]] name = "anstream" -version = "0.6.11" +version = "0.6.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" +checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" dependencies = [ "anstyle", "anstyle-parse", @@ -50,12 +59,33 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "anyhow" +version = "1.0.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitvec" version = "1.0.1" @@ -68,17 +98,37 @@ dependencies = [ "wyz", ] +[[package]] +name = 
"cactus" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbc26382d871df4b7442e3df10a9402bf3cf5e55cbd66f12be38861425f0564" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfgrammar" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "829b900c3abfb519f93fe713765625385117372cc83b9c68c22d3d8807d89440" +dependencies = [ + "indexmap", + "lazy_static", + "num-traits", + "regex", + "serde", + "vob", +] + [[package]] name = "clap" -version = "4.4.18" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +checksum = "b230ab84b0ffdf890d5a10abdbc8b83ae1c4918275daea1ab8801f71536b2651" dependencies = [ "clap_builder", "clap_derive", @@ -86,9 +136,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.18" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" dependencies = [ "anstream", "anstyle", @@ -98,9 +148,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.4.7" +version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47" dependencies = [ "heck", "proc-macro2", @@ -110,9 +160,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" +checksum 
= "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" [[package]] name = "colorchoice" @@ -120,6 +170,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "ena" version = "0.14.2" @@ -129,12 +188,45 @@ dependencies = [ "log", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "filetime" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "windows-sys", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "funty" version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.2.12" @@ -146,6 +238,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + [[package]] name = "heck" version = "0.4.1" @@ -218,6 +316,41 @@ dependencies = [ "libc", ] +[[package]] +name = "indexmap" +version = "2.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "juno_frontend" +version = "0.1.0" +dependencies = [ + "cfgrammar", + "clap", + "hercules_ir", + "hercules_opt", + "lrlex", + "lrpar", + "ordered-float", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.153" @@ -226,9 +359,63 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "lrlex" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bd189d9a727faaf7ec012bb2939f5c843aee894d02715e97c2882c2d2e96093" +dependencies = [ + "cfgrammar", + "getopts", + "lazy_static", + "lrpar", + "num-traits", + "quote", + "regex", + "regex-syntax 0.7.5", + "serde", + "vergen", +] + +[[package]] +name = "lrpar" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "9ca23a8fd40ca43869fa8e6cf5e4d87c3fce88d21a109f7e07126ebd8369b494" +dependencies = [ + 
"bincode", + "cactus", + "cfgrammar", + "filetime", + "indexmap", + "lazy_static", + "lrtable", + "num-traits", + "packedvec", + "regex", + "serde", + "static_assertions", + "vergen", + "vob", +] + +[[package]] +name = "lrtable" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee337773c4a85c7a270fbc4ca500ece4abb288cd9c3ac60f642441b96b9f53a7" +dependencies = [ + "cfgrammar", + "fnv", + "num-traits", + "serde", + "sparsevec", + "vob", +] [[package]] name = "memchr" @@ -252,15 +439,30 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", ] +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", +] + [[package]] name = "ordered-float" version = "4.2.0" @@ -270,6 +472,22 @@ dependencies = [ "num-traits", ] +[[package]] +name = "packedvec" +version = "1.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde3c690ec20e4a2b4fb46f0289a451181eb50011a1e2acc8d85e2fde9062a45" +dependencies = [ + "num-traits", + "serde", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -330,17 +548,120 @@ dependencies = [ "getrandom", ] 
+[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax 0.8.2", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.2", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + +[[package]] +name = "serde" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sparsevec" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35df5d2e580b29f3f7ec5b4ed49b0ab3acf7f3624122b3e823cafb9630f293b8" +dependencies = [ + "num-traits", + "packedvec", + "serde", + "vob", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strsim" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" [[package]] name = "syn" -version = "2.0.48" +version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" dependencies = [ "proc-macro2", "quote", @@ -359,18 +680,79 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "time" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" +dependencies = [ + "deranged", + "itoa", + "libc", + "num-conv", + "num_threads", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + 
+[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "utf8parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "vergen" +version = "8.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +dependencies = [ + "anyhow", + "rustversion", + "time", +] + +[[package]] +name = "vob" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c058f4c41e71a043c67744cb76dcc1ae63ece328c1732a72489ccccc2dec23e6" +dependencies = [ + "num-traits", + "rustc_version", + "serde", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -388,9 +770,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ 
"windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -403,45 +785,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.4" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "wyz" diff --git a/Cargo.toml b/Cargo.toml index 8320ec1520c225375638586a3b7c12e5ab4501ce..d625efb3dbef4e42001328b5800c34a7d8e7154f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] - +resolver = "2" members = [ "hercules_cg", "hercules_ir", @@ -9,5 +9,7 @@ members = [ "hercules_tools/hercules_dot", "hercules_tools/hercules_cpu", + "juno_frontend", + "hercules_samples/matmul" ] diff --git a/DESIGN.md b/DESIGN.md index c34c3023305a1625e524d65f5e5369a12c23d1be..841a05f91c372fbb5e61425be143edd3bd849c1e 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -12,7 +12,8 @@ The following sections contain information on how Hercules is designed to meet t ## Front-end Language Design -TODO: @aaronjc4 +The front-end language Juno is a relatively standard imperative programming language based on Rust syntax, though it uses Mutable Value Semantics and imposes certain restrictions that are useful when compiling to heterogeneous systems. +More information about Juno can be found in [juno_frontend/README.md](juno_frontend/README.md). 
## Scheduling Language Design diff --git a/hercules_ir/Cargo.toml b/hercules_ir/Cargo.toml index b68e6b24d65322e9a92981c45ce1711315dee54e..39fbebe550ea1f26d3f0cd28f443123a90c570a3 100644 --- a/hercules_ir/Cargo.toml +++ b/hercules_ir/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "hercules_ir" version = "0.1.0" -authors = ["Russel Arbore <rarbore2@illinois.edu>"] +authors = ["Russel Arbore <rarbore2@illinois.edu>, Aaron Councilman <aaronjc4@illinois.edu>"] [dependencies] rand = "*" diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 4345f7125f7da274d7f9f19d1d0709661dde6a29..6fe2edf37aae217a50302b53fa61ef60b2ba3b0e 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -370,6 +370,10 @@ impl<'a> Builder<'a> { Ok(self.intern_constant(Constant::Array(ty, cons), ty)) } + pub fn create_constant_zero(&mut self, typ : TypeID) -> ConstantID { + self.intern_constant(Constant::Zero(typ), typ) + } + pub fn create_dynamic_constant_constant(&mut self, val: usize) -> DynamicConstantID { self.intern_dynamic_constant(DynamicConstant::Constant(val)) } @@ -396,7 +400,7 @@ impl<'a> Builder<'a> { pub fn create_function( &mut self, - name: &'a str, + name: &str, param_types: Vec<TypeID>, return_type: TypeID, num_dynamic_constants: u32, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 466451791dde6ae771db61da484d010c7ed771db..0688c13749ec312e731a134bb3862c6b945adbb9 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1168,3 +1168,244 @@ define_id_type!(NodeID); define_id_type!(TypeID); define_id_type!(ConstantID); define_id_type!(DynamicConstantID); + +/* + * Sometimes, it's useful to debug print out a module. This code prints out a + * module in (approximately) the same textual format as is parsed in parse.rs. 
+ */ +use std::fmt::Display; +use std::fmt::Formatter; + +impl Display for Module { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for func in self.functions.iter() { + func.ir_fmt(f, self)?; + write!(f, "\n")?; + } + Ok(()) + } +} + +/* + * When printing out objects in a module, we may need to refer back (upwards) to + * other objects in the module. Display doesn't let us do that, so we make our + * own trait. + */ +trait IRDisplay { + fn ir_fmt(&self, f: &mut Formatter<'_>, module: &Module) -> std::fmt::Result; +} + +impl IRDisplay for Function { + fn ir_fmt(&self, f: &mut Formatter<'_>, module: &Module) -> std::fmt::Result { + write!(f, "fn {}<{}>(", self.name, self.num_dynamic_constants)?; + + for (idx, typ) in self.param_types.iter().enumerate() { + write!(f, "arg_{} : ", idx)?; + module.write_type(*typ, f)?; + if idx + 1 < self.param_types.len() { + write!(f, ", ")?; + } + } + + write!(f, ") -> ")?; + module.write_type(self.return_type, f)?; + + write!(f, "\n")?; + + for (idx, node) in self.nodes.iter().enumerate() { + write!(f, "\tvar_{} = ", idx)?; + node.ir_fmt(f, module)?; + write!(f, "\n")?; + } + + Ok(()) + } +} + +impl IRDisplay for Node { + fn ir_fmt(&self, f: &mut Formatter<'_>, module: &Module) -> std::fmt::Result { + match self { + Node::Start => { + write!(f, "start") + } + Node::Region { preds } => { + write!(f, "region(")?; + for (idx, pred) in preds.iter().enumerate() { + write!(f, "var_{}", pred.0)?; + if idx + 1 < preds.len() { + write!(f, ", ")?; + } + } + write!(f, ")") + } + Node::If { control, cond } => { + write!(f, "if(var_{}, var_{})", control.0, cond.0) + } + Node::Match { control, sum } => { + write!(f, "match(var_{}, var_{})", control.0, sum.0) + } + Node::Fork { control, factor } => { + write!(f, "fork(var_{}, ", control.0)?; + module.write_dynamic_constant(*factor, f)?; + write!(f, ")") + } + Node::Join { control } => { + write!(f, "join(var_{})", control.0) + } + Node::Phi { control, data } => { + write!(f, 
"phi(var_{}", control.0)?; + for val in data.iter() { + write!(f, ", var_{}", val.0)?; + } + write!(f, ")") + } + Node::ThreadID { control } => { + write!(f, "thread_id(var_{})", control.0) + } + Node::Reduce { + control, + init, + reduct, + } => { + write!( + f, + "reduce(var_{}, var_{}, var_{})", + control.0, init.0, reduct.0 + ) + } + Node::Return { control, data } => { + write!(f, "return(var_{}, var_{})", control.0, data.0) + } + Node::Parameter { index } => { + write!(f, "arg_{}", index) + } + Node::Constant { id } => { + write!(f, "constant(")?; + module.constants[id.idx()].ir_fmt(f, module)?; + write!(f, ")") + } + Node::DynamicConstant { id } => { + write!(f, "dynamic_constant(")?; + module.write_dynamic_constant(*id, f)?; + write!(f, ")") + } + Node::Unary { input, op } => { + write!(f, "{}(var_{})", op.lower_case_name(), input.0) + } + Node::Binary { left, right, op } => { + write!( + f, + "{}(var_{}, var_{})", + op.lower_case_name(), + left.0, + right.0 + ) + } + Node::Call { + function, + dynamic_constants, + args, + } => { + write!(f, "call<")?; + for (idx, dyn_const) in dynamic_constants.iter().enumerate() { + module.write_dynamic_constant(*dyn_const, f)?; + if idx + 1 < dynamic_constants.len() { + write!(f, ", ")?; + } + } + write!(f, ">({}", module.functions[function.0 as usize].name)?; + for arg in args.iter() { + write!(f, ", var_{}", arg.0)?; + } + write!(f, ")") + } + Node::Read { collect, indices } => { + write!(f, "read(var_{}", collect.0)?; + for idx in indices.iter() { + write!(f, ", ")?; + idx.ir_fmt(f, module)?; + } + write!(f, ")") + } + Node::Write { + collect, + data, + indices, + } => { + write!(f, "write(var_{}, var_{}", collect.0, data.0)?; + for idx in indices.iter() { + write!(f, ", ")?; + idx.ir_fmt(f, module)?; + } + write!(f, ")") + } + Node::Ternary { .. 
} => todo!(), + } + } +} + +impl IRDisplay for Index { + fn ir_fmt(&self, f: &mut Formatter<'_>, _module: &Module) -> std::fmt::Result { + match self { + Index::Field(idx) => write!(f, "field({})", idx), + Index::Variant(idx) => write!(f, "variant({})", idx), + Index::Control(idx) => write!(f, "control({})", idx), + Index::Position(indices) => { + write!(f, "position(")?; + for (i, idx) in indices.iter().enumerate() { + write!(f, "var_{}", idx.0)?; + if i + 1 < indices.len() { + write!(f, ", ")?; + } + } + write!(f, ")") + } + } + } +} + +impl IRDisplay for Constant { + fn ir_fmt(&self, f: &mut Formatter<'_>, module: &Module) -> std::fmt::Result { + match self { + Constant::Boolean(v) => write!(f, "{} : bool", v), + Constant::Integer8(v) => write!(f, "{} : i8", v), + Constant::Integer16(v) => write!(f, "{} : i16", v), + Constant::Integer32(v) => write!(f, "{} : i32", v), + Constant::Integer64(v) => write!(f, "{} : i64", v), + Constant::UnsignedInteger8(v) => write!(f, "{} : u8", v), + Constant::UnsignedInteger16(v) => write!(f, "{} : u16", v), + Constant::UnsignedInteger32(v) => write!(f, "{} : u32", v), + Constant::UnsignedInteger64(v) => write!(f, "{} : u64", v), + Constant::Float32(v) => write!(f, "{} : f32", v), + Constant::Float64(v) => write!(f, "{} : f64", v), + Constant::Product(t, cnsts) => { + write!(f, "(")?; + for i in 0..cnsts.len() { + module.constants[cnsts[i].idx()].ir_fmt(f, module)?; + write!(f, ", ")?; + } + write!(f, ") :")?; + module.write_type(*t, f) + } + Constant::Summation(t, tag, cnst) => { + write!(f, "{}(", tag)?; + module.constants[cnst.idx()].ir_fmt(f, module)?; + write!(f, ") : ")?; + module.write_type(*t, f) + } + Constant::Array(t, cnsts) => { + write!(f, "{{")?; + for i in 0..cnsts.len() { + module.constants[cnsts[i].idx()].ir_fmt(f, module)?; + write!(f, ", ")?; + } + write!(f, "}} : ")?; + module.write_type(*t, f) + } + Constant::Zero(t) => { + write!(f, "zero : ")?; + module.write_type(*t, f) + } + } + } +} diff --git 
a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 7743a73ff61bad93e473f5a169cf97c95edd71e7..5cf76b09aa87384f0418b00c777b8edde16833ed 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "hercules_opt" version = "0.1.0" -authors = ["Russel Arbore <rarbore2@illinois.edu>"] +authors = ["Russel Arbore <rarbore2@illinois.edu>, Aaron Councilman <aaronjc4@illinois.edu>"] [dependencies] ordered-float = "*" diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 0e7efb0ed952c241244871c5194b37b6698303ba..32e812a99eb394a78abe987225c46dd74ad78dd9 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -17,10 +17,13 @@ pub fn forkify( def_use: &ImmutableDefUseMap, loops: &LoopTree, ) { - // Ignore loops that are already fork-joins. + // Ignore loops that are already fork-joins. TODO: re-calculate def_use per + // loop, since it's technically invalidated after each individual loop + // modification. let natural_loops = loops .bottom_up_loops() .into_iter() + .rev() .filter(|(k, _)| function.nodes[k.idx()].is_region()); // Detect loops that have a simple loop induction variable. 
TODO: proper diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index a82e7247e9b979ef2b6790cfa06779332b0d9872..27a9cfed113151a177312269e8824b3d3a483696 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -4,6 +4,7 @@ pub mod ccp; pub mod dce; pub mod forkify; pub mod gvn; +pub mod phi_elim; pub mod pass; pub mod pred; @@ -11,5 +12,6 @@ pub use crate::ccp::*; pub use crate::dce::*; pub use crate::forkify::*; pub use crate::gvn::*; +pub use crate::phi_elim::*; pub use crate::pass::*; pub use crate::pred::*; diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 31a681bfbd660b353630dda7201e0f0f3b53f856..394de3c0a6080e67dfceb7d30021ef00dfc3960e 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -28,6 +28,7 @@ pub enum Pass { CCP, GVN, Forkify, + PhiElim, Predication, Verify, Xdot(bool), @@ -319,6 +320,11 @@ impl PassManager { ) } } + Pass::PhiElim => { + for function in self.module.functions.iter_mut() { + phi_elim(function); + } + } Pass::Predication => { self.make_def_uses(); self.make_reverse_postorders(); diff --git a/hercules_opt/src/phi_elim.rs b/hercules_opt/src/phi_elim.rs new file mode 100644 index 0000000000000000000000000000000000000000..21a1a2e5fea40eec40a03ab352841059c6118781 --- /dev/null +++ b/hercules_opt/src/phi_elim.rs @@ -0,0 +1,82 @@ +extern crate hercules_ir; + +use std::collections::HashMap; + +use self::hercules_ir::get_uses_mut; +use self::hercules_ir::ir::*; + +/* + * This is a Hercules IR transformation that: + * - Eliminates phi nodes where all inputs are the same (here this means the + * same node in IR, GVN and related are separate). + * - Eliminate regions with only a single predecessor. + * + * The first of these optimizations is inspired by the description of the SSA + * construction algorithm in Braun et al., CC 2013 used by the front-end. 
This + * optimization performs the phi removal suggested by that paper, as in our + * construction algorithm performing it during IR code generation is difficult. + */ + +/* + * Top level function to run phi elimination, as described above. Deletes nodes + * by setting nodes to gravestones. Works with a function already containing + * gravestones. + */ +pub fn phi_elim(function: &mut Function) { + // Keep a map of nodes that we need to replace, and what we need to replace + // them with. + let mut replace_nodes: HashMap<usize, NodeID> = HashMap::new(); + + // Iterate over the nodes of the function until convergence. In this loop, + // we look for phis and regions that can be eliminated, mark them as + // gravestones, and add them to the replacement map. For all other nodes, we + // see if any of their arguments are in the replacement map - if so, + // eliminate them. + let mut changed = true; + while changed { + changed = false; + + for (idx, node) in function.nodes.iter_mut().enumerate() { + // Replace any nodes that this node uses that are in the replacement + // map. + for u in get_uses_mut(node).as_mut() { + let old_id = u.idx(); + if let Some(replacement) = replace_nodes.get(&old_id) { + **u = *replacement; + changed = true; + } + } + + // Then, check if this node can be removed. + if let Node::Phi { control: _, data } = node { + // For a phi, we can remove it if all of its data inputs are the + // same node or self-cycles. + let mut unique = Some(data[0]); + for i in 1..data.len() { + // Ignore self-loops. + if data[i].idx() != idx && Some(data[i]) != unique { + if unique.unwrap().idx() == idx { + unique = Some(data[i]); + } else { + unique = None; + break; + } + } + } + if let Some(value) = unique { + changed = true; + replace_nodes.insert(idx, value); + // Delete this node. + *node = Node::Start; + } + } else if let Node::Region { preds } = node { + if preds.len() == 1 { + changed = true; + replace_nodes.insert(idx, preds[0]); + // Delete this node. 
+ *node = Node::Start; + } + } + } + } +} diff --git a/juno_frontend/.gitignore b/juno_frontend/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ef5f7e557e275015cb02004a1f69e6b98ae8f2f0 --- /dev/null +++ b/juno_frontend/.gitignore @@ -0,0 +1,4 @@ +*.aux +*.log +*.out +*.pdf diff --git a/juno_frontend/Cargo.toml b/juno_frontend/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..2332cfea777b3fdf4d64a4a45c925ce771c022f6 --- /dev/null +++ b/juno_frontend/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "juno_frontend" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno" +path = "src/main.rs" + +[build-dependencies] +cfgrammar = "0.13" +lrlex = "0.13" +lrpar = "0.13" + +[dependencies] +clap = { version = "*", features = ["derive"] } +cfgrammar = "0.13" +lrlex = "0.13" +lrpar = "0.13" +ordered-float = "*" +hercules_ir = { path = "../hercules_ir" } +hercules_opt = { path = "../hercules_opt" } diff --git a/juno_frontend/README.md b/juno_frontend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8974cdea00a9bb8c04722981c328872be0cbc5ae --- /dev/null +++ b/juno_frontend/README.md @@ -0,0 +1,201 @@ +# Juno Frontend +## Syntax +Throughout the description of syntax, we use the convention that any text contained within a pair of angle-brackets with no space (such as <this>) denotes a placeholder. 
+In particular, we use the following placeholders: +- name: an identifier, which must begin with a letter and the other characters may be letters, numbers, or underscores +- qualified name: a sequence of identifiers separated by `::` +- top levels: a number of top-level constructs; where the top-level constructs are those defined below as Modules, Constants, Types, and Functions +- type: a type name +- expr: an expression + +Additionally, text contained within a pair of square brackets without space (such as [this]) denotes an optional part of the syntax. + +### Modules +Modules can be imported using +``` +use <qualified-name>; +``` +In this context, the qualified name may also end with `::*` to indicate that all names in some package should be imported into the current scope. + +To define a new module within an existing file, use +``` +[pub] mod <name> { <top-levels> } +``` +where the optional pub indicates whether the module is exported when this file is used by another program. + +### Constants +Constant values can be declared as +``` +[pub] const <name> [: <type>] = <expr>; +``` + +Constant values must be able to be completely evaluated at compile time and therefore do not allow function calls, but all other expression forms, with the exception of those involving arrays, are supported. + +### Types +Type declarations have the form +``` +[pub] type <name> [< <type-variables> >] = <type-def> +``` +where `<type-variables>` is a comma-separated list of names and kinds expressed as `<name> : <kind>` where the kinds are `type`, `usize`, `number`, and `integer`. +The kind `type` indicates that the associated variable is some type; these are the standard type variables present in many languages. +The kind `usize` indicates that the variable is a dynamic constant, and can therefore be used as a dimension for arrays as well as being used in expressions.
+The kinds `number` and `integer` indicate that the associated variable is a type but must be a numeric type or integer type, respectively; this allows us to express computation such as matrix multiplication which can be applied to any numeric type while still ensuring that this parametric function will be well-typed for all types it can be applied to (in contrast to C++ where type errors can arise when a function is instantiated). +In addition, the kind can be omitted after a variable to instead create a list of variables followed by a kind, such as `x, y, z : usize`, in which case all preceding variables after the last kind are given the specified kind (here `usize`). + +Type definitions have three forms. First, they can just be another type, hence defining a type alias. +They can also be a struct or union definition having forms +``` +[pub] struct { <fields> } +[pub] union { <fields> } +``` +where in this context pub indicates for a struct whether all fields are public and for a union whether the constructors and their arguments are public. +The fields in these definitions are lists of fields of the following form +``` +[pub] <name> [: <type>] ; +``` +If a field does not have a specified type, it is taken (implicitly) to be unit. +In a struct the pub on individual fields allows control of exactly which fields are exposed publicly, and is a semantic error on unions. + +Finally, types have the following forms +``` +i8 +i16 +i32 +i64 +u8 +u16 +u32 +u64 +usize +f32 +f64 +bool +void +( <types> ) +<qualified-name> [::< <type-expressions> >] +<type>[ <type-expressions> ] +``` +Note that in this type system the void type and the unit type are considered to be the same type, and can be denoted as `()` as well. +The parenthesized form (assuming the number of types inside is not one), where the types are comma separated, is used to define product types.
+ +Lastly, type expressions are used in contexts such as type arguments and array sizes and can be either types or dynamic constant expressions (though only the latter are semantically valid in array dimensions) and have the following forms +``` +i8 +i16 +i32 +i64 +u8 +u16 +u32 +u64 +usize +f32 +f64 +bool +void +( <type-expressions> ) +<qualified-name> [::< <type-expressions> >] +<type-expression>[ <type-expressions> ] +<integer-literal> +- <type-expression> +<type-expression> + <type-expression> +<type-expression> - <type-expression> +<type-expression> * <type-expression> +<type-expression> / <type-expression> +``` + +### Functions +Function declarations have the form +``` +[pub] fn <name> [< <type-variables> >] ( <arguments> ) [ -> <type> ] { <body> } +``` +where the body is a statement and the arguments are a comma separated list of "argument binders" which have the form +``` +[inout] <pattern> [: <type>] +``` +where patterns are defined in the statements section. + +Note that an argument that is marked as inout must be provided with a variable as its parameter when the function is called and after the call the value of that variable is updated to the value that argument had within the function when it returned. + +### Statements +``` +let <pattern> [: <type>] [= <expr>]; +const <pattern> [: <type>] [= <expr>]; +<lexpr> <assign-op> <expr>; +if <non-struct-expr> { <body> } [else <if-stmt>] +match <non-struct-expr> { <cases> } +for <pattern> [: <type>] = <non-struct-expr> to <non-struct-expr> [by <signed-int-literal>] { <body> } +while <non-struct-expr> { <body> } +return [<expr>]; +break; +continue; +{ <body> } +<qualified-name>[::< <type-expressions> >]( <parameters> ); +``` +The assignment-operators `<assign-op>` are `=`, `+=`, `-=`, `*=`, `/=`, `%=`, `&=`, `|=`, `^=`, `&&=`, `||=`, `<<=`, `>>=` and have the standard meanings (for instance as in C). +An `<if-stmt>` means either another if-then-else construct or a curly-bracket enclosed body.
+Note that, like in Rust, the condition of an if-then-else, match, for, or while statement is a non-struct expression, meaning just that it cannot contain a struct declaration that is not contained in parentheses. +A signed integer literal is an integer literal that may optionally be prefixed by a single `+` or `-` sign. + +Left-hand side expressions (`<lexpr>`) have the following forms +``` +<name> +<lexpr> . <name> +<lexpr> <dot-number> +<lexpr> [ <expressions> ] +``` +where `<dot-number>` denotes a dot (.) followed immediately (with no space) by an integer; for instance .0; these are used to index into product types. + + +The cases of a match statement have the following form +``` +<patterns> => <body> +``` +where `<patterns>` is a bar (`|`) separated list of patterns; note that a bar is also permitted at the start, so a collection of patterns can have either the form `a | b | c` or `| a | b | c`. + +Finally, a pattern has the following forms +``` +_ +<int-literal> +<qualified-name> +( <patterns-comma> ) +<qualified-name> { <named-patterns> } +<qualified-name> ( <patterns-comma> ) +``` +where `_` is the wildcard pattern and `<patterns-comma>` denotes a comma-separated list of patterns. +Named patterns are also comma separated but have the form `<name> : <pattern>` and are used for pattern matching on the fields of a struct. +Note that the final form is used to pattern match on unions. + +### Expressions +Finally, expressions have the following forms +``` +<int-literal> +<float-literal> +true +false +<unary-op> <expression> +<expression> <binary-op> <expression> +<expression> as <type> +<lexpr> +( <expressions> ) +<qualified-name> [::< <type-expressions> >] { <id-expressions> } +if <expression> then <expression> else <expression> +<qualified-name> [::< <type-expressions> >] ( <parameters> ) +``` +The unary operators are negation `-`, bitwise-not `~`, and boolean not `!`.
+The binary operators are `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `&&`, `||`, `<<`, `>>`, `<`, `<=`, `>`, `>=`, `==`, `!=` which have their standard interpretations, precedences, and associativities. + +An `<id-expression>` has form `<name> = <expression>` and lists are comma separated. + +Parameters are comma separated and have the form `[&] <expression>` where the ampersand & is necessary for inout arguments. + + +## Unimplemented Features +The following features are incomplete: +- Implementation of match and non-variable patterns in all other bindings +- Type inference +- Dynamic constant expressions (i.e. expressions of dynamic constants which are themselves dynamic constants) +- Module definitions, imports, and qualified names +- Partial indexing into an array +- Struct values where some arguments are left as their default values +- Supporting arbitrary l-expressions for inout arguments; we currently only support variables diff --git a/juno_frontend/build.rs b/juno_frontend/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..eac7103d82ee77c42bd7cd146606e027400aee07 --- /dev/null +++ b/juno_frontend/build.rs @@ -0,0 +1,15 @@ +use cfgrammar::yacc::YaccKind; +use lrlex::CTLexerBuilder; + +fn main() { + CTLexerBuilder::new() + .lrpar_config(|ctp| { + ctp.yacckind(YaccKind::Grmtools) + .grammar_in_src_dir("lang.y") + .unwrap() + }) + .lexer_in_src_dir("lang.l") + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_frontend/examples/cava.jn b/juno_frontend/examples/cava.jn new file mode 100644 index 0000000000000000000000000000000000000000..ea2b65e1842afe439808c750d0125895c30ddc16 --- /dev/null +++ b/juno_frontend/examples/cava.jn @@ -0,0 +1,218 @@ +fn max<a : number>(x : a, y : a) -> a { + return if x >= y then x else y; +} + +fn min<a : number>(x : a, y : a) -> a { + return if x < y then x else y; +} + +fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a { + const n : usize = rows * cols; + + let tmp : a[rows * cols]; +
for r = 0 to rows { + for c = 0 to cols { + tmp[r * cols + c] = m[r, c]; + } + } + + for i = 0 to n - 1 { + for j = 0 to n - i - 1 { + if tmp[j] > tmp[j+1] { + let t : a = tmp[j]; + tmp[j] = tmp[j+1]; + tmp[j+1] = t; + } + } + } + + return tmp[n / 2]; +} + +const CHAN : u64 = 3; +const CTRL_PTS : u64 = 3702; + +fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] { + let res : f32[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + res[chan, r, c] = input[chan, r, c] as f32 * 1.0 / 255; + } + } + } + + return res; +} + +fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] { + let res : u8[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + res[chan, r, c] = min(max(input[chan, r, c] * 255, 0), 255) as u8; + } + } + } + + return res; +} + +fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { + let res : f32[CHAN, row, col]; + + for r = 1 to row-1 { + for c = 1 to col-1 { + // Branch on the parity of this pixel's coordinates (r, c); row and col + // are the image dimensions and do not vary per pixel. + if r % 2 == 0 && c % 2 == 0 { + let R1 : f32 = input[0, r, c-1]; + let R2 : f32 = input[0, r, c+1]; + let B1 : f32 = input[2, r-1, c]; + let B2 : f32 = input[2, r+1, c]; + res[0, r, c] = (R1 + R2) / 2; + res[1, r, c] = input[1, r, c] * 2; + res[2, r, c] = (B1 + B2) / 2; + } else if r % 2 == 0 && c % 2 == 1 { + let G1 = input[1, r-1, c]; + let G2 = input[1, r+1, c]; + let G3 = input[1, r, c-1]; + let G4 = input[1, r, c+1]; + let B1 = input[2, r-1, c-1]; + let B2 = input[2, r-1, c+1]; + let B3 = input[2, r+1, c-1]; + let B4 = input[2, r+1, c+1]; + res[0, r, c] = input[0, r, c]; + res[1, r, c] = (G1 + G2 + G3 + G4) / 2; + res[2, r, c] = (B1 + B2 + B3 + B4) / 4; + } else if r % 2 == 1 && c % 2 == 0 { + let R1 = input[0, r-1, c-1]; + let R2 = input[0, r+1, c-1]; + let R3 = input[0, r-1, c+1]; + let R4 = input[0, r+1, c+1]; + let G1 = input[1, r-1, c]; + let G2 = input[1, r+1, c]; + let G3 = input[1, r, c-1]; + let G4
= input[1, r, c+1]; + res[0, r, c] = (R1 + R2 + R3 + R4) / 4; + res[1, r, c] = (G1 + G2 + G3 + G4) / 2; + res[2, r, c] = input[2, r, c]; + } else { + let R1 = input[0, r-1, c]; + let R2 = input[0, r+1, c]; + let B1 = input[2, r, c-1]; + let B2 = input[2, r, c+1]; + res[0, r, c] = (R1 + R2) / 2; + res[1, r, c] = input[1, r, c] * 2; + res[2, r, c] = (B1 + B2) / 2; + } + } + } + + return res; +} + +fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { + let res : f32[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 { + let filter : f32[3][3]; // same as [3, 3] + for i = -1 to 2 by 1 { + for j = -1 to 2 by 1 { + filter[i+1, j+1] = input[chan, r + i, c + j]; + } + } + // Interior pixels take the median of their 3x3 neighbourhood. + res[chan, r, c] = medianMatrix(filter); + } else { + res[chan, r, c] = input[chan, r, c]; + } + } + } + } + + return res; +} + +fn transform<row : usize, col : usize> + (input : f32[CHAN, row, col], tstw_trans : f32[CHAN, CHAN]) + -> f32[CHAN, row, col] { + let result : f32[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + result[chan, r, c] = max(input[0, r, c] * tstw_trans[0, chan] + + input[1, r, c] * tstw_trans[1, chan] + + input[2, r, c] * tstw_trans[2, chan], 0); + } + } + } + + return result; +} + +fn gamut<row : usize, col : usize> + (input : f32[CHAN, row, col], ctrl_pts : f32[CTRL_PTS, CHAN], + weights : f32[CTRL_PTS, CHAN], coefs : f32[4, CHAN]) -> f32[CHAN, row, col] { + let result : f32[CHAN, row, col]; + let l2_dist : f32[CTRL_PTS]; + + for r = 0 to row { + for c = 0 to col { + for cp = 0 to CTRL_PTS { + let v1 : f32 = input[0, r, c] - ctrl_pts[cp, 0]; + let v2 : f32 = input[1, r, c] - ctrl_pts[cp, 1]; + let v3 : f32 = input[2, r, c] - ctrl_pts[cp, 2]; + let v : f32 = v1 * v1 + v2 * v2 + v3 * v3; + l2_dist[cp] = math::sqrt(v); // NOTE: Somehow some libraries will likely have to be external + } + + for chan = 0 to CHAN { + let
chan_val : f32 = 0.0; + for cp = 0 to CTRL_PTS { + chan_val += l2_dist[cp] * weights[cp, chan]; + } + + chan_val += coefs[0, chan] + coefs[1, chan] * input[0, r, c] + + coefs[2, chan] * input[1, r, c] + + coefs[3, chan] * input[2, r, c]; + // Store at the current pixel (r, c); row and col are the array + // dimensions, not valid indices. + result[chan, r, c] = max(chan_val, 0); + } + } + } + + return result; +} + +fn tone_map<row : usize, col:usize> + (input : f32[CHAN, row, col], tone_map : f32[255, CHAN]) -> f32[CHAN, row, col] { + let result : f32[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + let x = (input[chan, r, c] * 255) as u8; + result[chan, r, c] = tone_map[x, chan]; + } + } + } + + return result; +} + +// Runs the whole pipeline on an r x c image; every stage is instantiated +// with the same size parameters, so the result has the input's shape. +fn cava<r, c : usize>(input : u8[CHAN, r, c], TsTw : f32[CHAN, CHAN], + ctrl_pts : f32[CTRL_PTS, CHAN], + weights : f32[CTRL_PTS, CHAN], + coefs : f32[4, CHAN], tone_map : f32[255, CHAN]) + -> u8[CHAN, r, c] { + let scaled = scale::<r, c>(input); + let demosc = demosaic::<r, c>(scaled); + let denosd = denoise::<r, c>(demosc); + let transf = transform::<r, c>(denosd, TsTw); + let gamutd = gamut::<r, c>(transf, ctrl_pts, weights, coefs); + let tonemd = tone_map::<r, c>(gamutd, tone_map); + return descale::<r, c>(tonemd); +} diff --git a/juno_frontend/examples/matadd.jn b/juno_frontend/examples/matadd.jn new file mode 100644 index 0000000000000000000000000000000000000000..58217f7cdc4353762dec072ff74e3527a7b54976 --- /dev/null +++ b/juno_frontend/examples/matadd.jn @@ -0,0 +1,11 @@ +fn mat_add<n : usize, m : usize>(a : f32[n, m], b : f32[n, m]) -> f32[n, m] { + let res : f32[n, m]; + + for i = 0 to n { + for j = 0 to m { + res[i, j] = a[i, j] + b[i, j]; + } + } + + return res; +} diff --git a/juno_frontend/examples/matmul.jn b/juno_frontend/examples/matmul.jn new file mode 100644 index 0000000000000000000000000000000000000000..5c7b28a236167904f5c91a954da1b7cf1ae14163 --- /dev/null +++ b/juno_frontend/examples/matmul.jn @@ -0,0 +1,13 @@ +fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l])
-> f32[n, l] { + let res : f32[n, l]; + + for i = 0 to n { + for j = 0 to l { + for k = 0 to m { + res[i, j] += a[i, k] * b[k, j]; + } + } + } + + return res; +} diff --git a/juno_frontend/examples/poly_matmul.jn b/juno_frontend/examples/poly_matmul.jn new file mode 100644 index 0000000000000000000000000000000000000000..baff2620fe77852dbf49a306713d702da93d5e19 --- /dev/null +++ b/juno_frontend/examples/poly_matmul.jn @@ -0,0 +1,21 @@ +fn matmul<t : number, n, m, l : usize>(a : t[n, m], b : t[m, l]) -> t[n, l] { + let res : t[n, l]; + + for i = 0 to n { + for j = 0 to l { + for k = 0 to m { + res[i, j] += a[i, k] * b[k, j]; + } + } + } + + return res; +} + +fn matmul_i32<n, m, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { + return matmul::<i32, n, m, l>(a, b); +} + +fn matmul_f64<n, m, l : usize>(a : f64[n, m], b : f64[m, l]) -> f64[n, l] { + return matmul::<f64, n, m, l>(a, b); +} diff --git a/juno_frontend/examples/vecadd.jn b/juno_frontend/examples/vecadd.jn new file mode 100644 index 0000000000000000000000000000000000000000..115ab399341ba31d27522a7a17bf28c7b1962a95 --- /dev/null +++ b/juno_frontend/examples/vecadd.jn @@ -0,0 +1,9 @@ +fn vec_add<n : usize>(a : f32[n], b : f32[n]) -> f32[n] { + let res : f32[n]; + + for i = 0 to n { + res[i] = a[i] + b[i]; + } + + return res; +} diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs new file mode 100644 index 0000000000000000000000000000000000000000..c48c188e9ad96e177cff1d85e7ff6e285a9fba79 --- /dev/null +++ b/juno_frontend/src/codegen.rs @@ -0,0 +1,622 @@ +extern crate hercules_ir; + +use std::collections::{HashMap, VecDeque}; + +use self::hercules_ir::ir; +use self::hercules_ir::ir::*; +use self::hercules_ir::build::*; + +use crate::ssa::SSA; +use crate::semant; +use crate::semant::{Prg, Function, Stmt, Expr, Literal, UnaryOp, BinaryOp}; +use crate::types::{TypeSolver, TypeSolverInst, Primitive, Either}; + +// Loop info is a stack of the loop levels, recording the latch and exit block 
of each +type LoopInfo = Vec<(NodeID, NodeID)>; + +pub fn codegen_program(prg : Prg) -> Module { + CodeGenerator::build(prg) +} + +struct CodeGenerator<'a> { + builder : Builder<'a>, + types : &'a TypeSolver, + funcs : &'a Vec<Function>, + uid : usize, + // The function map tracks a map from function index and set of type variables to its function + // id in the builder + functions : HashMap<(usize, Vec<TypeID>), FunctionID>, + // The worklist tracks a list of functions to codegen, tracking the function's id, its + // type-solving instantiation (account for the type parameters), the function id, and the entry + // block id + worklist : VecDeque<(usize, TypeSolverInst<'a>, FunctionID, NodeID)>, +} + +impl CodeGenerator<'_> { + fn build((types, funcs) : Prg) -> Module { + // Identify the functions (by index) which have no type arguments, these are the ones we + // ask for code to be generated for + let func_idx + = funcs.iter().enumerate() + .filter_map(|(i, f)| + if f.num_type_args == 0 { Some(i) } else { None }); + + let mut codegen = CodeGenerator { builder : Builder::create(), + types : &types, + funcs : &funcs, + uid : 0, + functions : HashMap::new(), + worklist : VecDeque::new(), }; + + // Add the identifed functions to the list to code-gen + func_idx.for_each(|i| { let _ = codegen.get_function(i, vec![]); }); + + codegen.finish() + } + + fn finish(mut self) -> Module { + while !self.worklist.is_empty() { + let (idx, mut type_inst, func, entry) = self.worklist.pop_front().unwrap(); + self.codegen_function(&self.funcs[idx], &mut type_inst, func, entry); + } + + self.builder.finish() + } + + fn get_function(&mut self, func_idx : usize, ty_args : Vec<TypeID>) -> FunctionID { + let func_info = (func_idx, ty_args); + match self.functions.get(&func_info) { + Some(func_id) => *func_id, + None => { + let ty_args = func_info.1; + + let func = &self.funcs[func_idx]; + let mut solver_inst = self.types.create_instance(ty_args.clone()); + + // TODO: Ideally we would write 
out the type arguments, but now that they're + // lowered to TypeID we can't do that as far as I can tell + let name = format!("{}_{}", func.name, self.uid); + self.uid += 1; + + let mut param_types = vec![]; + for (_, ty) in func.arguments.iter() { + param_types.push(solver_inst.lower_type(&mut self.builder, *ty)); + } + + let return_type = solver_inst.lower_type(&mut self.builder, func.return_type); + + let (func_id, entry) + = self.builder.create_function( + &name, param_types, return_type, + func.num_dyn_consts as u32).unwrap(); + + self.functions.insert((func_idx, ty_args), func_id); + self.worklist.push_back((func_idx, solver_inst, func_id, entry)); + func_id + }, + } + } + + fn codegen_function(&mut self, func : &Function, types : &mut TypeSolverInst, + func_id : FunctionID, entry : NodeID) { + // Setup the SSA construction data structure + let mut ssa = SSA::new(func_id, entry); + + // Create nodes for the arguments + for (idx, (var, _)) in func.arguments.iter().enumerate() { + let mut node_builder = self.builder.allocate_node(func_id); + ssa.write_variable(*var, entry, node_builder.id()); + node_builder.build_parameter(idx); + let _ = self.builder.add_node(node_builder); + } + + // Generate code for the body + let None = self.codegen_stmt(&func.body, types, &mut ssa, + func_id, entry, &mut vec![]) + else { panic!("Generated code for a function missing a return") }; + } + + fn codegen_stmt(&mut self, stmt : &Stmt, types : &mut TypeSolverInst, + ssa : &mut SSA, func_id : FunctionID, cur_block : NodeID, + loops : &mut LoopInfo) -> Option<NodeID> { + match stmt { + Stmt::AssignStmt { var, val } => { + let (val, block) = self.codegen_expr(val, types, ssa, func_id, cur_block); + ssa.write_variable(*var, block, val); + Some(block) + }, + Stmt::IfStmt { cond, thn, els } => { + let (val_cond, block_cond) + = self.codegen_expr(cond, types, ssa, func_id, cur_block); + let (mut if_node, block_then, block_else) + = ssa.create_cond(&mut self.builder, block_cond); + + 
let then_end = self.codegen_stmt(thn, types, ssa, + func_id, block_then, loops); + let else_end = + match els { + None => Some(block_else), + Some(els_stmt) => + self.codegen_stmt(els_stmt, types, ssa, + func_id, block_else, loops), + }; + + if_node.build_if(block_cond, val_cond); + let _ = self.builder.add_node(if_node); + + match (then_end, else_end) { + (None, els) => els, + (thn, None) => thn, + (Some(then_term), Some(else_term)) => { + let block_join = ssa.create_block(&mut self.builder); + ssa.add_pred(block_join, then_term); + ssa.add_pred(block_join, else_term); + ssa.seal_block(block_join, &mut self.builder); + Some(block_join) + }, + } + }, + Stmt::LoopStmt { cond, update, body } => { + // We generate guarded loops, so the first step is to create + // a conditional branch, branching on the condition + let (val_guard, block_guard) + = self.codegen_expr(cond, types, ssa, func_id, cur_block); + let (mut if_node, true_guard, false_proj) + = ssa.create_cond(&mut self.builder, block_guard); + if_node.build_if(block_guard, val_guard); + let _ = self.builder.add_node(if_node); + + // We then create a region for the exit (since there may be breaks) + let block_exit = ssa.create_block(&mut self.builder); + ssa.add_pred(block_exit, false_proj); + + // Now, create a block for the loop's latch, we don't (currently) know any of its + // predecessors + let block_latch = ssa.create_block(&mut self.builder); + + // Code-gen any update into the latch and then code-gen the condition + let block_updated = + match update { + None => block_latch, + Some(stmt) => + self.codegen_stmt(stmt, types, ssa, func_id, block_latch, loops) + .expect("Loop update should return control"), + }; + let (val_cond, block_cond) + = self.codegen_expr(cond, types, ssa, func_id, block_updated); + + let (mut if_node, true_proj, false_proj) + = ssa.create_cond(&mut self.builder, block_cond); + if_node.build_if(block_cond, val_cond); + let _ = self.builder.add_node(if_node); + + // Add the false 
projection from the latch as a predecessor of the exit + ssa.add_pred(block_exit, false_proj); + + // Create a block for the loop header, and add the true branches from the guard and + // latch as its only predecessors + let body_block = ssa.create_block(&mut self.builder); + ssa.add_pred(body_block, true_guard); + ssa.add_pred(body_block, true_proj); + ssa.seal_block(body_block, &mut self.builder); + + // Generate code for the body + loops.push((block_latch, block_exit)); + let body_res = self.codegen_stmt(body, types, ssa, func_id, body_block, loops); + loops.pop(); + + // If the body of the loop can reach some block, we add that block as a predecessor + // of the latch + match body_res { + None => {}, + Some(block) => { + ssa.add_pred(block_latch, block); + }, + } + + // Seal remaining open blocks + ssa.seal_block(block_exit, &mut self.builder); + ssa.seal_block(block_latch, &mut self.builder); + + // It is always assumed a loop may be skipped and so control can reach after the + // loop + Some(block_exit) + }, + Stmt::ReturnStmt { expr } => { + let (val_ret, block_ret) + = self.codegen_expr(expr, types, ssa, func_id, cur_block); + let mut return_node = self.builder.allocate_node(func_id); + return_node.build_return(block_ret, val_ret); + let _ = self.builder.add_node(return_node); + None + }, + Stmt::BreakStmt {} => { + let last_loop = loops.len() - 1; + let (_latch, exit) = loops[last_loop]; + ssa.add_pred(exit, cur_block); // The block that contains this break now leads to + // the exit + None + }, + Stmt::ContinueStmt {} => { + let last_loop = loops.len() - 1; + let (latch, _exit) = loops[last_loop]; + ssa.add_pred(latch, cur_block); // The block that contains this continue now leads + // to the latch + None + }, + Stmt::BlockStmt { body } => { + let mut block = Some(cur_block); + for stmt in body.iter() { + block = self.codegen_stmt(stmt, types, ssa, func_id, + block.unwrap(), loops); + } + block + }, + Stmt::ExprStmt { expr } => { + let (_val, block) + = 
self.codegen_expr(expr, types, ssa, func_id, cur_block); + Some(block) + }, + } + } + + // The codegen_expr function returns a pair of node IDs, the first is the node whose value is + // the given expression and the second is the node of a control node at which the value is + // available + fn codegen_expr(&mut self, expr : &Expr, types : &mut TypeSolverInst, + ssa : &mut SSA, func_id : FunctionID, cur_block : NodeID) + -> (NodeID, NodeID) { + match expr { + Expr::Variable { var, .. } => { + (ssa.read_variable(*var, cur_block, &mut self.builder), + cur_block) + }, + Expr::DynConst { idx, .. } => { + let mut node = self.builder.allocate_node(func_id); + let node_id = node.id(); + let dyn_const = self.builder.create_dynamic_constant_parameter(*idx); + node.build_dynamicconstant(dyn_const); + let _ = self.builder.add_node(node); + (node_id, cur_block) + }, + Expr::Read { index, val, .. } => { + let (collection, block) + = self.codegen_expr(val, types, ssa, func_id, cur_block); + let (indices, end_block) + = self.codegen_indices(index, types, ssa, func_id, block); + + let mut node = self.builder.allocate_node(func_id); + let node_id = node.id(); + node.build_read(collection, indices.into()); + let _ = self.builder.add_node(node); + (node_id, end_block) + }, + Expr::Write { index, val, rep, .. 
} => { + let (collection, block) + = self.codegen_expr(val, types, ssa, func_id, cur_block); + let (indices, idx_block) + = self.codegen_indices(index, types, ssa, func_id, block); + let (replace, end_block) + = self.codegen_expr(rep, types, ssa, func_id, idx_block); + + let mut node = self.builder.allocate_node(func_id); + let node_id = node.id(); + node.build_write(collection, replace, indices.into()); + let _ = self.builder.add_node(node); + (node_id, end_block) + }, + Expr::Tuple { vals, typ } => { + let mut block = cur_block; + let mut values = vec![]; + for expr in vals { + let (val_expr, block_expr) + = self.codegen_expr(expr, types, ssa, func_id, block); + block = block_expr; + values.push(val_expr); + } + + let tuple_type = types.lower_type(&mut self.builder, *typ); + (self.build_tuple(values, tuple_type, func_id), block) + }, + Expr::Union { tag, val, typ } => { + let (value, block) + = self.codegen_expr(val, types, ssa, func_id, cur_block); + + let union_type = types.lower_type(&mut self.builder, *typ); + (self.build_union(*tag, value, union_type, func_id), block) + }, + Expr::Constant { val, .. } => { + let const_id = self.build_constant(val, types); + + let mut val = self.builder.allocate_node(func_id); + let val_node = val.id(); + val.build_constant(const_id); + let _ = self.builder.add_node(val); + + (val_node, cur_block) + }, + Expr::Zero { typ } => { + let type_id = types.lower_type(&mut self.builder, *typ); + let zero_const = self.builder.create_constant_zero(type_id); + let mut zero = self.builder.allocate_node(func_id); + let zero_val = zero.id(); + zero.build_constant(zero_const); + let _ = self.builder.add_node(zero); + + (zero_val, cur_block) + }, + Expr::UnaryExp { op, expr, .. 
} => { + let (val, block) + = self.codegen_expr(expr, types, ssa, func_id, cur_block); + + let mut expr = self.builder.allocate_node(func_id); + let expr_id = expr.id(); + expr.build_unary(val, + match op { + UnaryOp::Negation => UnaryOperator::Neg, + UnaryOp::BitwiseNot => UnaryOperator::Not, + }); + let _ = self.builder.add_node(expr); + + (expr_id, block) + }, + Expr::BinaryExp { op, lhs, rhs, .. } => { + let (val_lhs, block_lhs) + = self.codegen_expr(lhs, types, ssa, func_id, cur_block); + let (val_rhs, block_rhs) + = self.codegen_expr(rhs, types, ssa, func_id, block_lhs); + + let mut expr = self.builder.allocate_node(func_id); + let expr_id = expr.id(); + expr.build_binary(val_lhs, val_rhs, + match op { + BinaryOp::Add => BinaryOperator::Add, + BinaryOp::Sub => BinaryOperator::Sub, + BinaryOp::Mul => BinaryOperator::Mul, + BinaryOp::Div => BinaryOperator::Div, + BinaryOp::Mod => BinaryOperator::Rem, + BinaryOp::BitAnd => BinaryOperator::And, + BinaryOp::BitOr => BinaryOperator::Or, + BinaryOp::Xor => BinaryOperator::Xor, + BinaryOp::Lt => BinaryOperator::LT, + BinaryOp::Le => BinaryOperator::LTE, + BinaryOp::Gt => BinaryOperator::GT, + BinaryOp::Ge => BinaryOperator::GTE, + BinaryOp::Eq => BinaryOperator::EQ, + BinaryOp::Neq => BinaryOperator::NE, + BinaryOp::LShift => BinaryOperator::LSh, + BinaryOp::RShift => BinaryOperator::RSh, + }); + let _ = self.builder.add_node(expr); + + (expr_id, block_rhs) + }, + Expr::CastExpr { expr, typ } => { + let type_id = types.lower_type(&mut self.builder, *typ); + let (val, block) + = self.codegen_expr(expr, types, ssa, func_id, cur_block); + + let mut expr = self.builder.allocate_node(func_id); + let expr_id = expr.id(); + expr.build_unary(val, UnaryOperator::Cast(type_id)); + let _ = self.builder.add_node(expr); + + (expr_id, block) + }, + Expr::CondExpr { cond, thn, els, .. 
} => { + // Code-gen the condition + let (val_cond, block_cond) + = self.codegen_expr(cond, types, ssa, func_id, cur_block); + + // Create the if + let (mut if_builder, then_block, else_block) + = ssa.create_cond(&mut self.builder, block_cond); + if_builder.build_if(block_cond, val_cond); + let _ = self.builder.add_node(if_builder); + + // Code-gen the branches + let (then_val, block_then) + = self.codegen_expr(thn, types, ssa, func_id, then_block); + let (else_val, block_else) + = self.codegen_expr(els, types, ssa, func_id, else_block); + + // Create the join in the control-flow + let join = ssa.create_block(&mut self.builder); + ssa.add_pred(join, block_then); + ssa.add_pred(join, block_else); + ssa.seal_block(join, &mut self.builder); + + // Create a phi that joins the two branches + let mut phi = self.builder.allocate_node(func_id); + let phi_id = phi.id(); + phi.build_phi(join, vec![then_val, else_val].into()); + let _ = self.builder.add_node(phi); + + (phi_id, join) + }, + Expr::CallExpr { func, ty_args, dyn_consts, args, .. 
} => { + // We start by lowering the type arguments to TypeIDs + let mut type_params = vec![]; + for typ in ty_args { + type_params.push(types.lower_type(&mut self.builder, *typ)); + } + + // With the type arguments, we can now lookup the function + let call_func = self.get_function(*func, type_params); + + // We then build the dynamic constants + let dynamic_constants + = TypeSolverInst::build_dyn_consts(&mut self.builder, dyn_consts); + + // Code gen for each argument in order + // For inouts, this becomes an ssa.read_variable + // We also record the variables which are our inouts + let mut block = cur_block; + let mut arg_vals = vec![]; + let mut inouts = vec![]; + for arg in args { + match arg { + Either::Left(exp) => { + let (val, new_block) + = self.codegen_expr(exp, types, ssa, func_id, block); + block = new_block; + arg_vals.push(val); + }, + Either::Right(var) => { + inouts.push(*var); + arg_vals.push(ssa.read_variable(*var, block, &mut self.builder)); + }, + } + } + + // Create the call expression + let mut call = self.builder.allocate_node(func_id); + let call_id = call.id(); + call.build_call(call_func, dynamic_constants.into(), arg_vals.into()); + let _ = self.builder.add_node(call); + + // Read each of the "inout values" and perform the SSA update + let inouts_index = self.builder.create_field_index(1); + for (idx, var) in inouts.into_iter().enumerate() { + let index = self.builder.create_field_index(idx); + let mut read = self.builder.allocate_node(func_id); + let read_id = read.id(); + read.build_read(call_id, vec![inouts_index.clone(), index].into()); + let _ = self.builder.add_node(read); + + ssa.write_variable(var, block, read_id); + } + + // Read the "actual return" value and return it + let value_index = self.builder.create_field_index(0); + let mut read = self.builder.allocate_node(func_id); + let read_id = read.id(); + read.build_read(call_id, vec![value_index].into()); + let _ = self.builder.add_node(read); + + (read_id, block) + }, + } + } 
+
+    // Convert a list of Index from the semantic analysis into a list of indices for the builder.
+    // Note that this takes and returns a block since expressions may involve control flow
+    fn codegen_indices(&mut self, index : &Vec<semant::Index>, types : &mut TypeSolverInst,
+                       ssa : &mut SSA, func_id : FunctionID, cur_block : NodeID)
+        -> (Vec<ir::Index>, NodeID) {
+
+        let mut block = cur_block;
+        let mut built_index = vec![];
+        for idx in index {
+            match idx {
+                semant::Index::Field(idx) => {
+                    built_index.push(self.builder.create_field_index(*idx));
+                },
+                semant::Index::Variant(idx) => {
+                    built_index.push(self.builder.create_variant_index(*idx));
+                },
+                semant::Index::Array(exps) => {
+                    let mut expr_vals = vec![];
+                    for exp in exps {
+                        let (val, new_block) =
+                            self.codegen_expr(exp, types, ssa, func_id, block);
+                        block = new_block;
+                        expr_vals.push(val);
+                    }
+                    built_index.push(self.builder.create_position_index(expr_vals.into()));
+                },
+            }
+        }
+
+        (built_index, block)
+    }
+
+    fn build_tuple(&mut self, exprs : Vec<NodeID>, typ : TypeID, func_id : FunctionID) -> NodeID {
+        let zero_const = self.builder.create_constant_zero(typ);
+
+        let mut zero = self.builder.allocate_node(func_id);
+        let zero_val = zero.id();
+        zero.build_constant(zero_const);
+        let _ = self.builder.add_node(zero);
+
+        let mut val = zero_val;
+        for (idx, exp) in exprs.into_iter().enumerate() {
+            let mut write = self.builder.allocate_node(func_id);
+            let write_id = write.id();
+            let index = self.builder.create_field_index(idx);
+
+            write.build_write(val, exp, vec![index].into());
+            let _ = self.builder.add_node(write);
+            val = write_id;
+        }
+
+        val
+    }
+
+    fn build_union(&mut self, tag : usize, val : NodeID, typ : TypeID,
+                   func_id : FunctionID) -> NodeID {
+        let zero_const = self.builder.create_constant_zero(typ);
+
+        let mut zero = self.builder.allocate_node(func_id);
+        let zero_val = zero.id();
+        zero.build_constant(zero_const);
+        let _ = self.builder.add_node(zero);
+
+        let mut write = self.builder.allocate_node(func_id);
+        let write_id = write.id();
+        let index = self.builder.create_variant_index(tag);
+
+        write.build_write(zero_val, val, vec![index].into());
+        let _ = self.builder.add_node(write);
+
+        write_id
+    }
+
+    fn build_constant<'a>(&mut self, (lit, typ) : &semant::Constant,
+                          types : &mut TypeSolverInst<'a>) -> ConstantID {
+        match lit {
+            Literal::Unit => {
+                self.builder.create_constant_prod(vec![].into())
+            },
+            Literal::Bool(val) => {
+                self.builder.create_constant_bool(*val)
+            },
+            Literal::Integer(val) => {
+                let p = types.as_numeric_type(&mut self.builder, *typ);
+                match p {
+                    Primitive::I8 => self.builder.create_constant_i8(*val as i8),
+                    Primitive::I16 => self.builder.create_constant_i16(*val as i16),
+                    Primitive::I32 => self.builder.create_constant_i32(*val as i32),
+                    Primitive::I64 => self.builder.create_constant_i64(*val as i64),
+                    Primitive::U8 => self.builder.create_constant_u8(*val as u8),
+                    Primitive::U16 => self.builder.create_constant_u16(*val as u16),
+                    Primitive::U32 => self.builder.create_constant_u32(*val as u32),
+                    Primitive::U64 => self.builder.create_constant_u64(*val as u64),
+                    Primitive::F32 => self.builder.create_constant_f32(*val as f32),
+                    Primitive::F64 => self.builder.create_constant_f64(*val as f64),
+                    _ => panic!("Internal error in build_constant for integer"),
+                }
+            },
+            Literal::Float(val) => {
+                let p = types.as_numeric_type(&mut self.builder, *typ);
+                match p {
+                    Primitive::F32 => self.builder.create_constant_f32(*val as f32),
+                    Primitive::F64 => self.builder.create_constant_f64(*val as f64),
+                    _ => panic!("Internal error in build_constant for float"),
+                }
+            },
+            Literal::Tuple(vals) => {
+                let mut constants = vec![];
+                for val in vals {
+                    constants.push(self.build_constant(val, types));
+                }
+                self.builder.create_constant_prod(constants.into())
+            },
+            Literal::Sum(tag, val) => {
+                let constant = self.build_constant(val, types);
+                let type_id = types.lower_type(&mut self.builder, *typ);
+                self.builder.create_constant_sum(type_id, *tag as u32, constant).unwrap()
+            },
+        }
+    }
+}
diff --git a/juno_frontend/src/env.rs b/juno_frontend/src/env.rs
new file mode 100644
index 0000000000000000000000000000000000000000..3985a69a108c3c04f27f92b65b2c083c94006398
--- /dev/null
+++ b/juno_frontend/src/env.rs
@@ -0,0 +1,85 @@
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::hash::Hash;
+
+/// A lexically scoped environment mapping keys to values.
+///
+/// `table` holds, for every key, a stack of bindings whose top is the one
+/// currently visible. `scope` holds, for every open scope, the set of keys
+/// bound in it, so that `close_scope` can undo exactly those bindings.
+/// `count` is the counter handed out by `uniq`.
+pub struct Env<K, V> {
+    table : HashMap<K, Vec<V>>,
+    scope : Vec<HashSet<K>>,
+    count : usize,
+}
+
+impl<K : Eq + Hash + Copy, V> Env<K, V> {
+    /// Creates an empty environment with no open scope; `open_scope` must be
+    /// called before the first `insert`.
+    pub fn new() -> Env<K, V> {
+        Env { table : HashMap::new(), scope : vec![], count : 0 }
+    }
+
+    /// Returns the innermost visible binding of `k`, or `None` if `k` is not
+    /// bound, including when every scope that bound it has been closed.
+    pub fn lookup(&self, k : &K) -> Option<&V> {
+        // `last()` instead of indexing with `len() - 1`: a key whose bindings
+        // have all been popped must read as unbound rather than panic.
+        self.table.get(k).and_then(|bindings| bindings.last())
+    }
+
+    /// Binds `k` to `v` in the innermost scope. A binding of `k` made earlier
+    /// in that same scope is overwritten; a binding from an outer scope is
+    /// shadowed until the innermost scope is closed.
+    ///
+    /// Panics if no scope is open.
+    pub fn insert(&mut self, k : K, v : V) {
+        let innermost = self.scope.last_mut()
+            .expect("Internal Failure: Environment Insert with no open scope");
+
+        if innermost.insert(k) {
+            // First binding of `k` in this scope: push on top of outer bindings.
+            self.table.entry(k).or_default().push(v);
+        } else {
+            // Rebinding within the same scope: replace the top of the stack.
+            match self.table.get_mut(&k).and_then(|bindings| bindings.last_mut()) {
+                None => panic!("Internal Failure: Environment Insert"),
+                Some(slot) => { *slot = v; },
+            }
+        }
+    }
+
+    /// Opens a new innermost scope.
+    pub fn open_scope(&mut self) {
+        self.scope.push(HashSet::new());
+    }
+
+    /// Closes the innermost scope, dropping every binding made in it and
+    /// making visible again whatever those bindings shadowed.
+    ///
+    /// Panics if no scope is open.
+    pub fn close_scope(&mut self) {
+        let to_remove = self.scope.pop()
+            .expect("Internal Failure: Environment no scope to close");
+
+        for k in to_remove {
+            let now_unbound = match self.table.get_mut(&k) {
+                None => panic!("Internal Failure: Environment Close Scope"),
+                Some(bindings) => { bindings.pop(); bindings.is_empty() },
+            };
+            // Drop exhausted entries so the table does not collect empty stacks.
+            if now_unbound {
+                self.table.remove(&k);
+            }
+        }
+    }
+
+    /// Returns the current counter value and then increments it, so
+    /// successive calls yield 0, 1, 2, ...
+    pub fn uniq(&mut self) -> usize {
+        let n = self.count;
+        self.count += 1;
+        n
+    }
+}
diff --git a/juno_frontend/src/lang.l b/juno_frontend/src/lang.l
new file mode 100644
index 0000000000000000000000000000000000000000..6d96385ebc3aa0c38bd1439514c919192cc95b20
--- /dev/null
+++ b/juno_frontend/src/lang.l
@@ -0,0 +1,130 @@
+%x comment
+%%
+/\* <+comment>; + +<comment>/\* <+comment>; +<comment>\*+/ <-comment>; +<comment>\*+ ; +<comment>[\n\r] ; +<comment>. ; + +//[^\n\r]* ; + +[\t ]+ ; +[\n\r] ; + +#\[[^\]]*\] "FUNC_ATTR" + +as "as" +break "break" +by "by" +const "const" +continue "continue" +else "else" +false "false" +fn "fn" +for "for" +if "if" +inout "inout" +integer "integer" +let "let" +match "match" +mod "mod" +number "number" +pub "pub" +return "return" +struct "struct" +then "then" +to "to" +true "true" +type "type" +union "union" +use "use" +while "while" + +bool "bool" +i8 "i8" +i16 "i16" +i32 "i32" +i64 "i64" +u8 "u8" +u16 "u16" +u32 "u32" +u64 "u64" +usize "usize" +f32 "f32" +f64 "f64" +void "void" + +\+= "+=" +\+ "+" + +&&= "&&=" +&& "&&" +&= "&=" +& "&" + +\|\|= "||=" +\|\| "||" +\|= "|=" +\| "|" + +/= "/=" +/ "/" + +%= "%=" +% "%" + +\<<= "<<=" +\<< "<<" +\<= "<=" +\< "<" + +>>= ">>=" +>> ">>" +>= ">=" +> ">" + +\*= "*=" +\* "*" + +-= "-=" +-> "->" +- "-" + +\^= "^=" +\^ "^" + +=> "=>" +== "==" += "=" + +!= "!=" +! "!" + +:: "::" +: ":" + +, "," +\. "." +; ";" +~ "~" +_ "_" + +\( "(" +\) ")" +\{ "{" +\} "}" +\[ "[" +\] "]" + +\.[0-9]+ "DOT_NUM" +[a-zA-Z][a-zA-Z0-9_]* "ID" +[0-9]+ "INT" +0x[0-9a-fA-F]+ "HEX_INT" +0b[0-1]+ "BIN_INT" +0o[0-7]+ "OCT_INT" +[0-9]+\.[0-9]*(|e[0-9]+) "FLOAT" + +. "UNMATCHED" +. "UNARY" diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y new file mode 100644 index 0000000000000000000000000000000000000000..2f63f998eb62e305163a3ad24977c9efbcd5afdc --- /dev/null +++ b/juno_frontend/src/lang.y @@ -0,0 +1,791 @@ +%start Program + +%token UNARY +%avoid_insert "FUNC_ATTR" "DOT_NUM" "ID" "INT" "HEX_INT" "BIN_INT" "OCT_INT" "FLOAT" +%expect-unused Unmatched 'UNMATCHED' 'UNARY' + +%nonassoc ')' +%nonassoc 'else' +%left '||' +%left '&&' +%left '|' +%left '^' +%left '&' +%nonassoc '==' '!=' +%nonassoc '<' '<=' '>' '>=' +%left '<<' '>>' +%left '+' '-' +%left '*' '/' '%' +%left 'as' 'size' +%right '~' '!' 'UNARY' +%left '.' 
'DOT_NUM' '[' ']' + +%% +Program -> Result<Prg, ()> + : { Ok(vec![]) } + | Program Top { flatten($1, $2) } + ; + +Top -> Result<Top, ()> + : Import { $1 } + | TypeDecl { $1 } + | ConstDecl { $1 } + | FuncDecl { $1 } + | Module { $1 } + ; + +PubOption -> Result<bool, ()> + : { Ok(false) } + | 'pub' { Ok(true) } + ; + +Module -> Result<Top, ()> + : PubOption 'mod' 'ID' '{' Program '}' + { Ok(Top::ModDecl{ span : $span, public : $1?, name : span_of_tok($3)?, body : $5? }) }; + +PackageName -> Result<PackageName, ()> + : 'ID' { Ok(vec![span_of_tok($1)?]) } + | PackageName '::' 'ID' { flatten($1, span_of_tok($3)) } + ; + +PackageId -> Result<ImportName, ()> + : PackageName { Ok(($1?, None)) } + | PackageName '::' '*' { Ok(($1?, Some(span_of_tok($3)?))) } + ; + +Import -> Result<Top, ()> + : 'use' PackageId ';' { Ok(Top::Import{ span : $span, name : $2? }) }; + +TypeVars -> Result<Vec<TypeVar>, ()> + : { Ok(vec![]) } + | '<' TypeVarsI '>' { Ok($2?.into_iter().flatten().collect()) } + ; +TypeVarsI -> Result<VecDeque<Vec<TypeVar>>, ()> + : { Ok(VecDeque::new()) } + | TypeVar { Ok(VecDeque::from([$1?])) } + | TypeVar ',' TypeVarsI { cons_deque($1, $3) } + ; +TypeVar -> Result<Vec<TypeVar>, ()> + : IdList ':' Kind { let kind = $3?; + Ok($1?.into_iter() + .map(|n| TypeVar { span : $span, name : n, kind : kind }) + .collect()) } + ; +IdList -> Result<Vec<Id>, ()> + : 'ID' { Ok(vec![span_of_tok($1)?]) } + | IdList ',' 'ID' { flatten($1, span_of_tok($3)) } + ; + +Kind -> Result<Kind, ()> + : 'type' { Ok(Kind::Type) } + | 'usize' { Ok(Kind::USize) } + | 'number' { Ok(Kind::Number) } + | 'integer' { Ok(Kind::Integer) } + ; + +TypeDecl -> Result<Top, ()> + : PubOption 'type' 'ID' TypeVars '=' TypeDef ';' + { Ok(Top::TypeDecl{ span : $span, public : $1?, name : span_of_tok($3)?, ty_vars : $4?, + body : $6? } )}; + +TypeDef -> Result<TyDef, ()> + : Type + { Ok(TyDef::TypeAlias{ span : $span, body : $1? 
}) } + | PubOption 'struct' '{' ObjFields '}' + { Ok(TyDef::Struct{ span : $span, public : $1?, fields : $4? }) } + | PubOption 'union' '{' ObjFields '}' + { Ok(TyDef::Union { span : $span, public : $1?, fields : $4? }) } + ; + +ObjFields -> Result<Vec<ObjField>, ()> + : { Ok(vec![]) } + | ObjFields ObjField { flatten($1, $2) } + ; +ObjField -> Result<ObjField, ()> + : PubOption 'ID' ';' + { Ok(ObjField{ span : $span, public : $1?, name : span_of_tok($2)?, typ : None }) } + | PubOption 'ID' ':' Type ';' + { Ok(ObjField{ span : $span, public : $1?, name : span_of_tok($2)?, typ : Some($4?) }) } + ; + +Type -> Result<Type, ()> + : PrimType + { Ok(Type::PrimType{ span : $span, typ : $1? }) } + | '(' Types ')' + { Ok(Type::TupleType{ span : $span, tys : $2? }) } + | PackageName + { Ok(Type::NamedType{ span : $span, name : $1?, args : vec![] }) } + | PackageName '::' '<' TypeExprs '>' + { Ok(Type::NamedType{ span : $span, name : $1?, args : $4? }) } + | Type '[' TypeExprs ']' + { Ok(Type::ArrayType{ span : $span, elem : Box::new($1?), dims : $3? }) } + ; +Types -> Result<Vec<Type>, ()> + : { Ok(vec![]) } + | Type { Ok(vec![$1?]) } + | TypesS ',' Type { flatten($1, $3) } + ; +TypesS -> Result<Vec<Type>, ()> + : Type { Ok(vec![$1?]) } + | TypesS ',' Type { flatten($1, $3) } + ; +PrimType -> Result<Primitive, ()> + : 'bool' { Ok(Primitive::Bool) } + | 'i8' { Ok(Primitive::I8) } + | 'u8' { Ok(Primitive::U8) } + | 'i16' { Ok(Primitive::I16) } + | 'u16' { Ok(Primitive::U16) } + | 'i32' { Ok(Primitive::I32) } + | 'u32' { Ok(Primitive::U32) } + | 'i64' { Ok(Primitive::I64) } + | 'u64' { Ok(Primitive::U64) } + | 'usize' { Ok(Primitive::USize) } + | 'f32' { Ok(Primitive::F32) } + | 'f64' { Ok(Primitive::F64) } + | 'void' { Ok(Primitive::Void) } + ; + +ConstDecl -> Result<Top, ()> + : PubOption 'const' 'ID' '=' Expr ';' + { Ok(Top::ConstDecl{ span : $span, public : $1?, name : span_of_tok($3)?, ty : None, + body : $5? 
}) } + | PubOption 'const' 'ID' ':' Type '=' Expr ';' + { Ok(Top::ConstDecl{ span : $span, public : $1?, name : span_of_tok($3)?, ty : Some($5?), + body : $7? }) } + ; + +FuncDecl -> Result<Top, ()> + : PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts + { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?, + ty_vars : $4?, args : $6?, ty : None, body : $8? }) } + | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts + { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?), + name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : None, + body : $9? }) } + | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts + { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?, + ty_vars : $4?, args : $6?, ty : Some($9?), body : $10? }) } + | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts + { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?), + name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : Some($10?), + body : $11? }) } + ; +Arguments -> Result<Vec<(Option<Span>, VarBind)>, ()> + : { Ok(vec![]) } + | ArgBind { Ok(vec![$1?]) } + | ArgumentsS ',' ArgBind { flatten($1, $3) } + ; +ArgumentsS -> Result<Vec<(Option<Span>, VarBind)>, ()> + : ArgBind { Ok(vec![$1?]) } + | ArgumentsS ',' ArgBind { flatten($1, $3) } + ; +ArgBind -> Result<(Option<Span>, VarBind), ()> + : 'inout' VarBind { Ok((Some(span_of_tok($1)?), $2?)) } + | VarBind { Ok((None, $1?)) } + ; + +VarBind -> Result<VarBind, ()> + : Pattern { Ok(VarBind{ span : $span, pattern : $1?, typ : None }) } + | Pattern ':' Type { Ok(VarBind{ span : $span, pattern : $1?, typ : Some($3?) }) } + ; + +Pattern -> Result<Pattern, ()> + : '_' { Ok(Pattern::Wildcard { span : $span }) } + | IntLit { let (span, base) = $1?; + Ok(Pattern::IntLit { span : span, base : base }) } + | PackageName { Ok(Pattern::Variable { span : $span, name : $1? 
}) } + | '(' PatternsComma ')' { Ok(Pattern::TuplePattern { span : $span, pats : $2? }) } + | PackageName '{' NamePatterns '}' + { Ok(Pattern::StructPattern { span : $span, name : $1?, pats : $3? }) } + | PackageName '(' PatternsComma ')' + { Ok(Pattern::UnionPattern { span : $span, name : $1?, pats : $3? }) } + ; +PatternsComma -> Result<Vec<Pattern>, ()> + : { Ok(vec![]) } + | Pattern { Ok(vec![$1?]) } + | PatternsCommaS ',' Pattern { flatten($1, $3) } + ; +PatternsCommaS -> Result<Vec<Pattern>, ()> + : Pattern { Ok(vec![$1?]) } + | PatternsCommaS ',' Pattern { flatten($1, $3) } + ; +NamePatterns -> Result<Vec<(Id, Pattern)>, ()> + : 'ID' ':' Pattern { Ok(vec![(span_of_tok($1)?, $3?)]) } + | NamePatternsS ',' 'ID' ':' Pattern { flatten($1, res_pair(span_of_tok($3), $5)) } + ; +NamePatternsS -> Result<Vec<(Id, Pattern)>, ()> + : 'ID' ':' Pattern { Ok(vec![(span_of_tok($1)?, $3?)]) } + | NamePatternsS ',' 'ID' ':' Pattern { flatten($1, res_pair(span_of_tok($3), $5)) } + ; + +Stmt -> Result<Stmt, ()> + : 'let' VarBind ';' + { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : None }) } + | 'let' VarBind '=' Expr ';' + { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : Some($4?) }) } + | 'const' VarBind ';' + { Ok(Stmt::ConstStmt{ span : $span, var : $2?, init : None }) } + | 'const' VarBind '=' Expr ';' + { Ok(Stmt::ConstStmt{ span : $span, var : $2?, init : Some($4?) }) } + | LExpr '=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::None, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '+=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::Add, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '-=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::Sub, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '*=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::Mul, + assign_span : span_of_tok($2)?, rhs : $3? 
}) } + | LExpr '/=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::Div, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '%=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::Mod, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '&=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::BitAnd, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '|=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::BitOr, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '^=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::Xor, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '&&=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::LogAnd, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '||=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::LogOr, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '<<=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::LShift, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | LExpr '>>=' Expr ';' + { Ok(Stmt::AssignStmt{ span : $span, lhs : $1?, assign : AssignOp::RShift, + assign_span : span_of_tok($2)?, rhs : $3? }) } + | IfStmt + { $1 } + | 'match' NonStructExpr Cases + { Ok(Stmt::MatchStmt{ span : $span, expr : $2?, body : $3? }) } + | 'for' VarBind '=' NonStructExpr 'to' NonStructExpr Stmts + { Ok(Stmt::ForStmt{ span : $span, var : $2?, init : $4?, bound : $6?, step : None, + body : Box::new($7?) }) } + | 'for' VarBind '=' NonStructExpr 'to' NonStructExpr 'by' SignedIntLit Stmts + { Ok(Stmt::ForStmt{ span : $span, var : $2?, init : $4?, bound : $6?, step : Some($8?), + body : Box::new($9?) }) } + | 'while' NonStructExpr Stmts + { Ok(Stmt::WhileStmt{ span : $span, cond : $2?, body : Box::new($3?) 
}) } + | 'return' ';' + { Ok(Stmt::ReturnStmt{ span : $span, expr : None }) } + | 'return' Expr ';' + { Ok(Stmt::ReturnStmt{ span : $span, expr : Some($2?)}) } + | 'break' ';' + { Ok(Stmt::BreakStmt{ span : $span }) } + | 'continue' ';' + { Ok(Stmt::ContinueStmt{ span : $span }) } + | Stmts + { $1 } + | PackageName '(' Params ')' ';' + { Ok(Stmt::CallStmt{ span : $span, name : $1?, ty_args : vec![], args : $3? }) } + | PackageName '::' '<' TypeExprs '>' '(' Params ')' ';' + { Ok(Stmt::CallStmt{ span : $span, name : $1?, ty_args : $4?, args : $7? }) } + ; +Stmts -> Result<Stmt, ()> + : '{' StmtList '}' { Ok(Stmt::BlockStmt{ span : $span, body : $2? }) }; +StmtList -> Result<Vec<Stmt>, ()> + : { Ok(vec![]) } + | StmtList Stmt { flatten($1, $2) } + ; + +IfStmt -> Result<Stmt, ()> + : 'if' NonStructExpr Stmts + { Ok(Stmt::IfStmt{ span : $span, cond : $2?, thn : Box::new($3?), els : None }) } + | 'if' NonStructExpr Stmts 'else' IfStmt + { Ok(Stmt::IfStmt{ span : $span, cond : $2?, thn : Box::new($3?), + els : Some(Box::new($5?)) }) } + | 'if' NonStructExpr Stmts 'else' Stmts + { Ok(Stmt::IfStmt{ span : $span, cond : $2?, thn : Box::new($3?), + els : Some(Box::new($5?)) }) } + ; + +Cases -> Result<Vec<Case>, ()> + : '{' CaseList '}' { $2 } + ; +CaseList -> Result<Vec<Case>, ()> + : Case { Ok(vec![$1?]) } + | CaseList Case { flatten($1, $2) } + ; +Case -> Result<Case, ()> + : Patterns '=>' Stmt { Ok(Case{ span : $span, pat : $1?, body : $3? }) }; +Patterns -> Result<Vec<Pattern>, ()> + : Pattern { Ok(vec![$1?]) } + | '|' Pattern { Ok(vec![$2?]) } + | Patterns '|' Pattern { flatten($1, $3) } + ; + +LExpr -> Result<LExpr, ()> + : 'ID' { Ok(LExpr::VariableLExpr{ span : span_of_tok($1)? }) } + | LExpr '.' 'ID' { Ok(LExpr::FieldLExpr{ span : $span, lhs : Box::new($1?), + rhs : span_of_tok($3)? }) } + | LExpr 'DOT_NUM' { Ok(LExpr::NumFieldLExpr { span : $span, lhs : Box::new($1?), + rhs : span_of_tok($2)? 
}) } + | LExpr '[' Exprs ']' { Ok(LExpr::IndexLExpr{ span : $span, lhs : Box::new($1?), index : $3? }) } + ; + +SignedIntLit -> Result<(bool, Span, IntBase), ()> + : '+' IntLit { Ok((false, $2?.0, $2?.1)) } + | IntLit { Ok((false, $1?.0, $1?.1)) } + | '-' IntLit { Ok((true, $2?.0, $2?.1)) } + ; + +IntLit -> Result<(Span, IntBase), ()> + : 'INT' { Ok(($span, IntBase::Decimal)) } + | 'HEX_INT' { Ok(($span, IntBase::Hexadecimal)) } + | 'BIN_INT' { Ok(($span, IntBase::Binary)) } + | 'OCT_INT' { Ok(($span, IntBase::Octal)) } + ; + +Exprs -> Result<Vec<Expr>, ()> + : { Ok(vec![]) } + | Expr { Ok(vec![$1?]) } + | ExprsS ',' Expr { flatten($1, $3) } + ; +ExprsS -> Result<Vec<Expr>, ()> + : Expr { Ok(vec![$1?]) } + | ExprsS ',' Expr { flatten($1, $3) } + ; + +Expr -> Result<Expr, ()> + : PackageName + { Ok(Expr::Variable{ span : $span, name : $1? }) } + | Expr '.' 'ID' + { Ok(Expr::Field{ span : $span, lhs : Box::new($1?), rhs : span_of_tok($3)? }) } + | Expr 'DOT_NUM' + { Ok(Expr::NumField{ span : $span, lhs : Box::new($1?), rhs : span_of_tok($2)? }) } + | Expr '[' Exprs ']' + { Ok(Expr::ArrIndex{ span : $span, lhs : Box::new($1?), index : $3? }) } + | '(' Exprs ')' + { Ok(Expr::Tuple{ span : $span, exprs : $2? }) } + | PackageName '{' IdExprs '}' + { Ok(Expr::Struct{ span : $span, name : $1?, ty_args : vec![], exprs : $3? }) } + | PackageName '::' '<' TypeExprs '>' '{' IdExprs '}' + { Ok(Expr::Struct{ span : $span, name : $1?, ty_args : $4?, exprs : $7? }) } + | 'true' + { Ok(Expr::BoolLit{ span : $span, value : true }) } + | 'false' + { Ok(Expr::BoolLit{ span : $span, value : false }) } + | IntLit + { let (span, base) = $1?; + Ok(Expr::IntLit{ span : span, base : base }) } + | 'FLOAT' + { Ok(Expr::FloatLit{ span : $span }) } + | '-' Expr %prec 'UNARY' + { Ok(Expr::UnaryExpr{ span : $span, op : UnaryOp::Negation, expr : Box::new($2?)}) } + | '~' Expr + { Ok(Expr::UnaryExpr{ span : $span, op : UnaryOp::BitwiseNot, expr : Box::new($2?)}) } + | '!' 
Expr + { Ok(Expr::UnaryExpr{ span : $span, op : UnaryOp::LogicalNot, expr : Box::new($2?)}) } + | Expr 'as' Type + { Ok(Expr::CastExpr{ span : $span, expr : Box::new($1?), typ : $3?}) } + | Expr '+' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Add, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '-' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Sub, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '*' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Mul, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '/' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Div, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '%' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Mod, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '&' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::BitAnd, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '&&' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::LogAnd, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '|' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::BitOr, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '||' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::LogOr, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '^' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Xor, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '<' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Lt, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '<=' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Le, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '>' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Gt, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '>=' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Ge, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '==' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Eq, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '!=' Expr + { 
Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Neq, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '<<' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::LShift, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | Expr '>>' Expr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::RShift, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | 'if' Expr 'then' Expr 'else' Expr + { Ok(Expr::CondExpr{ span: $span, cond : Box::new($2?), thn : Box::new($4?), els : Box::new($6?) })} + | PackageName '(' Params ')' + { Ok(Expr::CallExpr{ span : $span, name : $1?, ty_args : vec![], args: $3? }) } + | PackageName '::' '<' TypeExprs '>' '(' Params ')' + { Ok(Expr::CallExpr{ span : $span, name : $1?, ty_args : $4?, args: $7? }) } + ; +IdExprs -> Result<Vec<(Id, Expr)>, ()> + : 'ID' '=' Expr { Ok(vec![(span_of_tok($1)?, $3?)]) } + | IdExprsS ',' 'ID' '=' Expr { flatten($1, res_pair(span_of_tok($3), $5)) } + ; +IdExprsS -> Result<Vec<(Id, Expr)>, ()> + : 'ID' '=' Expr { Ok(vec![(span_of_tok($1)?, $3?)]) } + | IdExprsS ',' 'ID' '=' Expr { flatten($1, res_pair(span_of_tok($3), $5)) } + ; +Params -> Result<Vec<(bool, Expr)>, ()> + : { Ok(vec![]) } + | Expr { Ok(vec![(false, $1?)]) } + | '&' Expr { Ok(vec![(true, $2?)]) } + | ParamsS ',' Expr { flatten($1, Ok((false, $3?))) } + | ParamsS ',' '&' Expr { flatten($1, Ok((true, $4?))) } + ; +ParamsS -> Result<Vec<(bool, Expr)>, ()> + : Expr { Ok(vec![(false, $1?)]) } + | '&' Expr { Ok(vec![(true, $2?)]) } + | ParamsS ',' Expr { flatten($1, Ok((false, $3?))) } + | ParamsS ',' '&' Expr { flatten($1, Ok((true, $4?))) } + ; + +NonStructExpr -> Result<Expr, ()> + : PackageName + { Ok(Expr::Variable{ span : $span, name : $1? }) } + | NonStructExpr '.' 'ID' + { Ok(Expr::Field{ span : $span, lhs : Box::new($1?), rhs : span_of_tok($3)? }) } + | NonStructExpr 'DOT_NUM' + { Ok(Expr::NumField{ span : $span, lhs : Box::new($1?), rhs : span_of_tok($2)? 
}) } + | NonStructExpr '[' Exprs ']' + { Ok(Expr::ArrIndex{ span : $span, lhs : Box::new($1?), index : $3? }) } + | '(' Exprs ')' + { Ok(Expr::Tuple{ span : $span, exprs : $2? }) } + | 'true' + { Ok(Expr::BoolLit{ span : $span, value : true }) } + | 'false' + { Ok(Expr::BoolLit{ span : $span, value : false }) } + | IntLit + { let (span, base) = $1?; + Ok(Expr::IntLit{ span : span, base : base }) } + | 'FLOAT' + { Ok(Expr::FloatLit{ span : $span }) } + | '-' NonStructExpr %prec 'UNARY' + { Ok(Expr::UnaryExpr{ span : $span, op : UnaryOp::Negation, expr : Box::new($2?)}) } + | '~' NonStructExpr + { Ok(Expr::UnaryExpr{ span : $span, op : UnaryOp::BitwiseNot, expr : Box::new($2?)}) } + | '!' NonStructExpr + { Ok(Expr::UnaryExpr{ span : $span, op : UnaryOp::LogicalNot, expr : Box::new($2?)}) } + | NonStructExpr 'as' Type + { Ok(Expr::CastExpr{ span : $span, expr : Box::new($1?), typ : $3?}) } + | NonStructExpr '+' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Add, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '-' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Sub, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '*' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Mul, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '/' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Div, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '%' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Mod, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '&' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::BitAnd, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '&&' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::LogAnd, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '|' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::BitOr, lhs: Box::new($1?), rhs: Box::new($3?)}) } + 
| NonStructExpr '||' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::LogOr, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '^' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Xor, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '<' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Lt, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '<=' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Le, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '>' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Gt, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '>=' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Ge, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '==' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Eq, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '!=' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::Neq, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '<<' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::LShift, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | NonStructExpr '>>' NonStructExpr + { Ok(Expr::BinaryExpr{ span: $span, op: BinaryOp::RShift, lhs: Box::new($1?), rhs: Box::new($3?)}) } + | 'if' NonStructExpr 'then' NonStructExpr 'else' NonStructExpr + { Ok(Expr::CondExpr{ span: $span, cond : Box::new($2?), thn : Box::new($4?), els : Box::new($6?) })} + | PackageName '(' Params ')' + { Ok(Expr::CallExpr{ span : $span, name : $1?, ty_args : vec![], args: $3? }) } + | PackageName '::' '<' TypeExprs '>' '(' Params ')' + { Ok(Expr::CallExpr{ span : $span, name : $1?, ty_args : $4?, args: $7? 
}) } + ; + +TypeExprs -> Result<Vec<TypeExpr>, ()> + : { Ok(vec![]) } + | TypeExpr { Ok(vec![$1?]) } + | TypeExprsS ',' TypeExpr { flatten($1, $3) } + ; +TypeExprsS -> Result<Vec<TypeExpr>, ()> + : TypeExpr { Ok(vec![$1?]) } + | TypeExprsS ',' TypeExpr { flatten($1, $3) } + ; +TypeExpr -> Result<TypeExpr, ()> + : PrimType + { Ok(TypeExpr::PrimType{ span : $span, typ : $1? }) } + | '(' TypeExprs ')' + { Ok(TypeExpr::TupleType{ span : $span, tys : $2? }) } + | PackageName + { Ok(TypeExpr::NamedTypeExpr{ span : $span, name : $1?, args : vec![]}) } + | PackageName '::' '<' TypeExprs '>' + { Ok(TypeExpr::NamedTypeExpr{ span : $span, name : $1?, args : $4? }) } + | TypeExpr '[' TypeExprs ']' + { Ok(TypeExpr::ArrayTypeExpr{ span : $span, elem : Box::new($1?), dims : $3? }) } + | IntLit + { let (span, base) = $1?; + Ok(TypeExpr::IntLiteral{ span : span, base : base }) } + | '-' TypeExpr %prec UNARY + { Ok(TypeExpr::Negative{ span : $span, expr : Box::new($2?)}) } + | TypeExpr '+' TypeExpr + { Ok(TypeExpr::Add{ span : $span, lhs : Box::new($1?), rhs : Box::new($3?) }) } + | TypeExpr '-' TypeExpr + { Ok(TypeExpr::Sub{ span : $span, lhs : Box::new($1?), rhs : Box::new($3?) }) } + | TypeExpr '*' TypeExpr + { Ok(TypeExpr::Mul{ span : $span, lhs : Box::new($1?), rhs : Box::new($3?) }) } + | TypeExpr '/' TypeExpr + { Ok(TypeExpr::Div{ span : $span, lhs : Box::new($1?), rhs : Box::new($3?) 
}) } + ; + +Unmatched -> (): 'UNMATCHED' { }; +%% + +use cfgrammar::Span; +use lrlex::DefaultLexeme; +use std::collections::VecDeque; + +fn flatten<T>(lhs: Result<Vec<T>, ()>, rhs: Result<T, ()>) -> Result<Vec<T>, ()> { + let mut flt = lhs?; + flt.push(rhs?); + Ok(flt) +} + +fn cons_deque<T>(lhs : Result<T, ()>, rhs : Result<VecDeque<T>, ()>) -> Result<VecDeque<T>, ()> { + let mut lst = rhs?; + lst.push_front(lhs?); + Ok(lst) +} + +fn span_of_tok(t : Result<DefaultLexeme, DefaultLexeme>) -> Result<Span, ()> { + t.map_err(|_| ()).map(|l| l.span()) +} + +fn res_pair<A, B>(x : Result<A, ()>, y : Result<B, ()>) -> Result<(A, B), ()> { + Ok((x?, y?)) +} + +pub type Prg = Vec<Top>; +pub type Id = Span; +pub type PackageName = Vec<Span>; +pub type ImportName = (PackageName, Option<Span>); // option is the wildcard * + +#[derive(Debug, Copy, Clone)] +pub enum Kind { Type, USize, Number, Integer } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Primitive { Bool, I8, U8, I16, U16, I32, U32, I64, U64, USize, F32, F64, Void } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AssignOp { None, Add, Sub, Mul, Div, Mod, BitAnd, BitOr, Xor, LogAnd, LogOr, + LShift, RShift } +#[derive(Debug, Copy, Clone)] +pub enum IntBase { Binary, Octal, Decimal, Hexadecimal } +#[derive(Debug, Copy, Clone)] +pub enum UnaryOp { Negation, BitwiseNot, LogicalNot } +#[derive(Debug, Copy, Clone)] +pub enum BinaryOp { Add, Sub, Mul, Div, Mod, BitAnd, LogAnd, BitOr, LogOr, Xor, + Lt, Le, Gt, Ge, Eq, Neq, LShift, RShift } + +#[derive(Debug)] +pub struct ObjField { pub span : Span, pub public : bool, pub name : Id, pub typ : Option<Type> } +#[derive(Debug)] +pub struct TypeVar { pub span : Span, pub name : Id, pub kind : Kind } +#[derive(Debug)] +pub struct VarBind { pub span : Span, pub pattern : Pattern, pub typ : Option<Type> } +#[derive(Debug)] +pub struct Case { pub span : Span, pub pat : Vec<Pattern>, pub body : Stmt } + +#[derive(Debug)] +pub enum Top { + Import { span : Span, name : 
ImportName }, + TypeDecl { span : Span, public : bool, name : Id, ty_vars : Vec<TypeVar>, body : TyDef }, + ConstDecl { span : Span, public : bool, name : Id, ty : Option<Type>, body : Expr }, + FuncDecl { span : Span, public : bool, attr : Option<Span>, name : Id, ty_vars : Vec<TypeVar>, + args : Vec<(Option<Span>, VarBind)>, // option is for inout + ty : Option<Type>, body : Stmt }, + ModDecl { span : Span, public : bool, name : Id, body : Vec<Top> }, +} + +#[derive(Debug)] +pub enum TyDef { + TypeAlias { span : Span, body : Type }, + Struct { span : Span, public : bool, fields : Vec<ObjField> }, + Union { span : Span, public : bool, fields : Vec<ObjField> }, +} + +#[derive(Debug)] +pub enum Type { + PrimType { span : Span, typ : Primitive }, + TupleType { span : Span, tys : Vec<Type> }, + NamedType { span : Span, name : PackageName, args : Vec<TypeExpr> }, + ArrayType { span : Span, elem : Box<Type>, dims : Vec<TypeExpr> }, +} + +#[derive(Debug)] +pub enum Stmt { + LetStmt { span : Span, var : VarBind, init : Option<Expr> }, + ConstStmt { span : Span, var : VarBind, init : Option<Expr> }, + AssignStmt { span : Span, lhs : LExpr, assign : AssignOp, assign_span : Span, rhs : Expr }, + IfStmt { span : Span, cond : Expr, thn : Box<Stmt>, els : Option<Box<Stmt>> }, + MatchStmt { span : Span, expr : Expr, body : Vec<Case> }, + // The step records: negative, number, base + ForStmt { span : Span, var : VarBind, init : Expr, bound : Expr, step : Option<(bool, Span, IntBase)>, + body : Box<Stmt> }, + WhileStmt { span : Span, cond : Expr, body : Box<Stmt> }, + ReturnStmt { span : Span, expr : Option<Expr> }, + BreakStmt { span : Span }, + ContinueStmt { span : Span }, + BlockStmt { span : Span, body : Vec<Stmt> }, + CallStmt { span : Span, name : PackageName, ty_args : Vec<TypeExpr>, + args : Vec<(bool, Expr)> }, // bool indicates & (for inouts) +} + +#[derive(Debug)] +pub enum Pattern { + Wildcard { span : Span }, + IntLit { span : Span, base : IntBase }, + Variable { 
span : Span, name : PackageName }, + TuplePattern { span : Span, pats : Vec<Pattern> }, + StructPattern { span : Span, name : PackageName, pats : Vec<(Id, Pattern)> }, + UnionPattern { span : Span, name : PackageName, pats : Vec<Pattern> }, +} + +#[derive(Debug)] +pub enum LExpr { + VariableLExpr { span : Span }, + FieldLExpr { span : Span, lhs : Box<LExpr>, rhs : Id }, + NumFieldLExpr { span : Span, lhs : Box<LExpr>, rhs : Span }, + IndexLExpr { span : Span, lhs : Box<LExpr>, index : Vec<Expr> }, +} + +#[derive(Debug)] +pub enum Expr { + Variable { span : Span, name : PackageName }, + Field { span : Span, lhs : Box<Expr>, rhs : Id }, + NumField { span : Span, lhs : Box<Expr>, rhs : Span }, + ArrIndex { span : Span, lhs : Box<Expr>, index : Vec<Expr> }, + Tuple { span : Span, exprs : Vec<Expr> }, + Struct { span : Span, name : PackageName, ty_args : Vec<TypeExpr>, + exprs : Vec<(Id, Expr)> }, + BoolLit { span : Span, value : bool }, + IntLit { span : Span, base : IntBase }, + FloatLit { span : Span }, + UnaryExpr { span : Span, op : UnaryOp, expr : Box<Expr> }, + BinaryExpr { span : Span, op : BinaryOp, lhs : Box<Expr>, rhs : Box<Expr> }, + CastExpr { span : Span, expr : Box<Expr>, typ : Type }, + CondExpr { span : Span, cond : Box<Expr>, thn : Box<Expr>, els : Box<Expr> }, + CallExpr { span : Span, name : PackageName, ty_args : Vec<TypeExpr>, + args : Vec<(bool, Expr)> }, // bool indicates & (for inouts) +} + +#[derive(Debug)] +pub enum TypeExpr { + PrimType { span : Span, typ : Primitive }, + TupleType { span : Span, tys : Vec<TypeExpr> }, + NamedTypeExpr { span : Span, name : PackageName, args : Vec<TypeExpr> }, + ArrayTypeExpr { span : Span, elem : Box<TypeExpr>, dims : Vec<TypeExpr> }, + IntLiteral { span : Span, base : IntBase }, + Negative { span : Span, expr : Box<TypeExpr> }, + Add { span : Span, lhs : Box<TypeExpr>, rhs : Box<TypeExpr> }, + Sub { span : Span, lhs : Box<TypeExpr>, rhs : Box<TypeExpr> }, + Mul { span : Span, lhs : Box<TypeExpr>, rhs : 
Box<TypeExpr> }, + Div { span : Span, lhs : Box<TypeExpr>, rhs : Box<TypeExpr> }, +} + +pub trait Spans { + fn span(&self) -> Span; +} + +impl Spans for Expr { + fn span(&self) -> Span { + match self { + Expr::Variable { span, .. } + | Expr::Field { span, .. } + | Expr::NumField { span, .. } + | Expr::ArrIndex { span, .. } + | Expr::Tuple { span, .. } + | Expr::Struct { span, .. } + | Expr::BoolLit { span, .. } + | Expr::IntLit { span, .. } + | Expr::FloatLit { span } + | Expr::UnaryExpr { span, .. } + | Expr::BinaryExpr { span, .. } + | Expr::CastExpr { span, .. } + | Expr::CondExpr { span, .. } + | Expr::CallExpr { span, .. } + => *span + } + } +} + +impl Spans for Stmt { + fn span(&self) -> Span { + match self { + Stmt::LetStmt { span, .. } + | Stmt::ConstStmt { span, .. } + | Stmt::AssignStmt { span, .. } + | Stmt::IfStmt { span, .. } + | Stmt::MatchStmt { span, .. } + | Stmt::ForStmt { span, .. } + | Stmt::WhileStmt { span, .. } + | Stmt::ReturnStmt { span, .. } + | Stmt::BreakStmt { span, .. } + | Stmt::ContinueStmt { span, .. } + | Stmt::BlockStmt { span, .. } + | Stmt::CallStmt { span, .. } + => *span + } + } +} + +impl Spans for TypeExpr { + fn span(&self) -> Span { + match self { + TypeExpr::PrimType { span, .. } + | TypeExpr::TupleType { span, .. } + | TypeExpr::NamedTypeExpr { span, .. } + | TypeExpr::ArrayTypeExpr { span, .. } + | TypeExpr::IntLiteral { span, .. } + | TypeExpr::Negative { span, .. } + | TypeExpr::Add { span, .. } + | TypeExpr::Sub { span, .. } + | TypeExpr::Mul { span, .. } + | TypeExpr::Div { span, .. 
} + => *span + } + } +} + +impl IntBase { + pub fn base(&self) -> u32 { + match self { + IntBase::Binary => 2, + IntBase::Octal => 8, + IntBase::Decimal => 10, + IntBase::Hexadecimal => 16, + } + } +} diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..f8e267eb3d2f2eb9bc8e7555cd313af0cc4c857d --- /dev/null +++ b/juno_frontend/src/main.rs @@ -0,0 +1,48 @@ +extern crate clap; + +use clap::Parser; + +mod codegen; +mod env; +mod semant; +mod ssa; +mod types; + +use codegen::*; + +extern crate hercules_ir; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + src_file : String, +} + +fn main() { + let args = Cli::parse(); + let prg = semant::parse_and_analyze(args.src_file); + match prg { + Ok(prg) => { + let module = codegen_program(prg); + + let mut pm = hercules_opt::pass::PassManager::new(module); + pm.add_pass(hercules_opt::pass::Pass::Verify); + pm.add_pass(hercules_opt::pass::Pass::PhiElim); + pm.add_pass(hercules_opt::pass::Pass::Verify); + pm.add_pass(hercules_opt::pass::Pass::CCP); + pm.add_pass(hercules_opt::pass::Pass::DCE); + pm.add_pass(hercules_opt::pass::Pass::GVN); + pm.add_pass(hercules_opt::pass::Pass::DCE); + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + pm.add_pass(hercules_opt::pass::Pass::Forkify); + pm.add_pass(hercules_opt::pass::Pass::DCE); + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + let _ = pm.run_passes(); + }, + Err(errs) => { + for err in errs{ + eprintln!("{}", err); + } + }, + } +} diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs new file mode 100644 index 0000000000000000000000000000000000000000..324ef5b9f9519a61ae73b2e782c3b82f813c15eb --- /dev/null +++ b/juno_frontend/src/semant.rs @@ -0,0 +1,4010 @@ +extern crate hercules_ir; + +use std::collections::{HashMap, LinkedList}; +use std::fs::File; +use std::io::Read; +use std::fmt; + +use lrlex::{lrlex_mod, DefaultLexerTypes}; 
/// Interning table that maps strings to unique identifiers (UIDs) and back.
///
/// UIDs are handed out densely starting at 0, in order of first insertion;
/// `count` is the next UID to be assigned.
struct StringTable {
    count : usize,
    string_to_index : HashMap<String, usize>,
    index_to_string : HashMap<usize, String>,
}

impl StringTable {
    /// Creates an empty table.
    fn new() -> StringTable {
        StringTable {
            count : 0,
            string_to_index : HashMap::new(),
            index_to_string : HashMap::new(),
        }
    }

    /// Returns the UID for `s`, assigning a fresh one the first time the
    /// string is seen.
    fn lookup_string(&mut self, s : String) -> usize {
        // The entry API hashes the key once, whether or not it is present.
        match self.string_to_index.entry(s) {
            std::collections::hash_map::Entry::Occupied(slot) => *slot.get(),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let id = self.count;
                self.count += 1;
                // Only the reverse map needs its own copy of the string.
                self.index_to_string.insert(id, slot.key().clone());
                slot.insert(id);
                id
            },
        }
    }

    /// Returns a copy of the string registered under UID `n`, or `None` if
    /// no string has been assigned that UID.
    fn lookup_id(&self, n : usize) -> Option<String> {
        self.index_to_string.get(&n).cloned()
    }
}
UndefinedVariable(Location, String), + // Kind error (at location, expected, actual) + KindError(Location, String, String), + // Type error (at location, expected type, actual type) + TypeError(Location, String, String), +} + +// Printing for error messages +impl fmt::Display for ErrorMessage { + fn fmt(&self, f : &mut fmt::Formatter) -> fmt::Result { + match self { + ErrorMessage::NotImplemented(loc, msg) => { + write!(f, "Error ({}). Feature not implemented : {}", loc, msg) + }, + ErrorMessage::IOError(msg) => { + write!(f, "Error: {}", msg) + }, + ErrorMessage::SyntaxError(msg) => { + write!(f, "Syntax Error : {}", msg) + }, + ErrorMessage::SemanticError(loc, msg) => { + write!(f, "Error ({}). {}", loc, msg) + }, + ErrorMessage::UndefinedVariable(loc, name) => { + write!(f, "Error ({}). Undefined variable '{}'", loc, name) + }, + ErrorMessage::KindError(loc, expected, actual) => { + write!(f, "Error ({}). Expected {} but found {}", loc, expected, actual) + }, + ErrorMessage::TypeError(loc, expected, actual) => { + write!(f, "Error ({}).
Type error, expected {} but found {}", + loc, expected, actual) + }, + } + } +} + +pub type ErrorMessages = LinkedList<ErrorMessage>; + +// Constructors and combiners for error messages +fn singleton_error(err : ErrorMessage) -> ErrorMessages { + LinkedList::from([err]) +} + +fn append_errors2<A, B>(x : Result<A, ErrorMessages>, y : Result<B, ErrorMessages>) + -> Result<(A, B), ErrorMessages> { + match (x, y) { + (Err(mut err_x), Err(mut err_y)) => { + err_x.append(&mut err_y); + Err(err_x) + }, + (Err(err_x), _) => Err(err_x), + (_, Err(err_y)) => Err(err_y), + (Ok(x), Ok(y)) => Ok((x, y)), + } +} + +fn append_errors3<A, B, C>(x : Result<A, ErrorMessages>, y : Result<B, ErrorMessages>, + z : Result<C, ErrorMessages>) -> Result<(A, B, C), ErrorMessages> { + let xy = append_errors2(x, y); + let xyz = append_errors2(xy, z); + + match xyz { + Err(errs) => Err(errs), + Ok(((x, y), z)) => Ok((x, y, z)), + } +} + +// Normalized AST forms after semantic analysis +// These include type information at all expression nodes, and remove names and locations +pub type Prg = (TypeSolver, Vec<Function>); + +// The function stores information for code-generation. The type information therefore is not the +// type information that is needed for type checking code that uses this function. +// In particular, the return type accounts for the type of inout arguments +pub struct Function { + pub name : String, + pub num_dyn_consts : usize, + pub num_type_args : usize, + pub arguments : Vec<(usize, Type)>, + pub return_type : Type, + pub body : Stmt, +} + +// Normalized statements differ in a number of ways from the form from the parser: +// 1. Let/Const are replaced by assignments (since we've already typed checked) +// 2. 
Assignment operators of the form lhs X= rhs are replaced by assignments of the form lhs = lhs +// X rhs, and assignments now always occur to variables, so modifications of fields/values in an +// array are replaced by appropriate expressions which write to those portions of the value +// 3. Match statements are eliminated except for those on union types, and only one match per level +// 4. Both for and while loops are combined into a single loop form +// 5. Call statements are transformed to call expressions, with a new statement form for +// expressions +// Additional notes +// - Returns in this AST include the inout values +pub enum Stmt { + AssignStmt { var : usize, val : Expr }, + IfStmt { cond : Expr, thn : Box<Stmt>, els : Option<Box<Stmt>> }, + // TODO: Not implemented + //MatchStmt { expr : Expr, cases : Vec<usize>, body : Vec<Stmt> }, + LoopStmt { cond : Expr, update : Option<Box<Stmt>>, body : Box<Stmt> }, + ReturnStmt { expr : Expr }, + BreakStmt {}, + ContinueStmt {}, + BlockStmt { body : Vec<Stmt> }, + ExprStmt { expr : Expr }, +} + +// Normalized expressions differ in a number of ways: +// 1. All expressions store their type +// 2. Field, index, and array access expressions are all replaced by a read node (like the IR) and +// we add a write expression that is used to simplify the assignment operation +// 3. Structs are eliminated and replaced by tuples +// 4. Unions are now tagged by a number rather than name and are also now separated from function +// calls +// 5. The unary and binary operations no longer contain boolean operations; instead those are +// expressed using the conditional expression +// 6. Functions are now identified by number, and arguments to functions are now represented as +// either an expression or a variable number for the inout arguments +// TODO: Technically inout arguments could be any l-expression +// 7.
There's an additional Zero which is used to construct the default of a type +#[derive(Clone, Debug)] +pub enum Expr { + Variable { var : usize, typ : Type }, + DynConst { idx : usize, typ : Type }, + Read { index : Vec<Index>, val : Box<Expr>, typ : Type }, + Write { index : Vec<Index>, val : Box<Expr>, rep : Box<Expr>, typ : Type }, + Tuple { vals : Vec<Expr>, typ : Type }, + Union { tag : usize, val : Box<Expr>, typ : Type }, + Constant { val : Constant, typ : Type }, + Zero { typ : Type }, + UnaryExp { op : UnaryOp, expr : Box<Expr>, typ : Type }, + BinaryExp { op : BinaryOp, lhs : Box<Expr>, rhs : Box<Expr>, typ : Type }, + CastExpr { expr : Box<Expr>, typ : Type }, + CondExpr { cond : Box<Expr>, thn : Box<Expr>, els : Box<Expr>, typ : Type }, + CallExpr { func : usize, ty_args : Vec<Type>, dyn_consts : Vec<DynamicConstant>, + args : Vec<Either<Expr, usize>>, typ : Type }, +} + +#[derive(Clone, Debug)] +pub enum Index { Field(usize), Variant(usize), Array(Vec<Expr>) } + +#[derive(Clone, Debug)] +pub enum UnaryOp { Negation, BitwiseNot } +#[derive(Clone, Debug)] +pub enum BinaryOp { Add, Sub, Mul, Div, Mod, + BitAnd, BitOr, Xor, + Lt, Le, Gt, Ge, Eq, Neq, + LShift, RShift } + +fn convert_assign_op(op : lang_y::AssignOp) -> BinaryOp { + match op { + AssignOp::None => panic!("Do not call convert_assign_op on AssignOp::None"), + AssignOp::Add => BinaryOp::Add, + AssignOp::Sub => BinaryOp::Sub, + AssignOp::Mul => BinaryOp::Mul, + AssignOp::Div => BinaryOp::Div, + AssignOp::Mod => BinaryOp::Mod, + AssignOp::BitAnd => BinaryOp::BitAnd, + AssignOp::BitOr => BinaryOp::BitOr, + AssignOp::Xor => BinaryOp::Xor, + AssignOp::LShift => BinaryOp::LShift, + AssignOp::RShift => BinaryOp::RShift, + AssignOp::LogAnd => panic!("Do not call convert_assign_op on AssignOp::LogAnd"), + AssignOp::LogOr => panic!("Do not call convert_assign_op on AssignOp::LogOr"), + } +} + +fn convert_binary_op(op : lang_y::BinaryOp) -> BinaryOp { + match op { + lang_y::BinaryOp::Add => BinaryOp::Add, + 
lang_y::BinaryOp::Sub => BinaryOp::Sub, + lang_y::BinaryOp::Mul => BinaryOp::Mul, + lang_y::BinaryOp::Div => BinaryOp::Div, + lang_y::BinaryOp::Mod => BinaryOp::Mod, + lang_y::BinaryOp::BitAnd => BinaryOp::BitAnd, + lang_y::BinaryOp::BitOr => BinaryOp::BitOr, + lang_y::BinaryOp::Xor => BinaryOp::Xor, + lang_y::BinaryOp::Lt => BinaryOp::Lt, + lang_y::BinaryOp::Le => BinaryOp::Le, + lang_y::BinaryOp::Gt => BinaryOp::Gt, + lang_y::BinaryOp::Ge => BinaryOp::Ge, + lang_y::BinaryOp::Eq => BinaryOp::Eq, + lang_y::BinaryOp::Neq => BinaryOp::Neq, + lang_y::BinaryOp::LShift => BinaryOp::LShift, + lang_y::BinaryOp::RShift => BinaryOp::RShift, + lang_y::BinaryOp::LogAnd => panic!("Do not call convert_binary_op on BinaryOp::LogAnd"), + lang_y::BinaryOp::LogOr => panic!("Do not call convert_binary_op on BinaryOp::LogOr"), + } +} + +// Be able to access the type of an expression easily +impl Expr { + pub fn get_type(&self) -> Type { + match self { + Expr::Variable { var : _, typ } | Expr::DynConst { idx : _, typ } + | Expr::Read { index : _, val : _, typ } + | Expr::Write { index : _, val : _, rep : _, typ } + | Expr::Tuple { vals : _, typ } | Expr::Union { tag : _, val : _, typ } + | Expr::Constant { val : _, typ } | Expr::UnaryExp { op : _, expr : _, typ } + | Expr::BinaryExp { op : _, lhs : _, rhs : _, typ } + | Expr::CastExpr { expr : _, typ } + | Expr::CondExpr { cond : _, thn : _, els : _, typ } + | Expr::CallExpr { func : _, ty_args : _, dyn_consts : _, args : _, typ } + | Expr::Zero { typ } + => *typ + } + } +} + +// Helper function to unparse types +fn unparse_type(types : &TypeSolver, typ : Type, stringtab : &StringTable) -> String { + types.to_string(typ, &|n| stringtab.lookup_id(n).unwrap()) +} + +// Start of parsing and semantic analysis + +// Loads the contents of the given file name, parses, and performs semantic analysis +pub fn parse_and_analyze(src_file : String) -> Result<Prg, ErrorMessages> { + if let Ok(mut file) = File::open(src_file) { + let mut contents = 
String::new(); + if let Ok(_) = file.read_to_string(&mut contents) { + let lexerdef = lang_l::lexerdef(); + let lexer = lexerdef.lexer(&contents); + let (res, errs) = lang_y::parse(&lexer); + + if errs.is_empty() { + match res { + None => Err(singleton_error( + ErrorMessage::SyntaxError("Parser did not return".to_string()))), + Some(Err(())) => Err(singleton_error( + ErrorMessage::SyntaxError("Unspecified parse error".to_string()))), + Some(Ok(r)) => analyze_program(r, &lexer), + } + } else { + Err(errs.iter() + .map(|e| ErrorMessage::SyntaxError( + e.pp(&lexer, &lang_y::token_epp))) + .collect()) + } + } else { + Err(singleton_error(ErrorMessage::IOError("Unable to read input file".to_string()))) + } + } else { + Err(singleton_error(ErrorMessage::IOError("Unable to open input file".to_string()))) + } +} + +fn analyze_program( + prg : lang_y::Prg, lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>) + -> Result<Prg, ErrorMessages> { + + let mut stringtab = StringTable::new(); + let mut env : Env<usize, Entity> = Env::new(); + let mut types = TypeSolver::new(); + + let mut res = vec![]; + + env.open_scope(); + + for top in prg { + match top { + lang_y::Top::Import { span, name: _ } => { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "imports".to_string())))? 
+ }, + lang_y::Top::TypeDecl { span : _, public: _, name, ty_vars, body } => { + // TODO: Handle public + env.open_scope(); // Create a new scope for the body (for type variables) + + // Add the type variables to the environment + let mut num_type = 0; + let mut num_dyn_const = 0; + // Track the kinds of the variables + let mut kinds = vec![]; + + for TypeVar { span : _, name, kind } in ty_vars { + let nm = intern_id(&name, lexer, &mut stringtab); + kinds.push(kind); + + match kind { + Kind::Type => { + let typ = types.new_type_var(nm, num_type, false, false); + env.insert(nm, Entity::Type { type_args : vec![], value : typ }); + num_type += 1; + }, + Kind::USize => { + env.insert(nm, Entity::DynConst { value : num_dyn_const }); + num_dyn_const += 1; + }, + Kind::Number => { + let typ = types.new_type_var(nm, num_type, true, false); + env.insert(nm, Entity::Type { type_args : vec![], value : typ }); + num_type += 1; + }, + Kind::Integer => { + let typ = types.new_type_var(nm, num_type, true, true); + env.insert(nm, Entity::Type { type_args : vec![], value : typ }); + num_type += 1; + }, + } + } + + let nm = intern_id(&name, lexer, &mut stringtab); + let typ = process_type_def(body, nm, lexer, &mut stringtab, + &mut env, &mut types)?; + + // Insert information into the global scope + env.close_scope(); + env.insert(nm, Entity::Type { type_args : kinds, value : typ }); + }, + lang_y::Top::ConstDecl { span : _, public: _, name, ty: _, body } => { + // TODO: Handle public + let nm = intern_id(&name, lexer, &mut stringtab); + let val = process_expr_as_constant(body, lexer, &mut stringtab, + &mut env, &mut types)?; + env.insert(nm, Entity::Constant { value : val }); + }, + lang_y::Top::FuncDecl { span, public: _, attr: _, name, ty_vars, args, ty, body } => { + // TODO: Handle public, attributes + env.open_scope(); // Open a new scope immediately to put type variables in + + // Process the type variables and add them into the environment + let mut num_dyn_const = 0; + let 
mut num_type_var = 0; + let mut type_kinds = vec![]; + for TypeVar { span : _, name, kind } in ty_vars { + type_kinds.push(kind); + let nm = intern_id(&name, lexer, &mut stringtab); + match kind { + Kind::USize => { + let num = num_dyn_const; + num_dyn_const += 1; + + env.insert(nm, Entity::DynConst { value : num }); + }, + Kind::Type => { + let typ = types.new_type_var(nm, num_type_var, false, false); + env.insert(nm, Entity::Type { type_args : vec![], value : typ }); + num_type_var += 1; + }, + Kind::Number => { + let typ = types.new_type_var(nm, num_type_var, true, false); + env.insert(nm, Entity::Type { type_args : vec![], value : typ }); + num_type_var += 1; + }, + Kind::Integer => { + let typ = types.new_type_var(nm, num_type_var, true, true); + env.insert(nm, Entity::Type { type_args : vec![], value : typ }); + num_type_var += 1; + }, + } + } + + // Process arguments + let mut arg_types : Vec<(usize, Type, bool)> = vec![]; // list of name, type, and + // whether is inout + let mut inout_args = vec![]; // list of indices into args + + // A collection of errors we encounter processing the arguments + let mut errors = LinkedList::new(); + + for (inout, VarBind { span, pattern, typ }) in args { + if !typ.is_some() { + errors.push_back( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "argument type inference".to_string())); + continue; + } + + match pattern { + Pattern::Variable { span, name } => { + if name.len() != 1 { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Bound variables must be local names, without a package separator".to_string())); + continue; + } + + let nm = intern_package_name(&name, lexer, &mut stringtab)[0]; + match process_type(typ.expect("FROM ABOVE"), lexer, + &mut stringtab, &env, &mut types) { + Ok(ty) => { + if inout.is_some() { + inout_args.push(arg_types.len()); + } + arg_types.push((nm, ty, inout.is_some())); + }, + Err(mut errs) => { errors.append(&mut errs); }, + } + }, + _ => { + 
errors.push_back( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "patterns in arguments".to_string())); + }, + } + } + + let return_type = + match ty { + None => { + errors.push_back( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "function return type inference".to_string())); + types.new_primitive(types::Primitive::Unit) + }, + Some(ty) => { + match process_type(ty, lexer, &mut stringtab, &env, + &mut types) { + Ok(ty) => ty, + Err(mut errs) => { + errors.append(&mut errs); + types.new_primitive(types::Primitive::Unit) + }, + } + }, + }; + + if !errors.is_empty() { + Err(errors)? + } + + // Compute the proper type accounting for the inouts (which become returns) + let mut inout_types = vec![]; + for arg_idx in &inout_args { + inout_types.push(arg_types[*arg_idx].1.clone()); + } + + let inout_tuple = types.new_tuple(inout_types.clone()); + let pure_return_type + = types.new_tuple(vec![return_type, inout_tuple]); + + // Add the arguments to the environment and assign each a unique variable number + // Also track the variable numbers of the inout arguments for generating returns + let mut arg_variables = vec![]; + let mut inout_variables = vec![]; + for (nm, ty, is_inout) in arg_types.iter() { + let variable = env.uniq(); + env.insert(*nm, + Entity::Variable { + variable : variable, + typ : *ty, + is_const : false }); + arg_variables.push(variable); + + if *is_inout { inout_variables.push(variable); } + } + + // Finally, we have a properly built environment and we can + // start processing the body + let (mut body, end_reachable) + = process_stmt(body, lexer, &mut stringtab, &mut env, &mut types, + false, return_type, &inout_variables, &inout_types)?; + + if end_reachable { + // The end of a function being reachable (i.e. 
there is some possible path + // where there is no return statement) is an error unless the return type is + // void + if types.is_void(return_type) { + // Insert return at the end + body = Stmt::BlockStmt { + body : vec![ + body, + generate_return(Expr::Tuple { + vals : vec![], + typ : types.new_primitive(types::Primitive::Unit) + }, + &inout_variables, &inout_types, + &mut types)] }; + } else { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "May reach end of control without return".to_string())))? + } + } + + env.close_scope(); + + // Add the function to the global environment + let nm = intern_id(&name, lexer, &mut stringtab); + env.insert(nm, Entity::Function { + index : res.len(), + type_args : type_kinds, + args : arg_types.iter() + .map(|(_, ty, is)| (*ty, *is)) + .collect::<Vec<_>>(), + return_type : return_type }); + + // Add the function definition to the list of functions + res.push(Function { + name : lexer.span_str(name).to_string(), + num_dyn_consts : num_dyn_const, + num_type_args : num_type_var, + arguments : arg_types.iter().zip(arg_variables.iter()) + .map(|(v, n)| (*n, v.1)).collect::<Vec<_>>(), + return_type : pure_return_type, + body : body }); + }, + lang_y::Top::ModDecl { span, public: _, name: _, body: _ } => { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "modules".to_string())))? 
+ }, + } + } + + Ok((types, res)) +} + +fn process_type_def(def : lang_y::TyDef, name : usize, + lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, env : &mut Env<usize, Entity>, + types : &mut TypeSolver) -> Result<Type, ErrorMessages> { + + match def { + lang_y::TyDef::TypeAlias { span: _, body } => { + process_type(body, lexer, stringtab, env, types) + }, + lang_y::TyDef::Struct { span: _, public: _, fields } => { + // TODO: handle public correctly (and field public) + + let mut field_list = vec![]; + let mut field_map = HashMap::new(); + let mut errors = LinkedList::new(); + + for ObjField { span, public: _, name, typ } in fields { + let nm = intern_id(&name, lexer, stringtab); + match typ { + None => { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "struct fields must have a type".to_string())); + }, + Some(ty) => { + match process_type(ty, lexer, stringtab, env, types) { + Ok(typ) => { + let idx = field_list.len(); + field_list.push(typ); + field_map.insert(nm, idx); + }, + Err(mut errs) => errors.append(&mut errs), + } + }, + } + } + + if !errors.is_empty() { + Err(errors) + } else { + Ok(types.new_struct(name, env.uniq(), field_list, field_map)) + } + }, + lang_y::TyDef::Union { span : _, public: _, fields } => { + // TODO: handle public correctly + let mut constr_list = vec![]; + let mut constr_map = HashMap::new(); + let mut errors = LinkedList::new(); + + for ObjField { span, public, name, typ } in fields { + if public { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "union constructors cannot be marked public, all constructors share the visibility of the union".to_string())); + } else { + let nm = intern_id(&name, lexer, stringtab); + match typ { + None => { + let idx = constr_list.len(); + constr_list.push(types.new_primitive(types::Primitive::Unit)); + constr_map.insert(nm, idx); + }, + Some(ty) => { + match process_type(ty, lexer, stringtab, env, 
types) { + Ok(typ) => { + let idx = constr_list.len(); + constr_list.push(typ); + constr_map.insert(nm, idx); + }, + Err(mut errs) => errors.append(&mut errs), + } + }, + } + } + } + + if !errors.is_empty() { Err(errors) } + else { + Ok(types.new_union(name, env.uniq(), constr_list, constr_map)) + } + }, + } +} + +fn process_type(typ : lang_y::Type, lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, env : &Env<usize, Entity>, + types : &mut TypeSolver) -> Result<Type, ErrorMessages> { + + match typ { + lang_y::Type::PrimType { span: _, typ } => { + Ok(types.new_primitive(convert_primitive(typ))) + }, + lang_y::Type::TupleType { span: _, tys } => { + let mut fields = vec![]; + let mut errors = LinkedList::new(); + + for ty in tys { + match process_type(ty, lexer, stringtab, env, types) { + Ok(t) => fields.push(t), + Err(mut errs) => errors.append(&mut errs), + } + } + + if !errors.is_empty() { + Err(errors) + } else { + if fields.len() == 1 { + Ok(fields.pop().expect("Length")) + } else { + Ok(types.new_tuple(fields)) + } + } + }, + lang_y::Type::NamedType { span, name, args } => { + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string()))) + } else { + let id = intern_package_name(&name, lexer, stringtab); + let nm = id[0]; + match env.lookup(&nm) { + Some(Entity::Type { type_args, value }) => { + if args.len() != type_args.len() { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} type arguments, provided {}", + type_args.len(), args.len()))))? 
+ } + + // Process the type arguments, ensuring they match the given kinds + let mut type_vars = vec![]; + let mut dynamic_constants = vec![]; + let mut errors = LinkedList::new(); + + for (arg, kind) in args.into_iter().zip(type_args.iter()) { + let arg_span = arg.span(); + match kind { + lang_y::Kind::USize => { + match process_type_expr_as_expr( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => dynamic_constants.push(val), + } + }, + lang_y::Kind::Type => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => type_vars.push(typ), + } + }, + lang_y::Kind::Number => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_number(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "number".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + lang_y::Kind::Integer => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_integer(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "integer".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + } + } + + if !errors.is_empty() { Err(errors)? 
} + + if type_vars.len() == 0 && dynamic_constants.len() == 0 { + Ok(*value) + } else { + Ok(types.instantiate(*value, &type_vars, &dynamic_constants)) + } + }, + Some(_) => + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "type".to_string(), + "value".to_string()))), + None => + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + stringtab.lookup_id(nm).unwrap()))), + } + } + }, + lang_y::Type::ArrayType { span: _, elem, dims } => { + let mut dimensions = vec![]; + let mut errors = LinkedList::new(); + + let element = process_type(*elem, lexer, stringtab, env, types); + + for dim in dims { + match process_type_expr_as_expr(dim, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(ex) => dimensions.push(ex), + } + } + + match element { + Err(mut errs) => { + errs.append(&mut errors); + Err(errs) + }, + Ok(element_type) => { + if !errors.is_empty() { + Err(errors) + } else { + if types.is_array(element_type) { + let elem_type = types.get_element_type(element_type).unwrap(); + let mut inner_dims = types.get_dimensions(element_type).unwrap(); + + dimensions.append(&mut inner_dims); + Ok(types.new_array(elem_type, dimensions)) + } else { + Ok(types.new_array(element_type, dimensions)) + } + } + } + } + }, + } +} + +fn process_type_expr_as_expr(exp : lang_y::TypeExpr, + lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, + env : &Env<usize, Entity>, types : &mut TypeSolver) + -> Result<DynamicConstant, ErrorMessages> { + + match exp { + lang_y::TypeExpr::PrimType { span, .. } + | lang_y::TypeExpr::TupleType { span, .. } + | lang_y::TypeExpr::ArrayTypeExpr { span, .. 
} => + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "dynamic constant expression".to_string(), + "type".to_string()))), + + lang_y::TypeExpr::NamedTypeExpr { span, name, args } => { + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string()))) + } else { + let id = intern_package_name(&name, lexer, stringtab); + let nm = id[0]; + match env.lookup(&nm) { + Some(Entity::DynConst { value }) => { + if args.len() > 0 { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("No type arguments exists on dynamic constants")))) + } else { + Ok(DynamicConstant::DynConst(nm, *value)) + } + }, + Some(Entity::Constant { value : (val, typ) }) => { + match val { + Literal::Integer(val) => + Ok(DynamicConstant::Constant(*val as usize)), + _ => + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(span, lexer), + "usize".to_string(), + unparse_type(types, *typ, stringtab)))), + } + }, + Some(Entity::Variable { .. }) => + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "dynamic constant expression".to_string(), + "runtime variable".to_string()))), + Some(Entity::Type { .. }) => + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "dynamic constant expression".to_string(), + "type".to_string()))), + Some(Entity::Function { .. 
}) => + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "dynamic constant expression".to_string(), + "function".to_string()))), + None => + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + stringtab.lookup_id(nm).unwrap()))), + } + } + }, + lang_y::TypeExpr::IntLiteral { span, base } => { + let res = usize::from_str_radix(lexer.span_str(span), base.base()); + assert!(res.is_ok(), "Internal Error: Int literal is not an integer"); + Ok(DynamicConstant::Constant(res.unwrap())) + }, + + lang_y::TypeExpr::Negative { span, .. } + | lang_y::TypeExpr::Add { span, .. } + | lang_y::TypeExpr::Sub { span, .. } + | lang_y::TypeExpr::Mul { span, .. } + | lang_y::TypeExpr::Div { span, .. } => + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "expressions of dynamic constants".to_string()))), + } +} + +fn process_type_expr_as_type(exp : lang_y::TypeExpr, + lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, + env : &Env<usize, Entity>, types : &mut TypeSolver) + -> Result<Type, ErrorMessages> { + + match exp { + lang_y::TypeExpr::IntLiteral { span, .. } + | lang_y::TypeExpr::Negative { span, .. } + | lang_y::TypeExpr::Add { span, .. } + | lang_y::TypeExpr::Sub { span, .. } + | lang_y::TypeExpr::Mul { span, .. } + | lang_y::TypeExpr::Div { span, .. 
} => { + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "type".to_string(), "expression".to_string()))) + }, + lang_y::TypeExpr::PrimType { span : _, typ } => { + Ok(types.new_primitive(convert_primitive(typ))) + }, + lang_y::TypeExpr::TupleType { span : _, tys } => { + let mut fields = vec![]; + let mut errors = LinkedList::new(); + + for ty in tys { + match process_type_expr_as_type(ty, lexer, stringtab, env, types) { + Ok(t) => fields.push(t), + Err(mut errs) => errors.append(&mut errs), + } + } + + if !errors.is_empty() { + Err(errors) + } else { + if fields.len() == 1 { + Ok(fields.pop().expect("Length")) + } else { + Ok(types.new_tuple(fields)) + } + } + }, + lang_y::TypeExpr::ArrayTypeExpr { span : _, elem, dims } => { + let mut dimensions = vec![]; + let mut errors = LinkedList::new(); + + let element = process_type_expr_as_type(*elem, lexer, stringtab, env, types); + + for dim in dims { + match process_type_expr_as_expr(dim, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(ex) => dimensions.push(ex), + } + } + + match element { + Err(mut errs) => { + errs.append(&mut errors); + Err(errs) + }, + Ok(element_type) => { + if !errors.is_empty() { + Err(errors) + } else { + if types.is_array(element_type) { + let elem_type = types.get_element_type(element_type).unwrap(); + let mut inner_dims = types.get_dimensions(element_type).unwrap(); + + dimensions.append(&mut inner_dims); + Ok(types.new_array(elem_type, dimensions)) + } else { + Ok(types.new_array(element_type, dimensions)) + } + } + } + } + }, + lang_y::TypeExpr::NamedTypeExpr { span, name, args } => { + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string()))) + } else { + let id = intern_package_name(&name, lexer, stringtab); + let nm = id[0]; + match env.lookup(&nm) { + Some(Entity::Type { type_args, value }) => { + if args.len() != type_args.len() { + 
Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} type arguments, provided {}", + type_args.len(), args.len()))))? + } + + // Process the type arguments, ensuring they match the given kinds + let mut type_vars = vec![]; + let mut dynamic_constants = vec![]; + let mut errors = LinkedList::new(); + + for (arg, kind) in args.into_iter().zip(type_args.iter()) { + let arg_span = arg.span(); + match kind { + lang_y::Kind::USize => { + match process_type_expr_as_expr( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => dynamic_constants.push(val), + } + }, + lang_y::Kind::Type => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => type_vars.push(typ), + } + }, + lang_y::Kind::Number => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_number(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "number".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + lang_y::Kind::Integer => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_integer(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "integer".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + } + } + + if !errors.is_empty() { Err(errors)? 
} + + if type_vars.len() == 0 && dynamic_constants.len() == 0 { + Ok(*value) + } else { + Ok(types.instantiate(*value, &type_vars, &dynamic_constants)) + } + }, + Some(_) => + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "type".to_string(), + "value".to_string()))), + None => + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + stringtab.lookup_id(nm).unwrap()))), + } + } + }, + } +} + +// Normalizes the given statement, and returns the normalized statement plus whether a statement +// after the analyzed one is reachable or not +fn process_stmt(stmt : lang_y::Stmt, lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, env : &mut Env<usize, Entity>, + types : &mut TypeSolver, in_loop : bool, return_type : Type, + inout_vars : &Vec<usize>, inout_types : &Vec<Type>) + -> Result<(Stmt, bool), ErrorMessages> { + + match stmt { + lang_y::Stmt::LetStmt { span: _, var : VarBind { span : v_span, pattern, typ }, init } => { + match pattern { + Pattern::Variable { span, name } => { + if typ.is_none() { + return Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(v_span, lexer), + "variable type inference".to_string()))); + } + + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Bound variables must be local names, without a package separator".to_string())))? 
+ } + + let nm = intern_package_name(&name, lexer, stringtab)[0]; + let ty = process_type(typ.expect("FROM ABOVE"), lexer, stringtab, env, types)?; + + let var = env.uniq(); + + let (val, exp_loc) = + match init { + Some(exp) => { + let loc = span_to_loc(exp.span(), lexer); + (process_expr(exp, lexer, stringtab, env, types)?, loc) + }, + None => { + (Expr::Zero { typ : ty }, Location::fake()) + }, + }; + let typ = val.get_type(); + + env.insert(nm, + Entity::Variable { variable : var, typ : ty, is_const : false }); + + if !types.equal(ty, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + exp_loc, + unparse_type(types, ty, stringtab), + unparse_type(types, typ, stringtab)))) + } else { + Ok((Stmt::AssignStmt { var : var, val : val }, true)) + } + }, + _ => { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(v_span, lexer), + "non-variable bindings".to_string()))) + }, + } + }, + lang_y::Stmt::ConstStmt { span: _, var : VarBind { span : v_span, pattern, typ }, init } => { + match pattern { + Pattern::Variable { span, name } => { + if typ.is_none() { + return Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(v_span, lexer), + "variable type inference".to_string()))); + } + + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Bound variables must be local names, without a package separator".to_string())))? 
+ } + + let nm = intern_package_name(&name, lexer, stringtab)[0]; + let ty = process_type(typ.expect("FROM ABOVE"), lexer, stringtab, env, types)?; + + let var = env.uniq(); + + let (val, exp_loc) = + match init { + Some(exp) => { + let loc = span_to_loc(exp.span(), lexer); + (process_expr(exp, lexer, stringtab, env, types)?, loc) + }, + None => { + (Expr::Zero { typ : ty }, Location::fake()) + }, + }; + let typ = val.get_type(); + + env.insert(nm, + Entity::Variable { variable : var, typ : ty, is_const : true }); + + if !types.equal(ty, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + exp_loc, + unparse_type(types, ty, stringtab), + unparse_type(types, typ, stringtab)))) + } else { + Ok((Stmt::AssignStmt { var : var, val : val }, true)) + } + }, + _ => { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(v_span, lexer), + "non-variable bindings".to_string()))) + }, + } + }, + lang_y::Stmt::AssignStmt { span: _, lhs, assign, assign_span, rhs } => { + let lhs_res = process_lexpr(lhs, lexer, stringtab, env, types); + let rhs_res = process_expr(rhs, lexer, stringtab, env, types); + let (((var, var_typ), (exp_typ, index)), val) + = append_errors2(lhs_res, rhs_res)?; + let typ = val.get_type(); + + // Perform the appropriate type checking + match assign { + AssignOp::None => { + if !types.equal(exp_typ, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + unparse_type(types, exp_typ, stringtab), + unparse_type(types, typ, stringtab))))? + } + }, + AssignOp::Add | AssignOp::Sub | AssignOp::Mul | AssignOp::Div => { + if !types.equal(exp_typ, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + unparse_type(types, exp_typ, stringtab), + unparse_type(types, typ, stringtab))))? + } + if !types.is_number(exp_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + "number".to_string(), + unparse_type(types, exp_typ, stringtab))))? 
+ } + }, + AssignOp::Mod | AssignOp::BitAnd | AssignOp::BitOr | AssignOp::Xor + | AssignOp::LShift | AssignOp::RShift => { + if !types.equal(exp_typ, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + unparse_type(types, exp_typ, stringtab), + unparse_type(types, typ, stringtab))))? + } + if !types.is_integer(exp_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + "integer".to_string(), + unparse_type(types, exp_typ, stringtab))))? + } + }, + AssignOp::LogAnd | AssignOp::LogOr => { + if !types.equal(exp_typ, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + unparse_type(types, exp_typ, stringtab), + unparse_type(types, typ, stringtab))))? + } + if !types.is_bool(exp_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(assign_span, lexer), + "bool".to_string(), + unparse_type(types, exp_typ, stringtab))))? + } + }, + } + + let empty_index = index.is_empty(); + let rhs_var = Expr::Variable { var : var, typ : var_typ }; + + let rhs_val = + if empty_index { + rhs_var + } else { + Expr::Read { + index : index.clone(), + val : Box::new(rhs_var), + typ : exp_typ } + }; + + // Construct the right-hand side for the normalized expression; for x= operations this + // will construct the read and the operation; the write is left for after this since it + // is common to all cases + let result_rhs = + match assign { + AssignOp::None => { + val + }, + AssignOp::Add | AssignOp::Sub | AssignOp::Mul | AssignOp::Div + | AssignOp::Mod | AssignOp::BitAnd | AssignOp::BitOr + | AssignOp::Xor | AssignOp::LShift | AssignOp::RShift => { + Expr::BinaryExp { + op : convert_assign_op(assign), + lhs : Box::new(rhs_val), + rhs : Box::new(val), + typ : typ } + }, + // For x &&= y we convert to if x then y else false + AssignOp::LogAnd => { + Expr::CondExpr { + cond : Box::new(rhs_val), + thn : Box::new(val), + // We know that the expected type is bool, 
so just use it to avoid + // creating additional new types + els : Box::new(Expr::Constant { + val : (Literal::Bool(false), exp_typ), + typ : exp_typ }), + typ : typ } + }, + // For x ||= y we convert to if x then true else y + AssignOp::LogOr => { + Expr::CondExpr { + cond : Box::new(rhs_val), + thn : Box::new(Expr::Constant { + val : (Literal::Bool(true), exp_typ), + typ : exp_typ }), + els : Box::new(val), + typ : typ } + }, + }; + + let write_exp = + if empty_index { + result_rhs + } else { + Expr::Write { + index : index, + val : Box::new(Expr::Variable { var : var, typ : var_typ }), + rep : Box::new(result_rhs), + typ : var_typ } + }; + + Ok((Stmt::AssignStmt { + var : var, + val : write_exp }, true)) + }, + lang_y::Stmt::IfStmt { span: _, cond, thn, els } => { + let cond_span = cond.span(); + let cond_res = process_expr(cond, lexer, stringtab, env, types); + + env.open_scope(); + let thn_res = process_stmt(*thn, lexer, stringtab, env, types, + in_loop, return_type, inout_vars, inout_types); + env.close_scope(); + + env.open_scope(); + let els_res = + match els { None => Ok((None, true)), + Some(stmt) => + process_stmt(*stmt, lexer, stringtab, env, types, + in_loop, return_type, inout_vars, inout_types) + .map(|(s, b)| (Some(s), b)), }; + env.close_scope(); + + let (cond_exp, (thn_body, thn_fall), (els_body, els_fall)) + = append_errors3(cond_res, thn_res, els_res)?; + let cond_typ = cond_exp.get_type(); + + if !types.is_bool(cond_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(cond_span, lexer), + "bool".to_string(), + unparse_type(types, cond_typ, stringtab))))? 
+ } + + Ok((Stmt::IfStmt { + cond : cond_exp, + thn : Box::new(thn_body), + els : els_body.map(|s| Box::new(s)) }, + thn_fall || els_fall)) + }, + lang_y::Stmt::MatchStmt { span, expr: _, body: _ } => { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "match statements".to_string()))) + }, + lang_y::Stmt::ForStmt { span: _, var : VarBind { span : v_span, pattern, typ }, + init, bound, step, body } => { + let (var, var_name, var_type) = + match pattern { + Pattern::Variable { span, name } => { + if name.len() != 1 { + return Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Bound variables must be local names, without a package separator".to_string()))); + } + + let nm = intern_package_name(&name, lexer, stringtab)[0]; + let var_type = + match typ { + None => types.new_primitive(types::Primitive::U64), + Some(t) => { + let ty = process_type(t, lexer, stringtab, env, types)?; + if !types.is_integer(ty) { + return Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(v_span, lexer), + "For loop variables must be integers".to_string()))); + } + ty + }, + }; + + let var = env.uniq(); + (var, nm, var_type) + }, + _ => { + return Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(v_span, lexer), + "patterns in for loop arguments".to_string()))); + }, + }; + + // Evaluate the initial value, bound, and step + let init_span = init.span(); + let bound_span = bound.span(); + + let init_res = process_expr(init, lexer, stringtab, env, types); + let bound_res = process_expr(bound, lexer, stringtab, env, types); + + // The step is tracked as a pair of the step's amount (always positive) and whether the + // step should be positive or negative + let (step_val, step_pos) = + match step { + None => { + (1, true) + }, + Some((negative, span, base)) => { + let val = u64::from_str_radix(lexer.span_str(span), base.base()); + assert!(val.is_ok(), "Internal Error: Int literal is not an 
integer"); + let num = val.unwrap(); + if num == 0 { + return Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "For loop step cannot be 0".to_string()))); + } + + (num, !negative) + }, + }; + + let (init_val, bound_val) + = append_errors2(init_res, bound_res)?; + let init_typ = init_val.get_type(); + let bound_typ = bound_val.get_type(); + + // Verify that the types of the initial value and bound are correct + let mut type_errors = LinkedList::new(); + if !types.equal(var_type, init_typ) { + type_errors.push_back( + ErrorMessage::TypeError( + span_to_loc(init_span, lexer), + unparse_type(types, var_type, stringtab), + unparse_type(types, init_typ, stringtab))); + } + if !types.equal(var_type, bound_typ) { + type_errors.push_back( + ErrorMessage::TypeError( + span_to_loc(bound_span, lexer), + unparse_type(types, var_type, stringtab), + unparse_type(types, bound_typ, stringtab))); + } + if !type_errors.is_empty() { + Err(type_errors)? + } + + // Create the scope for the body + env.open_scope(); + env.insert(var_name, Entity::Variable { + variable : var, typ : var_type, is_const : true }); + + // Process the body + let (body, _) + = process_stmt(*body, lexer, stringtab, env, types, true, + return_type, inout_vars, inout_types)?; + + env.close_scope(); + + // We bind the initial value of the loop counter + let init_eval = Stmt::AssignStmt { var : var, val : init_val }; + + // We create a new variable for the loop bound and we're going to bind the bound to + // that value before the loop so that it is only evaluated once + let bound_var = env.uniq(); + let bound_eval = Stmt::AssignStmt { var : bound_var, val : bound_val }; + + // The condition of the loop is var < bound, unless the step is negative in which case + // it is var > bound + let condition = + Expr::BinaryExp { + op : if step_pos { BinaryOp::Lt } else { BinaryOp::Gt }, + lhs : Box::new(Expr::Variable { var : var, typ : var_type }), + rhs : Box::new(Expr::Variable { var : 
bound_var, typ : bound_typ }), + typ : types.new_primitive(types::Primitive::Bool) }; + + // The update of the loop is var = var + step, unless the step is negative in which + // case it is var = var - step + let update = + Stmt::AssignStmt { + var : var, + val : Expr::BinaryExp { + op : if step_pos { BinaryOp::Add } else { BinaryOp::Sub }, + lhs : Box::new(Expr::Variable { var : var, typ : var_type }), + rhs : Box::new(Expr::Constant { + val : (Literal::Integer(step_val), var_type), + typ : var_type }), + typ : var_type }}; + + // Finally, the entire loop is constructed as: + // Evaluate initial value + // Evaluate bound value + // Loop + // Note that the statement after a loop is always assumed to be reachable + Ok((Stmt::BlockStmt { + body : vec![ + init_eval, + bound_eval, + Stmt::LoopStmt { + cond : condition, + update : Some(Box::new(update)), + body : Box::new(body) + } + ] + }, true)) + }, + lang_y::Stmt::WhileStmt { span: _, cond, body } => { + let cond_span = cond.span(); + let cond_res = process_expr(cond, lexer, stringtab, env, types); + + env.open_scope(); + let body_res = process_stmt(*body, lexer, stringtab, env, types, + true, return_type, inout_vars, inout_types); + env.close_scope(); + + let (cond_val, (body_stmt, _)) + = append_errors2(cond_res, body_res)?; + let cond_typ = cond_val.get_type(); + + if !types.is_bool(cond_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(cond_span, lexer), + "bool".to_string(), + unparse_type(types, cond_typ, stringtab))))? 
+ } + + // Again, the statement after a loop is always considered reachable + Ok((Stmt::LoopStmt { + cond : cond_val, + update : None, + body : Box::new(body_stmt) }, true)) + }, + lang_y::Stmt::ReturnStmt { span, expr } => { + let return_val = + if expr.is_none() && types.is_void(return_type) { + Expr::Constant { + val : (Literal::Unit, return_type), + typ : return_type } + } else if expr.is_none() { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected return of type {} found no return value", + unparse_type(types, return_type, stringtab)))))? + } else { + let val = process_expr(expr.unwrap(), lexer, stringtab, env, + types)?; + let typ = val.get_type(); + if !types.equal(return_type, typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(span, lexer), + unparse_type(types, return_type, stringtab), + unparse_type(types, typ, stringtab))))? + } + val + }; + + // We return a tuple of the return value and of the inout variables + // Statements after a return are never reachable + Ok((generate_return(return_val, inout_vars, inout_types, types), + false)) + }, + lang_y::Stmt::BreakStmt { span } => { + if !in_loop { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Break not contained within loop".to_string())))? + } + + // Code after a break is unreachable + Ok((Stmt::BreakStmt {}, false)) + }, + lang_y::Stmt::ContinueStmt { span } => { + if !in_loop { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Continue not contained within loop".to_string())))? 
+ } + + // Code after a continue is unreachable + Ok((Stmt::ContinueStmt {}, false)) + }, + lang_y::Stmt::BlockStmt { span: _, body } => { + // Blocks create a new scope for variables declared in them + env.open_scope(); + + let mut reachable = true; + let mut errors = LinkedList::new(); + let mut res = vec![]; + + for stmt in body { + if !reachable { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(stmt.span(), lexer), + "Unreachable statement".to_string())))? + } + + match process_stmt(stmt, lexer, stringtab, env, types, in_loop, + return_type, inout_vars, inout_types) { + Err(mut errs) => { errors.append(&mut errs); }, + Ok((stmt, post_reachable)) => { + res.push(stmt); + reachable = post_reachable; + }, + } + } + + env.close_scope(); + + if !errors.is_empty() { + Err(errors) + } else { + Ok((Stmt::BlockStmt { body : res }, reachable)) + } + }, + lang_y::Stmt::CallStmt { span, name, ty_args, args } => { + // Call statements are lowered to call expressions which is made a statment using the + // ExprStmt constructor + // Code after a call is always reachable + Ok((Stmt::ExprStmt { + expr : process_expr( + lang_y::Expr::CallExpr { span, name, ty_args, args }, + lexer, stringtab, env, types)? 
}, + true)) + }, + } +} + +// Process an l-expression to produce the variable that's modified and its type along with the type +// of the piece being modified and a list of the index operations needed to access the accessed +// piece +// This should only be used for the left-hand side of an assignment since it will return an error +// if the variable that is accessed is marked as constant +fn process_lexpr(expr : lang_y::LExpr, + lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, env : &mut Env<usize, Entity>, + types : &mut TypeSolver) + -> Result<((usize, Type), (Type, Vec<Index>)), ErrorMessages> { + match expr { + lang_y::LExpr::VariableLExpr { span } => { + let nm = intern_id(&span, lexer, stringtab); + match env.lookup(&nm) { + Some(Entity::Variable { variable, typ, is_const }) => { + if *is_const { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Variable {} is const, cannot assign to it", + lexer.span_str(span))))) + } else { + Ok(((*variable, *typ), (*typ, vec![]))) + } + }, + Some(Entity::DynConst { .. }) => { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("{} is a dynamic constant, cannot assign to it", + lexer.span_str(span))))) + }, + Some(Entity::Constant { .. }) => { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("{} is a constant, cannot assign to it", + lexer.span_str(span))))) + }, + Some(Entity::Function { .. }) => { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("{} is a function, cannot assign to it", + lexer.span_str(span))))) + }, + Some(Entity::Type { .. 
}) => { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("{} is a type, cannot assign to it", + lexer.span_str(span))))) + }, + None => { + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + lexer.span_str(span).to_string()))) + }, + } + }, + lang_y::LExpr::FieldLExpr { span, lhs, rhs } => { + let ((var, var_typ), (idx_typ, mut idx)) + = process_lexpr(*lhs, lexer, stringtab, env, types)?; + let field_nm = intern_id(&rhs, lexer, stringtab); + + match types.get_field(idx_typ, field_nm) { + None => Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Type {} does not possess field {}", + unparse_type(types, idx_typ, stringtab), + stringtab.lookup_id(field_nm).unwrap())))), + Some((field_idx, field_type)) => { + idx.push(Index::Field(field_idx)); + Ok(((var, var_typ), (field_type, idx))) + }, + } + }, + lang_y::LExpr::NumFieldLExpr { span, lhs, rhs } => { + let ((var, var_typ), (idx_typ, mut idx)) + = process_lexpr(*lhs, lexer, stringtab, env, types)?; + + // Identify the field number; to do this we remove the first character of the string of + // the right-hand side since the ".###" is lexed as a single token + let num = lexer.span_str(rhs)[1..].parse::<usize>() + .expect("From lexical analysis"); + + match types.get_index(idx_typ, num) { + None => Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Type {} does not possess index {}", + unparse_type(types, idx_typ, stringtab), + num)))), + Some(field_type) => { + idx.push(Index::Field(num)); + Ok(((var, var_typ), (field_type, idx))) + }, + } + }, + lang_y::LExpr::IndexLExpr { span, lhs, index } => { + let ((var, var_typ), (idx_typ, mut idx)) + = process_lexpr(*lhs, lexer, stringtab, env, types)?; + + let mut indices = vec![]; + let mut errors = LinkedList::new(); + for idx in index { + let idx_span = idx.span(); + match process_expr(idx, lexer, stringtab, env, 
types) { + Err(mut errs) => errors.append(&mut errs), + Ok(exp) => { + let typ = exp.get_type(); + if !types.is_u64(typ) { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(idx_span, lexer), + "usize".to_string(), + unparse_type(types, typ, stringtab))); + } else { + indices.push(exp); + } + }, + } + } + + if !errors.is_empty() { + Err(errors)? + } + + if !types.is_array(idx_typ) { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Array index does not apply to type {}", + unparse_type(types, idx_typ, stringtab)))))? + } + + let num_dims = types.get_num_dimensions(idx_typ).unwrap(); + if indices.len() < num_dims { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + format!("fewer array indices than dimensions, array has {} dimensions but using {} indices", + num_dims, indices.len())))) + } else if indices.len() > num_dims { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Too many array indices, array has {} dimensions but using {} indices", + num_dims, indices.len())))) + } else { + idx.push(Index::Array(indices)); + Ok(((var, var_typ), + (types.get_element_type(idx_typ).unwrap(), idx))) + } + }, + } +} +

/// Folds a numeric binary operation over two constant literals.
///
/// Two integers are combined with `int_op`; if either operand is a float the integer operand
/// (if any) is converted to `f64` and `float_op` is applied. Panics on any other literal; the
/// caller type-checks the operands before folding.
fn fold_numeric_constants(lhs : Literal, rhs : Literal,
                          int_op : fn(u64, u64) -> u64,
                          float_op : fn(f64, f64) -> f64) -> Literal {
    match (lhs, rhs) {
        (Literal::Integer(i), Literal::Integer(j)) => Literal::Integer(int_op(i, j)),
        (Literal::Integer(i), Literal::Float(j)) => Literal::Float(float_op(i as f64, j)),
        (Literal::Float(i), Literal::Integer(j)) => Literal::Float(float_op(i, j as f64)),
        (Literal::Float(i), Literal::Float(j)) => Literal::Float(float_op(i, j)),
        _ => panic!("Incorrect literal constructor"),
    }
}

/// Folds an integer-only binary operation over two constant literals, panicking if either
/// literal is not an integer (the caller type-checks the operands before folding).
fn fold_integer_constants(lhs : Literal, rhs : Literal,
                          int_op : fn(u64, u64) -> u64) -> Literal {
    let (Literal::Integer(i), Literal::Integer(j)) = (lhs, rhs)
        else { panic!("Incorrect literal constructor") };
    Literal::Integer(int_op(i, j))
}

/// Compares two numeric constant literals, converting the integer operand to `f64` when the
/// operands are mixed. Returns `None` if either literal is not an integer or a float.
fn compare_numeric_constants(lhs : &Literal, rhs : &Literal,
                             int_cmp : fn(u64, u64) -> bool,
                             float_cmp : fn(f64, f64) -> bool) -> Option<bool> {
    match (lhs, rhs) {
        (Literal::Integer(i), Literal::Integer(j)) => Some(int_cmp(*i, *j)),
        (Literal::Integer(i), Literal::Float(j)) => Some(float_cmp(*i as f64, *j)),
        (Literal::Float(i), Literal::Integer(j)) => Some(float_cmp(*i, *j as f64)),
        (Literal::Float(i), Literal::Float(j)) => Some(float_cmp(*i, *j)),
        _ => None,
    }
}

/// Evaluates `expr` as a compile-time constant.
///
/// On success returns the folded literal paired with its type. On failure returns the list of
/// error messages collected while checking the expression. Variables, struct and tuple
/// construction, literals, unary/binary operators, numeric casts, conditionals and union
/// constructors are folded; array indexing and function calls are rejected with an error.
///
/// Integer arithmetic is performed on `u64` with wrapping semantics, and integer division or
/// remainder by a constant zero is reported as a semantic error.
///
/// # Panics
///
/// Panics if a looked-up name resolves to a variable or dynamic constant (constants are not
/// expected to be evaluated in such an environment), or if a literal's constructor does not
/// agree with the type that was checked for it.
fn process_expr_as_constant(expr : lang_y::Expr,
                            lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
                            stringtab : &mut StringTable, env : &mut Env<usize, Entity>,
                            types : &mut TypeSolver)
    -> Result<Constant, ErrorMessages> {

    match expr {
        lang_y::Expr::Variable { span, name } => {
            if name.len() != 1 {
                return Err(singleton_error(
                        ErrorMessage::NotImplemented(
                            span_to_loc(span, lexer),
                            "packages".to_string())));
            }
            let nm = intern_package_name(&name, lexer, stringtab)[0];

            match env.lookup(&nm) {
                Some(Entity::Variable { .. }) => {
                    panic!("Constant should not be evaluated in an environment with variables")
                },
                Some(Entity::DynConst { .. }) => {
                    panic!("Constant should not be evaluated in an environment with dynamic constants")
                },
                Some(Entity::Constant { value }) => Ok(value.clone()),
                Some(Entity::Function { .. }) =>
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(span, lexer),
                                format!("{} is a function, expected a value",
                                        stringtab.lookup_id(nm).unwrap())))),
                Some(Entity::Type { .. }) =>
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(span, lexer),
                                format!("{} is a type, expected a value",
                                        stringtab.lookup_id(nm).unwrap())))),
                None =>
                    Err(singleton_error(
                            ErrorMessage::UndefinedVariable(
                                span_to_loc(span, lexer),
                                stringtab.lookup_id(nm).unwrap()))),
            }
        },
        lang_y::Expr::Field { span, lhs, rhs } => {
            let field_name = intern_id(&rhs, lexer, stringtab);
            let (lit, typ) = process_expr_as_constant(*lhs, lexer, stringtab, env, types)?;

            match types.get_field(typ, field_name) {
                None => Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(span, lexer),
                                format!("Type {} does not possess field {}",
                                        unparse_type(types, typ, stringtab),
                                        stringtab.lookup_id(field_name).unwrap())))),
                Some((field_idx, _)) => {
                    // The tuple is owned, so move the selected field out instead of cloning it
                    let Literal::Tuple(mut fields) = lit
                        else { panic!("Wrong constant constructor") };
                    Ok(fields.swap_remove(field_idx))
                },
            }
        },
        lang_y::Expr::NumField { span, lhs, rhs } => {
            let (lit, typ) = process_expr_as_constant(*lhs, lexer, stringtab, env, types)?;

            // The ".###" is lexed as a single token, so drop the leading character to get the
            // field number
            let num = lexer.span_str(rhs)[1..].parse::<usize>()
                           .expect("From lexical analysis");

            match types.get_index(typ, num) {
                None => Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(span, lexer),
                                format!("Type {} does not possess index {}",
                                        unparse_type(types, typ, stringtab),
                                        num)))),
                Some(_) => {
                    let Literal::Tuple(mut fields) = lit
                        else { panic!("Wrong constant constructor") };
                    Ok(fields.swap_remove(num))
                },
            }
        },
        lang_y::Expr::ArrIndex { span, .. } => {
            Err(singleton_error(
                    ErrorMessage::SemanticError(
                        span_to_loc(span, lexer),
                        "Arrays are not allowed in constants".to_string())))
        },
        lang_y::Expr::Tuple { span : _, mut exprs } => {
            // A parenthesized single expression is just that expression, and the empty tuple
            // is the unit value
            if exprs.len() == 1 {
                return process_expr_as_constant(exprs.pop().unwrap(), lexer, stringtab, env, types);
            }
            if exprs.is_empty() {
                return Ok((Literal::Unit, types.new_primitive(types::Primitive::Unit)));
            }

            let mut vals = vec![];
            let mut typs = vec![];
            let mut errors = LinkedList::new();

            for exp in exprs {
                match process_expr_as_constant(exp, lexer, stringtab, env, types) {
                    Err(mut errs) => errors.append(&mut errs),
                    Ok((lit, typ)) => {
                        typs.push(typ);
                        vals.push((lit, typ));
                    },
                }
            }

            if !errors.is_empty() {
                Err(errors)
            } else {
                Ok((Literal::Tuple(vals), types.new_tuple(typs)))
            }
        },
        lang_y::Expr::Struct { span, name, ty_args, exprs } => {
            if name.len() != 1 {
                return Err(singleton_error(
                        ErrorMessage::NotImplemented(
                            span_to_loc(span, lexer),
                            "packages".to_string())));
            }

            let struct_nm = intern_package_name(&name, lexer, stringtab)[0];
            match env.lookup(&struct_nm) {
                Some(Entity::Variable { .. }) => {
                    Err(singleton_error(
                            ErrorMessage::KindError(
                                span_to_loc(span, lexer),
                                "struct name".to_string(),
                                "variable".to_string())))
                },
                Some(Entity::DynConst { .. }) => {
                    Err(singleton_error(
                            ErrorMessage::KindError(
                                span_to_loc(span, lexer),
                                "struct name".to_string(),
                                "dynamic constant".to_string())))
                },
                Some(Entity::Constant { .. }) => {
                    Err(singleton_error(
                            ErrorMessage::KindError(
                                span_to_loc(span, lexer),
                                "struct name".to_string(),
                                "constant".to_string())))
                },
                Some(Entity::Function { .. }) => {
                    Err(singleton_error(
                            ErrorMessage::KindError(
                                span_to_loc(span, lexer),
                                "struct name".to_string(),
                                "function".to_string())))
                },
                None => {
                    Err(singleton_error(
                            ErrorMessage::UndefinedVariable(
                                span_to_loc(span, lexer),
                                stringtab.lookup_id(struct_nm).unwrap())))
                },
                Some(Entity::Type { type_args : kinds, value : typ }) => {
                    if !types.is_struct(*typ) {
                        return Err(singleton_error(
                                ErrorMessage::KindError(
                                    span_to_loc(span, lexer),
                                    "struct name".to_string(),
                                    "non-struct type".to_string())));
                    }
                    if kinds.len() != ty_args.len() {
                        return Err(singleton_error(
                                ErrorMessage::SemanticError(
                                    span_to_loc(span, lexer),
                                    format!("Expected {} type arguments, provided {}",
                                            kinds.len(), ty_args.len()))));
                    }

                    // Verify that the type arguments we are provided are correct and collect
                    // the type variable and dynamic constant substitutions
                    let mut type_vars = vec![];
                    let mut dyn_consts = vec![];
                    let mut errors = LinkedList::new();

                    for (arg, kind) in ty_args.into_iter().zip(kinds.iter()) {
                        let arg_span = arg.span();
                        match kind {
                            lang_y::Kind::USize => {
                                match process_type_expr_as_expr(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(val) => dyn_consts.push(val),
                                }
                            },
                            lang_y::Kind::Type => {
                                match process_type_expr_as_type(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(typ) => type_vars.push(typ),
                                }
                            },
                            lang_y::Kind::Number => {
                                match process_type_expr_as_type(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(typ) => {
                                        if types.is_number(typ) {
                                            type_vars.push(typ);
                                        } else {
                                            errors.push_back(
                                                ErrorMessage::KindError(
                                                    span_to_loc(arg_span, lexer),
                                                    "number".to_string(),
                                                    unparse_type(types, typ, stringtab)));
                                        }
                                    },
                                }
                            },
                            lang_y::Kind::Integer => {
                                match process_type_expr_as_type(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(typ) => {
                                        if types.is_integer(typ) {
                                            type_vars.push(typ);
                                        } else {
                                            errors.push_back(
                                                ErrorMessage::KindError(
                                                    span_to_loc(arg_span, lexer),
                                                    "integer".to_string(),
                                                    unparse_type(types, typ, stringtab)));
                                        }
                                    },
                                }
                            },
                        }
                    }

                    if !errors.is_empty() { return Err(errors); }

                    let struct_type =
                        if type_vars.is_empty() && dyn_consts.is_empty() {
                            *typ
                        } else {
                            types.instantiate(*typ, &type_vars, &dyn_consts)
                        };

                    // Check each field and construct the appropriate tuple. Every field must
                    // be given explicitly: omitted fields are reported as not implemented.
                    let num_fields = types.get_num_struct_fields(struct_type).unwrap();

                    // Values for the fields, in order
                    let mut values : Vec<Option<Constant>> = vec![None; num_fields];

                    for (field_name, expr) in exprs {
                        let field_nm = intern_id(&field_name, lexer, stringtab);
                        let expr_span = expr.span();

                        match types.get_field(struct_type, field_nm) {
                            None => {
                                errors.push_back(
                                    ErrorMessage::SemanticError(
                                        span_to_loc(field_name, lexer),
                                        format!("Struct {} does not have field {}",
                                                unparse_type(types, struct_type, stringtab),
                                                stringtab.lookup_id(field_nm).unwrap())));
                            },
                            Some((idx, _)) if values[idx].is_some() => {
                                errors.push_back(
                                    ErrorMessage::SemanticError(
                                        span_to_loc(field_name, lexer),
                                        format!("Field {} defined multiple times",
                                                stringtab.lookup_id(field_nm).unwrap())));
                            },
                            Some((idx, field_typ)) => {
                                match process_expr_as_constant(expr, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok((lit, val_typ)) => {
                                        if !types.equal(field_typ, val_typ) {
                                            // Set the value at this index even though there's
                                            // an error so that we also report if the field is
                                            // defined multiple times
                                            values[idx]
                                                = Some((Literal::Unit,
                                                        types.new_primitive(types::Primitive::Unit)));
                                            errors.push_back(
                                                ErrorMessage::TypeError(
                                                    span_to_loc(expr_span, lexer),
                                                    unparse_type(types, field_typ, stringtab),
                                                    unparse_type(types, val_typ, stringtab)));
                                        } else {
                                            values[idx] = Some((lit, val_typ));
                                        }
                                    },
                                }
                            },
                        }
                    }

                    if !errors.is_empty() { return Err(errors); }

                    // Collecting into an Option yields None as soon as any field is missing
                    let Some(filled_fields) = values.into_iter().collect::<Option<Vec<_>>>()
                    else {
                        return Err(singleton_error(
                                ErrorMessage::NotImplemented(
                                    span_to_loc(span, lexer),
                                    "constant struct with missing fields".to_string())));
                    };

                    Ok((Literal::Tuple(filled_fields), struct_type))
                },
            }
        },
        lang_y::Expr::BoolLit { span : _, value } => {
            let bool_typ = types.new_primitive(types::Primitive::Bool);
            Ok((Literal::Bool(value), bool_typ))
        },
        lang_y::Expr::IntLit { span, base } => {
            // NOTE(review): this panics if the literal text is not accepted by
            // from_str_radix (e.g. a value that does not fit in a u64) - confirm the lexer
            // rules this out or report an error here instead
            let val = u64::from_str_radix(lexer.span_str(span), base.base())
                .expect("Internal Error: Int literal is not an integer");

            let num_typ = types.new_number();
            Ok((Literal::Integer(val), num_typ))
        },
        lang_y::Expr::FloatLit { span } => {
            let val = lexer.span_str(span).parse::<f64>()
                .expect("Internal Error: Float literal is not a float");

            let float_typ = types.new_float();
            Ok((Literal::Float(val), float_typ))
        },
        lang_y::Expr::UnaryExpr { span, op, expr } => {
            let (expr_lit, expr_typ)
                = process_expr_as_constant(*expr, lexer, stringtab, env, types)?;

            match op {
                lang_y::UnaryOp::Negation => {
                    if !types.is_number(expr_typ) {
                        Err(singleton_error(
                                ErrorMessage::TypeError(
                                    span_to_loc(span, lexer),
                                    "number".to_string(),
                                    unparse_type(types, expr_typ, stringtab))))
                    } else {
                        Ok((match expr_lit {
                            // Two's complement negation of the stored bits; wrapping so that
                            // negating the minimum value cannot trip an overflow check
                            Literal::Integer(i) => Literal::Integer(i.wrapping_neg()),
                            Literal::Float(f) => Literal::Float(- f),
                            _ => panic!("Incorrect literal constructor"),
                        }, expr_typ))
                    }
                },
                lang_y::UnaryOp::BitwiseNot => {
                    if !types.is_integer(expr_typ) {
                        Err(singleton_error(
                                ErrorMessage::TypeError(
                                    span_to_loc(span, lexer),
                                    "integer".to_string(),
                                    unparse_type(types, expr_typ, stringtab))))
                    } else {
                        let Literal::Integer(i) = expr_lit
                            else { panic!("Incorrect literal constructor"); };
                        Ok((Literal::Integer(! i), expr_typ))
                    }
                },
                lang_y::UnaryOp::LogicalNot => {
                    if !types.is_bool(expr_typ) {
                        Err(singleton_error(
                                ErrorMessage::TypeError(
                                    span_to_loc(span, lexer),
                                    "bool".to_string(),
                                    unparse_type(types, expr_typ, stringtab))))
                    } else {
                        let Literal::Bool(b) = expr_lit
                            else { panic!("Incorrect literal constructor"); };
                        Ok((Literal::Bool(! b), expr_typ))
                    }
                },
            }
        },
        lang_y::Expr::BinaryExpr { span : _, op, lhs, rhs } => {
            let lhs_span = lhs.span();
            let rhs_span = rhs.span();

            let lhs_res = process_expr_as_constant(*lhs, lexer, stringtab, env, types);
            let rhs_res = process_expr_as_constant(*rhs, lexer, stringtab, env, types);

            let ((lhs_lit, lhs_typ), (rhs_lit, rhs_typ))
                = append_errors2(lhs_res, rhs_res)?;

            // First, type-check. Each operator class requires its operands to be of some kind
            // (any type for equality, numbers, integers, or booleans) and of equal types.
            let (expected, lhs_ok, rhs_ok) = match op {
                // Equality and inequality work on any types
                lang_y::BinaryOp::Eq | lang_y::BinaryOp::Neq => ("", true, true),
                // These work on any numbers
                lang_y::BinaryOp::Add | lang_y::BinaryOp::Sub | lang_y::BinaryOp::Mul
                | lang_y::BinaryOp::Div | lang_y::BinaryOp::Lt | lang_y::BinaryOp::Le
                | lang_y::BinaryOp::Gt | lang_y::BinaryOp::Ge
                    => ("number", types.is_number(lhs_typ), types.is_number(rhs_typ)),
                // These work only on integers
                lang_y::BinaryOp::Mod | lang_y::BinaryOp::BitAnd | lang_y::BinaryOp::BitOr
                | lang_y::BinaryOp::Xor | lang_y::BinaryOp::LShift | lang_y::BinaryOp::RShift
                    => ("integer", types.is_integer(lhs_typ), types.is_integer(rhs_typ)),
                // These work only on booleans
                lang_y::BinaryOp::LogAnd | lang_y::BinaryOp::LogOr
                    => ("bool", types.is_bool(lhs_typ), types.is_bool(rhs_typ)),
            };
            let equal = types.equal(lhs_typ, rhs_typ);

            let mut errors = LinkedList::new();
            if lhs_ok && !equal {
                // The left operand is acceptable, so blame the right one for the mismatch
                errors.push_back(
                    ErrorMessage::TypeError(
                        span_to_loc(rhs_span, lexer),
                        unparse_type(types, lhs_typ, stringtab),
                        unparse_type(types, rhs_typ, stringtab)));
            } else if rhs_ok && !equal {
                errors.push_back(
                    ErrorMessage::TypeError(
                        span_to_loc(lhs_span, lexer),
                        unparse_type(types, rhs_typ, stringtab),
                        unparse_type(types, lhs_typ, stringtab)));
            } else {
                // The types are equal or neither operand is of the expected kind
                if !lhs_ok {
                    errors.push_back(
                        ErrorMessage::TypeError(
                            span_to_loc(lhs_span, lexer),
                            expected.to_string(),
                            unparse_type(types, lhs_typ, stringtab)));
                }
                if !rhs_ok {
                    errors.push_back(
                        ErrorMessage::TypeError(
                            span_to_loc(rhs_span, lexer),
                            expected.to_string(),
                            unparse_type(types, rhs_typ, stringtab)));
                }
            }

            if !errors.is_empty() { return Err(errors); }

            // Then, fold the operation. Integer arithmetic wraps rather than panicking on
            // overflow, and shift amounts are reduced modulo the bit width.
            match op {
                // Integer division and remainder by zero would otherwise panic the compiler
                lang_y::BinaryOp::Div | lang_y::BinaryOp::Mod
                    if matches!((&lhs_lit, &rhs_lit),
                                (Literal::Integer(_), Literal::Integer(0))) => {
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(rhs_span, lexer),
                                "Division by zero in constant expression".to_string())))
                },
                lang_y::BinaryOp::Add => {
                    Ok((fold_numeric_constants(lhs_lit, rhs_lit,
                                               u64::wrapping_add,
                                               |i : f64, j : f64| i + j),
                        lhs_typ))
                },
                lang_y::BinaryOp::Sub => {
                    Ok((fold_numeric_constants(lhs_lit, rhs_lit,
                                               u64::wrapping_sub,
                                               |i : f64, j : f64| i - j),
                        lhs_typ))
                },
                lang_y::BinaryOp::Mul => {
                    Ok((fold_numeric_constants(lhs_lit, rhs_lit,
                                               u64::wrapping_mul,
                                               |i : f64, j : f64| i * j),
                        lhs_typ))
                },
                lang_y::BinaryOp::Div => {
                    Ok((fold_numeric_constants(lhs_lit, rhs_lit,
                                               |i : u64, j : u64| i / j,
                                               |i : f64, j : f64| i / j),
                        lhs_typ))
                },
                lang_y::BinaryOp::Mod => {
                    Ok((fold_integer_constants(lhs_lit, rhs_lit, |i : u64, j : u64| i % j),
                        lhs_typ))
                },
                lang_y::BinaryOp::BitAnd => {
                    Ok((fold_integer_constants(lhs_lit, rhs_lit, |i : u64, j : u64| i & j),
                        lhs_typ))
                },
                lang_y::BinaryOp::BitOr => {
                    Ok((fold_integer_constants(lhs_lit, rhs_lit, |i : u64, j : u64| i | j),
                        lhs_typ))
                },
                lang_y::BinaryOp::Xor => {
                    Ok((fold_integer_constants(lhs_lit, rhs_lit, |i : u64, j : u64| i ^ j),
                        lhs_typ))
                },
                lang_y::BinaryOp::LShift => {
                    Ok((fold_integer_constants(lhs_lit, rhs_lit,
                                               |i : u64, j : u64| i.wrapping_shl(j as u32)),
                        lhs_typ))
                },
                lang_y::BinaryOp::RShift => {
                    Ok((fold_integer_constants(lhs_lit, rhs_lit,
                                               |i : u64, j : u64| i.wrapping_shr(j as u32)),
                        lhs_typ))
                },
                lang_y::BinaryOp::Lt => {
                    let bool_typ = types.new_primitive(types::Primitive::Bool);
                    let res = compare_numeric_constants(&lhs_lit, &rhs_lit,
                                                        |i : u64, j : u64| i < j,
                                                        |i : f64, j : f64| i < j)
                        .expect("Incorrect literal constructor");
                    Ok((Literal::Bool(res), bool_typ))
                },
                lang_y::BinaryOp::Le => {
                    let bool_typ = types.new_primitive(types::Primitive::Bool);
                    let res = compare_numeric_constants(&lhs_lit, &rhs_lit,
                                                        |i : u64, j : u64| i <= j,
                                                        |i : f64, j : f64| i <= j)
                        .expect("Incorrect literal constructor");
                    Ok((Literal::Bool(res), bool_typ))
                },
                lang_y::BinaryOp::Gt => {
                    let bool_typ = types.new_primitive(types::Primitive::Bool);
                    let res = compare_numeric_constants(&lhs_lit, &rhs_lit,
                                                        |i : u64, j : u64| i > j,
                                                        |i : f64, j : f64| i > j)
                        .expect("Incorrect literal constructor");
                    Ok((Literal::Bool(res), bool_typ))
                },
                lang_y::BinaryOp::Ge => {
                    let bool_typ = types.new_primitive(types::Primitive::Bool);
                    let res = compare_numeric_constants(&lhs_lit, &rhs_lit,
                                                        |i : u64, j : u64| i >= j,
                                                        |i : f64, j : f64| i >= j)
                        .expect("Incorrect literal constructor");
                    Ok((Literal::Bool(res), bool_typ))
                },
                lang_y::BinaryOp::Eq => {
                    let bool_typ = types.new_primitive(types::Primitive::Bool);
                    // Numbers compare numerically; anything else compares structurally
                    let res = compare_numeric_constants(&lhs_lit, &rhs_lit,
                                                        |i : u64, j : u64| i == j,
                                                        |i : f64, j : f64| i == j)
                        .unwrap_or_else(|| lhs_lit == rhs_lit);
                    Ok((Literal::Bool(res), bool_typ))
                },
                lang_y::BinaryOp::Neq => {
                    let bool_typ = types.new_primitive(types::Primitive::Bool);
                    let res = compare_numeric_constants(&lhs_lit, &rhs_lit,
                                                        |i : u64, j : u64| i != j,
                                                        |i : f64, j : f64| i != j)
                        .unwrap_or_else(|| lhs_lit != rhs_lit);
                    Ok((Literal::Bool(res), bool_typ))
                },
                lang_y::BinaryOp::LogAnd => {
                    match (lhs_lit, rhs_lit) {
                        (Literal::Bool(i), Literal::Bool(j))
                            => Ok((Literal::Bool(i && j), lhs_typ)),
                        _ => panic!("Incorrect literal constructor"),
                    }
                },
                lang_y::BinaryOp::LogOr => {
                    match (lhs_lit, rhs_lit) {
                        (Literal::Bool(i), Literal::Bool(j))
                            => Ok((Literal::Bool(i || j), lhs_typ)),
                        _ => panic!("Incorrect literal constructor"),
                    }
                },
            }
        },
        lang_y::Expr::CastExpr { span, expr, typ } => {
            // Cast between numeric types
            let expr_res = process_expr_as_constant(*expr, lexer, stringtab, env, types);
            let type_res = process_type(typ, lexer, stringtab, env, types);

            let ((expr_lit, expr_typ), to_typ) = append_errors2(expr_res, type_res)?;

            if !types.is_number(expr_typ) || !types.is_number(to_typ) {
                Err(singleton_error(
                        ErrorMessage::SemanticError(
                            span_to_loc(span, lexer),
                            format!("Can only cast between numeric types, cannot cast {} to {}",
                                    unparse_type(types, expr_typ, stringtab),
                                    unparse_type(types, to_typ, stringtab)))))
            } else if types.is_integer(to_typ) {
                Ok((match expr_lit {
                    Literal::Integer(i) => Literal::Integer(i),
                    Literal::Float(f) => Literal::Integer(f as u64),
                    _ => panic!("Incorrect literal constructor"),
                }, to_typ))
            } else {
                Ok((match expr_lit {
                    Literal::Integer(i) => Literal::Float(i as f64),
                    Literal::Float(f) => Literal::Float(f),
                    _ => panic!("Incorrect literal constructor"),
                }, to_typ))
            }
        },
        lang_y::Expr::CondExpr { span, cond, thn, els } => {
            let cond_span = cond.span();

            // Both branches are evaluated so that errors in either one are reported
            let cond_res = process_expr_as_constant(*cond, lexer, stringtab, env, types);
            let thn_res = process_expr_as_constant(*thn, lexer, stringtab, env, types);
            let els_res = process_expr_as_constant(*els, lexer, stringtab, env, types);

            let ((cond_lit, cond_typ), (thn_lit, thn_typ), (els_lit, els_typ))
                = append_errors3(cond_res, thn_res, els_res)?;

            let mut errors = LinkedList::new();

            if !types.is_bool(cond_typ) {
                errors.push_back(
                    ErrorMessage::TypeError(
                        span_to_loc(cond_span, lexer),
                        "bool".to_string(),
                        unparse_type(types, cond_typ, stringtab)));
            }
            if !types.equal(thn_typ, els_typ) {
                errors.push_back(
                    ErrorMessage::SemanticError(
                        span_to_loc(span, lexer),
                        format!("Types of conditional branches do not match, have {} and {}",
                                unparse_type(types, thn_typ, stringtab),
                                unparse_type(types, els_typ, stringtab))));
            }

            if !errors.is_empty() {
                Err(errors)
            } else {
                let Literal::Bool(condition) = cond_lit
                    else { panic!("Incorrect literal constructor"); };
                if condition { Ok((thn_lit, thn_typ)) } else { Ok((els_lit, els_typ)) }
            }
        },
        lang_y::Expr::CallExpr { span, name, ty_args, args } => {
            // While calls cannot be evaluated as constants, enum values can be, so we need to
            // distinguish whether this is actually a call or the construction of some enum value
            if name.len() > 2 {
                return Err(singleton_error(
                        ErrorMessage::NotImplemented(
                            span_to_loc(span, lexer),
                            "packages".to_string())));
            }

            let nm = intern_package_name(&name, lexer, stringtab);

            match env.lookup(&nm[0]) {
                // A two-part name is only meaningful when the first part is a (union) type
                Some(Entity::Variable { .. }) | Some(Entity::DynConst { .. })
                | Some(Entity::Constant { .. }) | Some(Entity::Function { .. })
                | None if name.len() != 1 => {
                    Err(singleton_error(
                            ErrorMessage::NotImplemented(
                                span_to_loc(span, lexer),
                                "packages".to_string())))
                },
                None =>
                    Err(singleton_error(
                            ErrorMessage::UndefinedVariable(
                                span_to_loc(name[0], lexer),
                                stringtab.lookup_id(nm[0]).unwrap()))),
                Some(Entity::Variable { .. }) =>
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(name[0], lexer),
                                format!("{} is a variable, expected a function or union constructor",
                                        stringtab.lookup_id(nm[0]).unwrap())))),
                Some(Entity::DynConst { .. }) =>
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(name[0], lexer),
                                format!("{} is a dynamic constant, expected a function or union constructor",
                                        stringtab.lookup_id(nm[0]).unwrap())))),
                Some(Entity::Constant { .. }) =>
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(name[0], lexer),
                                format!("{} is a constant, expected a function or union constructor",
                                        stringtab.lookup_id(nm[0]).unwrap())))),
                Some(Entity::Function { .. }) =>
                    Err(singleton_error(
                            ErrorMessage::SemanticError(
                                span_to_loc(name[0], lexer),
                                "Function calls cannot be evaluated as a constant".to_string()))),
                Some(Entity::Type { type_args : kinds, value : typ }) => {
                    if !types.is_union(*typ) {
                        if name.len() != 1 {
                            return Err(singleton_error(
                                    ErrorMessage::NotImplemented(
                                        span_to_loc(span, lexer),
                                        "packages".to_string())));
                        } else {
                            return Err(singleton_error(
                                    ErrorMessage::SemanticError(
                                        span_to_loc(name[0], lexer),
                                        format!("{} is a type, expected a function or union constructor",
                                                stringtab.lookup_id(nm[0]).unwrap()))));
                        }
                    }
                    if name.len() != 2 {
                        return Err(singleton_error(
                                ErrorMessage::SemanticError(
                                    span_to_loc(name[0], lexer),
                                    "Expected constructor name".to_string())));
                    }

                    if types.get_constructor_info(*typ, nm[1]).is_none() {
                        return Err(singleton_error(
                                ErrorMessage::SemanticError(
                                    span_to_loc(name[1], lexer),
                                    format!("{} is not a constructor of type {}",
                                            stringtab.lookup_id(nm[1]).unwrap(),
                                            unparse_type(types, *typ, stringtab)))));
                    }

                    // Now, we know that we are constructing some union, we need to verify that
                    // the type arguments are appropriate
                    if kinds.len() != ty_args.len() {
                        return Err(singleton_error(
                                ErrorMessage::SemanticError(
                                    span_to_loc(span, lexer),
                                    format!("Expected {} type arguments, provided {}",
                                            kinds.len(), ty_args.len()))));
                    }

                    let mut type_vars = vec![];
                    let mut dyn_consts = vec![];
                    let mut errors = LinkedList::new();

                    for (arg, kind) in ty_args.into_iter().zip(kinds.iter()) {
                        let arg_span = arg.span();
                        match kind {
                            lang_y::Kind::USize => {
                                match process_type_expr_as_expr(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(val) => dyn_consts.push(val),
                                }
                            },
                            lang_y::Kind::Type => {
                                match process_type_expr_as_type(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(typ) => type_vars.push(typ),
                                }
                            },
                            lang_y::Kind::Number => {
                                match process_type_expr_as_type(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(typ) => {
                                        if types.is_number(typ) {
                                            type_vars.push(typ);
                                        } else {
                                            errors.push_back(
                                                ErrorMessage::KindError(
                                                    span_to_loc(arg_span, lexer),
                                                    "number".to_string(),
                                                    unparse_type(types, typ, stringtab)));
                                        }
                                    },
                                }
                            },
                            lang_y::Kind::Integer => {
                                match process_type_expr_as_type(
                                        arg, lexer, stringtab, env, types) {
                                    Err(mut errs) => errors.append(&mut errs),
                                    Ok(typ) => {
                                        if types.is_integer(typ) {
                                            type_vars.push(typ);
                                        } else {
                                            errors.push_back(
                                                ErrorMessage::KindError(
                                                    span_to_loc(arg_span, lexer),
                                                    "integer".to_string(),
                                                    unparse_type(types, typ, stringtab)));
                                        }
                                    },
                                }
                            },
                        }
                    }

                    if !errors.is_empty() { return Err(errors); }

                    let union_type =
                        if type_vars.is_empty() && dyn_consts.is_empty() {
                            *typ
                        } else {
                            types.instantiate(*typ, &type_vars, &dyn_consts)
                        };
                    let Some((constr_idx, constr_typ))
                        = types.get_constructor_info(union_type, nm[1])
                        else { panic!("From above"); };

                    // Now, process the arguments to ensure they have the type needed by this
                    // constructor
                    // To do this, since unions take a single argument, we process the arguments
                    // as a single tuple, reporting an error if inout is used anywhere
                    for (is_inout, arg) in args.iter() {
                        if *is_inout {
                            return Err(singleton_error(
                                    ErrorMessage::SemanticError(
                                        span_to_loc(arg.span(), lexer),
                                        "Union constructors cannot be marked inout".to_string())));
                        }
                    }

                    let (body_lit, body_typ)
                        = process_expr_as_constant(
                            lang_y::Expr::Tuple {
                                span,
                                exprs : args.into_iter().map(|(_, a)| a).collect::<Vec<_>>() },
                            lexer, stringtab, env, types)?;

                    if !types.equal(constr_typ, body_typ) {
                        Err(singleton_error(
                                ErrorMessage::TypeError(
                                    span_to_loc(span, lexer),
                                    unparse_type(types, constr_typ, stringtab),
                                    unparse_type(types, body_typ, stringtab))))
                    } else {
                        // The constant as a whole has the union's type; only the boxed payload
                        // carries the constructor argument's type
                        Ok((Literal::Sum(constr_idx, Box::new((body_lit, body_typ))),
                            union_type))
                    }
                },
            }
        },
    }
}

+fn process_expr(expr : lang_y::Expr, lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab : &mut StringTable, env : &mut Env<usize, Entity>, + types : &mut TypeSolver) + -> Result<Expr, ErrorMessages> { + + match expr { + lang_y::Expr::Variable { span, name } => { + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string())))? + } + let nm = intern_package_name(&name, lexer, stringtab)[0]; + + match env.lookup(&nm) { + Some(Entity::Variable { variable, typ, .. }) => { + Ok(Expr::Variable { var : *variable, typ : *typ }) + }, + Some(Entity::DynConst { value }) => { + let typ = types.new_primitive(types::Primitive::U64); + Ok(Expr::DynConst { idx : *value, typ : typ }) + }, + Some(Entity::Constant { value : (lit, typ) }) => { + Ok(Expr::Constant { val : (lit.clone(), *typ), typ : *typ }) + }, + Some(Entity::Function { .. }) => + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("{} is a function, expected a value", + stringtab.lookup_id(nm).unwrap())))), + Some(Entity::Type { ..
}) => + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("{} is a type, expected a value", + stringtab.lookup_id(nm).unwrap())))), + None => + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + stringtab.lookup_id(nm).unwrap()))) + } + }, + lang_y::Expr::Field { span, lhs, rhs } => { + let field_name = intern_id(&rhs, lexer, stringtab); + let exp = process_expr(*lhs, lexer, stringtab, env, types)?; + let exp_typ = exp.get_type(); + + match types.get_field(exp_typ, field_name) { + None => Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Type {} does not possess field {}", + unparse_type(types, exp_typ, stringtab), + stringtab.lookup_id(field_name).unwrap())))), + Some((field_idx, field_type)) => + Ok(Expr::Read { + index : vec![Index::Field(field_idx)], + val : Box::new(exp), + typ : field_type }), + } + }, + lang_y::Expr::NumField { span, lhs, rhs } => { + let exp = process_expr(*lhs, lexer, stringtab, env, types)?; + let exp_typ = exp.get_type(); + + let num = lexer.span_str(rhs)[1..].parse::<usize>() + .expect("From lexical analysis"); + + match types.get_index(exp_typ, num) { + None => Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Type {} does not possess index {}", + unparse_type(types, exp_typ, stringtab), + num)))), + Some(field_type) => + Ok(Expr::Read { + index : vec![Index::Field(num)], + val : Box::new(exp), + typ : field_type }), + } + }, + lang_y::Expr::ArrIndex { span, lhs, index } => { + let exp = process_expr(*lhs, lexer, stringtab, env, types)?; + let exp_typ = exp.get_type(); + + let mut indices = vec![]; + let mut errors = LinkedList::new(); + for idx in index { + let idx_span = idx.span(); + match process_expr(idx, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(exp) => { + let typ = exp.get_type(); + if !types.is_u64(typ) { + errors.push_back( + 
ErrorMessage::TypeError( + span_to_loc(idx_span, lexer), + "usize".to_string(), + unparse_type(types, typ, stringtab))); + } else { + indices.push(exp); + } + }, + } + } + + if !errors.is_empty() { + Err(errors)? + } + + if !types.is_array(exp_typ) { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Array index does not apply to type {}", + unparse_type(types, exp_typ, stringtab)))))? + } + + let num_dims = types.get_num_dimensions(exp_typ).unwrap(); + if indices.len() < num_dims { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + format!("fewer array indices than dimensions, array has {} dimensions but using {} indices", + num_dims, indices.len())))) + } else if indices.len() > num_dims { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Too many array indices, array has {} dimensions but using {} indices", + num_dims, indices.len())))) + } else { + Ok(Expr::Read { + index : vec![Index::Array(indices)], + val : Box::new(exp), + typ : types.get_element_type(exp_typ).unwrap() }) + } + }, + lang_y::Expr::Tuple { span : _, mut exprs } => { + if exprs.len() == 1 { + return process_expr(exprs.pop().unwrap(), lexer, stringtab, env, types); + } + if exprs.len() == 0 { + let unit_type = types.new_primitive(types::Primitive::Unit); + return Ok(Expr::Constant { + val : (Literal::Unit, unit_type), + typ : unit_type }); + } + + let mut vals = vec![]; + let mut typs = vec![]; + let mut errors = LinkedList::new(); + + for exp in exprs { + match process_expr(exp, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => { + typs.push(val.get_type()); + vals.push(val); + }, + } + } + + if !errors.is_empty() { + Err(errors) + } else { + Ok(Expr::Tuple { + vals : vals, + typ : types.new_tuple(typs) }) + } + }, + lang_y::Expr::Struct { span, name, ty_args, exprs } => { + if name.len() != 1 { + Err(singleton_error( + 
ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string())))? + } + + let struct_nm = intern_package_name(&name, lexer, stringtab)[0]; + match env.lookup(&struct_nm) { + Some(Entity::Variable { .. }) => { + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "variable".to_string()))) + }, + Some(Entity::DynConst { .. }) => { + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "dynamic constant".to_string()))) + }, + Some(Entity::Constant { .. }) => { + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "constant".to_string()))) + }, + Some(Entity::Function { .. }) => { + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "function".to_string()))) + }, + None => { + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + stringtab.lookup_id(struct_nm).unwrap()))) + }, + Some(Entity::Type { type_args : kinds, value : typ }) => { + if !types.is_struct(*typ) { + Err(singleton_error( + ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "non-struct type".to_string())))? + } + if kinds.len() != ty_args.len() { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} type arguments, provided {}", + kinds.len(), ty_args.len()))))? 
+ } + + // Verify that the type arguments we are provided are correct and collect the + // type variable and dynamic constant substitutions + let mut type_vars = vec![]; + let mut dyn_consts = vec![]; + let mut errors = LinkedList::new(); + + for (arg, kind) in ty_args.into_iter().zip(kinds.iter()) { + let arg_span = arg.span(); + match kind { + lang_y::Kind::USize => { + match process_type_expr_as_expr( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => dyn_consts.push(val), + } + }, + lang_y::Kind::Type => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => type_vars.push(typ), + } + }, + lang_y::Kind::Number => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_number(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "number".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + lang_y::Kind::Integer => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_integer(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "integer".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + } + } + + if !errors.is_empty() { return Err(errors); } + + let struct_type = + if type_vars.len() == 0 && dyn_consts.len() == 0 { + *typ + } else { + types.instantiate(*typ, &type_vars, &dyn_consts) + }; + + // Check each field and construct the appropriate tuple + // Note that fields that are omitted will be initialized with their type's + // default value + let num_fields = types.get_num_struct_fields(struct_type).unwrap(); + + // Values for the fields, in order + let mut values : 
Vec<Option<Expr>> = vec![None; num_fields]; + + for (field_name, expr) in exprs { + let field_nm = intern_id(&field_name, lexer, stringtab); + let expr_span = expr.span(); + + match types.get_field(struct_type, field_nm) { + None => { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(field_name, lexer), + format!("Struct {} does not have field {}", + unparse_type(types, struct_type, stringtab), + stringtab.lookup_id(field_nm).unwrap()))); + }, + Some((idx, field_typ)) => { + if values[idx].is_some() { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(field_name, lexer), + format!("Field {} defined multiple times", + stringtab.lookup_id(field_nm).unwrap()))); + } else { + match process_expr(expr, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => { + let val_typ = val.get_type(); + if !types.equal(field_typ, val_typ) { + // Set the value at this index even though there's + // an error so that we also report if the field is + // defined multiple times + values[idx] = Some(Expr::Zero { typ : field_typ }); + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(expr_span, lexer), + unparse_type(types, field_typ, stringtab), + unparse_type(types, val_typ, stringtab))); + } else { + values[idx] = Some(val); + } + }, + } + } + }, + } + } + + if !errors.is_empty() { return Err(errors); } + + // Construct the list of field values, filling in zero values as needed + let filled_fields + = values.into_iter().enumerate() + .map(|(i, t)| t.unwrap_or( + Expr::Zero { + typ : types.get_struct_field_type(struct_type, i).unwrap() })) + .collect::<Vec<_>>(); + + Ok(Expr::Tuple { vals : filled_fields, + typ : struct_type }) + }, + } + }, + lang_y::Expr::BoolLit { span : _, value } => { + let bool_typ = types.new_primitive(types::Primitive::Bool); + Ok(Expr::Constant { + val : (Literal::Bool(value), bool_typ), + typ : bool_typ }) + }, + lang_y::Expr::IntLit { span, base } => { + let res = 
u64::from_str_radix(lexer.span_str(span), base.base()); + assert!(res.is_ok(), "Internal Error: Int literal is not an integer"); + + let num_typ = types.new_number(); + Ok(Expr::Constant { + val : (Literal::Integer(res.unwrap()), num_typ), + typ : num_typ }) + }, + lang_y::Expr::FloatLit { span } => { + let res = lexer.span_str(span).parse::<f64>(); + assert!(res.is_ok(), "Internal Error: Float literal is not a float"); + + let float_typ = types.new_float(); + Ok(Expr::Constant { + val : (Literal::Float(res.unwrap()), float_typ), + typ : float_typ }) + }, + lang_y::Expr::UnaryExpr { span, op, expr } => { + let expr_val = process_expr(*expr, lexer, stringtab, env, types)?; + let expr_typ = expr_val.get_type(); + + match op { + lang_y::UnaryOp::Negation => { + if !types.is_number(expr_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(span, lexer), + "number".to_string(), + unparse_type(types, expr_typ, stringtab)))) + } else { + Ok(Expr::UnaryExp { + op : UnaryOp::Negation, + expr : Box::new(expr_val), + typ : expr_typ }) + } + }, + lang_y::UnaryOp::BitwiseNot => { + if !types.is_integer(expr_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(span, lexer), + "integer".to_string(), + unparse_type(types, expr_typ, stringtab)))) + } else { + Ok(Expr::UnaryExp { + op : UnaryOp::BitwiseNot, + expr : Box::new(expr_val), + typ : expr_typ }) + } + }, + lang_y::UnaryOp::LogicalNot => { + if !types.is_bool(expr_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(span, lexer), + "bool".to_string(), + unparse_type(types, expr_typ, stringtab)))) + } else { + // ! 
x is translated into if x then false else true + let val_true = + Expr::Constant { + val : (Literal::Bool(true), expr_typ), + typ : expr_typ }; + let val_false = + Expr::Constant { + val : (Literal::Bool(false), expr_typ), + typ : expr_typ }; + Ok(Expr::CondExpr { + cond : Box::new(expr_val), + thn : Box::new(val_false), + els : Box::new(val_true), + typ : expr_typ }) + } + }, + } + }, + lang_y::Expr::BinaryExpr { span : _, op, lhs, rhs } => { + let lhs_span = lhs.span(); + let rhs_span = rhs.span(); + + let lhs_res = process_expr(*lhs, lexer, stringtab, env, types); + let rhs_res = process_expr(*rhs, lexer, stringtab, env, types); + + let (lhs_val, rhs_val) = append_errors2(lhs_res, rhs_res)?; + let lhs_typ = lhs_val.get_type(); + let rhs_typ = rhs_val.get_type(); + + // First, type-check + match op { + // Equality and inequality work on any types + lang_y::BinaryOp::Eq | lang_y::BinaryOp::Neq => { + if !types.equal(lhs_typ, rhs_typ) { + return Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(rhs_span, lexer), + unparse_type(types, lhs_typ, stringtab), + unparse_type(types, rhs_typ, stringtab)))); + } + }, + // These work on any numbers + lang_y::BinaryOp::Add | lang_y::BinaryOp::Sub | lang_y::BinaryOp::Mul + | lang_y::BinaryOp::Div | lang_y::BinaryOp::Lt | lang_y::BinaryOp::Le + | lang_y::BinaryOp::Gt | lang_y::BinaryOp::Ge => { + let mut errors = LinkedList::new(); + let lhs_number = types.is_number(lhs_typ); + let rhs_number = types.is_number(rhs_typ); + let equal = types.equal(lhs_typ, rhs_typ); + + if lhs_number && !equal { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(rhs_span, lexer), + unparse_type(types, lhs_typ, stringtab), + unparse_type(types, rhs_typ, stringtab))); + } else if rhs_number && !equal { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(lhs_span, lexer), + unparse_type(types, rhs_typ, stringtab), + unparse_type(types, lhs_typ, stringtab))); + } else { + // The types are equal or both are not numbers + 
if !lhs_number { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(lhs_span, lexer), + "number".to_string(), + unparse_type(types, lhs_typ, stringtab))); + } + if !rhs_number { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(rhs_span, lexer), + "number".to_string(), + unparse_type(types, rhs_typ, stringtab))); + } + } + + if !errors.is_empty() { return Err(errors); } + }, + // These work on integer inputs + lang_y::BinaryOp::Mod | lang_y::BinaryOp::BitAnd | lang_y::BinaryOp::BitOr + | lang_y::BinaryOp::Xor | lang_y::BinaryOp::LShift | lang_y::BinaryOp::RShift + => { + let mut errors = LinkedList::new(); + let lhs_integer = types.is_integer(lhs_typ); + let rhs_integer = types.is_integer(rhs_typ); + let equal = types.equal(lhs_typ, rhs_typ); + + if lhs_integer && !equal { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(rhs_span, lexer), + unparse_type(types, lhs_typ, stringtab), + unparse_type(types, rhs_typ, stringtab))); + } else if rhs_integer && !equal { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(lhs_span, lexer), + unparse_type(types, rhs_typ, stringtab), + unparse_type(types, lhs_typ, stringtab))); + } else { + // The types are equal or both are not integers + if !lhs_integer { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(lhs_span, lexer), + "integer".to_string(), + unparse_type(types, lhs_typ, stringtab))); + } + if !rhs_integer { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(rhs_span, lexer), + "integer".to_string(), + unparse_type(types, rhs_typ, stringtab))); + } + } + + if !errors.is_empty() { return Err(errors); } + }, + // These work on boolean inputs + lang_y::BinaryOp::LogAnd | lang_y::BinaryOp::LogOr => { + let mut errors = LinkedList::new(); + let lhs_bool = types.is_bool(lhs_typ); + let rhs_bool = types.is_bool(rhs_typ); + let equal = types.equal(lhs_typ, rhs_typ); + + if lhs_bool && !equal { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(rhs_span, 
lexer), + unparse_type(types, lhs_typ, stringtab), + unparse_type(types, rhs_typ, stringtab))); + } else if rhs_bool && !equal { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(lhs_span, lexer), + unparse_type(types, rhs_typ, stringtab), + unparse_type(types, lhs_typ, stringtab))); + } else { + // The types are equal or both are not bools + if !lhs_bool { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(lhs_span, lexer), + "bool".to_string(), + unparse_type(types, lhs_typ, stringtab))); + } + if !rhs_bool { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(rhs_span, lexer), + "bool".to_string(), + unparse_type(types, rhs_typ, stringtab))); + } + } + + if !errors.is_empty() { return Err(errors); } + }, + }; + + match op { + // The binary operations are compiled into conditional expressions: + // x && y = if x then y else false + // x || y = if x then true else y + lang_y::BinaryOp::LogAnd => { + let false_val = + Expr::Constant { + val : (Literal::Bool(false), lhs_typ), + typ : lhs_typ }; + Ok(Expr::CondExpr { + cond : Box::new(lhs_val), + thn : Box::new(rhs_val), + els : Box::new(false_val), + typ : lhs_typ }) + }, + lang_y::BinaryOp::LogOr => { + let true_val = + Expr::Constant { + val : (Literal::Bool(true), lhs_typ), + typ : lhs_typ }; + Ok(Expr::CondExpr { + cond : Box::new(lhs_val), + thn : Box::new(true_val), + els : Box::new(rhs_val), + typ : lhs_typ }) + }, + // For comparison operators, the resulting type is a boolean, while for all other + // operations the result is the same as the two operands + lang_y::BinaryOp::Lt | lang_y::BinaryOp::Le + | lang_y::BinaryOp::Gt | lang_y::BinaryOp::Ge + | lang_y::BinaryOp::Eq | lang_y::BinaryOp::Neq => { + Ok(Expr::BinaryExp { + op : convert_binary_op(op), + lhs : Box::new(lhs_val), + rhs : Box::new(rhs_val), + typ : types.new_primitive(types::Primitive::Bool) }) + }, + _ => { + Ok(Expr::BinaryExp { + op : convert_binary_op(op), + lhs : Box::new(lhs_val), + rhs : Box::new(rhs_val), + 
typ : lhs_typ }) + }, + } + }, + lang_y::Expr::CastExpr { span, expr, typ } => { + // For the moment at least, casting is only supported between numeric types, and all + // numeric types can be cast to each other + let expr_res = process_expr(*expr, lexer, stringtab, env, types); + let type_res = process_type(typ, lexer, stringtab, env, types); + + let (expr_val, to_typ) = append_errors2(expr_res, type_res)?; + let expr_typ = expr_val.get_type(); + + if !types.is_number(expr_typ) || !types.is_number(to_typ) { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Can only cast between numeric types, cannot cast {} to {}", + unparse_type(types, expr_typ, stringtab), + unparse_type(types, to_typ, stringtab))))) + } else { + Ok(Expr::CastExpr { expr : Box::new(expr_val), + typ : to_typ }) + } + }, + lang_y::Expr::CondExpr { span, cond, thn, els } => { + let cond_span = cond.span(); + + let cond_res = process_expr(*cond, lexer, stringtab, env, types); + let thn_res = process_expr(*thn, lexer, stringtab, env, types); + let els_res = process_expr(*els, lexer, stringtab, env, types); + + let (cond_val, thn_val, els_val) + = append_errors3(cond_res, thn_res, els_res)?; + + let cond_typ = cond_val.get_type(); + let thn_typ = thn_val.get_type(); + let els_typ = els_val.get_type(); + + let mut errors = LinkedList::new(); + + if !types.is_bool(cond_typ) { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(cond_span, lexer), + "bool".to_string(), + unparse_type(types, cond_typ, stringtab))); + } + if !types.equal(thn_typ, els_typ) { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Types of conditional branches do not match, have {} and {}", + unparse_type(types, thn_typ, stringtab), + unparse_type(types, els_typ, stringtab)))); + } + + if !errors.is_empty() { + Err(errors) + } else { + Ok(Expr::CondExpr { + cond : Box::new(cond_val), + thn : Box::new(thn_val), + els : Box::new(els_val), + 
typ : thn_typ }) + } + }, + lang_y::Expr::CallExpr { span, name, ty_args, args } => { + // In the AST from the parser we have no way to distinguish between function calls and + // union construction. We have to identify which case we're in here. We do this by + // identifying whether the name (looking for the moment at just the first part of the + // name) and determining whether it's a type or a function. Obviously we then report + // errors if there are additional parts of the name + if name.len() > 2 { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string())))? + } + + let nm = intern_package_name(&name, lexer, stringtab); + + match env.lookup(&nm[0]) { + Some(Entity::Variable { .. }) | Some(Entity::DynConst { .. }) + | Some(Entity::Constant { .. }) | Some(Entity::Function { .. }) + | None if name.len() != 1 => { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string()))) + }, + None => + Err(singleton_error( + ErrorMessage::UndefinedVariable( + span_to_loc(name[0], lexer), + stringtab.lookup_id(nm[0]).unwrap()))), + Some(Entity::Variable { .. }) => + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(name[0], lexer), + format!("{} is a variable, expected a function or union constructor", + stringtab.lookup_id(nm[0]).unwrap())))), + Some(Entity::DynConst { .. }) => + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(name[0], lexer), + format!("{} is a dynamic constant, expected a function or union constructor", + stringtab.lookup_id(nm[0]).unwrap())))), + Some(Entity::Constant { .. 
}) => + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(name[0], lexer), + format!("{} is a constant, expected a function or union constructor", + stringtab.lookup_id(nm[0]).unwrap())))), + Some(Entity::Type { type_args : kinds, value : typ }) => { + if !types.is_union(*typ) { + if name.len() != 1 { + Err(singleton_error( + ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string())))? + } else { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(name[0], lexer), + format!("{} is a type, expected a function or union constructor", + stringtab.lookup_id(nm[0]).unwrap()))))? + } + } + if name.len() != 2 { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(name[0], lexer), + format!("Expected constructor name"))))? + } + + if types.get_constructor_info(*typ, nm[1]).is_none() { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(name[1], lexer), + format!("{} is not a constructor of type {}", + stringtab.lookup_id(nm[1]).unwrap(), + unparse_type(types, *typ, stringtab)))))? + } + + // Now, we know that we are constructing some union, we need to verify that + // the type arguments are appropriate + if kinds.len() != ty_args.len() { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} type arguments, provided {}", + kinds.len(), ty_args.len()))))? 
+ } + + let mut type_vars = vec![]; + let mut dyn_consts = vec![]; + let mut errors = LinkedList::new(); + + for (arg, kind) in ty_args.into_iter().zip(kinds.iter()) { + let arg_span = arg.span(); + match kind { + lang_y::Kind::USize => { + match process_type_expr_as_expr( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => dyn_consts.push(val), + } + }, + lang_y::Kind::Type => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => type_vars.push(typ), + } + }, + lang_y::Kind::Number => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_number(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "number".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + lang_y::Kind::Integer => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_integer(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "integer".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + } + } + + if !errors.is_empty() { return Err(errors); } + + let union_type = + if type_vars.len() == 0 && dyn_consts.len() == 0 { + *typ + } else { + types.instantiate(*typ, &type_vars, &dyn_consts) + }; + let Some((constr_idx, constr_typ)) + = types.get_constructor_info(union_type, nm[1]) + else { panic!("From above"); }; + + // Now, process the arguments to ensure they has the type needed by this + // constructor + // To do this, since unions take a single argument, we process the arguments as + // a single tuple, reporting an error if inout is used anywhere + for (is_inout, arg) in args.iter() { + if *is_inout { + 
Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(arg.span(), lexer), + format!("Union constructors cannot be marked inout"))))? + } + } + + let body = process_expr( + lang_y::Expr::Tuple { + span : span, + exprs : args.into_iter().map(|(_, a)| a).collect::<Vec<_>>() }, + lexer, stringtab, env, types)?; + let body_typ = body.get_type(); + + if !types.equal(constr_typ, body_typ) { + Err(singleton_error( + ErrorMessage::TypeError( + span_to_loc(span, lexer), + unparse_type(types, constr_typ, stringtab), + unparse_type(types, body_typ, stringtab)))) + } else { + Ok(Expr::Union { + tag : constr_idx, + val : Box::new(body), + typ : union_type }) + } + }, + Some(Entity::Function { index : function, type_args : kinds, + args : func_args, return_type }) => { + let func = *function; + + // Verify that the type arguments are appropriate + if kinds.len() != ty_args.len() { + Err(singleton_error( + ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} type arguments, provided {}", + kinds.len(), ty_args.len()))))? 
+ } + + let mut type_vars = vec![]; + let mut dyn_consts = vec![]; + let mut errors = LinkedList::new(); + + for (arg, kind) in ty_args.into_iter().zip(kinds.iter()) { + let arg_span = arg.span(); + match kind { + lang_y::Kind::USize => { + match process_type_expr_as_expr( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(val) => dyn_consts.push(val), + } + }, + lang_y::Kind::Type => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => type_vars.push(typ), + } + }, + lang_y::Kind::Number => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_number(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "number".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + lang_y::Kind::Integer => { + match process_type_expr_as_type( + arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(typ) => { + if types.is_integer(typ) { + type_vars.push(typ); + } else { + errors.push_back( + ErrorMessage::KindError( + span_to_loc(arg_span, lexer), + "integer".to_string(), + unparse_type(types, typ, stringtab))); + } + }, + } + }, + } + } + + if !errors.is_empty() { return Err(errors); } + + let arg_types = + if type_vars.len() == 0 && dyn_consts.len() == 0 { + func_args.clone() + } else { + let mut tys = vec![]; + for (t, inout) in func_args { + tys.push(( + types.instantiate(*t, &type_vars, &dyn_consts), + *inout)); + } + tys + }; + let return_typ = + types.instantiate(*return_type, &type_vars, &dyn_consts); + + // Now, process the arguments to ensure they has the type needed by this + // constructor + let mut arg_vals : Vec<Either<Expr, usize>> = vec![]; + let mut errors = LinkedList::new(); + + for ((is_inout, arg), (arg_typ, expect_inout)) + in 
args.into_iter().zip(arg_types.into_iter()) { + + let arg_span = arg.span(); + + if is_inout && !expect_inout { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(arg_span, lexer), + format!("Argument should be inout"))); + } else if !is_inout && expect_inout { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(arg_span, lexer), + format!("Argument should not be inout"))); + } else if is_inout { + // If the argument is an inout then it needs to just be a variable + match process_expr(arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(Expr::Variable { var, typ }) => { + if !types.equal(arg_typ, typ) { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(arg_span, lexer), + unparse_type(types, arg_typ, stringtab), + unparse_type(types, typ, stringtab))); + } else { + arg_vals.push(Either::Right(var)); + } + }, + Ok(_) => { + errors.push_back( + ErrorMessage::SemanticError( + span_to_loc(arg_span, lexer), + format!("An inout argument must just be a variable"))); + }, + } + } else { + match process_expr(arg, lexer, stringtab, env, types) { + Err(mut errs) => errors.append(&mut errs), + Ok(exp) => { + if !types.equal(arg_typ, exp.get_type()) { + errors.push_back( + ErrorMessage::TypeError( + span_to_loc(arg_span, lexer), + unparse_type(types, arg_typ, stringtab), + unparse_type(types, exp.get_type(), stringtab))); + } else { + arg_vals.push(Either::Left(exp)); + } + }, + } + } + } + + if !errors.is_empty() { + Err(errors) + } else { + Ok(Expr::CallExpr { + func : func, + ty_args : type_vars, + dyn_consts : dyn_consts, + args : arg_vals, + typ : return_typ }) + } + }, + } + }, + } +} + +fn generate_return(expr : Expr, vars : &Vec<usize>, var_types : &Vec<Type>, + types : &mut TypeSolver) -> Stmt { + let var_exprs = vars.iter().zip(var_types.iter()) + .map(|(var, typ)| Expr::Variable { var : *var, typ : *typ }) + .collect::<Vec<_>>(); + + let inout_type = types.new_tuple(var_types.clone()); + let 
inout_vals = Expr::Tuple { vals : var_exprs, typ : inout_type }; + + let expr_type = expr.get_type(); + + let val = Expr::Tuple { vals : vec![expr, inout_vals], + typ : types.new_tuple(vec![expr_type, inout_type]) }; + + Stmt::ReturnStmt { expr : val } +} + +fn convert_primitive(prim : lang_y::Primitive) -> types::Primitive { + match prim { + lang_y::Primitive::Bool => types::Primitive::Bool, + lang_y::Primitive::I8 => types::Primitive::I8, + lang_y::Primitive::U8 => types::Primitive::U8, + lang_y::Primitive::I16 => types::Primitive::I16, + lang_y::Primitive::U16 => types::Primitive::U16, + lang_y::Primitive::I32 => types::Primitive::I32, + lang_y::Primitive::U32 => types::Primitive::U32, + lang_y::Primitive::I64 => types::Primitive::I64, + lang_y::Primitive::U64 => types::Primitive::U64, + lang_y::Primitive::USize=> types::Primitive::U64, + lang_y::Primitive::F32 => types::Primitive::F32, + lang_y::Primitive::F64 => types::Primitive::F64, + lang_y::Primitive::Void => types::Primitive::Unit, + } +} diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs new file mode 100644 index 0000000000000000000000000000000000000000..0492fe4dde4c378c9b6eff55bb863995b23d31ff --- /dev/null +++ b/juno_frontend/src/ssa.rs @@ -0,0 +1,171 @@ +/* The data structure and algorithm described in + * Braun, M., Buchwald, S., Hack, S., Leißa, R., Mallon, C., Zwinkau, A. (2013). Simple and + * Efficient Construction of Static Single Assignment Form. In: Jhala, R., De Bosschere, K. (eds) + * Compiler Construction. CC 2013. Lecture Notes in Computer Science, vol 7791. Springer, Berlin, + * Heidelberg. 
https://doi.org/10.1007/978-3-642-37051-9_6 + */ +extern crate hercules_ir; + +use std::collections::{HashMap, HashSet}; + +use self::hercules_ir::ir::*; +use self::hercules_ir::build::*; + +pub struct SSA { + // Map from variable (usize) to build (NodeID) to definition (NodeID) + current_def : HashMap<usize, HashMap<NodeID, NodeID>>, + sealed_blocks : HashSet<NodeID>, + incomplete_phis : HashMap<NodeID, HashMap<usize, NodeBuilder>>, + + function : FunctionID, + block_preds : HashMap<NodeID, Vec<NodeID>>, + unsealed_blocks : HashMap<NodeID, NodeBuilder>, +} + +impl SSA { + pub fn new(func : FunctionID, entry : NodeID) -> SSA { + SSA { current_def : HashMap::new(), + sealed_blocks : HashSet::from([entry]), + incomplete_phis : HashMap::new(), + function : func, + block_preds : HashMap::from([(entry, vec![])]), + unsealed_blocks : HashMap::new() } + } + + pub fn create_cond<'a>(&mut self, builder : &mut Builder<'a>, + pred : NodeID) -> (NodeBuilder, NodeID, NodeID) { + let if_builder = builder.allocate_node(self.function); + let mut left_builder = builder.allocate_node(self.function); + let mut right_builder = builder.allocate_node(self.function); + + let left_proj = left_builder.id(); + let right_proj = right_builder.id(); + + // True branch + let proj_left = builder.create_control_index(1); + left_builder.build_read(if_builder.id(), vec![proj_left].into()); + + // False branch + let proj_right = builder.create_control_index(0); + right_builder.build_read(if_builder.id(), vec![proj_right].into()); + + let _ = builder.add_node(left_builder); + let _ = builder.add_node(right_builder); + + self.sealed_blocks.insert(if_builder.id()); + self.block_preds.insert(if_builder.id(), vec![pred]); + + self.sealed_blocks.insert(left_proj); + self.block_preds.insert(left_proj, vec![if_builder.id()]); + + self.sealed_blocks.insert(right_proj); + self.block_preds.insert(right_proj, vec![if_builder.id()]); + + (if_builder, left_proj, right_proj) + } + + pub fn create_block<'a>(&mut 
self, builder : &mut Builder<'a>) -> NodeID { + let node_builder = builder.allocate_node(self.function); + let block = node_builder.id(); + self.unsealed_blocks.insert(block, node_builder); + self.block_preds.insert(block, vec![]); + self.incomplete_phis.insert(block, HashMap::new()); + block + } + + // Add "pred" as a predecessor of "block" + pub fn add_pred(&mut self, block : NodeID, pred : NodeID) { + assert!(self.unsealed_blocks.contains_key(&block), + "Block must be unsealed to add predecessors"); + self.block_preds.get_mut(&block) + .expect("Block was created") + .push(pred); + } + + pub fn seal_block<'a>(&mut self, block : NodeID, builder : &mut Builder<'a>) { + let mut block_builder = self.unsealed_blocks.remove(&block) + .expect("A block must be unsealed to seal it"); + + let preds = self.block_preds.get(&block) + .expect("A block must be created to seal it") + .clone(); + let mut phis = + match self.incomplete_phis.remove(&block) { + None => HashMap::new(), + Some(phis) => phis, + }; + + for (variable, phi) in phis.drain() { + self.add_phi_operands(variable, block, phi, builder); + } + + self.sealed_blocks.insert(block); + block_builder.build_region(preds.into()); + let _ = builder.add_node(block_builder); + } + + pub fn write_variable(&mut self, variable : usize, block : NodeID, value : NodeID) { + match self.current_def.get_mut(&variable) { + Some(m) => { + m.insert(block, value); + }, + None => { + self.current_def.insert(variable, HashMap::from([(block, value)])); + }, + } + } + + pub fn read_variable<'a>(&mut self, variable : usize, block : NodeID, + builder : &mut Builder<'a>) -> NodeID { + match self.current_def.get(&variable) { + Some(var) => { + match var.get(&block) { + Some(val) => *val, + None => self.read_variable_recursive(variable, block, builder), + } + }, + None => { + panic!("ERROR: Variable in read_variable never written") + }, + } + } + + fn read_variable_recursive<'a>(&mut self, variable : usize, block : NodeID, + builder : &mut 
Builder<'a>) -> NodeID { + let val = + if !self.sealed_blocks.contains(&block) { + let node = builder.allocate_node(self.function); + let node_id = node.id(); + self.incomplete_phis.get_mut(&block) + .expect("Unsealed block has been added") + .insert(variable, node); + node_id + } else if self.block_preds.get(&block) + .expect("Sealed block has preds").len() == 1 { + self.read_variable(variable, + self.block_preds.get(&block) + .expect("Sealed block has preds")[0], + builder) + } else { + let node = builder.allocate_node(self.function); + let node_id = node.id(); + self.write_variable(variable, block, node_id); + self.add_phi_operands(variable, block, node, builder); + node_id + }; + + self.write_variable(variable, block, val); + val + } + + fn add_phi_operands<'a>(&mut self, variable : usize, block : NodeID, + mut phi : NodeBuilder, builder : &mut Builder<'a>) { + let mut vals = vec![]; + let preds = self.block_preds.get(&block).expect("Block exists").clone(); + for pred in preds { + vals.push(self.read_variable(variable, pred, builder)); + } + phi.build_phi(block, vals.into()); + let _ = builder.add_node(phi); + } +} diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs new file mode 100644 index 0000000000000000000000000000000000000000..03702f360d7a3cbeb9f9aa4a178b8b74c163f221 --- /dev/null +++ b/juno_frontend/src/types.rs @@ -0,0 +1,763 @@ +use std::collections::{HashMap, HashSet, VecDeque}; + +use crate::hercules_ir::ir::*; +use crate::hercules_ir::build::*; + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum Either<A, B> { + Left(A), + Right(B) +} + +#[derive(Copy, Clone, Eq, Debug)] +pub enum DynamicConstant { + Constant(usize), // constant value + DynConst(usize, usize) // name and dynamic constant number +} + +impl PartialEq for DynamicConstant { + fn eq(&self, other : &Self) -> bool { + match (self, other) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => x == y, + (DynamicConstant::DynConst(_, n), 
DynamicConstant::DynConst(_, m)) => n == m, + (_, _) => false, + } + } +} + +impl DynamicConstant { + fn subst(&self, dyn_consts : &Vec<DynamicConstant>) -> DynamicConstant { + match self { + DynamicConstant::Constant(val) => DynamicConstant::Constant(*val), + DynamicConstant::DynConst(_, idx) => dyn_consts[*idx], + } + } + + fn to_string(&self, stringtab : &dyn Fn(usize) -> String) -> String { + match self { + DynamicConstant::Constant(val) => val.to_string(), + DynamicConstant::DynConst(nm, _) => stringtab(*nm), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum Primitive { Bool, U8, I8, U16, I16, U32, I32, U64, I64, F32, F64, Unit } + +impl Primitive { + fn is_number(&self) -> bool { + match &self { + Primitive::U8 | Primitive::I8 | Primitive::U16 + | Primitive::I16 | Primitive::U32 | Primitive::I32 + | Primitive::U64 | Primitive::I64 | Primitive::F32 + | Primitive::F64 => true, + _ => false, + } + } + + fn is_integer(&self) -> bool { + match &self { + Primitive::U8 | Primitive::I8 | Primitive::U16 + | Primitive::I16 | Primitive::U32 | Primitive::I32 + | Primitive::U64 | Primitive::I64 => true, + _ => false, + } + } + + fn is_float(&self) -> bool { + match &self { + Primitive::F32 | Primitive::F64 => true, + _ => false, + } + } + + fn to_string(&self) -> String { + match self { + Primitive::Bool => "bool".to_string(), + Primitive::I8 => "i8".to_string(), + Primitive::U8 => "u8".to_string(), + Primitive::I16 => "i16".to_string(), + Primitive::U16 => "u16".to_string(), + Primitive::I32 => "i32".to_string(), + Primitive::U32 => "u32".to_string(), + Primitive::I64 => "i64".to_string(), + Primitive::U64 => "u64".to_string(), + Primitive::F32 => "f32".to_string(), + Primitive::F64 => "f64".to_string(), + Primitive::Unit => "()".to_string(), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct Type { val : usize } + +// Type forms, which include both concrete types, as well as unsolved types that may have some +// constraints. 
Note that constrained types are just primitives (particularly the numeric types) +// because the type system ensures that we always know information about arrays, structs, unions, +// and tuples +#[derive(Clone, Debug)] +enum TypeForm { + Primitive(Primitive), + Tuple(Vec<Type>), + Array(Type, Vec<DynamicConstant>), + + // This type is the same type as another type + OtherType(Type), + + // For type variables, we record its name, its index (in the list of type variables in this + // context), and anything we know about it (is it a number, is it an integer) + TypeVar { name : usize, index : usize, is_num : bool, is_int : bool }, + + // For structs and unions we record the name (via its interned representation), a UID, and the + // types of its fields/constructors in a set order and a map from field/constructor names to + // its index in the list + Struct { name : usize, id : usize, fields : Vec<Type>, names : HashMap<usize, usize> }, + Union { name : usize, id : usize, constr : Vec<Type>, names : HashMap<usize, usize> }, + + // Constrained types + AnyNumber, AnyInteger, AnyFloat, +} + +#[derive(Debug)] +pub struct TypeSolver { + types : Vec<TypeForm>, +} + +#[derive(Debug)] +pub struct TypeSolverInst<'a> { + solver : &'a TypeSolver, + // A collection of current values for type variables, and variables that we've solved for in + // that context + type_vars : Vec<TypeID>, + solved : Vec<Option<TypeID>>, +} + +impl TypeSolver { + pub fn new() -> TypeSolver { + TypeSolver { types : vec![] } + } + + pub fn new_number(&mut self) -> Type { + self.create_type(TypeForm::AnyNumber) + } + + pub fn new_integer(&mut self) -> Type { + self.create_type(TypeForm::AnyInteger) + } + + pub fn new_float(&mut self) -> Type { + self.create_type(TypeForm::AnyFloat) + } + + pub fn new_primitive(&mut self, p : Primitive) -> Type { + self.create_type(TypeForm::Primitive(p)) + } + + pub fn new_tuple(&mut self, fields : Vec<Type>) -> Type { + self.create_type(TypeForm::Tuple(fields)) + } + + 
pub fn new_array(&mut self, element : Type, dims : Vec<DynamicConstant>) -> Type { + self.create_type(TypeForm::Array(element, dims)) + } + + pub fn new_type_var(&mut self, name : usize, index : usize, is_num : bool, + is_int : bool) -> Type { + self.create_type(TypeForm::TypeVar { name, index, is_num, is_int }) + } + + pub fn new_struct(&mut self, name : usize, id : usize, fields : Vec<Type>, + names : HashMap<usize, usize>) -> Type { + self.create_type(TypeForm::Struct { name, id, fields, names }) + } + + pub fn new_union(&mut self, name : usize, id : usize, constr : Vec<Type>, + names : HashMap<usize, usize>) -> Type { + self.create_type(TypeForm::Union { name, id, constr, names }) + } + + fn create_type(&mut self, typ : TypeForm) -> Type { + let idx = self.types.len(); + self.types.push(typ); + Type { val : idx } + } + + pub fn create_instance(&self, type_vars : Vec<TypeID>) -> TypeSolverInst { + let num_vars = self.types.len(); + TypeSolverInst { solver : self, + type_vars : type_vars, + solved : vec![None; num_vars] } + } + + pub fn is_u64(&mut self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Primitive(Primitive::U64) => true, + TypeForm::OtherType(t) => self.is_u64(*t), + TypeForm::AnyNumber | TypeForm::AnyInteger => { + self.types[val] = TypeForm::Primitive(Primitive::U64); + true + }, + _ => false, + } + } + + pub fn is_bool(&mut self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Primitive(Primitive::Bool) => true, + TypeForm::OtherType(t) => self.is_bool(*t), + _ => false, + } + } + + pub fn is_void(&mut self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Primitive(Primitive::Unit) => true, + TypeForm::OtherType(t) => self.is_void(*t), + _ => false, + } + } + + pub fn is_number(&mut self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Primitive(p) => p.is_number(), + TypeForm::OtherType(t) => self.is_number(*t), + TypeForm::AnyNumber | TypeForm::AnyInteger | 
TypeForm::AnyFloat => true, + TypeForm::TypeVar { name : _, index : _, is_num, .. } => *is_num, + _ => false, + } + } + + pub fn is_integer(&mut self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Primitive(p) => p.is_integer(), + TypeForm::OtherType(t) => self.is_integer(*t), + TypeForm::TypeVar { name : _, index : _, is_num : _, is_int } => *is_int, + TypeForm::AnyInteger => true, + TypeForm::AnyNumber => { + self.types[val] = TypeForm::AnyInteger; + true + }, + _ => false, + } + } + + pub fn is_float(&mut self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Primitive(p) => p.is_float(), + TypeForm::OtherType(t) => self.is_float(*t), + TypeForm::AnyFloat => true, + TypeForm::AnyNumber => { + self.types[val] = TypeForm::AnyFloat; + true + }, + _ => false, + } + } + + pub fn is_tuple(&self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Tuple(_) => true, + TypeForm::OtherType(t) => self.is_tuple(*t), + _ => false, + } + } + + pub fn get_num_fields(&self, Type { val } : Type) -> Option<usize> { + match &self.types[val] { + TypeForm::Tuple(fields) => { Some(fields.len()) }, + TypeForm::OtherType(t) => self.get_num_fields(*t), + _ => None, + } + } + + fn get_fields(&self, Type { val } : Type) -> Vec<Type> { + match &self.types[val] { + TypeForm::Tuple(fields) => { fields.clone() }, + TypeForm::OtherType(t) => self.get_fields(*t), + _ => panic!("Internal function get_fields used on non-tuple"), + } + } + + pub fn get_index(&self, Type { val } : Type, idx : usize) -> Option<Type> { + match &self.types[val] { + TypeForm::Tuple(fields) => fields.get(idx).copied(), + TypeForm::OtherType(t) => self.get_index(*t, idx), + _ => None, + } + } + + pub fn is_struct(&self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Struct { .. 
} => true, + TypeForm::OtherType(t) => self.is_struct(*t), + _ => false, + } + } + + // Return the number of fields a struct has + pub fn get_num_struct_fields(&self, Type { val } : Type) -> Option<usize> { + match &self.types[val] { + TypeForm::Struct { name : _, id : _, fields, .. } => Some(fields.len()), + TypeForm::OtherType(t) => self.get_num_struct_fields(*t), + _ => None, + } + } + + // Returns the position and type of a field in a type (if it exists) + pub fn get_field(&self, Type { val } : Type, name : usize) -> Option<(usize, Type)> { + match &self.types[val] { + TypeForm::Struct { name : _, id : _, fields, names } => { + names.get(&name).map(|idx| (*idx, fields[*idx])) + }, + TypeForm::OtherType(t) => self.get_field(*t, name), + _ => None, + } + } + + // Returns the type of the field at a certain index in a struct + pub fn get_struct_field_type(&self, Type { val } : Type, idx : usize) -> Option<Type> { + match &self.types[val] { + TypeForm::Struct { name : _, id : _, fields, .. 
} => + fields.get(idx).copied(), + TypeForm::OtherType(t) => self.get_struct_field_type(*t, idx), + _ => None, + } + } + + pub fn get_field_names(&self, Type { val } : Type) -> Option<Vec<usize>> { + match &self.types[val] { + TypeForm::Struct { name : _, id : _, fields : _, names } => { + Some(names.keys().map(|i| *i).collect::<Vec<_>>()) + }, + TypeForm::OtherType(t) => self.get_field_names(*t), + _ => None, + } + } + + pub fn is_array(&self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Array(_, _) => true, + TypeForm::OtherType(t) => self.is_array(*t), + _ => false, + } + } + + pub fn get_element_type(&self, Type { val } : Type) -> Option<Type> { + match &self.types[val] { + TypeForm::Array(elem, _) => Some(*elem), + TypeForm::OtherType(t) => self.get_element_type(*t), + _ => None, + } + } + + pub fn get_dimensions(&self, Type { val } : Type) -> Option<Vec<DynamicConstant>> { + match &self.types[val] { + TypeForm::Array(_, dims) => Some(dims.to_vec()), + TypeForm::OtherType(t) => self.get_dimensions(*t), + _ => None, + } + } + + pub fn get_num_dimensions(&self, Type { val } : Type) -> Option<usize> { + match &self.types[val] { + TypeForm::Array(_, dims) => Some(dims.len()), + TypeForm::OtherType(t) => self.get_num_dimensions(*t), + _ => None, + } + } + + pub fn is_union(&self, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::Union { .. 
} => true, + TypeForm::OtherType(t) => self.is_union(*t), + _ => false, + } + } + + pub fn get_constructor_list(&self, Type { val } : Type) -> Option<Vec<usize>> { + match &self.types[val] { + TypeForm::Union { name : _, id : _, constr : _, names } => { + Some(names.keys().map(|i| *i).collect::<Vec<_>>()) + }, + TypeForm::OtherType(t) => self.get_constructor_list(*t), + _ => None, + } + } + + pub fn get_constructor_info(&self, Type { val } : Type, name : usize) + -> Option<(usize, Type)> { + match &self.types[val] { + TypeForm::Union { name : _, id : _, constr, names} => { + names.get(&name).map(|idx| (*idx, constr[*idx])) + }, + TypeForm::OtherType(t) => self.get_constructor_info(*t, name), + _ => None, + } + } + + fn is_type_var_num(&self, num : usize, Type { val } : Type) -> bool { + match &self.types[val] { + TypeForm::TypeVar { name : _, index, .. } => *index == num, + TypeForm::OtherType(t) => self.is_type_var_num(num, *t), + _ => false, + } + } + + pub fn equal(&mut self, Type { val : ty1 } : Type, Type { val : ty2 } : Type) -> bool { + if let TypeForm::OtherType(ty) = self.types[ty1] { + return self.equal(ty, Type { val : ty2 }); + } + if let TypeForm::OtherType(ty) = self.types[ty2] { + return self.equal(Type { val : ty1 }, ty); + } + + match (self.types[ty1].clone(), self.types[ty2].clone()) { + (TypeForm::Primitive(p1), TypeForm::Primitive(p2)) => p1 == p2, + + (TypeForm::Primitive(p), TypeForm::AnyNumber) if p.is_number() => { + self.types[ty2] = TypeForm::OtherType(Type { val : ty1 }); + true + }, + (TypeForm::Primitive(p), TypeForm::AnyInteger) if p.is_integer() => { + self.types[ty2] = TypeForm::OtherType(Type { val : ty1 }); + true + }, + (TypeForm::Primitive(p), TypeForm::AnyFloat) if p.is_float() => { + self.types[ty2] = TypeForm::OtherType(Type { val : ty1 }); + true + }, + + (TypeForm::AnyNumber, TypeForm::Primitive(p)) if p.is_number() => { + self.types[ty1] = TypeForm::OtherType(Type { val : ty2 }); + true + }, + (TypeForm::AnyInteger, 
TypeForm::Primitive(p)) if p.is_number() => { + self.types[ty1] = TypeForm::OtherType(Type { val : ty2 }); + true + }, + (TypeForm::AnyFloat, TypeForm::Primitive(p)) if p.is_float() => { + self.types[ty1] = TypeForm::OtherType(Type { val : ty2 }); + true + }, + + (TypeForm::Tuple(f1), TypeForm::Tuple(f2)) if f1.len() == f2.len() => { + for (t1, t2) in f1.iter().zip(f2.iter()) { + if !self.equal(*t1, *t2) { return false; } + } + true + }, + + (TypeForm::Array(t1, dm1), TypeForm::Array(t2, dm2)) => + self.equal(t1, t2) && dm1 == dm2, + + (TypeForm::TypeVar { name : _, index : idx1, .. }, + TypeForm::TypeVar { name : _, index : idx2, .. }) => idx1 == idx2, + + (TypeForm::Struct { name : _, id : id1, fields : fs1, .. }, + TypeForm::Struct { name : _, id : id2, fields : fs2, .. }) + | (TypeForm::Union {name : _, id : id1, constr : fs1, .. }, + TypeForm::Union {name : _, id : id2, constr : fs2, .. }) + if id1 == id2 && fs1.len() == fs2.len() => { + for (t1, t2) in fs1.iter().zip(fs2.iter()) { + if !self.equal(*t1, *t2) { return false; } + } + true + }, + + (TypeForm::AnyNumber | TypeForm::AnyInteger | TypeForm::AnyFloat, + TypeForm::AnyNumber) + | (TypeForm::AnyInteger, TypeForm::AnyInteger) + | (TypeForm::AnyFloat, TypeForm::AnyFloat) => { + self.types[ty2] = TypeForm::OtherType(Type { val : ty1 }); + true + }, + (TypeForm::AnyNumber, TypeForm::AnyInteger) + | (TypeForm::AnyNumber, TypeForm::AnyFloat) => { + self.types[ty1] = TypeForm::OtherType(Type { val : ty2 }); + true + }, + + _ => false, + } + } + + pub fn to_string(&self, Type { val } : Type, stringtab : &dyn Fn(usize) -> String) + -> String { + match &self.types[val] { + TypeForm::Primitive(p) => p.to_string(), + TypeForm::Tuple(fields) => { + "(" .to_string() + + &fields.iter().map(|t| self.to_string(*t, stringtab)).collect::<Vec<_>>().join(", ") + + ")" + }, + TypeForm::Array(elem, dims) => { + self.to_string(*elem, stringtab) + + "[" + + &dims.iter().map(|d| d.to_string(stringtab)).collect::<Vec<_>>().join(", 
") + + "]" + }, + TypeForm::OtherType(typ) => { + self.to_string(*typ, stringtab) + }, + TypeForm::TypeVar { name, .. } | TypeForm::Struct { name, .. } + | TypeForm::Union { name, .. } => { + stringtab(*name) + }, + TypeForm::AnyNumber => "number".to_string(), + TypeForm::AnyInteger => "integer".to_string(), + TypeForm::AnyFloat => "float".to_string(), + } + } + + // Instantiate a type using the provided list of type variables and dynamic constants + // This is useful for instantiating the return type of a function and parametric types + pub fn instantiate(&mut self, Type { val } : Type, type_vars : &Vec<Type>, + dynamic_constants : &Vec<DynamicConstant>) -> Type { + match self.types[val].clone() { + TypeForm::Primitive(_) => Type { val }, + TypeForm::AnyNumber | TypeForm::AnyInteger | TypeForm::AnyFloat => { + self.create_type(self.types[val].clone()) + }, + TypeForm::OtherType(t) => + self.instantiate(t, type_vars, dynamic_constants), + TypeForm::Tuple(fields) => { + let mut types = vec![]; + let mut changed = false; + for typ in fields { + let inst = self.instantiate(typ, type_vars, dynamic_constants); + changed = changed || typ.val != inst.val; + types.push(inst); + } + if changed { self.create_type(TypeForm::Tuple(types)) } + else { Type { val } } + }, + TypeForm::Array(elem, dims) => { + let elem_typ = self.instantiate(elem, type_vars, dynamic_constants); + let subst_dims = + dims.iter().map(|c| c.subst(dynamic_constants)).collect::<Vec<_>>(); + + self.create_type(TypeForm::Array(elem_typ, subst_dims)) + }, + TypeForm::TypeVar { name : _, index, is_num, is_int } => { + let typ = type_vars[index]; + assert!(!is_num || self.is_number(typ)); + assert!(!is_int || self.is_integer(typ)); + typ + }, + TypeForm::Struct { name, id, fields, names } => { + let mut new_fields = vec![]; + let mut changed = false; + for typ in fields { + let inst = self.instantiate(typ, type_vars, dynamic_constants); + changed = changed || typ.val != inst.val; + new_fields.push(inst); + } 
+ + if changed { self.create_type(TypeForm::Struct { + name : name, id : id, + fields : new_fields, + names : names.clone() }) } + else { Type { val } } + }, + TypeForm::Union { name, id, constr, names } => { + let mut new_constr = vec![]; + let mut changed = false; + for typ in constr { + let inst = self.instantiate(typ, type_vars, dynamic_constants); + changed = changed || typ.val != inst.val; + new_constr.push(inst); + } + + if changed { self.create_type(TypeForm::Union { + name : name, id : id, + constr : new_constr, + names : names.clone() }) } + else { Type { val } } + }, + } + } +} + +impl TypeSolverInst<'_> { + pub fn lower_type(&mut self, builder : &mut Builder, Type { val } : Type) -> TypeID { + if self.solved[val].is_some() { + return self.solved[val].unwrap(); + } + + let mut worklist = VecDeque::from([val]); + let mut depends : HashMap<usize, HashSet<usize>> = HashMap::new(); + + while !worklist.is_empty() { + let typ = worklist.pop_front().unwrap(); + + // If this type is already solved, just continue. 
+ // Since we don't depend on something unless its unsolved we only need to drain the set + // of dependences once + if self.solved[typ].is_some() { continue; } + + let solution : Either<TypeID, usize> = + match &self.solver.types[typ] { + TypeForm::Primitive(p) => Either::Left(Self::build_primitive(builder, *p)), + TypeForm::Tuple(fields) => { + let mut needs = None; + let mut i_fields = vec![]; + + for Type { val } in fields { + match &self.solved[*val] { + Some(ty) => i_fields.push(*ty), + None => { needs = Some(*val); break; }, + } + } + + if let Some(t) = needs { + Either::Right(t) + } else { + Either::Left(Self::build_product(builder, i_fields)) + } + }, + TypeForm::Array(Type { val }, dims) => { + match &self.solved[*val] { + Some(ty) => + Either::Left(Self::build_array(builder, *ty, dims)), + None => Either::Right(*val), + } + }, + TypeForm::OtherType(Type { val }) => { + match &self.solved[*val] { + Some(ty) => Either::Left(*ty), + None => Either::Right(*val), + } + }, + TypeForm::TypeVar { name : _, index, .. } => { + Either::Left(self.type_vars[*index]) + }, + TypeForm::Struct { name : _, id : _, fields, .. } => { + let mut needs = None; + let mut i_fields = vec![]; + + for Type { val } in fields { + match &self.solved[*val] { + Some(ty) => i_fields.push(*ty), + None => { needs = Some(*val); break; }, + } + } + + if let Some(t) = needs { + Either::Right(t) + } else { + Either::Left(Self::build_product(builder, i_fields)) + } + }, + TypeForm::Union { name : _, id : _, constr, .. 
} => { + let mut needs = None; + let mut i_constr = vec![]; + + for Type { val } in constr { + match &self.solved[*val] { + Some(ty) => i_constr.push(*ty), + None => { needs = Some(*val); break; }, + } + } + + if let Some(t) = needs { + Either::Right(t) + } else { + Either::Left(Self::build_union(builder, i_constr)) + } + }, + TypeForm::AnyNumber | TypeForm::AnyInteger => + Either::Left(Self::build_primitive(builder, Primitive::I64)), + TypeForm::AnyFloat => + Either::Left(Self::build_primitive(builder, Primitive::F64)), + }; + + match solution { + Either::Left(solution) => { + self.solved[typ] = Some(solution); + match depends.get_mut(&typ) { + None => {}, + Some(set) => { + for idx in set.drain() { + worklist.push_back(idx); + } + }, + } + }, + Either::Right(needs) => { + match depends.get_mut(&needs) { + None => { + depends.insert(needs, HashSet::from([typ])); + }, + Some(set) => { + set.insert(typ); + }, + } + worklist.push_back(needs); + }, + } + } + + self.solved[val].expect("Failure to solve type constraints") + } + + pub fn as_numeric_type(&mut self, builder : &mut Builder, ty : Type) -> Primitive { + let type_id = self.lower_type(builder, ty); + if type_id == builder.create_type_i8() { Primitive::I8 } + else if type_id == builder.create_type_i16() { Primitive::I16 } + else if type_id == builder.create_type_i32() { Primitive::I32 } + else if type_id == builder.create_type_i64() { Primitive::I64 } + else if type_id == builder.create_type_u8() { Primitive::U8 } + else if type_id == builder.create_type_u16() { Primitive::U16 } + else if type_id == builder.create_type_u32() { Primitive::U32 } + else if type_id == builder.create_type_u64() { Primitive::U64 } + else if type_id == builder.create_type_f32() { Primitive::F32 } + else if type_id == builder.create_type_f64() { Primitive::F64 } + else { panic!("as_numeric_type() called on non-numeric type") } + } + + fn build_primitive(builder : &mut Builder, p : Primitive) -> TypeID { + match p { + Primitive::Bool => 
builder.create_type_bool(), + Primitive::I8 => builder.create_type_i8(), + Primitive::I16 => builder.create_type_i16(), + Primitive::I32 => builder.create_type_i32(), + Primitive::I64 => builder.create_type_i64(), + Primitive::U8 => builder.create_type_u8(), + Primitive::U16 => builder.create_type_u16(), + Primitive::U32 => builder.create_type_u32(), + Primitive::U64 => builder.create_type_u64(), + Primitive::F32 => builder.create_type_f32(), + Primitive::F64 => builder.create_type_f64(), + Primitive::Unit => builder.create_type_prod(vec![].into()), + } + } + + fn build_product(builder : &mut Builder, tys : Vec<TypeID>) -> TypeID { + builder.create_type_prod(tys.into()) + } + + fn build_union(builder : &mut Builder, tys : Vec<TypeID>) -> TypeID { + builder.create_type_sum(tys.into()) + } + + fn build_array(builder : &mut Builder, elem : TypeID, dims : &Vec<DynamicConstant>) -> TypeID { + let extents = Self::build_dyn_consts(builder, dims); + builder.create_type_array(elem, extents.into()) + } + + pub fn build_dyn_consts(builder : &mut Builder, vals : &Vec<DynamicConstant>) -> Vec<DynamicConstantID> { + let mut res = vec![]; + for val in vals { + res.push(Self::build_dyn_const(builder, val)); + } + res + } + + pub fn build_dyn_const(builder : &mut Builder, val : &DynamicConstant) -> DynamicConstantID { + match val { + DynamicConstant::Constant(val) => builder.create_dynamic_constant_constant(*val), + DynamicConstant::DynConst(_, num) => builder.create_dynamic_constant_parameter(*num), + } + } +}