Skip to content
Snippets Groups Projects
Commit b935a755 authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'cava_opt_3' into 'main'

Set up cava benchmark

See merge request !190
parents 49a0d4c7 df22cb0e
No related branches found
No related tags found
1 merge request!190Set up cava benchmark
Pipeline #201771 passed
Showing
with 815 additions and 527 deletions
test-cpu: test-cpu:
stage: test stage: test
script: script:
- cargo test --features=opencv -vv - cargo test --features=opencv -vv -- --nocapture
test-gpu: test-gpu:
stage: test stage: test
script: script:
- cargo test --features=cuda,opencv -vv - cargo test --features=cuda,opencv -vv -- --nocapture
...@@ -10,9 +10,9 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" ...@@ -10,9 +10,9 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "0.7.15" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7404febffaa47dac81aa44dba71523c9d069b1bdc50a77db41195149e17f68e5" checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
...@@ -29,6 +29,12 @@ version = "0.2.21" ...@@ -29,6 +29,12 @@ version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.18" version = "0.6.18"
...@@ -102,12 +108,6 @@ dependencies = [ ...@@ -102,12 +108,6 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "arrayvec"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
[[package]] [[package]]
name = "arrayvec" name = "arrayvec"
version = "0.7.6" version = "0.7.6"
...@@ -255,7 +255,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -255,7 +255,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf" checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"arrayvec 0.7.6", "arrayvec",
"log", "log",
"nom 7.1.3", "nom 7.1.3",
"num-rational", "num-rational",
...@@ -268,7 +268,7 @@ version = "0.8.2" ...@@ -268,7 +268,7 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62" checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62"
dependencies = [ dependencies = [
"arrayvec 0.7.6", "arrayvec",
] ]
[[package]] [[package]]
...@@ -310,28 +310,16 @@ version = "2.6.0" ...@@ -310,28 +310,16 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
[[package]]
name = "bitvec"
version = "0.19.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55f93d0ef3363c364d5976646a38f04cf67cfe1d4c8d160cdea02cab2c116b33"
dependencies = [
"funty 1.1.0",
"radium 0.5.3",
"tap",
"wyz 0.2.0",
]
[[package]] [[package]]
name = "bitvec" name = "bitvec"
version = "1.0.1" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [ dependencies = [
"funty 2.0.0", "funty",
"radium 0.7.0", "radium",
"tap", "tap",
"wyz 0.5.1", "wyz",
] ]
[[package]] [[package]]
...@@ -349,9 +337,9 @@ dependencies = [ ...@@ -349,9 +337,9 @@ dependencies = [
[[package]] [[package]]
name = "built" name = "built"
version = "0.7.6" version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73848a43c5d63a1251d17adf6c2bf78aa94830e60a335a95eeea45d6ba9e1e4d" checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b"
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
...@@ -394,11 +382,17 @@ dependencies = [ ...@@ -394,11 +382,17 @@ dependencies = [
"with_builtin_macros", "with_builtin_macros",
] ]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.13" version = "1.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda" checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
...@@ -446,6 +440,33 @@ dependencies = [ ...@@ -446,6 +440,33 @@ dependencies = [
"vob", "vob",
] ]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]] [[package]]
name = "clang" name = "clang"
version = "2.0.0" version = "2.0.0"
...@@ -469,9 +490,9 @@ dependencies = [ ...@@ -469,9 +490,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.29" version = "4.5.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acebd8ad879283633b343856142139f2da2317c96b05b4dd6181c61e2480184" checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
...@@ -479,9 +500,9 @@ dependencies = [ ...@@ -479,9 +500,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.29" version = "4.5.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ba32cbda51c7e1dfd49acc1457ba1a7dec5b64fe360e828acb13ca8dc9c2f9" checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
...@@ -495,7 +516,7 @@ version = "4.5.28" ...@@ -495,7 +516,7 @@ version = "4.5.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed"
dependencies = [ dependencies = [
"heck 0.5.0", "heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.98", "syn 2.0.98",
...@@ -543,6 +564,42 @@ dependencies = [ ...@@ -543,6 +564,42 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]] [[package]]
name = "critical-section" name = "critical-section"
version = "1.2.0" version = "1.2.0"
...@@ -677,9 +734,9 @@ dependencies = [ ...@@ -677,9 +734,9 @@ dependencies = [
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.1" version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]] [[package]]
name = "errno" name = "errno"
...@@ -794,12 +851,6 @@ version = "0.1.4" ...@@ -794,12 +851,6 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]]
name = "funty"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7"
[[package]] [[package]]
name = "funty" name = "funty"
version = "2.0.0" version = "2.0.0"
...@@ -944,12 +995,6 @@ dependencies = [ ...@@ -944,12 +995,6 @@ dependencies = [
"stable_deref_trait", "stable_deref_trait",
] ]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.5.0" version = "0.5.0"
...@@ -960,7 +1005,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" ...@@ -960,7 +1005,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
name = "hercules_cg" name = "hercules_cg"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bitvec 1.0.1", "bitvec",
"hercules_ir", "hercules_ir",
"ordered-float", "ordered-float",
"rand 0.9.0", "rand 0.9.0",
...@@ -971,7 +1016,7 @@ dependencies = [ ...@@ -971,7 +1016,7 @@ dependencies = [
name = "hercules_interpreter" name = "hercules_interpreter"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bitvec 1.0.1", "bitvec",
"clap", "clap",
"derive_more", "derive_more",
"hercules_ir", "hercules_ir",
...@@ -988,9 +1033,9 @@ dependencies = [ ...@@ -988,9 +1033,9 @@ dependencies = [
name = "hercules_ir" name = "hercules_ir"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bitvec 1.0.1", "bitvec",
"either", "either",
"nom 6.2.2", "nom 8.0.0",
"ordered-float", "ordered-float",
"rand 0.9.0", "rand 0.9.0",
"serde", "serde",
...@@ -1001,7 +1046,7 @@ name = "hercules_opt" ...@@ -1001,7 +1046,7 @@ name = "hercules_opt"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bimap", "bimap",
"bitvec 1.0.1", "bitvec",
"egg", "egg",
"either", "either",
"hercules_cg", "hercules_cg",
...@@ -1025,7 +1070,7 @@ version = "0.1.0" ...@@ -1025,7 +1070,7 @@ version = "0.1.0"
name = "hercules_tests" name = "hercules_tests"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bitvec 1.0.1", "bitvec",
"clap", "clap",
"hercules_interpreter", "hercules_interpreter",
"hercules_ir", "hercules_ir",
...@@ -1102,12 +1147,32 @@ dependencies = [ ...@@ -1102,12 +1147,32 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "is-terminal"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37"
dependencies = [
"hermit-abi",
"libc",
"windows-sys",
]
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.1" version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.12.1" version = "0.12.1"
...@@ -1175,7 +1240,7 @@ dependencies = [ ...@@ -1175,7 +1240,7 @@ dependencies = [
"clap", "clap",
"hercules_rt", "hercules_rt",
"juno_build", "juno_build",
"nom 6.2.2", "nom 8.0.0",
"rand 0.9.0", "rand 0.9.0",
"with_builtin_macros", "with_builtin_macros",
] ]
...@@ -1188,7 +1253,7 @@ dependencies = [ ...@@ -1188,7 +1253,7 @@ dependencies = [
"clap", "clap",
"hercules_rt", "hercules_rt",
"juno_build", "juno_build",
"nom 6.2.2", "nom 8.0.0",
"with_builtin_macros", "with_builtin_macros",
] ]
...@@ -1217,6 +1282,7 @@ version = "0.1.0" ...@@ -1217,6 +1282,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"async-std", "async-std",
"clap", "clap",
"criterion",
"hercules_rt", "hercules_rt",
"image", "image",
"juno_build", "juno_build",
...@@ -1231,7 +1297,7 @@ dependencies = [ ...@@ -1231,7 +1297,7 @@ dependencies = [
"clap", "clap",
"hercules_rt", "hercules_rt",
"juno_build", "juno_build",
"nom 6.2.2", "nom 8.0.0",
"with_builtin_macros", "with_builtin_macros",
] ]
...@@ -1414,7 +1480,7 @@ dependencies = [ ...@@ -1414,7 +1480,7 @@ dependencies = [
"clap", "clap",
"hercules_rt", "hercules_rt",
"juno_build", "juno_build",
"nom 6.2.2", "nom 8.0.0",
"with_builtin_macros", "with_builtin_macros",
] ]
...@@ -1446,19 +1512,6 @@ version = "0.5.2" ...@@ -1446,19 +1512,6 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "lexical-core"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6607c62aa161d23d17a9072cc5da0be67cdfc89d3afb1e8d9c842bebc2525ffe"
dependencies = [
"arrayvec 0.5.2",
"bitflags 1.3.2",
"cfg-if",
"ryu",
"static_assertions",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.169" version = "0.2.169"
...@@ -1543,7 +1596,7 @@ dependencies = [ ...@@ -1543,7 +1596,7 @@ dependencies = [
"num-traits", "num-traits",
"quote", "quote",
"regex", "regex",
"regex-syntax 0.8.5", "regex-syntax",
"serde", "serde",
"vergen", "vergen",
] ]
...@@ -1608,9 +1661,9 @@ dependencies = [ ...@@ -1608,9 +1661,9 @@ dependencies = [
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.3.4" version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
...@@ -1648,25 +1701,21 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" ...@@ -1648,25 +1701,21 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]] [[package]]
name = "nom" name = "nom"
version = "6.2.2" version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6a7a9657c84d5814c6196b68bb4429df09c18b1573806259fba397ea4ad0d44" checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [ dependencies = [
"bitvec 0.19.6",
"funty 1.1.0",
"lexical-core",
"memchr", "memchr",
"version_check", "minimal-lexical",
] ]
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "8.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405"
dependencies = [ dependencies = [
"memchr", "memchr",
"minimal-lexical",
] ]
[[package]] [[package]]
...@@ -1746,6 +1795,12 @@ version = "1.20.3" ...@@ -1746,6 +1795,12 @@ version = "1.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]]
name = "oorandom"
version = "11.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
[[package]] [[package]]
name = "opencv" name = "opencv"
version = "0.94.2" version = "0.94.2"
...@@ -1783,9 +1838,9 @@ dependencies = [ ...@@ -1783,9 +1838,9 @@ dependencies = [
[[package]] [[package]]
name = "ordered-float" name = "ordered-float"
version = "4.6.0" version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01"
dependencies = [ dependencies = [
"num-traits", "num-traits",
"rand 0.8.5", "rand 0.8.5",
...@@ -1891,6 +1946,34 @@ version = "0.3.31" ...@@ -1891,6 +1946,34 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]] [[package]]
name = "png" name = "png"
version = "0.17.16" version = "0.17.16"
...@@ -2048,12 +2131,6 @@ dependencies = [ ...@@ -2048,12 +2131,6 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "radium"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "941ba9d78d8e2f7ce474c015eea4d9c6d25b6a3327f9832ee29a4de27f91bbb8"
[[package]] [[package]]
name = "radium" name = "radium"
version = "0.7.0" version = "0.7.0"
...@@ -2079,8 +2156,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2079,8 +2156,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
dependencies = [ dependencies = [
"rand_chacha 0.9.0", "rand_chacha 0.9.0",
"rand_core 0.9.0", "rand_core 0.9.1",
"zerocopy 0.8.17", "zerocopy 0.8.18",
] ]
[[package]] [[package]]
...@@ -2100,7 +2177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2100,7 +2177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [ dependencies = [
"ppv-lite86", "ppv-lite86",
"rand_core 0.9.0", "rand_core 0.9.1",
] ]
[[package]] [[package]]
...@@ -2115,12 +2192,12 @@ dependencies = [ ...@@ -2115,12 +2192,12 @@ dependencies = [
[[package]] [[package]]
name = "rand_core" name = "rand_core"
version = "0.9.0" version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3"
dependencies = [ dependencies = [
"getrandom 0.3.1", "getrandom 0.3.1",
"zerocopy 0.8.17", "zerocopy 0.8.18",
] ]
[[package]] [[package]]
...@@ -2131,7 +2208,7 @@ checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9" ...@@ -2131,7 +2208,7 @@ checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9"
dependencies = [ dependencies = [
"arbitrary", "arbitrary",
"arg_enum_proc_macro", "arg_enum_proc_macro",
"arrayvec 0.7.6", "arrayvec",
"av1-grain", "av1-grain",
"bitstream-io", "bitstream-io",
"built", "built",
...@@ -2213,20 +2290,26 @@ dependencies = [ ...@@ -2213,20 +2290,26 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.4.6" version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a26af418b574bd56588335b3a3659a65725d4e636eb1016c2f9e3b38c7cc759" checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
"regex-syntax 0.6.29", "regex-automata",
"regex-syntax",
] ]
[[package]] [[package]]
name = "regex-syntax" name = "regex-automata"
version = "0.6.29" version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
...@@ -2280,6 +2363,15 @@ version = "1.0.19" ...@@ -2280,6 +2363,15 @@ version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "saturating" name = "saturating"
version = "0.1.0" version = "0.1.0"
...@@ -2318,6 +2410,18 @@ dependencies = [ ...@@ -2318,6 +2410,18 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "serde_json"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]] [[package]]
name = "serde_spanned" name = "serde_spanned"
version = "0.6.8" version = "0.6.8"
...@@ -2374,9 +2478,9 @@ dependencies = [ ...@@ -2374,9 +2478,9 @@ dependencies = [
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.13.2" version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
[[package]] [[package]]
name = "sparsevec" name = "sparsevec"
...@@ -2458,12 +2562,12 @@ dependencies = [ ...@@ -2458,12 +2562,12 @@ dependencies = [
[[package]] [[package]]
name = "system-deps" name = "system-deps"
version = "6.1.1" version = "6.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30c2de8a4d8f4b823d634affc9cd2a74ec98c53a756f317e529a48046cbf71f3" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349"
dependencies = [ dependencies = [
"cfg-expr", "cfg-expr",
"heck 0.4.1", "heck",
"pkg-config", "pkg-config",
"toml", "toml",
"version-compare", "version-compare",
...@@ -2489,9 +2593,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" ...@@ -2489,9 +2593,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.16.0" version = "3.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
...@@ -2565,11 +2669,21 @@ dependencies = [ ...@@ -2565,11 +2669,21 @@ dependencies = [
"time-core", "time-core",
] ]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.7.6" version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148"
dependencies = [ dependencies = [
"serde", "serde",
"serde_spanned", "serde_spanned",
...@@ -2588,9 +2702,9 @@ dependencies = [ ...@@ -2588,9 +2702,9 @@ dependencies = [
[[package]] [[package]]
name = "toml_edit" name = "toml_edit"
version = "0.19.12" version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c500344a19072298cd05a7224b3c0c629348b78692bf48466c5238656e315a78" checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [ dependencies = [
"indexmap", "indexmap",
"serde", "serde",
...@@ -2675,9 +2789,9 @@ dependencies = [ ...@@ -2675,9 +2789,9 @@ dependencies = [
[[package]] [[package]]
name = "version-compare" name = "version-compare"
version = "0.1.1" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "579a42fc0b8e0c63b76519a339be31bed574929511fa53c1a3acae26eb258f29" checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b"
[[package]] [[package]]
name = "version_check" name = "version_check"
...@@ -2696,6 +2810,16 @@ dependencies = [ ...@@ -2696,6 +2810,16 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
...@@ -2814,6 +2938,15 @@ version = "0.4.0" ...@@ -2814,6 +2938,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys",
]
[[package]] [[package]]
name = "winapi-x86_64-pc-windows-gnu" name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0" version = "0.4.0"
...@@ -3022,9 +3155,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" ...@@ -3022,9 +3155,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.4.7" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca0ace3845f0d96209f0375e6d367e3eb87eb65d27d445bdc9f1843a26f39448" checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
...@@ -3058,12 +3191,6 @@ dependencies = [ ...@@ -3058,12 +3191,6 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "wyz"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85e60b0d1b5f99db2556934e21937020776a5d31520bf169e851ac44e6420214"
[[package]] [[package]]
name = "wyz" name = "wyz"
version = "0.5.1" version = "0.5.1"
...@@ -3085,11 +3212,11 @@ dependencies = [ ...@@ -3085,11 +3212,11 @@ dependencies = [
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"
version = "0.8.17" version = "0.8.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa91407dacce3a68c56de03abe2760159582b846c6a4acd2f456618087f12713" checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2"
dependencies = [ dependencies = [
"zerocopy-derive 0.8.17", "zerocopy-derive 0.8.18",
] ]
[[package]] [[package]]
...@@ -3105,9 +3232,9 @@ dependencies = [ ...@@ -3105,9 +3232,9 @@ dependencies = [
[[package]] [[package]]
name = "zerocopy-derive" name = "zerocopy-derive"
version = "0.8.17" version = "0.8.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06718a168365cad3d5ff0bb133aad346959a2074bd4a85c121255a11304a8626" checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
......
...@@ -6,7 +6,7 @@ edition = "2021" ...@@ -6,7 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
rand = "*" rand = "*"
nom = "6.2.2" nom = "*"
ordered-float = { version = "*", features = ["serde"] } ordered-float = { version = "*", features = ["serde"] }
bitvec = "*" bitvec = "*"
serde = { version = "*", features = ["derive"] } serde = { version = "*", features = ["derive"] }
......
...@@ -3,6 +3,8 @@ use std::collections::{HashMap, HashSet}; ...@@ -3,6 +3,8 @@ use std::collections::{HashMap, HashSet};
use std::ops::Deref; use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use nom::Parser;
use crate::*; use crate::*;
/* /*
...@@ -128,9 +130,8 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a ...@@ -128,9 +130,8 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a
// If there is any text left after successfully parsing some functions, // If there is any text left after successfully parsing some functions,
// treat that as an error. // treat that as an error.
let (rest, functions) = let (rest, functions) =
nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &context)))( nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &context)))
ir_text, .parse(ir_text)?;
)?;
let mut context = context.into_inner(); let mut context = context.into_inner();
// Functions, as returned by parsing, is in parse order, which may differ // Functions, as returned by parsing, is in parse order, which may differ
...@@ -215,7 +216,7 @@ fn parse_function<'a>( ...@@ -215,7 +216,7 @@ fn parse_function<'a>(
Ok((ir_text, num_dynamic_constants)) Ok((ir_text, num_dynamic_constants))
}; };
let (ir_text, num_dynamic_constants) = let (ir_text, num_dynamic_constants) =
nom::combinator::opt(parse_num_dynamic_constants)(ir_text)?; nom::combinator::opt(parse_num_dynamic_constants).parse(ir_text)?;
// If unspecified, assumed function has no dynamic constant arguments. // If unspecified, assumed function has no dynamic constant arguments.
let num_dynamic_constants = num_dynamic_constants.unwrap_or(0); let num_dynamic_constants = num_dynamic_constants.unwrap_or(0);
...@@ -223,7 +224,7 @@ fn parse_function<'a>( ...@@ -223,7 +224,7 @@ fn parse_function<'a>(
let ir_text = nom::character::complete::char('(')(ir_text)?.0; let ir_text = nom::character::complete::char('(')(ir_text)?.0;
let (ir_text, params) = nom::multi::separated_list0( let (ir_text, params) = nom::multi::separated_list0(
nom::character::complete::char(','), nom::character::complete::char(','),
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
parse_identifier, parse_identifier,
nom::character::complete::multispace0, nom::character::complete::multispace0,
...@@ -231,8 +232,9 @@ fn parse_function<'a>( ...@@ -231,8 +232,9 @@ fn parse_function<'a>(
nom::character::complete::multispace0, nom::character::complete::multispace0,
|x| parse_type_id(x, context), |x| parse_type_id(x, context),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
)(ir_text)?; )
.parse(ir_text)?;
// The start node is not explicitly specified in the textual IR, so create // The start node is not explicitly specified in the textual IR, so create
// it manually. // it manually.
...@@ -246,7 +248,7 @@ fn parse_function<'a>( ...@@ -246,7 +248,7 @@ fn parse_function<'a>(
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0; let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0;
let (ir_text, return_type) = parse_type_id(ir_text, context)?; let (ir_text, return_type) = parse_type_id(ir_text, context)?;
let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context))(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?;
// `nodes`, as returned by parsing, is in parse order, which may differ from // `nodes`, as returned by parsing, is in parse order, which may differ from
// the order dictated by NodeIDs in the node name intern map. // the order dictated by NodeIDs in the node name intern map.
...@@ -383,11 +385,11 @@ fn parse_region<'a>( ...@@ -383,11 +385,11 @@ fn parse_region<'a>(
// explicitly using nom's separated list functionality. This example here // explicitly using nom's separated list functionality. This example here
// is a bit of an abuse of what parse_tupleN functions are meant for. // is a bit of an abuse of what parse_tupleN functions are meant for.
let (ir_text, (preds,)) = parse_tuple1(nom::multi::separated_list1( let (ir_text, (preds,)) = parse_tuple1(nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
parse_identifier, parse_identifier,
))(ir_text)?; ))(ir_text)?;
...@@ -411,11 +413,11 @@ fn parse_fork<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes ...@@ -411,11 +413,11 @@ fn parse_fork<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes
let (ir_text, (control, factors)) = parse_tuple2( let (ir_text, (control, factors)) = parse_tuple2(
parse_identifier, parse_identifier,
nom::multi::separated_list1( nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_dynamic_constant_id(x, context), |x| parse_dynamic_constant_id(x, context),
), ),
)(ir_text)?; )(ir_text)?;
...@@ -440,11 +442,11 @@ fn parse_phi<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResu ...@@ -440,11 +442,11 @@ fn parse_phi<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResu
let (ir_text, (control, data)) = parse_tuple2( let (ir_text, (control, data)) = parse_tuple2(
parse_identifier, parse_identifier,
nom::multi::separated_list1( nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
parse_identifier, parse_identifier,
), ),
)(ir_text)?; )(ir_text)?;
...@@ -575,29 +577,32 @@ fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes ...@@ -575,29 +577,32 @@ fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes
let ir_text = nom::character::complete::char('<')(ir_text)?.0; let ir_text = nom::character::complete::char('<')(ir_text)?.0;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let (ir_text, dynamic_constants) = nom::multi::separated_list1( let (ir_text, dynamic_constants) = nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_dynamic_constant_id(x, context), |x| parse_dynamic_constant_id(x, context),
)(ir_text)?; )
.parse(ir_text)?;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let ir_text = nom::character::complete::char('>')(ir_text)?.0; let ir_text = nom::character::complete::char('>')(ir_text)?.0;
Ok((ir_text, dynamic_constants)) Ok((ir_text, dynamic_constants))
}; };
let (ir_text, dynamic_constants) = nom::combinator::opt(parse_dynamic_constants)(ir_text)?; let (ir_text, dynamic_constants) =
nom::combinator::opt(parse_dynamic_constants).parse(ir_text)?;
let dynamic_constants = dynamic_constants.unwrap_or(vec![]); let dynamic_constants = dynamic_constants.unwrap_or(vec![]);
let ir_text = nom::character::complete::char('(')(ir_text)?.0; let ir_text = nom::character::complete::char('(')(ir_text)?.0;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let (ir_text, mut function_and_args) = nom::multi::separated_list1( let (ir_text, mut function_and_args) = nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
parse_identifier, parse_identifier,
)(ir_text)?; )
.parse(ir_text)?;
let function = function_and_args.remove(0); let function = function_and_args.remove(0);
let mut args: Vec<NodeID> = function_and_args let mut args: Vec<NodeID> = function_and_args
.into_iter() .into_iter()
...@@ -628,13 +633,14 @@ fn parse_intrinsic<'a>( ...@@ -628,13 +633,14 @@ fn parse_intrinsic<'a>(
let ir_text = nom::character::complete::char('(')(ir_text)?.0; let ir_text = nom::character::complete::char('(')(ir_text)?.0;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let (ir_text, mut intrinsic_and_args) = nom::multi::separated_list1( let (ir_text, mut intrinsic_and_args) = nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
parse_identifier, parse_identifier,
)(ir_text)?; )
.parse(ir_text)?;
let intrinsic = intrinsic_and_args.remove(0); let intrinsic = intrinsic_and_args.remove(0);
let args: Vec<NodeID> = intrinsic_and_args let args: Vec<NodeID> = intrinsic_and_args
.into_iter() .into_iter()
...@@ -661,7 +667,7 @@ fn parse_index<'a>( ...@@ -661,7 +667,7 @@ fn parse_index<'a>(
) -> nom::IResult<&'a str, Index> { ) -> nom::IResult<&'a str, Index> {
let (ir_text, idx) = nom::branch::alt(( let (ir_text, idx) = nom::branch::alt((
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::bytes::complete::tag("field"), nom::bytes::complete::tag("field"),
nom::character::complete::multispace0, nom::character::complete::multispace0,
...@@ -671,11 +677,11 @@ fn parse_index<'a>( ...@@ -671,11 +677,11 @@ fn parse_index<'a>(
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(')'), nom::character::complete::char(')'),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|(_, _, _, _, _, x, _, _, _)| Index::Field(x), |(_, _, _, _, _, x, _, _, _)| Index::Field(x),
), ),
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::bytes::complete::tag("variant"), nom::bytes::complete::tag("variant"),
nom::character::complete::multispace0, nom::character::complete::multispace0,
...@@ -685,28 +691,28 @@ fn parse_index<'a>( ...@@ -685,28 +691,28 @@ fn parse_index<'a>(
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(')'), nom::character::complete::char(')'),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|(_, _, _, _, _, x, _, _, _)| Index::Variant(x), |(_, _, _, _, _, x, _, _, _)| Index::Variant(x),
), ),
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::bytes::complete::tag("position"), nom::bytes::complete::tag("position"),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char('('), nom::character::complete::char('('),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::multi::separated_list1( nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
parse_identifier, parse_identifier,
), ),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(')'), nom::character::complete::char(')'),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|(_, _, _, _, _, x, _, _, _)| { |(_, _, _, _, _, x, _, _, _)| {
Index::Position( Index::Position(
x.into_iter() x.into_iter()
...@@ -715,7 +721,8 @@ fn parse_index<'a>( ...@@ -715,7 +721,8 @@ fn parse_index<'a>(
) )
}, },
), ),
))(ir_text)?; ))
.parse(ir_text)?;
Ok((ir_text, idx)) Ok((ir_text, idx))
} }
...@@ -744,13 +751,14 @@ fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes ...@@ -744,13 +751,14 @@ fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes
let ir_text = nom::character::complete::char(',')(ir_text)?.0; let ir_text = nom::character::complete::char(',')(ir_text)?.0;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let (ir_text, indices) = nom::multi::separated_list1( let (ir_text, indices) = nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_index(x, context), |x| parse_index(x, context),
)(ir_text)?; )
.parse(ir_text)?;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let ir_text = nom::character::complete::char(')')(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0;
let collect = context.borrow_mut().get_node_id(collect); let collect = context.borrow_mut().get_node_id(collect);
...@@ -779,13 +787,14 @@ fn parse_write<'a>( ...@@ -779,13 +787,14 @@ fn parse_write<'a>(
let ir_text = nom::character::complete::char(',')(ir_text)?.0; let ir_text = nom::character::complete::char(',')(ir_text)?.0;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let (ir_text, indices) = nom::multi::separated_list1( let (ir_text, indices) = nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_index(x, context), |x| parse_index(x, context),
)(ir_text)?; )
.parse(ir_text)?;
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let ir_text = nom::character::complete::char(')')(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0;
let collect = context.borrow_mut().get_node_id(collect); let collect = context.borrow_mut().get_node_id(collect);
...@@ -844,48 +853,48 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes ...@@ -844,48 +853,48 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes
nom::combinator::map(nom::bytes::complete::tag("f64"), |_| Type::Float64), nom::combinator::map(nom::bytes::complete::tag("f64"), |_| Type::Float64),
// Product types are parsed as a list of their element types. // Product types are parsed as a list of their element types.
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::bytes::complete::tag("prod"), nom::bytes::complete::tag("prod"),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char('('), nom::character::complete::char('('),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::multi::separated_list1( nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_type_id(x, context), |x| parse_type_id(x, context),
), ),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(')'), nom::character::complete::char(')'),
)), ),
|(_, _, _, _, ids, _, _)| Type::Product(ids.into_boxed_slice()), |(_, _, _, _, ids, _, _)| Type::Product(ids.into_boxed_slice()),
), ),
// Sum types are parsed as a list of their variant types. // Sum types are parsed as a list of their variant types.
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::bytes::complete::tag("sum"), nom::bytes::complete::tag("sum"),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char('('), nom::character::complete::char('('),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::multi::separated_list1( nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_type_id(x, context), |x| parse_type_id(x, context),
), ),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(')'), nom::character::complete::char(')'),
)), ),
|(_, _, _, _, ids, _, _)| Type::Summation(ids.into_boxed_slice()), |(_, _, _, _, ids, _, _)| Type::Summation(ids.into_boxed_slice()),
), ),
// Array types are just a list of an element type and at least one // Array types are just a list of an element type and at least one
// dynamic constant representing its extent. // dynamic constant representing its extent.
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::bytes::complete::tag("array"), nom::bytes::complete::tag("array"),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char('('), nom::character::complete::char('('),
...@@ -895,21 +904,22 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes ...@@ -895,21 +904,22 @@ fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::multi::separated_list1( nom::multi::separated_list1(
nom::sequence::tuple(( (
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(','), nom::character::complete::char(','),
nom::character::complete::multispace0, nom::character::complete::multispace0,
)), ),
|x| parse_dynamic_constant_id(x, context), |x| parse_dynamic_constant_id(x, context),
), ),
nom::character::complete::multispace0, nom::character::complete::multispace0,
nom::character::complete::char(')'), nom::character::complete::char(')'),
)), ),
|(_, _, _, _, ty_id, _, _, _, dc_ids, _, _)| { |(_, _, _, _, ty_id, _, _, _, dc_ids, _, _)| {
Type::Array(ty_id, dc_ids.into_boxed_slice()) Type::Array(ty_id, dc_ids.into_boxed_slice())
}, },
), ),
))(ir_text)?; ))
.parse(ir_text)?;
Ok((ir_text, ty)) Ok((ir_text, ty))
} }
...@@ -940,20 +950,20 @@ fn parse_dynamic_constant<'a>( ...@@ -940,20 +950,20 @@ fn parse_dynamic_constant<'a>(
// Parameter dynamic constants of a function are written by preprending // Parameter dynamic constants of a function are written by preprending
// a '#' to the parameter's number. // a '#' to the parameter's number.
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple((nom::character::complete::char('#'), |x| { (nom::character::complete::char('#'), |x| {
parse_prim::<usize>(x, "1234567890") parse_prim::<usize>(x, "1234567890")
})), }),
|(_, x)| DynamicConstant::Parameter(x), |(_, x)| DynamicConstant::Parameter(x),
), ),
// Dynamic constant math is written using a prefix function // Dynamic constant math is written using a prefix function
nom::combinator::map( nom::combinator::map(
nom::sequence::tuple(( (
nom::character::complete::one_of("+-*/%"), nom::character::complete::one_of("+-*/%"),
parse_tuple2( parse_tuple2(
|x| parse_dynamic_constant_id(x, context), |x| parse_dynamic_constant_id(x, context),
|x| parse_dynamic_constant_id(x, context), |x| parse_dynamic_constant_id(x, context),
), ),
)), ),
|(op, (x, y))| match op { |(op, (x, y))| match op {
'+' => DynamicConstant::Add(vec![x, y]), '+' => DynamicConstant::Add(vec![x, y]),
'-' => DynamicConstant::Sub(x, y), '-' => DynamicConstant::Sub(x, y),
...@@ -963,7 +973,8 @@ fn parse_dynamic_constant<'a>( ...@@ -963,7 +973,8 @@ fn parse_dynamic_constant<'a>(
_ => panic!("Invalid parse"), _ => panic!("Invalid parse"),
}, },
), ),
))(ir_text)?; ))
.parse(ir_text)?;
Ok((ir_text, dc)) Ok((ir_text, dc))
} }
...@@ -1041,7 +1052,8 @@ fn parse_boolean<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Constant> { ...@@ -1041,7 +1052,8 @@ fn parse_boolean<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Constant> {
let (ir_text, val) = nom::branch::alt(( let (ir_text, val) = nom::branch::alt((
nom::combinator::map(nom::bytes::complete::tag("false"), |_| false), nom::combinator::map(nom::bytes::complete::tag("false"), |_| false),
nom::combinator::map(nom::bytes::complete::tag("true"), |_| true), nom::combinator::map(nom::bytes::complete::tag("true"), |_| true),
))(ir_text)?; ))
.parse(ir_text)?;
Ok((ir_text, Constant::Boolean(val))) Ok((ir_text, Constant::Boolean(val)))
} }
...@@ -1178,7 +1190,8 @@ fn parse_identifier<'a>(ir_text: &'a str) -> nom::IResult<&'a str, &'a str> { ...@@ -1178,7 +1190,8 @@ fn parse_identifier<'a>(ir_text: &'a str) -> nom::IResult<&'a str, &'a str> {
"1234567890_@ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", "1234567890_@ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
), ),
|s: &str| s.len() > 0, |s: &str| s.len() > 0,
)(ir_text) )
.parse(ir_text)
} }
/* /*
...@@ -1186,7 +1199,7 @@ fn parse_identifier<'a>(ir_text: &'a str) -> nom::IResult<&'a str, &'a str> { ...@@ -1186,7 +1199,7 @@ fn parse_identifier<'a>(ir_text: &'a str) -> nom::IResult<&'a str, &'a str> {
*/ */
fn parse_tuple1<'a, A, AF>(mut parse_a: AF) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A,)> fn parse_tuple1<'a, A, AF>(mut parse_a: AF) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A,)>
where where
AF: nom::Parser<&'a str, A, nom::error::Error<&'a str>>, AF: nom::Parser<&'a str, Output = A, Error = nom::error::Error<&'a str>>,
{ {
move |ir_text: &'a str| { move |ir_text: &'a str| {
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
...@@ -1204,8 +1217,8 @@ fn parse_tuple2<'a, A, B, AF, BF>( ...@@ -1204,8 +1217,8 @@ fn parse_tuple2<'a, A, B, AF, BF>(
mut parse_b: BF, mut parse_b: BF,
) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A, B)> ) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A, B)>
where where
AF: nom::Parser<&'a str, A, nom::error::Error<&'a str>>, AF: nom::Parser<&'a str, Output = A, Error = nom::error::Error<&'a str>>,
BF: nom::Parser<&'a str, B, nom::error::Error<&'a str>>, BF: nom::Parser<&'a str, Output = B, Error = nom::error::Error<&'a str>>,
{ {
move |ir_text: &'a str| { move |ir_text: &'a str| {
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
...@@ -1228,9 +1241,9 @@ fn parse_tuple3<'a, A, B, C, AF, BF, CF>( ...@@ -1228,9 +1241,9 @@ fn parse_tuple3<'a, A, B, C, AF, BF, CF>(
mut parse_c: CF, mut parse_c: CF,
) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A, B, C)> ) -> impl FnMut(&'a str) -> nom::IResult<&'a str, (A, B, C)>
where where
AF: nom::Parser<&'a str, A, nom::error::Error<&'a str>>, AF: nom::Parser<&'a str, Output = A, Error = nom::error::Error<&'a str>>,
BF: nom::Parser<&'a str, B, nom::error::Error<&'a str>>, BF: nom::Parser<&'a str, Output = B, Error = nom::error::Error<&'a str>>,
CF: nom::Parser<&'a str, C, nom::error::Error<&'a str>>, CF: nom::Parser<&'a str, Output = C, Error = nom::error::Error<&'a str>>,
{ {
move |ir_text: &'a str| { move |ir_text: &'a str| {
let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0;
......
#![feature(once_cell_try)] #![feature(once_cell_try)]
use std::alloc::{alloc, dealloc, Layout}; use std::alloc::{alloc, dealloc, GlobalAlloc, Layout, System};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ptr::{copy_nonoverlapping, write_bytes, NonNull}; use std::ptr::{copy_nonoverlapping, write_bytes, NonNull};
use std::slice::{from_raw_parts, from_raw_parts_mut}; use std::slice::{from_raw_parts, from_raw_parts_mut};
...@@ -867,3 +867,24 @@ impl<'a, T> HerculesRefInto<'a> for Box<[T]> { ...@@ -867,3 +867,24 @@ impl<'a, T> HerculesRefInto<'a> for Box<[T]> {
HerculesCPURef::from_slice(self) HerculesCPURef::from_slice(self)
} }
} }
/*
* We need all allocations to be aligned to LARGEST_ALIGNMENT bytes for
* vectorization. This is the easiest way to do that.
*/
pub struct AlignedAlloc;
unsafe impl GlobalAlloc for AlignedAlloc {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
let layout = layout.align_to(LARGEST_ALIGNMENT).unwrap();
System.alloc(layout)
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
let layout = layout.align_to(LARGEST_ALIGNMENT).unwrap();
System.dealloc(ptr, layout)
}
}
#[global_allocator]
static A: AlignedAlloc = AlignedAlloc;
...@@ -2,36 +2,20 @@ ...@@ -2,36 +2,20 @@
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
use hercules_rt::CUDABox; use hercules_rt::CUDABox;
use hercules_rt::{runner, HerculesCPURef}; use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
juno_build::juno!("dot"); juno_build::juno!("dot");
fn main() { fn main() {
async_std::task::block_on(async { async_std::task::block_on(async {
#[cfg(not(feature = "cuda"))] let a: Box<[f32; 8]> = Box::new([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
{ let b: Box<[f32; 8]> = Box::new([0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; let a = HerculesImmBox::from(a.as_ref() as &[f32]);
let a = HerculesCPURef::from_slice(&a); let b = HerculesImmBox::from(b.as_ref() as &[f32]);
let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; let mut r = runner!(dot);
let b = HerculesCPURef::from_slice(&b); let c = r.run(8, a.to(), b.to()).await;
let mut r = runner!(dot); println!("{}", c);
let c = r.run(8, a, b).await; assert_eq!(c, 70.0);
println!("{}", c);
assert_eq!(c, 70.0);
}
#[cfg(feature = "cuda")]
{
let mut a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
let a = a_box.get_ref();
let mut b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
let b = b_box.get_ref();
let mut r = runner!(dot);
let c = r.run(8, a, b).await;
println!("{}", c);
assert_eq!(c, 70.0);
}
}); });
} }
......
#![feature(concat_idents)] #![feature(concat_idents)]
use std::iter::zip;
use rand::random; use rand::random;
#[cfg(feature = "cuda")] use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
use hercules_rt::CUDABox;
use hercules_rt::{runner, HerculesCPURef};
juno_build::juno!("matmul"); juno_build::juno!("matmul");
fn main() { fn main() {
async_std::task::block_on(async { async_std::task::block_on(async {
const I: usize = 256; const I: usize = 256;
const J: usize = 8; // hardcoded constant in matmul.hir const J: usize = 64;
const K: usize = 128; const K: usize = 128;
let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
for i in 0..I { for i in 0..I {
for k in 0..K { for k in 0..K {
...@@ -23,26 +22,11 @@ fn main() { ...@@ -23,26 +22,11 @@ fn main() {
} }
} }
} }
#[cfg(not(feature = "cuda"))] let a = HerculesImmBox::from(a.as_ref());
{ let b = HerculesImmBox::from(b.as_ref());
let a = HerculesCPURef::from_slice(&mut a); let mut r = runner!(matmul);
let b = HerculesCPURef::from_slice(&mut b); let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
let mut r = runner!(matmul); assert_eq!(c.as_slice(), correct_c.as_ref());
let c = r.run(I as u64, J as u64, K as u64, a, b).await;
assert_eq!(c.as_slice::<i32>(), &*correct_c);
}
#[cfg(feature = "cuda")]
{
let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
let mut r = runner!(matmul);
let c = r
.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
.await;
let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
c.to_cpu_ref(&mut c_cpu);
assert_eq!(&*c_cpu, &*correct_c);
}
}); });
} }
......
fn matmul<3>(a: array(i32, #0, 8), b: array(i32, 8, #2)) -> array(i32, #0, #2) fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2)
c = constant(array(i32, #0, #2), []) c = constant(array(i32, #0, #2), [])
i_j_ctrl = fork(start, #0, #2) i_j_ctrl = fork(start, #0, #2)
i_idx = thread_id(i_j_ctrl, 0) i_idx = thread_id(i_j_ctrl, 0)
j_idx = thread_id(i_j_ctrl, 1) j_idx = thread_id(i_j_ctrl, 1)
k_ctrl = fork(i_j_ctrl, 8) k_ctrl = fork(i_j_ctrl, #1)
k_idx = thread_id(k_ctrl, 0) k_idx = thread_id(k_ctrl, 0)
k_join_ctrl = join(k_ctrl) k_join_ctrl = join(k_ctrl)
i_j_join_ctrl = join(k_join_ctrl) i_j_join_ctrl = join(k_join_ctrl)
......
...@@ -8,6 +8,9 @@ edition = "2021" ...@@ -8,6 +8,9 @@ edition = "2021"
name = "juno_cava" name = "juno_cava"
path = "src/main.rs" path = "src/main.rs"
[lib]
path = "src/lib.rs"
[features] [features]
cuda = ["juno_build/cuda", "hercules_rt/cuda"] cuda = ["juno_build/cuda", "hercules_rt/cuda"]
...@@ -21,3 +24,10 @@ async-std = "*" ...@@ -21,3 +24,10 @@ async-std = "*"
clap = { version = "*", features = ["derive"] } clap = { version = "*", features = ["derive"] }
image = "*" image = "*"
with_builtin_macros = "0.1.0" with_builtin_macros = "0.1.0"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
name = "cava_bench"
harness = false
\ No newline at end of file
#![feature(concat_idents)]
use criterion::{criterion_group, criterion_main, Criterion};
use hercules_rt::{runner, HerculesImmBoxTo};
use juno_cava::*;
juno_build::juno!("cava");
fn cava_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("cava bench");
group.sample_size(10);
let args = CavaInputs {
input: "examples/raw_tulip-small.bin".to_string(),
output: None,
verify: false,
output_verify: None,
cam_model: "cam_models/NikonD7000".to_string(),
crop_rows: Some(144),
crop_cols: Some(192),
};
let (raw_image, cam_model) = make_raw_image_and_cam_model(&args);
let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) =
prepare_hercules_inputs(&raw_image, &cam_model);
let mut r = runner!(cava);
group.bench_function("cava bench small", |b| {
b.iter(|| {
async_std::task::block_on(r.run(
rows as u64,
cols as u64,
num_ctrl_pts as u64,
image.to(),
tstw.to(),
ctrl_pts.to(),
weights.to(),
coefs.to(),
tonemap.to(),
));
})
});
let args = CavaInputs {
input: "examples/raw_tulips.bin".to_string(),
output: None,
verify: true,
output_verify: None,
cam_model: "cam_models/NikonD7000".to_string(),
crop_rows: None,
crop_cols: None,
};
let (raw_image, cam_model) = make_raw_image_and_cam_model(&args);
let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) =
prepare_hercules_inputs(&raw_image, &cam_model);
let mut r = runner!(cava);
group.bench_function("cava bench full", |b| {
b.iter(|| {
async_std::task::block_on(r.run(
rows as u64,
cols as u64,
num_ctrl_pts as u64,
image.to(),
tstw.to(),
ctrl_pts.to(),
weights.to(),
coefs.to(),
tonemap.to(),
));
})
});
}
criterion_group!(benches, cava_bench);
criterion_main!(benches);
...@@ -40,7 +40,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, ...@@ -40,7 +40,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row,
fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
@res2 let res : f32[CHAN, row, col]; @res2 let res : f32[CHAN, row, col];
for r = 1 to row-1 { @loop for r = 1 to row-1 {
for c = 1 to col-1 { for c = 1 to col-1 {
if r % 2 == 0 && c % 2 == 0 { if r % 2 == 0 && c % 2 == 0 {
let R1 = input[0, r, c-1]; let R1 = input[0, r, c-1];
...@@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( ...@@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
let v = v1 * v1 + v2 * v2 + v3 * v3; let v = v1 * v1 + v2 * v2 + v3 * v3;
l2_dist[cp] = sqrt!::<f32>(v); l2_dist[cp] = sqrt!::<f32>(v);
} }
@channel_loop for chan = 0 to CHAN { @channel_loop for chan = 0 to CHAN {
let chan_val : f32 = 0.0; let chan_val : f32 = 0.0;
for cp = 0 to num_ctrl_pts { for cp = 0 to num_ctrl_pts {
......
...@@ -46,6 +46,9 @@ fixpoint { ...@@ -46,6 +46,9 @@ fixpoint {
} }
predication(fuse1); predication(fuse1);
simpl!(fuse1); simpl!(fuse1);
write-predication(fuse1);
simpl!(fuse1);
parallel-reduce(fuse1@loop);
inline(fuse2); inline(fuse2);
no-memset(fuse2@res); no-memset(fuse2@res);
......
gvn(*); macro simpl!(X) {
phi-elim(*); ccp(X);
dce(*); simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
simpl!(*);
let fuse1 = outline(cava@fuse1);
inline(fuse1);
gpu(fuse1);
let fuse2 = outline(cava@fuse2);
inline(fuse2);
gpu(fuse2);
let fuse3 = outline(cava@fuse3);
inline(fuse3);
gpu(fuse3);
inline(denoise); let fuse4 = outline(cava@fuse4);
gpu(scale, demosaic, denoise, transform, gamut, tone_map, descale); inline(fuse4);
gpu(fuse4);
let fuse5 = outline(cava@fuse5);
inline(fuse5);
gpu(fuse5);
ip-sroa(*); ip-sroa(*);
sroa(*); sroa(*);
dce(*); simpl!(*);
gvn(*);
phi-elim(*);
dce(*);
// forkify(*); no-memset(fuse1@res1);
infer-schedules(*); no-memset(fuse1@res2);
fixpoint {
forkify(fuse1);
fork-guard-elim(fuse1);
fork-coalesce(fuse1);
}
simpl!(fuse1);
array-slf(fuse1);
loop-bound-canon(fuse1);
fixpoint {
forkify(fuse1);
fork-guard-elim(fuse1);
fork-coalesce(fuse1);
}
predication(fuse1);
simpl!(fuse1);
write-predication(fuse1);
simpl!(fuse1);
parallel-reduce(fuse1@loop);
gcm(*); inline(fuse2);
no-memset(fuse2@res);
no-memset(fuse2@filter);
no-memset(fuse2@tmp);
fixpoint {
forkify(fuse2);
fork-guard-elim(fuse2);
fork-coalesce(fuse2);
}
simpl!(fuse2);
predication(fuse2);
simpl!(fuse2);
let median = outline(fuse2@median);
fork-unroll(median@medianOuter);
simpl!(median);
fixpoint {
forkify(median);
fork-guard-elim(median);
}
simpl!(median);
fixpoint { fixpoint {
float-collections(*); fork-unroll(median);
dce(*);
gcm(*);
} }
ccp(median);
array-to-product(median);
sroa(median);
phi-elim(median);
predication(median);
simpl!(median);
inline(fuse2);
ip-sroa(*);
sroa(*);
array-slf(fuse2);
write-predication(fuse2);
simpl!(fuse2);
no-memset(fuse3@res);
fixpoint {
forkify(fuse3);
fork-guard-elim(fuse3);
fork-coalesce(fuse3);
}
simpl!(fuse3);
no-memset(fuse4@res);
no-memset(fuse4@l2);
fixpoint {
forkify(fuse4);
fork-guard-elim(fuse4);
fork-coalesce(fuse4);
}
simpl!(fuse4);
fork-unroll(fuse4@channel_loop);
simpl!(fuse4);
fixpoint {
fork-fusion(fuse4@channel_loop);
}
simpl!(fuse4);
array-slf(fuse4);
simpl!(fuse4);
//fork-tile[2, 0, false, true](fuse4@channel_loop);
//fork-split(fuse4@channel_loop);
//clean-monoid-reduces(fuse4);
no-memset(fuse5@res1);
no-memset(fuse5@res2);
fixpoint {
forkify(fuse5);
fork-guard-elim(fuse5);
fork-coalesce(fuse5);
}
simpl!(fuse5);
array-slf(fuse5);
simpl!(fuse5);
delete-uncalled(*);
simpl!(*);
gcm(*);
#![feature(concat_idents)]
mod camera_model;
mod cava_rust;
mod image_proc;
pub use self::camera_model::*;
pub use self::cava_rust::CHAN;
pub use self::image_proc::*;
use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
use image::ImageError;
use clap::Parser;
juno_build::juno!("cava");
pub fn make_raw_image_and_cam_model(args: &CavaInputs) -> (RawImage, CamModel) {
let raw_image =
read_raw(&args.input, args.crop_rows, args.crop_cols).expect("Error loading image");
let cam_model = load_cam_model(&args.cam_model, CHAN).expect("Error loading camera model");
println!(
"Running cava with {} rows, {} columns, and {} control points.",
raw_image.rows, raw_image.cols, cam_model.num_ctrl_pts
);
(raw_image, cam_model)
}
pub fn prepare_hercules_inputs<'a, 'b>(
raw_image: &'a RawImage,
cam_model: &'b CamModel,
) -> (
usize,
usize,
usize,
HerculesImmBox<'a, u8>,
HerculesImmBox<'b, f32>,
HerculesImmBox<'b, f32>,
HerculesImmBox<'b, f32>,
HerculesImmBox<'b, f32>,
HerculesImmBox<'b, f32>,
) {
assert_eq!(
raw_image.pixels.len(),
CHAN * raw_image.rows * raw_image.cols
);
assert_eq!(cam_model.tstw.len(), CHAN * CHAN);
assert_eq!(cam_model.ctrl_pts.len(), cam_model.num_ctrl_pts * CHAN);
assert_eq!(cam_model.weights.len(), cam_model.num_ctrl_pts * CHAN);
assert_eq!(cam_model.coefs.len(), 4 * CHAN);
assert_eq!(cam_model.tonemap.len(), 256 * CHAN);
let image = HerculesImmBox::from(&raw_image.pixels as &[u8]);
let tstw = HerculesImmBox::from(&cam_model.tstw as &[f32]);
let ctrl_pts = HerculesImmBox::from(&cam_model.ctrl_pts as &[f32]);
let weights = HerculesImmBox::from(&cam_model.weights as &[f32]);
let coefs = HerculesImmBox::from(&cam_model.coefs as &[f32]);
let tonemap = HerculesImmBox::from(&cam_model.tonemap as &[f32]);
(
raw_image.rows,
raw_image.cols,
cam_model.num_ctrl_pts,
image,
tstw,
ctrl_pts,
weights,
coefs,
tonemap,
)
}
enum Error {
IOError(std::io::Error),
ImageError(image::ImageError),
ParseIntError(std::num::ParseIntError),
ParseFloatError(std::num::ParseFloatError),
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::IOError(err) => write!(f, "IO Error: {:?}", err),
Error::ImageError(err) => write!(f, "Image Error: {:?}", err),
Error::ParseIntError(err) => write!(f, "Parse Error: {:?}", err),
Error::ParseFloatError(err) => write!(f, "Parse Error: {:?}", err),
}
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::IOError(value)
}
}
impl From<ImageError> for Error {
fn from(value: ImageError) -> Self {
Error::ImageError(value)
}
}
impl From<std::num::ParseIntError> for Error {
fn from(value: std::num::ParseIntError) -> Self {
Error::ParseIntError(value)
}
}
impl From<std::num::ParseFloatError> for Error {
fn from(value: std::num::ParseFloatError) -> Self {
Error::ParseFloatError(value)
}
}
#[derive(Parser)]
#[clap(author, version, about, long_about = None)]
pub struct CavaInputs {
pub input: String,
#[clap(short, long, value_name = "PATH")]
pub output: Option<String>,
#[clap(short, long)]
pub verify: bool,
#[clap(long = "output-verify", value_name = "PATH")]
pub output_verify: Option<String>,
pub cam_model: String,
#[clap(short, long)]
pub crop_rows: Option<usize>,
#[clap(short, long)]
pub crop_cols: Option<usize>,
}
pub fn cava_harness(args: CavaInputs) {
let (raw_image, cam_model) = make_raw_image_and_cam_model(&args);
let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) =
prepare_hercules_inputs(&raw_image, &cam_model);
let mut r = runner!(cava);
let result = async_std::task::block_on(async {
HerculesMutBox::from(
r.run(
rows as u64,
cols as u64,
num_ctrl_pts as u64,
image.to(),
tstw.to(),
ctrl_pts.to(),
weights.to(),
coefs.to(),
tonemap.to(),
)
.await,
)
})
.as_slice()
.to_vec()
.into_boxed_slice();
if let Some(output) = args.output {
extern_image(rows, cols, &*result)
.save(output)
.expect("Error saving image");
}
if args.verify {
let cpu_result = cava_rust::cava(
rows,
cols,
num_ctrl_pts,
&raw_image.pixels,
&cam_model.tstw,
&cam_model.ctrl_pts,
&cam_model.weights,
&cam_model.coefs,
&cam_model.tonemap,
);
if let Some(output) = args.output_verify {
extern_image(rows, cols, &cpu_result)
.save(output)
.expect("Error saving verification image");
}
let max_diff = result
.iter()
.zip(cpu_result.iter())
.map(|(a, b)| (*a as i16 - *b as i16).abs())
.max()
.unwrap_or(0);
assert!(
max_diff <= 3,
"Verification failed: maximum pixel difference of {} exceeds threshold of 3",
max_diff
);
println!("Verified!");
}
}
#![feature(concat_idents)]
mod camera_model;
mod cava_rust;
mod image_proc;
use self::camera_model::*;
use self::cava_rust::CHAN;
use self::image_proc::*;
use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
use image::ImageError;
use clap::Parser; use clap::Parser;
juno_build::juno!("cava"); use juno_cava::{cava_harness, CavaInputs};
// Individual lifetimes are not needed in this example but should probably be generated for
// flexibility
async fn safe_run<'a, 'b: 'a, 'c: 'a, 'd: 'a, 'e: 'a, 'f: 'a, 'g: 'a>(
runner: &'a mut HerculesRunner_cava,
r: u64,
c: u64,
num_ctrl_pts: u64,
input: &'b HerculesImmBox<'b, u8>,
tstw: &'c HerculesImmBox<'c, f32>,
ctrl_pts: &'d HerculesImmBox<'d, f32>,
weights: &'e HerculesImmBox<'e, f32>,
coefs: &'f HerculesImmBox<'f, f32>,
tonemap: &'g HerculesImmBox<'g, f32>,
) -> HerculesMutBox<'a, u8> {
HerculesMutBox::from(
runner
.run(
r,
c,
num_ctrl_pts,
input.to(),
tstw.to(),
ctrl_pts.to(),
weights.to(),
coefs.to(),
tonemap.to(),
)
.await,
)
}
fn run_cava(
rows: usize,
cols: usize,
num_ctrl_pts: usize,
image: &[u8],
tstw: &[f32],
ctrl_pts: &[f32],
weights: &[f32],
coefs: &[f32],
tonemap: &[f32],
) -> Box<[u8]> {
assert_eq!(image.len(), CHAN * rows * cols);
assert_eq!(tstw.len(), CHAN * CHAN);
assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN);
assert_eq!(weights.len(), num_ctrl_pts * CHAN);
assert_eq!(coefs.len(), 4 * CHAN);
assert_eq!(tonemap.len(), 256 * CHAN);
let image = HerculesImmBox::from(image);
let tstw = HerculesImmBox::from(tstw);
let ctrl_pts = HerculesImmBox::from(ctrl_pts);
let weights = HerculesImmBox::from(weights);
let coefs = HerculesImmBox::from(coefs);
let tonemap = HerculesImmBox::from(tonemap);
let mut r = runner!(cava);
async_std::task::block_on(async {
safe_run(
&mut r,
rows as u64,
cols as u64,
num_ctrl_pts as u64,
&image,
&tstw,
&ctrl_pts,
&weights,
&coefs,
&tonemap,
)
.await
})
.as_slice()
.to_vec()
.into_boxed_slice()
}
enum Error {
IOError(std::io::Error),
ImageError(image::ImageError),
ParseIntError(std::num::ParseIntError),
ParseFloatError(std::num::ParseFloatError),
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::IOError(err) => write!(f, "IO Error: {:?}", err),
Error::ImageError(err) => write!(f, "Image Error: {:?}", err),
Error::ParseIntError(err) => write!(f, "Parse Error: {:?}", err),
Error::ParseFloatError(err) => write!(f, "Parse Error: {:?}", err),
}
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::IOError(value)
}
}
impl From<ImageError> for Error {
fn from(value: ImageError) -> Self {
Error::ImageError(value)
}
}
impl From<std::num::ParseIntError> for Error {
fn from(value: std::num::ParseIntError) -> Self {
Error::ParseIntError(value)
}
}
impl From<std::num::ParseFloatError> for Error {
fn from(value: std::num::ParseFloatError) -> Self {
Error::ParseFloatError(value)
}
}
#[derive(Parser)]
#[clap(author, version, about, long_about = None)]
struct CavaInputs {
input: String,
#[clap(short, long, value_name = "PATH")]
output: Option<String>,
#[clap(short, long)]
verify: bool,
#[clap(long = "output-verify", value_name = "PATH")]
output_verify: Option<String>,
cam_model: String,
#[clap(short, long)]
crop_rows: Option<usize>,
#[clap(short, long)]
crop_cols: Option<usize>,
}
fn cava_harness(args: CavaInputs) {
let CavaInputs {
input,
output,
verify,
output_verify,
cam_model,
crop_rows,
crop_cols,
} = args;
let RawImage { rows, cols, pixels } =
read_raw(input, crop_rows, crop_cols).expect("Error loading image");
let CamModel {
tstw,
num_ctrl_pts,
ctrl_pts,
weights,
coefs,
tonemap,
} = load_cam_model(cam_model, CHAN).expect("Error loading camera model");
println!(
"Running cava with {} rows, {} columns, and {} control points.",
rows, cols, num_ctrl_pts
);
let result = run_cava(
rows,
cols,
num_ctrl_pts,
&pixels,
&tstw,
&ctrl_pts,
&weights,
&coefs,
&tonemap,
);
if let Some(output) = output {
extern_image(rows, cols, &*result)
.save(output)
.expect("Error saving image");
}
if verify {
let cpu_result = cava_rust::cava(
rows,
cols,
num_ctrl_pts,
&pixels,
&tstw,
&ctrl_pts,
&weights,
&coefs,
&tonemap,
);
if let Some(output) = output_verify {
extern_image(rows, cols, &cpu_result)
.save(output)
.expect("Error saving verification image");
}
let max_diff = result
.iter()
.zip(cpu_result.iter())
.map(|(a, b)| (*a as i16 - *b as i16).abs())
.max()
.unwrap_or(0);
assert!(
max_diff <= 3,
"Verification failed: maximum pixel difference of {} exceeds threshold of 3",
max_diff
);
}
}
fn main() { fn main() {
let args = CavaInputs::parse(); let args = CavaInputs::parse();
......
...@@ -10,12 +10,12 @@ juno_build::juno!("concat"); ...@@ -10,12 +10,12 @@ juno_build::juno!("concat");
fn main() { fn main() {
async_std::task::block_on(async { async_std::task::block_on(async {
let mut r = runner!(concat_entry); let mut r = runner!(concat_entry);
let mut a_data = [7, 7, 0]; let mut a_data = Box::new([7, 7, 0]);
let mut b_data = [7, 7, 0, 0, 7, 7]; let mut b_data = Box::new([7, 7, 0, 0, 7, 7]);
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
{ {
let a = HerculesCPURef::from_slice(&mut a_data); let a = HerculesCPURef::from_slice(a_data.as_ref());
let b = HerculesCPURef::from_slice(&mut b_data); let b = HerculesCPURef::from_slice(b_data.as_ref());
let output = r.run(3, 6, a, b).await; let output = r.run(3, 6, a, b).await;
assert_eq!(output, 42); assert_eq!(output, 42);
...@@ -36,10 +36,8 @@ fn main() { ...@@ -36,10 +36,8 @@ fn main() {
} }
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
{ {
let mut a_data = [7, 7, 0]; let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(a_data.as_ref()));
let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a_data)); let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(b_data.as_ref()));
let mut b_data = [7, 7, 0, 0, 7, 7];
let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b_data));
let output = r.run(3, 6, a.get_ref(), b.get_ref()).await; let output = r.run(3, 6, a.get_ref(), b.get_ref()).await;
assert_eq!(output, 42); assert_eq!(output, 42);
} }
......
...@@ -3,9 +3,7 @@ use std::iter::zip; ...@@ -3,9 +3,7 @@ use std::iter::zip;
use rand::random; use rand::random;
use hercules_rt::{runner, HerculesRefInto}; use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
#[cfg(feature = "cuda")]
use hercules_rt::{CUDABox, HerculesCPURef};
juno_build::juno!("matmul"); juno_build::juno!("matmul");
...@@ -24,24 +22,12 @@ fn main() { ...@@ -24,24 +22,12 @@ fn main() {
} }
} }
} }
#[cfg(not(feature = "cuda"))] let a = HerculesImmBox::from(a.as_ref());
{ let b = HerculesImmBox::from(b.as_ref());
let mut r = runner!(matmul); let mut r = runner!(matmul);
let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
let c = c.as_slice::<f32>(); for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) {
assert_eq!(c, &*correct_c); assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct);
}
#[cfg(feature = "cuda")]
{
let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a));
let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b));
let mut r = runner!(matmul);
let c = r
.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
.await;
let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice();
c.to_cpu_ref(&mut c_cpu);
assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001));
} }
}); });
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment