extern crate bitvec;
extern crate hercules_ir;

use std::collections::HashMap;
use std::iter::zip;

use self::bitvec::prelude::*;

use self::hercules_ir::dataflow::*;
use self::hercules_ir::def_use::*;
use self::hercules_ir::ir::*;

/*
 * Top level function to run SROA, intraprocedurally. Product values can be used
 * and created by a relatively small number of nodes. Here are *all* of them:
 *
 * - Phi: can merge SSA values of products - these get broken up into phis on
 *   the individual fields
 *
 * - Reduce: similarly to phis, reduce nodes can cycle product values through
 *   reduction loops - these get broken up into reduces on the fields
 *
 * + Return: can return a product - these are untouched, and are the sinks for
 *   unbroken product values
 *
 * + Parameter: can introduce a product - these are untouched, and are the
 *   sources for unbroken product values
 *
 * - Constant: can introduce a product - these are broken up into constants for
 *   the individual fields
 *
 * - Ternary: the select ternary operator can select between products - these
 *   are broken up into ternary nodes for the individual fields
 *
 * + Call: the call node can use a product value as an argument to another
 *   function, and can produce a product value as a result - these are
 *   untouched, and are the sink and source for unbroken product values
 *
 * - Read: the read node reads primitive fields from product values - these get
 *   replaced by a direct use of the field value from the broken product value,
 *   but are retained when the product value is unbroken
 *
 * - Write: the write node writes primitive fields in product values - these get
 *   replaced by a direct def of the field value from the broken product value,
 *   but are retained when the product value is unbroken
 *
 * The nodes above with the list marker "+" are retained for maintaining API/ABI
 * compatibility with other Hercules functions and the host code. These are
 * called "sink" or "source" nodes in comments below.
 */
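
/*
 * As a rough illustration (hypothetical IR, not an actual Hercules dump): a
 * product constant c = (3, 4.5) whose only use is r = read(c, field 1) gets
 * broken into scalar constants c0 = 3 and c1 = 4.5, and every use of r becomes
 * a direct use of c1. If c instead flowed into a return or call (the "+" sink
 * nodes above), it would be left intact.
 */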
pub fn sroa(
    function: &mut Function,
    def_use: &ImmutableDefUseMap,
    reverse_postorder: &Vec<NodeID>,
    typing: &Vec<TypeID>,
    types: &Vec<Type>,
    constants: &mut Vec<Constant>,
) {
    // Determine which sources of product values we want to try breaking up. On
    // the source side, we can easily tell whether a node produces a product
    // that shouldn't be broken up just by examining the node type. However,
    // the way a product is used also matters for deciding whether it can be
    // broken up, so we propagate that usage information backward, via the
    // dataflow analysis below, to the sources of product values.
    #[derive(PartialEq, Eq, Clone, Debug)]
    enum ProductUseLattice {
        // The product value used by this node is eventually used by a sink.
        UsedBySink,
        // This node uses multiple product values - the stored node ID indicates
        // which is eventually used by a sink. This lattice value is produced by
        // read and write nodes implementing partial indexing.
        SpecificUsedBySink(NodeID),
        // This node doesn't use a product node, or the product node it does use
        // is not in turn used by a sink.
        UnusedBySink,
    }

    impl Semilattice for ProductUseLattice {
        fn meet(a: &Self, b: &Self) -> Self {
            match (a, b) {
                (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink,
                (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => {
                    if id1 == id2 {
                        Self::SpecificUsedBySink(*id1)
                    } else {
                        Self::UsedBySink
                    }
                }
                (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => {
                    Self::SpecificUsedBySink(*id)
                }
                _ => Self::UnusedBySink,
            }
        }

        fn bottom() -> Self {
            Self::UsedBySink
        }

        fn top() -> Self {
            Self::UnusedBySink
        }
    }
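
    // In this lattice UnusedBySink is the top element and UsedBySink is the
    // bottom: meets only ever move a value down toward UsedBySink, and two
    // different SpecificUsedBySink values meet to UsedBySink, so the fixpoint
    // conservatively marks any product value that might reach a sink.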

    // Run dataflow analysis to find which product values are used by a sink.
    let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| {
        match function.nodes[id.idx()] {
            Node::Return {
                control: _,
                data: _,
            } => {
                if types[typing[id.idx()].idx()].is_product() {
                    ProductUseLattice::UsedBySink
                } else {
                    ProductUseLattice::UnusedBySink
                }
            }
            Node::Call {
                function: _,
                dynamic_constants: _,
                args: _,
            } => todo!(),
            // For reads and writes, we only want to propagate the use of the
            // product to the collect input of the node.
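            // A SpecificUsedBySink(collect) value propagates backward to all
            // of this node's uses; at the collect node itself it collapses to
            // UsedBySink (see the catch-all arm below), while at any other use
            // - such as the data input of a write - it becomes UnusedBySink.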
            Node::Read {
                collect,
                indices: _,
            }
            | Node::Write {
                collect,
                data: _,
                indices: _,
            } => {
                let meet = succ_outs
                    .iter()
                    .fold(ProductUseLattice::top(), |acc, latt| {
                        ProductUseLattice::meet(&acc, latt)
                    });
                if meet == ProductUseLattice::UnusedBySink {
                    ProductUseLattice::UnusedBySink
                } else {
                    ProductUseLattice::SpecificUsedBySink(collect)
                }
            }
            // For non-sink nodes.
            _ => {
                if function.nodes[id.idx()].is_control() {
                    return ProductUseLattice::UnusedBySink;
                }
                let meet = succ_outs
                    .iter()
                    .fold(ProductUseLattice::top(), |acc, latt| {
                        ProductUseLattice::meet(&acc, latt)
                    });
                if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet {
                    if meet_id == id {
                        ProductUseLattice::UsedBySink
                    } else {
                        ProductUseLattice::UnusedBySink
                    }
                } else {
                    meet
                }
            }
        }
    });

    // Only product values introduced as constants can be replaced by scalars.
    let to_sroa: Vec<(NodeID, ConstantID)> = product_uses
        .into_iter()
        .enumerate()
        .filter_map(|(node_idx, product_use)| {
            if ProductUseLattice::UnusedBySink == product_use
                && types[typing[node_idx].idx()].is_product()
            {
                function.nodes[node_idx]
                    .try_constant()
                    .map(|cons_id| (NodeID::new(node_idx), cons_id))
            } else {
                None
            }
        })
        .collect();
    println!("{:?}", to_sroa);

    // Perform SROA. TODO: repair def-use when there are multiple product
    // constants to SROA away.
    assert!(to_sroa.len() < 2);
    for (constant_node_id, constant_id) in to_sroa {
        // Get the field constants to replace the product constant with.
        let product_constant = constants[constant_id.idx()].clone();
        let constant_fields = product_constant
            .try_product_fields(types, constants)
            .unwrap();
        println!("{:?}", constant_fields);

        // DFS to find all data nodes that use the product constant.
        let to_replace = sroa_dfs(constant_node_id, function, def_use);
        println!("{:?}", to_replace);

        // Assemble a mapping from old node IDs acting on the product constant
        // to new node IDs operating on the field constants.
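        // Each broken-up node gets one fresh node ID per product field; the
        // fresh slots are filled with placeholder Node::Start entries here and
        // overwritten with the real field-level nodes in the loop below. Read
        // nodes map to an empty vector because, per the header comment, they
        // are replaced by direct uses of the corresponding field values rather
        // than by new nodes.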
        let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace
            .iter()
            .map(|old_id| match function.nodes[old_id.idx()] {
                Node::Phi {
                    control: _,
                    data: _,
                }
                | Node::Reduce {
                    control: _,
                    init: _,
                    reduct: _,
                }
                | Node::Constant { id: _ }
                | Node::Ternary {
                    op: _,
                    first: _,
                    second: _,
                    third: _,
                }
                | Node::Write {
                    collect: _,
                    data: _,
                    indices: _,
                } => {
                    let new_ids = (0..constant_fields.len())
                        .map(|_| {
                            let id = NodeID::new(function.nodes.len());
                            function.nodes.push(Node::Start);
                            id
                        })
                        .collect();
                    (*old_id, new_ids)
                }
                Node::Read {
                    collect: _,
                    indices: _,
                } => (*old_id, vec![]),
                _ => panic!("PANIC: Invalid node using a constant product found during SROA."),
            })
            .collect();

        // Replace the old nodes with the new nodes. Since we've already
        // allocated the node IDs, at this point we can iterate through the to-
        // replace nodes in an arbitrary order.
        for (old_id, new_ids) in &old_to_new_id_map {
            // First, add the new nodes to the node list.
            let node = function.nodes[old_id.idx()].clone();
            match node {
                // Replace the original constant with constants for each field.
                Node::Constant { id: _ } => {
                    for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) {
                        function.nodes[new_id.idx()] = Node::Constant { id: *field_id };
                    }
                }
                // Replace writes using the constant as the data use with a
                // series of writes writing the individual constant fields. TODO:
                // handle the case where the constant is the collect use of the
                // write node.
                Node::Write {
                    collect,
                    data,
                    ref indices,
                } => {
                    // Create the write chain.
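                    // As a rough sketch, for a product with fields f0..fN-1
                    // written at indices `idx`, the chain below looks like:
                    //   w0 = write(collect, new_data[0], idx + Field(0))
                    //   w1 = write(w0,      new_data[1], idx + Field(1))
                    //   ...
                    // and the final write then replaces every use of the
                    // original write node.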
                    assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet.");
                    let mut collect_def = collect;
                    for (idx, (new_id, new_data_def)) in
                        zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate()
                    {
                        let mut new_indices = indices.clone().into_vec();
                        new_indices.push(Index::Field(idx));
                        function.nodes[new_id.idx()] = Node::Write {
                            collect: collect_def,
                            data: *new_data_def,
                            indices: new_indices.into_boxed_slice(),
                        };
                        collect_def = *new_id;
                    }

                    // Replace uses of the old write with the new write.
                    for user in def_use.get_users(*old_id) {
                        get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def);
                    }
                }
                _ => todo!(),
            }

            // Delete the old node.
            function.nodes[old_id.idx()] = Node::Start;
        }
    }
}

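/*
 * Forward DFS from a product constant through its users, collecting the nodes
 * that need to be rewritten. Phi, reduce, ternary, and collect-side write
 * users are traversed through; read users and writes reached through their
 * data input are recorded but not traversed past; any other user panics,
 * since no other node kind should consume a product constant chosen for SROA.
 */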
fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> {
    // Initialize order vector and bitset for tracking which nodes have been
    // visited.
    let order = Vec::with_capacity(def_uses.num_nodes());
    let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()];

    // Order and visited are threaded through the argument / return pair of
    // sroa_dfs_helper for ownership reasons.
    let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited);
    order
}

fn sroa_dfs_helper(
    node: NodeID,
    def: NodeID,
    function: &Function,
    def_uses: &ImmutableDefUseMap,
    mut order: Vec<NodeID>,
    mut visited: BitVec<u8, Lsb0>,
) -> (Vec<NodeID>, BitVec<u8, Lsb0>) {
    if visited[node.idx()] {
        // If already visited, return early.
        (order, visited)
    } else {
        // Set visited to true.
        visited.set(node.idx(), true);

        // Before iterating users, push this node.
        order.push(node);
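
        // Decide whether to keep walking into this node's users, stop the
        // traversal here, or panic on an unexpected user of the product value.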
        match function.nodes[node.idx()] {
            Node::Phi {
                control: _,
                data: _,
            }
            | Node::Reduce {
                control: _,
                init: _,
                reduct: _,
            }
            | Node::Constant { id: _ }
            | Node::Ternary {
                op: _,
                first: _,
                second: _,
                third: _,
            } => {}
            Node::Read {
                collect,
                indices: _,
            } => {
                assert_eq!(def, collect);
                return (order, visited);
            }
            Node::Write {
                collect,
                data,
                indices: _,
            } => {
                if def == data {
                    return (order, visited);
                }
                assert_eq!(def, collect);
            }
            _ => panic!("PANIC: Invalid node using a constant product found during SROA."),
        }

        // Iterate over users, if we shouldn't stop here.
        for user in def_uses.get_users(node) {
            (order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited);
        }

        (order, visited)
    }
}