diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 12b91194749b6125418b3639fa1d29d6b0397fc1..9a16c99c7d3decceff04cc7a4ce8b13a149f96ad 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -676,50 +676,47 @@ pub fn fork_coalesce_helper( // CHECKME / FIXME: Might need to be added the other way. new_factors.append(&mut inner_dims.to_vec()); - for tid in inner_tids { - let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap(); - let new_tid = Node::ThreadID { - control: fork, - dimension: dim + num_outer_dims, - }; + let mut new_fork = NodeID::new(0); + let new_join = inner_join; // We'll reuse the inner join as the join of the new fork + + let success = editor.edit(|mut edit| { + for tid in inner_tids { + let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap(); + let new_tid = Node::ThreadID { + control: fork, + dimension: dim + num_outer_dims, + }; - editor.edit(|mut edit| { let new_tid = edit.add_node(new_tid); - let mut edit = edit.replace_all_uses(tid, new_tid)?; + edit = edit.replace_all_uses(tid, new_tid)?; edit.sub_edit(tid, new_tid); - Ok(edit) - }); - } - - // Fuse Reductions - for (outer_reduce, inner_reduce) in pairs { - let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()] - .try_reduce() - .unwrap(); - let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] - .try_reduce() - .unwrap(); - let success = editor.edit(|mut edit| { + } + // Fuse Reductions + for (outer_reduce, inner_reduce) in pairs { + let (_, outer_init, _) = edit.get_node(outer_reduce) + .try_reduce() + .unwrap(); + let (_, inner_init, _) = edit.get_node(inner_reduce) + .try_reduce() + .unwrap(); // Set inner init to outer init. edit = edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; + } - Ok(edit) - }); - } - - let mut new_fork = NodeID::new(0); - let new_join = inner_join; // We reuse the inner join as the join of the new fork - - editor.edit(|mut edit| { let new_fork_node = Node::Fork { control: outer_pred, factors: new_factors.into(), }; new_fork = edit.add_node(new_fork_node); + if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork) + && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) { + edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; + } + edit = edit.replace_all_uses(inner_fork, new_fork)?; edit = edit.replace_all_uses(outer_fork, new_fork)?; edit = edit.replace_all_uses(outer_join, inner_join)?; @@ -730,7 +727,11 @@ pub fn fork_coalesce_helper( Ok(edit) }); - Some((new_fork, new_join)) + if success { + Some((new_fork, new_join)) + } else { + None + } } pub fn split_any_fork( @@ -1277,9 +1278,14 @@ pub fn fork_interchange( edit = edit.delete_node(old_id)?; } let new_fork = edit.add_node(new_fork); - new_fork_id = Some(new_fork); + if edit.get_schedule(fork).contains(&Schedule::ParallelFork) { + edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; + } edit = edit.replace_all_uses(fork, new_fork)?; - edit.delete_node(fork) + edit = edit.delete_node(fork)?; + + new_fork_id = Some(new_fork); + Ok(edit) }); new_fork_id