diff --git a/Cargo.lock b/Cargo.lock index 9a0ebed2f7ee50e8cebbe0ff4d82c99f68f00c6d..2dc1fecb35270afe9fc1a62f0c90eece5b283b08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,20 +2,11 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "aho-corasick" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" -dependencies = [ - "memchr", -] - [[package]] name = "anstream" -version = "0.5.0" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f58811cfac344940f1a400b6e6231ce35171f614f26439e80f8c1465c5cc0c" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" dependencies = [ "anstyle", "anstyle-parse", @@ -27,33 +18,33 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.2" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15c4c2c83f81532e5845a733998b6971faca23490340a418e9b72a3ec9de12ea" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anstyle-parse" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" dependencies = [ "windows-sys", ] [[package]] name = "anstyle-wincon" -version = "2.1.0" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"58f54d10c6dfa51283a066ceab3ec1ab78d13fae00aa49243a45e4571fb79dfd" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" dependencies = [ "anstyle", "windows-sys", @@ -77,15 +68,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "cc" -version = "1.0.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -94,9 +76,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.2" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a13b88d2c62ff462f88e4a121f17a82c1af05693a2f192b5c38d14de73c19f6" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" dependencies = [ "clap_builder", "clap_derive", @@ -104,9 +86,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.2" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bb9faaa7c2ef94b2743a21f5a29e6f0010dff4caa69ac8e9d6cf8b6fa74da08" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" dependencies = [ "anstream", "anstyle", @@ -116,9 +98,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.4.2" +version = "4.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", @@ -128,9 +110,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" +checksum = 
"702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "colorchoice" @@ -138,12 +120,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - [[package]] name = "ena" version = "0.14.2" @@ -161,9 +137,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", @@ -177,89 +153,56 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] -name = "hercules_codegen" +name = "hercules_cg" version = "0.1.0" dependencies = [ "bitvec", "ena", "hercules_ir", - "inkwell", ] [[package]] -name = "hercules_ir" -version = "0.1.0" -dependencies = [ - "bitvec", - "nom", - "ordered-float", -] - -[[package]] -name = "hercules_opt" +name = "hercules_cpu" version = "0.1.0" dependencies = [ + "clap", + "hercules_cg", "hercules_ir", + "hercules_opt", + "rand", ] [[package]] -name = "hercules_tools" +name = "hercules_dot" version = "0.1.0" dependencies = [ "clap", - "hercules_codegen", "hercules_ir", "hercules_opt", "rand", ] [[package]] -name = "inkwell" -version = "0.2.0" -source = "git+https://github.com/TheDan64/inkwell?branch=master#7a09ad8a5f3b1fc416f95b5e1c97d33df0ab3f06" +name = "hercules_ir" +version = "0.1.0" dependencies = [ - "either", - "inkwell_internals", - "libc", - "llvm-sys", - "once_cell", - 
"thiserror", + "bitvec", + "nom", + "ordered-float", ] [[package]] -name = "inkwell_internals" -version = "0.8.0" -source = "git+https://github.com/TheDan64/inkwell?branch=master#7a09ad8a5f3b1fc416f95b5e1c97d33df0ab3f06" +name = "hercules_opt" +version = "0.1.0" dependencies = [ - "proc-macro2", - "quote", - "syn", + "hercules_ir", ] -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - [[package]] name = "libc" -version = "0.2.149" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" - -[[package]] -name = "llvm-sys" -version = "160.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf51981ac0622b10fe4790763e3de1f3d68a0ee4222e03accaaab6731bd508d" -dependencies = [ - "cc", - "lazy_static", - "libc", - "regex", - "semver", -] +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "log" @@ -269,9 +212,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.6.3" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "minimal-lexical" @@ -291,24 +234,18 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", ] -[[package]] -name = "once_cell" -version = "1.18.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" - [[package]] name = "ordered-float" -version = "3.9.1" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a54938017eacd63036332b4ae5c8a49fc8c0c1d6d629893057e4f13609edd06" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" dependencies = [ "num-traits", ] @@ -321,18 +258,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -373,41 +310,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "regex" -version = "1.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" - -[[package]] -name = "semver" -version = "1.0.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad977052201c6de01a8ef2aa3378c4bd23217a056337d1d6da40468d267a4fb0" - [[package]] name = "strsim" version = "0.10.0" @@ -416,9 +318,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "2.0.31" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "718fa2415bcb8d8bd775917a1bf12a7931b6dfa890753378538118181e0cb398" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -431,31 +333,11 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" -[[package]] -name = "thiserror" -version = "1.0.50" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.50" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "utf8parse" @@ -471,18 +353,18 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ "windows-targets", ] [[package]] name = "windows-targets" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -495,45 +377,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.48.5" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "wyz" diff --git a/Cargo.toml b/Cargo.toml index fc931a3cd3057a8ae8f64a939d51ba14b465caee..d6a8629e19126222a539896445cab51e9792716c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,9 @@ [workspace] members = [ - "hercules_codegen", + "hercules_cg", "hercules_ir", "hercules_opt", - "hercules_tools" + "hercules_tools/hercules_dot", + "hercules_tools/hercules_cpu" ] diff --git a/IR.md b/IR.md index 30b7ea89b0dbcbf4baf27fcea5b07bfb37c0eac8..6e08159abad301d146066abf94ebf5d7dd7f3e42 100644 --- a/IR.md +++ b/IR.md @@ -20,7 +20,7 @@ Region nodes are the mechanism for merging multiple branches inside Hercules IR. ## If -The branch mechanism in Hercules IR. An if node takes two inputs - a control predecessor, and a condition. The control predecessor must have control type, and the condition must have boolean type. The output type is a product of two control types, which are the same as the control input's type. Every if node must be followed directly by two [read\_prod](#readprod) nodes, each of which reads differing elements of the if node's output product. 
This is the mechanism by which the output edges from the if node (and also the [match](#match) node) are labelled, even though nodes only explicitly store their input edges. +The branch mechanism in Hercules IR. An if node takes two inputs - a control predecessor, and a condition. The control predecessor must have control type, and the condition must have boolean type. The output type is the same control type as the control input. Every if node must be followed directly by two [read](#read) nodes, each of which represents the opposite destinations of the branch. This is the mechanism by which the output edges from the if node (and also the [match](#match) node) are labelled, even though nodes only explicitly store their input edges. ## Fork @@ -70,30 +70,14 @@ The binary node represents a basic binary operation. A binary node takes two inp The call node passes its inputs to a function, and outputs the result of the function call. A call node takes some number of data inputs. A call node also stores a reference to the function it calls. The number and types of the data inputs must match the referenced function. A call node also stores references to dynamic constants it uses as inputs to the function. The number of dynamic constants references must match the number of dynamic constant inputs of the referenced function. The output type of a call node is the return type of the referenced function. A call node notably does not take as input or output a control type. This is because all operations in Hercules IR are pure, including arbitrary function calls. Thus, the only things affecting a function call are the data inputs, and (conceptually) the function may be called an arbitrary amount of times. -## ReadProd +## Read -The read\_prod node reads an element from an product typed value. A read\_prod node takes one data input. A read\_prod node also stores the index into the product it reads. The type of the data input must be a product type. 
The index must be a valid index into the product type. The output type of a read\_prod node is the type of the index-th element in the product (0-indexed). +The read node reads an element from a collection consisting of product, summation, and array types. It uses a series of indices, that index successive levels in a collection's type tree. There are 4 kinds of indices - fields (for products), variants (for summations), positions (for arrays), and controls (for succeeding [if](#if) and [match](#match) nodes). A read node takes one collect input, and potentially position inputs (there are node inputs only for positions). The output type of a read node is the indexed type in the collection's type tree. -## WriteProd +## Write -The write\_prod node modifies an input product with an input datum, and outputs the new product. A write\_prod node takes two inputs - one product input, and one data input. A write\_prod node also stores the index into the product it writes. The type of the product input must be a product type. The type of the data input must be the same as the index-th element in the product (0-indexed). The output type of a write\_prod node is the same as the product input type. - -## ReadArray - -The read\_array node reads an element from an array typed value. A read\_array node takes two inputs - one array input, and one index input. The type of the array input must be an array type. The type of the index input must be an integer type. The output type of a read\_array node is the element type of the array input's array type. At runtime, if an out-of-bounds array access occurs, the conductor will eventually notify the host. - -## WriteArray - -The write\_array node modifies an input array with an input datum. A write\_array node takes three inputs - one array input, one data input, and one index input. The type of the array input must be an array type. The type of the data input must be the same as the element type of the array input's array type. 
The type of the index input must be an integer type. The output type of a write\_array node is the same as the array input's array type. At runtime, if an out-of-bounds array access occurs, the conductor will eventually notify the host. +The write node writes an element into a collection consisting of product, summation, and array types. It uses a series of indices, that index successive levels in a collection's type tree. There are 4 kinds of indices - fields (for products), variants (for summations), and positions (for arrays). A write node takes one collect input, a data input, and potentially position inputs (there are node inputs only for positions). The output type of a write node is the same as the collection input. The indexed type in the collection must be the same type as the data input. ## Match -The match node branches based on the variant of a sum typed value. A match node takes two inputs - a control predecessor, and a sum input. The control predecessor must have control type, and the sum input must have a sum type. The output type is a product of N control types, where N is the number of possible variants in the sum input's sum type. The control types in the product are the same as the control input's type. Every match node must be followed directly by N [read\_prod](#readprod) nodes, each of which reads differing elements of the match node's output product. This is the mechanism by which the output edges from the match node (and also the [if](#if) node) are labelled, even though nodes only explicitly store their input edges. - -## BuildSum - -The build\_sum node creates a sum typed value from a datum. A build\_sum node takes one input - a data input. A build\_sum node additionally stores the sum type it builds, as well as which variant of the aforementioned sum type it builds. The stored variant must be a valid variant inside the stored sum type. The type of the data input must match the type of the variant of the sum type. 
The output type of a build\_sum node is the aforementioned sum type. - -## ExtractSum - -The extract\_sum node extracts the concrete value inside a sum value, given a particular variant to extract. An extract\_sum node takes one input - a data input. The data input must have a sum type. An extract\_sum node also stored the variant it extracts. The stored variant must be a valid variant of the data input's sum type. The output type of an extract\_sum node is the type of the specified variant of the data input's sum type. At runtime, if the input sum value holds the stored variant, the output of an extract\_sum node is the value inside that variant in the sum value. If the input sum value holds a different variant, the output of an extract\_sum node is defined as the bit-pattern of all zeros for the output type of the extract\_sum node. +The match node branches based on the variant of a sum typed value. A match node takes two inputs - a control predecessor, and a sum input. The control predecessor must have control type, and the sum input must have a sum type. The output type is a product of N control types, where N is the number of possible variants in the sum input's sum type. The control types in the product are the same as the control input's type. Every match node must be followed directly by N [read](#read) nodes, each of which reads differing elements of the match node's output product. This is the mechanism by which the output edges from the match node (and also the [if](#if) node) are labelled, even though nodes only explicitly store their input edges. 
diff --git a/hercules_codegen/Cargo.toml b/hercules_cg/Cargo.toml similarity index 52% rename from hercules_codegen/Cargo.toml rename to hercules_cg/Cargo.toml index dfc4f7d19adf3e403e2bef5878b73fe02e1a0cb8..8a4cf940b06a12f7131769107dc4b1b8cf8ab1c5 100644 --- a/hercules_codegen/Cargo.toml +++ b/hercules_cg/Cargo.toml @@ -1,10 +1,9 @@ [package] -name = "hercules_codegen" +name = "hercules_cg" version = "0.1.0" authors = ["Russel Arbore <rarbore2@illinois.edu>"] [dependencies] bitvec = "*" ena = "*" -inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm16-0-prefer-dynamic"] } hercules_ir = { path = "../hercules_ir" } diff --git a/hercules_cg/src/antideps.rs b/hercules_cg/src/antideps.rs new file mode 100644 index 0000000000000000000000000000000000000000..2ce6db83bef9687eda8e1b01cd16824e92562545 --- /dev/null +++ b/hercules_cg/src/antideps.rs @@ -0,0 +1,78 @@ +extern crate hercules_ir; + +use self::hercules_ir::def_use::*; +use self::hercules_ir::ir::*; + +/* + * Top level function to assemble anti-dependence edges. Returns a list of pairs + * of nodes. The first item in the pair is the read node, and the second item is + * the write node. + */ +pub fn antideps<I: Iterator<Item = NodeID>>( + function: &Function, + def_use: &ImmutableDefUseMap, + nodes: I, +) -> Vec<(NodeID, NodeID)> { + // Anti-dependence edges are between a write node and a read node, where + // each node uses the same array value. The read must be scheduled before + // the write to avoid incorrect compilation. + let mut antideps = vec![]; + + for id in nodes { + // Collect the reads and writes to / from this collection. 
+ let users = def_use.get_users(id); + let reads = users.iter().filter(|user| { + if let Node::Read { + collect, + indices: _, + } = function.nodes[user.idx()] + { + collect == id + } else { + false + } + }); + let mut writes = users.iter().filter(|user| { + if let Node::Write { + collect, + data: _, + indices: _, + } = function.nodes[user.idx()] + { + collect == id + } else { + false + } + }); + + // If there are any writes, compute the anti dependence edges. + if let Some(write) = writes.next() { + for read in reads { + antideps.push((*read, *write)); + } + } + + // TODO: Multiple write uses should clone the collection for N - 1 of the writes. + assert!(writes.next() == None, "Can't form anti-dependencies when there are two independent writes depending on a single collection value."); + } + + antideps +} + +/* + * Sometimes, we are only interested in anti-dependence edges involving arrays. + */ +pub fn array_antideps( + function: &Function, + def_use: &ImmutableDefUseMap, + types: &Vec<Type>, + typing: &Vec<TypeID>, +) -> Vec<(NodeID, NodeID)> { + antideps( + function, + def_use, + (0..function.nodes.len()) + .map(NodeID::new) + .filter(|id| types[typing[id.idx()].idx()].is_array()), + ) +} diff --git a/hercules_cg/src/cpu_beta.rs b/hercules_cg/src/cpu_beta.rs new file mode 100644 index 0000000000000000000000000000000000000000..bb73e1d348b31ccbe1b3e71fd1852364da3d08ce --- /dev/null +++ b/hercules_cg/src/cpu_beta.rs @@ -0,0 +1,755 @@ +extern crate bitvec; +extern crate hercules_ir; + +use std::collections::HashMap; +use std::collections::VecDeque; +use std::fmt::Write; +use std::iter::zip; + +use self::bitvec::prelude::*; + +use self::hercules_ir::def_use::*; +use self::hercules_ir::ir::*; + +/* + * This CPU backend is a rewrite of the original CPU alpha backend. It accounts + * for many changes in the IR designed to make lowering, optimization, and code + * generation easier. 
+ */ + +#[derive(Debug)] +struct LLVMBlock { + header: String, + phis: String, + data: String, + terminator: String, +} + +/* + * Top level function to generate code for a module. Emits LLVM IR text. + */ +pub fn cpu_beta_codegen<W: Write>( + module: &hercules_ir::ir::Module, + typing: &hercules_ir::typecheck::ModuleTyping, + reverse_postorders: &Vec<Vec<NodeID>>, + def_uses: &Vec<ImmutableDefUseMap>, + bbs: &Vec<Vec<NodeID>>, + antideps: &Vec<Vec<(NodeID, NodeID)>>, + fork_join_maps: &Vec<HashMap<NodeID, NodeID>>, + fork_join_nests: &Vec<HashMap<NodeID, Vec<NodeID>>>, + w: &mut W, +) -> std::fmt::Result { + let hercules_ir::ir::Module { + functions, + types, + constants, + dynamic_constants, + } = module; + + // Step 1: render types into LLVM IR. This requires translating from our + // interning structures to LLVM types. We can't just blow through the types + // vector, since a type may reference a type ID ahead of it in the vector. + // Instead, iterate types in a bottom up order with respect to the type + // intern DAGs. + let mut llvm_types = vec!["".to_string(); types.len()]; + for id in module.types_bottom_up() { + match &types[id.idx()] { + Type::Control(_) => { + // Later, we create virtual registers corresponding to fork + // nodes of type i64, so we need the "type" of the fork node + // to be i64. 
+ llvm_types[id.idx()] = "i64".to_string(); + } + Type::Boolean => { + llvm_types[id.idx()] = "i1".to_string(); + } + Type::Integer8 | Type::UnsignedInteger8 => { + llvm_types[id.idx()] = "i8".to_string(); + } + Type::Integer16 | Type::UnsignedInteger16 => { + llvm_types[id.idx()] = "i16".to_string(); + } + Type::Integer32 | Type::UnsignedInteger32 => { + llvm_types[id.idx()] = "i32".to_string(); + } + Type::Integer64 | Type::UnsignedInteger64 => { + llvm_types[id.idx()] = "i64".to_string(); + } + Type::Float32 => { + llvm_types[id.idx()] = "float".to_string(); + } + Type::Float64 => { + llvm_types[id.idx()] = "double".to_string(); + } + // Because we traverse in bottom-up order, we can assume that the + // LLVM types for children types are already computed. + Type::Product(fields) => { + let mut iter = fields.iter(); + if let Some(first) = iter.next() { + llvm_types[id.idx()] = + iter.fold("{".to_string() + &llvm_types[first.idx()], |s, f| { + s + ", " + &llvm_types[f.idx()] + }) + "}"; + } else { + llvm_types[id.idx()] = "{}".to_string(); + } + } + Type::Array(_, _) => { + // Array types becomes pointers. The element type and dynamic + // constant bounds characterize the access code we generate + // later, not the type itself. + llvm_types[id.idx()] = "ptr".to_string(); + } + Type::Summation(_) => todo!(), + } + } + + // Step 2: render constants into LLVM IR. This is done in a very similar + // manner as types. 
+ let mut llvm_constants = vec!["".to_string(); types.len()]; + for id in module.constants_bottom_up() { + match &constants[id.idx()] { + Constant::Boolean(val) => { + llvm_constants[id.idx()] = if *val { + "true".to_string() + } else { + "false".to_string() + }; + } + Constant::Integer8(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::Integer16(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::Integer32(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::Integer64(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::UnsignedInteger8(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::UnsignedInteger16(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::UnsignedInteger32(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::UnsignedInteger64(val) => llvm_constants[id.idx()] = format!("{}", val), + Constant::Float32(val) => { + llvm_constants[id.idx()] = if val.fract() == 0.0 { + format!("{}.0", val) + } else { + format!("{}", val) + } + } + Constant::Float64(val) => { + llvm_constants[id.idx()] = if val.fract() == 0.0 { + format!("{}.0", val) + } else { + format!("{}", val) + } + } + Constant::Product(_, fields) => { + let mut iter = fields.iter(); + if let Some(first) = iter.next() { + llvm_constants[id.idx()] = + iter.fold("{".to_string() + &llvm_constants[first.idx()], |s, f| { + s + ", " + &llvm_constants[f.idx()] + }) + "}"; + } else { + llvm_constants[id.idx()] = "{}".to_string(); + } + } + Constant::Array(_, _) => llvm_constants[id.idx()] = format!("%arr.{}", id.idx()), + Constant::Summation(_, _, _) => todo!(), + } + } + + // Step 3: render dynamic constants into LLVM IR. 
+ let mut llvm_dynamic_constants = vec!["".to_string(); dynamic_constants.len()]; + for id in (0..dynamic_constants.len()).map(DynamicConstantID::new) { + match &dynamic_constants[id.idx()] { + DynamicConstant::Constant(val) => llvm_dynamic_constants[id.idx()] = format!("{}", val), + DynamicConstant::Parameter(num) => { + llvm_dynamic_constants[id.idx()] = format!("%dc{}", num) + } + } + } + + // Step 4: do codegen for each function. + for function_idx in 0..functions.len() { + let function = &functions[function_idx]; + let typing = &typing[function_idx]; + let reverse_postorder = &reverse_postorders[function_idx]; + let def_use = &def_uses[function_idx]; + let bb = &bbs[function_idx]; + let antideps = &antideps[function_idx]; + let fork_join_map = &fork_join_maps[function_idx]; + let fork_join_nest = &fork_join_nests[function_idx]; + + // Step 4.1: emit function signature. + let llvm_ret_type = &llvm_types[function.return_type.idx()]; + let mut llvm_params = function + .param_types + .iter() + .enumerate() + .map(|(idx, id)| format!("{} %p{}", &llvm_types[id.idx()], idx)) + .chain((0..function.num_dynamic_constants).map(|idx| format!("i64 %dc{}", idx))) + .chain( + (0..constants.len()) + .filter(|idx| constants[*idx].is_array()) + .map(|idx| format!("%arr.{}", idx)), + ); + write!(w, "define {} @{}(", llvm_ret_type, function.name)?; + if let Some(first) = llvm_params.next() { + write!(w, "{}", first)?; + for p in llvm_params { + write!(w, ", {}", p)?; + } + } + write!(w, ") {{\n")?; + + // Step 4.2: emit basic blocks. A node represents a basic block if its + // entry in the basic blocks vector points to itself. Each basic block + // is created as four strings: the block header, the block's phis, the + // block's data computations, and the block's terminator instruction. 
+ let mut llvm_bbs = HashMap::new(); + for id in (0..function.nodes.len()).map(NodeID::new) { + if bb[id.idx()] == id { + llvm_bbs.insert( + id, + LLVMBlock { + header: format!("bb_{}:\n", id.idx()), + phis: "".to_string(), + data: "".to_string(), + terminator: "".to_string(), + }, + ); + } + } + + // Step 4.3: emit nodes. Nodes are emitted into basic blocks separately + // as nodes are not necessarily emitted in order. Assemble worklist of + // nodes, starting as reverse post order of nodes. For non-phi and non- + // reduce nodes, only emit once all data uses are emitted. In addition, + // consider additional anti-dependence edges from read to write nodes. + let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; + let mut worklist = VecDeque::from(reverse_postorder.clone()); + while let Some(id) = worklist.pop_front() { + if !(function.nodes[id.idx()].is_phi() || function.nodes[id.idx()].is_reduce()) + && !get_uses(&function.nodes[id.idx()]) + .as_ref() + .into_iter() + .chain( + antideps.iter().filter_map( + |(read, write)| if id == *write { Some(read) } else { None }, + ), + ) + .all(|x| function.nodes[x.idx()].is_control() || visited[x.idx()]) + { + // Skip emitting node if it's not a phi or reducee node and if + // its data uses are not emitted yet. + worklist.push_back(id); + } else { + // Once all of the data dependencies for this node are emitted, + // this node can be emitted. + emit_llvm_for_node( + id, + function, + typing, + types, + dynamic_constants, + bb, + def_use, + fork_join_map, + fork_join_nest, + &mut llvm_bbs, + &llvm_types, + &llvm_constants, + &llvm_dynamic_constants, + )?; + visited.set(id.idx(), true); + } + } + + // Step 4.4: put basic blocks in order. + for node in reverse_postorder { + if bb[node.idx()] == *node { + write!( + w, + "{}{}{}{}", + llvm_bbs[node].header, + llvm_bbs[node].phis, + llvm_bbs[node].data, + llvm_bbs[node].terminator + )?; + } + } + + // Step 4.5: close function. 
+ write!(w, "}}\n")?; + } + + Ok(()) +} + +/* + * Emit LLVM implementing a single node. + */ +fn emit_llvm_for_node( + id: NodeID, + function: &Function, + typing: &Vec<TypeID>, + types: &Vec<Type>, + dynamic_constants: &Vec<DynamicConstant>, + bb: &Vec<NodeID>, + def_use: &ImmutableDefUseMap, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + llvm_bbs: &mut HashMap<NodeID, LLVMBlock>, + llvm_types: &Vec<String>, + llvm_constants: &Vec<String>, + llvm_dynamic_constants: &Vec<String>, +) -> std::fmt::Result { + // Helper to get the virtual register corresponding to a node. Overload to + // also emit constants, dynamic constants, parameters, and thread IDs. + // Override writes to use previous non-store pointer, since emitting a store + // doesn't create a new pointer virtual register we should use. If the user + // (current node being emitted) is outside the fork-join nest of a reduce + // node, use the reduct input to the reduce node instead - this is needed to + // not get a value from the loop one iteration too early. + let virtual_register = |mut vid: NodeID| { + while let Node::Reduce { + control, + init: _, + reduct, + } = &function.nodes[vid.idx()] + { + // Figure out the fork corresponding to the associated join. 
+ let fork_id = if let Node::Join { control } = function.nodes[control.idx()] { + if let Type::Control(factors) = &types[typing[control.idx()].idx()] { + *factors.last().unwrap() + } else { + panic!() + } + } else { + panic!() + }; + + if !fork_join_nest[&bb[id.idx()]].contains(&fork_id) { + vid = *reduct; + } else { + break; + } + } + + while let Node::Write { + collect, + indices: _, + data: _, + } = &function.nodes[vid.idx()] + { + vid = *collect; + } + + match function.nodes[vid.idx()] { + Node::Constant { id } => llvm_constants[id.idx()].clone(), + Node::DynamicConstant { id } => llvm_dynamic_constants[id.idx()].clone(), + Node::Parameter { index } => format!("%p{}", index), + Node::ThreadID { control } => format!("%v{}", control.idx()), + _ => format!("%v{}", vid.idx()), + } + }; + let type_of = |id: NodeID| format!("{}", llvm_types[typing[id.idx()].idx()]); + let normal_value = |id: NodeID| format!("{} {}", type_of(id), virtual_register(id)); + + // Helper to emit code to index into an aggregate, and return a pointer to + // the indexed element. This only works when the first aggregate is an + // array. 
+ let mut generate_index_code = |collect: NodeID, indices: &[Index]| { + let extents = types[typing[collect.idx()].idx()].try_extents().unwrap(); + let position = indices[0].try_position().unwrap(); + for (idx, (extent, index_id)) in zip(extents, position).enumerate() { + if idx == 0 { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " %index.{}.acc.add.{} = add {}, {}\n", + id.idx(), + idx, + normal_value(*index_id), + 0, + ); + } else { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " %index.{}.acc.mul.{} = mul i64 {}, %index.{}.acc.add.{}\n", + id.idx(), + idx, + llvm_dynamic_constants[extent.idx()], + id.idx(), + idx - 1, + ); + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " %index.{}.acc.add.{} = add {}, %index.{}.acc.mul.{}\n", + id.idx(), + idx, + normal_value(*index_id), + id.idx(), + idx, + ); + } + } + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " %index.{}.ptr = getelementptr {}, {}, i64 %index.{}.acc.add.{}", + id.idx(), + llvm_types[types[typing[collect.idx()].idx()] + .try_element_type() + .unwrap() + .idx()], + normal_value(collect), + id.idx(), + extents.len() - 1 + ); + for index in &indices[1..] { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += + &format!(", i32 {}", index.try_field().unwrap()); + } + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += "\n"; + }; + + // Emit code depending on the type of the node. 
+ match function.nodes[id.idx()] { + Node::Start | Node::Region { preds: _ } => { + let successor = def_use + .get_users(id) + .iter() + .filter(|id| function.nodes[id.idx()].is_strictly_control()) + .next() + .unwrap(); + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().terminator = + format!(" br label %bb_{}\n", successor.idx()); + } + Node::If { control: _, cond } => { + let successors = def_use.get_users(id); + let rev = if let Node::Read { + collect: _, + indices, + } = &function.nodes[successors[0].idx()] + { + indices[0] != Index::Control(0) + } else { + panic!() + }; + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().terminator = format!( + " br {}, label %bb_{}, label %bb_{}\n", + normal_value(cond), + successors[(!rev) as usize].idx(), + successors[rev as usize].idx() + ); + } + Node::Fork { control, factor: _ } => { + // Calculate the join and successor. + let join = fork_join_map[&id]; + let successor = def_use + .get_users(id) + .iter() + .filter(|id| function.nodes[id.idx()].is_strictly_control()) + .next() + .unwrap(); + + // Need to create phi node for the loop index. + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().phis += &format!( + " {} = phi i64 [ 0, %bb_{} ], [ %fork.{}.inc, %bb_{} ]\n", + virtual_register(id), + control.idx(), + id.idx(), + join.idx(), + ); + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " %fork.{}.inc = add i64 1, {}\n", + id.idx(), + virtual_register(id), + ); + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().terminator = + format!(" br label %bb_{}\n", successor.idx()); + } + Node::Join { control } => { + // Get the fork, its factor, and the successor to this join. 
+ let fork_id = if let Type::Control(factors) = &types[typing[control.idx()].idx()] { + *factors.last().unwrap() + } else { + panic!() + }; + let factor = if let Node::Fork { control: _, factor } = &function.nodes[fork_id.idx()] { + *factor + } else { + panic!() + }; + let successor = def_use + .get_users(id) + .iter() + .filter(|id| function.nodes[id.idx()].is_strictly_control()) + .next() + .unwrap(); + + // Form the bottom of the loop. We need to branch between the + // successor and the fork. + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " %join.{}.cond = icmp ult i64 %fork.{}.inc, {}\n", + id.idx(), + fork_id.idx(), + llvm_dynamic_constants[factor.idx()], + ); + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().terminator = format!( + " br i1 %join.{}.cond, label %bb_{}, label %bb_{}\n", + id.idx(), + fork_id.idx(), + successor.idx(), + ); + } + Node::Phi { + control: _, + ref data, + } => { + let pred_ids = get_uses(&function.nodes[bb[id.idx()].idx()]); + let mut iter = zip(data.iter(), pred_ids.as_ref().iter()); + let (first_data, first_control) = iter.next().unwrap(); + let mut phi = format!( + " {} = phi {} [ {}, %bb_{} ]", + virtual_register(id), + type_of(id), + virtual_register(*first_data), + bb[first_control.idx()].idx() + ); + for (data, control) in iter { + phi += &format!( + ", [ {}, %bb_{} ]", + virtual_register(*data), + bb[control.idx()].idx() + ); + } + phi += "\n"; + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().phis += φ + } + // No code needs to get emitted for thread ID nodes - the loop index is + // emitted in the fork. + Node::ThreadID { control: _ } => {} + Node::Reduce { + control, + init, + reduct, + } => { + // Figure out the fork corresponding to the associated join. 
+ let fork_id = if let Node::Join { control } = function.nodes[control.idx()] { + if let Type::Control(factors) = &types[typing[control.idx()].idx()] { + *factors.last().unwrap() + } else { + panic!() + } + } else { + panic!() + }; + + // Figure out the fork's predecessor. + let pred = if let Node::Fork { control, factor: _ } = function.nodes[fork_id.idx()] { + control + } else { + panic!() + }; + + // Create the phi node for the reduction. + llvm_bbs.get_mut(&bb[fork_id.idx()]).unwrap().phis += &format!( + " {} = phi {} [ {}, %bb_{} ], [ {}, %bb_{} ]\n", + virtual_register(id), + type_of(id), + virtual_register(init), + pred.idx(), + virtual_register(reduct), + control.idx(), + ); + } + Node::Return { control: _, data } => { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().terminator = + format!(" ret {}\n", normal_value(data)); + } + // No code needs to get emitted for parameters, constants, or dynamic + // constants - these are just specific virtual registers or constant + // values. + Node::Parameter { index: _ } => {} + Node::Constant { id: _ } => {} + Node::DynamicConstant { id: _ } => {} + Node::Unary { input, op } => match op { + UnaryOperator::Not => todo!(), + UnaryOperator::Neg => { + if types[typing[input.idx()].idx()].is_float() { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " {} = fneg {}\n", + virtual_register(id), + normal_value(input) + ); + } else { + todo!() + } + } + }, + Node::Binary { left, right, op } => { + let opcode = match op { + BinaryOperator::Add => { + if types[typing[left.idx()].idx()].is_float() { + "fadd" + } else { + "add" + } + } + BinaryOperator::Sub => { + if types[typing[left.idx()].idx()].is_float() { + "fsub" + } else { + "sub" + } + } + BinaryOperator::Mul => { + if types[typing[left.idx()].idx()].is_float() { + "fmul" + } else { + "mul" + } + } + BinaryOperator::Div => { + if types[typing[left.idx()].idx()].is_float() { + "fdiv" + } else if types[typing[left.idx()].idx()].is_unsigned() { + "udiv" + } else { + 
"sdiv"
+                    }
+                }
+                BinaryOperator::Rem => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "frem"
+                    } else if types[typing[left.idx()].idx()].is_unsigned() {
+                        "urem"
+                    } else {
+                        "srem"
+                    }
+                }
+                BinaryOperator::LT => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "fcmp olt"
+                    } else if types[typing[left.idx()].idx()].is_unsigned() {
+                        "icmp ult"
+                    } else {
+                        "icmp slt"
+                    }
+                }
+                BinaryOperator::LTE => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "fcmp ole"
+                    } else if types[typing[left.idx()].idx()].is_unsigned() {
+                        "icmp ule"
+                    } else {
+                        "icmp sle"
+                    }
+                }
+                BinaryOperator::GT => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "fcmp ogt"
+                    } else if types[typing[left.idx()].idx()].is_unsigned() {
+                        "icmp ugt"
+                    } else {
+                        "icmp sgt"
+                    }
+                }
+                BinaryOperator::GTE => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "fcmp oge"
+                    } else if types[typing[left.idx()].idx()].is_unsigned() {
+                        "icmp uge"
+                    } else {
+                        "icmp sge"
+                    }
+                }
+                BinaryOperator::EQ => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "fcmp oeq"
+                    } else {
+                        "icmp eq"
+                    }
+                }
+                BinaryOperator::NE => {
+                    if types[typing[left.idx()].idx()].is_float() {
+                        "fcmp one"
+                    } else {
+                        "icmp ne"
+                    }
+                }
+                BinaryOperator::Or => "or",
+                BinaryOperator::And => "and",
+                BinaryOperator::Xor => "xor",
+                BinaryOperator::LSh => "shl",
+                BinaryOperator::RSh => {
+                    if types[typing[left.idx()].idx()].is_unsigned() {
+                        "lshr"
+                    } else {
+                        "ashr"
+                    }
+                }
+            };
+            llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!(
+                "  {} = {} {}, {}\n",
+                virtual_register(id),
+                opcode,
+                normal_value(left),
+                virtual_register(right),
+            );
+        }
+        Node::Read {
+            collect,
+            ref indices,
+        } => {
+            // Read nodes may be projection successors of if or match nodes.
+ if function.nodes[collect.idx()].is_strictly_control() { + let successor = def_use.get_users(id)[0]; + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().terminator = + format!(" br label %bb_{}\n", successor.idx()); + } else { + generate_index_code(collect, indices); + if types[typing[collect.idx()].idx()].is_array() { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " {} = load {}, ptr %index.{}.ptr\n", + virtual_register(id), + llvm_types[typing[id.idx()].idx()], + id.idx() + ); + } else { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " {} = extractvalue {}", + virtual_register(id), + normal_value(collect) + ); + for index in indices.iter() { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += + &format!(", {}", index.try_field().unwrap()); + } + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += "\n"; + } + } + } + Node::Write { + collect, + ref indices, + data, + } => { + generate_index_code(collect, indices); + if types[typing[collect.idx()].idx()].is_array() { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " store {}, ptr %index.{}.ptr\n", + normal_value(data), + id.idx() + ); + } else { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " {} = insertvalue {}, {}", + virtual_register(id), + normal_value(collect), + normal_value(data) + ); + for index in indices.iter() { + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += + &format!(", {}", index.try_field().unwrap()); + } + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += "\n"; + } + } + _ => todo!(), + } + + Ok(()) +} diff --git a/hercules_codegen/src/gcm.rs b/hercules_cg/src/gcm.rs similarity index 98% rename from hercules_codegen/src/gcm.rs rename to hercules_cg/src/gcm.rs index a27f15b08bff875c51dfaafa6e6640ca20dfb820..baaeee330c41d1098d4c1546ddb60732d1ee6dbc 100644 --- a/hercules_codegen/src/gcm.rs +++ b/hercules_cg/src/gcm.rs @@ -103,7 +103,7 @@ pub fn compute_fork_join_nesting( // the corresponding join node. 
(0..function.nodes.len()) .map(NodeID::new) - .filter(|id| function.is_control(*id)) + .filter(|id| function.nodes[id.idx()].is_control()) .map(|id| { ( id, diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..dbe405317d8bf07a17448ff9fc257c1df6f3c279 --- /dev/null +++ b/hercules_cg/src/lib.rs @@ -0,0 +1,7 @@ +pub mod antideps; +pub mod cpu_beta; +pub mod gcm; + +pub use crate::antideps::*; +pub use crate::cpu_beta::*; +pub use crate::gcm::*; diff --git a/hercules_codegen/src/antideps.rs b/hercules_codegen/src/antideps.rs deleted file mode 100644 index ff748ce8b93d2294f6aa55dda5bc4281f17a8fa5..0000000000000000000000000000000000000000 --- a/hercules_codegen/src/antideps.rs +++ /dev/null @@ -1,52 +0,0 @@ -extern crate hercules_ir; - -use self::hercules_ir::def_use::*; -use self::hercules_ir::ir::*; - -/* - * Top level function to assemble anti-dependence edges. Returns a list of pairs - * of nodes. The first item in the pair is the read node, and the second item is - * the write node. - */ -pub fn antideps(function: &Function, def_use: &ImmutableDefUseMap) -> Vec<(NodeID, NodeID)> { - // Array typed values are not directly computed on. Thus, there are actually - // very few nodes that have array inputs or output. As a result, when - // forming anti-dependencies for a single allocation, we only need to - // consider immediate users that are read or write nodes - no proper - // dataflow analysis necessary. - let mut antideps = vec![]; - - for id in (0..function.nodes.len()).map(NodeID::new) { - // We only need to consider array reads and writes. 
- let users = def_use.get_users(id); - let reads = users.iter().filter(|user| { - if let Node::ReadArray { array, index: _ } = function.nodes[user.idx()] { - array == id - } else { - false - } - }); - let mut writes = users.iter().filter(|user| { - if let Node::WriteArray { - array, - index: _, - data: _, - } = function.nodes[user.idx()] - { - array == id - } else { - false - } - }); - - // If there are any writes, compute the anti dependence edges. - if let Some(write) = writes.next() { - for read in reads { - antideps.push((*read, *write)); - } - } - assert!(writes.next() == None, "Can't form anti-dependencies when there are two independent writes depending on a single array value."); - } - - antideps -} diff --git a/hercules_codegen/src/array_alloc.rs b/hercules_codegen/src/array_alloc.rs deleted file mode 100644 index 3ce9fcc13ed242fc0f2d4503340b68561745fe0d..0000000000000000000000000000000000000000 --- a/hercules_codegen/src/array_alloc.rs +++ /dev/null @@ -1,214 +0,0 @@ -extern crate ena; -extern crate hercules_ir; - -use std::collections::HashMap; - -use self::ena::unify::*; - -use self::hercules_ir::def_use::*; -use self::hercules_ir::ir::*; - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -struct UnitKey(u32); - -impl UnifyKey for UnitKey { - type Value = (); - - fn index(&self) -> u32 { - self.0 - } - - fn from_index(u: u32) -> UnitKey { - UnitKey(u) - } - - fn tag() -> &'static str { - "UnitKey" - } -} - -/* - * Top level function to allocate individual arrays for sets of nodes in the IR - * graph. Doesn't attempt to overlap allocations w/ liveness analysis. Returns - * a set of nodes per allocation, which are the nodes that use that allocation, - * along with dimensions specified with dynamic constants. The return value - * consists of a list of array allocations (described by their dynamic constant - * size), and a map from nodes involving array values to the number of the array - * allocation they operate on. 
- */ -pub fn logical_array_alloc( - function: &Function, - typing: &Vec<TypeID>, - types: &Vec<Type>, - fork_join_map: &HashMap<NodeID, NodeID>, - bbs: &Vec<NodeID>, - fork_join_nests: &HashMap<NodeID, Vec<NodeID>>, -) -> (Vec<Vec<DynamicConstantID>>, HashMap<NodeID, usize>) { - // Step 1: filter nodes that operate on arrays, either on their input or - // their output. - let id_outputs_array = |id: &NodeID| types[typing[id.idx()].idx()].is_array(); - let array_nodes: Vec<_> = (0..function.nodes.len()) - .map(NodeID::new) - .filter(|id| { - id_outputs_array(id) - || get_uses(&function.nodes[id.idx()]) - .as_ref() - .into_iter() - .any(id_outputs_array) - }) - .collect(); - let array_node_numbers: HashMap<_, _> = - std::iter::zip(array_nodes.iter().map(|x| *x), 0..array_nodes.len()).collect(); - - // Step 2: union find the nodes based on use edges. Every node in each set - // should use the same array allocation. The representative node for a set - // will be the node with the smallest ID. - let mut allocs: UnificationTable<InPlace<UnitKey>> = UnificationTable::new(); - let keys: Vec<_> = (0..array_nodes.len()).map(|_| allocs.new_key(())).collect(); - for node in array_nodes.iter() { - for array_use in get_uses(&function.nodes[node.idx()]) - .as_ref() - .into_iter() - .map(|x| *x) - .filter(id_outputs_array) - { - allocs.union( - keys[array_node_numbers[&node]], - keys[array_node_numbers[&array_use]], - ); - } - } - - // Step 3: determine the size of each array allocation. This is the size of - // the array type operated on, possibly in addition to dynamic constant - // factors corresponding to uses inside fork / joins. Each node that can - // operate on array values, and their corresponding affect on the array - // value's size, are listed below. 
- // - // Phi: only provides the base array dimensions - // Collect: provides the base array dimensions, in addition to dimensions - // corresponding to dominating fork / join nests - // Return: provides no array dimensions - // Parameter: only provides the base array dimensions - // Constant: only provides the base array dimensions - // Call: TODO - // ReadArray: only provides the base array dimensions - // WriteArray: provides the base array dimensions, in addition to dimensions - // corresponding to each fork / join the node is nested in - let mut key_to_value_size: HashMap<UnitKey, Vec<DynamicConstantID>> = HashMap::new(); - for key in keys.iter() { - let value_key = allocs.find(*key); - let id = array_nodes[key.index() as usize]; - - let extents = match function.nodes[id.idx()] { - Node::Phi { - control: _, - data: _, - } - | Node::Parameter { index: _ } - | Node::Constant { id: _ } - | Node::ReadArray { array: _, index: _ } => { - // For nodes that don't write to the array, the required size - // is just the underlying size of the array. - type_extents(typing[id.idx()], types) - } - Node::Collect { - control: _, - data: _, - } - | Node::WriteArray { - array: _, - data: _, - index: _, - } => { - // For nodes that write to the array, the required size depends - // on the surrounding fork / join pairs. - write_dimensionality( - function, - id, - typing, - types, - fork_join_map, - bbs, - fork_join_nests, - ) - } - Node::Return { - control: _, - data: _, - } => { - continue; - } - _ => todo!(), - }; - - // The largest required size is the correct size. It is assumed that all - // sizes calculated above form a total order with respect to suffix - // vectoring. 
- if let Some(old_extents) = key_to_value_size.get(&value_key) { - assert!( - std::iter::zip(old_extents.iter().rev(), extents.iter().rev()).all(|(a, b)| a == b) - ); - if old_extents.len() < extents.len() { - key_to_value_size.insert(value_key, extents); - } - } else { - key_to_value_size.insert(value_key, extents); - } - } - - // Step 4: collect array allocations as a numbered list. Map from array - // nodes to the array allocation number they use. - let mut logical_allocations: Vec<Vec<DynamicConstantID>> = vec![]; - let mut key_to_number: HashMap<UnitKey, usize> = HashMap::new(); - for (key, array_value) in key_to_value_size { - key_to_number.insert(key, logical_allocations.len()); - logical_allocations.push(array_value); - } - let mut node_to_logical_numbers: HashMap<NodeID, usize> = HashMap::new(); - for node in array_nodes.iter() { - node_to_logical_numbers.insert( - *node, - key_to_number[&allocs.find(keys[array_node_numbers[node]])], - ); - } - - (logical_allocations, node_to_logical_numbers) -} - -/* - * Get the dimensionality of a write. Checks for the dimensions of the array - * type, and adds dimensions corresponding to dominating fork / join nests. - */ -pub fn write_dimensionality( - function: &Function, - write: NodeID, - typing: &Vec<TypeID>, - types: &Vec<Type>, - fork_join_map: &HashMap<NodeID, NodeID>, - bbs: &Vec<NodeID>, - fork_join_nests: &HashMap<NodeID, Vec<NodeID>>, -) -> Vec<DynamicConstantID> { - let mut extents = type_extents(typing[write.idx()], types); - assert!( - extents.len() > 0, - "Can't call write_dimensionality with a node that doesn't output an array." - ); - extents.reverse(); - - for fork in fork_join_nests[&bbs[write.idx()]].iter() { - if let Node::Fork { control: _, factor } = function.nodes[fork.idx()] { - // If this node is a collect, we don't need to add the dimension - // from the corresponding fork. 
- if function.nodes[write.idx()].is_collect() && fork_join_map[&fork] != bbs[write.idx()] - { - extents.push(factor); - } - } else { - panic!("Fork join nests map contains a non-fork in the value list."); - } - } - - extents.reverse(); - extents -} diff --git a/hercules_codegen/src/cpu_alpha.rs b/hercules_codegen/src/cpu_alpha.rs deleted file mode 100644 index d08ed6d23f8fe082220048e1d7aa965608e53663..0000000000000000000000000000000000000000 --- a/hercules_codegen/src/cpu_alpha.rs +++ /dev/null @@ -1,1163 +0,0 @@ -extern crate bitvec; -extern crate hercules_ir; -extern crate inkwell; - -use std::collections::HashMap; -use std::collections::VecDeque; -use std::convert::TryFrom; -use std::iter::repeat; -use std::iter::zip; - -use self::inkwell::basic_block::*; -use self::inkwell::builder::*; -use self::inkwell::context::*; -use self::inkwell::types::*; -use self::inkwell::values::*; -use self::inkwell::*; - -use self::hercules_ir::def_use::*; -use self::hercules_ir::ir::*; - -/* - * This CPU backend was written to get some Hercules IR running, and to better - * understand how writing backends for Hercules IR will work. This backend is - * not meant to be used in the long term. If you are reading this code in a - * significant amount of time from when this comment was written, you are - * probably already doing something wrong - Russel. - */ - -/* - * Top level function to generate code for a module. Writes the result object - * file to the specified path. 
- */ -pub fn cpu_alpha_codegen( - module: &hercules_ir::ir::Module, - typing: &hercules_ir::typecheck::ModuleTyping, - reverse_postorders: &Vec<Vec<NodeID>>, - def_uses: &Vec<ImmutableDefUseMap>, - bbs: &Vec<Vec<NodeID>>, - antideps: &Vec<Vec<(NodeID, NodeID)>>, - array_allocations: &Vec<(Vec<Vec<DynamicConstantID>>, HashMap<NodeID, usize>)>, - fork_join_nests: &Vec<HashMap<NodeID, Vec<NodeID>>>, - path: &std::path::Path, -) { - let hercules_ir::ir::Module { - functions, - types, - constants, - dynamic_constants, - } = module; - - // Step 1: initialize LLVM objects. - let llvm_context = Context::create(); - let llvm_module = llvm_context.create_module(""); - let llvm_builder = llvm_context.create_builder(); - - // Step 2: convert the types. This requires translating from our interning - // structures to LLVM's. We can't just blow through the types vector, since - // a type may reference a type ID ahead of it in the vector. Instead, - // iterate types in a bottom up order with respect to the type intern DAGs. 
- let mut llvm_types = vec![llvm_context.i8_type().as_basic_type_enum(); types.len()]; - for id in module.types_bottom_up() { - match &types[id.idx()] { - Type::Control(_) => {} - Type::Boolean => { - llvm_types[id.idx()] = llvm_context.bool_type().as_basic_type_enum(); - } - Type::Integer8 | Type::UnsignedInteger8 => { - llvm_types[id.idx()] = llvm_context.i8_type().as_basic_type_enum(); - } - Type::Integer16 | Type::UnsignedInteger16 => { - llvm_types[id.idx()] = llvm_context.i16_type().as_basic_type_enum(); - } - Type::Integer32 | Type::UnsignedInteger32 => { - llvm_types[id.idx()] = llvm_context.i32_type().as_basic_type_enum(); - } - Type::Integer64 | Type::UnsignedInteger64 => { - llvm_types[id.idx()] = llvm_context.i64_type().as_basic_type_enum(); - } - Type::Float32 => { - llvm_types[id.idx()] = llvm_context.f32_type().as_basic_type_enum(); - } - Type::Float64 => { - llvm_types[id.idx()] = llvm_context.f64_type().as_basic_type_enum(); - } - // Because we traverse in bottom-up order, we can assume that the - // LLVM types for children types are already computed. - Type::Product(fields) => { - let field_types = fields - .iter() - .map(|id| llvm_types[id.idx()]) - .collect::<Box<[_]>>(); - llvm_types[id.idx()] = llvm_context - .struct_type(&field_types, false) - .as_basic_type_enum(); - } - Type::Array(elem, _) => { - // Array types need to be flattened - an array of an array in - // Hercules IR needs to translate to a single pointer in LLVM. - if let Type::Array(_, _) = types[elem.idx()] { - llvm_types[id.idx()] = llvm_types[elem.idx()]; - } else { - let elem_type = llvm_types[elem.idx()]; - llvm_types[id.idx()] = elem_type - .ptr_type(AddressSpace::default()) - .as_basic_type_enum(); - } - } - Type::Summation(_) => todo!(), - } - } - - // Step 3: convert the constants. This is done in a very similar manner as - // types. 
- let mut llvm_constants = vec![ - llvm_context - .i8_type() - .const_int(0, false) - .as_basic_value_enum(); - constants.len() - ]; - for id in module.constants_bottom_up() { - match &constants[id.idx()] { - Constant::Boolean(val) => { - llvm_constants[id.idx()] = llvm_context - .bool_type() - .const_int(*val as u64, false) - .as_basic_value_enum(); - } - Constant::Integer8(val) => { - llvm_constants[id.idx()] = llvm_context - .i8_type() - .const_int(*val as u64, true) - .as_basic_value_enum(); - } - Constant::Integer16(val) => { - llvm_constants[id.idx()] = llvm_context - .i16_type() - .const_int(*val as u64, true) - .as_basic_value_enum(); - } - Constant::Integer32(val) => { - llvm_constants[id.idx()] = llvm_context - .i32_type() - .const_int(*val as u64, true) - .as_basic_value_enum(); - } - Constant::Integer64(val) => { - llvm_constants[id.idx()] = llvm_context - .i64_type() - .const_int(*val as u64, true) - .as_basic_value_enum(); - } - Constant::UnsignedInteger8(val) => { - llvm_constants[id.idx()] = llvm_context - .i8_type() - .const_int(*val as u64, false) - .as_basic_value_enum(); - } - Constant::UnsignedInteger16(val) => { - llvm_constants[id.idx()] = llvm_context - .i16_type() - .const_int(*val as u64, false) - .as_basic_value_enum(); - } - Constant::UnsignedInteger32(val) => { - llvm_constants[id.idx()] = llvm_context - .i32_type() - .const_int(*val as u64, false) - .as_basic_value_enum(); - } - Constant::UnsignedInteger64(val) => { - llvm_constants[id.idx()] = llvm_context - .i64_type() - .const_int(*val, false) - .as_basic_value_enum(); - } - Constant::Float32(val) => { - llvm_constants[id.idx()] = llvm_context - .f32_type() - .const_float(val.into_inner() as f64) - .as_basic_value_enum(); - } - Constant::Float64(val) => { - llvm_constants[id.idx()] = llvm_context - .f64_type() - .const_float(val.into_inner()) - .as_basic_value_enum(); - } - // Because we traverse in bottom-up order, we can assume that the - // LLVM constants for children constants 
are already computed. - Constant::Product(_, fields) => { - let field_constants = fields - .iter() - .map(|id| llvm_constants[id.idx()]) - .collect::<Box<[_]>>(); - llvm_constants[id.idx()] = llvm_context - .const_struct(&field_constants, false) - .as_basic_value_enum(); - } - Constant::Array(_, _) => todo!(), - Constant::Summation(_, _, _) => todo!(), - } - } - - // Step 4: do codegen for each function. - for function_idx in 0..functions.len() { - let function = &functions[function_idx]; - let typing = &typing[function_idx]; - let reverse_postorder = &reverse_postorders[function_idx]; - let def_use = &def_uses[function_idx]; - let bb = &bbs[function_idx]; - let antideps = &antideps[function_idx]; - let fork_join_nest = &fork_join_nests[function_idx]; - let array_allocations = &array_allocations[function_idx]; - - // Step 4.1: create LLVM function object. - let llvm_ret_type = llvm_types[function.return_type.idx()]; - let llvm_param_types = function - .param_types - .iter() - .map(|id| llvm_types[id.idx()].into()) - .chain( - repeat(BasicMetadataTypeEnum::try_from(llvm_context.i64_type()).unwrap()) - .take(function.num_dynamic_constants as usize), - ) - .chain( - repeat( - BasicMetadataTypeEnum::try_from( - llvm_context.i8_type().ptr_type(AddressSpace::default()), - ) - .unwrap(), - ) - .take(array_allocations.0.len() as usize), - ) - .collect::<Box<[_]>>(); - let llvm_fn_type = llvm_ret_type.fn_type(&llvm_param_types, false); - let llvm_fn = llvm_module.add_function(&function.name, llvm_fn_type, None); - - // Step 4.2: create LLVM basic blocks. A node needs a corresponding - // basic block if its entry in the basic blocks vector points to iself. - let mut llvm_bbs = HashMap::new(); - for id in (0..function.nodes.len()).map(NodeID::new) { - if bb[id.idx()] == id { - llvm_bbs.insert( - id, - llvm_context.append_basic_block(llvm_fn, &format!("bb_{}", id.idx())), - ); - } - } - - // Step 4.3: emit LLVM for each node. 
Assemble worklist of nodes, - // starting as reverse post order of nodes. For non-phi nodes, only emit - // once all data uses are emitted. In addition, consider additional anti - // dependence edges from read to write nodes. - let mut values = HashMap::new(); - let mut phi_values = HashMap::new(); - let mut branch_instructions = HashMap::new(); - let mut worklist = VecDeque::from(reverse_postorder.clone()); - while let Some(id) = worklist.pop_front() { - if !function.nodes[id.idx()].is_phi() - && !get_uses(&function.nodes[id.idx()]) - .as_ref() - .into_iter() - .chain( - antideps.iter().filter_map( - |(read, write)| if id == *write { Some(read) } else { None }, - ), - ) - .all(|x| function.is_control(*x) || values.contains_key(x)) - { - // Skip emitting node if it's not a phi node and if its data - // uses are not emitted yet. - worklist.push_back(id); - } else { - // Once all of the data dependencies for this node are emitted, - // this node can be emitted. - emit_llvm_for_node( - id, - &mut values, - &mut phi_values, - &mut branch_instructions, - function, - typing, - types, - dynamic_constants, - bb, - def_use, - fork_join_nest, - array_allocations, - &llvm_context, - &llvm_builder, - llvm_fn, - &llvm_bbs, - &llvm_types, - &llvm_constants, - ); - } - } - - // Step 4.4: patch phi nodes with incoming data values. - for id in (0..function.nodes.len()).map(NodeID::new) { - let node = &function.nodes[id.idx()]; - if node.is_phi() { - // Region node is the only strictly control use of the phi. - let uses = get_uses(node); - let region = uses - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_strictly_control()) - .next() - .unwrap(); - - // Need to create intermediate vector - Inkwell expects a list - // of dynamic references to basic values. Those references must - // reference concrete basic values, which we need to create. - // Thus, we need to store them in this intermediate vector. 
- let data_uses: Vec<_> = uses - .as_ref() - .iter() - .filter(|id| !function.nodes[id.idx()].is_strictly_control()) - .map(|id| BasicValueEnum::try_from(values[id]).unwrap()) - .collect(); - let data_uses = data_uses - .iter() - .map(|ref_value| ref_value as &dyn BasicValue); - - // The basic blocks are the uses of the region node. - let region_uses = get_uses(&function.nodes[region.idx()]); - let pred_bbs = region_uses.as_ref().iter().map(|x| llvm_bbs[&bb[x.idx()]]); - - // The order of the data uses of the phi corresponds with the - // order of the control uses of the region. - let incoming_values: Vec<_> = zip(data_uses, pred_bbs).collect(); - phi_values[&id].add_incoming(&incoming_values[..]); - } - } - } - - // Step 5: write out module to given file path. - llvm_module.write_bitcode_to_path(path); -} - -/* - * Emit LLVM implementing a single node. - */ -fn emit_llvm_for_node<'ctx>( - id: NodeID, - values: &mut HashMap<NodeID, AnyValueEnum<'ctx>>, - phi_values: &mut HashMap<NodeID, PhiValue<'ctx>>, - branch_instructions: &mut HashMap<BasicBlock<'ctx>, InstructionValue<'ctx>>, - function: &Function, - typing: &Vec<TypeID>, - types: &Vec<Type>, - dynamic_constants: &Vec<DynamicConstant>, - bb: &Vec<NodeID>, - def_use: &ImmutableDefUseMap, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, - array_allocations: &(Vec<Vec<DynamicConstantID>>, HashMap<NodeID, usize>), - llvm_context: &'ctx Context, - llvm_builder: &'ctx Builder, - llvm_fn: FunctionValue<'ctx>, - llvm_bbs: &HashMap<NodeID, BasicBlock<'ctx>>, - llvm_types: &Vec<BasicTypeEnum<'ctx>>, - llvm_constants: &Vec<BasicValueEnum<'ctx>>, -) { - // Helper to emit code for dynamic constants. 
- let emit_dynamic_constant = - |dyn_cons_id: DynamicConstantID| match dynamic_constants[dyn_cons_id.idx()] { - DynamicConstant::Constant(val) => llvm_context - .i64_type() - .const_int(val as u64, false) - .as_any_value_enum(), - DynamicConstant::Parameter(num) => llvm_fn - .get_nth_param((num + function.param_types.len()) as u32) - .unwrap() - .as_any_value_enum(), - }; - - // Helper to get array allocation. - let get_array_alloc = |idx: usize| { - llvm_fn - .get_nth_param( - (idx + function.param_types.len() + function.num_dynamic_constants as usize) as u32, - ) - .unwrap() - .as_any_value_enum() - }; - - // Helper to position at the beginning of a basic block. - let position_at_beginning = |bb: BasicBlock<'ctx>| { - if let Some(first_inst) = bb.get_first_instruction() { - llvm_builder.position_before(&first_inst); - } else { - llvm_builder.position_at_end(bb); - } - }; - - let llvm_bb = llvm_bbs[&bb[id.idx()]]; - if let Some(iv) = branch_instructions.get(&llvm_bb) { - llvm_builder.position_before(iv); - } else { - llvm_builder.position_at_end(llvm_bb); - } - match function.nodes[id.idx()] { - Node::Start | Node::Region { preds: _ } => { - let successor = def_use - .get_users(id) - .iter() - .filter(|id| function.nodes[id.idx()].is_strictly_control()) - .next() - .unwrap(); - branch_instructions.insert( - llvm_bb, - llvm_builder - .build_unconditional_branch(llvm_bbs[successor]) - .unwrap(), - ); - } - Node::If { control: _, cond } => { - let successors = def_use.get_users(id); - if function.nodes[successors[0].idx()] == (Node::ReadProd { prod: id, index: 0 }) { - branch_instructions.insert( - llvm_bb, - llvm_builder - .build_conditional_branch( - values[&cond].into_int_value(), - llvm_bbs[&bb[successors[1].idx()]], - llvm_bbs[&bb[successors[0].idx()]], - ) - .unwrap(), - ); - } else { - branch_instructions.insert( - llvm_bb, - llvm_builder - .build_conditional_branch( - values[&cond].into_int_value(), - llvm_bbs[&bb[successors[0].idx()]], - 
llvm_bbs[&bb[successors[1].idx()]], - ) - .unwrap(), - ); - } - } - Node::Fork { - control: _, - factor: _, - } => { - // Need to create phi node for the loop index. - let phi_value = llvm_builder.build_phi(llvm_context.i64_type(), "").unwrap(); - phi_values.insert(id, phi_value); - - let successor = def_use - .get_users(id) - .iter() - .filter(|id| function.nodes[id.idx()].is_strictly_control()) - .next() - .unwrap(); - branch_instructions.insert( - llvm_bb, - llvm_builder - .build_unconditional_branch(llvm_bbs[successor]) - .unwrap(), - ); - } - Node::Join { control } => { - // Form the bottom of the loop. We need to branch between the - // successor and the fork. - let fork_id = if let Type::Control(factors) = &types[typing[control.idx()].idx()] { - *factors.last().unwrap() - } else { - panic!() - }; - let phi_value = phi_values[&fork_id]; - let (fork_predecessor, factor) = - if let Node::Fork { control, factor } = &function.nodes[fork_id.idx()] { - (*control, *factor) - } else { - panic!() - }; - - // Create a serial loop with a simple index. - let bound = emit_dynamic_constant(factor); - let new_index = llvm_builder - .build_int_add( - phi_value.as_any_value_enum().into_int_value(), - llvm_context.i64_type().const_int(1, false), - "", - ) - .unwrap(); - phi_value.add_incoming(&[ - ( - &llvm_context.i64_type().const_int(0, false), - llvm_bbs[&bb[fork_predecessor.idx()]], - ), - (&new_index, llvm_bbs[&bb[id.idx()]]), - ]); - - // Create branch forming the serial loop. 
- let condition = llvm_builder - .build_int_compare(IntPredicate::ULT, new_index, bound.into_int_value(), "") - .unwrap(); - let successor = def_use - .get_users(id) - .iter() - .filter(|id| function.nodes[id.idx()].is_strictly_control()) - .next() - .unwrap(); - branch_instructions.insert( - llvm_bb, - llvm_builder - .build_conditional_branch( - condition, - llvm_bbs[&bb[fork_id.idx()]], - llvm_bbs[&bb[successor.idx()]], - ) - .unwrap(), - ); - } - Node::Phi { - control: _, - data: _, - } => { - // For some reason, Inkwell doesn't convert phi values to/from the - // AnyValueEnum type properly, so store phi values in another map. - position_at_beginning(llvm_bb); - let phi_value = llvm_builder - .build_phi(llvm_types[typing[id.idx()].idx()], "") - .unwrap(); - phi_values.insert(id, phi_value); - values.insert(id, phi_value.as_any_value_enum()); - } - Node::ThreadID { control } => { - let phi_value = phi_values[&control]; - values.insert(id, phi_value.as_any_value_enum()); - } - Node::Collect { control, data } => { - // Write into destination array only in inner-most collect. Outer - // collects become no-ops. - let elem_type = typing[data.idx()]; - if !types[elem_type.idx()].is_array() { - // Get all the thread IDs of the nested forks. These are the phi - // values corresponding to each fork. - let thread_ids = fork_join_nest[&control] - .iter() - .map(|fork| phi_values[&fork]); - let alloc_num = array_allocations.1[&id]; - let extents = &array_allocations.0[alloc_num]; - let mut write_index = llvm_context.i64_type().const_int(0, false); - let mut multiplier = llvm_context.i64_type().const_int(1, false); - for (thread_id, extent) in zip(thread_ids, extents.iter().rev()) { - // Add contribution of this index dimension to flat write - // index. 
- write_index = llvm_builder - .build_int_add( - write_index, - llvm_builder - .build_int_mul( - multiplier, - thread_id.as_any_value_enum().into_int_value(), - "", - ) - .unwrap(), - "", - ) - .unwrap(); - - // Keep running multiplication of extents seen so far. - multiplier = llvm_builder - .build_int_mul( - multiplier, - emit_dynamic_constant(*extent).into_int_value(), - "", - ) - .unwrap(); - } - - // Emit the write. - let array = get_array_alloc(alloc_num); - let ptr_type = llvm_types[typing[data.idx()].idx()]; - let gep_ptr = unsafe { - llvm_builder - .build_gep(ptr_type, array.into_pointer_value(), &[write_index], "") - .unwrap() - }; - llvm_builder - .build_store(gep_ptr, BasicValueEnum::try_from(values[&data]).unwrap()) - .unwrap(); - - values.insert(id, array); - } else { - values.insert(id, values[&data]); - } - } - Node::Return { control: _, data } => { - llvm_builder - .build_return(Some(&BasicValueEnum::try_from(values[&data]).unwrap())) - .unwrap(); - } - Node::Parameter { index } => { - values.insert( - id, - llvm_fn - .get_nth_param(index as u32) - .unwrap() - .as_any_value_enum(), - ); - } - Node::Constant { id: cons_id } => { - values.insert(id, llvm_constants[cons_id.idx()].into()); - } - Node::DynamicConstant { id: dyn_cons_id } => { - values.insert(id, emit_dynamic_constant(dyn_cons_id)); - } - Node::Unary { input, op } => { - let input = values[&input]; - match op { - UnaryOperator::Not => { - values.insert( - id, - llvm_builder - .build_not(input.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - UnaryOperator::Neg => { - if input.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_neg(input.into_float_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_neg(input.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - } - } - } - Node::Binary { left, right, op } => { - let left = values[&left]; - let right = 
values[&right]; - match op { - BinaryOperator::Add => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_add( - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_add(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::Sub => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_sub( - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_sub(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::Mul => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_mul( - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_mul(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::Div => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_div( - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else if types[typing[id.idx()].idx()].is_unsigned() { - values.insert( - id, - llvm_builder - .build_int_unsigned_div( - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_signed_div( - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::Rem => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_rem( - 
left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else if types[typing[id.idx()].idx()].is_unsigned() { - values.insert( - id, - llvm_builder - .build_int_unsigned_rem( - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_signed_rem( - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::LT => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_compare( - FloatPredicate::OLT, - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else if types[typing[id.idx()].idx()].is_unsigned() { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::ULT, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::SLT, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::LTE => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_compare( - FloatPredicate::OLE, - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else if types[typing[id.idx()].idx()].is_unsigned() { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::ULE, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::SLE, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::GT => { - if left.get_type().is_float_type() { - values.insert( - id, 
- llvm_builder - .build_float_compare( - FloatPredicate::OGT, - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else if types[typing[id.idx()].idx()].is_unsigned() { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::UGT, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::SGT, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::GTE => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_compare( - FloatPredicate::OGE, - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else if types[typing[id.idx()].idx()].is_unsigned() { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::UGE, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::SGE, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::EQ => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_compare( - FloatPredicate::OEQ, - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::EQ, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::NE => { - if left.get_type().is_float_type() { - values.insert( - id, - llvm_builder - .build_float_compare( - FloatPredicate::ONE, - left.into_float_value(), - right.into_float_value(), - "", - ) - .unwrap() - 
.as_any_value_enum(), - ); - } else { - values.insert( - id, - llvm_builder - .build_int_compare( - IntPredicate::NE, - left.into_int_value(), - right.into_int_value(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - BinaryOperator::Or => { - values.insert( - id, - llvm_builder - .build_or(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - BinaryOperator::And => { - values.insert( - id, - llvm_builder - .build_and(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - BinaryOperator::Xor => { - values.insert( - id, - llvm_builder - .build_xor(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - BinaryOperator::LSh => { - values.insert( - id, - llvm_builder - .build_left_shift(left.into_int_value(), right.into_int_value(), "") - .unwrap() - .as_any_value_enum(), - ); - } - BinaryOperator::RSh => { - values.insert( - id, - llvm_builder - .build_right_shift( - left.into_int_value(), - right.into_int_value(), - !types[typing[id.idx()].idx()].is_unsigned(), - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - } - } - Node::ReadProd { prod, index } => { - // ReadProd nodes are special in that they may be projection nodes. 
- if function.nodes[prod.idx()].is_strictly_control() { - let successor = def_use.get_users(id)[0]; - branch_instructions.insert( - llvm_bb, - llvm_builder - .build_unconditional_branch(llvm_bbs[&successor]) - .unwrap(), - ); - } else { - values.insert( - id, - llvm_builder - .build_extract_value(values[&prod].into_struct_value(), index as u32, "") - .unwrap() - .as_any_value_enum(), - ); - } - } - Node::WriteProd { prod, index, data } => { - values.insert( - id, - llvm_builder - .build_insert_value( - values[&prod].into_struct_value(), - BasicValueEnum::try_from(values[&data]).unwrap(), - index as u32, - "", - ) - .unwrap() - .as_any_value_enum(), - ); - } - Node::ReadArray { array, index } => { - let elem_type = element_type(typing[id.idx()], types); - let llvm_elem_type = llvm_types[elem_type.idx()]; - - // If this is the last level of the array type, then do a load. - // Otherwise, the output is a pointer to the sub-array. - if types[typing[id.idx()].idx()].is_array() { - let mut index = values[&index].into_int_value(); - for dc in type_extents(typing[id.idx()], types) { - let dc = emit_dynamic_constant(dc); - index = llvm_builder - .build_int_mul(index, dc.into_int_value(), "") - .unwrap(); - } - values.insert(id, unsafe { - llvm_builder - .build_gep( - llvm_elem_type, - values[&array].into_pointer_value(), - &[index], - "", - ) - .unwrap() - .as_any_value_enum() - }); - } else { - let gep_ptr = unsafe { - llvm_builder - .build_gep( - llvm_elem_type, - values[&array].into_pointer_value(), - &[values[&index].into_int_value()], - "", - ) - .unwrap() - }; - values.insert( - id, - llvm_builder - .build_load(llvm_elem_type, gep_ptr, "") - .unwrap() - .as_any_value_enum(), - ); - } - } - Node::WriteArray { - array: _, - index: _, - data: _, - } => todo!(), - _ => todo!(), - } -} diff --git a/hercules_codegen/src/lib.rs b/hercules_codegen/src/lib.rs deleted file mode 100644 index 7a1d6b6fda36262b4e4400c6b4f8f6730a017695..0000000000000000000000000000000000000000 --- 
a/hercules_codegen/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod antideps; -pub mod array_alloc; -pub mod cpu_alpha; -pub mod gcm; - -pub use crate::antideps::*; -pub use crate::array_alloc::*; -pub use crate::cpu_alpha::*; -pub use crate::gcm::*; diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index e1e6aa2e2421be8036ade92cf1214c71c669253b..3c37ecb86d59e87f08a8b2a8f5caaa80152ba8f3 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -262,8 +262,8 @@ impl<'a> Builder<'a> { self.create_type_sum(Box::new([a, b, c, d, e, f, g, h])) } - pub fn create_type_array(&mut self, elem: TypeID, dc: DynamicConstantID) -> TypeID { - self.intern_type(Type::Array(elem, dc)) + pub fn create_type_array(&mut self, elem: TypeID, extents: Box<[DynamicConstantID]>) -> TypeID { + self.intern_type(Type::Array(elem, extents)) } pub fn create_constant_bool(&mut self, val: bool) -> ConstantID { @@ -355,14 +355,18 @@ impl<'a> Builder<'a> { &mut self, elem_ty: TypeID, cons: Box<[ConstantID]>, + extents: Box<[u32]>, ) -> BuilderResult<ConstantID> { for con in cons.iter() { if self.constant_types[con.idx()] != elem_ty { Err("Constant provided to create_constant_array has a different type than the provided element type.")? 
} } - let dc = self.create_dynamic_constant_constant(cons.len()); - let ty = self.create_type_array(elem_ty, dc); + let extents = extents + .iter() + .map(|extent| self.create_dynamic_constant_constant(*extent as usize)) + .collect(); + let ty = self.create_type_array(elem_ty, extents); Ok(self.intern_constant(Constant::Array(ty, cons), ty)) } @@ -374,6 +378,22 @@ impl<'a> Builder<'a> { self.intern_dynamic_constant(DynamicConstant::Parameter(val)) } + pub fn create_field_index(&self, idx: usize) -> Index { + Index::Field(idx) + } + + pub fn create_variant_index(&self, idx: usize) -> Index { + Index::Variant(idx) + } + + pub fn create_position_index(&self, idx: Box<[NodeID]>) -> Index { + Index::Position(idx) + } + + pub fn create_control_index(&self, idx: usize) -> Index { + Index::Control(idx) + } + pub fn create_function( &mut self, name: &'a str, @@ -446,8 +466,12 @@ impl NodeBuilder { self.node = Node::ThreadID { control }; } - pub fn build_collect(&mut self, control: NodeID, data: NodeID) { - self.node = Node::Collect { control, data }; + pub fn build_collect(&mut self, control: NodeID, init: NodeID, reduct: NodeID) { + self.node = Node::Reduce { + control, + init, + reduct, + }; } pub fn build_return(&mut self, control: NodeID, data: NodeID) { @@ -487,35 +511,15 @@ impl NodeBuilder { }; } - pub fn build_readprod(&mut self, prod: NodeID, index: usize) { - self.node = Node::ReadProd { prod, index }; + pub fn build_read(&mut self, collect: NodeID, indices: Box<[Index]>) { + self.node = Node::Read { collect, indices }; } - pub fn build_writeprod(&mut self, prod: NodeID, data: NodeID, index: usize) { - self.node = Node::WriteProd { prod, data, index }; - } - - pub fn build_readarray(&mut self, array: NodeID, index: NodeID) { - self.node = Node::ReadArray { array, index }; - } - - pub fn build_writearray(&mut self, array: NodeID, data: NodeID, index: NodeID) { - self.node = Node::WriteArray { array, data, index }; - } - - pub fn build_match(&mut self, control: 
NodeID, sum: NodeID) { - self.node = Node::Match { control, sum }; - } - - pub fn build_buildsum(&mut self, data: NodeID, sum_ty: TypeID, variant: usize) { - self.node = Node::BuildSum { + pub fn build_write(&mut self, collect: NodeID, data: NodeID, indices: Box<[Index]>) { + self.node = Node::Write { + collect, data, - sum_ty, - variant, + indices, }; } - - pub fn build_extractsum(&mut self, data: NodeID, variant: usize) { - self.node = Node::ExtractSum { data, variant }; - } } diff --git a/hercules_ir/src/dataflow.rs b/hercules_ir/src/dataflow.rs index c3f3e9f900116e79f10803f3225fa5c5565f0bbb..5ad46a4c0c453a913c0cee5c8205644e14643b5a 100644 --- a/hercules_ir/src/dataflow.rs +++ b/hercules_ir/src/dataflow.rs @@ -316,9 +316,9 @@ impl<'a> Iterator for NodeSetIterator<'a> { /* * Flow function for collecting all of a node's uses of "control outputs". What - * this flow function does is collect all immediate phi, thread ID, and collect + * this flow function does is collect all immediate phi, thread ID, and reduce * nodes that every other node depends on through data nodes. Flow is ended at - * a control node, or at a phi, thread ID, or collect node. + * a control, phi, thread ID, or reduce node. */ pub fn control_output_flow( inputs: &[&UnionNodeSet], @@ -332,12 +332,12 @@ pub fn control_output_flow( let node = &function.nodes[node_id.idx()]; // Step 2: clear all bits, if applicable. - if node.is_strictly_control() || node.is_thread_id() || node.is_collect() || node.is_phi() { + if node.is_strictly_control() || node.is_thread_id() || node.is_reduce() || node.is_phi() { out = UnionNodeSet::Empty; } // Step 3: set bit for current node, if applicable. 
- if node.is_thread_id() || node.is_collect() || node.is_phi() { + if node.is_thread_id() || node.is_reduce() || node.is_phi() { let mut singular = bitvec![u8, Lsb0; 0; function.nodes.len()]; singular.set(node_id.idx(), true); out = UnionNodeSet::meet(&out, &UnionNodeSet::Bits(singular)); @@ -350,7 +350,7 @@ pub fn control_output_flow( * Flow function for collecting all of a data node's immediate uses / users of * control nodes. Useful for code generation. Since this is for immediate uses / * users of control nodes, control node uses / users do not propagate through - * control nodes, or through control output nodes (phis, thread IDs, collects). + * control nodes, or through control output nodes (phis, thread IDs, reduces). */ pub fn immediate_control_flow( inputs: &[&UnionNodeSet], @@ -362,7 +362,11 @@ pub fn immediate_control_flow( // Step 1: replace node if this is a phi, thread ID, or collect. if let Node::Phi { control, data: _ } | Node::ThreadID { control } - | Node::Collect { control, data: _ } = &function.nodes[node_id.idx()] + | Node::Reduce { + control, + init: _, + reduct: _, + } = &function.nodes[node_id.idx()] { node_id = *control; } else { @@ -373,15 +377,8 @@ pub fn immediate_control_flow( } let node = &function.nodes[node_id.idx()]; - // Step 2: figure out if this node is a control node. - let control = if let Node::ReadProd { prod, index: _ } = node { - function.nodes[prod.idx()].is_strictly_control() - } else { - node.is_strictly_control() - }; - - // Step 3: clear all bits and set bit for current node, if applicable. - if control { + // Step 2: clear all bits and set bit for current node, if applicable. 
+ if node.is_control() { let mut singular = bitvec![u8, Lsb0; 0; function.nodes.len()]; singular.set(node_id.idx(), true); out = UnionNodeSet::Bits(singular); diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index fbd75201ef298442ec8baa37ce71805f129246e3..a8fd5c17f6799887364f3cf3ef62b36e6e1953e1 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -83,7 +83,7 @@ pub enum NodeUses<'a> { // Phi nodes are special, and store both a NodeID locally *and* many in a // boxed slice. Since these NodeIDs are not stored contiguously, we have to // construct a new contiguous slice by copying. Sigh. - Phi(Box<[NodeID]>), + Owned(Box<[NodeID]>), } /* @@ -107,7 +107,7 @@ impl<'a> AsRef<[NodeID]> for NodeUses<'a> { NodeUses::Two(x) => x, NodeUses::Three(x) => x, NodeUses::Variable(x) => x, - NodeUses::Phi(x) => x, + NodeUses::Owned(x) => x, } } } @@ -132,15 +132,20 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { Node::Start => NodeUses::Zero, Node::Region { preds } => NodeUses::Variable(preds), Node::If { control, cond } => NodeUses::Two([*control, *cond]), + Node::Match { control, sum } => NodeUses::Two([*control, *sum]), Node::Fork { control, factor: _ } => NodeUses::One([*control]), Node::Join { control } => NodeUses::One([*control]), Node::Phi { control, data } => { let mut uses: Vec<NodeID> = Vec::from(&data[..]); uses.push(*control); - NodeUses::Phi(uses.into_boxed_slice()) + NodeUses::Owned(uses.into_boxed_slice()) } Node::ThreadID { control } => NodeUses::One([*control]), - Node::Collect { control, data } => NodeUses::Two([*control, *data]), + Node::Reduce { + control, + init, + reduct, + } => NodeUses::Three([*control, *init, *reduct]), Node::Return { control, data } => NodeUses::Two([*control, *data]), Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]), Node::Constant { id: _ } => NodeUses::One([NodeID::new(0)]), @@ -152,21 +157,43 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { dynamic_constants: 
_, args, } => NodeUses::Variable(args), - Node::ReadProd { prod, index: _ } => NodeUses::One([*prod]), - Node::WriteProd { - prod, - data, - index: _, - } => NodeUses::Two([*prod, *data]), - Node::ReadArray { array, index } => NodeUses::Two([*array, *index]), - Node::WriteArray { array, data, index } => NodeUses::Three([*array, *data, *index]), - Node::Match { control, sum } => NodeUses::Two([*control, *sum]), - Node::BuildSum { + Node::Read { collect, indices } => { + let mut uses = vec![]; + for index in indices.iter() { + if let Index::Position(pos) = index { + uses.append(&mut pos.clone().into_vec()); + } + } + if uses.len() > 0 { + uses.reverse(); + uses.push(*collect); + uses.reverse(); + NodeUses::Owned(uses.into_boxed_slice()) + } else { + NodeUses::One([*collect]) + } + } + Node::Write { + collect, + indices, data, - sum_ty: _, - variant: _, - } => NodeUses::One([*data]), - Node::ExtractSum { data, variant: _ } => NodeUses::One([*data]), + } => { + let mut uses = vec![]; + for index in indices.iter() { + if let Index::Position(pos) = index { + uses.append(&mut pos.clone().into_vec()); + } + } + if uses.len() > 0 { + uses.reverse(); + uses.push(*data); + uses.push(*collect); + uses.reverse(); + NodeUses::Owned(uses.into_boxed_slice()) + } else { + NodeUses::Two([*collect, *data]) + } + } } } @@ -182,13 +209,18 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { Node::Start => NodeUsesMut::Zero, Node::Region { preds } => NodeUsesMut::Variable(preds.iter_mut().collect()), Node::If { control, cond } => NodeUsesMut::Two([control, cond]), + Node::Match { control, sum } => NodeUsesMut::Two([control, sum]), Node::Fork { control, factor: _ } => NodeUsesMut::One([control]), Node::Join { control } => NodeUsesMut::One([control]), Node::Phi { control, data } => { NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect()) } Node::ThreadID { control } => NodeUsesMut::One([control]), - Node::Collect { control, data } => 
NodeUsesMut::Two([control, data]), + Node::Reduce { + control, + init, + reduct, + } => NodeUsesMut::Three([control, init, reduct]), Node::Return { control, data } => NodeUsesMut::Two([control, data]), Node::Parameter { index: _ } => NodeUsesMut::Zero, Node::Constant { id: _ } => NodeUsesMut::Zero, @@ -200,20 +232,42 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { dynamic_constants: _, args, } => NodeUsesMut::Variable(args.iter_mut().collect()), - Node::ReadProd { prod, index: _ } => NodeUsesMut::One([prod]), - Node::WriteProd { - prod, - data, - index: _, - } => NodeUsesMut::Two([prod, data]), - Node::ReadArray { array, index } => NodeUsesMut::Two([array, index]), - Node::WriteArray { array, data, index } => NodeUsesMut::Three([array, data, index]), - Node::Match { control, sum } => NodeUsesMut::Two([control, sum]), - Node::BuildSum { + Node::Read { collect, indices } => { + let mut uses = vec![]; + for index in indices.iter_mut() { + if let Index::Position(pos) = index { + for d in pos.iter_mut() { + uses.push(d); + } + } + } + if uses.len() > 0 { + uses.push(collect); + NodeUsesMut::Variable(uses.into_boxed_slice()) + } else { + NodeUsesMut::One([collect]) + } + } + Node::Write { + collect, + indices, data, - sum_ty: _, - variant: _, - } => NodeUsesMut::One([data]), - Node::ExtractSum { data, variant: _ } => NodeUsesMut::One([data]), + } => { + let mut uses = vec![]; + for index in indices.iter_mut() { + if let Index::Position(pos) = index { + for d in pos.iter_mut() { + uses.push(d); + } + } + } + if uses.len() > 0 { + uses.push(collect); + uses.push(data); + NodeUsesMut::Variable(uses.into_boxed_slice()) + } else { + NodeUsesMut::Two([collect, data]) + } + } } } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index af1555d6d85153cd08d1c39de5ce93ead1c356c7..ecc7eecfeb20c96969c50ebe8ac6c436c5f47a1a 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -46,13 +46,11 @@ pub struct Function { * control type. 
Hercules IR is based off of the sea-of-nodes IR, the main * feature of which being a merged control and data flow graph. Thus, control * is a type of value, just like any other type. However, the type system is - * very restrictive over what can be done with control values. A novel addition - * in Hercules IR is that a control type is parameterized by a list of thread + * very restrictive over what can be done with control values. An addition in + * Hercules IR is that a control type is parameterized by a list of thread * spawning factors. This is the mechanism in Hercules IR for representing * parallelism. Summation types are an IR equivalent of Rust's enum types. - * These are lowered into tagged unions during scheduling. Array types are one- - * dimensional. Multi-dimensional arrays are represented by nesting array types. - * An array extent is represented with a dynamic constant. + * These are lowered into tagged unions during scheduling. */ #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Type { @@ -70,7 +68,7 @@ pub enum Type { Float64, Product(Box<[TypeID]>), Summation(Box<[TypeID]>), - Array(TypeID, DynamicConstantID), + Array(TypeID, Box<[DynamicConstantID]>), } /* @@ -115,6 +113,23 @@ pub enum DynamicConstant { Parameter(usize), } +/* + * Hercules has a single node for reading, Read, and a single node for writing, + * Write, that are both used for modifying product, sum, and array structures. + * However, each of these types are indexed differently. Thus, these two nodes + * operate on an index list, composing indices at different levels in a type + * tree. Each type that can be indexed has a unique variant in the index enum. + * Read nodes are overloaded to select between control successors of if and + * match nodes. 
+ */ +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Index { + Field(usize), + Variant(usize), + Position(Box<[NodeID]>), + Control(usize), +} + /* * Hercules IR is a combination of a possibly cylic control flow graph, and * many possibly cyclic data flow graphs. Each node represents some operation on @@ -136,6 +151,10 @@ pub enum Node { control: NodeID, cond: NodeID, }, + Match { + control: NodeID, + sum: NodeID, + }, Fork { control: NodeID, factor: DynamicConstantID, @@ -150,9 +169,10 @@ pub enum Node { ThreadID { control: NodeID, }, - Collect { + Reduce { control: NodeID, - data: NodeID, + init: NodeID, + reduct: NodeID, }, Return { control: NodeID, @@ -181,36 +201,14 @@ pub enum Node { dynamic_constants: Box<[DynamicConstantID]>, args: Box<[NodeID]>, }, - ReadProd { - prod: NodeID, - index: usize, - }, - WriteProd { - prod: NodeID, - data: NodeID, - index: usize, - }, - ReadArray { - array: NodeID, - index: NodeID, - }, - WriteArray { - array: NodeID, - data: NodeID, - index: NodeID, - }, - Match { - control: NodeID, - sum: NodeID, - }, - BuildSum { - data: NodeID, - sum_ty: TypeID, - variant: usize, + Read { + collect: NodeID, + indices: Box<[Index]>, }, - ExtractSum { + Write { + collect: NodeID, data: NodeID, - variant: usize, + indices: Box<[Index]>, }, } @@ -322,11 +320,13 @@ impl Module { } write!(w, ")") } - Type::Array(elem, length) => { + Type::Array(elem, extents) => { write!(w, "Array(")?; self.write_type(*elem, w)?; - write!(w, ", ")?; - self.write_dynamic_constant(*length, w)?; + for extent in extents.iter() { + write!(w, ", ")?; + self.write_dynamic_constant(*extent, w)?; + } write!(w, ")") } }?; @@ -586,31 +586,6 @@ impl Function { std::mem::swap(&mut new_nodes, &mut self.nodes); } - - /* - * Checking if a node is control requires surrounding context, so this is a - * member of Function, not Node. 
- */ - pub fn is_control(&self, id: NodeID) -> bool { - if self.nodes[id.idx()].is_strictly_control() { - return true; - } - - if let Node::ReadProd { prod, index: _ } = self.nodes[id.idx()] { - return match self.nodes[prod.idx()] { - // ReadProd nodes are control nodes if their predecessor is a - // legal control node. - Node::Match { control: _, sum: _ } - | Node::If { - control: _, - cond: _, - } => true, - _ => false, - }; - } - - false - } } impl Type { @@ -673,15 +648,22 @@ impl Type { false } } -} -pub fn type_extents(mut ty: TypeID, types: &Vec<Type>) -> Vec<DynamicConstantID> { - let mut extents = vec![]; - while let Type::Array(elem, dc) = types[ty.idx()] { - extents.push(dc); - ty = elem; + pub fn try_element_type(&self) -> Option<TypeID> { + if let Type::Array(elem, _) = self { + Some(*elem) + } else { + None + } + } + + pub fn try_extents(&self) -> Option<&[DynamicConstantID]> { + if let Type::Array(_, extents) = self { + Some(extents) + } else { + None + } } - extents } pub fn element_type(mut ty: TypeID, types: &Vec<Type>) -> TypeID { @@ -692,6 +674,14 @@ pub fn element_type(mut ty: TypeID, types: &Vec<Type>) -> TypeID { } impl Constant { + pub fn is_array(&self) -> bool { + if let Constant::Array(_, _) = self { + true + } else { + false + } + } + /* * Useful for GVN. */ @@ -743,6 +733,35 @@ macro_rules! 
define_pattern_predicate { }; } +impl Index { + define_pattern_predicate!(is_control, Index::Control(_)); + + pub fn try_field(&self) -> Option<usize> { + if let Index::Field(field) = self { + Some(*field) + } else { + None + } + } + + pub fn try_position(&self) -> Option<&[NodeID]> { + if let Index::Position(indices) = self { + Some(&indices) + } else { + None + } + } + + pub fn lower_case_name(&self) -> &'static str { + match self { + Index::Field(_) => "field", + Index::Variant(_) => "variant", + Index::Position(_) => "position", + Index::Control(_) => "control", + } + } +} + impl Node { define_pattern_predicate!(is_start, Node::Start); define_pattern_predicate!(is_region, Node::Region { preds: _ }); @@ -770,10 +789,11 @@ impl Node { ); define_pattern_predicate!(is_thread_id, Node::ThreadID { control: _ }); define_pattern_predicate!( - is_collect, - Node::Collect { + is_reduce, + Node::Reduce { control: _, - data: _, + init: _, + reduct: _, } ); define_pattern_predicate!( @@ -783,28 +803,25 @@ impl Node { data: _, } ); - define_pattern_predicate!(is_read_prod, Node::ReadProd { prod: _, index: _ }); define_pattern_predicate!( - is_write_prod, - Node::WriteProd { - prod: _, - index: _, - data: _ + is_read, + Node::Read { + collect: _, + indices: _ } ); - define_pattern_predicate!(is_read_array, Node::ReadArray { array: _, index: _ }); define_pattern_predicate!( - is_write_array, - Node::WriteArray { - array: _, - index: _, + is_write, + Node::Write { + collect: _, + indices: _, data: _ } ); define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ }); /* - * ReadProd nodes can be considered control when following an if or match + * Read nodes can be considered control when following an if or match * node. However, it is sometimes useful to exclude such nodes when * considering control nodes. 
*/ @@ -826,6 +843,7 @@ impl Node { control: _, cond: _, } => "If", + Node::Match { control: _, sum: _ } => "Match", Node::Fork { control: _, factor: _, @@ -836,10 +854,11 @@ impl Node { data: _, } => "Phi", Node::ThreadID { control: _ } => "ThreadID", - Node::Collect { + Node::Reduce { control: _, - data: _, - } => "Collect", + init: _, + reduct: _, + } => "Reduce", Node::Return { control: _, data: _, @@ -858,28 +877,15 @@ impl Node { dynamic_constants: _, args: _, } => "Unary", - Node::ReadProd { prod: _, index: _ } => "ReadProd", - Node::WriteProd { - prod: _, - data: _, - index: _, - } => "WriteProd", - Node::ReadArray { array: _, index: _ } => "ReadArray", - Node::WriteArray { - array: _, - data: _, - index: _, - } => "WriteArray", - Node::Match { control: _, sum: _ } => "Match", - Node::BuildSum { - data: _, - sum_ty: _, - variant: _, - } => "BuildSum", - Node::ExtractSum { + Node::Read { + collect: _, + indices: _, + } => "Read", + Node::Write { + collect: _, data: _, - variant: _, - } => "ExtractSum", + indices: _, + } => "Write", } } @@ -891,6 +897,7 @@ impl Node { control: _, cond: _, } => "if", + Node::Match { control: _, sum: _ } => "match", Node::Fork { control: _, factor: _, @@ -901,10 +908,11 @@ impl Node { data: _, } => "phi", Node::ThreadID { control: _ } => "thread_id", - Node::Collect { + Node::Reduce { control: _, - data: _, - } => "collect", + init: _, + reduct: _, + } => "reduce", Node::Return { control: _, data: _, @@ -923,29 +931,32 @@ impl Node { dynamic_constants: _, args: _, } => "call", - Node::ReadProd { prod: _, index: _ } => "read_prod", - Node::WriteProd { - prod: _, - data: _, - index: _, - } => "write_prod ", - Node::ReadArray { array: _, index: _ } => "read_array", - Node::WriteArray { - array: _, - data: _, - index: _, - } => "write_array", - Node::Match { control: _, sum: _ } => "match", - Node::BuildSum { + Node::Read { + collect: _, + indices: _, + } => "read", + Node::Write { + collect: _, data: _, - sum_ty: _, - variant: _, - 
} => "build_sum", - Node::ExtractSum { - data: _, - variant: _, - } => "extract_sum", + indices: _, + } => "write", + } + } + + pub fn is_control(&self) -> bool { + if self.is_strictly_control() { + return true; + } + + if let Node::Read { + collect: _, + indices, + } = self + { + return indices.len() == 1 && indices[0].is_control(); } + + false } } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 0900d49ff5438725ed0eaf0af9366373521c58f8..3354cbb284808f376db491b222e496358300a760 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -115,7 +115,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a )?; let mut context = context.into_inner(); - // functions, as returned by parsing, is in parse order, which may differ + // Functions, as returned by parsing, is in parse order, which may differ // from the order dictated by FunctionIDs in the function name intern map. let mut fixed_functions = vec![ Function { @@ -279,11 +279,12 @@ fn parse_node<'a>( let (ir_text, node) = match node_kind { "region" => parse_region(ir_text, context)?, "if" => parse_if(ir_text, context)?, + "match" => parse_match(ir_text, context)?, "fork" => parse_fork(ir_text, context)?, "join" => parse_join(ir_text, context)?, "phi" => parse_phi(ir_text, context)?, "thread_id" => parse_thread_id(ir_text, context)?, - "collect" => parse_collect(ir_text, context)?, + "reduce" => parse_reduce(ir_text, context)?, "return" => parse_return(ir_text, context)?, "constant" => parse_constant_node(ir_text, context)?, "dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?, @@ -308,13 +309,8 @@ fn parse_node<'a>( "lsh" => parse_binary(ir_text, context, BinaryOperator::LSh)?, "rsh" => parse_binary(ir_text, context, BinaryOperator::RSh)?, "call" => parse_call(ir_text, context)?, - "read_prod" => parse_read_prod(ir_text, context)?, - "write_prod" => parse_write_prod(ir_text, context)?, - "read_array" => parse_read_array(ir_text, 
context)?, - "write_array" => parse_write_array(ir_text, context)?, - "match" => parse_match(ir_text, context)?, - "build_sum" => parse_build_sum(ir_text, context)?, - "extract_sum" => parse_extract_sum(ir_text, context)?, + "read" => parse_read(ir_text, context)?, + "write" => parse_write(ir_text, context)?, _ => Err(nom::Err::Error(nom::error::Error { input: ir_text, code: nom::error::ErrorKind::IsNot, @@ -412,14 +408,23 @@ fn parse_thread_id<'a>( Ok((ir_text, Node::ThreadID { control })) } -fn parse_collect<'a>( +fn parse_reduce<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { - let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; + let (ir_text, (control, init, reduct)) = + parse_tuple3(parse_identifier, parse_identifier, parse_identifier)(ir_text)?; let control = context.borrow_mut().get_node_id(control); - let data = context.borrow_mut().get_node_id(data); - Ok((ir_text, Node::Collect { control, data })) + let init = context.borrow_mut().get_node_id(init); + let reduct = context.borrow_mut().get_node_id(reduct); + Ok(( + ir_text, + Node::Reduce { + control, + init, + reduct, + }, + )) } fn parse_return<'a>( @@ -532,48 +537,147 @@ fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes )) } -fn parse_read_prod<'a>( +fn parse_index<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, -) -> nom::IResult<&'a str, Node> { - let (ir_text, (prod, index)) = - parse_tuple2(parse_identifier, |x| parse_prim::<usize>(x, "1234567890"))(ir_text)?; - let prod = context.borrow_mut().get_node_id(prod); - Ok((ir_text, Node::ReadProd { prod, index })) -} - -fn parse_write_prod<'a>( - ir_text: &'a str, - context: &RefCell<Context<'a>>, -) -> nom::IResult<&'a str, Node> { - let (ir_text, (prod, data, index)) = parse_tuple3(parse_identifier, parse_identifier, |x| { - parse_prim::<usize>(x, "1234567890") - })(ir_text)?; - let prod = context.borrow_mut().get_node_id(prod); - 
let data = context.borrow_mut().get_node_id(data); - Ok((ir_text, Node::WriteProd { prod, data, index })) +) -> nom::IResult<&'a str, Index> { + let (ir_text, idx) = nom::branch::alt(( + nom::combinator::map( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::bytes::complete::tag("field"), + nom::character::complete::multispace0, + nom::character::complete::char('('), + nom::character::complete::multispace0, + |x| parse_prim::<usize>(x, "1234567890"), + nom::character::complete::multispace0, + nom::character::complete::char(')'), + nom::character::complete::multispace0, + )), + |(_, _, _, _, _, x, _, _, _)| Index::Field(x), + ), + nom::combinator::map( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::bytes::complete::tag("variant"), + nom::character::complete::multispace0, + nom::character::complete::char('('), + nom::character::complete::multispace0, + |x| parse_prim::<usize>(x, "1234567890"), + nom::character::complete::multispace0, + nom::character::complete::char(')'), + nom::character::complete::multispace0, + )), + |(_, _, _, _, _, x, _, _, _)| Index::Variant(x), + ), + nom::combinator::map( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::bytes::complete::tag("position"), + nom::character::complete::multispace0, + nom::character::complete::char('('), + nom::character::complete::multispace0, + nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + parse_identifier, + ), + nom::character::complete::multispace0, + nom::character::complete::char(')'), + nom::character::complete::multispace0, + )), + |(_, _, _, _, _, x, _, _, _)| { + Index::Position( + x.into_iter() + .map(|x| context.borrow_mut().get_node_id(x)) + .collect(), + ) + }, + ), + nom::combinator::map( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::bytes::complete::tag("control"), 
+ nom::character::complete::multispace0, + nom::character::complete::char('('), + nom::character::complete::multispace0, + |x| parse_prim::<usize>(x, "1234567890"), + nom::character::complete::multispace0, + nom::character::complete::char(')'), + nom::character::complete::multispace0, + )), + |(_, _, _, _, _, x, _, _, _)| Index::Control(x), + ), + ))(ir_text)?; + Ok((ir_text, idx)) } -fn parse_read_array<'a>( - ir_text: &'a str, - context: &RefCell<Context<'a>>, -) -> nom::IResult<&'a str, Node> { - let (ir_text, (array, index)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; - let array = context.borrow_mut().get_node_id(array); - let index = context.borrow_mut().get_node_id(index); - Ok((ir_text, Node::ReadArray { array, index })) +fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char('(')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, collect) = parse_identifier(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(',')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, indices) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + |x| parse_index(x, context), + )(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(')')(ir_text)?.0; + let collect = context.borrow_mut().get_node_id(collect); + Ok(( + ir_text, + Node::Read { + collect, + indices: indices.into_boxed_slice(), + }, + )) } -fn parse_write_array<'a>( +fn parse_write<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { - let 
(ir_text, (array, data, index)) = - parse_tuple3(parse_identifier, parse_identifier, parse_identifier)(ir_text)?; - let array = context.borrow_mut().get_node_id(array); + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char('(')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, collect) = parse_identifier(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(',')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, data) = parse_identifier(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(',')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, indices) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + |x| parse_index(x, context), + )(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(')')(ir_text)?.0; + let collect = context.borrow_mut().get_node_id(collect); let data = context.borrow_mut().get_node_id(data); - let index = context.borrow_mut().get_node_id(index); - Ok((ir_text, Node::WriteArray { array, data, index })) + Ok(( + ir_text, + Node::Write { + collect, + data, + indices: indices.into_boxed_slice(), + }, + )) } fn parse_type_id<'a>( @@ -596,36 +700,6 @@ fn parse_match<'a>( Ok((ir_text, Node::Match { control, sum })) } -fn parse_build_sum<'a>( - ir_text: &'a str, - context: &RefCell<Context<'a>>, -) -> nom::IResult<&'a str, Node> { - let (ir_text, (data, sum_ty, variant)) = parse_tuple3( - parse_identifier, - |x| parse_type_id(x, context), - |x| parse_prim::<usize>(x, "1234567890"), - )(ir_text)?; - let data = 
context.borrow_mut().get_node_id(data); - Ok(( - ir_text, - Node::BuildSum { - data, - sum_ty, - variant, - }, - )) -} - -fn parse_extract_sum<'a>( - ir_text: &'a str, - context: &RefCell<Context<'a>>, -) -> nom::IResult<&'a str, Node> { - let (ir_text, (data, variant)) = - parse_tuple2(parse_identifier, |x| parse_prim::<usize>(x, "1234567890"))(ir_text)?; - let data = context.borrow_mut().get_node_id(data); - Ok((ir_text, Node::ExtractSum { data, variant })) -} - fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Type> { // Parser combinators are very convenient, if a bit hard to read. let ir_text = nom::character::complete::multispace0(ir_text)?.0; @@ -716,8 +790,8 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes )), |(_, _, _, _, ids, _, _)| Type::Summation(ids.into_boxed_slice()), ), - // Array types are just a pair between an element type and a dynamic - // constant representing its extent. + // Array types are just a list of an element type and at least one + // dynamic constant representing its extent. 
nom::combinator::map( nom::sequence::tuple(( nom::bytes::complete::tag("array"), @@ -728,11 +802,20 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - |x| parse_dynamic_constant_id(x, context), + nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + |x| parse_dynamic_constant_id(x, context), + ), nom::character::complete::multispace0, nom::character::complete::char(')'), )), - |(_, _, _, _, ty_id, _, _, _, dc_id, _, _)| Type::Array(ty_id, dc_id), + |(_, _, _, _, ty_id, _, _, _, dc_ids, _, _)| { + Type::Array(ty_id, dc_ids.into_boxed_slice()) + }, ), ))(ir_text)?; Ok((ir_text, ty)) diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs index c4829b60c488a9b5c8032c90229eb3851c0eb298..c0b7aa4bb504ba40be2b4cde754d70058eefc711 100644 --- a/hercules_ir/src/subgraph.rs +++ b/hercules_ir/src/subgraph.rs @@ -235,5 +235,7 @@ where * Get the control subgraph of a function. 
*/ pub fn control_subgraph(function: &Function, def_use: &ImmutableDefUseMap) -> Subgraph { - subgraph(function, def_use, |node| function.is_control(node)) + subgraph(function, def_use, |node| { + function.nodes[node.idx()].is_control() + }) } diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index ad53283d4e5e0b67b697b1e895426d940217df88..1a183411f59581d71bf7ba682d03d563b9f11855 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -257,9 +257,6 @@ fn typeflow( return Error(String::from( "If node's control input cannot have non-control type.", )); - } else { - let out_ty = Type::Product(Box::new([*id, *id])); - return Concrete(get_type_id(out_ty, types, reverse_type_map)); } } @@ -392,45 +389,45 @@ fn typeflow( reverse_type_map, )) } - Node::Collect { control, data: _ } => { - if inputs.len() != 2 { - return Error(String::from("Collect node must have exactly two inputs.")); + Node::Reduce { + control: _, + init: _, + reduct: _, + } => { + if inputs.len() != 3 { + return Error(String::from("Reduce node must have exactly two inputs.")); } - if let (Concrete(control_id), Concrete(data_id)) = (inputs[0], inputs[1]) { + if let (Concrete(control_id), Concrete(init_id), Concrete(reduct_id)) = + (inputs[0], inputs[1], inputs[2]) + { // Check control input is control. if let Type::Control(_) = types[control_id.idx()] { } else { return Error(String::from( - "Collect node's control input must have control type.", + "Reduce node's control input must have control type.", )); } - // Check data input isn't control. - if let Type::Control(_) = types[data_id.idx()] { + // Check init input isn't control. 
+ if let Type::Control(_) = types[init_id.idx()] { return Error(String::from( - "Collect node's data input must not have control type.", + "Reduce node's initialization input must not have control type.", )); } - // Unfortunately, the type of the control input doesn't contain - // the thread replication factor this collect node is operating - // with. We use the join replication factor map side data - // structure to store the replication factor each join reduces - // over to make this easier. - if let Some(factor) = join_factor_map.get(control) { - let array_out_id = - get_type_id(Type::Array(*data_id, *factor), types, reverse_type_map); - Concrete(array_out_id) - } else { - // If the join factor map doesn't contain the control - // input, stay optimistic. - Unconstrained + // Check reduct input isn't control. + if let Type::Control(_) = types[reduct_id.idx()] { + return Error(String::from( + "Reduce node's reduction input must not have control type.", + )); } + + TypeSemilattice::meet(inputs[1], inputs[2]) } else if inputs[0].is_error() { inputs[0].clone() } else { - inputs[1].clone() + TypeSemilattice::meet(inputs[1], inputs[2]) } } Node::Return { @@ -561,13 +558,21 @@ fn typeflow( } // Array typechecking also consists of validating the number of constant elements. 
Constant::Array(id, ref elems) => { - if let Type::Array(_, dc_id) = types[id.idx()] { - if dynamic_constants[dc_id.idx()] == DynamicConstant::Constant(elems.len()) - { - Concrete(id) - } else { - Error(String::from("Array constant must have the correct number of constant elements as specified by its type.")) + if let Type::Array(_, dc_ids) = &types[id.idx()] { + let mut total_num_elems = 1; + for dc_id in dc_ids.iter() { + total_num_elems *= if let DynamicConstant::Constant(extent) = + dynamic_constants[dc_id.idx()] + { + extent + } else { + return Error(String::from("Array constant type must reference only constant valued dynamic constants.")); + }; + } + if total_num_elems != 1 && total_num_elems != elems.len() { + return Error(String::from("Array constant must have a compatible amount of elements as the extent of the array.")); } + Concrete(id) } else { Error(String::from( "Array constant must store an explicit array type.", @@ -745,140 +750,100 @@ fn typeflow( Concrete(callee.return_type) } - Node::ReadProd { prod: _, index } => { - if inputs.len() != 1 { - return Error(String::from("ReadProd node must have exactly one input.")); - } - - // If the input type isn't concrete, just propagate input type. - if let Concrete(id) = inputs[0] { - if let Type::Product(elem_tys) = &types[id.idx()] { - if *index >= elem_tys.len() { - // ReadProd's index being out of range is a type error. - return Error(String::from("ReadProd node's index must be within range of input product type's element list.")); - } else { - return Concrete(elem_tys[*index]); - } - } else { - return Error(String::from( - "ReadProd node's input type must be a product type.", - )); - } - } - - inputs[0].clone() - } - Node::WriteProd { - prod: _, - data: _, - index, + Node::Read { + collect: _, + indices, } => { - if inputs.len() != 2 { - return Error(String::from("WriteProd node must have exactly two inputs.")); - } - - // If the input type isn't concrete, just propagate input type. 
- if let Concrete(id) = inputs[0] { - if let Type::Product(elem_tys) = &types[id.idx()] { - if *index >= elem_tys.len() { - // ReadProd's index being out of range is a type error. - return Error(String::from("WriteProd node's index must be within range of input product type's element list.")); - } else if let Concrete(data_id) = inputs[1] { - if elem_tys[*index] != *data_id { - return Error(format!("WriteProd node's data input doesn't match the type of the element at index {} inside the product type.", index)); - } else if let Type::Control(_) = &types[data_id.idx()] { + if indices.len() == 0 { + return Error(String::from("Read node must have at least one index.")); + } + + // Traverse the collect input's type tree downwards. + if let Concrete(mut collect_id) = inputs[0] { + for index in indices.iter() { + match (&types[collect_id.idx()], index) { + (Type::Product(fields), Index::Field(field)) => { + if *field >= fields.len() { + return Error(String::from("Read node's field index must be in the range of the product type being indexed.")); + } + collect_id = fields[*field]; + } + (Type::Summation(variants), Index::Variant(variant)) => { + if *variant >= variants.len() { + return Error(String::from("Read node's variant index must be in the range of the variant type being indexed.")); + } + collect_id = variants[*variant]; + } + (Type::Array(elem_ty_id, dim_sizes), Index::Position(indices)) => { + if dim_sizes.len() != indices.len() { + return Error(String::from("Read node's position index must have the same number of dimensions as the array type being indexed.")); + } + collect_id = *elem_ty_id; + } + (Type::Control(_), Index::Control(_)) => {} + _ => { return Error(String::from( - "WriteProd node's data input cannot have a control type.", + "Read node has mismatched input type and indices.", )); } - } else if inputs[1].is_error() { - // If an input lattice value is an error, we must - // propagate it. 
- return inputs[1].clone(); } - return Concrete(elem_tys[*index]); - } else { - return Error(String::from( - "WriteProd node's input type must be a product type.", - )); } - } - inputs[0].clone() - } - Node::ReadArray { array: _, index: _ } => { - if inputs.len() != 2 { - return Error(String::from("ReadArray node must have exactly two inputs.")); - } - - // Check that index has unsigned type. - if let Concrete(id) = inputs[1] { - if !types[id.idx()].is_unsigned() { - return Error(String::from( - "ReadArray node's index input must have unsigned type.", - )); - } - } else if inputs[1].is_error() { - return inputs[1].clone(); - } - - // If array input is concrete, we can get type of ReadArray node. - if let Concrete(id) = inputs[0] { - if let Type::Array(elem_id, _) = types[id.idx()] { - return Concrete(elem_id); - } else { - return Error(String::from( - "ReadArray node's array input must have array type.", - )); - } + // If successfully traversed, the leaf type is the result. + return Concrete(collect_id); } inputs[0].clone() } - Node::WriteArray { - array: _, + Node::Write { + collect: _, data: _, - index: _, + indices, } => { - if inputs.len() != 3 { - return Error(String::from("WriteArray node must have exactly 3 inputs.")); - } - - // Check that index has unsigned type. - if let Concrete(id) = inputs[2] { - if !types[id.idx()].is_unsigned() { - return Error(String::from( - "WriteArray node's index input must have unsigned type.", - )); - } - } else if inputs[2].is_error() { - return inputs[2].clone(); - } - - // Check that array and data types match. 
- if let Concrete(array_id) = inputs[0] { - if let Type::Array(elem_id, _) = types[array_id.idx()] { - if let Concrete(data_id) = inputs[1] { - if elem_id != *data_id { - return Error(String::from("WriteArray node's array and data inputs must have compatible types (type of data input must be the same as the array input's element type).")); - } else if let Type::Control(_) = &types[data_id.idx()] { + if indices.len() == 0 { + return Error(String::from("Write node must have at least one index.")); + } + + // Traverse the collect input's type tree downwards. + if let (Concrete(mut collect_id), Concrete(data_id)) = (inputs[0], inputs[1]) { + for index in indices.iter() { + match (&types[collect_id.idx()], index) { + (Type::Product(fields), Index::Field(field)) => { + if *field >= fields.len() { + return Error(String::from("Write node's field index must be in the range of the product type being indexed.")); + } + collect_id = fields[*field]; + } + (Type::Summation(variants), Index::Variant(variant)) => { + if *variant >= variants.len() { + return Error(String::from("Write node's variant index must be in the range of the variant type being indexed.")); + } + collect_id = variants[*variant]; + } + (Type::Array(elem_ty_id, dim_sizes), Index::Position(indices)) => { + if dim_sizes.len() != indices.len() { + return Error(String::from("Write node's position index must have the same number of dimensions as the array type being indexed.")); + } + collect_id = *elem_ty_id; + } + (Type::Control(_), Index::Control(_)) => {} + _ => { return Error(String::from( - "WriteArray node's data input cannot have a control type.", + "Write node has mismatched input type and indices.", )); } } - } else { + } + + // The leaf type being indexed must be what's being written. 
+ if *data_id != collect_id { return Error(String::from( - "WriteArray node's array input must have array type.", + "Write node has mismatched data type and indexed type.", )); } } - // If an input type is an error, we must propagate it. - if inputs[1].is_error() { - return inputs[1].clone(); - } - + // No matter what, the type is the type of the collect input. inputs[0].clone() } Node::Match { control: _, sum: _ } => { @@ -889,15 +854,13 @@ fn typeflow( // Check sum and control inputs simultaneously, since both need to // be concrete to determine a concrete type for a match node. if let (Concrete(control_id), Concrete(sum_id)) = (inputs[0], inputs[1]) { - if let Type::Summation(variants) = &types[sum_id.idx()] { + if let Type::Summation(_) = &types[sum_id.idx()] { if !types[control_id.idx()].is_control() { return Error(String::from( "Match node's control input cannot have non-control type.", )); } else { - let out_ty = - Type::Product(vec![*control_id; variants.len()].into_boxed_slice()); - return Concrete(get_type_id(out_ty, types, reverse_type_map)); + return inputs[0].clone(); } } else { return Error(String::from( @@ -913,66 +876,6 @@ fn typeflow( TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg), } } - Node::BuildSum { - data: _, - sum_ty, - variant, - } => { - if inputs.len() != 1 { - return Error(String::from("BuildSum node must have exactly one input.")); - } - - if let Concrete(id) = inputs[0] { - if let Type::Control(_) = &types[id.idx()] { - return Error(String::from( - "BuildSum node's data input cannot have a control type.", - )); - } - - // BuildSum node stores its own result type. - if let Type::Summation(variants) = &types[sum_ty.idx()] { - // Must reference an existing variant. - if *variant >= variants.len() { - return Error(String::from("BuildSum node's variant number must be in range of valid variant numbers for referenced sum type.")); - } - - // The variant type has to be the same as the type of data. 
- if *id == variants[*variant] { - return Error(String::from( - "BuildSum node's input type must match the referenced variant type.", - )); - } - - return Concrete(*sum_ty); - } else { - return Error(String::from("BuildSum node must reference a sum type.")); - } - } - - inputs[0].clone() - } - Node::ExtractSum { data: _, variant } => { - if inputs.len() != 1 { - return Error(String::from("ExtractSum node must have exactly one input.")); - } - - if let Concrete(id) = inputs[0] { - if let Type::Summation(variants) = &types[id.idx()] { - // Must reference an existing variant. - if *variant >= variants.len() { - return Error(String::from("BuildSum node's variant number must be in range of valid variant numbers for referenced sum type.")); - } - - return Concrete(variants[*variant]); - } else { - return Error(String::from( - "ExtractSum node's input cannot have non-sum type.", - )); - } - } - - inputs[0].clone() - } } } diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index 00513d52db117232cd0aa12220fd8c61766dab27..1454a64b98fb50190f4107ed4efe2e6b9a4d0298 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -188,7 +188,7 @@ fn verify_structure( found_control = true; } } else { - Err("All fork of a start node must be control or ThreadID nodes.")?; + Err("All users of a fork node must be control or ThreadID nodes.")?; } } } @@ -198,14 +198,15 @@ fn verify_structure( } } // A join node must have exactly one control user. Additionally, - // it may have many collect users. + // it may have many reduce users. 
Node::Join { control: _ } => { let mut found_control = false; for user in users { match function.nodes[user.idx()] { - Node::Collect { + Node::Reduce { control: _, - data: _, + init: _, + reduct: _, } => {} _ => { if function.nodes[user.idx()].is_strictly_control() { @@ -215,7 +216,7 @@ fn verify_structure( found_control = true; } } else { - Err("All join of a start node must be control or Collect nodes.")?; + Err("All uses of a join node must be control or Reduce nodes.")?; } } } @@ -224,7 +225,7 @@ fn verify_structure( Err("A join node must have exactly one control user.")?; } } - // Each if node must have exactly two ReadProd users, which + // Each if node must have exactly two Read users, which // reference differing elements of the node's output product. Node::If { control: _, @@ -234,23 +235,28 @@ fn verify_structure( Err(format!("If node must have 2 users, not {}.", users.len()))?; } if let ( - Node::ReadProd { - prod: _, - index: index1, + Node::Read { + collect: _, + indices: indices1, }, - Node::ReadProd { - prod: _, - index: index2, + Node::Read { + collect: _, + indices: indices2, }, ) = ( &function.nodes[users[0].idx()], &function.nodes[users[1].idx()], ) { - if !((*index1 == 0 && *index2 == 1) || (*index1 == 1 && *index2 == 0)) { - Err("If node's user ReadProd nodes must reference different elements of output product.")?; + if indices1.len() != 1 + || indices2.len() != 1 + || !((indices1[0] == Index::Control(0) && indices2[0] == Index::Control(1)) + || (indices1[0] == Index::Control(1) + && indices2[0] == Index::Control(0))) + { + Err("If node's user Read nodes must reference different elements of output product.")?; } } else { - Err("If node's users must both be ReadProd nodes.")?; + Err("If node's users must both be Read nodes.")?; } } // Phi nodes must depend on a region node. @@ -272,7 +278,11 @@ fn verify_structure( } } // Collect nodes must depend on a join node. 
- Node::Collect { control, data: _ } => { + Node::Reduce { + control, + init: _, + reduct: _, + } => { if let Node::Join { control: _ } = function.nodes[control.idx()] { } else { Err("Collect node's control input must be a join node.")?; @@ -291,7 +301,7 @@ fn verify_structure( } } // Match nodes are similar to if nodes, but have a variable number - // of ReadProd users, corresponding to the sum type being matched. + // of Read users, corresponding to the sum type being matched. Node::Match { control: _, sum } => { let sum_ty = &types[typing[sum.idx()].idx()]; if let Type::Summation(tys) = sum_ty { @@ -305,13 +315,25 @@ fn verify_structure( } let mut users_covered = bitvec![u8, Lsb0; 0; users.len()]; for user in users { - if let Node::ReadProd { prod: _, index } = function.nodes[user.idx()] { - assert!(index < users.len(), "ReadProd child of match node reads from bad index, but ran after typecheck succeeded."); + if let Node::Read { + collect: _, + ref indices, + } = function.nodes[user.idx()] + { + if indices.len() != 1 { + Err("Match node's user Read nodes must have a single index.")?; + } + let index = if let Index::Control(index) = indices[0] { + index + } else { + Err("Match node's user Read node must use a control index.")? 
+ }; + assert!(index < users.len(), "Read child of match node reads from bad index, but ran after typecheck succeeded."); users_covered.set(index, true); } } if users_covered.count_ones() != users.len() { - Err(format!("Match node's user ReadProd nodes must reference all {} elements of match node's output product, but they only reference {} of them.", users.len(), users_covered.count_ones()))?; + Err(format!("Match node's user Read nodes must reference all {} elements of match node's output product, but they only reference {} of them.", users.len(), users_covered.count_ones()))?; } } else { panic!("Type of match node's sum input is not a summation type, but ran after typecheck succeeded."); @@ -358,7 +380,7 @@ fn verify_dominance_relationships( let mut to_check = vec![]; for idx in 0..function.nodes.len() { // If this node is a phi node, we need to handle adding dominance checks - // completely differently. + // completely differently. Reduce nodes need to be handled similarly. if let Node::Phi { control, data } = &function.nodes[idx] { // Get the control predecessors of a region. let region_preds = if let Node::Region { preds } = &function.nodes[control.idx()] { @@ -377,19 +399,21 @@ fn verify_dominance_relationships( control_output_dependencies[data_pred.idx()].clone(), )); } + } else if let Node::Reduce { + control: _, + init: _, + reduct: _, + } = &function.nodes[idx] + { + // TODO: Properly check dominance relations of reduce nodes. } else { // Having a control output dependency only matters if this node is a // control node, or if this node is a control output of a control node. // If this node is a control output, then we want to consider the - // control node itself. We exclude the case of a phi node here, since - // phi nodes can explicitly have non-dominating inputs. We handle phis - // separately above. + // control node itself. We exclude the case of phi and reduce nodes + // here, since these nodes can explicitly have non-dominating inputs. 
let this_id = if let Node::ThreadID { control: dominated_control, - } - | Node::Collect { - control: dominated_control, - data: _, } = function.nodes[idx] { dominated_control @@ -424,9 +448,14 @@ fn verify_dominance_relationships( for pred_idx in 0..function.nodes.len() { if dependencies.is_set(NodeID::new(pred_idx)) { match function.nodes[pred_idx] { - // Verify that uses of phis / collect nodes are dominated + // Verify that uses of phis / reduce nodes are dominated // by the corresponding region / join nodes, respectively. - Node::Phi { control, data: _ } | Node::Collect { control, data: _ } => { + Node::Phi { control, data: _ } + | Node::Reduce { + control, + init: _, + reduct: _, + } => { if dom.contains(this_id) && !dom.does_dom(control, this_id) { Err(format!( "{} node (ID {}) doesn't dominate its use (ID {}).", diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 2993bf8762a479609cc8ff161f592bec8d2e97c5..c061a877a258d3b558e9c35b246a2eb3a72b18b8 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -271,14 +271,14 @@ pub fn ccp( // The reachable users iterator will contain one user if we need to // remove this branch node. if let None = reachable_users.next() { - // The user is a ReadProd node, which in turn has one user. + // The user is a Read node, which in turn has one user. assert!( def_use.get_users(*the_reachable_user).len() == 1, - "Control ReadProd node doesn't have exactly one user." + "Control Read node doesn't have exactly one user." ); let target = def_use.get_users(*the_reachable_user)[0]; - // For each use in the target of the reachable ReadProd, turn it + // For each use in the target of the reachable Read, turn it // into a use of the node proceeding this branch node. for u in get_uses_mut(&mut function.nodes[target.idx()]).as_mut() { if **u == *the_reachable_user { @@ -287,7 +287,7 @@ pub fn ccp( } // Remove this branch node, since it is malformed. Also remove - // all successor ReadProd nodes. 
+ // all successor Read nodes. function.nodes[branch_id.idx()] = Node::Start; for user in users { function.nodes[user.idx()] = Node::Start; @@ -379,9 +379,10 @@ fn ccp_flow_function( Node::Region { preds } => preds.iter().fold(CCPLattice::top(), |val, id| { CCPLattice::meet(&val, &inputs[id.idx()]) }), - // If node has only one output, so doesn't directly handle crossover of - // reachability and constant propagation. ReadProd handles that. + // If node has only one output, if doesn't directly handle crossover of + // reachability and constant propagation. Read handles that. Node::If { control, cond: _ } => inputs[control.idx()].clone(), + Node::Match { control, sum: _ } => inputs[control.idx()].clone(), Node::Fork { control, factor: _ } => inputs[control.idx()].clone(), Node::Join { control } => inputs[control.idx()].clone(), // Phi nodes must look at the reachability of the inputs to its @@ -415,10 +416,11 @@ fn ccp_flow_function( // for the corresponding fork is one. Node::ThreadID { control } => inputs[control.idx()].clone(), // TODO: At least for now, collect nodes always produce unknown values. - // It may be worthwile to add interpretation of constants for collect - // nodes, but it would involve plumbing dynamic constant and fork join - // pairing information here, and I don't feel like doing that. - Node::Collect { control, data: _ } => inputs[control.idx()].clone(), + Node::Reduce { + control, + init: _, + reduct: _, + } => inputs[control.idx()].clone(), Node::Return { control, data } => CCPLattice { reachability: inputs[control.idx()].reachability.clone(), constant: inputs[data.idx()].constant.clone(), @@ -645,22 +647,22 @@ fn ccp_flow_function( }), constant: ConstantLattice::bottom(), }, - // ReadProd handles reachability when following an if or match. - Node::ReadProd { prod, index } => match &function.nodes[prod.idx()] { + // Read handles reachability when following an if or match. 
+ Node::Read { collect, indices } => match &function.nodes[collect.idx()] { Node::If { control: _, cond } => { let cond_constant = &inputs[cond.idx()].constant; - let if_reachability = &inputs[prod.idx()].reachability; - let if_constant = &inputs[prod.idx()].constant; + let if_reachability = &inputs[collect.idx()].reachability; + let if_constant = &inputs[collect.idx()].constant; let new_reachability = if cond_constant.is_top() { ReachabilityLattice::top() } else if let ConstantLattice::Constant(cons) = cond_constant { if let Constant::Boolean(val) = cons { - if *val && *index == 0 { + if *val && indices[0] == Index::Control(0) { // If condition is true and this is the false // branch, then unreachable. ReachabilityLattice::top() - } else if !val && *index == 1 { + } else if !val && indices[0] == Index::Control(1) { // If condition is true and this is the true branch, // then unreachable. ReachabilityLattice::top() @@ -668,7 +670,7 @@ fn ccp_flow_function( if_reachability.clone() } } else { - panic!("Attempted to interpret ReadProd node, where corresponding if node has a non-boolean constant input. Did typechecking succeed?") + panic!("Attempted to interpret Read node, where corresponding if node has a non-boolean constant input. 
Did typechecking succeed?") } } else { if_reachability.clone() @@ -681,14 +683,14 @@ fn ccp_flow_function( } Node::Match { control: _, sum } => { let sum_constant = &inputs[sum.idx()].constant; - let if_reachability = &inputs[prod.idx()].reachability; - let if_constant = &inputs[prod.idx()].constant; + let if_reachability = &inputs[collect.idx()].reachability; + let if_constant = &inputs[collect.idx()].constant; let new_reachability = if sum_constant.is_top() { ReachabilityLattice::top() } else if let ConstantLattice::Constant(cons) = sum_constant { if let Constant::Summation(_, variant, _) = cons { - if *variant as usize != *index { + if Index::Control(*variant as usize) != indices[0] { // If match variant is not the same as this branch, // then unreachable. ReachabilityLattice::top() @@ -696,7 +698,7 @@ fn ccp_flow_function( if_reachability.clone() } } else { - panic!("Attempted to interpret ReadProd node, where corresponding match node has a non-summation constant input. Did typechecking succeed?") + panic!("Attempted to interpret Read node, where corresponding match node has a non-summation constant input. Did typechecking succeed?") } } else { if_reachability.clone() @@ -707,98 +709,35 @@ fn ccp_flow_function( constant: if_constant.clone(), } } - _ => { - let CCPLattice { - ref reachability, - ref constant, - } = inputs[prod.idx()]; - - let new_constant = if let ConstantLattice::Constant(cons) = constant { - let new_cons = if let Constant::Product(_, fields) = cons { - // Index into product constant to get result constant. - old_constants[fields[*index].idx()].clone() - } else { - panic!("Attempted to interpret ReadProd on non-product constant. 
Did typechecking succeed?") - }; - ConstantLattice::Constant(new_cons) - } else { - constant.clone() - }; - - CCPLattice { - reachability: reachability.clone(), - constant: new_constant, - } - } + _ => CCPLattice { + reachability: inputs[collect.idx()].reachability.clone(), + constant: ConstantLattice::bottom(), + }, }, - // WriteProd is uninterpreted for now. - Node::WriteProd { - prod, + // Write is uninterpreted for now. + Node::Write { + collect, data, - index: _, - } => CCPLattice { - reachability: ReachabilityLattice::meet( - &inputs[prod.idx()].reachability, + indices, + } => { + let mut reachability = ReachabilityLattice::meet( + &inputs[collect.idx()].reachability, &inputs[data.idx()].reachability, - ), - constant: ConstantLattice::bottom(), - }, - Node::ReadArray { array, index } => { - let CCPLattice { - reachability: ref array_reachability, - constant: ref array_constant, - } = inputs[array.idx()]; - let CCPLattice { - reachability: ref index_reachability, - constant: ref index_constant, - } = inputs[index.idx()]; - - let new_constant = if let ( - ConstantLattice::Constant(array_cons), - ConstantLattice::Constant(index_cons), - ) = (array_constant, index_constant) - { - let new_cons = match (array_cons, index_cons) { - (Constant::Array(_, elems), Constant::UnsignedInteger8(idx)) => { - elems[*idx as usize] + ); + for index in indices.iter() { + if let Index::Position(positions) = index { + for position in positions.iter() { + reachability = ReachabilityLattice::meet( + &reachability, + &inputs[position.idx()].reachability, + ); } - (Constant::Array(_, elems), Constant::UnsignedInteger16(idx)) => { - elems[*idx as usize] - } - (Constant::Array(_, elems), Constant::UnsignedInteger32(idx)) => { - elems[*idx as usize] - } - (Constant::Array(_, elems), Constant::UnsignedInteger64(idx)) => { - elems[*idx as usize] - } - _ => panic!("Unsupported inputs to ReadArray node. 
Did typechecking succeed?"), - }; - ConstantLattice::Constant(old_constants[new_cons.idx()].clone()) - } else if (array_constant.is_top() && !index_constant.is_bottom()) - || (!array_constant.is_bottom() && index_constant.is_top()) - { - ConstantLattice::top() - } else { - ConstantLattice::meet(array_constant, index_constant) - }; - + } + } CCPLattice { - reachability: ReachabilityLattice::meet(array_reachability, index_reachability), - constant: new_constant, + reachability, + constant: ConstantLattice::bottom(), } } - // WriteArray is uninterpreted for now. - Node::WriteArray { array, data, index } => CCPLattice { - reachability: ReachabilityLattice::meet( - &ReachabilityLattice::meet( - &inputs[array.idx()].reachability, - &inputs[data.idx()].reachability, - ), - &inputs[index.idx()].reachability, - ), - constant: ConstantLattice::bottom(), - }, - Node::Match { control, sum: _ } => inputs[control.idx()].clone(), - _ => CCPLattice::bottom(), } } diff --git a/hercules_tools/Cargo.toml b/hercules_tools/Cargo.toml deleted file mode 100644 index 412f1022638722bd15023dbb289642634d654733..0000000000000000000000000000000000000000 --- a/hercules_tools/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "hercules_tools" -version = "0.1.0" -authors = ["Russel Arbore <rarbore2@illinois.edu>"] - -[[bin]] -name = "hercules_dot" -path = "src/hercules_dot/main.rs" - -[[bin]] -name = "hercules_cpu" -path = "src/hercules_cpu/main.rs" - -[dependencies] -clap = { version = "*", features = ["derive"] } -hercules_ir = { path = "../hercules_ir" } -hercules_opt = { path = "../hercules_opt" } -hercules_codegen = { path = "../hercules_codegen" } -rand = "*" diff --git a/hercules_tools/hercules_cpu/Cargo.toml b/hercules_tools/hercules_cpu/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..58519ce50d47a85a668aee99452aa85398489052 --- /dev/null +++ b/hercules_tools/hercules_cpu/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "hercules_cpu" +version = 
"0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] + +[dependencies] +clap = { version = "*", features = ["derive"] } +hercules_ir = { path = "../../hercules_ir" } +hercules_opt = { path = "../../hercules_opt" } +hercules_cg = { path = "../../hercules_cg" } +rand = "*" diff --git a/hercules_tools/src/hercules_cpu/main.rs b/hercules_tools/hercules_cpu/src/main.rs similarity index 80% rename from hercules_tools/src/hercules_cpu/main.rs rename to hercules_tools/hercules_cpu/src/main.rs index 9ea218afa729c60372a78baa4d02d15385c897a8..c6b077f96589fdc0cb9e0692923f3af3cc7ebac9 100644 --- a/hercules_tools/src/hercules_cpu/main.rs +++ b/hercules_tools/hercules_cpu/src/main.rs @@ -57,7 +57,14 @@ fn main() { .functions .iter() .enumerate() - .map(|(idx, function)| hercules_codegen::antideps::antideps(function, &def_uses[idx])) + .map(|(idx, function)| { + hercules_cg::antideps::array_antideps( + function, + &def_uses[idx], + &module.types, + &typing[idx], + ) + }) .collect(); let bbs: Vec<_> = module @@ -65,7 +72,7 @@ fn main() { .iter() .enumerate() .map(|(idx, function)| { - hercules_codegen::gcm::gcm( + hercules_cg::gcm::gcm( function, &def_uses[idx], &reverse_postorders[idx], @@ -82,39 +89,23 @@ fn main() { .iter() .enumerate() .map(|(idx, function)| { - hercules_codegen::gcm::compute_fork_join_nesting( - function, - &doms[idx], - &fork_join_maps[idx], - ) + hercules_cg::gcm::compute_fork_join_nesting(function, &doms[idx], &fork_join_maps[idx]) }) .collect(); - let array_allocs: Vec<_> = module - .functions - .iter() - .enumerate() - .map(|(idx, function)| { - hercules_codegen::array_alloc::logical_array_alloc( - function, - &typing[idx], - &module.types, - &fork_join_maps[idx], - &bbs[idx], - &fork_join_nests[idx], - ) - }) - .collect(); - - hercules_codegen::cpu_alpha::cpu_alpha_codegen( + let mut file = File::create("test.ll").unwrap(); + let mut contents = String::new(); + hercules_cg::cpu_beta::cpu_beta_codegen( &module, &typing, &reverse_postorders, 
&def_uses, &bbs, &antideps, - &array_allocs, + &fork_join_maps, &fork_join_nests, - &std::path::Path::new("test.bc"), - ); + &mut contents, + ) + .unwrap(); + file.write_all(contents.as_bytes()).unwrap(); } diff --git a/hercules_tools/hercules_dot/Cargo.toml b/hercules_tools/hercules_dot/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..f2b42c0ea2bf18fdca6dc00b33af17302c19f145 --- /dev/null +++ b/hercules_tools/hercules_dot/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "hercules_dot" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] + +[dependencies] +clap = { version = "*", features = ["derive"] } +hercules_ir = { path = "../../hercules_ir" } +hercules_opt = { path = "../../hercules_opt" } +rand = "*" diff --git a/hercules_tools/src/hercules_dot/dot.rs b/hercules_tools/hercules_dot/src/dot.rs similarity index 89% rename from hercules_tools/src/hercules_dot/dot.rs rename to hercules_tools/hercules_dot/src/dot.rs index 5ccd6363bafcb59b734923ca873d60f0b70ca1b7..70f368313e05f5a50b2182077e519cc7ddaf382b 100644 --- a/hercules_tools/src/hercules_dot/dot.rs +++ b/hercules_tools/hercules_dot/src/dot.rs @@ -68,7 +68,9 @@ pub fn write_dot<W: Write>( // the source in the reverse postorder. 
let is_back_edge = reverse_postorder_node_numbers[node_id.idx()] < reverse_postorder_node_numbers[u.idx()] - && (node.is_phi() || (function.is_control(node_id) && function.is_control(*u))); + && (node.is_phi() + || (function.nodes[node_id.idx()].is_control() + && function.nodes[u.idx()].is_control())); write_edge( node_id, function_id, @@ -185,18 +187,31 @@ fn write_node<W: Write>( module.write_dynamic_constant(*dc_id, &mut suffix)?; } } - Node::ReadProd { prod: _, index } => write!(&mut suffix, "{}", index)?, - Node::WriteProd { - prod: _, - data: _, - index, - } => write!(&mut suffix, "{}", index)?, - Node::BuildSum { + Node::Read { + collect: _, + indices, + } => { + let mut iter = indices.iter(); + if let Some(first) = iter.next() { + write!(&mut suffix, "{}", first.lower_case_name())?; + for next in iter { + write!(&mut suffix, ", {}", next.lower_case_name())?; + } + } + } + Node::Write { + collect: _, data: _, - sum_ty: _, - variant, - } => write!(&mut suffix, "{}", variant)?, - Node::ExtractSum { data: _, variant } => write!(&mut suffix, "{}", variant)?, + indices, + } => { + let mut iter = indices.iter(); + if let Some(first) = iter.next() { + write!(&mut suffix, "{}", first.lower_case_name())?; + for next in iter { + write!(&mut suffix, ", {}", next.lower_case_name())?; + } + } + } _ => {} }; diff --git a/hercules_tools/src/hercules_dot/main.rs b/hercules_tools/hercules_dot/src/main.rs similarity index 100% rename from hercules_tools/src/hercules_dot/main.rs rename to hercules_tools/hercules_dot/src/main.rs diff --git a/samples/fork_join.hir b/samples/fork_join.hir index e8371c62d4f9be3607d0be4120232aa198b5d760..99e95829578f5893ce507e9643a97727fb85ccc2 100644 --- a/samples/fork_join.hir +++ b/samples/fork_join.hir @@ -1,9 +1,8 @@ -fn fork_join<2>() -> array(array(u64, #1), #0) - af_ctrl = fork(start, #0) - bf_ctrl = fork(af_ctrl, #1) - bj_ctrl = join(bf_ctrl) - aj_ctrl = join(bj_ctrl) - x = thread_id(bf_ctrl) - bdata = collect(bj_ctrl, x) - adata = 
collect(aj_ctrl, bdata) - r = return(aj_ctrl, adata) +fn fork_join<1>() -> u64 + f_ctrl = fork(start, #0) + j_ctrl = join(f_ctrl) + zero = constant(u64, 0) + x = thread_id(f_ctrl) + data = reduce(j_ctrl, zero, sum) + sum = add(data, x) + r = return(j_ctrl, data) diff --git a/samples/matmul.hir b/samples/matmul.hir index af13ce95489cd5675f61c9023d78f0866a9d4e68..65e0cb1045189277b3ac04bb430125a0a7a14cce 100644 --- a/samples/matmul.hir +++ b/samples/matmul.hir @@ -1,28 +1,57 @@ -fn matmul<3>(a: array(array(f32, #1), #0), b: array(array(f32, #2), #1)) -> array(array(f32, #2), #0) +fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2), c: array(f32, #0, #2)) -> array(f32, #0, #2) i_ctrl = fork(start, #0) i_idx = thread_id(i_ctrl) - k_ctrl = fork(i_ctrl, #2) + j_ctrl = fork(i_ctrl, #2) + j_idx = thread_id(j_ctrl) + k_ctrl = fork(j_ctrl, #1) k_idx = thread_id(k_ctrl) - zero_idx = constant(u64, 0) - one_idx = constant(u64, 1) - zero_val = constant(f32, 0) - loop = region(k_ctrl, if_true) - j = phi(loop, zero_idx, j_inc) - sum = phi(loop, zero_val, sum_inc) - j_inc = add(j, one_idx) - fval1 = read_array(a, i_idx) - fval2 = read_array(b, j) - val1 = read_array(fval1, j) - val2 = read_array(fval2, k_idx) - mul = mul(val1, val2) - sum_inc = add(sum, mul) - j_size = dynamic_constant(#1) - less = lt(j_inc, j_size) - if = if(loop, less) - if_false = read_prod(if, 0) - if_true = read_prod(if, 1) - k_join_ctrl = join(if_false) - k_join_data = collect(k_join_ctrl, sum_inc) - i_join_ctrl = join(k_join_ctrl) - i_join_data = collect(i_join_ctrl, k_join_data) - r = return(i_join_ctrl, i_join_data) + k_join_ctrl = join(k_ctrl) + j_join_ctrl = join(k_join_ctrl) + i_join_ctrl = join(j_join_ctrl) + r = return(i_join_ctrl, update_i_c) + zero = constant(f32, 0) + a_val = read(a, position(i_idx, k_idx)) + b_val = read(b, position(k_idx, j_idx)) + mul = mul(a_val, b_val) + add = add(mul, dot) + dot = reduce(k_join_ctrl, zero, add) + updated_c = write(update_j_c, dot, position(i_idx, j_idx)) 
+ update_j_c = reduce(j_join_ctrl, update_i_c, updated_c) + update_i_c = reduce(i_join_ctrl, c, update_j_c) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/samples/strset.hir b/samples/strset.hir index b83ff39e288139e633e68d2670cdd79b0cf2b64c..2e3ba49225c314a5803e3e678dad2596253e6064 100644 --- a/samples/strset.hir +++ b/samples/strset.hir @@ -7,11 +7,11 @@ fn strset<1>(str: array(u8, #0), byte: u8) -> array(u8, #0) str_inc = phi(loop, str, write) idx_inc = add(idx, one) in_bounds = lt(idx_inc, bound) - read = read_array(str_inc, idx) - write = write_array(str_inc, byte, idx) + read = read(str_inc, position(idx)) + write = write(str_inc, byte, position(idx)) continue = ne(read, byte) if_cond = and(continue, in_bounds) if = if(loop, if_cond) - if_false = read_prod(if, 0) - if_true = read_prod(if, 1) - r = return(if_false, str_inc) \ No newline at end of file + if_false = read(if, control(0)) + if_true = read(if, control(1)) + r = return(if_false, str_inc)