From 0a6326e6aaaffec59d8490f89b9bc187e79a686d Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 27 Feb 2025 19:06:34 -0600
Subject: [PATCH] Infer parallelreduce in bfs

---
 hercules_ir/src/ir.rs                | 17 ++++++--
 hercules_opt/src/pred.rs             | 63 ++++++++++++++++++++++++++++
 hercules_opt/src/schedule.rs         | 31 +++++++++++---
 juno_samples/rodinia/bfs/src/bfs.jn  |  2 +-
 juno_samples/rodinia/bfs/src/cpu.sch |  3 +-
 5 files changed, 106 insertions(+), 10 deletions(-)

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5dfe2915..f6aafa35 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1048,9 +1048,23 @@ impl Constant {
         }
     }
 
+    /*
+     * Returns true only if this constant is the boolean `false`.
+     */
+    pub fn is_false(&self) -> bool {
+        matches!(self, Constant::Boolean(false))
+    }
+
+    /*
+     * Returns true only if this constant is the boolean `true`.
+     */
+    pub fn is_true(&self) -> bool {
+        matches!(self, Constant::Boolean(true))
+    }
+
     /*
      * Useful for GVN.
      */
     pub fn is_zero(&self) -> bool {
         match self {
             Constant::Integer8(0) => true,
diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs
index ed7c3a85..587c4507 100644
--- a/hercules_opt/src/pred.rs
+++ b/hercules_opt/src/pred.rs
@@ -136,6 +136,50 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
             bad_branches.insert(branch);
         }
     }
+
+    // Do a quick and dirty rewrite of boolean selects. Assuming the usual
+    // select(cond, if_true, if_false) operand order:
+    //   select(c, true, f)  == c || f
+    //   select(c, t, false) == c && t
+    // The two remaining constant forms need the negated condition
+    // (select(c, false, f) == !c && f, select(c, t, true) == !c || t), so they
+    // are deliberately left alone rather than rewritten to a plain And / Or.
+    for id in editor.node_ids() {
+        let nodes = &editor.func().nodes;
+        if let Node::Ternary {
+            op: TernaryOperator::Select,
+            first,
+            second,
+            third,
+        } = nodes[id.idx()]
+        {
+            if let Some(cons) = nodes[second.idx()].try_constant()
+                && editor.get_constant(cons).is_true()
+            {
+                editor.edit(|mut edit| {
+                    let node = edit.add_node(Node::Binary {
+                        op: BinaryOperator::Or,
+                        left: first,
+                        right: third,
+                    });
+                    edit = edit.replace_all_uses(id, node)?;
+                    edit.delete_node(id)
+                });
+            } else if let Some(cons) = nodes[third.idx()].try_constant()
+                && editor.get_constant(cons).is_false()
+            {
+                editor.edit(|mut edit| {
+                    let node = edit.add_node(Node::Binary {
+                        op: BinaryOperator::And,
+                        left: first,
+                        right: second,
+                    });
+                    edit = edit.replace_all_uses(id, node)?;
+                    edit.delete_node(id)
+                });
+            }
+        }
+    }
 }
 
 /*
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index d7ae4048..9bc7823e 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -69,6 +69,20 @@ pub fn infer_parallel_reduce(
             chain_id = reduct;
         }
 
+        // If the use is a phi whose inputs are the reduce itself plus exactly
+        // one other node (e.g. a write), then we might want to parallelize
+        // this still. Set the chain ID to that other input.
+        if let Node::Phi {
+            control: _,
+            ref data,
+        } = func.nodes[chain_id.idx()]
+        {
+            let mut others = data.iter().filter(|phi_use| **phi_use != last_reduce);
+            if let (Some(&other), None) = (others.next(), others.next()) {
+                chain_id = other;
+            }
+        }
+
         // Check for a Write-Reduce tight cycle.
         if let Node::Write {
             collect,
diff --git a/juno_samples/rodinia/bfs/src/bfs.jn b/juno_samples/rodinia/bfs/src/bfs.jn
index 51dcd945..ca0f7774 100644
--- a/juno_samples/rodinia/bfs/src/bfs.jn
+++ b/juno_samples/rodinia/bfs/src/bfs.jn
@@ -43,10 +43,10 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
     }
 
     @loop2 for i in 0..n {
+      stop = stop && !updated[i];
       if updated[i] {
         mask[i] = true;
         visited[i] = true;
-        stop = false;
         updated[i] = false;
       }
     }
diff --git a/juno_samples/rodinia/bfs/src/cpu.sch b/juno_samples/rodinia/bfs/src/cpu.sch
index 44cfa8ad..ae67fdd9 100644
--- a/juno_samples/rodinia/bfs/src/cpu.sch
+++ b/juno_samples/rodinia/bfs/src/cpu.sch
@@ -23,7 +23,8 @@ fixpoint {
   fork-guard-elim(*);
 }
 simpl!(*);
+predication(*);
+simpl!(*);
 
 unforkify(*);
-
 gcm(*);
-- 
GitLab