diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 72b5716f6ab56b260cc1d0d0cb8c78fa2da60d42..c32a517e46ab52a69f13feacb94382a8edaaf731 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -912,16 +912,16 @@ pub fn chunk_fork_unguarded( let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx + 1, tile_size); new_factors[dim_idx] = edit.add_dynamic_constant(outer); - + let new_fork = Node::Fork { control: old_control, factors: new_factors.into(), }; let new_fork = edit.add_node(new_fork); - + edit = edit.replace_all_uses(fork, new_fork)?; edit.sub_edit(fork, new_fork); - + for (tid, node) in fork_users { let Node::ThreadID { control: _, @@ -945,7 +945,7 @@ pub fn chunk_fork_unguarded( dimension: tid_dim + 1, }; let tile_tid = edit.add_node(tile_tid); - + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); let mul = edit.add_node(Node::Binary { left: tid, @@ -965,23 +965,23 @@ pub fn chunk_fork_unguarded( edit = edit.delete_node(fork)?; Ok(edit) }); - }, + } TileOrder::TileOuter => { editor.edit(|mut edit| { let inner = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx, tile_size); - let inner_dc_id = edit.add_dynamic_constant(inner); + let inner_dc_id = edit.add_dynamic_constant(inner); new_factors[dim_idx + 1] = inner_dc_id; - + let new_fork = Node::Fork { control: old_control, factors: new_factors.into(), }; let new_fork = edit.add_node(new_fork); - + edit = edit.replace_all_uses(fork, new_fork)?; edit.sub_edit(fork, new_fork); - + for (tid, node) in fork_users { let Node::ThreadID { control: _, @@ -1000,13 +1000,12 @@ pub fn chunk_fork_unguarded( edit.sub_edit(tid, new_tid); edit = edit.delete_node(tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim, }; let tile_tid = edit.add_node(tile_tid); - let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } ); + let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id }); let mul = edit.add_node(Node::Binary { left: tid, right: inner_dc, @@ -1027,7 +1026,6 @@ pub fn chunk_fork_unguarded( }); } } - } pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { @@ -1350,3 +1348,101 @@ pub fn fork_unroll( Ok(edit) }) } + +/* + * Looks for fork-joins that are next to each other, not inter-dependent, and + * have the same bounds. These fork-joins can be fused, pooling together all + * their reductions. + */ +pub fn fork_fusion_all_forks( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { + for (fork, join) in fork_join_map { + if editor.is_mutable(*fork) + && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins) + { + break; + } + } +} + +/* + * Tries to fuse a given fork join with the immediately following fork-join, if + * it exists. + */ +fn fork_fusion( + editor: &mut FunctionEditor, + top_fork: NodeID, + top_join: NodeID, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + let nodes = &editor.func().nodes; + // Rust operator precedence is not such that these can be put in one big + // let-else statement. Sad! + let Some(bottom_fork) = editor + .get_users(top_join) + .filter(|id| nodes[id.idx()].is_control()) + .next() + else { + return false; + }; + let Some(bottom_join) = fork_join_map.get(&bottom_fork) else { + return false; + }; + let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap(); + let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap(); + assert_eq!(bottom_fork_pred, top_join); + let top_join_pred = nodes[top_join.idx()].try_join().unwrap(); + let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap(); + + // The fork factors must be identical. + if top_factors != bottom_factors { + return false; + } + + // Check that no iterated users of the top's reduces are in the bottom fork- + // join (iteration stops at a phi or reduce outside the bottom fork-join). + for reduce in editor + .get_users(top_join) + .filter(|id| nodes[id.idx()].is_reduce()) + { + let mut visited = HashSet::new(); + visited.insert(reduce); + let mut workset = vec![reduce]; + while let Some(pop) = workset.pop() { + for u in editor.get_users(pop) { + if nodes_in_fork_joins[&bottom_fork].contains(&u) { + return false; + } else if (nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce()) + && !nodes_in_fork_joins[&top_fork].contains(&u) + { + } else if !visited.contains(&u) && !nodes_in_fork_joins[&top_fork].contains(&u) { + visited.insert(u); + workset.push(u); + } + } + } + } + + // Perform the fusion. + editor.edit(|mut edit| { + if bottom_join_pred != bottom_fork { + // If there is control flow in the bottom fork-join, stitch it into + // the top fork-join. + edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| { + nodes_in_fork_joins[&bottom_fork].contains(id) + })?; + edit = + edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?; + } + // Replace the bottom fork and join with the top fork and join. + edit = edit.replace_all_uses(bottom_fork, top_fork)?; + edit = edit.replace_all_uses(*bottom_join, top_join)?; + edit = edit.delete_node(bottom_fork)?; + edit = edit.delete_node(*bottom_join)?; + Ok(edit) + }) +} diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index e6961faf01bc90114186e3ea72b6b9671e2758d4..8158bf0a9bec25d57e1136bbf040ae0f561dcc3f 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( let v = v1 * v1 + v2 * v2 + v3 * v3; l2_dist[cp] = sqrt!::<f32>(v); } - + @channel_loop for chan = 0 to CHAN { let chan_val : f32 = 0.0; for cp = 0 to num_ctrl_pts { diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 1ae1dc132942ff868a25080d297a25f06c8395a2..0b8869d55de67ab8e796397bc25d5c15cc926d59 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -101,7 +101,9 @@ fixpoint { simpl!(fuse4); fork-unroll(fuse4@channel_loop); simpl!(fuse4); -//fork-fusion(fuse4@channel_loop); +fixpoint { + fork-fusion(fuse4@channel_loop); +} simpl!(fuse4); array-slf(fuse4); simpl!(fuse4); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 912cc91f7f5fd968374a50d304dfb4e5cce1654f..7c92e00d93dc100c6aec313c68ec17e42f7c44e9 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -129,6 +129,7 @@ impl FromStr for Appliable { "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), + "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 8ad923242a3b73f8b9379f1f9b51070245d2030c..205cd70b710dc810a89e88a3683fde4a9f695e5b 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -13,6 +13,7 @@ pub enum Pass { ForkCoalesce, ForkDimMerge, ForkFissionBufferize, + ForkFusion, ForkGuardElim, ForkInterchange, ForkSplit, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index b26d1720fc192378a214e3951d15e957c317b77f..a4783a93701564f041237f2fecef8a22c62abe5c 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2526,6 +2526,27 @@ fn run_pass( fields: new_fork_joins, }; } + Pass::ForkFusion => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((func, fork_join_map), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps();