Skip to content
Snippets Groups Projects
forkify.rs 18.72 KiB
use std::collections::HashMap;
use std::collections::HashSet;
use std::iter::zip;
use std::iter::FromIterator;

use itertools::Itertools;
use nestify::nest;

use hercules_ir::*;

use crate::*;

/*
 * TODO: Forkify currently makes a bunch of small edits - this needs to be
 * changed so that every loop that gets forkified corresponds to a single edit
 * + sub-edits. This would allow us to run forkify on a subset of a function.
 */
/**
 Tries to convert one natural loop of the function into a fork-join.

 Candidate loops are taken in the order produced by `LoopTree::bottom_up_loops`, keeping only
 those whose header is a region node. A candidate is attempted only if its header is mutable
 in `editor`. The search stops at the first loop that `forkify_loop` converts.

 Returns `true` if some loop was converted, `false` otherwise.
*/
pub fn forkify(
    editor: &mut FunctionEditor,
    control_subgraph: &Subgraph,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loops: &LoopTree,
) -> bool {
    // Materialize the candidate list up front so the shared borrow of `editor` used by the
    // filter has ended before any loop is edited.
    let candidates: Vec<_> = loops
        .bottom_up_loops()
        .into_iter()
        .filter(|(header, _)| editor.func().nodes[header.idx()].is_region())
        .collect();

    // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses.
    candidates.into_iter().any(|candidate| {
        let header = candidate.0;
        editor.is_mutable(header)
            && forkify_loop(
                editor,
                control_subgraph,
                fork_join_map,
                &Loop {
                    header,
                    control: candidate.1.clone(),
                },
            )
    })
}

