Skip to content
Snippets Groups Projects
fork_transforms.rs 60.6 KiB
Newer Older
  • Learn to ignore specific revisions
  •                         });
                            edit.sub_edit(tid, add);
                            edit.sub_edit(tid, tile_tid);
                            edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
                        }
                    }
                    edit = edit.delete_node(fork)?;
                    Ok(edit)
                });
    
    rarbore2's avatar
    rarbore2 committed
            }
    
            TileOrder::TileOuter => {
                editor.edit(|mut edit| {
                    let inner = DynamicConstant::div(new_factors[dim_idx], tile_size);
                    new_factors.insert(dim_idx, tile_size);
    
    rarbore2's avatar
    rarbore2 committed
                    let inner_dc_id = edit.add_dynamic_constant(inner);
    
    Xavier Routh's avatar
    Xavier Routh committed
                    new_factors[dim_idx + 1] = inner_dc_id;
    
                    let new_fork = Node::Fork {
                        control: old_control,
                        factors: new_factors.into(),
    
    Xavier Routh's avatar
    Xavier Routh committed
                    };
    
                    let new_fork = edit.add_node(new_fork);
    
                    edit = edit.replace_all_uses(fork, new_fork)?;
                    edit.sub_edit(fork, new_fork);
    
                    for (tid, node) in fork_users {
                        let Node::ThreadID {
                            control: _,
                            dimension: tid_dim,
                        } = node
                        else {
                            continue;
                        };
                        if tid_dim > dim_idx {
                            let new_tid = Node::ThreadID {
                                control: new_fork,
                                dimension: tid_dim + 1,
                            };
                            let new_tid = edit.add_node(new_tid);
                            edit = edit.replace_all_uses(tid, new_tid)?;
                            edit.sub_edit(tid, new_tid);
                            edit = edit.delete_node(tid)?;
                        } else if tid_dim == dim_idx {
                            let tile_tid = Node::ThreadID {
                                control: new_fork,
    
    rarbore2's avatar
    rarbore2 committed
                                dimension: tid_dim + 1,
    
                            };
                            let tile_tid = edit.add_node(tile_tid);
    
    rarbore2's avatar
    rarbore2 committed
                            let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
    
                            let mul = edit.add_node(Node::Binary {
                                left: tid,
                                right: inner_dc,
                                op: BinaryOperator::Mul,
                            });
                            let add = edit.add_node(Node::Binary {
                                left: mul,
                                right: tile_tid,
                                op: BinaryOperator::Add,
                            });
                            edit.sub_edit(tid, add);
                            edit.sub_edit(tid, tile_tid);
                            edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
                        }
                    }
                    edit = edit.delete_node(fork)?;
                    Ok(edit)
                });
    
    Xavier Routh's avatar
    Xavier Routh committed
            }
    
    Xavier Routh's avatar
    Xavier Routh committed
    }
    
    /*
     * Collapses every mutable fork in `fork_join_map` into a single dimension by
     * repeatedly merging its first two dimensions with `fork_dim_merge`. Forks
     * with zero or one dimension are left untouched.
     */
    pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
        for fork in fork_join_map.keys() {
            // Skip fork-joins the editor may not change, like the other
            // `*_all_forks` drivers do.
            if !editor.is_mutable(*fork) {
                continue;
            }
            let Node::Fork {
                control: _,
                factors: dims,
            } = editor.node(fork)
            else {
                unreachable!();
            };

            // `saturating_sub` so a zero-dimensional fork doesn't underflow.
            let num_merges = dims.len().saturating_sub(1);
            let mut fork = *fork;
            for _ in 0..num_merges {
                let merged = fork_dim_merge(editor, fork, 0, 1);
                // No new fork means the merge could not be applied; stop instead
                // of retrying the same failing edit.
                if merged == fork {
                    break;
                }
                fork = merged;
            }
        }
    }
    
    /*
     * Merges two dimensions of a fork into a single dimension whose factor is the
     * product of the two original factors. The merged dimension takes the place of
     * the outer (lower-indexed) dimension. The inner dimension is removed, so every
     * dimension after it shifts down by one.
     *
     * Let `t` be the merged thread ID and `d` the original factor of the outer
     * dimension. The old thread IDs are recomputed as:
     *   tid(outer) <- t % d
     *   tid(inner) <- t / d
     * Every thread ID that is rewritten is deleted, so no ThreadID node is left
     * referring to a dimension index the new fork may no longer have.
     *
     * Returns the ID of the new fork. Returns `fork` unchanged if it is not a
     * fork, if `dim_idx1`/`dim_idx2` is out of range, or if the edit could not
     * be applied.
     *
     * Panics if `dim_idx1 == dim_idx2`.
     */
    pub fn fork_dim_merge(
        editor: &mut FunctionEditor,
        fork: NodeID,
        dim_idx1: usize,
        dim_idx2: usize,
    ) -> NodeID {
        assert_ne!(dim_idx1, dim_idx2);

        // Outer is smaller, and also closer to the left of the factors array.
        let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
            (dim_idx2, dim_idx1)
        } else {
            (dim_idx1, dim_idx2)
        };
        let Node::Fork {
            control: old_control,
            factors: ref old_factors,
        } = *editor.node(fork)
        else {
            return fork;
        };
        // Out-of-range dimensions cannot be merged; indexing below would panic.
        if inner_idx >= old_factors.len() {
            return fork;
        }
        let mut new_factors: Vec<_> = old_factors.to_vec();
        let fork_users: Vec<_> = editor
            .get_users(fork)
            .map(|f| (f, editor.node(f).clone()))
            .collect();
        // Original outer factor, used to split the merged thread ID back apart.
        let outer_dc_id = new_factors[outer_idx];
        // Only set once the edit closure has built the whole rewrite.
        let mut new_fork = None;

        let success = editor.edit(|mut edit| {
            new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(
                new_factors[outer_idx],
                new_factors[inner_idx],
            ));
            new_factors.remove(inner_idx);
            let merged_fork = edit.add_node(Node::Fork {
                control: old_control,
                factors: new_factors.into(),
            });
            edit.sub_edit(fork, merged_fork);
            edit = edit.replace_all_uses(fork, merged_fork)?;
            edit = edit.delete_node(fork)?;

            for (tid, node) in fork_users {
                let Node::ThreadID {
                    control: _,
                    dimension: tid_dim,
                } = node
                else {
                    continue;
                };
                if tid_dim > inner_idx {
                    // Dimensions after the removed one shift down by one.
                    let new_tid = edit.add_node(Node::ThreadID {
                        control: merged_fork,
                        dimension: tid_dim - 1,
                    });
                    edit = edit.replace_all_uses(tid, new_tid)?;
                    edit.sub_edit(tid, new_tid);
                    edit = edit.delete_node(tid)?;
                } else if tid_dim == outer_idx {
                    let outer_tid = edit.add_node(Node::ThreadID {
                        control: merged_fork,
                        dimension: outer_idx,
                    });
                    let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });

                    // tid(outer) = merged_tid % dim(outer)
                    let rem = edit.add_node(Node::Binary {
                        left: outer_tid,
                        right: outer_dc,
                        op: BinaryOperator::Rem,
                    });
                    edit.sub_edit(tid, rem);
                    edit.sub_edit(tid, outer_tid);
                    edit = edit.replace_all_uses(tid, rem)?;
                    edit = edit.delete_node(tid)?;
                } else if tid_dim == inner_idx {
                    let outer_tid = edit.add_node(Node::ThreadID {
                        control: merged_fork,
                        dimension: outer_idx,
                    });
                    let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });

                    // tid(inner) = merged_tid / dim(outer)
                    let div = edit.add_node(Node::Binary {
                        left: outer_tid,
                        right: outer_dc,
                        op: BinaryOperator::Div,
                    });
                    edit.sub_edit(tid, div);
                    edit.sub_edit(tid, outer_tid);
                    edit = edit.replace_all_uses(tid, div)?;
                    edit = edit.delete_node(tid)?;
                }
            }

            new_fork = Some(merged_fork);
            Ok(edit)
        });

        // Never hand back an ID from an edit that was not committed.
        match new_fork {
            Some(merged_fork) if success => merged_fork,
            _ => fork,
        }
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Run fork interchange on all fork-joins that are mutable in an editor.
     */
    pub fn fork_interchange_all_forks(
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        first_dim: usize,
        second_dim: usize,
    ) {
        for (fork, join) in fork_join_map {
            if editor.is_mutable(*fork) {
                fork_interchange(editor, *fork, *join, first_dim, second_dim);
            }
        }
    }
    
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    pub fn fork_interchange(
    
    rarbore2's avatar
    rarbore2 committed
        editor: &mut FunctionEditor,
        fork: NodeID,
        join: NodeID,
        first_dim: usize,
        second_dim: usize,
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    ) -> Option<NodeID> {
    
    rarbore2's avatar
    rarbore2 committed
        // Check that every reduce on the join is parallel or associative.
    
    rarbore2's avatar
    rarbore2 committed
        let nodes = &editor.func().nodes;
        let schedules = &editor.func().schedules;
        if !editor
            .get_users(join)
            .filter(|id| nodes[id.idx()].is_reduce())
            .all(|id| {
                schedules[id.idx()].contains(&Schedule::ParallelReduce)
    
    rarbore2's avatar
    rarbore2 committed
                    || schedules[id.idx()].contains(&Schedule::MonoidReduce)
    
    rarbore2's avatar
    rarbore2 committed
            })
        {
            // If not, we can't necessarily do interchange.
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            return None;
    
    rarbore2's avatar
    rarbore2 committed
        }
    
        let Node::Fork {
            control,
            ref factors,
        } = nodes[fork.idx()]
        else {
            panic!()
        };
        let fix_tids: Vec<(NodeID, Node)> = editor
            .get_users(fork)
            .filter_map(|id| {
                nodes[id.idx()]
                    .try_thread_id()
                    .map(|(_, dim)| {
                        if dim == first_dim {
                            Some((
                                id,
                                Node::ThreadID {
                                    control: fork,
                                    dimension: second_dim,
                                },
                            ))
                        } else if dim == second_dim {
                            Some((
                                id,
                                Node::ThreadID {
                                    control: fork,
                                    dimension: first_dim,
                                },
                            ))
                        } else {
                            None
                        }
                    })
                    .flatten()
            })
            .collect();
        let mut factors = factors.clone();
        factors.swap(first_dim, second_dim);
        let new_fork = Node::Fork { control, factors };
    
    Aaron Councilman's avatar
    Aaron Councilman committed
        let mut new_fork_id = None;
    
    rarbore2's avatar
    rarbore2 committed
        editor.edit(|mut edit| {
            for (old_id, new_tid) in fix_tids {
                let new_id = edit.add_node(new_tid);
                edit = edit.replace_all_uses(old_id, new_id)?;
                edit = edit.delete_node(old_id)?;
            }
            let new_fork = edit.add_node(new_fork);
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
                edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
            }
    
    rarbore2's avatar
    rarbore2 committed
            edit = edit.replace_all_uses(fork, new_fork)?;
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            edit = edit.delete_node(fork)?;
    
            new_fork_id = Some(new_fork);
            Ok(edit)
    
    rarbore2's avatar
    rarbore2 committed
        });
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    
        new_fork_id
    
    rarbore2's avatar
    rarbore2 committed
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Run fork unrolling on all fork-joins that are mutable in an editor.
     */
    pub fn fork_unroll_all_forks(
        editor: &mut FunctionEditor,
        fork_joins: &HashMap<NodeID, NodeID>,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    ) {
        for (fork, join) in fork_joins {
            if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) {
                break;
            }
        }
    }
    
    /*
     * Fully unrolls a fork-join whose single dimension has a compile-time
     * constant factor. The fork-join is replaced by `factor` sequential copies of
     * its body. In each copy, the thread ID becomes a `UnsignedInteger64` constant
     * holding the iteration number, and each reduce reads the running value
     * produced by the previous copy (or its init for the first copy).
     *
     * Only forks with exactly two users are handled. These are presumably the
     * control successor and a single thread ID. Returns whether the fork-join
     * was unrolled.
     *
     * Panics if `fork` is not a fork node.
     */
    pub fn fork_unroll(
        editor: &mut FunctionEditor,
        fork: NodeID,
        join: NodeID,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    ) -> bool {
        // We can only unroll fork-joins with a compile time known factor list. For
        // simplicity, just unroll fork-joins that have a single dimension.
        let nodes = &editor.func().nodes;
        let Node::Fork {
            control,
            ref factors,
        } = nodes[fork.idx()]
        else {
            panic!()
        };
        if factors.len() != 1 || editor.get_users(fork).count() != 2 {
            return false;
        }
        let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
            return false;
        };
        // If neither of the two users is a thread ID, this is not the expected
        // shape; decline instead of panicking.
        let Some(tid) = editor
            .get_users(fork)
            .find(|id| nodes[id.idx()].is_thread_id())
        else {
            return false;
        };
        let inits: HashMap<NodeID, NodeID> = editor
            .get_users(join)
            .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
            .collect();

        editor.edit(|mut edit| {
            // Create a copy of the nodes in the fork join per unrolled iteration,
            // excluding the fork itself, the join itself, the thread IDs of the fork,
            // and the reduces on the join. Keep a running tally of the top control
            // node and the current reduction value.
            let mut top_control = control;
            let mut current_reduces = inits;
            for iter in 0..cons {
                let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
                let iter_tid = edit.add_node(Node::Constant { id: iter_cons });

                // First, add a copy of each node in the fork join unmodified.
                // Record the mapping from old ID to new ID.
                let mut added_ids = HashSet::new();
                let mut old_to_new_ids = HashMap::new();
                let mut new_control = None;
                let mut new_reduces = HashMap::new();
                for node in nodes_in_fork_joins[&fork].iter() {
                    if *node == fork {
                        // This copy hangs off the previous copy's last control node.
                        old_to_new_ids.insert(*node, top_control);
                    } else if *node == join {
                        new_control = Some(edit.get_node(*node).try_join().unwrap());
                    } else if *node == tid {
                        old_to_new_ids.insert(*node, iter_tid);
                    } else if let Some(current) = current_reduces.get(node) {
                        // Inside this copy, a reduce reads the running value.
                        old_to_new_ids.insert(*node, *current);
                        new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
                    } else {
                        let new_node = edit.add_node(edit.get_node(*node).clone());
                        old_to_new_ids.insert(*node, new_node);
                        added_ids.insert(new_node);
                    }
                }

                // Second, replace all the uses in the just added nodes.
                if let Some(new_control) = new_control {
                    top_control = old_to_new_ids[&new_control];
                }
                for (reduce, reduct) in new_reduces {
                    current_reduces.insert(reduce, old_to_new_ids[&reduct]);
                }
                for (old, new) in old_to_new_ids {
                    edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
                }
            }

            // Hook up the control and reduce outputs to the rest of the function.
            edit = edit.replace_all_uses(join, top_control)?;
            for (reduce, reduct) in current_reduces {
                edit = edit.replace_all_uses(reduce, reduct)?;
            }

            // Delete the old fork-join.
            for node in nodes_in_fork_joins[&fork].iter() {
                edit = edit.delete_node(*node)?;
            }

            Ok(edit)
        })
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Looks for fork-joins that are next to each other, not inter-dependent, and
     * have the same bounds. These fork-joins can be fused, pooling together all
     * their reductions.
     */
    pub fn fork_fusion_all_forks(
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    ) {
        for (fork, join) in fork_join_map {
            if editor.is_mutable(*fork)
                && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins)
            {
                break;
            }
        }
    }
    
    /*
     * Tries to fuse the fork-join `top_fork`/`top_join` with the fork-join that
     * immediately follows its join, if one exists. The two forks must have
     * identical factors. No transitive user of the top fork-join's reduces may
     * lie in the bottom fork-join; the walk stops at phis and reduces outside
     * the top fork-join. Returns whether the fork-joins were fused.
     */
    fn fork_fusion(
        editor: &mut FunctionEditor,
        top_fork: NodeID,
        top_join: NodeID,
        fork_join_map: &HashMap<NodeID, NodeID>,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    ) -> bool {
        let nodes = &editor.func().nodes;

        // The bottom fork is the control successor of the top join, and it must
        // head a known fork-join.
        let Some(bottom_fork) = editor
            .get_users(top_join)
            .find(|id| nodes[id.idx()].is_control())
        else {
            return false;
        };
        let Some(&bottom_join) = fork_join_map.get(&bottom_fork) else {
            return false;
        };
        let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
        let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
        assert_eq!(bottom_fork_pred, top_join);
        let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
        let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();

        // The fork factors must be identical.
        if top_factors != bottom_factors {
            return false;
        }

        // Walk the iterated users of each top reduce. Reaching any node of the
        // bottom fork-join means the two fork-joins are inter-dependent.
        let top_reduces = editor
            .get_users(top_join)
            .filter(|id| nodes[id.idx()].is_reduce());
        for reduce in top_reduces {
            let mut visited = HashSet::new();
            visited.insert(reduce);
            let mut workset = vec![reduce];
            while let Some(current) = workset.pop() {
                for user in editor.get_users(current) {
                    if nodes_in_fork_joins[&bottom_fork].contains(&user) {
                        return false;
                    }
                    let stops_walk =
                        nodes[user.idx()].is_phi() || nodes[user.idx()].is_reduce();
                    if stops_walk && !nodes_in_fork_joins[&top_fork].contains(&user) {
                        continue;
                    }
                    if !visited.contains(&user) && !nodes_in_fork_joins[&top_fork].contains(&user) {
                        visited.insert(user);
                        workset.push(user);
                    }
                }
            }
        }

        // Perform the fusion.
        let bottom_tids: Vec<_> = editor
            .get_users(bottom_fork)
            .filter(|id| nodes[id.idx()].is_thread_id())
            .collect();

        editor.edit(|mut edit| {
            // The bottom body's thread IDs now index the top fork.
            edit = edit.replace_all_uses_where(bottom_fork, top_fork, |id| bottom_tids.contains(id))?;

            if bottom_join_pred != bottom_fork {
                // If there is control flow in the bottom fork-join, stitch it into
                // the top fork-join. It starts after the top body's last control
                // node, and the top join then waits on the bottom body's last one.
                edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| {
                    nodes_in_fork_joins[&bottom_fork].contains(id)
                })?;
                edit =
                    edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
            }

            // Replace the bottom fork and join with the top fork and join.
            edit = edit.replace_all_uses(bottom_fork, top_fork)?;
            edit = edit.replace_all_uses(bottom_join, top_join)?;
            edit = edit.delete_node(bottom_fork)?;
            edit = edit.delete_node(bottom_join)?;
            Ok(edit)
        })
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Looks for monoid reductions where the initial input is not the identity
     * element, and converts them into a form whose initial input is an identity
     * element. This aides in parallelizing outer loops. Looks only at reduces with
     * the monoid reduce schedule, since that indicates a particular structure which
     * is annoying to check for again.
    
    rarbore2's avatar
    rarbore2 committed
     *
     * Looks for would-be monoid reduces, if not for a gate on the reduction.
     * Partially predicate the gated reduction to allow for a proper monoid
     * reduction.
    
    rarbore2's avatar
    rarbore2 committed
     */
    pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
        for id in editor.node_ids() {
            if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
                continue;
            }
            let nodes = &editor.func().nodes;
            let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
                continue;
            };
    
    rarbore2's avatar
    rarbore2 committed
            let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
    
    rarbore2's avatar
    rarbore2 committed
    
            match nodes[reduct.idx()] {
                Node::Binary {
                    op,
                    left: _,
                    right: _,
                } if (op == BinaryOperator::Add || op == BinaryOperator::Or)
    
    rarbore2's avatar
    rarbore2 committed
                    && !is_zero(editor, init)
                    && !is_false(editor, init) =>
    
    rarbore2's avatar
    rarbore2 committed
                {
                    editor.edit(|mut edit| {
                        let zero = edit.add_zero_constant(typing[init.idx()]);
                        let zero = edit.add_node(Node::Constant { id: zero });
                        edit.sub_edit(id, zero);
                        edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?;
                        let final_op = edit.add_node(Node::Binary {
                            op,
                            left: init,
                            right: id,
                        });
    
    rarbore2's avatar
    rarbore2 committed
                        for u in out_users {
    
    rarbore2's avatar
    rarbore2 committed
                            edit.sub_edit(u, final_op);
                        }
                        edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                    });
                }
                Node::Binary {
                    op,
                    left: _,
                    right: _,
                } if (op == BinaryOperator::Mul || op == BinaryOperator::And)
    
    rarbore2's avatar
    rarbore2 committed
                    && !is_one(editor, init)
                    && !is_true(editor, init) =>
    
    rarbore2's avatar
    rarbore2 committed
                {
                    editor.edit(|mut edit| {
                        let one = edit.add_one_constant(typing[init.idx()]);
                        let one = edit.add_node(Node::Constant { id: one });
                        edit.sub_edit(id, one);
                        edit = edit.replace_all_uses_where(init, one, |u| *u == id)?;
                        let final_op = edit.add_node(Node::Binary {
                            op,
                            left: init,
                            right: id,
                        });
    
    rarbore2's avatar
    rarbore2 committed
                        for u in out_users {
    
    rarbore2's avatar
    rarbore2 committed
                            edit.sub_edit(u, final_op);
                        }
                        edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                    });
                }
    
    rarbore2's avatar
    rarbore2 committed
                Node::IntrinsicCall {
                    intrinsic: Intrinsic::Max,
                    args: _,
                } if !is_smallest(editor, init) => {
                    editor.edit(|mut edit| {
                        let smallest = edit.add_smallest_constant(typing[init.idx()]);
                        let smallest = edit.add_node(Node::Constant { id: smallest });
                        edit.sub_edit(id, smallest);
                        edit = edit.replace_all_uses_where(init, smallest, |u| *u == id)?;
                        let final_op = edit.add_node(Node::IntrinsicCall {
                            intrinsic: Intrinsic::Max,
                            args: Box::new([init, id]),
                        });
    
    rarbore2's avatar
    rarbore2 committed
                        for u in out_users {
    
    rarbore2's avatar
    rarbore2 committed
                            edit.sub_edit(u, final_op);
                        }
                        edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                    });
                }
                Node::IntrinsicCall {
                    intrinsic: Intrinsic::Min,
                    args: _,
                } if !is_largest(editor, init) => {
                    editor.edit(|mut edit| {
                        let largest = edit.add_largest_constant(typing[init.idx()]);
                        let largest = edit.add_node(Node::Constant { id: largest });
                        edit.sub_edit(id, largest);
                        edit = edit.replace_all_uses_where(init, largest, |u| *u == id)?;
                        let final_op = edit.add_node(Node::IntrinsicCall {
                            intrinsic: Intrinsic::Min,
                            args: Box::new([init, id]),
                        });
    
    rarbore2's avatar
    rarbore2 committed
                        for u in out_users {
    
    rarbore2's avatar
    rarbore2 committed
                            edit.sub_edit(u, final_op);
                        }
                        edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                    });
                }
    
    rarbore2's avatar
    rarbore2 committed
                _ => {}
            }
        }
    
    rarbore2's avatar
    rarbore2 committed
    
        for id in editor.node_ids() {
            if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
                continue;
            }
            let nodes = &editor.func().nodes;
            let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
                continue;
            };
            if let Node::Phi {
                control: phi_control,
                ref data,
            } = nodes[reduct.idx()]
                && data.len() == 2
                && data.contains(&id)
                && let other = *data
                    .into_iter()
                    .filter(|other| **other != id)
                    .next()
                    .unwrap()
                && let Node::Binary {
                    op: BinaryOperator::Add,
                    left,
                    right,
                } = nodes[other.idx()]
                && ((left == id) ^ (right == id))
            {
                let gated_input = if left == id { right } else { left };
                let data = data.clone();
                editor.edit(|mut edit| {
                    let zero = edit.add_zero_constant(typing[id.idx()]);
                    let zero = edit.add_node(Node::Constant { id: zero });
                    let phi = edit.add_node(Node::Phi {
                        control: phi_control,
                        data: data
                            .iter()
                            .map(|phi_use| if *phi_use == id { zero } else { gated_input })
                            .collect(),
                    });
                    let new_reduce_id = NodeID::new(edit.num_node_ids());
                    let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
                    let new_reduce = Node::Reduce {
                        control,
                        init,
                        reduct: new_reduct_id,
                    };
                    let new_add = Node::Binary {
                        op: BinaryOperator::Add,
                        left: new_reduce_id,
                        right: phi,
                    };
                    let new_reduce = edit.add_node(new_reduce);
                    edit.add_node(new_add);
                    edit = edit.replace_all_uses(id, new_reduce)?;
                    edit = edit.delete_node(id)?;
                    Ok(edit)
                });
            }
        }
    
    rarbore2's avatar
    rarbore2 committed
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Extends the dimensions of a fork-join to be a multiple of a number and gates
     * the execution of the body.
     */
    pub fn extend_all_forks(
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        multiple: usize,
    ) {
        for (fork, join) in fork_join_map {
            if editor.is_mutable(*fork) {
                extend_fork(editor, *fork, *join, multiple);
            }
        }
    }
    
    fn extend_fork(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, multiple: usize) {
        let nodes = &editor.func().nodes;
        let (fork_pred, factors) = nodes[fork.idx()].try_fork().unwrap();
        let factors = factors.to_vec();
        let fork_succ = editor
            .get_users(fork)
            .filter(|id| nodes[id.idx()].is_control())
            .next()
            .unwrap();
        let join_pred = nodes[join.idx()].try_join().unwrap();
        let ctrl_between = fork != join_pred;
        let reduces: Vec<_> = editor
            .get_users(join)
            .filter_map(|id| nodes[id.idx()].try_reduce().map(|x| (id, x)))
            .collect();
    
        editor.edit(|mut edit| {
            // We can round up a dynamic constant A to a multiple of another dynamic
            // constant B via the following math:
            // ((A + B - 1) / B) * B
            let new_factors: Vec<_> = factors
                .iter()
                .map(|factor| {
                    let b = edit.add_dynamic_constant(DynamicConstant::Constant(multiple));
                    let apb = edit.add_dynamic_constant(DynamicConstant::add(*factor, b));
                    let o = edit.add_dynamic_constant(DynamicConstant::Constant(1));
                    let apbmo = edit.add_dynamic_constant(DynamicConstant::sub(apb, o));
                    let apbmodb = edit.add_dynamic_constant(DynamicConstant::div(apbmo, b));
                    edit.add_dynamic_constant(DynamicConstant::mul(apbmodb, b))
                })
                .collect();
    
            // Create the new control structure.
            let new_fork = edit.add_node(Node::Fork {
                control: fork_pred,
                factors: new_factors.into_boxed_slice(),
            });
            edit = edit.replace_all_uses_where(fork, new_fork, |id| *id != fork_succ)?;
            edit.sub_edit(fork, new_fork);
            let conds: Vec<_> = factors
                .iter()
                .enumerate()
                .map(|(idx, old_factor)| {
                    let tid = edit.add_node(Node::ThreadID {
                        control: new_fork,
                        dimension: idx,
                    });
    
    rarbore2's avatar
    rarbore2 committed
                    edit.sub_edit(fork, tid);
    
    rarbore2's avatar
    rarbore2 committed
                    let old_bound = edit.add_node(Node::DynamicConstant { id: *old_factor });
                    edit.add_node(Node::Binary {
                        op: BinaryOperator::LT,
                        left: tid,
                        right: old_bound,
                    })
                })
                .collect();
            let cond = conds
                .into_iter()
                .reduce(|left, right| {
                    edit.add_node(Node::Binary {
                        op: BinaryOperator::And,
                        left,
                        right,
                    })
                })
                .unwrap();
            let branch = edit.add_node(Node::If {
                control: new_fork,
                cond,
            });
            let false_proj = edit.add_node(Node::ControlProjection {
                control: branch,
                selection: 0,
            });
            let true_proj = edit.add_node(Node::ControlProjection {
                control: branch,
                selection: 1,
            });
            if ctrl_between {
                edit = edit.replace_all_uses_where(fork, true_proj, |id| *id == fork_succ)?;
            }
            let bottom_region = edit.add_node(Node::Region {
                preds: Box::new([false_proj, if ctrl_between { join_pred } else { true_proj }]),
            });
            let new_join = edit.add_node(Node::Join {
                control: bottom_region,
            });
            edit = edit.replace_all_uses(join, new_join)?;
            edit.sub_edit(join, new_join);
            edit = edit.delete_node(fork)?;
            edit = edit.delete_node(join)?;
    
            // Update the reduces to use phis on the region node to gate their execution.
            for (reduce, (_, init, reduct)) in reduces {
                let phi = edit.add_node(Node::Phi {
                    control: bottom_region,
                    data: Box::new([reduce, reduct]),
                });
                let new_reduce = edit.add_node(Node::Reduce {
                    control: new_join,
                    init,
                    reduct: phi,
                });
                edit = edit.replace_all_uses(reduce, new_reduce)?;
                edit.sub_edit(reduce, new_reduce);
                edit = edit.delete_node(reduce)?;
            }
    
            Ok(edit)
        });
    }