Skip to content
Snippets Groups Projects
outline.rs 23.27 KiB
extern crate hercules_ir;

use std::collections::{BTreeMap, BTreeSet};
use std::iter::zip;
use std::sync::atomic::{AtomicUsize, Ordering};

use self::hercules_ir::def_use::*;
use self::hercules_ir::dom::*;
use self::hercules_ir::gcm::*;
use self::hercules_ir::ir::*;
use self::hercules_ir::subgraph::*;

use crate::*;

static NUM_OUTLINES: AtomicUsize = AtomicUsize::new(0);

/*
 * Top level function to outline a subset of nodes from a function.
 *
 * Moves every node in `partition` out of the function being edited by `editor`
 * into a freshly built `Function`, and replaces the partition in the original
 * function with a call to `to_be_function_id` (the ID the caller intends to
 * give the returned function). Data values flowing into the partition become
 * parameters of the outlined function; data values defined in the partition
 * and used outside of it are returned in a product and read back out of the
 * call's result in the original function.
 *
 * - `typing` is indexed by node ID and gives each node's type.
 * - `control_subgraph` supplies the control predecessors / successors used to
 *   find the partition's entry point and exit points.
 * - `dom` is passed to `does_data_dom_control` to decide whether a value is
 *   available on a given control path.
 *
 * Returns the outlined function when the edit closure runs to completion, and
 * None otherwise (e.g. when one of the `?`-propagated edit operations fails).
 *
 * Panics if the partition contains the start node, a parameter, or a return
 * node; if it has zero or several entry points; if it has no exit points; or if
 * a control node in it has several successors and one of them lies outside the
 * partition.
 *
 * TODO: Partitions with multiple entry or exit points should really accept /
 * return summations over the needed items given the incoming / outgoing control
 * paths, rather than accepting / returning every possibly needed value and
 * using undefs. Re-visit when summations are implemented in codegen.
 */
