From 6a0c4a3410d40410dbdac7a92926583e9c6e640a Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sat, 1 Feb 2025 15:04:24 -0600 Subject: [PATCH 01/25] control in reduce cycle fixes --- Cargo.lock | 11 ++++ Cargo.toml | 3 +- hercules_opt/src/forkify.rs | 12 +++++ hercules_opt/src/unforkify.rs | 2 +- .../hercules_tests/tests/loop_tests.rs | 4 +- juno_scheduler/src/pm.rs | 53 +++++++++++-------- 6 files changed, 59 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 49630436..ad69bc72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1181,6 +1181,17 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_test" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "rand", + "with_builtin_macros", +] + [[package]] name = "juno_utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index ced011a9..46fc7eaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ members = [ "hercules_samples/ccp", "juno_samples/simple3", - "juno_samples/patterns", + "juno_samples/patterns", "juno_samples/matmul", "juno_samples/casts_and_intrinsics", "juno_samples/nested_ccp", @@ -30,4 +30,5 @@ members = [ "juno_samples/cava", "juno_samples/concat", "juno_samples/schedule_test", + "juno_samples/test", ] diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index ec4e9fbc..0f06627d 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -152,6 +152,7 @@ pub fn forkify_loop( .filter(|id| !l.control[id.idx()]) .collect(); + // FIXME: @xrouth if loop_preds.len() != 1 { return false; } @@ -388,6 +389,7 @@ nest! { is_associative: bool, }, LoopDependant(NodeID), + ControlDependant(NodeID), // This phi is reductionable, but its cycle might depend on internal control within the loop. UsedByDependant(NodeID), } } @@ -398,6 +400,7 @@ impl LoopPHI { LoopPHI::Reductionable { phi, .. } => *phi, LoopPHI::LoopDependant(node_id) => *node_id, LoopPHI::UsedByDependant(node_id) => *node_id, + LoopPHI::ControlDependant(node_id) => *node_id, } } } @@ -415,6 +418,9 @@ pub fn analyze_phis<'a>( loop_nodes: &'a HashSet<NodeID>, ) -> impl Iterator<Item = LoopPHI> + 'a { + // We are also moving the phi from the top of the loop (the header), + // to the very end (the join). If there are uses of the phi somewhere in the loop, + // then they may try to use the phi (now a reduce) before it hits the join. // Find data cycles within the loop of this phi. // Start from the phi's loop_continue_latch, and walk its uses until we find the original phi. @@ -509,6 +515,12 @@ pub fn analyze_phis<'a>( // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). + // If anything in the intersection is a phi (that isn't the candidate phi itself), then the reduction cycle depends on control, + // which is not allowed. + if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() { + return LoopPHI::ControlDependant(*phi); + } + // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
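// (E.g. if an add on the cycle is also read by a node outside the loop, that outside use would have no well-defined replacement value once the cycle becomes a reduce, so such phis are rejected by the check below.)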
if intersection diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 85ffd233..7d158d1a 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -133,7 +133,7 @@ pub fn unforkify( if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. - continue; + break; // Because we have to unforkify top down, we can't unforkify forks that are contained } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 5832a161..192c1366 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -401,7 +401,7 @@ fn matmul_pipeline() { let dyn_consts = [I, J, K]; // FIXME: This path should not leave the crate - let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin"); + let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin"); // let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { @@ -425,7 +425,7 @@ fn matmul_pipeline() { }; assert_eq!(correct_c[0], value); - let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]); + let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]); module = run_schedule_on_hercules(module, schedule).unwrap(); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2371e0f2..d2772c71 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1471,29 +1471,38 @@ fn run_pass( } Pass::Forkify => { assert!(args.is_empty()); - pm.make_fork_join_maps(); - pm.make_control_subgraphs(); - pm.make_loops(); - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); - let control_subgraphs = pm.control_subgraphs.take().unwrap(); - for (((func, fork_join_map), loop_nest), control_subgraph) in - build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - .zip(control_subgraphs.iter()) - { - let Some(mut func) = func else { - continue; - }; - // TODO: uses direct return from forkify for now instead of - // func.modified, see comment on top of `forkify` for why. Fix - // this eventually. - changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest); + loop { + let mut inner_changed = false; + pm.make_fork_join_maps(); + pm.make_control_subgraphs(); + pm.make_loops(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + for (((func, fork_join_map), loop_nest), control_subgraph) in + build_selection(pm, selection.clone()) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(control_subgraphs.iter()) + { + let Some(mut func) = func else { + continue; + }; + // TODO: uses direct return from forkify for now instead of + // func.modified, see comment on top of `forkify` for why. Fix + // this eventually. 
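+                    // (The enclosing `loop` reruns forkify to a fixpoint: each successful forkify invalidates the fork/loop analyses, so they are rebuilt and the pass retries until no function changes.)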
+ let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest); + changed |= c; + inner_changed |= c; + } + pm.delete_gravestones(); + pm.clear_analyses(); + + if !inner_changed { + break; + } } - pm.delete_gravestones(); - pm.clear_analyses(); } Pass::GCM => { assert!(args.is_empty()); -- GitLab From fab913636e8b63cef26af9db6df2eb699f415161 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sat, 1 Feb 2025 16:30:10 -0600 Subject: [PATCH 02/25] misc --- hercules_opt/src/forkify.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 0f06627d..299422c1 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -514,7 +514,6 @@ pub fn analyze_phis<'a>( // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). - // If anything in the intersection is a phi (that isn't the candidate phi itself), then the reduction cycle depends on control, // which is not allowed. if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() { -- GitLab From 43b4022c38f65fcf5c2abad523c63f59888dabe9 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 18:29:52 -0600 Subject: [PATCH 03/25] simple test --- Cargo.lock | 10 +++++++ Cargo.toml | 1 + juno_samples/fork_join_tests/Cargo.toml | 21 +++++++++++++ juno_samples/fork_join_tests/build.rs | 24 +++++++++++++++ .../fork_join_tests/src/fork_join_tests.jn | 10 +++++++ juno_samples/fork_join_tests/src/gpu.sch | 30 +++++++++++++++++++ juno_samples/fork_join_tests/src/main.rs | 17 +++++++++++ 7 files changed, 113 insertions(+) create mode 100644 juno_samples/fork_join_tests/Cargo.toml create mode 100644 juno_samples/fork_join_tests/build.rs create mode 100644 juno_samples/fork_join_tests/src/fork_join_tests.jn create mode 100644 juno_samples/fork_join_tests/src/gpu.sch create mode 100644 juno_samples/fork_join_tests/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index b8bf2278..af7902c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1130,6 +1130,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_fork_join_tests" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_frontend" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f7b9322a..890d7924 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,4 +31,5 @@ members = [ "juno_samples/concat", "juno_samples/schedule_test", "juno_samples/edge_detection", + "juno_samples/fork_join_tests", ] diff --git a/juno_samples/fork_join_tests/Cargo.toml b/juno_samples/fork_join_tests/Cargo.toml new file mode 100644 index 00000000..a109e782 --- /dev/null +++ b/juno_samples/fork_join_tests/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "juno_fork_join_tests" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_fork_join_tests" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*"
diff --git a/juno_samples/fork_join_tests/build.rs b/juno_samples/fork_join_tests/build.rs new file mode 100644 index 00000000..796e9f32 --- /dev/null +++ b/juno_samples/fork_join_tests/build.rs @@ -0,0 +1,24 @@ +use juno_build::JunoCompiler; + +fn main() { + #[cfg(not(feature = "cuda"))] + { + JunoCompiler::new() + .file_in_src("fork_join_tests.jn") + .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() + .build() + .unwrap(); + } + #[cfg(feature = "cuda")] + { + JunoCompiler::new() + .file_in_src("fork_join_tests.jn") + .unwrap() + .schedule_in_src("gpu.sch") + .unwrap() + .build() + .unwrap(); + } +} diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn new file mode 100644 index 00000000..aa8eb4bb --- /dev/null +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -0,0 +1,10 @@ +#[entry] +fn test1(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + arr[i, j] = input; + } + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch new file mode 100644 index 00000000..e2fe980e --- /dev/null +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -0,0 +1,30 @@ +gvn(*); +phi-elim(*); +dce(*); + +let out = auto-outline(*); +gpu(out.test1); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} + +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + infer-schedules(*); +} +xdot[true](*); + +gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs new file mode 100644 index 00000000..6e5f2182 --- /dev/null +++ b/juno_samples/fork_join_tests/src/main.rs @@ -0,0 +1,17 @@ +#![feature(concat_idents)] + +use hercules_rt::runner; + +juno_build::juno!("fork_join_tests"); + +fn main() { + async_std::task::block_on(async { + let mut r = runner!(tests1); + let output = r.run(5).await; + }); +} + +#[test] +fn implicit_clone_test() { + main(); +} -- GitLab From 5eeb70cb675bdd7ac2cf4ac4887f803d24ebe165 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 18:52:57 -0600 Subject: [PATCH 04/25] fixes for test --- hercules_cg/src/gpu.rs | 8 +++--- juno_samples/fork_join_tests/src/cpu.sch | 31 ++++++++++++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 1 - juno_samples/fork_join_tests/src/main.rs | 13 +++++++++- juno_scheduler/src/pm.rs | 2 ++ 5 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 juno_samples/fork_join_tests/src/cpu.sch diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 81e31396..afc016a4 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -539,15 +539,17 @@ namespace cg = cooperative_groups; w: &mut String, ) -> Result<(), Error> { write!(w, "\n")?; - for (id, goto) in gotos.iter() { - let goto_block = self.get_block_name(*id, false); + let rev_po = self.control_subgraph.rev_po(NodeID::new(0)); + for id in rev_po { + let goto = &gotos[&id]; + let goto_block = self.get_block_name(id, false); write!(w, "{}:\n", goto_block)?; if goto_debug { write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?; } write!(w, "{}", goto.init)?; if !goto.post_init.is_empty() { - let goto_block = self.get_block_name(*id, true); + let goto_block = self.get_block_name(id, true); write!(w, "{}:\n", goto_block)?; write!(w, "{}", goto.post_init)?; } diff --git a/juno_samples/fork_join_tests/src/cpu.sch 
b/juno_samples/fork_join_tests/src/cpu.sch new file mode 100644 index 00000000..81f5a12c --- /dev/null +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -0,0 +1,31 @@ +gvn(*); +phi-elim(*); +dce(*); + +let out = auto-outline(*); +cpu(out.test1); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} + +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + infer-schedules(*); +} +fork-split(*); +unforkify(*); + +gcm(*); diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index e2fe980e..e4e4e04f 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -25,6 +25,5 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } -xdot[true](*); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 6e5f2182..a63b3f78 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -6,8 +6,19 @@ juno_build::juno!("fork_join_tests"); fn main() { async_std::task::block_on(async { - let mut r = runner!(tests1); + let mut r = runner!(test1); let output = r.run(5).await; + let correct = vec![5i32; 16]; + #[cfg(not(feature = "cuda"))] + { + assert_eq!(output.as_slice::<i32>(), &correct); + } + #[cfg(feature = "cuda")] + { + let mut dst = vec![0i32; 16]; + let output = output.to_cpu_ref(&mut dst); + assert_eq!(output.as_slice::<i32>(), &correct); + } }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f6fe2fc1..5d398804 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -575,6 +575,8 @@ impl PassManager { self.postdoms = None; self.fork_join_maps = None; self.fork_join_nests = None; + self.fork_control_maps = None; + self.fork_trees = None; self.loops = None; self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; -- GitLab From d67263e6a871fa539067ec84e87c58f242809ea4 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 10:50:56 -0600 Subject: [PATCH 05/25] second fork test, fails on cpu w/ forkify/unforkify --- hercules_opt/src/float_collections.rs | 27 ++++++++++++------ juno_samples/fork_join_tests/src/cpu.sch | 1 + .../fork_join_tests/src/fork_join_tests.jn | 13 +++++++++ juno_samples/fork_join_tests/src/gpu.sch | 2 ++ juno_samples/fork_join_tests/src/main.rs | 28 ++++++++++++------- juno_scheduler/src/pm.rs | 17 +++++------ 6 files changed, 60 insertions(+), 28 deletions(-) diff --git a/hercules_opt/src/float_collections.rs b/hercules_opt/src/float_collections.rs index faa38375..6ef050c2 100644 --- a/hercules_opt/src/float_collections.rs +++ b/hercules_opt/src/float_collections.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use hercules_ir::*; use crate::*; @@ -7,27 +9,36 @@ use crate::*; * allowed. */ pub fn float_collections( - editors: &mut [FunctionEditor], + editors: &mut BTreeMap<FunctionID, FunctionEditor>, typing: &ModuleTyping, callgraph: &CallGraph, devices: &Vec<Device>, ) { - let topo = callgraph.topo(); + let topo: Vec<_> = callgraph + .topo() + .into_iter() + .filter(|id| editors.contains_key(&id)) + .collect(); for to_float_id in topo { // Collection constants float until reaching an AsyncRust function. if devices[to_float_id.idx()] == Device::AsyncRust { continue; } + // Check that all callers are in the selection as well. 
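+        // (Floating a collection out of a callee turns it into extra parameters, so every caller must also be editable to allocate the constants and pass them in as arguments.)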
+ for caller in callgraph.get_callers(to_float_id) { + assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller); + } + // Find the target constant nodes in the function. - let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()] + let cons: Vec<(NodeID, Node)> = editors[&to_float_id] .func() .nodes .iter() .enumerate() .filter(|(_, node)| { node.try_constant() - .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar()) + .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar()) .unwrap_or(false) }) .map(|(idx, node)| (NodeID::new(idx), node.clone())) @@ -37,12 +48,12 @@ pub fn float_collections( } // Each constant node becomes a new parameter. - let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone(); + let mut new_param_types = editors[&to_float_id].func().param_types.clone(); let old_num_params = new_param_types.len(); for (id, _) in cons.iter() { new_param_types.push(typing[to_float_id.idx()][id.idx()]); } - let success = editors[to_float_id.idx()].edit(|mut edit| { + let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| { for (idx, (id, _)) in cons.iter().enumerate() { let param = edit.add_node(Node::Parameter { index: idx + old_num_params, @@ -59,7 +70,7 @@ pub fn float_collections( // Add constants in callers and pass them into calls. for caller in callgraph.get_callers(to_float_id) { - let calls: Vec<(NodeID, Node)> = editors[caller.idx()] + let calls: Vec<(NodeID, Node)> = editors[&caller] .func() .nodes .iter() @@ -71,7 +82,7 @@ pub fn float_collections( }) .map(|(idx, node)| (NodeID::new(idx), node.clone())) .collect(); - let success = editors[caller.idx()].edit(|mut edit| { + let success = editors.get_mut(&caller).unwrap().edit(|mut edit| { let cons_ids: Vec<_> = cons .iter() .map(|(_, node)| edit.add_node(node.clone())) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 81f5a12c..a6b1afe7 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -4,6 +4,7 @@ dce(*); let out = auto-outline(*); cpu(out.test1); +cpu(out.test2); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index aa8eb4bb..4a6a94c9 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -8,3 +8,16 @@ fn test1(input : i32) -> i32[4, 4] { } return arr; } + +#[entry] +fn test2(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 8 { + for j = 0 to 4 { + for k = 0 to 4 { + arr[j, k] += input; + } + } + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index e4e4e04f..b506c4a4 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -4,6 +4,7 @@ dce(*); let out = auto-outline(*); gpu(out.test1); +gpu(out.test2); ip-sroa(*); sroa(*); @@ -26,4 +27,5 @@ fixpoint panic after 20 { infer-schedules(*); } +float-collections(test2, out.test2); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index a63b3f78..de1b0805 100644 --- 
a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -5,20 +5,28 @@ use hercules_rt::runner; juno_build::juno!("fork_join_tests"); fn main() { + #[cfg(not(feature = "cuda"))] + let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| { + assert_eq!(output.as_slice::<i32>(), &correct); + }; + + #[cfg(feature = "cuda")] + let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| { + let mut dst = vec![0i32; 16]; + let output = output.to_cpu_ref(&mut dst); + assert_eq!(output.as_slice::<i32>(), &correct); + }; + async_std::task::block_on(async { let mut r = runner!(test1); let output = r.run(5).await; let correct = vec![5i32; 16]; - #[cfg(not(feature = "cuda"))] - { - assert_eq!(output.as_slice::<i32>(), &correct); - } - #[cfg(feature = "cuda")] - { - let mut dst = vec![0i32; 16]; - let output = output.to_cpu_ref(&mut dst); - assert_eq!(output.as_slice::<i32>(), &correct); - } + assert(correct, output); + + let mut r = runner!(test2); + let output = r.run(3).await; + let correct = vec![24i32; 16]; + assert(correct, output); }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 5d398804..1ebc885c 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1305,7 +1305,7 @@ fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, - selection: Option<Vec<CodeLocation>>, + mut selection: Option<Vec<CodeLocation>>, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), @@ -1441,13 +1441,6 @@ fn run_pass( } Pass::FloatCollections => { assert!(args.is_empty()); - if let Some(_) = selection { - return Err(SchedulerError::PassError { - pass: "floatCollections".to_string(), - error: "must be applied to the entire module".to_string(), - }); - } - pm.make_typing(); pm.make_callgraph(); pm.make_devices(); @@ -1455,11 +1448,15 @@ fn run_pass( let callgraph = pm.callgraph.take().unwrap(); let devices = pm.devices.take().unwrap(); - let mut editors = build_editors(pm); + // Modify the selection to include callers of selected functions. + let mut editors = build_selection(pm, selection) + .into_iter() + .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor))) + .collect(); float_collections(&mut editors, &typing, &callgraph, &devices); for func in editors { - changed |= func.modified(); + changed |= func.1.modified(); } pm.delete_gravestones(); -- GitLab From ba614011066dcde2305a649f7911368d99f95a84 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 11:12:34 -0600 Subject: [PATCH 06/25] fix reduce cycles --- hercules_ir/src/fork_join_analysis.rs | 34 ++++++++++++------------ juno_samples/fork_join_tests/src/cpu.sch | 9 +++++++ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 263fa952..7a098a35 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -140,23 +140,23 @@ fn reduce_cycle_dfs_helper( } current_visited.insert(iter); - let found_reduce = get_uses(&function.nodes[iter.idx()]) - .as_ref() - .into_iter() - .any(|u| { - !current_visited.contains(u) - && !function.nodes[u.idx()].is_control() - && isnt_outside_fork_join(*u) - && reduce_cycle_dfs_helper( - function, - *u, - fork, - reduce, - current_visited, - in_cycle, - fork_join_nest, - ) - }); + let mut found_reduce = false; + + // This doesn't short circuit on purpose. 
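+    // (Short-circuiting, as the old `any` did, would stop at the first use that reaches the reduce and leave nodes on the other branches of the cycle unmarked.)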
+ for u in get_uses(&function.nodes[iter.idx()]).as_ref() { + found_reduce |= !current_visited.contains(u) + && !function.nodes[u.idx()].is_control() + && isnt_outside_fork_join(*u) + && reduce_cycle_dfs_helper( + function, + *u, + fork, + reduce, + current_visited, + in_cycle, + fork_join_nest, + ) + } if found_reduce { in_cycle.insert(iter); } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index a6b1afe7..2889cec0 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -18,6 +18,9 @@ fixpoint panic after 20 { fork-guard-elim(*); fork-coalesce(*); } +gvn(*); +phi-elim(*); +dce(*); gvn(*); phi-elim(*); @@ -27,6 +30,12 @@ fixpoint panic after 20 { infer-schedules(*); } fork-split(*); +gvn(*); +phi-elim(*); +dce(*); unforkify(*); +gvn(*); +phi-elim(*); +dce(*); gcm(*); -- GitLab From f2865cbfd0fd7f3a7e84c08e70283c8c0eedefdf Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 11:21:02 -0600 Subject: [PATCH 07/25] interesting test --- juno_samples/fork_join_tests/src/cpu.sch | 1 + .../fork_join_tests/src/fork_join_tests.jn | 23 +++++++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 1 + juno_samples/fork_join_tests/src/main.rs | 5 ++++ 4 files changed, 30 insertions(+) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 2889cec0..0263c275 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -5,6 +5,7 @@ dce(*); let out = auto-outline(*); cpu(out.test1); cpu(out.test2); +cpu(out.test3); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 4a6a94c9..073cfd1e 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -21,3 +21,26 @@ fn test2(input : i32) -> i32[4, 4] { } return arr; } + +#[entry] +fn test3(input : i32) -> i32[3, 3] { + let arr1 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr1[i, j] = (i + j) as i32 + input; + } + } + let arr2 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr2[i, j] = arr1[3 - i, 3 - j]; + } + } + let arr3 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr3[i, j] = arr2[i, j] + 7; + } + } + return arr3; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index b506c4a4..80f1bbc9 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -5,6 +5,7 @@ dce(*); let out = auto-outline(*); gpu(out.test1); gpu(out.test2); +gpu(out.test3); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index de1b0805..4384ecd5 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -27,6 +27,11 @@ fn main() { let output = r.run(3).await; let correct = vec![24i32; 16]; assert(correct, output); + + let mut r = runner!(test3); + let output = r.run(0).await; + let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; + assert(correct, output); }); } -- GitLab From 912a729ac05bd2077f2ab864cef04e1dddde7667 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 12:34:13 -0600 Subject: [PATCH 08/25] unforkify fixes --- hercules_opt/src/unforkify.rs | 6 +++-- juno_scheduler/src/pm.rs | 42 ++++++++++++++++++++++------------- 2 files changed, 30 
insertions(+), 18 deletions(-) diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 7d158d1a..a08d1667 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -118,7 +118,7 @@ pub fn unforkify( // control insides of the fork-join should become the successor of the true // projection node, and what was the use of the join should become a use of // the new region. - for l in loop_tree.bottom_up_loops().into_iter().rev() { + for l in loop_tree.bottom_up_loops().iter().rev() { if !editor.node(l.0).is_fork() { continue; } @@ -133,7 +133,8 @@ pub fn unforkify( if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. - break; // Because we have to unforkify top down, we can't unforkify forks that are contained + // We can't unforkify, because then the outer forks reduce will depend on non-fork control. + break; } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor @@ -293,5 +294,6 @@ pub fn unforkify( Ok(edit) }); + break; } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d2772c71..378d8730 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1815,25 +1815,35 @@ fn run_pass( } Pass::Unforkify => { assert!(args.is_empty()); - pm.make_fork_join_maps(); - pm.make_loops(); + loop { + let mut inner_changed = false; - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); + pm.make_fork_join_maps(); + pm.make_loops(); - for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - { - let Some(mut func) = func else { - continue; - }; - unforkify(&mut func, fork_join_map, loop_tree); - changed |= func.modified(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + + for ((func, fork_join_map), loop_tree) in build_selection(pm, selection.clone()) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + { + let Some(mut func) = func else { + continue; + }; + unforkify(&mut func, fork_join_map, loop_tree); + changed |= func.modified(); + inner_changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + + if !inner_changed { + break; + } + break; } - pm.delete_gravestones(); - pm.clear_analyses(); } Pass::ForkCoalesce => { assert!(args.is_empty()); -- GitLab From d390705275493b662dcfa92a8d7ad35551d119a4 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 13:20:04 -0600 Subject: [PATCH 09/25] whoops --- juno_samples/fork_join_tests/src/fork_join_tests.jn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 073cfd1e..3d003f3c 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -33,7 +33,7 @@ fn test3(input : i32) -> i32[3, 3] { let arr2 : i32[3, 3]; for i = 0 to 3 { for j = 0 to 3 { - arr2[i, j] = arr1[3 - i, 3 - j]; + arr2[i, j] = arr1[2 - i, 2 - j]; } } let arr3 : i32[3, 3]; -- GitLab From ed33189c2fcaf725a8ae34f09d4fb963826ac4b6 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 13:38:57 -0600 Subject: [PATCH 10/25] whoops x2 --- hercules_ir/src/ir.rs | 2 +- juno_samples/fork_join_tests/src/gpu.sch | 1 + 2 files changed, 2 insertions(+), 1 
deletion(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 846347b0..5c575ea1 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { * list of indices B. */ pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { - if indices_a.len() < indices_b.len() { + if indices_a.len() > indices_b.len() { return false; } diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 80f1bbc9..0647d781 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -28,5 +28,6 @@ fixpoint panic after 20 { infer-schedules(*); } +xdot[true](*); float-collections(test2, out.test2); gcm(*); -- GitLab From e1f5634f38cca264a25abe83174434ef333af9c3 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 13:59:43 -0600 Subject: [PATCH 11/25] fix to antideps --- hercules_opt/src/gcm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 271bfaf1..b13c919a 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -290,7 +290,7 @@ fn basic_blocks( .collect(); for mutator in reverse_postorder.iter() { let mutator_early = schedule_early[mutator.idx()].unwrap(); - if dom.does_dom(root_early, mutator_early) + if dom.does_prop_dom(root_early, mutator_early) && (root_early != mutator_early || root_block_iterated_users.contains(&mutator)) && mutating_objects(function, func_id, *mutator, objects) -- GitLab From 7f381ff54caddce24b0808725390a982e27bfdbc Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 14:38:13 -0600 Subject: [PATCH 12/25] hack for gpu --- hercules_cg/src/gpu.rs | 17 ++++++++--------- hercules_opt/src/gcm.rs | 18 ++++++++++++++++-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index afc016a4..8f186aa7 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -622,23 +622,23 @@ extern \"C\" {} {}(", write!(pass_args, "ret")?; write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; } - write!(w, "\tcudaError_t err;\n"); + write!(w, "\tcudaError_t err;\n")?; write!( w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args )?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; write!(w, "\tcudaDeviceSynchronize();\n")?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; if has_ret_var { // Copy return from device to host, whether it's primitive value or collection pointer write!(w, "\t{} host_ret;\n", ret_type)?; @@ -1150,7 +1150,8 @@ extern \"C\" {} {}(", // for all threads. Otherwise, it can be inside or outside block fork. // If inside, it's stored in shared memory so we "allocate" it once // and parallelize memset to 0. If outside, we initialize as offset - // to backing, but if multi-block grid, don't memset to avoid grid-level sync. + // to backing, but if multi-block grid, don't memset to avoid grid- + // level sync. 
Node::Constant { id: cons_id } => { let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive(); let cg_tile = match state { @@ -1192,9 +1193,7 @@ extern \"C\" {} {}(", )?; } if !is_primitive - && (state != KernelState::OutBlock - || is_block_parallel.is_none() - || !is_block_parallel.unwrap()) + && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false)) { let data_size = self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects)); diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index b13c919a..65f7c2d0 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -90,6 +90,7 @@ pub fn gcm( loops, fork_join_map, objects, + devices, ); let liveness = liveness_dataflow( @@ -174,6 +175,7 @@ fn basic_blocks( loops: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, + devices: &Vec<Device>, ) -> BasicBlocks { let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; @@ -421,9 +423,18 @@ fn basic_blocks( // If the next node further up the dominator tree is in a shallower // loop nest or if we can get out of a reduce loop when we don't // need to be in one, place this data node in a higher-up location. - // Only do this is the node isn't a constant or undef. + // Only do this if the node isn't a constant or undef - if a + // node is a constant or undef, we want its placement to be as + // control dependent as possible, even inside loops. In GPU + // functions specifically, lift constants that may be returned + // outside fork-joins. let is_constant_or_undef = function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef(); + let is_gpu_returned = devices[func_id.idx()] == Device::CUDA + && objects[&func_id] + .objects(id) + .into_iter() + .any(|obj| objects[&func_id].returned_objects().contains(obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -444,7 +455,10 @@ fn basic_blocks( // loop use the reduce node forming the loop, so the dominator chain // will consist of one block, and this loop won't ever iterate.
let currently_at_join = function.nodes[location.idx()].is_join(); - if !is_constant_or_undef && (shallower_nest || currently_at_join) { + + if (!is_constant_or_undef || is_gpu_returned) + && (shallower_nest || currently_at_join) + { location = control_node; } } -- GitLab From 25725cb1df4ec7b3c59f963c00aa5a2844a52916 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 14:50:58 -0600 Subject: [PATCH 13/25] Ok fix antideps for real this time --- hercules_opt/src/gcm.rs | 19 +++++++++++++++++-- juno_scheduler/src/pm.rs | 3 +++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 65f7c2d0..3ff6d2fe 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1,5 +1,5 @@ use std::cell::Ref; -use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter::{empty, once, zip, FromIterator}; use bitvec::prelude::*; @@ -76,6 +76,7 @@ pub fn gcm( dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, devices: &Vec<Device>, object_device_demands: &FunctionObjectDeviceDemands, @@ -88,6 +89,7 @@ pub fn gcm( reverse_postorder, dom, loops, + reduce_cycles, fork_join_map, objects, devices, @@ -173,6 +175,7 @@ fn basic_blocks( reverse_postorder: &Vec<NodeID>, dom: &DomTree, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, devices: &Vec<Device>, @@ -246,6 +249,9 @@ fn basic_blocks( // but not forwarding read - forwarding reads are collapsed, and the // bottom read is treated as reading from the transitive parent of the // forwarding read(s). + // 3: If the node producing the collection is a reduce node, then any read + // users that aren't in the reduce's cycle shouldn't anti-depend on any + // mutators in the reduce cycle. let mut antideps = BTreeSet::new(); for id in reverse_postorder.iter() { // Find a terminating read node and the collections it reads. @@ -271,6 +277,10 @@ fn basic_blocks( // TODO: make this less outrageously inefficient.
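// (Each read root currently rescans every node in reverse_postorder looking for mutators, which is what makes this quadratic in practice.)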
let func_objects = &objects[&func_id]; for root in roots.iter() { + let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles + .get(root) + .map(|cycle| !cycle.contains(&id)) + .unwrap_or(false); let root_early = schedule_early[root.idx()].unwrap(); let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new(); let mut workset = BTreeSet::new(); @@ -292,12 +302,17 @@ fn basic_blocks( .collect(); for mutator in reverse_postorder.iter() { let mutator_early = schedule_early[mutator.idx()].unwrap(); - if dom.does_prop_dom(root_early, mutator_early) + if dom.does_dom(root_early, mutator_early) && (root_early != mutator_early || root_block_iterated_users.contains(&mutator)) && mutating_objects(function, func_id, *mutator, objects) .any(|mutated| read_objs.contains(&mutated)) && id != mutator + && (!root_is_reduce_and_read_isnt_in_cycle + || !reduce_cycles + .get(root) + .map(|cycle| cycle.contains(mutator)) + .unwrap_or(false)) { antideps.insert((*id, *mutator)); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 1ebc885c..db4455e8 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1567,6 +1567,7 @@ fn run_pass( pm.make_doms(); pm.make_fork_join_maps(); pm.make_loops(); + pm.make_reduce_cycles(); pm.make_collection_objects(); pm.make_devices(); pm.make_object_device_demands(); @@ -1577,6 +1578,7 @@ fn run_pass( let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let collection_objects = pm.collection_objects.take().unwrap(); let devices = pm.devices.take().unwrap(); @@ -1598,6 +1600,7 @@ fn run_pass( &doms[id.idx()], &fork_join_maps[id.idx()], &loops[id.idx()], + &reduce_cycles[id.idx()], &collection_objects, &devices, &object_device_demands[id.idx()], -- GitLab From c7d47ec8de0737d73395c06cd3ea5e835b2324d8 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 14:51:22 -0600 Subject: [PATCH 14/25] remove xdot from schedule --- juno_samples/fork_join_tests/src/gpu.sch | 1 - 1 file changed, 1 deletion(-) diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 0647d781..80f1bbc9 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -28,6 +28,5 @@ fixpoint panic after 20 { infer-schedules(*); } -xdot[true](*); float-collections(test2, out.test2); gcm(*); -- GitLab From 0caa58c5be3deb70cd0ef5dfde77b02a09bfbf2a Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 15:28:02 -0600 Subject: [PATCH 15/25] Test requiring outer split + unforkify --- .../fork_join_tests/src/fork_join_tests.jn | 15 +++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 1 + juno_samples/fork_join_tests/src/main.rs | 5 +++++ 3 files changed, 21 insertions(+) diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 3d003f3c..55e0a37e 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -44,3 +44,18 @@ fn test3(input : i32) -> i32[3, 3] { } return arr3; } + +#[entry] +fn test4(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + let acc = arr[i, j]; + for k = 0 to 7 { + acc += input; + } + arr[i, j] = acc; + } + } + return arr; +} diff 
--git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 80f1bbc9..bf35caea 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -6,6 +6,7 @@ let out = auto-outline(*); gpu(out.test1); gpu(out.test2); gpu(out.test3); +gpu(out.test4); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 4384ecd5..cbd42c50 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -32,6 +32,11 @@ fn main() { let output = r.run(0).await; let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; assert(correct, output); + + let mut r = runner!(test4); + let output = r.run(9).await; + let correct = vec![63i32; 16]; + assert(correct, output); }); } -- GitLab From 50555d655836a3090d605dfe622a9cc1127076a6 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 15:42:08 -0600 Subject: [PATCH 16/25] interpreter fixes + product consts --- hercules_test/hercules_interpreter/src/interpreter.rs | 3 +-- hercules_test/hercules_interpreter/src/value.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 871e304a..22ef062a 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -783,8 +783,7 @@ impl<'a> FunctionExecutionState<'a> { &self.module.dynamic_constants, &self.dynamic_constant_params, ) - }) - .rev(); + }); let n_tokens: usize = factors.clone().product(); diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index 53911e05..adbed6e6 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -156,8 +156,8 @@ impl<'a> InterpreterVal { Constant::Float64(v) => Self::Float64(v), Constant::Product(ref type_id, ref constant_ids) => { - // Self::Product((), ()) - todo!() + let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params)); + InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice()) } Constant::Summation(_, _, _) => todo!(), Constant::Array(type_id) => { -- GitLab From acf060c7daf486ef28a68a292beda3132850891d Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 15:42:26 -0600 Subject: [PATCH 17/25] read schedule from file --- juno_scheduler/src/lib.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 571d1fbf..2479af98 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -146,6 +146,42 @@ pub fn run_schedule_on_hercules( .map_err(|e| format!("Scheduling Error: {}", e)) } + +pub fn run_schedule_from_file_on_hercules( + module: Module, + sched_filename: Option<String>, +) -> Result<Module, String> { + let sched = process_schedule(sched_filename)?; + + // Prepare the scheduler's string table and environment + // For this, we put all of the Hercules function names into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + for (idx, func) in module.functions.iter().enumerate() { + let func_name = 
strings.lookup_string(func.name.clone()); + env.insert( + func_name, + Value::HerculesFunction { + func: FunctionID::new(idx), + }, + ); + } + + env.open_scope(); + schedule_module( + module, + sched, + strings, + env, + JunoFunctions { func_ids: vec![] }, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) +} + pub fn schedule_hercules( module: Module, sched_filename: Option<String>, -- GitLab From 61bc9ae455c3933f930d7a980a4cc4efc39afb6f Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 15:43:08 -0600 Subject: [PATCH 18/25] unforkify fix --- hercules_opt/src/fork_guard_elim.rs | 8 ++++---- juno_scheduler/src/pm.rs | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 052fd0e4..f6914b74 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -39,7 +39,6 @@ struct GuardedFork { guard_if: NodeID, fork_taken_proj: NodeID, fork_skipped_proj: NodeID, - guard_pred: NodeID, guard_join_region: NodeID, phi_reduce_map: HashMap<NodeID, NodeID>, factor: Factor, // The factor that matches the guard @@ -302,7 +301,6 @@ fn guarded_fork( guard_if: if_node, fork_taken_proj: *control, fork_skipped_proj: other_pred, - guard_pred: if_pred, guard_join_region: join_control, phi_reduce_map: phi_nodes, factor, @@ -323,13 +321,15 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node join, fork_taken_proj, fork_skipped_proj, - guard_pred, phi_reduce_map, factor, guard_if, guard_join_region, } in guard_info - { + { + let Some(guard_pred) = editor.get_uses(guard_if).next() else { + unreachable!() + }; let new_fork_info = if let Factor::Max(idx, dc) = factor { let Node::Fork { control: _, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index db3904f7..dd2ae73a 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1895,7 +1895,6 @@ fn run_pass( if !inner_changed { break; } - break; } } Pass::ForkCoalesce => { -- GitLab From 82ab7c076640225cf0ff4de2b8443a46dcc9b95f Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 22:39:36 -0600 Subject: [PATCH 19/25] fork dim merge --- hercules_opt/src/fork_transforms.rs | 161 ++++++++++++++++++++++++++++ juno_scheduler/src/compile.rs | 3 +- juno_scheduler/src/ir.rs | 2 + juno_scheduler/src/lib.rs | 3 +- juno_scheduler/src/pm.rs | 30 ++++-- 5 files changed, 189 insertions(+), 10 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e23f586f..58ace775 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::iter::zip; +use std::thread::ThreadId; use bimap::BiMap; use itertools::Itertools; @@ -693,3 +694,163 @@ pub(crate) fn split_fork( None } } + +// Splits a dimension of a single fork join into multiple. +// Iterates an outer loop original_dim / tile_size times +// adds a tile_size loop as the inner loop +// Assumes that tile size divides original dim evenly. 
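+// E.g. tiling a fork of factor 16 with tile_size 4 yields factors [4, 4], and each old +// thread ID is rebuilt as tid_outer * tile_size + tid_inner, per the formula below.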
+pub fn chunk_fork_unguarded( + editor: &mut FunctionEditor, + fork: NodeID, + dim_idx: usize, + tile_size: DynamicConstantID, +) -> () { + // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) + + let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return}; + + let mut new_factors: Vec<_> = old_factors.to_vec(); + + let fork_users: Vec<_> = editor.get_users(fork).collect(); + + + for tid in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + editor.edit(|mut edit| { + if tid_dim > dim_idx { + let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let new_tid = edit.add_node(new_tid); + edit.replace_all_uses(tid, new_tid) + } else if tid_dim == dim_idx { + let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let tile_tid = edit.add_node(tile_tid); + + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); + let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); + let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); + edit.replace_all_uses_where(tid, add, |usee| *usee != mul ) + } else { + Ok(edit) + } + }); + } + + editor.edit(|mut edit| { + let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx + 1, tile_size); + new_factors[dim_idx] = edit.add_dynamic_constant(outer); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + + edit.replace_all_uses(fork, new_fork) + }); +} + + +pub fn merge_all_fork_dims( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, +) { + for (fork, _) in fork_join_map { + let Node::Fork { control: _, factors: dims } = editor.node(fork) else { + unreachable!(); + }; + + let mut fork = *fork; + for _ in 0..dims.len() - 1 { + let outer = 0; + let inner = 1; + fork = fork_dim_merge(editor, fork, outer, inner); + } + } +} + +// Merges two dimensions of a single fork-join into one. +// The merged dimension iterates dim(dim_idx1) * dim(dim_idx2) times; +// the thread IDs of the original two dimensions are rebuilt from the +// merged thread ID with division and remainder. +pub fn fork_dim_merge( + editor: &mut FunctionEditor, + fork: NodeID, + dim_idx1: usize, + dim_idx2: usize, +) -> NodeID { + // tid_dim_idx1 (replaced w/) <- dim_idx1 / dim(dim_idx2) + // tid_dim_idx2 (replaced w/) <- dim_idx1 % dim(dim_idx2) + assert_ne!(dim_idx1, dim_idx2); + + // Outer is smaller, and also closer to the left of the factors array.
+ let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 { + (dim_idx2, dim_idx1) + } else { + (dim_idx1, dim_idx2) + }; + + let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork}; + + let mut new_factors: Vec<_> = old_factors.to_vec(); + + + + let fork_users: Vec<_> = editor.get_users(fork).collect(); + + let mut new_nodes = vec![]; + + let outer_dc_id = new_factors[outer_idx]; + let inner_dc_id = new_factors[inner_idx]; + + let mut new_fork_id = NodeID::new(0); + editor.edit(|mut edit| { + new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); + new_factors.remove(inner_idx); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + new_fork_id = new_fork; + + + edit = edit.replace_all_uses(fork, new_fork)?; + edit.delete_node(fork) + }); + + + + for tid in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + + println!("tid: {:?}", tid); + editor.edit(|mut edit| { + if tid_dim > inner_idx { + let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; + let new_tid = edit.add_node(new_tid); + edit.replace_all_uses(tid, new_tid) + } else if tid_dim == outer_idx { + let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = edit.add_node(outer_tid); + + let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); + new_nodes.push(outer_tid); + + // inner_idx % dim(outer_idx) + let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); + + edit.replace_all_uses(tid, rem) + } else if tid_dim == inner_idx { + let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = edit.add_node(outer_tid); + + let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); + // inner_idx / dim(outer_idx) + let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); + + edit.replace_all_uses(tid, div) + } else { + Ok(edit) + } + }); + }; + + return new_fork_id; + +} \ No newline at end of file diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 11a8ec53..07ad5e7a 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -108,7 +108,8 @@ impl FromStr for Appliable { "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) - } + }, + "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index d6a41baf..939ef925 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -12,6 +12,8 @@ pub enum Pass { ForkSplit, ForkCoalesce, Forkify, + ForkDimMerge, + ForkChunk, GCM, GVN, InferSchedules, diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 2479af98..ad9195fb 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -60,7 +60,7 @@ fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> { } } -fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> { +pub fn process_schedule(sched_filename: Option<String>) 
-> Result<ScheduleStmt, String> { if let Some(name) = sched_filename { build_schedule(name) } else { @@ -146,7 +146,6 @@ pub fn run_schedule_on_hercules( .map_err(|e| format!("Scheduling Error: {}", e)) } - pub fn run_schedule_from_file_on_hercules( module: Module, sched_filename: Option<String>, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 43f355c3..8b71d24e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1871,14 +1871,11 @@ fn run_pass( } Pass::Unforkify => { assert!(args.is_empty()); - loop { - let mut inner_changed = false; - - pm.make_fork_join_maps(); - pm.make_loops(); + pm.make_fork_join_maps(); + pm.make_loops(); - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) .into_iter() @@ -1894,6 +1891,24 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkDimMerge => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + merge_all_fork_dims(&mut func, fork_join_map); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkCoalesce => { assert!(args.is_empty()); pm.make_fork_join_maps(); @@ -1991,6 +2006,7 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } + Pass::ForkChunk => todo!(), } println!("Ran Pass: {:?}", pass); -- GitLab From 1788bd1f2ba4ec496cb5afa62da884f92c1761a4 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Mon, 3 Feb 2025 11:10:10 -0600 Subject: [PATCH 20/25] tiling + dim merge with one edit per loop dim --- hercules_opt/src/fork_transforms.rs | 96 ++++++++++++++++------------- juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 34 ++++++++++ 4 files changed, 90 insertions(+), 42 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 58ace775..cbb09bbf 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -695,6 +695,24 @@ pub(crate) fn split_fork( } } +pub fn chunk_all_forks_unguarded( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + dim_idx: usize, + tile_size: usize, +) -> () { + // Add dc + let mut dc_id = DynamicConstantID::new(0); + editor.edit(|mut edit| { + dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size)); + Ok(edit) + }); + + for (fork, _ ) in fork_join_map { + chunk_fork_unguarded(editor, *fork, dim_idx, dc_id); + } + +} // Splits a dimension of a single fork join into multiple. 
// Iterates an outer loop original_dim / tile_size times // adds a tile_size loop as the inner loop @@ -711,39 +729,36 @@ pub fn chunk_fork_unguarded( let mut new_factors: Vec<_> = old_factors.to_vec(); - let fork_users: Vec<_> = editor.get_users(fork).collect(); + let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); + + editor.edit(|mut edit| { + let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx + 1, tile_size); + new_factors[dim_idx] = edit.add_dynamic_constant(outer); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + edit = edit.replace_all_uses(fork, new_fork)?; - for tid in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; - editor.edit(|mut edit| { + for (tid, node) in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue}; if tid_dim > dim_idx { - let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; let new_tid = edit.add_node(new_tid); - edit.replace_all_uses(tid, new_tid) + edit = edit.replace_all_uses(tid, new_tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; let tile_tid = edit.add_node(tile_tid); let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); - edit.replace_all_uses_where(tid, add, |usee| *usee != mul ) - } else { - Ok(edit) + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?; } - }); - } - - editor.edit(|mut edit| { - let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); - new_factors.insert(dim_idx + 1, tile_size); - new_factors[dim_idx] = edit.add_dynamic_constant(outer); - - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; - let new_fork = edit.add_node(new_fork); - - edit.replace_all_uses(fork, new_fork) + } + edit = edit.delete_node(fork)?; + Ok(edit) }); } @@ -791,9 +806,8 @@ pub fn fork_dim_merge( let mut new_factors: Vec<_> = old_factors.to_vec(); - - let fork_users: Vec<_> = editor.get_users(fork).collect(); + let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); let mut new_nodes = vec![]; @@ -801,6 +815,7 @@ pub fn fork_dim_merge( let inner_dc_id = new_factors[inner_idx]; let mut new_fork_id = NodeID::new(0); + editor.edit(|mut edit| { new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); new_factors.remove(inner_idx); @@ -809,22 +824,20 @@ pub fn fork_dim_merge( let new_fork = edit.add_node(new_fork); new_fork_id = new_fork; + edit.sub_edit(fork, new_fork); edit = edit.replace_all_uses(fork, new_fork)?; - edit.delete_node(fork) - }); - - - - for tid in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + edit = edit.delete_node(fork)?; - println!("tid: {:?}", tid); - editor.edit(|mut edit| { + for (tid, node) in fork_users { + // FIXME: DO we want sub edits in this? 
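+            // Rewrite every ThreadID that read the old fork: dimensions above
+            // inner_idx shift down by one (inner_idx was removed), the old outer
+            // dimension is rebuilt as merged_tid % outer_factor, and the old
+            // inner dimension as merged_tid / outer_factor.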
+ + let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue }; if tid_dim > inner_idx { let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; let new_tid = edit.add_node(new_tid); - edit.replace_all_uses(tid, new_tid) + edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); } else if tid_dim == outer_idx { let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; let outer_tid = edit.add_node(outer_tid); @@ -834,8 +847,8 @@ pub fn fork_dim_merge( // inner_idx % dim(outer_idx) let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); - - edit.replace_all_uses(tid, rem) + edit.sub_edit(tid, rem); + edit = edit.replace_all_uses(tid, rem)?; } else if tid_dim == inner_idx { let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; let outer_tid = edit.add_node(outer_tid); @@ -843,13 +856,12 @@ pub fn fork_dim_merge( let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); // inner_idx / dim(outer_idx) let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); - - edit.replace_all_uses(tid, div) - } else { - Ok(edit) + edit.sub_edit(tid, div); + edit = edit.replace_all_uses(tid, div)?; } - }); - }; + } + Ok(edit) + }); return new_fork_id; diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 07ad5e7a..49dedd2b 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -110,6 +110,7 @@ impl FromStr for Appliable { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) }, "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), + "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 939ef925..796437a7 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -36,6 +36,7 @@ impl Pass { pub fn num_args(&self) -> usize { match self { Pass::Xdot => 1, + Pass::ForkChunk => 3, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8b71d24e..5740d2a6 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1891,6 +1891,40 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkChunk => { + assert_eq!(args.len(), 3); + let tile_size = args.get(0); + let dim_idx = args.get(1); + + let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { + panic!(); // How to error here? + }; + + let Some(Value::Integer { val: dim_idx }) = args.get(1) else { + panic!(); // How to error here? + }; + + let Some(Value::Integer { val: tile_size }) = args.get(0) else { + panic!(); // How to error here? 
+ }; + + assert_eq!(*guarded_flag, true); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); -- GitLab From 27731fa4fe3553ba1a93a98115636b60a4dd00ce Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Mon, 3 Feb 2025 11:15:56 -0600 Subject: [PATCH 21/25] check for out of bounds dim on chunking --- hercules_opt/src/fork_transforms.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index cbb09bbf..190dbd25 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -725,8 +725,13 @@ pub fn chunk_fork_unguarded( ) -> () { // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) + let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return}; + if dim_idx >= old_factors.len() { + return; // FIXME Error here? + } + let mut new_factors: Vec<_> = old_factors.to_vec(); let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); -- GitLab From 1f1c6cb94c80ac867958c41d2a84d506bf06ef92 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Mon, 3 Feb 2025 11:46:58 -0600 Subject: [PATCH 22/25] rewrite forkify as single edit per loop --- hercules_opt/src/forkify.rs | 75 ++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 38b9aaaa..0a2d5601 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -298,29 +298,11 @@ pub fn forkify_loop( let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); let dimension = factors.len() - 1; - // Create ThreadID - editor.edit(|mut edit| { - let thread_id = Node::ThreadID { - control: fork_id, - dimension: dimension, - }; - let thread_id_id = edit.add_node(thread_id); - - // Replace uses that are inside with the thread id - edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { - loop_nodes.contains(node) - })?; + // Start failable edit: - // Replace uses that are outside with DC - 1. Or just give up. 
- let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id }); - edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| { - !loop_nodes.contains(node) - })?; - - edit.delete_node(canonical_iv.phi()) - }); - - for reduction_phi in reductionable_phis { + let redcutionable_phis_and_init: Vec<(_, NodeID)> = + reductionable_phis.iter().map(|reduction_phi| { + let LoopPHI::Reductionable { phi, data_cycle: _, @@ -342,12 +324,41 @@ pub fn forkify_loop( .unwrap() .1; - editor.edit(|mut edit| { + (reduction_phi, init) + }).collect(); + + editor.edit(|mut edit| { + let thread_id = Node::ThreadID { + control: fork_id, + dimension: dimension, + }; + let thread_id_id = edit.add_node(thread_id); + + // Replace uses that are inside with the thread id + edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { + loop_nodes.contains(node) + })?; + + edit = edit.delete_node(canonical_iv.phi())?; + + for (reduction_phi, init) in redcutionable_phis_and_init { + let LoopPHI::Reductionable { + phi, + data_cycle: _, + continue_latch, + is_associative: _, + } = *reduction_phi + else { + panic!(); + }; + let reduce = Node::Reduce { control: join_id, init, reduct: continue_latch, }; + + let reduce_id = edit.add_node(reduce); if (!edit.get_node(init).is_reduce() @@ -375,20 +386,14 @@ pub fn forkify_loop( edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| { !loop_nodes.contains(usee) && *usee != reduce_id })?; - edit.delete_node(phi) - }); - } - - // Replace all uses of the loop header with the fork - editor.edit(|edit| edit.replace_all_uses(l.header, fork_id)); + edit = edit.delete_node(phi)? - editor.edit(|edit| edit.replace_all_uses(loop_continue_projection, fork_id)); + } - editor.edit(|edit| edit.replace_all_uses(loop_exit_projection, join_id)); + edit = edit.replace_all_uses(l.header, fork_id)?; + edit = edit.replace_all_uses(loop_continue_projection, fork_id)?; + edit = edit.replace_all_uses(loop_exit_projection, join_id)?; - // Get rid of loop condition - // DCE should get these, but delete them ourselves because we are nice :) - editor.edit(|mut edit| { edit = edit.delete_node(loop_continue_projection)?; edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. 
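+        // All of this loop control is dead once the loop becomes a fork-join;
+        // delete it inside the same failable edit so the rewrite lands atomically.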
edit = edit.delete_node(loop_exit_projection)?; @@ -396,7 +401,7 @@ pub fn forkify_loop( edit = edit.delete_node(l.header)?; Ok(edit) }); - + return true; } -- GitLab From 6382ef4263b16f54a8d3b4d5e3a795c9c9e11013 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 3 Feb 2025 15:44:25 -0600 Subject: [PATCH 23/25] fix ups --- Cargo.lock | 11 -- hercules_cg/src/fork_tree.rs | 19 +- hercules_opt/src/fork_guard_elim.rs | 19 +- hercules_opt/src/fork_transforms.rs | 166 +++++++++++------- hercules_opt/src/forkify.rs | 61 ++++--- hercules_samples/dot/build.rs | 6 +- hercules_samples/dot/src/main.rs | 2 +- hercules_samples/matmul/build.rs | 6 +- hercules_samples/matmul/src/main.rs | 6 +- .../hercules_interpreter/src/interpreter.rs | 76 ++++---- .../hercules_interpreter/src/value.rs | 10 +- .../hercules_tests/tests/loop_tests.rs | 40 ++--- juno_frontend/src/semant.rs | 11 +- juno_samples/cava/src/main.rs | 45 ++--- juno_samples/concat/src/main.rs | 4 +- juno_samples/edge_detection/src/main.rs | 11 +- juno_samples/matmul/src/main.rs | 18 +- juno_samples/nested_ccp/src/main.rs | 2 +- juno_samples/patterns/src/main.rs | 2 +- juno_samples/schedule_test/build.rs | 6 +- juno_samples/schedule_test/src/main.rs | 13 +- juno_samples/simple3/src/main.rs | 2 +- juno_scheduler/src/compile.rs | 2 +- juno_scheduler/src/ir.rs | 2 +- juno_scheduler/src/pm.rs | 31 ++-- 25 files changed, 336 insertions(+), 235 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a70825a..af7902c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,17 +1236,6 @@ dependencies = [ "with_builtin_macros", ] -[[package]] -name = "juno_test" -version = "0.1.0" -dependencies = [ - "async-std", - "hercules_rt", - "juno_build", - "rand", - "with_builtin_macros", -] - [[package]] name = "juno_utils" version = "0.1.0" diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs index 64a93160..c048f7e3 100644 --- a/hercules_cg/src/fork_tree.rs +++ b/hercules_cg/src/fork_tree.rs @@ -9,11 +9,16 @@ use crate::*; * c) no domination by any other fork that's also dominated by F, where we do count self-domination * Here too we include the non-fork start node, as key for all controls outside any fork. */ -pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> { +pub fn fork_control_map( + fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, +) -> HashMap<NodeID, HashSet<NodeID>> { let mut fork_control_map = HashMap::new(); for (control, forks) in fork_join_nesting { let fork = forks.first().copied().unwrap_or(NodeID::new(0)); - fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control); + fork_control_map + .entry(fork) + .or_insert_with(HashSet::new) + .insert(*control); } fork_control_map } @@ -24,13 +29,19 @@ pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> Has * c) no domination by any other fork that's also dominated by F, where we don't count self-domination * Note that the fork_tree also includes the non-fork start node, as unique root node. 
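+ * For a fork F, the children of F in the returned map are exactly the forks
+ * whose innermost enclosing fork is F.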
*/ -pub fn fork_tree(function: &Function, fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> { +pub fn fork_tree( + function: &Function, + fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, +) -> HashMap<NodeID, HashSet<NodeID>> { let mut fork_tree = HashMap::new(); for (control, forks) in fork_join_nesting { if function.nodes[control.idx()].is_fork() { fork_tree.entry(*control).or_insert_with(HashSet::new); let nesting_fork = forks.get(1).copied().unwrap_or(NodeID::new(0)); - fork_tree.entry(nesting_fork).or_insert_with(HashSet::new).insert(*control); + fork_tree + .entry(nesting_fork) + .or_insert_with(HashSet::new) + .insert(*control); } } fork_tree diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index f6914b74..df40e60f 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -76,13 +76,16 @@ fn guarded_fork( }; // Filter out any terms which are just 1s - let non_ones = xs.iter().filter(|i| { - if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { - false - } else { - true - } - }).collect::<Vec<_>>(); + let non_ones = xs + .iter() + .filter(|i| { + if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { + false + } else { + true + } + }) + .collect::<Vec<_>>(); // If we're left with just one term x, we had max { 1, x } if non_ones.len() == 1 { Factor::Max(idx, *non_ones[0]) @@ -326,7 +329,7 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node guard_if, guard_join_region, } in guard_info - { + { let Some(guard_pred) = editor.get_uses(guard_if).next() else { unreachable!() }; diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 190dbd25..ed6283fd 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -708,14 +708,13 @@ pub fn chunk_all_forks_unguarded( Ok(edit) }); - for (fork, _ ) in fork_join_map { + for (fork, _) in fork_join_map { chunk_fork_unguarded(editor, *fork, dim_idx, dc_id); } - } -// Splits a dimension of a single fork join into multiple. -// Iterates an outer loop original_dim / tile_size times -// adds a tile_size loop as the inner loop +// Splits a dimension of a single fork join into multiple. +// Iterates an outer loop original_dim / tile_size times +// adds a tile_size loop as the inner loop // Assumes that tile size divides original dim evenly. pub fn chunk_fork_unguarded( editor: &mut FunctionEditor, @@ -724,42 +723,68 @@ pub fn chunk_fork_unguarded( tile_size: DynamicConstantID, ) -> () { // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) - - - let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return}; - - if dim_idx >= old_factors.len() { - return; // FIXME Error here? 
- } - + let Node::Fork { + control: old_control, + factors: ref old_factors, + } = *editor.node(fork) + else { + return; + }; + assert!(dim_idx < old_factors.len()); let mut new_factors: Vec<_> = old_factors.to_vec(); - - let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); + let fork_users: Vec<_> = editor + .get_users(fork) + .map(|f| (f, editor.node(f).clone())) + .collect(); editor.edit(|mut edit| { let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx + 1, tile_size); new_factors[dim_idx] = edit.add_dynamic_constant(outer); - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = Node::Fork { + control: old_control, + factors: new_factors.into(), + }; let new_fork = edit.add_node(new_fork); edit = edit.replace_all_uses(fork, new_fork)?; for (tid, node) in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue}; + let Node::ThreadID { + control: _, + dimension: tid_dim, + } = node + else { + continue; + }; if tid_dim > dim_idx { - let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; + let new_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; let new_tid = edit.add_node(new_tid); edit = edit.replace_all_uses(tid, new_tid)?; + edit = edit.delete_node(tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; + let tile_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; let tile_tid = edit.add_node(tile_tid); - + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); - let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); - let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); - edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?; + let mul = edit.add_node(Node::Binary { + left: tid, + right: tile_size, + op: BinaryOperator::Mul, + }); + let add = edit.add_node(Node::Binary { + left: mul, + right: tile_tid, + op: BinaryOperator::Add, + }); + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?; } } edit = edit.delete_node(fork)?; @@ -767,13 +792,13 @@ pub fn chunk_fork_unguarded( }); } - -pub fn merge_all_fork_dims( - editor: &mut FunctionEditor, - fork_join_map: &HashMap<NodeID, NodeID>, -) { +pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { for (fork, _) in fork_join_map { - let Node::Fork { control: _, factors: dims } = editor.node(fork) else { + let Node::Fork { + control: _, + factors: dims, + } = editor.node(fork) + else { unreachable!(); }; @@ -786,10 +811,6 @@ pub fn merge_all_fork_dims( } } -// Splits a dimension of a single fork join into multiple. -// Iterates an outer loop original_dim / tile_size times -// adds a tile_size loop as the inner loop -// Assumes that tile size divides original dim evenly. 
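+// Merges two dimensions of a single fork-join into one dimension whose factor
+// is the product of the two original factors. ThreadID users are rewritten
+// with Div / Rem against the old outer factor so that every iteration keeps
+// its original coordinates. Returns the ID of the replacement fork.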
pub fn fork_dim_merge( editor: &mut FunctionEditor, fork: NodeID, @@ -806,61 +827,85 @@ pub fn fork_dim_merge( } else { (dim_idx1, dim_idx2) }; - - let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork}; - + let Node::Fork { + control: old_control, + factors: ref old_factors, + } = *editor.node(fork) + else { + return fork; + }; let mut new_factors: Vec<_> = old_factors.to_vec(); - - - let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); - + let fork_users: Vec<_> = editor + .get_users(fork) + .map(|f| (f, editor.node(f).clone())) + .collect(); let mut new_nodes = vec![]; - let outer_dc_id = new_factors[outer_idx]; let inner_dc_id = new_factors[inner_idx]; - - let mut new_fork_id = NodeID::new(0); + let mut new_fork = NodeID::new(0); editor.edit(|mut edit| { - new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); + new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul( + new_factors[outer_idx], + new_factors[inner_idx], + )); new_factors.remove(inner_idx); - - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; - let new_fork = edit.add_node(new_fork); - new_fork_id = new_fork; - + new_fork = edit.add_node(Node::Fork { + control: old_control, + factors: new_factors.into(), + }); edit.sub_edit(fork, new_fork); - edit = edit.replace_all_uses(fork, new_fork)?; edit = edit.delete_node(fork)?; for (tid, node) in fork_users { - // FIXME: DO we want sub edits in this? - - let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue }; + let Node::ThreadID { + control: _, + dimension: tid_dim, + } = node + else { + continue; + }; if tid_dim > inner_idx { - let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; + let new_tid = Node::ThreadID { + control: new_fork_id, + dimension: tid_dim - 1, + }; let new_tid = edit.add_node(new_tid); edit = edit.replace_all_uses(tid, new_tid)?; edit.sub_edit(tid, new_tid); } else if tid_dim == outer_idx { - let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = Node::ThreadID { + control: new_fork_id, + dimension: outer_idx, + }; let outer_tid = edit.add_node(outer_tid); let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); new_nodes.push(outer_tid); // inner_idx % dim(outer_idx) - let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); + let rem = edit.add_node(Node::Binary { + left: outer_tid, + right: outer_dc, + op: BinaryOperator::Rem, + }); edit.sub_edit(tid, rem); edit = edit.replace_all_uses(tid, rem)?; } else if tid_dim == inner_idx { - let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = Node::ThreadID { + control: new_fork_id, + dimension: outer_idx, + }; let outer_tid = edit.add_node(outer_tid); let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); // inner_idx / dim(outer_idx) - let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); + let div = edit.add_node(Node::Binary { + left: outer_tid, + right: outer_dc, + op: BinaryOperator::Div, + }); edit.sub_edit(tid, div); edit = edit.replace_all_uses(tid, div)?; } @@ -868,6 +913,5 @@ pub fn fork_dim_merge( Ok(edit) }); - return new_fork_id; - -} \ No newline at end of file + new_fork +} diff --git a/hercules_opt/src/forkify.rs 
b/hercules_opt/src/forkify.rs index 0a2d5601..f6db06ca 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -300,32 +300,33 @@ pub fn forkify_loop( // Start failable edit: - let redcutionable_phis_and_init: Vec<(_, NodeID)> = - reductionable_phis.iter().map(|reduction_phi| { - - let LoopPHI::Reductionable { - phi, - data_cycle: _, - continue_latch, - is_associative: _, - } = reduction_phi - else { - panic!(); - }; + let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis + .iter() + .map(|reduction_phi| { + let LoopPHI::Reductionable { + phi, + data_cycle: _, + continue_latch, + is_associative: _, + } = reduction_phi + else { + panic!(); + }; - let function = editor.func(); + let function = editor.func(); - let init = *zip( - editor.get_uses(l.header), - function.nodes[phi.idx()].try_phi().unwrap().1.iter(), - ) - .filter(|(c, _)| *c == loop_pred) - .next() - .unwrap() - .1; + let init = *zip( + editor.get_uses(l.header), + function.nodes[phi.idx()].try_phi().unwrap().1.iter(), + ) + .filter(|(c, _)| *c == loop_pred) + .next() + .unwrap() + .1; - (reduction_phi, init) - }).collect(); + (reduction_phi, init) + }) + .collect(); editor.edit(|mut edit| { let thread_id = Node::ThreadID { @@ -351,14 +352,13 @@ pub fn forkify_loop( else { panic!(); }; - + let reduce = Node::Reduce { control: join_id, init, reduct: continue_latch, }; - - + let reduce_id = edit.add_node(reduce); if (!edit.get_node(init).is_reduce() @@ -387,7 +387,6 @@ pub fn forkify_loop( !loop_nodes.contains(usee) && *usee != reduce_id })?; edit = edit.delete_node(phi)? - } edit = edit.replace_all_uses(l.header, fork_id)?; @@ -401,7 +400,7 @@ pub fn forkify_loop( edit = edit.delete_node(l.header)?; Ok(edit) }); - + return true; } @@ -538,7 +537,11 @@ pub fn analyze_phis<'a>( // by the time the reduce is triggered (at the end of the loop's internal control). // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control. // Which is not allowed. 
- if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() { + if intersection + .iter() + .any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) + || editor.node(loop_continue_latch).is_phi() + { return LoopPHI::ControlDependant(*phi); } diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs index 8657fdc1..c8de7e90 100644 --- a/hercules_samples/dot/build.rs +++ b/hercules_samples/dot/build.rs @@ -4,7 +4,11 @@ fn main() { JunoCompiler::new() .ir_in_src("dot.hir") .unwrap() - .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) .unwrap() .build() .unwrap(); diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 8862c11a..7f5b453a 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -1,8 +1,8 @@ #![feature(concat_idents)] -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("dot"); diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs index 735458c0..ed92e022 100644 --- a/hercules_samples/matmul/build.rs +++ b/hercules_samples/matmul/build.rs @@ -4,7 +4,11 @@ fn main() { JunoCompiler::new() .ir_in_src("matmul.hir") .unwrap() - .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) .unwrap() .build() .unwrap(); diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index abd25ec9..5c879915 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -2,9 +2,9 @@ use rand::random; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("matmul"); @@ -36,7 +36,9 @@ fn main() { let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await; + let c = r + .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) + .await; let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); assert_eq!(&*c_cpu, &*correct_c); diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 22ef062a..2e352644 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -69,18 +69,18 @@ pub fn dyn_const_value( match dc { DynamicConstant::Constant(v) => *v, DynamicConstant::Parameter(v) => dyn_const_params[*v], - DynamicConstant::Add(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(0, |s, v| s + v) - } + DynamicConstant::Add(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(0, |s, v| s + v), DynamicConstant::Sub(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) - dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Mul(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(1, |p, v| p * v) - } + 
DynamicConstant::Mul(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(1, |p, v| p * v), DynamicConstant::Div(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) / dyn_const_value(b, dyn_const_values, dyn_const_params) @@ -89,28 +89,28 @@ pub fn dyn_const_value( dyn_const_value(a, dyn_const_values, dyn_const_params) % dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Max(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(None, |m, v| { - if let Some(m) = m { - Some(max(m, v)) - } else { - Some(v) - } - }) - .unwrap() - } - DynamicConstant::Min(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(None, |m, v| { - if let Some(m) = m { - Some(min(m, v)) - } else { - Some(v) - } - }) - .unwrap() - } + DynamicConstant::Max(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(max(m, v)) + } else { + Some(v) + } + }) + .unwrap(), + DynamicConstant::Min(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(min(m, v)) + } else { + Some(v) + } + }) + .unwrap(), } } @@ -775,15 +775,13 @@ impl<'a> FunctionExecutionState<'a> { // panic!("multi-dimensional forks unimplemented") // } - let factors = factors - .iter() - .map(|f| { - dyn_const_value( - &f, - &self.module.dynamic_constants, - &self.dynamic_constant_params, - ) - }); + let factors = factors.iter().map(|f| { + dyn_const_value( + &f, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }); let n_tokens: usize = factors.clone().product(); diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index adbed6e6..4a802f7a 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -156,7 +156,15 @@ impl<'a> InterpreterVal { Constant::Float64(v) => Self::Float64(v), Constant::Product(ref type_id, ref constant_ids) => { - let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params)); + let contents = constant_ids.iter().map(|const_id| { + InterpreterVal::from_constant( + &constants[const_id.idx()], + constants, + types, + dynamic_constants, + dynamic_constant_params, + ) + }); InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice()) } Constant::Summation(_, _, _) => todo!(), diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 192c1366..795642b2 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -35,9 +35,7 @@ fn alternate_bounds_use_after_loop_no_tid() { println!("result: {:?}", result_1); - let schedule = default_schedule![ - Forkify, - ]; + let schedule = default_schedule![Forkify,]; let module = run_schedule_on_hercules(module, Some(schedule)).unwrap(); @@ -61,9 +59,7 @@ fn alternate_bounds_use_after_loop() { println!("result: {:?}", result_1); - let schedule = Some(default_schedule![ - Forkify, - ]); + let schedule = Some(default_schedule![Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -108,10 +104,7 @@ fn do_while_separate_body() { println!("result: {:?}", result_1); - let schedule = 
Some(default_schedule![ - PhiElim, - Forkify, - ]); + let schedule = Some(default_schedule![PhiElim, Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -131,10 +124,7 @@ fn alternate_bounds_internal_control() { println!("result: {:?}", result_1); - let schedule = Some(default_schedule![ - PhiElim, - Forkify, - ]); + let schedule = Some(default_schedule![PhiElim, Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -155,10 +145,7 @@ fn alternate_bounds_internal_control2() { println!("result: {:?}", result_1); - let schedule = Some(default_schedule![ - PhiElim, - Forkify, - ]); + let schedule = Some(default_schedule![PhiElim, Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -366,16 +353,13 @@ fn look_at_local() { "/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin", ); - let schedule = Some(default_schedule![ - ]); + let schedule = Some(default_schedule![]); let result_1 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone()); let module = run_schedule_on_hercules(module.clone(), schedule).unwrap(); - let schedule = Some(default_schedule![ - Unforkify, Verify, - ]); + let schedule = Some(default_schedule![Unforkify, Verify,]); let module = run_schedule_on_hercules(module.clone(), schedule).unwrap(); @@ -425,7 +409,15 @@ fn matmul_pipeline() { }; assert_eq!(correct_c[0], value); - let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]); + let schedule = Some(default_schedule![ + AutoOutline, + InterproceduralSROA, + SROA, + InferSchedules, + DCE, + Xdot, + GCM + ]); module = run_schedule_on_hercules(module, schedule).unwrap(); diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index e133e3c2..8668d1b4 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -752,7 +752,16 @@ fn analyze_program( } arg_info.push((ty, inout.is_some(), var)); - match process_irrefutable_pattern(pattern, false, var, ty, lexer, &mut stringtab, &mut env, &mut types) { + match process_irrefutable_pattern( + pattern, + false, + var, + ty, + lexer, + &mut stringtab, + &mut env, + &mut types, + ) { Ok(prep) => { stmts.extend(prep); } diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 482bbf8d..e8a7e4e9 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -8,9 +8,9 @@ use self::camera_model::*; use self::cava_rust::CHAN; use self::image_proc::*; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; use image::ImageError; @@ -31,7 +31,6 @@ fn run_cava( coefs: &[f32], tonemap: &[f32], ) -> Box<[u8]> { - assert_eq!(image.len(), CHAN * rows * cols); assert_eq!(tstw.len(), CHAN * CHAN); assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN); @@ -47,21 +46,24 @@ fn run_cava( let weights = HerculesCPURef::from_slice(weights); let coefs = HerculesCPURef::from_slice(coefs); let tonemap = HerculesCPURef::from_slice(tonemap); - let mut r = runner!(cava); - async_std::task::block_on(async { - r.run( - rows as u64, - cols as u64, - num_ctrl_pts as u64, - image, - tstw, - ctrl_pts, - weights, - coefs, - tonemap, - ) - .await - }).as_slice::<u8>().to_vec().into_boxed_slice() + let mut r = runner!(cava); + async_std::task::block_on(async { + r.run( + rows as u64, + cols as u64, + num_ctrl_pts as u64, + image, + tstw, + ctrl_pts, + weights, + coefs, + tonemap, + ) + .await + }) + 
.as_slice::<u8>() + .to_vec() + .into_boxed_slice() } #[cfg(feature = "cuda")] @@ -72,8 +74,8 @@ fn run_cava( let weights = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(weights)); let coefs = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(coefs)); let tonemap = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(tonemap)); - let mut r = runner!(cava); - let res = async_std::task::block_on(async { + let mut r = runner!(cava); + let res = async_std::task::block_on(async { r.run( rows as u64, cols as u64, @@ -86,7 +88,7 @@ fn run_cava( tonemap.get_ref(), ) .await - }); + }); let num_out = unsafe { res.__size() / std::mem::size_of::<u8>() }; let mut res_cpu: Box<[u8]> = vec![0; num_out].into_boxed_slice(); res.to_cpu_ref(&mut res_cpu); @@ -204,7 +206,8 @@ fn cava_harness(args: CavaInputs) { .expect("Error saving verification image"); } - let max_diff = result.iter() + let max_diff = result + .iter() .zip(cpu_result.iter()) .map(|(a, b)| (*a as i16 - *b as i16).abs()) .max() diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index 9674c2c5..547dee08 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -1,9 +1,9 @@ #![feature(concat_idents)] use hercules_rt::runner; -use hercules_rt::HerculesCPURef; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::HerculesCPURef; juno_build::juno!("concat"); @@ -20,7 +20,7 @@ fn main() { assert_eq!(output, 42); const N: usize = 3; - let arr : Box<[i32]> = (2..=4).collect(); + let arr: Box<[i32]> = (2..=4).collect(); let arr = HerculesCPURef::from_slice(&arr); let mut r = runner!(concat_switch); diff --git a/juno_samples/edge_detection/src/main.rs b/juno_samples/edge_detection/src/main.rs index eda65016..3b067ebd 100644 --- a/juno_samples/edge_detection/src/main.rs +++ b/juno_samples/edge_detection/src/main.rs @@ -2,9 +2,9 @@ mod edge_detection_rust; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; use std::slice::from_raw_parts; @@ -228,9 +228,9 @@ fn edge_detection_harness(args: EdgeDetectionInputs) { }); #[cfg(not(feature = "cuda"))] - let result : Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice(); + let result: Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice(); #[cfg(feature = "cuda")] - let result : Box<[f32]> = { + let result: Box<[f32]> = { let num_out = unsafe { result.__size() / std::mem::size_of::<f32>() }; let mut res_cpu: Box<[f32]> = vec![0.0; num_out].into_boxed_slice(); result.to_cpu_ref(&mut res_cpu); @@ -261,7 +261,10 @@ fn edge_detection_harness(args: EdgeDetectionInputs) { theta, ); - assert_eq!(result.as_ref(), <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result)); + assert_eq!( + result.as_ref(), + <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result) + ); println!("Frames {} match", i); if display_verify { diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 50fe1760..2892cd34 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -2,9 +2,9 @@ use rand::random; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("matmul"); @@ -28,10 +28,14 @@ fn main() { let a = HerculesCPURef::from_slice(&a); let b = HerculesCPURef::from_slice(&b); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let c = r + .run(I as u64, J as u64, K as u64, 
a.clone(), b.clone()) + .await; assert_eq!(c.as_slice::<i32>(), &*correct_c); let mut r = runner!(tiled_64_matmul); - let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let tiled_c = r + .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) + .await; assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); } #[cfg(feature = "cuda")] @@ -39,12 +43,16 @@ fn main() { let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await; + let c = r + .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) + .await; let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); assert_eq!(&*c_cpu, &*correct_c); let mut r = runner!(tiled_64_matmul); - let tiled_c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await; + let tiled_c = r + .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) + .await; let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); tiled_c.to_cpu_ref(&mut tiled_c_cpu); assert_eq!(&*tiled_c_cpu, &*correct_c); diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs index bc99a4bd..b364c03c 100644 --- a/juno_samples/nested_ccp/src/main.rs +++ b/juno_samples/nested_ccp/src/main.rs @@ -1,8 +1,8 @@ #![feature(concat_idents)] -use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut}; juno_build::juno!("nested_ccp"); diff --git a/juno_samples/patterns/src/main.rs b/juno_samples/patterns/src/main.rs index 5cc2e7c8..a5586c8b 100644 --- a/juno_samples/patterns/src/main.rs +++ b/juno_samples/patterns/src/main.rs @@ -1,6 +1,6 @@ #![feature(concat_idents)] -use hercules_rt::{runner}; +use hercules_rt::runner; juno_build::juno!("patterns"); diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs index 749a660c..0129c4de 100644 --- a/juno_samples/schedule_test/build.rs +++ b/juno_samples/schedule_test/build.rs @@ -4,7 +4,11 @@ fn main() { JunoCompiler::new() .file_in_src("code.jn") .unwrap() - .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) .unwrap() .build() .unwrap(); diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs index 1505d4e5..f769e750 100644 --- a/juno_samples/schedule_test/src/main.rs +++ b/juno_samples/schedule_test/src/main.rs @@ -2,9 +2,9 @@ use rand::random; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("code"); @@ -43,7 +43,16 @@ fn main() { let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); let c = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&c)); let mut r = runner!(test); - let res = r.run(N as u64, M as u64, K as u64, a.get_ref(), b.get_ref(), c.get_ref()).await; + let res = r + .run( + N as u64, + M as u64, + K as u64, + a.get_ref(), + b.get_ref(), + c.get_ref(), + ) + .await; let mut res_cpu: Box<[i32]> = vec![0; correct_res.len()].into_boxed_slice(); res.to_cpu_ref(&mut res_cpu); assert_eq!(&*res_cpu, &*correct_res); diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs index 8eb78f7c..687ff414 100644 
--- a/juno_samples/simple3/src/main.rs +++ b/juno_samples/simple3/src/main.rs @@ -1,8 +1,8 @@ #![feature(concat_idents)] -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("simple3"); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index ea06a0f2..713c30d4 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -108,7 +108,7 @@ impl FromStr for Appliable { "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) - }, + } "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 796437a7..9e85509f 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -36,7 +36,7 @@ impl Pass { pub fn num_args(&self) -> usize { match self { Pass::Xdot => 1, - Pass::ForkChunk => 3, + Pass::ForkChunk => 3, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d176b636..2142d5c5 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1566,7 +1566,7 @@ fn run_pass( // this eventually. let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest); changed |= c; - inner_changed |= c; + inner_changed |= c; } pm.delete_gravestones(); pm.clear_analyses(); @@ -1921,24 +1921,32 @@ fn run_pass( let dim_idx = args.get(1); let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { - panic!(); // How to error here? + return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected boolean argument".to_string(), + }); }; let Some(Value::Integer { val: dim_idx }) = args.get(1) else { - panic!(); // How to error here? + return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected integer argument".to_string(), + }); }; let Some(Value::Integer { val: tile_size }) = args.get(0) else { - panic!(); // How to error here? 
+                    return Err(SchedulerError::PassError {
+                        pass: "forkChunk".to_string(),
+                        error: "expected integer argument".to_string(),
+                    });
                 };
 
                 assert_eq!(*guarded_flag, true);
                 pm.make_fork_join_maps();
                 let fork_join_maps = pm.fork_join_maps.take().unwrap();
-                for (func, fork_join_map) in
-                    build_selection(pm, selection)
-                        .into_iter()
-                        .zip(fork_join_maps.iter())
+                for (func, fork_join_map) in build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
                 {
                     let Some(mut func) = func else {
                         continue;
                     };
                     chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size);
                     changed |= func.modified();
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
             }
             Pass::ForkDimMerge => {
                 assert!(args.is_empty());
                 pm.make_fork_join_maps();
                 let fork_join_maps = pm.fork_join_maps.take().unwrap();
-                for (func, fork_join_map) in
-                    build_selection(pm, selection)
-                        .into_iter()
-                        .zip(fork_join_maps.iter())
+                for (func, fork_join_map) in build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
                 {
                     let Some(mut func) = func else {
                         continue;
--
GitLab


From 43404780896ebb6bd287c8a56783a08cded6852e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 3 Feb 2025 15:53:31 -0600
Subject: [PATCH 24/25] whoops

---
 hercules_opt/src/fork_transforms.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ed6283fd..c4a6ba7f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -868,7 +868,7 @@ pub fn fork_dim_merge(
         };
         if tid_dim > inner_idx {
             let new_tid = Node::ThreadID {
-                control: new_fork_id,
+                control: new_fork,
                 dimension: tid_dim - 1,
             };
             let new_tid = edit.add_node(new_tid);
@@ -876,7 +876,7 @@ pub fn fork_dim_merge(
             edit.sub_edit(tid, new_tid);
         } else if tid_dim == outer_idx {
             let outer_tid = Node::ThreadID {
-                control: new_fork_id,
+                control: new_fork,
                 dimension: outer_idx,
             };
             let outer_tid = edit.add_node(outer_tid);
@@ -894,7 +894,7 @@ pub fn fork_dim_merge(
             edit = edit.replace_all_uses(tid, rem)?;
         } else if tid_dim == inner_idx {
             let outer_tid = Node::ThreadID {
-                control: new_fork_id,
+                control: new_fork,
                 dimension: outer_idx,
             };
             let outer_tid = edit.add_node(outer_tid);
--
GitLab


From 13bc8938b80f96711eb26738fec53044a7800136 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 4 Feb 2025 09:52:26 -0600
Subject: [PATCH 25/25] use editor modified() for forkify

---
 juno_scheduler/src/pm.rs | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 80699fee..b2845913 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1561,12 +1561,9 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                // TODO: uses direct return from forkify for now instead of
-                // func.modified, see comment on top of `forkify` for why. Fix
-                // this eventually.
-                let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
-                changed |= c;
-                inner_changed |= c;
+                forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+                changed |= func.modified();
+                inner_changed |= func.modified();
             }
             pm.delete_gravestones();
             pm.clear_analyses();
--
GitLab
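
A quick standalone check of the thread-index arithmetic that fork-chunk and
fork-dim-merge generate, written as plain integer loops. This sketch is not
part of the patch series; the names and sizes are illustrative. fork-chunk
rebuilds the original index as outer * tile + inner (the Mul/Add nodes in
chunk_fork_unguarded), and fork-dim-merge recovers the two original
coordinates from the merged index with Rem/Div against the old outer factor
(the nodes built in fork_dim_merge).

fn main() {
    // fork-chunk: a dimension of factor n becomes (n / tile, tile).
    let (n, tile) = (12usize, 4usize);
    assert_eq!(n % tile, 0); // the pass assumes tile divides n evenly
    for outer in 0..n / tile {
        for inner in 0..tile {
            let orig = outer * tile + inner; // the Mul-then-Add rewrite of the old tid
            assert!(orig < n);
        }
    }

    // fork-dim-merge: dims (o_dim, i_dim) become one dim of factor o_dim * i_dim.
    let (o_dim, i_dim) = (6usize, 4usize);
    for merged in 0..o_dim * i_dim {
        let o = merged % o_dim; // the Rem node built for tid_dim == outer_idx
        let i = merged / o_dim; // the Div node built for tid_dim == inner_idx
        assert!(o < o_dim && i < i_dim);
    }
}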