use std::collections::{BTreeMap, HashMap, HashSet};
use std::iter::once;

use hercules_ir::*;

use crate::*;

/*
 * The SLF lattice tracks which sub-values of a collection are known. Each sub-
 * value is a node ID recorded at the set of indices it was written to. A write
 * to a set of indices that structurally overlaps a previous sub-value removes
 * the old sub-value, since that write may overwrite it. The lattice top
 * corresponds to the state where every sub-value is zero. When the sub-value
 * at a set of indices is not known, the `subvalues` map stores `None` for that
 * entry. When a write involves array positions, sub-values that are clobbered
 * are removed and an indices set with an empty positions list and a `None`
 * value is inserted.
 */
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SLFLattice {
    subvalues: BTreeMap<Box<[Index]>, Option<NodeID>>,
}

impl Semilattice for SLFLattice {
    fn meet(a: &Self, b: &Self) -> Self {
        // Merge the two maps. Find equal indices sets between `a` and `b` and
        // keep their known sub-value if they're equivalent. All other indices
        // sets in `a` or `b` map to `None`.
        let mut ret = BTreeMap::new();
        for (indices, a_subvalue) in a.subvalues.iter() {
            if let Some(b_subvalue) = b.subvalues.get(indices)
                && a_subvalue == b_subvalue
            {
                // If both maps have the same sub-value for this set of indices,
                // add it unmolested to the meet lattice value.
                ret.insert(indices.clone(), *a_subvalue);
            } else {
                // If `b` doesn't have a write at the same set of indices, or
                // if the writes don't match, then we don't know what's been
                // written there.
                ret.insert(indices.clone(), None);
            }
        }
        for (indices, _) in b.subvalues.iter() {
            // Any indices sets in `b` that aren't in `ret` are indices sets
            // that aren't in `a`, so the sub-value isn't known.
            ret.entry(indices.clone()).or_insert(None);
        }
        SLFLattice { subvalues: ret }
    }

    fn top() -> Self {
        SLFLattice {
            subvalues: BTreeMap::new(),
        }
    }

    fn bottom() -> Self {
        let mut subvalues = BTreeMap::new();
        // The empty indices set overlaps with all possible indices sets.
        subvalues.insert(Box::new([]) as Box<[Index]>, None);
        SLFLattice { subvalues }
    }
}
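
/*
 * Illustrative sketch of how `meet` combines two lattice values. The index
 * notation and node IDs below are hypothetical, shown only to make the
 * semantics concrete:
 *
 *   a          = { [.0] -> Some(v1) }
 *   b          = { [.0] -> Some(v1), [.1] -> Some(v2) }
 *   meet(a, b) = { [.0] -> Some(v1), [.1] -> None }
 *
 * The entry both sides agree on ([.0] -> Some(v1)) survives unchanged; the
 * entry only present in `b` drops to `None`, since along the path described
 * by `a` we don't know what was written there.
 */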

/*
 * Top level function to run store-to-load forwarding on a function. Looks for
 * known values inside collections and replaces reads of those values with the
 * values directly.
 */
pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>) {
    // First, run a dataflow analysis that looks at known values inside
    // collections. Thanks to the value semantics of Hercules IR, this analysis
    // is relatively simple and straightforward.
    let func = editor.func();
    let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
        match func.nodes[id.idx()] {
            Node::Phi {
                control: _,
                data: _,
            }
            | Node::Reduce {
                control: _,
                init: _,
                reduct: _,
            }
            | Node::Ternary {
                op: TernaryOperator::Select,
                first: _,
                second: _,
                third: _,
            } => inputs.into_iter().fold(SLFLattice::top(), |acc, input| {
                SLFLattice::meet(&acc, input)
            }),
            Node::Write {
                collect: _,
                data,
                ref indices,
            } => {
                // Start with the known sub-values of the `collect` input.
                let mut value = inputs[0].clone();

                // The sub-value at any indices set that may overlap with
                // `indices` becomes `None`, since we no longer know what's
                // stored there.
                for (other_indices, subvalue) in value.subvalues.iter_mut() {
                    if indices_may_overlap(other_indices, indices) {
                        *subvalue = None;
                    }
                }

                // Track `data` at `indices`.
                value.subvalues.insert(indices.clone(), Some(data));

                value
            }
            Node::Constant { id: _ } | Node::Undef { ty: _ } => SLFLattice::top(),
            _ => SLFLattice::bottom(),
        }
    });

    // Second, look for reads where the indices set either:
    // 1. Equals the indices of a known sub-value. The read can be replaced by
    //    the known sub-value.
    // 2. Otherwise, if the indices set doesn't overlap with any known or
    //    unknown sub-value, the read can be replaced by a zero constant.
    // 3. Otherwise, the read can't be replaced.
    // Keep track of which nodes we've already replaced, since a known
    // sub-value may be the ID of a node that has itself already been replaced.
    let mut replacements = BTreeMap::new();
    for id in editor.node_ids() {
        let Node::Read {
            collect,
            ref indices,
        } = editor.func().nodes[id.idx()]
        else {
            continue;
        };
        let subvalues = &lattice[collect.idx()].subvalues;

        if let Some(sub_value) = subvalues.get(indices)
            && let Some(mut known) = *sub_value
        {
            while let Some(replacement) = replacements.get(&known) {
                known = *replacement;
            }
            editor.edit(|mut edit| {
                edit = edit.replace_all_uses(id, known)?;
                edit.delete_node(id)
            });
            replacements.insert(id, known);
        } else if !subvalues
            .keys()
            .any(|other_indices| indices_may_overlap(other_indices, indices))
        {
            editor.edit(|mut edit| {
                let zero = edit.add_zero_constant(typing[id.idx()]);
                let zero = edit.add_node(Node::Constant { id: zero });
                edit = edit.replace_all_uses(id, zero)?;
                edit.delete_node(id)
            });
        }
    }
}
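
/*
 * Schematic example of the rewrite performed by `slf` (pseudo-IR, not the
 * actual textual format; `c`, `x`, `w`, and `r` are hypothetical node names):
 *
 *   w = write(c, x, [.0])
 *   r = read(w, [.0])
 *
 * The dataflow above records `x` as the known sub-value of `w` at [.0], so `r`
 * is replaced by `x` (case 1). If instead no indices set in the collection's
 * lattice value may overlap the read - for example, the collection originates
 * from a constant or undef node, which map to lattice top - the read falls
 * into case 2 and is replaced by a zero constant.
 */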

/*
 * Top level function to run fork-join level store-to-load forwarding on a
 * function. Looks for reduce nodes holding arrays that have nice einsum
 * expressions, and replaces reads of such an array with the sub-expression of
 * the einsum array comprehension.
 */
pub fn array_slf(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_einsum: &(MathEnv, HashMap<NodeID, MathID>),
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    let (env, einsums) = reduce_einsum;
    for (reduce, einsum) in einsums {
        // Check that the expression is an array comprehension.
        let MathExpr::Comprehension(elem, _) = env[einsum.idx()] else {
            continue;
        };

        // If any of the opaque nodes are "in" the fork-join of the reduce, then
        // they depend on the thread IDs of the fork-join in a way that's not
        // modeled in the einsum expression, and therefore those thread IDs
        // can't be substituted with read position indices. We need to skip
        // applying SLF to these arrays.
        let nodes = &editor.func().nodes;
        let join = nodes[reduce.idx()].try_reduce().unwrap().0;
        let fork = join_fork_map[&join];
        let nodes_in_this_fork_join = &nodes_in_fork_joins[&fork];
        let opaque_nodes = opaque_nodes_in_expr(env, *einsum);
        if opaque_nodes
            .into_iter()
            .any(|id| nodes_in_this_fork_join.contains(&id))
        {
            continue;
        }

        // Look for read users of the reduce. Each read can be replaced by
        // substituting its position indices into the einsum expression, so the
        // needed value is computed rather than read.
        let reads: Vec<(NodeID, Box<[NodeID]>)> = editor
            .get_users(*reduce)
            .filter_map(|id| {
                nodes[id.idx()].try_read().map(|(_, indices)| {
                    // The indices list should just be a single position index, since
                    // einsum expressions are only derived for arrays of primitives.
                    assert_eq!(indices.len(), 1);
                    let Index::Position(indices) = &indices[0] else {
                        panic!()
                    };
                    (id, indices.clone())
                })
            })
            .collect();
        for (read, indices) in reads {
            editor.edit(|mut edit| {
                // Create the expression equivalent to the read.
                let id = materialize_simple_einsum_expr(&mut edit, elem, env, &indices);

                // Replace and delete the read.
                edit = edit.replace_all_uses(read, id)?;
                edit.delete_node(read)
            });
        }
    }
}
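
/*
 * Illustrative example for `array_slf` (hypothetical arrays `a` and `b`): if a
 * fork-join reduce builds an array whose einsum comprehension element is
 * `a[i] + b[i]`, then a later read of that array at position `j` is rewritten
 * to compute `a[j] + b[j]` directly via `materialize_simple_einsum_expr`,
 * rather than reading the materialized array.
 */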

/*
 * Top level function to run reduce store-to-load forwarding on a function.
 * There are two variants of reduce SLF. One optimizes parallel reductions and
 * the other optimizes scalar reductions. This pass just runs one variant after
 * the other - it's up to the user to potentially wrap this pass in a fixpoint.
 *
 * The parallel variant looks for reductions on collections with the following
 * form:
 *
 * 1. A write `reduct` use.
 * 2. A single read user in the reduction cycle.
 * 3. The write use and read user have identical indices.
 * 4. The indices set directly refers to at least every thread ID produced by
 *    the fork corresponding to the reduce.
 *
 * Condition #4 roughly corresponds to the condition needed to infer the
 * ParallelReduce schedule - however, in this scenario, it can't be inferred
 * because the data being written is in the reduce cycle (it is derived from
 * the read user of the reduce). However, since the indices change each
 * iteration, the read could equivalently read from the `init` input of the
 * reduce rather than from the reduce itself. This optimization replaces the
 * use of the reduce in the read with the `init` input of the reduce, primarily
 * so that ParallelReduce can then be inferred.
 *
 * If a reduction has the ParallelReduce schedule on it, then any direct read
 * of it can be much more easily optimized to refer to the `init` input rather
 * than the reduce itself.
 *
 * The scalar variant looks for reductions on collections with the following
 * form:
 *
 * 1. A write `reduct` use.
 * 2. A single read user in the reduction cycle.
 * 3. The write use and read user have identical indices.
 * 4. The indices set doesn't reference any nodes in the fork-join of the reduce
 *    (doesn't depend on the thread IDs of the fork-join).
 *
 * Instead of repeatedly reading / writing the same collection item each
 * iteration, the reduction can reduce over the scalar value directly and do a
 * single write into the collection after the fork-join (technically, the
 * "scalar" may itself be a collection in the case of nested collections, but
 * this optimization most often handles scalars).
 */
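/*
 * Schematic shapes of the two variants (pseudo-IR, hypothetical node names):
 *
 * Parallel - inside a fork-join with thread ID `tid`:
 *   red = reduce(join, init, w)
 *   r   = read(red, [pos(tid)])
 *   w   = write(red, f(r), [pos(tid)])
 * Every iteration touches a distinct position, so `r` is rewritten to read
 * from `init` instead of from `red`.
 *
 * Scalar - the indices don't depend on the fork-join:
 *   red = reduce(join, init, w)
 *   r   = read(red, [.0])
 *   w   = write(red, f(r), [.0])
 * The reduction is rewritten to reduce over the scalar itself, seeded by
 * `read(init, [.0])`, with a single write of the final value into the
 * collection at [.0] after the fork-join.
 */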
pub fn reduce_slf(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    // Helper to get write use and read user of reduce with identical indices.
    // This checks conditions 1, 2, and 3 of both parallel and scalar reduction
    // optimization.
    let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> {
        let nodes = &editor.func().nodes;
        let reduct = nodes[reduce.idx()].try_reduce().unwrap().2;
        if !nodes[reduct.idx()].is_write() || nodes[reduct.idx()].try_write().unwrap().0 != reduce {
            return None;
        }
        let mut users = editor
            .get_users(reduce)
            .filter(|id| reduce_cycles[&reduce].contains(id) && *id != reduct);
        let read = users.next()?;
        if users.next().is_some()
            || !nodes[read.idx()].is_read()
            || nodes[read.idx()].try_read().unwrap().1 != nodes[reduct.idx()].try_write().unwrap().2
        {
            return None;
        }
        Some((read, reduct))
    };
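
    // For example, `read_write_helper` matches the following shape
    // (hypothetical node names):
    //   red = reduce(join, init, w)
    //   r   = read(red, idx)   <- the only user of `red` in the reduce cycle
    //                              other than `w`
    //   w   = write(red, data, idx)
    // and returns Some((r, w)); any other shape returns None.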

    // First, optimize parallel reductions.
    for (fork, join) in fork_join_map {
        let reduces: Vec<_> = editor
            .get_users(*join)
            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
            .collect();
        for reduce in reduces {
            if let Some((read, _)) = read_write_helper(reduce, editor) {
                // Check condition 4 of parallel reduction optimization.
                let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
                if indices_parallel_over_forks(editor, indices, once(*fork)) {
                    let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
                    let new_read = Node::Read {
                        collect: init,
                        indices: indices.to_vec().into_boxed_slice(),
                    };
                    editor.edit(|mut edit| {
                        let new_read = edit.add_node(new_read);
                        edit = edit.replace_all_uses(read, new_read)?;
                        edit.delete_node(read)
                    });
                }
            } else if editor.func().schedules[reduce.idx()].contains(&Schedule::ParallelReduce) {
                let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
                let read_users: Vec<_> = editor
                    .get_users(reduce)
                    .filter(|id| {
                        editor.func().nodes[id.idx()].is_read()
                            && nodes_in_fork_joins[fork].contains(&id)
                    })
                    .collect();
                editor.edit(|edit| {
                    edit.replace_all_uses_where(reduce, init, |id| read_users.contains(id))
                });
            }
        }
    }

    // Second, optimize scalar reductions.
    for (fork, join) in fork_join_map {
        let reduces: Vec<_> = editor
            .get_users(*join)
            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
            .collect();
        for reduce in reduces {
            let Some((read, write)) = read_write_helper(reduce, editor) else {
                continue;
            };

            // Check condition 4 of scalar reduction optimization.
            let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
            if node_indices(indices).all(|id| !nodes_in_fork_joins[fork].contains(&id)) {
                let indices = indices.to_vec().into_boxed_slice();
                let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
                let data = editor.func().nodes[write.idx()].try_write().unwrap().1;
                let init_read = Node::Read {
                    collect: init,
                    indices: indices.clone(),
                };
                editor.edit(|mut edit| {
                    let init_read = edit.add_node(init_read);
                    let new_reduce = Node::Reduce {
                        control: *join,
                        init: init_read,
                        reduct: data,
                    };
                    let new_reduce = edit.add_node(new_reduce);
                    let post_write = Node::Write {
                        collect: init,
                        data: new_reduce,
                        indices,
                    };
                    let post_write = edit.add_node(post_write);

                    edit = edit.replace_all_uses(read, new_reduce)?;
                    edit = edit.replace_all_uses(reduce, post_write)?;

                    edit = edit.delete_node(read)?;
                    edit = edit.delete_node(reduce)?;
                    edit = edit.delete_node(write)?;

                    Ok(edit)
                });
            }
        }
    }
}