Skip to content
Snippets Groups Projects
unforkify.rs 5.45 KiB
use std::collections::HashMap;
use std::iter::zip;

use hercules_ir::ir::*;

use crate::*;

/// Convert fork-joins back into sequential loops right before codegen, for
/// backends that do not lower a fork-join to vector / parallel code. Lowering
/// fork-joins into sequential loops in LLVM is not entirely trivial, so it is
/// easier to do this transformation within Hercules IR.
///
/// Only one-dimensional fork-joins are converted. Fork-joins with any other
/// number of factors are left untouched; a separate pass is expected to split
/// multi-dimensional fork-joins first.
///
/// Each converted fork-join becomes:
///
/// - a `Region` in place of the fork, with predecessors
///   `[fork_control, back edge]`;
/// - an induction-variable `Phi` (`0` on entry, `indvar + 1` on the back edge)
///   in place of every thread ID of the fork;
/// - one `Phi` per reduce (`init` on entry, `reduct` on the back edge);
/// - an `If` on `indvar != factor` in place of the join. Its true projection
///   (selection 1) enters the old body and its false projection (selection 0)
///   takes over all users of the join.
///
/// Fork-joins are processed in ascending fork node index, so that the IDs of
/// the newly created nodes (and therefore the output IR) do not depend on
/// `HashMap` iteration order.
///
/// # Panics
///
/// Panics if adding the two `u64` constants fails, or if the editor does not
/// hand out new node IDs sequentially starting at the current node count.
pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    let mut zero_cons_id = ConstantID::new(0);
    let mut one_cons_id = ConstantID::new(0);
    assert!(editor.edit(|mut edit| {
        zero_cons_id = edit.add_constant(Constant::UnsignedInteger64(0));
        one_cons_id = edit.add_constant(Constant::UnsignedInteger64(1));
        Ok(edit)
    }));

    // `HashMap` iteration order is randomized per process. Sort the pairs so
    // that the node IDs allocated below are reproducible across compiler runs.
    // The order does not matter for correctness: every iteration re-reads the
    // current function, and `replace_all_uses` patches references between
    // nested fork-joins regardless of which one is converted first.
    let mut fork_joins: Vec<(NodeID, NodeID)> = fork_join_map
        .iter()
        .map(|(fork, join)| (*fork, *join))
        .collect();
    fork_joins.sort_unstable_by_key(|(fork, _)| fork.idx());

    // Convert the fork to a region, thread IDs to a single phi, reduces to
    // phis, and the join to a branch at the top of the loop. The previous
    // control insides of the fork-join become the successor of the true
    // projection node, and what was the use of the join becomes a use of the
    // false (exit) projection.
    for (fork, join) in fork_joins {
        let nodes = &editor.func().nodes;
        let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
        if factors.len() != 1 {
            // For now, only convert one-dimensional fork-joins. Rely on the
            // pass that splits multi-dimensional fork-joins. Also skip a
            // (malformed) zero-factor fork rather than panicking on
            // `factors[0]` below.
            continue;
        }
        let join_control = nodes[join.idx()].try_join().unwrap();
        let tids: Vec<_> = editor
            .get_users(fork)
            .filter(|id| nodes[id.idx()].is_thread_id())
            .collect();
        let reduces: Vec<_> = editor
            .get_users(join)
            .filter(|id| nodes[id.idx()].is_reduce())
            .collect();

        // Pre-compute the IDs the new nodes will receive. The assertions inside
        // the edit below check that the editor hands them out in this order.
        let num_nodes = editor.node_ids().len();
        let region_id = NodeID::new(num_nodes);
        let if_id = NodeID::new(num_nodes + 1);
        let proj_back_id = NodeID::new(num_nodes + 2);
        let proj_exit_id = NodeID::new(num_nodes + 3);
        let zero_id = NodeID::new(num_nodes + 4);
        let one_id = NodeID::new(num_nodes + 5);
        let indvar_id = NodeID::new(num_nodes + 6);
        let add_id = NodeID::new(num_nodes + 7);
        let dc_id = NodeID::new(num_nodes + 8);
        let neq_id = NodeID::new(num_nodes + 9);
        let phi_ids = (num_nodes + 10..num_nodes + 10 + reduces.len()).map(NodeID::new);

        // Loop header. When the fork-join body is empty (the join's control is
        // the fork itself), the back edge comes straight from the true
        // projection.
        let region = Node::Region {
            preds: Box::new([
                fork_control,
                if join_control == fork {
                    proj_back_id
                } else {
                    join_control
                },
            ]),
        };
        let if_node = Node::If {
            control: region_id,
            cond: neq_id,
        };
        // Selection 1 is taken while `indvar != factor`: run the body again.
        let proj_back = Node::Projection {
            control: if_id,
            selection: 1,
        };
        // Selection 0 is taken once `indvar == factor`: leave the loop.
        let proj_exit = Node::Projection {
            control: if_id,
            selection: 0,
        };
        let zero = Node::Constant { id: zero_cons_id };
        let one = Node::Constant { id: one_cons_id };
        // Induction variable: 0 on entry from `fork_control`, `indvar + 1` on
        // the back edge. The body therefore sees 0, 1, ..., factor - 1.
        let indvar = Node::Phi {
            control: region_id,
            data: Box::new([zero_id, add_id]),
        };
        let add = Node::Binary {
            op: BinaryOperator::Add,
            left: indvar_id,
            right: one_id,
        };
        let dc = Node::DynamicConstant { id: factors[0] };
        let neq = Node::Binary {
            op: BinaryOperator::NE,
            left: indvar_id,
            right: dc_id,
        };
        // Each reduce becomes a loop-carried phi: `init` on entry, `reduct` on
        // the back edge.
        let phis: Vec<_> = reduces
            .iter()
            .map(|reduce_id| {
                let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap();
                Node::Phi {
                    control: region_id,
                    data: Box::new([init, reduct]),
                }
            })
            .collect();

        // NOTE(review): the success flag of this edit is not checked.
        // Presumably a rejected edit leaves this fork-join unconverted.
        editor.edit(|mut edit| {
            assert_eq!(edit.add_node(region), region_id);
            assert_eq!(edit.add_node(if_node), if_id);
            assert_eq!(edit.add_node(proj_back), proj_back_id);
            assert_eq!(edit.add_node(proj_exit), proj_exit_id);
            assert_eq!(edit.add_node(zero), zero_id);
            assert_eq!(edit.add_node(one), one_id);
            assert_eq!(edit.add_node(indvar), indvar_id);
            assert_eq!(edit.add_node(add), add_id);
            assert_eq!(edit.add_node(dc), dc_id);
            assert_eq!(edit.add_node(neq), neq_id);
            for (phi_id, phi) in zip(phi_ids.clone(), phis) {
                assert_eq!(edit.add_node(phi), phi_id);
            }

            // The body's entry now hangs off the true projection, and the
            // join's users hang off the exit projection.
            edit = edit.replace_all_uses(fork, proj_back_id)?;
            edit = edit.replace_all_uses(join, proj_exit_id)?;
            edit.sub_edit(fork, region_id);
            edit.sub_edit(join, if_id);
            for tid in tids.iter() {
                edit.sub_edit(*tid, indvar_id);
                edit = edit.replace_all_uses(*tid, indvar_id)?;
            }
            for (reduce, phi_id) in zip(reduces.iter(), phi_ids) {
                edit.sub_edit(*reduce, phi_id);
                edit = edit.replace_all_uses(*reduce, phi_id)?;
            }

            edit = edit.delete_node(fork)?;
            edit = edit.delete_node(join)?;
            for tid in tids {
                edit = edit.delete_node(tid)?;
            }
            for reduce in reduces {
                edit = edit.delete_node(reduce)?;
            }

            Ok(edit)
        });
    }
}