use std::collections::{HashMap, HashSet}; use std::iter::zip; use bitvec::{order::Lsb0, vec::BitVec}; use hercules_ir::{ir::*, LoopTree}; use crate::*; type NodeVec = BitVec<u8, Lsb0>; pub fn calculate_fork_nodes( editor: &FunctionEditor, inner_control: &NodeVec, fork: NodeID, ) -> HashSet<NodeID> { // Stop on PHIs / reduces outside of loop. let stop_on: HashSet<NodeID> = editor .node_ids() .filter(|node| { let data = &editor.func().nodes[node.idx()]; // External Phi if let Node::Phi { control, data: _ } = data { if match inner_control.get(control.idx()) { Some(v) => !*v, // None => true, // Doesn't exist, must be external } { return true; } } // External Reduce if let Node::Reduce { control, init: _, reduct: _, } = data { if match inner_control.get(control.idx()) { Some(v) => !*v, // None => true, // Doesn't exist, must be external } { return true; } } // External Control if data.is_control() { return match inner_control.get(node.idx()) { Some(v) => !*v, // None => true, // Doesn't exist, must be external }; } // else return false; }) .collect(); let reduces: Vec<_> = editor .node_ids() .filter(|node| { let Node::Reduce { control, .. } = editor.func().nodes[node.idx()] else { return false; }; match inner_control.get(control.idx()) { Some(v) => *v, None => false, } }) .chain( editor .get_users(fork) .filter(|node| editor.node(node).is_thread_id()), ) .collect(); let all_users: HashSet<NodeID> = reduces .clone() .iter() .flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) .chain(reduces.clone()) .collect(); let all_uses: HashSet<_> = reduces .clone() .iter() .flat_map(|phi| walk_all_uses_stop_on(*phi, editor, stop_on.clone())) .chain(reduces) .filter(|node| { // Get rid of nodes in stop_on !stop_on.contains(node) }) .collect(); all_users.intersection(&all_uses).cloned().collect() } /* * Convert forks back into loops right before codegen when a backend is not * lowering a fork-join to vector / parallel code. Lowering fork-joins into * sequential loops in LLVM is actually not entirely trivial, so it's easier to * just do this transformation within Hercules IR. */ // FIXME: Only works on fully split fork nests. pub fn unforkify_all( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, loop_tree: &LoopTree, ) { for l in loop_tree.bottom_up_loops().into_iter().rev() { if !editor.node(l.0).is_fork() { continue; } let fork = l.0; let join = fork_join_map[&fork]; unforkify(editor, fork, join, loop_tree); } } pub fn unforkify_one( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, loop_tree: &LoopTree, ) { for l in loop_tree.bottom_up_loops().into_iter().rev() { if !editor.node(l.0).is_fork() { continue; } let fork = l.0; let join = fork_join_map[&fork]; if unforkify(editor, fork, join, loop_tree) { break; } } } pub fn unforkify( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_tree: &LoopTree, ) -> bool { let mut zero_cons_id = ConstantID::new(0); let mut one_cons_id = ConstantID::new(0); assert!(editor.edit(|mut edit| { zero_cons_id = edit.add_constant(Constant::UnsignedInteger64(0)); one_cons_id = edit.add_constant(Constant::UnsignedInteger64(1)); Ok(edit) })); // Convert the fork to a region, thread IDs to a single phi, reduces to // phis, and the join to a branch at the top of the loop. The previous // control insides of the fork-join should become the successor of the true // projection node, and what was the use of the join should become a use of // the new region. let fork_nodes = calculate_fork_nodes(editor, loop_tree.nodes_in_loop_bitvec(fork), fork); let nodes = &editor.func().nodes; let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap(); if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. return false; } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor .get_users(fork) .filter(|id| nodes[id.idx()].is_thread_id()) .collect(); let reduces: Vec<_> = editor .get_users(join) .filter(|id| nodes[id.idx()].is_reduce()) .collect(); let num_nodes = editor.node_ids().len(); let region_id = NodeID::new(num_nodes); let if_id = NodeID::new(num_nodes + 1); let proj_back_id = NodeID::new(num_nodes + 2); let proj_exit_id = NodeID::new(num_nodes + 3); let zero_id = NodeID::new(num_nodes + 4); let one_id = NodeID::new(num_nodes + 5); let indvar_id = NodeID::new(num_nodes + 6); let add_id = NodeID::new(num_nodes + 7); let dc_id = NodeID::new(num_nodes + 8); let neq_id = NodeID::new(num_nodes + 9); let guard_if_id = NodeID::new(num_nodes + 10); let guard_join_id = NodeID::new(num_nodes + 11); let guard_taken_proj_id = NodeID::new(num_nodes + 12); let guard_skipped_proj_id = NodeID::new(num_nodes + 13); let guard_cond_id = NodeID::new(num_nodes + 14); let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new); let s = num_nodes + 15 + reduces.len(); let join_phi_ids = (s..s + reduces.len()).map(NodeID::new); let guard_cond = Node::Binary { left: zero_id, right: dc_id, op: BinaryOperator::LT, }; let guard_if = Node::If { control: fork_control, cond: guard_cond_id, }; let guard_taken_proj = Node::Projection { control: guard_if_id, selection: 1, }; let guard_skipped_proj = Node::Projection { control: guard_if_id, selection: 0, }; let guard_join = Node::Region { preds: Box::new([guard_skipped_proj_id, proj_exit_id]), }; let region = Node::Region { preds: Box::new([guard_taken_proj_id, proj_back_id]), }; let if_node = Node::If { control: join_control, cond: neq_id, }; let proj_back = Node::Projection { control: if_id, selection: 1, }; let proj_exit = Node::Projection { control: if_id, selection: 0, }; let zero = Node::Constant { id: zero_cons_id }; let one = Node::Constant { id: one_cons_id }; let indvar = Node::Phi { control: region_id, data: Box::new([zero_id, add_id]), }; let add = Node::Binary { op: BinaryOperator::Add, left: indvar_id, right: one_id, }; let dc = Node::DynamicConstant { id: factors[0] }; let neq = Node::Binary { op: BinaryOperator::NE, left: add_id, right: dc_id, }; let (phis, join_phis): (Vec<_>, Vec<_>) = reduces .iter() .map(|reduce_id| { let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap(); ( Node::Phi { control: region_id, data: Box::new([init, reduct]), }, Node::Phi { control: guard_join_id, data: Box::new([init, reduct]), }, ) }) .unzip(); editor.edit(|mut edit| { assert_eq!(edit.add_node(region), region_id); assert_eq!(edit.add_node(if_node), if_id); assert_eq!(edit.add_node(proj_back), proj_back_id); assert_eq!(edit.add_node(proj_exit), proj_exit_id); assert_eq!(edit.add_node(zero), zero_id); assert_eq!(edit.add_node(one), one_id); assert_eq!(edit.add_node(indvar), indvar_id); assert_eq!(edit.add_node(add), add_id); assert_eq!(edit.add_node(dc), dc_id); assert_eq!(edit.add_node(neq), neq_id); assert_eq!(edit.add_node(guard_if), guard_if_id); assert_eq!(edit.add_node(guard_join), guard_join_id); assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id); assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id); assert_eq!(edit.add_node(guard_cond), guard_cond_id); for (phi_id, phi) in zip(phi_ids.clone(), &phis) { assert_eq!(edit.add_node(phi.clone()), phi_id); } for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) { assert_eq!(edit.add_node(phi.clone()), phi_id); } edit = edit.replace_all_uses(fork, region_id)?; edit = edit.replace_all_uses_where(join, guard_join_id, |usee| *usee != if_id)?; edit.sub_edit(fork, region_id); edit.sub_edit(join, if_id); for tid in tids.iter() { edit.sub_edit(*tid, indvar_id); edit = edit.replace_all_uses(*tid, indvar_id)?; } for (((reduce, phi_id), phi), join_phi_id) in zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids) { edit.sub_edit(*reduce, phi_id); let Node::Phi { control: _, data } = phi else { panic!() }; edit = edit .replace_all_uses_where(*reduce, join_phi_id, |usee| !fork_nodes.contains(usee))?; //, |usee| *usee != *reduct)?; edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| { fork_nodes.contains(usee) || *usee == data[1] })?; edit = edit.delete_node(*reduce)?; } edit = edit.delete_node(fork)?; edit = edit.delete_node(join)?; for tid in tids { edit = edit.delete_node(tid)?; } Ok(edit) }) }