Skip to content
Snippets Groups Projects

Array to prod

Merged Aaron Councilman requested to merge array-to-prod into main
All threads resolved!
14 files
+ 609
20
Compare changes
  • Side-by-side
  • Inline
Files
14
+ 352
0
 
use hercules_ir::define_id_type;
 
use hercules_ir::ir::*;
 
 
use bitvec::prelude::*;
 
 
use crate::*;
 
 
use std::collections::{HashMap, HashSet};
 
use std::marker::PhantomData;
 
 
/*
 
* Top level function for array to product which will convert constant
 
* sized arrays into products if the array is only accessed at indices which
 
* are constants.
 
*
 
* To identify the collections we can convert we look at each constant-sized
 
* array constant and compute the set which includes the constant node and is
 
* closed under the following properties:
 
* - For each collection in the set, its uses are in the set
 
* - For each node that uses a collection, all collections it uses are in the
 
* set
 
* From this set, we then determine whether this whole set can be converted to
 
* operating on products, rather than arrays, as follows
 
* - Each read and write node must be to a constant index
 
* - It may not contain any arguments (we could generate code to read a an
 
* array argument into a product, but do not do so for now)
 
* - There are call or return nodes in the set (this would mean that the
 
* collections are consumed by a call or return, again we could reconstruct
 
* the array where needed but do not do so for now and so have this
 
* restriction)
 
* - All nodes in the set are editable (if we cannot modify some node then the
 
* conversion will fail)
 
*
 
* The max_size argument allows the user to specify a limit on the size of arrays
 
* that should be converted to products. If the number of elements in the array
 
* is larger than the max size the array will not be converted.
 
*/
 
pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID], max_size: Option<usize>) {
 
let replace_nodes = array_usage_analysis(editor, types, max_size);
 
let num_nodes = editor.func().nodes.len();
 
 
// Replace nodes
 
for node_idx in 0..num_nodes {
 
if !replace_nodes[node_idx] {
 
continue;
 
}
 
let node = NodeID::new(node_idx);
 
 
// We can replace the array(s) this node uses with a product. What we have to do depends on
 
// the type of the node
 
match &editor.func().nodes[node_idx] {
 
// Phi, Reduce, and Ternary just use the whole collection, they do not need to change,
 
// except as they will be modified by replace_all_uses_of
 
Node::Phi { .. }
 
| Node::Reduce { .. }
 
| Node::Ternary {
 
op: TernaryOperator::Select,
 
..
 
} => {}
 
Node::Constant { id } => {
 
assert!(editor.get_constant(*id).is_array());
 
let element: TypeID = editor.get_type(types[node_idx]).try_element_type().unwrap();
 
let dims: Vec<usize> = editor
 
.get_type(types[node_idx])
 
.try_extents()
 
.unwrap()
 
.iter()
 
.map(|dc| editor.get_dynamic_constant(*dc).try_constant().unwrap())
 
.collect();
 
// Replace the constant by a product that is a product (for each dimension) and the
 
// elements are zero'd
 
editor.edit(|mut edit| {
 
let element_zero = edit.add_zero_constant(element);
 
let (constant, _) = dims.into_iter().rfold(
 
(element_zero, element),
 
|(cur_const, cur_type), dim| {
 
let new_type = edit.add_type(Type::Product(vec![cur_type; dim].into()));
 
let new_const = edit.add_constant(Constant::Product(
 
new_type,
 
vec![cur_const; dim].into(),
 
));
 
(new_const, new_type)
 
},
 
);
 
let new_val = edit.add_node(Node::Constant { id: constant });
 
let edit = edit.replace_all_uses(node, new_val)?;
 
edit.delete_node(node)
 
});
 
}
 
Node::Read { collect, indices } => {
 
let collect = *collect;
 
let new_indices = convert_indices_to_prod(editor, indices);
 
editor.edit(|mut edit| {
 
let new_val = edit.add_node(Node::Read {
 
collect,
 
indices: new_indices,
 
});
 
let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?;
 
edit.delete_node(node)
 
});
 
}
 
Node::Write {
 
collect,
 
data,
 
indices,
 
} => {
 
let collect = *collect;
 
let data = *data;
 
let new_indices = convert_indices_to_prod(editor, indices);
 
editor.edit(|mut edit| {
 
let new_val = edit.add_node(Node::Write {
 
collect,
 
data,
 
indices: new_indices,
 
});
 
let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?;
 
edit.delete_node(node)
 
});
 
}
 
node => panic!("Node cannot be replaced: {:?}", node),
 
}
 
}
 
}
 
 
fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[Index]> {
 
let mut result = vec![];
 
 
for index in indices {
 
match index {
 
Index::Position(positions) => {
 
for pos in positions {
 
let const_id = editor.func().nodes[pos.idx()]
 
.try_constant()
 
.expect("Array position must be constant");
 
match *editor.get_constant(const_id) {
 
Constant::UnsignedInteger64(idx) => result.push(Index::Field(idx as usize)),
 
ref val => panic!("Position should be u64 constant: {:?}", val),
 
}
 
}
 
}
 
index => panic!("Index cannot be replaced: {:?}", index),
 
}
 
}
 
 
result.into()
 
}
 
 
// Given the editor, while compute a mask of which nodes are to be converted
 
