Skip to content
Snippets Groups Projects
schedule.rs 6.05 KiB
use std::collections::{BTreeSet, HashMap, HashSet};

use hercules_ir::def_use::*;
use hercules_ir::ir::*;

use crate::*;

/*
 * Infer parallel fork-joins: a Fork is marked ParallelFork when every
 * non-control user of its matching Join already carries the ParallelReduce
 * schedule (i.e. the fork-join has only parallel reduction variables).
 */
pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for fork in editor.node_ids() {
        let function = editor.func();
        if !function.nodes[fork.idx()].is_fork() {
            continue;
        }

        // Look for any user of the Join that is neither a control node nor
        // already scheduled as a parallel reduction; one such user is enough
        // to keep this fork-join sequential.
        let join = fork_join_map[&fork];
        let has_sequential_user = editor.get_users(join).any(|user| {
            !function.schedules[user.idx()].contains(&Schedule::ParallelReduce)
                && !function.nodes[user.idx()].is_control()
        });

        if !has_sequential_user {
            editor.edit(|edit| edit.add_schedule(fork, Schedule::ParallelFork));
        }
    }
}

/*
 * Infer parallel reductions consisting of a simple cycle between a Reduce node
 * and a Write node, where indices of the Write are position indices using the
 * ThreadID nodes attached to the corresponding Fork, and data of the Write is
 * not in the Reduce node's cycle. This procedure also adds the ParallelReduce
 * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the
 * base Write node also has position indices of the ThreadID of the outer fork.
 * In other words, the complete Reduce chain is annotated with ParallelReduce,
 * as long as each ThreadID dimension appears in the positional indexing of the
 * original Write.
 *
 * A chain of Reduce nodes that loops back on itself (e.g. a Reduce whose
 * reduct is itself) never reaches a Write; such Reduces are skipped instead of
 * being walked forever.
 */
pub fn infer_parallel_reduce(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    'nodes: for id in editor.node_ids() {
        let func = editor.func();

        // The control input of the Reduce being considered tells us which
        // Join (and therefore which Fork's ThreadIDs) to look for in the Write.
        let Node::Reduce { control: join, .. } = func.nodes[id.idx()] else {
            continue;
        };

        // Walk down the Reduce chain until we reach the node the innermost
        // Reduce reduces over, which is potentially a Write looping with it.
        let mut last_reduce = id;
        let mut chain_id = id;
        let mut visited = HashSet::new();
        while let Node::Reduce { reduct, .. } = func.nodes[chain_id.idx()] {
            if !visited.insert(chain_id) {
                // Revisited a Reduce: the chain is a cycle made only of
                // Reduce nodes and can never end at a Write.
                continue 'nodes;
            }
            last_reduce = chain_id;
            chain_id = reduct;
        }

        // Require a Write-Reduce tight cycle whose written data is not itself
        // part of the Reduce's cycle.
        let Node::Write {
            collect,
            data,
            indices,
        } = &func.nodes[chain_id.idx()]
        else {
            continue;
        };
        if *collect != last_reduce || reduce_cycles[&last_reduce].contains(data) {
            continue;
        }

        // Group the dimensions of every bare ThreadID used as a position index
        // by the Fork that ThreadID belongs to.
        let positions = indices
            .iter()
            .filter_map(|index| {
                if let Index::Position(pos) = index {
                    Some(pos)
                } else {
                    None
                }
            })
            .flat_map(|pos| pos.iter());
        let mut fork_dims = HashMap::<NodeID, Vec<usize>>::new();
        for position in positions {
            if let Node::ThreadID { control, dimension } = func.nodes[position.idx()] {
                fork_dims.entry(control).or_default().push(dimension);
            }
        }

        // The Reduce is parallel if the Fork matching its Join has every one
        // of its dimensions represented (at least once) in the indexing.
        let is_parallel = fork_dims.into_iter().any(|(fork, mut dims)| {
            dims.sort_unstable();
            dims.dedup();
            fork_join_map[&fork] == join
                && func.nodes[fork.idx()].try_fork().unwrap().1.len() == dims.len()
        });

        if is_parallel {
            editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
        }
    }
}

/*
 * Infer monoid reduction loops. Exactly one of the associative operation's
 * operands must be the Reduce node, and all other operands must not be in the
 * Reduce node's cycle. Supported operations are the binary Add, Mul, Or, and
 * And operators and the Max and Min intrinsics.
 */
pub fn infer_monoid_reduce(
    editor: &mut FunctionEditor,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    let is_binop_monoid = |op| {
        matches!(
            op,
            BinaryOperator::Add | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And
        )
    };
    let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);

    for id in editor.node_ids() {
        let func = editor.func();
        let Node::Reduce { reduct, .. } = func.nodes[id.idx()] else {
            continue;
        };

        let is_monoid = match &func.nodes[reduct.idx()] {
            // Exactly one operand is the Reduce (so `x + x` / `x * x` are
            // rejected), and the other operand must not depend on the Reduce.
            Node::Binary { left, right, op } => {
                is_binop_monoid(*op)
                    && ((*left == id && *right != id && !reduce_cycles[&id].contains(right))
                        || (*right == id && *left != id && !reduce_cycles[&id].contains(left)))
            }
            // Same rule for intrinsics: the Reduce appears exactly once among
            // the arguments and no other argument is in its cycle.
            Node::IntrinsicCall { intrinsic, args } => {
                is_intrinsic_monoid(*intrinsic)
                    && args.iter().filter(|arg| **arg == id).count() == 1
                    && args
                        .iter()
                        .filter(|arg| **arg != id)
                        .all(|arg| !reduce_cycles[&id].contains(arg))
            }
            _ => false,
        };

        if is_monoid {
            editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce));
        }
    }
}

/*
 * Attach the NoResetConstant schedule to every node ID reported by the
 * no-reset-constants analysis, visiting them in the set's sorted order.
 */
pub fn infer_no_reset_constants(
    editor: &mut FunctionEditor,
    no_reset_constants: &BTreeSet<NodeID>,
) {
    for constant in no_reset_constants.iter().copied() {
        editor.edit(|edit| edit.add_schedule(constant, Schedule::NoResetConstant));
    }
}