Skip to content
Snippets Groups Projects 18.72 KiB
use std::collections::HashMap;
use std::collections::HashSet;
use std::iter::zip;
use std::iter::FromIterator;

use itertools::Itertools;
use nestify::nest;

use hercules_ir::*;

use crate::*;

/*
 * TODO: Forkify currently makes a bunch of small edits - this needs to be
 * changed so that every loop that gets forkified corresponds to a single edit
 * + sub-edits. This would allow us to run forkify on a subset of a function.
 */
/// Attempts to forkify the natural loops of the function held by `editor`.
///
/// Candidate loops are drawn from `loops`, keeping only entries whose key node
/// is a region node. For a candidate whose header is mutable in the editor,
/// `forkify_loop` is invoked; `true` is returned after a successful call and
/// `false` is returned when no loop was forkified.
///
/// NOTE(review): this copy of the function is missing lines (the iterator
/// source on `loops`, the remaining `forkify_loop` arguments and the closing
/// delimiters) - confirm the details against the repository version.
pub fn forkify(
    editor: &mut FunctionEditor,
    control_subgraph: &Subgraph,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loops: &LoopTree,
) -> bool {
    // Keep only loops whose header (the tuple key) is a region node.
    let natural_loops = loops
        .filter(|(k, _)| editor.func().nodes[k.idx()].is_region());

    // Materialize the candidates up front; presumably so the borrow of
    // `editor` held by the filter closure ends before `editor` is mutated.
    let natural_loops: Vec<_> = natural_loops.collect();

    for l in natural_loops {
        // FIXME: Run on all bottom-level loops, as they can be independently optimized without recomputing analyses.
        // `l.0` is the loop header and `l.1` its control set (cloned into a `Loop`).
        if editor.is_mutable(l.0)
            && forkify_loop(
                &Loop {
                    header: l.0,
                    control: l.1.clone(),
            // Stop after the first loop that was forkified.
            return true;
    return false;

/**
 * Given a node used as a loop bound, return a dynamic constant ID.
 *
 * - A `Node::DynamicConstant` yields the ID it already carries.
 * - A `Node::Constant` holding a signed or unsigned 8/16/32/64-bit integer is
 *   converted into a `DynamicConstant::Constant` that is registered through an
 *   edit on `editor`.
 * - Any other constant kind or node kind yields an `Err` with a message.
 *
 * NOTE(review): this copy of the function is missing lines (closing
 * delimiters and the final result of the `Node::Constant` arm) - confirm
 * against the repository version.
 */
pub fn get_node_as_dc(
    editor: &mut FunctionEditor,
    node: NodeID,
) -> Result<DynamicConstantID, String> {
    // Check for a constant used as loop bound.
    match editor.node(node) {
        Node::DynamicConstant {
            id: dynamic_constant_id,
        } => Ok(*dynamic_constant_id),
        Node::Constant { id: constant_id } => {
            // `as _` casts to whatever integer type `DynamicConstant::Constant`
            // holds. NOTE(review): `as` never fails, so a negative signed bound
            // or an out-of-range value would be converted silently - confirm
            // this is acceptable for loop bounds.
            let dc = match *editor.get_constant(*constant_id) {
                Constant::Integer8(x) => DynamicConstant::Constant(x as _),
                Constant::Integer16(x) => DynamicConstant::Constant(x as _),
                Constant::Integer32(x) => DynamicConstant::Constant(x as _),
                Constant::Integer64(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger8(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger16(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger32(x) => DynamicConstant::Constant(x as _),
                Constant::UnsignedInteger64(x) => DynamicConstant::Constant(x as _),
                _ => return Err("Invalid constant as loop bound".to_string()),

            // Placeholder ID; overwritten inside the edit closure below.
            let mut b = DynamicConstantID::new(0);
            editor.edit(|mut edit| {
                b = edit.add_dynamic_constant(dc);
            // Return the ID of the dynamic constant that is generated from the constant
            // or dynamic constant that is the existing loop bound
        // NOTE(review): "Blah" is a placeholder message; a descriptive error
        // (e.g. naming the unsupported node kind) would help callers.
        _ => Err("Blah".to_owned()),

/*
 * Top level function to convert natural loops with simple induction variables
 * into fork-joins.
 */
/// Attempts to rewrite the single natural loop `l` as a fork-join.
///
/// The function first runs a series of checks (exit condition shape, presence
/// of a canonical induction variable with a bound convertible to a dynamic
/// constant, a single loop predecessor, reductionable phis, no disallowed
/// users of the induction variable) and returns `false` as soon as one fails.
/// If all checks pass it adds fork/join nodes, replaces the induction variable
/// with a thread ID, turns the remaining phis into reduces and deletes the old
/// loop control nodes; the result of that final edit is returned.
///
/// `_fork_join_map` is unused.
///
/// NOTE(review): this copy of the function is missing many lines (iterator
/// sources, closing delimiters, some bindings such as `final_value`,
/// `condition_node`, `phi` and `continue_latch`) - confirm every detail
/// against the repository version before relying on these comments.
pub fn forkify_loop(
    editor: &mut FunctionEditor,
    control_subgraph: &Subgraph,
    _fork_join_map: &HashMap<NodeID, NodeID>,
    l: &Loop,
) -> bool {
    let function = editor.func();

    // Bail out when no loop exit condition can be determined.
    let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else {
        return false;

    // Only conditional exits (driven by an `if` node) are handled.
    let LoopExit::Conditional {
        if_node: loop_if,
    } = loop_condition.clone()
    else {
        return false;

    // Compute loop variance
    let loop_variance = compute_loop_variance(editor, l);
    let ivs = compute_induction_vars(editor.func(), l, &loop_variance);
    let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition);
    let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else {
        return false;

    // Get bound
    // Only basic induction variables are supported; the final value (when
    // present) must be convertible to a dynamic constant, otherwise `bound`
    // is `None` and the loop is rejected below.
    let bound = match canonical_iv {
        InductionVariable::Basic {
            node: _,
            initializer: _,
        } => final_value
            .map(|final_value| get_node_as_dc(editor, final_value))
            .and_then(|r| r.ok()),
        InductionVariable::SCEV(_) => return false,

    let Some(bound_dc_id) = bound else {
        return false;

    let function = editor.func();

    // Check if it is do-while loop.
    // The exit projection is selected among nodes outside the loop's control
    // set, the continue projection among nodes inside it.
    let loop_exit_projection = editor
        .filter(|id| !l.control[id.idx()])

    let loop_continue_projection = editor
        .filter(|id| l.control[id.idx()])

    // Predecessors that are not part of the loop's control set.
    let loop_preds: Vec<_> = editor
        .filter(|id| !l.control[id.idx()])

    // FIXME: @xrouth
    // Exactly one predecessor from outside the loop is required.
    if loop_preds.len() != 1 {
        return false;

    let loop_pred = loop_preds[0];

    if !editor
        return false;

    // Get all phis used outside of the loop, they need to be reductionable.
    // For now just assume all phis will be phis used outside of the loop, except for the canonical iv.
    // FIXME: We need a different definition of `loop_nodes` to check for phis used outside the loop than the one
    // we currently have.
    let loop_nodes = calculate_loop_nodes(editor, l);

    // Check phis to see if they are reductionable, only PHIs depending on the loop are considered.
    // The canonical induction variable's own phi is excluded from the candidates.
    let candidate_phis: Vec<_> = editor
        .filter(|id| function.nodes[id.idx()].is_phi())
        .filter(|id| *id != canonical_iv.phi())

    let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes)
    // TODO: Handle multiple loop body lasts.
    // If there are multiple candidates for loop body last, return false.
    if editor
        .filter(|id| l.control[id.idx()])
        > 1
        return false;

    // The first use of the loop `if` node is taken as the last control node of the body.
    let loop_body_last = editor.get_uses(loop_if).next().unwrap();

    // Every candidate phi must have been classified as `Reductionable`.
    if reductionable_phis
        .any(|phi| !matches!(phi, LoopPHI::Reductionable { .. }))
        return false;

    let phi_latches: Vec<_> = reductionable_phis
        .map(|phi| {
            let LoopPHI::Reductionable {
                phi: _,
                data_cycle: _,
                is_associative: _,
            } = phi
            else {

    // Nodes at which the user walk below stops: phis, reduces, control nodes
    // and the latches of the reducing phis.
    let stop_on: HashSet<_> = editor
        .filter(|node| {
            if editor.node(node).is_phi() {
                return true;
            if editor.node(node).is_reduce() {
                return true;
            if editor.node(node).is_control() {
                return true;
            if phi_latches.contains(&node) {
                return true;


    // Outside loop users of IV, then exit;
    // Unless the outside user is through the loop latch of a reducing phi,
    // then we know how to replace this edge, so its fine!
    let iv_users: Vec<_> =
        walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect();

    // Reject the loop if a reachable user of the IV is neither a loop node nor the loop `if`.
    if iv_users
        .any(|node| !loop_nodes.contains(&node) && *node != loop_if)
        return false;

    // Start Transformation:

    // Graft everything between header and loop condition
    // Attach join to right before header (after loop_body_last, unless loop body last *is* the header).
    // Attach fork to right after loop_continue_projection.

    // Create fork and join nodes:
    // Placeholder IDs; both are assigned inside the edit closure below.
    let mut join_id = NodeID::new(0);
    let mut fork_id = NodeID::new(0);

    // Turn dc bound into max (1, bound),
    let bound_dc_id = {
        let mut max_id = DynamicConstantID::new(0);
        editor.edit(|mut edit| {
            let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1));
            max_id = edit.add_dynamic_constant(DynamicConstant::max(one_id, bound_dc_id));

    // FIXME: (@xrouth) double check handling of control in loop body.
    editor.edit(|mut edit| {
        // The fork hangs off the loop's single outside predecessor and has one
        // factor: the (clamped) loop bound.
        let fork = Node::Fork {
            control: loop_pred,
            factors: Box::new([bound_dc_id]),
        fork_id = edit.add_node(fork);

        // The join's control input depends on whether the loop body consists
        // of the header alone.
        let join = Node::Join {
            control: if l.header == loop_body_last {
            } else {

        join_id = edit.add_node(join);


    let function = editor.func();
    let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap();
    // Index of the last fork factor; the fork built above was given a single factor.
    let dimension = factors.len() - 1;

    // Pair each reductionable phi with its initial value, i.e. the phi input
    // that arrives via `loop_pred`.
    let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis
        .map(|reduction_phi| {
            let LoopPHI::Reductionable {
                data_cycle: _,
                continue_latch: _,
                is_associative: _,
            } = reduction_phi
            else {

            let function = editor.func();

            let init = *zip(
            .filter(|(c, _)| *c == loop_pred)

            (reduction_phi, init)

    // Start failable edit:
    let result = editor.edit(|mut edit| {
        let thread_id = Node::ThreadID {
            control: fork_id,
            dimension: dimension,
        let thread_id_id = edit.add_node(thread_id);

        // Replace uses that are inside with the thread id
        edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
        edit.sub_edit(canonical_iv.phi(), thread_id_id);

        edit = edit.delete_node(canonical_iv.phi())?;

        // Convert every reductionable phi into a reduce node on the join.
        for (reduction_phi, init) in redcutionable_phis_and_init {
            let LoopPHI::Reductionable {
                data_cycle: _,
                is_associative: _,
            } = *reduction_phi
            else {

            let reduce = Node::Reduce {
                control: join_id,
                reduct: continue_latch,

            let reduce_id = edit.add_node(reduce);

            // Carry the `ParallelReduce` / `MonoidReduce` schedules over to the
            // new reduce when they are present on a non-reduce `init` or
            // `continue_latch` node.
            if (!edit.get_node(init).is_reduce()
                && edit.get_schedule(init).contains(&Schedule::ParallelReduce))
                || (!edit.get_node(continue_latch).is_reduce()
                    && edit
                edit = edit.add_schedule(reduce_id, Schedule::ParallelReduce)?;
            if (!edit.get_node(init).is_reduce()
                && edit.get_schedule(init).contains(&Schedule::MonoidReduce))
                || (!edit.get_node(continue_latch).is_reduce()
                    && edit
                edit = edit.add_schedule(reduce_id, Schedule::MonoidReduce)?;

            // Redirect users of the phi to the reduce, and users of the latch
            // that lie outside the loop; the reduce itself is excluded in both cases.
            edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?;
            edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| {
                !loop_nodes.contains(usee) && *usee != reduce_id
            edit.sub_edit(phi, reduce_id);
            edit = edit.delete_node(phi)?

        // Rewire control: the header and continue projection become the fork,
        // the exit projection becomes the join.
        edit = edit.replace_all_uses(l.header, fork_id)?;
        edit = edit.replace_all_uses(loop_continue_projection, fork_id)?;
        edit = edit.replace_all_uses(loop_exit_projection, join_id)?;
        edit.sub_edit(l.header, fork_id);
        edit.sub_edit(loop_continue_projection, fork_id);
        edit.sub_edit(loop_exit_projection, join_id);

        // Remove the now-unused loop control nodes and the exit condition.
        edit = edit.delete_node(loop_continue_projection)?;
        edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this.
        edit = edit.delete_node(loop_exit_projection)?;
        edit = edit.delete_node(loop_if)?;
        edit = edit.delete_node(l.header)?;

    return result;

// Classification of a loop-header phi, as produced by `analyze_phis`.
//
// NOTE(review): `LoopPHI::get_phi` also matches `LoopDependant` and
// `UsedByDependant` variants that are not visible in this copy of the enum;
// the definition appears to be missing lines - confirm against the repository.
nest! {
    pub enum LoopPHI {
        // The phi can be converted into a reduce.
        Reductionable {
            phi: NodeID, // The phi node itself
            data_cycle: HashSet<NodeID>, // All nodes in a data cycle with this phi
            continue_latch: NodeID, // The phi's data input selected at the in-loop predecessor (see `analyze_phis`)
            is_associative: bool, // Presumably whether the cycle is one associative op; `analyze_phis` binds this to `false` - TODO confirm
        ControlDependant(NodeID), // This phi is reductionable, but its cycle might depend on internal control within the loop.

impl LoopPHI {
    /// Returns the `NodeID` carried by this classification, regardless of
    /// variant: the `phi` field for `Reductionable`, or the single payload of
    /// the tuple variants.
    pub fn get_phi(&self) -> NodeID {
        match self {
            LoopPHI::Reductionable { phi, .. } => *phi,
            LoopPHI::LoopDependant(node_id) => *node_id,
            LoopPHI::UsedByDependant(node_id) => *node_id,
            LoopPHI::ControlDependant(node_id) => *node_id,

/*
 * Checks some conditions on loop variables that will need to be converted into reductions to be forkified.
 *  - The phi is in a cycle *in the loop* with itself.
 *  - Every cycle *in the loop* containing the phi does not contain any other phi of the loop header.
 *  - The phi does not immediately (not blocked by another phi or another reduce) use any other phis of the loop header.
 */
pub fn analyze_phis<'a>(
    editor: &'a FunctionEditor,
    natural_loop: &'a Loop,
    phis: &'a [NodeID],
    loop_nodes: &'a HashSet<NodeID>,
) -> impl Iterator<Item = LoopPHI> + 'a {
    // Find data cycles within the loop of this phi,
    // Start from the phis loop_continue_latch, and walk its uses until we find the original phi.

    phis.into_iter().map(move |phi| {
        let stop_on: HashSet<NodeID> = editor
            .filter(|node| {
                let data = &editor.func().nodes[node.idx()];

                // External Phi
                if let Node::Phi { control, data: _ } = data {
                    if !natural_loop.control[control.idx()] {
                        return true;

                // This phi
                if node == phi {
                    return true;

                // External Reduce
                if let Node::Reduce {
                    init: _,
                    reduct: _,
                } = data
                    if !natural_loop.control[control.idx()] {
                        return true;
                    } else {
                        return false;

                // Data Cycles Only
                if data.is_control() {
                    return true;

                return false;
        let continue_idx = editor
            .position(|node| natural_loop.control[node.idx()])

        let loop_continue_latch = editor.node(phi).try_phi().unwrap().1[continue_idx];

        let uses = walk_all_uses_stop_on(loop_continue_latch, editor, stop_on.clone());
        let users = walk_all_users_stop_on(*phi, editor, stop_on.clone());

        let other_stop_on: HashSet<NodeID> = editor
            .filter(|node| {
                let data = &editor.func().nodes[node.idx()];

                // Phi, Reduce
                if data.is_phi() {
                    return true;

                if data.is_reduce() {
                    return true;

                // External Control
                if data.is_control() {
                    return true;

                return false;

        let mut uses_for_dependance =
            walk_all_users_stop_on(loop_continue_latch, editor, other_stop_on);

        let set1: HashSet<_> = HashSet::from_iter(uses);
        let set2: HashSet<_> = HashSet::from_iter(users);

        let intersection: HashSet<_> = set1.intersection(&set2).cloned().collect();

        // If this phi uses any other phis the node is loop dependant,
        // we use `phis` because this phi can actually contain the loop iv and its fine.
        if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) {
        } else if intersection.clone().iter().next().is_some() {
            // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
            // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
            // by the time the reduce is triggered (at the end of the loop's internal control).

            // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
            // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
            if intersection
                .filter(|node| **node != loop_continue_latch)
                .any(|data_node| {
                        .any(|user| !loop_nodes.contains(&user))
                // This phi can be made into a reduce in different ways, if the cycle is associative (contains all the same kind of associative op)
                // 3) Split the cycle into two phis, add them or multiply them together at the end.
                // 4) Split the cycle into two reduces, add them or multiply them together at the end.
                // Somewhere else should handle this.
                return LoopPHI::LoopDependant(*phi);

            // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify
            // i.e as described above.
            let is_associative = false;

            // No nodes in the data cycle are used outside of the loop, besides the latched value of the phi
            LoopPHI::Reductionable {
                phi: *phi,
                data_cycle: intersection,
                continue_latch: loop_continue_latch,
        } else {
            // No cycles exist, this isn't a reduction.