Skip to content
Snippets Groups Projects
fork_guard_elim.rs 11.76 KiB
use std::collections::{HashMap, HashSet};
use std::ops::Deref;

use hercules_ir::*;

use crate::*;

/*
 * This is a Hercules IR transformation that eliminates guards (directly)
 * surrounding fork-joins. A guard is eliminated when:
 * - The guard's condition is of the form 0 < n for a fork with replication
 *   factor n (a factor of the form max { 1, n } also qualifies, and is
 *   simplified to n once the guard is removed), and
 * - Every phi node on the region that joins control back between the guard's
 *   two branches merges a reduce node of the fork-join with that reduce
 *   node's own initial input.
 *
 * This optimization is useful with code generated by the Juno frontend, which
 * generates guarded loops. These loops are eventually converted into forks,
 * but the guard remains even though it is no longer needed.
 */

// A fork replication factor, simplified through max where possible: a max
// whose terms, once the constant 1s are dropped, leave exactly one term x is
// represented by that term x.
enum Factor {
    // The factor was a max that simplified to a single term. Holds the index
    // of the factor in the fork's list of factors and the ID of that term.
    Max(usize, DynamicConstantID),
    // Any other factor. Holds the factor's own dynamic constant ID.
    Normal(DynamicConstantID),
}

impl Factor {
    // Returns the dynamic constant ID carried by this factor, whichever
    // variant it is. For a Max this is the simplified term, not the ID of the
    // original max.
    fn get_id(&self) -> DynamicConstantID {
        match *self {
            Factor::Max(_, id) | Factor::Normal(id) => id,
        }
    }
}

// Describes one fork whose surrounding guard can be eliminated, together with
// the nodes that make up that guard.
struct GuardedFork {
    // The fork node itself
    fork: NodeID,
    // The join that the fork-join map pairs with the fork
    join: NodeID,
    // The if node of the guard
    guard_if: NodeID,
    // The projection of the guard's if that is the fork's control predecessor
    fork_taken_proj: NodeID,
    // The other projection of the guard's if, which bypasses the fork-join
    fork_skipped_proj: NodeID,
    // The region whose two predecessors are the join and fork_skipped_proj
    guard_join_region: NodeID,
    // Maps each phi node on guard_join_region to the reduce node that
    // replaces it
    phi_reduce_map: HashMap<NodeID, NodeID>,
    factor: Factor, // The factor that matches the guard
}

/* Given a node index, return None if the node is not a guarded fork where we
 * can eliminate the guard.
 * If the node is a fork with a guard we can eliminate, returns a GuardedFork
 * holding
 * - This node's NodeID and the NodeID of its join
 * - The ID of the if of the guard
 * - The IDs of the two projections of the if: the one that leads into the
 *   fork and the one that skips it
 * - The region that joins the guard's two branches
 * - A map of NodeIDs for the phi nodes on that region to the reduce they
 *   should be replaced with
 * - The replication factor that matched the guard's condition, which also
 *   records whether it is a max that can be eliminated.
 */