/**
 Given a node used as a loop bound, return a dynamic constant ID for it.

 - A `DynamicConstant` node yields its existing ID.
 - An integer `Constant` node is converted into a `DynamicConstant::Constant`, which is added
   to the function through `editor`.

 Returns `Err` when:
 - the node is neither of those;
 - the constant is not an integer;
 - a signed constant is negative;
 - the edit that adds the dynamic constant is rejected.
*/
pub fn get_node_as_dc(
    editor: &mut FunctionEditor,
    node: NodeID,
) -> Result<DynamicConstantID, String> {
    // Check for a constant used as loop bound.
    match editor.node(node) {
        // Already a dynamic constant: reuse it as is.
        Node::DynamicConstant {
            id: dynamic_constant_id,
        } => Ok(*dynamic_constant_id),
        Node::Constant { id: constant_id } => {
            let dc = match *editor.get_constant(*constant_id) {
                // A negative value is not a meaningful loop bound. If the dynamic-constant
                // payload is unsigned (as `x as _` suggests), the cast would also silently wrap
                // it into a huge bound, so negative values are rejected below.
                Constant::Integer8(x) if x >= 0 => DynamicConstant::Constant(x as _),
                Constant::Integer16(x) if x >= 0 => DynamicConstant::Constant(x as _),
                Constant::Integer32(x) if x >= 0 => DynamicConstant::Constant(x as _),
                Constant::Integer64(x) if x >= 0 => DynamicConstant::Constant(x as _),
                Constant::Integer8(_)
                | Constant::Integer16(_)
                | Constant::Integer32(_)
                | Constant::Integer64(_) => {
                    return Err("Negative constant as loop bound".to_string())
                }
                Constant::UnsignedInteger8(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger16(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger32(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger64(x) => DynamicConstant::Constant(x as _),
                _ => return Err("Invalid constant as loop bound".to_string()),
            };

            // Only trust the new ID if the edit that created it was actually applied.
            let mut dc_id = None;
            let success = editor.edit(|mut edit| {
                dc_id = Some(edit.add_dynamic_constant(dc));
                Ok(edit)
            });
            match dc_id {
                // The ID of the dynamic constant generated from the constant loop bound.
                Some(id) if success => Ok(id),
                _ => Err("Failed to add dynamic constant for loop bound".to_string()),
            }
        }
        _ => Err("Loop bound is not a constant or dynamic constant".to_string()),
    }
}

/**
 Top level function to convert natural loops with simple induction variables
 into fork-joins.
*/
pub fn forkify_loop(
    editor: &mut FunctionEditor,
    control_subgraph: &Subgraph,
    _fork_join_map: &HashMap<NodeID, NodeID>,
    l: &Loop,
) -> bool {
    let function = editor.func();

    let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else {
        return false;
    };

    let LoopExit::Conditional {
        if_node: loop_if,
        condition_node,
    } = loop_condition.clone()
    else {
        return false;
    };

    // Compute loop variance
    let loop_variance = compute_loop_variance(editor, l);
    let ivs = compute_induction_vars(editor.func(), l, &loop_variance);
    let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition);
    let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else {
        return false;
    };

    // Get bound
    let bound = match canonical_iv {
        InductionVariable::Basic {
            node: _,
            initializer: _,
            final_value,
            update_expression,
            update_value,
        } => final_value
            .map(|final_value| get_node_as_dc(editor, final_value))
            .and_then(|r| r.ok()),
        InductionVariable::SCEV(_) => return false,
    };

    let Some(bound_dc_id) = bound else {
        return false;
    };

    let function = editor.func();

    // Check if it is do-while loop.
    let loop_exit_projection = editor
        .get_users(loop_if)
        .filter(|id| !l.control[id.idx()])
        .next()
        .unwrap();

    let loop_continue_projection = editor
        .get_users(loop_if)
        .filter(|id| l.control[id.idx()])
        .next()
        .unwrap();

    let loop_preds: Vec<_> = editor
        .get_uses(l.header)
        .filter(|id| !l.control[id.idx()])
        .collect();

    // FIXME: @xrouth
    if loop_preds.len() != 1 {
        return false;
    }

    let loop_pred = loop_preds[0];

    if !editor
        .get_uses(l.header)
        .contains(&loop_continue_projection)
    {
        return false;
    }

    // Get all phis used outside of the loop, they need to be reductionable.
    // For now just assume all phis will be phis used outside of the loop, except for the canonical iv.
    // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one
    // we currently have.
    let loop_nodes = calculate_loop_nodes(editor, l);

    // Check phis to see if they are reductionable, only PHIs depending on the loop are considered,
    let candidate_phis: Vec<_> = editor
        .get_users(l.header)
        .filter(|id| function.nodes[id.idx()].is_phi())
        .filter(|id| *id != canonical_iv.phi())
        .collect();

    let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes)
        .into_iter()
        .collect();
    // TODO: Handle multiple loop body lasts.
    // If there are multiple candidates for loop body last, return false.
    if editor
        .get_uses(loop_if)
        .filter(|id| l.control[id.idx()])
        .count()
        > 1
    {
        return false;
    }

    let loop_body_last = editor.get_uses(loop_if).next().unwrap();

    if reductionable_phis
        .iter()
        .any(|phi| !matches!(phi, LoopPHI::Reductionable { .. }))
    {
        return false;
    }

    let phi_latches: Vec<_> = reductionable_phis
        .iter()
        .map(|phi| {
            let LoopPHI::Reductionable {
                phi: _,
                data_cycle: _,
                continue_latch,
                is_associative: _,
            } = phi
            else {
                unreachable!()
            };
            continue_latch
        })
        .collect();

    let stop_on: HashSet<_> = editor
        .node_ids()
        .filter(|node| {
            if editor.node(node).is_phi() {
                return true;
            }
            if editor.node(node).is_reduce() {
                return true;
            }
            if editor.node(node).is_control() {
                return true;
            }
            if phi_latches.contains(&node) {
                return true;
            }

            false
        })
        .collect();

    // Outside loop users of IV, then exit;
    // Unless the outside user is through the loop latch of a reducing phi,
    // then we know how to replace this edge, so its fine!
    let iv_users: Vec<_> =
        walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect();

    if iv_users
        .iter()
        .any(|node| !loop_nodes.contains(&node) && *node != loop_if)
    {
        return false;
    }

    // Start Transformation:

    // Graft everything between header and loop condition
    // Attach join to right before header (after loop_body_last, unless loop body last *is* the header).
    // Attach fork to right after loop_continue_projection.

    // // Create fork and join nodes:
    let mut join_id = NodeID::new(0);
    let mut fork_id = NodeID::new(0);

    // Turn dc bound into max (1, bound),
    let bound_dc_id = {
        let mut max_id = DynamicConstantID::new(0);
        editor.edit(|mut edit| {
            let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1));
            max_id = edit.add_dynamic_constant(DynamicConstant::max(one_id, bound_dc_id));
            Ok(edit)
        });
        max_id
    };

    // FIXME: (@xrouth) double check handling of control in loop body.
    editor.edit(|mut edit| {
        let fork = Node::Fork {
            control: loop_pred,
            factors: Box::new([bound_dc_id]),
        };
        fork_id = edit.add_node(fork);

        let join = Node::Join {
            control: if l.header == loop_body_last {
                fork_id
            } else {
                loop_body_last
            },
        };

        join_id = edit.add_node(join);

        Ok(edit)
    });

    let function = editor.func();
    let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap();
    let dimension = factors.len() - 1;

    let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis
        .iter()
        .map(|reduction_phi| {
            let LoopPHI::Reductionable {
                phi,
                data_cycle: _,
                continue_latch: _,
                is_associative: _,
            } = reduction_phi
            else {
                panic!();
            };

            let function = editor.func();

            let init = *zip(
                editor.get_uses(l.header),
                function.nodes[phi.idx()].try_phi().unwrap().1.iter(),
            )
            .filter(|(c, _)| *c == loop_pred)
            .next()
            .unwrap()
            .1;

            (reduction_phi, init)
        })
        .collect();

    // Start failable edit:
    let result = editor.edit(|mut edit| {
        let thread_id = Node::ThreadID {
            control: fork_id,
            dimension: dimension,
        };
        let thread_id_id = edit.add_node(thread_id);

        // Replace uses that are inside with the thread id
        edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
            loop_nodes.contains(node)
        })?;
        edit.sub_edit(canonical_iv.phi(), thread_id_id);

        edit = edit.delete_node(canonical_iv.phi())?;

        for (reduction_phi, init) in redcutionable_phis_and_init {
            let LoopPHI::Reductionable {
                phi,
                data_cycle: _,
                continue_latch,
                is_associative: _,
            } = *reduction_phi
            else {
                panic!();
            };

            let reduce = Node::Reduce {
                control: join_id,
                init,
                reduct: continue_latch,
            };

            let reduce_id = edit.add_node(reduce);

            if (!edit.get_node(init).is_reduce()
                && edit.get_schedule(init).contains(&Schedule::ParallelReduce))
                || (!edit.get_node(continue_latch).is_reduce()
                    && edit
                        .get_schedule(continue_latch)
                        .contains(&Schedule::ParallelReduce))
            {
                edit = edit.add_schedule(reduce_id, Schedule::ParallelReduce)?;
            }
            if (!edit.get_node(init).is_reduce()
                && edit.get_schedule(init).contains(&Schedule::MonoidReduce))
                || (!edit.get_node(continue_latch).is_reduce()
                    && edit
                        .get_schedule(continue_latch)
                        .contains(&Schedule::MonoidReduce))
            {
                edit = edit.add_schedule(reduce_id, Schedule::MonoidReduce)?;
            }

            edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?;
            edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| {
                !loop_nodes.contains(usee) && *usee != reduce_id
            })?;
            edit.sub_edit(phi, reduce_id);
            edit = edit.delete_node(phi)?
        }

        edit = edit.replace_all_uses(l.header, fork_id)?;
        edit = edit.replace_all_uses(loop_continue_projection, fork_id)?;
        edit = edit.replace_all_uses(loop_exit_projection, join_id)?;
        edit.sub_edit(l.header, fork_id);
        edit.sub_edit(loop_continue_projection, fork_id);
        edit.sub_edit(loop_exit_projection, join_id);

        edit = edit.delete_node(loop_continue_projection)?;
        edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this.
        edit = edit.delete_node(loop_exit_projection)?;
        edit = edit.delete_node(loop_if)?;
        edit = edit.delete_node(l.header)?;
        Ok(edit)
    });

    return result;
}