// from using a constant sized array into using a product
 
fn array_usage_analysis(
 
editor: &FunctionEditor,
 
types: &[TypeID],
 
max_size: Option<usize>,
 
) -> BitVec<u8, Lsb0> {
 
let num_nodes = editor.func().nodes.len();
 
 
// Step 1: identify the constant nodes that are constant sized arrays no larger than the
 
// max_size, these are what we are interested in converting into products
 
let sources = editor
 
.func()
 
.nodes
 
.iter()
 
.enumerate()
 
.filter_map(|(idx, node)| {
 
let Node::Constant { id } = node else {
 
return None;
 
};
 
let Constant::Array(array_type) = *editor.get_constant(*id) else {
 
return None;
 
};
 
let typ = editor.get_type(array_type);
 
let Some(dims) = typ.try_extents() else {
 
return None;
 
};
 
// Compute the total number of elements, the result is None if some dimension is not a
 
// constant and otherwise is Some(num_elements) which we can then compare with max_size
 
if let Some(elements) = dims.iter().fold(Some(1), |prod, dc| {
 
prod.and_then(|prod| {
 
editor
 
.get_dynamic_constant(*dc)
 
.try_constant()
 
.map(|dim| prod * dim)
 
})
 
}) {
 
if let Some(max_size) = max_size
 
&& elements > max_size
 
{
 
// Too many elements, don't convert
 
None
 
} else {
 
Some(NodeID::new(idx))
 
}
 
} else {
 
None
 
}
 
})
 
.collect::<Vec<_>>();
 
 
// Step 2: collect the collection information we need for the (whole) function. For each node
 
// that returns a collection (that in reference semantics returns the same reference as some of
 
// its inputs) union with all of its users. The nodes that matter in this are arguments,
 
// constants, writes, phis, selects, and reduces with array types.
 
let mut analysis = UnionFind::new();
 
for node_idx in 0..num_nodes {
 
let node_id = NodeID::new(node_idx);
 
if editor.get_type(types[node_idx]).is_array() {
 
match editor.func().nodes[node_idx] {
 
Node::Phi { .. }
 
| Node::Reduce { .. }
 
| Node::Parameter { .. }
 
| Node::Constant { .. }
 
| Node::Ternary {
 
op: TernaryOperator::Select,
 
..
 
}
 
| Node::Write { .. } => {
 
for user in editor.get_users(node_id) {
 
analysis.union(node_id, user);
 
}
 
}
 
_ => {}
 
}
 
}
 
}
 
 
let sets = analysis.sets(&sources);
 
 
// Step 3: determine which sets can be converted and mark the nodes in those sets
 
let mut result = bitvec![u8, Lsb0; 0; num_nodes];
 
 
for nodes in sets {
 
if nodes
 
.iter()
 
.all(|node_id| editor.is_mutable(*node_id) && can_replace(editor, *node_id))
 
{
 
for node_id in nodes {
 
result.set(node_id.idx(), true);
 
}
 
}
 
}
 
 
result
 
}
 
 
fn can_replace(editor: &FunctionEditor, node: NodeID) -> bool {
 
match &editor.func().nodes[node.idx()] {
 
// Reads and writes must be at constant indices
 
Node::Read { indices, .. } | Node::Write { indices, .. } => {
 
indices.iter().all(|idx| match idx {
 
Index::Position(pos) => pos
 
.iter()
 
.all(|node| editor.func().nodes[node.idx()].is_constant()),
 
_ => false,
 
})
 
}
 
// phi, reduce, constants, and select can always be replaced if their users and uses allow
 
// it, which is handled by the construction of the set
 
Node::Phi { .. }
 
| Node::Reduce { .. }
 
| Node::Constant { .. }
 
| Node::Ternary {
 
op: TernaryOperator::Select,
 
..
 
} => true,
 
// No other nodes allow replacement
 
_ => false,
 
}
 
}
 
 
define_id_type!(SetID);
 
 
#[derive(Clone, Debug)]
 
