From c63b72a64691c606d9a09503b9b39e2cc3e6fded Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 10:20:05 -0600
Subject: [PATCH] cleanup fork_guard_elim

---
 hercules_opt/src/fork_concat_split.rs |  3 +-
 hercules_opt/src/fork_guard_elim.rs   | 53 +++++++++++----------------
 2 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs
index 1339a384..bb3a2cff 100644
--- a/hercules_opt/src/fork_concat_split.rs
+++ b/hercules_opt/src/fork_concat_split.rs
@@ -7,7 +7,8 @@ use crate::*;
 
 /*
  * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
- * Useful for code generation.
+ * Useful for code generation. A single iteration of `fork_split` only splits
+ * at most one fork-join, it must be called repeatedly to split all fork-joins.
  */
 pub fn fork_split(
     editor: &mut FunctionEditor,
diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs
index 435e63b6..9384a8c1 100644
--- a/hercules_opt/src/fork_guard_elim.rs
+++ b/hercules_opt/src/fork_guard_elim.rs
@@ -1,11 +1,10 @@
 use std::collections::{HashMap, HashSet};
 
 use either::Either;
-use hercules_ir::get_uses_mut;
-use hercules_ir::ir::*;
-use hercules_ir::ImmutableDefUseMap;
 
-use crate::FunctionEditor;
+use hercules_ir::*;
+
+use crate::*;
 
 /*
  * This is a Hercules IR transformation that:
@@ -20,20 +19,6 @@ use crate::FunctionEditor;
  * guard remains and in these cases the guard is no longer needed.
  */
 
-/* Given a node index and the node itself, return None if the node is not
- * a guarded fork where we can eliminate the guard.
- * If the node is a fork with a guard we can eliminate returns a tuple of
- * - This node's NodeID
- * - The replication factor of the fork
- * - The ID of the if of the guard
- * - The ID of the projections of the if
- * - The guard's predecessor
- * - A map of NodeIDs for the phi nodes to the reduce they should be replaced
- *   with, and also the region that joins the guard's branches mapping to the
- *   fork's join NodeID
- * - If the replication factor is a max that can be eliminated.
- */
-
 // Simplify factors through max
 enum Factor {
     Max(usize, DynamicConstantID),
@@ -61,6 +46,19 @@ struct GuardedFork {
     factor: Factor, // The factor that matches the guard
 }
 
+/* Given a node index and the node itself, return None if the node is not
+ * a guarded fork where we can eliminate the guard.
+ * If the node is a fork with a guard we can eliminate returns a tuple of
+ * - This node's NodeID
+ * - The replication factor of the fork
+ * - The ID of the if of the guard
+ * - The ID of the projections of the if
+ * - The guard's predecessor
+ * - A map of NodeIDs for the phi nodes to the reduce they should be replaced
+ *   with, and also the region that joins the guard's branches mapping to the
+ *   fork's join NodeID
+ * - If the replication factor is a max that can be eliminated.
+ */
 fn guarded_fork(
     editor: &mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
@@ -73,8 +71,7 @@ fn guarded_fork(
         return None;
     };
 
-    let factors = factors.iter().enumerate().map(|(idx, dc)| {
-        // FIXME: Can we hide .idx() in an impl Index or something so we don't index Vec<Nodes> iwht DynamicConstantId.idx()
+    let mut factors = factors.iter().enumerate().map(|(idx, dc)| {
         let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else {
             return Factor::Normal(idx, *dc);
         };
@@ -140,24 +137,22 @@ fn guarded_fork(
                 }
 
                 // Match Factor
-                let factor = factors.clone().find(|factor| {
-                    // This clone on the dc is painful.
+                let factor = factors.find(|factor| {
                     match (
                         &function.nodes[pattern_factor.idx()],
-                        editor.get_dynamic_constant(factor.get_id()).clone(),
+                        &*editor.get_dynamic_constant(factor.get_id()),
                     ) {
                         (Node::Constant { id }, DynamicConstant::Constant(v)) => {
                             let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id)
                             else {
                                 return false;
                             };
-                            pattern_v == (v as u64)
+                            pattern_v == (*v as u64)
                         }
                         (Node::DynamicConstant { id }, _) => *id == factor.get_id(),
                         _ => false,
                     }
                 });
-                // return Factor
                 factor
             })
         }
@@ -184,12 +179,10 @@ fn guarded_fork(
                 }
 
                 // Match Factor
-                // FIXME: Implement dc / constant matching as in case where branch_idx == 1
-                let factor = factors.clone().find(|factor| {
+                let factor = factors.find(|factor| {
                     function.nodes[pattern_factor.idx()].try_dynamic_constant()
                         == Some(factor.get_id())
                 });
-                // return Factor
                 factor
             })
         } else {
@@ -229,7 +222,7 @@ fn guarded_fork(
     } else {
         return None;
     };
-    // Other predecessor needs to be the other read from the guard's if
+    // Other predecessor needs to be the other projection from the guard's if
     let Node::Projection {
         control: if_node2,
         ref selection,
@@ -317,8 +310,6 @@ fn guarded_fork(
 
 /*
  * Top level function to run fork guard elimination, as described above.
- * Deletes nodes by setting nodes to gravestones. Works with a function already
- * containing gravestones.
  */
 pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
     let guard_info = editor
-- 
GitLab