From 6a0c4a3410d40410dbdac7a92926583e9c6e640a Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sat, 1 Feb 2025 15:04:24 -0600 Subject: [PATCH 01/33] control in reduce cycle fixes --- Cargo.lock | 11 ++++ Cargo.toml | 3 +- hercules_opt/src/forkify.rs | 12 +++++ hercules_opt/src/unforkify.rs | 2 +- .../hercules_tests/tests/loop_tests.rs | 4 +- juno_scheduler/src/pm.rs | 53 +++++++++++-------- 6 files changed, 59 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 49630436..ad69bc72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1181,6 +1181,17 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_test" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "rand", + "with_builtin_macros", +] + [[package]] name = "juno_utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index ced011a9..46fc7eaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ members = [ "hercules_samples/ccp", "juno_samples/simple3", - "juno_samples/patterns", + "juno_samples/patterns", "juno_samples/matmul", "juno_samples/casts_and_intrinsics", "juno_samples/nested_ccp", @@ -30,4 +30,5 @@ members = [ "juno_samples/cava", "juno_samples/concat", "juno_samples/schedule_test", + "juno_samples/test", ] diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index ec4e9fbc..0f06627d 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -152,6 +152,7 @@ pub fn forkify_loop( .filter(|id| !l.control[id.idx()]) .collect(); + // FIXME: @xrouth if loop_preds.len() != 1 { return false; } @@ -388,6 +389,7 @@ nest! { is_associative: bool, }, LoopDependant(NodeID), + ControlDependant(NodeID), // This phi is reductionable, but its cycle might depend on internal control within the loop. UsedByDependant(NodeID), } } @@ -398,6 +400,7 @@ impl LoopPHI { LoopPHI::Reductionable { phi, .. } => *phi, LoopPHI::LoopDependant(node_id) => *node_id, LoopPHI::UsedByDependant(node_id) => *node_id, + LoopPHI::ControlDependant(node_id) => *node_id, } } } @@ -415,6 +418,9 @@ pub fn analyze_phis<'a>( loop_nodes: &'a HashSet<NodeID>, ) -> impl Iterator<Item = LoopPHI> + 'a { + // We are also moving the phi from the top of the loop (the header), + // to the very end (the join). If there are uses of the phi somewhere in the loop, + // then they may try to use the phi (now a reduce) before it hits the join. // Find data cycles within the loop of this phi, // Start from the phi's loop_continue_latch, and walk its uses until we find the original phi. @@ -509,6 +515,12 @@ pub fn analyze_phis<'a>( // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). + // If anything in the intersection is a phi (that isn't the phi itself), then the reduction cycle depends on control. + // Which is not allowed. + if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() { + return LoopPHI::ControlDependant(*phi); + } + // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
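// For example (a hypothetical sketch): if `acc` cycles through // `tmp = acc + x` and `tmp` is also consumed after the loop, only the // reduce's final value exists outside the fork-join, so `tmp` has no legal // replacement once the cycle becomes a reduce.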
if intersection diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 85ffd233..7d158d1a 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -133,7 +133,7 @@ pub fn unforkify( if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. - continue; + break; // Because we have to unforkify top down, we can't unforkify forks that are contained inside other forks } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 5832a161..192c1366 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -401,7 +401,7 @@ fn matmul_pipeline() { let dyn_consts = [I, J, K]; // FIXME: This path should not leave the crate - let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin"); + let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin"); // let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { @@ -425,7 +425,7 @@ fn matmul_pipeline() { }; assert_eq!(correct_c[0], value); - let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]); + let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]); module = run_schedule_on_hercules(module, schedule).unwrap(); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2371e0f2..d2772c71 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1471,29 +1471,38 @@ fn run_pass( } Pass::Forkify => { assert!(args.is_empty()); - pm.make_fork_join_maps(); - pm.make_control_subgraphs(); - pm.make_loops(); - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); - let control_subgraphs = pm.control_subgraphs.take().unwrap(); - for (((func, fork_join_map), loop_nest), control_subgraph) in - build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - .zip(control_subgraphs.iter()) - { - let Some(mut func) = func else { - continue; - }; - // TODO: uses direct return from forkify for now instead of - // func.modified, see comment on top of `forkify` for why. Fix - // this eventually. - changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest); + loop { + let mut inner_changed = false; + pm.make_fork_join_maps(); + pm.make_control_subgraphs(); + pm.make_loops(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + for (((func, fork_join_map), loop_nest), control_subgraph) in + build_selection(pm, selection.clone()) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(control_subgraphs.iter()) + { + let Some(mut func) = func else { + continue; + }; + // TODO: uses direct return from forkify for now instead of + // func.modified, see comment on top of `forkify` for why. Fix + // this eventually.
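+ // Iterate to a fixpoint: forkify may convert only some of a + // function's loops per call (see the TODO above), and converting an + // inner loop can expose its parent as a new forkification candidate, + // so rebuild the analyses and retry until nothing changes.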
+ let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest); + changed |= c; + inner_changed |= c; + } + pm.delete_gravestones(); + pm.clear_analyses(); + + if !inner_changed { + break; + } } - pm.delete_gravestones(); - pm.clear_analyses(); } Pass::GCM => { assert!(args.is_empty()); -- GitLab From fab913636e8b63cef26af9db6df2eb699f415161 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sat, 1 Feb 2025 16:30:10 -0600 Subject: [PATCH 02/33] misc --- hercules_opt/src/forkify.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 0f06627d..299422c1 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -514,7 +514,6 @@ pub fn analyze_phis<'a>( // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). - // If anything in the intersection is a phi (that isn't the phi itself), then the reduction cycle depends on control. // Which is not allowed. if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() { -- GitLab From 43b4022c38f65fcf5c2abad523c63f59888dabe9 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 18:29:52 -0600 Subject: [PATCH 03/33] simple test --- Cargo.lock | 10 +++++++ Cargo.toml | 1 + juno_samples/fork_join_tests/Cargo.toml | 21 +++++++++++++ juno_samples/fork_join_tests/build.rs | 24 +++++++++++++++ .../fork_join_tests/src/fork_join_tests.jn | 10 +++++++ juno_samples/fork_join_tests/src/gpu.sch | 30 +++++++++++++++++++ juno_samples/fork_join_tests/src/main.rs | 17 +++++++++++ 7 files changed, 113 insertions(+) create mode 100644 juno_samples/fork_join_tests/Cargo.toml create mode 100644 juno_samples/fork_join_tests/build.rs create mode 100644 juno_samples/fork_join_tests/src/fork_join_tests.jn create mode 100644 juno_samples/fork_join_tests/src/gpu.sch create mode 100644 juno_samples/fork_join_tests/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index b8bf2278..af7902c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1130,6 +1130,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_fork_join_tests" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_frontend" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f7b9322a..890d7924 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,4 +31,5 @@ members = [ "juno_samples/concat", "juno_samples/schedule_test", "juno_samples/edge_detection", + "juno_samples/fork_join_tests", ] diff --git a/juno_samples/fork_join_tests/Cargo.toml b/juno_samples/fork_join_tests/Cargo.toml new file mode 100644 index 00000000..a109e782 --- /dev/null +++ b/juno_samples/fork_join_tests/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "juno_fork_join_tests" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_fork_join_tests" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*"
diff --git a/juno_samples/fork_join_tests/build.rs b/juno_samples/fork_join_tests/build.rs new file mode 100644 index 00000000..796e9f32 --- /dev/null +++ b/juno_samples/fork_join_tests/build.rs @@ -0,0 +1,24 @@ +use juno_build::JunoCompiler; + +fn main() { + #[cfg(not(feature = "cuda"))] + { + JunoCompiler::new() + .file_in_src("fork_join_tests.jn") + .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() + .build() + .unwrap(); + } + #[cfg(feature = "cuda")] + { + JunoCompiler::new() + .file_in_src("fork_join_tests.jn") + .unwrap() + .schedule_in_src("gpu.sch") + .unwrap() + .build() + .unwrap(); + } +} diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn new file mode 100644 index 00000000..aa8eb4bb --- /dev/null +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -0,0 +1,10 @@ +#[entry] +fn test1(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + arr[i, j] = input; + } + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch new file mode 100644 index 00000000..e2fe980e --- /dev/null +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -0,0 +1,30 @@ +gvn(*); +phi-elim(*); +dce(*); + +let out = auto-outline(*); +gpu(out.test1); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} + +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + infer-schedules(*); +} +xdot[true](*); + +gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs new file mode 100644 index 00000000..6e5f2182 --- /dev/null +++ b/juno_samples/fork_join_tests/src/main.rs @@ -0,0 +1,17 @@ +#![feature(concat_idents)] + +use hercules_rt::runner; + +juno_build::juno!("fork_join_tests"); + +fn main() { + async_std::task::block_on(async { + let mut r = runner!(tests1); + let output = r.run(5).await; + }); +} + +#[test] +fn implicit_clone_test() { + main(); +} -- GitLab From 5eeb70cb675bdd7ac2cf4ac4887f803d24ebe165 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 18:52:57 -0600 Subject: [PATCH 04/33] fixes for test --- hercules_cg/src/gpu.rs | 8 +++--- juno_samples/fork_join_tests/src/cpu.sch | 31 ++++++++++++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 1 - juno_samples/fork_join_tests/src/main.rs | 13 +++++++++- juno_scheduler/src/pm.rs | 2 ++ 5 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 juno_samples/fork_join_tests/src/cpu.sch diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 81e31396..afc016a4 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -539,15 +539,17 @@ namespace cg = cooperative_groups; w: &mut String, ) -> Result<(), Error> { write!(w, "\n")?; - for (id, goto) in gotos.iter() { - let goto_block = self.get_block_name(*id, false); + let rev_po = self.control_subgraph.rev_po(NodeID::new(0)); + for id in rev_po { + let goto = &gotos[&id]; + let goto_block = self.get_block_name(id, false); write!(w, "{}:\n", goto_block)?; if goto_debug { write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?; } write!(w, "{}", goto.init)?; if !goto.post_init.is_empty() { - let goto_block = self.get_block_name(*id, true); + let goto_block = self.get_block_name(id, true); write!(w, "{}:\n", goto_block)?; write!(w, "{}", goto.post_init)?; } diff --git a/juno_samples/fork_join_tests/src/cpu.sch 
b/juno_samples/fork_join_tests/src/cpu.sch new file mode 100644 index 00000000..81f5a12c --- /dev/null +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -0,0 +1,31 @@ +gvn(*); +phi-elim(*); +dce(*); + +let out = auto-outline(*); +cpu(out.test1); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} + +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + infer-schedules(*); +} +fork-split(*); +unforkify(*); + +gcm(*); diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index e2fe980e..e4e4e04f 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -25,6 +25,5 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } -xdot[true](*); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 6e5f2182..a63b3f78 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -6,8 +6,19 @@ juno_build::juno!("fork_join_tests"); fn main() { async_std::task::block_on(async { - let mut r = runner!(tests1); + let mut r = runner!(test1); let output = r.run(5).await; + let correct = vec![5i32; 16]; + #[cfg(not(feature = "cuda"))] + { + assert_eq!(output.as_slice::<i32>(), &correct); + } + #[cfg(feature = "cuda")] + { + let mut dst = vec![0i32; 16]; + let output = output.to_cpu_ref(&mut dst); + assert_eq!(output.as_slice::<i32>(), &correct); + } }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f6fe2fc1..5d398804 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -575,6 +575,8 @@ impl PassManager { self.postdoms = None; self.fork_join_maps = None; self.fork_join_nests = None; + self.fork_control_maps = None; + self.fork_trees = None; self.loops = None; self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; -- GitLab From d67263e6a871fa539067ec84e87c58f242809ea4 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 10:50:56 -0600 Subject: [PATCH 05/33] second fork test, fails on cpu w/ forkify/unforkify --- hercules_opt/src/float_collections.rs | 27 ++++++++++++------ juno_samples/fork_join_tests/src/cpu.sch | 1 + .../fork_join_tests/src/fork_join_tests.jn | 13 +++++++++ juno_samples/fork_join_tests/src/gpu.sch | 2 ++ juno_samples/fork_join_tests/src/main.rs | 28 ++++++++++++------- juno_scheduler/src/pm.rs | 17 +++++------ 6 files changed, 60 insertions(+), 28 deletions(-) diff --git a/hercules_opt/src/float_collections.rs b/hercules_opt/src/float_collections.rs index faa38375..6ef050c2 100644 --- a/hercules_opt/src/float_collections.rs +++ b/hercules_opt/src/float_collections.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use hercules_ir::*; use crate::*; @@ -7,27 +9,36 @@ use crate::*; * allowed. */ pub fn float_collections( - editors: &mut [FunctionEditor], + editors: &mut BTreeMap<FunctionID, FunctionEditor>, typing: &ModuleTyping, callgraph: &CallGraph, devices: &Vec<Device>, ) { - let topo = callgraph.topo(); + let topo: Vec<_> = callgraph + .topo() + .into_iter() + .filter(|id| editors.contains_key(&id)) + .collect(); for to_float_id in topo { // Collection constants float until reaching an AsyncRust function. if devices[to_float_id.idx()] == Device::AsyncRust { continue; } + // Check that all callers are in the selection as well. 
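+ // (Floating a collection constant out of a callee turns it into an + // extra parameter of that callee, so every caller must also be + // editable to build the constant and pass it at each call site - that + // is exactly the parameter/call rewriting performed below.)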
+ for caller in callgraph.get_callers(to_float_id) { + assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller); + } + // Find the target constant nodes in the function. - let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()] + let cons: Vec<(NodeID, Node)> = editors[&to_float_id] .func() .nodes .iter() .enumerate() .filter(|(_, node)| { node.try_constant() - .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar()) + .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar()) .unwrap_or(false) }) .map(|(idx, node)| (NodeID::new(idx), node.clone())) @@ -37,12 +48,12 @@ pub fn float_collections( } // Each constant node becomes a new parameter. - let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone(); + let mut new_param_types = editors[&to_float_id].func().param_types.clone(); let old_num_params = new_param_types.len(); for (id, _) in cons.iter() { new_param_types.push(typing[to_float_id.idx()][id.idx()]); } - let success = editors[to_float_id.idx()].edit(|mut edit| { + let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| { for (idx, (id, _)) in cons.iter().enumerate() { let param = edit.add_node(Node::Parameter { index: idx + old_num_params, @@ -59,7 +70,7 @@ pub fn float_collections( // Add constants in callers and pass them into calls. for caller in callgraph.get_callers(to_float_id) { - let calls: Vec<(NodeID, Node)> = editors[caller.idx()] + let calls: Vec<(NodeID, Node)> = editors[&caller] .func() .nodes .iter() @@ -71,7 +82,7 @@ pub fn float_collections( }) .map(|(idx, node)| (NodeID::new(idx), node.clone())) .collect(); - let success = editors[caller.idx()].edit(|mut edit| { + let success = editors.get_mut(&caller).unwrap().edit(|mut edit| { let cons_ids: Vec<_> = cons .iter() .map(|(_, node)| edit.add_node(node.clone())) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 81f5a12c..a6b1afe7 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -4,6 +4,7 @@ dce(*); let out = auto-outline(*); cpu(out.test1); +cpu(out.test2); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index aa8eb4bb..4a6a94c9 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -8,3 +8,16 @@ fn test1(input : i32) -> i32[4, 4] { } return arr; } + +#[entry] +fn test2(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 8 { + for j = 0 to 4 { + for k = 0 to 4 { + arr[j, k] += input; + } + } + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index e4e4e04f..b506c4a4 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -4,6 +4,7 @@ dce(*); let out = auto-outline(*); gpu(out.test1); +gpu(out.test2); ip-sroa(*); sroa(*); @@ -26,4 +27,5 @@ fixpoint panic after 20 { infer-schedules(*); } +float-collections(test2, out.test2); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index a63b3f78..de1b0805 100644 --- 
a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -5,20 +5,28 @@ use hercules_rt::runner; juno_build::juno!("fork_join_tests"); fn main() { + #[cfg(not(feature = "cuda"))] + let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| { + assert_eq!(output.as_slice::<i32>(), &correct); + }; + + #[cfg(feature = "cuda")] + let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| { + let mut dst = vec![0i32; 16]; + let output = output.to_cpu_ref(&mut dst); + assert_eq!(output.as_slice::<i32>(), &correct); + }; + async_std::task::block_on(async { let mut r = runner!(test1); let output = r.run(5).await; let correct = vec![5i32; 16]; - #[cfg(not(feature = "cuda"))] - { - assert_eq!(output.as_slice::<i32>(), &correct); - } - #[cfg(feature = "cuda")] - { - let mut dst = vec![0i32; 16]; - let output = output.to_cpu_ref(&mut dst); - assert_eq!(output.as_slice::<i32>(), &correct); - } + assert(correct, output); + + let mut r = runner!(test2); + let output = r.run(3).await; + let correct = vec![24i32; 16]; + assert(correct, output); }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 5d398804..1ebc885c 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1305,7 +1305,7 @@ fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, - selection: Option<Vec<CodeLocation>>, + mut selection: Option<Vec<CodeLocation>>, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), @@ -1441,13 +1441,6 @@ fn run_pass( } Pass::FloatCollections => { assert!(args.is_empty()); - if let Some(_) = selection { - return Err(SchedulerError::PassError { - pass: "floatCollections".to_string(), - error: "must be applied to the entire module".to_string(), - }); - } - pm.make_typing(); pm.make_callgraph(); pm.make_devices(); @@ -1455,11 +1448,15 @@ fn run_pass( let callgraph = pm.callgraph.take().unwrap(); let devices = pm.devices.take().unwrap(); - let mut editors = build_editors(pm); + // Modify the selection to include callers of selected functions. + let mut editors = build_selection(pm, selection) + .into_iter() + .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor))) + .collect(); float_collections(&mut editors, &typing, &callgraph, &devices); for func in editors { - changed |= func.modified(); + changed |= func.1.modified(); } pm.delete_gravestones(); -- GitLab From ba614011066dcde2305a649f7911368d99f95a84 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 11:12:34 -0600 Subject: [PATCH 06/33] fix reduce cycles --- hercules_ir/src/fork_join_analysis.rs | 34 ++++++++++++------------ juno_samples/fork_join_tests/src/cpu.sch | 9 +++++++ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 263fa952..7a098a35 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -140,23 +140,23 @@ fn reduce_cycle_dfs_helper( } current_visited.insert(iter); - let found_reduce = get_uses(&function.nodes[iter.idx()]) - .as_ref() - .into_iter() - .any(|u| { - !current_visited.contains(u) - && !function.nodes[u.idx()].is_control() - && isnt_outside_fork_join(*u) - && reduce_cycle_dfs_helper( - function, - *u, - fork, - reduce, - current_visited, - in_cycle, - fork_join_nest, - ) - }); + let mut found_reduce = false; + + // This doesn't short circuit on purpose. 
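+ // A short-circuiting walk would stop at the first use that reaches the + // reduce, so uses explored after it would never be visited and nodes on + // those paths would be missing from `in_cycle`.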
+ for u in get_uses(&function.nodes[iter.idx()]).as_ref() { + found_reduce |= !current_visited.contains(u) + && !function.nodes[u.idx()].is_control() + && isnt_outside_fork_join(*u) + && reduce_cycle_dfs_helper( + function, + *u, + fork, + reduce, + current_visited, + in_cycle, + fork_join_nest, + ) + } if found_reduce { in_cycle.insert(iter); } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index a6b1afe7..2889cec0 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -18,6 +18,9 @@ fixpoint panic after 20 { fork-guard-elim(*); fork-coalesce(*); } +gvn(*); +phi-elim(*); +dce(*); gvn(*); phi-elim(*); @@ -27,6 +30,12 @@ fixpoint panic after 20 { infer-schedules(*); } fork-split(*); +gvn(*); +phi-elim(*); +dce(*); unforkify(*); +gvn(*); +phi-elim(*); +dce(*); gcm(*); -- GitLab From f2865cbfd0fd7f3a7e84c08e70283c8c0eedefdf Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 11:21:02 -0600 Subject: [PATCH 07/33] interesting test --- juno_samples/fork_join_tests/src/cpu.sch | 1 + .../fork_join_tests/src/fork_join_tests.jn | 23 +++++++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 1 + juno_samples/fork_join_tests/src/main.rs | 5 ++++ 4 files changed, 30 insertions(+) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 2889cec0..0263c275 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -5,6 +5,7 @@ dce(*); let out = auto-outline(*); cpu(out.test1); cpu(out.test2); +cpu(out.test3); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 4a6a94c9..073cfd1e 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -21,3 +21,26 @@ fn test2(input : i32) -> i32[4, 4] { } return arr; } + +#[entry] +fn test3(input : i32) -> i32[3, 3] { + let arr1 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr1[i, j] = (i + j) as i32 + input; + } + } + let arr2 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr2[i, j] = arr1[3 - i, 3 - j]; + } + } + let arr3 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr3[i, j] = arr2[i, j] + 7; + } + } + return arr3; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index b506c4a4..80f1bbc9 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -5,6 +5,7 @@ dce(*); let out = auto-outline(*); gpu(out.test1); gpu(out.test2); +gpu(out.test3); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index de1b0805..4384ecd5 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -27,6 +27,11 @@ fn main() { let output = r.run(3).await; let correct = vec![24i32; 16]; assert(correct, output); + + let mut r = runner!(test3); + let output = r.run(0).await; + let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; + assert(correct, output); }); } -- GitLab From 912a729ac05bd2077f2ab864cef04e1dddde7667 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 12:34:13 -0600 Subject: [PATCH 08/33] unforkify fixes --- hercules_opt/src/unforkify.rs | 6 +++-- juno_scheduler/src/pm.rs | 42 ++++++++++++++++++++++------------- 2 files changed, 30 
insertions(+), 18 deletions(-) diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 7d158d1a..a08d1667 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -118,7 +118,7 @@ pub fn unforkify( // control insides of the fork-join should become the successor of the true // projection node, and what was the use of the join should become a use of // the new region. - for l in loop_tree.bottom_up_loops().into_iter().rev() { + for l in loop_tree.bottom_up_loops().iter().rev() { if !editor.node(l.0).is_fork() { continue; } @@ -133,7 +133,8 @@ pub fn unforkify( if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. - break; // Because we have to unforkify top down, we can't unforkify forks that are contained inside other forks + // We can't unforkify, because then the outer fork's reduce will depend on non-fork control. + break; } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor @@ -293,5 +294,6 @@ pub fn unforkify( Ok(edit) }); + break; } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d2772c71..378d8730 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1815,25 +1815,35 @@ fn run_pass( } Pass::Unforkify => { assert!(args.is_empty()); - pm.make_fork_join_maps(); - pm.make_loops(); + loop { + let mut inner_changed = false; - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); + pm.make_fork_join_maps(); + pm.make_loops(); - for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - { - let Some(mut func) = func else { - continue; - }; - unforkify(&mut func, fork_join_map, loop_tree); - changed |= func.modified(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + + for ((func, fork_join_map), loop_tree) in build_selection(pm, selection.clone()) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + { + let Some(mut func) = func else { + continue; + }; + unforkify(&mut func, fork_join_map, loop_tree); + changed |= func.modified(); + inner_changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + + if !inner_changed { + break; + } + break; } - pm.delete_gravestones(); - pm.clear_analyses(); } Pass::ForkCoalesce => { assert!(args.is_empty()); -- GitLab From d390705275493b662dcfa92a8d7ad35551d119a4 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 13:20:04 -0600 Subject: [PATCH 09/33] whoops --- juno_samples/fork_join_tests/src/fork_join_tests.jn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 073cfd1e..3d003f3c 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -33,7 +33,7 @@ fn test3(input : i32) -> i32[3, 3] { let arr2 : i32[3, 3]; for i = 0 to 3 { for j = 0 to 3 { - arr2[i, j] = arr1[3 - i, 3 - j]; + arr2[i, j] = arr1[2 - i, 2 - j]; } } let arr3 : i32[3, 3]; -- GitLab From ed33189c2fcaf725a8ae34f09d4fb963826ac4b6 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 13:38:57 -0600 Subject: [PATCH 10/33] whoops x2 --- hercules_ir/src/ir.rs | 2 +- juno_samples/fork_join_tests/src/gpu.sch | 1 + 2 files changed, 2 insertions(+), 1
deletion(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 846347b0..5c575ea1 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { * list of indices B. */ pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { - if indices_a.len() < indices_b.len() { + if indices_a.len() > indices_b.len() { return false; } diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 80f1bbc9..0647d781 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -28,5 +28,6 @@ fixpoint panic after 20 { infer-schedules(*); } +xdot[true](*); float-collections(test2, out.test2); gcm(*); -- GitLab From e1f5634f38cca264a25abe83174434ef333af9c3 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 13:59:43 -0600 Subject: [PATCH 11/33] fix to antideps --- hercules_opt/src/gcm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 271bfaf1..b13c919a 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -290,7 +290,7 @@ fn basic_blocks( .collect(); for mutator in reverse_postorder.iter() { let mutator_early = schedule_early[mutator.idx()].unwrap(); - if dom.does_dom(root_early, mutator_early) + if dom.does_prop_dom(root_early, mutator_early) && (root_early != mutator_early || root_block_iterated_users.contains(&mutator)) && mutating_objects(function, func_id, *mutator, objects) -- GitLab From 7f381ff54caddce24b0808725390a982e27bfdbc Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 14:38:13 -0600 Subject: [PATCH 12/33] hack for gpu --- hercules_cg/src/gpu.rs | 17 ++++++++--------- hercules_opt/src/gcm.rs | 18 ++++++++++++++++-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index afc016a4..8f186aa7 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -622,23 +622,23 @@ extern \"C\" {} {}(", write!(pass_args, "ret")?; write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; } - write!(w, "\tcudaError_t err;\n"); + write!(w, "\tcudaError_t err;\n")?; write!( w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args )?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; write!(w, "\tcudaDeviceSynchronize();\n")?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; if has_ret_var { // Copy return from device to host, whether it's primitive value or collection pointer write!(w, "\t{} host_ret;\n", ret_type)?; @@ -1150,7 +1150,8 @@ extern \"C\" {} {}(", // for all threads. Otherwise, it can be inside or outside block fork. // If inside, it's stored in shared memory so we "allocate" it once // and parallelize memset to 0. If outside, we initialize as offset - // to backing, but if multi-block grid, don't memset to avoid grid-level sync. + // to backing, but if multi-block grid, don't memset to avoid grid- + // level sync. 
Node::Constant { id: cons_id } => { let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive(); let cg_tile = match state { @@ -1192,9 +1193,7 @@ extern \"C\" {} {}(", )?; } if !is_primitive - && (state != KernelState::OutBlock - || is_block_parallel.is_none() - || !is_block_parallel.unwrap()) + && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false)) { let data_size = self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects)); diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index b13c919a..65f7c2d0 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -90,6 +90,7 @@ pub fn gcm( loops, fork_join_map, objects, + devices, ); let liveness = liveness_dataflow( @@ -174,6 +175,7 @@ fn basic_blocks( loops: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, + devices: &Vec<Device>, ) -> BasicBlocks { let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; @@ -421,9 +423,18 @@ fn basic_blocks( // If the next node further up the dominator tree is in a shallower // loop nest or if we can get out of a reduce loop when we don't // need to be in one, place this data node in a higher-up location. - // Only do this is the node isn't a constant or undef. + // Only do this if the node isn't a constant or undef - if a + // node is a constant or undef, we want its placement to be as + // control dependent as possible, even inside loops. In GPU + // functions specifically, lift constants that may be returned + // outside fork-joins. let is_constant_or_undef = function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef(); + let is_gpu_returned = devices[func_id.idx()] == Device::CUDA + && objects[&func_id] + .objects(id) + .into_iter() + .any(|obj| objects[&func_id].returned_objects().contains(obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -444,7 +455,10 @@ fn basic_blocks( // loop use the reduce node forming the loop, so the dominator chain // will consist of one block, and this loop won't ever iterate.
let currently_at_join = function.nodes[location.idx()].is_join(); - if !is_constant_or_undef && (shallower_nest || currently_at_join) { + + if (!is_constant_or_undef || is_gpu_returned) + && (shallower_nest || currently_at_join) + { location = control_node; } } -- GitLab From 25725cb1df4ec7b3c59f963c00aa5a2844a52916 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 14:50:58 -0600 Subject: [PATCH 13/33] Ok fix antideps for real this time --- hercules_opt/src/gcm.rs | 19 +++++++++++++++++-- juno_scheduler/src/pm.rs | 3 +++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 65f7c2d0..3ff6d2fe 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1,5 +1,5 @@ use std::cell::Ref; -use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter::{empty, once, zip, FromIterator}; use bitvec::prelude::*; @@ -76,6 +76,7 @@ pub fn gcm( dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, devices: &Vec<Device>, object_device_demands: &FunctionObjectDeviceDemands, @@ -88,6 +89,7 @@ pub fn gcm( reverse_postorder, dom, loops, + reduce_cycles, fork_join_map, objects, devices, @@ -173,6 +175,7 @@ fn basic_blocks( reverse_postorder: &Vec<NodeID>, dom: &DomTree, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, devices: &Vec<Device>, @@ -246,6 +249,9 @@ fn basic_blocks( // but not forwarding read - forwarding reads are collapsed, and the // bottom read is treated as reading from the transitive parent of the // forwarding read(s). + // 3: If the node producing the collection is a reduce node, then any read + // users that aren't in the reduce's cycle shouldn't anti-depend on any + // mutators in the reduce cycle. let mut antideps = BTreeSet::new(); for id in reverse_postorder.iter() { // Find a terminating read node and the collections it reads. @@ -271,6 +277,10 @@ fn basic_blocks( // TODO: make this less outrageously inefficient.
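// (Rule 3 in practice: a read of a reduce's final value that lives outside // the cycle logically happens after the whole loop, so forcing the in-cycle // mutators to be ordered around it makes no sense; such read/mutator pairs // are filtered out below.)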
let func_objects = &objects[&func_id]; for root in roots.iter() { + let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles + .get(root) + .map(|cycle| !cycle.contains(&id)) + .unwrap_or(false); let root_early = schedule_early[root.idx()].unwrap(); let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new(); let mut workset = BTreeSet::new(); @@ -292,12 +302,17 @@ fn basic_blocks( .collect(); for mutator in reverse_postorder.iter() { let mutator_early = schedule_early[mutator.idx()].unwrap(); - if dom.does_prop_dom(root_early, mutator_early) + if dom.does_dom(root_early, mutator_early) && (root_early != mutator_early || root_block_iterated_users.contains(&mutator)) && mutating_objects(function, func_id, *mutator, objects) .any(|mutated| read_objs.contains(&mutated)) && id != mutator + && (!root_is_reduce_and_read_isnt_in_cycle + || !reduce_cycles + .get(root) + .map(|cycle| cycle.contains(mutator)) + .unwrap_or(false)) { antideps.insert((*id, *mutator)); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 1ebc885c..db4455e8 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1567,6 +1567,7 @@ fn run_pass( pm.make_doms(); pm.make_fork_join_maps(); pm.make_loops(); + pm.make_reduce_cycles(); pm.make_collection_objects(); pm.make_devices(); pm.make_object_device_demands(); @@ -1577,6 +1578,7 @@ fn run_pass( let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let collection_objects = pm.collection_objects.take().unwrap(); let devices = pm.devices.take().unwrap(); @@ -1598,6 +1600,7 @@ fn run_pass( &doms[id.idx()], &fork_join_maps[id.idx()], &loops[id.idx()], + &reduce_cycles[id.idx()], &collection_objects, &devices, &object_device_demands[id.idx()], -- GitLab From c7d47ec8de0737d73395c06cd3ea5e835b2324d8 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 14:51:22 -0600 Subject: [PATCH 14/33] remove xdot from schedule --- juno_samples/fork_join_tests/src/gpu.sch | 1 - 1 file changed, 1 deletion(-) diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 0647d781..80f1bbc9 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -28,6 +28,5 @@ fixpoint panic after 20 { infer-schedules(*); } -xdot[true](*); float-collections(test2, out.test2); gcm(*); -- GitLab From 0caa58c5be3deb70cd0ef5dfde77b02a09bfbf2a Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Feb 2025 15:28:02 -0600 Subject: [PATCH 15/33] Test requiring outer split + unforkify --- .../fork_join_tests/src/fork_join_tests.jn | 15 +++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 1 + juno_samples/fork_join_tests/src/main.rs | 5 +++++ 3 files changed, 21 insertions(+) diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 3d003f3c..55e0a37e 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -44,3 +44,18 @@ fn test3(input : i32) -> i32[3, 3] { } return arr3; } + +#[entry] +fn test4(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + let acc = arr[i, j]; + for k = 0 to 7 { + acc += input; + } + arr[i, j] = acc; + } + } + return arr; +} diff 
--git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 80f1bbc9..bf35caea 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -6,6 +6,7 @@ let out = auto-outline(*); gpu(out.test1); gpu(out.test2); gpu(out.test3); +gpu(out.test4); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 4384ecd5..cbd42c50 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -32,6 +32,11 @@ fn main() { let output = r.run(0).await; let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; assert(correct, output); + + let mut r = runner!(test4); + let output = r.run(9).await; + let correct = vec![63i32; 16]; + assert(correct, output); }); } -- GitLab From 50555d655836a3090d605dfe622a9cc1127076a6 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 15:42:08 -0600 Subject: [PATCH 16/33] interpreter fixes + product consts --- hercules_test/hercules_interpreter/src/interpreter.rs | 3 +-- hercules_test/hercules_interpreter/src/value.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 871e304a..22ef062a 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -783,8 +783,7 @@ impl<'a> FunctionExecutionState<'a> { &self.module.dynamic_constants, &self.dynamic_constant_params, ) - }) - .rev(); + }); let n_tokens: usize = factors.clone().product(); diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index 53911e05..adbed6e6 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -156,8 +156,8 @@ impl<'a> InterpreterVal { Constant::Float64(v) => Self::Float64(v), Constant::Product(ref type_id, ref constant_ids) => { - // Self::Product((), ()) - todo!() + let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params)); + InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice()) } Constant::Summation(_, _, _) => todo!(), Constant::Array(type_id) => { -- GitLab From acf060c7daf486ef28a68a292beda3132850891d Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 15:42:26 -0600 Subject: [PATCH 17/33] read schedule from file --- juno_scheduler/src/lib.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 571d1fbf..2479af98 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -146,6 +146,42 @@ pub fn run_schedule_on_hercules( .map_err(|e| format!("Scheduling Error: {}", e)) } + +pub fn run_schedule_from_file_on_hercules( + module: Module, + sched_filename: Option<String>, +) -> Result<Module, String> { + let sched = process_schedule(sched_filename)?; + + // Prepare the scheduler's string table and environment + // For this, we put all of the Hercules function names into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + for (idx, func) in module.functions.iter().enumerate() { + let func_name = 
strings.lookup_string(func.name.clone()); + env.insert( + func_name, + Value::HerculesFunction { + func: FunctionID::new(idx), + }, + ); + } + + env.open_scope(); + schedule_module( + module, + sched, + strings, + env, + JunoFunctions { func_ids: vec![] }, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) +} + pub fn schedule_hercules( module: Module, sched_filename: Option<String>, -- GitLab From 61bc9ae455c3933f930d7a980a4cc4efc39afb6f Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 15:43:08 -0600 Subject: [PATCH 18/33] unforkify fix --- hercules_opt/src/fork_guard_elim.rs | 8 ++++---- juno_scheduler/src/pm.rs | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 052fd0e4..f6914b74 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -39,7 +39,6 @@ struct GuardedFork { guard_if: NodeID, fork_taken_proj: NodeID, fork_skipped_proj: NodeID, - guard_pred: NodeID, guard_join_region: NodeID, phi_reduce_map: HashMap<NodeID, NodeID>, factor: Factor, // The factor that matches the guard @@ -302,7 +301,6 @@ fn guarded_fork( guard_if: if_node, fork_taken_proj: *control, fork_skipped_proj: other_pred, - guard_pred: if_pred, guard_join_region: join_control, phi_reduce_map: phi_nodes, factor, @@ -323,13 +321,15 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node join, fork_taken_proj, fork_skipped_proj, - guard_pred, phi_reduce_map, factor, guard_if, guard_join_region, } in guard_info - { + { + let Some(guard_pred) = editor.get_uses(guard_if).next() else { + unreachable!() + }; let new_fork_info = if let Factor::Max(idx, dc) = factor { let Node::Fork { control: _, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index db3904f7..dd2ae73a 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1895,7 +1895,6 @@ fn run_pass( if !inner_changed { break; } - break; } } Pass::ForkCoalesce => { -- GitLab From 82ab7c076640225cf0ff4de2b8443a46dcc9b95f Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 22:39:36 -0600 Subject: [PATCH 19/33] fork dim merge --- hercules_opt/src/fork_transforms.rs | 161 ++++++++++++++++++++++++++++ juno_scheduler/src/compile.rs | 3 +- juno_scheduler/src/ir.rs | 2 + juno_scheduler/src/lib.rs | 3 +- juno_scheduler/src/pm.rs | 30 ++++-- 5 files changed, 189 insertions(+), 10 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e23f586f..58ace775 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::iter::zip; +use std::thread::ThreadId; use bimap::BiMap; use itertools::Itertools; @@ -693,3 +694,163 @@ pub(crate) fn split_fork( None } } + +// Splits a dimension of a single fork join into multiple. +// Iterates an outer loop original_dim / tile_size times +// adds a tile_size loop as the inner loop +// Assumes that tile size divides original dim evenly. 
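+// For example (a sketch, assuming the tile evenly divides the factor): +// chunking dim 0 of fork[16] with tile_size 4 produces fork[4, 4], and each +// old tid0 is rewritten to tid0 * 4 + tid1.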
+pub fn chunk_fork_unguarded( + editor: &mut FunctionEditor, + fork: NodeID, + dim_idx: usize, + tile_size: DynamicConstantID, +) -> () { + // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) + + let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return}; + + let mut new_factors: Vec<_> = old_factors.to_vec(); + + let fork_users: Vec<_> = editor.get_users(fork).collect(); + + + for tid in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + editor.edit(|mut edit| { + if tid_dim > dim_idx { + let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let new_tid = edit.add_node(new_tid); + edit.replace_all_uses(tid, new_tid) + } else if tid_dim == dim_idx { + let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let tile_tid = edit.add_node(tile_tid); + + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); + let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); + let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); + edit.replace_all_uses_where(tid, add, |usee| *usee != mul ) + } else { + Ok(edit) + } + }); + } + + editor.edit(|mut edit| { + let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx + 1, tile_size); + new_factors[dim_idx] = edit.add_dynamic_constant(outer); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + + edit.replace_all_uses(fork, new_fork) + }); +} + + +pub fn merge_all_fork_dims( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, +) { + for (fork, _) in fork_join_map { + let Node::Fork { control: _, factors: dims } = editor.node(fork) else { + unreachable!(); + }; + + let mut fork = *fork; + for _ in 0..dims.len() - 1 { + let outer = 0; + let inner = 1; + fork = fork_dim_merge(editor, fork, outer, inner); + } + } +} + +// Merges two dimensions of a single fork-join into one. +// The merged dimension's factor is the product of the two original factors, +// and uses of the original thread IDs are rewritten in terms of the merged +// thread ID. +pub fn fork_dim_merge( + editor: &mut FunctionEditor, + fork: NodeID, + dim_idx1: usize, + dim_idx2: usize, +) -> NodeID { + // tid_dim_idx1 (replaced w/) <- dim_idx1 / dim(dim_idx2) + // tid_dim_idx2 (replaced w/) <- dim_idx1 % dim(dim_idx2) + assert_ne!(dim_idx1, dim_idx2); + + // Outer is smaller, and also closer to the left of the factors array.
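+ // e.g. (sketch): merging fork[2, 3] yields fork[6]; with the rewrites + // below, a merged tid m maps back to the old pair as (m % 2, m / 2), + // hitting every (outer, inner) combination exactly once - the order is + // irrelevant since fork iterations are parallel.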
+ let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 { + (dim_idx2, dim_idx1) + } else { + (dim_idx1, dim_idx2) + }; + + let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork}; + + let mut new_factors: Vec<_> = old_factors.to_vec(); + + + + let fork_users: Vec<_> = editor.get_users(fork).collect(); + + let mut new_nodes = vec![]; + + let outer_dc_id = new_factors[outer_idx]; + let inner_dc_id = new_factors[inner_idx]; + + let mut new_fork_id = NodeID::new(0); + editor.edit(|mut edit| { + new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); + new_factors.remove(inner_idx); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + new_fork_id = new_fork; + + + edit = edit.replace_all_uses(fork, new_fork)?; + edit.delete_node(fork) + }); + + + + for tid in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + + println!("tid: {:?}", tid); + editor.edit(|mut edit| { + if tid_dim > inner_idx { + let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; + let new_tid = edit.add_node(new_tid); + edit.replace_all_uses(tid, new_tid) + } else if tid_dim == outer_idx { + let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = edit.add_node(outer_tid); + + let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); + new_nodes.push(outer_tid); + + // inner_idx % dim(outer_idx) + let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); + + edit.replace_all_uses(tid, rem) + } else if tid_dim == inner_idx { + let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = edit.add_node(outer_tid); + + let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); + // inner_idx / dim(outer_idx) + let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); + + edit.replace_all_uses(tid, div) + } else { + Ok(edit) + } + }); + }; + + return new_fork_id; + +} \ No newline at end of file diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 11a8ec53..07ad5e7a 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -108,7 +108,8 @@ impl FromStr for Appliable { "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) - } + }, + "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index d6a41baf..939ef925 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -12,6 +12,8 @@ pub enum Pass { ForkSplit, ForkCoalesce, Forkify, + ForkDimMerge, + ForkChunk, GCM, GVN, InferSchedules, diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 2479af98..ad9195fb 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -60,7 +60,7 @@ fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> { } } -fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> { +pub fn process_schedule(sched_filename: Option<String>) 
-> Result<ScheduleStmt, String> { if let Some(name) = sched_filename { build_schedule(name) } else { @@ -146,7 +146,6 @@ pub fn run_schedule_on_hercules( .map_err(|e| format!("Scheduling Error: {}", e)) } - pub fn run_schedule_from_file_on_hercules( module: Module, sched_filename: Option<String>, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 43f355c3..8b71d24e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1871,14 +1871,11 @@ fn run_pass( } Pass::Unforkify => { assert!(args.is_empty()); - loop { - let mut inner_changed = false; - - pm.make_fork_join_maps(); - pm.make_loops(); + pm.make_fork_join_maps(); + pm.make_loops(); - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) .into_iter() @@ -1894,6 +1891,24 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkDimMerge => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + merge_all_fork_dims(&mut func, fork_join_map); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkCoalesce => { assert!(args.is_empty()); pm.make_fork_join_maps(); @@ -1991,6 +2006,7 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } + Pass::ForkChunk => todo!(), } println!("Ran Pass: {:?}", pass); -- GitLab From 1788bd1f2ba4ec496cb5afa62da884f92c1761a4 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Mon, 3 Feb 2025 11:10:10 -0600 Subject: [PATCH 20/33] tiling + dim merge with one edit per loop dim --- hercules_opt/src/fork_transforms.rs | 96 ++++++++++++++++------------- juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 34 ++++++++++ 4 files changed, 90 insertions(+), 42 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 58ace775..cbb09bbf 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -695,6 +695,24 @@ pub(crate) fn split_fork( } } +pub fn chunk_all_forks_unguarded( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + dim_idx: usize, + tile_size: usize, +) -> () { + // Add dc + let mut dc_id = DynamicConstantID::new(0); + editor.edit(|mut edit| { + dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size)); + Ok(edit) + }); + + for (fork, _ ) in fork_join_map { + chunk_fork_unguarded(editor, *fork, dim_idx, dc_id); + } + +} // Splits a dimension of a single fork join into multiple. 
// Iterates an outer loop original_dim / tile_size times // adds a tile_size loop as the inner loop @@ -711,39 +729,36 @@ pub fn chunk_fork_unguarded( let mut new_factors: Vec<_> = old_factors.to_vec(); - let fork_users: Vec<_> = editor.get_users(fork).collect(); + let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); + + editor.edit(|mut edit| { + let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx + 1, tile_size); + new_factors[dim_idx] = edit.add_dynamic_constant(outer); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + edit = edit.replace_all_uses(fork, new_fork)?; - for tid in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; - editor.edit(|mut edit| { + for (tid, node) in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue}; if tid_dim > dim_idx { - let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; let new_tid = edit.add_node(new_tid); - edit.replace_all_uses(tid, new_tid) + edit = edit.replace_all_uses(tid, new_tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; let tile_tid = edit.add_node(tile_tid); let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); - edit.replace_all_uses_where(tid, add, |usee| *usee != mul ) - } else { - Ok(edit) + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?; } - }); - } - - editor.edit(|mut edit| { - let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); - new_factors.insert(dim_idx + 1, tile_size); - new_factors[dim_idx] = edit.add_dynamic_constant(outer); - - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; - let new_fork = edit.add_node(new_fork); - - edit.replace_all_uses(fork, new_fork) + } + edit = edit.delete_node(fork)?; + Ok(edit) }); } @@ -791,9 +806,8 @@ pub fn fork_dim_merge( let mut new_factors: Vec<_> = old_factors.to_vec(); - - let fork_users: Vec<_> = editor.get_users(fork).collect(); + let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); let mut new_nodes = vec![]; @@ -801,6 +815,7 @@ pub fn fork_dim_merge( let inner_dc_id = new_factors[inner_idx]; let mut new_fork_id = NodeID::new(0); + editor.edit(|mut edit| { new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); new_factors.remove(inner_idx); @@ -809,22 +824,20 @@ pub fn fork_dim_merge( let new_fork = edit.add_node(new_fork); new_fork_id = new_fork; + edit.sub_edit(fork, new_fork); edit = edit.replace_all_uses(fork, new_fork)?; - edit.delete_node(fork) - }); - - - - for tid in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + edit = edit.delete_node(fork)?; - println!("tid: {:?}", tid); - editor.edit(|mut edit| { + for (tid, node) in fork_users { + // FIXME: DO we want sub edits in this? 
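+            // (sub_edit presumably records the old -> new node correspondence so
+            // that any schedules/labels on the old node carry over to its
+            // replacement; that assumption is why it is applied below.)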
+ + let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue }; if tid_dim > inner_idx { let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; let new_tid = edit.add_node(new_tid); - edit.replace_all_uses(tid, new_tid) + edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); } else if tid_dim == outer_idx { let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; let outer_tid = edit.add_node(outer_tid); @@ -834,8 +847,8 @@ pub fn fork_dim_merge( // inner_idx % dim(outer_idx) let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); - - edit.replace_all_uses(tid, rem) + edit.sub_edit(tid, rem); + edit = edit.replace_all_uses(tid, rem)?; } else if tid_dim == inner_idx { let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; let outer_tid = edit.add_node(outer_tid); @@ -843,13 +856,12 @@ pub fn fork_dim_merge( let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); // inner_idx / dim(outer_idx) let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); - - edit.replace_all_uses(tid, div) - } else { - Ok(edit) + edit.sub_edit(tid, div); + edit = edit.replace_all_uses(tid, div)?; } - }); - }; + } + Ok(edit) + }); return new_fork_id; diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 07ad5e7a..49dedd2b 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -110,6 +110,7 @@ impl FromStr for Appliable { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) }, "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), + "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 939ef925..796437a7 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -36,6 +36,7 @@ impl Pass { pub fn num_args(&self) -> usize { match self { Pass::Xdot => 1, + Pass::ForkChunk => 3, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8b71d24e..5740d2a6 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1891,6 +1891,40 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkChunk => { + assert_eq!(args.len(), 3); + let tile_size = args.get(0); + let dim_idx = args.get(1); + + let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { + panic!(); // How to error here? + }; + + let Some(Value::Integer { val: dim_idx }) = args.get(1) else { + panic!(); // How to error here? + }; + + let Some(Value::Integer { val: tile_size }) = args.get(0) else { + panic!(); // How to error here? 
+ }; + + assert_eq!(*guarded_flag, true); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); -- GitLab From 27731fa4fe3553ba1a93a98115636b60a4dd00ce Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Mon, 3 Feb 2025 11:15:56 -0600 Subject: [PATCH 21/33] check for out of bounds dim on chunking --- hercules_opt/src/fork_transforms.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index cbb09bbf..190dbd25 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -725,8 +725,13 @@ pub fn chunk_fork_unguarded( ) -> () { // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) + let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return}; + if dim_idx >= old_factors.len() { + return; // FIXME Error here? + } + let mut new_factors: Vec<_> = old_factors.to_vec(); let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); -- GitLab From 1f1c6cb94c80ac867958c41d2a84d506bf06ef92 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Mon, 3 Feb 2025 11:46:58 -0600 Subject: [PATCH 22/33] rewrite forkify as single edit per loop --- hercules_opt/src/forkify.rs | 75 ++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 38b9aaaa..0a2d5601 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -298,29 +298,11 @@ pub fn forkify_loop( let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); let dimension = factors.len() - 1; - // Create ThreadID - editor.edit(|mut edit| { - let thread_id = Node::ThreadID { - control: fork_id, - dimension: dimension, - }; - let thread_id_id = edit.add_node(thread_id); - - // Replace uses that are inside with the thread id - edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { - loop_nodes.contains(node) - })?; + // Start failable edit: - // Replace uses that are outside with DC - 1. Or just give up. 
- let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id }); - edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| { - !loop_nodes.contains(node) - })?; - - edit.delete_node(canonical_iv.phi()) - }); - - for reduction_phi in reductionable_phis { + let redcutionable_phis_and_init: Vec<(_, NodeID)> = + reductionable_phis.iter().map(|reduction_phi| { + let LoopPHI::Reductionable { phi, data_cycle: _, @@ -342,12 +324,41 @@ pub fn forkify_loop( .unwrap() .1; - editor.edit(|mut edit| { + (reduction_phi, init) + }).collect(); + + editor.edit(|mut edit| { + let thread_id = Node::ThreadID { + control: fork_id, + dimension: dimension, + }; + let thread_id_id = edit.add_node(thread_id); + + // Replace uses that are inside with the thread id + edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { + loop_nodes.contains(node) + })?; + + edit = edit.delete_node(canonical_iv.phi())?; + + for (reduction_phi, init) in redcutionable_phis_and_init { + let LoopPHI::Reductionable { + phi, + data_cycle: _, + continue_latch, + is_associative: _, + } = *reduction_phi + else { + panic!(); + }; + let reduce = Node::Reduce { control: join_id, init, reduct: continue_latch, }; + + let reduce_id = edit.add_node(reduce); if (!edit.get_node(init).is_reduce() @@ -375,20 +386,14 @@ pub fn forkify_loop( edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| { !loop_nodes.contains(usee) && *usee != reduce_id })?; - edit.delete_node(phi) - }); - } - - // Replace all uses of the loop header with the fork - editor.edit(|edit| edit.replace_all_uses(l.header, fork_id)); + edit = edit.delete_node(phi)? - editor.edit(|edit| edit.replace_all_uses(loop_continue_projection, fork_id)); + } - editor.edit(|edit| edit.replace_all_uses(loop_exit_projection, join_id)); + edit = edit.replace_all_uses(l.header, fork_id)?; + edit = edit.replace_all_uses(loop_continue_projection, fork_id)?; + edit = edit.replace_all_uses(loop_exit_projection, join_id)?; - // Get rid of loop condition - // DCE should get these, but delete them ourselves because we are nice :) - editor.edit(|mut edit| { edit = edit.delete_node(loop_continue_projection)?; edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. 
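+        // All of these replacements and deletions happen inside one failable
+        // edit, so a failure at any point rolls back the whole forkify of this
+        // loop atomically.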
edit = edit.delete_node(loop_exit_projection)?; @@ -396,7 +401,7 @@ pub fn forkify_loop( edit = edit.delete_node(l.header)?; Ok(edit) }); - + return true; } -- GitLab From 6382ef4263b16f54a8d3b4d5e3a795c9c9e11013 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 3 Feb 2025 15:44:25 -0600 Subject: [PATCH 23/33] fix ups --- Cargo.lock | 11 -- hercules_cg/src/fork_tree.rs | 19 +- hercules_opt/src/fork_guard_elim.rs | 19 +- hercules_opt/src/fork_transforms.rs | 166 +++++++++++------- hercules_opt/src/forkify.rs | 61 ++++--- hercules_samples/dot/build.rs | 6 +- hercules_samples/dot/src/main.rs | 2 +- hercules_samples/matmul/build.rs | 6 +- hercules_samples/matmul/src/main.rs | 6 +- .../hercules_interpreter/src/interpreter.rs | 76 ++++---- .../hercules_interpreter/src/value.rs | 10 +- .../hercules_tests/tests/loop_tests.rs | 40 ++--- juno_frontend/src/semant.rs | 11 +- juno_samples/cava/src/main.rs | 45 ++--- juno_samples/concat/src/main.rs | 4 +- juno_samples/edge_detection/src/main.rs | 11 +- juno_samples/matmul/src/main.rs | 18 +- juno_samples/nested_ccp/src/main.rs | 2 +- juno_samples/patterns/src/main.rs | 2 +- juno_samples/schedule_test/build.rs | 6 +- juno_samples/schedule_test/src/main.rs | 13 +- juno_samples/simple3/src/main.rs | 2 +- juno_scheduler/src/compile.rs | 2 +- juno_scheduler/src/ir.rs | 2 +- juno_scheduler/src/pm.rs | 31 ++-- 25 files changed, 336 insertions(+), 235 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a70825a..af7902c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,17 +1236,6 @@ dependencies = [ "with_builtin_macros", ] -[[package]] -name = "juno_test" -version = "0.1.0" -dependencies = [ - "async-std", - "hercules_rt", - "juno_build", - "rand", - "with_builtin_macros", -] - [[package]] name = "juno_utils" version = "0.1.0" diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs index 64a93160..c048f7e3 100644 --- a/hercules_cg/src/fork_tree.rs +++ b/hercules_cg/src/fork_tree.rs @@ -9,11 +9,16 @@ use crate::*; * c) no domination by any other fork that's also dominated by F, where we do count self-domination * Here too we include the non-fork start node, as key for all controls outside any fork. */ -pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> { +pub fn fork_control_map( + fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, +) -> HashMap<NodeID, HashSet<NodeID>> { let mut fork_control_map = HashMap::new(); for (control, forks) in fork_join_nesting { let fork = forks.first().copied().unwrap_or(NodeID::new(0)); - fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control); + fork_control_map + .entry(fork) + .or_insert_with(HashSet::new) + .insert(*control); } fork_control_map } @@ -24,13 +29,19 @@ pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> Has * c) no domination by any other fork that's also dominated by F, where we don't count self-domination * Note that the fork_tree also includes the non-fork start node, as unique root node. 
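+ * For example (illustrative, not part of the original comment): for control flow
+ * start -> F1 -> F2 -> J2 -> J1, where fork F2 is the only fork nested inside
+ * fork F1, fork_tree maps start to {F1}, F1 to {F2}, and F2 to {}.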
*/ -pub fn fork_tree(function: &Function, fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> { +pub fn fork_tree( + function: &Function, + fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, +) -> HashMap<NodeID, HashSet<NodeID>> { let mut fork_tree = HashMap::new(); for (control, forks) in fork_join_nesting { if function.nodes[control.idx()].is_fork() { fork_tree.entry(*control).or_insert_with(HashSet::new); let nesting_fork = forks.get(1).copied().unwrap_or(NodeID::new(0)); - fork_tree.entry(nesting_fork).or_insert_with(HashSet::new).insert(*control); + fork_tree + .entry(nesting_fork) + .or_insert_with(HashSet::new) + .insert(*control); } } fork_tree diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index f6914b74..df40e60f 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -76,13 +76,16 @@ fn guarded_fork( }; // Filter out any terms which are just 1s - let non_ones = xs.iter().filter(|i| { - if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { - false - } else { - true - } - }).collect::<Vec<_>>(); + let non_ones = xs + .iter() + .filter(|i| { + if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { + false + } else { + true + } + }) + .collect::<Vec<_>>(); // If we're left with just one term x, we had max { 1, x } if non_ones.len() == 1 { Factor::Max(idx, *non_ones[0]) @@ -326,7 +329,7 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node guard_if, guard_join_region, } in guard_info - { + { let Some(guard_pred) = editor.get_uses(guard_if).next() else { unreachable!() }; diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 190dbd25..ed6283fd 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -708,14 +708,13 @@ pub fn chunk_all_forks_unguarded( Ok(edit) }); - for (fork, _ ) in fork_join_map { + for (fork, _) in fork_join_map { chunk_fork_unguarded(editor, *fork, dim_idx, dc_id); } - } -// Splits a dimension of a single fork join into multiple. -// Iterates an outer loop original_dim / tile_size times -// adds a tile_size loop as the inner loop +// Splits a dimension of a single fork join into multiple. +// Iterates an outer loop original_dim / tile_size times +// adds a tile_size loop as the inner loop // Assumes that tile size divides original dim evenly. pub fn chunk_fork_unguarded( editor: &mut FunctionEditor, @@ -724,42 +723,68 @@ pub fn chunk_fork_unguarded( tile_size: DynamicConstantID, ) -> () { // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) - - - let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return}; - - if dim_idx >= old_factors.len() { - return; // FIXME Error here? 
- } - + let Node::Fork { + control: old_control, + factors: ref old_factors, + } = *editor.node(fork) + else { + return; + }; + assert!(dim_idx < old_factors.len()); let mut new_factors: Vec<_> = old_factors.to_vec(); - - let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); + let fork_users: Vec<_> = editor + .get_users(fork) + .map(|f| (f, editor.node(f).clone())) + .collect(); editor.edit(|mut edit| { let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx + 1, tile_size); new_factors[dim_idx] = edit.add_dynamic_constant(outer); - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = Node::Fork { + control: old_control, + factors: new_factors.into(), + }; let new_fork = edit.add_node(new_fork); edit = edit.replace_all_uses(fork, new_fork)?; for (tid, node) in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue}; + let Node::ThreadID { + control: _, + dimension: tid_dim, + } = node + else { + continue; + }; if tid_dim > dim_idx { - let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; + let new_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; let new_tid = edit.add_node(new_tid); edit = edit.replace_all_uses(tid, new_tid)?; + edit = edit.delete_node(tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; + let tile_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; let tile_tid = edit.add_node(tile_tid); - + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); - let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); - let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); - edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?; + let mul = edit.add_node(Node::Binary { + left: tid, + right: tile_size, + op: BinaryOperator::Mul, + }); + let add = edit.add_node(Node::Binary { + left: mul, + right: tile_tid, + op: BinaryOperator::Add, + }); + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?; } } edit = edit.delete_node(fork)?; @@ -767,13 +792,13 @@ pub fn chunk_fork_unguarded( }); } - -pub fn merge_all_fork_dims( - editor: &mut FunctionEditor, - fork_join_map: &HashMap<NodeID, NodeID>, -) { +pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { for (fork, _) in fork_join_map { - let Node::Fork { control: _, factors: dims } = editor.node(fork) else { + let Node::Fork { + control: _, + factors: dims, + } = editor.node(fork) + else { unreachable!(); }; @@ -786,10 +811,6 @@ pub fn merge_all_fork_dims( } } -// Splits a dimension of a single fork join into multiple. -// Iterates an outer loop original_dim / tile_size times -// adds a tile_size loop as the inner loop -// Assumes that tile size divides original dim evenly. 
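+// Merges two fork dimensions into a single dimension whose factor is their
+// product; thread IDs for the old dimensions are rebuilt below with % and / on
+// the merged thread ID.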
pub fn fork_dim_merge( editor: &mut FunctionEditor, fork: NodeID, @@ -806,61 +827,85 @@ pub fn fork_dim_merge( } else { (dim_idx1, dim_idx2) }; - - let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork}; - + let Node::Fork { + control: old_control, + factors: ref old_factors, + } = *editor.node(fork) + else { + return fork; + }; let mut new_factors: Vec<_> = old_factors.to_vec(); - - - let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); - + let fork_users: Vec<_> = editor + .get_users(fork) + .map(|f| (f, editor.node(f).clone())) + .collect(); let mut new_nodes = vec![]; - let outer_dc_id = new_factors[outer_idx]; let inner_dc_id = new_factors[inner_idx]; - - let mut new_fork_id = NodeID::new(0); + let mut new_fork = NodeID::new(0); editor.edit(|mut edit| { - new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); + new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul( + new_factors[outer_idx], + new_factors[inner_idx], + )); new_factors.remove(inner_idx); - - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; - let new_fork = edit.add_node(new_fork); - new_fork_id = new_fork; - + new_fork = edit.add_node(Node::Fork { + control: old_control, + factors: new_factors.into(), + }); edit.sub_edit(fork, new_fork); - edit = edit.replace_all_uses(fork, new_fork)?; edit = edit.delete_node(fork)?; for (tid, node) in fork_users { - // FIXME: DO we want sub edits in this? - - let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue }; + let Node::ThreadID { + control: _, + dimension: tid_dim, + } = node + else { + continue; + }; if tid_dim > inner_idx { - let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; + let new_tid = Node::ThreadID { + control: new_fork_id, + dimension: tid_dim - 1, + }; let new_tid = edit.add_node(new_tid); edit = edit.replace_all_uses(tid, new_tid)?; edit.sub_edit(tid, new_tid); } else if tid_dim == outer_idx { - let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = Node::ThreadID { + control: new_fork_id, + dimension: outer_idx, + }; let outer_tid = edit.add_node(outer_tid); let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); new_nodes.push(outer_tid); // inner_idx % dim(outer_idx) - let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); + let rem = edit.add_node(Node::Binary { + left: outer_tid, + right: outer_dc, + op: BinaryOperator::Rem, + }); edit.sub_edit(tid, rem); edit = edit.replace_all_uses(tid, rem)?; } else if tid_dim == inner_idx { - let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; + let outer_tid = Node::ThreadID { + control: new_fork_id, + dimension: outer_idx, + }; let outer_tid = edit.add_node(outer_tid); let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); // inner_idx / dim(outer_idx) - let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); + let div = edit.add_node(Node::Binary { + left: outer_tid, + right: outer_dc, + op: BinaryOperator::Div, + }); edit.sub_edit(tid, div); edit = edit.replace_all_uses(tid, div)?; } @@ -868,6 +913,5 @@ pub fn fork_dim_merge( Ok(edit) }); - return new_fork_id; - -} \ No newline at end of file + new_fork +} diff --git a/hercules_opt/src/forkify.rs 
b/hercules_opt/src/forkify.rs index 0a2d5601..f6db06ca 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -300,32 +300,33 @@ pub fn forkify_loop( // Start failable edit: - let redcutionable_phis_and_init: Vec<(_, NodeID)> = - reductionable_phis.iter().map(|reduction_phi| { - - let LoopPHI::Reductionable { - phi, - data_cycle: _, - continue_latch, - is_associative: _, - } = reduction_phi - else { - panic!(); - }; + let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis + .iter() + .map(|reduction_phi| { + let LoopPHI::Reductionable { + phi, + data_cycle: _, + continue_latch, + is_associative: _, + } = reduction_phi + else { + panic!(); + }; - let function = editor.func(); + let function = editor.func(); - let init = *zip( - editor.get_uses(l.header), - function.nodes[phi.idx()].try_phi().unwrap().1.iter(), - ) - .filter(|(c, _)| *c == loop_pred) - .next() - .unwrap() - .1; + let init = *zip( + editor.get_uses(l.header), + function.nodes[phi.idx()].try_phi().unwrap().1.iter(), + ) + .filter(|(c, _)| *c == loop_pred) + .next() + .unwrap() + .1; - (reduction_phi, init) - }).collect(); + (reduction_phi, init) + }) + .collect(); editor.edit(|mut edit| { let thread_id = Node::ThreadID { @@ -351,14 +352,13 @@ pub fn forkify_loop( else { panic!(); }; - + let reduce = Node::Reduce { control: join_id, init, reduct: continue_latch, }; - - + let reduce_id = edit.add_node(reduce); if (!edit.get_node(init).is_reduce() @@ -387,7 +387,6 @@ pub fn forkify_loop( !loop_nodes.contains(usee) && *usee != reduce_id })?; edit = edit.delete_node(phi)? - } edit = edit.replace_all_uses(l.header, fork_id)?; @@ -401,7 +400,7 @@ pub fn forkify_loop( edit = edit.delete_node(l.header)?; Ok(edit) }); - + return true; } @@ -538,7 +537,11 @@ pub fn analyze_phis<'a>( // by the time the reduce is triggered (at the end of the loop's internal control). // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control. // Which is not allowed. 
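+        // An illustrative case: if the loop body computes the continue-latch value
+        // through an inner if/else that merges at a phi, the would-be reduce cycle
+        // runs through that control-dependent phi, so the PHI is ControlDependant.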
- if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() { + if intersection + .iter() + .any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) + || editor.node(loop_continue_latch).is_phi() + { return LoopPHI::ControlDependant(*phi); } diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs index 8657fdc1..c8de7e90 100644 --- a/hercules_samples/dot/build.rs +++ b/hercules_samples/dot/build.rs @@ -4,7 +4,11 @@ fn main() { JunoCompiler::new() .ir_in_src("dot.hir") .unwrap() - .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) .unwrap() .build() .unwrap(); diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 8862c11a..7f5b453a 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -1,8 +1,8 @@ #![feature(concat_idents)] -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("dot"); diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs index 735458c0..ed92e022 100644 --- a/hercules_samples/matmul/build.rs +++ b/hercules_samples/matmul/build.rs @@ -4,7 +4,11 @@ fn main() { JunoCompiler::new() .ir_in_src("matmul.hir") .unwrap() - .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) .unwrap() .build() .unwrap(); diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index abd25ec9..5c879915 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -2,9 +2,9 @@ use rand::random; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("matmul"); @@ -36,7 +36,9 @@ fn main() { let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await; + let c = r + .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) + .await; let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); assert_eq!(&*c_cpu, &*correct_c); diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 22ef062a..2e352644 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -69,18 +69,18 @@ pub fn dyn_const_value( match dc { DynamicConstant::Constant(v) => *v, DynamicConstant::Parameter(v) => dyn_const_params[*v], - DynamicConstant::Add(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(0, |s, v| s + v) - } + DynamicConstant::Add(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(0, |s, v| s + v), DynamicConstant::Sub(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) - dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Mul(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(1, |p, v| p * v) - } + 
DynamicConstant::Mul(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(1, |p, v| p * v), DynamicConstant::Div(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) / dyn_const_value(b, dyn_const_values, dyn_const_params) @@ -89,28 +89,28 @@ pub fn dyn_const_value( dyn_const_value(a, dyn_const_values, dyn_const_params) % dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Max(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(None, |m, v| { - if let Some(m) = m { - Some(max(m, v)) - } else { - Some(v) - } - }) - .unwrap() - } - DynamicConstant::Min(xs) => { - xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) - .fold(None, |m, v| { - if let Some(m) = m { - Some(min(m, v)) - } else { - Some(v) - } - }) - .unwrap() - } + DynamicConstant::Max(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(max(m, v)) + } else { + Some(v) + } + }) + .unwrap(), + DynamicConstant::Min(xs) => xs + .iter() + .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(min(m, v)) + } else { + Some(v) + } + }) + .unwrap(), } } @@ -775,15 +775,13 @@ impl<'a> FunctionExecutionState<'a> { // panic!("multi-dimensional forks unimplemented") // } - let factors = factors - .iter() - .map(|f| { - dyn_const_value( - &f, - &self.module.dynamic_constants, - &self.dynamic_constant_params, - ) - }); + let factors = factors.iter().map(|f| { + dyn_const_value( + &f, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }); let n_tokens: usize = factors.clone().product(); diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index adbed6e6..4a802f7a 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -156,7 +156,15 @@ impl<'a> InterpreterVal { Constant::Float64(v) => Self::Float64(v), Constant::Product(ref type_id, ref constant_ids) => { - let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params)); + let contents = constant_ids.iter().map(|const_id| { + InterpreterVal::from_constant( + &constants[const_id.idx()], + constants, + types, + dynamic_constants, + dynamic_constant_params, + ) + }); InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice()) } Constant::Summation(_, _, _) => todo!(), diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 192c1366..795642b2 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -35,9 +35,7 @@ fn alternate_bounds_use_after_loop_no_tid() { println!("result: {:?}", result_1); - let schedule = default_schedule![ - Forkify, - ]; + let schedule = default_schedule![Forkify,]; let module = run_schedule_on_hercules(module, Some(schedule)).unwrap(); @@ -61,9 +59,7 @@ fn alternate_bounds_use_after_loop() { println!("result: {:?}", result_1); - let schedule = Some(default_schedule![ - Forkify, - ]); + let schedule = Some(default_schedule![Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -108,10 +104,7 @@ fn do_while_separate_body() { println!("result: {:?}", result_1); - let schedule = 
Some(default_schedule![ - PhiElim, - Forkify, - ]); + let schedule = Some(default_schedule![PhiElim, Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -131,10 +124,7 @@ fn alternate_bounds_internal_control() { println!("result: {:?}", result_1); - let schedule = Some(default_schedule![ - PhiElim, - Forkify, - ]); + let schedule = Some(default_schedule![PhiElim, Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -155,10 +145,7 @@ fn alternate_bounds_internal_control2() { println!("result: {:?}", result_1); - let schedule = Some(default_schedule![ - PhiElim, - Forkify, - ]); + let schedule = Some(default_schedule![PhiElim, Forkify,]); let module = run_schedule_on_hercules(module, schedule).unwrap(); @@ -366,16 +353,13 @@ fn look_at_local() { "/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin", ); - let schedule = Some(default_schedule![ - ]); + let schedule = Some(default_schedule![]); let result_1 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone()); let module = run_schedule_on_hercules(module.clone(), schedule).unwrap(); - let schedule = Some(default_schedule![ - Unforkify, Verify, - ]); + let schedule = Some(default_schedule![Unforkify, Verify,]); let module = run_schedule_on_hercules(module.clone(), schedule).unwrap(); @@ -425,7 +409,15 @@ fn matmul_pipeline() { }; assert_eq!(correct_c[0], value); - let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]); + let schedule = Some(default_schedule![ + AutoOutline, + InterproceduralSROA, + SROA, + InferSchedules, + DCE, + Xdot, + GCM + ]); module = run_schedule_on_hercules(module, schedule).unwrap(); diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index e133e3c2..8668d1b4 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -752,7 +752,16 @@ fn analyze_program( } arg_info.push((ty, inout.is_some(), var)); - match process_irrefutable_pattern(pattern, false, var, ty, lexer, &mut stringtab, &mut env, &mut types) { + match process_irrefutable_pattern( + pattern, + false, + var, + ty, + lexer, + &mut stringtab, + &mut env, + &mut types, + ) { Ok(prep) => { stmts.extend(prep); } diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 482bbf8d..e8a7e4e9 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -8,9 +8,9 @@ use self::camera_model::*; use self::cava_rust::CHAN; use self::image_proc::*; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; use image::ImageError; @@ -31,7 +31,6 @@ fn run_cava( coefs: &[f32], tonemap: &[f32], ) -> Box<[u8]> { - assert_eq!(image.len(), CHAN * rows * cols); assert_eq!(tstw.len(), CHAN * CHAN); assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN); @@ -47,21 +46,24 @@ fn run_cava( let weights = HerculesCPURef::from_slice(weights); let coefs = HerculesCPURef::from_slice(coefs); let tonemap = HerculesCPURef::from_slice(tonemap); - let mut r = runner!(cava); - async_std::task::block_on(async { - r.run( - rows as u64, - cols as u64, - num_ctrl_pts as u64, - image, - tstw, - ctrl_pts, - weights, - coefs, - tonemap, - ) - .await - }).as_slice::<u8>().to_vec().into_boxed_slice() + let mut r = runner!(cava); + async_std::task::block_on(async { + r.run( + rows as u64, + cols as u64, + num_ctrl_pts as u64, + image, + tstw, + ctrl_pts, + weights, + coefs, + tonemap, + ) + .await + }) + 
.as_slice::<u8>() + .to_vec() + .into_boxed_slice() } #[cfg(feature = "cuda")] @@ -72,8 +74,8 @@ fn run_cava( let weights = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(weights)); let coefs = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(coefs)); let tonemap = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(tonemap)); - let mut r = runner!(cava); - let res = async_std::task::block_on(async { + let mut r = runner!(cava); + let res = async_std::task::block_on(async { r.run( rows as u64, cols as u64, @@ -86,7 +88,7 @@ fn run_cava( tonemap.get_ref(), ) .await - }); + }); let num_out = unsafe { res.__size() / std::mem::size_of::<u8>() }; let mut res_cpu: Box<[u8]> = vec![0; num_out].into_boxed_slice(); res.to_cpu_ref(&mut res_cpu); @@ -204,7 +206,8 @@ fn cava_harness(args: CavaInputs) { .expect("Error saving verification image"); } - let max_diff = result.iter() + let max_diff = result + .iter() .zip(cpu_result.iter()) .map(|(a, b)| (*a as i16 - *b as i16).abs()) .max() diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index 9674c2c5..547dee08 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -1,9 +1,9 @@ #![feature(concat_idents)] use hercules_rt::runner; -use hercules_rt::HerculesCPURef; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::HerculesCPURef; juno_build::juno!("concat"); @@ -20,7 +20,7 @@ fn main() { assert_eq!(output, 42); const N: usize = 3; - let arr : Box<[i32]> = (2..=4).collect(); + let arr: Box<[i32]> = (2..=4).collect(); let arr = HerculesCPURef::from_slice(&arr); let mut r = runner!(concat_switch); diff --git a/juno_samples/edge_detection/src/main.rs b/juno_samples/edge_detection/src/main.rs index eda65016..3b067ebd 100644 --- a/juno_samples/edge_detection/src/main.rs +++ b/juno_samples/edge_detection/src/main.rs @@ -2,9 +2,9 @@ mod edge_detection_rust; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; use std::slice::from_raw_parts; @@ -228,9 +228,9 @@ fn edge_detection_harness(args: EdgeDetectionInputs) { }); #[cfg(not(feature = "cuda"))] - let result : Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice(); + let result: Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice(); #[cfg(feature = "cuda")] - let result : Box<[f32]> = { + let result: Box<[f32]> = { let num_out = unsafe { result.__size() / std::mem::size_of::<f32>() }; let mut res_cpu: Box<[f32]> = vec![0.0; num_out].into_boxed_slice(); result.to_cpu_ref(&mut res_cpu); @@ -261,7 +261,10 @@ fn edge_detection_harness(args: EdgeDetectionInputs) { theta, ); - assert_eq!(result.as_ref(), <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result)); + assert_eq!( + result.as_ref(), + <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result) + ); println!("Frames {} match", i); if display_verify { diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 50fe1760..2892cd34 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -2,9 +2,9 @@ use rand::random; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("matmul"); @@ -28,10 +28,14 @@ fn main() { let a = HerculesCPURef::from_slice(&a); let b = HerculesCPURef::from_slice(&b); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let c = r + .run(I as u64, J as u64, K as u64, 
a.clone(), b.clone()) + .await; assert_eq!(c.as_slice::<i32>(), &*correct_c); let mut r = runner!(tiled_64_matmul); - let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let tiled_c = r + .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) + .await; assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); } #[cfg(feature = "cuda")] @@ -39,12 +43,16 @@ fn main() { let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await; + let c = r + .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) + .await; let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); assert_eq!(&*c_cpu, &*correct_c); let mut r = runner!(tiled_64_matmul); - let tiled_c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await; + let tiled_c = r + .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) + .await; let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); tiled_c.to_cpu_ref(&mut tiled_c_cpu); assert_eq!(&*tiled_c_cpu, &*correct_c); diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs index bc99a4bd..b364c03c 100644 --- a/juno_samples/nested_ccp/src/main.rs +++ b/juno_samples/nested_ccp/src/main.rs @@ -1,8 +1,8 @@ #![feature(concat_idents)] -use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut}; juno_build::juno!("nested_ccp"); diff --git a/juno_samples/patterns/src/main.rs b/juno_samples/patterns/src/main.rs index 5cc2e7c8..a5586c8b 100644 --- a/juno_samples/patterns/src/main.rs +++ b/juno_samples/patterns/src/main.rs @@ -1,6 +1,6 @@ #![feature(concat_idents)] -use hercules_rt::{runner}; +use hercules_rt::runner; juno_build::juno!("patterns"); diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs index 749a660c..0129c4de 100644 --- a/juno_samples/schedule_test/build.rs +++ b/juno_samples/schedule_test/build.rs @@ -4,7 +4,11 @@ fn main() { JunoCompiler::new() .file_in_src("code.jn") .unwrap() - .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) .unwrap() .build() .unwrap(); diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs index 1505d4e5..f769e750 100644 --- a/juno_samples/schedule_test/src/main.rs +++ b/juno_samples/schedule_test/src/main.rs @@ -2,9 +2,9 @@ use rand::random; -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("code"); @@ -43,7 +43,16 @@ fn main() { let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); let c = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&c)); let mut r = runner!(test); - let res = r.run(N as u64, M as u64, K as u64, a.get_ref(), b.get_ref(), c.get_ref()).await; + let res = r + .run( + N as u64, + M as u64, + K as u64, + a.get_ref(), + b.get_ref(), + c.get_ref(), + ) + .await; let mut res_cpu: Box<[i32]> = vec![0; correct_res.len()].into_boxed_slice(); res.to_cpu_ref(&mut res_cpu); assert_eq!(&*res_cpu, &*correct_res); diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs index 8eb78f7c..687ff414 100644 
--- a/juno_samples/simple3/src/main.rs +++ b/juno_samples/simple3/src/main.rs @@ -1,8 +1,8 @@ #![feature(concat_idents)] -use hercules_rt::{runner, HerculesCPURef}; #[cfg(feature = "cuda")] use hercules_rt::CUDABox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("simple3"); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index ea06a0f2..713c30d4 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -108,7 +108,7 @@ impl FromStr for Appliable { "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) - }, + } "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 796437a7..9e85509f 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -36,7 +36,7 @@ impl Pass { pub fn num_args(&self) -> usize { match self { Pass::Xdot => 1, - Pass::ForkChunk => 3, + Pass::ForkChunk => 3, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d176b636..2142d5c5 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1566,7 +1566,7 @@ fn run_pass( // this eventually. let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest); changed |= c; - inner_changed |= c; + inner_changed |= c; } pm.delete_gravestones(); pm.clear_analyses(); @@ -1921,24 +1921,32 @@ fn run_pass( let dim_idx = args.get(1); let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { - panic!(); // How to error here? + return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected boolean argument".to_string(), + }); }; let Some(Value::Integer { val: dim_idx }) = args.get(1) else { - panic!(); // How to error here? + return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected integer argument".to_string(), + }); }; let Some(Value::Integer { val: tile_size }) = args.get(0) else { - panic!(); // How to error here? 
+ return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected integer argument".to_string(), + }); }; assert_eq!(*guarded_flag, true); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in - build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -1953,10 +1961,9 @@ fn run_pass( assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in - build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; -- GitLab From 43404780896ebb6bd287c8a56783a08cded6852e Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 3 Feb 2025 15:53:31 -0600 Subject: [PATCH 24/33] whoops --- hercules_opt/src/fork_transforms.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ed6283fd..c4a6ba7f 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -868,7 +868,7 @@ pub fn fork_dim_merge( }; if tid_dim > inner_idx { let new_tid = Node::ThreadID { - control: new_fork_id, + control: new_fork, dimension: tid_dim - 1, }; let new_tid = edit.add_node(new_tid); @@ -876,7 +876,7 @@ pub fn fork_dim_merge( edit.sub_edit(tid, new_tid); } else if tid_dim == outer_idx { let outer_tid = Node::ThreadID { - control: new_fork_id, + control: new_fork, dimension: outer_idx, }; let outer_tid = edit.add_node(outer_tid); @@ -894,7 +894,7 @@ pub fn fork_dim_merge( edit = edit.replace_all_uses(tid, rem)?; } else if tid_dim == inner_idx { let outer_tid = Node::ThreadID { - control: new_fork_id, + control: new_fork, dimension: outer_idx, }; let outer_tid = edit.add_node(outer_tid); -- GitLab From e1124c929bd07adbe14d73b28af3fea777e5bbbe Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Wed, 5 Feb 2025 12:06:36 -0600 Subject: [PATCH 25/33] add bufferize test --- juno_samples/fork_join_tests/src/cpu.sch | 8 ++++++++ juno_samples/fork_join_tests/src/fork_join_tests.jn | 13 +++++++++++++ 2 files changed, 21 insertions(+) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 38010004..a557cd03 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -12,6 +12,7 @@ cpu(out.test3); cpu(out.test4); cpu(out.test5); + ip-sroa(*); sroa(*); dce(*); @@ -42,14 +43,21 @@ gvn(*); phi-elim(*); dce(*); +xdot[true](*); +fork-fission-bufferize(test7@loop, test7@bufferize1, test7@bufferize2, test7@bufferize3, test7@bufferize4); fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); //let out = outline(out.test6.fj1); + +let out7 = auto-outline(test7); +cpu(out7.test7); + let out = auto-outline(test6); cpu(out.test6); ip-sroa(*); sroa(*); unforkify(out.test6); +unforkify(out7.test7); dce(*); ccp(*); gvn(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 806cb0f1..128d3ce0 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ 
b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -81,3 +81,16 @@ fn test6(input: i32) -> i32[1024] {
   }
   return arr;
 }
+
+#[entry]
+fn test7(input : i32) -> i32[8] {
+  let arr : i32[8];
+  @loop for i = 0 to 8 {
+    @bufferize1 let a = arr[i];
+    @bufferize2 let b = a + arr[7-i];
+    @bufferize3 let c = b * i as i32;
+    @bufferize4 let d = c;
+    arr[i] = d;
+  }
+  return arr;
+}
\ No newline at end of file
-- 
GitLab


From 3d9e31eedef992af21699fd94e5ec5ff9fdbb644 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Wed, 5 Feb 2025 12:07:19 -0600
Subject: [PATCH 26/33] stash apply

---
 hercules_cg/src/cpu.rs                        |   4 +-
 hercules_opt/src/fork_transforms.rs           | 298 ++++++++++++++----
 .../hercules_tests/tests/loop_tests.rs        |  27 +-
 juno_samples/fork_join_tests/src/cpu.sch      |  12 +-
 .../fork_join_tests/src/fork_join_tests.jn    |   2 +
 juno_samples/fork_join_tests/src/main.rs      |  22 ++
 juno_samples/matmul/build.rs                  |   2 +
 juno_scheduler/src/compile.rs                 |   1 +
 juno_scheduler/src/ir.rs                      |   1 +
 juno_scheduler/src/pm.rs                      |  50 ++-
 10 files changed, 326 insertions(+), 93 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index f6a1f309..20a3e6cb 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -694,7 +694,7 @@ impl<'a> CPUContext<'a> {
             );
             write!(
                 body,
-                "    {} = call i64 @llvm.umin.i64(i64{},i64%dc{}))\n",
+                "    {} = call i64 @llvm.umin.i64(i64{},i64%dc{})\n",
                 new_val,
                 cur_value,
                 x.idx()
@@ -719,7 +719,7 @@ impl<'a> CPUContext<'a> {
             );
             write!(
                 body,
-                "    {} = call i64 @llvm.umax.i64(i64{},i64%dc{}))\n",
+                "    {} = call i64 @llvm.umax.i64(i64{},i64%dc{})\n",
                 new_val,
                 cur_value,
                 x.idx()
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 456f670e..f82bbdfb 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -98,6 +98,31 @@ pub fn find_reduce_dependencies<'a>(
     ret_val
}
 
+pub fn copy_subgraph_in_edit<'a>(
+    mut edit: FunctionEdit<'a, 'a>,
+    subgraph: HashSet<NodeID>,
+) -> (
+    Result<(FunctionEdit<'a, 'a>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'a>> // a map from old nodes to new nodes
+    // Unlike copy_subgraph, the list of (inside *old* node, outside node) pairs, s.t.
+    // inside old node -> outside node, is not returned here; the caller probably wants
+    // to rebuild it from the returned map.
+) {
+
+    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
+
+    // Copy nodes in subgraph
+    for old_id in subgraph.iter().cloned() {
+        let new_id = edit.copy_node(old_id);
+        map.insert(old_id, new_id);
+    }
+
+    // Update edges to new nodes
+    for old_id in subgraph.iter() {
+        edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| map.values().contains(node_id))?;
+    }
+
+    Ok((edit, map))
+}
+
 pub fn copy_subgraph(
     editor: &mut FunctionEditor,
     subgraph: HashSet<NodeID>,
@@ -144,6 +169,37 @@ pub fn copy_subgraph(
     (new_nodes, map, outside_users)
 }
 
+
+pub fn fork_fission_bufferize_toplevel<'a>(
+    editor: &'a mut FunctionEditor<'a>,
+    loop_tree: &'a LoopTree,
+    fork_join_map: &'a HashMap<NodeID, NodeID>,
+    data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
+    typing: &'a Vec<TypeID>
+) -> bool {
+
+    let forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
+        .collect();
+
+    for l in forks {
+        let fork_info = &Loop {
+            header: l.0,
+            control: l.1.clone(),
+        };
+        let fork = fork_info.header;
+        let join = fork_join_map[&fork];
+        let mut edges = HashSet::new();
+        // FIXME: hardcoded test edge; the edges to bufferize should be computed or
+        // passed in rather than fixed node IDs, and only the first fork is processed.
+        edges.insert((NodeID::new(8), NodeID::new(3)));
+        let result = fork_bufferize_fission_helper(editor, fork_info, edges, data_node_in_fork_joins, typing, fork, join);
+        return result.is_some();
+    }
+    return false;
+}
+
+
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
     _control_subgraph: &Subgraph,
@@ -184,71 +240,100 @@ pub fn fork_fission<'a>(
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
 pub fn fork_bufferize_fission_helper<'a>(
-    editor: &'a mut FunctionEditor,
-    fork_join_map: &HashMap<NodeID, NodeID>,
+    editor: &'a mut FunctionEditor<'a>,
+    l: &Loop,
     bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
-    _original_control_pred: NodeID, // What the new fork connects to.
+    data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
     types: &Vec<TypeID>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
+    join: NodeID
+) -> Option<(NodeID, NodeID)> {
     // Returns the two forks that it generates.
 
-    // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork.
+    if bufferized_edges.is_empty() {
+        return None
+    }
 
-    // Copy fork + control intermediates + join to new fork + join,
-    // How does control get partitioned?
-    // (depending on how it affects the data nodes on each side of the bufferized_edges)
-    // may end up in each loop, fix me later.
-    // place new fork + join after join of first.
+    let all_loop_nodes = l.get_all_nodes();
 
-    // Only handle fork+joins with no inner control for now.
+    // FIXME: Cloning hell.
+    let data_nodes = data_node_in_fork_joins[&fork].clone();
+    let loop_nodes = editor.node_ids().filter(|node_id| all_loop_nodes[node_id.idx()]);
-    // Create fork + join + Thread control
+    // Clone the subgraph that consists of this fork-join and all data and control nodes in it.
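+    //
+    // The intended overall effect (an illustrative sketch, for a single bufferized
+    // edge (x, y) in a one-dimensional fork):
+    //     fork i { x = f(i); y = g(x); ... }
+    // becomes
+    //     fork i { buf[i] = f(i) }  followed by  fork i' { y = g(buf[i']); ... }
+    // where buf is a fresh zero-initialized array reduce sized by the fork factors.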
+ let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes)); - // Create fork + join + Thread control - let join = fork_join_map[&fork]; - let mut new_fork_id = NodeID::new(0); - let mut new_join_id = NodeID::new(0); + let mut outside_users = Vec::new(); // old_node, outside_user - editor.edit(|mut edit| { - new_join_id = edit.add_node(Node::Join { control: fork }); - let factors = edit.get_node(fork).try_fork().unwrap().1; - new_fork_id = edit.add_node(Node::Fork { - control: new_join_id, - factors: factors.into(), - }); - edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) - }); + for node in subgraph.iter() { + for user in editor.get_users(*node) { + if !subgraph.iter().contains(&user) { + outside_users.push((*node, user)); + } + } + } - for (src, dst) in bufferized_edges { - // FIXME: Disgusting cloning and allocationing and iterators. - let factors: Vec<_> = editor.func().nodes[fork.idx()] - .try_fork() - .unwrap() - .1 - .iter() - .cloned() - .collect(); - editor.edit(|mut edit| { - // Create write to buffer + let factors: Vec<_> = editor.func().nodes[fork.idx()] + .try_fork() + .unwrap() + .1 + .iter() + .cloned() + .collect(); - let thread_stuff_it = factors.into_iter().enumerate(); + let thread_stuff_it = factors.into_iter().enumerate(); - // FIxme: try to use unzip here? Idk why it wasn't working. - let tids = thread_stuff_it.clone().map(|(dim, _)| { - edit.add_node(Node::ThreadID { - control: fork, - dimension: dim, - }) + // Control succesors + let fork_pred = editor.get_uses(fork).filter(|a| editor.node(a).is_control()).next().unwrap(); + let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap(); + + let mut new_fork_id = NodeID::new(0); + editor.edit(|edit| { + let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?; + + // Put new subgraph after old subgraph + println!("map: {:?}", map); + println!("join: {:?}, fork: {:?}", join, fork); + println!("fork_sccueue: {:?}", join_successor); + edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?; + edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?; + + // Replace outside uses of reduces in old subgraph with reduces in new subgraph. + for (old_node, outside_user) in outside_users { + edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; + } + + // Add buffers to old subgraph + + let new_join = map[&join]; + let new_fork = map[&fork]; + + // FIXME: Do this as part of copy subgraph? + // Add tids to original subgraph for indexing. + let mut old_tids = Vec::new(); + let mut new_tids = Vec::new(); + for (dim, _) in thread_stuff_it.clone() { + let old_id = edit.add_node(Node::ThreadID { + control: fork, + dimension: dim, }); - let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + let new_id = edit.add_node(Node::ThreadID { + control: new_fork, + dimension: dim, + }); - // Assume 1-d fork only for now. 
- // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); - let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + old_tids.push(old_id); + new_tids.push(new_id); + } + + for (src, dst) in &bufferized_edges { + let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + let position_idx = Index::Position(old_tids.clone().into_boxed_slice()); + let write = edit.add_node(Node::Write { collect: NodeID::new(0), - data: src, - indices: vec![position_idx].into(), + data: *src, + indices: vec![position_idx.clone()].into(), }); let ele_type = types[src.idx()]; let empty_buffer = edit.add_type(hercules_ir::Type::Array( @@ -258,36 +343,125 @@ pub fn fork_bufferize_fission_helper<'a>( let empty_buffer = edit.add_zero_constant(empty_buffer); let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); let reduce = Node::Reduce { - control: new_join_id, + control: join, init: empty_buffer, reduct: write, }; let reduce = edit.add_node(reduce); + // Fix write node edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; - // Create read from buffer - let tids = thread_stuff_it.clone().map(|(dim, _)| { - edit.add_node(Node::ThreadID { - control: new_fork_id, - dimension: dim, - }) - }); - let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + // Create reads from buffer + let position_idx = Index::Position(new_tids.clone().into_boxed_slice()); let read = edit.add_node(Node::Read { collect: reduce, indices: vec![position_idx].into(), }); - edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; + // Replaces uses of bufferized edge src with corresponding reduce and read in old subraph + edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?; + + } + + + - Ok(edit) - }); - } + new_fork_id = new_fork; + + Ok(edit) + }); - (fork, new_fork_id) + Some((fork, new_fork_id)) + + // let internal_control: Vec<NodeID> = Vec::new(); + + // // Create fork + join + Thread control + // let mut new_fork_id = NodeID::new(0); + // let mut new_join_id = NodeID::new(0); + + // editor.edit(|mut edit| { + // new_join_id = edit.add_node(Node::Join { control: fork }); + // let factors = edit.get_node(fork).try_fork().unwrap().1; + // new_fork_id = edit.add_node(Node::Fork { + // control: new_join_id, + // factors: factors.into(), + // }); + // edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) + // }); + + // for (src, dst) in bufferized_edges { + // // FIXME: Disgusting cloning and allocationing and iterators. + // let factors: Vec<_> = editor.func().nodes[fork.idx()] + // .try_fork() + // .unwrap() + // .1 + // .iter() + // .cloned() + // .collect(); + // editor.edit(|mut edit| { + // // Create write to buffer + + // let thread_stuff_it = factors.into_iter().enumerate(); + + // // FIxme: try to use unzip here? Idk why it wasn't working. + // let tids = thread_stuff_it.clone().map(|(dim, _)| { + // edit.add_node(Node::ThreadID { + // control: fork, + // dimension: dim, + // }) + // }); + + // let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + + // // Assume 1-d fork only for now. 
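The buffer construction above is the heart of the transformation. For each bufferized edge (src, dst), the original fork-join keeps src and gains a Write of src into a fresh zero-initialized array, threaded through a Reduce at the original join; the cloned fork-join gains a Read of that array, and its copy of dst is rewired to consume the read instead of its copy of src. Both halves still contain the entire loop body at this point: the read is what cuts the dependence, and the dce passes scheduled afterwards shrink each half to the nodes it actually needs. A scalar picture of the effect (f and g are hypothetical, not from the patch):

    // Before: one loop computes a value and immediately consumes it.
    fn fused(f: impl Fn(usize) -> i32, g: impl Fn(i32) -> i32, out: &mut [i32]) {
        for i in 0..out.len() {
            let a = f(i); // src side of the bufferized edge
            out[i] = g(a); // dst side
        }
    }

    // After: the first loop stores every value into a buffer (the Write
    // feeding the Reduce at the join); the cloned loop reads it back.
    fn fissioned(f: impl Fn(usize) -> i32, g: impl Fn(i32) -> i32, out: &mut [i32]) {
        let mut buf = vec![0i32; out.len()];
        for i in 0..out.len() {
            buf[i] = f(i);
        }
        for i in 0..out.len() {
            out[i] = g(buf[i]);
        }
    }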
+ // // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); + // let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + // let write = edit.add_node(Node::Write { + // collect: NodeID::new(0), + // data: src, + // indices: vec![position_idx].into(), + // }); + // let ele_type = types[src.idx()]; + // let empty_buffer = edit.add_type(hercules_ir::Type::Array( + // ele_type, + // array_dims.collect::<Vec<_>>().into_boxed_slice(), + // )); + // let empty_buffer = edit.add_zero_constant(empty_buffer); + // let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); + // let reduce = Node::Reduce { + // control: new_join_id, + // init: empty_buffer, + // reduct: write, + // }; + // let reduce = edit.add_node(reduce); + // // Fix write node + // edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; + + // // Create read from buffer + // let tids = thread_stuff_it.clone().map(|(dim, _)| { + // edit.add_node(Node::ThreadID { + // control: new_fork_id, + // dimension: dim, + // }) + // }); + + // let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + + // let read = edit.add_node(Node::Read { + // collect: reduce, + // indices: vec![position_idx].into(), + // }); + + // edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; + + // Ok(edit) + // }); + // } + + // (fork, new_fork_id) } /** Split a 1D fork into a separate fork for each reduction. */ diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 795642b2..f42a6520 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -385,7 +385,7 @@ fn matmul_pipeline() { let dyn_consts = [I, J, K]; // FIXME: This path should not leave the crate - let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin"); + let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin"); // let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { @@ -398,25 +398,20 @@ fn matmul_pipeline() { let result_1 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone()); - println!("golden: {:?}", correct_c); - println!("result: {:?}", result_1); + // println!("golden: {:?}", correct_c); + // println!("result: {:?}", result_1); - let InterpreterVal::Array(_, d) = result_1.clone() else { - panic!() - }; - let InterpreterVal::Integer32(value) = d[0] else { - panic!() - }; - assert_eq!(correct_c[0], value); + // let InterpreterVal::Array(_, d) = result_1.clone() else { + // panic!() + // }; + // let InterpreterVal::Integer32(value) = d[0] else { + // panic!() + // }; + // assert_eq!(correct_c[0], value); let schedule = Some(default_schedule![ - AutoOutline, - InterproceduralSROA, - SROA, - InferSchedules, - DCE, Xdot, - GCM + Verify, ]); module = run_schedule_on_hercules(module, schedule).unwrap(); diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index a557cd03..947d0dc8 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -7,10 +7,13 @@ dce(*); let out = auto-outline(test1, test2, test3, test4, test5); cpu(out.test1); +<<<<<<< Updated upstream cpu(out.test2); cpu(out.test3); cpu(out.test4); cpu(out.test5); +======= +>>>>>>> Stashed changes ip-sroa(*); @@ -24,8 +27,15 @@ dce(*); fixpoint panic after 20 { forkify(*); fork-guard-elim(*); - fork-coalesce(*); + fork-coalesce(*); 
} + +dce(*); +gvn(*); + +xdot[true](*); +fork-fission-bufferize(*); +xdot[true](*); gvn(*); phi-elim(*); dce(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 128d3ce0..90d06c2f 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -9,6 +9,7 @@ fn test1(input : i32) -> i32[4, 4] { return arr; } +/** #[entry] fn test2(input : i32) -> i32[4, 4] { let arr : i32[4, 4]; @@ -72,6 +73,7 @@ fn test5(input : i32) -> i32[4] { } return arr1; } +<<<<<<< Updated upstream #[entry] fn test6(input: i32) -> i32[1024] { diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 19838fd7..33c5602e 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -23,6 +23,7 @@ fn main() { let correct = vec![5i32; 16]; assert(&correct, output); +<<<<<<< Updated upstream let mut r = runner!(test2); let output = r.run(3).await; let correct = vec![24i32; 16]; @@ -47,6 +48,27 @@ fn main() { let output = r.run(73).await; let correct = (73i32..73i32+1024i32).collect(); assert(&correct, output); +======= + // let mut r = runner!(test2); + // let output = r.run(3).await; + // let correct = vec![24i32; 16]; + // assert(correct, output); + + // let mut r = runner!(test3); + // let output = r.run(0).await; + // let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; + // assert(correct, output); + + // let mut r = runner!(test4); + // let output = r.run(9).await; + // let correct = vec![63i32; 16]; + // assert(correct, output); + + // let mut r = runner!(test5); + // let output = r.run(4).await; + // let correct = vec![7i32; 4]; + // assert(correct, output); +>>>>>>> Stashed changes }); } diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 0be838c6..d2813388 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,6 +6,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 713c30d4..08e952a4 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -109,6 +109,7 @@ impl FromStr for Appliable { "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) } + "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)), "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 9e85509f..8a6e04ed 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -12,6 +12,7 @@ pub enum Pass { ForkSplit, ForkCoalesce, Forkify, + ForkFissionBufferize, ForkDimMerge, ForkChunk, GCM, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d4acac19..28b0cbf5 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1987,10 +1987,7 @@ fn run_pass( }); }; let Some(Value::Integer { val: dim_idx }) = args.get(1) else { - return Err(SchedulerError::PassError { - pass: "forkChunk".to_string(), - error: "expected integer argument".to_string(), - }); + panic!(); // How to error here? 
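The juno_scheduler hunks above are the boilerplate for exposing a new pass: compile.rs maps the schedule-language name fork-fission-bufferize onto the new ir::Pass::ForkFissionBufferize variant, and the hunk that follows adds the matching run_pass arm, which builds the analyses the pass needs before dispatching. The panic!() substituted into the fork-chunk argument check just above is a temporary regression; the author's "How to error here?" is resolved a couple of patches later in the series, when the structured SchedulerError::PassError return is restored. The name-to-variant step follows a fixed pattern; a self-contained mock of it (simplified stand-ins for the real scheduler types):

    use std::str::FromStr;

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum Pass {
        Forkify,
        ForkFissionBufferize,
        ForkDimMerge,
    }

    struct Appliable(Pass);

    impl FromStr for Appliable {
        type Err = String;

        // Surface syntax resolves to enum variants here; adding a pass means
        // one arm here plus a dispatch arm in run_pass.
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "forkify" => Ok(Appliable(Pass::Forkify)),
                "fork-fission-bufferize" => Ok(Appliable(Pass::ForkFissionBufferize)),
                "fork-dim-merge" => Ok(Appliable(Pass::ForkDimMerge)),
                other => Err(format!("unknown pass: {other}")),
            }
        }
    }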
}; let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { return Err(SchedulerError::PassError { @@ -2002,9 +1999,10 @@ fn run_pass( assert_eq!(*guarded_flag, true); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -2015,13 +2013,41 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkFissionBufferize => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_typing(); + pm.make_loops(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(typing.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + let result = fork_fission_bufferize_toplevel(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing); + changed |= result; + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -2051,8 +2077,8 @@ fn run_pass( let Some(mut func) = func else { continue; }; - changed |= fork_coalesce(&mut func, loop_nest, fork_join_map); - // func.modified(); + fork_coalesce(&mut func, loop_nest, fork_join_map); + changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); -- GitLab From d2cec2531bce086be7132bae76b16383d7ddccaa Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sat, 8 Feb 2025 13:23:19 -0600 Subject: [PATCH 27/33] fork fission bufferize --- hercules_opt/src/fork_transforms.rs | 107 +++++++++++++----- juno_samples/fork_join_tests/src/blah.sch | 34 ++++++ juno_samples/fork_join_tests/src/cpu.sch | 14 +-- .../fork_join_tests/src/fork_join_tests.jn | 24 ++-- juno_samples/fork_join_tests/src/main.rs | 28 +---- juno_scheduler/src/pm.rs | 60 +++++++++- 6 files changed, 198 insertions(+), 69 deletions(-) create mode 100644 juno_samples/fork_join_tests/src/blah.sch diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index f82bbdfb..ea486c94 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::iter::zip; use std::thread::ThreadId; +use std::hash::Hash; use bimap::BiMap; use itertools::Itertools; @@ -98,11 +99,11 @@ pub fn find_reduce_dependencies<'a>( ret_val } -pub fn copy_subgraph_in_edit<'a>( - mut edit: FunctionEdit<'a, 'a>, +pub fn copy_subgraph_in_edit<'a, 'b>( + mut edit: FunctionEdit<'a, 'b>, subgraph: HashSet<NodeID>, ) -> ( - Result<(FunctionEdit<'a, 'a>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'a>>// a map from old nodes to new nodes + Result<(FunctionEdit<'a, 'b>, 
HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>// a map from old nodes to new nodes // A list of (inside *old* node, outside node) s.t insde old node -> outside node. // The caller probably wants ) { @@ -169,14 +170,56 @@ pub fn copy_subgraph( (new_nodes, map, outside_users) } +fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> bool { + // A strict superset must be larger than its subset + if set1.len() <= set2.len() { + return false; + } + + // Every element in set2 must be in set1 + set2.iter().all(|item| set1.contains(item)) +} -pub fn fork_fission_bufferize_toplevel<'a>( - editor: &'a mut FunctionEditor<'a>, - loop_tree: &'a LoopTree, - fork_join_map: &'a HashMap<NodeID, NodeID>, - data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, - typing: &'a Vec<TypeID> -) -> bool { +pub fn find_bufferize_edges( + editor: & mut FunctionEditor, + fork: NodeID, + loop_tree: & LoopTree, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: & HashMap<NodeID, HashSet<NodeID>>, +) -> HashSet<(NodeID, NodeID)> { + + // println!("func: {:?}", editor.func_id()); + let mut edges: HashSet<_> = HashSet::new(); + // print labels + for node in &nodes_in_fork_joins[&fork] { + println!("node: {:?}, label {:?}, ", node, editor.func().labels[node.idx()]); + let node_labels = &editor.func().labels[node.idx()]; + for usee in editor.get_uses(*node) { + // If usee labels is a superset of this node labels, then make an edge. + let usee_labels = &editor.func().labels[usee.idx()]; + // strict superset + if !(usee_labels.is_superset(node_labels) && usee_labels.len() > node_labels.len()) { + continue; + } + + if editor.node(usee).is_control() || editor.node(node).is_control() { + continue; + } + + edges.insert((usee, *node)); + } + } + println!("edges: {:?}", edges); + edges +} + +pub fn ff_bufferize_any_fork<'a, 'b>( + editor: &'b mut FunctionEditor<'a>, + loop_tree: &'b LoopTree, + fork_join_map: &'b HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, + typing: &'b Vec<TypeID> +) -> Option<(NodeID, NodeID)> where 'a: 'b { let forks: Vec<_> = loop_tree .bottom_up_loops() @@ -185,18 +228,22 @@ pub fn fork_fission_bufferize_toplevel<'a>( .collect(); for l in forks { - let fork_info = &Loop { + let fork_info = Loop { header: l.0, control: l.1.clone(), }; let fork = fork_info.header; let join = fork_join_map[&fork]; - let mut edges = HashSet::new(); - edges.insert((NodeID::new(8), NodeID::new(3))); - let result = fork_bufferize_fission_helper(editor, fork_info, edges, data_node_in_fork_joins, typing, fork, join); - return result.is_some(); + + let edges = find_bufferize_edges(editor, fork, &loop_tree, &fork_join_map, &nodes_in_fork_joins); + let result = fork_bufferize_fission_helper(editor, &fork_info, &edges, nodes_in_fork_joins, typing, fork, join); + if result.is_none() { + continue + } else { + return result; + } } - return false; + return None; } @@ -239,15 +286,15 @@ pub fn fork_fission<'a>( } /** Split a 1D fork into two forks, placing select intermediate data into buffers. */ -pub fn fork_bufferize_fission_helper<'a>( - editor: &'a mut FunctionEditor<'a>, +pub fn fork_bufferize_fission_helper<'a, 'b>( + editor: &'b mut FunctionEditor<'a>, l: &Loop, - bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. 
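The signature rewrite in progress here is a Rust lifetime fix rather than a functional change: &'a mut FunctionEditor<'a> ties the mutable borrow to the editor's own lifetime parameter, which freezes the editor for the rest of its life after a single call, while &'b mut FunctionEditor<'a> with 'a: 'b keeps the borrow short so a caller can invoke the helper repeatedly. A minimal illustration of the difference (toy types, not the Hercules API):

    struct Editor<'a> {
        src: &'a mut Vec<u32>,
    }

    // Anti-pattern: after one call, ed is borrowed for all of 'a and is
    // effectively dead to the caller.
    fn locks_forever<'a>(ed: &'a mut Editor<'a>) {
        ed.src.push(0);
    }

    // Fix: the mutable borrow lives only for 'b, so the caller may call
    // this any number of times.
    fn borrows_briefly<'a, 'b>(ed: &'b mut Editor<'a>) where 'a: 'b {
        ed.src.push(0);
    }

    fn main() {
        let mut v = vec![1, 2, 3];
        let mut ed = Editor { src: &mut v };
        borrows_briefly(&mut ed);
        borrows_briefly(&mut ed); // fine; with locks_forever this would not compile
    }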
- data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, - types: &Vec<TypeID>, + bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. + data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, + types: &'b Vec<TypeID>, fork: NodeID, join: NodeID -) -> Option<(NodeID, NodeID)> { +) -> Option<(NodeID, NodeID)> where 'a: 'b { // Returns the two forks that it generates. if bufferized_edges.is_empty() { @@ -287,13 +334,15 @@ pub fn fork_bufferize_fission_helper<'a>( let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap(); let mut new_fork_id = NodeID::new(0); - editor.edit(|edit| { + + let edit_result = editor.edit(|edit| { let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?; // Put new subgraph after old subgraph println!("map: {:?}", map); println!("join: {:?}, fork: {:?}", join, fork); println!("fork_sccueue: {:?}", join_successor); + edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?; edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?; @@ -302,6 +351,8 @@ pub fn fork_bufferize_fission_helper<'a>( edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; } + + // Add buffers to old subgraph let new_join = map[&join]; @@ -326,7 +377,7 @@ pub fn fork_bufferize_fission_helper<'a>( new_tids.push(new_id); } - for (src, dst) in &bufferized_edges { + for (src, dst) in bufferized_edges { let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); let position_idx = Index::Position(old_tids.clone().into_boxed_slice()); @@ -366,14 +417,18 @@ pub fn fork_bufferize_fission_helper<'a>( } - - new_fork_id = new_fork; Ok(edit) }); + println!("edit_result: {:?}", edit_result); + if edit_result == false { + todo!(); + return None + } + Some((fork, new_fork_id)) // let internal_control: Vec<NodeID> = Vec::new(); diff --git a/juno_samples/fork_join_tests/src/blah.sch b/juno_samples/fork_join_tests/src/blah.sch new file mode 100644 index 00000000..52dea702 --- /dev/null +++ b/juno_samples/fork_join_tests/src/blah.sch @@ -0,0 +1,34 @@ + +xdot[true](*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); + dce(*); +} + +xdot[true](*); + +//gvn(*); +//phi-elim(*); +//dce(*); + +//gvn(*); +//phi-elim(*); +//dce(*); + +//fixpoint panic after 20 { +// infer-schedules(*); +//} + +//fork-split(*); +//gvn(*); +//phi-elim(*); +//dce(*); +//unforkify(*); +//gvn(*); +//phi-elim(*); +//dce(*); + +//gcm(*); diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 947d0dc8..e04d8dfe 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -7,13 +7,10 @@ dce(*); let out = auto-outline(test1, test2, test3, test4, test5); cpu(out.test1); -<<<<<<< Updated upstream cpu(out.test2); cpu(out.test3); cpu(out.test4); cpu(out.test5); -======= ->>>>>>> Stashed changes ip-sroa(*); @@ -33,9 +30,6 @@ fixpoint panic after 20 { dce(*); gvn(*); -xdot[true](*); -fork-fission-bufferize(*); -xdot[true](*); gvn(*); phi-elim(*); dce(*); @@ -53,10 +47,12 @@ gvn(*); phi-elim(*); dce(*); -xdot[true](*); -fork-fission-bufferize(test7@loop, test7@bufferize1, test7@bufferize2, test7@bufferize3, test7@bufferize4); +fork-fission-bufferize(test7); +dce(*); + fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); +fork-split(*); //let out = outline(out.test6.fj1); let out7 = auto-outline(test7); @@ -69,6 
+65,8 @@ sroa(*); unforkify(out.test6); unforkify(out7.test7); dce(*); +unforkify(*); +dce(*); ccp(*); gvn(*); phi-elim(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 90d06c2f..ae3be778 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -9,7 +9,6 @@ fn test1(input : i32) -> i32[4, 4] { return arr; } -/** #[entry] fn test2(input : i32) -> i32[4, 4] { let arr : i32[4, 4]; @@ -73,7 +72,6 @@ fn test5(input : i32) -> i32[4] { } return arr1; } -<<<<<<< Updated upstream #[entry] fn test6(input: i32) -> i32[1024] { @@ -87,12 +85,22 @@ fn test6(input: i32) -> i32[1024] { #[entry] fn test7(input : i32) -> i32[8] { let arr : i32[8]; + let out : i32[8]; + + for i = 0 to 8 { + arr[i] = i as i32; + } + @loop for i = 0 to 8 { - @bufferize1 let a = arr[i]; - @bufferize2 let b = a + arr[7-i]; - @bufferize3 let c = b * i as i32; - @bufferize4 let d = c; - arr[i] = d; + let b: i32; + @bufferize1 { + let a = arr[i]; + let a2 = a + arr[7-i]; + b = a2 * i as i32; + } + let c = b; + let d = c + 10; + out[i] = d; } - return arr; + return out; } \ No newline at end of file diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 33c5602e..caf956a1 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -23,7 +23,6 @@ fn main() { let correct = vec![5i32; 16]; assert(&correct, output); -<<<<<<< Updated upstream let mut r = runner!(test2); let output = r.run(3).await; let correct = vec![24i32; 16]; @@ -44,31 +43,10 @@ fn main() { let correct = vec![7i32; 4]; assert(&correct, output); - let mut r = runner!(test6); - let output = r.run(73).await; - let correct = (73i32..73i32+1024i32).collect(); + let mut r = runner!(test7); + let output = r.run(0).await; + let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; assert(&correct, output); -======= - // let mut r = runner!(test2); - // let output = r.run(3).await; - // let correct = vec![24i32; 16]; - // assert(correct, output); - - // let mut r = runner!(test3); - // let output = r.run(0).await; - // let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; - // assert(correct, output); - - // let mut r = runner!(test4); - // let output = r.run(9).await; - // let correct = vec![63i32; 16]; - // assert(correct, output); - - // let mut r = runner!(test5); - // let output = r.run(4).await; - // let correct = vec![7i32; 4]; - // assert(correct, output); ->>>>>>> Stashed changes }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 28b0cbf5..2c5a3687 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2015,6 +2015,8 @@ fn run_pass( } Pass::ForkFissionBufferize => { assert!(args.is_empty()); + let mut created_fork_joins = vec![vec![]; pm.functions.len()]; + pm.make_fork_join_maps(); pm.make_typing(); pm.make_loops(); @@ -2034,9 +2036,63 @@ fn run_pass( let Some(mut func) = func else { continue; }; - let result = fork_fission_bufferize_toplevel(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing); - changed |= result; + if let Some((fork1, fork2)) = + ff_bufferize_any_fork(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing) + { + let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; + created_fork_joins.push(fork1); + created_fork_joins.push(fork2); + } + changed |= func.modified(); + } + + pm.make_nodes_in_fork_joins(); + let nodes_in_fork_joins = 
pm.nodes_in_fork_joins.take().unwrap(); + let mut new_fork_joins = HashMap::new(); + + for (mut func, created_fork_joins) in + build_editors(pm).into_iter().zip(created_fork_joins) + { + // For every function, create a label for every level of fork- + // joins resulting from the split. + let name = func.func().name.clone(); + let func_id = func.func_id(); + let labels = create_labels_for_node_sets( + &mut func, + created_fork_joins.into_iter().map(|fork| { + nodes_in_fork_joins[func_id.idx()][&fork] + .iter() + .map(|id| *id) + }) + , + ); + + // Assemble those labels into a record for this function. The + // format of the records is <function>.<f>, where N is the + // level of the split fork-joins being referred to. + todo!(); + // FIXME: What if there are multiple bufferized forks per function? + let mut func_record = HashMap::new(); + for (idx, label) in labels { + func_record.insert( + format!("fj{}", idx), + Value::Label { + labels: vec![LabelInfo { + func: func_id, + label: label, + }], + }, + ); + } + + // Try to avoid creating unnecessary record entries. + if !func_record.is_empty() { + new_fork_joins.entry(name).insert_entry(Value::Record { + fields: func_record, + }); + } } + pm.delete_gravestones(); pm.clear_analyses(); } -- GitLab From c0d161709074c5c79705de7052953b0fb27489a4 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 9 Feb 2025 13:09:53 -0600 Subject: [PATCH 28/33] fork fission bufferize pm --- hercules_opt/src/fork_transforms.rs | 144 +++++++++++++++-------- juno_samples/fork_join_tests/src/cpu.sch | 11 +- juno_samples/fork_join_tests/src/main.rs | 2 +- juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 101 ++++++++++++---- 5 files changed, 179 insertions(+), 80 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ea486c94..ad81d90b 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; +use std::hash::Hash; use std::iter::zip; use std::thread::ThreadId; -use std::hash::Hash; use bimap::BiMap; use itertools::Itertools; @@ -102,14 +102,9 @@ pub fn find_reduce_dependencies<'a>( pub fn copy_subgraph_in_edit<'a, 'b>( mut edit: FunctionEdit<'a, 'b>, subgraph: HashSet<NodeID>, -) -> ( - Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>// a map from old nodes to new nodes - // A list of (inside *old* node, outside node) s.t insde old node -> outside node. 
- // The caller probably wants -) { - +) -> (Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>) { let mut map: HashMap<NodeID, NodeID> = HashMap::new(); - + // Copy nodes in subgraph for old_id in subgraph.iter().cloned() { let new_id = edit.copy_node(old_id); @@ -118,7 +113,9 @@ pub fn copy_subgraph_in_edit<'a, 'b>( // Update edges to new nodes for old_id in subgraph.iter() { - edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| map.values().contains(node_id))?; + edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| { + map.values().contains(node_id) + })?; } Ok((edit, map)) @@ -175,38 +172,48 @@ fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> boo if set1.len() <= set2.len() { return false; } - + // Every element in set2 must be in set1 set2.iter().all(|item| set1.contains(item)) } pub fn find_bufferize_edges( - editor: & mut FunctionEditor, + editor: &mut FunctionEditor, fork: NodeID, - loop_tree: & LoopTree, + loop_tree: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, - nodes_in_fork_joins: & HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, + data_label: &LabelID, ) -> HashSet<(NodeID, NodeID)> { - // println!("func: {:?}", editor.func_id()); let mut edges: HashSet<_> = HashSet::new(); - // print labels + + println!("ndoes in fork joins: {:?}", &nodes_in_fork_joins[&fork]); + // print labels for node in &nodes_in_fork_joins[&fork] { - println!("node: {:?}, label {:?}, ", node, editor.func().labels[node.idx()]); + // Edge from *has data label** to doesn't have data label* let node_labels = &editor.func().labels[node.idx()]; - for usee in editor.get_uses(*node) { - // If usee labels is a superset of this node labels, then make an edge. 
- let usee_labels = &editor.func().labels[usee.idx()]; - // strict superset - if !(usee_labels.is_superset(node_labels) && usee_labels.len() > node_labels.len()) { + + if !node_labels.contains(data_label) { + continue; + } + + // Don't draw bufferize edges from fork tids + if editor.get_users(fork).contains(node) { + continue; + } + + for user in editor.get_users(*node) { + let user_labels = &editor.func().labels[user.idx()]; + if user_labels.contains(data_label) { continue; } - if editor.node(usee).is_control() || editor.node(node).is_control() { + if editor.node(user).is_control() || editor.node(node).is_control() { continue; } - edges.insert((usee, *node)); + edges.insert((*node, user)); } } println!("edges: {:?}", edges); @@ -218,15 +225,20 @@ pub fn ff_bufferize_any_fork<'a, 'b>( loop_tree: &'b LoopTree, fork_join_map: &'b HashMap<NodeID, NodeID>, nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, - typing: &'b Vec<TypeID> -) -> Option<(NodeID, NodeID)> where 'a: 'b { - + typing: &'b Vec<TypeID>, + fork_label: &'b LabelID, + data_label: &'b LabelID, +) -> Option<(NodeID, NodeID)> +where + 'a: 'b, +{ let forks: Vec<_> = loop_tree .bottom_up_loops() .into_iter() .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) .collect(); + println!("fork_label {:?}", fork_label); for l in forks { let fork_info = Loop { header: l.0, @@ -235,18 +247,39 @@ pub fn ff_bufferize_any_fork<'a, 'b>( let fork = fork_info.header; let join = fork_join_map[&fork]; - let edges = find_bufferize_edges(editor, fork, &loop_tree, &fork_join_map, &nodes_in_fork_joins); - let result = fork_bufferize_fission_helper(editor, &fork_info, &edges, nodes_in_fork_joins, typing, fork, join); + println!("fork labels: {:?}", editor.func().labels[fork.idx()]); + if !editor.func().labels[fork.idx()].contains(fork_label) { + continue; + } + + println!("fork: {:?}", fork); + + let edges = find_bufferize_edges( + editor, + fork, + &loop_tree, + &fork_join_map, + &nodes_in_fork_joins, + data_label, + ); + let result = fork_bufferize_fission_helper( + editor, + &fork_info, + &edges, + nodes_in_fork_joins, + typing, + fork, + join, + ); if result.is_none() { - continue + continue; } else { - return result; + return result; } } return None; } - pub fn fork_fission<'a>( editor: &'a mut FunctionEditor, _control_subgraph: &Subgraph, @@ -293,20 +326,25 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, types: &'b Vec<TypeID>, fork: NodeID, - join: NodeID -) -> Option<(NodeID, NodeID)> where 'a: 'b { + join: NodeID, +) -> Option<(NodeID, NodeID)> +where + 'a: 'b, +{ // Returns the two forks that it generates. if bufferized_edges.is_empty() { - return None + return None; } let all_loop_nodes = l.get_all_nodes(); // FIXME: Cloning hell. let data_nodes = data_node_in_fork_joins[&fork].clone(); - let loop_nodes = editor.node_ids().filter(|node_id| all_loop_nodes[node_id.idx()]); - // Clone the subgraph that consists of this fork-join and all data and control nodes in it. + let loop_nodes = editor + .node_ids() + .filter(|node_id| all_loop_nodes[node_id.idx()]); + // Clone the subgraph that consists of this fork-join and all data and control nodes in it. 
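find_bufferize_edges as revised above replaces the first-cut heuristic (cut wherever a user's label set strictly contains the producer's) with an explicit rule: an edge is bufferized exactly when the producer carries the data label, the consumer does not, neither endpoint is a control node, and the producer is not one of the fork's own ThreadIDs. In the test schedule this is what test8@bufferize1 selects: everything inside the labeled block is the producer side, and the first consumers outside it receive the buffer reads. The rule itself is a pure predicate; a sketch detached from the editor (label sets standing in for the per-node label tables):

    use std::collections::HashSet;

    type LabelID = u32;

    fn is_bufferize_edge(
        producer_labels: &HashSet<LabelID>,
        consumer_labels: &HashSet<LabelID>,
        producer_is_control: bool,
        consumer_is_control: bool,
        data_label: LabelID,
    ) -> bool {
        producer_labels.contains(&data_label)
            && !consumer_labels.contains(&data_label)
            && !producer_is_control
            && !consumer_is_control
    }

With the edges chosen, the fission helper below proceeds as before: clone the fork-join's subgraph and splice the copy in after the original.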
let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes)); let mut outside_users = Vec::new(); // old_node, outside_user @@ -330,8 +368,16 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( let thread_stuff_it = factors.into_iter().enumerate(); // Control succesors - let fork_pred = editor.get_uses(fork).filter(|a| editor.node(a).is_control()).next().unwrap(); - let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap(); + let fork_pred = editor + .get_uses(fork) + .filter(|a| editor.node(a).is_control()) + .next() + .unwrap(); + let join_successor = editor + .get_users(join) + .filter(|a| editor.node(a).is_control()) + .next() + .unwrap(); let mut new_fork_id = NodeID::new(0); @@ -339,25 +385,24 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?; // Put new subgraph after old subgraph - println!("map: {:?}", map); - println!("join: {:?}, fork: {:?}", join, fork); - println!("fork_sccueue: {:?}", join_successor); + // println!("map: {:?}", map); + // println!("join: {:?}, fork: {:?}", join, fork); + // println!("fork_sccueue: {:?}", join_successor); edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?; edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?; // Replace outside uses of reduces in old subgraph with reduces in new subgraph. for (old_node, outside_user) in outside_users { - edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; + edit = edit + .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; } - - // Add buffers to old subgraph let new_join = map[&join]; let new_fork = map[&fork]; - + // FIXME: Do this as part of copy subgraph? // Add tids to original subgraph for indexing. 
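The tid bookkeeping that follows exists because the write and the read must agree on buffer positions: one ThreadID per fork dimension is created for the original fork (indexing the Write) and, in the same dimension order, for the copied fork (indexing the Read), with both wrapped in an Index::Position. A toy flattened version of the invariant (dimensions are hypothetical, and the real code keeps positions multi-dimensional rather than flattening):

    fn main() {
        let dims = [4usize, 8];
        let slot = |i: usize, j: usize| i * dims[1] + j; // row-major stand-in
        let mut buf = vec![0u32; dims[0] * dims[1]];
        // Writer side: the original fork's thread ids pick the slot.
        for i in 0..dims[0] {
            for j in 0..dims[1] {
                buf[slot(i, j)] = (i + j) as u32;
            }
        }
        // Reader side: the copied fork's thread ids name the same slot.
        assert_eq!(buf[slot(2, 3)], 5);
    }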
let mut old_tids = Vec::new(); @@ -376,11 +421,11 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( old_tids.push(old_id); new_tids.push(new_id); } - + for (src, dst) in bufferized_edges { let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); let position_idx = Index::Position(old_tids.clone().into_boxed_slice()); - + let write = edit.add_node(Node::Write { collect: NodeID::new(0), data: *src, @@ -403,7 +448,6 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( // Fix write node edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; - // Create reads from buffer let position_idx = Index::Position(new_tids.clone().into_boxed_slice()); @@ -414,9 +458,7 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( // Replaces uses of bufferized edge src with corresponding reduce and read in old subraph edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?; - } - new_fork_id = new_fork; @@ -426,7 +468,7 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( println!("edit_result: {:?}", edit_result); if edit_result == false { todo!(); - return None + return None; } Some((fork, new_fork_id)) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index e04d8dfe..57290a62 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -47,23 +47,24 @@ gvn(*); phi-elim(*); dce(*); -fork-fission-bufferize(test7); -dce(*); +let blah = fork-fission-bufferize[test8@loop, test8@bufferize1](test8); +dce(blah.test8.fj_loop_top); + fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); fork-split(*); //let out = outline(out.test6.fj1); -let out7 = auto-outline(test7); -cpu(out7.test7); +let out8 = auto-outline(test8); +cpu(out8.test8); let out = auto-outline(test6); cpu(out.test6); ip-sroa(*); sroa(*); unforkify(out.test6); -unforkify(out7.test7); +unforkify(out8.test8); dce(*); unforkify(*); dce(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index caf956a1..fa0d80b3 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -43,7 +43,7 @@ fn main() { let correct = vec![7i32; 4]; assert(&correct, output); - let mut r = runner!(test7); + let mut r = runner!(test8); let output = r.run(0).await; let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; assert(&correct, output); diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 8a6e04ed..4e5dc4c5 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -38,6 +38,7 @@ impl Pass { match self { Pass::Xdot => 1, Pass::ForkChunk => 3, + Pass::ForkFissionBufferize => 2, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2c5a3687..75cd377a 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1987,7 +1987,10 @@ fn run_pass( }); }; let Some(Value::Integer { val: dim_idx }) = args.get(1) else { - panic!(); // How to error here? 
+ return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected integer argument".to_string(), + }); }; let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { return Err(SchedulerError::PassError { @@ -1999,10 +2002,9 @@ fn run_pass( assert_eq!(*guarded_flag, true); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in - build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -2014,7 +2016,29 @@ fn run_pass( pm.clear_analyses(); } Pass::ForkFissionBufferize => { - assert!(args.is_empty()); + pm.make_fork_join_maps(); + + assert_eq!(args.len(), 2); + let Some(Value::Label { + labels: fork_labels, + }) = args.get(0) + else { + return Err(SchedulerError::PassError { + pass: "forkFissionBufferize".to_string(), + error: "expected label argument".to_string(), + }); + }; + + let Some(Value::Label { + labels: fork_data_labels, + }) = args.get(1) + else { + return Err(SchedulerError::PassError { + pass: "forkFissionBufferize".to_string(), + error: "expected label argument".to_string(), + }); + }; + let mut created_fork_joins = vec![vec![]; pm.functions.len()]; pm.make_fork_join_maps(); @@ -2025,6 +2049,21 @@ fn run_pass( let typing = pm.typing.take().unwrap(); let loops = pm.loops.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + // assert only one function is in the selection. + let num_functions = build_selection(pm, selection.clone()) + .iter() + .filter(|func| func.is_some()) + .count(); + + assert!(num_functions <= 1); + assert_eq!(fork_labels.len(), 1); + assert_eq!(fork_data_labels.len(), 1); + + let fork_label = fork_labels[0].label; + let data_label = fork_data_labels[0].label; + + // Only one func is_some(). for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in build_selection(pm, selection) .into_iter() @@ -2036,9 +2075,15 @@ fn run_pass( let Some(mut func) = func else { continue; }; - if let Some((fork1, fork2)) = - ff_bufferize_any_fork(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing) - { + if let Some((fork1, fork2)) = ff_bufferize_any_fork( + &mut func, + loop_tree, + fork_join_map, + nodes_in_fork_joins, + typing, + &fork_label, + &data_label, + ) { let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; created_fork_joins.push(fork1); created_fork_joins.push(fork2); @@ -2046,36 +2091,41 @@ fn run_pass( changed |= func.modified(); } + pm.clear_analyses(); pm.make_nodes_in_fork_joins(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let mut new_fork_joins = HashMap::new(); + let fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone(); + for (mut func, created_fork_joins) in build_editors(pm).into_iter().zip(created_fork_joins) { - // For every function, create a label for every level of fork- + // For every function, create a label for every level of fork- // joins resulting from the split. let name = func.func().name.clone(); let func_id = func.func_id(); let labels = create_labels_for_node_sets( &mut func, created_fork_joins.into_iter().map(|fork| { - nodes_in_fork_joins[func_id.idx()][&fork] - .iter() - .map(|id| *id) - }) - , + nodes_in_fork_joins[func_id.idx()][&fork] + .iter() + .map(|id| *id) + }), ); // Assemble those labels into a record for this function. 
The // format of the records is <function>.<f>, where N is the // level of the split fork-joins being referred to. - todo!(); - // FIXME: What if there are multiple bufferized forks per function? let mut func_record = HashMap::new(); for (idx, label) in labels { + let fmt = if idx % 2 == 0 { + format!("fj_{}_top", fork_label_name) + } else { + format!("fj_{}_bottom", fork_label_name) + }; func_record.insert( - format!("fj{}", idx), + fmt, Value::Label { labels: vec![LabelInfo { func: func_id, @@ -2095,15 +2145,21 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); + + result = Value::Record { + fields: new_fork_joins, + }; + + println!("result: {:?}", result); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in - build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -2211,7 +2267,6 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } - Pass::ForkChunk => todo!(), } println!("Ran Pass: {:?}", pass); -- GitLab From ab3035178e8ad0433b8d7ce31ace008b648fe25b Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 9 Feb 2025 13:31:05 -0600 Subject: [PATCH 29/33] pm fixes --- juno_samples/fork_join_tests/src/blah.sch | 34 ------------------- juno_samples/fork_join_tests/src/cpu.sch | 12 ++++--- .../fork_join_tests/src/fork_join_tests.jn | 2 +- juno_scheduler/src/pm.rs | 5 ++- 4 files changed, 10 insertions(+), 43 deletions(-) delete mode 100644 juno_samples/fork_join_tests/src/blah.sch diff --git a/juno_samples/fork_join_tests/src/blah.sch b/juno_samples/fork_join_tests/src/blah.sch deleted file mode 100644 index 52dea702..00000000 --- a/juno_samples/fork_join_tests/src/blah.sch +++ /dev/null @@ -1,34 +0,0 @@ - -xdot[true](*); - -fixpoint panic after 20 { - forkify(*); - fork-guard-elim(*); - fork-coalesce(*); - dce(*); -} - -xdot[true](*); - -//gvn(*); -//phi-elim(*); -//dce(*); - -//gvn(*); -//phi-elim(*); -//dce(*); - -//fixpoint panic after 20 { -// infer-schedules(*); -//} - -//fork-split(*); -//gvn(*); -//phi-elim(*); -//dce(*); -//unforkify(*); -//gvn(*); -//phi-elim(*); -//dce(*); - -//gcm(*); diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 3b672846..fd3fc0ff 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -5,7 +5,7 @@ gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8); cpu(auto.test1); cpu(auto.test2); cpu(auto.test3); @@ -50,12 +50,8 @@ gvn(*); phi-elim(*); dce(*); - fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); - -let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8); - let out = outline(out.test6.fj1); cpu(out); ip-sroa(*); @@ -75,4 +71,10 @@ dce(auto.test7); simplify-cfg(auto.test7); dce(auto.test7); +let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8); +dce(auto.test8); +unforkify(auto.test8); +ccp(auto.test8); +dce(auto.test8); + gcm(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index bc89b2e2..a765726f 100644 --- 
a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -82,7 +82,7 @@ fn test6(input: i32) -> i32[1024] { return arr; } - +#[entry] fn test7(input: i32) -> i32 { let arr : i32[32]; for i = 0 to 32 { diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f762e5ff..58ff399f 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2094,7 +2094,7 @@ fn run_pass( let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); // assert only one function is in the selection. - let num_functions = build_selection(pm, selection.clone()) + let num_functions = build_selection(pm, selection.clone(), false) .iter() .filter(|func| func.is_some()) .count(); @@ -2106,9 +2106,8 @@ fn run_pass( let fork_label = fork_labels[0].label; let data_label = fork_data_labels[0].label; - // Only one func is_some(). for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in - build_selection(pm, selection) + build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) -- GitLab From b4978ef526f248e205f2df22c4493e649b080162 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 10 Feb 2025 09:28:34 -0600 Subject: [PATCH 30/33] cleanup --- hercules_opt/src/fork_transforms.rs | 106 +---------------------- juno_samples/fork_join_tests/src/cpu.sch | 3 + juno_samples/fork_join_tests/src/main.rs | 11 ++- 3 files changed, 12 insertions(+), 108 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ad81d90b..80ff7b8f 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,7 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::hash::Hash; use std::iter::zip; -use std::thread::ThreadId; use bimap::BiMap; use itertools::Itertools; @@ -102,7 +101,7 @@ pub fn find_reduce_dependencies<'a>( pub fn copy_subgraph_in_edit<'a, 'b>( mut edit: FunctionEdit<'a, 'b>, subgraph: HashSet<NodeID>, -) -> (Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>) { +) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> { let mut map: HashMap<NodeID, NodeID> = HashMap::new(); // Copy nodes in subgraph @@ -185,11 +184,8 @@ pub fn find_bufferize_edges( nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, data_label: &LabelID, ) -> HashSet<(NodeID, NodeID)> { - // println!("func: {:?}", editor.func_id()); let mut edges: HashSet<_> = HashSet::new(); - println!("ndoes in fork joins: {:?}", &nodes_in_fork_joins[&fork]); - // print labels for node in &nodes_in_fork_joins[&fork] { // Edge from *has data label** to doesn't have data label* let node_labels = &editor.func().labels[node.idx()]; @@ -216,7 +212,6 @@ pub fn find_bufferize_edges( edges.insert((*node, user)); } } - println!("edges: {:?}", edges); edges } @@ -238,7 +233,6 @@ where .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) .collect(); - println!("fork_label {:?}", fork_label); for l in forks { let fork_info = Loop { header: l.0, @@ -247,13 +241,10 @@ where let fork = fork_info.header; let join = fork_join_map[&fork]; - println!("fork labels: {:?}", editor.func().labels[fork.idx()]); if !editor.func().labels[fork.idx()].contains(fork_label) { continue; } - println!("fork: {:?}", fork); - let edges = find_bufferize_edges( editor, fork, @@ -384,11 +375,6 @@ where let edit_result = editor.edit(|edit| { let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?; - // Put 
new subgraph after old subgraph - // println!("map: {:?}", map); - // println!("join: {:?}, fork: {:?}", join, fork); - // println!("fork_sccueue: {:?}", join_successor); - edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?; edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?; @@ -398,8 +384,6 @@ where .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; } - // Add buffers to old subgraph - let new_join = map[&join]; let new_fork = map[&fork]; @@ -465,100 +449,12 @@ where Ok(edit) }); - println!("edit_result: {:?}", edit_result); if edit_result == false { todo!(); return None; } Some((fork, new_fork_id)) - - // let internal_control: Vec<NodeID> = Vec::new(); - - // // Create fork + join + Thread control - // let mut new_fork_id = NodeID::new(0); - // let mut new_join_id = NodeID::new(0); - - // editor.edit(|mut edit| { - // new_join_id = edit.add_node(Node::Join { control: fork }); - // let factors = edit.get_node(fork).try_fork().unwrap().1; - // new_fork_id = edit.add_node(Node::Fork { - // control: new_join_id, - // factors: factors.into(), - // }); - // edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) - // }); - - // for (src, dst) in bufferized_edges { - // // FIXME: Disgusting cloning and allocationing and iterators. - // let factors: Vec<_> = editor.func().nodes[fork.idx()] - // .try_fork() - // .unwrap() - // .1 - // .iter() - // .cloned() - // .collect(); - // editor.edit(|mut edit| { - // // Create write to buffer - - // let thread_stuff_it = factors.into_iter().enumerate(); - - // // FIxme: try to use unzip here? Idk why it wasn't working. - // let tids = thread_stuff_it.clone().map(|(dim, _)| { - // edit.add_node(Node::ThreadID { - // control: fork, - // dimension: dim, - // }) - // }); - - // let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); - - // // Assume 1-d fork only for now. - // // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); - // let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); - // let write = edit.add_node(Node::Write { - // collect: NodeID::new(0), - // data: src, - // indices: vec![position_idx].into(), - // }); - // let ele_type = types[src.idx()]; - // let empty_buffer = edit.add_type(hercules_ir::Type::Array( - // ele_type, - // array_dims.collect::<Vec<_>>().into_boxed_slice(), - // )); - // let empty_buffer = edit.add_zero_constant(empty_buffer); - // let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); - // let reduce = Node::Reduce { - // control: new_join_id, - // init: empty_buffer, - // reduct: write, - // }; - // let reduce = edit.add_node(reduce); - // // Fix write node - // edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; - - // // Create read from buffer - // let tids = thread_stuff_it.clone().map(|(dim, _)| { - // edit.add_node(Node::ThreadID { - // control: new_fork_id, - // dimension: dim, - // }) - // }); - - // let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); - - // let read = edit.add_node(Node::Read { - // collect: reduce, - // indices: vec![position_idx].into(), - // }); - - // edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; - - // Ok(edit) - // }); - // } - - // (fork, new_fork_id) } /** Split a 1D fork into a separate fork for each reduction. 
*/ diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index fd3fc0ff..38a38c2b 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -74,7 +74,10 @@ dce(auto.test7); let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8); dce(auto.test8); unforkify(auto.test8); +dce(auto.test8); ccp(auto.test8); dce(auto.test8); +simplify-cfg(auto.test8); +dce(auto.test8); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 4006afa8..1013a1f0 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -43,15 +43,20 @@ fn main() { let correct = vec![7i32; 4]; assert(&correct, output); - let mut r = runner!(test8); - let output = r.run(0).await; - let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; + let mut r = runner!(test6); + let output = r.run(73).await; + let correct = (73i32..73i32+1024i32).collect(); assert(&correct, output); let mut r = runner!(test7); let output = r.run(42).await; let correct: i32 = (42i32..42i32+32i32).sum(); assert_eq!(correct, output); + + let mut r = runner!(test8); + let output = r.run(0).await; + let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; + assert(&correct, output); }); } -- GitLab From 7a9bdc262e76dbedb1849b3e3d318829a8c1fe3a Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 10 Feb 2025 09:39:16 -0600 Subject: [PATCH 31/33] gpu --- hercules_opt/src/fork_transforms.rs | 14 ++++++-------- .../fork_join_tests/src/fork_join_tests.jn | 6 +++--- juno_samples/fork_join_tests/src/gpu.sch | 13 ++++++++++++- juno_scheduler/src/pm.rs | 3 --- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 80ff7b8f..b5d3bb28 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -322,8 +322,6 @@ pub fn fork_bufferize_fission_helper<'a, 'b>( where 'a: 'b, { - // Returns the two forks that it generates. - if bufferized_edges.is_empty() { return None; } @@ -384,7 +382,6 @@ where .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; } - let new_join = map[&join]; let new_fork = map[&fork]; // FIXME: Do this as part of copy subgraph? @@ -422,12 +419,14 @@ where )); let empty_buffer = edit.add_zero_constant(empty_buffer); let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); + edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?; let reduce = Node::Reduce { control: join, init: empty_buffer, reduct: write, }; let reduce = edit.add_node(reduce); + edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?; // Fix write node edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; @@ -449,12 +448,11 @@ where Ok(edit) }); - if edit_result == false { - todo!(); - return None; + if edit_result { + Some((fork, new_fork_id)) + } else { + None } - - Some((fork, new_fork_id)) } /** Split a 1D fork into a separate fork for each reduction. 
*/ diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index a765726f..886ab13b 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -97,8 +97,8 @@ fn test7(input: i32) -> i32 { #[entry] fn test8(input : i32) -> i32[8] { - let arr : i32[8]; - let out : i32[8]; + @const1 let arr : i32[8]; + @const2 let out : i32[8]; for i = 0 to 8 { arr[i] = i as i32; @@ -116,4 +116,4 @@ fn test8(input : i32) -> i32[8] { out[i] = d; } return out; -} \ No newline at end of file +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index bd1fd1d8..159fac94 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -5,18 +5,21 @@ no-memset(test3@const1); no-memset(test3@const2); no-memset(test3@const3); no-memset(test6@const); +no-memset(test8@const1); +no-memset(test8@const2); gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8); gpu(auto.test1); gpu(auto.test2); gpu(auto.test3); gpu(auto.test4); gpu(auto.test5); gpu(auto.test7); +gpu(auto.test8); ip-sroa(*); sroa(*); @@ -50,6 +53,14 @@ dce(auto.test7); simplify-cfg(auto.test7); dce(auto.test7); +let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8); +xdot[true](*); +dce(auto.test8); +ccp(auto.test8); +dce(auto.test8); +simplify-cfg(auto.test8); +dce(auto.test8); + ip-sroa(*); sroa(*); dce(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 58ff399f..20825c54 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2191,9 +2191,6 @@ fn run_pass( result = Value::Record { fields: new_fork_joins, }; - - println!("result: {:?}", result); - } Pass::ForkDimMerge => { assert!(args.is_empty()); -- GitLab From 2b48a39eeea07a6654dafff1372e45ccd75a7321 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 10 Feb 2025 09:41:45 -0600 Subject: [PATCH 32/33] wut --- juno_samples/matmul/build.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index d2813388..0be838c6 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,8 +6,6 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() - .schedule_in_src("cpu.sch") - .unwrap() .build() .unwrap(); } -- GitLab From 17d1c03be21ec990964917f073e0e25b8bae5f66 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 10 Feb 2025 09:42:18 -0600 Subject: [PATCH 33/33] cleanup --- hercules_opt/src/fork_transforms.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index b5d3bb28..7c423892 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -166,16 +166,6 @@ pub fn copy_subgraph( (new_nodes, map, outside_users) } -fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> bool { - // A strict superset must be larger than its subset - if set1.len() <= set2.len() { - return false; - } - - // Every element in set2 must be in set1 - set2.iter().all(|item| set1.contains(item)) -} - pub fn find_bufferize_edges( editor: &mut FunctionEditor, fork: NodeID, -- GitLab
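This stretch of the series lands fork-fission-bufferize end to end: the transform itself, label-driven edge selection, scheduler plumbing that returns a record of labels (fj_<label>_top and fj_<label>_bottom) so follow-on schedule statements can target each half, CPU and GPU test coverage, and a final sweep that reverts the stray matmul schedule hookup and deletes the now-unused is_strict_superset helper. The two schedule annotations introduced alongside the GPU support deserve a note: NoResetConstant on the buffer's zero constant (presumably safe because every slot is overwritten before it is read, so re-zeroing per invocation would be wasted work) and ParallelReduce on the buffer's reduce. The latter is sound because each iteration writes a distinct ThreadID-indexed slot; the reduction is really a scatter into disjoint positions, so iterations commute. In plain Rust terms (a standalone sketch, not the Hercules API):

    // Each iteration owns slot i; no two iterations touch the same element,
    // so the loop body can run in any order, or fully in parallel.
    fn scatter(n: usize, f: impl Fn(usize) -> i32) -> Vec<i32> {
        let mut buf = vec![0; n]; // the zero-constant buffer
        for i in 0..n {
            buf[i] = f(i);
        }
        buf
    }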