From 375199f5369a1237f52edd82af8c0167219f2cce Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 3 Mar 2025 11:39:24 -0600
Subject: [PATCH] some tweaks

---
 hercules_opt/src/fork_transforms.rs |  6 ++++--
 hercules_opt/src/gcm.rs             | 14 +++++++++++---
 hercules_opt/src/utils.rs           | 18 ++++++++++++++++++
 3 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e1598463..6998f879 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1533,7 +1533,8 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                 left: _,
                 right: _,
             } if (op == BinaryOperator::Add || op == BinaryOperator::Or)
-                && !is_zero(editor, init) =>
+                && !is_zero(editor, init)
+                && !is_false(editor, init) =>
             {
                 editor.edit(|mut edit| {
                     let zero = edit.add_zero_constant(typing[init.idx()]);
@@ -1556,7 +1557,8 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                 left: _,
                 right: _,
             } if (op == BinaryOperator::Mul || op == BinaryOperator::And)
-                && !is_one(editor, init) =>
+                && !is_one(editor, init)
+                && !is_true(editor, init) =>
             {
                 editor.edit(|mut edit| {
                     let one = edit.add_one_constant(typing[init.idx()]);
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index d950941a..4a6365c8 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -212,7 +212,8 @@ fn preliminary_fixups(
             let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
 
             // Replace uses of the reduce in its cycle with the init.
-            let success = editor.edit(|edit| {
+            let success = editor.edit(|mut edit| {
+                edit = edit.add_schedule(init, Schedule::ParallelReduce)?;
                 edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id))
             });
             assert!(success);
@@ -870,7 +871,7 @@ fn spill_clones(
     // Step 2: filter edges (A, B) to just see edges where A uses B and A
     // mutates B. These are the edges that may require a spill.
     let mut spill_edges = edges.into_iter().filter(|(a, b)| {
-        mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
+        (mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
             || (get_uses(&editor.func().nodes[a.idx()])
                 .as_ref()
                 .into_iter()
@@ -890,7 +891,14 @@ fn spill_clones(
                         data.contains(b)
                             && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
                     })
-                    .unwrap_or(false))
+                    .unwrap_or(false)))
+            && !editor.func().nodes[a.idx()]
+                .try_write()
+                .map(|(collect, _, _)| {
+                    collect == *b
+                        && editor.func().schedules[b.idx()].contains(&Schedule::ParallelReduce)
+                })
+                .unwrap_or(false)
     });
 
     // Step 3: if there is a spill edge, spill it and return true. Otherwise,
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index b910a128..351abc2b 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -598,6 +598,24 @@ pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
         || nodes[id.idx()].is_undef()
 }
 
+pub fn is_false(editor: &FunctionEditor, id: NodeID) -> bool {
+    let nodes = &editor.func().nodes;
+    nodes[id.idx()]
+        .try_constant()
+        .map(|id| editor.get_constant(id).is_false())
+        .unwrap_or(false)
+        || nodes[id.idx()].is_undef()
+}
+
+pub fn is_true(editor: &FunctionEditor, id: NodeID) -> bool {
+    let nodes = &editor.func().nodes;
+    nodes[id.idx()]
+        .try_constant()
+        .map(|id| editor.get_constant(id).is_true())
+        .unwrap_or(false)
+        || nodes[id.idx()].is_undef()
+}
+
 pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool {
     let nodes = &editor.func().nodes;
     nodes[id.idx()]
-- 
GitLab