struct UnionFindNode {
 
parent: SetID,
 
rank: usize,
 
}
 
 
#[derive(Clone, Debug)]
 
struct UnionFind<T> {
 
sets: Vec<UnionFindNode>,
 
_phantom: PhantomData<T>,
 
}
 
 
impl<T: ID> UnionFind<T> {
 
pub fn new() -> Self {
 
UnionFind {
 
sets: vec![],
 
_phantom: PhantomData,
 
}
 
}
 
 
fn extend_past(&mut self, size: usize) {
 
for i in self.sets.len()..=size {
 
// The new nodes we add are in their own sets and have rank 0
 
self.sets.push(UnionFindNode {
 
parent: SetID::new(i),
 
rank: 0,
 
});
 
}
 
}
 
 
pub fn find(&mut self, x: T) -> SetID {
 
self.extend_past(x.idx());
 
self.find_set(x.idx())
 
}
 
 
fn find_set(&mut self, x: usize) -> SetID {
 
let mut parent = self.sets[x].parent;
 
if parent.idx() != x {
 
parent = self.find_set(parent.idx());
 
self.sets[x].parent = parent;
 
}
 
parent
 
}
 
 
pub fn union(&mut self, x: T, y: T) {
 
let x = self.find(x);
 
let y = self.find(y);
 
self.link(x, y);
 
}
 
 
fn link(&mut self, x: SetID, y: SetID) {
 
if self.sets[x.idx()].rank > self.sets[y.idx()].rank {
 
self.sets[y.idx()].parent = x;
 
} else {
 
self.sets[x.idx()].parent = y;
 
if self.sets[x.idx()].rank == self.sets[y.idx()].rank {
 
self.sets[y.idx()].rank += 1;
 
}
 
}
 
}
 
 
pub fn sets(&mut self, keys: &[T]) -> Vec<Vec<T>> {
 
let key_index = keys
 
.iter()
 
.enumerate()
 
.map(|(i, k)| (self.find(*k), i))
 
.collect::<HashMap<SetID, usize>>();
 
let mut result = vec![vec![]; keys.len()];
 
 
let num_elements = self.sets.len();
 
for i in 0..num_elements {
 
let set = self.find_set(i);
 
let Some(idx) = key_index.get(&set) else {
 
continue;
 
};
 
result[*idx].push(T::new(i));
 
}
 
 
result
 
}
 
}
Loading