nest! {
    // Classification of a loop-header phi, as produced by `analyze_phis`.
    #[derive(Debug)]
    pub enum LoopPHI {
        // The phi can be turned into a reduce when its loop is forkified.
        Reductionable {
            phi: NodeID,
            data_cycle: HashSet<NodeID>, // All nodes in a data cycle with this phi
            // The phi's input along the loop back edge; forkify uses it as the reduce's `reduct`.
            continue_latch: NodeID,
            // Whether the cycle is associative (`analyze_phis` currently always sets `false`).
            is_associative: bool,
        },
        // The phi cannot be made into a reduction (e.g. it depends on another header phi, or
        // it has no data cycle in the loop).
        LoopDependant(NodeID),
        ControlDependant(NodeID), // This phi is reductionable, but its cycle might depend on internal control within the loop.
        // NOTE(review): not constructed in this file; presumably a phi used by a
        // loop-dependent phi — confirm.
        UsedByDependant(NodeID),
    }
}

impl LoopPHI {
    /// Returns the phi node this classification describes, whatever the variant.
    pub fn get_phi(&self) -> NodeID {
        match *self {
            LoopPHI::Reductionable { phi, .. }
            | LoopPHI::LoopDependant(phi)
            | LoopPHI::UsedByDependant(phi)
            | LoopPHI::ControlDependant(phi) => phi,
        }
    }
}

