diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 58ace775feba391d284ec9e39d9dd054940a7616..cbb09bbfe3a1dd9e0359ee05376a40481fa4b5f1 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -695,6 +695,24 @@ pub(crate) fn split_fork( } } +pub fn chunk_all_forks_unguarded( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + dim_idx: usize, + tile_size: usize, +) -> () { + // Add dc + let mut dc_id = DynamicConstantID::new(0); + editor.edit(|mut edit| { + dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size)); + Ok(edit) + }); + + for (fork, _ ) in fork_join_map { + chunk_fork_unguarded(editor, *fork, dim_idx, dc_id); + } + +} // Splits a dimension of a single fork join into multiple. // Iterates an outer loop original_dim / tile_size times // adds a tile_size loop as the inner loop @@ -711,39 +729,36 @@ pub fn chunk_fork_unguarded( let mut new_factors: Vec<_> = old_factors.to_vec(); - let fork_users: Vec<_> = editor.get_users(fork).collect(); + let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); + + editor.edit(|mut edit| { + let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx + 1, tile_size); + new_factors[dim_idx] = edit.add_dynamic_constant(outer); + + let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; + let new_fork = edit.add_node(new_fork); + edit = edit.replace_all_uses(fork, new_fork)?; - for tid in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; - editor.edit(|mut edit| { + for (tid, node) in fork_users { + let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue}; if tid_dim > dim_idx { - let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; let new_tid = edit.add_node(new_tid); - edit.replace_all_uses(tid, new_tid) + edit = edit.replace_all_uses(tid, new_tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 }; + let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 }; let tile_tid = edit.add_node(tile_tid); let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul }); let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add }); - edit.replace_all_uses_where(tid, add, |usee| *usee != mul ) - } else { - Ok(edit) + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?; } - }); - } - - editor.edit(|mut edit| { - let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); - new_factors.insert(dim_idx + 1, tile_size); - new_factors[dim_idx] = edit.add_dynamic_constant(outer); - - let new_fork = Node::Fork { control: old_control, factors: new_factors.into() }; - let new_fork = edit.add_node(new_fork); - - edit.replace_all_uses(fork, new_fork) + } + edit = edit.delete_node(fork)?; + Ok(edit) }); } @@ -791,9 +806,8 @@ pub fn fork_dim_merge( let mut new_factors: Vec<_> = old_factors.to_vec(); - - let fork_users: Vec<_> = editor.get_users(fork).collect(); + let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect(); let mut new_nodes = vec![]; @@ -801,6 +815,7 @@ pub fn fork_dim_merge( let inner_dc_id = new_factors[inner_idx]; let mut new_fork_id = NodeID::new(0); + editor.edit(|mut edit| { new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx])); new_factors.remove(inner_idx); @@ -809,22 +824,20 @@ pub fn fork_dim_merge( let new_fork = edit.add_node(new_fork); new_fork_id = new_fork; + edit.sub_edit(fork, new_fork); edit = edit.replace_all_uses(fork, new_fork)?; - edit.delete_node(fork) - }); - - - - for tid in fork_users { - let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue }; + edit = edit.delete_node(fork)?; - println!("tid: {:?}", tid); - editor.edit(|mut edit| { + for (tid, node) in fork_users { + // FIXME: DO we want sub edits in this? + + let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue }; if tid_dim > inner_idx { let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 }; let new_tid = edit.add_node(new_tid); - edit.replace_all_uses(tid, new_tid) + edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); } else if tid_dim == outer_idx { let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; let outer_tid = edit.add_node(outer_tid); @@ -834,8 +847,8 @@ pub fn fork_dim_merge( // inner_idx % dim(outer_idx) let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem}); - - edit.replace_all_uses(tid, rem) + edit.sub_edit(tid, rem); + edit = edit.replace_all_uses(tid, rem)?; } else if tid_dim == inner_idx { let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx }; let outer_tid = edit.add_node(outer_tid); @@ -843,13 +856,12 @@ pub fn fork_dim_merge( let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id }); // inner_idx / dim(outer_idx) let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div}); - - edit.replace_all_uses(tid, div) - } else { - Ok(edit) + edit.sub_edit(tid, div); + edit = edit.replace_all_uses(tid, div)?; } - }); - }; + } + Ok(edit) + }); return new_fork_id; diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 07ad5e7ab75820d3654cc52ecf52ae13bd31aca1..49dedd2b0ebce4712150b5b20f5bd0166531982b 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -110,6 +110,7 @@ impl FromStr for Appliable { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) }, "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), + "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 939ef925b5f24f17650c6c5119b6d10665677580..796437a71412a8411faee201c648c8ad46b577f5 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -36,6 +36,7 @@ impl Pass { pub fn num_args(&self) -> usize { match self { Pass::Xdot => 1, + Pass::ForkChunk => 3, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8b71d24e34ceed7664588bbe29b3000db939714d..5740d2a66b2e7b74595497526a860fcb1d01d459 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1891,6 +1891,40 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkChunk => { + assert_eq!(args.len(), 3); + let tile_size = args.get(0); + let dim_idx = args.get(1); + + let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { + panic!(); // How to error here? + }; + + let Some(Value::Integer { val: dim_idx }) = args.get(1) else { + panic!(); // How to error here? + }; + + let Some(Value::Integer { val: tile_size }) = args.get(0) else { + panic!(); // How to error here? + }; + + assert_eq!(*guarded_flag, true); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps();