Skip to content
Snippets Groups Projects
pred.rs 13.29 KiB
use std::cmp::{max, min};
use std::collections::{BTreeMap, BTreeSet};

use itertools::Itertools;

use hercules_ir::*;

use crate::*;

/*
 * Top level function to run predication on a function. Repeatedly looks for
 * acyclic control flow that can be converted into dataflow.
 */
pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
    // Remove branches iteratively, since predicating an inside branch may cause
    // an outside branch to be available for predication.
    let mut bad_branches = BTreeSet::new();
    loop {
        // First, look for a branch whose projections all point to the same
        // region. These are branches with no internal control flow.
        let nodes = &editor.func().nodes;
        let Some((region, branch, false_proj, true_proj)) = editor
            .node_ids()
            .filter_map(|id| {
                if let Node::Region { ref preds } = nodes[id.idx()] {
                    // Look for two projections with the same branch.
                    let preds = preds.into_iter().filter_map(|id| {
                        nodes[id.idx()]
                            .try_control_proj()
                            .map(|(branch, selection)| (*id, branch, selection))
                    });
                    // Index projections by if branch.
                    let mut pred_map: BTreeMap<NodeID, Vec<(NodeID, usize)>> = BTreeMap::new();
                    for (proj, branch, selection) in preds {
                        if nodes[branch.idx()].is_if() && !bad_branches.contains(&branch) {
                            pred_map.entry(branch).or_default().push((proj, selection));
                        }
                    }
                    // Look for an if branch with two projections going into the
                    // same region.
                    for (branch, projs) in pred_map {
                        if projs.len() == 2 {
                            let way = projs[0].1;
                            assert_ne!(way, projs[1].1);
                            return Some((id, branch, projs[way].0, projs[1 - way].0));
                        }
                    }
                }
                None
            })
            .next()
        else {
            break;
        };
        let Node::Region { preds } = nodes[region.idx()].clone() else {
            panic!()
        };
        let Node::If {
            control: if_pred,
            cond,
        } = nodes[branch.idx()]
        else {
            panic!()
        };
        let phis: Vec<_> = editor
            .get_users(region)
            .filter(|id| nodes[id.idx()].is_phi())
            .collect();
        // Don't predicate branches where one of the phis is a collection.
        // Predicating this branch would result in a clone and probably woudln't
        // result in good vector code.
        if phis
            .iter()
            .any(|id| !editor.get_type(typing[id.idx()]).is_primitive())
        {
            bad_branches.insert(branch);
            continue;
        }
        let false_pos = preds.iter().position(|id| *id == false_proj).unwrap();
        let true_pos = preds.iter().position(|id| *id == true_proj).unwrap();

        // Second, make all the modifications:
        // - Add the select nodes.
        // - Replace uses in phis with the select node.
        // - Remove the branch from the control flow.
        // This leaves the old region in place - if the region has one
        // predecessor, it may be removed by CCP.
        let success = editor.edit(|mut edit| {
            // Replace the branch projection predecessors of the region with the
            // predecessor of the branch.
            let Node::Region { preds } = edit.get_node(region).clone() else {
                panic!()
            };
            let mut preds = Vec::from(preds);
            preds.remove(max(true_pos, false_pos));
            preds.remove(min(true_pos, false_pos));
            preds.push(if_pred);
            let new_region = edit.add_node(Node::Region {
                preds: preds.into_boxed_slice(),
            });
            edit = edit.replace_all_uses(region, new_region)?;

            // Replace the corresponding inputs in the phi nodes with select
            // nodes selecting over the old inputs to the phis.
            for phi in phis {
                let Node::Phi { control: _, data } = edit.get_node(phi).clone() else {
                    panic!()
                };
                let mut data = Vec::from(data);
                let select = edit.add_node(Node::Ternary {
                    op: TernaryOperator::Select,
                    first: cond,
                    second: data[true_pos],
                    third: data[false_pos],
                });
                data.remove(max(true_pos, false_pos));
                data.remove(min(true_pos, false_pos));
                data.push(select);
                let new_phi = edit.add_node(Node::Phi {
                    control: new_region,
                    data: data.into_boxed_slice(),
                });
                edit = edit.replace_all_uses(phi, new_phi)?;
                edit = edit.delete_node(phi)?;
            }

            // Delete the old control nodes.
            edit = edit.delete_node(region)?;
            edit = edit.delete_node(branch)?;
            edit = edit.delete_node(false_proj)?;
            edit = edit.delete_node(true_proj)?;

            Ok(edit)
        });
        if !success {
            bad_branches.insert(branch);
        }
    }

    // Do a quick and dirty rewrite to convert select(a, b, false) to a && b and
    // select(a, b, true) to a || b.
    for id in editor.node_ids() {
        let nodes = &editor.func().nodes;
        if let Node::Ternary {
            op: TernaryOperator::Select,
            first,
            second,
            third,
        } = nodes[id.idx()]
        {
            if let Some(cons) = nodes[second.idx()].try_constant()
                && editor.get_constant(cons).is_false()
            {
                editor.edit(|mut edit| {
                    let inv = edit.add_node(Node::Unary {
                        op: UnaryOperator::Not,
                        input: first,
                    });
                    let node = edit.add_node(Node::Binary {
                        op: BinaryOperator::And,
                        left: inv,
                        right: third,
                    });
                    edit = edit.replace_all_uses(id, node)?;
                    edit.delete_node(id)
                });
            } else if let Some(cons) = nodes[third.idx()].try_constant()
                && editor.get_constant(cons).is_false()
            {
                editor.edit(|mut edit| {
                    let node = edit.add_node(Node::Binary {
                        op: BinaryOperator::And,
                        left: first,
                        right: second,
                    });
                    edit = edit.replace_all_uses(id, node)?;
                    edit.delete_node(id)
                });
            } else if let Some(cons) = nodes[second.idx()].try_constant()
                && editor.get_constant(cons).is_true()
            {
                editor.edit(|mut edit| {
                    let node = edit.add_node(Node::Binary {
                        op: BinaryOperator::Or,
                        left: first,
                        right: third,
                    });
                    edit = edit.replace_all_uses(id, node)?;
                    edit.delete_node(id)
                });
            } else if let Some(cons) = nodes[third.idx()].try_constant()
                && editor.get_constant(cons).is_true()
            {
                editor.edit(|mut edit| {
                    let inv = edit.add_node(Node::Unary {
                        op: UnaryOperator::Not,
                        input: first,
                    });
                    let node = edit.add_node(Node::Binary {
                        op: BinaryOperator::Or,
                        left: inv,
                        right: second,
                    });
                    edit = edit.replace_all_uses(id, node)?;
                    edit.delete_node(id)
                });
            }
        }
    }
}