pub fn outline(
    editor: &mut FunctionEditor,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    dom: &DomTree,
    partition: &BTreeSet<NodeID>,
    to_be_function_id: FunctionID,
) -> Option<Function> {
    // Step 1: do a whole bunch of analysis on the partition.
    let nodes = &editor.func().nodes;
    assert!(
        !partition
            .iter()
            .any(|id| nodes[id.idx()].is_start() || nodes[id.idx()].is_parameter() || nodes[id.idx()].is_return()),
        "PANIC: Can't outline a partition containing the start node, parameter nodes, or return nodes."
    );
    let mut top_nodes = partition.iter().filter(|id| {
        nodes[id.idx()].is_control()
            && control_subgraph
                .preds(**id)
                .filter(|pred_id| !partition.contains(pred_id))
                .count()
                > 0
    });

    // There should be exactly one top node. Note that a top node has a
    // predecessor in a different partition, but may also have predecessors
    // from its own partition.
    let top_node = *top_nodes.next().expect(
        "PANIC: Can't outline a partition with no entry points (has to contain at least one control node).",
    );
    assert!(
        top_nodes.next().is_none(),
        "PANIC: Can't outline a partition with multiple entry points."
    );

    // Figure out the nodes that will need to be passed in as parameters.
    let mut param_idx_to_outside_id = vec![];
    let mut found_outside_ids: BTreeSet<NodeID> = BTreeSet::new();
    for id in partition.iter() {
        // Data uses of other nodes get assigned a parameter index in the
        // outlined function's signature.
        for use_id in get_uses(&nodes[id.idx()]).as_ref() {
            if !nodes[use_id.idx()].is_control()
                && !partition.contains(use_id)
                && !found_outside_ids.contains(use_id)
            {
                param_idx_to_outside_id.push(*use_id);
                found_outside_ids.insert(*use_id);
            }
        }
    }
    let outside_id_to_param_idx: BTreeMap<_, _> = param_idx_to_outside_id
        .iter()
        .enumerate()
        .map(|(idx, id)| (*id, idx))
        .collect();

    // If there are multiple predecessors, then any phis depending on the top
    // node need to know which of their inputs to select. Explicitly take a
    // parameter indicating which control predecessor jumped to this partition.
    // Note, this is the control ID of the exit point of the previous partition,
    // not just the previous partition - partition A may have multiple exit
    // points that jump to partition B.
    let top_node_outside_preds: Box<[_]> = control_subgraph
        .preds(top_node)
        .filter(|pred_id| !partition.contains(pred_id))
        .collect();
    let callee_pred_param_idx = if top_node_outside_preds.len() > 1 {
        Some(param_idx_to_outside_id.len())
    } else {
        None
    };
    let outside_id_and_top_pred_to_does_dom: BTreeMap<_, _> = param_idx_to_outside_id
        .iter()
        .map(|outside_id| {
            let func = editor.func();
            top_node_outside_preds.iter().map(move |pred_id| {
                (
                    (*outside_id, *pred_id),
                    does_data_dom_control(func, *outside_id, *pred_id, dom),
                )
            })
        })
        .flatten()
        .collect();

    // Figure out the nodes that will need to be returned.
    let mut return_idx_to_inside_id = vec![];
    for id in partition.iter() {
        if nodes[id.idx()].is_control() {
            continue;
        }

        let users = editor.get_users(*id);
        for user_id in users {
            if !partition.contains(&user_id) {
                return_idx_to_inside_id.push(*id);
                break;
            }
        }
    }

    // Figure out the partition successors, and the exit points.
    assert!(
        !partition.iter().any(|id| nodes[id.idx()].is_control()
            && control_subgraph
                .succs(*id)
                .filter(|succ_id| !partition.contains(succ_id))
                .count()
                > 0 && control_subgraph
                .succs(*id)
                .count()
                > 1),
        "PANIC: Can't outline a partition where a single node has multiple successors and one is in another partition. Please include projection nodes in the same partition as its accompanying node."
    );
    let exit_points: Vec<_> = partition
        .iter()
        .filter(|id| nodes[id.idx()].is_control())
        .filter_map(|id| {
            Some((
                *id,
                control_subgraph
                    .succs(*id)
                    .filter(|succ_id| !partition.contains(succ_id))
                    .next()?,
            ))
        })
        .collect();
    assert!(
        exit_points.len() > 0,
        "PANIC: Can't outline partition with no exit points (how did you even get here?)."
    );

    // For each exit point, compute whether each data output dominates that
    // exit point. If not, then that return point should return an Undef
    // instead of the value.
    let exit_point_dom_return_values: Vec<BTreeSet<_>> = exit_points
        .iter()
        .map(|(control, _)| {
            return_idx_to_inside_id
                .iter()
                .map(|id| *id)
                .filter(|data| does_data_dom_control(editor.func(), *data, *control, dom))
                .collect()
        })
        .collect();

    // If there are multiple successors, then the caller needs to know which
    // successor to jump to after calling the outlined callee. Explicitly return
    // the exit point control ID reached for this purpose.
    let callee_succ_return_idx = if exit_points.len() > 1 {
        Some(return_idx_to_inside_id.len())
    } else {
        None
    };

    // Do the remaining work inside an edit.
    let mut outlined_done = None;
    editor.edit(|mut edit| {
        // Step 2: assemble the outlined function.
        let u32_ty = edit.add_type(Type::UnsignedInteger32);
        let mut outlined = Function {
            name: format!(
                "{}_{}",
                edit.get_name(),
                NUM_OUTLINES.fetch_add(1, Ordering::Relaxed)
            ),
            param_types: param_idx_to_outside_id
                .iter()
                .map(|id| typing[id.idx()])
                .chain(callee_pred_param_idx.map(|_| u32_ty))
                .collect(),
            return_type: edit.add_type(Type::Product(
                return_idx_to_inside_id
                    .iter()
                    .map(|id| typing[id.idx()])
                    .chain(callee_succ_return_idx.map(|_| u32_ty))
                    .collect(),
            )),
            nodes: vec![],
            num_dynamic_constants: edit.get_num_dynamic_constant_params(),
            entry: false,
        };

        // Re-number nodes in the partition.
        let mut old_to_new_id = BTreeMap::new();
        // Reserve ID 0 for the start node of the outlined function, and N IDs
        // for the parameters of the outlined function.
        let mut new_id_counter = NodeID::new(outlined.param_types.len() + 1);
        let mut new_id = || {
            let ret = new_id_counter;
            new_id_counter = NodeID::new(new_id_counter.idx() + 1);
            ret
        };
        for id in partition.iter() {
            old_to_new_id.entry(*id).or_insert(new_id());
        }

        // Create the partition nodes in the outlined function.

        // Add the start and parameter nodes.
        outlined.nodes.push(Node::Start);
        for index in 0..outlined.param_types.len() {
            outlined.nodes.push(Node::Parameter { index });
        }

        // Add the nodes from the partition.
        let mut select_top_phi_inputs = vec![];
        for id in partition.iter() {
            let convert_id = |old| {
                if let Some(new) = old_to_new_id.get(&old) {
                    // Map a use inside the partition to its new ID in the
                    // outlined function.
                    *new
                } else if let Some(idx) = outside_id_to_param_idx.get(&old) {
                    // Map a data use outside the partition to the ID of the
                    // corresponding parameter node.
                    NodeID::new(idx + 1)
                } else {
                    // Map a control use outside the partition to the start
                    // node. This corresponds to the outside predecessors of the
                    // top node and the use of the old start by constant,
                    // dynamic constant, parameter, and undef nodes.
                    NodeID::new(0)
                }
            };

            let mut node = edit.get_node(*id).clone();
            if let Node::Phi { control, data } = &mut node
                && *control == top_node
            {
                for datum in data.iter_mut() {
                    *datum = convert_id(*datum);
                }

                // If this node is a phi on the top node, we need to replace the
                // inputs corresponding to outside partition predecessors with a
                // single data input (with the corresponding control input of
                // the top node being the new start node) that explicitly
                // selects over the possible outside data inputs.
                let mut outside_pred_per_data = vec![];
                for (pred, data) in zip(edit.get_node(*control).try_region().unwrap(), data.iter())
                {
                    if !partition.contains(pred) {
                        outside_pred_per_data.push(Some((*pred, *data)));
                    } else {
                        outside_pred_per_data.push(None);
                    }
                }

                *control = convert_id(*control);
                let total_select = outside_pred_per_data
                    .iter()
                    .filter_map(|x| *x)
                    .reduce(|(_, acc), (pred, data)| {
                        // Select between the accumulated select or this
                        // particular data input if the predecessor parameter is
                        // equal to the predecessor corresponding to this data
                        // input in the original phi.
                        let constant_id = new_id();
                        let cmp_id = new_id();
                        let select_id = new_id();
                        let param_id = NodeID::new(callee_pred_param_idx.unwrap() + 1);
                        let cons_id =
                            edit.add_constant(Constant::UnsignedInteger32(pred.idx() as u32));
                        let constant = Node::Constant { id: cons_id };
                        let cmp = Node::Binary {
                            op: BinaryOperator::EQ,
                            left: param_id,
                            right: constant_id,
                        };
                        let select = Node::Ternary {
                            op: TernaryOperator::Select,
                            first: cmp_id,
                            second: data,
                            third: acc,
                        };
                        select_top_phi_inputs.push(constant);
                        select_top_phi_inputs.push(cmp);
                        select_top_phi_inputs.push(select);
                        (NodeID::new(0), select_id)
                    })
                    .map(|(_, acc)| acc);

                // Actually edit the data inputs of the phi node.
                let mut found_outside = false;
                *data = zip(
                    data.into_iter(),
                    outside_pred_per_data
                        .into_iter()
                        .map(|outside| outside.is_some()),
                )
                .filter_map(|(data, outside)| {
                    if outside {
                        // Only the first outside data input changes to the
                        // select node - the rest get dropped.
                        if found_outside {
                            None
                        } else {
                            found_outside = true;
                            Some(total_select.unwrap())
                        }
                    } else {
                        Some(*data)
                    }
                })
                .collect();
            } else if *id == top_node
                && let Node::Region { preds } = &mut node
            {
                // If this node is the top node and is a region, then only the
                // first predecessor that is outside the partition gets mapped
                // to the new predecessor (the new start node). All of the other
                // predecessors get removed from the predecessor list.
                let mut found_outside = false;
                *preds = preds
                    .into_iter()
                    .map(|u| convert_id(*u))
                    .filter(|pred_id| {
                        let include = !found_outside || *pred_id != NodeID::new(0);
                        found_outside = found_outside || *pred_id == NodeID::new(0);
                        include
                    })
                    .collect();
            } else {
                // Otherwise, just convert all the used IDs.
                for u in get_uses_mut(&mut node).as_mut() {
                    **u = convert_id(**u);
                }
            }
            assert_eq!(outlined.nodes.len(), old_to_new_id[id].idx());
            outlined.nodes.push(node);
        }
        outlined.nodes.extend(select_top_phi_inputs);

        // Add the return nodes.
        let cons_id = edit.add_zero_constant(outlined.return_type);
        for ((exit, _), dom_return_values) in
            zip(exit_points.iter(), exit_point_dom_return_values.iter())
        {
            // Get the IDs of the nodes we're going to return. This isn't just
            // the set of returned IDs, since some IDs may not dominate some
            // exit points - these get mapped to undefs. Additionally, when
            // there are multiple exit points, return which exit was taken.
            let mut data_ids: Vec<_> = return_idx_to_inside_id
                .iter()
                .map(|inside_id| {
                    if dom_return_values.contains(inside_id) {
                        // If this outside used value dominates this return,
                        // then return the value itself. The value lives in the
                        // outlined function now, so use its re-numbered ID
                        // rather than its ID in the original function.
                        old_to_new_id[inside_id]
                    } else {
                        // If not, then return an Undef. Since this value
                        // doesn't dominate this return, it can't be used on
                        // this path, so returning Undef is fine. All the
                        // Undefs will get GVNed together, so make a bunch.
                        let undef_id = NodeID::new(outlined.nodes.len());
                        outlined.nodes.push(Node::Undef {
                            ty: typing[inside_id.idx()],
                        });
                        undef_id
                    }
                })
                .collect();
            if callee_succ_return_idx.is_some() {
                // The indicator value is the exit point's ID in the ORIGINAL
                // function; the caller-side if tree compares against the same
                // original IDs.
                let cons_id = edit.add_constant(Constant::UnsignedInteger32(exit.idx() as u32));
                let cons_node_id = NodeID::new(outlined.nodes.len());
                outlined.nodes.push(Node::Constant { id: cons_id });
                data_ids.push(cons_node_id);
            }

            // Build the return product.
            let mut construct_id = NodeID::new(outlined.nodes.len());
            outlined.nodes.push(Node::Constant { id: cons_id });
            for (idx, data) in data_ids.into_iter().enumerate() {
                let write = Node::Write {
                    collect: construct_id,
                    data: data,
                    indices: Box::new([Index::Field(idx)]),
                };
                construct_id = NodeID::new(outlined.nodes.len());
                outlined.nodes.push(write);
            }

            // Return the return product. The return's control input is the
            // exit point's re-numbered ID inside the outlined function.
            outlined.nodes.push(Node::Return {
                control: old_to_new_id[exit],
                data: construct_id,
            });
        }

        // Step 3: edit the original function to call the outlined function.
        let dynamic_constants = (0..edit.get_num_dynamic_constant_params())
            .map(|idx| edit.add_dynamic_constant(DynamicConstant::Parameter(idx as usize)))
            .collect();
        // The top node can be a...
        // - Region (multiple predecessors): the original region gets replaced
        //   with a new region with the same predecessors. The inputs to the
        //   call, if not dominating the region, need to be passed to the call
        //   via phis.
        // - Region (single predecessor), If, Match, or Fork: the original top
        //   node gets replaced with a new region with a single predecessor. The
        //   call is generated in a straightforward manner.
        let (new_region_id, call_id) = if top_node_outside_preds.len() > 1 {
            let new_region = Node::Region {
                preds: top_node_outside_preds.clone(),
            };
            let new_region_id = edit.add_node(new_region);
            let mut args = vec![];

            // Add a phi per argument - selects between the input or undef when
            // defined or not on a specific control path.
            for outside_id in param_idx_to_outside_id.iter() {
                let undef = Node::Undef {
                    ty: typing[outside_id.idx()],
                };
                let undef_id = edit.add_node(undef);
                let phi = Node::Phi {
                    control: new_region_id,
                    data: top_node_outside_preds
                        .iter()
                        .map(|pred_id| {
                            if outside_id_and_top_pred_to_does_dom[&(*outside_id, *pred_id)] {
                                *outside_id
                            } else {
                                undef_id
                            }
                        })
                        .collect(),
                };
                args.push(edit.add_node(phi));
            }

            // Add a phi for the control predecessor indicator.
            let control_id_constants = top_node_outside_preds.iter().map(|pred_id| {
                let id = edit.add_constant(Constant::UnsignedInteger32(pred_id.idx() as u32));
                edit.add_node(Node::Constant { id })
            });
            let indicator_phi = Node::Phi {
                control: new_region_id,
                data: control_id_constants.collect(),
            };
            args.push(edit.add_node(indicator_phi));

            let call = Node::Call {
                control: new_region_id,
                function: to_be_function_id,
                dynamic_constants,
                args: args.into_boxed_slice(),
            };
            let call_id = edit.add_node(call);
            (new_region_id, call_id)
        } else {
            let new_region = Node::Region {
                preds: top_node_outside_preds,
            };
            let new_region_id = edit.add_node(new_region);
            let call = Node::Call {
                control: new_region_id,
                function: to_be_function_id,
                dynamic_constants,
                args: param_idx_to_outside_id.clone().into_boxed_slice(),
            };
            let call_id = edit.add_node(call);
            (new_region_id, call_id)
        };

        // Create the read nodes from the call node to get the outputs of the
        // outlined function.
        let output_reads: Vec<_> = (0..return_idx_to_inside_id.len())
            .map(|idx| {
                let read = Node::Read {
                    collect: call_id,
                    indices: Box::new([Index::Field(idx)]),
                };
                edit.add_node(read)
            })
            .collect();
        let indicator_read = callee_succ_return_idx.map(|idx| {
            let read = Node::Read {
                collect: call_id,
                indices: Box::new([Index::Field(idx)]),
            };
            edit.add_node(read)
        });
        for (old_id, new_id) in zip(return_idx_to_inside_id.iter(), output_reads.iter()) {
            edit = edit.replace_all_uses(*old_id, *new_id)?;
        }

        // The partition may have multiple exit points - in this case, build an
        // (imbalanced) if tree to branch to the correct successor control node.
        // The control successor to branch to is returned by the callee.
        let mut if_tree_acc = new_region_id;
        for idx in 0..exit_points.len() - 1 {
            let (indicator_id, _) = exit_points[idx];
            let indicator_cons_id =
                edit.add_constant(Constant::UnsignedInteger32(indicator_id.idx() as u32));
            let indicator_cons_node_id = edit.add_node(Node::Constant {
                id: indicator_cons_id,
            });
            let cmp_id = edit.add_node(Node::Binary {
                op: BinaryOperator::EQ,
                left: indicator_read.unwrap(),
                right: indicator_cons_node_id,
            });
            let if_id = edit.add_node(Node::If {
                control: if_tree_acc,
                cond: cmp_id,
            });
            let false_id = edit.add_node(Node::Projection {
                control: if_id,
                selection: 0,
            });
            let true_id = edit.add_node(Node::Projection {
                control: if_id,
                selection: 1,
            });
            edit = edit.replace_all_uses(indicator_id, true_id)?;
            if_tree_acc = false_id;
        }
        edit = edit.replace_all_uses(exit_points.last().unwrap().0, if_tree_acc)?;

        // Delete all the nodes that were outlined.
        for id in partition.iter() {
            edit = edit.delete_node(*id)?;
        }

        outlined_done = Some(outlined);
        Ok(edit)
    });

    outlined_done
}

/*
 * Just outlines all of a function accept the entry and return. Minimum work
 * needed to cause runtime Rust code to be generated as necessary.
 */
pub fn dumb_outline(
    editor: &mut FunctionEditor,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    dom: &DomTree,
    to_be_function_id: FunctionID,
) -> Option<Function> {
    collapse_returns(editor);
    let partition = editor
        .node_ids()
        .filter(|id| {
            let node = &editor.func().nodes[id.idx()];
            !(node.is_start() || node.is_parameter() || node.is_return())
        })
        .collect();
    outline(
        editor,
        typing,
        control_subgraph,
        dom,
        &partition,
        to_be_function_id,
    )
}