/**
Checks some conditions on loop variables that will need to be converted into reductions to be forkified.

Each phi in `phis` is classified lazily, yielding one `LoopPHI` per phi in order. A phi is
`LoopPHI::Reductionable` only if all of the following hold:
 - The phi is in a data cycle *in the loop* with itself. This cycle is the intersection of
   what its continue latch (transitively) uses and what the phi (transitively) is used by.
 - No phi of `phis` other than this one is reached by the dependence walk from the
   continue latch. That walk is not blocked by anything except another phi, a reduce, or
   control.
 - No node of that data cycle, other than the continue latch, is used outside the loop.
Otherwise the phi is reported as `LoopPHI::LoopDependant`.

Each phi in `phis` must be a phi whose inputs line up with the predecessors of
`natural_loop.header`. The code unwraps `try_phi` and indexes the phi's inputs by the
header's in-loop predecessor.
 */
pub fn analyze_phis<'a>(
    editor: &'a FunctionEditor,
    natural_loop: &'a Loop,
    phis: &'a [NodeID],
    loop_nodes: &'a HashSet<NodeID>,
) -> impl Iterator<Item = LoopPHI> + 'a {
    // Stop set for the dependence walk: every phi, reduce and control node. It does not depend
    // on the phi being analyzed, so it is built once and cloned for each phi.
    let dependence_stop_on: HashSet<NodeID> = editor
        .node_ids()
        .filter(|node| {
            let data = &editor.func().nodes[node.idx()];
            data.is_phi() || data.is_reduce() || data.is_control()
        })
        .collect();

    // Find data cycles within the loop of this phi,
    // Start from the phis loop_continue_latch, and walk its uses until we find the original phi.
    phis.iter().map(move |phi| {
        let cycle_stop_on: HashSet<NodeID> = editor
            .node_ids()
            .filter(|node| {
                let data = &editor.func().nodes[node.idx()];

                // External Phi
                if let Node::Phi { control, data: _ } = data {
                    if !natural_loop.control[control.idx()] {
                        return true;
                    }
                }

                // This phi
                if node == phi {
                    return true;
                }

                // External Reduce (reduces inside the loop are walked through)
                if let Node::Reduce {
                    control,
                    init: _,
                    reduct: _,
                } = data
                {
                    return !natural_loop.control[control.idx()];
                }

                // Data Cycles Only
                data.is_control()
            })
            .collect();

        // The continue latch is the phi input arriving from the header's in-loop predecessor.
        let continue_idx = editor
            .get_uses(natural_loop.header)
            .position(|node| natural_loop.control[node.idx()])
            .unwrap();

        let loop_continue_latch = editor.node(phi).try_phi().unwrap().1[continue_idx];

        // Nodes reachable backwards from the latch and forwards from the phi; their
        // intersection is the data cycle through this phi.
        let from_latch: HashSet<NodeID> =
            walk_all_uses_stop_on(loop_continue_latch, editor, cycle_stop_on.clone()).collect();
        let from_phi: HashSet<NodeID> =
            walk_all_users_stop_on(*phi, editor, cycle_stop_on).collect();

        let intersection: HashSet<NodeID> =
            from_latch.intersection(&from_phi).cloned().collect();

        // If this phi uses any other phis the node is loop dependant,
        // we use `phis` because this phi can actually contain the loop iv and its fine.
        // NOTE(review): this walks the *users* of the continue latch, while the comment above
        // (and the doc bullet) speak of what the phi *uses*; confirm the intended direction.
        if walk_all_users_stop_on(loop_continue_latch, editor, dependence_stop_on.clone())
            .any(|node| phis.contains(&node) && node != *phi)
        {
            return LoopPHI::LoopDependant(*phi);
        }

        if intersection.is_empty() {
            // No cycles exist, this isn't a reduction.
            return LoopPHI::LoopDependant(*phi);
        }

        // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
        // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
        // by the time the reduce is triggered (at the end of the loop's internal control).

        // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
        // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
        let cycle_used_outside_loop = intersection
            .iter()
            .filter(|node| **node != loop_continue_latch)
            .any(|data_node| {
                editor
                    .get_users(*data_node)
                    .any(|user| !loop_nodes.contains(&user))
            });
        if cycle_used_outside_loop {
            // This phi can be made into a reduce in different ways, if the cycle is associative (contains all the same kind of associative op)
            // 3) Split the cycle into two phis, add them or multiply them together at the end.
            // 4) Split the cycle into two reduces, add them or multiply them together at the end.
            // Somewhere else should handle this.
            return LoopPHI::LoopDependant(*phi);
        }

        // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify
        // i.e as described above.
        let is_associative = false;

        // No nodes in the data cycle are used outside of the loop, besides the latched value of the phi
        LoopPHI::Reductionable {
            phi: *phi,
            data_cycle: intersection,
            continue_latch: loop_continue_latch,
            is_associative,
        }
    })
}