From 23103bc2d1e31fae8880cd9063d4948a1be89a92 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 22:02:51 -0600 Subject: [PATCH] holy shit that just worked --- hercules_opt/src/fork_transforms.rs | 81 +++++++++++++++++++++--- juno_samples/fork_join_tests/src/cpu.sch | 6 +- 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 94898b0d..2f7a91fa 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1186,7 +1186,8 @@ pub fn fork_unroll( join: NodeID, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> bool { - // We can only unroll forks with a compile time known factor list. + // We can only unroll fork-joins with a compile time known factor list. For + // simplicity, just unroll fork-joins that have a single dimension. let nodes = &editor.func().nodes; let Node::Fork { control, @@ -1195,17 +1196,79 @@ pub fn fork_unroll( else { panic!() }; - let mut cons_factors = vec![]; - for factor in factors { - let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else { - return false; - }; - cons_factors.push(cons); + if factors.len() != 1 || editor.get_users(fork).count() != 2 { + return false; } - println!("{}: {:?}", editor.func().name, cons_factors); + let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else { + return false; + }; + let tid = editor + .get_users(fork) + .filter(|id| nodes[id.idx()].is_thread_id()) + .next() + .unwrap(); + let inits: HashMap<NodeID, NodeID> = editor + .get_users(join) + .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init))) + .collect(); editor.edit(|mut edit| { - (); + // Create a copy of the nodes in the fork join per unrolled iteration, + // excluding the fork itself, the join itself, the thread IDs of the fork, + // and the reduces on the join. Keep a running tally of the top control + // node and the current reduction value. + let mut top_control = control; + let mut current_reduces = inits; + for iter in 0..cons { + let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64)); + let iter_tid = edit.add_node(Node::Constant { id: iter_cons }); + + // First, add a copy of each node in the fork join unmodified. + // Record the mapping from old ID to new ID. + let mut added_ids = HashSet::new(); + let mut old_to_new_ids = HashMap::new(); + let mut new_control = None; + let mut new_reduces = HashMap::new(); + for node in nodes_in_fork_joins[&fork].iter() { + if *node == fork { + old_to_new_ids.insert(*node, top_control); + } else if *node == join { + new_control = Some(edit.get_node(*node).try_join().unwrap()); + } else if *node == tid { + old_to_new_ids.insert(*node, iter_tid); + } else if let Some(current) = current_reduces.get(node) { + old_to_new_ids.insert(*node, *current); + new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2); + } else { + let new_node = edit.add_node(edit.get_node(*node).clone()); + old_to_new_ids.insert(*node, new_node); + added_ids.insert(new_node); + } + } + + // Second, replace all the uses in the just added nodes. + if let Some(new_control) = new_control { + top_control = old_to_new_ids[&new_control]; + } + for (reduce, reduct) in new_reduces { + current_reduces.insert(reduce, old_to_new_ids[&reduct]); + } + for (old, new) in old_to_new_ids { + edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?; + } + } + + // Hook up the control and reduce outputs to the rest of the function. + edit = edit.replace_all_uses(join, top_control)?; + for (reduce, reduct) in current_reduces { + edit = edit.replace_all_uses(reduce, reduct)?; + } + + // Delete the old fork-join. + for node in nodes_in_fork_joins[&fork].iter() { + edit = edit.delete_node(*node)?; + } + Ok(edit) }) } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 2c832d66..9e87d26a 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -39,8 +39,10 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } -unroll(auto.test1); -xdot[true](*); +fork-split(auto.test1); +fixpoint panic after 20 { + unroll(auto.test1); +} fork-split(auto.test2, auto.test3, auto.test4, auto.test5); gvn(*); -- GitLab