fn guarded_fork(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    node: NodeID,
) -> Option<GuardedFork> {
    let function = editor.func();

    // Identify fork nodes
    let Node::Fork { control, factors } = &function.nodes[node.idx()] else {
        return None;
    };

    // Simplify each factor through max. This iterator is lazy: factors are
    // only simplified as they are examined while matching the guard below.
    let mut factors = factors.iter().enumerate().map(|(idx, dc)| {
        let factor = editor.get_dynamic_constant(*dc);
        let DynamicConstant::Max(xs) = factor.deref() else {
            return Factor::Normal(*dc);
        };

        // Filter out any terms which are just 1s
        let non_ones = xs
            .iter()
            .filter(|i| {
                !matches!(
                    editor.get_dynamic_constant(**i).deref(),
                    DynamicConstant::Constant(1)
                )
            })
            .collect::<Vec<_>>();
        // If we're left with just one term x, we had max { 1, x }
        match non_ones.as_slice() {
            [x] => Factor::Max(idx, **x),
            _ => Factor::Normal(*dc),
        }
    });

    // Whose predecessor is a projection from an if
    let Node::ControlProjection {
        control: if_node,
        selection: branch_idx,
    } = function.nodes[control.idx()]
    else {
        return None;
    };
    let Node::If { cond, .. } = function.nodes[if_node.idx()] else {
        return None;
    };

    // Whose condition is appropriate
    let Node::Binary { left, right, op } = function.nodes[cond.idx()] else {
        return None;
    };

    // Each pattern is (operand that must be zero, operator, operand that must
    // be a factor of the fork).
    let patterns = match branch_idx {
        // branch_idx == 1 means the true branch so we want the condition to
        // be 0 < n or n > 0
        1 => [
            (left, BinaryOperator::LT, right),
            (right, BinaryOperator::GT, left),
        ],
        // branch_idx == 0 means the false branch so we want the condition to
        // be n < 0 or 0 > n
        0 => [
            (right, BinaryOperator::LT, left),
            (left, BinaryOperator::GT, right),
        ],
        _ => return None,
    };

    // A node counts as zero if it is a zero constant or a zero dynamic
    // constant.
    let is_zero = |id: &NodeID| {
        function.nodes[id.idx()].is_zero_constant(&editor.get_constants())
            || editor.node(id).is_zero_dc(&editor.get_dynamic_constants())
    };

    let factor = patterns
        .iter()
        .find_map(|(pattern_zero, pattern_op, pattern_factor)| {
            // Match Op, then match Zero
            if op != *pattern_op || !is_zero(pattern_zero) {
                return None;
            }

            // Match Factor. The two patterns use different operators, so at
            // most one of them gets this far and consumes the factor iterator.
            let pattern_node = &function.nodes[pattern_factor.idx()];
            factors.find(|candidate| {
                let candidate_id = candidate.get_id();
                if branch_idx == 1 {
                    // On the true branch the guard may compare against either
                    // a constant node or a dynamic constant node.
                    match (pattern_node, &*editor.get_dynamic_constant(candidate_id)) {
                        (Node::Constant { id }, DynamicConstant::Constant(v)) => {
                            let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id)
                            else {
                                return false;
                            };
                            pattern_v == (*v as u64)
                        }
                        (Node::DynamicConstant { id }, _) => *id == candidate_id,
                        _ => false,
                    }
                } else {
                    // On the false branch only dynamic constant nodes are
                    // matched.
                    pattern_node.try_dynamic_constant() == Some(candidate_id)
                }
            })
        })?;

    // Identify the join node and its users
    let join = *fork_join_map.get(&node)?;

    // Find the unique region using the join; if there isn't exactly one we
    // can't eliminate this guard
    let region_users = editor
        .get_users(join)
        .filter(|n| function.nodes[n.idx()].is_region())
        .collect::<Vec<_>>();
    let &[join_control] = region_users.as_slice() else {
        return None;
    };
    let Some(Node::Region { preds }) = function.nodes.get(join_control.idx()) else {
        return None;
    };

    // The region after the join can only have two predecessors (for the guard
    // and the fork-join)
    let other_pred = match &preds[..] {
        [other, pred] if *pred == join => *other,
        [pred, other] if *pred == join => *other,
        _ => return None,
    };
    // Other predecessor needs to be the other projection from the guard's if
    let Node::ControlProjection {
        control: other_if,
        selection: else_branch,
    } = function.nodes[other_pred.idx()]
    else {
        return None;
    };
    if else_branch == branch_idx || other_if != if_node {
        return None;
    }

    // Finally, identify the phi nodes associated with the region and match
    // them with the reduce nodes of the fork-join
    let reduce_nodes = editor
        .get_users(join)
        .filter(|n| function.nodes[n.idx()].is_reduce())
        .collect::<HashSet<_>>();

    // Given the inputs of a phi, returns the reduce node it can be replaced
    // with: the phi must have exactly two inputs, one a reduce of this
    // fork-join and the other that reduce's initial input.
    let matching_reduce = |data: &[NodeID]| -> Option<NodeID> {
        let &[first, second] = data else {
            return None;
        };
        let (reduce_node, init_input) = if reduce_nodes.contains(&first) {
            (first, second)
        } else if reduce_nodes.contains(&second) {
            (second, first)
        } else {
            return None;
        };
        let Node::Reduce { init, .. } = function.nodes[reduce_node.idx()] else {
            return None;
        };
        (init_input == init).then_some(reduce_node)
    };

    // Construct a map from phi nodes to their reduce node. If any of the phi
    // nodes do not have an associated reduce node, the collect yields None
    // and we cannot remove the loop guard.
    let phi_reduce_map = editor
        .get_users(join_control)
        .filter_map(|user| {
            let Node::Phi { ref data, .. } = function.nodes[user.idx()] else {
                return None;
            };
            Some(matching_reduce(&data[..]).map(|reduce| (user, reduce)))
        })
        .collect::<Option<HashMap<_, _>>>()?;

    Some(GuardedFork {
        fork: node,
        join,
        guard_if: if_node,
        fork_taken_proj: *control,
        fork_skipped_proj: other_pred,
        guard_join_region: join_control,
        phi_reduce_map,
        factor,
    })
}

/*
 * Top level function to run fork guard elimination, as described above.
 */
pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    let guard_info = editor
        .node_ids()
        .filter_map(|node| guarded_fork(editor, fork_join_map, node))
        .collect::<Vec<_>>();

    for GuardedFork {
        fork,
        join,
        fork_taken_proj,
        fork_skipped_proj,
        phi_reduce_map,
        factor,
        guard_if,
        guard_join_region,
    } in guard_info
    {
        let Some(guard_pred) = editor.get_uses(guard_if).next() else {
            unreachable!()
        };
        let new_fork_info = if let Factor::Max(idx, dc) = factor {
            let Node::Fork {
                control: _,
                mut factors,
            } = editor.func().nodes[fork.idx()].clone()
            else {
                unreachable!()
            };
            factors[idx] = dc;
            let new_fork = Node::Fork {
                control: guard_pred,
                factors,
            };
            Some(new_fork)
        } else {
            None
        };

        editor.edit(|mut edit| {
            edit =
                edit.replace_all_uses_where(fork_taken_proj, guard_pred, |usee| *usee == fork)?;
            edit = edit.delete_node(guard_if)?;
            edit = edit.delete_node(fork_taken_proj)?;
            edit = edit.delete_node(fork_skipped_proj)?;
            edit = edit.replace_all_uses(guard_join_region, join)?;
            edit = edit.delete_node(guard_join_region)?;
            // Delete region node

            for (phi, reduce) in phi_reduce_map.iter() {
                edit = edit.replace_all_uses(*phi, *reduce)?;
                edit = edit.delete_node(*phi)?;
            }

            if let Some(new_fork_info) = new_fork_info {
                let new_fork = edit.add_node(new_fork_info);
                edit = edit.replace_all_uses(fork, new_fork)?;
                edit = edit.delete_node(fork)?;
            }

            Ok(edit)
        });
    }
}