Skip to content
Snippets Groups Projects
loop_bound_canon.rs 11.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • use std::collections::HashMap;
    use std::collections::HashSet;
    use std::iter::zip;
    use std::iter::FromIterator;
    
    use itertools::Itertools;
    use nestify::nest;
    
    use hercules_ir::*;
    
    use crate::*;
    
    /// Try to canonicalize the bounds of one natural loop of the function being edited.
    ///
    /// Natural loops are the entries of `loops.bottom_up_loops()` whose header is a region node;
    /// they are tried in the order `bottom_up_loops` yields them, skipping headers that are not
    /// mutable in `editor`. Stops at, and returns `true` for, the first loop that
    /// `canonicalize_single_loop_bounds` reports as changed; returns `false` otherwise.
    ///
    /// `fork_join_map` is currently unused.
    pub fn loop_bound_canon_toplevel(
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        control_subgraph: &Subgraph,
        loops: &LoopTree,
    ) -> bool {
        // Collected up front: the filter reads `editor` immutably, while canonicalization below
        // needs it mutably.
        let candidates: Vec<_> = loops
            .bottom_up_loops()
            .into_iter()
            .filter(|(header, _)| editor.func().nodes[header.idx()].is_region())
            .collect();

        candidates.into_iter().any(|(header, control)| {
            editor.is_mutable(header)
                && canonicalize_single_loop_bounds(
                    editor,
                    control_subgraph,
                    &Loop {
                        header,
                        control: control.clone(),
                    },
                )
        })
    }
    
    pub fn canonicalize_single_loop_bounds(
        editor: &mut FunctionEditor,
        control_subgraph: &Subgraph,
        l: &Loop,
    ) -> bool {
        let function = editor.func();
    
        let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else {
            return false;
        };
    
        let LoopExit::Conditional {
            if_node: loop_if,
            condition_node,
        } = loop_condition.clone()
        else {
            return false;
        };
    
        let loop_variance = compute_loop_variance(editor, l);
        let ivs = compute_induction_vars(editor.func(), l, &loop_variance);
        let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition);
    
        if has_canonical_iv(editor, l, &ivs).is_some() {
    
    Xavier Routh's avatar
    Xavier Routh committed
            return false;
    
        }
    
        let loop_bound_iv_phis = get_loop_condition_ivs(editor, l, &ivs, &loop_condition);
    
        let (loop_bound_ivs, _): (Vec<InductionVariable>, Vec<InductionVariable>) = ivs
            .into_iter()
            .partition(|f| loop_bound_iv_phis.contains(&f.phi()));
    
    
        // Assume there is only one loop bound iv.
        if loop_bound_ivs.len() != 1 {
            return false;
        }
    
        let Some(iv) = loop_bound_ivs.first() else {
            return false;
        };
    
        let InductionVariable::Basic {
            node: iv_phi,
            initializer,
            final_value,
            update_expression,
            update_value,
        } = iv
        else {
            return false;
        };
    
    
        let Some(loop_pred) = editor
            .get_uses(l.header)
            .filter(|node| !l.control[node.idx()])
            .next()
        else {
            return false;
        };
    
        // If there is a guard, we need to edit it.
    
        // (init_id, bound_id, binop node, if node).
    
    
    Xavier Routh's avatar
    Xavier Routh committed
        // FIXME: This is not always correct, depends on lots of things about the loop IV. 
        let loop_bound_dc = match *editor.node(condition_node) {
            Node::Binary { left, right, op } => match op {
                BinaryOperator::LT => right,
                BinaryOperator::LTE => right,
                BinaryOperator::GT => {return false}
                BinaryOperator::GTE => {return false}
                BinaryOperator::EQ => {return false}
                BinaryOperator::NE => {return false}
                _ => {return false}
            },
            _ => {return false}
        };
    
        
    
        // FIXME: This is quite fragile.
    
    Xavier Routh's avatar
    Xavier Routh committed
        let mut guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| {
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            let Node::ControlProjection {
    
                control,
                selection: _,
            } = editor.node(loop_pred)
            else {
                return None;
            };
    
    
    Xavier Routh's avatar
    Xavier Routh committed
            let Node::If { cond, ..} = editor.node(control) else {
    
                return None;
            };
    
            let Node::Binary { left, right, op } = editor.node(cond) else {
                return None;
            };
    
            let Node::Binary {
                left: _,
    
    Xavier Routh's avatar
    Xavier Routh committed
                right: r,
    
                op: loop_op,
            } = editor.node(condition_node)
            else {
                return None;
            };
    
            if op != loop_op {
                return None;
            }
    
            if left != initializer {
                return None;
            }
    
    
    Xavier Routh's avatar
    Xavier Routh committed
            if right != r {
    
                return None;
            }
    
            return Some((*left, *right, *cond, *control));
        })();
    
        // // If guard is none, if some, make sure it is a good guard! move on
        // if let Some((init_id, bound_id, binop_node, if_node))= potential_guard_info {
    
        // };
    
        // let fork_guard_condition =
    
        // Lift dc math should make all constant into DCs, so these should all be DCs.
        let Node::DynamicConstant { id: init_dc_id } = *editor.node(initializer) else {
            return false;
        };
        let Node::DynamicConstant { id: update_dc_id } = *editor.node(update_value) else {
            return false;
        };
    
        // We are assuming this is a simple loop bound (i.e only one induction variable involved), so that .
        let Node::DynamicConstant {
            id: loop_bound_dc_id,
    
    Xavier Routh's avatar
    Xavier Routh committed
        } = *editor.node(loop_bound_dc)
    
        else {
            return false;
        };
    
        // We need to do 4 (5) things, which are mostly separate.
    
        // 0) Make the update into addition.
    
    Xavier Routh's avatar
    Xavier Routh committed
        // 1) Adjust update to be 1 (and bounds).
        // 2) Make the update a positive value. / Transform the condition into a `<`
        // - Are these separate?
    
        // 4) Change init to start from 0.
    
        // 5) Find some way to get fork-guard-elim to work with the new fork.
        // ideally, this goes in fork-guard-elim, but for now we hack it to change the guard condition bounds
        // here when we edit the loop bounds.
    
        // Right now we are just going to do (4), because I am lazy!
    
        // Collect info about the loop condition transformation.
        let mut dc_bound_node = match *editor.node(condition_node) {
            Node::Binary { left, right, op } => match op {
                BinaryOperator::LT => {
                    if left == *update_expression && editor.node(right).is_dynamic_constant() {
                        right
                    } else {
                        return false;
                    }
                }
    
    Xavier Routh's avatar
    Xavier Routh committed
                BinaryOperator::LTE => {
                    if left == *update_expression && editor.node(right).is_dynamic_constant() {
                        right
                    } else {
                        return false;
                    }
                }
    
                BinaryOperator::GT => todo!(),
                BinaryOperator::GTE => todo!(),
                BinaryOperator::EQ => todo!(),
                BinaryOperator::NE => todo!(),
                BinaryOperator::Or => todo!(),
                BinaryOperator::And => todo!(),
                BinaryOperator::Xor => todo!(),
                _ => panic!(),
            },
            _ => return false,
        };
    
    
    Xavier Routh's avatar
    Xavier Routh committed
        let condition_node_data = editor.node(condition_node).clone();
    
    
        let Node::DynamicConstant {
    
    Xavier Routh's avatar
    Xavier Routh committed
            id: mut bound_node_dc_id,
    
        } = *editor.node(dc_bound_node)
        else {
            return false;
        };
    
        // If increment is negative (how in the world do we know that...)
        // Increment can be DefinetlyPostiive, Unknown, DefinetlyNegative.
    
    Xavier Routh's avatar
    Xavier Routh committed
        let misc_guard_thing: Option<Node> =  if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
            Some(editor.node(binop_node).clone())
        } else {
            None
        };
    
        let mut condition_node = condition_node;
    
        let result = editor.edit(|mut edit| {
            // 2) Transform the condition into a < (from <=)
            if let Node::Binary { left, right, op } = condition_node_data {
                if BinaryOperator::LTE == op && left == *update_expression {
                    // Change the condition into <
                    let new_bop = edit.add_node(Node::Binary { left, right, op: BinaryOperator::LT });
                    
                    // Change the bound dc to be bound_dc + 1
                    let one = DynamicConstant::Constant(1);
                    let one = edit.add_dynamic_constant(one);
    
                    let tmp = DynamicConstant::add(bound_node_dc_id, one);
                    let new_condition_dc = edit.add_dynamic_constant(tmp);
    
                    let new_dc_bound_node = edit.add_node(Node::DynamicConstant { id: new_condition_dc });
    
                    // // 5) Change loop guard:
                    guard_info = if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
                        // Change binop node
                        let Some(Node::Binary { left, right, op }) =  misc_guard_thing else {unreachable!()};
                        let blah = edit.add_node(Node::DynamicConstant { id:  new_condition_dc});
    
                        // FIXME: Don't assume that right is the loop bound in the guard. 
                        let new_binop_node = edit.add_node(Node::Binary { left, right: blah, op: BinaryOperator::LT });
    
                        edit = edit.replace_all_uses_where(binop_node, new_binop_node, |usee| *usee == if_node)?;
                        Some((init_id, bound_id, new_binop_node, if_node))
                    } else {guard_info};
    
                    edit = edit.replace_all_uses_where(dc_bound_node, new_dc_bound_node, |usee| *usee == new_bop)?;
                    edit = edit.replace_all_uses(condition_node, new_bop)?;
    
                    // Change loop condition
                    dc_bound_node = new_dc_bound_node;
                    bound_node_dc_id = new_condition_dc;
                    condition_node = new_bop;
                }
            };
            Ok(edit)
        });
    
        let update_expr_users: Vec<_> = editor
            .get_users(*update_expression)
            .filter(|node| *node != iv.phi() && *node != condition_node)
            .collect();
        let iv_phi_users: Vec<_> = editor
            .get_users(iv.phi())
            .filter(|node| *node != iv.phi() && *node != *update_expression)
            .collect();
    
        let result = editor.edit(|mut edit| {
            // 4) Second, change loop IV to go from 0..N.
            // we subtract off init from init and dc_bound_node,
            // and then we add it back to uses of the IV.
            let new_init_dc = DynamicConstant::Constant(0);
            let new_init = Node::DynamicConstant {
                id: edit.add_dynamic_constant(new_init_dc),
            };
            let new_init = edit.add_node(new_init);
            edit = edit.replace_all_uses_where(*initializer, new_init, |usee| *usee == iv.phi())?;
    
    
    Xavier Routh's avatar
    Xavier Routh committed
            let new_condition_dc = DynamicConstant::sub(bound_node_dc_id, init_dc_id);
            let new_condition_dc_id = Node::DynamicConstant {
                id: edit.add_dynamic_constant(new_condition_dc),
    
    Xavier Routh's avatar
    Xavier Routh committed
            let new_condition_dc = edit.add_node(new_condition_dc_id);
    
            edit = edit
    
    Xavier Routh's avatar
    Xavier Routh committed
                .replace_all_uses_where(dc_bound_node, new_condition_dc, |usee| *usee == condition_node)?;
    
    Xavier Routh's avatar
    Xavier Routh committed
            // 5) Change loop guard:
    
            if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
                edit = edit.replace_all_uses_where(init_id, new_init, |usee| *usee == binop_node)?;
                edit =
    
    Xavier Routh's avatar
    Xavier Routh committed
                    edit.replace_all_uses_where(bound_id, new_condition_dc, |usee| *usee == binop_node)?;
    
    Xavier Routh's avatar
    Xavier Routh committed
            // 4) Add the offset back to users of the IV update expression
    
            let new_user = Node::Binary {
                left: *update_expression,
                right: *initializer,
                op: BinaryOperator::Add,
            };
            let new_user = edit.add_node(new_user);
            edit = edit.replace_all_uses_where(*update_expression, new_user, |usee| {
                *usee != iv.phi()
                    && *usee != *update_expression
                    && *usee != new_user
                    && *usee != condition_node
            })?;
    
    
    Xavier Routh's avatar
    Xavier Routh committed
            // Add the offset back to users of the IV directly
    
            let new_user = Node::Binary {
                left: *iv_phi,
                right: *initializer,
                op: BinaryOperator::Add,
            };
            let new_user = edit.add_node(new_user);
            edit = edit.replace_all_uses_where(*iv_phi, new_user, |usee| {
                *usee != iv.phi() && *usee != *update_expression && *usee != new_user
            })?;
    
            Ok(edit)
        });
    
        return result;
    }