From cb5a7b4eb050b78d0454d9f35e64b7a5aec553ad Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Thu, 13 Feb 2025 14:52:37 -0600 Subject: [PATCH] add alternate order to fork tilling --- hercules_opt/src/fork_transforms.rs | 188 ++++++++++++++++------- juno_samples/fork_join_tests/src/cpu.sch | 2 +- juno_samples/fork_join_tests/src/gpu.sch | 2 +- juno_scheduler/src/ir.rs | 2 +- juno_scheduler/src/pm.rs | 10 +- 5 files changed, 144 insertions(+), 60 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 2f7a91fa..20982163 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -856,6 +856,7 @@ pub fn chunk_all_forks_unguarded( fork_join_map: &HashMap<NodeID, NodeID>, dim_idx: usize, tile_size: usize, + order: bool, ) -> () { // Add dc let mut dc_id = DynamicConstantID::new(0); @@ -864,19 +865,31 @@ pub fn chunk_all_forks_unguarded( Ok(edit) }); + let order = match order { + true => &TileOrder::TileInner, + false => &TileOrder::TileOuter, + }; + for (fork, _) in fork_join_map { - chunk_fork_unguarded(editor, *fork, dim_idx, dc_id); + chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order); } } // Splits a dimension of a single fork join into multiple. // Iterates an outer loop original_dim / tile_size times // adds a tile_size loop as the inner loop // Assumes that tile size divides original dim evenly. + +enum TileOrder { + TileInner, + TileOuter, +} + pub fn chunk_fork_unguarded( editor: &mut FunctionEditor, fork: NodeID, dim_idx: usize, tile_size: DynamicConstantID, + order: &TileOrder, ) -> () { // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1) let Node::Fork { @@ -893,63 +906,128 @@ pub fn chunk_fork_unguarded( .map(|f| (f, editor.node(f).clone())) .collect(); - editor.edit(|mut edit| { - let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); - new_factors.insert(dim_idx + 1, tile_size); - new_factors[dim_idx] = edit.add_dynamic_constant(outer); - - let new_fork = Node::Fork { - control: old_control, - factors: new_factors.into(), - }; - let new_fork = edit.add_node(new_fork); - - edit = edit.replace_all_uses(fork, new_fork)?; - edit.sub_edit(fork, new_fork); - - for (tid, node) in fork_users { - let Node::ThreadID { - control: _, - dimension: tid_dim, - } = node - else { - continue; - }; - if tid_dim > dim_idx { - let new_tid = Node::ThreadID { - control: new_fork, - dimension: tid_dim + 1, + match order { + TileOrder::TileInner => { + editor.edit(|mut edit| { + let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx + 1, tile_size); + new_factors[dim_idx] = edit.add_dynamic_constant(outer); + + let new_fork = Node::Fork { + control: old_control, + factors: new_factors.into(), }; - let new_tid = edit.add_node(new_tid); - edit = edit.replace_all_uses(tid, new_tid)?; - edit.sub_edit(tid, new_tid); - edit = edit.delete_node(tid)?; - } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { - control: new_fork, - dimension: tid_dim + 1, + let new_fork = edit.add_node(new_fork); + + edit = edit.replace_all_uses(fork, new_fork)?; + edit.sub_edit(fork, new_fork); + + for (tid, node) in fork_users { + let Node::ThreadID { + control: _, + dimension: tid_dim, + } = node + else { + continue; + }; + if tid_dim > dim_idx { + let new_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; + let new_tid = edit.add_node(new_tid); + edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); + edit = edit.delete_node(tid)?; + } else if tid_dim == dim_idx { + let tile_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; + let tile_tid = edit.add_node(tile_tid); + + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); + let mul = edit.add_node(Node::Binary { + left: tid, + right: tile_size, + op: BinaryOperator::Mul, + }); + let add = edit.add_node(Node::Binary { + left: mul, + right: tile_tid, + op: BinaryOperator::Add, + }); + edit.sub_edit(tid, add); + edit.sub_edit(tid, tile_tid); + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?; + } + } + edit = edit.delete_node(fork)?; + Ok(edit) + }); + }, + TileOrder::TileOuter => { + editor.edit(|mut edit| { + let inner = DynamicConstant::div(new_factors[dim_idx], tile_size); + new_factors.insert(dim_idx, tile_size); + let inner_dc_id = edit.add_dynamic_constant(inner); + + + let new_fork = Node::Fork { + control: old_control, + factors: new_factors.into(), }; - let tile_tid = edit.add_node(tile_tid); - - let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); - let mul = edit.add_node(Node::Binary { - left: tid, - right: tile_size, - op: BinaryOperator::Mul, - }); - let add = edit.add_node(Node::Binary { - left: mul, - right: tile_tid, - op: BinaryOperator::Add, - }); - edit.sub_edit(tid, add); - edit.sub_edit(tid, tile_tid); - edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?; - } + let new_fork = edit.add_node(new_fork); + + edit = edit.replace_all_uses(fork, new_fork)?; + edit.sub_edit(fork, new_fork); + + for (tid, node) in fork_users { + let Node::ThreadID { + control: _, + dimension: tid_dim, + } = node + else { + continue; + }; + if tid_dim > dim_idx { + let new_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim + 1, + }; + let new_tid = edit.add_node(new_tid); + edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); + edit = edit.delete_node(tid)?; + } else if tid_dim == dim_idx { + + let tile_tid = Node::ThreadID { + control: new_fork, + dimension: tid_dim, + }; + let tile_tid = edit.add_node(tile_tid); + let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } ); + let mul = edit.add_node(Node::Binary { + left: tid, + right: inner_dc, + op: BinaryOperator::Mul, + }); + let add = edit.add_node(Node::Binary { + left: mul, + right: tile_tid, + op: BinaryOperator::Add, + }); + edit.sub_edit(tid, add); + edit.sub_edit(tid, tile_tid); + edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?; + } + } + edit = edit.delete_node(fork)?; + Ok(edit) + }); } - edit = edit.delete_node(fork)?; - Ok(edit) - }); + } + } pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 9e87d26a..abc2fde0 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -54,7 +54,7 @@ gvn(*); phi-elim(*); dce(*); -fork-tile[32, 0, false](test6@loop); +fork-tile[32, 0, false, true](test6@loop); let out = fork-split(test6@loop); let out = outline(out.test6.fj1); cpu(out); diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 8f4ec9ad..91bd6c79 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -52,7 +52,7 @@ slf(auto.test2); infer-schedules(auto.test2); fork-interchange[0, 1](auto.test2); -fork-tile[32, 0, false](test6@loop); +fork-tile[32, 0, false, true](test6@loop); let out = fork-split(test6@loop); let out = auto-outline(test6); gpu(out.test6); diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 5bfb4e21..a9ee7956 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -43,7 +43,7 @@ impl Pass { pub fn num_args(&self) -> usize { match self { Pass::Xdot => 1, - Pass::ForkChunk => 3, + Pass::ForkChunk => 4, Pass::ForkFissionBufferize => 2, Pass::ForkInterchange => 2, _ => 0, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 342f875b..de725608 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2118,7 +2118,7 @@ fn run_pass( pm.clear_analyses(); } Pass::ForkChunk => { - assert_eq!(args.len(), 3); + assert_eq!(args.len(), 4); let Some(Value::Integer { val: tile_size }) = args.get(0) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), @@ -2137,6 +2137,12 @@ fn run_pass( error: "expected boolean argument".to_string(), }); }; + let Some(Value::Boolean { val: tile_order }) = args.get(3) else { + return Err(SchedulerError::PassError { + pass: "forkChunk".to_string(), + error: "expected boolean argument".to_string(), + }); + }; assert!(!*guarded_flag); pm.make_fork_join_maps(); @@ -2148,7 +2154,7 @@ fn run_pass( let Some(mut func) = func else { continue; }; - chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size); + chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size, *tile_order); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab