From 1b78bc2e9d2acce3e9b1eec8aed25489eb9aced5 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 28 Feb 2025 15:59:10 -0600
Subject: [PATCH] Refactor infer parallelreduce

---
 hercules_opt/src/schedule.rs              | 46 ++++++-----------------
 juno_samples/rodinia/backprop/src/cpu.sch |  1 +
 2 files changed, 12 insertions(+), 35 deletions(-)

diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index 9bc7823e..10eca72e 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -1,6 +1,6 @@
 use std::collections::{BTreeSet, HashMap, HashSet};
+use std::iter::once;
 
-use hercules_ir::def_use::*;
 use hercules_ir::ir::*;
 
 use crate::*;
@@ -42,6 +42,10 @@ pub fn infer_parallel_reduce(
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
 ) {
+    let join_fork_map: HashMap<_, _> = fork_join_map
+        .into_iter()
+        .map(|(fork, join)| (*join, *fork))
+        .collect();
     for id in editor.node_ids() {
         let func = editor.func();
         if !func.nodes[id.idx()].is_reduce() {
@@ -98,40 +102,11 @@ pub fn infer_parallel_reduce(
             && *collect == last_reduce
             && !reduce_cycles[&last_reduce].contains(data)
         {
-            // If there is a Write-Reduce tight cycle, get the position indices.
-            let positions = indices
-                .iter()
-                .filter_map(|index| {
-                    if let Index::Position(indices) = index {
-                        Some(indices)
-                    } else {
-                        None
-                    }
-                })
-                .flat_map(|pos| pos.iter());
-
-            // Get the Forks corresponding to uses of bare ThreadIDs.
-            let fork_thread_id_pairs = positions.filter_map(|id| {
-                if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
-                    Some((control, dimension))
-                } else {
-                    None
-                }
-            });
-            let mut forks = HashMap::<NodeID, Vec<usize>>::new();
-            for (fork, dim) in fork_thread_id_pairs {
-                forks.entry(fork).or_default().push(dim);
-            }
-
-            // Check if one of the Forks correspond to the Join associated with
-            // the Reduce being considered, and has all of its dimensions
-            // represented in the indexing.
-            let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| {
-                rep_dims.sort();
-                rep_dims.dedup();
-                fork_join_map[&id] == first_control.unwrap()
-                    && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
-            });
+            let is_parallel = indices_parallel_over_forks(
+                editor,
+                indices,
+                once(join_fork_map[&first_control.unwrap()]),
+            );
 
             if is_parallel {
                 editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
@@ -145,6 +120,7 @@ pub fn infer_parallel_reduce(
  * operands must be the Reduce node, and all other operands must not be in the
  * Reduce node's cycle.
  */
+#[rustfmt::skip]
 pub fn infer_monoid_reduce(
     editor: &mut FunctionEditor,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch
index d59fd5f5..de34d660 100644
--- a/juno_samples/rodinia/backprop/src/cpu.sch
+++ b/juno_samples/rodinia/backprop/src/cpu.sch
@@ -23,6 +23,7 @@ fixpoint {
   fork-guard-elim(*);
   fork-coalesce(*);
 }
+reduce-slf(*);
 simpl!(*);
 
 fork-split(*);
-- 
GitLab