diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e635b3c00d7bfa0090376d8056e65d8d01e60ce2..ec111e69d3f69270beffc0ca92ad3be67e6ff765 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -686,40 +686,33 @@ pub fn fork_coalesce_helper( // CHECKME / FIXME: Might need to be added the other way. new_factors.append(&mut inner_dims.to_vec()); - for tid in inner_tids { - let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap(); - let new_tid = Node::ThreadID { - control: fork, - dimension: dim + num_outer_dims, - }; + editor.edit(|mut edit| { + for tid in inner_tids { + let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap(); + let new_tid = Node::ThreadID { + control: fork, + dimension: dim + num_outer_dims, + }; - editor.edit(|mut edit| { let new_tid = edit.add_node(new_tid); - let edit = edit.replace_all_uses(tid, new_tid)?; - Ok(edit) - }); - } + edit = edit.replace_all_uses(tid, new_tid)?; + } - // Fuse Reductions - for (outer_reduce, inner_reduce) in pairs { - let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()] - .try_reduce() - .unwrap(); - let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] - .try_reduce() - .unwrap(); - editor.edit(|mut edit| { + // Fuse Reductions + for (outer_reduce, inner_reduce) in pairs { + let (_, outer_init, _) = edit.get_node(outer_reduce) + .try_reduce() + .unwrap(); + let (_, inner_init, _) = edit.get_node(inner_reduce) + .try_reduce() + .unwrap(); // Set inner init to outer init. edit = edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; + } - Ok(edit) - }); - } - - editor.edit(|mut edit| { let new_fork = Node::Fork { control: outer_pred, factors: new_factors.into(), @@ -734,9 +727,7 @@ pub fn fork_coalesce_helper( edit = edit.delete_node(outer_fork)?; Ok(edit) - }); - - true + }) } pub fn split_any_fork(