diff --git a/.gitignore b/.gitignore
index 749cea40376eed265a3734958bb1b298e2f083de..87af5349ee5aa22c4f3763c6f8905b69f2773bb2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,4 @@
 *.swp
 .vscode
 *_env
+*.txt
diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs
index da7f640a6370c6040aa7e396b17f4080a02f618d..64a93160aabc5564c784bf0ff8b14e82045b81bb 100644
--- a/hercules_cg/src/fork_tree.rs
+++ b/hercules_cg/src/fork_tree.rs
@@ -2,6 +2,22 @@ use std::collections::{HashMap, HashSet};
 
 use crate::*;
 
+/*
+ * Construct a map from each fork node F to all control nodes (including F itself) satisfying:
+ * a) domination by F
+ * b) no domination by F's join
+ * c) no domination by any other fork that is also dominated by F, counting self-domination
+ * The non-fork start node is also included as a key, for all controls outside any fork.
+ */
+pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+    let mut fork_control_map = HashMap::new();
+    for (control, forks) in fork_join_nesting {
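+        // fork_join_nesting lists the forks enclosing each control node; take the
+        // first entry as the owning fork (assumed innermost), falling back to the
+        // start node (NodeID 0) when the control sits outside every fork.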
+        let fork = forks.first().copied().unwrap_or(NodeID::new(0));
+        fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control);
+    }
+    fork_control_map
+}
+
 /* Construct a map from each fork node F to all forks satisfying:
  * a) domination by F
  * b) no domination by F's join
@@ -19,19 +35,3 @@ pub fn fork_tree(function: &Function, fork_join_nesting: &HashMap<NodeID, Vec<No
     }
     fork_tree
 }
-
-/*
- * Construct a map from fork node to all control nodes (including itself) satisfying:
- * a) domination by F
- * b) no domination by F's join
- * c) no domination by any other fork that's also dominated by F, where we do count self-domination
- * Here too we include the non-fork start node, as key for all controls outside any fork.
- */
-pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
-    let mut fork_control_map = HashMap::new();
-    for (control, forks) in fork_join_nesting {
-        let fork = forks.first().copied().unwrap_or(NodeID::new(0));
-        fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control);
-    }
-    fork_control_map
-}
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 1a9b6869324f724c58dd4438a1d7f856d91fb6eb..a6711a33ab0325afb84362e8f3e7b3827efb0db7 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -3,6 +3,8 @@ extern crate hercules_ir;
 
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::{Error, Write};
+use std::fs::OpenOptions;
+use std::io::Write as _;
 
 use self::hercules_ir::*;
 
@@ -107,6 +109,7 @@ pub fn gpu_codegen<W: Write>(
                         .entry(fork_node)
                         .or_default()
                         .push(*reduce_node);
+                    println!("reduce_node: {:?}, fork_node: {:?}, join: {:?}", reduce_node, fork_node, control);
                 }
                 Node::Region { preds: _ } => {
                     // TODO: map region node to fork node
@@ -129,7 +132,7 @@ pub fn gpu_codegen<W: Write>(
         if function.nodes[idx].is_fork() {
             assert!(fork_reduce_map
                 .get(&NodeID::new(idx))
-                .is_none_or(|reduces| reduces.is_empty()),
+                .is_some_and(|reduces| !reduces.is_empty()),
                 "Fork node {} has no reduce nodes",
                 idx
             );
@@ -158,11 +161,12 @@ pub fn gpu_codegen<W: Write>(
     };
 
     let kernel_params = &GPUKernelParams {
-        max_num_blocks: 1024,
         max_num_threads: 1024,
         threads_per_warp: 32,
     };
 
+    std::fs::write("out.txt", "debug\n\n").unwrap();
+
     let ctx = GPUContext {
         function,
         types,
@@ -187,7 +191,6 @@ pub fn gpu_codegen<W: Write>(
 }
 
 struct GPUKernelParams {
-    max_num_blocks: usize,
     max_num_threads: usize,
     threads_per_warp: usize,
 }
@@ -259,6 +262,8 @@ enum CGType {
 
 impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
+        let mut file = OpenOptions::new().append(true).open("out.txt").unwrap();
+
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
         self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
@@ -281,13 +286,17 @@ impl GPUContext<'_> {
 
         // If there are no forks, fast forward to single-block, single-thread codegen
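+        // num_blocks (and num_threads at the launch boundary) are emitted as strings
+        // so the grid size can be a runtime expression over dynamic constants rather
+        // than a compile-time usize.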
         let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
+            writeln!(file, "shortcut to 1b1t").unwrap();
             self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?;
-            (1, 1)
+            ("1".to_string(), "1".to_string())
         } else {
+            writeln!(file, "no shortcut! fork tree: {:?}", self.fork_tree).unwrap();
             // Create structures and determine block and thread parallelization strategy
-            let (root_forks, num_blocks) =
-                self.get_root_forks_and_num_blocks(&self.fork_tree, self.kernel_params.max_num_blocks);
-            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, num_blocks);
+            let (root_forks, num_blocks, is_block_parallel) =
+                self.get_root_forks_and_num_blocks(&self.fork_tree);
+            writeln!(file, "is_block_parallel: {}", is_block_parallel).unwrap();
+            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel);
+            writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap();
             let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork);
             // TODO: Uncomment and adjust once we know logic of extra dim
             // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
@@ -295,7 +304,7 @@ impl GPUContext<'_> {
 
             // Core function for the CUDA code of all data and control nodes.
             self.codegen_data_control(
-                if num_blocks > 1 {
+                if is_block_parallel {
                     Some(thread_root_root_fork)
                 } else {
                     None
@@ -304,11 +313,11 @@ impl GPUContext<'_> {
                 &fork_thread_quota_map,
                 &extra_dim_collects,
                 &mut dynamic_shared_offset,
-                num_blocks,
+                is_block_parallel,
                 num_threads,
                 &mut gotos,
             )?;
-            (num_blocks, num_threads)
+            (num_blocks, num_threads.to_string())
         };
 
         // Emit all GPU kernel code from previous steps
@@ -493,7 +502,7 @@ namespace cg = cooperative_groups;
         Ok(())
     }
 
-    fn codegen_launch_code(&self, num_blocks: usize, num_threads: usize, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> {
+    fn codegen_launch_code(&self, num_blocks: String, num_threads: String, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> {
         // The following steps are for host-side C function arguments, but we also
         // need to pass arguments to kernel, so we keep track of the arguments here.
         let ret_type = self.get_type(self.function.return_type, false);
@@ -559,32 +568,28 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     fn get_root_forks_and_num_blocks(
         &self,
         fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
-        max_num_blocks: usize,
-    ) -> (HashSet<NodeID>, usize) {
+    ) -> (HashSet<NodeID>, String, bool) {
         let root_forks: HashSet<NodeID> = fork_tree.get(&NodeID::new(0)).unwrap().clone();
         if root_forks.len() != 1 {
-            return (root_forks, 1);
+            return (root_forks, "1".to_string(), false);
         }
 
         let root_fork = root_forks.iter().next().unwrap();
         let Node::Fork { factors, .. } = &self.function.nodes[root_fork.idx()] else {
             panic!("Expected fork node");
         };
-        let fork_size = self.multiply_fork_factors(factors);
         let reduces = &self.fork_reduce_map[root_fork];
         assert!(reduces.iter().all(|reduce| {
             self.collection_objects.objects(*reduce).iter().all(|object| {
                 self.collection_objects.origin(*object).try_parameter().is_some()
             })
         }), "All collection reduces in block fork must originate from parameters");
-        if let Some(fork_size) = fork_size
-            && fork_size <= max_num_blocks
-            && fork_size.is_power_of_two()
-            && self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
+        if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
         {
-            (root_forks, fork_size)
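+            // Build the block count as a runtime expression over the fork's dynamic
+            // constant factors (e.g. "dc0 * dc1"); the old compile-time power-of-two
+            // and max_num_blocks checks no longer apply.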
+            let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ");
+            (root_forks, fork_size, true)
         } else {
-            (root_forks, 1)
+            (root_forks, "1".to_string(), false)
         }
     }
 
@@ -597,9 +602,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         &self,
         root_forks: &HashSet<NodeID>,
         fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
-        num_blocks: usize,
+        is_block_parallel: bool,
     ) -> (NodeID, HashSet<NodeID>) {
-        if num_blocks > 1 {
+        if is_block_parallel {
             let root_fork = root_forks.iter().next().unwrap();
             (*root_fork, fork_tree.get(&root_fork).unwrap().iter().copied().collect())
         } else {
@@ -768,7 +773,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let term = &mut goto.term;
                 let mut tabs = self.codegen_control_node(control, None, None, None, init, post_init, term)?;
                 for data in self.bbs.1[control.idx()].iter() {
-                    self.codegen_data_node(*data, KernelState::OutBlock, Some(1), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
+                    self.codegen_data_node(*data, KernelState::OutBlock, Some(false), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
                 }
                 Ok(())
             })
@@ -784,7 +789,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
         extra_dim_collects: &HashSet<TypeID>,
         dynamic_shared_offset: &mut String,
-        num_blocks: usize,
+        is_block_parallel: bool,
         num_threads: usize,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
     ) -> Result<(), Error> {
@@ -799,7 +804,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             let term = &mut goto.term;
             let mut tabs = self.codegen_control_node(*control, None, None, None, init, post_init, term)?;
             for data in self.bbs.1[control.idx()].iter() {
-                self.codegen_data_node(*data, state, Some(num_blocks), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
+                self.codegen_data_node(*data, state, Some(is_block_parallel), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
             }
         }
         // Then generate data and control for the single block fork if it exists
@@ -911,7 +916,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         &self,
         id: NodeID,
         state: KernelState,
-        num_blocks: Option<usize>,
+        is_block_parallel: Option<bool>,
         use_thread_quota: Option<usize>,
         parallel_factor: Option<usize>,
         nesting_fork: Option<NodeID>,
@@ -1215,7 +1220,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
-                if KernelState::OutBlock == state && num_blocks.unwrap() > 1 {
+                if KernelState::OutBlock == state && is_block_parallel.unwrap() {
                     panic!("GPU can't guarantee correctness for multi-block collection writes");
                 }
                 let cg_tile = match state {
diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs
index 4cfd2a87fba14d3c542bb54806a65da2d1a9b8f5..8657fdc166fe68ad2565a8a0736984c7991be0a7 100644
--- a/hercules_samples/dot/build.rs
+++ b/hercules_samples/dot/build.rs
@@ -4,8 +4,7 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("dot.hir")
         .unwrap()
-        //.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
-        .schedule_in_src("cpu.sch")
+        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
         .unwrap()
         .build()
         .unwrap();
diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs
index f895af867a019dfd23381a4df2d9a02f80a032f8..735458c0c8be76bdae6cd7b3b308e38ccae78edd 100644
--- a/hercules_samples/matmul/build.rs
+++ b/hercules_samples/matmul/build.rs
@@ -4,8 +4,7 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("matmul.hir")
         .unwrap()
-        //.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
-        .schedule_in_src("cpu.sch")
+        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
         .unwrap()
         .build()
         .unwrap();
diff --git a/juno_frontend/Cargo.toml b/juno_frontend/Cargo.toml
index b6d9a71d70786d389649a80ea865c9b4153eef7f..648daf5f4757f055259c1053d10c5849259941f8 100644
--- a/juno_frontend/Cargo.toml
+++ b/juno_frontend/Cargo.toml
@@ -5,7 +5,7 @@ authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
 edition = "2021"
 
 [features]
-cuda = ["hercules_opt/cuda"]
+cuda = ["hercules_opt/cuda", "juno_scheduler/cuda"]
 default = []
 
 [[bin]]
diff --git a/juno_samples/schedule_test/Cargo.toml b/juno_samples/schedule_test/Cargo.toml
index be5d949bf1959d48c3463717f28b9fd186e05170..c783217a816960c764083c24aed1fdc5f0d1fb77 100644
--- a/juno_samples/schedule_test/Cargo.toml
+++ b/juno_samples/schedule_test/Cargo.toml
@@ -8,6 +8,9 @@ edition = "2021"
 name = "juno_schedule_test"
 path = "src/main.rs"
 
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
 [build-dependencies]
 juno_build = { path = "../../juno_build" }
 
diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs
index 4a4282473e87d6c24e12b5e3d59521ee8c99141e..749a660c551e8b231f63287898adb2863aef826e 100644
--- a/juno_samples/schedule_test/build.rs
+++ b/juno_samples/schedule_test/build.rs
@@ -4,7 +4,7 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("code.jn")
         .unwrap()
-        .schedule_in_src("sched.sch")
+        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
         .unwrap()
         .build()
         .unwrap();
diff --git a/juno_samples/schedule_test/src/sched.sch b/juno_samples/schedule_test/src/cpu.sch
similarity index 100%
rename from juno_samples/schedule_test/src/sched.sch
rename to juno_samples/schedule_test/src/cpu.sch
diff --git a/juno_samples/schedule_test/src/gpu.sch b/juno_samples/schedule_test/src/gpu.sch
new file mode 100644
index 0000000000000000000000000000000000000000..edca678ee103ec053f54a8e2caf0ee3e9e721e18
--- /dev/null
+++ b/juno_samples/schedule_test/src/gpu.sch
@@ -0,0 +1,47 @@
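+// GPU schedule for schedule_test: outline the kernels, lower the two outlined
+// functions through the GPU backend, and keep the rest of the program on the host.
+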
+macro juno-setup!(X) {
+  //gvn(X);
+  phi-elim(X);
+  dce(X);
+  lift-dc-math(X);
+}
+macro codegen-prep!(X) {
+  infer-schedules(X);
+  dce(X);
+  gcm(X);
+  dce(X);
+  phi-elim(X);
+  float-collections(X);
+  gcm(X);
+}
+
+
+juno-setup!(*);
+
+let first = outline(test@outer);
+let second = outline(test@row);
+
+// We can use the functions produced by outlining in our schedules
+gvn(first, second, test);
+
+ip-sroa(*);
+sroa(*);
+
+// We can evaluate expressions using labels and save them for later use
+let inner = first@inner;
+
+// A fixpoint runs a series of passes until no more changes are made
+// (though some passes report edits even when nothing really changed,
+// so this is fragile).
+// We could let it run until it converges, but we can also tell it to panic
+// if it hasn't converged after a number of iterations (as done here), to
+// simply stop after a certain number of iterations (stop after #), or to
+// print the iteration number (print iter).
+fixpoint panic after 2 {
+  phi-elim(*);
+}
+
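+// Everything defaults to the host backend; only the two outlined functions are
+// lowered through the GPU backend.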
+host(*);
+gpu(first, second);
+
+codegen-prep!(*);
+//xdot[true](*);
diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml
index 1c837d4a32764abb179b6a05fd5225b808ea764a..3d81ea967f5a6b862ac0fa6a03303b4eebcf1e01 100644
--- a/juno_scheduler/Cargo.toml
+++ b/juno_scheduler/Cargo.toml
@@ -4,6 +4,9 @@ version = "0.0.1"
 authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
 edition = "2021"
 
+[features]
+cuda = []
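+# Empty marker feature, enabled via juno_frontend's `cuda` feature (see
+# juno_frontend/Cargo.toml); presumably used to cfg-gate GPU-only scheduler code.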
+
 [build-dependencies]
 cfgrammar = "0.13"
 lrlex = "0.13"