// fork_transforms.rs
use std::collections::{HashMap, HashSet};

use bimap::BiMap;
use itertools::Itertools;

use hercules_ir::*;

use crate::*;

type ForkID = usize;

/** Places each reduce node into its own fork */
pub fn default_reduce_partition(
    editor: &FunctionEditor,
    _fork: NodeID,
    join: NodeID,
) -> SparseNodeMap<ForkID> {
    let mut map = SparseNodeMap::new();

    editor
        .get_users(join)
        .filter(|id| editor.func().nodes[id.idx()].is_reduce())
        .enumerate()
        .for_each(|(fork, reduce)| {
            map.insert(reduce, fork);
        });

    map
}
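
// A minimal usage sketch (hypothetical driver, not part of this file): compute the
// one-reduce-per-fork partition and feed it to `fork_reduce_fission_helper` below,
// using the start node (ID 0) as the control predecessor, as `fork_fission` does.
#[allow(dead_code)]
fn example_split_every_reduce(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork: NodeID,
    join: NodeID,
) {
    // Reduces r0, r1, ... attached to `join` land in fork partitions 0, 1, ...
    let partition = default_reduce_partition(editor, fork, join);
    fork_reduce_fission_helper(editor, fork_join_map, partition, NodeID::new(0), fork);
}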

// TODO: Refine these conditions.
/** Returns all nodes on any use-path between `fork` and `reduce`, i.e. everything
    between the fork and the reduce that the reduce transitively depends on. */
pub fn find_reduce_dependencies<'a>(
    function: &'a Function,
    reduce: NodeID,
    fork: NodeID,
) -> impl IntoIterator<Item = NodeID> + 'a {
    let len = function.nodes.len();

    let mut visited: DenseNodeMap<bool> = vec![false; len];
    let mut dependent: DenseNodeMap<bool> = vec![false; len];

    // Does `fork` need to be a parameter here? It never changes; if this were a closure, it could just capture it.
    fn recurse(
        function: &Function,
        node: NodeID,
        fork: NodeID,
        dependent_map: &mut DenseNodeMap<bool>,
        visited: &mut DenseNodeMap<bool>,
    ) {
        // Results are returned through `dependent_map`.

        if visited[node.idx()] {
            return;
        }

        visited[node.idx()] = true;

        if node == fork {
            dependent_map[node.idx()] = true;
            return;
        }

        let binding = get_uses(&function.nodes[node.idx()]);
        let uses = binding.as_ref();

        for used in uses {
            recurse(function, *used, fork, dependent_map, visited);
        }
        dependent_map[node.idx()] = uses.iter().map(|id| dependent_map[id.idx()]).any(|a| a);
    }

    // NOTE: HACKY. The condition we want is: all nodes on any path from the fork
    // to the reduce (in the forward graph), equivalently from the reduce to the
    // fork (in the reverse graph). Cycles break this, but we assume for now that
    // the only cycles are ones that involve the reduce node.
    // NOTE: control may break this (e.g. a loop inside the fork is a cycle that
    // doesn't involve the reduce).
    // The current solution is to mark the reduce as dependent before traversing.
    dependent[reduce.idx()] = true;

    recurse(function, reduce, fork, &mut dependent, &mut visited);

    // Return the node IDs that are marked dependent.
    dependent
        .iter()
        .enumerate()
        .filter_map(|(idx, dep)| if *dep { Some(NodeID::new(idx)) } else { None })
        .collect::<Vec<_>>()
}
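
// Shape of the result (sketch): for a dataflow chain
//     fork -> tid -> mul -> reduce
// with some unrelated node `x` elsewhere in the function,
// `find_reduce_dependencies(function, reduce, fork)` yields {fork, tid, mul, reduce}
// and leaves `x` out, since no use-path connects `x` to both the reduce and the fork.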

pub fn copy_subgraph(
    editor: &mut FunctionEditor,
    subgraph: HashSet<NodeID>,
) -> (
    HashSet<NodeID>,
    HashMap<NodeID, NodeID>,
    Vec<(NodeID, NodeID)>,
) // Returns all new nodes, a map from old nodes to new nodes, and a vec of
  // (new node, outside node) pairs s.t. new node -> outside node, where
  // "outside" means not part of the copied subgraph.
{
    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
    let mut new_nodes: HashSet<NodeID> = HashSet::new();

    // Copy nodes
    for old_id in subgraph.iter() {
        editor.edit(|mut edit| {
            let new_id = edit.copy_node(*old_id);
            map.insert(*old_id, new_id);
            new_nodes.insert(new_id);
            Ok(edit)
        });
    }

    // Update edges to new nodes
    for old_id in subgraph.iter() {
        // Replace all uses of old_id w/ new_id, where the user is in new_nodes.
        editor.edit(|edit| {
            edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id))
        });
    }

    // Get all users that aren't in new_nodes.
    let mut outside_users = Vec::new();

    for node in new_nodes.iter() {
        for user in editor.get_users(*node) {
            if !new_nodes.contains(&user) {
                outside_users.push((*node, user));
            }
        }
    }

    (new_nodes, map, outside_users)
}
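
// Usage sketch (hypothetical helper): duplicate the slice of the graph that one
// reduce depends on, as `fork_reduce_fission_helper` below does per partition
// entry. Returns the copy of `reduce`; the ignored `_boundary` pairs are the
// (new node, outside user) edges that escape the copy.
#[allow(dead_code)]
fn example_copy_reduce_slice(
    editor: &mut FunctionEditor,
    reduce: NodeID,
    fork: NodeID,
) -> NodeID {
    let subgraph: HashSet<NodeID> = find_reduce_dependencies(editor.func(), reduce, fork)
        .into_iter()
        .collect();
    let (_new_nodes, mapping, _boundary) = copy_subgraph(editor, subgraph);
    mapping[&reduce]
}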

pub fn fork_fission<'a>(
    editor: &'a mut FunctionEditor,
    _control_subgraph: &Subgraph,
    _types: &Vec<TypeID>,
    _loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) {
    let forks: Vec<_> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| {
            if node.is_fork() {
                Some(NodeID::new(idx))
            } else {
                None
            }
        })
        .collect();

    // The start node (ID 0) serves as the control predecessor for the first new fork.
    let control_pred = NodeID::new(0);

    // This does the reduction fission:
    for fork in forks {
        // FIXME: If there is control in between fork and join, don't just give up.
        let join = fork_join_map[&fork];
        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
        if join_pred != fork {
            todo!("Can't do fork fission on nodes with internal control")
            // Inner control LOOPs are hard
            // inner control in general *should* work right now without modifications.
        }
        let reduce_partition = default_reduce_partition(editor, fork, join);
        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
    }
}

/** Split a 1D fork into two forks, placing select intermediate data into buffers. */
pub fn fork_bufferize_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
    _original_control_pred: NodeID,              // What the new fork connects to.
    types: &Vec<TypeID>,
    fork: NodeID,
) -> (NodeID, NodeID) {
    // Returns the two forks that it generates.

    // TODO: Check that each bufferized edge's src doesn't depend on anything that
    // comes after the fork.

    // Copy fork + control intermediates + join to a new fork + join, placed after
    // the join of the first. How control gets partitioned (depending on how it
    // affects the data nodes on each side of the bufferized edges) is unresolved;
    // it may end up duplicated in each loop. Fix me later.

    // Only handle fork-joins with no inner control for now.
    // Create fork + join + Thread control
    let join = fork_join_map[&fork];
    let mut new_fork_id = NodeID::new(0); // Placeholder, overwritten by the edit below.
    let mut new_join_id = NodeID::new(0); // Placeholder, overwritten by the edit below.

    editor.edit(|mut edit| {
        new_join_id = edit.add_node(Node::Join { control: fork });
        let factors = edit.get_node(fork).try_fork().unwrap().1;
        new_fork_id = edit.add_node(Node::Fork {
            control: new_join_id,
            factors: factors.into(),
        });
        edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
    });

    for (src, dst) in bufferized_edges {
        // FIXME: Disgusting cloning and allocation and iterators.
        let factors: Vec<_> = editor.func().nodes[fork.idx()]
            .try_fork()
            .unwrap()
            .1
            .iter()
            .cloned()
            .collect();
        editor.edit(|mut edit| {
            // Create write to buffer

            let thread_stuff_it = factors.into_iter().enumerate();

            // FIXME: try to use unzip here? Idk why it wasn't working.
            let tids = thread_stuff_it.clone().map(|(dim, _)| {
                edit.add_node(Node::ThreadID {
                    control: fork,
                    dimension: dim,
                })
            });

            let array_dims = thread_stuff_it.clone().map(|(_, factor)| factor);

            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
            // `collect` is a placeholder; it is patched to point at the reduce once
            // that node exists below.
            let write = edit.add_node(Node::Write {
                collect: NodeID::new(0),
                data: src,
                indices: vec![position_idx].into(),
            });
            let ele_type = types[src.idx()];
            let empty_buffer = edit.add_type(hercules_ir::Type::Array(
                ele_type,
                array_dims.collect::<Vec<_>>().into_boxed_slice(),
            ));
            let empty_buffer = edit.add_zero_constant(empty_buffer);
            let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
            let reduce = Node::Reduce {
                control: new_join_id,
                init: empty_buffer,
                reduct: write,
            };
            let reduce = edit.add_node(reduce);
            // Fix write node
            edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;

            // Create read from buffer
            let tids = thread_stuff_it.clone().map(|(dim, _)| {
                edit.add_node(Node::ThreadID {
                    control: new_fork_id,
                    dimension: dim,
                })
            });

            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());

            let read = edit.add_node(Node::Read {
                collect: reduce,
                indices: vec![position_idx].into(),
            });

            edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;

            Ok(edit)
        });
    }

    (fork, new_fork_id)
}
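
// Usage sketch (hypothetical): cut a single intermediate edge (src, dst) inside a
// fork-join with a buffer. The result is two sequential fork-joins where `src` is
// written into a fresh array in the first and read back in the second; the start
// node (ID 0) stands in for the unused control predecessor.
#[allow(dead_code)]
fn example_bufferize_one_edge(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    types: &Vec<TypeID>,
    fork: NodeID,
    src: NodeID,
    dst: NodeID,
) -> (NodeID, NodeID) {
    let mut edges = HashSet::new();
    edges.insert((src, dst));
    fork_bufferize_fission_helper(editor, fork_join_map, edges, NodeID::new(0), types, fork)
}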

/** Split a 1D fork into a separate fork for each reduction. */
pub fn fork_reduce_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split.
    original_control_pred: NodeID,           // What the new fork connects to.
    fork: NodeID,
) -> (NodeID, NodeID) {
    let join = fork_join_map[&fork];

    let mut new_control_pred: NodeID = original_control_pred;
    // The important edges here are the reduces.

    // NOTE: Say two reduces are in a fork s.t. reduce A depends on reduce B.
    // If the user wants A and B in separate forks, we can either:
    // - simply refuse, or
    // - duplicate B.

    let mut new_fork = NodeID::new(0); // Placeholders, overwritten in the loop below.
    let mut new_join = NodeID::new(0);

    // For each reduce, gather everything between fork & join that it needs (ALL CONTROL).
    for (reduce, _) in reduce_partition {

        let function = editor.func();
        let subgraph = find_reduce_dependencies(function, reduce, fork);

        let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();

        subgraph.insert(join);
        subgraph.insert(fork);
        subgraph.insert(reduce);

        let (_, mapping, _) = copy_subgraph(editor, subgraph);

        new_fork = mapping[&fork];
        new_join = mapping[&join];

        editor.edit(|mut edit| {
            // Attach new_fork after control_pred.
            let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
            edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
                *usee == new_fork
            })?;

            // Replace uses of reduce
            edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
            Ok(edit)
        });

        new_control_pred = new_join;
    }

    editor.edit(|mut edit| {
        // Replace original join w/ new final join
        edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
        // Delete original join (all reduce users have been moved)
        edit = edit.delete_node(join)?;

        // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit.delete_node(fork)
    });

    (new_fork, new_join)
}
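
// Effect sketch: with the default partition, a fork-join carrying reduces r0 and r1
//
//     pred -> fork { r0, r1 } -> join -> succ
//
// becomes two sequential fork-joins, each copy carrying one reduce, with the second
// fork chained after the first copy's join:
//
//     pred -> fork0 { r0' } -> join0 -> fork1 { r1' } -> join1 -> succ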

pub fn fork_coalesce(
    editor: &mut FunctionEditor,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| {
        if editor.func().nodes[k.idx()].is_fork() {
            Some(k)
        } else {
            None
        }
    });

    let fork_joins: Vec<_> = fork_joins.collect();
    // FIXME: Add a postorder traversal to optimize this.

    // FIXME: This could give us two forks that aren't actually ancestors / related,
    // but then the helper will just return false early.
    // Something like `fork_joins.postorder_iter().windows(2)` would be ideal here.
    for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
            return true;
        }
    }
    false
}

/** Opposite of fork split, takes two fork-joins
    with no control between them, and merges them into a single fork-join.
*/
pub fn fork_coalesce_helper(
    editor: &mut FunctionEditor,
    outer_fork: NodeID,
    inner_fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.

    let outer_join = fork_join_map[&outer_fork];
    let inner_join = fork_join_map[&inner_fork];

    let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner

    // FIXME: Iterate all control uses of joins to really collect all reduces
    // (reduces can be attached to inner control)
    for outer_reduce in editor
        .get_users(outer_join)
        .filter(|node| editor.func().nodes[node.idx()].is_reduce())
    {
        // Check that the inner reduce belongs to the inner join.
        let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();

        let inner_reduce = outer_reduct;
        let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()];

        let Node::Reduce {
            control: inner_control,
            init: inner_init,
            reduct: _,
        } = inner_reduce_node
        else {
            return false;
        };

        // FIXME: check this condition better (i.e reduce might not be attached to join)
        if *inner_control != inner_join {
            return false;
        };
        if *inner_init != outer_reduce {
            return false;
        };

        if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
            return false;
        } else {
            pairs.insert(outer_reduce, inner_reduce);
        }
    }

    // Check for control between the fork-fork and join-join pairs: the outer
    // fork's control user must be the inner fork, and the inner join's control
    // user must be the outer join.
    let Some(user) = editor
        .get_users(outer_fork)
        .find(|node| editor.func().nodes[node.idx()].is_control())
    else {
        return false;
    };

    if user != inner_fork {
        return false;
    }

    let Some(user) = editor
        .get_users(inner_join)
        .find(|node| editor.func().nodes[node.idx()].is_control())
    else {
        return false;
    };

    if user != outer_join {
        return false;
    }

    // Checklist:
    // Increment inner TIDs
    // Add outer fork's dimension to front of inner fork.
    // Fuse reductions
    //  - Initializer becomes outer initializer
    // Replace uses of outer fork w/ inner fork.
    // Replace uses of outer join w/ inner join.
    // Delete outer fork-join

    let inner_tids: Vec<NodeID> = editor
        .get_users(inner_fork)
        .filter(|node| editor.func().nodes[node.idx()].is_thread_id())
        .collect();

    let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
    let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
    let num_outer_dims = outer_dims.len();
    let mut new_factors = outer_dims.to_vec();

    // CHECKME / FIXME: Might need to be added the other way.
    new_factors.append(&mut inner_dims.to_vec());

    for tid in inner_tids {
        let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
        let new_tid = Node::ThreadID {
            control: fork,
            dimension: dim + num_outer_dims,
        };

        editor.edit(|mut edit| {
            let new_tid = edit.add_node(new_tid);
            let edit = edit.replace_all_uses(tid, new_tid)?;
            Ok(edit)
        });
    }

    // Fuse Reductions
    for (outer_reduce, inner_reduce) in pairs {
        let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();
        let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
            .try_reduce()
            .unwrap();
        editor.edit(|mut edit| {
            // Set inner init to outer init.
            edit =
                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
            edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
            edit = edit.delete_node(outer_reduce)?;

            Ok(edit)
        });
    }

    editor.edit(|mut edit| {
        let new_fork = Node::Fork {
            control: outer_pred,
            factors: new_factors.into(),
        };
        let new_fork = edit.add_node(new_fork);

        edit = edit.replace_all_uses(inner_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_join, inner_join)?;
        edit = edit.delete_node(outer_join)?;
        edit = edit.delete_node(inner_fork)?;
        edit = edit.delete_node(outer_fork)?;

        Ok(edit)
    });

    true
}
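
// Effect sketch: two perfectly nested 1D fork-joins
//
//     fork(a) -> fork(b) -> ... -> join(b) -> join(a)
//
// coalesce into a single 2D fork-join
//
//     fork(a, b) -> ... -> join
//
// where the inner thread IDs shift up by the number of outer dimensions (dim 0
// becomes dim 1, etc.) and each (outer, inner) reduce pair fuses into one reduce
// whose initializer is the outer reduce's initializer.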