Skip to content
Snippets Groups Projects
array_to_prod.rs 12.37 KiB
use hercules_ir::define_id_type;
use hercules_ir::ir::*;

use bitvec::prelude::*;

use crate::*;

use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;

/*
 * Top level function for array to product which will convert constant
 * sized arrays into products if the array is only accessed at indices which
 * are constants.
 *
 * To identify the collections we can convert we look at each constant-sized
 * array constant and compute the set which includes the constant node and is
 * closed under the following properties:
 * - For each collection in the set, its users are in the set
 * - For each node that uses a collection, all collections it uses are in the
 *   set
 * From this set, we then determine whether this whole set can be converted to
 * operating on products, rather than arrays, as follows
 * - Each read and write node must be to a constant index
 * - It may not contain any arguments (we could generate code to read an
 *   array argument into a product, but do not do so for now)
 * - There are no call or return nodes in the set (this would mean that the
 *   collections are consumed by a call or return, again we could reconstruct
 *   the array where needed but do not do so for now and so have this
 *   restriction)
 * - All nodes in the set are editable (if we cannot modify some node then the
 *   conversion will fail)
 *
 * The max_size argument allows the user to specify a limit on the size of arrays
 * that should be converted to products. If the number of elements in the array
 * is larger than the max size the array will not be converted.
 *
 * `types` must be indexed by node index. A converted array constant is replaced
 * by a nested product constant (one product level per array dimension) built
 * from the zero constant of the array's element type, and every array position
 * index of a converted read or write becomes the corresponding field index.
 */
pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID], max_size: Option<usize>) {
    let replace_nodes = array_usage_analysis(editor, types, max_size);
    let num_nodes = editor.func().nodes.len();

    // Replace nodes. Only the nodes that existed when the analysis ran are visited, since
    // `replace_nodes` was computed for exactly those nodes.
    for node_idx in 0..num_nodes {
        if !replace_nodes[node_idx] {
            continue;
        }
        let node = NodeID::new(node_idx);

        // We can replace the array(s) this node uses with a product. What we have to do depends on
        // the type of the node. The analysis only selects nodes for which `editor.is_mutable`
        // holds, so the edits below are expected to succeed; their results are not checked.
        match &editor.func().nodes[node_idx] {
            // Phi, Reduce, and Ternary just use the whole collection, they do not need to change,
            // except as they will be modified by replace_all_uses of their inputs
            Node::Phi { .. }
            | Node::Reduce { .. }
            | Node::Ternary {
                op: TernaryOperator::Select,
                ..
            } => {}
            Node::Constant { id } => {
                assert!(editor.get_constant(*id).is_array());
                let element: TypeID = editor.get_type(types[node_idx]).try_element_type().unwrap();
                let dims: Vec<usize> = editor
                    .get_type(types[node_idx])
                    .try_extents()
                    .unwrap()
                    .iter()
                    .map(|dc| editor.get_dynamic_constant(*dc).try_constant().unwrap())
                    .collect();
                // Replace the constant by a product that is a product (for each dimension) and the
                // elements are zero'd. Folding from the last dimension makes the first dimension
                // the outermost product, so position i_k maps to field i_k at nesting depth k.
                editor.edit(|mut edit| {
                    let element_zero = edit.add_zero_constant(element);
                    let (constant, _) = dims.into_iter().rfold(
                        (element_zero, element),
                        |(cur_const, cur_type), dim| {
                            let new_type = edit.add_type(Type::Product(vec![cur_type; dim].into()));
                            let new_const = edit.add_constant(Constant::Product(
                                new_type,
                                vec![cur_const; dim].into(),
                            ));
                            (new_const, new_type)
                        },
                    );
                    let new_val = edit.add_node(Node::Constant { id: constant });
                    let edit = edit.replace_all_uses(node, new_val)?;
                    edit.delete_node(node)
                });
            }
            Node::Read { collect, indices } => {
                let collect = *collect;
                let new_indices = convert_indices_to_prod(editor, indices);
                editor.edit(|mut edit| {
                    let new_val = edit.add_node(Node::Read {
                        collect,
                        indices: new_indices,
                    });
                    let edit = edit.replace_all_uses(node, new_val)?;
                    edit.delete_node(node)
                });
            }
            Node::Write {
                collect,
                data,
                indices,
            } => {
                let collect = *collect;
                let data = *data;
                let new_indices = convert_indices_to_prod(editor, indices);
                editor.edit(|mut edit| {
                    let new_val = edit.add_node(Node::Write {
                        collect,
                        data,
                        indices: new_indices,
                    });
                    let edit = edit.replace_all_uses(node, new_val)?;
                    edit.delete_node(node)
                });
            }
            node => panic!("Node cannot be replaced: {:?}", node),
        }
    }
}

