diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index d2813388e0e7a1d7bd1696ffbb641e629096e2c2..7bc2083cf07c063eccda5855fa4fed3bfca91f87 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -1,24 +1,11 @@ use juno_build::JunoCompiler; fn main() { - #[cfg(not(feature = "cuda"))] - { - JunoCompiler::new() - .file_in_src("matmul.jn") - .unwrap() - .schedule_in_src("cpu.sch") - .unwrap() - .build() - .unwrap(); - } - #[cfg(feature = "cuda")] - { - JunoCompiler::new() - .file_in_src("matmul.jn") - .unwrap() - .schedule_in_src("gpu.sch") - .unwrap() - .build() - .unwrap(); - } + JunoCompiler::new() + .file_in_src("matmul.jn") + .unwrap() + .schedule_in_src("matmul.sch") + .unwrap() + .build() + .unwrap(); } diff --git a/juno_samples/matmul/src/matmul.sch b/juno_samples/matmul/src/matmul.sch new file mode 100644 index 0000000000000000000000000000000000000000..306997f58eb217f9ce301dc18e418c412e6df621 --- /dev/null +++ b/juno_samples/matmul/src/matmul.sch @@ -0,0 +1,81 @@ +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen-prep!(X) { + optimize!(X); + gcm(X); + float-collections(X); + dce(X); + gcm(X); +} + +macro forkify!(X) { + fixpoint { + forkify(X); + fork-guard-elim(X); + } +} + +macro fork-tile { + fork-tile[n, 0, false, true](X); +} + +macro parallelize!(X) { + parallel-fork(X); + parallel-reduce(X); +} + +macro unforkify!(X) { + fork-split(X); + unforkify(X); +} + +optimize!(*); +forkify!(*); + +if feature("cuda") { + fixpoint { + reduce-slf(*); + slf(*); + infer-schedules(*); + } + fork-coalesce(*); + infer-schedules(*); + dce(*); + rewrite(*); + fixpoint { + simplify-cfg(*); + dce(*); + } + + optimize!(*); + codegen-prep!(*); +} else { + associative(matmul@outer); + + // Parallelize by computing output array as 16 chunks + let par = matmul@outer \ matmul@inner; + fork-tile; + let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); + parallelize!(outer \ inner); + + let body = outline(inner); + cpu(body); + + // Tile for cache, assuming 64B cache lines + fork-tile; + let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); + + reduce-slf(inner); + unforkify!(body); + codegen-prep!(*); +}