diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index af867eb8e4752b67cfb7cd503a343cd27ee13d91..7a14950ee0416e427b5aeceaacd4b5f68b014fba 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,9 +1,9 @@ test-cpu: stage: test script: - - cargo test --features=opencv -vv + - cargo test --features=opencv -vv -- --nocapture test-gpu: stage: test script: - - cargo test --features=cuda,opencv -vv + - cargo test --features=cuda,opencv -vv -- --nocapture diff --git a/Cargo.lock b/Cargo.lock index c438e846d87f12fd7eea2c27d73a8fc7e043836d..1973fbbeb88622cb01fe3ddfc191aae796a00f54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,9 +10,9 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "aho-corasick" -version = "0.7.15" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7404febffaa47dac81aa44dba71523c9d069b1bdc50a77db41195149e17f68e5" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -29,6 +29,12 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.18" @@ -102,12 +108,6 @@ dependencies = [ "syn 2.0.98", ] -[[package]] -name = "arrayvec" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" - [[package]] name = "arrayvec" version = "0.7.6" @@ -255,7 +255,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf" dependencies = [ "anyhow", - "arrayvec 0.7.6", + "arrayvec", "log", "nom 7.1.3", "num-rational", @@ -268,7 +268,7 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62" dependencies = [ - "arrayvec 0.7.6", + "arrayvec", ] [[package]] @@ -310,28 +310,16 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" -[[package]] -name = "bitvec" -version = "0.19.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55f93d0ef3363c364d5976646a38f04cf67cfe1d4c8d160cdea02cab2c116b33" -dependencies = [ - "funty 1.1.0", - "radium 0.5.3", - "tap", - "wyz 0.2.0", -] - [[package]] name = "bitvec" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" dependencies = [ - "funty 2.0.0", - "radium 0.7.0", + "funty", + "radium", "tap", - "wyz 0.5.1", + "wyz", ] [[package]] @@ -349,9 +337,9 @@ dependencies = [ [[package]] name = "built" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73848a43c5d63a1251d17adf6c2bf78aa94830e60a335a95eeea45d6ba9e1e4d" +checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b" [[package]] name = "bumpalo" @@ -394,11 +382,17 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" -version = "1.2.13" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda" +checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" dependencies = [ "jobserver", "libc", @@ -446,6 +440,33 @@ dependencies = [ "vob", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang" version = "2.0.0" @@ -469,9 +490,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.29" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acebd8ad879283633b343856142139f2da2317c96b05b4dd6181c61e2480184" +checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" dependencies = [ "clap_builder", "clap_derive", @@ -479,9 +500,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.29" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ba32cbda51c7e1dfd49acc1457ba1a7dec5b64fe360e828acb13ca8dc9c2f9" +checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" dependencies = [ "anstream", "anstyle", @@ -495,7 +516,7 @@ version = "4.5.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.98", @@ -543,6 +564,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -677,9 +734,9 @@ dependencies = [ [[package]] name = "equivalent" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" @@ -794,12 +851,6 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" -[[package]] -name = "funty" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7" - [[package]] name = "funty" version = "2.0.0" @@ -944,12 +995,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -960,7 +1005,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" name = "hercules_cg" version = "0.1.0" dependencies = [ - "bitvec 1.0.1", + "bitvec", "hercules_ir", "ordered-float", "rand 0.9.0", @@ -971,7 +1016,7 @@ dependencies = [ name = "hercules_interpreter" version = "0.1.0" dependencies = [ - "bitvec 1.0.1", + "bitvec", "clap", "derive_more", "hercules_ir", @@ -988,9 +1033,9 @@ dependencies = [ name = "hercules_ir" version = "0.1.0" dependencies = [ - "bitvec 1.0.1", + "bitvec", "either", - "nom 6.2.2", + "nom 8.0.0", "ordered-float", "rand 0.9.0", "serde", @@ -1001,7 +1046,7 @@ name = "hercules_opt" version = "0.1.0" dependencies = [ "bimap", - "bitvec 1.0.1", + "bitvec", "egg", "either", "hercules_cg", @@ -1025,7 +1070,7 @@ version = "0.1.0" name = "hercules_tests" version = "0.1.0" dependencies = [ - "bitvec 1.0.1", + "bitvec", "clap", "hercules_interpreter", "hercules_ir", @@ -1102,12 +1147,32 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "is-terminal" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -1175,7 +1240,7 @@ dependencies = [ "clap", "hercules_rt", "juno_build", - "nom 6.2.2", + "nom 8.0.0", "rand 0.9.0", "with_builtin_macros", ] @@ -1188,7 +1253,7 @@ dependencies = [ "clap", "hercules_rt", "juno_build", - "nom 6.2.2", + "nom 8.0.0", "with_builtin_macros", ] @@ -1217,6 +1282,7 @@ version = "0.1.0" dependencies = [ "async-std", "clap", + "criterion", "hercules_rt", "image", "juno_build", @@ -1231,7 +1297,7 @@ dependencies = [ "clap", "hercules_rt", "juno_build", - "nom 6.2.2", + "nom 8.0.0", "with_builtin_macros", ] @@ -1414,7 +1480,7 @@ dependencies = [ "clap", "hercules_rt", "juno_build", - "nom 6.2.2", + "nom 8.0.0", "with_builtin_macros", ] @@ -1446,19 +1512,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" -[[package]] -name = "lexical-core" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6607c62aa161d23d17a9072cc5da0be67cdfc89d3afb1e8d9c842bebc2525ffe" -dependencies = [ - "arrayvec 0.5.2", - "bitflags 1.3.2", - "cfg-if", - "ryu", - "static_assertions", -] - [[package]] name = "libc" version = "0.2.169" @@ -1543,7 +1596,7 @@ dependencies = [ "num-traits", "quote", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "vergen", ] @@ -1608,9 +1661,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.3.4" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "minimal-lexical" @@ -1648,25 +1701,21 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" [[package]] name = "nom" -version = "6.2.2" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6a7a9657c84d5814c6196b68bb4429df09c18b1573806259fba397ea4ad0d44" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" dependencies = [ - "bitvec 0.19.6", - "funty 1.1.0", - "lexical-core", "memchr", - "version_check", + "minimal-lexical", ] [[package]] name = "nom" -version = "7.1.3" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" dependencies = [ "memchr", - "minimal-lexical", ] [[package]] @@ -1746,6 +1795,12 @@ version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "opencv" version = "0.94.2" @@ -1783,9 +1838,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.6.0" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" dependencies = [ "num-traits", "rand 0.8.5", @@ -1891,6 +1946,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.16" @@ -2048,12 +2131,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radium" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "941ba9d78d8e2f7ce474c015eea4d9c6d25b6a3327f9832ee29a4de27f91bbb8" - [[package]] name = "radium" version = "0.7.0" @@ -2079,8 +2156,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.0", - "zerocopy 0.8.17", + "rand_core 0.9.1", + "zerocopy 0.8.18", ] [[package]] @@ -2100,7 +2177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.0", + "rand_core 0.9.1", ] [[package]] @@ -2115,12 +2192,12 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" +checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3" dependencies = [ "getrandom 0.3.1", - "zerocopy 0.8.17", + "zerocopy 0.8.18", ] [[package]] @@ -2131,7 +2208,7 @@ checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9" dependencies = [ "arbitrary", "arg_enum_proc_macro", - "arrayvec 0.7.6", + "arrayvec", "av1-grain", "bitstream-io", "built", @@ -2213,20 +2290,26 @@ dependencies = [ [[package]] name = "regex" -version = "1.4.6" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a26af418b574bd56588335b3a3659a65725d4e636eb1016c2f9e3b38c7cc759" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] -name = "regex-syntax" -version = "0.6.29" +name = "regex-automata" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] [[package]] name = "regex-syntax" @@ -2280,6 +2363,15 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "saturating" version = "0.1.0" @@ -2318,6 +2410,18 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "serde_json" +version = "1.0.138" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -2374,9 +2478,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" [[package]] name = "sparsevec" @@ -2458,12 +2562,12 @@ dependencies = [ [[package]] name = "system-deps" -version = "6.1.1" +version = "6.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30c2de8a4d8f4b823d634affc9cd2a74ec98c53a756f317e529a48046cbf71f3" +checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck 0.4.1", + "heck", "pkg-config", "toml", "version-compare", @@ -2489,9 +2593,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.16.0" +version = "3.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" +checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" dependencies = [ "cfg-if", "fastrand", @@ -2565,11 +2669,21 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "toml" -version = "0.7.6" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" dependencies = [ "serde", "serde_spanned", @@ -2588,9 +2702,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.19.12" +version = "0.22.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c500344a19072298cd05a7224b3c0c629348b78692bf48466c5238656e315a78" +checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ "indexmap", "serde", @@ -2675,9 +2789,9 @@ dependencies = [ [[package]] name = "version-compare" -version = "0.1.1" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579a42fc0b8e0c63b76519a339be31bed574929511fa53c1a3acae26eb258f29" +checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" @@ -2696,6 +2810,16 @@ dependencies = [ "serde", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2814,6 +2938,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -3022,9 +3155,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.4.7" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca0ace3845f0d96209f0375e6d367e3eb87eb65d27d445bdc9f1843a26f39448" +checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" dependencies = [ "memchr", ] @@ -3058,12 +3191,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "wyz" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85e60b0d1b5f99db2556934e21937020776a5d31520bf169e851ac44e6420214" - [[package]] name = "wyz" version = "0.5.1" @@ -3085,11 +3212,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.17" +version = "0.8.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa91407dacce3a68c56de03abe2760159582b846c6a4acd2f456618087f12713" +checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2" dependencies = [ - "zerocopy-derive 0.8.17", + "zerocopy-derive 0.8.18", ] [[package]] @@ -3105,9 +3232,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.17" +version = "0.8.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06718a168365cad3d5ff0bb133aad346959a2074bd4a85c121255a11304a8626" +checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7" dependencies = [ "proc-macro2", "quote", diff --git a/hercules_ir/Cargo.toml b/hercules_ir/Cargo.toml index 0d6a25fb3884f0512b6defaba1c3255951e701a7..26950d4b7700d19326e6ea61aa2488b4c5d5df59 100644 --- a/hercules_ir/Cargo.toml +++ b/hercules_ir/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] rand = "*" -nom = "6.2.2" +nom = "*" ordered-float = { version = "*", features = ["serde"] } bitvec = "*" serde = { version = "*", features = ["derive"] } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index f1f4153a1e5d202d6ff6d3e75e5e8353bab658e5..a38df8e98f9e4527ebe23ffbc39a4184ce714f90 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -3,6 +3,8 @@ use std::collections::{HashMap, HashSet}; use std::ops::Deref; use std::str::FromStr; +use nom::Parser; + use crate::*; /* @@ -128,9 +130,8 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a // If there is any text left after successfully parsing some functions, // treat that as an error. let (rest, functions) = - nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &context)))( - ir_text, - )?; + nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &context))) + .parse(ir_text)?; let mut context = context.into_inner(); // Functions, as returned by parsing, is in parse order, which may differ @@ -215,7 +216,7 @@ fn parse_function<'a>( Ok((ir_text, num_dynamic_constants)) }; let (ir_text, num_dynamic_constants) = - nom::combinator::opt(parse_num_dynamic_constants)(ir_text)?; + nom::combinator::opt(parse_num_dynamic_constants).parse(ir_text)?; // If unspecified, assumed function has no dynamic constant arguments. let num_dynamic_constants = num_dynamic_constants.unwrap_or(0); @@ -223,7 +224,7 @@ fn parse_function<'a>( let ir_text = nom::character::complete::char('(')(ir_text)?.0; let (ir_text, params) = nom::multi::separated_list0( nom::character::complete::char(','), - nom::sequence::tuple(( + ( nom::character::complete::multispace0, parse_identifier, nom::character::complete::multispace0, @@ -231,8 +232,9 @@ fn parse_function<'a>( nom::character::complete::multispace0, |x| parse_type_id(x, context), nom::character::complete::multispace0, - )), - )(ir_text)?; + ), + ) + .parse(ir_text)?; // The start node is not explicitly specified in the textual IR, so create // it manually. @@ -246,7 +248,7 @@ fn parse_function<'a>( let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0; let (ir_text, return_type) = parse_type_id(ir_text, context)?; - let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context))(ir_text)?; + let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?; // `nodes`, as returned by parsing, is in parse order, which may differ from // the order dictated by NodeIDs in the node name intern map. @@ -383,11 +385,11 @@ fn parse_region<'a>( // explicitly using nom's separated list functionality. This example here // is a bit of an abuse of what parse_tupleN functions are meant for. let (ir_text, (preds,)) = parse_tuple1(nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, ))(ir_text)?; @@ -411,11 +413,11 @@ fn parse_fork<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes let (ir_text, (control, factors)) = parse_tuple2( parse_identifier, nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_dynamic_constant_id(x, context), ), )(ir_text)?; @@ -440,11 +442,11 @@ fn parse_phi<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResu let (ir_text, (control, data)) = parse_tuple2( parse_identifier, nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, ), )(ir_text)?; @@ -575,29 +577,32 @@ fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes let ir_text = nom::character::complete::char('<')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, dynamic_constants) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_dynamic_constant_id(x, context), - )(ir_text)?; + ) + .parse(ir_text)?; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char('>')(ir_text)?.0; Ok((ir_text, dynamic_constants)) }; - let (ir_text, dynamic_constants) = nom::combinator::opt(parse_dynamic_constants)(ir_text)?; + let (ir_text, dynamic_constants) = + nom::combinator::opt(parse_dynamic_constants).parse(ir_text)?; let dynamic_constants = dynamic_constants.unwrap_or(vec![]); let ir_text = nom::character::complete::char('(')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, mut function_and_args) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, - )(ir_text)?; + ) + .parse(ir_text)?; let function = function_and_args.remove(0); let mut args: Vec<NodeID> = function_and_args .into_iter() @@ -628,13 +633,14 @@ fn parse_intrinsic<'a>( let ir_text = nom::character::complete::char('(')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, mut intrinsic_and_args) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, - )(ir_text)?; + ) + .parse(ir_text)?; let intrinsic = intrinsic_and_args.remove(0); let args: Vec<NodeID> = intrinsic_and_args .into_iter() @@ -661,7 +667,7 @@ fn parse_index<'a>( ) -> nom::IResult<&'a str, Index> { let (ir_text, idx) = nom::branch::alt(( nom::combinator::map( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::bytes::complete::tag("field"), nom::character::complete::multispace0, @@ -671,11 +677,11 @@ fn parse_index<'a>( nom::character::complete::multispace0, nom::character::complete::char(')'), nom::character::complete::multispace0, - )), + ), |(_, _, _, _, _, x, _, _, _)| Index::Field(x), ), nom::combinator::map( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::bytes::complete::tag("variant"), nom::character::complete::multispace0, @@ -685,28 +691,28 @@ fn parse_index<'a>( nom::character::complete::multispace0, nom::character::complete::char(')'), nom::character::complete::multispace0, - )), + ), |(_, _, _, _, _, x, _, _, _)| Index::Variant(x), ), nom::combinator::map( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::bytes::complete::tag("position"), nom::character::complete::multispace0, nom::character::complete::char('('), nom::character::complete::multispace0, nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, ), nom::character::complete::multispace0, nom::character::complete::char(')'), nom::character::complete::multispace0, - )), + ), |(_, _, _, _, _, x, _, _, _)| { Index::Position( x.into_iter() @@ -715,7 +721,8 @@ fn parse_index<'a>( ) }, ), - ))(ir_text)?; + )) + .parse(ir_text)?; Ok((ir_text, idx)) } @@ -744,13 +751,14 @@ fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes let ir_text = nom::character::complete::char(',')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, indices) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_index(x, context), - )(ir_text)?; + ) + .parse(ir_text)?; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0; let collect = context.borrow_mut().get_node_id(collect); @@ -779,13 +787,14 @@ fn parse_write<'a>( let ir_text = nom::character::complete::char(',')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, indices) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_index(x, context), - )(ir_text)?; + ) + .parse(ir_text)?; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0; let collect = context.borrow_mut().get_node_id(collect); @@ -844,48 +853,48 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes nom::combinator::map(nom::bytes::complete::tag("f64"), |_| Type::Float64), // Product types are parsed as a list of their element types. nom::combinator::map( - nom::sequence::tuple(( + ( nom::bytes::complete::tag("prod"), nom::character::complete::multispace0, nom::character::complete::char('('), nom::character::complete::multispace0, nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_type_id(x, context), ), nom::character::complete::multispace0, nom::character::complete::char(')'), - )), + ), |(_, _, _, _, ids, _, _)| Type::Product(ids.into_boxed_slice()), ), // Sum types are parsed as a list of their variant types. nom::combinator::map( - nom::sequence::tuple(( + ( nom::bytes::complete::tag("sum"), nom::character::complete::multispace0, nom::character::complete::char('('), nom::character::complete::multispace0, nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_type_id(x, context), ), nom::character::complete::multispace0, nom::character::complete::char(')'), - )), + ), |(_, _, _, _, ids, _, _)| Type::Summation(ids.into_boxed_slice()), ), // Array types are just a list of an element type and at least one // dynamic constant representing its extent. nom::combinator::map( - nom::sequence::tuple(( + ( nom::bytes::complete::tag("array"), nom::character::complete::multispace0, nom::character::complete::char('('), @@ -895,21 +904,22 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes nom::character::complete::char(','), nom::character::complete::multispace0, nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |x| parse_dynamic_constant_id(x, context), ), nom::character::complete::multispace0, nom::character::complete::char(')'), - )), + ), |(_, _, _, _, ty_id, _, _, _, dc_ids, _, _)| { Type::Array(ty_id, dc_ids.into_boxed_slice()) }, ), - ))(ir_text)?; + )) + .parse(ir_text)?; Ok((ir_text, ty)) } @@ -940,20 +950,20 @@ fn parse_dynamic_constant<'a>( // Parameter dynamic constants of a function are written by preprending // a '#' to the parameter's number. nom::combinator::map( - nom::sequence::tuple((nom::character::complete::char('#'), |x| { + (nom::character::complete::char('#'), |x| { parse_prim::<usize>(x, "1234567890") - })), + }), |(_, x)| DynamicConstant::Parameter(x), ), // Dynamic constant math is written using a prefix function nom::combinator::map( - nom::sequence::tuple(( + ( nom::character::complete::one_of("+-*/%"), parse_tuple2( |x| parse_dynamic_constant_id(x, context), |x| parse_dynamic_constant_id(x, context), ), - )), + ), |(op, (x, y))| match op { '+' => DynamicConstant::Add(vec![x, y]), '-' => DynamicConstant::Sub(x, y), @@ -963,7 +973,8 @@ fn parse_dynamic_constant<'a>( _ => panic!("Invalid parse"), }, ), - ))(ir_text)?; + )) + .parse(ir_text)?; Ok((ir_text, dc)) } @@ -1041,7 +1052,8 @@ fn parse_boolean<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Constant> { let (ir_text, val) = nom::branch::alt(( nom::combinator::map(nom::bytes::complete::tag("false"), |_| false), nom::combinator::map(nom::bytes::complete::tag("true"), |_| true), - ))(ir_text)?; + )) + .parse(ir_text)?; Ok((ir_text, Constant::Boolean(val))) } @@ -1178,7 +1190,8 @@ fn parse_identifier<'a>(ir_text: &'a str) -> nom::IResult<&'a str, &'a str> { "1234567890_@ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", ), |s: &str| s.len() > 0, - )(ir_text) + ) + .parse(ir_text) } /* @@ -1186,7 +1199,7 @@ fn parse_identifier<'a>(ir_text: &'a str) -> nom::IResult<&'a str, &'a str> { */ fn parse_tuple1<'a, A, AF>(mut parse_a: AF) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A,)> where - AF: nom::Parser<&'a str, A, nom::error::Error<&'a str>>, + AF: nom::Parser<&'a str, Output = A, Error = nom::error::Error<&'a str>>, { move |ir_text: &'a str| { let ir_text = nom::character::complete::multispace0(ir_text)?.0; @@ -1204,8 +1217,8 @@ fn parse_tuple2<'a, A, B, AF, BF>( mut parse_b: BF, ) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A, B)> where - AF: nom::Parser<&'a str, A, nom::error::Error<&'a str>>, - BF: nom::Parser<&'a str, B, nom::error::Error<&'a str>>, + AF: nom::Parser<&'a str, Output = A, Error = nom::error::Error<&'a str>>, + BF: nom::Parser<&'a str, Output = B, Error = nom::error::Error<&'a str>>, { move |ir_text: &'a str| { let ir_text = nom::character::complete::multispace0(ir_text)?.0; @@ -1228,9 +1241,9 @@ fn parse_tuple3<'a, A, B, C, AF, BF, CF>( mut parse_c: CF, ) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A, B, C)> where - AF: nom::Parser<&'a str, A, nom::error::Error<&'a str>>, - BF: nom::Parser<&'a str, B, nom::error::Error<&'a str>>, - CF: nom::Parser<&'a str, C, nom::error::Error<&'a str>>, + AF: nom::Parser<&'a str, Output = A, Error = nom::error::Error<&'a str>>, + BF: nom::Parser<&'a str, Output = B, Error = nom::error::Error<&'a str>>, + CF: nom::Parser<&'a str, Output = C, Error = nom::error::Error<&'a str>>, { move |ir_text: &'a str| { let ir_text = nom::character::complete::multispace0(ir_text)?.0; diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 419a760fa49647c28295625bd7df9db0e87707c1..090a38a02cbbcd46253452f76a0b71681363c833 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -1,6 +1,6 @@ #![feature(once_cell_try)] -use std::alloc::{alloc, dealloc, Layout}; +use std::alloc::{alloc, dealloc, GlobalAlloc, Layout, System}; use std::marker::PhantomData; use std::ptr::{copy_nonoverlapping, write_bytes, NonNull}; use std::slice::{from_raw_parts, from_raw_parts_mut}; @@ -867,3 +867,24 @@ impl<'a, T> HerculesRefInto<'a> for Box<[T]> { HerculesCPURef::from_slice(self) } } + +/* + * We need all allocations to be aligned to LARGEST_ALIGNMENT bytes for + * vectorization. This is the easiest way to do that. + */ +pub struct AlignedAlloc; + +unsafe impl GlobalAlloc for AlignedAlloc { + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + let layout = layout.align_to(LARGEST_ALIGNMENT).unwrap(); + System.alloc(layout) + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + let layout = layout.align_to(LARGEST_ALIGNMENT).unwrap(); + System.dealloc(ptr, layout) + } +} + +#[global_allocator] +static A: AlignedAlloc = AlignedAlloc; diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 7f5b453ab426f1ce0ab220682ce6be89bf851305..1f28cee28241827277f8836e963a0d80edeb5abc 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -2,36 +2,20 @@ #[cfg(feature = "cuda")] use hercules_rt::CUDABox; -use hercules_rt::{runner, HerculesCPURef}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; juno_build::juno!("dot"); fn main() { async_std::task::block_on(async { - #[cfg(not(feature = "cuda"))] - { - let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; - let a = HerculesCPURef::from_slice(&a); - let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; - let b = HerculesCPURef::from_slice(&b); - let mut r = runner!(dot); - let c = r.run(8, a, b).await; - println!("{}", c); - assert_eq!(c, 70.0); - } - #[cfg(feature = "cuda")] - { - let mut a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; - let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); - let a = a_box.get_ref(); - let mut b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; - let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); - let b = b_box.get_ref(); - let mut r = runner!(dot); - let c = r.run(8, a, b).await; - println!("{}", c); - assert_eq!(c, 70.0); - } + let a: Box<[f32; 8]> = Box::new([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]); + let b: Box<[f32; 8]> = Box::new([0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]); + let a = HerculesImmBox::from(a.as_ref() as &[f32]); + let b = HerculesImmBox::from(b.as_ref() as &[f32]); + let mut r = runner!(dot); + let c = r.run(8, a.to(), b.to()).await; + println!("{}", c); + assert_eq!(c, 70.0); }); } diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 5c87991569ab59e8d978f17d74f1d18423679669..277276648e905186bfeb54714fb00f7275f17b22 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -1,20 +1,19 @@ #![feature(concat_idents)] +use std::iter::zip; use rand::random; -#[cfg(feature = "cuda")] -use hercules_rt::CUDABox; -use hercules_rt::{runner, HerculesCPURef}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; juno_build::juno!("matmul"); fn main() { async_std::task::block_on(async { const I: usize = 256; - const J: usize = 8; // hardcoded constant in matmul.hir + const J: usize = 64; const K: usize = 128; - let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); + let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { @@ -23,26 +22,11 @@ fn main() { } } } - #[cfg(not(feature = "cuda"))] - { - let a = HerculesCPURef::from_slice(&mut a); - let b = HerculesCPURef::from_slice(&mut b); - let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a, b).await; - assert_eq!(c.as_slice::<i32>(), &*correct_c); - } - #[cfg(feature = "cuda")] - { - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); - let mut r = runner!(matmul); - let c = r - .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) - .await; - let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); - c.to_cpu_ref(&mut c_cpu); - assert_eq!(&*c_cpu, &*correct_c); - } + let a = HerculesImmBox::from(a.as_ref()); + let b = HerculesImmBox::from(b.as_ref()); + let mut r = runner!(matmul); + let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + assert_eq!(c.as_slice(), correct_c.as_ref()); }); } diff --git a/hercules_samples/matmul/src/matmul.hir b/hercules_samples/matmul/src/matmul.hir index f9d37afcbb9e74038aafd35dfacf91f533bc4b99..b0c31da4b32207bf3308c4b03583bc74c61f3737 100644 --- a/hercules_samples/matmul/src/matmul.hir +++ b/hercules_samples/matmul/src/matmul.hir @@ -1,9 +1,9 @@ -fn matmul<3>(a: array(i32, #0, 8), b: array(i32, 8, #2)) -> array(i32, #0, #2) +fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2) c = constant(array(i32, #0, #2), []) i_j_ctrl = fork(start, #0, #2) i_idx = thread_id(i_j_ctrl, 0) j_idx = thread_id(i_j_ctrl, 1) - k_ctrl = fork(i_j_ctrl, 8) + k_ctrl = fork(i_j_ctrl, #1) k_idx = thread_id(k_ctrl, 0) k_join_ctrl = join(k_ctrl) i_j_join_ctrl = join(k_join_ctrl) diff --git a/juno_samples/cava/Cargo.toml b/juno_samples/cava/Cargo.toml index 63b6b2ac98dcc022a45f3c4084930cbfb24956ff..17e9a1d330ae08c0ae36dd2474ac29aa9ce76d7a 100644 --- a/juno_samples/cava/Cargo.toml +++ b/juno_samples/cava/Cargo.toml @@ -8,6 +8,9 @@ edition = "2021" name = "juno_cava" path = "src/main.rs" +[lib] +path = "src/lib.rs" + [features] cuda = ["juno_build/cuda", "hercules_rt/cuda"] @@ -21,3 +24,10 @@ async-std = "*" clap = { version = "*", features = ["derive"] } image = "*" with_builtin_macros = "0.1.0" + +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "cava_bench" +harness = false \ No newline at end of file diff --git a/juno_samples/cava/benches/cava_bench.rs b/juno_samples/cava/benches/cava_bench.rs new file mode 100644 index 0000000000000000000000000000000000000000..b8dd3ce26901e94872a5d1b28f1494c5bdfb8a77 --- /dev/null +++ b/juno_samples/cava/benches/cava_bench.rs @@ -0,0 +1,77 @@ +#![feature(concat_idents)] + +use criterion::{criterion_group, criterion_main, Criterion}; + +use hercules_rt::{runner, HerculesImmBoxTo}; + +use juno_cava::*; + +juno_build::juno!("cava"); + +fn cava_bench(c: &mut Criterion) { + let mut group = c.benchmark_group("cava bench"); + group.sample_size(10); + + let args = CavaInputs { + input: "examples/raw_tulip-small.bin".to_string(), + output: None, + verify: false, + output_verify: None, + cam_model: "cam_models/NikonD7000".to_string(), + crop_rows: Some(144), + crop_cols: Some(192), + }; + let (raw_image, cam_model) = make_raw_image_and_cam_model(&args); + let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) = + prepare_hercules_inputs(&raw_image, &cam_model); + let mut r = runner!(cava); + + group.bench_function("cava bench small", |b| { + b.iter(|| { + async_std::task::block_on(r.run( + rows as u64, + cols as u64, + num_ctrl_pts as u64, + image.to(), + tstw.to(), + ctrl_pts.to(), + weights.to(), + coefs.to(), + tonemap.to(), + )); + }) + }); + + let args = CavaInputs { + input: "examples/raw_tulips.bin".to_string(), + output: None, + verify: true, + output_verify: None, + cam_model: "cam_models/NikonD7000".to_string(), + crop_rows: None, + crop_cols: None, + }; + let (raw_image, cam_model) = make_raw_image_and_cam_model(&args); + let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) = + prepare_hercules_inputs(&raw_image, &cam_model); + let mut r = runner!(cava); + + group.bench_function("cava bench full", |b| { + b.iter(|| { + async_std::task::block_on(r.run( + rows as u64, + cols as u64, + num_ctrl_pts as u64, + image.to(), + tstw.to(), + ctrl_pts.to(), + weights.to(), + coefs.to(), + tonemap.to(), + )); + }) + }); +} + +criterion_group!(benches, cava_bench); +criterion_main!(benches); diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 8158bf0a9bec25d57e1136bbf040ae0f561dcc3f..dbe799f9f23e63ed40157c9d57f5c4c8d9b4eb23 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -40,7 +40,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { @res2 let res : f32[CHAN, row, col]; - for r = 1 to row-1 { + @loop for r = 1 to row-1 { for c = 1 to col-1 { if r % 2 == 0 && c % 2 == 0 { let R1 = input[0, r, c-1]; @@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( let v = v1 * v1 + v2 * v2 + v3 * v3; l2_dist[cp] = sqrt!::<f32>(v); } - + @channel_loop for chan = 0 to CHAN { let chan_val : f32 = 0.0; for cp = 0 to num_ctrl_pts { diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 6cd33a3b1054069264ba6406ba37f2b11031f095..3ae1c6bf8971fb91eae339c5ced475bd9ffd2265 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -46,6 +46,9 @@ fixpoint { } predication(fuse1); simpl!(fuse1); +write-predication(fuse1); +simpl!(fuse1); +parallel-reduce(fuse1@loop); inline(fuse2); no-memset(fuse2@res); diff --git a/juno_samples/cava/src/gpu.sch b/juno_samples/cava/src/gpu.sch index f440dacde5d4dc0f9d42599e19923ccba91d82ab..c8db124ede9b98220866a8c1cdc7b17cdfb8093f 100644 --- a/juno_samples/cava/src/gpu.sch +++ b/juno_samples/cava/src/gpu.sch @@ -1,23 +1,138 @@ -gvn(*); -phi-elim(*); -dce(*); +macro simpl!(X) { + ccp(X); + simplify-cfg(X); + lift-dc-math(X); + gvn(X); + phi-elim(X); + dce(X); + infer-schedules(X); +} + +simpl!(*); + +let fuse1 = outline(cava@fuse1); +inline(fuse1); +gpu(fuse1); + +let fuse2 = outline(cava@fuse2); +inline(fuse2); +gpu(fuse2); + +let fuse3 = outline(cava@fuse3); +inline(fuse3); +gpu(fuse3); -inline(denoise); -gpu(scale, demosaic, denoise, transform, gamut, tone_map, descale); +let fuse4 = outline(cava@fuse4); +inline(fuse4); +gpu(fuse4); + +let fuse5 = outline(cava@fuse5); +inline(fuse5); +gpu(fuse5); ip-sroa(*); sroa(*); -dce(*); -gvn(*); -phi-elim(*); -dce(*); +simpl!(*); -// forkify(*); -infer-schedules(*); +no-memset(fuse1@res1); +no-memset(fuse1@res2); +fixpoint { + forkify(fuse1); + fork-guard-elim(fuse1); + fork-coalesce(fuse1); +} +simpl!(fuse1); +array-slf(fuse1); +loop-bound-canon(fuse1); +fixpoint { + forkify(fuse1); + fork-guard-elim(fuse1); + fork-coalesce(fuse1); +} +predication(fuse1); +simpl!(fuse1); +write-predication(fuse1); +simpl!(fuse1); +parallel-reduce(fuse1@loop); -gcm(*); +inline(fuse2); +no-memset(fuse2@res); +no-memset(fuse2@filter); +no-memset(fuse2@tmp); +fixpoint { + forkify(fuse2); + fork-guard-elim(fuse2); + fork-coalesce(fuse2); +} +simpl!(fuse2); +predication(fuse2); +simpl!(fuse2); + +let median = outline(fuse2@median); +fork-unroll(median@medianOuter); +simpl!(median); +fixpoint { + forkify(median); + fork-guard-elim(median); +} +simpl!(median); fixpoint { - float-collections(*); - dce(*); - gcm(*); + fork-unroll(median); } +ccp(median); +array-to-product(median); +sroa(median); +phi-elim(median); +predication(median); +simpl!(median); + +inline(fuse2); +ip-sroa(*); +sroa(*); +array-slf(fuse2); +write-predication(fuse2); +simpl!(fuse2); + +no-memset(fuse3@res); +fixpoint { + forkify(fuse3); + fork-guard-elim(fuse3); + fork-coalesce(fuse3); +} +simpl!(fuse3); + +no-memset(fuse4@res); +no-memset(fuse4@l2); +fixpoint { + forkify(fuse4); + fork-guard-elim(fuse4); + fork-coalesce(fuse4); +} +simpl!(fuse4); +fork-unroll(fuse4@channel_loop); +simpl!(fuse4); +fixpoint { + fork-fusion(fuse4@channel_loop); +} +simpl!(fuse4); +array-slf(fuse4); +simpl!(fuse4); +//fork-tile[2, 0, false, true](fuse4@channel_loop); +//fork-split(fuse4@channel_loop); +//clean-monoid-reduces(fuse4); + +no-memset(fuse5@res1); +no-memset(fuse5@res2); +fixpoint { + forkify(fuse5); + fork-guard-elim(fuse5); + fork-coalesce(fuse5); +} +simpl!(fuse5); +array-slf(fuse5); +simpl!(fuse5); + +delete-uncalled(*); +simpl!(*); + +gcm(*); diff --git a/juno_samples/cava/src/lib.rs b/juno_samples/cava/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..1810a24670d8fa7325b2ff353ba70756824a0269 --- /dev/null +++ b/juno_samples/cava/src/lib.rs @@ -0,0 +1,198 @@ +#![feature(concat_idents)] + +mod camera_model; +mod cava_rust; +mod image_proc; + +pub use self::camera_model::*; +pub use self::cava_rust::CHAN; +pub use self::image_proc::*; + +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; + +use image::ImageError; + +use clap::Parser; + +juno_build::juno!("cava"); + +pub fn make_raw_image_and_cam_model(args: &CavaInputs) -> (RawImage, CamModel) { + let raw_image = + read_raw(&args.input, args.crop_rows, args.crop_cols).expect("Error loading image"); + let cam_model = load_cam_model(&args.cam_model, CHAN).expect("Error loading camera model"); + println!( + "Running cava with {} rows, {} columns, and {} control points.", + raw_image.rows, raw_image.cols, cam_model.num_ctrl_pts + ); + (raw_image, cam_model) +} + +pub fn prepare_hercules_inputs<'a, 'b>( + raw_image: &'a RawImage, + cam_model: &'b CamModel, +) -> ( + usize, + usize, + usize, + HerculesImmBox<'a, u8>, + HerculesImmBox<'b, f32>, + HerculesImmBox<'b, f32>, + HerculesImmBox<'b, f32>, + HerculesImmBox<'b, f32>, + HerculesImmBox<'b, f32>, +) { + assert_eq!( + raw_image.pixels.len(), + CHAN * raw_image.rows * raw_image.cols + ); + assert_eq!(cam_model.tstw.len(), CHAN * CHAN); + assert_eq!(cam_model.ctrl_pts.len(), cam_model.num_ctrl_pts * CHAN); + assert_eq!(cam_model.weights.len(), cam_model.num_ctrl_pts * CHAN); + assert_eq!(cam_model.coefs.len(), 4 * CHAN); + assert_eq!(cam_model.tonemap.len(), 256 * CHAN); + + let image = HerculesImmBox::from(&raw_image.pixels as &[u8]); + let tstw = HerculesImmBox::from(&cam_model.tstw as &[f32]); + let ctrl_pts = HerculesImmBox::from(&cam_model.ctrl_pts as &[f32]); + let weights = HerculesImmBox::from(&cam_model.weights as &[f32]); + let coefs = HerculesImmBox::from(&cam_model.coefs as &[f32]); + let tonemap = HerculesImmBox::from(&cam_model.tonemap as &[f32]); + + ( + raw_image.rows, + raw_image.cols, + cam_model.num_ctrl_pts, + image, + tstw, + ctrl_pts, + weights, + coefs, + tonemap, + ) +} + +enum Error { + IOError(std::io::Error), + ImageError(image::ImageError), + ParseIntError(std::num::ParseIntError), + ParseFloatError(std::num::ParseFloatError), +} + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::IOError(err) => write!(f, "IO Error: {:?}", err), + Error::ImageError(err) => write!(f, "Image Error: {:?}", err), + Error::ParseIntError(err) => write!(f, "Parse Error: {:?}", err), + Error::ParseFloatError(err) => write!(f, "Parse Error: {:?}", err), + } + } +} + +impl From<std::io::Error> for Error { + fn from(value: std::io::Error) -> Self { + Error::IOError(value) + } +} + +impl From<ImageError> for Error { + fn from(value: ImageError) -> Self { + Error::ImageError(value) + } +} + +impl From<std::num::ParseIntError> for Error { + fn from(value: std::num::ParseIntError) -> Self { + Error::ParseIntError(value) + } +} + +impl From<std::num::ParseFloatError> for Error { + fn from(value: std::num::ParseFloatError) -> Self { + Error::ParseFloatError(value) + } +} + +#[derive(Parser)] +#[clap(author, version, about, long_about = None)] +pub struct CavaInputs { + pub input: String, + #[clap(short, long, value_name = "PATH")] + pub output: Option<String>, + #[clap(short, long)] + pub verify: bool, + #[clap(long = "output-verify", value_name = "PATH")] + pub output_verify: Option<String>, + pub cam_model: String, + #[clap(short, long)] + pub crop_rows: Option<usize>, + #[clap(short, long)] + pub crop_cols: Option<usize>, +} + +pub fn cava_harness(args: CavaInputs) { + let (raw_image, cam_model) = make_raw_image_and_cam_model(&args); + let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) = + prepare_hercules_inputs(&raw_image, &cam_model); + let mut r = runner!(cava); + + let result = async_std::task::block_on(async { + HerculesMutBox::from( + r.run( + rows as u64, + cols as u64, + num_ctrl_pts as u64, + image.to(), + tstw.to(), + ctrl_pts.to(), + weights.to(), + coefs.to(), + tonemap.to(), + ) + .await, + ) + }) + .as_slice() + .to_vec() + .into_boxed_slice(); + + if let Some(output) = args.output { + extern_image(rows, cols, &*result) + .save(output) + .expect("Error saving image"); + } + + if args.verify { + let cpu_result = cava_rust::cava( + rows, + cols, + num_ctrl_pts, + &raw_image.pixels, + &cam_model.tstw, + &cam_model.ctrl_pts, + &cam_model.weights, + &cam_model.coefs, + &cam_model.tonemap, + ); + + if let Some(output) = args.output_verify { + extern_image(rows, cols, &cpu_result) + .save(output) + .expect("Error saving verification image"); + } + + let max_diff = result + .iter() + .zip(cpu_result.iter()) + .map(|(a, b)| (*a as i16 - *b as i16).abs()) + .max() + .unwrap_or(0); + + assert!( + max_diff <= 3, + "Verification failed: maximum pixel difference of {} exceeds threshold of 3", + max_diff + ); + println!("Verified!"); + } +} diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 142ed70394aeb38d09ce162a30af8664e604a559..6bd84639ddac38300d7b15b82a902508df9fcfb6 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -1,234 +1,6 @@ -#![feature(concat_idents)] - -mod camera_model; -mod cava_rust; -mod image_proc; - -use self::camera_model::*; -use self::cava_rust::CHAN; -use self::image_proc::*; - -use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; - -use image::ImageError; - use clap::Parser; -juno_build::juno!("cava"); - -// Individual lifetimes are not needed in this example but should probably be generated for -// flexibility -async fn safe_run<'a, 'b: 'a, 'c: 'a, 'd: 'a, 'e: 'a, 'f: 'a, 'g: 'a>( - runner: &'a mut HerculesRunner_cava, - r: u64, - c: u64, - num_ctrl_pts: u64, - input: &'b HerculesImmBox<'b, u8>, - tstw: &'c HerculesImmBox<'c, f32>, - ctrl_pts: &'d HerculesImmBox<'d, f32>, - weights: &'e HerculesImmBox<'e, f32>, - coefs: &'f HerculesImmBox<'f, f32>, - tonemap: &'g HerculesImmBox<'g, f32>, -) -> HerculesMutBox<'a, u8> { - HerculesMutBox::from( - runner - .run( - r, - c, - num_ctrl_pts, - input.to(), - tstw.to(), - ctrl_pts.to(), - weights.to(), - coefs.to(), - tonemap.to(), - ) - .await, - ) -} - -fn run_cava( - rows: usize, - cols: usize, - num_ctrl_pts: usize, - image: &[u8], - tstw: &[f32], - ctrl_pts: &[f32], - weights: &[f32], - coefs: &[f32], - tonemap: &[f32], -) -> Box<[u8]> { - assert_eq!(image.len(), CHAN * rows * cols); - assert_eq!(tstw.len(), CHAN * CHAN); - assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN); - assert_eq!(weights.len(), num_ctrl_pts * CHAN); - assert_eq!(coefs.len(), 4 * CHAN); - assert_eq!(tonemap.len(), 256 * CHAN); - - let image = HerculesImmBox::from(image); - let tstw = HerculesImmBox::from(tstw); - let ctrl_pts = HerculesImmBox::from(ctrl_pts); - let weights = HerculesImmBox::from(weights); - let coefs = HerculesImmBox::from(coefs); - let tonemap = HerculesImmBox::from(tonemap); - - let mut r = runner!(cava); - - async_std::task::block_on(async { - safe_run( - &mut r, - rows as u64, - cols as u64, - num_ctrl_pts as u64, - &image, - &tstw, - &ctrl_pts, - &weights, - &coefs, - &tonemap, - ) - .await - }) - .as_slice() - .to_vec() - .into_boxed_slice() -} - -enum Error { - IOError(std::io::Error), - ImageError(image::ImageError), - ParseIntError(std::num::ParseIntError), - ParseFloatError(std::num::ParseFloatError), -} - -impl std::fmt::Debug for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Error::IOError(err) => write!(f, "IO Error: {:?}", err), - Error::ImageError(err) => write!(f, "Image Error: {:?}", err), - Error::ParseIntError(err) => write!(f, "Parse Error: {:?}", err), - Error::ParseFloatError(err) => write!(f, "Parse Error: {:?}", err), - } - } -} - -impl From<std::io::Error> for Error { - fn from(value: std::io::Error) -> Self { - Error::IOError(value) - } -} - -impl From<ImageError> for Error { - fn from(value: ImageError) -> Self { - Error::ImageError(value) - } -} - -impl From<std::num::ParseIntError> for Error { - fn from(value: std::num::ParseIntError) -> Self { - Error::ParseIntError(value) - } -} - -impl From<std::num::ParseFloatError> for Error { - fn from(value: std::num::ParseFloatError) -> Self { - Error::ParseFloatError(value) - } -} - -#[derive(Parser)] -#[clap(author, version, about, long_about = None)] -struct CavaInputs { - input: String, - #[clap(short, long, value_name = "PATH")] - output: Option<String>, - #[clap(short, long)] - verify: bool, - #[clap(long = "output-verify", value_name = "PATH")] - output_verify: Option<String>, - cam_model: String, - #[clap(short, long)] - crop_rows: Option<usize>, - #[clap(short, long)] - crop_cols: Option<usize>, -} - -fn cava_harness(args: CavaInputs) { - let CavaInputs { - input, - output, - verify, - output_verify, - cam_model, - crop_rows, - crop_cols, - } = args; - let RawImage { rows, cols, pixels } = - read_raw(input, crop_rows, crop_cols).expect("Error loading image"); - let CamModel { - tstw, - num_ctrl_pts, - ctrl_pts, - weights, - coefs, - tonemap, - } = load_cam_model(cam_model, CHAN).expect("Error loading camera model"); - - println!( - "Running cava with {} rows, {} columns, and {} control points.", - rows, cols, num_ctrl_pts - ); - let result = run_cava( - rows, - cols, - num_ctrl_pts, - &pixels, - &tstw, - &ctrl_pts, - &weights, - &coefs, - &tonemap, - ); - - if let Some(output) = output { - extern_image(rows, cols, &*result) - .save(output) - .expect("Error saving image"); - } - - if verify { - let cpu_result = cava_rust::cava( - rows, - cols, - num_ctrl_pts, - &pixels, - &tstw, - &ctrl_pts, - &weights, - &coefs, - &tonemap, - ); - - if let Some(output) = output_verify { - extern_image(rows, cols, &cpu_result) - .save(output) - .expect("Error saving verification image"); - } - - let max_diff = result - .iter() - .zip(cpu_result.iter()) - .map(|(a, b)| (*a as i16 - *b as i16).abs()) - .max() - .unwrap_or(0); - - assert!( - max_diff <= 3, - "Verification failed: maximum pixel difference of {} exceeds threshold of 3", - max_diff - ); - } -} +use juno_cava::{cava_harness, CavaInputs}; fn main() { let args = CavaInputs::parse(); diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index 547dee08b118c475e1905d0fe93e9aaebfdca535..2f704f168bf64b619dde32610dd1a12603fc9dd0 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -10,12 +10,12 @@ juno_build::juno!("concat"); fn main() { async_std::task::block_on(async { let mut r = runner!(concat_entry); - let mut a_data = [7, 7, 0]; - let mut b_data = [7, 7, 0, 0, 7, 7]; + let mut a_data = Box::new([7, 7, 0]); + let mut b_data = Box::new([7, 7, 0, 0, 7, 7]); #[cfg(not(feature = "cuda"))] { - let a = HerculesCPURef::from_slice(&mut a_data); - let b = HerculesCPURef::from_slice(&mut b_data); + let a = HerculesCPURef::from_slice(a_data.as_ref()); + let b = HerculesCPURef::from_slice(b_data.as_ref()); let output = r.run(3, 6, a, b).await; assert_eq!(output, 42); @@ -36,10 +36,8 @@ fn main() { } #[cfg(feature = "cuda")] { - let mut a_data = [7, 7, 0]; - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a_data)); - let mut b_data = [7, 7, 0, 0, 7, 7]; - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b_data)); + let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(a_data.as_ref())); + let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(b_data.as_ref())); let output = r.run(3, 6, a.get_ref(), b.get_ref()).await; assert_eq!(output, 42); } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index c0e228daa04704b90156a592d27b761aeba6591c..29415b511992946b08a1496f3eb92d957615d8aa 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -3,9 +3,7 @@ use std::iter::zip; use rand::random; -use hercules_rt::{runner, HerculesRefInto}; -#[cfg(feature = "cuda")] -use hercules_rt::{CUDABox, HerculesCPURef}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; juno_build::juno!("matmul"); @@ -24,24 +22,12 @@ fn main() { } } } - #[cfg(not(feature = "cuda"))] - { - let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; - let c = c.as_slice::<f32>(); - assert_eq!(c, &*correct_c); - } - #[cfg(feature = "cuda")] - { - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a)); - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); - let mut r = runner!(matmul); - let c = r - .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) - .await; - let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice(); - c.to_cpu_ref(&mut c_cpu); - assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001)); + let a = HerculesImmBox::from(a.as_ref()); + let b = HerculesImmBox::from(b.as_ref()); + let mut r = runner!(matmul); + let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) { + assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct); } }); }