diff --git a/Cargo.lock b/Cargo.lock index c3138b1b025e0997f47d3b04d8f129909936c9dd..cc0a005b796e843ed013a824feb2c2f5e3c9976d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,119 @@ version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +[[package]] +name = "async-channel" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" +dependencies = [ + "concurrent-queue", + "event-listener 2.5.3", + "futures-core", +] + +[[package]] +name = "async-channel" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" +dependencies = [ + "concurrent-queue", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-executor" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ca9a001c1e8ba5149f91a74362376cc6bc5b919d92d988668657bd570bdcec" +dependencies = [ + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-global-executor" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c" +dependencies = [ + "async-channel 2.3.1", + "async-executor", + "async-io", + "async-lock", + "blocking", + "futures-lite", + "once_cell", +] + +[[package]] +name = "async-io" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6baa8f0178795da0e71bc42c9e5d13261aac7ee549853162e66a241ba17964" +dependencies = [ + "async-lock", + "cfg-if", + "concurrent-queue", + "futures-io", + "futures-lite", + "parking", + "polling", + "rustix", + "slab", + "tracing", + "windows-sys", +] + 
+[[package]] +name = "async-lock" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +dependencies = [ + "event-listener 5.3.1", + "event-listener-strategy", + "pin-project-lite", +] + +[[package]] +name = "async-std" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c634475f29802fde2b8f0b505b1bd00dfe4df7d4a000f0b36f7671197d5c3615" +dependencies = [ + "async-channel 1.9.0", + "async-global-executor", + "async-io", + "async-lock", + "crossbeam-utils", + "futures-channel", + "futures-core", + "futures-io", + "futures-lite", + "gloo-timers", + "kv-log-macro", + "log", + "memchr", + "once_cell", + "pin-project-lite", + "pin-utils", + "slab", + "wasm-bindgen-futures", +] + +[[package]] +name = "async-task" +version = "4.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" + [[package]] name = "atomic-polyfill" version = "1.0.3" @@ -75,6 +188,12 @@ dependencies = [ "critical-section", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.3.0" @@ -123,6 +242,25 @@ dependencies = [ "wyz", ] +[[package]] +name = "blocking" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703f41c54fc768e63e091340b424302bb1c29ef4aa0c7f10fe849dfb114d29ea" +dependencies = [ + "async-channel 2.3.1", + "async-task", + "futures-io", + "futures-lite", + "piper", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "byteorder" version = 
"1.5.0" @@ -207,12 +345,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "critical-section" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216" +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "deranged" version = "0.3.11" @@ -260,6 +413,49 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "event-listener" +version = "2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + +[[package]] +name = "event-listener" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +dependencies = [ + "event-listener 5.3.1", + "pin-project-lite", +] + +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + [[package]] name = "filetime" version = "0.2.23" @@ -284,6 +480,40 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-lite" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "getopts" version = "0.2.21" @@ -304,6 +534,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "gloo-timers" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "hash32" version = "0.2.1" @@ -340,23 +582,33 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] -name = "hercules_driver" +name = "hercules_cg" version = "0.1.0" dependencies = [ - "clap", + "bitvec", "hercules_ir", - "hercules_opt", - "ron", + "ordered-float", + "serde", ] [[package]] -name = "hercules_hbin_dump" +name = "hercules_dot" +version = "0.1.0" +dependencies = [ + "async-std", + "clap", + "hercules_rt", + "rand", +] + +[[package]] +name = "hercules_driver" version = "0.1.0" dependencies = [ "clap", "hercules_ir", - "postcard", - "serde", + "hercules_opt", + "ron", ] [[package]] @@ -389,6 +641,7 @@ name = "hercules_opt" version = "0.1.0" dependencies = [ "bitvec", + "hercules_cg", "hercules_ir", "ordered-float", "postcard", @@ -411,7 +664,7 @@ name = "hercules_rt_proc" version = "0.1.0" dependencies = [ "anyhow", - "hercules_ir", + "hercules_cg", "postcard", "serde", ] @@ -430,6 +683,12 @@ dependencies = [ "rand", ] +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + [[package]] name = "indexmap" version = "2.2.6" @@ -461,6 +720,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "js-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "juno_frontend" version = "0.1.0" @@ -477,6 +745,15 @@ dependencies = [ "phf", ] +[[package]] +name = "kv-log-macro" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f" +dependencies = [ + "log", +] + 
[[package]] name = "lazy_static" version = "1.4.0" @@ -485,9 +762,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "lock_api" @@ -499,6 +782,15 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +dependencies = [ + "value-bag", +] + [[package]] name = "lrlex" version = "0.13.6" @@ -629,6 +921,12 @@ dependencies = [ "libc", ] +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + [[package]] name = "ordered-float" version = "4.2.0" @@ -648,6 +946,12 @@ dependencies = [ "serde", ] +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "phf" version = "0.11.2" @@ -690,6 +994,44 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "piper" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96c8c490f422ef9a4efd2cb5b42b76c8613d7e7dfc1caf667b8a3350a5acc066" +dependencies = [ + "atomic-waker", + "fastrand", + "futures-io", +] + +[[package]] +name = "polling" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ed00ed3fbf728b5816498ecd316d1716eecaced9c0c8d2c5a6740ca214985b" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix", + "tracing", + "windows-sys", +] + [[package]] name = "postcard" version = "1.0.8" @@ -827,6 +1169,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +dependencies = [ + "bitflags 2.5.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.17" @@ -871,6 +1226,15 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "sparsevec" version = "0.2.0" @@ -966,6 +1330,22 @@ dependencies = [ "time-core", ] +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" + [[package]] name = "unicode-ident" version = "1.0.12" @@ -984,6 +1364,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "value-bag" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101" + [[package]] name = "vergen" version = "8.3.1" @@ -1012,6 +1398,83 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + 
+[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" + +[[package]] +name = "web-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 14184efac2c7dce70c9edca005ad5cc3dde70610..67d0cea41d242081592b33c7b7ae4b2347db83d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "hercules_cg", "hercules_ir", "hercules_opt", "hercules_rt", @@ -10,10 +11,11 @@ members = [ "hercules_test/hercules_tests", "hercules_tools/hercules_driver", - "hercules_tools/hercules_hbin_dump", + #"hercules_tools/hercules_hbin_dump", "juno_frontend", + "hercules_samples/dot", #"hercules_samples/matmul", #"hercules_samples/task_parallel" ] diff --git a/hercules_cg/Cargo.toml b/hercules_cg/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..8658a63733c3b64c15f5ab4e88d28d8e2441b4a0 --- /dev/null +++ b/hercules_cg/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "hercules_cg" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] + +[dependencies] +ordered-float = "*" +bitvec = "*" +serde = { version = "*", features = ["derive"] } +hercules_ir = { path = "../hercules_ir" } + diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs new file mode 100644 index 
0000000000000000000000000000000000000000..4f93707fc892d48183e95fb80f3ded8836239cba --- /dev/null +++ b/hercules_cg/src/cpu.rs @@ -0,0 +1,658 @@ +extern crate bitvec; + +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Error, Write}; + +use self::bitvec::prelude::*; + +use crate::*; + +/* + * List of big TODOs that aren't urgent: + * + * 1. Return `undef` when a PartitionExit data output isn't dominated by that + * datum's definition. PartitionExit always returns the complete set of data + * outputs that could ever be needed from a partition - this is because we + * don't want to introduce sum types into partition signatures or do funky + * pointer tricks. So, we could run into the following situation: + * + * Block 1 + * / \ + * Block 2 Block 3 + * | | + * Define v1 Define v2 + * / \ + * PartitionExit(v1,v2) PartitionExit(v1, v2) + * + * Let's assume that the PartitionExits branch to different partitions where + * one of them needs v1 and the other needs v2. Notice that both + * PartitionExits need to return both v1 and v2, since their signatures must + * be identical, even though for both, one of v1 or v2 doesn't dominate the + * PartitionExit. What *should* happen here is that each PartitionExit gets + * lowered to an LLVM `ret`, where the non-dominating output is set to + * `undef` This works since in the original, un-partitioned, Hercules IR, + * defs must dominate uses, so we won't run into a situation where a returned + * `undef` value is actually read. What happens currently is that the + * generated LLVM will `ret` `%v1` and `%v2`, which LLVM won't compile (since + * the code wouldn't be in SSA form). This should get fixed when we start + * compiling more complicated codes. + * + * 2. We're a bit, "loose", with how certain basic blocks involving fork-joins + * are handled. 
In particular, the following things need to be handled + * properly, but aren't yet: + * - Phis in a block with a reduce block predecessor need to generate LLVM + * phis that actually depend on the top parallel block of that fork-join, + * not the reduce block. This is because we hijack the top block to be the + * loop header, rather than use the reduce block to be the loop header. + * - The above also applies in the case of phis generated from thread IDs and + * reduction variables inside the top parallel block of a fork-join. This + * case occurs when there is a parallel-reduce section inside another + * parallel-reduce section. + * - The above also applies in the case of a parallel-reduce section + * immediately following another parallel-reduce section (a reduce block + * jumps to the top parallel block of another parallel-reduce section). + * + * 3. Handle >= 3D fork-joins and array accesses. This isn't conceptually + * difficult, but generating the LLVM code to implement these is annoying. + * + * 4. Handle ABI properly when taking in / returning structs taking more than 16 + * bytes. When a passed / returned struct takes more than 16 bytes, it needs + * to be passed around via pointers. This is one of many platform specific C + * ABI rules we need to handle to be properly called from Rust (that 16 byte + * rule is actually x86-64 specific). I'm honestly not sure how to handle + * this well. We avoid running into the manifestation of this problem for + * some samples by removing unneeded parameters / return values from + * partitions at the schedule IR level, which we should do anyway, but this + * isn't a complete solution. + */ + +/* + * The top level function to compile a schedule IR function into LLVM IR, for + * execution on the CPU. We generate LLVM IR textually, since there are no good + * LLVM bindings for Rust, and we are *not* writing any C++. 
+ */ +pub fn cpu_compile<W: Write>( + function: &SFunction, + manifest: &PartitionManifest, + w: &mut W, +) -> Result<(), Error> { + // Calculate basic analyses over schedule IR. + let dep_graph = sched_dependence_graph(function); + let svalue_types = sched_svalue_types(function); + let parallel_reduce_infos = sched_parallel_reduce_sections(function); + + // Generate a dummy uninitialized global - this is needed so that there'll + // be a non-empty .bss section in the ELF object file. + write!( + w, + "@dummy_{} = dso_local global i8 0, align 1\n", + manifest.name + )?; + + // Emit the partition function signature. + write!(w, "define ")?; + if function.return_types.len() == 1 { + emit_type(&function.return_types[0], w)?; + } else { + // Functions with multiple return values return said values in a struct. + emit_type(&SType::Product(function.return_types.clone().into()), w)?; + } + write!(w, " @{}(", manifest.name)?; + (0..function.param_types.len()) + .map(|param_idx| Some(SValue::VirtualRegister(param_idx))) + .intersperse(None) + .map(|token| -> Result<(), Error> { + match token { + Some(param_svalue) => { + emit_svalue(&param_svalue, true, &svalue_types, w)?; + } + None => write!(w, ", ")?, + } + Ok(()) + }) + .collect::<Result<(), Error>>()?; + write!(w, ") {{\n",)?; + + // Emit the function body. Emit each block, one at a time. + for (block_idx, block) in function.blocks.iter().enumerate() { + // Emit the header for the block. + write!(w, "bb_{}:\n", block_idx)?; + + // For "tops" of sequential fork-joins, we hijack the top block to be + // the loop header for the fork-join loop. + if let Some(fork_join_id) = block.kind.try_parallel() + && parallel_reduce_infos[&fork_join_id] + .top_parallel_block + .idx() + == block_idx + { + emit_fork_join_seq_header( + &parallel_reduce_infos[&fork_join_id], + &svalue_types, + block_idx, + w, + )?; + } + + // For each basic block, emit instructions in that block. Emit using a + // worklist over the dependency graph. 
+ let mut emitted = bitvec![u8, Lsb0; 0; block.insts.len()]; + let mut worklist = VecDeque::from((0..block.insts.len()).collect::<Vec<_>>()); + while let Some(inst_idx) = worklist.pop_front() { + let inst_id = InstID::new(block_idx, inst_idx); + let dependencies = &dep_graph[&inst_id]; + let all_uses_emitted = dependencies + .into_iter() + // Check that all used instructions in this block... + .filter(|inst_id| inst_id.idx_0() == block_idx) + // were already emitted. + .all(|inst_id| emitted[inst_id.idx_1()]); + // Phis don't need to wait for all of their uses to be emitted. + if block.insts[inst_idx].is_phi() || all_uses_emitted { + emit_inst( + block.virt_regs[inst_id.idx_1()].0, + &block.insts[inst_idx], + block + .kind + .try_fork_join_id() + .map(|fork_join_id| &parallel_reduce_infos[&fork_join_id]), + &svalue_types, + w, + )?; + emitted.set(inst_id.idx_1(), true); + } else { + worklist.push_back(inst_idx); + } + } + } + + write!(w, "}}\n\n",)?; + Ok(()) +} + +fn emit_type<W: Write>(stype: &SType, w: &mut W) -> Result<(), Error> { + match stype { + SType::Boolean => write!(w, "i1")?, + SType::Integer8 | SType::UnsignedInteger8 => write!(w, "i8")?, + SType::Integer16 | SType::UnsignedInteger16 => write!(w, "i16")?, + SType::Integer32 | SType::UnsignedInteger32 => write!(w, "i32")?, + SType::Integer64 | SType::UnsignedInteger64 => write!(w, "i64")?, + SType::Float32 => write!(w, "float")?, + SType::Float64 => write!(w, "double")?, + SType::Product(fields) => { + write!(w, "{{")?; + fields + .into_iter() + .map(Some) + .intersperse(None) + .map(|token| -> Result<(), Error> { + match token { + Some(field_ty) => emit_type(field_ty, w)?, + None => write!(w, ", ")?, + } + Ok(()) + }) + .collect::<Result<(), Error>>()?; + write!(w, "}}")?; + } + SType::ArrayRef(_) => write!(w, "ptr")?, + } + + Ok(()) +} + +fn emit_constant<W: Write>(sconstant: &SConstant, w: &mut W) -> Result<(), Error> { + match sconstant { + SConstant::Boolean(val) => write!(w, "{}", val)?, + 
SConstant::Integer8(val) => write!(w, "{}", val)?, + SConstant::Integer16(val) => write!(w, "{}", val)?, + SConstant::Integer32(val) => write!(w, "{}", val)?, + SConstant::Integer64(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger8(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger16(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger32(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger64(val) => write!(w, "{}", val)?, + SConstant::Float32(val) => { + if val.fract() == 0.0 { + write!(w, "{}.0", val)? + } else { + write!(w, "{}", val)? + } + } + SConstant::Float64(val) => { + if val.fract() == 0.0 { + write!(w, "{}.0", val)? + } else { + write!(w, "{}", val)? + } + } + SConstant::Product(fields) => { + write!(w, "{{")?; + fields + .into_iter() + .map(Some) + .intersperse(None) + .map(|token| -> Result<(), Error> { + match token { + Some(field_cons) => { + emit_type(&field_cons.get_type(), w)?; + write!(w, " ")?; + emit_constant(field_cons, w)?; + } + None => write!(w, ", ")?, + } + Ok(()) + }) + .collect::<Result<(), Error>>()?; + write!(w, "}}")?; + } + } + + Ok(()) +} + +fn emit_svalue<W: Write>( + svalue: &SValue, + add_ty: bool, + types: &HashMap<SValue, SType>, + w: &mut W, +) -> Result<(), Error> { + if add_ty { + emit_type(&types[svalue], w)?; + write!(w, " ")?; + } + match svalue { + SValue::VirtualRegister(virt_reg) => write!(w, "%v{}", virt_reg)?, + SValue::Constant(cons) => emit_constant(cons, w)?, + } + Ok(()) +} + +fn emit_inst<W: Write>( + virt_reg: usize, + inst: &SInst, + parallel_reduce_info: Option<&ParallelReduceInfo>, + types: &HashMap<SValue, SType>, + w: &mut W, +) -> Result<(), Error> { + // Helper to emit the initial assignment to the destination virtual + // register, when applicable. 
+ let self_svalue = SValue::VirtualRegister(virt_reg); + let emit_assign = |w: &mut W| -> Result<(), Error> { write!(w, "%v{} = ", virt_reg) }; + + write!(w, " ")?; + match inst { + SInst::Phi { inputs } => { + emit_assign(w)?; + write!(w, "phi ")?; + emit_type(&types[&self_svalue], w)?; + write!(w, " ")?; + inputs + .into_iter() + .map(Some) + .intersperse(None) + .map(|token| match token { + Some((pred_block_id, svalue)) => { + write!(w, "[ ")?; + emit_svalue(svalue, false, types, w)?; + write!(w, ", %bb_{} ]", pred_block_id.idx())?; + Ok(()) + } + None => write!(w, ", "), + }) + .collect::<Result<(), Error>>()?; + } + SInst::ThreadID { dimension } => { + let block = parallel_reduce_info.unwrap().top_parallel_block; + emit_assign(w)?; + write!(w, "add i64 0, %thread_id_{}_{}", block.idx(), dimension)?; + } + SInst::ReductionVariable { number } => { + write!(w, "; Already emitted reduction variable #{number}.")?; + } + SInst::Jump { + target, + parallel_entry: _, + reduce_exit, + } => { + if reduce_exit.is_some() { + write!( + w, + "br label %bb_{}", + parallel_reduce_info.unwrap().top_parallel_block.idx() + )?; + } else { + write!(w, "br label %bb_{}", target.idx())?; + } + } + SInst::Branch { + cond, + false_target, + true_target, + } => { + write!(w, "br ")?; + emit_svalue(cond, true, types, w)?; + write!( + w, + ", label %bb_{}, label %bb_{}", + true_target.idx(), + false_target.idx() + )?; + } + SInst::PartitionExit { data_outputs } => { + if data_outputs.len() == 0 { + write!(w, "ret {{}} zeroinitializer")?; + } else if data_outputs.len() == 1 { + write!(w, "ret ")?; + emit_svalue(&data_outputs[0], true, types, w)?; + } else { + let ret_ty = SType::Product( + data_outputs + .iter() + .map(|svalue| types[svalue].clone()) + .collect(), + ); + write!(w, "%v{}_0 = insertvalue ", virt_reg)?; + emit_type(&ret_ty, w)?; + write!(w, " undef, ")?; + emit_svalue(&data_outputs[0], true, types, w)?; + write!(w, ", 0\n")?; + for idx in 1..data_outputs.len() { + write!(w, 
" %v{}_{} = insertvalue ", virt_reg, idx)?; + emit_type(&ret_ty, w)?; + write!(w, " %v{}_{}, ", virt_reg, idx - 1)?; + emit_svalue(&data_outputs[idx], true, types, w)?; + write!(w, ", {}\n", idx)?; + } + write!(w, " ret ")?; + emit_type(&ret_ty, w)?; + write!(w, " %v{}_{}", virt_reg, data_outputs.len() - 1)?; + } + } + SInst::Return { value } => { + write!(w, "ret ")?; + emit_svalue(value, true, types, w)?; + } + SInst::Unary { input, op } => { + emit_assign(w)?; + match op { + SUnaryOperator::Not => { + write!(w, "xor ")?; + emit_svalue(input, true, types, w)?; + write!(w, ", -1")?; + } + SUnaryOperator::Neg => { + if types[input].is_float() { + write!(w, "fneg ")?; + emit_svalue(input, true, types, w)?; + } else { + write!(w, "mul ")?; + emit_svalue(input, true, types, w)?; + write!(w, ", -1")?; + } + } + SUnaryOperator::Cast(_) => todo!(), + } + } + SInst::Binary { left, right, op } => { + emit_assign(w)?; + let op = op.get_llvm_op(&types[left]); + write!(w, "{} ", op)?; + emit_svalue(left, true, types, w)?; + write!(w, ", ")?; + emit_svalue(right, false, types, w)?; + } + SInst::Ternary { + first, + second, + third, + op, + } => { + emit_assign(w)?; + match op { + STernaryOperator::Select => { + write!(w, "select ")?; + emit_svalue(first, true, types, w)?; + write!(w, ", ")?; + emit_svalue(second, true, types, w)?; + write!(w, ", ")?; + emit_svalue(third, true, types, w)?; + } + } + } + SInst::ArrayLoad { + array, + position, + bounds, + } => { + emit_linear_index_calc(virt_reg, position, bounds, types, w)?; + write!(w, "%load_ptr_{} = getelementptr ", virt_reg)?; + emit_type(&types[&self_svalue], w)?; + write!(w, ", ")?; + emit_svalue(array, true, types, w)?; + write!(w, ", i64 %calc_linear_idx_{}\n ", virt_reg)?; + emit_assign(w)?; + write!(w, "load ")?; + emit_type(&types[&self_svalue], w)?; + write!(w, ", ptr %load_ptr_{}", virt_reg)?; + } + SInst::ArrayStore { + array, + value, + position, + bounds, + } => { + emit_linear_index_calc(virt_reg, position, 
bounds, types, w)?; + write!(w, "%store_ptr_{} = getelementptr ", virt_reg)?; + emit_type(&types[&self_svalue], w)?; + write!(w, ", ")?; + emit_svalue(array, true, types, w)?; + write!(w, ", i64 %calc_linear_idx_{}\n ", virt_reg)?; + write!(w, "store ")?; + emit_svalue(value, true, types, w)?; + write!(w, ", ptr %store_ptr_{}", virt_reg)?; + } + _ => {} + } + write!(w, "\n")?; + + Ok(()) +} + +/* + * Emit the loop header implementing a sequential fork-join. + */ +fn emit_fork_join_seq_header<W: Write>( + info: &ParallelReduceInfo, + types: &HashMap<SValue, SType>, + block_idx: usize, + w: &mut W, +) -> Result<(), Error> { + // Emit the phis for the linear loop index variable and the reduction + // variables. + // TODO: handle these cases: + // 1. A parallel-reduce section is nested inside another parallel-reduce + // section. If the predecessor block is itself a top parallel block, the + // predecessor needs to be the _body block, not the original block. + // 2. A parallel-reduce section is immediately followed by another parallel- + // reduce section. If the predecessor block is itself a reduce block, the + // predecessor needs to be the top parallel block of the previous + // parallel-reduce section, not its reduce block. + let entry_pred = info.predecessor; + let loop_pred = info.reduce_block; + write!( + w, + " %linear_{} = phi i64 [ 0, %bb_{} ], [ %linear_{}_inc, %bb_{} ]\n", + block_idx, + entry_pred.idx(), + block_idx, + loop_pred.idx(), + )?; + for (var_num, virt_reg) in info.reduction_variables.iter() { + write!(w, " %v{} = phi ", virt_reg)?; + emit_type(&types[&SValue::VirtualRegister(*virt_reg)], w)?; + write!(w, " [ ")?; + emit_svalue(&info.reduce_inits[*var_num], false, types, w)?; + write!(w, ", %bb_{} ], [ ", entry_pred.idx())?; + emit_svalue(&info.reduce_reducts[*var_num], false, types, w)?; + write!(w, ", %bb_{} ]\n", loop_pred.idx())?; + } + + // Calculate the loop bounds. 
+ if info.thread_counts.len() == 1 { + write!(w, " %bound_{} = add i64 0, ", block_idx)?; + emit_svalue(&info.thread_counts[0], false, types, w)?; + write!(w, "\n")?; + } else if info.thread_counts.len() == 2 { + write!(w, " %bound_{} = mul ", block_idx)?; + emit_svalue(&info.thread_counts[0], true, types, w)?; + write!(w, ", ")?; + emit_svalue(&info.thread_counts[1], false, types, w)?; + write!(w, "\n")?; + } else { + todo!("TODO: Handle the 3 or more dimensional fork-join case.") + } + + // Emit the branch. + write!( + w, + " %cond_{} = icmp ult i64 %linear_{}, %bound_{}\n", + block_idx, block_idx, block_idx + )?; + write!( + w, + " br i1 %cond_{}, label %bb_{}_body, label %bb_{}\n", + block_idx, + block_idx, + info.successor.idx() + )?; + + // Start the body of the loop. + write!(w, "bb_{}_body:\n", block_idx)?; + + // Calculate the multi-dimensional thread indices. + if info.thread_counts.len() == 1 { + write!( + w, + " %thread_id_{}_0 = add i64 0, %linear_{}\n", + block_idx, block_idx + )?; + } else if info.thread_counts.len() == 2 { + write!( + w, + " %thread_id_{}_0 = udiv i64 %linear_{}, ", + block_idx, block_idx + )?; + emit_svalue(&info.thread_counts[1], false, types, w)?; + write!(w, "\n")?; + write!( + w, + " %thread_id_{}_1 = urem i64 %linear_{}, ", + block_idx, block_idx + )?; + emit_svalue(&info.thread_counts[1], false, types, w)?; + write!(w, "\n")?; + } else { + todo!("TODO: Handle the 3 or more dimensional fork-join case.") + } + + // Increment the linear index. + write!( + w, + " %linear_{}_inc = add i64 %linear_{}, 1\n", + block_idx, block_idx + )?; + + Ok(()) +} + +/* + * Implement the index math to convert a multi-dimensional position to a linear + * position inside an array. 
+ */ +fn emit_linear_index_calc<W: Write>( + virt_reg: usize, + position: &[SValue], + bounds: &[SValue], + types: &HashMap<SValue, SType>, + w: &mut W, +) -> Result<(), Error> { + assert_eq!(position.len(), bounds.len()); + + if position.len() == 1 { + write!(w, "%calc_linear_idx_{} = add i64 0, ", virt_reg)?; + emit_svalue(&position[0], false, types, w)?; + write!(w, "\n ")?; + } else if position.len() == 2 { + write!(w, "%calc_linear_idx_{}_0 = mul ", virt_reg)?; + emit_svalue(&position[0], true, types, w)?; + write!(w, ", ")?; + emit_svalue(&bounds[1], false, types, w)?; + write!( + w, + "\n %calc_linear_idx_{} = add i64 %calc_linear_idx_{}_0, ", + virt_reg, virt_reg + )?; + emit_svalue(&position[1], false, types, w)?; + write!(w, "\n ")?; + } else { + todo!("TODO: Handle the 3 or more dimensional array case.") + } + + Ok(()) +} + +impl SBinaryOperator { + fn get_llvm_op(&self, left_ty: &SType) -> &'static str { + enum OpTy { + Float, + Unsigned, + Signed, + } + + let op_ty = if left_ty.is_float() { + OpTy::Float + } else if left_ty.is_unsigned() { + OpTy::Unsigned + } else { + OpTy::Signed + }; + + match (self, op_ty) { + (SBinaryOperator::Add, OpTy::Float) => "fadd", + (SBinaryOperator::Add, _) => "add", + (SBinaryOperator::Sub, OpTy::Float) => "fsub", + (SBinaryOperator::Sub, _) => "sub", + (SBinaryOperator::Mul, OpTy::Float) => "fmul", + (SBinaryOperator::Mul, _) => "mul", + (SBinaryOperator::Div, OpTy::Float) => "fdiv", + (SBinaryOperator::Div, OpTy::Unsigned) => "udiv", + (SBinaryOperator::Div, OpTy::Signed) => "sdiv", + (SBinaryOperator::Rem, OpTy::Float) => "frem", + (SBinaryOperator::Rem, OpTy::Unsigned) => "urem", + (SBinaryOperator::Rem, OpTy::Signed) => "srem", + (SBinaryOperator::LT, OpTy::Float) => "fcmp olt", + (SBinaryOperator::LT, OpTy::Unsigned) => "icmp ult", + (SBinaryOperator::LT, OpTy::Signed) => "icmp slt", + (SBinaryOperator::LTE, OpTy::Float) => "fcmp ole", + (SBinaryOperator::LTE, OpTy::Unsigned) => "icmp ule", + (SBinaryOperator::LTE, 
OpTy::Signed) => "icmp sle", + (SBinaryOperator::GT, OpTy::Float) => "fcmp ogt", + (SBinaryOperator::GT, OpTy::Unsigned) => "icmp ugt", + (SBinaryOperator::GT, OpTy::Signed) => "icmp sgt", + (SBinaryOperator::GTE, OpTy::Float) => "fcmp oge", + (SBinaryOperator::GTE, OpTy::Unsigned) => "icmp uge", + (SBinaryOperator::GTE, OpTy::Signed) => "icmp sge", + (SBinaryOperator::EQ, OpTy::Float) => "fcmp oeq", + (SBinaryOperator::EQ, _) => "icmp eq", + (SBinaryOperator::NE, OpTy::Float) => "fcmp one", + (SBinaryOperator::NE, _) => "icmp ne", + (SBinaryOperator::Or, _) => "or", + (SBinaryOperator::And, _) => "and", + (SBinaryOperator::Xor, _) => "xor", + (SBinaryOperator::LSh, _) => "shl", + (SBinaryOperator::RSh, OpTy::Unsigned) => "lshr", + (SBinaryOperator::RSh, _) => "ashr", + } + } +} diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..b9a67ebab31efd77c12bfddaee5665443df5deb0 --- /dev/null +++ b/hercules_cg/src/lib.rs @@ -0,0 +1,13 @@ +#![feature(let_chains, iter_intersperse)] + +pub mod cpu; +pub mod manifest; +pub mod sched_gen; +pub mod sched_ir; +pub mod sched_schedule; + +pub use crate::cpu::*; +pub use crate::manifest::*; +pub use crate::sched_gen::*; +pub use crate::sched_ir::*; +pub use crate::sched_schedule::*; diff --git a/hercules_cg/src/manifest.rs b/hercules_cg/src/manifest.rs new file mode 100644 index 0000000000000000000000000000000000000000..1634b87c131b1b9e4dc1217bf89e8f9118418925 --- /dev/null +++ b/hercules_cg/src/manifest.rs @@ -0,0 +1,123 @@ +extern crate serde; + +extern crate hercules_ir; + +use std::iter::once; + +use self::serde::Deserialize; +use self::serde::Serialize; + +use self::hercules_ir::*; + +use crate::*; + +/* + * A manifest stores metadata about a Hercules function. This metadata is used + * by the runtime to actually call a Hercules function.
+ */ +#[derive(Debug, Clone, Hash, Serialize, Deserialize)] +pub struct Manifest { + // The signature of each Hercules function is represented in terms of + // STypes, since this is the lowest level type representation that Hercules + // constructs before reaching target-specific backends. + pub param_types: Vec<(SType, ParameterKind)>, + pub return_type: SType, + + // The dynamic constants (potentially) used in this Hercules function. + pub dynamic_constants: Vec<DynamicConstant>, + // The dimensions for array constants defined and used in this Hercules + // function. + pub array_constants: Vec<Box<[DynamicConstantID]>>, + + // The partitions that make up this Hercules function. + pub partitions: Vec<PartitionManifest>, +} + +#[derive(Debug, Clone, Hash, Serialize, Deserialize)] +pub struct PartitionManifest { + // Each partition has one corresponding SFunction. + pub name: SFunctionName, + // Record the type and kind of each parameter. + pub parameters: Vec<(SType, ParameterKind)>, + // Record the type and kind of each return value. + pub returns: Vec<(SType, ReturnKind)>, + // Record the list of possible successors from this partition. + pub successors: Vec<PartitionID>, +} + +#[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq)] +pub enum ParameterKind { + // A parameter corresponding to a parameter of the Hercules function. + HerculesParameter(usize), + // A parameter corresponding to some data defined in some other partition. + DataInput(NodeID), + // A parameter corresponding to a dynamic constant input to the Hercules + // function. + DynamicConstant(usize), + // A parameter corresponding to an array constant used in the partition. + ArrayConstant(ArrayID), +} + +#[derive(Debug, Clone, Hash, Serialize, Deserialize)] +pub enum ReturnKind { + // A return value corresponding to the return value of the Hercules + // function. + HerculesReturn, + // A return value corresponding to some data used in some other partition. 
+ DataOutput(NodeID), + // An integer specifying which partition should be executed next, if this + // partition has multiple successors. + NextPartition, +} + +impl Manifest { + pub fn all_visible_types(&self) -> impl Iterator<Item = SType> + '_ { + self.param_types + // Include the Hercules function parameter types. + .iter() + .map(|(ty, _)| ty.clone()) + // Include the Hercules function return type. + .chain(once(self.return_type.clone())) + // Include the partition parameter types. + .chain( + self.partitions + .iter() + .map(|partition| partition.parameters.iter().map(|(ty, _)| ty.clone())) + .flatten(), + ) + // Include the partition return types. + .chain( + self.partitions + .iter() + .map(|partition| partition.returns.iter().map(|(ty, _)| ty.clone())) + .flatten(), + ) + // Include the product types formed by the partition return types, + // since multiple return values are returned inside a struct. + .chain(self.partitions.iter().map(|partition| { + SType::Product(partition.returns.iter().map(|(ty, _)| ty.clone()).collect()) + })) + } +} + +impl PartitionManifest { + pub fn data_inputs(&self) -> impl Iterator<Item = (NodeID, &SType)> + '_ { + self.parameters.iter().filter_map(|(stype, param_kind)| { + if let ParameterKind::DataInput(id) = param_kind { + Some((*id, stype)) + } else { + None + } + }) + } + + pub fn data_outputs(&self) -> impl Iterator<Item = (NodeID, &SType)> + '_ { + self.returns.iter().filter_map(|(stype, return_kind)| { + if let ReturnKind::DataOutput(id) = return_kind { + Some((*id, stype)) + } else { + None + } + }) + } +} diff --git a/hercules_cg/src/sched_gen.rs b/hercules_cg/src/sched_gen.rs new file mode 100644 index 0000000000000000000000000000000000000000..c969680108a281e1c2c36dd9905604c5b83be032 --- /dev/null +++ b/hercules_cg/src/sched_gen.rs @@ -0,0 +1,1329 @@ +extern crate bitvec; + +extern crate hercules_ir; + +use std::cell::Cell; +use std::collections::{HashMap, VecDeque}; +use std::iter::zip; +use std::mem::{swap, take}; 
+ +use self::bitvec::prelude::*; + +use self::hercules_ir::*; + +use crate::*; + +pub fn sched_compile( + module: &Module, + def_uses: &Vec<ImmutableDefUseMap>, + reverse_postorders: &Vec<Vec<NodeID>>, + typing: &ModuleTyping, + control_subgraphs: &Vec<Subgraph>, + fork_join_maps: &Vec<HashMap<NodeID, NodeID>>, + fork_join_nests: &Vec<HashMap<NodeID, Vec<NodeID>>>, + antideps: &Vec<Vec<(NodeID, NodeID)>>, + bbs: &Vec<Vec<NodeID>>, + plans: &Vec<Plan>, +) -> SModule { + verify_types_well_formed_for_sched_ir(&module.types); + let stypes = convert_to_sched_ir_types(&module.types); + verify_constants_well_formed_for_sched_ir(&module.constants); + let sconstants = convert_to_sched_ir_constants(&module.constants); + + let mut functions = HashMap::new(); + let mut manifests = HashMap::new(); + for idx in 0..module.functions.len() { + let (sfunctions, manifest) = FunctionContext::new( + &module.functions[idx], + &module.types, + &module.constants, + &module.dynamic_constants, + &def_uses[idx], + &reverse_postorders[idx], + &typing[idx], + &control_subgraphs[idx], + &fork_join_maps[idx], + &fork_join_nests[idx], + &antideps[idx], + &bbs[idx], + &plans[idx], + &stypes, + &sconstants, + ) + .compile_function(); + + functions.extend(sfunctions.into_iter()); + manifests.insert(module.functions[idx].name.clone(), manifest); + } + + SModule { + functions, + manifests, + } +} + +/* + * Checks for the following conditions: + * 1. Array types are only ever root types. + * 2. Summation types do not exist - TODO: properly support summation types. 
+ */ +fn verify_types_well_formed_for_sched_ir(types: &Vec<Type>) { + for ty in types.iter() { + match ty { + Type::Product(fields) => { + let any_array = fields + .iter() + .any(|field_ty| types[field_ty.idx()].is_array()); + assert!(!any_array, "PANIC: Found non-root array type."); + } + Type::Summation(_) => panic!("PANIC: Can't lower summations to schedule IR yet."), + Type::Array(elem, _) => assert!( + !types[elem.idx()].is_array(), + "PANIC: Found non-root array type." + ), + _ => {} + } + } +} + +fn convert_to_sched_ir_types(types: &Vec<Type>) -> Vec<SType> { + let mut stypes = vec![SType::Boolean; types.len()]; + + for id in types_bottom_up(types) { + stypes[id.idx()] = match &types[id.idx()] { + Type::Control => SType::Boolean, + Type::Boolean => SType::Boolean, + Type::Integer8 => SType::Integer8, + Type::Integer16 => SType::Integer16, + Type::Integer32 => SType::Integer32, + Type::Integer64 => SType::Integer64, + Type::UnsignedInteger8 => SType::UnsignedInteger8, + Type::UnsignedInteger16 => SType::UnsignedInteger16, + Type::UnsignedInteger32 => SType::UnsignedInteger32, + Type::UnsignedInteger64 => SType::UnsignedInteger64, + Type::Float32 => SType::Float32, + Type::Float64 => SType::Float64, + Type::Product(fields) => { + SType::Product(fields.iter().map(|id| stypes[id.idx()].clone()).collect()) + } + Type::Summation(_) => todo!(), + Type::Array(elem_ty, _) => SType::ArrayRef(Box::new(stypes[elem_ty.idx()].clone())), + }; + } + + stypes +} + +/* + * Checks for the following conditions: + * 1. Array constants are only ever root constants. + * 2. Summation constants do not exist - TODO: properly support summation + * constants. 
+ */ +fn verify_constants_well_formed_for_sched_ir(constants: &Vec<Constant>) { + for ty in constants.iter() { + match ty { + Constant::Product(_, fields) => { + let any_array = fields + .iter() + .any(|field_ty| constants[field_ty.idx()].is_array()); + assert!(!any_array, "PANIC: Found non-root array constant."); + } + Constant::Summation(_, _, _) => { + panic!("PANIC: Can't lower summations to schedule IR yet.") + } + // We don't need to check array constants explicitly, since they + // explicitly store their type, and nothing else - if an array + // constant were invalid by condition #1, then its corresponding + // type would be invalid, which would get caught by + // `verify_types_well_formed_for_sched_ir`. + _ => {} + } + } +} + +fn convert_to_sched_ir_constants(constants: &Vec<Constant>) -> Vec<SConstant> { + let mut sconstants = vec![SConstant::Boolean(false); constants.len()]; + + for id in constants_bottom_up(constants) { + sconstants[id.idx()] = match &constants[id.idx()] { + Constant::Boolean(val) => SConstant::Boolean(*val), + Constant::Integer8(val) => SConstant::Integer8(*val), + Constant::Integer16(val) => SConstant::Integer16(*val), + Constant::Integer32(val) => SConstant::Integer32(*val), + Constant::Integer64(val) => SConstant::Integer64(*val), + Constant::UnsignedInteger8(val) => SConstant::UnsignedInteger8(*val), + Constant::UnsignedInteger16(val) => SConstant::UnsignedInteger16(*val), + Constant::UnsignedInteger32(val) => SConstant::UnsignedInteger32(*val), + Constant::UnsignedInteger64(val) => SConstant::UnsignedInteger64(*val), + Constant::Float32(val) => SConstant::Float32(*val), + Constant::Float64(val) => SConstant::Float64(*val), + Constant::Product(_, fields) => SConstant::Product( + fields + .iter() + .map(|id| sconstants[id.idx()].clone()) + .collect(), + ), + Constant::Summation(_, _, _) => todo!(), + // Array constants are never generated inline schedule IR. 
+ Constant::Array(_) => SConstant::Boolean(false), + }; + } + + sconstants +} + +/* + * Converts one Hercules function to N schedule IR functions, where N is the + * number of partitions in the Hercules function. + */ +struct FunctionContext<'a> { + function: &'a Function, + types: &'a Vec<Type>, + constants: &'a Vec<Constant>, + dynamic_constants: &'a Vec<DynamicConstant>, + def_use: &'a ImmutableDefUseMap, + reverse_postorder: &'a Vec<NodeID>, + typing: &'a Vec<TypeID>, + control_subgraph: &'a Subgraph, + fork_join_map: &'a HashMap<NodeID, NodeID>, + fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>, + antideps: &'a Vec<(NodeID, NodeID)>, + bbs: &'a Vec<NodeID>, + plan: &'a Plan, + stypes: &'a Vec<SType>, + sconstants: &'a Vec<SConstant>, + + top_nodes: Vec<NodeID>, + partition_graph: Subgraph, + inverted_partition_map: Vec<Vec<NodeID>>, + data_inputs: Vec<Vec<NodeID>>, + data_outputs: Vec<Vec<NodeID>>, + + num_virtual_registers: Vec<Cell<usize>>, +} + +impl<'a> FunctionContext<'a> { + fn new( + function: &'a Function, + types: &'a Vec<Type>, + constants: &'a Vec<Constant>, + dynamic_constants: &'a Vec<DynamicConstant>, + def_use: &'a ImmutableDefUseMap, + reverse_postorder: &'a Vec<NodeID>, + typing: &'a Vec<TypeID>, + control_subgraph: &'a Subgraph, + fork_join_map: &'a HashMap<NodeID, NodeID>, + fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>, + antideps: &'a Vec<(NodeID, NodeID)>, + bbs: &'a Vec<NodeID>, + plan: &'a Plan, + stypes: &'a Vec<SType>, + sconstants: &'a Vec<SConstant>, + ) -> Self { + let inverted_partition_map = plan.invert_partition_map(); + let top_nodes = plan.compute_top_nodes(function, control_subgraph, &inverted_partition_map); + let partition_graph = partition_graph(function, def_use, plan); + let data_inputs = plan.compute_data_inputs(function); + let data_outputs = plan.compute_data_outputs(function, def_use); + + let num_virtual_registers = vec![Cell::new(0); plan.num_partitions]; + + FunctionContext { + function, + types, + constants, 
+ dynamic_constants, + def_use, + reverse_postorder, + typing, + control_subgraph, + fork_join_map, + fork_join_nest, + antideps, + bbs, + plan, + stypes, + sconstants, + + top_nodes, + partition_graph, + inverted_partition_map, + data_inputs, + data_outputs, + + num_virtual_registers, + } + } + + /* + * Top level function to compile a Hercules IR function into simple IR + * functions. + */ + fn compile_function(&self) -> (HashMap<SFunctionName, SFunction>, Manifest) { + let (manifest, array_node_to_array_id) = self.compute_manifest(); + + manifest + .partitions + .iter() + .enumerate() + .for_each(|(idx, partition_manifest)| { + self.num_virtual_registers[idx].set(partition_manifest.parameters.len()) + }); + + let partition_functions = (0..self.plan.num_partitions) + .map(|partition_idx| { + ( + self.get_sfunction_name(partition_idx), + self.compile_partition(partition_idx, &manifest, &array_node_to_array_id), + ) + }) + .collect(); + + (partition_functions, manifest) + } + + /* + * Compute the manifest for a Hercules function. This includes all of the + * partition signature information. + */ + fn compute_manifest(&self) -> (Manifest, HashMap<NodeID, ArrayID>) { + // The manifest needs to contain metadata for allocating arrays. + let dynamic_constants = self.dynamic_constants.clone(); + let array_constants = self + .function + .nodes + .iter() + .filter_map(|node| { + if let Some(cons) = node.try_constant() + && let Some(ty) = self.constants[cons.idx()].try_array_type() + { + let extents = self.types[ty.idx()] + .try_extents() + .expect("PANIC: Type of array constant is not an array type."); + Some(extents.into_iter().map(|id| *id).collect()) + } else { + None + } + }) + .collect(); + + // Assign each array constant a unique ID for noting which ones to pass + // to what partition functions. 
+ let array_node_to_array_id = (0..self.function.nodes.len()) + .filter(|node_idx| { + if let Some(cons) = self.function.nodes[*node_idx].try_constant() { + self.constants[cons.idx()].is_array() + } else { + false + } + }) + .enumerate() + .map(|(idx, node_idx)| (NodeID::new(node_idx), ArrayID::new(idx))) + .collect::<HashMap<NodeID, ArrayID>>(); + + let partitions = (0..self.plan.num_partitions) + .map(|partition_idx| { + let partition = &self.inverted_partition_map[partition_idx]; + let name = self.get_sfunction_name(partition_idx); + let mut parameters = vec![]; + let mut returns = vec![]; + + // Compute the signature of each partitions' schedule IR + // function, which has the following structure: + // 1. If the partition is the entry partition, the first + // parameters are the parameters to the Hercules function. If + // not, then the first parameters are all of the data inputs + // to the partition. Note that parameter nodes are always in + // the partition of the start node (the entry partition), so + // function parameters used in other partitions are treated + // as an inter-partition data dependence. + if partition_idx == 0 { + parameters.extend(self.function.param_types.iter().enumerate().map( + |(param_idx, ty_id)| { + ( + self.stypes[ty_id.idx()].clone(), + ParameterKind::HerculesParameter(param_idx), + ) + }, + )); + } else { + parameters.extend(self.data_inputs[partition_idx].iter().map(|node_id| { + ( + self.stypes[self.typing[node_id.idx()].idx()].clone(), + ParameterKind::DataInput(*node_id), + ) + })) + } + + // 2. The second set of parameters are references to zero-ed + // memories for implementing array constants. Implicit array + // cloning is, for now, forbidden. Array constants are + // rematerialized into each partition that uses the constant, + // so look over all of the uses of all the nodes in the + // partition, not all of the nodes in the partition. 
+ parameters.extend( + partition + .iter() + .map(|node_id| { + get_uses(&self.function.nodes[node_id.idx()]) + .as_ref() + .iter() + .filter_map(|use_id| { + if let Some(array_id) = array_node_to_array_id.get(use_id) { + Some(( + self.stypes[self.typing[use_id.idx()].idx()].clone(), + ParameterKind::ArrayConstant(*array_id), + )) + } else { + None + } + }) + .collect::<Vec<_>>() + }) + .flatten(), + ); + + // 3. The third set of parameters are the dynamic constants + // passed to the overall function. + parameters.extend((0..self.function.num_dynamic_constants).map(|idx| { + ( + SType::UnsignedInteger64, + ParameterKind::DynamicConstant(idx as usize), + ) + })); + + // Note that many partitions will be given unused parameters + // (mainly dynamic constants). These will be removed during the + // small amount of optimization done on simple IR. + + // Simple IR functions may return multiple values (this is to + // avoid needing to pack / un-pack product types). The return + // value of an exit partition is the return value of the + // Hercules function. The return values of non-exit partitions + // are the data outputs of the partition, possibly plus an + // integer specifying what partition should be executed next, if + // there are multiple successor partitions. A valid partitioning + // will only contain partitions with either a branch to another + // partition xor a return node. 
+ let successors = self + .partition_graph + .succs(NodeID::new(partition_idx)) + .map(|node_id| PartitionID::new(node_id.idx())) + .collect::<Vec<PartitionID>>(); + if partition + .iter() + .any(|node_id| self.function.nodes[node_id.idx()].is_return()) + { + assert_eq!(successors.len(), 0); + returns.push(( + self.stypes[self.function.return_type.idx()].clone(), + ReturnKind::HerculesReturn, + )); + } else { + assert!(successors.len() > 0); + returns.extend(self.data_outputs[partition_idx].iter().map(|node_id| { + ( + self.stypes[self.typing[node_id.idx()].idx()].clone(), + ReturnKind::DataOutput(*node_id), + ) + })); + if successors.len() > 1 { + returns.push((SType::Integer8, ReturnKind::NextPartition)); + } + } + + PartitionManifest { + name, + parameters, + returns, + successors, + } + }) + .collect(); + + // The parameters for the overall Hercules function is computed in a + // similar fashion as for the individual partition functions. + let mut param_types = vec![]; + param_types.extend(self.function.param_types.iter().enumerate().map( + |(param_idx, ty_id)| { + ( + self.stypes[ty_id.idx()].clone(), + ParameterKind::HerculesParameter(param_idx), + ) + }, + )); + param_types.extend(array_node_to_array_id.iter().map(|(node_id, array_id)| { + ( + self.stypes[self.typing[node_id.idx()].idx()].clone(), + ParameterKind::ArrayConstant(*array_id), + ) + })); + param_types.extend((0..self.function.num_dynamic_constants).map(|idx| { + ( + SType::UnsignedInteger64, + ParameterKind::DynamicConstant(idx as usize), + ) + })); + + // The return type is just the schedule IR type corresponding to the + // Hercules function's return type. + let return_type = self.stypes[self.function.return_type.idx()].clone(); + + let manifest = Manifest { + param_types, + return_type, + dynamic_constants, + array_constants, + partitions, + }; + (manifest, array_node_to_array_id) + } + + /* + * Compile a partition into an SFunction. 
+ */ + fn compile_partition( + &self, + partition_idx: usize, + manifest: &Manifest, + array_node_to_array_id: &HashMap<NodeID, ArrayID>, + ) -> SFunction { + let partition = &self.inverted_partition_map[partition_idx]; + let mut blocks = vec![]; + + // First, create basic blocks inside the SFunction corresponding to the + // control nodes in the partition. If this isn't the entry partition + // (partition #0), add an entry block, since the first basic block in a + // partition may have a predecessor inside the partition. + let mut control_id_to_block_id = HashMap::new(); + let mut fork_node_id_to_fork_join_id = HashMap::new(); + if partition_idx != 0 { + // Create an explicit entry block, if one is already created via the + // Start node. + blocks.push(SBlock::default()); + } + for node in partition { + if self.function.nodes[node.idx()].is_control() { + control_id_to_block_id.insert(*node, BlockID::new(blocks.len())); + let mut block = SBlock::default(); + if let Some(imm_fork) = self.fork_join_nest[node].get(0) { + let new_id = ForkJoinID::new(fork_node_id_to_fork_join_id.len()); + let fork_join_id = *fork_node_id_to_fork_join_id + .entry(*imm_fork) + .or_insert(new_id); + block.kind = if self.function.nodes[node.idx()].is_join() { + SBlockKind::Reduce(fork_join_id) + } else { + SBlockKind::Parallel(fork_join_id) + }; + } + blocks.push(block); + } + } + + // Second, assign every data node a SValue. This map incorporates info + // from the manifest to make using SFunction parameters easy. + let mut data_id_to_svalue = manifest.partitions[partition_idx] + .parameters + .iter() + .enumerate() + .filter_map(|(idx, (_, kind))| match kind { + // Assign SValues to nodes defined outside the partition and + // passed in via SFunction parameters. + ParameterKind::DataInput(node_id) => Some((*node_id, SValue::VirtualRegister(idx))), + _ => None, + }) + .chain( + // Assign SValues for nodes inside the partition. 
+ partition + .iter() + .filter(|node_id| !self.function.nodes[node_id.idx()].is_control()) + .filter_map(|data_id| { + let value = match self.function.nodes[data_id.idx()] { + // Phis in a block with no predecessors inside the + // current partition don't get lowered to phis in + // schedule IR - they get lowered to partition + // parameters. Phis with some predecessors in the + // same partition and some in a different partition + // get lowered to a combination of a SFunction + // parameter and a phi instruction, and uses of the + // phi node should become uses of the phi + // instruction. + Node::Phi { control, data: _ } + if self.control_subgraph.preds(control).all(|pred| { + self.plan.partitions[pred.idx()] + != self.plan.partitions[control.idx()] + }) => + { + // If the phi just gets lowered to a parameter, + // it got added above when adding the virtual + // registers for the SFunction parameters. + return None; + } + // Figure out which virtual constant in the + // signature of the current SFunction corresponds to + // a particular Hercules parameter. + Node::Parameter { index } => SValue::VirtualRegister( + manifest.partitions[partition_idx] + .parameters + .iter() + .position(|(_, kind)| { + *kind == ParameterKind::HerculesParameter(index) + }) + .unwrap(), + ), + // Wait to assign SValues to constants. We assign + // SValues to constants in user partitions, not in + // the partition the constant node happens to be in. + Node::Constant { id: _ } => { + return None; + } + // Dynamic constant nodes get generated upfront, + // since they may or may not need a virtual register + // freshly allocated for them. The math necessary + // for them gets put in the block corresponding to + // the control node the DynamicConstant node was + // scheduled to. 
+ Node::DynamicConstant { id } => { + let block_id = control_id_to_block_id[&self.bbs[data_id.idx()]]; + self.compile_dynamic_constant( + id, + &mut blocks[block_id.idx()], + partition_idx, + manifest, + ) + } + // Wait to assign SValues to array writes. + Node::Write { + collect: _, + data: _, + indices: _, + } if self.types[self.typing[data_id.idx()].idx()].is_array() => { + return None + } + _ => SValue::VirtualRegister(self.make_virt_reg(partition_idx)), + }; + Some((*data_id, value)) + }), + ) + .chain( + // Assign SValues for constants used by nodes in the partition. + partition + .iter() + .map(|node_id| { + get_uses(&self.function.nodes[node_id.idx()]) + .as_ref() + .iter() + .filter_map(|use_id| { + if let Node::Constant { id } = self.function.nodes[use_id.idx()] { + // Array constants map to the parameter the + // array memory is passed in through - all + // other constants are represented inline in + // an SValue. + let svalue = if let Some(array_id) = + array_node_to_array_id.get(use_id) + { + SValue::VirtualRegister( + manifest.partitions[partition_idx] + .parameters + .iter() + .position(|(_, kind)| { + *kind == ParameterKind::ArrayConstant(*array_id) + }) + .unwrap(), + ) + } else { + SValue::Constant(self.sconstants[id.idx()].clone()) + }; + Some((*use_id, svalue)) + } else { + None + } + }) + .collect::<Vec<_>>() + }) + .flatten(), + ) + .collect::<HashMap<_, _>>(); + + // Next, assign all the array write nodes. Array write nodes are + // recursively assigned the SValue of their `collect` input. 
+ let mut worklist = partition + .iter() + .filter(|id| { + self.function.nodes[id.idx()].is_write() && !data_id_to_svalue.contains_key(id) + }) + .map(|id| *id) + .collect::<VecDeque<_>>(); + while let Some(id) = worklist.pop_front() { + let pred = match self.function.nodes[id.idx()] { + Node::Write { + data: _, + indices: _, + collect, + } => collect, + _ => panic!("PANIC: Filtered out write nodes, but found a different node kind."), + }; + if let Some(svalue) = data_id_to_svalue.get(&pred) { + data_id_to_svalue.insert(id, svalue.clone()); + } else { + worklist.push_front(id); + } + } + + // Third, generate code for every node in the partition. Iterates + // through a worklist of nodes in the partition. For non-phi and non- + // reduce nodes, only emit once all data uses are emitted. In addition, + // consider additional anti-dependence edges from read to write nodes. + // Def-use and anti-dependence edges are the only ordering we guarantee + // in schedule IR basic blocks, and it's up to device-specific backends + // to perform instruction scheduling. + let mut visited = bitvec![u8, Lsb0; 0; self.function.nodes.len()]; + let mut worklist = partition.iter().map(|id| *id).collect::<VecDeque<_>>(); + while let Some(id) = worklist.pop_front() { + if self.function.nodes[id.idx()].is_phi() + || self.function.nodes[id.idx()].is_reduce() + || get_uses(&self.function.nodes[id.idx()]) + .as_ref() + .into_iter() + // If this node isn't a phi or reduce, we need to check that + // all uses, as well as all reads we anti-depend with, have + // been emitted. + .chain(self.antideps.iter().filter_map(|(read, write)| { + if id == *write { + Some(read) + } else { + None + } + })) + // Only data dependencies within this partition need to have + // already been visited. 
+ .all(|use_id| { + self.plan.partitions[use_id.idx()] != PartitionID::new(partition_idx) + || self.function.nodes[use_id.idx()].is_control() + || visited[use_id.idx()] + }) + { + // Once all of the data dependencies for this node are emitted, + // this node can be emitted. + self.compile_node( + id, + &control_id_to_block_id, + &data_id_to_svalue, + &mut blocks, + partition_idx, + manifest, + ); + visited.set(id.idx(), true); + } else { + // Skip emitting node if it's not a phi or reduce node and if + // its data uses are not emitted yet. + worklist.push_back(id); + } + } + + // Fourth, add the jump from the explicit entry block to the top node's + // block in the partition. + if partition_idx != 0 { + // Explicitly jump to the block corresponding to the top of the + // partition. That block may be a parallel block, but it's not a + // reduce block. + let top_node = self.top_nodes[partition_idx]; + let top_block = control_id_to_block_id[&top_node]; + let parallel_entry = if self.function.nodes[top_node.idx()].is_fork() { + Some(self.compile_parallel_entry( + top_node, + &data_id_to_svalue, + &mut blocks[0], + partition_idx, + manifest, + )) + } else { + None + }; + blocks[0].insts.push(SInst::Jump { + target: top_block, + parallel_entry, + reduce_exit: None, + }); + blocks[0] + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + } + + SFunction { + blocks, + param_types: manifest.partitions[partition_idx] + .parameters + .iter() + .map(|(sty, _)| sty.clone()) + .collect(), + return_types: manifest.partitions[partition_idx] + .returns + .iter() + .map(|(sty, _)| sty.clone()) + .collect(), + } + } + + fn compile_node( + &self, + id: NodeID, + control_id_to_block_id: &HashMap<NodeID, BlockID>, + data_id_to_svalue: &HashMap<NodeID, SValue>, + blocks: &mut Vec<SBlock>, + partition_idx: usize, + manifest: &Manifest, + ) { + let bb = self.bbs[id.idx()]; + let block_id = Cell::new(control_id_to_block_id[&bb]); + let mut block = take(&mut 
blocks[block_id.get().idx()]); + + // Uses of reduce nodes inside their corresponding reduce block need to + // refer to the reduction variable instruction, not the output of the + // reduce block. + let get_svalue = |id: NodeID| data_id_to_svalue[&id].clone(); + let self_virt_reg = || get_svalue(id).try_virt_reg().unwrap(); + + // Helper function to lower a jump to a particular control node. + let lower_jmp = |dst: NodeID, block: &mut SBlock| { + if let Some(block_id) = control_id_to_block_id.get(&dst) { + // The successor block is in this partition. Add extra info to + // the jump if we're jumping into a parallel section or out of a + // reduce section. Note that both of those may be true at once. + let parallel_entry = if self.function.nodes[dst.idx()].is_fork() { + Some(self.compile_parallel_entry( + dst, + data_id_to_svalue, + block, + partition_idx, + manifest, + )) + } else { + None + }; + let reduce_exit = if self.function.nodes[id.idx()].is_join() { + Some(self.compile_reduce_exit(dst, data_id_to_svalue)) + } else { + None + }; + block.insts.push(SInst::Jump { + target: *block_id, + parallel_entry, + reduce_exit, + }); + } else { + assert_ne!( + self.plan.partitions[id.idx()], + self.plan.partitions[dst.idx()] + ); + + // The successor block is in a different partition. 
+ let next_partition = self.plan.partitions[dst.idx()]; + let data_outputs = manifest.partitions[partition_idx] + .returns + .iter() + .map(|(_, kind)| match kind { + ReturnKind::DataOutput(id) => get_svalue(*id).clone(), + ReturnKind::HerculesReturn => panic!("PANIC: Partition can't contain a HerculesReturn kind of return value when it jumps to another partition."), + ReturnKind::NextPartition => SValue::Constant(SConstant::Integer8(next_partition.idx() as i8)), + }) + .collect(); + + block.insts.push(SInst::PartitionExit { data_outputs }); + } + block + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + }; + + // Helper function to generate the dynamic constant math to compute the + // bounds of an array type of a node. + let lower_extents = |id: NodeID, block: &mut SBlock| { + self.types[self.typing[id.idx()].idx()] + .try_extents() + .unwrap() + .iter() + .map(|dc| self.compile_dynamic_constant(*dc, block, partition_idx, manifest)) + .collect() + }; + + // Emit schedule IR instructions corresponding to this Hercules IR node. + match self.function.nodes[id.idx()] { + // Forks are super simple to lower here. Since what's sequential / + // parallel / reducing is encoded in basic block kinds, and entry / + // exits are handled in `lower_jmp`, we just need to add a jump like + // any other control block with one successor. 
+ Node::Start + | Node::Region { preds: _ } + | Node::Projection { + control: _, + selection: _, + } + | Node::Fork { + control: _, + factors: _, + } => { + let mut succs = self.control_subgraph.succs(id); + assert_eq!(succs.len(), 1); + let succ = succs.next().unwrap(); + lower_jmp(succ, &mut block); + } + Node::Join { control: _ } => { + let mut succs = self.control_subgraph.succs(id); + assert_eq!(succs.len(), 1); + let succ = succs.next().unwrap(); + if self.plan.partitions[id.idx()] != self.plan.partitions[succ.idx()] { + // If the successor is in another partition, we need to add + // a sequential block to hold the PartitionExit. Add a jump, + // with reduce exit metadata, to the reduce block. + let exit_block_id = BlockID::new(blocks.len()); + let reduce_exit = self.compile_reduce_exit(id, data_id_to_svalue); + block.insts.push(SInst::Jump { + target: exit_block_id, + parallel_entry: None, + reduce_exit: Some(reduce_exit), + }); + block + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + // The exit block contains just a PartitionExit instruction. + let mut exit_block = SBlock::default(); + // `lower_jmp` depends on `block_id`, so temporarily update. + let old_block_id = block_id.get(); + block_id.set(exit_block_id); + lower_jmp(succ, &mut exit_block); + block_id.set(old_block_id); + blocks.push(exit_block); + } else { + // Otherwise, lower the jump as normal. 
+ lower_jmp(succ, &mut block); + } + } + Node::If { control: _, cond } => { + let mut succs = self.control_subgraph.succs(id); + let mut proj1 = succs.next().unwrap(); + let mut proj2 = succs.next().unwrap(); + assert_eq!(succs.next(), None); + if self.function.nodes[proj1.idx()].try_proj().unwrap().1 == 1 { + swap(&mut proj1, &mut proj2); + } + block.insts.push(SInst::Branch { + cond: get_svalue(cond).clone(), + false_target: control_id_to_block_id[&self.bbs[proj1.idx()]], + true_target: control_id_to_block_id[&self.bbs[proj2.idx()]], + }); + block + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + } + Node::Return { control: _, data } => { + block.insts.push(SInst::Return { + value: get_svalue(data).clone(), + }); + block + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + } + + Node::Phi { control, ref data } => { + let control_uses = get_uses(&self.function.nodes[control.idx()]); + let mut found_in_partition_predecessor = false; + let mut found_out_of_partition_predecessor = false; + let inputs = zip(control_uses.as_ref().iter(), data.iter()) + .filter_map(|(control_use, data_id)| { + if let Some(block_id) = control_id_to_block_id.get(control_use) { + // If any of the predecessors are in this partition, + // we actually generate a phi instruction. + // Otherwise, we just need to refer to the parameter + // of the SFunction corresponding to this phi. + found_in_partition_predecessor = true; + Some((*block_id, get_svalue(*data_id).clone())) + } else { + // Don't add multiple inputs for block #0. + if found_out_of_partition_predecessor { + return None; + } + // This predecessor for the phi gets passed in via a + // parameter set up for this phi. 
+ found_out_of_partition_predecessor = true; + let param_idx = manifest.partitions[partition_idx] + .parameters + .iter() + .position(|(_, kind)| *kind == ParameterKind::DataInput(id)) + .unwrap(); + Some((BlockID::new(0), SValue::VirtualRegister(param_idx))) + } + }) + .collect(); + + // If there's at least one predecessor inside this partition, we + // need to generate an actual phi instruction. + if found_in_partition_predecessor { + block.insts.push(SInst::Phi { inputs }); + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + } + + Node::ThreadID { + control: _, + dimension, + } => { + block.insts.push(SInst::ThreadID { dimension }); + block + .virt_regs + .push((self_virt_reg(), SType::UnsignedInteger64)); + } + Node::Reduce { + control, + init: _, + reduct: _, + } => { + // Determine the reduction variable number based on the users of + // the join node. + let number = self + .def_use + .get_users(control) + .iter() + .filter(|user| self.function.nodes[user.idx()].is_reduce()) + .position(|user| *user == id) + .unwrap(); + block.insts.push(SInst::ReductionVariable { number }); + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + + Node::Unary { input, op } => { + block.insts.push(SInst::Unary { + input: get_svalue(input).clone(), + op: convert_unary_op(op, &self.stypes), + }); + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + Node::Binary { left, right, op } => { + block.insts.push(SInst::Binary { + left: get_svalue(left).clone(), + right: get_svalue(right).clone(), + op: convert_binary_op(op), + }); + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + Node::Ternary { + first, + second, + third, + op, + } => { + block.insts.push(SInst::Ternary { + first: get_svalue(first).clone(), + second: get_svalue(second).clone(), + third: 
get_svalue(third).clone(), + op: convert_ternary_op(op), + }); + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + + Node::Read { + collect, + ref indices, + } => { + let mut collect_svalue = get_svalue(collect); + let mut prod_indices = &indices[..]; + + // We currently only support top-level arrays and products. The + // array and product portions become separate instructions. + // Since arrays are always root types, handle them first. + if let Some(position) = indices[0].try_position() { + // If there's both an array load and a product extract, we + // need to allocate an intermediary virtual register. + let dst_virt_reg = if indices.len() > 1 { + self.make_virt_reg(partition_idx) + } else { + self_virt_reg() + }; + + let position = position.iter().map(|id| get_svalue(*id)).collect(); + // Array loads need the dynamic constant bounds for indexing + // math. + let bounds = lower_extents(collect, &mut block); + let load_ty = if let SType::ArrayRef(elem_ty) = + self.stypes[self.typing[collect.idx()].idx()].clone() + { + *elem_ty + } else { + panic!("PANIC: Type of collection isn't an array when an ArrayLoad use is generated.") + }; + block.insts.push(SInst::ArrayLoad { + array: collect_svalue, + position, + bounds, + }); + block.virt_regs.push((dst_virt_reg, load_ty)); + + // The product extract needs to extract from the product + // loaded from the array. + collect_svalue = SValue::VirtualRegister(dst_virt_reg); + prod_indices = &indices[1..]; + } + + // Handle the product indices. 
+ if prod_indices.len() > 0 { + let indices = prod_indices + .iter() + .map(|index| index.try_field().unwrap()) + .collect(); + block.insts.push(SInst::ProductExtract { + product: collect_svalue, + indices, + }); + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + } + Node::Write { + collect, + data, + ref indices, + } => { + // We currently only support top-level arrays and products. + // There are three cases that we handle separately: + // 1. Writing to an array. This just lowers to an ArrayStore. + // 2. Writing to a product inside an array. This lowers to an + // ArrayLoad to get the initial product value, a + // ProductInsert to update the product value, and an + // ArrayStore to write the new product value into the array. + // 3. Writing to a product. This just lowers to a ProductInsert. + + if let Some(position) = indices[0].try_position() + && indices.len() == 1 + { + // Handle case #1. + let position = position.iter().map(|id| get_svalue(*id)).collect(); + // Array stores need the dynamic constant bounds for + // indexing math. + let bounds = lower_extents(collect, &mut block); + block.insts.push(SInst::ArrayStore { + array: get_svalue(collect), + value: get_svalue(data), + position, + bounds, + }); + // Array stores don't produce a meaningful virtual register. + block + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + } else if let Some(position) = indices[0].try_position() { + // Handle case #2. + let position = position + .iter() + .map(|id| get_svalue(*id)) + .collect::<Box<[_]>>(); + let bounds = lower_extents(collect, &mut block); + + // Load the product. 
+ let load_virt_reg = self.make_virt_reg(partition_idx); + let load_ty = if let SType::ArrayRef(elem_ty) = + self.stypes[self.typing[collect.idx()].idx()].clone() + { + *elem_ty + } else { + panic!("PANIC: Type of collection isn't an array when an ArrayLoad use is generated.") + }; + block.insts.push(SInst::ArrayLoad { + array: get_svalue(collect), + position: position.clone(), + bounds: bounds.clone(), + }); + block.virt_regs.push((load_virt_reg, load_ty.clone())); + + // Update the product. + let update_virt_reg = self.make_virt_reg(partition_idx); + let indices = indices[1..] + .iter() + .map(|index| index.try_field().unwrap()) + .collect(); + block.insts.push(SInst::ProductInsert { + product: SValue::VirtualRegister(load_virt_reg), + data: get_svalue(data), + indices, + }); + block.virt_regs.push((update_virt_reg, load_ty)); + + // Store the product. + block.insts.push(SInst::ArrayStore { + array: get_svalue(collect), + value: SValue::VirtualRegister(update_virt_reg), + position, + bounds, + }); + block + .virt_regs + .push((self.make_virt_reg(partition_idx), SType::Boolean)); + } else { + // Handle case #3. + let indices = indices + .iter() + .map(|index| index.try_field().unwrap()) + .collect(); + block.insts.push(SInst::ProductInsert { + product: get_svalue(collect), + data: get_svalue(data), + indices, + }); + // Product insertions do produce a virtual register, since + // they create a new product value. + block.virt_regs.push(( + self_virt_reg(), + self.stypes[self.typing[id.idx()].idx()].clone(), + )); + } + } + + // There are a few nodes for which no code needs to get emitted. + _ => {} + } + + blocks[block_id.get().idx()] = block; + } + + /* + * Compiles a reference to a dynamic constant into math to compute that + * dynamic constant. We need a mutable reference to some basic block, since + * we may need to generate math inline to compute the dynamic constant. + * TODO: actually implement dynamic constant math - only then will the above + * be true. 
+ */ + fn compile_dynamic_constant( + &self, + dc: DynamicConstantID, + _block: &mut SBlock, + partition_idx: usize, + manifest: &Manifest, + ) -> SValue { + match self.dynamic_constants[dc.idx()] { + DynamicConstant::Constant(cons) => { + SValue::Constant(SConstant::UnsignedInteger64(cons as u64)) + } + DynamicConstant::Parameter(idx) => SValue::VirtualRegister( + manifest.partitions[partition_idx] + .parameters + .iter() + .position(|(_, kind)| *kind == ParameterKind::DynamicConstant(idx)) + .unwrap(), + ), + } + } + + /* + * Makes a parallel entry for a jump to a fork. + */ + fn compile_parallel_entry( + &self, + fork: NodeID, + data_id_to_svalue: &HashMap<NodeID, SValue>, + block: &mut SBlock, + partition_idx: usize, + manifest: &Manifest, + ) -> ParallelEntry { + let (_, factors) = self.function.nodes[fork.idx()].try_fork().unwrap(); + let thread_counts = factors + .iter() + .map(|dc_id| self.compile_dynamic_constant(*dc_id, block, partition_idx, manifest)) + .collect(); + let reduce_inits = self + .def_use + .get_users(self.fork_join_map[&fork]) + .iter() + .filter_map(|user| self.function.nodes[user.idx()].try_reduce()) + .map(|(_, init, _)| data_id_to_svalue[&init].clone()) + .collect(); + ParallelEntry { + thread_counts, + reduce_inits, + } + } + + /* + * Makes a reduce exit for a jump from a join. + */ + fn compile_reduce_exit( + &self, + join: NodeID, + data_id_to_svalue: &HashMap<NodeID, SValue>, + ) -> ReduceExit { + let reduce_reducts = self + .def_use + .get_users(join) + .iter() + .filter(|user| self.function.nodes[user.idx()].is_reduce()) + .map(|reduce| { + // The SValues that get passed to the reduce exit are the + // `reduct` input to the reduce node. 
+ data_id_to_svalue[&get_uses(&self.function.nodes[reduce.idx()]).as_ref()[2]].clone() + }) + .collect(); + ReduceExit { reduce_reducts } + } + + fn make_virt_reg(&self, partition_idx: usize) -> usize { + let virt_reg = self.num_virtual_registers[partition_idx].get(); + self.num_virtual_registers[partition_idx].set(virt_reg + 1); + virt_reg + } + + fn get_sfunction_name(&self, partition_idx: usize) -> SFunctionName { + format!("{}_{}", self.function.name, partition_idx) + } +} + +fn convert_unary_op(op: UnaryOperator, simple_ir_types: &[SType]) -> SUnaryOperator { + match op { + UnaryOperator::Not => SUnaryOperator::Not, + UnaryOperator::Neg => SUnaryOperator::Neg, + UnaryOperator::Cast(ty) => SUnaryOperator::Cast(simple_ir_types[ty.idx()].clone()), + } +} + +fn convert_binary_op(op: BinaryOperator) -> SBinaryOperator { + match op { + BinaryOperator::Add => SBinaryOperator::Add, + BinaryOperator::Sub => SBinaryOperator::Sub, + BinaryOperator::Mul => SBinaryOperator::Mul, + BinaryOperator::Div => SBinaryOperator::Div, + BinaryOperator::Rem => SBinaryOperator::Rem, + BinaryOperator::LT => SBinaryOperator::LT, + BinaryOperator::LTE => SBinaryOperator::LTE, + BinaryOperator::GT => SBinaryOperator::GT, + BinaryOperator::GTE => SBinaryOperator::GTE, + BinaryOperator::EQ => SBinaryOperator::EQ, + BinaryOperator::NE => SBinaryOperator::NE, + BinaryOperator::Or => SBinaryOperator::Or, + BinaryOperator::And => SBinaryOperator::And, + BinaryOperator::Xor => SBinaryOperator::Xor, + BinaryOperator::LSh => SBinaryOperator::LSh, + BinaryOperator::RSh => SBinaryOperator::RSh, + } +} + +fn convert_ternary_op(op: TernaryOperator) -> STernaryOperator { + match op { + TernaryOperator::Select => STernaryOperator::Select, + } +} diff --git a/hercules_cg/src/sched_ir.rs b/hercules_cg/src/sched_ir.rs new file mode 100644 index 0000000000000000000000000000000000000000..56a2bb51a8070545c9db247c569c5818be3077f3 --- /dev/null +++ b/hercules_cg/src/sched_ir.rs @@ -0,0 +1,486 @@ +extern crate 
ordered_float; +extern crate serde; + +extern crate hercules_ir; + +use std::collections::HashMap; + +use self::serde::Deserialize; +use self::serde::Serialize; + +use self::hercules_ir::*; + +use crate::*; + +/* + * A schedule IR module is a list of functions and a description of each + * Hercules function in terms of schedule IR functions (called the manifest). + */ +#[derive(Debug, Default, Clone)] +pub struct SModule { + // Refer to schedule IR functions by their name. + pub functions: HashMap<SFunctionName, SFunction>, + // Each Hercules function maps to a manifest. + pub manifests: HashMap<String, Manifest>, +} + +/* + * A schedule IR function consists of a CFG of basic blocks, each containing + * instructions. Instructions can produce virtual register outputs, and SSA form + * is guaranteed. SFunctions can have multiple parameters and return values - + * many values may cross partition boundaries at once. + * + * Since SFunctions represent partitions, many SFunctions don't "return". + * Instead, conceptually they "jump" to the next partition to run. SFunctions + * that jump to another partition contain the "PartitionExit" instruction, while + * SFunctions that return from the Hercules function contain the "Return" + * instruction. An SFunction must contain either PartitionExits xor Returns. + * + * There are two special kinds of basic blocks for representing fork-joins: + * parallel blocks and reduce blocks. Each parallel / reduce block is associated + * with a unique ID per fork-join. A parallel block can contain a "ThreadID" + * instruction, which gets the Nth thread ID. A reduce block can contain a + * "Reduce" instruction, which gets the last value of the Mth reduction + * variable. When jumping to a parallel block, a u64 must be provided per fork + * dimension, specifying how many threads should spawn, and an initial value per + * reduction variable must be provided. 
When jumping out of a reduce block, a + * "new" value for each reduction variable must be provided. + */ +#[derive(Debug, Default, Clone)] +pub struct SFunction { + pub blocks: Vec<SBlock>, + pub param_types: Vec<SType>, + pub return_types: Vec<SType>, +} + +impl SFunction { + pub fn get_inst(&self, id: InstID) -> &SInst { + &self.blocks[id.idx_0()].insts[id.idx_1()] + } + + pub fn get_inst_mut(&mut self, id: InstID) -> &mut SInst { + &mut self.blocks[id.idx_0()].insts[id.idx_1()] + } +} + +/* + * Use a very simple representation for blocks, since modification is not a + * priority. Unlike many IRs (say LLVM), the instructions in schedule IR blocks + * aren't necessarily ordered, as different backends may have different + * scheduling considerations. This means that, for example, each SBlock must + * contain exactly one terminating instruction, but the position of that + * instruction may not be at the end of the block. All that's required is that + * defs precede uses, and that loads and stores to array references are ordered. + */ +#[derive(Debug, Default, Clone)] +pub struct SBlock { + pub insts: Vec<SInst>, + // The virtual registers created by each instruction. Technically, this will + // assign instructions like ArrayStores and Regions a virtual register, + // which doesn't make sense. These virtual registers are just ignored. Each + // virtual register has a certain schedule IR type. The type of virtual + // registers produced by certain instructions, like Jump or ArrayStore, is + // set to SType::Boolean, but it's not meaningful. 
+ pub virt_regs: Vec<(usize, SType)>, + pub kind: SBlockKind, +} + +impl SBlock { + pub fn successors(&self) -> BlockSuccessors { + self.insts + .iter() + .map(|inst| inst.block_successors()) + .filter(|successors| *successors != BlockSuccessors::Zero) + .next() + .unwrap_or(BlockSuccessors::Zero) + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub enum SBlockKind { + #[default] + Sequential, + Parallel(ForkJoinID), + Reduce(ForkJoinID), +} + +impl SBlockKind { + pub fn try_parallel(&self) -> Option<ForkJoinID> { + if let SBlockKind::Parallel(id) = self { + Some(*id) + } else { + None + } + } + + pub fn try_reduce(&self) -> Option<ForkJoinID> { + if let SBlockKind::Reduce(id) = self { + Some(*id) + } else { + None + } + } + + pub fn try_fork_join_id(&self) -> Option<ForkJoinID> { + match self { + SBlockKind::Sequential => None, + SBlockKind::Parallel(id) | SBlockKind::Reduce(id) => Some(*id), + } + } +} + +/* + * Unlike Hercules IR, we can represent a reference to an array (so that we + * don't need to use an array value in this IR). This is fine, since we're not + * doing much analysis / optimization at this stage, and most platforms we want + * to target have a similar model for working with arrays anyway. We still need + * value product types, since the layout of these types may be platform + * dependent. + */ +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SType { + Boolean, + Integer8, + Integer16, + Integer32, + Integer64, + UnsignedInteger8, + UnsignedInteger16, + UnsignedInteger32, + UnsignedInteger64, + Float32, + Float64, + // Don't intern STypes to make developing the code generator easier. + Product(Box<[SType]>), + // Array types don't include their bounds, since dynamic constants are not + // an IR-level concept in schedule IR. 
+ ArrayRef(Box<SType>), +} + +impl SType { + pub fn is_float(&self) -> bool { + match self { + SType::Float32 | SType::Float64 => true, + _ => false, + } + } + + pub fn is_unsigned(&self) -> bool { + match self { + SType::UnsignedInteger8 + | SType::UnsignedInteger16 + | SType::UnsignedInteger32 + | SType::UnsignedInteger64 => true, + _ => false, + } + } + + pub fn try_product(&self) -> Option<&[SType]> { + if let SType::Product(fields) = self { + Some(fields) + } else { + None + } + } +} + +/* + * Represents constants, except for array constants. + */ +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SConstant { + Boolean(bool), + Integer8(i8), + Integer16(i16), + Integer32(i32), + Integer64(i64), + UnsignedInteger8(u8), + UnsignedInteger16(u16), + UnsignedInteger32(u32), + UnsignedInteger64(u64), + Float32(ordered_float::OrderedFloat<f32>), + Float64(ordered_float::OrderedFloat<f64>), + // Don't intern SConstants to make developing the code generator easier. + Product(Box<[SConstant]>), +} + +impl SConstant { + pub fn get_type(&self) -> SType { + match self { + SConstant::Boolean(_) => SType::Boolean, + SConstant::Integer8(_) => SType::Integer8, + SConstant::Integer16(_) => SType::Integer16, + SConstant::Integer32(_) => SType::Integer32, + SConstant::Integer64(_) => SType::Integer64, + SConstant::UnsignedInteger8(_) => SType::UnsignedInteger8, + SConstant::UnsignedInteger16(_) => SType::UnsignedInteger16, + SConstant::UnsignedInteger32(_) => SType::UnsignedInteger32, + SConstant::UnsignedInteger64(_) => SType::UnsignedInteger64, + SConstant::Float32(_) => SType::Float32, + SConstant::Float64(_) => SType::Float64, + SConstant::Product(fields) => { + SType::Product(fields.into_iter().map(|field| field.get_type()).collect()) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SValue { + Constant(SConstant), + VirtualRegister(usize), +} + +impl SValue { + pub fn try_virt_reg(&self) -> Option<usize> { + if let SValue::VirtualRegister(vr) = 
self { + Some(*vr) + } else { + None + } + } + + pub fn try_constant(&self) -> Option<&SConstant> { + if let SValue::Constant(cons) = self { + Some(cons) + } else { + None + } + } +} + +/* + * Typical instructions of a CFG + SSA IR, plus some instructions for + * representing particular Hercules IR quirks. + */ +#[derive(Debug, Clone)] +pub enum SInst { + Phi { + inputs: Box<[(BlockID, SValue)]>, + }, + ThreadID { + dimension: usize, + }, + ReductionVariable { + number: usize, + }, + Jump { + target: BlockID, + parallel_entry: Option<ParallelEntry>, + reduce_exit: Option<ReduceExit>, + }, + Branch { + cond: SValue, + false_target: BlockID, + true_target: BlockID, + }, + PartitionExit { + data_outputs: Box<[SValue]>, + }, + Return { + value: SValue, + }, + Unary { + input: SValue, + op: SUnaryOperator, + }, + Binary { + left: SValue, + right: SValue, + op: SBinaryOperator, + }, + Ternary { + first: SValue, + second: SValue, + third: SValue, + op: STernaryOperator, + }, + ProductExtract { + product: SValue, + indices: Box<[usize]>, + }, + ProductInsert { + product: SValue, + data: SValue, + indices: Box<[usize]>, + }, + ArrayLoad { + array: SValue, + position: Box<[SValue]>, + bounds: Box<[SValue]>, + }, + ArrayStore { + array: SValue, + value: SValue, + position: Box<[SValue]>, + bounds: Box<[SValue]>, + }, +} + +impl SInst { + pub fn is_phi(&self) -> bool { + if let SInst::Phi { inputs: _ } = self { + true + } else { + false + } + } + + pub fn is_jump(&self) -> bool { + if let SInst::Jump { + target: _, + parallel_entry: _, + reduce_exit: _, + } = self + { + true + } else { + false + } + } + + pub fn is_partition_exit(&self) -> bool { + if let SInst::PartitionExit { data_outputs: _ } = self { + true + } else { + false + } + } + + pub fn is_return(&self) -> bool { + if let SInst::Return { value: _ } = self { + true + } else { + false + } + } + + pub fn is_terminator(&self) -> bool { + self.is_jump() || self.is_partition_exit() || self.is_return() + } + + pub fn 
try_thread_id(&self) -> Option<usize> { + if let SInst::ThreadID { dimension } = self { + Some(*dimension) + } else { + None + } + } + + pub fn try_reduction_variable(&self) -> Option<usize> { + if let SInst::ReductionVariable { number } = self { + Some(*number) + } else { + None + } + } + + pub fn try_jump(&self) -> Option<(BlockID, Option<&ParallelEntry>, Option<&ReduceExit>)> { + if let SInst::Jump { + target, + parallel_entry, + reduce_exit, + } = self + { + Some((*target, parallel_entry.as_ref(), reduce_exit.as_ref())) + } else { + None + } + } + + pub fn block_successors(&self) -> BlockSuccessors { + match self { + SInst::Jump { + target, + parallel_entry: _, + reduce_exit: _, + } => BlockSuccessors::One([*target]), + SInst::Branch { + cond: _, + false_target, + true_target, + } => BlockSuccessors::Two([*false_target, *true_target]), + _ => BlockSuccessors::Zero, + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum BlockSuccessors { + Zero, + One([BlockID; 1]), + Two([BlockID; 2]), +} + +impl AsRef<[BlockID]> for BlockSuccessors { + fn as_ref(&self) -> &[BlockID] { + match self { + BlockSuccessors::Zero => &[], + BlockSuccessors::One(x) => x, + BlockSuccessors::Two(x) => x, + } + } +} + +/* + * On entering a parallel section, we need to specify how many threads to spawn + * and what the initial values of the reduction variables are. + */ +#[derive(Debug, Clone)] +pub struct ParallelEntry { + pub thread_counts: Box<[SValue]>, + pub reduce_inits: Box<[SValue]>, +} + +/* + * On exiting a reduce section, we need to specify which instructions in the + * reduce block correspond to what reduction variables. This also specifies + * which values defined inside the reduce block can be used outside the block. + */ +#[derive(Debug, Clone)] +pub struct ReduceExit { + pub reduce_reducts: Box<[SValue]>, +} + +/* + * The operator types are mostly the same. 
+ */ +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SUnaryOperator { + Not, + Neg, + Cast(SType), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SBinaryOperator { + Add, + Sub, + Mul, + Div, + Rem, + LT, + LTE, + GT, + GTE, + EQ, + NE, + Or, + And, + Xor, + LSh, + RSh, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum STernaryOperator { + Select, +} + +pub type SFunctionName = String; + +define_id_type!(ArrayID); +define_id_type!(BlockID); +define_id_type!(ForkJoinID); + +define_dual_id_type!(InstID); diff --git a/hercules_cg/src/sched_schedule.rs b/hercules_cg/src/sched_schedule.rs new file mode 100644 index 0000000000000000000000000000000000000000..ac4583df1f8abc493649fbb73cc150c20191ea75 --- /dev/null +++ b/hercules_cg/src/sched_schedule.rs @@ -0,0 +1,352 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::iter::{empty, once, zip}; + +use crate::*; + +/* + * Iterate over the uses of a instruction. + */ +pub fn sched_get_uses(inst: &SInst) -> Box<dyn Iterator<Item = &SValue> + '_> { + match inst { + SInst::Phi { inputs } => Box::new(inputs.iter().map(|(_, svalue)| svalue)), + SInst::ThreadID { dimension: _ } => Box::new(empty()), + SInst::ReductionVariable { number: _ } => Box::new(empty()), + SInst::Jump { + target: _, + parallel_entry, + reduce_exit, + } => { + let first = parallel_entry.as_ref().map(|parallel_entry| { + parallel_entry + .thread_counts + .iter() + .chain(parallel_entry.reduce_inits.iter()) + }); + let second = reduce_exit + .as_ref() + .map(|reduce_exit| reduce_exit.reduce_reducts.iter()); + match (first, second) { + (Some(first), Some(second)) => Box::new(first.chain(second)), + (Some(first), None) => Box::new(first), + (None, Some(second)) => Box::new(second), + (None, None) => Box::new(empty()), + } + } + SInst::Branch { + cond, + false_target: _, + true_target: _, + } => Box::new(once(cond)), + SInst::PartitionExit { data_outputs } => Box::new(data_outputs.iter().map(|svalue| 
svalue)), + SInst::Return { value } => Box::new(once(value)), + SInst::Unary { input, op: _ } => Box::new(once(input)), + SInst::Binary { left, right, op: _ } => Box::new(once(left).chain(once(right))), + SInst::Ternary { + first, + second, + third, + op: _, + } => Box::new(once(first).chain(once(second)).chain(once(third))), + SInst::ProductExtract { + product, + indices: _, + } => Box::new(once(product)), + SInst::ProductInsert { + product, + data, + indices: _, + } => Box::new(once(product).chain(once(data))), + SInst::ArrayLoad { + array, + position, + bounds, + } => Box::new(once(array).chain(position.iter()).chain(bounds.iter())), + SInst::ArrayStore { + array, + value, + position, + bounds, + } => Box::new( + once(array) + .chain(once(value)) + .chain(position.iter()) + .chain(bounds.iter()), + ), + } +} + +/* + * Build a dependency graph of instructions in an SFunction. + */ +pub fn sched_dependence_graph(function: &SFunction) -> HashMap<InstID, Vec<InstID>> { + let mut dep_graph = HashMap::new(); + + // First, map each virtual register to the instruction ID producing it. + let mut virt_reg_to_inst_id = HashMap::new(); + for block_idx in 0..function.blocks.len() { + let block = &function.blocks[block_idx]; + for inst_idx in 0..block.insts.len() { + let virt_reg = block.virt_regs[inst_idx].0; + let inst_id = InstID::new(block_idx, inst_idx); + virt_reg_to_inst_id.insert(virt_reg, inst_id); + dep_graph.insert(inst_id, vec![]); + } + } + + // Process the dependencies in each block. This includes inter-block + // dependencies for normal def-use edges. + for block_idx in 0..function.blocks.len() { + let block = &function.blocks[block_idx]; + + // Add normal dependencies. 
+ for inst_idx in 0..block.insts.len() { + let inst_id = InstID::new(block_idx, inst_idx); + let inst = &block.insts[inst_idx]; + for use_sval in sched_get_uses(inst) { + if let SValue::VirtualRegister(virt_reg) = use_sval { + // Uses of parameters don't correspond to any instruction we + // need to depend on. + if let Some(use_id) = virt_reg_to_inst_id.get(virt_reg) { + let deps = dep_graph.get_mut(&inst_id).unwrap(); + if !deps.contains(use_id) { + deps.push(*use_id); + } + } + } + } + } + + // Phis should appear at the top of linear basic blocks. + + // Add dependencies between the phis. + let mut last_phi = None; + for inst_idx in 0..block.insts.len() { + let inst_id = InstID::new(block_idx, inst_idx); + let inst = &block.insts[inst_idx]; + if inst.is_phi() { + if let Some(last_phi) = last_phi { + let deps = dep_graph.get_mut(&inst_id).unwrap(); + if !deps.contains(&last_phi) { + deps.push(last_phi); + } + } + last_phi = Some(inst_id); + } + } + + // If there is at least one phi, add a dependency between the "last" phi + // and every non-phi instruction with no dependencies yet. + if let Some(last_phi) = last_phi { + for inst_idx in 0..block.insts.len() { + let inst_id = InstID::new(block_idx, inst_idx); + let inst = &block.insts[inst_idx]; + if !inst.is_phi() { + let deps = dep_graph.get_mut(&inst_id).unwrap(); + if deps.is_empty() { + deps.push(last_phi); + } + } + } + } + + // Terminator instructions appear at the bottom of linear basic blocks. + + // Find every non-terminator instruction with no users. 
+ let mut no_user_insts = (0..block.insts.len()) + .filter(|inst_idx| !block.insts[*inst_idx].is_terminator()) + .map(|inst_idx| InstID::new(block_idx, inst_idx)) + .collect::<HashSet<_>>(); + for inst_idx in 0..block.insts.len() { + let inst_id = InstID::new(block_idx, inst_idx); + for dep in dep_graph[&inst_id].iter() { + no_user_insts.remove(dep); + } + } + + // Add a dependency between each instruction with no users (previously) + // and each terminator instruction. + for inst_idx in 0..block.insts.len() { + let inst_id = InstID::new(block_idx, inst_idx); + let inst = &block.insts[inst_idx]; + if inst.is_terminator() { + let deps = dep_graph.get_mut(&inst_id).unwrap(); + for no_user_inst in no_user_insts.iter() { + if !deps.contains(no_user_inst) { + deps.push(*no_user_inst); + } + } + } + } + } + + dep_graph +} + +/* + * Assemble a map from SValue to SType. + */ +pub fn sched_svalue_types(function: &SFunction) -> HashMap<SValue, SType> { + let mut result = HashMap::new(); + + // Add types of parameters. + for (param_idx, param_ty) in function.param_types.iter().enumerate() { + result.insert(SValue::VirtualRegister(param_idx), param_ty.clone()); + } + + // Add types of instructions and constants. + for block in function.blocks.iter() { + for (inst, (virt_reg, sty)) in zip(block.insts.iter(), block.virt_regs.iter()) { + // Add the type of the output of the instruction. + result.insert(SValue::VirtualRegister(*virt_reg), sty.clone()); + + // Find constants inthe uses of instructions. + for u in sched_get_uses(inst) { + if let SValue::Constant(cons) = u { + result.insert(u.clone(), cons.get_type()); + } + } + } + } + + result +} + +/* + * Analysis information for one fork-join. + */ +pub struct ParallelReduceInfo { + // The block that jumps into the parallel section. + pub predecessor: BlockID, + // The block that is jumped into after the reduce section. + pub successor: BlockID, + + // The first parallel block in the parallel section. 
+ pub top_parallel_block: BlockID, + // The parallel block that jumps to the reduce section. + pub bottom_parallel_block: BlockID, + // The single block in the reduce section. + pub reduce_block: BlockID, + + // The thread count SValues used for this fork-join. + pub thread_counts: Box<[SValue]>, + // The initial SValues for the reduction variables. + pub reduce_inits: Box<[SValue]>, + // The reduct SValues for the reduction variables. + pub reduce_reducts: Box<[SValue]>, + + // Map from thread ID dimension to virtual registers of corresponding thread + // ID instructions. + pub thread_ids: HashMap<usize, Vec<usize>>, + // Map from reduction variable number to virtual register of the + // corresponding reduction variable instruction. + pub reduction_variables: HashMap<usize, usize>, +} + +/* + * Analyze parallel-reduce sections to make lowering them easier. Returns a map + * from ForkJoinID to information about that parallel-reduce section. + */ +pub fn sched_parallel_reduce_sections( + function: &SFunction, +) -> HashMap<ForkJoinID, ParallelReduceInfo> { + let mut result = HashMap::new(); + + for (block_idx, block) in function.blocks.iter().enumerate() { + // Start by identifying a jump into a parallel section. + for inst in block.insts.iter() { + if let SInst::Jump { + target, + parallel_entry, + reduce_exit: _, + } = inst + && let Some(parallel_entry) = parallel_entry + { + let predecessor = BlockID::new(block_idx); + let ParallelEntry { + thread_counts, + reduce_inits, + } = parallel_entry.clone(); + + // The jump target is the top of the parallel section. Get the + // fork-join ID from that block. + let top_parallel_block = *target; + let fork_join_id = function.blocks[top_parallel_block.idx()] + .kind + .try_parallel() + .unwrap(); + + // Traverse the blocks until finding a jump to a reduce block. 
+ let mut queue = VecDeque::from(vec![top_parallel_block]); + let mut visited = HashSet::new(); + visited.insert(top_parallel_block); + let mut bfs_dest = None; + while let Some(bfs) = queue.pop_front() { + for succ in function.blocks[bfs.idx()].successors().as_ref() { + if function.blocks[succ.idx()].kind.try_reduce().is_some() { + bfs_dest = Some((bfs, *succ)); + } else if !visited.contains(succ) { + queue.push_back(*succ); + visited.insert(*succ); + } + } + } + let (bottom_parallel_block, reduce_block) = bfs_dest.unwrap(); + + // Find the jump out of the reduce block. + let (successor, _, reduce_exit) = function.blocks[reduce_block.idx()] + .insts + .iter() + .filter_map(|inst| inst.try_jump()) + .next() + .unwrap(); + let reduce_reducts = reduce_exit.unwrap().reduce_reducts.clone(); + + // Find the thread ID instructions. + let mut thread_ids = (0..thread_counts.len()) + .map(|dim| (dim, vec![])) + .collect::<HashMap<usize, Vec<usize>>>(); + for parallel_block in visited { + for (inst, (virt_reg, _)) in zip( + function.blocks[parallel_block.idx()].insts.iter(), + function.blocks[parallel_block.idx()].virt_regs.iter(), + ) { + if let Some(dim) = inst.try_thread_id() { + thread_ids.get_mut(&dim).unwrap().push(*virt_reg); + } + } + } + + // Find the reduction variable instructions. + let reduction_variables = zip( + function.blocks[reduce_block.idx()].insts.iter(), + function.blocks[reduce_block.idx()].virt_regs.iter(), + ) + .filter_map(|(inst, (virt_reg, _))| { + inst.try_reduction_variable() + .map(|number| (number, *virt_reg)) + }) + .collect(); + + // Assemble all of the info and add it to the map. 
+ let info = ParallelReduceInfo { + predecessor, + successor, + + top_parallel_block, + bottom_parallel_block, + reduce_block, + + thread_counts, + reduce_inits, + reduce_reducts, + + thread_ids, + reduction_variables, + }; + result.insert(fork_join_id, info); + } + } + } + + result +} diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 1d6c47546b8c3f3c49c4f238fa15d9f283af621f..6a73da890a5bdb4fe0dfd04dec98b1da506763cb 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -360,9 +360,9 @@ impl<'a> Builder<'a> { self.intern_constant(Constant::Array(ty), ty) } - pub fn create_constant_zero(&mut self, typ : TypeID) -> ConstantID { + pub fn create_constant_zero(&mut self, typ: TypeID) -> ConstantID { match &self.module.types[typ.idx()] { - Type::Control(_) => panic!("Cannot create constant for control types"), + Type::Control => panic!("Cannot create constant for control types"), Type::Boolean => self.create_constant_bool(false), Type::Integer8 => self.create_constant_i8(0), Type::Integer16 => self.create_constant_i16(0), @@ -380,13 +380,13 @@ impl<'a> Builder<'a> { cs.push(self.create_constant_zero(t)); } self.create_constant_prod(cs.into()) - }, + } Type::Summation(cs) => { assert!(cs.len() >= 1, "Cannot create zero for empty summation"); let c = self.create_constant_zero(cs[0]); self.create_constant_sum(typ, 0, c) .expect("Exists and well typed by construction") - }, + } Type::Array(t, dims) => self.create_constant_array(*t, dims.clone()), } } @@ -411,8 +411,6 @@ impl<'a> Builder<'a> { Index::Position(idx) } - - pub fn create_function( &mut self, name: &str, @@ -469,8 +467,8 @@ impl NodeBuilder { self.node = Node::If { control, cond }; } - pub fn build_fork(&mut self, control: NodeID, factor: DynamicConstantID) { - self.node = Node::Fork { control, factor }; + pub fn build_fork(&mut self, control: NodeID, factors: Box<[DynamicConstantID]>) { + self.node = Node::Fork { control, factors }; } pub fn build_join(&mut self, control: 
NodeID) { @@ -481,8 +479,8 @@ impl NodeBuilder { self.node = Node::Phi { control, data }; } - pub fn build_threadid(&mut self, control: NodeID) { - self.node = Node::ThreadID { control }; + pub fn build_thread_id(&mut self, control: NodeID, dimension: usize) { + self.node = Node::ThreadID { control, dimension }; } pub fn build_collect(&mut self, control: NodeID, init: NodeID, reduct: NodeID) { @@ -549,15 +547,8 @@ impl NodeBuilder { }; } - pub fn build_intrinsic( - &mut self, - intrinsic : Intrinsic, - args: Box<[NodeID]>, - ) { - self.node = Node::IntrinsicCall { - intrinsic, - args, - }; + pub fn build_intrinsic(&mut self, intrinsic: Intrinsic, args: Box<[NodeID]>) { + self.node = Node::IntrinsicCall { intrinsic, args }; } pub fn build_read(&mut self, collect: NodeID, indices: Box<[Index]>) { diff --git a/hercules_ir/src/dataflow.rs b/hercules_ir/src/dataflow.rs index bde6be4ad8789b14c612e6e1f6341483faefff29..6df19d1408e06cc9e18c2f3f52e4eb5f5f618315 100644 --- a/hercules_ir/src/dataflow.rs +++ b/hercules_ir/src/dataflow.rs @@ -361,7 +361,10 @@ pub fn immediate_control_flow( // Step 1: replace node if this is a phi, thread ID, or collect. 
if let Node::Phi { control, data: _ } - | Node::ThreadID { control } + | Node::ThreadID { + control, + dimension: _, + } | Node::Reduce { control, init: _, diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index a9c6f6e96853c7ed04fdc8ef88d610411e2137c7..e237ef55a6c9296d2c2fd8e7e6bb5a41d98e393b 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -143,14 +143,20 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { Node::Region { preds } => NodeUses::Variable(preds), Node::If { control, cond } => NodeUses::Two([*control, *cond]), Node::Match { control, sum } => NodeUses::Two([*control, *sum]), - Node::Fork { control, factor: _ } => NodeUses::One([*control]), + Node::Fork { + control, + factors: _, + } => NodeUses::One([*control]), Node::Join { control } => NodeUses::One([*control]), Node::Phi { control, data } => { let mut uses: Vec<NodeID> = Vec::from(&data[..]); uses.push(*control); NodeUses::Owned(uses.into_boxed_slice()) } - Node::ThreadID { control } => NodeUses::One([*control]), + Node::ThreadID { + control, + dimension: _, + } => NodeUses::One([*control]), Node::Reduce { control, init, @@ -173,10 +179,7 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { dynamic_constants: _, args, } => NodeUses::Variable(args), - Node::IntrinsicCall { - intrinsic: _, - args, - } => NodeUses::Variable(args), + Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args), Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter() { @@ -214,10 +217,10 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { NodeUses::Two([*collect, *data]) } } - Node::Projection { control, selection: _ } => { - NodeUses::One([*control]) - - } + Node::Projection { + control, + selection: _, + } => NodeUses::One([*control]), } } @@ -234,12 +237,18 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { Node::Region { preds } => NodeUsesMut::Variable(preds.iter_mut().collect()), Node::If { control, 
cond } => NodeUsesMut::Two([control, cond]), Node::Match { control, sum } => NodeUsesMut::Two([control, sum]), - Node::Fork { control, factor: _ } => NodeUsesMut::One([control]), + Node::Fork { + control, + factors: _, + } => NodeUsesMut::One([control]), Node::Join { control } => NodeUsesMut::One([control]), Node::Phi { control, data } => { NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect()) } - Node::ThreadID { control } => NodeUsesMut::One([control]), + Node::ThreadID { + control, + dimension: _, + } => NodeUsesMut::One([control]), Node::Reduce { control, init, @@ -262,10 +271,9 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { dynamic_constants: _, args, } => NodeUsesMut::Variable(args.iter_mut().collect()), - Node::IntrinsicCall { - intrinsic: _, - args, - } => NodeUsesMut::Variable(args.iter_mut().collect()), + Node::IntrinsicCall { intrinsic: _, args } => { + NodeUsesMut::Variable(args.iter_mut().collect()) + } Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter_mut() { @@ -303,8 +311,9 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { NodeUsesMut::Two([collect, data]) } } - Node::Projection { control, selection } => { - NodeUsesMut::One([control]) - }, + Node::Projection { + control, + selection: _, + } => NodeUsesMut::One([control]), } } diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 92603712241bb7f653fda49ef559ce4188745c33..92bfdaaff77cf067055b0d8fcf4d1a1d1ea1d4e8 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -22,7 +22,6 @@ pub fn xdot_module( fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, bbs: Option<&Vec<Vec<NodeID>>>, plans: Option<&Vec<Plan>>, - fork_join_placements: Option<&Vec<Vec<ForkJoinPlacement>>>, ) { let mut tmp_path = temp_dir(); let mut rng = rand::thread_rng(); @@ -37,7 +36,6 @@ pub fn xdot_module( fork_join_maps, bbs, plans, - fork_join_placements, &mut contents, ) .expect("PANIC: Unable to 
generate output file contents."); @@ -46,7 +44,7 @@ pub fn xdot_module( Command::new("xdot") .args([tmp_path]) .output() - .expect("PANIC: Couldn't execute xdot."); + .expect("PANIC: Couldn't execute xdot. Is xdot installed?"); } /* @@ -60,7 +58,6 @@ pub fn write_dot<W: Write>( fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, bbs: Option<&Vec<Vec<NodeID>>>, plans: Option<&Vec<Plan>>, - fork_join_placements: Option<&Vec<Vec<ForkJoinPlacement>>>, w: &mut W, ) -> std::fmt::Result { write_digraph_header(w)?; @@ -218,44 +215,6 @@ pub fn write_dot<W: Write>( } } - // Step 5: draw fork-join placement edges in purple. - if let Some(fork_join_placements) = fork_join_placements { - let fork_join_map = &fork_join_maps.unwrap()[function_id.idx()]; - let fork_join_placement = &fork_join_placements[function_id.idx()]; - for node_idx in 0..fork_join_placement.len() { - let node_id = NodeID::new(node_idx); - match fork_join_placement[node_id.idx()] { - ForkJoinPlacement::Sequential => {} - ForkJoinPlacement::Fork(fork_id) => { - write_edge( - node_id, - function_id, - fork_id, - function_id, - true, - "purple, constraint=false", - "dotted", - &module, - w, - )?; - } - ForkJoinPlacement::Reduce(fork_id) => { - write_edge( - node_id, - function_id, - fork_join_map[&fork_id], - function_id, - true, - "purple, constraint=false", - "dotted", - &module, - w, - )?; - } - } - } - } - write_graph_footer(w)?; } @@ -327,7 +286,26 @@ fn write_node<W: Write>( // node label. 
let mut suffix = String::new(); match node { - Node::Fork { control: _, factor } => module.write_dynamic_constant(*factor, &mut suffix)?, + Node::Fork { + control: _, + factors, + } => { + for token in factors + .into_iter() + .map(|factor| Some(factor)) + .intersperse(None) + { + if let Some(factor) = token { + module.write_dynamic_constant(*factor, &mut suffix)?; + } else { + write!(&mut suffix, ", ")?; + } + } + } + Node::ThreadID { + control: _, + dimension, + } => write!(&mut suffix, "{}", dimension)?, Node::Parameter { index } => write!(&mut suffix, "#{}", index)?, Node::Constant { id } => module.write_constant(*id, &mut suffix)?, Node::DynamicConstant { id } => module.write_dynamic_constant(*id, &mut suffix)?, @@ -342,10 +320,7 @@ fn write_node<W: Write>( module.write_dynamic_constant(*dc_id, &mut suffix)?; } } - Node::IntrinsicCall { - intrinsic, - args: _, - } => { + Node::IntrinsicCall { intrinsic, args: _ } => { write!(&mut suffix, "{}", intrinsic.lower_case_name())? } Node::Read { diff --git a/hercules_ir/src/gcm.rs b/hercules_ir/src/gcm.rs index 9da269885c84cf802eba561eb0aad9334ed21103..716dae5f143fe591fa0c51b5b8b389312c6a4b9a 100644 --- a/hercules_ir/src/gcm.rs +++ b/hercules_ir/src/gcm.rs @@ -163,175 +163,3 @@ pub fn compute_fork_join_nesting( }) .collect() } - -/* - * Find all the reduce-cycles in a function. - */ -pub fn compute_reduce_cycles(function: &Function) -> HashMap<NodeID, Vec<NodeID>> { - let mut result = HashMap::new(); - let mut dfs_visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; - - for id in (0..function.nodes.len()).map(NodeID::new) { - if let Node::Reduce { - control: _, - init: _, - reduct, - } = &function.nodes[id.idx()] - { - // DFS to find data cycle "rooted" at reduce. - dfs_visited.fill(false); - dfs_visited.set(id.idx(), true); - // The stack starts with the reduce node itself and the `reduct` use - // of the reduce node. 
- let mut dfs_stack = vec![(id, 0), (*reduct, 0)]; - 'dfs: while let Some((node_id, use_idx)) = dfs_stack.pop() { - if node_id == id { - // If we returned to the reduce node, then there is no - // cycle. This will be signified by any empty vector in the - // return map. - break; - } - - dfs_visited.set(node_id.idx(), true); - - // If there are further uses... - let uses = get_uses(&function.nodes[node_id.idx()]); - if use_idx < uses.as_ref().len() { - // Push ourselves back on to the stack. - dfs_stack.push((node_id, use_idx + 1)); - - // Check if the use is a data node. - let use_id = uses.as_ref()[use_idx]; - if !function.nodes[use_id.idx()].is_control() { - // If so, check if the next use was already visited. - if !dfs_visited[use_id.idx()] { - // If not, add the use to the stack. - dfs_stack.push((use_id, 0)); - } else if dfs_stack.iter().any(|(id, _)| *id == use_id) { - // If so, and the use is already in the stack, we've - // found a cycle - if the already visited node we - // found isn't the reduce, then there's a cycle not - // involving the reduce, which isn't valid. - assert_eq!( - id, use_id, - "PANIC: Found cycle not containing expected reduce node." - ); - break 'dfs; - } - } - } - } - - result.insert(id, dfs_stack.into_iter().map(|(id, _)| id).collect()); - } - } - - result -} - -pub fn invert_reduce_cycles( - function: &Function, - reduce_cycles: &HashMap<NodeID, Vec<NodeID>>, - join_fork_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, -) -> Vec<Option<NodeID>> { - let mut result: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; - - for (reduce, in_cycle) in reduce_cycles { - for node in in_cycle { - if let Some(old_reduce) = result[node.idx()] { - // A node may be in multiple reduce cycles when there are nested - // fork-joins. In such cases, we pick the more "deeply nested" - // reduce cycle. 
- let old_join_id = function.nodes[old_reduce.idx()].try_reduce().unwrap().0; - let old_fork_id = join_fork_map[&old_join_id]; - let new_join_id = function.nodes[reduce.idx()].try_reduce().unwrap().0; - let new_fork_id = join_fork_map[&new_join_id]; - - let old_above_new = fork_join_nest[&new_fork_id].contains(&old_fork_id); - let new_above_old = fork_join_nest[&old_fork_id].contains(&new_fork_id); - assert!(old_above_new ^ new_above_old, "PANIC: A node can only be in reduce cycles that are hierarchically related and from different fork-joins."); - if old_above_new { - result[node.idx()] = Some(*reduce); - } - } else { - result[node.idx()] = Some(*reduce); - } - } - } - - result -} - -/* - * Description of a node's placement amongst fork-joins, generated per-node by - * `compute_fork_join_placements`. - */ -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum ForkJoinPlacement { - // The node is not "in" any fork-joins. - Sequential, - // The node is in the "fork" section of the fork/join, marked by fork ID. - Fork(NodeID), - // The node is in the "reduce" section of the fork/join, marked by fork ID. - Reduce(NodeID), -} - -/* - * Find which fork/join each data node is "inside" of, based off of basic block - * scheduling (global code motion information). A data node is either not a part - * of any fork/join, is a part of the "fork" section of a fork/join, or is a - * part of the "reduce" section of a fork/join. The following conditions are - * applied in order to determine which category a data node is in: - * 1. If a data node is contained in a cycle containing a reduce node OR is - * scheduled to the basic block of the join node of the fork/join, the data - * node is in the "reduce" section of that fork/join (more specifically, the - * most deeply nested such reduce node). Otherwise... - * 2. 
If a data node is scheduled to a control node inside a fork/join, the data - * node is in the "fork" section of that fork/join (more specifically, the - * most deeply nested such fork/join). Otherwise... - * 3. If a data node is not in a "reduce" or "fork" section of any fork/join, it - * is a "sequential" node. - */ -pub fn compute_fork_join_placement( - function: &Function, - fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, - bbs: &Vec<NodeID>, -) -> Vec<ForkJoinPlacement> { - let mut result = vec![ForkJoinPlacement::Sequential; function.nodes.len()]; - let join_fork_map = fork_join_map - .into_iter() - .map(|(fork, join)| (*join, *fork)) - .collect::<HashMap<_, _>>(); - let reduce_cycles = compute_reduce_cycles(function); - let inverted_reduce_cycles = - invert_reduce_cycles(function, &reduce_cycles, &join_fork_map, fork_join_nest); - - for id in (0..function.nodes.len()).map(NodeID::new) { - // Check condition #1. - if let Some(reduce_id) = &inverted_reduce_cycles[id.idx()] { - let join_id = function.nodes[reduce_id.idx()].try_reduce().unwrap().0; - let fork_id = join_fork_map[&join_id]; - result[id.idx()] = ForkJoinPlacement::Reduce(fork_id); - continue; - } - - if let Some(fork_id) = join_fork_map.get(&bbs[id.idx()]) { - result[id.idx()] = ForkJoinPlacement::Reduce(*fork_id); - continue; - } - - // Check condition #2. - let forks = &fork_join_nest[&bbs[id.idx()]]; - if let Some(fork_id) = forks.get(0) { - result[id.idx()] = ForkJoinPlacement::Fork(*fork_id); - continue; - } - - // Default to condition #3. - result[id.idx()] = ForkJoinPlacement::Sequential; - } - - result -} diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 7a126672390697d32c8a30945deb6a5f1a6e943d..1ea2950b26d3614b9fd26d84fd6e02e10064c867 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -49,15 +49,12 @@ pub struct Function { * control type. 
Hercules IR is based off of the sea-of-nodes IR, the main * feature of which being a merged control and data flow graph. Thus, control * is a type of value, just like any other type. However, the type system is - * very restrictive over what can be done with control values. An addition in - * Hercules IR is that a control type is parameterized by a list of thread - * spawning factors. This is the mechanism in Hercules IR for representing - * parallelism. Summation types are an IR equivalent of Rust's enum types. - * These are lowered into tagged unions during scheduling. + * very restrictive over what can be done with control values. Summation types + * are an IR equivalent of Rust's enum types. */ #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum Type { - Control(Box<[NodeID]>), + Control, Boolean, Integer8, Integer16, @@ -159,7 +156,7 @@ pub enum Node { }, Fork { control: NodeID, - factor: DynamicConstantID, + factors: Box<[DynamicConstantID]>, }, Join { control: NodeID, @@ -170,6 +167,7 @@ pub enum Node { }, ThreadID { control: NodeID, + dimension: usize, }, Reduce { control: NodeID, @@ -225,7 +223,7 @@ pub enum Node { Projection { control: NodeID, selection: usize, - } + }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -302,7 +300,7 @@ impl Module { */ pub fn write_type<W: Write>(&self, ty_id: TypeID, w: &mut W) -> std::fmt::Result { match &self.types[ty_id.idx()] { - Type::Control(_) => write!(w, "Control"), + Type::Control => write!(w, "Control"), Type::Boolean => write!(w, "Boolean"), Type::Integer8 => write!(w, "Integer8"), Type::Integer16 => write!(w, "Integer16"), @@ -622,7 +620,7 @@ impl<T: Clone> GraveUpdatable for Vec<T> { impl Type { pub fn is_control(&self) -> bool { - if let Type::Control(_) = self { + if let Type::Control = self { true } else { false @@ -883,7 +881,7 @@ impl Node { is_fork, Node::Fork { control: _, - factor: _, + factors: _, } ); define_pattern_predicate!(is_join, Node::Join { control: _ }); @@ 
-894,7 +892,13 @@ impl Node { data: _, } ); - define_pattern_predicate!(is_thread_id, Node::ThreadID { control: _ }); + define_pattern_predicate!( + is_thread_id, + Node::ThreadID { + control: _, + dimension: _ + } + ); define_pattern_predicate!( is_reduce, Node::Reduce { @@ -932,7 +936,7 @@ impl Node { define_pattern_predicate!( is_projection, Node::Projection { - control: _, + control: _, selection: _ } ); @@ -953,6 +957,14 @@ impl Node { } } + pub fn try_proj(&self) -> Option<(NodeID, usize)> { + if let Node::Projection { control, selection } = self { + Some((*control, *selection)) + } else { + None + } + } + pub fn try_phi(&self) -> Option<(NodeID, &[NodeID])> { if let Node::Phi { control, data } = self { Some((*control, data)) @@ -969,17 +981,17 @@ impl Node { } } - pub fn try_fork(&self) -> Option<(NodeID, DynamicConstantID)> { - if let Node::Fork { control, factor } = self { - Some((*control, *factor)) + pub fn try_fork(&self) -> Option<(NodeID, &[DynamicConstantID])> { + if let Node::Fork { control, factors } = self { + Some((*control, factors)) } else { None } } - pub fn try_thread_id(&self) -> Option<NodeID> { - if let Node::ThreadID { control } = self { - Some(*control) + pub fn try_thread_id(&self) -> Option<(NodeID, usize)> { + if let Node::ThreadID { control, dimension } = self { + Some((*control, *dimension)) } else { None } @@ -1063,14 +1075,17 @@ impl Node { Node::Match { control: _, sum: _ } => "Match", Node::Fork { control: _, - factor: _, + factors: _, } => "Fork", Node::Join { control: _ } => "Join", Node::Phi { control: _, data: _, } => "Phi", - Node::ThreadID { control: _ } => "ThreadID", + Node::ThreadID { + control: _, + dimension: _, + } => "ThreadID", Node::Reduce { control: _, init: _, @@ -1113,9 +1128,9 @@ impl Node { data: _, indices: _, } => "Write", - Node::Projection { - control: _, - selection: _ + Node::Projection { + control: _, + selection: _, } => "Projection", } } @@ -1131,14 +1146,17 @@ impl Node { Node::Match { control: _, 
sum: _ } => "match", Node::Fork { control: _, - factor: _, + factors: _, } => "fork", Node::Join { control: _ } => "join", Node::Phi { control: _, data: _, } => "phi", - Node::ThreadID { control: _ } => "thread_id", + Node::ThreadID { + control: _, + dimension: _, + } => "thread_id", Node::Reduce { control: _, init: _, @@ -1181,9 +1199,9 @@ impl Node { data: _, indices: _, } => "write", - Node::Projection { - control: _, - selection: _ + Node::Projection { + control: _, + selection: _, } => "projection", } } @@ -1279,107 +1297,107 @@ impl TernaryOperator { impl Intrinsic { pub fn parse<'a>(name: &'a str) -> Option<Self> { match name { - "abs" => Some(Intrinsic::Abs), - "acos" => Some(Intrinsic::ACos), - "acosh" => Some(Intrinsic::ACosh), - "asin" => Some(Intrinsic::ASin), - "asinh" => Some(Intrinsic::ASinh), - "atan" => Some(Intrinsic::ATan), - "atan2" => Some(Intrinsic::ATan2), - "atanh" => Some(Intrinsic::ATanh), - "cbrt" => Some(Intrinsic::Cbrt), - "ceil" => Some(Intrinsic::Ceil), - "cos" => Some(Intrinsic::Cos), - "cosh" => Some(Intrinsic::Cosh), - "exp" => Some(Intrinsic::Exp), - "exp2" => Some(Intrinsic::Exp2), + "abs" => Some(Intrinsic::Abs), + "acos" => Some(Intrinsic::ACos), + "acosh" => Some(Intrinsic::ACosh), + "asin" => Some(Intrinsic::ASin), + "asinh" => Some(Intrinsic::ASinh), + "atan" => Some(Intrinsic::ATan), + "atan2" => Some(Intrinsic::ATan2), + "atanh" => Some(Intrinsic::ATanh), + "cbrt" => Some(Intrinsic::Cbrt), + "ceil" => Some(Intrinsic::Ceil), + "cos" => Some(Intrinsic::Cos), + "cosh" => Some(Intrinsic::Cosh), + "exp" => Some(Intrinsic::Exp), + "exp2" => Some(Intrinsic::Exp2), "exp_m1" => Some(Intrinsic::ExpM1), - "floor" => Some(Intrinsic::Floor), - "ln" => Some(Intrinsic::Ln), - "ln_1p" => Some(Intrinsic::Ln1P), - "log" => Some(Intrinsic::Log), - "log10" => Some(Intrinsic::Log10), - "log2" => Some(Intrinsic::Log2), - "pow" => Some(Intrinsic::Pow), - "powf" => Some(Intrinsic::Powf), - "powi" => Some(Intrinsic::Powi), - "round" => 
Some(Intrinsic::Round), - "sin" => Some(Intrinsic::Sin), - "sinh" => Some(Intrinsic::Sinh), - "sqrt" => Some(Intrinsic::Sqrt), - "tan" => Some(Intrinsic::Tan), - "tanh" => Some(Intrinsic::Tanh), + "floor" => Some(Intrinsic::Floor), + "ln" => Some(Intrinsic::Ln), + "ln_1p" => Some(Intrinsic::Ln1P), + "log" => Some(Intrinsic::Log), + "log10" => Some(Intrinsic::Log10), + "log2" => Some(Intrinsic::Log2), + "pow" => Some(Intrinsic::Pow), + "powf" => Some(Intrinsic::Powf), + "powi" => Some(Intrinsic::Powi), + "round" => Some(Intrinsic::Round), + "sin" => Some(Intrinsic::Sin), + "sinh" => Some(Intrinsic::Sinh), + "sqrt" => Some(Intrinsic::Sqrt), + "tan" => Some(Intrinsic::Tan), + "tanh" => Some(Intrinsic::Tanh), _ => None, } } pub fn upper_case_name(&self) -> &'static str { match self { - Intrinsic::Abs => "Abs", - Intrinsic::ACos => "Acos", + Intrinsic::Abs => "Abs", + Intrinsic::ACos => "Acos", Intrinsic::ACosh => "Acosh", - Intrinsic::ASin => "Asin", + Intrinsic::ASin => "Asin", Intrinsic::ASinh => "Asinh", - Intrinsic::ATan => "Atan", + Intrinsic::ATan => "Atan", Intrinsic::ATan2 => "Atan2", Intrinsic::ATanh => "Atanh", - Intrinsic::Cbrt => "Cbrt", - Intrinsic::Ceil => "Ceil", - Intrinsic::Cos => "Cos", - Intrinsic::Cosh => "Cosh", - Intrinsic::Exp => "Exp", - Intrinsic::Exp2 => "Exp2", + Intrinsic::Cbrt => "Cbrt", + Intrinsic::Ceil => "Ceil", + Intrinsic::Cos => "Cos", + Intrinsic::Cosh => "Cosh", + Intrinsic::Exp => "Exp", + Intrinsic::Exp2 => "Exp2", Intrinsic::ExpM1 => "Exp_m1", Intrinsic::Floor => "Floor", - Intrinsic::Ln => "Ln", - Intrinsic::Ln1P => "Ln_1p", - Intrinsic::Log => "Log", + Intrinsic::Ln => "Ln", + Intrinsic::Ln1P => "Ln_1p", + Intrinsic::Log => "Log", Intrinsic::Log10 => "Log10", - Intrinsic::Log2 => "Log2", - Intrinsic::Pow => "Pow", - Intrinsic::Powf => "Powf", - Intrinsic::Powi => "Powi", + Intrinsic::Log2 => "Log2", + Intrinsic::Pow => "Pow", + Intrinsic::Powf => "Powf", + Intrinsic::Powi => "Powi", Intrinsic::Round => "Round", - 
Intrinsic::Sin => "Sin", - Intrinsic::Sinh => "Sinh", - Intrinsic::Sqrt => "Sqrt", - Intrinsic::Tan => "Tan", - Intrinsic::Tanh => "Tanh", + Intrinsic::Sin => "Sin", + Intrinsic::Sinh => "Sinh", + Intrinsic::Sqrt => "Sqrt", + Intrinsic::Tan => "Tan", + Intrinsic::Tanh => "Tanh", } } pub fn lower_case_name(&self) -> &'static str { match self { - Intrinsic::Abs => "abs", - Intrinsic::ACos => "acos", + Intrinsic::Abs => "abs", + Intrinsic::ACos => "acos", Intrinsic::ACosh => "acosh", - Intrinsic::ASin => "asin", + Intrinsic::ASin => "asin", Intrinsic::ASinh => "asinh", - Intrinsic::ATan => "atan", + Intrinsic::ATan => "atan", Intrinsic::ATan2 => "atan2", Intrinsic::ATanh => "atanh", - Intrinsic::Cbrt => "cbrt", - Intrinsic::Ceil => "ceil", - Intrinsic::Cos => "cos", - Intrinsic::Cosh => "cosh", - Intrinsic::Exp => "exp", - Intrinsic::Exp2 => "exp2", + Intrinsic::Cbrt => "cbrt", + Intrinsic::Ceil => "ceil", + Intrinsic::Cos => "cos", + Intrinsic::Cosh => "cosh", + Intrinsic::Exp => "exp", + Intrinsic::Exp2 => "exp2", Intrinsic::ExpM1 => "exp_m1", Intrinsic::Floor => "floor", - Intrinsic::Ln => "ln", - Intrinsic::Ln1P => "ln_1p", - Intrinsic::Log => "log", + Intrinsic::Ln => "ln", + Intrinsic::Ln1P => "ln_1p", + Intrinsic::Log => "log", Intrinsic::Log10 => "log10", - Intrinsic::Log2 => "log2", - Intrinsic::Pow => "pow", - Intrinsic::Powf => "powf", - Intrinsic::Powi => "powi", + Intrinsic::Log2 => "log2", + Intrinsic::Pow => "pow", + Intrinsic::Powf => "powf", + Intrinsic::Powi => "powi", Intrinsic::Round => "round", - Intrinsic::Sin => "sin", - Intrinsic::Sinh => "sinh", - Intrinsic::Sqrt => "sqrt", - Intrinsic::Tan => "tan", - Intrinsic::Tanh => "tanh", + Intrinsic::Sin => "sin", + Intrinsic::Sinh => "sinh", + Intrinsic::Sqrt => "sqrt", + Intrinsic::Tan => "tan", + Intrinsic::Tanh => "tanh", } } } @@ -1440,6 +1458,14 @@ macro_rules! 
define_dual_id_type { pub fn new(x: usize, y: usize) -> Self { $x(x as u32, y as u32) } + + pub fn idx_0(&self) -> usize { + self.0 as usize + } + + pub fn idx_1(&self) -> usize { + self.1 as usize + } } }; } diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index c2fe35b704a08bc9bb4a32ffe778a082c96eb4f6..8dd42d6c13e2d6d230cda758bad5a9aaafe81b71 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -1,4 +1,10 @@ -#![feature(coroutines, coroutine_trait, let_chains, stmt_expr_attributes)] +#![feature( + coroutines, + coroutine_trait, + let_chains, + stmt_expr_attributes, + iter_intersperse +)] pub mod antideps; pub mod build; @@ -9,7 +15,6 @@ pub mod dot; pub mod gcm; pub mod ir; pub mod loops; -pub mod manifest; pub mod parse; pub mod schedule; pub mod subgraph; @@ -25,7 +30,6 @@ pub use crate::dot::*; pub use crate::gcm::*; pub use crate::ir::*; pub use crate::loops::*; -pub use crate::manifest::*; pub use crate::parse::*; pub use crate::schedule::*; pub use crate::subgraph::*; diff --git a/hercules_ir/src/manifest.rs b/hercules_ir/src/manifest.rs deleted file mode 100644 index f2bbd03ab884529d89fe320ec33fc6508ca5d394..0000000000000000000000000000000000000000 --- a/hercules_ir/src/manifest.rs +++ /dev/null @@ -1,232 +0,0 @@ -extern crate serde; - -use self::serde::Deserialize; -use self::serde::Serialize; - -use crate::*; - -/* - * Every .hbin file contains a manifest which describes the Hercules functions - * contained in the module. This information is used by the runtime to execute - * the functions properly, the chief concern being how to stitch together the - * execution of each partition. - */ -#[derive(Debug, Serialize, Deserialize, Hash)] -pub struct ModuleManifest { - // A module contains a manifest per individual function. - pub functions: Vec<FunctionManifest>, - // All of the types used in the module. - pub types: Vec<Type>, - // All of the sizes and alignments for types in the module. 
- pub type_sizes_aligns: Vec<(Option<usize>, usize)>, - // All of the dynamic constants used in the module. - pub dynamic_constants: Vec<DynamicConstant>, - // Store the types of constants. - pub constant_types: Vec<(ConstantID, TypeID)>, - // The only constants that aren't baked into the generated code are array - // constants. These are explicitly stored in and loaded from the manifest. - // We only store the array bytes if the array is non-zero. If the array is - // all zeros, we optionally store the size of the array. We may not be able - // to store the array size at compile time if it's all zeros, as its size - // may be a dynamic constant with compile-time unknown size. - pub array_constants: Vec<ConstantBytes>, - // Store map from array number, in array_constants, to constant ID. - pub array_cons_ids: Vec<ConstantID>, -} - -/* - * embed_constant calculates the byte representation of a constant. For zero - * constants, we avoid storing the actual zero bytes, and optionally store the - * size - zero constant arrays may have dynamic constant dimensions unknown at - * compile time. The second usize in each variant is the alignment the bytes - * must be at to be used properly. - */ -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum ConstantBytes { - NonZero(Vec<u8>, usize), - Zero(Option<usize>, usize), -} - -#[derive(Debug, Serialize, Deserialize, Hash)] -pub struct FunctionManifest { - pub name: String, - // Types of the function parameters. - pub param_types: Vec<TypeID>, - // Type of the return value. - pub return_type: TypeID, - // Types of all of the nodes in this function. Used for figuring out the - // type of partition data inputs and outputs. - pub typing: Vec<TypeID>, - // IDs of constants that are actually used in this function. - pub used_constants: Vec<ConstantID>, - // Number of dynamic constant parameters that need to provided. - pub num_dynamic_constant_parameters: u32, - // Manifests for constituent partitions. 
- pub partitions: Vec<PartitionManifest>, - // When using dynamic constants, certain constraints are generated. For - // example, if one uses division in dynamic constant math, the resulting - // dynamic constant must be an integer, so the numerator dynamic constant - // must be divisible by the denominator dynamic constant. These are stored - // per function, since different functions have different contraints on - // their dynamic constant parameters. - pub dynamic_constant_rules: Vec<DynamicConstantRule>, - // The partition containing the start node might not be partition 0. - pub top_partition: PartitionID, - // Keep track of which node values are returned from the overall function. - pub returned_values: Vec<NodeID>, -} - -/* - * Rules for validity of provided dynamic constants - */ -#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub enum DynamicConstantRule { - // Generated from subtraction. - LessThan(DynamicConstantID, DynamicConstantID), - // Generated from division. - Divides(DynamicConstantID, DynamicConstantID), -} - -#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub enum PartitionInput { - // Data input from another partition within this function. - DataInput(NodeID), - // An argument from the function parameters. Integer is the parameter index. - FunctionArgument(u32), - // An array constant used in this function. Integer is the array constant - // number. - ArrayConstant(u32), - // A dynamic constant parameter of this function. Integer is the dynamic - // constant parameter number. - DynamicConstant(u32), -} - -#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub enum PartitionOutput { - // Data output used by another partition within this function, or to be - // returned from this function. - DataOutput(NodeID), - // Value indicating control flow that the runtime should take. 
- ControlIndicator, -} - -#[derive(Debug, Serialize, Deserialize, Default, Hash)] -pub struct PartitionManifest { - // Top node for this partition. - pub top_node: NodeID, - pub inputs: Vec<PartitionInput>, - pub outputs: Vec<PartitionOutput>, - - // Store the partitions that come after this partition. The first element of - // the pair is the control returning node in the current partition that - // corresponds to a particular successor, and the second element of the pair - // is the partition ID of the corresponding successor partition. - pub successor_partitions: Vec<(NodeID, PartitionID)>, -} - -impl ConstantBytes { - pub fn align(&self) -> usize { - match self { - ConstantBytes::NonZero(_, align) => *align, - ConstantBytes::Zero(_, align) => *align, - } - } -} - -/* - * Get the data inputs of partitions in a function. - */ -pub fn function_partition_inputs(manifest: &FunctionManifest) -> Vec<NodeID> { - let mut all_data_inputs = manifest - .partitions - .iter() - .map(|manifest| { - manifest.inputs.iter().filter_map(|input| { - if let PartitionInput::DataInput(idx) = input { - Some(*idx) - } else { - None - } - }) - }) - .flatten() - .collect::<Vec<NodeID>>(); - all_data_inputs.sort(); - all_data_inputs.dedup(); - all_data_inputs -} - -/* - * Get the data outputs of partitions in a function. - */ -pub fn function_partition_outputs(manifest: &FunctionManifest) -> Vec<NodeID> { - let mut all_data_outputs = manifest - .partitions - .iter() - .map(|manifest| { - manifest.outputs.iter().filter_map(|output| { - if let PartitionOutput::DataOutput(idx) = output { - Some(*idx) - } else { - None - } - }) - }) - .flatten() - .collect::<Vec<NodeID>>(); - all_data_outputs.sort(); - all_data_outputs.dedup(); - all_data_outputs -} - -/* - * Verify a manifest is valid. 
- */ -pub fn verify_manifest(manifest: &ModuleManifest) -> Result<(), String> { - manifest - .functions - .iter() - .map(|manifest| { - // First, check that all partitiion data inputs get used, and that - // all partition data outputs are produced. - let mut all_data_inputs = function_partition_inputs(manifest); - let all_data_outputs = function_partition_outputs(manifest); - - // Data outputs include values returned from the overall function. - // These are obviously not used by another partition, but that's - // fine. We add these artificially as "inputs" to a partition so - // that this check still passes. We don't just remove the returned - // value from the data outputs because it might also be the - // legitimate output of another partition, and doing that would - // cause a false negative. - all_data_inputs.extend(&manifest.returned_values); - all_data_inputs.sort(); - all_data_inputs.dedup(); - if all_data_inputs != all_data_outputs { - return Err(format!("PANIC: Partitions in manifest contain inconsistent data inputs and data outputs of partitions.\nHere are the data input IDs:\n{:?}\nHere are the data output IDs:\n{:?}\n", all_data_inputs, all_data_outputs)); - } - - // Next, check that partitions contain multiple successors if and - // only if a control indicator is present in their outputs. - // TODO: this will fail to verify a partition with a return node and - // a single successor partition, which technically is allowed. In - // this scenario, a control indicator is required (to tell if we - // should return or advance the control token), but there is only - // one successor partition. Deal with this later. 
- manifest.partitions.iter().map(|manifest| { - let multiple_successors = manifest.successor_partitions.len() > 1; - let control_indicator = manifest.outputs.contains(&PartitionOutput::ControlIndicator); - if !multiple_successors && control_indicator { - Err("PANIC: Partition in manifest doesn't have multiple successor partitions, but has a control indicator output.".to_string()) - } else if multiple_successors && !control_indicator { - Err("PANIC: Partition in manifest has multiple successor partitions, but not a control indicator output.".to_string()) - } else { - Ok(()) - } - }).collect::<Result<Vec<()>, String>>()?; - - Ok(()) - }) - .collect::<Result<Vec<()>, String>>() - .map(|_| ()) -} diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index d53ffdb288aca78540463c0fb7e46709e8d3fe6c..5b9ff6940ea1127f307f2ced0add35666a093f12 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -138,7 +138,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a // Assemble flat lists of interned goodies, now that we've figured out // everyones' IDs. 
- let mut types = vec![Type::Control(Box::new([])); context.interned_types.len()]; + let mut types = vec![Type::Control; context.interned_types.len()]; for (ty, id) in context.interned_types { types[id.idx()] = ty; } @@ -373,13 +373,21 @@ fn parse_if<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResul } fn parse_fork<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { - let (ir_text, (control, factor)) = - parse_tuple2(parse_identifier, |x| parse_dynamic_constant_id(x, context))(ir_text)?; + let (ir_text, (control, factors)) = parse_tuple2( + parse_identifier, + nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + |x| parse_dynamic_constant_id(x, context), + ), + )(ir_text)?; let control = context.borrow_mut().get_node_id(control); - // Because parse_dynamic_constant_id returned a DynamicConstantID directly, - // we don't need to manually convert it here. 
- Ok((ir_text, Node::Fork { control, factor })) + let factors = factors.into(); + Ok((ir_text, Node::Fork { control, factors })) } fn parse_join<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { @@ -417,9 +425,10 @@ fn parse_thread_id<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { - let (ir_text, (control,)) = parse_tuple1(parse_identifier)(ir_text)?; + let (ir_text, (control, dimension)) = + parse_tuple2(parse_identifier, |x| parse_prim::<usize>(x, "1234567890"))(ir_text)?; let control = context.borrow_mut().get_node_id(control); - Ok((ir_text, Node::ThreadID { control })) + Ok((ir_text, Node::ThreadID { control, dimension })) } fn parse_reduce<'a>( @@ -572,7 +581,10 @@ fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes )) } -fn parse_intrinsic<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { +fn parse_intrinsic<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, +) -> nom::IResult<&'a str, Node> { // Intrinsic nodes take an intrinsic name as an argument and a variable // number of normal arguments. 
let ir_text = nom::character::complete::multispace0(ir_text)?.0; @@ -593,11 +605,10 @@ fn parse_intrinsic<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom: .collect(); let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0; - let intrinsic = Intrinsic::parse(intrinsic) - .ok_or(nom::Err::Error(nom::error::Error { - input: ir_text, - code: nom::error::ErrorKind::IsNot, - }))?; + let intrinsic = Intrinsic::parse(intrinsic).ok_or(nom::Err::Error(nom::error::Error { + input: ir_text, + code: nom::error::ErrorKind::IsNot, + }))?; Ok(( ir_text, Node::IntrinsicCall { @@ -671,12 +682,20 @@ fn parse_index<'a>( Ok((ir_text, idx)) } -fn parse_projection<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>, +fn parse_projection<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { let parse_usize = |x| parse_prim::<usize>(x, "1234567890"); let (ir_text, (control, index)) = parse_tuple2(parse_identifier, parse_usize)(ir_text)?; let control = context.borrow_mut().get_node_id(control); - Ok((ir_text, Node::Projection { control, selection: index })) + Ok(( + ir_text, + Node::Projection { + control, + selection: index, + }, + )) } fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { @@ -767,34 +786,7 @@ fn parse_match<'a>( fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Type> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, ty) = nom::branch::alt(( - // Control tokens are parameterized by a list of dynamic constants - // representing their thread spawn factors. 
- nom::combinator::map( - nom::sequence::tuple(( - nom::bytes::complete::tag("ctrl"), - nom::character::complete::multispace0, - nom::character::complete::char('('), - nom::character::complete::multispace0, - nom::multi::separated_list1( - nom::sequence::tuple(( - nom::character::complete::multispace0, - nom::character::complete::char(','), - nom::character::complete::multispace0, - )), - |x| { - let (ir_text, node) = parse_identifier(x)?; - Ok((ir_text, context.borrow_mut().get_node_id(node))) - }, - ), - nom::character::complete::multispace0, - nom::character::complete::char(')'), - )), - |(_, _, _, _, id, _, _)| Type::Control(id.into_boxed_slice()), - ), - // If no arguments are provided, assumed that no forks have occurred. - nom::combinator::map(nom::bytes::complete::tag("ctrl"), |_| { - Type::Control(Box::new([])) - }), + nom::combinator::map(nom::bytes::complete::tag("ctrl"), |_| Type::Control), // Primitive types are written in Rust style. nom::combinator::map(nom::bytes::complete::tag("bool"), |_| Type::Boolean), nom::combinator::map(nom::bytes::complete::tag("i8"), |_| Type::Integer8), @@ -941,7 +933,7 @@ fn parse_constant<'a>( ) -> nom::IResult<&'a str, Constant> { let (ir_text, constant) = match ty { // There are not control constants. - Type::Control(_) => Err(nom::Err::Error(nom::error::Error { + Type::Control => Err(nom::Err::Error(nom::error::Error { input: ir_text, code: nom::error::ErrorKind::IsNot, }))?, diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs index fb839b5c4a469e8f6a9b8a718f361f959e0c85e7..86f18424fb603e443650ba3b4d131723921c9737 100644 --- a/hercules_ir/src/schedule.rs +++ b/hercules_ir/src/schedule.rs @@ -285,7 +285,23 @@ impl Plan { } } - // Fourth, verify that every partition has at least one partition + // Fourth, verify that every projection node is in the same partition as + // its control use. 
+ for id in (0..function.nodes.len()).map(NodeID::new) { + if let Node::Projection { + control, + selection: _, + } = function.nodes[id.idx()] + { + assert_eq!( + self.partitions[id.idx()], + self.partitions[control.idx()], + "PANIC: Found a projection node in a different partition than its control use." + ); + } + } + + // Fifth, verify that every partition has at least one partition // successor xor has at least one return node. for partition_idx in 0..self.num_partitions { let has_successor = partition_graph.succs(NodeID::new(partition_idx)).count() > 0; @@ -354,8 +370,11 @@ impl Plan { // For every non-phi node, check each of its data uses. If the // node and its use are in different partitions, then the use is // a data input for the partition of the node. Also, don't add - // the same node to the data inputs list twice. + // the same node to the data inputs list twice. Only consider + // non-constant uses data inputs - constant nodes are always + // rematerialized into the user partition. if !function.nodes[use_id.idx()].is_control() + && !function.nodes[use_id.idx()].is_constant() && self.partitions[id.idx()] != self.partitions[use_id.idx()] && !data_inputs.contains(use_id) { @@ -380,7 +399,7 @@ impl Plan { // be lowered to a single parameter to the corresponding simple // IR function. Note that for a phi node with some uses outside // and some uses inside the partition, the uses outside the - // partition become a single parameter to the simple IR + // partition become a single parameter to the schedule IR // function, and that parameter and all of the "inside" uses // become the inputs to a phi inside the simple IR function. 
if self.partitions[id.idx()] != self.partitions[use_id.idx()] @@ -410,7 +429,9 @@ impl Plan { let mut data_outputs = vec![vec![]; self.num_partitions]; for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_control() { + // Only consider non-constant data nodes as data outputs, since + // constant nodes are rematerialized into the user partition. + if function.nodes[id.idx()].is_control() || function.nodes[id.idx()].is_constant() { continue; } @@ -474,13 +495,13 @@ pub fn default_plan( /* * Infer parallel reductions consisting of a simple cycle between a Reduce node - * and a Write node, where an index of the Write is a position index using the - * ThreadID node attached to the corresponding Fork. This procedure also adds + * and a Write node, where indices of the Write are position indices using the + * ThreadID nodes attached to the corresponding Fork. This procedure also adds * the ParallelReduce schedule to Reduce nodes reducing over a parallelized - * Reduce, as long as the base Write node also has a position index that is the + * Reduce, as long as the base Write node also has position indices of the * ThreadID of the outer fork. In other words, the complete Reduce chain is - * annotated with ParallelReduce, as long as each ThreadID appears in the - * positional indexing of the Write. + * annotated with ParallelReduce, as long as each ThreadID dimension appears in + * the positional indexing of the original Write. */ pub fn infer_parallel_reduce( function: &Function, @@ -533,17 +554,27 @@ pub fn infer_parallel_reduce( .flat_map(|pos| pos.iter()); // Get the Forks corresponding to uses of bare ThreadIDs. 
- let mut forks = positions.filter_map(|id| { - if let Node::ThreadID { control } = function.nodes[id.idx()] { - Some(control) + let fork_thread_id_pairs = positions.filter_map(|id| { + if let Node::ThreadID { control, dimension } = function.nodes[id.idx()] { + Some((control, dimension)) } else { None } }); + let mut forks = HashMap::<NodeID, Vec<usize>>::new(); + for (fork, dim) in fork_thread_id_pairs { + forks.entry(fork).or_default().push(dim); + } - // Check if any of the Forks correspond to the Join associated with - // the Reduce being considered. - let is_parallel = forks.any(|id| fork_join_map[&id] == first_control.unwrap()); + // Check if one of the Forks correspond to the Join associated with + // the Reduce being considered, and has all of its dimensions + // represented in the indexing. + let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { + rep_dims.sort(); + rep_dims.dedup(); + fork_join_map[&id] == first_control.unwrap() + && function.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() + }); if is_parallel { plan.schedules[id.idx()].push(Schedule::ParallelReduce); @@ -619,7 +650,10 @@ pub fn partition_out_forks( reverse_postorder, |inputs: &[&NodeID], node_id: NodeID| match function.nodes[node_id.idx()] { Node::Start => NodeID::new(0), - Node::Fork { control, factor: _ } => { + Node::Fork { + control, + factors: _, + } => { // Start a partition if the preceding partition isn't a fork // partition and the predecessor isn't the join for the // predecessor fork partition. 
Otherwise, be part of the parent diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 4b22f153e381d72c3f07e7709d1ca3d17dbb3010..6137982bb6595697625da48e779104d525680e96 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::iter::zip; use crate::*; @@ -100,10 +100,6 @@ pub fn typecheck( .map(|(idx, ty)| (ty.clone(), TypeID::new(idx))) .collect(); - // Also create a join replication factor map. This is needed to typecheck - // collect node. - let mut join_factor_map: HashMap<NodeID, DynamicConstantID> = HashMap::new(); - // Step 2: run dataflow. This is an occurrence of dataflow where the flow // function performs a non-associative operation on the predecessor "out" // values. @@ -119,7 +115,6 @@ pub fn typecheck( constants, dynamic_constants, &mut reverse_type_map, - &mut join_factor_map, ) }) }) @@ -157,7 +152,6 @@ fn typeflow( constants: &Vec<Constant>, dynamic_constants: &Vec<DynamicConstant>, reverse_type_map: &mut HashMap<Type, TypeID>, - join_factor_map: &mut HashMap<NodeID, DynamicConstantID>, ) -> TypeSemilattice { // Whenever we want to reference a specific type (for example, for the // start node), we need to get its type ID. This helper function gets the @@ -200,11 +194,7 @@ fn typeflow( } // The start node is the producer of the control token. 
- Concrete(get_type_id( - Type::Control(Box::new([])), - types, - reverse_type_map, - )) + Concrete(get_type_id(Type::Control, types, reverse_type_map)) } Node::Region { preds: _ } => { if inputs.len() == 0 { @@ -262,30 +252,28 @@ fn typeflow( inputs[0].clone() } - Node::Fork { control: _, factor } => { + Node::Fork { + control: _, + factors, + } => { if inputs.len() != 1 { return Error(String::from("Fork node must have exactly one input.")); } - if !check_dynamic_constants(*factor, dynamic_constants, function.num_dynamic_constants) - { - return Error(String::from("Referenced parameter dynamic constant is not a valid dynamic constant parameter for the current function.")); + for factor in factors.iter() { + if !check_dynamic_constants( + *factor, + dynamic_constants, + function.num_dynamic_constants, + ) { + return Error(String::from("Referenced parameter dynamic constant is not a valid dynamic constant parameter for the current function.")); + } } if let Concrete(id) = inputs[0] { - if let Type::Control(factors) = &types[id.idx()] { - // Fork adds a new factor to the thread spawn factor list. - let mut new_factors = factors.clone().into_vec(); - new_factors.push(node_id); - - // Out type is control type, with the new thread spawn - // factor. - let control_out_id = get_type_id( - Type::Control(new_factors.into_boxed_slice()), - types, - reverse_type_map, - ); - return Concrete(control_out_id); + if let Type::Control = &types[id.idx()] { + // Out type is control type. + return Concrete(*id); } else { return Error(String::from( "Fork node's input cannot have non-control type.", @@ -304,28 +292,8 @@ fn typeflow( // concrete output type, so just return the control input non- // concrete type. if let Concrete(control_id) = inputs[0] { - if let Type::Control(factors) = &types[control_id.idx()] { - // Join removes a factor from the factor list. 
- if factors.len() == 0 { - return Error(String::from("Join node's first input must have a control type with at least one thread replication factor.")); - } - let mut new_factors = factors.clone().into_vec(); - let factor = if let Node::Fork { control: _, factor } = - function.nodes[new_factors.pop().unwrap().idx()] - { - factor - } else { - panic!("Node ID in factor list doesn't correspond with a fork node."); - }; - join_factor_map.insert(node_id, factor); - - // Out type is the new control type. - let control_out_id = get_type_id( - Type::Control(new_factors.into_boxed_slice()), - types, - reverse_type_map, - ); - return Concrete(control_out_id); + if let Type::Control = &types[control_id.idx()] { + return Concrete(*control_id); } else { return Error(String::from( "Join node's first input cannot have non-control type.", @@ -372,7 +340,10 @@ fn typeflow( meet } - Node::ThreadID { control: _ } => { + Node::ThreadID { + control: _, + dimension: _, + } => { if inputs.len() != 1 { return Error(String::from("ThreadID node must have exactly one input.")); } @@ -402,7 +373,7 @@ fn typeflow( (inputs[0], inputs[1], inputs[2]) { // Check control input is control. - if let Type::Control(_) = types[control_id.idx()] { + if let Type::Control = types[control_id.idx()] { } else { return Error(String::from( "Reduce node's control input must have control type.", @@ -410,14 +381,14 @@ fn typeflow( } // Check init input isn't control. - if let Type::Control(_) = types[init_id.idx()] { + if let Type::Control = types[init_id.idx()] { return Error(String::from( "Reduce node's initialization input must not have control type.", )); } // Check reduct input isn't control. - if let Type::Control(_) = types[reduct_id.idx()] { + if let Type::Control = types[reduct_id.idx()] { return Error(String::from( "Reduce node's reduction input must not have control type.", )); @@ -441,13 +412,7 @@ fn typeflow( // Check type of control input first, since this may produce an // error. 
if let Concrete(id) = inputs[0] { - if let Type::Control(factors) = &types[id.idx()] { - if factors.len() != 0 { - return Error(String::from( - "Return node's control input must have no thread replications.", - )); - } - } else { + if !types[id.idx()].is_control() { return Error(String::from( "Return node's control input cannot have non-control type.", )); @@ -779,26 +744,39 @@ fn typeflow( Concrete(callee.return_type) } - Node::IntrinsicCall { - intrinsic, - args: _, - } => { - let num_params = - match intrinsic { - Intrinsic::Abs | Intrinsic::ACos | Intrinsic::ACosh - | Intrinsic::ASin | Intrinsic::ASinh | Intrinsic::ATan - | Intrinsic::ATanh | Intrinsic::Cbrt | Intrinsic::Ceil - | Intrinsic::Cos | Intrinsic::Cosh | Intrinsic::Exp - | Intrinsic::Exp2 | Intrinsic::ExpM1 | Intrinsic::Floor - | Intrinsic::Ln | Intrinsic::Ln1P | Intrinsic::Log10 - | Intrinsic::Log2 | Intrinsic::Round | Intrinsic::Sin - | Intrinsic::Sinh | Intrinsic::Sqrt | Intrinsic::Tan - | Intrinsic::Tanh - => 1, - Intrinsic::ATan2 | Intrinsic::Log | Intrinsic::Pow - | Intrinsic::Powf | Intrinsic::Powi - => 2, - }; + Node::IntrinsicCall { intrinsic, args: _ } => { + let num_params = match intrinsic { + Intrinsic::Abs + | Intrinsic::ACos + | Intrinsic::ACosh + | Intrinsic::ASin + | Intrinsic::ASinh + | Intrinsic::ATan + | Intrinsic::ATanh + | Intrinsic::Cbrt + | Intrinsic::Ceil + | Intrinsic::Cos + | Intrinsic::Cosh + | Intrinsic::Exp + | Intrinsic::Exp2 + | Intrinsic::ExpM1 + | Intrinsic::Floor + | Intrinsic::Ln + | Intrinsic::Ln1P + | Intrinsic::Log10 + | Intrinsic::Log2 + | Intrinsic::Round + | Intrinsic::Sin + | Intrinsic::Sinh + | Intrinsic::Sqrt + | Intrinsic::Tan + | Intrinsic::Tanh => 1, + Intrinsic::ATan2 + | Intrinsic::Log + | Intrinsic::Pow + | Intrinsic::Powf + | Intrinsic::Powi => 2, + }; // Check number of run-time arguments if inputs.len() != num_params { @@ -820,94 +798,122 @@ fn typeflow( if types[id.idx()].is_arithmetic() { Concrete(*id) } else { - Error(format!("{} intrinsic 
cannot have non-numeric input type.", - intrinsic.lower_case_name())) + Error(format!( + "{} intrinsic cannot have non-numeric input type.", + intrinsic.lower_case_name() + )) } } else { // Otherwise, propogate errors and unconstrained types (*inputs[0]).clone() } - }, + } // Intrinsics that take any float type and return the same - Intrinsic::ACos | Intrinsic::ACosh | Intrinsic::ASin - | Intrinsic::ASinh | Intrinsic::ATan | Intrinsic::ATanh - | Intrinsic::Cbrt | Intrinsic::Ceil | Intrinsic::Cos - | Intrinsic::Cosh | Intrinsic::Exp | Intrinsic::Exp2 - | Intrinsic::ExpM1 | Intrinsic::Floor | Intrinsic::Ln - | Intrinsic::Ln1P | Intrinsic::Log10 | Intrinsic::Log2 - | Intrinsic::Round | Intrinsic::Sin | Intrinsic::Sinh - | Intrinsic::Sqrt | Intrinsic::Tan | Intrinsic::Tanh - => { + Intrinsic::ACos + | Intrinsic::ACosh + | Intrinsic::ASin + | Intrinsic::ASinh + | Intrinsic::ATan + | Intrinsic::ATanh + | Intrinsic::Cbrt + | Intrinsic::Ceil + | Intrinsic::Cos + | Intrinsic::Cosh + | Intrinsic::Exp + | Intrinsic::Exp2 + | Intrinsic::ExpM1 + | Intrinsic::Floor + | Intrinsic::Ln + | Intrinsic::Ln1P + | Intrinsic::Log10 + | Intrinsic::Log2 + | Intrinsic::Round + | Intrinsic::Sin + | Intrinsic::Sinh + | Intrinsic::Sqrt + | Intrinsic::Tan + | Intrinsic::Tanh => { if let Concrete(id) = inputs[0] { if types[id.idx()].is_float() { Concrete(*id) } else { - Error(format!("{} intrinsic cannot have non-float input type.", - intrinsic.lower_case_name())) + Error(format!( + "{} intrinsic cannot have non-float input type.", + intrinsic.lower_case_name() + )) } } else { // Otherwise, propogate errors and unconstrained types (*inputs[0]).clone() } - }, + } // Intrinsics that take any two values of the same float type // and return the same - Intrinsic::ATan2 | Intrinsic::Log | Intrinsic::Powf - => { + Intrinsic::ATan2 | Intrinsic::Log | Intrinsic::Powf => { let input_ty = TypeSemilattice::meet(inputs[0], inputs[1]); if let Concrete(id) = input_ty { if types[id.idx()].is_float() { 
Concrete(id) } else { - Error(format!("{} intrinsic cannot have non-float input types.", - intrinsic.lower_case_name())) + Error(format!( + "{} intrinsic cannot have non-float input types.", + intrinsic.lower_case_name() + )) } } else { // Otherwise, propogate errors and unconstrained types (*inputs[0]).clone() } - }, + } Intrinsic::Pow => { if let Concrete(id) = inputs[0] { if types[id.idx()].is_fixed() { if let Concrete(id) = inputs[1] { if types[id.idx()] != Type::UnsignedInteger32 { - return Error(format!("{} intrinsic expects u32 as second argument.", - intrinsic.lower_case_name())); + return Error(format!( + "{} intrinsic expects u32 as second argument.", + intrinsic.lower_case_name() + )); } } Concrete(*id) } else { - Error(format!("{} intrinsic cannot have non-integer first argument.", - intrinsic.lower_case_name())) + Error(format!( + "{} intrinsic cannot have non-integer first argument.", + intrinsic.lower_case_name() + )) } } else { // Otherwise, propagate errors and unconstrained types (*inputs[0]).clone() } - }, + } Intrinsic::Powi => { if let Concrete(id) = inputs[0] { if types[id.idx()].is_float() { if let Concrete(id) = inputs[1] { if types[id.idx()] != Type::Integer32 { - return Error(format!("{} intrinsic expects i32 as second argument.", - intrinsic.lower_case_name())); + return Error(format!( + "{} intrinsic expects i32 as second argument.", + intrinsic.lower_case_name() + )); } } Concrete(*id) } else { - Error(format!("{} intrinsic cannot have non-float first argument.", - intrinsic.lower_case_name())) + Error(format!( + "{} intrinsic cannot have non-float first argument.", + intrinsic.lower_case_name() + )) } } else { // Otherwise, propagate errors and unconstrained types (*inputs[0]).clone() } - }, + } } - }, + } Node::Read { collect: _, indices, @@ -1032,10 +1038,13 @@ fn typeflow( TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg), } } - Node::Projection { control: _, selection: _ } => { + Node::Projection { + control: _, + 
selection: _, + } => { // Type is the type of the _if node inputs[0].clone() - }, + } } } @@ -1045,24 +1054,36 @@ fn typeflow( * (that's the mechanism used to implement this analysis). This analysis depends * on type information. */ -pub fn fork_join_map( - function: &Function, - typing: &Vec<TypeID>, - types: &Vec<Type>, -) -> HashMap<NodeID, NodeID> { +pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> { let mut fork_join_map = HashMap::new(); for idx in 0..function.nodes.len() { // We only care about join nodes. - if let Node::Join { control } = function.nodes[idx] { - // A join's input type must be control. Since we have types, if this - // isn't the case, the typing is incorrect and we should panic. - if let Type::Control(factors) = &types[typing[control.idx()].idx()] { - let join_id = NodeID::new(idx); - let fork_id = *factors.last().unwrap(); - fork_join_map.insert(fork_id, join_id); - } else { - panic!("Join node's control predecessor has a non-control type."); - } + if function.nodes[idx].is_join() { + // Iterate the control predecessors until finding a fork. Maintain a + // counter of unmatched fork-join pairs seen on the way, since fork- + // joins may be nested. Every join is dominated by their fork, so + // just iterate the first unseen predecessor of each control node. 
+ let join_id = NodeID::new(idx); + let mut unpaired = 0; + let mut cursor = join_id; + let mut seen = HashSet::<NodeID>::new(); + let fork_id = loop { + cursor = control + .preds(cursor) + .filter(|pred| !seen.contains(pred)) + .next() + .unwrap(); + seen.insert(cursor); + + if function.nodes[cursor.idx()].is_join() { + unpaired += 1; + } else if function.nodes[cursor.idx()].is_fork() && unpaired > 0 { + unpaired -= 1; + } else if function.nodes[cursor.idx()].is_fork() { + break cursor; + } + }; + fork_join_map.insert(fork_id, join_id); } } fork_join_map diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index 815e735606da3074f832786487a3eb49c064cedc..c86920f9efe17333ab20e335dec4b8bd124974e7 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -34,8 +34,11 @@ pub fn verify( let typing = typecheck(module, &reverse_postorders)?; // Assemble fork join maps for module. - let fork_join_maps: Vec<_> = zip(module.functions.iter(), typing.iter()) - .map(|(function, typing)| fork_join_map(function, typing, &module.types)) + let subgraphs: Vec<_> = zip(module.functions.iter(), def_uses.iter()) + .map(|(function, def_use)| control_subgraph(function, def_use)) + .collect(); + let fork_join_maps: Vec<_> = zip(module.functions.iter(), subgraphs.iter()) + .map(|(function, subgraph)| fork_join_map(function, subgraph)) .collect(); // Check the structure of the functions in the module. @@ -46,9 +49,6 @@ pub fn verify( } // Calculate dominator and postdominator trees. - let subgraphs: Vec<_> = zip(module.functions.iter(), def_uses.iter()) - .map(|(function, def_use)| control_subgraph(function, def_use)) - .collect(); let doms: Vec<_> = subgraphs .iter() .map(|subgraph| dominator(subgraph, NodeID::new(0))) @@ -174,12 +174,15 @@ fn verify_structure( // it may have many thread ID users. 
Node::Fork { control: _, - factor: _, + factors: _, } => { let mut found_control = false; for user in users { match function.nodes[user.idx()] { - Node::ThreadID { control: _ } => {} + Node::ThreadID { + control: _, + dimension: _, + } => {} _ => { if function.nodes[user.idx()].is_control() { if found_control { @@ -188,7 +191,9 @@ fn verify_structure( found_control = true; } } else { - Err("All users of a fork node must be control or ThreadID nodes.")?; + Err( + "All users of a fork node must be control or thread ID nodes.", + )?; } } } @@ -235,22 +240,19 @@ fn verify_structure( Err(format!("If node must have 2 users, not {}.", users.len()))?; } if let ( - Node::Projection { - control: _, + Node::Projection { + control: _, selection: result1, }, - Node::Projection { - control: _, + Node::Projection { + control: _, selection: result2, }, ) = ( &function.nodes[users[0].idx()], &function.nodes[users[1].idx()], ) { - - if - !((*result1 == 0 && *result2 == 1) || (*result2 == 0 && *result1 == 1)) - { + if !((*result1 == 0 && *result2 == 1) || (*result2 == 0 && *result1 == 1)) { Err("If node's user Read nodes must reference different elements of output product.")?; } } else { @@ -265,12 +267,15 @@ fn verify_structure( } } // ThreadID nodes must depend on a fork node. - Node::ThreadID { control } => { + Node::ThreadID { control, dimension } => { if let Node::Fork { control: _, - factor: _, - } = function.nodes[control.idx()] + factors, + } = &function.nodes[control.idx()] { + if *dimension >= factors.len() { + Err("The dimension in a thread ID must be in bounds for its fork node.")?; + } } else { Err("ThreadID node's control input must be a fork node.")?; } @@ -404,6 +409,7 @@ fn verify_dominance_relationships( // here, since these nodes can explicitly have non-dominating inputs. 
let this_id = if let Node::ThreadID { control: dominated_control, + dimension: _, } = function.nodes[idx] { dominated_control @@ -457,7 +463,10 @@ fn verify_dominance_relationships( } // Verify that uses of thread ID nodes are dominated by the // corresponding fork nodes. - Node::ThreadID { control } => { + Node::ThreadID { + control, + dimension: _, + } => { if dom.contains(this_id) && !dom.does_dom(control, this_id) { Err(format!( "ThreadID node (ID {}) doesn't dominate its use (ID {}).", diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 1a903d6d63431d73f75b6494d13d43bd32aa7414..6682ae8980d76bbdd4ac392add2de7b280442fb9 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -9,4 +9,5 @@ bitvec = "*" take_mut = "*" postcard = { version = "*", features = ["alloc"] } serde = { version = "*", features = ["derive"] } +hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 6f644d46ceb9057a990e98efbde0212e18ed3c1d..5d13a6f4c6514134e8c4b17c75c84ccb37a0a2e4 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -162,14 +162,13 @@ macro_rules! binary_float_intrinsic { */ pub fn ccp( function: &mut Function, - types: &Vec<Type>, constants: &mut Vec<Constant>, def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, ) { // Step 1: run ccp analysis to understand the function. let result = dataflow_global(&function, reverse_postorder, |inputs, node_id| { - ccp_flow_function(inputs, node_id, &function, &types, &constants) + ccp_flow_function(inputs, node_id, &function, &constants) }); // Step 2: update uses of constants. 
Any node that doesn't produce a @@ -401,7 +400,6 @@ fn ccp_flow_function( inputs: &[CCPLattice], node_id: NodeID, function: &Function, - types: &Vec<Type>, old_constants: &Vec<Constant>, ) -> CCPLattice { let node = &function.nodes[node_id.idx()]; @@ -414,7 +412,10 @@ fn ccp_flow_function( // reachability and constant propagation. Read handles that. Node::If { control, cond: _ } => inputs[control.idx()].clone(), Node::Match { control, sum: _ } => inputs[control.idx()].clone(), - Node::Fork { control, factor: _ } => inputs[control.idx()].clone(), + Node::Fork { + control, + factors: _, + } => inputs[control.idx()].clone(), Node::Join { control } => inputs[control.idx()].clone(), // Phi nodes must look at the reachability of the inputs to its // corresponding region node to determine the constant value being @@ -445,7 +446,10 @@ fn ccp_flow_function( } // TODO: This should produce a constant zero if the dynamic constant for // for the corresponding fork is one. - Node::ThreadID { control } => inputs[control.idx()].clone(), + Node::ThreadID { + control, + dimension: _, + } => inputs[control.idx()].clone(), // TODO: At least for now, reduce nodes always produce unknown values. 
Node::Reduce { control, @@ -479,6 +483,10 @@ fn ccp_flow_function( (UnaryOperator::Not, Constant::Integer16(val)) => ConstantLattice::Constant(Constant::Integer16(!val)), (UnaryOperator::Not, Constant::Integer32(val)) => ConstantLattice::Constant(Constant::Integer32(!val)), (UnaryOperator::Not, Constant::Integer64(val)) => ConstantLattice::Constant(Constant::Integer64(!val)), + (UnaryOperator::Not, Constant::UnsignedInteger8(val)) => ConstantLattice::Constant(Constant::UnsignedInteger8(!val)), + (UnaryOperator::Not, Constant::UnsignedInteger16(val)) => ConstantLattice::Constant(Constant::UnsignedInteger16(!val)), + (UnaryOperator::Not, Constant::UnsignedInteger32(val)) => ConstantLattice::Constant(Constant::UnsignedInteger32(!val)), + (UnaryOperator::Not, Constant::UnsignedInteger64(val)) => ConstantLattice::Constant(Constant::UnsignedInteger64(!val)), (UnaryOperator::Neg, Constant::Integer8(val)) => ConstantLattice::Constant(Constant::Integer8(-val)), (UnaryOperator::Neg, Constant::Integer16(val)) => ConstantLattice::Constant(Constant::Integer16(-val)), (UnaryOperator::Neg, Constant::Integer32(val)) => ConstantLattice::Constant(Constant::Integer32(-val)), @@ -734,23 +742,20 @@ fn ccp_flow_function( }), constant: ConstantLattice::bottom(), }, - Node::IntrinsicCall { - intrinsic, - args, - } => { + Node::IntrinsicCall { intrinsic, args } => { let mut new_reachability = ReachabilityLattice::top(); let mut new_constant = ConstantLattice::top(); let mut constants = vec![]; let mut all_constants = true; for arg in args.iter() { - let CCPLattice { ref reachability, ref constant } - = inputs[arg.idx()]; + let CCPLattice { + ref reachability, + ref constant, + } = inputs[arg.idx()]; - new_reachability = - ReachabilityLattice::meet(&new_reachability, reachability); - new_constant = - ConstantLattice::meet(&new_constant, constant); + new_reachability = ReachabilityLattice::meet(&new_reachability, reachability); + new_constant = ConstantLattice::meet(&new_constant, constant); 
if let ConstantLattice::Constant(constant) = constant { constants.push(constant); @@ -760,112 +765,120 @@ fn ccp_flow_function( } if all_constants { - new_constant = - match intrinsic { - Intrinsic::Abs => { + new_constant = match intrinsic { + Intrinsic::Abs => { + if let Constant::Integer8(i) = constants[0] { + ConstantLattice::Constant(Constant::Integer8(i.abs())) + } else if let Constant::Integer16(i) = constants[0] { + ConstantLattice::Constant(Constant::Integer16(i.abs())) + } else if let Constant::Integer32(i) = constants[0] { + ConstantLattice::Constant(Constant::Integer32(i.abs())) + } else if let Constant::Integer64(i) = constants[0] { + ConstantLattice::Constant(Constant::Integer64(i.abs())) + } else if let Constant::UnsignedInteger8(i) = constants[0] { + ConstantLattice::Constant(Constant::UnsignedInteger8(*i)) + } else if let Constant::UnsignedInteger16(i) = constants[0] { + ConstantLattice::Constant(Constant::UnsignedInteger16(*i)) + } else if let Constant::UnsignedInteger32(i) = constants[0] { + ConstantLattice::Constant(Constant::UnsignedInteger32(*i)) + } else if let Constant::UnsignedInteger64(i) = constants[0] { + ConstantLattice::Constant(Constant::UnsignedInteger64(*i)) + } else if let Constant::Float32(i) = constants[0] { + ConstantLattice::Constant(Constant::Float32( + ordered_float::OrderedFloat(i.abs()), + )) + } else if let Constant::Float64(i) = constants[0] { + ConstantLattice::Constant(Constant::Float64( + ordered_float::OrderedFloat(i.abs()), + )) + } else { + panic!("Unsupported combination of intrinsic abs and constant value. 
Did typechecking succeed?") + } + } + Intrinsic::ACos => unary_float_intrinsic!(intrinsic, constants, acos), + Intrinsic::ACosh => unary_float_intrinsic!(intrinsic, constants, acosh), + Intrinsic::ASin => unary_float_intrinsic!(intrinsic, constants, asin), + Intrinsic::ASinh => unary_float_intrinsic!(intrinsic, constants, asinh), + Intrinsic::ATan => unary_float_intrinsic!(intrinsic, constants, atan), + Intrinsic::ATan2 => binary_float_intrinsic!(intrinsic, constants, atan2), + Intrinsic::ATanh => unary_float_intrinsic!(intrinsic, constants, atanh), + Intrinsic::Cbrt => unary_float_intrinsic!(intrinsic, constants, cbrt), + Intrinsic::Ceil => unary_float_intrinsic!(intrinsic, constants, ceil), + Intrinsic::Cos => unary_float_intrinsic!(intrinsic, constants, cos), + Intrinsic::Cosh => unary_float_intrinsic!(intrinsic, constants, cosh), + Intrinsic::Exp => unary_float_intrinsic!(intrinsic, constants, exp), + Intrinsic::Exp2 => unary_float_intrinsic!(intrinsic, constants, exp2), + Intrinsic::ExpM1 => unary_float_intrinsic!(intrinsic, constants, exp_m1), + Intrinsic::Floor => unary_float_intrinsic!(intrinsic, constants, floor), + Intrinsic::Ln => unary_float_intrinsic!(intrinsic, constants, ln), + Intrinsic::Ln1P => unary_float_intrinsic!(intrinsic, constants, ln_1p), + Intrinsic::Log => binary_float_intrinsic!(intrinsic, constants, log), + Intrinsic::Log10 => unary_float_intrinsic!(intrinsic, constants, log10), + Intrinsic::Log2 => unary_float_intrinsic!(intrinsic, constants, log2), + Intrinsic::Pow => { + if let Constant::UnsignedInteger32(p) = constants[1] { if let Constant::Integer8(i) = constants[0] { - ConstantLattice::Constant(Constant::Integer8(i.abs())) + ConstantLattice::Constant(Constant::Integer8(i.pow(*p))) } else if let Constant::Integer16(i) = constants[0] { - ConstantLattice::Constant(Constant::Integer16(i.abs())) + ConstantLattice::Constant(Constant::Integer16(i.pow(*p))) } else if let Constant::Integer32(i) = constants[0] { - 
ConstantLattice::Constant(Constant::Integer32(i.abs())) + ConstantLattice::Constant(Constant::Integer32(i.pow(*p))) } else if let Constant::Integer64(i) = constants[0] { - ConstantLattice::Constant(Constant::Integer64(i.abs())) + ConstantLattice::Constant(Constant::Integer64(i.pow(*p))) } else if let Constant::UnsignedInteger8(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger8(*i)) + ConstantLattice::Constant(Constant::UnsignedInteger8(i.pow(*p))) } else if let Constant::UnsignedInteger16(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger16(*i)) + ConstantLattice::Constant(Constant::UnsignedInteger16(i.pow(*p))) } else if let Constant::UnsignedInteger32(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger32(*i)) + ConstantLattice::Constant(Constant::UnsignedInteger32(i.pow(*p))) } else if let Constant::UnsignedInteger64(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger64(*i)) - } else if let Constant::Float32(i) = constants[0] { - ConstantLattice::Constant(Constant::Float32(ordered_float::OrderedFloat(i.abs()))) - } else if let Constant::Float64(i) = constants[0] { - ConstantLattice::Constant(Constant::Float64(ordered_float::OrderedFloat(i.abs()))) - } else { - panic!("Unsupported combination of intrinsic abs and constant value. 
Did typechecking succeed?") - } - }, - Intrinsic::ACos => unary_float_intrinsic!(intrinsic, constants, acos), - Intrinsic::ACosh => unary_float_intrinsic!(intrinsic, constants, acosh), - Intrinsic::ASin => unary_float_intrinsic!(intrinsic, constants, asin), - Intrinsic::ASinh => unary_float_intrinsic!(intrinsic, constants, asinh), - Intrinsic::ATan => unary_float_intrinsic!(intrinsic, constants, atan), - Intrinsic::ATan2 => binary_float_intrinsic!(intrinsic, constants, atan2), - Intrinsic::ATanh => unary_float_intrinsic!(intrinsic, constants, atanh), - Intrinsic::Cbrt => unary_float_intrinsic!(intrinsic, constants, cbrt), - Intrinsic::Ceil => unary_float_intrinsic!(intrinsic, constants, ceil), - Intrinsic::Cos => unary_float_intrinsic!(intrinsic, constants, cos), - Intrinsic::Cosh => unary_float_intrinsic!(intrinsic, constants, cosh), - Intrinsic::Exp => unary_float_intrinsic!(intrinsic, constants, exp), - Intrinsic::Exp2 => unary_float_intrinsic!(intrinsic, constants, exp2), - Intrinsic::ExpM1 => unary_float_intrinsic!(intrinsic, constants, exp_m1), - Intrinsic::Floor => unary_float_intrinsic!(intrinsic, constants, floor), - Intrinsic::Ln => unary_float_intrinsic!(intrinsic, constants, ln), - Intrinsic::Ln1P => unary_float_intrinsic!(intrinsic, constants, ln_1p), - Intrinsic::Log => binary_float_intrinsic!(intrinsic, constants, log), - Intrinsic::Log10 => unary_float_intrinsic!(intrinsic, constants, log10), - Intrinsic::Log2 => unary_float_intrinsic!(intrinsic, constants, log2), - Intrinsic::Pow => { - if let Constant::UnsignedInteger32(p) = constants[1] { - if let Constant::Integer8(i) = constants[0] { - ConstantLattice::Constant(Constant::Integer8(i.pow(*p))) - } else if let Constant::Integer16(i) = constants[0] { - ConstantLattice::Constant(Constant::Integer16(i.pow(*p))) - } else if let Constant::Integer32(i) = constants[0] { - ConstantLattice::Constant(Constant::Integer32(i.pow(*p))) - } else if let Constant::Integer64(i) = constants[0] { - 
ConstantLattice::Constant(Constant::Integer64(i.pow(*p))) - } else if let Constant::UnsignedInteger8(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger8(i.pow(*p))) - } else if let Constant::UnsignedInteger16(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger16(i.pow(*p))) - } else if let Constant::UnsignedInteger32(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger32(i.pow(*p))) - } else if let Constant::UnsignedInteger64(i) = constants[0] { - ConstantLattice::Constant(Constant::UnsignedInteger64(i.pow(*p))) - } else { - panic!("Unsupported combination of intrinsic pow and constant values. Did typechecking succeed?") - } + ConstantLattice::Constant(Constant::UnsignedInteger64(i.pow(*p))) } else { panic!("Unsupported combination of intrinsic pow and constant values. Did typechecking succeed?") } - }, - Intrinsic::Powf => binary_float_intrinsic!(intrinsic, constants, powf), - Intrinsic::Powi => { - if let Constant::Integer32(p) = constants[1] { - if let Constant::Float32(v) = constants[0] { - ConstantLattice::Constant(Constant::Float32(ordered_float::OrderedFloat(v.powi(*p)))) - } else if let Constant::Float64(v) = constants[0] { - ConstantLattice::Constant(Constant::Float64(ordered_float::OrderedFloat(v.powi(*p)))) - } else { - panic!("Unsupported combination of intrinsic powi and constant value. Did typechecking succeed?") - } + } else { + panic!("Unsupported combination of intrinsic pow and constant values. 
Did typechecking succeed?") + } + } + Intrinsic::Powf => binary_float_intrinsic!(intrinsic, constants, powf), + Intrinsic::Powi => { + if let Constant::Integer32(p) = constants[1] { + if let Constant::Float32(v) = constants[0] { + ConstantLattice::Constant(Constant::Float32( + ordered_float::OrderedFloat(v.powi(*p)), + )) + } else if let Constant::Float64(v) = constants[0] { + ConstantLattice::Constant(Constant::Float64( + ordered_float::OrderedFloat(v.powi(*p)), + )) } else { - panic!("Unsupported combination of intrinsic powi and constant values. Did typechecking succeed?") + panic!("Unsupported combination of intrinsic powi and constant value. Did typechecking succeed?") } - }, - Intrinsic::Round => unary_float_intrinsic!(intrinsic, constants, round), - Intrinsic::Sin => unary_float_intrinsic!(intrinsic, constants, sin), - Intrinsic::Sinh => unary_float_intrinsic!(intrinsic, constants, sinh), - Intrinsic::Sqrt => unary_float_intrinsic!(intrinsic, constants, sqrt), - Intrinsic::Tan => unary_float_intrinsic!(intrinsic, constants, tan), - Intrinsic::Tanh => unary_float_intrinsic!(intrinsic, constants, tanh), - }; + } else { + panic!("Unsupported combination of intrinsic powi and constant values. 
Did typechecking succeed?") + } + } + Intrinsic::Round => unary_float_intrinsic!(intrinsic, constants, round), + Intrinsic::Sin => unary_float_intrinsic!(intrinsic, constants, sin), + Intrinsic::Sinh => unary_float_intrinsic!(intrinsic, constants, sinh), + Intrinsic::Sqrt => unary_float_intrinsic!(intrinsic, constants, sqrt), + Intrinsic::Tan => unary_float_intrinsic!(intrinsic, constants, tan), + Intrinsic::Tanh => unary_float_intrinsic!(intrinsic, constants, tanh), + }; } CCPLattice { reachability: new_reachability, constant: new_constant, } - }, - Node::Read { collect, indices: _ } => { - CCPLattice { - reachability: inputs[collect.idx()].reachability.clone(), - constant: ConstantLattice::bottom(), - } } + Node::Read { + collect, + indices: _, + } => CCPLattice { + reachability: inputs[collect.idx()].reachability.clone(), + constant: ConstantLattice::bottom(), + }, // Projection handles reachability when following an if or match. Node::Projection { control, selection } => match &function.nodes[control.idx()] { Node::If { control: _, cond } => { diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index a996e67c593b9c470b7f64d9f684046ed0900852..cfa2a6ff344f5641a28d01b7b1fa44feb480b03e 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -2,9 +2,9 @@ extern crate hercules_ir; use std::collections::{HashMap, HashSet}; +use self::hercules_ir::get_uses_mut; use self::hercules_ir::ir::*; use self::hercules_ir::ImmutableDefUseMap; -use self::hercules_ir::get_uses_mut; /* * This is a Hercules IR transformation that: @@ -31,42 +31,77 @@ use self::hercules_ir::get_uses_mut; * with, and also the region that joins the guard's branches mapping to the * fork's join NodeID */ -fn guarded_fork(function: &Function, - constants: &Vec<Constant>, - fork_join_map: &HashMap<NodeID, NodeID>, - def_use: &ImmutableDefUseMap, - index: usize, - node: &Node) - -> Option<(NodeID, DynamicConstantID, NodeID, NodeID, 
NodeID, - NodeID, HashMap<NodeID, NodeID>)> { +fn guarded_fork( + function: &Function, + constants: &Vec<Constant>, + fork_join_map: &HashMap<NodeID, NodeID>, + def_use: &ImmutableDefUseMap, + index: usize, + node: &Node, +) -> Option<( + NodeID, + Box<[DynamicConstantID]>, + NodeID, + NodeID, + NodeID, + NodeID, + HashMap<NodeID, NodeID>, +)> { // Identify fork nodes - let Node::Fork { control, factor } = node else { return None; }; + let Node::Fork { control, factors } = node else { + return None; + }; // Whose predecessor is a read from an if - let Node::Projection { control : if_node, ref selection } - = function.nodes[control.idx()] else { return None; }; - let Node::If { control : if_pred, cond } = function.nodes[if_node.idx()] - else { return None; }; + let Node::Projection { + control: if_node, + ref selection, + } = function.nodes[control.idx()] + else { + return None; + }; + let Node::If { + control: if_pred, + cond, + } = function.nodes[if_node.idx()] + else { + return None; + }; // Whose condition is appropriate - let Node::Binary { left, right, op } = function.nodes[cond.idx()] - else { return None; }; + let Node::Binary { left, right, op } = function.nodes[cond.idx()] else { + return None; + }; let branch_idx = *selection; // branchIdx == 1 means the true branch so we want the condition to be // 0 < n or n > 0 if branch_idx == 1 - && !((op == BinaryOperator::LT && function.nodes[left.idx()].is_zero_constant(constants) - && function.nodes[right.idx()].try_dynamic_constant() == Some(*factor)) - || (op == BinaryOperator::GT && function.nodes[right.idx()].is_zero_constant(constants) - && function.nodes[left.idx()].try_dynamic_constant() == Some(*factor))) { + && !((op == BinaryOperator::LT + && function.nodes[left.idx()].is_zero_constant(constants) + && factors + .iter() + .any(|factor| function.nodes[right.idx()].try_dynamic_constant() == Some(*factor))) + || (op == BinaryOperator::GT + && function.nodes[right.idx()].is_zero_constant(constants) + && 
factors.iter().any(|factor| { + function.nodes[left.idx()].try_dynamic_constant() == Some(*factor) + }))) + { return None; } // branchIdx == 0 means the false branch so we want the condition to be // n < 0 or 0 > n if branch_idx == 0 - && !((op == BinaryOperator::LT && function.nodes[left.idx()].try_dynamic_constant() == Some(*factor) - && function.nodes[right.idx()].is_zero_constant(constants)) - || (op == BinaryOperator::GT && function.nodes[right.idx()].try_dynamic_constant() == Some(*factor) - && function.nodes[left.idx()].is_zero_constant(constants))) { + && !((op == BinaryOperator::LT + && factors + .iter() + .any(|factor| function.nodes[left.idx()].try_dynamic_constant() == Some(*factor)) + && function.nodes[right.idx()].is_zero_constant(constants)) + || (op == BinaryOperator::GT + && factors.iter().any(|factor| { + function.nodes[right.idx()].try_dynamic_constant() == Some(*factor) + }) + && function.nodes[left.idx()].is_zero_constant(constants))) + { return None; } @@ -76,64 +111,98 @@ fn guarded_fork(function: &Function, // Find the unique control use of the join; if it's not a region we can't // eliminate this guard - let join_control = join_users.iter().filter(|n| function.nodes[n.idx()].is_region()) - .collect::<Vec<_>>(); - if join_control.len() != 1 { return None; } + let join_control = join_users + .iter() + .filter(|n| function.nodes[n.idx()].is_region()) + .collect::<Vec<_>>(); + if join_control.len() != 1 { + return None; + } let join_control = join_control[0]; - let Some(Node::Region { preds }) = function.nodes.get(join_control.idx()) - else { return None; }; + let Some(Node::Region { preds }) = function.nodes.get(join_control.idx()) else { + return None; + }; // The region after the join can only have two predecessors (for the guard // and the fork-join) - if preds.len() != 2 { return None; } - let other_pred = - if preds[1] == *join_id { - preds[0] - } else if preds[0] == *join_id { - preds[1] - } else { - return None; - }; + if preds.len() != 
2 { + return None; + } + let other_pred = if preds[1] == *join_id { + preds[0] + } else if preds[0] == *join_id { + preds[1] + } else { + return None; + }; // Other predecessor needs to be the other read from the guard's if - let Node::Projection { control : if_node2, ref selection } - = function.nodes[other_pred.idx()] - else { return None; }; + let Node::Projection { + control: if_node2, + ref selection, + } = function.nodes[other_pred.idx()] + else { + return None; + }; let else_branch = *selection; - if else_branch == branch_idx { return None; } - if if_node2 != if_node { return None; } + if else_branch == branch_idx { + return None; + } + if if_node2 != if_node { + return None; + } // Finally, identify the phi nodes associated with the region and match // them with the reduce nodes of the fork-join - let reduce_nodes = join_users.iter().filter(|n| function.nodes[n.idx()].is_reduce()) - .collect::<HashSet<_>>(); + let reduce_nodes = join_users + .iter() + .filter(|n| function.nodes[n.idx()].is_reduce()) + .collect::<HashSet<_>>(); // Construct a map from phi nodes indices to the reduce node index - let phi_nodes = def_use.get_users(*join_control).iter() - .filter_map(|n| { - let Node::Phi { control : _, ref data } = function.nodes[n.idx()] - else { return None; }; - if data.len() != 2 { return Some((*n, None)); } - let (init_idx, reduce_node) - = if reduce_nodes.contains(&data[0]) { - (1, data[0]) - } else if reduce_nodes.contains(&data[1]) { - (0, data[1]) - } else { return Some((*n, None)); }; - let Node::Reduce { control : _, init, .. 
} - = function.nodes[reduce_node.idx()] - else { return Some((*n, None)); }; - if data[init_idx] != init { return Some((*n, None)); } - Some((*n, Some(reduce_node))) - }) - .collect::<HashMap<_, _>>(); + let phi_nodes = def_use + .get_users(*join_control) + .iter() + .filter_map(|n| { + let Node::Phi { + control: _, + ref data, + } = function.nodes[n.idx()] + else { + return None; + }; + if data.len() != 2 { + return Some((*n, None)); + } + let (init_idx, reduce_node) = if reduce_nodes.contains(&data[0]) { + (1, data[0]) + } else if reduce_nodes.contains(&data[1]) { + (0, data[1]) + } else { + return Some((*n, None)); + }; + let Node::Reduce { + control: _, init, .. + } = function.nodes[reduce_node.idx()] + else { + return Some((*n, None)); + }; + if data[init_idx] != init { + return Some((*n, None)); + } + Some((*n, Some(reduce_node))) + }) + .collect::<HashMap<_, _>>(); // If any of the phi nodes do not have an associated reduce node, we cannot // remove the loop guard - if phi_nodes.iter().any(|(_, red)| red.is_none()) { return None; } + if phi_nodes.iter().any(|(_, red)| red.is_none()) { + return None; + } - let mut phi_nodes = phi_nodes.into_iter() - .map(|(phi, red)| (phi, red.unwrap())) - .collect::<HashMap<_, _>>(); + let mut phi_nodes = phi_nodes + .into_iter() + .map(|(phi, red)| (phi, red.unwrap())) + .collect::<HashMap<_, _>>(); // We also add a map from the region to the join to this map so we only // need one map to handle all node replacements in the elimination process @@ -145,7 +214,15 @@ fn guarded_fork(function: &Function, // - The true and false reads of the if // - The guard's predecessor // - The map from phi nodes to reduce nodes and the region to the join - Some ((NodeID::new(index), *factor, if_node, *control, other_pred, if_pred, phi_nodes)) + Some(( + NodeID::new(index), + factors.clone(), + if_node, + *control, + other_pred, + if_pred, + phi_nodes, + )) } /* @@ -153,26 +230,33 @@ fn guarded_fork(function: &Function, * Deletes nodes by 
setting nodes to gravestones. Works with a function already * containing gravestones. */ -pub fn fork_guard_elim(function: &mut Function, - constants: &Vec<Constant>, - fork_join_map: &HashMap<NodeID, NodeID>, - def_use: &ImmutableDefUseMap) { - let guard_info = function.nodes.iter().enumerate() - .filter_map(|(i, n)| guarded_fork(function, constants, - fork_join_map, - def_use, i, n)) - .collect::<Vec<_>>(); - - for (fork_node, factor, guard_node, guard_proj1, guard_proj2, guard_pred, map) in guard_info { +pub fn fork_guard_elim( + function: &mut Function, + constants: &Vec<Constant>, + fork_join_map: &HashMap<NodeID, NodeID>, + def_use: &ImmutableDefUseMap, +) { + let guard_info = function + .nodes + .iter() + .enumerate() + .filter_map(|(i, n)| guarded_fork(function, constants, fork_join_map, def_use, i, n)) + .collect::<Vec<_>>(); + + for (fork_node, factors, guard_node, guard_proj1, guard_proj2, guard_pred, map) in guard_info { function.nodes[guard_node.idx()] = Node::Start; function.nodes[guard_proj1.idx()] = Node::Start; function.nodes[guard_proj2.idx()] = Node::Start; - function.nodes[fork_node.idx()] - = Node::Fork { control : guard_pred, factor : factor}; + function.nodes[fork_node.idx()] = Node::Fork { + control: guard_pred, + factors, + }; for (idx, node) in function.nodes.iter_mut().enumerate() { let node_idx = NodeID::new(idx); - if map.contains_key(&node_idx) { *node = Node::Start; } + if map.contains_key(&node_idx) { + *node = Node::Start; + } for u in get_uses_mut(node).as_mut() { if let Some(replacement) = map.get(u) { **u = *replacement; diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 1b9e4d5b5b53aff5a0821a4a409c43b86b89de46..e32bef383a077d6444bf3f3817c4adbe96fee78b 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -141,7 +141,7 @@ pub fn forkify( // Create fork and join nodes. 
let fork = Node::Fork { control: *loop_pred, - factor: dc_id, + factors: Box::new([dc_id]), }; let fork_id = NodeID::new(function.nodes.len()); function.nodes.push(fork); @@ -215,7 +215,10 @@ pub fn forkify( } // Convert index phi node to thread ID node. - let thread_id = Node::ThreadID { control: fork_id }; + let thread_id = Node::ThreadID { + control: fork_id, + dimension: 0, + }; let thread_id_id = NodeID::new(function.nodes.len()); function.nodes.push(thread_id); diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 641b46ab5dcd957f1f9d330e8044a120c5b72b83..395f24d64f58ec5aff40e9a8549e2e4a68234f23 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -1,13 +1,18 @@ +extern crate hercules_cg; extern crate hercules_ir; extern crate postcard; extern crate serde; extern crate take_mut; use std::collections::HashMap; +use std::fs::File; +use std::io::Write; use std::iter::zip; +use std::process::{Command, Stdio}; use self::serde::Deserialize; +use self::hercules_cg::*; use self::hercules_ir::*; use crate::*; @@ -56,7 +61,6 @@ pub struct PassManager { pub loops: Option<Vec<LoopTree>>, pub antideps: Option<Vec<Vec<(NodeID, NodeID)>>>, pub bbs: Option<Vec<Vec<NodeID>>>, - pub fork_join_placements: Option<Vec<Vec<ForkJoinPlacement>>>, // Current plan. Keep track of the last time the plan was updated. 
pub plans: Option<Vec<Plan>>, @@ -79,7 +83,6 @@ impl PassManager { antideps: None, bbs: None, plans: None, - fork_join_placements: None, } } @@ -157,13 +160,13 @@ impl PassManager { pub fn make_fork_join_maps(&mut self) { if self.fork_join_maps.is_none() { - self.make_typing(); + self.make_control_subgraphs(); self.fork_join_maps = Some( zip( self.module.functions.iter(), - self.typing.as_ref().unwrap().iter(), + self.control_subgraphs.as_ref().unwrap().iter(), ) - .map(|(function, typing)| fork_join_map(function, typing, &self.module.types)) + .map(|(function, subgraph)| fork_join_map(function, subgraph)) .collect(), ); } @@ -251,27 +254,6 @@ impl PassManager { } } - pub fn make_fork_join_placements(&mut self) { - if self.fork_join_placements.is_none() { - self.make_fork_join_maps(); - self.make_fork_join_nests(); - self.make_bbs(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); - let bbs = self.bbs.as_ref().unwrap().iter(); - self.fork_join_placements = Some( - zip( - self.module.functions.iter(), - zip(fork_join_maps, zip(fork_join_nests, bbs)), - ) - .map(|(function, (fork_join_map, (fork_join_nest, bb)))| { - compute_fork_join_placement(function, fork_join_map, fork_join_nest, bb) - }) - .collect(), - ); - } - } - pub fn make_plans(&mut self) { if self.plans.is_none() { self.make_reverse_postorders(); @@ -309,7 +291,6 @@ impl PassManager { for idx in 0..self.module.functions.len() { ccp( &mut self.module.functions[idx], - &self.module.types, &mut self.module.constants, &def_uses[idx], &reverse_postorders[idx], @@ -451,7 +432,6 @@ impl PassManager { self.make_fork_join_maps(); self.make_bbs(); self.make_plans(); - self.make_fork_join_placements(); } xdot_module( &self.module, @@ -460,7 +440,6 @@ impl PassManager { self.fork_join_maps.as_ref(), self.bbs.as_ref(), self.plans.as_ref(), - self.fork_join_placements.as_ref(), ); // Xdot doesn't require clearing analysis 
results. @@ -476,48 +455,53 @@ impl PassManager { self.make_antideps(); self.make_bbs(); self.make_plans(); - self.make_fork_join_placements(); - - //let smodule = simple_compile( - // &self.module, - // self.def_uses.as_ref().unwrap(), - // self.reverse_postorders.as_ref().unwrap(), - // self.typing.as_ref().unwrap(), - // self.control_subgraphs.as_ref().unwrap(), - // self.fork_join_maps.as_ref().unwrap(), - // self.fork_join_nests.as_ref().unwrap(), - // self.antideps.as_ref().unwrap(), - // self.bbs.as_ref().unwrap(), - // self.plans.as_ref().unwrap(), - // self.fork_join_placements.as_ref().unwrap(), - //); - //println!("{:#?}", smodule); - - //// Compile LLVM IR into ELF object. - //let llc_process = Command::new("llc") - // .arg("-filetype=obj") - // .arg("-O3") - // .stdin(Stdio::piped()) - // .stdout(Stdio::piped()) - // .spawn() - // .unwrap(); - //llc_process - // .stdin - // .as_ref() - // .unwrap() - // .write(llvm_ir.as_bytes()) - // .unwrap(); - //let elf_object = llc_process.wait_with_output().unwrap().stdout; - // - //// Package manifest and ELF object into the same file. 
- //let hbin_module = (manifest, elf_object); - //let hbin_contents: Vec<u8> = postcard::to_allocvec(&hbin_module).unwrap(); - // - //let mut file = - // File::create(output_file_name).expect("PANIC: Unable to open output file."); - //file.write_all(&hbin_contents) - // .expect("PANIC: Unable to write output file contents."); - // + + let smodule = sched_compile( + &self.module, + self.def_uses.as_ref().unwrap(), + self.reverse_postorders.as_ref().unwrap(), + self.typing.as_ref().unwrap(), + self.control_subgraphs.as_ref().unwrap(), + self.fork_join_maps.as_ref().unwrap(), + self.fork_join_nests.as_ref().unwrap(), + self.antideps.as_ref().unwrap(), + self.bbs.as_ref().unwrap(), + self.plans.as_ref().unwrap(), + ); + + let mut llvm_ir = String::new(); + for manifest in smodule.manifests.values() { + for partition_manifest in manifest.partitions.iter() { + let function = &smodule.functions[&partition_manifest.name]; + cpu_compile(function, partition_manifest, &mut llvm_ir).unwrap(); + } + } + println!("{}", llvm_ir); + + // Compile LLVM IR into ELF object. + let llc_process = Command::new("llc") + .arg("-filetype=obj") + .arg("-O3") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .unwrap(); + llc_process + .stdin + .as_ref() + .unwrap() + .write(llvm_ir.as_bytes()) + .unwrap(); + let elf_object = llc_process.wait_with_output().unwrap().stdout; + + // Package manifest and ELF object into the same file. + let hbin_module = (smodule.manifests, elf_object); + let hbin_contents: Vec<u8> = postcard::to_allocvec(&hbin_module).unwrap(); + + let mut file = + File::create(output_file_name).expect("PANIC: Unable to open output file."); + file.write_all(&hbin_contents) + .expect("PANIC: Unable to write output file contents."); // Codegen doesn't require clearing analysis results. 
continue; diff --git a/hercules_rt/src/elf.rs b/hercules_rt/src/elf.rs index 4f2a196cd4e933a2a3e381fcfe430246600844b3..9a7aac0d87d1712b5b37dd7803d881329f2b472a 100644 --- a/hercules_rt/src/elf.rs +++ b/hercules_rt/src/elf.rs @@ -5,6 +5,7 @@ use std::mem::size_of; use std::ptr::copy_nonoverlapping; use std::ptr::null_mut; use std::ptr::read_unaligned; +use std::str::from_utf8; use self::libc::*; @@ -38,6 +39,13 @@ pub(crate) struct Elf { pub(crate) program_size: usize, } +/* + * Mmaps are visible to all threads and are thread safe, so we can share the ELF + * across threads. + */ +unsafe impl Send for Elf {} +unsafe impl Sync for Elf {} + impl Drop for Elf { fn drop(&mut self) { unsafe { munmap(self.program_section as *mut _, self.program_size) }; @@ -80,15 +88,20 @@ pub(crate) unsafe fn parse_elf(elf: &[u8]) -> Elf { let shstrtab = &elf[section_header_table[header.e_shstrndx as usize].sh_offset as usize..]; for i in 0..header.e_shnum as usize { let section_name = &shstrtab[section_header_table[i].sh_name as usize..]; - if section_name.starts_with(b".symtab") { + let null_position = section_name + .iter() + .position(|&c| c == b'\0') + .unwrap_or(section_name.len()); + let name_str = from_utf8(§ion_name[..null_position]).unwrap(); + if name_str == ".symtab" { symtab_ndx = i as i32; - } else if section_name.starts_with(b".strtab") { + } else if name_str == ".strtab" { strtab_ndx = i as i32; - } else if section_name.starts_with(b".text") { + } else if name_str == ".text" { text_ndx = i as i32; - } else if section_name.starts_with(b".bss") { + } else if name_str == ".bss" { bss_ndx = i as i32; - } else if section_name.starts_with(b".rela.text") { + } else if name_str == ".rela.text" { rela_text_ndx = i as i32; } } diff --git a/hercules_rt_proc/Cargo.toml b/hercules_rt_proc/Cargo.toml index eafa8a446884dc52f6f0afe208b03104396f28de..bbf52c6ba75b93db67bcd878f57cd5d475401591 100644 --- a/hercules_rt_proc/Cargo.toml +++ b/hercules_rt_proc/Cargo.toml @@ -10,5 +10,5 @@ 
proc-macro = true [dependencies] postcard = { version = "*", features = ["alloc"] } serde = { version = "*", features = ["derive"] } -hercules_ir = { path = "../hercules_ir" } +hercules_cg = { path = "../hercules_cg" } anyhow = "*" \ No newline at end of file diff --git a/hercules_rt_proc/src/lib.rs b/hercules_rt_proc/src/lib.rs index 4f4723db57b2907749c0aa53f0d5a43d45fad76b..0056ccd75a29ae0703dad666e3473d843dea911f 100644 --- a/hercules_rt_proc/src/lib.rs +++ b/hercules_rt_proc/src/lib.rs @@ -1,11 +1,11 @@ #![feature(iter_intersperse)] extern crate anyhow; -extern crate hercules_ir; +extern crate hercules_cg; extern crate postcard; extern crate proc_macro; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::ffi::OsStr; use std::fmt::Write; use std::fs::File; @@ -15,274 +15,185 @@ use std::path::Path; use proc_macro::*; -use self::hercules_ir::*; +use self::hercules_cg::*; /* * Parse manifest from header of .hbin file. */ -fn manifest_and_module_bytes(buffer: &[u8]) -> (ModuleManifest, Vec<u8>) { +fn manifests_and_module_bytes(buffer: &[u8]) -> (HashMap<String, Manifest>, Vec<u8>) { postcard::from_bytes(buffer).unwrap() } /* - * Convert Hercules types to the Rust types generated in the interface. + * Convert schedule IR types to the Rust types generated in the interface. 
*/ -fn generate_type_string(ty: &Type, rust_types: &Vec<String>) -> String { +fn generate_type_string(ty: &SType) -> String { match ty { - Type::Control(_) => "NOT_A_REAL_TYPE".to_string(), - Type::Boolean => "bool".to_string(), - Type::Integer8 => "i8".to_string(), - Type::Integer16 => "i16".to_string(), - Type::Integer32 => "i32".to_string(), - Type::Integer64 => "i64".to_string(), - Type::UnsignedInteger8 => "u8".to_string(), - Type::UnsignedInteger16 => "u16".to_string(), - Type::UnsignedInteger32 => "u32".to_string(), - Type::UnsignedInteger64 => "u64".to_string(), - Type::Float32 => "f32".to_string(), - Type::Float64 => "f64".to_string(), - Type::Product(fields) => fields.iter().fold("Prod".to_string(), |acc, field| { - format!("{}_{}", acc, field.idx()) - }), - Type::Summation(_) => todo!(), - Type::Array(elem, _) => format!("*mut {}", &rust_types[elem.idx()]), - } -} - -/* - * Emit a dynamic constant, which isn't necessarily just a dynamic constant - * parameter. - */ -fn emit_dynamic_constant(dc: DynamicConstantID, manifest: &ModuleManifest) -> String { - match manifest.dynamic_constants[dc.idx()] { - DynamicConstant::Constant(val) => format!("{}", val), - DynamicConstant::Parameter(idx) => format!("dyn_cons_{}", idx), + SType::Boolean => "bool".to_string(), + SType::Integer8 => "i8".to_string(), + SType::Integer16 => "i16".to_string(), + SType::Integer32 => "i32".to_string(), + SType::Integer64 => "i64".to_string(), + SType::UnsignedInteger8 => "u8".to_string(), + SType::UnsignedInteger16 => "u16".to_string(), + SType::UnsignedInteger32 => "u32".to_string(), + SType::UnsignedInteger64 => "u64".to_string(), + SType::Float32 => "f32".to_string(), + SType::Float64 => "f64".to_string(), + SType::Product(fields) => { + fields.iter().fold("__Prod".to_string(), |acc, field| { + format!("{}_{}", acc, generate_type_name(field)) + }) + "_" + } + SType::ArrayRef(elem) => format!("*mut {}", generate_type_string(elem)), } } -/* - * Emit dynamic constant math to 
calculate the runtime size of a type, in number - * of bytes. - */ -fn emit_type_size_expression(ty: TypeID, manifest: &ModuleManifest) -> String { - match manifest.types[ty.idx()] { - Type::Array(elem, ref dcs) => { - dcs.iter().fold( - "(".to_string() + &emit_type_size_expression(elem, manifest), - |acc, dc| acc + " * " + &emit_dynamic_constant(*dc, manifest), - ) + ")" +fn generate_type_name(ty: &SType) -> String { + match ty { + SType::Boolean + | SType::Integer8 + | SType::Integer16 + | SType::Integer32 + | SType::Integer64 + | SType::UnsignedInteger8 + | SType::UnsignedInteger16 + | SType::UnsignedInteger32 + | SType::UnsignedInteger64 + | SType::Float32 + | SType::Float64 => generate_type_string(ty), + SType::Product(fields) => { + fields.iter().fold("__Prod".to_string(), |acc, field| { + format!("{}_{}", acc, generate_type_name(field)) + }) + "_" } - _ => format!( - "{}", - manifest.type_sizes_aligns[ty.idx()] - .0 - .expect("PANIC: Size of non-array type is unknown at compile-time.") - ), + SType::ArrayRef(elem) => format!("ArrayRef_{}", generate_type_name(elem)), } } /* * Generate async Rust code orchestrating partition execution. */ -fn codegen(manifest: &ModuleManifest, elf: &[u8]) -> Result<String, anyhow::Error> { - // Verify the manifest. - verify_manifest(manifest).map_err(anyhow::Error::msg)?; - - // Convert Hercules IR types to Rust types. - let mut rust_types = vec!["".to_string(); manifest.types.len()]; - for id in types_bottom_up(&manifest.types) { - rust_types[id.idx()] = generate_type_string(&manifest.types[id.idx()], &rust_types); - } - +fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, anyhow::Error> { // Write to a String containing all of the Rust code. let mut rust_code = "".to_string(); - // Emit the ELF bytes as a static byte string constant. + // Emit the ELF bytes as a static byte string constant. Construct the module + // object on first access with LazyLock. 
write!( rust_code, - "const __HERCULES_ELF_OBJ: &[u8] = {};", + "const __HERCULES_ELF_OBJ: &[u8] = {};\n", Literal::byte_string(elf) )?; - - // Emit the async functions. - for function in &manifest.functions { - // Emit the function signature. - write!(rust_code, "async fn {}(", function.name)?; - let param_tys = function.param_types.iter().map(|id| &rust_types[id.idx()]); - for (idx, param) in param_tys.enumerate() { - write!(rust_code, "param_{}: {},", idx, param)?; - } - for dyn_cons_idx in 0..function.num_dynamic_constant_parameters { - write!(rust_code, "dyn_cons_{}: u64,", dyn_cons_idx)?; + write!( + rust_code, + "static __HERCULES_MODULE_OBJ: ::std::sync::LazyLock<::hercules_rt::Module> = ::std::sync::LazyLock::new(|| {{\n", + )?; + // Check that the ELF got embedded properly. + let hash = { + let mut s = DefaultHasher::new(); + elf.hash(&mut s); + s.finish() + }; + write!( + rust_code, + " use std::hash::{{DefaultHasher, Hash, Hasher}};\n debug_assert_eq!({{let mut s = DefaultHasher::new(); __HERCULES_ELF_OBJ.hash(&mut s); s.finish()}}, {});\n", + hash + )?; + write!( + rust_code, + " ::hercules_rt::Module::new(__HERCULES_ELF_OBJ)\n" + )?; + write!(rust_code, "}});\n")?; + + // Emit the product types used in this module. We can't just emit product + // types, since we need #[repr(C)] to interact with LLVM. + let all_stypes = manifests + .into_iter() + .map(|(_, manifest)| manifest.all_visible_types()) + .flatten() + .collect::<HashSet<SType>>(); + for stype in all_stypes.iter() { + if let Some(fields) = stype.try_product() { + write!( + rust_code, + "#[derive(Clone, Copy, Debug)]\n#[repr(C)]\nstruct {}({});\n", + generate_type_string(stype), + fields + .iter() + .map(|field| generate_type_string(field)) + .intersperse(", ".to_string()) + .fold("".to_string(), |acc, token| acc + &token) + )?; } - write!( - rust_code, - ") -> {} {{", - rust_types[function.return_type.idx()] - )?; - - // Check that the ELF got embedded properly. 
- write!( - rust_code, - "use std::hash::{{DefaultHasher, Hash, Hasher}};debug_assert_eq!({{let mut s = DefaultHasher::new(); __HERCULES_ELF_OBJ.hash(&mut s); s.finish()}}, {});", - { - let mut s = DefaultHasher::new(); - elf.hash(&mut s); - s.finish() - } - )?; + } - // In order to get repr(C), we need to define named structs to attach - // that attribute to. This unfortunately means we can't use anonymous - // products, which would be much nicer than this nonsense. - let mut already_seen = HashSet::new(); - let mut emit_product_type_def = - |ty: &Type, rust_code: &mut String| -> Result<(), anyhow::Error> { - match ty { - Type::Product(ref fields) => { - if !already_seen.contains(ty) { - write!( - rust_code, - "#[repr(C)] struct {}({});", - generate_type_string(&ty, &rust_types), - fields.iter().fold("".to_string(), |acc, field| { - acc + &rust_types[field.idx()] + "," - }) - )?; - already_seen.insert(ty.clone()); - } - } - _ => {} + // Emit the async Rust functions implementing each Hercules function. + for (function_name, manifest) in manifests.into_iter() { + // Emit the function signature. + write!(rust_code, "async unsafe fn {}(", function_name)?; + for (param_ty, param_kind) in manifest.param_types.iter() { + match param_kind { + ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}", idx)?, + ParameterKind::DataInput(_) => panic!( + "PANIC: Parameter kind for Hercules function parameter cannot be DataInput." + ), + ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?, + ParameterKind::ArrayConstant(array_id) => { + write!(rust_code, "array_{}", array_id.idx())? } - Ok(()) - }; - for id in types_bottom_up(&manifest.types) { - emit_product_type_def(&manifest.types[id.idx()], &mut rust_code)?; + } + write!(rust_code, ": {}, ", generate_type_string(param_ty))? } - - // Load the ELF object, and cast the appropriate pointers. 
write!( rust_code, - r#"let module = ::hercules_rt::Module::new(__HERCULES_ELF_OBJ);"# + ") -> {} {{\n", + generate_type_string(&manifest.return_type) )?; - for (idx, partition) in function.partitions.iter().enumerate() { - // Determine the signature for the partition function. - let u64_ty_id = TypeID::new( - manifest - .types - .iter() - .position(|ty| *ty == Type::UnsignedInteger64) - .unwrap(), - ); - let input_types = partition.inputs.iter().map(|input| match input { - PartitionInput::DataInput(node_id) => function.typing[node_id.idx()], - PartitionInput::FunctionArgument(param_idx) => { - function.param_types[*param_idx as usize] - } - PartitionInput::ArrayConstant(num) => { - // We store the types of constants in the manifest - use - // this info to get the type of this array constant. - let arr_cons_id = manifest.array_cons_ids[*num as usize]; - manifest - .constant_types - .iter() - .filter(|(cons_id, _)| arr_cons_id == *cons_id) - .next() - .expect("PANIC: Couldn't find type of array constant in manifest.") - .1 - } - PartitionInput::DynamicConstant(_) => u64_ty_id, - }); - let output_type = Type::Product( - partition - .outputs - .iter() - .map(|output| match output { - PartitionOutput::DataOutput(node_id) => function.typing[node_id.idx()], - PartitionOutput::ControlIndicator => u64_ty_id, - }) - .collect(), - ); - - // The product type outputs of partitions might not already exist in - // the list of types in the module. - emit_product_type_def(&output_type, &mut rust_code)?; - // Get the pointer for the partition function, and cast it to the - // correct function pointer type. + // Compute the signature for each partition function and emit the + // function pointers. 
+ for (partition_idx, partition) in manifest.partitions.iter().enumerate() { write!( rust_code, - r#"let fn_ptr_part_{}: extern "C" fn({}) -> {} = unsafe {{ ::core::mem::transmute(module.get_function_ptr("{}_part_{}")) }};"#, - idx, - input_types - .map(|id| &rust_types[id.idx()] as &str) - .intersperse(", ") - .fold("".to_string(), |acc, field| acc + field), - // Generate the output product type on the fly. This type is not - // necessarily present in the typing of the original IR, before - // code generation. - generate_type_string(&output_type, &rust_types), - function.name, - idx + " let fn_ptr_part_{}: extern \"C\" fn(", + partition_idx )?; - } - - // Declare all of the input / output intermediate variables. We declare - // them as Options since we want to catch use before definition bugs. We - // get the list of variables from partition outputs specifically, since - // this also includes returned values. - let all_partition_outputs = function_partition_outputs(function); - for node in all_partition_outputs { + for (param_stype, _) in partition.parameters.iter() { + write!(rust_code, "{}, ", generate_type_string(param_stype))?; + } + let return_stype = if partition.returns.len() == 1 { + partition.returns[0].0.clone() + } else { + SType::Product( + partition + .returns + .iter() + .map(|(return_stype, _)| return_stype.clone()) + .collect(), + ) + }; write!( rust_code, - "let mut node_{}: Option<{}> = None;", - node.idx(), - &rust_types[function.typing[node.idx()].idx()] + ") -> {} = ::core::mem::transmute(__HERCULES_MODULE_OBJ.get_function_ptr(\"{}\"));\n", + generate_type_string(&return_stype), + partition.name, )?; } - // Declare all of the array constant memories. We declare them as Vecs - // to allocate the memories. We emit multiplications of the dynamic - // constant dimensions to allocate the whole memory as one contiguous - // range. 
- for (arr_cons_num, arr_cons_id) in manifest - .array_cons_ids - .iter() - .filter(|id| function.used_constants.contains(id)) - .enumerate() - { - let arr_ty_id = manifest - .constant_types - .iter() - .filter(|(cons_id, _)| arr_cons_id == cons_id) - .next() - .expect("PANIC: Couldn't find type of array constant in manifest.") - .1; - if let ConstantBytes::NonZero(ref bytes, _) = manifest.array_constants[arr_cons_num] { - // Initialize the vector from the non-zero constant bytes of the - // array. - write!( - rust_code, - "let mut arr_cons_{}_vec = Vec::from({});let mut arr_cons_{} = arr_cons_{}_vec.as_mut_ptr();arr_cons_{}_vec.leak();", - arr_cons_num, - Literal::byte_string(bytes), - arr_cons_num, - arr_cons_num, - arr_cons_num, - )?; - } else { - // The array is all zeros, so create a vector of zeros with a - // possibly runtime only known size. - write!( - rust_code, - "let mut arr_cons_{}_vec = vec![0u8; {} as usize];let mut arr_cons_{} = arr_cons_{}_vec.as_mut_ptr();arr_cons_{}_vec.leak();", - arr_cons_num, - emit_type_size_expression(arr_ty_id, manifest), - arr_cons_num, - arr_cons_num, - arr_cons_num, - )?; - } + // Declare all of the intermediary data input / output variables. They + // are declared as MaybeUninit, since they get assigned after running a + // partition. MaybeUninits should always be defined before assume_init() + // is called on them, assuming a valid partitioning. + let mut data_inputs = HashSet::new(); + let mut data_outputs = HashSet::new(); + for partition in manifest.partitions.iter() { + data_inputs.extend(partition.data_inputs()); + data_outputs.extend(partition.data_outputs()); + } + assert_eq!(data_inputs, data_outputs); + for (node, stype) in data_inputs { + write!(rust_code, " let mut node_{}: ::std::mem::MaybeUninit<{}> = ::std::mem::MaybeUninit::uninit();\n", node.idx(), generate_type_string(stype))?; } // The core executor is a Rust loop. 
We literally run a "control token" @@ -290,101 +201,88 @@ fn codegen(manifest: &ModuleManifest, elf: &[u8]) -> Result<String, anyhow::Erro // partitions to drive execution. write!( rust_code, - "let mut control_token: u32 = {};loop {{", - function.top_partition.idx() + " let mut control_token: i8 = 0;\n loop {{\n", )?; // Match on the control token position to determine which partition to // execute. - write!(rust_code, "match control_token {{")?; - for (idx, partition) in function.partitions.iter().enumerate() { - // The control token stores the current partition index. - write!(rust_code, "{} => {{", idx)?; + write!(rust_code, " match control_token {{\n")?; - // Calculate the inputs. - for (input_idx, input) in partition.inputs.iter().enumerate() { - write!( - rust_code, - "let input_{} = {};", - input_idx, - match input { - PartitionInput::DataInput(id) => format!( - "node_{}.expect(\"PANIC: Encountered use before def in runtime.\")", - id.idx() - ), - PartitionInput::FunctionArgument(idx) => format!("param_{}", idx), - PartitionInput::ArrayConstant(num) => format!("arr_cons_{} as _", num), - PartitionInput::DynamicConstant(idx) => format!("dyn_cons_{}", idx), - } - )?; - } + // Emit the match arm per partition. + for (idx, partition) in manifest.partitions.iter().enumerate() { + // Open the arm. + write!(rust_code, " {} => {{\n", idx)?; - // Make the call. - write!(rust_code, "let output = fn_ptr_part_{}(", idx)?; - for input_idx in 0..partition.inputs.len() { - write!(rust_code, "input_{},", input_idx)?; + // Call the partition function. + write!( + rust_code, + " let output = fn_ptr_part_{}(", + idx + )?; + for (_, kind) in partition.parameters.iter() { + match kind { + ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}, ", idx)?, + ParameterKind::DataInput(id) => { + write!(rust_code, "node_{}.assume_init(), ", id.idx())? 
+ } + ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}, ", idx)?, + ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}, ", id.idx())?, + } } - write!(rust_code, ");")?; + write!(rust_code, ");\n")?; // Assign the outputs. - for (output_idx, output) in partition.outputs.iter().enumerate() { - write!( - rust_code, - "{} = Some(output.{});", - match output { - PartitionOutput::DataOutput(id) => format!("node_{}", id.idx()), - // TODO: handle partitions with control indicator - // outputs. - PartitionOutput::ControlIndicator => "todo!()".to_string(), - }, - output_idx - )?; + for (output_idx, (_, kind)) in partition.returns.iter().enumerate() { + let output_ref = if partition.returns.len() == 1 { + "output".to_string() + } else { + format!("output.{}", output_idx) + }; + match kind { + ReturnKind::HerculesReturn => { + write!(rust_code, " return {};\n", output_ref)? + } + ReturnKind::DataOutput(id) => write!( + rust_code, + " node_{}.write({});\n", + id.idx(), + output_ref + )?, + ReturnKind::NextPartition => write!( + rust_code, + " control_token = {};\n", + output_ref + )?, + } } - // If there are no control successors, return from the function. If - // there is one control successor, unconditionally set the control - // token to that successor's index. If there are multiple control - // successors, set the control token to the control indicator output - // of the partition function. - // TODO: handle partitions with multiple return nodes. - // TODO: handle partitions with a return node and a single successor - // partition. - if partition.successor_partitions.len() == 0 { - assert_eq!(function.returned_values.len(), 1); + // If there's only one partition successor, then an explicit + // NextPartition isn't returned - emit the new control token here. 
+ if partition.successors.len() == 1 { write!( rust_code, - "return node_{}.expect(\"PANIC: Encountered use before def in runtime.\");", - function.returned_values[0].idx() + " control_token = {};\n", + partition.successors[0].idx() )?; - } else if partition.successor_partitions.len() == 1 { - write!( - rust_code, - "control_token = {};", - partition.successor_partitions[0].1.idx() - )?; - } else { - todo!(); } - // Close the branch. - write!(rust_code, "}}")?; + // Close the arm. + write!(rust_code, " }}\n")?; } // Close the match, and handle invalid control token values. write!( rust_code, - "_ => panic!(\"PANIC: Invalid control token value.\"),}}" + " _ => panic!(\"PANIC: Invalid control token value.\"),\n }}\n" )?; // Close the loop. - write!(rust_code, "}}")?; + write!(rust_code, " }}\n")?; // Close the function. - write!(rust_code, "}}")?; + write!(rust_code, "}}\n")?; } - eprintln!("{}", rust_code); - Ok(rust_code) } @@ -423,9 +321,10 @@ pub fn use_hbin(path: TokenStream) -> TokenStream { let mut f = File::open(path).unwrap(); let mut buffer = vec![]; f.read_to_end(&mut buffer).unwrap(); - let (manifest, elf) = manifest_and_module_bytes(&buffer); - eprintln!("{:?}", manifest); + let (manifests, elf) = manifests_and_module_bytes(&buffer); // Generate Rust code. 
- codegen(&manifest, &elf).unwrap().parse().unwrap() + let rust_code = codegen(&manifests, &elf).unwrap(); + eprintln!("{}", rust_code); + rust_code.parse().unwrap() } diff --git a/hercules_samples/dot/Cargo.toml b/hercules_samples/dot/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..fe5873abd363aea21b8c4e24bdc497c26ecd27f1 --- /dev/null +++ b/hercules_samples/dot/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "hercules_dot" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] +edition = "2021" + +[dependencies] +clap = { version = "*", features = ["derive"] } +hercules_rt = { path = "../../hercules_rt" } +rand = "*" +async-std = "*" diff --git a/hercules_samples/dot/dot.hir b/hercules_samples/dot/dot.hir new file mode 100644 index 0000000000000000000000000000000000000000..63790a343e9ec8dcbcac0829dea082a6e4b60e9c --- /dev/null +++ b/hercules_samples/dot/dot.hir @@ -0,0 +1,11 @@ +fn dot<1>(a: array(f32, #0), b: array(f32, #0)) -> f32 + zero = constant(f32, 0.0) + fork = fork(start, #0) + id = thread_id(fork, 0) + join = join(fork) + r = return(join, dot_red) + aval = read(a, position(id)) + bval = read(b, position(id)) + mul = mul(aval, bval) + dot = add(mul, dot_red) + dot_red = reduce(join, zero, dot) diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..3aea6409dd08eb6c8b8157877b55ad57cc811b80 --- /dev/null +++ b/hercules_samples/dot/src/main.rs @@ -0,0 +1,18 @@ +extern crate async_std; +extern crate clap; +extern crate hercules_rt; + +// To compile currently, run from the Hercules project root directory: +// cargo run --bin hercules_driver hercules_samples/dot/dot.hir "Codegen(\"dot.hbin\")" +// Then, you can execute this example with: +// cargo run --bin hercules_dot +hercules_rt::use_hbin!("dot.hbin"); + +fn main() { + async_std::task::block_on(async { + let mut a = vec![1.0, 2.0, 3.0, 4.0]; + let mut b = vec![5.0, 
6.0, 7.0, 8.0]; + let c = unsafe { dot(a.as_mut_ptr(), b.as_mut_ptr(), 4).await }; + println!("{}", c,); + }); +} diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index 1c52fea8ce4e2bce729ec04189368e7a7e54ced0..d236145c00e1bfd63a1b45ff5bebad93398062ea 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -1,7 +1,7 @@ #![allow(unused)] extern crate derive_more; -use derive_more::{From}; +use derive_more::From; /* Defines semantic meaning of IR operations. */ extern crate itertools; @@ -62,12 +62,10 @@ pub enum InterpreterWrapper { Array(Box<[InterpreterWrapper]>), // TypeID of the array Type (not the element type) } - - -impl <T> From<&Vec<T>> for InterpreterWrapper - where T: Into<InterpreterWrapper> + Clone +impl<T> From<&Vec<T>> for InterpreterWrapper +where + T: Into<InterpreterWrapper> + Clone, { - fn from(value: &Vec<T>) -> Self { let mut values = vec![]; for i in 0..value.len() { @@ -77,8 +75,9 @@ impl <T> From<&Vec<T>> for InterpreterWrapper } } -impl <T> From<Vec<T>> for InterpreterWrapper - where T: Into<InterpreterWrapper> + Clone +impl<T> From<Vec<T>> for InterpreterWrapper +where + T: Into<InterpreterWrapper> + Clone, { fn from(value: Vec<T>) -> Self { let mut values = vec![]; @@ -89,11 +88,10 @@ impl <T> From<Vec<T>> for InterpreterWrapper } } - -impl <T> From<&[T]> for InterpreterWrapper - where T: Into<InterpreterWrapper> + Clone +impl<T> From<&[T]> for InterpreterWrapper +where + T: Into<InterpreterWrapper> + Clone, { - fn from(value: &[T]) -> Self { let mut values = vec![]; for i in 0..value.len() { @@ -103,7 +101,7 @@ impl <T> From<&[T]> for InterpreterWrapper } } -// Map rust types to interpreter values. +// Map rust types to interpreter values. macro_rules! 
from_impl { ($rust:ty, $variant:tt) => { impl From<$rust> for InterpreterWrapper { @@ -155,7 +153,9 @@ impl<'a> InterpreterVal { .try_extents() .expect("PANIC: wrong type for array") .into_iter() - .map(|extent| dyn_const_value(&dynamic_constants[extent.idx()], &dynamic_constant_params)) + .map(|extent| { + dyn_const_value(&dynamic_constants[extent.idx()], &dynamic_constant_params) + }) .collect(); let size = InterpreterVal::array_size(&extents); @@ -198,7 +198,6 @@ impl<'a> InterpreterVal { left: InterpreterVal, right: InterpreterVal, ) -> InterpreterVal { - // Do some type conversion first. let left = match left { InterpreterVal::DynamicConstant(v) => match right { @@ -785,6 +784,10 @@ impl<'a> InterpreterVal { (UnaryOperator::Not, Self::Integer16(val)) => Self::Integer16(!val), (UnaryOperator::Not, Self::Integer32(val)) => Self::Integer32(!val), (UnaryOperator::Not, Self::Integer64(val)) => Self::Integer64(!val), + (UnaryOperator::Not, Self::UnsignedInteger8(val)) => Self::UnsignedInteger8(!val), + (UnaryOperator::Not, Self::UnsignedInteger16(val)) => Self::UnsignedInteger16(!val), + (UnaryOperator::Not, Self::UnsignedInteger32(val)) => Self::UnsignedInteger32(!val), + (UnaryOperator::Not, Self::UnsignedInteger64(val)) => Self::UnsignedInteger64(!val), (UnaryOperator::Neg, Self::Integer8(val)) => Self::Integer8(-val), (UnaryOperator::Neg, Self::Integer16(val)) => Self::Integer16(-val), (UnaryOperator::Neg, Self::Integer32(val)) => Self::Integer32(-val),