// Flattens a list of array indices into field indices: each node of every
// `Index::Position` must be a u64 constant node, and each such constant value
// becomes one `Index::Field`, in order. Panics on any non-position index, on a
// non-constant position node, or on a constant that is not a u64.
fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[Index]> {
    indices
        .iter()
        .flat_map(|index| {
            let Index::Position(positions) = index else {
                panic!("Index cannot be replaced: {:?}", index);
            };
            positions.iter()
        })
        .map(|pos| {
            let const_id = editor.func().nodes[pos.idx()]
                .try_constant()
                .expect("Array position must be constant");
            match *editor.get_constant(const_id) {
                Constant::UnsignedInteger64(field) => Index::Field(field as usize),
                ref val => panic!("Position should be u64 constant: {:?}", val),
            }
        })
        .collect()
}

// Given the editor, compute a mask of which nodes are to be converted from using a constant
// sized array into using a product.
//
// `types` must be indexed by node index. Bit `i` of the result is set exactly when node `i`
// belongs to the set built around an eligible array constant (see Step 1) and every node of
// that set is mutable and passes `can_replace`.
fn array_usage_analysis(
    editor: &FunctionEditor,
    types: &[TypeID],
    max_size: Option<usize>,
) -> BitVec<u8, Lsb0> {
    let num_nodes = editor.func().nodes.len();

    // Step 1: identify the constant nodes that are constant sized arrays no larger than the
    // max_size, these are what we are interested in converting into products
    let sources = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| {
            let Node::Constant { id } = node else {
                return None;
            };
            let Constant::Array(array_type) = *editor.get_constant(*id) else {
                return None;
            };
            let typ = editor.get_type(array_type);
            let dims = typ.try_extents()?;
            // Total number of elements. This is None (so the constant is skipped) if some
            // dimension is not a constant, or if the product does not fit in a usize.
            let elements = dims.iter().try_fold(1usize, |prod, dc| {
                editor
                    .get_dynamic_constant(*dc)
                    .try_constant()
                    .and_then(|dim| prod.checked_mul(dim))
            })?;
            match max_size {
                // Too many elements, don't convert
                Some(max_size) if elements > max_size => None,
                _ => Some(NodeID::new(idx)),
            }
        })
        .collect::<Vec<_>>();

    // Step 2: collect the collection information we need for the (whole) function. Every node
    // that produces an array is unioned with all of its users. This is done for *every*
    // array-typed node, not just parameters, constants, writes, phis, selects and reduces:
    // otherwise an array produced by some other node (such as a call result) could flow into a
    // phi or select of a candidate set without joining it. That phi would then be rewritten to
    // products while its input kept its array type. With this rule such producers join the set,
    // and since `can_replace` rejects them, the whole set is left unconverted.
    let mut analysis = UnionFind::new();
    for node_idx in 0..num_nodes {
        if !editor.get_type(types[node_idx]).is_array() {
            continue;
        }
        let node_id = NodeID::new(node_idx);
        for user in editor.get_users(node_id) {
            analysis.union(node_id, user);
        }
    }

    let sets = analysis.sets(&sources);

    // Step 3: determine which sets can be converted and mark the nodes in those sets
    let mut result = bitvec![u8, Lsb0; 0; num_nodes];

    for nodes in sets {
        if nodes
            .iter()
            .all(|node_id| editor.is_mutable(*node_id) && can_replace(editor, *node_id))
        {
            for node_id in nodes {
                result.set(node_id.idx(), true);
            }
        }
    }

    result
}

