From e0f26259f8895724ffc4d7767c7b66675ab2e871 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 17:45:14 -0600 Subject: [PATCH 01/10] new dot test --- Cargo.lock | 11 ++++++++++ Cargo.toml | 39 +++++++++++++++++------------------- juno_samples/dot/Cargo.toml | 22 ++++++++++++++++++++ juno_samples/dot/build.rs | 24 ++++++++++++++++++++++ juno_samples/dot/src/cpu.sch | 17 ++++++++++++++++ juno_samples/dot/src/dot.jn | 10 +++++++++ juno_samples/dot/src/gpu.sch | 18 +++++++++++++++++ juno_samples/dot/src/main.rs | 27 +++++++++++++++++++++++++ 8 files changed, 147 insertions(+), 21 deletions(-) create mode 100644 juno_samples/dot/Cargo.toml create mode 100644 juno_samples/dot/build.rs create mode 100644 juno_samples/dot/src/cpu.sch create mode 100644 juno_samples/dot/src/dot.jn create mode 100644 juno_samples/dot/src/gpu.sch create mode 100644 juno_samples/dot/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 4431cf5d..81c37d79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1170,6 +1170,17 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_dot" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "rand 0.8.5", + "with_builtin_macros", +] + [[package]] name = "juno_edge_detection" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 6514046b..7a3906fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,34 +5,31 @@ members = [ "hercules_ir", "hercules_opt", "hercules_rt", - - "juno_utils", - "juno_frontend", - "juno_scheduler", - "juno_build", - - "hercules_test/hercules_interpreter", - "hercules_test/hercules_tests", - - "hercules_samples/dot", - "hercules_samples/matmul", - "hercules_samples/fac", "hercules_samples/call", "hercules_samples/ccp", - - "juno_samples/simple3", - "juno_samples/patterns", - "juno_samples/matmul", - "juno_samples/casts_and_intrinsics", - "juno_samples/control", + "hercules_samples/dot", + "hercules_samples/fac", + "hercules_samples/matmul", + "hercules_test/hercules_interpreter", + "hercules_test/hercules_tests", + "juno_build", + "juno_frontend", "juno_samples/antideps", - "juno_samples/implicit_clone", + "juno_samples/casts_and_intrinsics", "juno_samples/cava", "juno_samples/concat", - "juno_samples/schedule_test", + "juno_samples/control", + "juno_samples/dot", "juno_samples/edge_detection", "juno_samples/fork_join_tests", + "juno_samples/implicit_clone", + "juno_samples/matmul", + "juno_samples/median_window", "juno_samples/multi_device", + "juno_samples/patterns", "juno_samples/products", - "juno_samples/median_window", + "juno_samples/schedule_test", + "juno_samples/simple3", + "juno_scheduler", + "juno_utils", ] diff --git a/juno_samples/dot/Cargo.toml b/juno_samples/dot/Cargo.toml new file mode 100644 index 00000000..155a0b13 --- /dev/null +++ b/juno_samples/dot/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "juno_dot" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_dot" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" +rand = "*" diff --git a/juno_samples/dot/build.rs b/juno_samples/dot/build.rs new file mode 100644 index 00000000..b23b9860 --- /dev/null +++ b/juno_samples/dot/build.rs @@ -0,0 +1,24 @@ +use juno_build::JunoCompiler; + 
+fn main() { + #[cfg(not(feature = "cuda"))] + { + JunoCompiler::new() + .file_in_src("dot.jn") + .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() + .build() + .unwrap(); + } + #[cfg(feature = "cuda")] + { + JunoCompiler::new() + .file_in_src("dot.jn") + .unwrap() + .schedule_in_src("gpu.sch") + .unwrap() + .build() + .unwrap(); + } +} diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch new file mode 100644 index 00000000..be110bde --- /dev/null +++ b/juno_samples/dot/src/cpu.sch @@ -0,0 +1,17 @@ +phi-elim(*); + +forkify(*); +fork-guard-elim(*); +dce(*); + +fork-tile[8, 0, false, true](*); +fork-split(*); + +let out = auto-outline(*); +cpu(out.dot); +ip-sroa(*); +sroa(*); +dce(*); + +unforkify(*); +gcm(*); diff --git a/juno_samples/dot/src/dot.jn b/juno_samples/dot/src/dot.jn new file mode 100644 index 00000000..0421dc4c --- /dev/null +++ b/juno_samples/dot/src/dot.jn @@ -0,0 +1,10 @@ +#[entry] +fn dot<n : usize>(a : f32[n], b : f32[n]) -> f32 { + let res : f32; + + for i = 0 to n { + res += a[i] * b[i]; + } + + return res; +} diff --git a/juno_samples/dot/src/gpu.sch b/juno_samples/dot/src/gpu.sch new file mode 100644 index 00000000..b7ece681 --- /dev/null +++ b/juno_samples/dot/src/gpu.sch @@ -0,0 +1,18 @@ +phi-elim(*); + +forkify(*); +fork-guard-elim(*); +dce(*); + +fork-tile[8, 0, false, true](*); +fork-split(*); + +let out = auto-outline(*); +gpu(out.dot); +ip-sroa(*); +sroa(*); +dce(*); + +unforkify(*); +gcm(*); + diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs new file mode 100644 index 00000000..5d0aaf7b --- /dev/null +++ b/juno_samples/dot/src/main.rs @@ -0,0 +1,27 @@ +#![feature(concat_idents)] +use std::iter::zip; + +use rand::random; + +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; + +juno_build::juno!("dot"); + +fn main() { + async_std::task::block_on(async { + const N: u64 = 4096; + let a: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect(); + let b: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect(); + let a_herc = HerculesImmBox::from(&a as &[f32]); + let b_herc = HerculesImmBox::from(&b as &[f32]); + let mut r = runner!(dot); + let output = r.run(N, a_herc.to(), b_herc.to()).await; + let correct = zip(a, b).map(|(a, b)| a * b).sum(); + assert_eq!(output, correct); + }); +} + +#[test] +fn dot_test() { + main(); +} -- GitLab From 3329cb5f72e361010d2596f5724211798467a3dd Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 18:01:01 -0600 Subject: [PATCH 02/10] Rename tightassociative to monoidreduce --- hercules_cg/src/gpu.rs | 4 ++-- hercules_ir/src/einsum.rs | 13 ++++++------ hercules_ir/src/ir.rs | 2 +- hercules_opt/src/fork_transforms.rs | 4 ++-- hercules_opt/src/forkify.rs | 31 ++++++++++++++--------------- hercules_opt/src/schedule.rs | 24 ++++++++++++---------- juno_scheduler/src/compile.rs | 2 +- juno_scheduler/src/pm.rs | 2 +- 8 files changed, 42 insertions(+), 40 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index d6461a1e..33b239f7 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -821,7 +821,7 @@ extern \"C\" {} {}(", && fork_size.is_power_of_two() && reduces.iter().all(|&reduce| { self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce) - || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + || self.function.schedules[reduce.idx()].contains(&Schedule::MonoidReduce) }) { // If there's an associative Reduce, parallelize the larger factor @@ -834,7 +834,7 @@ extern \"C\" 
{} {}(", // restriction doesn't help for parallel Writes, so nested parallelization // is possible. if reduces.iter().any(|&reduce| { - self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + self.function.schedules[reduce.idx()].contains(&Schedule::MonoidReduce) }) || fork_size > self.kernel_params.max_num_threads / subtree_quota { if fork_size >= subtree_quota { diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index b222e1bc..6c2ca31b 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -150,13 +150,12 @@ pub fn einsum( ctx.result_insert(reduce, total_id); } // The reduce defines a sum reduction over a set of fork dimensions. - else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) - && let Node::Binary { - op: BinaryOperator::Add, - left, - right, - } = function.nodes[reduct.idx()] - && (left == reduce || right == reduce) + else if let Node::Binary { + op: BinaryOperator::Add, + left, + right, + } = function.nodes[reduct.idx()] + && ((left == reduce) ^ (right == reduce)) { let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left }); let reduce_expr = MathExpr::SumReduction( diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index eb008904..972fd7f9 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -328,7 +328,7 @@ pub enum Schedule { Vectorizable, // This reduce can be re-associated. This may lower a sequential dependency // chain into a reduction tree. - TightAssociative, + MonoidReduce, // This constant node doesn't need to be memset to zero. NoResetConstant, // This call should be called in a spawned future. diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index c32a517e..342728fd 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1175,7 +1175,7 @@ fn fork_interchange( first_dim: usize, second_dim: usize, ) { - // Check that every reduce on the join is parallel or tight associative. + // Check that every reduce on the join is parallel or associative. let nodes = &editor.func().nodes; let schedules = &editor.func().schedules; if !editor @@ -1183,7 +1183,7 @@ fn fork_interchange( .filter(|id| nodes[id.idx()].is_reduce()) .all(|id| { schedules[id.idx()].contains(&Schedule::ParallelReduce) - || schedules[id.idx()].contains(&Schedule::TightAssociative) + || schedules[id.idx()].contains(&Schedule::MonoidReduce) }) { // If not, we can't necessarily do interchange. diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 774220df..2f6466c0 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -30,15 +30,17 @@ pub fn forkify( for l in natural_loops { // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses. - if editor.is_mutable(l.0) && forkify_loop( - editor, - control_subgraph, - fork_join_map, - &Loop { - header: l.0, - control: l.1.clone(), - }, - ) { + if editor.is_mutable(l.0) + && forkify_loop( + editor, + control_subgraph, + fork_join_map, + &Loop { + header: l.0, + control: l.1.clone(), + }, + ) + { return true; } } @@ -166,7 +168,6 @@ pub fn forkify_loop( return false; } - // Get all phis used outside of the loop, they need to be reductionable. // For now just assume all phis will be phis used outside of the loop, except for the canonical iv. 
// FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one @@ -371,15 +372,13 @@ pub fn forkify_loop( edit = edit.add_schedule(reduce_id, Schedule::ParallelReduce)?; } if (!edit.get_node(init).is_reduce() - && edit - .get_schedule(init) - .contains(&Schedule::TightAssociative)) + && edit.get_schedule(init).contains(&Schedule::MonoidReduce)) || (!edit.get_node(continue_latch).is_reduce() && edit .get_schedule(continue_latch) - .contains(&Schedule::TightAssociative)) + .contains(&Schedule::MonoidReduce)) { - edit = edit.add_schedule(reduce_id, Schedule::TightAssociative)?; + edit = edit.add_schedule(reduce_id, Schedule::MonoidReduce)?; } edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?; @@ -539,7 +538,7 @@ pub fn analyze_phis<'a>( // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). - + // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce. if intersection diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index fe894e47..4cb912fd 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -146,21 +146,25 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N } /* - * Infer tight associative reduction loops. Exactly one of the associative - * operation's operands must be the Reduce node, and all other operands must - * not be in the Reduce node's cycle. + * Infer monoid reduction loops. Exactly one of the associative operation's + * operands must be the Reduce node, and all other operands must not be in the + * Reduce node's cycle. 
*/ -pub fn infer_tight_associative( +pub fn infer_monoid_reduce( editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) { - let is_binop_associative = |op| { + let is_binop_monoid = |op| { matches!( op, - BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor + BinaryOperator::Add + | BinaryOperator::Mul + | BinaryOperator::Or + | BinaryOperator::And + | BinaryOperator::Xor ) }; - let is_intrinsic_associative = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); + let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); for id in editor.node_ids() { let func = editor.func(); @@ -172,12 +176,12 @@ pub fn infer_tight_associative( && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } if ((left == id && !reduce_cycles[&id].contains(&right)) || (right == id && !reduce_cycles[&id].contains(&left))) && - is_binop_associative(op)) + is_binop_monoid(op)) || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } - if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && + if (args.contains(&id) && is_intrinsic_monoid(*intrinsic) && args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg))))) { - editor.edit(|edit| edit.add_schedule(id, Schedule::TightAssociative)); + editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce)); } } } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 7c92e00d..188cb1c6 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -153,7 +153,7 @@ impl FromStr for Appliable { "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)), - "associative" => Ok(Appliable::Schedule(Schedule::TightAssociative)), + "monoid" | "associative" => Ok(Appliable::Schedule(Schedule::MonoidReduce)), "parallel-fork" => Ok(Appliable::Schedule(Schedule::ParallelFork)), "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)), "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)), diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8e152cfe..19bd78e2 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2034,7 +2034,7 @@ fn run_pass( infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles); infer_parallel_fork(&mut func, fork_join_map); infer_vectorizable(&mut func, fork_join_map); - infer_tight_associative(&mut func, reduce_cycles); + infer_monoid_reduce(&mut func, reduce_cycles); infer_no_reset_constants(&mut func, no_reset_constants); changed |= func.modified(); } -- GitLab From 822eb61ac348e7b80a36690234c305d427e2b03b Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 18:26:29 -0600 Subject: [PATCH 03/10] Clean monoid reduces pass --- hercules_ir/src/ir.rs | 16 ++++++++ hercules_opt/src/fork_transforms.rs | 61 +++++++++++++++++++++++++++++ hercules_opt/src/schedule.rs | 1 - hercules_opt/src/utils.rs | 26 ++++++++++++ juno_samples/dot/src/cpu.sch | 2 + juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 17 ++++++++ 8 files changed, 124 insertions(+), 1 deletion(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 972fd7f9..e8dfc280 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1086,6 +1086,22 @@ impl DynamicConstant { } } + pub fn is_zero(&self) -> bool { + if *self == 
DynamicConstant::Constant(0) { + true + } else { + false + } + } + + pub fn is_one(&self) -> bool { + if *self == DynamicConstant::Constant(1) { + true + } else { + false + } + } + pub fn try_parameter(&self) -> Option<usize> { if let DynamicConstant::Parameter(v) = self { Some(*v) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 342728fd..0b5de1e5 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1446,3 +1446,64 @@ fn fork_fusion( Ok(edit) }) } + +/* + * Looks for monoid reductions where the initial input is not the identity + * element, and converts them into a form whose initial input is an identity + * element. This aides in parallelizing outer loops. Looks only at reduces with + * the monoid reduce schedule, since that indicates a particular structure which + * is annoying to check for again. + */ +pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { + for id in editor.node_ids() { + if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) { + continue; + } + let nodes = &editor.func().nodes; + let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else { + continue; + }; + + match nodes[reduct.idx()] { + Node::Binary { + op, + left: _, + right: _, + } if (op == BinaryOperator::Add || op == BinaryOperator::Or) + && !is_zero(editor, init) => + { + editor.edit(|mut edit| { + let zero = edit.add_zero_constant(typing[init.idx()]); + let zero = edit.add_node(Node::Constant { id: zero }); + edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?; + let final_add = edit.add_node(Node::Binary { + op, + left: init, + right: id, + }); + edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) + }); + } + Node::Binary { + op, + left: _, + right: _, + } if (op == BinaryOperator::Mul || op == BinaryOperator::And) + && !is_one(editor, init) => + { + editor.edit(|mut edit| { + let one = edit.add_one_constant(typing[init.idx()]); + let one = edit.add_node(Node::Constant { id: one }); + edit = edit.replace_all_uses_where(init, one, |u| *u == id)?; + let final_add = edit.add_node(Node::Binary { + op, + left: init, + right: id, + }); + edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) + }); + } + _ => panic!(), + } + } +} diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index 4cb912fd..7ecf07a4 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -161,7 +161,6 @@ pub fn infer_monoid_reduce( | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And - | BinaryOperator::Xor ) }; let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 3f12ad7c..1806d5c7 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -541,3 +541,29 @@ where nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len() }) } + +pub fn is_zero(editor: &FunctionEditor, id: NodeID) -> bool { + let nodes = &editor.func().nodes; + nodes[id.idx()] + .try_constant() + .map(|id| editor.get_constant(id).is_zero()) + .unwrap_or(false) + || nodes[id.idx()] + .try_dynamic_constant() + .map(|id| editor.get_dynamic_constant(id).is_zero()) + .unwrap_or(false) + || nodes[id.idx()].is_undef() +} + +pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool { + let nodes = &editor.func().nodes; + nodes[id.idx()] + .try_constant() + .map(|id| editor.get_constant(id).is_one()) + 
.unwrap_or(false) + || nodes[id.idx()] + .try_dynamic_constant() + .map(|id| editor.get_dynamic_constant(id).is_one()) + .unwrap_or(false) + || nodes[id.idx()].is_undef() +} diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index be110bde..6ee00c8b 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -6,6 +6,8 @@ dce(*); fork-tile[8, 0, false, true](*); fork-split(*); +infer-schedules(*); +clean-monoid-reduces(*); let out = auto-outline(*); cpu(out.dot); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 188cb1c6..2e930639 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -111,6 +111,7 @@ impl FromStr for Appliable { "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)), "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)), "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)), + "clean-monoid-reduces" => Ok(Appliable::Pass(ir::Pass::CleanMonoidReduces)), "dce" => Ok(Appliable::Pass(ir::Pass::DCE)), "delete-uncalled" => Ok(Appliable::DeleteUncalled), "float-collections" | "collections" => Ok(Appliable::Pass(ir::Pass::FloatCollections)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 205cd70b..d6f59bef 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -6,6 +6,7 @@ pub enum Pass { ArrayToProduct, AutoOutline, CCP, + CleanMonoidReduces, CRC, DCE, FloatCollections, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 19bd78e2..45cebe80 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1668,6 +1668,23 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::CleanMonoidReduces => { + assert!(args.is_empty()); + pm.make_typing(); + let typing = pm.typing.take().unwrap(); + for (func, typing) in build_selection(pm, selection, false) + .into_iter() + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + clean_monoid_reduces(&mut func, typing); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::CRC => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { -- GitLab From 544c6659b5cfeaef555990e3d29808587cfa7b95 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 18:39:56 -0600 Subject: [PATCH 04/10] Found bug in fork tiling in the other direction --- hercules_opt/src/fork_transforms.rs | 2 +- juno_samples/dot/src/cpu.sch | 5 +++++ juno_samples/dot/src/dot.jn | 4 ++-- juno_samples/dot/src/main.rs | 10 +++++----- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 0b5de1e5..05606f5a 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1503,7 +1503,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) }); } - _ => panic!(), + _ => {} } } } diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index 6ee00c8b..4e40e351 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -5,9 +5,12 @@ fork-guard-elim(*); dce(*); fork-tile[8, 0, false, true](*); +fork-tile[32, 0, false, false](*); fork-split(*); infer-schedules(*); clean-monoid-reduces(*); +infer-schedules(*); +clean-monoid-reduces(*); let out = auto-outline(*); cpu(out.dot); @@ -15,5 +18,7 @@ ip-sroa(*); sroa(*); 
dce(*); +xdot[true](*); + unforkify(*); gcm(*); diff --git a/juno_samples/dot/src/dot.jn b/juno_samples/dot/src/dot.jn index 0421dc4c..8c0e029c 100644 --- a/juno_samples/dot/src/dot.jn +++ b/juno_samples/dot/src/dot.jn @@ -1,6 +1,6 @@ #[entry] -fn dot<n : usize>(a : f32[n], b : f32[n]) -> f32 { - let res : f32; +fn dot<n : usize>(a : i64[n], b : i64[n]) -> i64 { + let res : i64; for i = 0 to n { res += a[i] * b[i]; diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs index 5d0aaf7b..b73f8710 100644 --- a/juno_samples/dot/src/main.rs +++ b/juno_samples/dot/src/main.rs @@ -9,11 +9,11 @@ juno_build::juno!("dot"); fn main() { async_std::task::block_on(async { - const N: u64 = 4096; - let a: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect(); - let b: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect(); - let a_herc = HerculesImmBox::from(&a as &[f32]); - let b_herc = HerculesImmBox::from(&b as &[f32]); + const N: u64 = 1024 * 1024; + let a: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect(); + let b: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect(); + let a_herc = HerculesImmBox::from(&a as &[i64]); + let b_herc = HerculesImmBox::from(&b as &[i64]); let mut r = runner!(dot); let output = r.run(N, a_herc.to(), b_herc.to()).await; let correct = zip(a, b).map(|(a, b)| a * b).sum(); -- GitLab From d9a01c8b3536d28bba4726e9bc9e331a56c52d4f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 18:44:41 -0600 Subject: [PATCH 05/10] Fix --- hercules_opt/src/fork_transforms.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 05606f5a..7d4fa9a2 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1002,7 +1002,7 @@ pub fn chunk_fork_unguarded( } else if tid_dim == dim_idx { let tile_tid = Node::ThreadID { control: new_fork, - dimension: tid_dim, + dimension: tid_dim + 1, }; let tile_tid = edit.add_node(tile_tid); let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id }); -- GitLab From 900024ee60c8d8b8c5fb8f1979b2a19105ac7bed Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 21:53:56 -0600 Subject: [PATCH 06/10] Fixes for bufferize fission --- hercules_opt/src/fork_transforms.rs | 51 ++++++++++++++++++++++++++--- hercules_opt/src/simplify_cfg.rs | 45 ++++++++++++------------- juno_samples/dot/src/dot.jn | 2 +- juno_scheduler/src/ir.rs | 4 +-- juno_scheduler/src/pm.rs | 48 ++++++++++++++------------- 5 files changed, 97 insertions(+), 53 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 7d4fa9a2..e832e559 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -204,14 +204,48 @@ pub fn find_bufferize_edges( edges } +pub fn ff_bufferize_create_not_reduce_cycle_label_helper( + editor: &mut FunctionEditor, + fork: NodeID, + fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> LabelID { + let join = fork_join_map[&fork]; + let mut nodes_not_in_a_reduce_cycle = nodes_in_fork_joins[&fork].clone(); + for (cycle, reduce) in editor + .get_users(join) + .filter_map(|id| reduce_cycles.get(&id).map(|cycle| (cycle, id))) + { + nodes_not_in_a_reduce_cycle.remove(&reduce); + for id in cycle { + nodes_not_in_a_reduce_cycle.remove(id); + } + } + 
nodes_not_in_a_reduce_cycle.remove(&join); + + let mut label = LabelID::new(0); + let success = editor.edit(|mut edit| { + label = edit.fresh_label(); + for id in nodes_not_in_a_reduce_cycle { + edit = edit.add_label(id, label)?; + } + Ok(edit) + }); + + assert!(success); + label +} + pub fn ff_bufferize_any_fork<'a, 'b>( editor: &'b mut FunctionEditor<'a>, loop_tree: &'b LoopTree, fork_join_map: &'b HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, typing: &'b Vec<TypeID>, - fork_label: &'b LabelID, - data_label: &'b LabelID, + fork_label: LabelID, + data_label: Option<LabelID>, ) -> Option<(NodeID, NodeID)> where 'a: 'b, @@ -230,17 +264,26 @@ where let fork = fork_info.header; let join = fork_join_map[&fork]; - if !editor.func().labels[fork.idx()].contains(fork_label) { + if !editor.func().labels[fork.idx()].contains(&fork_label) { continue; } + let data_label = data_label.unwrap_or_else(|| { + ff_bufferize_create_not_reduce_cycle_label_helper( + editor, + fork, + fork_join_map, + reduce_cycles, + nodes_in_fork_joins, + ) + }); let edges = find_bufferize_edges( editor, fork, &loop_tree, &fork_join_map, &nodes_in_fork_joins, - data_label, + &data_label, ); let result = fork_bufferize_fission_helper( editor, diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs index d579012e..14a152dc 100644 --- a/hercules_opt/src/simplify_cfg.rs +++ b/hercules_opt/src/simplify_cfg.rs @@ -91,36 +91,33 @@ fn remove_useless_fork_joins( fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) { - // First, try to get rid of reduces where possible. We can only delete all - // the reduces or none of the reduces in a particular fork-join, since even - // if one reduce may have no users outside the reduction cycle, it may be - // used by a reduce that is used outside the cycle, so it shouldn't be - // deleted. The reduction cycle may contain every reduce in a fork-join. + // First, try to get rid of reduces where possible. Look for reduces with no + // users outside its reduce cycle, and its reduce cycle contains no other + // reduce nodes. for (_, join) in fork_join_map { - let nodes = &editor.func().nodes; let reduces: Vec<_> = editor .get_users(*join) - .filter(|id| nodes[id.idx()].is_reduce()) + .filter(|id| editor.func().nodes[id.idx()].is_reduce()) .collect(); - // If every reduce has users only in the reduce cycle, then all the - // reduces can be deleted, along with every node in the reduce cycles. - if reduces.iter().all(|reduce| { - editor - .get_users(*reduce) - .all(|user| reduce_cycles[reduce].contains(&user)) - }) { - let mut all_the_nodes = HashSet::new(); - for reduce in reduces { - all_the_nodes.insert(reduce); - all_the_nodes.extend(&reduce_cycles[&reduce]); + for reduce in reduces { + // If the reduce has users only in the reduce cycle, and none of + // the nodes in the cycle are reduce nodes, then the reduce and its + // whole cycle can be deleted. 
+ if editor + .get_users(reduce) + .all(|user| reduce_cycles[&reduce].contains(&user)) + && reduce_cycles[&reduce] + .iter() + .all(|id| !editor.func().nodes[id.idx()].is_reduce()) + { + editor.edit(|mut edit| { + for id in reduce_cycles[&reduce].iter() { + edit = edit.delete_node(*id)?; + } + edit.delete_node(reduce) + }); } - editor.edit(|mut edit| { - for id in all_the_nodes { - edit = edit.delete_node(id)?; - } - Ok(edit) - }); } } diff --git a/juno_samples/dot/src/dot.jn b/juno_samples/dot/src/dot.jn index 8c0e029c..cf097178 100644 --- a/juno_samples/dot/src/dot.jn +++ b/juno_samples/dot/src/dot.jn @@ -2,7 +2,7 @@ fn dot<n : usize>(a : i64[n], b : i64[n]) -> i64 { let res : i64; - for i = 0 to n { + @loop for i = 0 to n { res += a[i] * b[i]; } diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index d6f59bef..11cf6b13 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -48,7 +48,7 @@ impl Pass { match self { Pass::ArrayToProduct => num == 0 || num == 1, Pass::ForkChunk => num == 4, - Pass::ForkFissionBufferize => num == 2, + Pass::ForkFissionBufferize => num == 2 || num == 1, Pass::ForkInterchange => num == 2, Pass::Print => num == 1, Pass::Rename => num == 1, @@ -61,7 +61,7 @@ impl Pass { match self { Pass::ArrayToProduct => "0 or 1", Pass::ForkChunk => "4", - Pass::ForkFissionBufferize => "2", + Pass::ForkFissionBufferize => "1 or 2", Pass::ForkInterchange => "2", Pass::Print => "1", Pass::Rename => "1", diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 45cebe80..34c2474b 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2414,7 +2414,7 @@ fn run_pass( pm.clear_analyses(); } Pass::ForkFissionBufferize => { - assert_eq!(args.len(), 2); + assert!(args.len() == 1 || args.len() == 2); let Some(Value::Label { labels: fork_labels, }) = args.get(0) @@ -2425,25 +2425,17 @@ fn run_pass( }); }; - let Some(Value::Label { - labels: fork_data_labels, - }) = args.get(1) - else { - return Err(SchedulerError::PassError { - pass: "forkFissionBufferize".to_string(), - error: "expected label argument".to_string(), - }); - }; - let mut created_fork_joins = vec![vec![]; pm.functions.len()]; pm.make_fork_join_maps(); pm.make_typing(); pm.make_loops(); + pm.make_reduce_cycles(); pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let typing = pm.typing.take().unwrap(); let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); // assert only one function is in the selection. 
@@ -2454,30 +2446,42 @@ fn run_pass( assert!(num_functions <= 1); assert_eq!(fork_labels.len(), 1); - assert_eq!(fork_data_labels.len(), 1); let fork_label = fork_labels[0].label; - let data_label = fork_data_labels[0].label; - for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in - build_selection(pm, selection, false) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - .zip(typing.iter()) - .zip(nodes_in_fork_joins.iter()) + for ( + ((((func, fork_join_map), loop_tree), typing), reduce_cycles), + nodes_in_fork_joins, + ) in build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(typing.iter()) + .zip(reduce_cycles.iter()) + .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; + + let data_label = if let Some(Value::Label { + labels: fork_data_labels, + }) = args.get(1) + { + assert_eq!(fork_data_labels.len(), 1); + Some(fork_data_labels[0].label) + } else { + None + }; if let Some((fork1, fork2)) = ff_bufferize_any_fork( &mut func, loop_tree, fork_join_map, + reduce_cycles, nodes_in_fork_joins, typing, - &fork_label, - &data_label, + fork_label, + data_label, ) { let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; created_fork_joins.push(fork1); -- GitLab From c8ee3bdcd8756a4d4bbb7f5229688388fb2d9e9e Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 22:23:22 -0600 Subject: [PATCH 07/10] fission works for dot example --- hercules_opt/src/fork_transforms.rs | 18 ++++++++++----- juno_samples/dot/src/cpu.sch | 35 ++++++++++++++++++----------- juno_scheduler/src/compile.rs | 4 +++- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e832e559..7f6dd1bc 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -250,11 +250,12 @@ pub fn ff_bufferize_any_fork<'a, 'b>( where 'a: 'b, { - let forks: Vec<_> = loop_tree + let mut forks: Vec<_> = loop_tree .bottom_up_loops() .into_iter() .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) .collect(); + forks.reverse(); for l in forks { let fork_info = Loop { @@ -1506,6 +1507,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else { continue; }; + let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect(); match nodes[reduct.idx()] { Node::Binary { @@ -1519,12 +1521,15 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let zero = edit.add_zero_constant(typing[init.idx()]); let zero = edit.add_node(Node::Constant { id: zero }); edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?; - let final_add = edit.add_node(Node::Binary { + let final_op = edit.add_node(Node::Binary { op, left: init, right: id, }); - edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) + for u in out_uses { + edit.sub_edit(u, final_op); + } + edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) }); } Node::Binary { @@ -1538,12 +1543,15 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let one = edit.add_one_constant(typing[init.idx()]); let one = edit.add_node(Node::Constant { id: one }); edit = edit.replace_all_uses_where(init, one, |u| *u == id)?; - let final_add = edit.add_node(Node::Binary { + let final_op = 
edit.add_node(Node::Binary { op, left: init, right: id, }); - edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) + for u in out_uses { + edit.sub_edit(u, final_op); + } + edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) }); } _ => {} diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index 4e40e351..734054ab 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -1,24 +1,33 @@ -phi-elim(*); +phi-elim(dot); +ip-sroa(*); +sroa(dot); +dce(dot); -forkify(*); -fork-guard-elim(*); -dce(*); +forkify(dot); +fork-guard-elim(dot); +dce(dot); -fork-tile[8, 0, false, true](*); -fork-tile[32, 0, false, false](*); -fork-split(*); +fork-tile[8, 0, false, true](dot); +fork-tile[32, 0, false, false](dot); +let split_out = fork-split(dot); infer-schedules(*); clean-monoid-reduces(*); infer-schedules(*); clean-monoid-reduces(*); -let out = auto-outline(*); -cpu(out.dot); +let out = outline(split_out.dot.fj1); ip-sroa(*); -sroa(*); -dce(*); +sroa(dot); +gvn(dot); +dce(dot); -xdot[true](*); +let fission_out = fork-fission[out@loop](dot); +simplify-cfg(dot); +dce(dot); +unforkify(fission_out.dot.fj_loop_bottom); +ccp(dot); +gvn(dot); +dce(dot); -unforkify(*); +unforkify(out); gcm(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 2e930639..88816562 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -125,7 +125,9 @@ impl FromStr for Appliable { "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) } - "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)), + "fork-fission-bufferize" | "fork-fission" => { + Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)) + } "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), -- GitLab From e9e5aa319d3ad26b543f48519ecdc06c2f97c486 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 15 Feb 2025 22:48:21 -0600 Subject: [PATCH 08/10] a bunch of stuff for dot --- hercules_opt/src/fork_transforms.rs | 2 ++ hercules_opt/src/unforkify.rs | 30 ++++++++++++++++++++++++++--- juno_samples/dot/src/cpu.sch | 13 +++++++++---- juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 22 +++++++++++++++++++++ 6 files changed, 62 insertions(+), 7 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 7f6dd1bc..283734a0 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1520,6 +1520,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { editor.edit(|mut edit| { let zero = edit.add_zero_constant(typing[init.idx()]); let zero = edit.add_node(Node::Constant { id: zero }); + edit.sub_edit(id, zero); edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?; let final_op = edit.add_node(Node::Binary { op, @@ -1542,6 +1543,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { editor.edit(|mut edit| { let one = edit.add_one_constant(typing[init.idx()]); let one = edit.add_node(Node::Constant { id: one }); + edit.sub_edit(id, one); edit = edit.replace_all_uses_where(init, one, |u| *u == id)?; let final_op = edit.add_node(Node::Binary { op, diff --git a/hercules_opt/src/unforkify.rs 
b/hercules_opt/src/unforkify.rs index 7451b0ad..b44ed8df 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -117,7 +117,31 @@ pub fn unforkify_all( } } -pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_tree: &LoopTree) { +pub fn unforkify_one( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + loop_tree: &LoopTree, +) { + for l in loop_tree.bottom_up_loops().into_iter().rev() { + if !editor.node(l.0).is_fork() { + continue; + } + + let fork = l.0; + let join = fork_join_map[&fork]; + + if unforkify(editor, fork, join, loop_tree) { + break; + } + } +} + +pub fn unforkify( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + loop_tree: &LoopTree, +) -> bool { let mut zero_cons_id = ConstantID::new(0); let mut one_cons_id = ConstantID::new(0); assert!(editor.edit(|mut edit| { @@ -138,7 +162,7 @@ pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_t if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. - return; + return false; } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor @@ -296,5 +320,5 @@ pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_t } Ok(edit) - }); + }) } diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index 734054ab..aa87972e 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -17,17 +17,22 @@ clean-monoid-reduces(*); let out = outline(split_out.dot.fj1); ip-sroa(*); -sroa(dot); -gvn(dot); -dce(dot); +sroa(*); +gvn(*); +dce(*); let fission_out = fork-fission[out@loop](dot); simplify-cfg(dot); dce(dot); unforkify(fission_out.dot.fj_loop_bottom); ccp(dot); +simplify-cfg(dot); gvn(dot); dce(dot); -unforkify(out); +unforkify-one(out); +ccp(out); +simplify-cfg(out); +gvn(out); +dce(out); gcm(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 88816562..fc2a729e 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -144,6 +144,7 @@ impl FromStr for Appliable { "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)), + "unforkify-one" => Ok(Appliable::Pass(ir::Pass::UnforkifyOne)), "fork-coalesce" => Ok(Appliable::Pass(ir::Pass::ForkCoalesce)), "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 11cf6b13..bf3fe037 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -38,6 +38,7 @@ pub enum Pass { Serialize, SimplifyCFG, Unforkify, + UnforkifyOne, Verify, WritePredication, Xdot, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 34c2474b..8db79b46 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2364,6 +2364,28 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::UnforkifyOne => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_loops(); + + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + + for ((func, fork_join_map), loop_tree) in build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + { + let Some(mut func) = func else { + continue; + }; + unforkify_one(&mut func, fork_join_map, loop_tree); + 
changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkChunk => { assert_eq!(args.len(), 4); let Some(Value::Integer { val: tile_size }) = args.get(0) else { -- GitLab From 5a1fa18bb347ff034c2adea0e924ecf2ee73425f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 16 Feb 2025 09:33:18 -0600 Subject: [PATCH 09/10] Lower read and write in rt backend --- hercules_cg/src/rt.rs | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index f758fed3..8c5775d8 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -622,9 +622,26 @@ impl<'a> RTContext<'a> { } => { let block = &mut blocks.get_mut(&bb).unwrap().data; let collect_ty = self.typing[collect.idx()]; - let out_size = self.codegen_type_size(self.typing[id.idx()]); + let self_ty = self.typing[id.idx()]; let offset = self.codegen_index_math(collect_ty, indices, bb)?; - todo!(); + if self.module.types[self_ty.idx()].is_primitive() { + write!( + block, + "{} = ({}.byte_add({} as usize).0 as *mut {}).read();", + self.get_value(id, bb, true), + self.get_value(collect, bb, false), + offset, + self.get_type(self_ty) + )?; + } else { + write!( + block, + "{} = {}.byte_add({} as usize);", + self.get_value(id, bb, true), + self.get_value(collect, bb, false), + offset, + )?; + } } Node::Write { collect, @@ -633,11 +650,18 @@ impl<'a> RTContext<'a> { } => { let block = &mut blocks.get_mut(&bb).unwrap().data; let collect_ty = self.typing[collect.idx()]; - let data_size = self.codegen_type_size(self.typing[data.idx()]); - let offset = self.codegen_index_math(collect_ty, indices, bb)?; let data_ty = self.typing[data.idx()]; + let data_size = self.codegen_type_size(data_ty); + let offset = self.codegen_index_math(collect_ty, indices, bb)?; if self.module.types[data_ty.idx()].is_primitive() { - todo!(); + write!( + block, + "({}.byte_add({} as usize).0 as *mut {}).write({});", + self.get_value(collect, bb, false), + offset, + self.get_type(data_ty), + self.get_value(data, bb, false), + )?; } else { // If the data item being written is not a primitive type, // then perform a memcpy from the data collection to the -- GitLab From 85ce14bde0ca675bfa967fb35dc746eea343ca86 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 16 Feb 2025 10:15:36 -0600 Subject: [PATCH 10/10] enough rt backend stuff for multi-threaded dot --- hercules_cg/src/rt.rs | 69 ++++++++++++++++++++++++++++++++---- hercules_opt/src/gcm.rs | 9 +++-- juno_samples/dot/src/cpu.sch | 5 +-- juno_samples/dot/src/main.rs | 2 +- 4 files changed, 73 insertions(+), 12 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 8c5775d8..4d9a6cf6 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -139,7 +139,10 @@ struct RTContext<'a> { struct RustBlock { prologue: String, data: String, + phi_tmp_assignments: String, + phi_assignments: String, epilogue: String, + join_epilogue: String, } impl<'a> RTContext<'a> { @@ -251,7 +254,28 @@ impl<'a> RTContext<'a> { // fork and join nodes open and close environments, respectively. 
for id in rev_po.iter() { let block = &blocks[id]; - write!(w, "{}{}{}", block.prologue, block.data, block.epilogue)?; + if func.nodes[id.idx()].is_join() { + write!( + w, + "{}{}{}{}{}{}", + block.prologue, + block.data, + block.epilogue, + block.phi_tmp_assignments, + block.phi_assignments, + block.join_epilogue + )?; + } else { + write!( + w, + "{}{}{}{}{}", + block.prologue, + block.data, + block.phi_tmp_assignments, + block.phi_assignments, + block.epilogue + )?; + } } // Close the root environment. @@ -367,7 +391,10 @@ impl<'a> RTContext<'a> { // Close the branch inside the async closure. let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; - write!(epilogue, "return;}}")?; + write!( + epilogue, + "::std::sync::atomic::fence(::std::sync::atomic::Ordering::Release);return;}}" + )?; // Close the fork's environment. self.codegen_close_environment(epilogue)?; @@ -405,9 +432,10 @@ impl<'a> RTContext<'a> { } } + let join_epilogue = &mut blocks.get_mut(&id).unwrap().join_epilogue; // Branch to the successor control node in the surrounding // context, and close the branch for the join. - write!(epilogue, "control_token = {};}}", succ.idx())?; + write!(join_epilogue, "control_token = {};}}", succ.idx())?; } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } @@ -481,15 +509,39 @@ impl<'a> RTContext<'a> { write!(block, ";")?; } Node::ThreadID { control, dimension } => { + assert_eq!(control, bb); let block = &mut blocks.get_mut(&bb).unwrap().data; write!( block, "{} = tid_{}_{};", self.get_value(id, bb, true), - control.idx(), + bb.idx(), dimension )?; } + Node::Phi { control, ref data } => { + assert_eq!(control, bb); + // Phis aren't executable in their own basic block - predecessor + // blocks assign the to-be phi values themselves. Assign + // temporary values first before assigning the phi itself, since + // there may be simultaneous inter-dependent phis. + for (data, pred) in zip(data.into_iter(), self.control_subgraph.preds(bb)) { + let block = &mut blocks.get_mut(&pred).unwrap().phi_tmp_assignments; + write!( + block, + "let {}_tmp = {};", + self.get_value(id, pred, true), + self.get_value(*data, pred, false), + )?; + let block = &mut blocks.get_mut(&pred).unwrap().phi_assignments; + write!( + block, + "{} = {}_tmp;", + self.get_value(id, pred, true), + self.get_value(id, pred, false), + )?; + } + } Node::Reduce { control: _, init: _, @@ -498,11 +550,12 @@ impl<'a> RTContext<'a> { assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce)); } Node::Call { - control: _, + control, function: callee_id, ref dynamic_constants, ref args, } => { + assert_eq!(control, bb); // The device backends ensure that device functions have the // same interface as AsyncRust functions. let block = &mut blocks.get_mut(&bb).unwrap().data; @@ -975,7 +1028,9 @@ impl<'a> RTContext<'a> { if is_reduce_on_child { "reduce" } else { "node" }, idx, self.get_type(self.typing[idx]), - if self.module.types[self.typing[idx].idx()].is_integer() { + if self.module.types[self.typing[idx].idx()].is_bool() { + "false" + } else if self.module.types[self.typing[idx].idx()].is_integer() { "0" } else if self.module.types[self.typing[idx].idx()].is_float() { "0.0" @@ -1241,7 +1296,7 @@ impl<'a> RTContext<'a> { // Before using the value of a reduction outside the fork-join, // await the futures. format!( - "{{for fut in fork_{}.drain(..) {{ fut.await; }}; reduce_{}}}", + "{{for fut in fork_{}.drain(..) 
{{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}", fork.idx(), id.idx() ) diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 99c44d52..821d02ea 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -95,9 +95,11 @@ pub fn gcm( let bbs = basic_blocks( editor.func(), + editor.get_types(), editor.func_id(), def_use, reverse_postorder, + typing, dom, loops, reduce_cycles, @@ -218,9 +220,11 @@ fn preliminary_fixups( */ fn basic_blocks( function: &Function, + types: Ref<Vec<Type>>, func_id: FunctionID, def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, + typing: &Vec<TypeID>, dom: &DomTree, loops: &LoopTree, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, @@ -498,8 +502,9 @@ fn basic_blocks( // control dependent as possible, even inside loops. In GPU // functions specifically, lift constants that may be returned // outside fork-joins. - let is_constant_or_undef = - function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef(); + let is_constant_or_undef = (function.nodes[id.idx()].is_constant() + || function.nodes[id.idx()].is_undef()) + && !types[typing[id.idx()].idx()].is_primitive(); let is_gpu_returned = devices[func_id.idx()] == Device::CUDA && objects[&func_id] .objects(id) diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index aa87972e..1f8953d9 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -8,7 +8,7 @@ fork-guard-elim(dot); dce(dot); fork-tile[8, 0, false, true](dot); -fork-tile[32, 0, false, false](dot); +fork-tile[8, 0, false, false](dot); let split_out = fork-split(dot); infer-schedules(*); clean-monoid-reduces(*); @@ -29,8 +29,9 @@ ccp(dot); simplify-cfg(dot); gvn(dot); dce(dot); +infer-schedules(dot); -unforkify-one(out); +unforkify(out); ccp(out); simplify-cfg(out); gvn(out); diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs index b73f8710..bd887194 100644 --- a/juno_samples/dot/src/main.rs +++ b/juno_samples/dot/src/main.rs @@ -9,7 +9,7 @@ juno_build::juno!("dot"); fn main() { async_std::task::block_on(async { - const N: u64 = 1024 * 1024; + const N: u64 = 1024 * 8; let a: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect(); let b: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect(); let a_herc = HerculesImmBox::from(&a as &[i64]); -- GitLab
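
What the clean-monoid-reduces pass from patches 03 and 07 buys, in concrete terms: a monoid reduce whose initial value is not the identity element is awkward to parallelize, since the init threads a serial dependence into every tile. The pass swaps the reduce's init for the monoid identity (zero for Add/Or, one for Mul/And) and folds the original init back in exactly once through the final binary op it inserts after the reduce. The sketch below restates that rewrite on the dot sample in plain Rust as an illustration only; dot_seq, dot_tiled, and the tile width are assumed names echoing fork-tile[8, 0, false, true] in cpu.sch, not code the compiler actually emits.

    // Before: the reduction starts from `init`, which is not the identity of `+`,
    // so tiles cannot be reduced independently without re-counting `init`.
    fn dot_seq(init: i64, a: &[i64], b: &[i64]) -> i64 {
        let mut res = init;
        for i in 0..a.len() {
            res += a[i] * b[i];
        }
        res
    }

    // After: every tile reduces from the identity (0), and `init` is folded in
    // once after the join, mirroring the `final_op` node the pass adds.
    fn dot_tiled(init: i64, a: &[i64], b: &[i64]) -> i64 {
        const TILE: usize = 8; // assumed to match fork-tile[8, 0, false, true]
        let mut partials = vec![0i64; a.len().div_ceil(TILE)];
        for (t, partial) in partials.iter_mut().enumerate() {
            // Inner, per-tile monoid reduce with an identity init.
            for i in (t * TILE)..((t + 1) * TILE).min(a.len()) {
                *partial += a[i] * b[i];
            }
        }
        // Outer combine over the per-tile partials, plus the one-time init fold.
        init + partials.iter().sum::<i64>()
    }

    fn main() {
        let a: Vec<i64> = (0..32).collect();
        let b: Vec<i64> = (0..32).rev().collect();
        assert_eq!(dot_seq(5, &a, &b), dot_tiled(5, &a, &b));
    }

Because the per-tile partial sums start from the identity, they can be combined in any grouping and, for these commutative operators, in any order, which is what the pass's own comment means by aiding parallelization of outer loops.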
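
Patch 10 emits phi updates in two waves, phi_tmp_assignments followed by phi_assignments, because the phis of a block form a parallel copy: each phi must read its operands as they were before any phi in the same block is overwritten. The loop below is a plain-Rust illustration of the hazard the temporaries avoid; fib_like and its body are invented for this example and are not what the runtime backend generates.

    // Two inter-dependent values carried around a loop back edge, playing the
    // role of two phis in the same block.
    fn fib_like(n: u64) -> u64 {
        let (mut phi_a, mut phi_b) = (0u64, 1u64);
        for _ in 0..n {
            // Naive sequential copies clobber phi_a before phi_b reads it:
            //   phi_a = phi_b; phi_b = phi_a + phi_b;  // wrong: uses the new phi_a
            // Writing the new values to temporaries first preserves the
            // simultaneous (parallel-copy) semantics, which is the
            // `{phi}_tmp = ...;` then `{phi} = {phi}_tmp;` pattern in rt.rs.
            let phi_a_tmp = phi_b;
            let phi_b_tmp = phi_a + phi_b;
            phi_a = phi_a_tmp;
            phi_b = phi_b_tmp;
        }
        phi_a
    }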