    use std::collections::{HashMap, HashSet};
    
    
    use bimap::BiMap;
    use itertools::Itertools;
    
    use hercules_ir::*;
    
    use crate::*;
    
    type ForkID = usize;
    
    /** Places each reduce node into its own fork */
    pub fn default_reduce_partition(
        editor: &FunctionEditor,
        _fork: NodeID,
        join: NodeID,
    ) -> SparseNodeMap<ForkID> {
        let mut map = SparseNodeMap::new();
    
        editor
            .get_users(join)
            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
            .enumerate()
            .for_each(|(fork, reduce)| {
                map.insert(reduce, fork);
            });
    
        map
    }
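
    // A hypothetical illustration (names are not from this file): for a join with two reduce
    // users `r0` and `r1`, the resulting partition is { r0 -> 0, r1 -> 1 }, so the fission
    // helpers below later build one fork-join per reduce.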
    
    // TODO: Refine these conditions.
    /** Returns the nodes that `reduce` transitively depends on which themselves depend on
    `fork`, i.e. the nodes on any use-def path from the fork to the reduce (inclusive of both). */
    pub fn find_reduce_dependencies<'a>(
        function: &'a Function,
        reduce: NodeID,
        fork: NodeID,
    ) -> impl IntoIterator<Item = NodeID> + 'a {
        let len = function.nodes.len();
    
        let mut visited: DenseNodeMap<bool> = vec![false; len];
        let mut dependent: DenseNodeMap<bool> = vec![false; len];
    
        // Does `fork` need to be a parameter here? It never changes. If this was a closure could it just capture it?
        fn recurse(
            function: &Function,
            node: NodeID,
            fork: NodeID,
            dependent_map: &mut DenseNodeMap<bool>,
            visited: &mut DenseNodeMap<bool>,
        ) -> () {
            // Results are returned through `dependent_map`.
    
            if visited[node.idx()] {
                return;
            }
    
            visited[node.idx()] = true;
    
            if node == fork {
                dependent_map[node.idx()] = true;
                return;
            }
    
            let binding = get_uses(&function.nodes[node.idx()]);
            let uses = binding.as_ref();
    
            for used in uses {
                recurse(function, *used, fork, dependent_map, visited);
            }
    
            dependent_map[node.idx()] = uses.iter().map(|id| dependent_map[id.idx()]).any(|a| a);
            return;
        }
    
        // Note: HACKY. The condition we want is: all nodes on any path from the fork to the
        // reduce (in the forward graph), equivalently from the reduce to the fork (in the
        // reversed graph). Cycles break this, but we assume for now that the only cycles are
        // ones that involve the reduce node.
        // NOTE: control may break this (i.e. a loop inside the fork is a cycle that doesn't
        // involve the reduce).
        // The current workaround is to mark the reduce as dependent before traversing the graph.
        dependent[reduce.idx()] = true;

        recurse(function, reduce, fork, &mut dependent, &mut visited);
    
        // Return node IDs that are dependent
        let ret_val: Vec<_> = dependent
            .iter()
            .enumerate()
            .filter_map(|(idx, dependent)| {
                if *dependent {
                    Some(NodeID::new(idx))
                } else {
                    None
                }
            })
            .collect();
    
        ret_val
    }
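
    // A hypothetical example (not from this file): if the reduce's `reduct` is `tid * 2`,
    // where `tid` is a thread ID of `fork`, the returned set is { fork, tid, tid * 2, reduce };
    // the constant `2` is excluded because it does not lie on a fork-to-reduce path.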
    
    
    pub fn copy_subgraph_in_edit<'a, 'b>(
        mut edit: FunctionEdit<'a, 'b>,
        subgraph: HashSet<NodeID>,
    ) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
        let mut map: HashMap<NodeID, NodeID> = HashMap::new();
    
        // Copy nodes in subgraph
        for old_id in subgraph.iter().cloned() {
            let new_id = edit.copy_node(old_id);
            map.insert(old_id, new_id);
        }
    
        // Update edges to new nodes
        for old_id in subgraph.iter() {
            edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| {
                map.values().contains(node_id)
            })?;
        }
    
        Ok((edit, map))
    }
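
    // A hypothetical usage sketch (assumed names, not from this file): this helper is meant to
    // be called from inside an `editor.edit` closure, threading the edit through on success and
    // bubbling the failed edit back out via `?`:
    //
    //     editor.edit(|edit| {
    //         let (mut edit, old_to_new) = copy_subgraph_in_edit(edit, subgraph.clone())?;
    //         // ... rewire uses of the original nodes via `old_to_new` ...
    //         Ok(edit)
    //     });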
    
    
    pub fn copy_subgraph(
        editor: &mut FunctionEditor,
        subgraph: HashSet<NodeID>,
    ) -> (
        HashSet<NodeID>,
        HashMap<NodeID, NodeID>,
        Vec<(NodeID, NodeID)>,
    ) // Returns all new nodes, a map from old nodes to new nodes, and
      // a vec of pairs of nodes (old node, outside node) s.t. old node -> outside node,
      // where "outside" means not part of the original subgraph.
    {
        let mut map: HashMap<NodeID, NodeID> = HashMap::new();
        let mut new_nodes: HashSet<NodeID> = HashSet::new();
    
        // Copy nodes
        for old_id in subgraph.iter() {
            editor.edit(|mut edit| {
                let new_id = edit.copy_node(*old_id);
                map.insert(*old_id, new_id);
                new_nodes.insert(new_id);
                Ok(edit)
            });
        }
    
        // Update edges to new nodes
        for old_id in subgraph.iter() {
            // Replace all uses of old_id w/ new_id, where the use is in new_node
            editor.edit(|edit| {
                edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id))
            });
        }
    
        // Get all users that aren't in new_nodes.
        let mut outside_users = Vec::new();
    
        for node in new_nodes.iter() {
            for user in editor.get_users(*node) {
                if !new_nodes.contains(&user) {
                    outside_users.push((*node, user));
                }
            }
        }
    
        (new_nodes, map, outside_users)
    }
    
    
    pub fn find_bufferize_edges(
        editor: &mut FunctionEditor,
        fork: NodeID,
        loop_tree: &LoopTree,
        fork_join_map: &HashMap<NodeID, NodeID>,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
        data_label: &LabelID,
    ) -> HashSet<(NodeID, NodeID)> {
        let mut edges: HashSet<_> = HashSet::new();
    
        for node in &nodes_in_fork_joins[&fork] {
            // Add an edge from a node that *has* the data label to a user that *doesn't*.
            let node_labels = &editor.func().labels[node.idx()];
    
            if !node_labels.contains(data_label) {
                continue;
            }
    
            // Don't draw bufferize edges from fork tids
            if editor.get_users(fork).contains(node) {
                continue;
            }
    
            for user in editor.get_users(*node) {
                let user_labels = &editor.func().labels[user.idx()];
                if user_labels.contains(data_label) {
                    continue;
                }
    
                if editor.node(user).is_control() || editor.node(node).is_control() {
                    continue;
                }
    
                edges.insert((*node, user));
            }
        }
        edges
    }
    
    
    pub fn ff_bufferize_create_not_reduce_cycle_label_helper(
        editor: &mut FunctionEditor,
        fork: NodeID,
        fork_join_map: &HashMap<NodeID, NodeID>,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    ) -> LabelID {
        let join = fork_join_map[&fork];
        let mut nodes_not_in_a_reduce_cycle = nodes_in_fork_joins[&fork].clone();
        for (cycle, reduce) in editor
            .get_users(join)
            .filter_map(|id| reduce_cycles.get(&id).map(|cycle| (cycle, id)))
        {
            nodes_not_in_a_reduce_cycle.remove(&reduce);
            for id in cycle {
                nodes_not_in_a_reduce_cycle.remove(id);
            }
        }
        nodes_not_in_a_reduce_cycle.remove(&join);
    
        let mut label = LabelID::new(0);
        let success = editor.edit(|mut edit| {
            label = edit.fresh_label();
            for id in nodes_not_in_a_reduce_cycle {
                edit = edit.add_label(id, label)?;
            }
            Ok(edit)
        });
    
        assert!(success);
        label
    }
    
    
    pub fn ff_bufferize_any_fork<'a, 'b>(
        editor: &'b mut FunctionEditor<'a>,
        loop_tree: &'b LoopTree,
        fork_join_map: &'b HashMap<NodeID, NodeID>,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
        nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
        typing: &'b Vec<TypeID>,
        fork_label: LabelID,
        data_label: Option<LabelID>,
    ) -> Option<(NodeID, NodeID)>
    where
        'a: 'b,
    {
    
    rarbore2's avatar
    rarbore2 committed
        let mut forks: Vec<_> = loop_tree
    
    Xavier Routh's avatar
    Xavier Routh committed
            .bottom_up_loops()
            .into_iter()
            .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
            .collect();
    
    rarbore2's avatar
    rarbore2 committed
        forks.reverse();
    
    Xavier Routh's avatar
    Xavier Routh committed
    
        for l in forks {
            let fork_info = Loop {
                header: l.0,
                control: l.1.clone(),
            };
            let fork = fork_info.header;
            let join = fork_join_map[&fork];
    
    
            if !editor.func().labels[fork.idx()].contains(&fork_label) {
                continue;
            }
    
    
            let data_label = data_label.unwrap_or_else(|| {
                ff_bufferize_create_not_reduce_cycle_label_helper(
                    editor,
                    fork,
                    fork_join_map,
                    reduce_cycles,
                    nodes_in_fork_joins,
                )
            });
    
            let edges = find_bufferize_edges(
                editor,
                fork,
                &loop_tree,
                &fork_join_map,
                &nodes_in_fork_joins,
                &data_label,
            );
            let result = fork_bufferize_fission_helper(
                editor,
                &fork_info,
                &edges,
                nodes_in_fork_joins,
                typing,
                fork,
                join,
            );
            if result.is_none() {
                continue;
            } else {
                return result;
            }
        }
        return None;
    }
    
    
    pub fn fork_fission<'a>(
        editor: &'a mut FunctionEditor,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
        loop_tree: &LoopTree,
        fork_join_map: &HashMap<NodeID, NodeID>,
        fork_label: LabelID,
    ) -> Vec<NodeID> {
        let forks: Vec<_> = loop_tree
            .bottom_up_loops()
            .into_iter()
            .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
            .collect();

        let mut created_forks = Vec::new();
    
    
        // This does the reduction fission
    
        for fork in forks {
            let join = fork_join_map[&fork.0];
    
            // FIXME: Don't make multiple forks for reduces that are in cycles with each other.
    
            let reduce_partition = default_reduce_partition(editor, fork.0, join);
    
            if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
                continue;
            }

            if editor.is_mutable(fork.0) {
    
                created_forks = fork_reduce_fission_helper(
                    editor,
                    fork_join_map,
                    reduce_partition,
                    nodes_in_fork_joins,
                    fork.0,
                );
    
                if created_forks.is_empty() {
                    continue;
                } else {
                    return created_forks;
                }
            }
        }

        created_forks
    
    }
    
    /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
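    // A rough sketch of the resulting shape (illustrative only): for each bufferized edge
    // src -> dst, the original fork-join writes `src` into a fresh array at its thread
    // position, a `ParallelReduce` reduce carries that array across the join, and a copy of
    // the fork-join (sequenced after the original) reads the array at its own thread position
    // and feeds the read to `dst`'s copy in place of `src`.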
    
    pub fn fork_bufferize_fission_helper<'a, 'b>(
        editor: &'b mut FunctionEditor<'a>,
        l: &Loop,
        bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
        data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
        types: &'b Vec<TypeID>,
    
        fork: NodeID,
        join: NodeID,
    ) -> Option<(NodeID, NodeID)>
    where
        'a: 'b,
    {
        if bufferized_edges.is_empty() {
            return None;
        }
    
        let all_loop_nodes = l.get_all_nodes();
    
        // FIXME: Cloning hell.
        let data_nodes = data_node_in_fork_joins[&fork].clone();
        let loop_nodes = editor
            .node_ids()
            .filter(|node_id| all_loop_nodes[node_id.idx()]);
        // Clone the subgraph that consists of this fork-join and all data and control nodes in it.
        let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));
    
        let mut outside_users = Vec::new(); // old_node, outside_user
    
        for node in subgraph.iter() {
            for user in editor.get_users(*node) {
                if !subgraph.iter().contains(&user) {
                    outside_users.push((*node, user));
                }
            }
        }
    
        let factors: Vec<_> = editor.func().nodes[fork.idx()]
            .try_fork()
            .unwrap()
            .1
            .iter()
            .cloned()
            .collect();
    
        let thread_stuff_it = factors.into_iter().enumerate();
    
        // Control predecessor of the fork and control successor of the join
        let fork_pred = editor
            .get_uses(fork)
            .filter(|a| editor.node(a).is_control())
            .next()
            .unwrap();
        let join_successor = editor
            .get_users(join)
            .filter(|a| editor.node(a).is_control())
            .next()
            .unwrap();
    
    
        let mut new_fork_id = NodeID::new(0);
    
    
        let edit_result = editor.edit(|edit| {
            let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
    
            edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
            edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
    
            // Replace outside uses of reduces in old subgraph with reduces in new subgraph.
            for (old_node, outside_user) in outside_users {
                edit = edit
                    .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
            }
    
            let new_fork = map[&fork];
    
            // FIXME: Do this as part of copy subgraph?
            // Add tids to original subgraph for indexing.
            let mut old_tids = Vec::new();
            let mut new_tids = Vec::new();
            for (dim, _) in thread_stuff_it.clone() {
                let old_id = edit.add_node(Node::ThreadID {
                    control: fork,
                    dimension: dim,
                });
                let new_id = edit.add_node(Node::ThreadID {
                    control: new_fork,
                    dimension: dim,
                });
    
                old_tids.push(old_id);
                new_tids.push(new_id);
            }
    
            for (src, dst) in bufferized_edges {
    
                let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
    
                let position_idx = Index::Position(old_tids.clone().into_boxed_slice());
    
    
                let write = edit.add_node(Node::Write {
                    collect: NodeID::new(0),
    
                    data: *src,
                    indices: vec![position_idx.clone()].into(),
    
                });
                let ele_type = types[src.idx()];
                let empty_buffer = edit.add_type(hercules_ir::Type::Array(
                    ele_type,
                    array_dims.collect::<Vec<_>>().into_boxed_slice(),
                ));
                let empty_buffer = edit.add_zero_constant(empty_buffer);
                let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
    
                edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;
    
                let reduce = Node::Reduce {
    
                    control: join,
    
                    init: empty_buffer,
                    reduct: write,
                };
                let reduce = edit.add_node(reduce);
    
                edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;
    
    
                // Fix write node
                edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
    
    
                // Create reads from buffer
                let position_idx = Index::Position(new_tids.clone().into_boxed_slice());
    
    
                let read = edit.add_node(Node::Read {
                    collect: reduce,
                    indices: vec![position_idx].into(),
                });
    
    
                // Replace uses of the bufferized edge's src with the corresponding read in the copied subgraph
                edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
            }
    
            new_fork_id = new_fork;
    
            Ok(edit)
        });
    
        if edit_result {
            Some((fork, new_fork_id))
        } else {
            None
        }
    
    }

    /** Split a fork into a separate fork for each reduction. */
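    // A rough sketch (illustrative only): for each reduce in the partition, the whole fork-join
    // subgraph is copied and the copies are chained sequentially (each copy's fork hangs off the
    // previous copy's join); uses of each original reduce are redirected to its copy, and the
    // original fork and join are then deleted.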
    
    pub fn fork_reduce_fission_helper<'a>(
        editor: &'a mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
    
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    
        fork: NodeID,
    ) -> Vec<NodeID> {
    
        let join = fork_join_map[&fork];
    
    
        let mut new_forks = Vec::new();
    
        let mut new_control_pred: NodeID = editor
            .get_uses(fork)
            .filter(|n| editor.node(n).is_control())
            .next()
            .unwrap();
    
    
        let mut new_fork = NodeID::new(0);
        let mut new_join = NodeID::new(0);
    
    
        let subgraph = &nodes_in_fork_joins[&fork];
    
    
        // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
    
        editor.edit(|mut edit| {
            for reduce in reduce_partition {
                let reduce = reduce.0;
    
                let a = copy_subgraph_in_edit(edit, subgraph.clone())?;
                edit = a.0;
                let mapping = a.1;
    
                new_fork = mapping[&fork];
                new_forks.push(new_fork);
                new_join = mapping[&join];
    
    
    
                // Attach new_fork after new_control_pred
                let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
                edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
                    *usee == new_fork
                })?;
    
                // Replace uses of reduce
                edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
    
                new_control_pred = new_join;
    
            }
    
    
            // Replace original join w/ new final join
            edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
    
            // Delete original join (all reduce users have been moved)
            edit = edit.delete_node(join)?;
    
            // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
            edit = edit.replace_all_uses(fork, new_fork)?;
    
            edit = edit.delete_node(fork)?;
    
            Ok(edit)
        });

        new_forks
    
    }
    
    pub fn fork_coalesce(
        editor: &mut FunctionEditor,
        loops: &LoopTree,
        fork_join_map: &HashMap<NodeID, NodeID>,
    ) -> bool {
        let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| {
            if editor.func().nodes[k.idx()].is_fork() {
                Some(k)
            } else {
                None
            }
        });
    
        let fork_joins: Vec<_> = fork_joins.collect();
        // FIXME: Add a postorder traversal to optimize this.
    
        // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
        // something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
        for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
    
            if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
    
                return true;
            }
        }
        return false;
    }
    
    /** Opposite of fork split, takes two fork-joins
        with no control between them, and merges them into a single fork-join.
    
        Returns None if the forks could not be merged and the NodeIDs of the
        resulting fork and join if it succeeds in merging them.
    
    */
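
    // A rough sketch of the shape (illustrative only, not emitted IR):
    //
    //     fork(a) -> fork(b) -> ... -> join_inner -> join_outer
    //
    // becomes
    //
    //     fork(a, b) -> ... -> join_inner
    //
    // with the inner thread IDs shifted up by the number of outer dimensions and each
    // outer/inner reduce pair fused into a single reduce on the surviving join.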
    pub fn fork_coalesce_helper(
        editor: &mut FunctionEditor,
        outer_fork: NodeID,
        inner_fork: NodeID,
        fork_join_map: &HashMap<NodeID, NodeID>,
    
    ) -> Option<(NodeID, NodeID)> {
    
        // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
    
        let outer_join = fork_join_map[&outer_fork];
        let inner_join = fork_join_map[&inner_fork];
    
        let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner
    
        // FIXME: Iterate all control uses of joins to really collect all reduces
        // (reduces can be attached to inner control)
        for outer_reduce in editor
            .get_users(outer_join)
            .filter(|node| editor.func().nodes[node.idx()].is_reduce())
        {
            // check that inner reduce is of the inner join
            let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()]
                .try_reduce()
                .unwrap();
    
            let inner_reduce = outer_reduct;
            let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()];
    
            let Node::Reduce {
                control: inner_control,
                init: inner_init,
                reduct: _,
            } = inner_reduce_node
            else {
    
                return None;
    
            };
    
            // FIXME: check this condition better (i.e reduce might not be attached to join)
            if *inner_control != inner_join {
    
                return None;
    
            };
            if *inner_init != outer_reduce {
    
                return None;
    
            };
    
            if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
    
                return None;
    
            } else {
                pairs.insert(outer_reduce, inner_reduce);
            }
        }
    
        // Check for control between join-join and fork-fork
    
        let (control, _) = editor.node(inner_fork).try_fork().unwrap();
    
        if control != outer_fork {
            return None;
        }

        let control = editor.node(outer_join).try_join().unwrap();
    
        if control != inner_join {
            return None;
    
        }
    
        // Checklist:
        // Increment inner TIDs
        // Add outer fork's dimension to front of inner fork.
        // Fuse reductions
        //  - Initializer becomes outer initializer
        // Replace uses of outer fork w/ inner fork.
        // Replace uses of outer join w/ inner join.
        // Delete outer fork-join
    
        let inner_tids: Vec<NodeID> = editor
            .get_users(inner_fork)
            .filter(|node| editor.func().nodes[node.idx()].is_thread_id())
            .collect();
    
        let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
        let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
        let num_outer_dims = outer_dims.len();
        let mut new_factors = outer_dims.to_vec();
    
        // CHECKME / FIXME: Might need to be added the other way.
        new_factors.append(&mut inner_dims.to_vec());
    
    
        let mut new_fork = NodeID::new(0);
        let new_join = inner_join; // We'll reuse the inner join as the join of the new fork
    
        let success = editor.edit(|mut edit| {
            for tid in inner_tids {
                let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
                let new_tid = Node::ThreadID {
                    control: fork,
                    dimension: dim + num_outer_dims,
                };
    
    
                let new_tid = edit.add_node(new_tid);
    
                edit = edit.replace_all_uses(tid, new_tid)?;
    
                edit.sub_edit(tid, new_tid);
    
            }
            // Fuse Reductions
            for (outer_reduce, inner_reduce) in pairs {
                let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
                let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();
    
                // Set inner init to outer init.
                edit =
                    edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
                edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
                edit = edit.delete_node(outer_reduce)?;
    
            }
    
            let new_fork_node = Node::Fork {
    
                control: outer_pred,
                factors: new_factors.into(),
            };
    
            new_fork = edit.add_node(new_fork_node);
    
            if edit
                .get_schedule(outer_fork)
                .contains(&Schedule::ParallelFork)
                && edit
                    .get_schedule(inner_fork)
                    .contains(&Schedule::ParallelFork)
            {
                edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
            }
    
    
            edit = edit.replace_all_uses(inner_fork, new_fork)?;
            edit = edit.replace_all_uses(outer_fork, new_fork)?;
            edit = edit.replace_all_uses(outer_join, inner_join)?;
            edit = edit.delete_node(outer_join)?;
            edit = edit.delete_node(inner_fork)?;
            edit = edit.delete_node(outer_fork)?;
    
            Ok(edit)
        });
    
    
        if success {
            Some((new_fork, new_join))
        } else {
            None
        }
    
    }

    pub fn split_any_fork(
    
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    
    ) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    
        for (fork, join) in fork_join_map {
    
            if let Some((forks, joins)) = split_fork(editor, *fork, *join, reduce_cycles) {
    
                return Some((forks, joins));
            }
        }
        None
    }
    
    /*
     * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
     * Useful for code generation. A single iteration of `fork_split` only splits
     * at most one fork-join, it must be called repeatedly to split all fork-joins.
     */
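
    // A rough sketch (illustrative only): a two-dimensional fork(d0, d1) with one reduce becomes
    //
    //     fork(d0) -> fork(d1) -> ... -> join -> join
    //
    // with one thread ID per new fork and a chain of two reduces: the innermost reduce carries
    // the original `reduct`, the outermost carries the original `init`, and they feed into each
    // other through the join nest.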
    
    pub fn split_fork(
    
        editor: &mut FunctionEditor,
        fork: NodeID,
        join: NodeID,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    ) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
        // A single multi-dimensional fork becomes multiple forks, a join becomes
        // multiple joins, a thread ID becomes a thread ID on the correct
        // fork, and a reduce becomes multiple reduces to shuffle the reduction
        // value through the fork-join nest.
        let nodes = &editor.func().nodes;
        let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
        if factors.len() < 2 {
            return Some((vec![fork], vec![join]));
        }
        let factors: Box<[DynamicConstantID]> = factors.into();
        let join_control = nodes[join.idx()].try_join().unwrap();
        let tids: Vec<_> = editor
            .get_users(fork)
            .filter(|id| nodes[id.idx()].is_thread_id())
            .collect();
        let reduces: Vec<_> = editor
            .get_users(join)
            .filter(|id| nodes[id.idx()].is_reduce())
            .collect();
    
        let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
            .iter()
            .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
            .flatten()
            .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
            .collect();
    
        let mut new_forks = vec![];
        let mut new_joins = vec![];
        let success = editor.edit(|mut edit| {
            // Create the forks and a thread ID per fork.
            let mut acc_fork = fork_control;
            let mut new_tids = vec![];
            for factor in factors {
                acc_fork = edit.add_node(Node::Fork {
                    control: acc_fork,
                    factors: Box::new([factor]),
                });
                new_forks.push(acc_fork);
                edit.sub_edit(fork, acc_fork);
                new_tids.push(edit.add_node(Node::ThreadID {
                    control: acc_fork,
                    dimension: 0,
                }));
            }
    
            // Create the joins.
            let mut acc_join = if join_control == fork {
                acc_fork
            } else {
                join_control
            };
            for _ in new_tids.iter() {
                acc_join = edit.add_node(Node::Join { control: acc_join });
                edit.sub_edit(join, acc_join);
                new_joins.push(acc_join);
            }
    
            // Create the reduces.
            let mut new_reduces = vec![];
            for reduce in reduces.iter() {
                let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
                let num_nodes = edit.num_node_ids();
                let mut inner_reduce = NodeID::new(0);
                let mut outer_reduce = NodeID::new(0);
                for (join_idx, join) in new_joins.iter().enumerate() {
                    let init = if join_idx == new_joins.len() - 1 {
                        init
                    } else {
                        NodeID::new(num_nodes + join_idx + 1)
                    };
                    let reduct = if join_idx == 0 {
                        reduct
                    } else {
                        NodeID::new(num_nodes + join_idx - 1)
                    };
                    let new_reduce = edit.add_node(Node::Reduce {
                        control: *join,
                        init,
                        reduct,
                    });
                    assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
                    edit.sub_edit(*reduce, new_reduce);
                    if join_idx == 0 {
                        inner_reduce = new_reduce;
                    }
                    if join_idx == new_joins.len() - 1 {
                        outer_reduce = new_reduce;
                    }
                }
                new_reduces.push((inner_reduce, outer_reduce));
            }
    
            // Replace everything.
            edit = edit.replace_all_uses(fork, acc_fork)?;
            edit = edit.replace_all_uses(join, acc_join)?;
            for tid in tids.iter() {
                let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
                edit.sub_edit(*tid, new_tids[dim]);
                edit = edit.replace_all_uses(*tid, new_tids[dim])?;
            }
            for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
                edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
                    data_in_reduce_cycle.contains(&(*id, *reduce))
                })?;
                edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
                    !data_in_reduce_cycle.contains(&(*id, *reduce))
                })?;
            }
    
            // Delete all the old stuff.
            edit = edit.delete_node(fork)?;
            edit = edit.delete_node(join)?;
            for tid in tids {
                edit = edit.delete_node(tid)?;
            }
            for reduce in reduces {
                edit = edit.delete_node(reduce)?;
            }
    
            Ok(edit)
        });
        if success {
    
            new_joins.reverse();
    
            Some((new_forks, new_joins))
        } else {
            None
        }
    }
    
    
    pub fn chunk_all_forks_unguarded(
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        dim_idx: usize,
        tile_size: usize,
    
        order: bool,
    
    ) -> () {
        // Add dc
        let mut dc_id = DynamicConstantID::new(0);
        editor.edit(|mut edit| {
            dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
            Ok(edit)
        });
    
    
        let order = match order {
            true => &TileOrder::TileInner,
            false => &TileOrder::TileOuter,
        };
    
    
        for (fork, _) in fork_join_map {
    
            if editor.is_mutable(*fork) {
                chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order);
            }
    
        }
    }
    // Splits a dimension of a single fork-join into multiple:
    // iterates an outer loop original_dim / tile_size times and
    // adds a tile_size loop as the inner loop.
    // Assumes that tile_size divides the original dimension evenly.
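    //
    // A hypothetical illustration (values are not from this file): with original_dim = 8 and
    // tile_size = 2, TileInner turns fork(8) into fork(4, 2), and the original thread ID is
    // reconstructed as tid_outer * 2 + tid_inner (the `Mul`/`Add` rebuild below).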
    
    
    enum TileOrder {
        TileInner,
        TileOuter,
    }
    
    
    pub fn chunk_fork_unguarded(
        editor: &mut FunctionEditor,
        fork: NodeID,
        dim_idx: usize,
        tile_size: DynamicConstantID,
    
        order: &TileOrder,
    
    ) -> () {
        // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
        let Node::Fork {
            control: old_control,
            factors: ref old_factors,
        } = *editor.node(fork)
        else {
            return;
        };
        assert!(dim_idx < old_factors.len());
        let mut new_factors: Vec<_> = old_factors.to_vec();
        let fork_users: Vec<_> = editor
            .get_users(fork)
            .map(|f| (f, editor.node(f).clone()))
            .collect();
    
    
        match order {
            TileOrder::TileInner => {
                editor.edit(|mut edit| {
                    let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
                    new_factors.insert(dim_idx + 1, tile_size);
                    new_factors[dim_idx] = edit.add_dynamic_constant(outer);
    
    
    
                    let new_fork = Node::Fork {
                        control: old_control,
                        factors: new_factors.into(),
    
                    };
    
                    let new_fork = edit.add_node(new_fork);
    
    
    
                    edit = edit.replace_all_uses(fork, new_fork)?;
                    edit.sub_edit(fork, new_fork);
    
    
    
                    for (tid, node) in fork_users {
                        let Node::ThreadID {
                            control: _,
                            dimension: tid_dim,
                        } = node
                        else {
                            continue;
                        };
                        if tid_dim > dim_idx {
                            let new_tid = Node::ThreadID {
                                control: new_fork,
                                dimension: tid_dim + 1,
                            };
                            let new_tid = edit.add_node(new_tid);
                            edit = edit.replace_all_uses(tid, new_tid)?;
                            edit.sub_edit(tid, new_tid);
                            edit = edit.delete_node(tid)?;
                        } else if tid_dim == dim_idx {
                            let tile_tid = Node::ThreadID {
                                control: new_fork,
                                dimension: tid_dim + 1,
                            };
                            let tile_tid = edit.add_node(tile_tid);
    
    
    
                            let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                            let mul = edit.add_node(Node::Binary {
                                left: tid,
                                right: tile_size,
                                op: BinaryOperator::Mul,
                            });
                            let add = edit.add_node(Node::Binary {
                                left: mul,
                                right: tile_tid,
                                op: BinaryOperator::Add,