diff --git a/hercules_opt/src/loop_bound_canon.rs b/hercules_opt/src/loop_bound_canon.rs new file mode 100644 index 0000000000000000000000000000000000000000..0dce9c28dda6b6027830e192ecfa542769921e6f --- /dev/null +++ b/hercules_opt/src/loop_bound_canon.rs @@ -0,0 +1,314 @@ +use std::collections::HashMap; +use std::collections::HashSet; +use std::iter::zip; +use std::iter::FromIterator; + +use itertools::Itertools; +use nestify::nest; + +use hercules_ir::*; + +use crate::*; + +/* + * TODO: Forkify currently makes a bunch of small edits - this needs to be + * changed so that every loop that gets forkified corresponds to a single edit + * + sub-edits. This would allow us to run forkify on a subset of a function. + */ +pub fn loop_bound_canon_toplevel( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + control_subgraph: &Subgraph, + loops: &LoopTree, +) -> bool { + let natural_loops = loops + .bottom_up_loops() + .into_iter() + .filter(|(k, _)| editor.func().nodes[k.idx()].is_region()); + + let natural_loops: Vec<_> = natural_loops.collect(); + + for l in natural_loops { + if editor.is_mutable(l.0) + && canonicalize_single_loop_bounds( + editor, + control_subgraph, + &Loop { + header: l.0, + control: l.1.clone(), + }, + ) + { + return true; + } + } + return false; +} + +pub fn canonicalize_single_loop_bounds( + editor: &mut FunctionEditor, + control_subgraph: &Subgraph, + l: &Loop, +) -> bool { + let function = editor.func(); + + let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else { + return false; + }; + + let LoopExit::Conditional { + if_node: loop_if, + condition_node, + } = loop_condition.clone() + else { + return false; + }; + + let loop_variance = compute_loop_variance(editor, l); + let ivs = compute_induction_vars(editor.func(), l, &loop_variance); + let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition); + + if has_canonical_iv(editor, l, &ivs).is_some() { + // println!("has canon iv!"); + return true; + } + + let loop_bound_iv_phis = get_loop_condition_ivs(editor, l, &ivs, &loop_condition); + + let (loop_bound_ivs, _): (Vec<InductionVariable>, Vec<InductionVariable>) = ivs + .into_iter() + .partition(|f| loop_bound_iv_phis.contains(&f.phi())); + + // Assume there is only one loop bound iv. + if loop_bound_ivs.len() != 1 { + // println!("has multiple iv!"); + return false; + } + + let Some(iv) = loop_bound_ivs.first() else { + return false; + }; + + let InductionVariable::Basic { + node: iv_phi, + initializer, + final_value, + update_expression, + update_value, + } = iv + else { + return false; + }; + + let Some(final_value) = final_value else { + return false; + }; + + let Some(loop_pred) = editor + .get_uses(l.header) + .filter(|node| !l.control[node.idx()]) + .next() + else { + return false; + }; + + // If there is a guard, we need to edit it. + + // (init_id, bound_id, binop node, if node). + + // FIXME: This is quite fragile. + let guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| { + let Node::Projection { + control, + selection: _, + } = editor.node(loop_pred) + else { + return None; + }; + + let Node::If { control, cond } = editor.node(control) else { + return None; + }; + + let Node::Binary { left, right, op } = editor.node(cond) else { + return None; + }; + + let Node::Binary { + left: _, + right: _, + op: loop_op, + } = editor.node(condition_node) + else { + return None; + }; + + if op != loop_op { + return None; + } + + if left != initializer { + return None; + } + + if right != final_value { + return None; + } + + return Some((*left, *right, *cond, *control)); + })(); + + // // If guard is none, if some, make sure it is a good guard! move on + // if let Some((init_id, bound_id, binop_node, if_node))= potential_guard_info { + + // }; + + // let fork_guard_condition = + + // Lift dc math should make all constant into DCs, so these should all be DCs. + let Node::DynamicConstant { id: init_dc_id } = *editor.node(initializer) else { + return false; + }; + let Node::DynamicConstant { id: update_dc_id } = *editor.node(update_value) else { + return false; + }; + + // We are assuming this is a simple loop bound (i.e only one induction variable involved), so that . + let Node::DynamicConstant { + id: loop_bound_dc_id, + } = *editor.node(final_value) + else { + return false; + }; + + // We need to do 4 (5) things, which are mostly separate. + + // 0) Make the update into addition. + // 1) Make the update a positive value. + // 2) Transform the condition into a `<` + // 3) Adjust update to be 1 (and bounds). + // 4) Change init to start from 0. + + // 5) Find some way to get fork-guard-elim to work with the new fork. + // ideally, this goes in fork-guard-elim, but for now we hack it to change the guard condition bounds + // here when we edit the loop bounds. + + // Right now we are just going to do (4), because I am lazy! + + // Collect info about the loop condition transformation. + let mut dc_bound_node = match *editor.node(condition_node) { + Node::Binary { left, right, op } => match op { + BinaryOperator::LT => { + if left == *update_expression && editor.node(right).is_dynamic_constant() { + right + } else { + return false; + } + } + BinaryOperator::LTE => todo!(), + BinaryOperator::GT => todo!(), + BinaryOperator::GTE => todo!(), + BinaryOperator::EQ => todo!(), + BinaryOperator::NE => todo!(), + BinaryOperator::Or => todo!(), + BinaryOperator::And => todo!(), + BinaryOperator::Xor => todo!(), + _ => panic!(), + }, + _ => return false, + }; + + let Node::DynamicConstant { + id: bound_node_dc_id, + } = *editor.node(dc_bound_node) + else { + return false; + }; + + // If increment is negative (how in the world do we know that...) + // Increment can be DefinetlyPostiive, Unknown, DefinetlyNegative. + + // // First, massage loop condition to be <, because that is normal! + // Also includes + // editor.edit(|mut edit| { + + // } + // Collect immediate IV users + + let update_expr_users: Vec<_> = editor + .get_users(*update_expression) + .filter(|node| *node != iv.phi() && *node != condition_node) + .collect(); + // println!("update_expr_users: {:?}", update_expr_users); + let iv_phi_users: Vec<_> = editor + .get_users(iv.phi()) + .filter(|node| *node != iv.phi() && *node != *update_expression) + .collect(); + + // println!(" iv_phi_users: {:?}", iv_phi_users); + + let result = editor.edit(|mut edit| { + // 4) Second, change loop IV to go from 0..N. + // we subtract off init from init and dc_bound_node, + // and then we add it back to uses of the IV. + let new_init_dc = DynamicConstant::Constant(0); + let new_init = Node::DynamicConstant { + id: edit.add_dynamic_constant(new_init_dc), + }; + let new_init = edit.add_node(new_init); + edit = edit.replace_all_uses_where(*initializer, new_init, |usee| *usee == iv.phi())?; + + let new_condition_id = DynamicConstant::sub(bound_node_dc_id, init_dc_id); + let new_condition = Node::DynamicConstant { + id: edit.add_dynamic_constant(new_condition_id), + }; + let new_condition = edit.add_node(new_condition); + edit = edit + .replace_all_uses_where(dc_bound_node, new_condition, |usee| *usee == condition_node)?; + + // Change loop guard: + if let Some((init_id, bound_id, binop_node, if_node)) = guard_info { + edit = edit.replace_all_uses_where(init_id, new_init, |usee| *usee == binop_node)?; + edit = + edit.replace_all_uses_where(bound_id, new_condition, |usee| *usee == binop_node)?; + } + + // Add back to uses of the IV + for user in update_expr_users { + let new_user = Node::Binary { + left: user, + right: *initializer, + op: BinaryOperator::Add, + }; + let new_user = edit.add_node(new_user); + edit = edit.replace_all_uses(user, new_user)?; + } + + let new_user = Node::Binary { + left: *update_expression, + right: *initializer, + op: BinaryOperator::Add, + }; + let new_user = edit.add_node(new_user); + edit = edit.replace_all_uses_where(*update_expression, new_user, |usee| { + *usee != iv.phi() + && *usee != *update_expression + && *usee != new_user + && *usee != condition_node + })?; + + let new_user = Node::Binary { + left: *iv_phi, + right: *initializer, + op: BinaryOperator::Add, + }; + let new_user = edit.add_node(new_user); + edit = edit.replace_all_uses_where(*iv_phi, new_user, |usee| { + *usee != iv.phi() && *usee != *update_expression && *usee != new_user + })?; + + Ok(edit) + }); + + return result; +}