use std::collections::{BTreeMap, HashMap, HashSet};
use std::iter::once;

use hercules_ir::*;

use crate::*;

/*
 * The SLF lattice tracks which sub-values of a collection are known. Each sub-
 * value is a node ID stored at the set of indices it was written at. A write
 * to a set of indices that structurally overlaps a previous sub-value clobbers
 * the old sub-value, since that write may overwrite the old known sub-value.
 * The lattice top corresponds to every sub-value being zero (a freshly created
 * collection). When the sub-value at a set of indices is not known, the
 * `subvalues` map stores `None` for it; in particular, a write involving array
 * positions marks every sub-value it may clobber as `None`.
 */
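/*
 * For example, meeting the lattice values flowing into a phi (node IDs and
 * field indices below are illustrative):
 *
 *   a = { [Field(0)] -> Some(x), [Field(1)] -> Some(y) }
 *   b = { [Field(0)] -> Some(x), [Field(2)] -> Some(z) }
 *
 *   meet(a, b) = { [Field(0)] -> Some(x),   // both agree: still known
 *                  [Field(1)] -> None,      // only written in `a`: unknown
 *                  [Field(2)] -> None }     // only written in `b`: unknown
 */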
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SLFLattice {
subvalues: BTreeMap<Box<[Index]>, Option<NodeID>>,
}

impl Semilattice for SLFLattice {
fn meet(a: &Self, b: &Self) -> Self {
// Merge the two maps. Find equal indices sets between `a` and `b` and
// keep their known sub-value if they're equivalent. All other indices
// sets in `a` or `b` map to `None`.
let mut ret = BTreeMap::new();
for (indices, a_subvalue) in a.subvalues.iter() {
if let Some(b_subvalue) = b.subvalues.get(indices)
&& a_subvalue == b_subvalue
{
// If both maps have the same sub-value for this set of indices,
// add it unmolested to the meet lattice value.
ret.insert(indices.clone(), *a_subvalue);
} else {
                // If the two maps don't both have a write at the same set of
                // indices, or if the writes don't match, then we don't know
                // what's been written there.
ret.insert(indices.clone(), None);
}
}
for (indices, _) in b.subvalues.iter() {
// Any indices sets in `b` that aren't in `ret` are indices sets
// that aren't in `a`, so the sub-value isn't known.
ret.entry(indices.clone()).or_insert(None);
}
SLFLattice { subvalues: ret }
    }

    fn top() -> Self {
SLFLattice {
subvalues: BTreeMap::new(),
}
    }

    fn bottom() -> Self {
let mut subvalues = BTreeMap::new();
// The empty indices set overlaps with all possible indices sets.
subvalues.insert(Box::new([]) as Box<[Index]>, None);
SLFLattice { subvalues }
}
}

/*
* Top level function to run store-to-load forwarding on a function. Looks for
* known values inside collections and replaces reads of those values with the
* values directly.
*/
pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>) {
// First, run a dataflow analysis that looks at known values inside
// collections. Thanks to the value semantics of Hercules IR, this analysis
// is relatively simple and straightforward.
let func = editor.func();
let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
match func.nodes[id.idx()] {
Node::Phi {
control: _,
data: _,
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
}
| Node::Ternary {
op: TernaryOperator::Select,
first: _,
second: _,
third: _,
} => inputs.into_iter().fold(SLFLattice::top(), |acc, input| {
SLFLattice::meet(&acc, input)
}),
Node::Write {
collect: _,
data,
ref indices,
} => {
                // Start with the lattice value of the `collect` input.
let mut value = inputs[0].clone();
// Any indices sets that overlap with `indices` become `None`,
// since we no longer know what's stored there.
for (other_indices, subvalue) in value.subvalues.iter_mut() {
if indices_may_overlap(other_indices, indices) {
*subvalue = None;
}
}
// Track `data` at `indices`.
value.subvalues.insert(indices.clone(), Some(data));
value
}
Node::Constant { id: _ } | Node::Undef { ty: _ } => SLFLattice::top(),
_ => SLFLattice::bottom(),
}
});
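    // For example, on a straight-line chain of writes, the analysis computes
    // (node names and field indices illustrative):
    //
    //   c  = constant zeros             lattice: {}  (top: all zeros)
    //   w1 = write(c,  x, [Field(0)])   lattice: { [Field(0)] -> Some(x) }
    //   w2 = write(w1, y, [Field(0)])   lattice: { [Field(0)] -> Some(y) }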
    // Second, look for reads where the indices set either:
    // 1. Equals the indices of a known sub-value. Then, the read can be
    //    replaced by the known sub-value.
    // 2. Doesn't overlap with any known or unknown sub-value. Then, the read
    //    can be replaced by a zero constant, since nothing has been written
    //    there and the lattice top corresponds to all zeros.
    // Otherwise, the read can't be replaced.
    // Keep track of which nodes we've already replaced, since a recorded
    // sub-value may be the ID of a node that has itself been replaced.
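    // For example, if an earlier iteration of this loop replaced read r1 with
    // known value r2, and a later iteration then replaced r2 itself, a
    // sub-value recorded as r1 must be chased through `replacements` until it
    // reaches a live node.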
let mut replacements = BTreeMap::new();
for id in editor.node_ids() {
let Node::Read {
collect,
ref indices,
} = editor.func().nodes[id.idx()]
else {
continue;
};
let subvalues = &lattice[collect.idx()].subvalues;
if let Some(sub_value) = subvalues.get(indices)
&& let Some(mut known) = *sub_value
{
while let Some(replacement) = replacements.get(&known) {
known = *replacement;
}
editor.edit(|mut edit| {
edit = edit.replace_all_uses(id, known)?;
edit.delete_node(id)
});
replacements.insert(id, known);
} else if !subvalues
.keys()
.any(|other_indices| indices_may_overlap(other_indices, indices))
{
editor.edit(|mut edit| {
let zero = edit.add_zero_constant(typing[id.idx()]);
let zero = edit.add_node(Node::Constant { id: zero });
edit = edit.replace_all_uses(id, zero)?;
edit.delete_node(id)
});
}
}
}

/*
 * Top level function to run fork-join level store-to-load forwarding on a
 * function. Looks for reduce nodes holding arrays that have nice einsum
 * expressions, and replaces reads of those arrays with the sub-expression of
 * the einsum array comprehension.
 */
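/*
 * For example, if a reduce produces the array comprehension
 * `out[i] = a[i] + b[i]` (illustrative), a later `read(out, [Position([j])])`
 * can be replaced by materializing the computation `a[j] + b[j]` directly,
 * rather than reading the value back out of memory.
 */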
pub fn array_slf(
editor: &mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_einsum: &(MathEnv, HashMap<NodeID, MathID>),
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
.into_iter()
.map(|(fork, join)| (*join, *fork))
.collect();
let (env, einsums) = reduce_einsum;
for (reduce, einsum) in einsums {
// Check that the expression is an array comprehension.
let MathExpr::Comprehension(elem, _) = env[einsum.idx()] else {
continue;
};
// If any of the opaque nodes are "in" the fork-join of the reduce, then
// they depend on the thread IDs of the fork-join in a way that's not
// modeled in the einsum expression, and therefore those thread IDs
// can't be substituted with read position indices. We need to skip
// applying SLF to these arrays.
let nodes = &editor.func().nodes;
let join = nodes[reduce.idx()].try_reduce().unwrap().0;
let fork = join_fork_map[&join];
let nodes_in_this_fork_join = &nodes_in_fork_joins[&fork];
let opaque_nodes = opaque_nodes_in_expr(env, *einsum);
if opaque_nodes
.into_iter()
.any(|id| nodes_in_this_fork_join.contains(&id))
{
continue;
}
// Look for read users of the reduce. They can be replaced with
// substituting the read indices of the array into the einsum expression
// to compute, rather than read, the needed value.
let reads: Vec<(NodeID, Box<[NodeID]>)> = editor
.get_users(*reduce)
.filter_map(|id| {
nodes[id.idx()].try_read().map(|(_, indices)| {
// The indices list should just be a single position index, since
// einsum expressions are only derived for arrays of primitives.
assert_eq!(indices.len(), 1);
                    let Index::Position(indices) = &indices[0] else {
                        panic!("expected a single position index");
                    };
(id, indices.clone())
})
})
.collect();
for (read, indices) in reads {
editor.edit(|mut edit| {
// Create the expression equivalent to the read.
let id = materialize_simple_einsum_expr(&mut edit, elem, env, &indices);
// Replace and delete the read.
edit = edit.replace_all_uses(read, id)?;
edit.delete_node(read)
});
}
}
}

/*
* Top level function to run reduce store-to-load forwarding on a function.
* There are two variants of reduce SLF. One optimizes parallel reductions and
* the other optimizes scalar reductions. This pass just runs one after the
* other - it's up to the user to potentially wrap this pass in a fixpoint.
*
* The parallel variant looks for reductions on collections with the following
* form:
*
* 1. A write `reduct` use.
* 2. A single read user in the reduction cycle.
* 3. The write use and read user have identical indices.
* 4. The indices set directly refers to at least every thread ID produced by
* the fork corresponding to the reduce.
*
 * Condition #4 roughly corresponds to the condition needed to infer the
 * ParallelReduce schedule - however, in this scenario, it can't be inferred
 * because the written data is in the reduce cycle: it is derived from the
 * read user of the reduce. Since the indices change each iteration, though,
 * the read could equivalently read from the `init` input of the reduce,
 * rather than the reduce itself. This optimization replaces the use of the
 * reduce in the read with the `init` input of the reduce, nominally so that
 * ParallelReduce can get inferred.
*
 * If a reduction has the ParallelReduce schedule on it, then any direct read
 * of it can be much more easily optimized to refer to the `init` input rather
 * than the reduce itself.
*
* The scalar variant looks for reductions on collections with the following
* form:
*
* 1. A write `reduct` use.
* 2. A single read user in the reduction cycle.
* 3. The write use and read user have identical indices.
* 4. The indices set doesn't reference any nodes in the fork-join of the reduce
* (doesn't depend on the thread IDs of the fork-join).
*
 * Instead of repeatedly reading / writing the same collection item each
 * iteration, the reduction can reduce over the scalar value directly and do a
 * single write into the collection after the fork-join (technically, the
 * "scalar" may itself be a collection in the case of nested collections, but
 * this optimization most often handles scalars).
*/
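/*
 * As a sketch (node names illustrative), the parallel variant rewrites
 *
 *   tid = thread_id(fork)
 *   red = reduce(join, init, w)
 *   r   = read(red, [Position([tid])])
 *   w   = write(red, f(r), [Position([tid])])
 *
 * so that `r` reads from `init` instead of `red`. The scalar variant takes a
 * reduction that reads and writes, say, `[Field(0)]` (independent of any
 * thread ID) every iteration, and turns it into a reduce over the scalar
 * itself: one read of `init` before the reduce, and one write back into the
 * collection after the fork-join.
 */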
pub fn reduce_slf(
editor: &mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
// Helper to get write use and read user of reduce with identical indices.
// This checks conditions 1, 2, and 3 of both parallel and scalar reduction
// optimization.
let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> {
let nodes = &editor.func().nodes;
let reduct = nodes[reduce.idx()].try_reduce().unwrap().2;
if !nodes[reduct.idx()].is_write() || nodes[reduct.idx()].try_write().unwrap().0 != reduce {
return None;
}
let mut users = editor
.get_users(reduce)
.filter(|id| reduce_cycles[&reduce].contains(id) && *id != reduct);
let read = users.next()?;
if users.next().is_some()
|| !nodes[read.idx()].is_read()
|| nodes[read.idx()].try_read().unwrap().1 != nodes[reduct.idx()].try_write().unwrap().2
{
return None;
}
Some((read, reduct))
};
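    // For example (node names illustrative), the helper matches the pattern
    //
    //   red = reduce(join, init, w)
    //   r   = read(red, idx)
    //   w   = write(red, f(r), idx)
    //
    // and returns Some((r, w)); any other user of `red` in the reduction
    // cycle, a non-read user, or mismatched indices makes it return None.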
// First, optimize parallel reductions.
for (fork, join) in fork_join_map {
let reduces: Vec<_> = editor
.get_users(*join)
.filter(|id| editor.func().nodes[id.idx()].is_reduce())
.collect();
for reduce in reduces {
if let Some((read, _)) = read_write_helper(reduce, editor) {
// Check condition 4 of parallel reduction optimization.
let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
if indices_parallel_over_forks(editor, indices, once(*fork)) {
let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
let new_read = Node::Read {
collect: init,
indices: indices.to_vec().into_boxed_slice(),
};
editor.edit(|mut edit| {
let new_read = edit.add_node(new_read);
edit = edit.replace_all_uses(read, new_read)?;
edit.delete_node(read)
});
}
} else if editor.func().schedules[reduce.idx()].contains(&Schedule::ParallelReduce) {
let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
let read_users: Vec<_> = editor
.get_users(reduce)
.filter(|id| {
editor.func().nodes[id.idx()].is_read()
                        && nodes_in_fork_joins[fork].contains(id)
})
.collect();
editor.edit(|edit| {
edit.replace_all_uses_where(reduce, init, |id| read_users.contains(id))
});
}
}
    }

    // Second, optimize scalar reductions.
for (fork, join) in fork_join_map {
let reduces: Vec<_> = editor
.get_users(*join)
.filter(|id| editor.func().nodes[id.idx()].is_reduce())
.collect();
for reduce in reduces {
let Some((read, write)) = read_write_helper(reduce, editor) else {
continue;
};
// Check condition 4 of scalar reduction optimization.
let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
if node_indices(indices).all(|id| !nodes_in_fork_joins[fork].contains(&id)) {
let indices = indices.to_vec().into_boxed_slice();
let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
let data = editor.func().nodes[write.idx()].try_write().unwrap().1;
let init_read = Node::Read {
collect: init,
indices: indices.clone(),
};
editor.edit(|mut edit| {
let init_read = edit.add_node(init_read);
let new_reduce = Node::Reduce {
control: *join,
init: init_read,
reduct: data,
};
let new_reduce = edit.add_node(new_reduce);
let post_write = Node::Write {
collect: init,
data: new_reduce,
indices,
};
let post_write = edit.add_node(post_write);
edit = edit.replace_all_uses(read, new_reduce)?;
edit = edit.replace_all_uses(reduce, post_write)?;
edit = edit.delete_node(read)?;
edit = edit.delete_node(reduce)?;
edit = edit.delete_node(write)?;
Ok(edit)
});
}
}
}
}
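
// A minimal sanity check of the lattice semantics above, as a sketch. It
// assumes `Index::Field(usize)` and `NodeID::new(usize)` are constructible as
// in `hercules_ir`; adjust the constructors if they differ there.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn meet_keeps_agreeing_subvalues_and_forgets_one_sided_writes() {
        let idx0: Box<[Index]> = Box::new([Index::Field(0)]);
        let idx1: Box<[Index]> = Box::new([Index::Field(1)]);

        // Both sides know the same sub-value at field 0...
        let mut a = SLFLattice::top();
        a.subvalues.insert(idx0.clone(), Some(NodeID::new(1)));
        // ...but only `a` has written field 1.
        a.subvalues.insert(idx1.clone(), Some(NodeID::new(2)));
        let mut b = SLFLattice::top();
        b.subvalues.insert(idx0.clone(), Some(NodeID::new(1)));

        let m = SLFLattice::meet(&a, &b);
        // The agreeing sub-value survives the meet; the one-sided write is
        // remembered as an unknown (`None`) sub-value.
        assert_eq!(m.subvalues[&idx0], Some(NodeID::new(1)));
        assert_eq!(m.subvalues[&idx1], None);
    }
}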