From cc90b62246ffed252546a40e26dba9d868bae6b3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 31 Jan 2025 10:51:59 -0600
Subject: [PATCH 1/9] Start analysis for collections that don't need to be
 zeroed

---
 hercules_ir/src/collections.rs | 105 ++++++++++++++++++++++++++++++++-
 hercules_opt/src/slf.rs        |   6 +-
 2 files changed, 107 insertions(+), 4 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 9f421221..33b25731 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -1,4 +1,5 @@
-use std::collections::{BTreeMap, BTreeSet};
+use std::collections::{BTreeMap, BTreeSet, HashMap};
+use std::iter::once;
 
 use crate::*;
 
@@ -347,3 +348,105 @@ pub fn collection_objects(
 
     collection_objects
 }
+
+/*
+ * The zero lattice determines what indices a collection has been written to so
+ * far. Reads from collection constants not at an index that's been written to
+ * return zero - if any such read exists for a collection constant, then that
+ * collection constant must be memset to zero explicitly. This is a similar
+ * analysis to store-to-load forwarding, but we are looking for cases where the
+ * store to forward might be a zero coming from a zero memset.
+ */
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+struct ZeroLattice {
+    written_to: BTreeSet<Box<[Index]>>,
+}
+
+impl Semilattice for ZeroLattice {
+    fn meet(a: &Self, b: &Self) -> Self {
+        // Meet is set intersection: keep only indices sets present in both.
+        let mut ret = BTreeSet::new();
+        for indices in a.written_to.iter() {
+            if b.written_to.contains(indices) {
+                // Both sets contain this indices set, so add it to the meet.
+                ret.insert(indices.clone());
+            }
+        }
+        ZeroLattice { written_to: ret }
+    }
+
+    fn top() -> Self {
+        ZeroLattice {
+            written_to: once(Box::new([]) as Box<[Index]>).collect(),
+        }
+    }
+
+    fn bottom() -> Self {
+        ZeroLattice {
+            written_to: BTreeSet::new(),
+        }
+    }
+}
+
+/*
+ * Analysis that determines what collection constants in a function don't need
+ * to be lowered to memsets. If a collection constant is only read at indices
+ * that it has first been written to, then the backing memory for the constant
+ * doesn't have to be set to zero, since those zeros are never read.
+ */
+pub fn no_reset_constant_collections(
+    function: &Function,
+    types: &Vec<Type>,
+    reverse_postorder: &Vec<NodeID>,
+    typing: &Vec<TypeID>,
+    objects: &CollectionObjects,
+    reduce_einsum: &(MathEnv, HashMap<NodeID, NodeID>),
+) -> BTreeSet<NodeID> {
+    // First, run a dataflow analysis that determine at each collection node
+    // what indices have definitely been written to.
+    let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| {
+        match function.nodes[id.idx()] {
+            Node::Phi {
+                control: _,
+                data: _,
+            }
+            | Node::Ternary {
+                op: TernaryOperator::Select,
+                first: _,
+                second: _,
+                third: _,
+            } => inputs.into_iter().fold(ZeroLattice::top(), |acc, input| {
+                ZeroLattice::meet(&acc, input)
+            }),
+            Node::Reduce {
+                control: _,
+                init: _,
+                reduct: _,
+            } => {
+                // If the einsum for this reduce node is a full array
+                // comprehension, then every array element is written to, and
+                // the empty indices set (the whole collection) is considered as
+                // written to.
+                let (env, exprs) = reduce_einsum;
+                if let MathExpr::Comprehension(_, _) = env[exprs[&id].idx()] {
+                    ZeroLattice::top()
+                }
+                // Otherwise, meet the `init` and `reduct` inputs.
+                else {
+                    ZeroLattice::meet(&inputs[0], &inputs[1])
+                }
+            }
+            Node::Write {
+                collect: _,
+                data: _,
+                ref indices,
+            } => {
+                let mut value = inputs[0].clone();
+                value.written_to.insert(indices.clone());
+                value
+            }
+            _ => ZeroLattice::bottom(),
+        }
+    });
+    todo!()
+}
diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index 981a0cce..92acb2a8 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -15,7 +15,7 @@ use crate::*;
  * insert an indices set with an empty positions list and a `None` value.
  */
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub struct SLFLattice {
+struct SLFLattice {
     subvalues: BTreeMap<Box<[Index]>, Option<NodeID>>,
 }
 
@@ -25,7 +25,7 @@ impl Semilattice for SLFLattice {
         // keep their known sub-value if they're equivalent. All other indices
         // sets in `a` or `b` map to `None`.
         let mut ret = BTreeMap::new();
-        for (indices, a_subvalue) in &a.subvalues {
+        for (indices, a_subvalue) in a.subvalues.iter() {
             if let Some(b_subvalue) = b.subvalues.get(indices)
                 && a_subvalue == b_subvalue
             {
@@ -39,7 +39,7 @@ impl Semilattice for SLFLattice {
                 ret.insert(indices.clone(), None);
             }
         }
-        for (indices, _) in &b.subvalues {
+        for (indices, _) in b.subvalues.iter() {
             // Any indices sets in `b` that aren't in `ret` are indices sets
             // that aren't in `a`, so the sub-value isn't known.
             ret.entry(indices.clone()).or_insert(None);
-- 
GitLab
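
Patch 1 above defines ZeroLattice so that the meet of two lattice values is the
set intersection of their written-to indices sets: an indices set survives the
meet only if it was written along every incoming path. Below is a minimal,
self-contained sketch of that idea, using Vec<u32> as a hypothetical stand-in
for the IR's Box<[Index]> index paths rather than the real hercules_ir types.

    use std::collections::BTreeSet;

    // Meet is set intersection: an index path counts as definitely written
    // after a merge point only if it was written along every incoming path.
    fn meet(a: &BTreeSet<Vec<u32>>, b: &BTreeSet<Vec<u32>>) -> BTreeSet<Vec<u32>> {
        a.intersection(b).cloned().collect()
    }

    fn main() {
        // One predecessor wrote index paths [0] and [1]; the other wrote only [0].
        let path1: BTreeSet<Vec<u32>> = [vec![0], vec![1]].into_iter().collect();
        let path2: BTreeSet<Vec<u32>> = [vec![0]].into_iter().collect();
        let at_phi = meet(&path1, &path2);
        // Only [0] is definitely written at the phi, so a read of [1] would
        // still observe the implicit zero and the constant would need a memset.
        let expected: BTreeSet<Vec<u32>> = [vec![0]].into_iter().collect();
        assert_eq!(at_phi, expected);
        println!("definitely written at the phi: {:?}", at_phi);
    }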


From 6c5517509d70820cab300df9d675387ec7882433 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 31 Jan 2025 11:13:42 -0600
Subject: [PATCH 2/9] Progress on no_reset_constant_collections

---
 hercules_ir/src/collections.rs | 57 ++++++++++++++++++++++++++++++++--
 hercules_opt/src/utils.rs      | 18 +++++++++++
 2 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 33b25731..5d50b14a 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -1,6 +1,8 @@
 use std::collections::{BTreeMap, BTreeSet, HashMap};
 use std::iter::once;
 
+use either::Either;
+
 use crate::*;
 
 /*
@@ -61,6 +63,13 @@ impl CollectionObjectOrigin {
             _ => None,
         }
     }
+
+    pub fn try_constant(&self) -> Option<NodeID> {
+        match self {
+            CollectionObjectOrigin::Constant(id) => Some(*id),
+            _ => None,
+        }
+    }
 }
 
 impl FunctionCollectionObjects {
@@ -399,10 +408,10 @@ pub fn no_reset_constant_collections(
     types: &Vec<Type>,
     reverse_postorder: &Vec<NodeID>,
     typing: &Vec<TypeID>,
-    objects: &CollectionObjects,
+    objects: &FunctionCollectionObjects,
     reduce_einsum: &(MathEnv, HashMap<NodeID, NodeID>),
 ) -> BTreeSet<NodeID> {
-    // First, run a dataflow analysis that determine at each collection node
+    // First, run a dataflow analysis that determines at each collection node
     // what indices have definitely been written to.
     let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| {
         match function.nodes[id.idx()] {
@@ -448,5 +457,49 @@ pub fn no_reset_constant_collections(
             _ => ZeroLattice::bottom(),
         }
     });
+
+    // Second, collect triples from reads in the program containing:
+    // 1. What indices are read at that point.
+    // 2. What collection constants may be read.
+    // 3. What indices were written to by that point.
+    let full_indices: Box<[Index]> = Box::new([]);
+    let triples = (0..function.nodes.len()).filter_map(|idx| {
+        // Item #1.
+        let id = NodeID::new(idx);
+        let (collect, indices) = match function.nodes[id.idx()] {
+            Node::Read {
+                collect,
+                ref indices,
+            } => (Either::Left(once(collect)), indices),
+            Node::Write {
+                collect: _,
+                data,
+                indices: _,
+            } => (Either::Left(once(data)), &full_indices),
+            Node::Call {
+                control: _,
+                function: _,
+                dynamic_constants: _,
+                ref args,
+            } => (Either::Right(args.into_iter().map(|id| *id)), &full_indices),
+            _ => return None,
+        };
+
+        // Item #2.
+        let constants: Vec<_> = collect
+            .map(|collect| objects.objects(collect).into_iter())
+            .flatten()
+            .filter_map(|obj| objects.origin(*obj).try_constant())
+            .collect();
+        if constants.is_empty() {
+            return None;
+        }
+
+        // Item #3.
+        let written_to = &lattice[id.idx()].written_to;
+
+        Some((indices, constants, written_to))
+    });
+
     todo!()
 }
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 2ab4e094..a2ca5028 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -370,6 +370,24 @@ pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> boo
     true
 }
 
+/*
+ * Helper function to determine if a list of indices A definitely contains a
+ * list of indices B.
+ */
+pub(crate) fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool {
+    if indices_a.len() < indices_b.len() {
+        return false;
+    }
+
+    for (idx1, idx2) in zip(indices_a, indices_b) {
+        if idx1 != idx2 {
+            return false;
+        }
+    }
+
+    true
+}
+
 pub type DenseNodeMap<T> = Vec<T>;
 pub type SparseNodeMap<T> = HashMap<NodeID, T>;
 
-- 
GitLab
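
Patch 2 introduces indices_contain_other_indices, a prefix check over index
lists, and starts collecting (read indices, constants, written-to indices)
triples. The sketch below restates that prefix check with u32 standing in for
the Index type (an assumption for illustration only); it returns true exactly
when the second list is no longer than the first and matches its leading
entries.

    use std::iter::zip;

    // True exactly when `b` is no longer than `a` and matches its leading
    // entries, mirroring the shape of the helper added in the patch.
    fn indices_contain_other_indices(a: &[u32], b: &[u32]) -> bool {
        if a.len() < b.len() {
            return false;
        }
        zip(a, b).all(|(x, y)| x == y)
    }

    fn main() {
        assert!(indices_contain_other_indices(&[0, 3], &[0])); // shared prefix
        assert!(!indices_contain_other_indices(&[1, 3], &[0])); // prefix mismatch
        assert!(!indices_contain_other_indices(&[0], &[0, 3])); // `b` is longer
        println!("prefix checks behave as expected");
    }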


From 7b313282541b12bdd028101ea09b544d1dbec3d1 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 10:29:46 -0600
Subject: [PATCH 3/9] Finish no reset analysis

---
 hercules_ir/src/collections.rs | 106 ++++++++++++++++++++-------------
 hercules_ir/src/ir.rs          |  67 +++++++++++++++++++++
 hercules_opt/src/utils.rs      |  66 --------------------
 3 files changed, 132 insertions(+), 107 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 5d50b14a..e57a7756 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -1,5 +1,5 @@
 use std::collections::{BTreeMap, BTreeSet, HashMap};
-use std::iter::once;
+use std::iter::{once, repeat, zip};
 
 use either::Either;
 
@@ -459,47 +459,71 @@ pub fn no_reset_constant_collections(
     });
 
     // Second, collect triples from reads in the program containing:
-    // 1. What indices are read at that point.
-    // 2. What collection constants may be read.
-    // 3. What indices were written to by that point.
+    // 1. What indices are read.
+    // 2. What node is read.
+    // 3. What indices were written into the node already.
     let full_indices: Box<[Index]> = Box::new([]);
-    let triples = (0..function.nodes.len()).filter_map(|idx| {
-        // Item #1.
-        let id = NodeID::new(idx);
-        let (collect, indices) = match function.nodes[id.idx()] {
-            Node::Read {
-                collect,
-                ref indices,
-            } => (Either::Left(once(collect)), indices),
-            Node::Write {
-                collect: _,
-                data,
-                indices: _,
-            } => (Either::Left(once(data)), &full_indices),
-            Node::Call {
-                control: _,
-                function: _,
-                dynamic_constants: _,
-                ref args,
-            } => (Either::Right(args.into_iter().map(|id| *id)), &full_indices),
-            _ => return None,
-        };
-
-        // Item #2.
-        let constants: Vec<_> = collect
-            .map(|collect| objects.objects(collect).into_iter())
-            .flatten()
-            .filter_map(|obj| objects.origin(*obj).try_constant())
-            .collect();
-        if constants.is_empty() {
-            return None;
+    let triples = (0..function.nodes.len())
+        .map(|idx| {
+            // Items #1 and #2.
+            let id = NodeID::new(idx);
+            let indices_and_nodes = match function.nodes[id.idx()] {
+                Node::Read {
+                    collect,
+                    ref indices,
+                } => Either::Left(zip(once(indices), once(collect))),
+                Node::Write {
+                    collect: _,
+                    data,
+                    indices: _,
+                }
+                | Node::Return { control: _, data } => {
+                    Either::Left(zip(once(&full_indices), once(data)))
+                }
+                Node::Call {
+                    control: _,
+                    function: _,
+                    dynamic_constants: _,
+                    ref args,
+                } => Either::Right(zip(repeat(&full_indices), args.into_iter().map(|id| *id))),
+                _ => return None,
+            };
+
+            // Item #3.
+            Some(
+                indices_and_nodes
+                    .map(|(indices, read)| (indices, read, &lattice[read.idx()].written_to)),
+            )
+        })
+        .flatten()
+        .flatten();
+
+    // Third, look at each triple and check whether the read is at indices not
+    // contained by any of the written-to indices sets. If so, then every
+    // collection constant that may be the origin of the object read there has
+    // to be explicitly memset to zero. The result starts as every collection
+    // constant in the function, and is progressively shaved down to the
+    // constants that don't need to be reset explicitly.
+    let mut result: BTreeSet<_> = (0..function.nodes.len())
+        .filter(|idx| {
+            function.nodes[*idx].is_constant() && !types[typing[*idx].idx()].is_primitive()
+        })
+        .map(NodeID::new)
+        .collect();
+    for (indices, node, written_to) in triples {
+        if !written_to
+            .into_iter()
+            .all(|written_to| indices_contain_other_indices(written_to, indices))
+        {
+            for constant in objects
+                .objects(node)
+                .into_iter()
+                .filter_map(|obj| objects.origin(*obj).try_constant())
+            {
+                result.remove(&constant);
+            }
         }
+    }
 
-        // Item #3.
-        let written_to = &lattice[id.idx()].written_to;
-
-        Some((indices, constants, written_to))
-    });
-
-    todo!()
+    result
 }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 187b3f98..30c2e4fa 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1,5 +1,6 @@
 use std::collections::HashSet;
 use std::fmt::Write;
+use std::iter::zip;
 use std::ops::Coroutine;
 use std::ops::CoroutineState;
 use std::pin::Pin;
@@ -1801,6 +1802,72 @@ impl Device {
     }
 }
 
+/*
+ * Helper function to tell if two lists of indices have the same structure.
+ */
+pub fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool {
+    if indices1.len() == indices2.len() {
+        let mut equiv = true;
+        for pair in zip(indices1, indices2) {
+            equiv = equiv
+                && match pair {
+                    (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2,
+                    (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2,
+                    (Index::Position(ref pos1), Index::Position(ref pos2)) => {
+                        assert_eq!(pos1.len(), pos2.len());
+                        true
+                    }
+                    _ => false,
+                };
+        }
+        equiv
+    } else {
+        false
+    }
+}
+
+/*
+ * Helper function to determine if two lists of indices may overlap.
+ */
+pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
+    for pair in zip(indices1, indices2) {
+        match pair {
+            // Check that the field numbers are the same.
+            (Index::Field(idx1), Index::Field(idx2)) => {
+                if idx1 != idx2 {
+                    return false;
+                }
+            }
+            // Variant indices always may overlap, since it's the same
+            // underlying memory. Position indices always may overlap, since the
+            // indexing nodes may be the same at runtime.
+            (Index::Variant(_), Index::Variant(_)) | (Index::Position(_), Index::Position(_)) => {}
+            _ => panic!(),
+        }
+    }
+    // `zip` will exit as soon as either iterator is done - two sets of indices
+    // may overlap when one indexes a larger sub-value than the other.
+    true
+}
+
+/*
+ * Helper function to determine if a list of indices A definitely contains a
+ * list of indices B.
+ */
+pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool {
+    if indices_a.len() < indices_b.len() {
+        return false;
+    }
+
+    for (idx1, idx2) in zip(indices_a, indices_b) {
+        if idx1 != idx2 {
+            return false;
+        }
+    }
+
+    true
+}
+
 /*
  * Rust things to make newtyped IDs usable.
  */
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index a2ca5028..0c6c2fac 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -322,72 +322,6 @@ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID
     }
 }
 
-/*
- * Helper function to tell if two lists of indices have the same structure.
- */
-pub(crate) fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool {
-    if indices1.len() == indices2.len() {
-        let mut equiv = true;
-        for pair in zip(indices1, indices2) {
-            equiv = equiv
-                && match pair {
-                    (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2,
-                    (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2,
-                    (Index::Position(ref pos1), Index::Position(ref pos2)) => {
-                        assert_eq!(pos1.len(), pos2.len());
-                        true
-                    }
-                    _ => false,
-                };
-        }
-        equiv
-    } else {
-        false
-    }
-}
-
-/*
- * Helper function to determine if two lists of indices may overlap.
- */
-pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
-    for pair in zip(indices1, indices2) {
-        match pair {
-            // Check that the field numbers are the same.
-            (Index::Field(idx1), Index::Field(idx2)) => {
-                if idx1 != idx2 {
-                    return false;
-                }
-            }
-            // Variant indices always may overlap, since it's the same
-            // underlying memory. Position indices always may overlap, since the
-            // indexing nodes may be the same at runtime.
-            (Index::Variant(_), Index::Variant(_)) | (Index::Position(_), Index::Position(_)) => {}
-            _ => panic!(),
-        }
-    }
-    // `zip` will exit as soon as either iterator is done - two sets of indices
-    // may overlap when one indexes a larger sub-value than the other.
-    true
-}
-
-/*
- * Helper function to determine if a list of indices A definitely contains a
- * list of indices B.
- */
-pub(crate) fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool {
-    if indices_a.len() < indices_b.len() {
-        return false;
-    }
-
-    for (idx1, idx2) in zip(indices_a, indices_b) {
-        if idx1 != idx2 {
-            return false;
-        }
-    }
-
-    true
-}
-
 pub type DenseNodeMap<T> = Vec<T>;
 pub type SparseNodeMap<T> = HashMap<NodeID, T>;
 
-- 
GitLab
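
Patch 3 completes the analysis: the result starts as every collection constant
in the function, and a constant is dropped from the set whenever an object
originating at it may be read at indices that no definitely-written indices set
covers. A condensed sketch of that shaving step, with integers standing in for
NodeID and Vec<u32> for index paths (hypothetical stand-ins, not the real IR
types):

    use std::collections::BTreeSet;

    // `a` "contains" `b` when `b` matches a leading prefix of `a`, as in the
    // helper moved into ir.rs by this patch.
    fn contains(a: &[u32], b: &[u32]) -> bool {
        a.len() >= b.len() && a.iter().zip(b).all(|(x, y)| x == y)
    }

    fn main() {
        // Suppose nodes 7 and 9 are the collection constants in the function.
        let mut no_reset: BTreeSet<u32> = [7, 9].into_iter().collect();
        // One triple: a read at path [1] of an object originating at constant
        // 9, where only path [0] has definitely been written beforehand.
        let written: BTreeSet<Vec<u32>> = [vec![0]].into_iter().collect();
        let read = vec![1u32];
        if !written.iter().any(|w| contains(w, &read)) {
            // No prior write covers the read, so constant 9 keeps its memset.
            no_reset.remove(&9);
        }
        let expected: BTreeSet<u32> = [7].into_iter().collect();
        assert_eq!(no_reset, expected);
        println!("constants that can skip the zeroing memset: {:?}", no_reset);
    }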


From 6a760fa689ff829a0e9c48486eefe0691209542f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 10:34:05 -0600
Subject: [PATCH 4/9] Add link search path for juno_build

---
 juno_build/build.rs | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/juno_build/build.rs b/juno_build/build.rs
index 7ba34c8c..b2464ecb 100644
--- a/juno_build/build.rs
+++ b/juno_build/build.rs
@@ -1,8 +1,9 @@
 fn main() {
     #[cfg(feature = "cuda")]
-    println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/");
-    #[cfg(feature = "cuda")]
-    println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
-    #[cfg(feature = "cuda")]
-    println!("cargo:rustc-link-lib=cudart");
+    {
+        println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/");
+        println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
+        println!("cargo::rustc-link-search=native=/opt/cuda/lib/");
+        println!("cargo:rustc-link-lib=cudart");
+    }
 }
-- 
GitLab
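
Patch 4 only adds another native link search path for CUDA builds of
juno_build. For reference, a build script emitting the same directives in the
newer cargo:: spelling throughout would look like the sketch below; both the
cargo: and cargo:: forms are accepted by current Cargo, and the library paths
are assumptions about where a given distribution installs libcudart.

    // A minimal build.rs sketch along the lines of patch 4.
    fn main() {
        #[cfg(feature = "cuda")]
        {
            println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/");
            println!("cargo::rustc-link-search=native=/usr/local/cuda/lib64");
            println!("cargo::rustc-link-search=native=/opt/cuda/lib/");
            println!("cargo::rustc-link-lib=cudart");
        }
    }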


From 0268c3a4cf01d1be18238eec0ee13862c4ba9ffc Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 11:02:31 -0600
Subject: [PATCH 5/9] Use analysis to schedule constant nodes

---
 hercules_ir/src/collections.rs      |  8 +--
 hercules_ir/src/ir.rs               |  2 +
 hercules_opt/src/editor.rs          | 13 +++--
 hercules_opt/src/schedule.rs        | 15 +++++-
 hercules_samples/matmul/src/cpu.sch |  7 ++-
 juno_scheduler/src/pm.rs            | 78 ++++++++++++++++++++++++-----
 6 files changed, 98 insertions(+), 25 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index e57a7756..d9d6c8f7 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -409,7 +409,7 @@ pub fn no_reset_constant_collections(
     reverse_postorder: &Vec<NodeID>,
     typing: &Vec<TypeID>,
     objects: &FunctionCollectionObjects,
-    reduce_einsum: &(MathEnv, HashMap<NodeID, NodeID>),
+    reduce_einsum: &(MathEnv, HashMap<NodeID, MathID>),
 ) -> BTreeSet<NodeID> {
     // First, run a dataflow analysis that determines at each collection node
     // what indices have definitely been written to.
@@ -437,7 +437,9 @@ pub fn no_reset_constant_collections(
                 // the empty indices set (the whole collection) is considered as
                 // written to.
                 let (env, exprs) = reduce_einsum;
-                if let MathExpr::Comprehension(_, _) = env[exprs[&id].idx()] {
+                if let Some(expr) = exprs.get(&id)
+                    && let MathExpr::Comprehension(_, _) = env[expr.idx()]
+                {
                     ZeroLattice::top()
                 }
                 // Otherwise, meet the `init` and `reduct` inputs.
@@ -513,7 +515,7 @@ pub fn no_reset_constant_collections(
     for (indices, node, written_to) in triples {
         if !written_to
             .into_iter()
-            .all(|written_to| indices_contain_other_indices(written_to, indices))
+            .any(|written_to| indices_contain_other_indices(written_to, indices))
         {
             for constant in objects
                 .objects(node)
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 30c2e4fa..846347b0 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -327,6 +327,8 @@ pub enum Schedule {
     // This reduce can be re-associated. This may lower a sequential dependency
     // chain into a reduction tree.
     TightAssociative,
+    // This constant node doesn't need to be memset to zero.
+    NoResetConstant,
 }
 
 /*
diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 1d586057..43a88b7c 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -557,19 +557,22 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
     pub fn add_schedule(mut self, id: NodeID, schedule: Schedule) -> Result<Self, Self> {
         if self.is_mutable(id) {
             if let Some(schedules) = self.added_and_updated_schedules.get_mut(&id) {
-                schedules.push(schedule);
+                if !schedules.contains(&schedule) {
+                    schedules.push(schedule);
+                }
             } else {
-                let mut schedules = self
+                let empty = vec![];
+                let schedules = self
                     .editor
                     .function
                     .schedules
                     .get(id.idx())
-                    .unwrap_or(&vec![])
-                    .clone();
+                    .unwrap_or(&empty);
                 if !schedules.contains(&schedule) {
+                    let mut schedules = schedules.clone();
                     schedules.push(schedule);
+                    self.added_and_updated_schedules.insert(id, schedules);
                 }
-                self.added_and_updated_schedules.insert(id, schedules);
             }
             Ok(self)
         } else {
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index f9f720be..fe894e47 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -1,4 +1,4 @@
-use std::collections::{HashMap, HashSet};
+use std::collections::{BTreeSet, HashMap, HashSet};
 
 use hercules_ir::def_use::*;
 use hercules_ir::ir::*;
@@ -181,3 +181,16 @@ pub fn infer_tight_associative(
         }
     }
 }
+
+/*
+ * From analysis result of which constants don't need to be reset, add schedules
+ * to those constant nodes.
+ */
+pub fn infer_no_reset_constants(
+    editor: &mut FunctionEditor,
+    no_reset_constants: &BTreeSet<NodeID>,
+) {
+    for id in no_reset_constants {
+        editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant));
+    }
+}
diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index aeed7e10..0321e13d 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -7,15 +7,14 @@ auto-outline(*);
 ip-sroa(*);
 sroa(*);
 dce(*);
-infer-schedules(*);
+fixpoint {
+  infer-schedules(*);
+}
 fork-split(*);
 unforkify(*);
 dce(*);
-float-collections(*);
 gvn(*);
 phi-elim(*);
 dce(*);
 
-infer-schedules(*);
-
 gcm(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 2371e0f2..f6fe2fc1 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -184,6 +184,7 @@ pub struct PassManager {
     pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>,
+    pub no_reset_constants: Option<Vec<BTreeSet<NodeID>>>,
     pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
     pub devices: Option<Vec<Device>>,
@@ -222,6 +223,7 @@ impl PassManager {
             reduce_cycles: None,
             data_nodes_in_fork_joins: None,
             reduce_einsums: None,
+            no_reset_constants: None,
             collection_objects: None,
             callgraph: None,
             devices: None,
@@ -345,7 +347,12 @@ impl PassManager {
         if self.fork_control_maps.is_none() {
             self.make_fork_join_nests();
             self.fork_control_maps = Some(
-                self.fork_join_nests.as_ref().unwrap().iter().map(fork_control_map).collect(),
+                self.fork_join_nests
+                    .as_ref()
+                    .unwrap()
+                    .iter()
+                    .map(fork_control_map)
+                    .collect(),
             );
         }
     }
@@ -463,6 +470,43 @@ impl PassManager {
         }
     }
 
+    pub fn make_no_reset_constants(&mut self) {
+        if self.no_reset_constants.is_none() {
+            self.make_reverse_postorders();
+            self.make_typing();
+            self.make_collection_objects();
+            self.make_reduce_einsums();
+            let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter();
+            let typing = self.typing.as_ref().unwrap().iter();
+            let collection_objects = self.collection_objects.as_ref().unwrap().iter();
+            let reduce_einsums = self.reduce_einsums.as_ref().unwrap().iter();
+            self.no_reset_constants = Some(
+                self.functions
+                    .iter()
+                    .zip(reverse_postorders)
+                    .zip(typing)
+                    .zip(collection_objects)
+                    .zip(reduce_einsums)
+                    .map(
+                        |(
+                            (((function, reverse_postorder), typing), collection_object),
+                            reduce_einsum,
+                        )| {
+                            no_reset_constant_collections(
+                                function,
+                                &self.types.borrow(),
+                                reverse_postorder,
+                                typing,
+                                collection_object.1,
+                                reduce_einsum,
+                            )
+                        },
+                    )
+                    .collect(),
+            );
+        }
+    }
+
     pub fn make_collection_objects(&mut self) {
         if self.collection_objects.is_none() {
             self.make_reverse_postorders();
@@ -535,6 +579,7 @@ impl PassManager {
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
         self.reduce_einsums = None;
+        self.no_reset_constants = None;
         self.collection_objects = None;
         self.callgraph = None;
         self.devices = None;
@@ -716,8 +761,8 @@ impl PassManager {
         let mut llvm_path = tmp_dir.path().to_path_buf();
         llvm_path.push(format!("{}.ll", module_name));
         println!("{}", llvm_path.display());
-        let mut file = File::create(&llvm_path)
-            .expect("PANIC: Unable to open output LLVM IR file.");
+        let mut file =
+            File::create(&llvm_path).expect("PANIC: Unable to open output LLVM IR file.");
         file.write_all(llvm_ir.as_bytes())
             .expect("PANIC: Unable to write output LLVM IR file contents.");
 
@@ -738,13 +783,17 @@ impl PassManager {
 
         let mut ar_args = vec!["crus", &output_archive, &llvm_object];
 
-        let cuda_object = format!("{}/{}_cuda.o", tmp_dir.path().to_str().unwrap(), module_name);
+        let cuda_object = format!(
+            "{}/{}_cuda.o",
+            tmp_dir.path().to_str().unwrap(),
+            module_name
+        );
         if cfg!(feature = "cuda") {
             // Write the CUDA IR into a temporary file.
             let mut cuda_path = tmp_dir.path().to_path_buf();
             cuda_path.push(format!("{}.cu", module_name));
-            let mut file = File::create(&cuda_path)
-                .expect("PANIC: Unable to open output CUDA IR file.");
+            let mut file =
+                File::create(&cuda_path).expect("PANIC: Unable to open output CUDA IR file.");
             file.write_all(cuda_ir.as_bytes())
                 .expect("PANIC: Unable to write output CUDA IR file contents.");
 
@@ -770,8 +819,8 @@ impl PassManager {
         // Write the Rust runtime into a file.
         let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name);
         println!("{}", output_rt);
-        let mut file = File::create(&output_rt)
-            .expect("PANIC: Unable to open output Rust runtime file.");
+        let mut file =
+            File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file.");
         file.write_all(rust_rt.as_bytes())
             .expect("PANIC: Unable to write output Rust runtime file contents.");
 
@@ -1592,12 +1641,16 @@ fn run_pass(
             assert!(args.is_empty());
             pm.make_fork_join_maps();
             pm.make_reduce_cycles();
+            pm.make_no_reset_constants();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
             let reduce_cycles = pm.reduce_cycles.take().unwrap();
-            for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection)
-                .into_iter()
-                .zip(fork_join_maps.iter())
-                .zip(reduce_cycles.iter())
+            let no_reset_constants = pm.no_reset_constants.take().unwrap();
+            for (((func, fork_join_map), reduce_cycles), no_reset_constants) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(reduce_cycles.iter())
+                    .zip(no_reset_constants.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -1606,6 +1659,7 @@ fn run_pass(
                 infer_parallel_fork(&mut func, fork_join_map);
                 infer_vectorizable(&mut func, fork_join_map);
                 infer_tight_associative(&mut func, reduce_cycles);
+                infer_no_reset_constants(&mut func, no_reset_constants);
                 changed |= func.modified();
             }
             pm.delete_gravestones();
-- 
GitLab
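
Patch 5 threads the analysis through the pass manager as one BTreeSet<NodeID>
per function, attaches a NoResetConstant schedule to each listed constant, and
makes FunctionEdit::add_schedule idempotent so repeated inference runs don't
stack duplicate schedules. The stand-in sketch below shows that flow end to end
with simplified types (not the real hercules_opt editor API): deduplicated
schedule attachment, then a backend-style check that decides whether to emit
the zeroing memset.

    use std::collections::BTreeSet;

    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    enum Schedule {
        NoResetConstant,
    }

    // Mirror the editor fix: only push the schedule if it isn't already there.
    fn add_schedule(schedules: &mut Vec<Schedule>, schedule: Schedule) {
        if !schedules.contains(&schedule) {
            schedules.push(schedule);
        }
    }

    fn main() {
        // Analysis result: node 3 is a constant that never needs zeroing.
        let no_reset_constants: BTreeSet<usize> = [3].into_iter().collect();
        // One schedule list per node (4 nodes in this toy function).
        let mut schedules: Vec<Vec<Schedule>> = vec![vec![]; 4];
        for id in &no_reset_constants {
            add_schedule(&mut schedules[*id], Schedule::NoResetConstant);
            add_schedule(&mut schedules[*id], Schedule::NoResetConstant); // no dup
        }
        // A backend would skip the zeroing memset exactly for scheduled nodes.
        for (id, s) in schedules.iter().enumerate() {
            let skip_memset = s.contains(&Schedule::NoResetConstant);
            println!("node {}: emit memset = {}", id, !skip_memset);
        }
        assert_eq!(schedules[3], vec![Schedule::NoResetConstant]);
    }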


From f4809b03f948e3f5f30581b2db5ac95ae9be4e27 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 11:10:20 -0600
Subject: [PATCH 6/9] Use NoResetConstant in backends

---
 hercules_cg/src/cpu.rs              |  16 +-
 hercules_cg/src/gpu.rs              | 622 ++++++++++++++++++++++------
 hercules_cg/src/rt.rs               |  18 +-
 hercules_samples/matmul/src/gpu.sch |   7 +-
 4 files changed, 519 insertions(+), 144 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 344554b6..f6a1f309 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -288,13 +288,15 @@ impl<'a> CPUContext<'a> {
                         self.get_value(id, false),
                         offset.idx()
                     )?;
-                    let data_size = self.codegen_type_size(self.typing[id.idx()], body)?;
-                    write!(
-                        body,
-                        "  call void @llvm.memset.p0.i64({}, i8 0, i64 {}, i1 false)\n",
-                        self.get_value(id, true),
-                        data_size,
-                    )?;
+                    if !self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant) {
+                        let data_size = self.codegen_type_size(self.typing[id.idx()], body)?;
+                        write!(
+                            body,
+                            "  call void @llvm.memset.p0.i64({}, i8 0, i64 {}, i1 false)\n",
+                            self.get_value(id, true),
+                            data_size,
+                        )?;
+                    }
                 }
             }
             Node::DynamicConstant { id: dc_id } => {
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index d7a6d258..81e31396 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -3,7 +3,7 @@ extern crate hercules_ir;
 
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::{Error, Write};
-use std::fs::{OpenOptions, File};
+use std::fs::{File, OpenOptions};
 use std::io::Write as _;
 
 use self::hercules_ir::*;
@@ -85,7 +85,6 @@ pub fn gpu_codegen<W: Write>(
         .map(NodeID::new)
         .collect();
 
-
     let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
         .iter()
         .map(|(fork, join)| (*join, *fork))
@@ -104,7 +103,7 @@ pub fn gpu_codegen<W: Write>(
         } = &function.nodes[reduce_node.idx()]
         {
             match function.nodes[control.idx()] {
-                Node::Join {..} => {
+                Node::Join { .. } => {
                     let fork_node = join_fork_map[control];
                     fork_reduce_map
                         .entry(fork_node)
@@ -126,9 +125,10 @@ pub fn gpu_codegen<W: Write>(
     }
     for idx in 0..function.nodes.len() {
         if function.nodes[idx].is_fork() {
-            assert!(fork_reduce_map
-                .get(&NodeID::new(idx))
-                .is_some_and(|reduces| !reduces.is_empty()),
+            assert!(
+                fork_reduce_map
+                    .get(&NodeID::new(idx))
+                    .is_some_and(|reduces| !reduces.is_empty()),
                 "Fork node {} has no reduce nodes",
                 idx
             );
@@ -144,14 +144,18 @@ pub fn gpu_codegen<W: Write>(
                 panic!("Phi's control must be a region node");
             };
             for (i, &pred) in preds.iter().enumerate() {
-                control_data_phi_map.entry(pred).or_default().push((data[i], NodeID::new(idx)));
+                control_data_phi_map
+                    .entry(pred)
+                    .or_default()
+                    .push((data[i], NodeID::new(idx)));
             }
         }
     }
 
     let return_parameter = if collection_objects.returned_objects().len() == 1 {
-        collection_objects.origin(*collection_objects.returned_objects()
-            .first().unwrap()).try_parameter()
+        collection_objects
+            .origin(*collection_objects.returned_objects().first().unwrap())
+            .try_parameter()
     } else {
         None
     };
@@ -280,14 +284,20 @@ impl GPUContext<'_> {
 
         // If there are no forks, fast forward to single-block, single-thread codegen
         let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
-            self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?;
+            self.codegen_data_control_no_forks(
+                &HashSet::new(),
+                &mut dynamic_shared_offset,
+                &mut gotos,
+            )?;
             ("1".to_string(), "1".to_string())
         } else {
             // Create structures and determine block and thread parallelization strategy
             let (root_forks, num_blocks, is_block_parallel) =
                 self.get_root_forks_and_num_blocks(self.fork_tree);
-            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
-            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
+            let (thread_root_root_fork, thread_root_forks) =
+                self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
+            let (fork_thread_quota_map, num_threads) =
+                self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
             // TODO: Uncomment and adjust once we know logic of extra dim. This will affect constant
             // collections, reads, and writes.
             // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
@@ -319,7 +329,12 @@ impl GPUContext<'_> {
 
         // Emit host launch code
         let mut host_launch = String::new();
-        self.codegen_launch_code(num_blocks, num_threads, &dynamic_shared_offset, &mut host_launch)?;
+        self.codegen_launch_code(
+            num_blocks,
+            num_threads,
+            &dynamic_shared_offset,
+            &mut host_launch,
+        )?;
         write!(w, "{}", host_launch)?;
 
         Ok(())
@@ -329,7 +344,9 @@ impl GPUContext<'_> {
      * Emit kernel headers, signature, arguments, and dynamic shared memory declaration
      */
     fn codegen_kernel_begin(&self, has_ret_var: bool, w: &mut String) -> Result<(), Error> {
-        write!(w, "
+        write!(
+            w,
+            "
 #include <assert.h>
 #include <stdio.h>
 #include <stddef.h>
@@ -390,10 +407,7 @@ namespace cg = cooperative_groups;
             if !first_param {
                 write!(w, ", ")?;
             }
-            write!(
-                w,
-                "void* __restrict__ ret",
-            )?;
+            write!(w, "void* __restrict__ ret",)?;
         }
 
         // Type is char since it's simplest to use single bytes for indexing
@@ -425,24 +439,48 @@ namespace cg = cooperative_groups;
                     }
                 }
                 DynamicConstant::Add(args) => {
-                    let rhs = args.iter().map(|arg| format!("dc{}", arg.idx())).collect::<Vec<_>>().join(" + ");
+                    let rhs = args
+                        .iter()
+                        .map(|arg| format!("dc{}", arg.idx()))
+                        .collect::<Vec<_>>()
+                        .join(" + ");
                     write!(w, "\t{} = {};\n", dc_val, rhs)?
                 }
                 DynamicConstant::Mul(args) => {
-                    let rhs = args.iter().map(|arg| format!("dc{}", arg.idx())).collect::<Vec<_>>().join(" * ");
+                    let rhs = args
+                        .iter()
+                        .map(|arg| format!("dc{}", arg.idx()))
+                        .collect::<Vec<_>>()
+                        .join(" * ");
                     write!(w, "\t{} = {};\n", dc_val, rhs)?
                 }
                 DynamicConstant::Min(args) => {
-                    let rhs_but_last: String = args.iter().take(args.len() - 1).map(|arg| format!("min(dc{}, ", arg.idx())).collect();
+                    let rhs_but_last: String = args
+                        .iter()
+                        .take(args.len() - 1)
+                        .map(|arg| format!("min(dc{}, ", arg.idx()))
+                        .collect();
                     let rhs_last = format!("dc{}", args.last().unwrap().idx());
                     let rhs_end: String = std::iter::repeat(")").take(args.len() - 1).collect();
-                    write!(w, "\t{} = {}{}{};\n", dc_val, rhs_but_last, rhs_last, rhs_end)?
+                    write!(
+                        w,
+                        "\t{} = {}{}{};\n",
+                        dc_val, rhs_but_last, rhs_last, rhs_end
+                    )?
                 }
                 DynamicConstant::Max(args) => {
-                    let rhs_but_last: String = args.iter().take(args.len() - 1).map(|arg| format!("max(dc{}, ", arg.idx())).collect();
+                    let rhs_but_last: String = args
+                        .iter()
+                        .take(args.len() - 1)
+                        .map(|arg| format!("max(dc{}, ", arg.idx()))
+                        .collect();
                     let rhs_last = format!("dc{}", args.last().unwrap().idx());
                     let rhs_end: String = std::iter::repeat(")").take(args.len() - 1).collect();
-                    write!(w, "\t{} = {}{}{};\n", dc_val, rhs_but_last, rhs_last, rhs_end)?
+                    write!(
+                        w,
+                        "\t{} = {}{}{};\n",
+                        dc_val, rhs_but_last, rhs_last, rhs_end
+                    )?
                 }
                 DynamicConstant::Sub(left, right) => {
                     write!(w, "\t{} = dc{} - dc{};\n", dc_val, left.idx(), right.idx())?
@@ -464,9 +502,10 @@ namespace cg = cooperative_groups;
      */
     fn codegen_declare_data(&self, w: &mut String) -> Result<(), Error> {
         for id in (0..self.function.nodes.len()).map(NodeID::new) {
-            if !self.function.nodes[id.idx()].is_control() &&
-                !self.function.nodes[id.idx()].is_dynamic_constant() &&
-                !self.function.nodes[id.idx()].is_parameter() {
+            if !self.function.nodes[id.idx()].is_control()
+                && !self.function.nodes[id.idx()].is_dynamic_constant()
+                && !self.function.nodes[id.idx()].is_parameter()
+            {
                 write!(w, "\t{};\n", self.get_value(id, true, false))?;
             }
             if self.function.nodes[id.idx()].is_phi() {
@@ -493,7 +532,12 @@ namespace cg = cooperative_groups;
         Ok(())
     }
 
-    fn codegen_gotos(&self, goto_debug: bool, gotos: &mut BTreeMap<NodeID, CudaGoto>, w: &mut String) -> Result<(), Error> {
+    fn codegen_gotos(
+        &self,
+        goto_debug: bool,
+        gotos: &mut BTreeMap<NodeID, CudaGoto>,
+        w: &mut String,
+    ) -> Result<(), Error> {
         write!(w, "\n")?;
         for (id, goto) in gotos.iter() {
             let goto_block = self.get_block_name(*id, false);
@@ -513,13 +557,24 @@ namespace cg = cooperative_groups;
         Ok(())
     }
 
-    fn codegen_launch_code(&self, num_blocks: String, num_threads: String, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> {
+    fn codegen_launch_code(
+        &self,
+        num_blocks: String,
+        num_threads: String,
+        dynamic_shared_offset: &str,
+        w: &mut String,
+    ) -> Result<(), Error> {
         // The following steps are for host-side C function arguments, but we also
         // need to pass arguments to kernel, so we keep track of the arguments here.
         let ret_type = self.get_type(self.function.return_type, false);
         let mut pass_args = String::new();
-        write!(w, "
-extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
+        write!(
+            w,
+            "
+extern \"C\" {} {}(",
+            ret_type.clone(),
+            self.function.name
+        )?;
         let mut first_param = true;
         // The first parameter is a pointer to GPU backing memory, if it's
         // needed.
@@ -566,16 +621,30 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
         }
         write!(w, "\tcudaError_t err;\n");
-        write!(w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args)?;
+        write!(
+            w,
+            "\t{}_gpu<<<{}, {}, {}>>>({});\n",
+            self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args
+        )?;
         write!(w, "\terr = cudaGetLastError();\n");
-        write!(w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n");
+        write!(
+            w,
+            "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
+        );
         write!(w, "\tcudaDeviceSynchronize();\n")?;
         write!(w, "\terr = cudaGetLastError();\n");
-        write!(w, "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n");
+        write!(
+            w,
+            "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n"
+        );
         if has_ret_var {
             // Copy return from device to host, whether it's primitive value or collection pointer
             write!(w, "\t{} host_ret;\n", ret_type)?;
-            write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?;
+            write!(
+                w,
+                "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n",
+                ret_type
+            )?;
             write!(w, "\treturn host_ret;\n")?;
         } else {
             write!(w, "\treturn p{};\n", self.return_parameter.unwrap())?;
@@ -604,9 +673,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             panic!("Expected fork node");
         };
         let reduces = &self.fork_reduce_map[root_fork];
-        if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
-        {
-            let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ");
+        if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) {
+            let fork_size = factors
+                .iter()
+                .map(|dc| format!("dc{}", dc.idx()))
+                .collect::<Vec<_>>()
+                .join(" * ");
             (root_forks, fork_size, true)
         } else {
             (root_forks, "1".to_string(), false)
@@ -626,7 +698,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     ) -> (NodeID, HashSet<NodeID>) {
         if is_block_parallel {
             let root_fork = root_forks.iter().next().unwrap();
-            (*root_fork, fork_tree.get(&root_fork).unwrap().iter().copied().collect())
+            (
+                *root_fork,
+                fork_tree.get(&root_fork).unwrap().iter().copied().collect(),
+            )
         } else {
             (NodeID::new(0), root_forks.clone())
         }
@@ -676,12 +751,17 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             .map(|child| (child, self.recurse_thread_quotas(*child, fork_tree, false)))
             .fold(
                 (HashMap::new(), HashMap::new(), 1),
-                |(mut subsubtree_map, mut children_quota_map, subtree_quota), (child, (curr_map, curr_quota, use_curr))| {
+                |(mut subsubtree_map, mut children_quota_map, subtree_quota),
+                 (child, (curr_map, curr_quota, use_curr))| {
                     subsubtree_map.extend(curr_map);
                     if use_curr {
                         children_quota_map.insert(child, curr_quota);
                     }
-                    (subsubtree_map, children_quota_map, subtree_quota.max(curr_quota))
+                    (
+                        subsubtree_map,
+                        children_quota_map,
+                        subtree_quota.max(curr_quota),
+                    )
                 },
             );
         // First update children_quota_map items with full information and add
@@ -695,7 +775,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         }
         let subtree_map = subsubtree_map;
         if is_root {
-            return (subtree_map, subtree_quota, true)
+            return (subtree_map, subtree_quota, true);
         }
         // A node can only be considered for parallelization if:
         // a) it has statically known size
@@ -725,7 +805,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             // is possible.
             if reduces.iter().any(|&reduce| {
                 self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
-            }) || fork_size > self.kernel_params.max_num_threads / subtree_quota {
+            }) || fork_size > self.kernel_params.max_num_threads / subtree_quota
+            {
                 if fork_size >= subtree_quota {
                     (HashMap::new(), fork_size, true)
                 } else {
@@ -770,9 +851,22 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let post_init = &mut goto.post_init;
                 let body = &mut goto.body;
                 let term = &mut goto.term;
-                let mut tabs = self.codegen_control_node(control, None, None, None, init, post_init, term)?;
+                let mut tabs =
+                    self.codegen_control_node(control, None, None, None, init, post_init, term)?;
                 for data in self.bbs.1[control.idx()].iter() {
-                    self.codegen_data_node(*data, KernelState::OutBlock, Some(false), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
+                    self.codegen_data_node(
+                        *data,
+                        KernelState::OutBlock,
+                        Some(false),
+                        None,
+                        None,
+                        None,
+                        false,
+                        extra_dim_collects,
+                        dynamic_shared_offset,
+                        body,
+                        &mut tabs,
+                    )?;
                 }
                 Ok(())
             })
@@ -801,9 +895,22 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             let post_init = &mut goto.post_init;
             let body = &mut goto.body;
             let term = &mut goto.term;
-            let mut tabs = self.codegen_control_node(*control, None, None, None, init, post_init, term)?;
+            let mut tabs =
+                self.codegen_control_node(*control, None, None, None, init, post_init, term)?;
             for data in self.bbs.1[control.idx()].iter() {
-                self.codegen_data_node(*data, state, Some(is_block_parallel), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
+                self.codegen_data_node(
+                    *data,
+                    state,
+                    Some(is_block_parallel),
+                    None,
+                    None,
+                    None,
+                    false,
+                    extra_dim_collects,
+                    dynamic_shared_offset,
+                    body,
+                    &mut tabs,
+                )?;
             }
         }
         // Then generate data and control for the single block fork if it exists
@@ -815,9 +922,29 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let post_init = &mut goto.post_init;
                 let body = &mut goto.body;
                 let term = &mut goto.term;
-                let mut tabs = self.codegen_control_node(*control, Some(num_threads), Some(num_threads), Some(1), init, post_init, term)?;
+                let mut tabs = self.codegen_control_node(
+                    *control,
+                    Some(num_threads),
+                    Some(num_threads),
+                    Some(1),
+                    init,
+                    post_init,
+                    term,
+                )?;
                 for data in self.bbs.1[control.idx()].iter() {
-                    self.codegen_data_node(*data, state, None, Some(num_threads), None, Some(block_fork.unwrap()), false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
+                    self.codegen_data_node(
+                        *data,
+                        state,
+                        None,
+                        Some(num_threads),
+                        None,
+                        Some(block_fork.unwrap()),
+                        false,
+                        extra_dim_collects,
+                        dynamic_shared_offset,
+                        body,
+                        &mut tabs,
+                    )?;
                 }
             }
         }
@@ -862,14 +989,19 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         let reduces = &self.fork_reduce_map[&curr_fork];
         let reducts = if parallel_factor.is_some() {
             reduces
-            .iter()
-            .map(|&reduce| {
-                let Node::Reduce { control: _, init: _, reduct} = &self.function.nodes[reduce.idx()] else {
-                    panic!("Expected reduce node");
-                };
-                *reduct
-            })
-            .collect()
+                .iter()
+                .map(|&reduce| {
+                    let Node::Reduce {
+                        control: _,
+                        init: _,
+                        reduct,
+                    } = &self.function.nodes[reduce.idx()]
+                    else {
+                        panic!("Expected reduce node");
+                    };
+                    *reduct
+                })
+                .collect()
         } else {
             HashSet::new()
         };
@@ -879,7 +1011,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             let post_init = &mut goto.post_init;
             let body = &mut goto.body;
             let term = &mut goto.term;
-            let mut tabs = self.codegen_control_node(*control, Some(available_thread_quota), Some(use_thread_quota), parallel_factor, init, post_init, term)?;
+            let mut tabs = self.codegen_control_node(
+                *control,
+                Some(available_thread_quota),
+                Some(use_thread_quota),
+                parallel_factor,
+                init,
+                post_init,
+                term,
+            )?;
             for data in self.bbs.1[control.idx()].iter() {
                 self.codegen_data_node(
                     *data,
@@ -928,8 +1068,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         let define_variable = self.get_value(id, false, false).to_string();
         let tabs = "\t".repeat(*num_tabs);
         match &self.function.nodes[id.idx()] {
-            Node::Phi { control: _, data: _ } => {
-                write!(w, "{}{} = {}_tmp;\n", tabs, define_variable, define_variable)?;
+            Node::Phi {
+                control: _,
+                data: _,
+            } => {
+                write!(
+                    w,
+                    "{}{} = {}_tmp;\n",
+                    tabs, define_variable, define_variable
+                )?;
             }
             Node::ThreadID { control, dimension } => {
                 let Node::Fork { factors, .. } = &self.function.nodes[control.idx()] else {
@@ -950,11 +1097,22 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                             // No dependence on threadIdx.x because each used thread
                             // will run this Fork serially
                             let fork_iter = self.get_fork_iter(*control, false);
-                            write!(w, "{}{} = ({} / {}) % {};\n", tabs, define_variable, fork_iter, divide, modulo)?;
+                            write!(
+                                w,
+                                "{}{} = ({} / {}) % {};\n",
+                                tabs, define_variable, fork_iter, divide, modulo
+                            )?;
                         } else {
                             // We can directly use use_thread_quota and not worry about available
                             // because Fork basic block's init section already does gating
-                            write!(w, "{}{} = (threadIdx.x % {}) / {};\n", tabs, define_variable, use_thread_quota.unwrap(), use_thread_quota.unwrap() / parallel_factor.unwrap())?;
+                            write!(
+                                w,
+                                "{}{} = (threadIdx.x % {}) / {};\n",
+                                tabs,
+                                define_variable,
+                                use_thread_quota.unwrap(),
+                                use_thread_quota.unwrap() / parallel_factor.unwrap()
+                            )?;
                         }
                     }
                     _ => {
@@ -995,7 +1153,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive();
                 let cg_tile = match state {
                     KernelState::OutBlock | KernelState::InBlock => "block".to_string(),
-                    KernelState::InThread => self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId),
+                    KernelState::InThread => {
+                        self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId)
+                    }
                 };
                 if !is_primitive && state != KernelState::OutBlock {
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
@@ -1018,13 +1178,29 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     *num_tabs -= 1;
                 }
                 if !is_primitive && state == KernelState::OutBlock {
+                    assert!(self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant), "PANIC: The CUDA backend cannot lower a global memory constant that has to be reset to zero. This is because we cannot efficiently implement a memset to the underlying memory of the constant due to the need for a grid level sync. Consider floating this collection outside the CUDA function and into an AsyncRust function, or attaching the NoResetConstant schedule to indicate that no memset is semantically necessary.");
                     let (_, offsets) = &self.backing_allocation[&Device::CUDA];
                     let offset = offsets[&id];
-                    write!(w, "{}{} = backing + dc{};\n", tabs, define_variable, offset.idx())?;
+                    write!(
+                        w,
+                        "{}{} = backing + dc{};\n",
+                        tabs,
+                        define_variable,
+                        offset.idx()
+                    )?;
                 }
-                if !is_primitive && (state != KernelState::OutBlock || is_block_parallel.is_none() || !is_block_parallel.unwrap()) {
-                    let data_size = self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects));
-                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
+                if !is_primitive
+                    && (state != KernelState::OutBlock
+                        || is_block_parallel.is_none()
+                        || !is_block_parallel.unwrap())
+                {
+                    let data_size =
+                        self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects));
+                    write!(
+                        w,
+                        "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
+                        tabs, cg_tile, data_size, cg_tile
+                    )?;
                     write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?;
                     write!(w, "{}}}\n", tabs)?;
                     write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
@@ -1083,14 +1259,25 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let mut left_val = self.get_value(*left, false, false);
                 let mut right_val = self.get_value(*right, false, false);
                 let id_type = self.typing[id.idx()];
-                if matches!(op, BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And
-                    | BinaryOperator::Xor) && is_special_reduct {
+                if matches!(
+                    op,
+                    BinaryOperator::Add
+                        | BinaryOperator::Or
+                        | BinaryOperator::And
+                        | BinaryOperator::Xor
+                ) && is_special_reduct
+                {
                     // For parallelized associative Reduces, use the cooperative
                     // groups reduce API. Associative multiplication is not
                     // supported. We need to use CGType::Use not CGType::UsePerId
                     // because for parallelized reduction we only have one thread
                     // per ThreadID and the reduction is over Use, not UsePerId.
-                    let (reduce_val, non_reduce_val) = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] {
+                    let (reduce_val, non_reduce_val) = if let Node::Reduce {
+                        control: _,
+                        init: _,
+                        reduct: _,
+                    } = &self.function.nodes[left.idx()]
+                    {
                         (left_val, right_val)
                     } else {
                         (right_val, left_val)
@@ -1107,7 +1294,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                         _ => unreachable!(),
                     };
                     let id_type_name = self.get_type(id_type, false);
-                    write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name)?;
+                    write!(
+                        w,
+                        "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n",
+                        tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name
+                    )?;
                     // Setup binop between reduce's init and reduced reduct. Since it's associative,
                     // we can change binop ordering
                     left_val = define_variable.clone();
@@ -1184,7 +1375,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let id_type = self.typing[id.idx()];
                 if matches!(intrinsic, Intrinsic::Max | Intrinsic::Min) && is_special_reduct {
                     // Similar to associative Binops
-                    let non_reduce_arg = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[args[0].idx()] {
+                    let non_reduce_arg = if let Node::Reduce {
+                        control: _,
+                        init: _,
+                        reduct: _,
+                    } = &self.function.nodes[args[0].idx()]
+                    {
                         self.get_value(args[1], false, false)
                     } else {
                         self.get_value(args[0], false, false)
@@ -1197,33 +1393,44 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                         _ => unreachable!(),
                     };
                     let id_type_name = self.get_type(id_type, false);
-                    write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, non_reduce_arg, cg_tile, cg_op, id_type_name)?;
+                    write!(
+                        w,
+                        "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n",
+                        tabs, define_variable, non_reduce_arg, cg_tile, cg_op, id_type_name
+                    )?;
                 } else {
                     let ty = &self.types[id_type.idx()];
                     let intrinsic = self.codegen_intrinsic(intrinsic, ty);
-                    let args = args.iter()
+                    let args = args
+                        .iter()
                         .map(|arg| self.get_value(*arg, false, false))
                         .collect::<Vec<_>>()
                         .join(", ");
                     write!(
                         w,
                         "{}{} = {}({});\n",
-                        tabs,
-                        define_variable,
-                        intrinsic,
-                        args,
+                        tabs, define_variable, intrinsic, args,
                     )?;
                 }
             }
             // Read of primitive requires load after pointer math.
             Node::Read { collect, indices } => {
-                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects);
+                let collect_with_indices =
+                    self.codegen_collect(*collect, indices, extra_dim_collects);
                 let data_type_id = self.typing[id.idx()];
                 if self.types[data_type_id.idx()].is_primitive() {
                     let type_name = self.get_type(data_type_id, true);
-                    write!(w, "{}{} = *reinterpret_cast<{}>({});\n", tabs, define_variable, type_name, collect_with_indices)?;
+                    write!(
+                        w,
+                        "{}{} = *reinterpret_cast<{}>({});\n",
+                        tabs, define_variable, type_name, collect_with_indices
+                    )?;
                 } else {
-                    write!(w, "{}{} = {};\n", tabs, define_variable, collect_with_indices)?;
+                    write!(
+                        w,
+                        "{}{} = {};\n",
+                        tabs, define_variable, collect_with_indices
+                    )?;
                 }
             }
             // Write of primitive needs a thread rank gate for safety. Write of
@@ -1233,24 +1440,43 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 data,
                 indices,
             } => {
-                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects);
+                let collect_with_indices =
+                    self.codegen_collect(*collect, indices, extra_dim_collects);
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
                 let cg_tile = match state {
                     KernelState::OutBlock | KernelState::InBlock => "block".to_string(),
-                    KernelState::InThread => self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId),
+                    KernelState::InThread => {
+                        self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId)
+                    }
                 };
                 if self.types[data_type_id.idx()].is_primitive() {
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
                     let type_name = self.get_type(data_type_id, true);
-                    write!(w, "{}\t*reinterpret_cast<{}>({}) = {};\n", tabs, type_name, collect_with_indices, data_variable)?;
+                    write!(
+                        w,
+                        "{}\t*reinterpret_cast<{}>({}) = {};\n",
+                        tabs, type_name, collect_with_indices, data_variable
+                    )?;
                     write!(w, "{}}}\n", tabs)?;
                 } else {
                     let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
-                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
-                    write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, collect_with_indices, data_variable)?;
+                    write!(
+                        w,
+                        "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
+                        tabs, cg_tile, data_size, cg_tile
+                    )?;
+                    write!(
+                        w,
+                        "{}\t*({} + i) = *({} + i);\n",
+                        tabs, collect_with_indices, data_variable
+                    )?;
                     write!(w, "{}}}\n", tabs)?;
-                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
+                    write!(
+                        w,
+                        "{}if ({}.thread_rank() < {} % {}.size()) {{\n",
+                        tabs, cg_tile, data_size, cg_tile
+                    )?;
                     write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?;
                     write!(w, "{}}}\n", tabs)?;
                 }
@@ -1259,7 +1485,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
             }
             _ => {
-                panic!("Unsupported data node type: {:?}", self.function.nodes[id.idx()])
+                panic!(
+                    "Unsupported data node type: {:?}",
+                    self.function.nodes[id.idx()]
+                )
             }
         }
         // Since reducts are responsible for updating Reduce nodes, we check and
@@ -1292,7 +1521,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         let tabs = match &self.function.nodes[id.idx()] {
             Node::Start
             | Node::Region { preds: _ }
-            | Node::Projection { control: _, selection: _ } => {
+            | Node::Projection {
+                control: _,
+                selection: _,
+            } => {
                 let succ = self.control_subgraph.succs(id).next().unwrap();
                 write!(w_term, "\tgoto {};\n", self.get_block_name(succ, false))?;
                 1
@@ -1309,13 +1541,32 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     "\tif ({}) {{\n",
                     self.get_value(*cond, false, false)
                 )?;
-                write!(w_term, "\t\tgoto {};\n", if succ1_is_true { succ1_block_name.clone() } else { succ2_block_name.clone() })?;
+                write!(
+                    w_term,
+                    "\t\tgoto {};\n",
+                    if succ1_is_true {
+                        succ1_block_name.clone()
+                    } else {
+                        succ2_block_name.clone()
+                    }
+                )?;
                 write!(w_term, "\t}} else {{\n")?;
-                write!(w_term, "\t\tgoto {};\n", if succ1_is_true { succ2_block_name } else { succ1_block_name })?;
+                write!(
+                    w_term,
+                    "\t\tgoto {};\n",
+                    if succ1_is_true {
+                        succ2_block_name
+                    } else {
+                        succ1_block_name
+                    }
+                )?;
                 write!(w_term, "\t}}\n")?;
                 1
             }
-            Node::Fork { control: _, factors: _ } => {
+            Node::Fork {
+                control: _,
+                factors: _,
+            } => {
                 // We create a cooperative group tile for each of: used threads per
                 // thread ID- for reads and writes-, used threads across all thread
                 // IDs- for parallelized reductions-, and available threads- to
@@ -1332,12 +1583,24 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     } else {
                         use_thread_quota
                     };
-                    write!(w_init, "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", use_thread_per_id, cg_tile, use_thread_per_id)?;
+                    write!(
+                        w_init,
+                        "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
+                        use_thread_per_id, cg_tile, use_thread_per_id
+                    )?;
                     let cg_tile_use = self.get_cg_tile(id, CGType::Use);
-                    write!(w_init, "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", use_thread_quota, cg_tile_use, use_thread_quota)?;
+                    write!(
+                        w_init,
+                        "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
+                        use_thread_quota, cg_tile_use, use_thread_quota
+                    )?;
                     let available_thread_quota = available_thread_quota.unwrap();
                     let cg_tile_available = self.get_cg_tile(id, CGType::Available);
-                    write!(w_init, "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", available_thread_quota, cg_tile_available, available_thread_quota)?;
+                    write!(
+                        w_init,
+                        "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
+                        available_thread_quota, cg_tile_available, available_thread_quota
+                    )?;
                     if parallel_factor.is_none() {
                         write!(w_init, "\t{} = 0;\n", self.get_fork_iter(id, true))?;
                         write!(w_init, "\tgoto {};\n", self.get_block_name(id, true))?;
@@ -1347,9 +1610,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 // threads. If unused, we jump straight to the Join, and if used,
                 // we jump to successor like normal.
                 let succ = self.control_subgraph.succs(id).next().unwrap();
-                if let Some(available_thread_quota) = available_thread_quota && let Some(use_thread_quota) = use_thread_quota && use_thread_quota < available_thread_quota {
-                    let w_target = if parallel_factor.is_none() { w_post_init } else { w_init };
-                    write!(w_target, "\tif (threadIdx.x % {} < {}) {{\n", available_thread_quota, use_thread_quota)?;
+                if let Some(available_thread_quota) = available_thread_quota
+                    && let Some(use_thread_quota) = use_thread_quota
+                    && use_thread_quota < available_thread_quota
+                {
+                    let w_target = if parallel_factor.is_none() {
+                        w_post_init
+                    } else {
+                        w_init
+                    };
+                    write!(
+                        w_target,
+                        "\tif (threadIdx.x % {} < {}) {{\n",
+                        available_thread_quota, use_thread_quota
+                    )?;
                     write!(w_term, "\t\tgoto {};\n", self.get_block_name(succ, false))?;
                     write!(w_term, "\t}}\n")?;
                     write!(w_term, "\telse {{\n")?;
@@ -1376,7 +1650,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     let available_thread_quota = available_thread_quota.unwrap();
                     let use_thread_quota = use_thread_quota.unwrap();
                     if use_thread_quota < available_thread_quota {
-                        write!(w_init, "\tif (threadIdx.x % {} < {}) {{\n", available_thread_quota, use_thread_quota)?;
+                        write!(
+                            w_init,
+                            "\tif (threadIdx.x % {} < {}) {{\n",
+                            available_thread_quota, use_thread_quota
+                        )?;
                         write!(w_term, "\t}}\n")?;
                         tabs += 1;
                     }
@@ -1415,7 +1693,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     let return_val = self.get_value(*data, false, false);
                     let return_type_ptr = self.get_type(self.function.return_type, true);
                     write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
-                    write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?;
+                    write!(
+                        w_term,
+                        "\t\t*(reinterpret_cast<{}>(ret)) = {};\n",
+                        return_type_ptr, return_val
+                    )?;
                     write!(w_term, "\t}}\n")?;
                 }
                 write!(w_term, "\treturn;\n")?;
@@ -1432,7 +1714,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
      * This function emits collection name + pointer math for the provided indices.
      * All collection types use char pointers.
      */
-    fn codegen_collect(&self, collect: NodeID, indices: &[Index], extra_dim_collects: &HashSet<TypeID>) -> String {
+    fn codegen_collect(
+        &self,
+        collect: NodeID,
+        indices: &[Index],
+        extra_dim_collects: &HashSet<TypeID>,
+    ) -> String {
         let mut index_ptr = "0".to_string();
         let type_id = self.typing[collect.idx()];
         for index in indices {
@@ -1514,7 +1801,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             Constant::UnsignedInteger32(val) => write!(w, "{}{} = {}ul;\n", tabs, name, val)?,
             Constant::Integer64(val) => write!(w, "{}{} = {}ll;\n", tabs, name, val)?,
             Constant::UnsignedInteger64(val) => write!(w, "{}{} = {}ull;\n", tabs, name, val)?,
-            Constant::Float32(val) => write!(w, "{}{} = {}f;\n", tabs, name, format_float(**val as f64))?,
+            Constant::Float32(val) => {
+                write!(w, "{}{} = {}f;\n", tabs, name, format_float(**val as f64))?
+            }
             Constant::Float64(val) => write!(w, "{}{} = {};\n", tabs, name, format_float(**val))?,
             // All three following collections involve align then allocate from the
             // single dynamic shared memory buffer by using and updating the offset.
@@ -1522,9 +1811,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
                     let size = self.get_size(*type_id, None, extra_dim_collects);
-                    *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
-                    write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
-                    write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
+                    *dynamic_shared_offset = format!(
+                        "(({} + {} - 1) / {}) * {}",
+                        dynamic_shared_offset, alignment, alignment, alignment
+                    );
+                    write!(
+                        w,
+                        "{}dynamic_shared_offset = {};\n",
+                        tabs, dynamic_shared_offset
+                    )?;
+                    write!(
+                        w,
+                        "{}{} = dynamic_shared + dynamic_shared_offset;\n",
+                        tabs, name
+                    )?;
                     *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
                 }
                 let Type::Product(type_fields) = &self.types[type_id.idx()] else {
@@ -1546,7 +1846,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                             num_tabs,
                         )?;
                     } else if !field_constant.is_array() {
-                        self.codegen_constant(format!("{}+{}", name, offset), constant_fields[i], false, extra_dim_collects, dynamic_shared_offset, w, num_tabs)?;
+                        self.codegen_constant(
+                            format!("{}+{}", name, offset),
+                            constant_fields[i],
+                            false,
+                            extra_dim_collects,
+                            dynamic_shared_offset,
+                            w,
+                            num_tabs,
+                        )?;
                     }
                 }
             }
@@ -1554,9 +1862,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
                     let size = self.get_size(*type_id, None, extra_dim_collects);
-                    *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
-                    write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
-                    write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
+                    *dynamic_shared_offset = format!(
+                        "(({} + {} - 1) / {}) * {}",
+                        dynamic_shared_offset, alignment, alignment, alignment
+                    );
+                    write!(
+                        w,
+                        "{}dynamic_shared_offset = {};\n",
+                        tabs, dynamic_shared_offset
+                    )?;
+                    write!(
+                        w,
+                        "{}{} = dynamic_shared + dynamic_shared_offset;\n",
+                        tabs, name
+                    )?;
                     *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
                 }
                 // No offset updating needed since all variants start at 0
@@ -1565,7 +1884,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 };
                 let variant_constant = &self.constants[field.idx()];
                 if variant_constant.is_scalar() {
-                    let variant_type = self.get_type(self.typing[variants[*variant as usize].idx()], true);
+                    let variant_type =
+                        self.get_type(self.typing[variants[*variant as usize].idx()], true);
                     self.codegen_constant(
                         format!("*reinterpret_cast<{}>({})", variant_type, name),
                         *field,
@@ -1576,7 +1896,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                         num_tabs,
                     )?;
                 } else if !variant_constant.is_array() {
-                    self.codegen_constant(name, *field, false, extra_dim_collects, dynamic_shared_offset, w, num_tabs)?;
+                    self.codegen_constant(
+                        name,
+                        *field,
+                        false,
+                        extra_dim_collects,
+                        dynamic_shared_offset,
+                        w,
+                        num_tabs,
+                    )?;
                 };
             }
             Constant::Array(type_id) => {
@@ -1585,9 +1913,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 }
                 let alignment = self.get_alignment(*type_id);
                 let size = self.get_size(*type_id, None, extra_dim_collects);
-                *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
-                write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
-                write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
+                *dynamic_shared_offset = format!(
+                    "(({} + {} - 1) / {}) * {}",
+                    dynamic_shared_offset, alignment, alignment, alignment
+                );
+                write!(
+                    w,
+                    "{}dynamic_shared_offset = {};\n",
+                    tabs, dynamic_shared_offset
+                )?;
+                write!(
+                    w,
+                    "{}{} = dynamic_shared + dynamic_shared_offset;\n",
+                    tabs, name
+                )?;
                 *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
             }
         }
@@ -1600,10 +1939,21 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
      * and offset to 2nd field. This is useful for constant initialization and read/write
      * index math.
      */
-    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String {
+    fn get_size(
+        &self,
+        type_id: TypeID,
+        num_fields: Option<usize>,
+        extra_dim_collects: Option<&HashSet<TypeID>>,
+    ) -> String {
         match &self.types[type_id.idx()] {
             Type::Array(element_type, extents) => {
-                let array_size = if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { "1".to_string() } else { multiply_dcs(extents) };
+                let array_size = if extra_dim_collects.is_some()
+                    && extra_dim_collects.unwrap().contains(&type_id)
+                {
+                    "1".to_string()
+                } else {
+                    multiply_dcs(extents)
+                };
                 format!("{} * {}", self.get_alignment(*element_type), array_size)
             }
             Type::Product(fields) => {
@@ -1612,7 +1962,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     .iter()
                     .enumerate()
                     .filter(|(i, _)| i < num_fields)
-                    .map(|(_, id)| (self.get_size(*id, None, extra_dim_collects), self.get_alignment(*id)))
+                    .map(|(_, id)| {
+                        (
+                            self.get_size(*id, None, extra_dim_collects),
+                            self.get_alignment(*id),
+                        )
+                    })
                     .fold(String::from("0"), |acc, (size, align)| {
                         if acc == "0" {
                             size
@@ -1637,16 +1992,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 // The argmax variant by size is not guaranteed to be same as
                 // argmax variant by alignment, eg product of 3 4-byte primitives
                 // vs 1 8-byte primitive, so we need to calculate both.
-                let max_size = variants.iter().map(|id| self.get_size(*id, None, extra_dim_collects)).fold(
-                    String::from("0"),
-                    |acc, x| {
+                let max_size = variants
+                    .iter()
+                    .map(|id| self.get_size(*id, None, extra_dim_collects))
+                    .fold(String::from("0"), |acc, x| {
                         if acc == "0" {
                             x
                         } else {
                             format!("umax({}, {})", acc, x)
                         }
-                    },
-                );
+                    });
                 let max_alignment = variants
                     .iter()
                     .map(|id| self.get_alignment(*id))
@@ -1793,7 +2148,17 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     }
 
     fn get_cg_tile(&self, fork: NodeID, cg_type: CGType) -> String {
-        format!("cg_{}{}", self.get_value(fork, false, false), if cg_type == CGType::Use { "_use" } else if cg_type == CGType::Available { "_available" } else { "" })
+        format!(
+            "cg_{}{}",
+            self.get_value(fork, false, false),
+            if cg_type == CGType::Use {
+                "_use"
+            } else if cg_type == CGType::Available {
+                "_available"
+            } else {
+                ""
+            }
+        )
     }
 
     fn get_fork_iter(&self, fork: NodeID, ty: bool) -> String {
@@ -1805,7 +2170,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     }
 
     fn get_block_name(&self, id: NodeID, post: bool) -> String {
-        format!("bb_{}{}", self.get_value(id, false, false), if post { "_post" } else { "" })
+        format!(
+            "bb_{}{}",
+            self.get_value(id, false, false),
+            if post { "_post" } else { "" }
+        )
     }
 
     /*
@@ -1860,7 +2229,10 @@ fn multiply_dcs(dcs: &[DynamicConstantID]) -> String {
     if dcs.is_empty() {
         "1".to_string()
     } else {
-        dcs.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ")
+        dcs.iter()
+            .map(|dc| format!("dc{}", dc.idx()))
+            .collect::<Vec<_>>()
+            .join(" * ")
     }
 }
 
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 916d6520..cbef5a00 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -358,14 +358,16 @@ impl<'a> RTContext<'a> {
                     }
                 }
                 write!(block, ";\n")?;
-                if let Some((size, device)) = size_and_device {
-                    write!(
-                        block,
-                        "                ::hercules_rt::__{}_zero_mem({}, {} as usize);\n",
-                        device.name(),
-                        self.get_value(id),
-                        size
-                    )?;
+                if !func.schedules[id.idx()].contains(&Schedule::NoResetConstant) {
+                    if let Some((size, device)) = size_and_device {
+                        write!(
+                            block,
+                            "                ::hercules_rt::__{}_zero_mem({}, {} as usize);\n",
+                            device.name(),
+                            self.get_value(id),
+                            size
+                        )?;
+                    }
                 }
             }
             Node::Call {
diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch
index c0a1a5ce..4303d376 100644
--- a/hercules_samples/matmul/src/gpu.sch
+++ b/hercules_samples/matmul/src/gpu.sch
@@ -13,9 +13,8 @@ phi-elim(*);
 dce(*);
 
 forkify(*);
-infer-schedules(*);
+fixpoint {
+  infer-schedules(*);
+}
 
 gcm(*);
-float-collections(*);
-dce(*);
-gcm(*);
-- 
GitLab
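
Most of the gpu.rs hunks in the patch above only restructure the CUDA backend's emission code; the cooperative-groups patterns it writes (tiled_partition tiles in the Fork init section, the gated zero-initialization loop with a tile sync, and cg::reduce for parallelized associative reductions) are unchanged. As a rough sketch of the device code those format strings expand to (the kernel name, tile width of 32, buffer size, and variable names here are illustrative assumptions, not values taken from this patch):

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

// Illustrative only: the real kernels are generated, and their tile widths,
// sizes, and names come from the fork/thread-quota analysis, not from here.
__global__ void sketch_kernel(char* backing, float* ret) {
    cg::thread_block block = cg::this_thread_block();
    // Corresponds to the emitted
    // "cg::thread_block_tile<N> cg_... = cg::tiled_partition<N>(block);" lines.
    cg::thread_block_tile<32> cg_tile = cg::tiled_partition<32>(block);

    // Zero-initialization of a non-primitive constant, as in the emitted
    // "for (int i = <tile>.thread_rank(); i < <size>; i += <tile>.size())" loop.
    char* collect = backing;  // "backing + dc<offset>" in the generated code
    for (int i = cg_tile.thread_rank(); i < 1024; i += cg_tile.size()) {
        *(collect + i) = 0;
    }
    cg_tile.sync();

    // Parallelized associative reduction, as in the emitted
    // "cg::reduce(<tile>, <val>, cg::<op><T>())" call.
    float partial = static_cast<float>(cg_tile.thread_rank());
    float reduced = cg::reduce(cg_tile, partial, cg::plus<float>());
    if (cg_tile.thread_rank() == 0) {
        *ret = reduced;
    }
}

The new assert in the CUDA backend and the rt.rs hunk both key off the NoResetConstant schedule: when it is attached, the runtime skips the __<device>_zero_mem call, and the CUDA backend accepts global-memory constants it otherwise could not reset without a grid-level sync.
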


From 4a70e1c75c78291c36ae4249de79b4425d77d8d6 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 11:14:44 -0600
Subject: [PATCH 7/9] fix implicit_clone

---
 juno_samples/implicit_clone/src/gpu.sch | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/juno_samples/implicit_clone/src/gpu.sch b/juno_samples/implicit_clone/src/gpu.sch
index 0f7c8021..50c4ae8e 100644
--- a/juno_samples/implicit_clone/src/gpu.sch
+++ b/juno_samples/implicit_clone/src/gpu.sch
@@ -15,6 +15,8 @@ dce(*);
 infer-schedules(*);
 
 gcm(*);
-float-collections(*);
-dce(*);
-gcm(*);
+fixpoint {
+  float-collections(*);
+  dce(*);
+  gcm(*);
+}
-- 
GitLab


From 7114b2083ff64e8fdec1969038a11d8477862576 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 11:18:17 -0600
Subject: [PATCH 8/9] fix cava gpu ci for now

---
 juno_samples/cava/src/gpu.sch | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/juno_samples/cava/src/gpu.sch b/juno_samples/cava/src/gpu.sch
index a5570b8d..594cbfa3 100644
--- a/juno_samples/cava/src/gpu.sch
+++ b/juno_samples/cava/src/gpu.sch
@@ -17,3 +17,8 @@ dce(*);
 infer-schedules(*);
 
 gcm(*);
+fixpoint {
+  float-collections(*);
+  dce(*);
+  gcm(*);
+}
-- 
GitLab


From 6b39e69c801eb11c11e3b51a2506467b1c79598f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 16:40:40 -0600
Subject: [PATCH 9/9] fix for edge_detection

---
 Cargo.toml                              | 8 ++++----
 juno_samples/edge_detection/src/gpu.sch | 5 +++++
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index f4976106..f7b9322a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,14 +21,14 @@ members = [
 	"hercules_samples/ccp",
 
 	"juno_samples/simple3",
-  "juno_samples/patterns",
+	"juno_samples/patterns",
 	"juno_samples/matmul",
 	"juno_samples/casts_and_intrinsics",
 	"juno_samples/nested_ccp",
 	"juno_samples/antideps",
 	"juno_samples/implicit_clone",
-  "juno_samples/cava",
+	"juno_samples/cava",
 	"juno_samples/concat",
-  "juno_samples/schedule_test",
-  "juno_samples/edge_detection",
+	"juno_samples/schedule_test",
+	"juno_samples/edge_detection",
 ]
diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch
index d7330c67..a1bf06a4 100644
--- a/juno_samples/edge_detection/src/gpu.sch
+++ b/juno_samples/edge_detection/src/gpu.sch
@@ -17,3 +17,8 @@ dce(*);
 infer-schedules(*);
 
 gcm(*);
+fixpoint {
+  float-collections(*);
+  dce(*);
+  gcm(*);
+}
-- 
GitLab