/*
 * Top level function to run write predication on a function. Repeatedly looks
 * for phi nodes where every data input is a write node and each write node has
 * matching types of indices. These writes are coalesced into a single write
 * using the phi as the `collect` input, and the phi selects over the old
 * `collect` inputs to the writes. New phis are added to select over the `data`
 * inputs and indices.
 */
pub fn write_predication(editor: &mut FunctionEditor) {
    let mut bad_phis = BTreeSet::new();
    loop {
        // First, look for phis where every input is a write with the same
        // indexing structure.
        let nodes = &editor.func().nodes;
        let Some((phi, control, writes)) = editor
            .node_ids()
            .filter_map(|id| {
                if let Node::Phi {
                control,
                ref data,
            } = nodes[id.idx()]
                && !bad_phis.contains(&id)
                // Check that every input is a write - if this weren't true,
                // we'd have to insert dummy writes that write something that
                // was just read from the array. We could handle this case, but
                // it's probably not needed for now. Also check that the phi is
                // the only user of the write.
                && data.into_iter().all(|id| nodes[id.idx()].is_write() && editor.get_users(*id).count() == 1)
                // Check that every write input has equivalent indexing
                // structure.
                && data
                    .into_iter()
                    .filter_map(|id| nodes[id.idx()].try_write())
                    .tuple_windows()
                    .all(|(w1, w2)| indices_structurally_equivalent(w1.2, w2.2))
                {
                    Some((id, control, data.clone()))
                } else {
                    None
                }
            })
            .next()
        else {
            break;
        };
        let (collects, datas, indices): (Vec<_>, Vec<_>, Vec<_>) = writes
            .iter()
            .filter_map(|id| nodes[id.idx()].try_write())
            .map(|(collect, data, indices)| (collect, data, indices.to_owned()))
            .multiunzip();

        // Second, make all the modifications:
        // - Replace the old phi with a phi selecting over the `collect` inputs
        //   of the old write inputs.
        // - Add phis for the `data` and `indices` inputs to the old writes.
        // - Add a write that uses the phi-ed data and indices.
        // Be a little careful over how the old phi and writes get replaced,
        // since the old phi itself may be used by the old writes.
        let success = editor.edit(|mut edit| {
            // Create all the phis.
            let collect_phi = edit.add_node(Node::Phi {
                control,
                data: collects.into_boxed_slice(),
            });
            let data_phi = edit.add_node(Node::Phi {
                control,
                data: datas.into_boxed_slice(),
            });
            let mut phied_indices = vec![];
            for index in 0..indices[0].len() {
                match indices[0][index] {
                    // For field and variant indices, the index is the same
                    // across all writes, so just take the one from the first
                    // set of indices.
                    Index::Position(ref old_pos) => {
                        let mut pos = vec![];
                        for pos_idx in 0..old_pos.len() {
                            // This code is kind of messy due to three layers of
                            // arrays. Basically, we are collecting every
                            // indexing node on indices across each write.
                            pos.push(
                                edit.add_node(Node::Phi {
                                    control,
                                    data: indices
                                        .iter()
                                        .map(|indices| {
                                            indices[index].try_position().unwrap()[pos_idx]
                                        })
                                        .collect(),
                                }),
                            );
                        }
                        phied_indices.push(Index::Position(pos.into_boxed_slice()));
                    }
                    _ => {
                        phied_indices.push(indices[0][index].clone());
                    }
                }
            }

            // Create the write.
            let new_write = edit.add_node(Node::Write {
                collect: collect_phi,
                data: data_phi,
                indices: phied_indices.into_boxed_slice(),
            });

            // Replace the old phi with the new write.
            edit = edit.replace_all_uses(phi, new_write)?;

            // Delete the old phi and writes.
            edit = edit.delete_node(phi)?;
            for write in writes {
                edit = edit.delete_node(write)?;
            }

            Ok(edit)
        });
        if !success {
            bad_phis.insert(phi);
        }
    }
}