fn can_replace(editor: &FunctionEditor, node: NodeID) -> bool {
    match &editor.func().nodes[node.idx()] {
        // Reads and writes must be at constant indices
        Node::Read { indices, .. } | Node::Write { indices, .. } => {
            indices.iter().all(|idx| match idx {
                Index::Position(pos) => pos
                    .iter()
                    .all(|node| editor.func().nodes[node.idx()].is_constant()),
                _ => false,
            })
        }
        // phi, reduce, constants, and select can always be replaced if their users and uses allow
        // it, which is handled by the construction of the set
        Node::Phi { .. }
        | Node::Reduce { .. }
        | Node::Constant { .. }
        | Node::Ternary {
            op: TernaryOperator::Select,
            ..
        } => true,
        // No other nodes allow replacement
        _ => false,
    }
}

// Identifier of a union-find entry: `UnionFind` uses `SetID::idx()` to index its `sets` vector,
// and a root entry's `SetID` serves as the representative of its set.
define_id_type!(SetID);

// One entry of the union-find forest.
#[derive(Clone, Debug)]
struct UnionFindNode {
    // Parent entry; an entry whose parent is itself is the root (representative) of its set.
    parent: SetID,
    // Union-by-rank value consulted by `UnionFind::link` to choose which root becomes the parent.
    rank: usize,
}

// Union-find (disjoint-set) structure over IDs of type `T`, using path compression and union by
// rank. Entry `i` of `sets` describes the ID whose index is `i`.
#[derive(Clone, Debug)]
struct UnionFind<T> {
    sets: Vec<UnionFindNode>,
    // Ties the structure to the ID type `T` without storing any `T` value.
    _phantom: PhantomData<T>,
}

impl<T: ID> UnionFind<T> {
    // Creates an empty union-find; entries are added lazily by `find`.
    pub fn new() -> Self {
        UnionFind {
            sets: vec![],
            _phantom: PhantomData,
        }
    }

    // Ensures entries exist for every index up to and including `size`. Each new entry is the
    // root of its own singleton set with rank 0.
    fn extend_past(&mut self, size: usize) {
        for i in self.sets.len()..=size {
            self.sets.push(UnionFindNode {
                parent: SetID::new(i),
                rank: 0,
            });
        }
    }

    // Returns the representative of the set containing `x`. If `x` (or any smaller index) has
    // not been seen yet, it is first added as a singleton set.
    pub fn find(&mut self, x: T) -> SetID {
        self.extend_past(x.idx());
        self.find_set(x.idx())
    }

    // Returns the root of the entry at index `x`, re-pointing every entry on the path directly
    // at that root (path compression).
    fn find_set(&mut self, x: usize) -> SetID {
        let mut parent = self.sets[x].parent;
        if parent.idx() != x {
            parent = self.find_set(parent.idx());
            self.sets[x].parent = parent;
        }
        parent
    }

    // Merges the sets containing `x` and `y`.
    pub fn union(&mut self, x: T, y: T) {
        let x = self.find(x);
        let y = self.find(y);
        self.link(x, y);
    }

    // Links two roots using union by rank. The root with the larger rank becomes the parent;
    // on a tie `y` becomes the parent and its rank grows by one.
    fn link(&mut self, x: SetID, y: SetID) {
        // Both elements are already in the same set. Without this check the tie branch below
        // would increment that root's rank every time two members of one set are unioned
        // again, breaking the union-by-rank invariant.
        if x.idx() == y.idx() {
            return;
        }
        if self.sets[x.idx()].rank > self.sets[y.idx()].rank {
            self.sets[y.idx()].parent = x;
        } else {
            self.sets[x.idx()].parent = y;
            if self.sets[x.idx()].rank == self.sets[y.idx()].rank {
                self.sets[y.idx()].rank += 1;
            }
        }
    }

    // For each key, returns every element of the key's set in ascending index order. When
    // several keys share a set, all of its elements are reported under the last such key and
    // the other keys get empty vectors.
    pub fn sets(&mut self, keys: &[T]) -> Vec<Vec<T>> {
        let key_index = keys
            .iter()
            .enumerate()
            .map(|(i, k)| (self.find(*k), i))
            .collect::<HashMap<SetID, usize>>();
        let mut result = vec![vec![]; keys.len()];

        let num_elements = self.sets.len();
        for i in 0..num_elements {
            let set = self.find_set(i);
            let Some(idx) = key_index.get(&set) else {
                continue;
            };
            result[*idx].push(T::new(i));
        }

        result
    }
}