diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs index 1339a38436bcf1db5a613d3cb121d1f67d612a2e..bb3a2cff556077d2bf3fe54a7fa21d0dd6d4e4b9 100644 --- a/hercules_opt/src/fork_concat_split.rs +++ b/hercules_opt/src/fork_concat_split.rs @@ -7,7 +7,8 @@ use crate::*; /* * Split multi-dimensional fork-joins into separate one-dimensional fork-joins. - * Useful for code generation. + * Useful for code generation. A single iteration of `fork_split` only splits + * at most one fork-join, it must be called repeatedly to split all fork-joins. */ pub fn fork_split( editor: &mut FunctionEditor, diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 435e63b6eb0a2b8cc91adbb50451a0caddd2b16a..9384a8c18a03cdbafecdef9264c447d0f1a61592 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -1,11 +1,10 @@ use std::collections::{HashMap, HashSet}; use either::Either; -use hercules_ir::get_uses_mut; -use hercules_ir::ir::*; -use hercules_ir::ImmutableDefUseMap; -use crate::FunctionEditor; +use hercules_ir::*; + +use crate::*; /* * This is a Hercules IR transformation that: @@ -20,20 +19,6 @@ use crate::FunctionEditor; * guard remains and in these cases the guard is no longer needed. */ -/* Given a node index and the node itself, return None if the node is not - * a guarded fork where we can eliminate the guard. - * If the node is a fork with a guard we can eliminate returns a tuple of - * - This node's NodeID - * - The replication factor of the fork - * - The ID of the if of the guard - * - The ID of the projections of the if - * - The guard's predecessor - * - A map of NodeIDs for the phi nodes to the reduce they should be replaced - * with, and also the region that joins the guard's branches mapping to the - * fork's join NodeID - * - If the replication factor is a max that can be eliminated. - */ - // Simplify factors through max enum Factor { Max(usize, DynamicConstantID), @@ -61,6 +46,19 @@ struct GuardedFork { factor: Factor, // The factor that matches the guard } +/* Given a node index and the node itself, return None if the node is not + * a guarded fork where we can eliminate the guard. + * If the node is a fork with a guard we can eliminate returns a tuple of + * - This node's NodeID + * - The replication factor of the fork + * - The ID of the if of the guard + * - The ID of the projections of the if + * - The guard's predecessor + * - A map of NodeIDs for the phi nodes to the reduce they should be replaced + * with, and also the region that joins the guard's branches mapping to the + * fork's join NodeID + * - If the replication factor is a max that can be eliminated. + */ fn guarded_fork( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, @@ -73,8 +71,7 @@ fn guarded_fork( return None; }; - let factors = factors.iter().enumerate().map(|(idx, dc)| { - // FIXME: Can we hide .idx() in an impl Index or something so we don't index Vec<Nodes> iwht DynamicConstantId.idx() + let mut factors = factors.iter().enumerate().map(|(idx, dc)| { let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else { return Factor::Normal(idx, *dc); }; @@ -140,24 +137,22 @@ fn guarded_fork( } // Match Factor - let factor = factors.clone().find(|factor| { - // This clone on the dc is painful. + let factor = factors.find(|factor| { match ( &function.nodes[pattern_factor.idx()], - editor.get_dynamic_constant(factor.get_id()).clone(), + &*editor.get_dynamic_constant(factor.get_id()), ) { (Node::Constant { id }, DynamicConstant::Constant(v)) => { let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id) else { return false; }; - pattern_v == (v as u64) + pattern_v == (*v as u64) } (Node::DynamicConstant { id }, _) => *id == factor.get_id(), _ => false, } }); - // return Factor factor }) } @@ -184,12 +179,10 @@ fn guarded_fork( } // Match Factor - // FIXME: Implement dc / constant matching as in case where branch_idx == 1 - let factor = factors.clone().find(|factor| { + let factor = factors.find(|factor| { function.nodes[pattern_factor.idx()].try_dynamic_constant() == Some(factor.get_id()) }); - // return Factor factor }) } else { @@ -229,7 +222,7 @@ fn guarded_fork( } else { return None; }; - // Other predecessor needs to be the other read from the guard's if + // Other predecessor needs to be the other projection from the guard's if let Node::Projection { control: if_node2, ref selection, @@ -317,8 +310,6 @@ fn guarded_fork( /* * Top level function to run fork guard elimination, as described above. - * Deletes nodes by setting nodes to gravestones. Works with a function already - * containing gravestones. */ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { let guard_info = editor