use std::collections::HashMap; use std::collections::HashSet; use std::iter::zip; use std::iter::FromIterator; use itertools::Itertools; use nestify::nest; use hercules_ir::*; use crate::*; /* * TODO: Forkify currently makes a bunch of small edits - this needs to be * changed so that every loop that gets forkified corresponds to a single edit * + sub-edits. This would allow us to run forkify on a subset of a function. */ pub fn forkify( editor: &mut FunctionEditor, control_subgraph: &Subgraph, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, ) -> bool { let natural_loops = loops .bottom_up_loops() .into_iter() .filter(|(k, _)| editor.func().nodes[k.idx()].is_region()); let natural_loops: Vec<_> = natural_loops.collect(); for l in natural_loops { // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses. if editor.is_mutable(l.0) && forkify_loop( editor, control_subgraph, fork_join_map, &Loop { header: l.0, control: l.1.clone(), }, ) { return true; } } return false; } /** Given a node used as a loop bound, return a dynamic constant ID. */ pub fn get_node_as_dc( editor: &mut FunctionEditor, node: NodeID, ) -> Result<DynamicConstantID, String> { // Check for a constant used as loop bound. match editor.node(node) { Node::DynamicConstant { id: dynamic_constant_id, } => Ok(*dynamic_constant_id), Node::Constant { id: constant_id } => { let dc = match *editor.get_constant(*constant_id) { Constant::Integer8(x) => DynamicConstant::Constant(x as _), Constant::Integer16(x) => DynamicConstant::Constant(x as _), Constant::Integer32(x) => DynamicConstant::Constant(x as _), Constant::Integer64(x) => DynamicConstant::Constant(x as _), Constant::UnsignedInteger8(x) => DynamicConstant::Constant(x as _), Constant::UnsignedInteger16(x) => DynamicConstant::Constant(x as _), Constant::UnsignedInteger32(x) => DynamicConstant::Constant(x as _), Constant::UnsignedInteger64(x) => DynamicConstant::Constant(x as _), _ => return Err("Invalid constant as loop bound".to_string()), }; let mut b = DynamicConstantID::new(0); editor.edit(|mut edit| { b = edit.add_dynamic_constant(dc); Ok(edit) }); // Return the ID of the dynamic constant that is generated from the constant // or dynamic constant that is the existing loop bound Ok(b) } _ => Err("Blah".to_owned()), } } /** Top level function to convert natural loops with simple induction variables into fork-joins. */ pub fn forkify_loop( editor: &mut FunctionEditor, control_subgraph: &Subgraph, _fork_join_map: &HashMap<NodeID, NodeID>, l: &Loop, ) -> bool { let function = editor.func(); println!("forkifying {:?}", l.header); let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else { return false; }; let LoopExit::Conditional { if_node: loop_if, condition_node, } = loop_condition.clone() else { return false; }; // Compute loop variance let loop_variance = compute_loop_variance(editor, l); let ivs = compute_induction_vars(editor.func(), l, &loop_variance); let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition); let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else { println!("no canonical iv"); return false; }; // Get bound let bound = match canonical_iv { InductionVariable::Basic { node: _, initializer: _, final_value, update_expression, update_value, } => final_value .map(|final_value| get_node_as_dc(editor, final_value)) .and_then(|r| r.ok()), InductionVariable::SCEV(_) => return false, }; let Some(bound_dc_id) = bound else { println!("no bound iv"); return false; }; let function = editor.func(); // Check if it is do-while loop. let loop_exit_projection = editor .get_users(loop_if) .filter(|id| !l.control[id.idx()]) .next() .unwrap(); let loop_continue_projection = editor .get_users(loop_if) .filter(|id| l.control[id.idx()]) .next() .unwrap(); let loop_preds: Vec<_> = editor .get_uses(l.header) .filter(|id| !l.control[id.idx()]) .collect(); // FIXME: @xrouth if loop_preds.len() != 1 { return false; } let loop_pred = loop_preds[0]; if !editor .get_uses(l.header) .contains(&loop_continue_projection) { return false; } println!("this one"); // Get all phis used outside of the loop, they need to be reductionable. // For now just assume all phis will be phis used outside of the loop, except for the canonical iv. // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one // we currently have. let loop_nodes = calculate_loop_nodes(editor, l); // Check phis to see if they are reductionable, only PHIs depending on the loop are considered, let candidate_phis: Vec<_> = editor .get_users(l.header) .filter(|id| function.nodes[id.idx()].is_phi()) .filter(|id| *id != canonical_iv.phi()) .collect(); let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes) .into_iter() .collect(); // TODO: Handle multiple loop body lasts. // If there are multiple candidates for loop body last, return false. if editor .get_uses(loop_if) .filter(|id| l.control[id.idx()]) .count() > 1 { return false; } let loop_body_last = editor.get_uses(loop_if).next().unwrap(); println!("phis {:?}", reductionable_phis); if reductionable_phis .iter() .any(|phi| !matches!(phi, LoopPHI::Reductionable { .. })) { return false; } let phi_latches: Vec<_> = reductionable_phis .iter() .map(|phi| { let LoopPHI::Reductionable { phi: _, data_cycle: _, continue_latch, is_associative: _, } = phi else { unreachable!() }; continue_latch }) .collect(); let stop_on: HashSet<_> = editor .node_ids() .filter(|node| { if editor.node(node).is_phi() { return true; } if editor.node(node).is_reduce() { return true; } if editor.node(node).is_control() { return true; } if phi_latches.contains(&node) { return true; } false }) .collect(); // Outside loop users of IV, then exit; // Unless the outside user is through the loop latch of a reducing phi, // then we know how to replace this edge, so its fine! let iv_users: Vec<_> = walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect(); if iv_users .iter() .any(|node| !loop_nodes.contains(&node) && *node != loop_if) { return false; } // Start Transformation: // Graft everything between header and loop condition // Attach join to right before header (after loop_body_last, unless loop body last *is* the header). // Attach fork to right after loop_continue_projection. // // Create fork and join nodes: let mut join_id = NodeID::new(0); let mut fork_id = NodeID::new(0); // Turn dc bound into max (1, bound), let bound_dc_id = { let mut max_id = DynamicConstantID::new(0); editor.edit(|mut edit| { let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1)); max_id = edit.add_dynamic_constant(DynamicConstant::max(one_id, bound_dc_id)); Ok(edit) }); max_id }; // FIXME: (@xrouth) double check handling of control in loop body. editor.edit(|mut edit| { let fork = Node::Fork { control: loop_pred, factors: Box::new([bound_dc_id]), }; fork_id = edit.add_node(fork); let join = Node::Join { control: if l.header == loop_body_last { fork_id } else { loop_body_last }, }; join_id = edit.add_node(join); Ok(edit) }); let function = editor.func(); let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); let dimension = factors.len() - 1; let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis .iter() .map(|reduction_phi| { let LoopPHI::Reductionable { phi, data_cycle: _, continue_latch: _, is_associative: _, } = reduction_phi else { panic!(); }; let function = editor.func(); let init = *zip( editor.get_uses(l.header), function.nodes[phi.idx()].try_phi().unwrap().1.iter(), ) .filter(|(c, _)| *c == loop_pred) .next() .unwrap() .1; (reduction_phi, init) }) .collect(); // Start failable edit: let result = editor.edit(|mut edit| { let thread_id = Node::ThreadID { control: fork_id, dimension: dimension, }; let thread_id_id = edit.add_node(thread_id); // Replace uses that are inside with the thread id edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { loop_nodes.contains(node) })?; edit.sub_edit(canonical_iv.phi(), thread_id_id); edit = edit.delete_node(canonical_iv.phi())?; for (reduction_phi, init) in redcutionable_phis_and_init { let LoopPHI::Reductionable { phi, data_cycle: _, continue_latch, is_associative: _, } = *reduction_phi else { panic!(); }; let reduce = Node::Reduce { control: join_id, init, reduct: continue_latch, }; let reduce_id = edit.add_node(reduce); if (!edit.get_node(init).is_reduce() && edit.get_schedule(init).contains(&Schedule::ParallelReduce)) || (!edit.get_node(continue_latch).is_reduce() && edit .get_schedule(continue_latch) .contains(&Schedule::ParallelReduce)) { edit = edit.add_schedule(reduce_id, Schedule::ParallelReduce)?; } if (!edit.get_node(init).is_reduce() && edit.get_schedule(init).contains(&Schedule::MonoidReduce)) || (!edit.get_node(continue_latch).is_reduce() && edit .get_schedule(continue_latch) .contains(&Schedule::MonoidReduce)) { edit = edit.add_schedule(reduce_id, Schedule::MonoidReduce)?; } edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?; edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| { !loop_nodes.contains(usee) && *usee != reduce_id })?; edit.sub_edit(phi, reduce_id); edit = edit.delete_node(phi)? } edit = edit.replace_all_uses(l.header, fork_id)?; edit = edit.replace_all_uses(loop_continue_projection, fork_id)?; edit = edit.replace_all_uses(loop_exit_projection, join_id)?; edit.sub_edit(l.header, fork_id); edit.sub_edit(loop_continue_projection, fork_id); edit.sub_edit(loop_exit_projection, join_id); edit = edit.delete_node(loop_continue_projection)?; edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. edit = edit.delete_node(loop_exit_projection)?; edit = edit.delete_node(loop_if)?; edit = edit.delete_node(l.header)?; Ok(edit) }); println!("result: {:?}", result); return result; } nest! { #[derive(Debug)] pub enum LoopPHI { Reductionable { phi: NodeID, data_cycle: HashSet<NodeID>, // All nodes in a data cycle with this phi continue_latch: NodeID, is_associative: bool, }, LoopDependant(NodeID), ControlDependant(NodeID), // This phi is redcutionable, but its cycle might depend on internal control within the loop. UsedByDependant(NodeID), } } impl LoopPHI { pub fn get_phi(&self) -> NodeID { match self { LoopPHI::Reductionable { phi, .. } => *phi, LoopPHI::LoopDependant(node_id) => *node_id, LoopPHI::UsedByDependant(node_id) => *node_id, LoopPHI::ControlDependant(node_id) => *node_id, } } } /** Checks some conditions on loop variables that will need to be converted into reductions to be forkified. - The phi is in a cycle *in the loop* with itself. - Every cycle *in the loop* containing the phi does not contain any other phi of the loop header. - The phi does not immediatley (not blocked by another phi or another reduce) use any other phis of the loop header. */ pub fn analyze_phis<'a>( editor: &'a FunctionEditor, natural_loop: &'a Loop, phis: &'a [NodeID], loop_nodes: &'a HashSet<NodeID>, ) -> impl Iterator<Item = LoopPHI> + 'a { // Find data cycles within the loop of this phi, // Start from the phis loop_continue_latch, and walk its uses until we find the original phi. phis.into_iter().map(move |phi| { let stop_on: HashSet<NodeID> = editor .node_ids() .filter(|node| { let data = &editor.func().nodes[node.idx()]; // External Phi if let Node::Phi { control, data: _ } = data { if !natural_loop.control[control.idx()] { return true; } } // This phi if node == phi { return true; } // External Reduce if let Node::Reduce { control, init: _, reduct: _, } = data { if !natural_loop.control[control.idx()] { return true; } else { return false; } } // Data Cycles Only if data.is_control() { return true; } return false; }) .collect(); let continue_idx = editor .get_uses(natural_loop.header) .position(|node| natural_loop.control[node.idx()]) .unwrap(); let loop_continue_latch = editor.node(phi).try_phi().unwrap().1[continue_idx]; let uses = walk_all_uses_stop_on(loop_continue_latch, editor, stop_on.clone()); let users = walk_all_users_stop_on(*phi, editor, stop_on.clone()); let other_stop_on: HashSet<NodeID> = editor .node_ids() .filter(|node| { let data = &editor.func().nodes[node.idx()]; // Phi, Reduce if data.is_phi() { return true; } if data.is_reduce() { return true; } // External Control if data.is_control() { return true; } return false; }) .collect(); let mut uses_for_dependance = walk_all_users_stop_on(loop_continue_latch, editor, other_stop_on); let set1: HashSet<_> = HashSet::from_iter(uses); let set2: HashSet<_> = HashSet::from_iter(users); let intersection: HashSet<_> = set1.intersection(&set2).cloned().collect(); // If this phi uses any other phis the node is loop dependant, // // we use `phis` because this phi can actually contain the loop iv and its fine. // if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) { // LoopPHI::LoopDependant(*phi) // } else if intersection.clone().iter().next().is_some() { // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce. if intersection .iter() .filter(|node| **node != loop_continue_latch) .any(|data_node| { editor .get_users(*data_node) .any(|user| !loop_nodes.contains(&user)) }) { // This phi can be made into a reduce in different ways, if the cycle is associative (contains all the same kind of associative op) // 3) Split the cycle into two phis, add them or multiply them together at the end. // 4) Split the cycle into two reduces, add them or multiply them together at the end. // Somewhere else should handle this. return LoopPHI::LoopDependant(*phi); } // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify // i.e as described above. let is_associative = false; // No nodes in the data cycle are used outside of the loop, besides the latched value of the phi LoopPHI::Reductionable { phi: *phi, data_cycle: intersection, continue_latch: loop_continue_latch, is_associative, } } else { // No cycles exist, this isn't a reduction. LoopPHI::LoopDependant(*phi) } }) }