From 10574e4972a56f1707d3aa8272dfe62ebdbf0422 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Sun, 22 Dec 2024 00:06:23 -0600 Subject: [PATCH] Refactor schedules --- hercules_cg/src/cpu.rs | 2 +- hercules_ir/src/build.rs | 13 +- hercules_ir/src/dot.rs | 152 ++--- hercules_ir/src/ir.rs | 52 +- hercules_ir/src/lib.rs | 2 - hercules_ir/src/parse.rs | 11 +- hercules_ir/src/schedule.rs | 818 ----------------------- hercules_ir/src/subgraph.rs | 83 --- hercules_opt/src/ccp.rs | 10 +- hercules_opt/src/delete_uncalled.rs | 5 +- hercules_opt/src/editor.rs | 370 +++------- hercules_opt/src/fork_concat_split.rs | 12 +- hercules_opt/src/gvn.rs | 3 +- hercules_opt/src/inline.rs | 22 +- hercules_opt/src/interprocedural_sroa.rs | 4 +- hercules_opt/src/lib.rs | 2 + hercules_opt/src/outline.rs | 11 +- hercules_opt/src/pass.rs | 251 ++----- hercules_opt/src/pred.rs | 2 - hercules_opt/src/schedule.rs | 175 +++++ hercules_opt/src/sroa.rs | 15 +- hercules_opt/src/unforkify.rs | 4 + juno_frontend/src/lib.rs | 4 + juno_scheduler/src/lib.rs | 3 +- 24 files changed, 492 insertions(+), 1534 deletions(-) delete mode 100644 hercules_ir/src/schedule.rs create mode 100644 hercules_opt/src/schedule.rs diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index d9bf505c..a8f16790 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -1,7 +1,7 @@ extern crate bitvec; extern crate hercules_ir; -use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; +use std::collections::{BTreeMap, VecDeque}; use std::fmt::{Error, Write}; use std::iter::{zip, FromIterator}; use std::sync::atomic::{AtomicUsize, Ordering}; diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 6dd5a3e6..78c4eca4 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -36,6 +36,7 @@ pub struct NodeBuilder { id: NodeID, function_id: FunctionID, node: Node, + schedules: Vec<Schedule>, } /* @@ -468,9 +469,11 @@ impl<'a> Builder<'a> { name: name.to_owned(), param_types, return_type, - nodes: vec![Node::Start], num_dynamic_constants, entry, + nodes: vec![Node::Start], + schedules: vec![vec![]], + device: None, }); Ok((id, NodeID::new(0))) } @@ -480,10 +483,12 @@ impl<'a> Builder<'a> { self.module.functions[function.idx()] .nodes .push(Node::Start); + self.module.functions[function.idx()].schedules.push(vec![]); NodeBuilder { id, function_id: function, node: Node::Start, + schedules: vec![], } } @@ -492,6 +497,8 @@ impl<'a> Builder<'a> { Err("Can't add node from a NodeBuilder before NodeBuilder has built a node.")? 
} self.module.functions[builder.function_id.idx()].nodes[builder.id.idx()] = builder.node; + self.module.functions[builder.function_id.idx()].schedules[builder.id.idx()] = + builder.schedules; Ok(()) } } @@ -606,4 +613,8 @@ impl NodeBuilder { indices, }; } + + pub fn add_schedule(&mut self, schedule: Schedule) { + self.schedules.push(schedule); + } } diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index d23a4972..a8754e78 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -21,7 +21,6 @@ pub fn xdot_module( doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, bbs: Option<&Vec<Vec<NodeID>>>, - plans: Option<&Vec<Plan>>, ) { let mut tmp_path = temp_dir(); let mut rng = rand::thread_rng(); @@ -35,7 +34,6 @@ pub fn xdot_module( doms, fork_join_maps, bbs, - plans, &mut contents, ) .expect("PANIC: Unable to generate output file contents."); @@ -58,7 +56,6 @@ pub fn write_dot<W: Write>( doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, bbs: Option<&Vec<Vec<NodeID>>>, - plans: Option<&Vec<Plan>>, w: &mut W, ) -> std::fmt::Result { write_digraph_header(w)?; @@ -66,101 +63,74 @@ pub fn write_dot<W: Write>( for function_id in (0..module.functions.len()).map(FunctionID::new) { let function = &module.functions[function_id.idx()]; let reverse_postorder = &reverse_postorders[function_id.idx()]; - let plan = plans.map(|plans| &plans[function_id.idx()]); let mut reverse_postorder_node_numbers = vec![0; function.nodes.len()]; for (idx, id) in reverse_postorder.iter().enumerate() { reverse_postorder_node_numbers[id.idx()] = idx; } write_subgraph_header(function_id, module, w)?; - let mut partition_to_node_map = plan.map(|plan| plan.invert_partition_map()); - // Step 1: draw IR graph itself. This includes all IR nodes and all edges // between IR nodes. - for partition_idx in 0..plan.map_or(1, |plan| plan.num_partitions) { - // First, write all the nodes in their subgraph. - let partition_color = plan.map(|plan| match plan.partition_devices[partition_idx] { - Device::CPU => "lightblue", - Device::GPU => "darkseagreen", - Device::AsyncRust => "plum2", - }); - if let Some(partition_color) = partition_color { - write_partition_header(function_id, partition_idx, module, partition_color, w)?; - } - let nodes_ids = if let Some(partition_to_node_map) = &mut partition_to_node_map { - let mut empty = vec![]; - std::mem::swap(&mut partition_to_node_map[partition_idx], &mut empty); - empty - } else { - (0..function.nodes.len()) - .map(NodeID::new) - .collect::<Vec<_>>() - }; - for node_id in nodes_ids.iter() { - let node = &function.nodes[node_id.idx()]; - let dst_control = node.is_control(); + for node_id in (0..function.nodes.len()).map(NodeID::new) { + let node = &function.nodes[node_id.idx()]; + let dst_control = node.is_control(); + + // Control nodes are dark red, data nodes are dark blue. + let color = if dst_control { "darkred" } else { "darkblue" }; + + write_node( + node_id, + function_id, + color, + module, + &function.schedules[node_id.idx()], + w, + )?; + } + + for node_id in (0..function.nodes.len()).map(NodeID::new) { + let node = &function.nodes[node_id.idx()]; + let dst_control = node.is_control(); + for u in def_use::get_uses(&node).as_ref() { + let src_control = function.nodes[u.idx()].is_control(); - // Control nodes are dark red, data nodes are dark blue. - let color = if dst_control { "darkred" } else { "darkblue" }; + // An edge between control nodes is dashed. An edge between data + // nodes is filled. 
An edge between a control node and a data + node is dotted. + let style = if dst_control && src_control { + "dashed" + } else if !dst_control && !src_control { + "" + } else { + "dotted" + }; - write_node( - *node_id, + // To have a consistent layout, we will add "back edges" in the + // IR graph as backward facing edges in the graphviz output, so + // that they don't mess up the layout. There isn't necessarily a + // precise definition of a "back edge" in Hercules IR. I've + // found that the clearest output graphs result from treating + // edges to phi nodes as back edges when the phi node appears + // before the use in the reverse postorder, and treating a + // control edge as a back edge when the destination appears before + // the source in the reverse postorder. + let is_back_edge = reverse_postorder_node_numbers[node_id.idx()] + < reverse_postorder_node_numbers[u.idx()] + && (node.is_phi() + || (function.nodes[node_id.idx()].is_control() + && function.nodes[u.idx()].is_control())); + write_edge( + node_id, function_id, - color, + *u, + function_id, + !is_back_edge, + "black", + style, module, - plan.map_or(&vec![], |plan| &plan.schedules[node_id.idx()]), w, )?; } - if plans.is_some() { - write_graph_footer(w)?; - } - - // Second, write all the edges coming out of a node. - for node_id in nodes_ids.iter() { - let node = &function.nodes[node_id.idx()]; - let dst_control = node.is_control(); - for u in def_use::get_uses(&node).as_ref() { - let src_control = function.nodes[u.idx()].is_control(); - - // An edge between control nodes is dashed. An edge between data - // nodes is filled. An edge between a control node and a data - // node is dotted. - let style = if dst_control && src_control { - "dashed" - } else if !dst_control && !src_control { - "" - } else { - "dotted" - }; - - // To have a consistent layout, we will add "back edges" in the - // IR graph as backward facing edges in the graphviz output, so - // that they don't mess up the layout. There isn't necessarily a - // precise definition of a "back edge" in Hercules IR. I've - // found what makes for the most clear output graphs is treating - // edges to phi nodes as back edges when the phi node appears - // before the use in the reverse postorder, and treating a - // control edge a back edge when the destination appears before - // the source in the reverse postorder. - let is_back_edge = reverse_postorder_node_numbers[node_id.idx()] - < reverse_postorder_node_numbers[u.idx()] - && (node.is_phi() - || (function.nodes[node_id.idx()].is_control() - && function.nodes[u.idx()].is_control())); - write_edge( - *node_id, - function_id, - *u, - function_id, - !is_back_edge, - "black", - style, - module, - w, - )?; - } - } } // Step 2: draw dominance edges in dark green.
Don't draw post dominance @@ -258,22 +228,6 @@ fn write_subgraph_header<W: Write>( Ok(()) } -fn write_partition_header<W: Write>( - function_id: FunctionID, - partition_idx: usize, - module: &Module, - color: &str, - w: &mut W, -) -> std::fmt::Result { - let function = &module.functions[function_id.idx()]; - write!(w, "subgraph {}_{} {{\n", function.name, partition_idx)?; - write!(w, "label=\"\"\n")?; - write!(w, "style=rounded\n")?; - write!(w, "bgcolor={}\n", color)?; - write!(w, "cluster=true\n")?; - Ok(()) -} - fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result { write!(w, "}}\n")?; Ok(()) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index cef8e43a..d4eed8e2 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -40,9 +40,13 @@ pub struct Function { pub name: String, pub param_types: Vec<TypeID>, pub return_type: TypeID, - pub nodes: Vec<Node>, pub num_dynamic_constants: u32, pub entry: bool, + + pub nodes: Vec<Node>, + + pub schedules: FunctionSchedule, + pub device: Option<Device>, } /* @@ -296,6 +300,48 @@ pub enum Intrinsic { Tanh, } +/* + * An individual schedule is a single "directive" for the compiler to take into + * consideration at some point during the compilation pipeline. Each schedule is + * associated with a single node. + */ +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum Schedule { + // This fork can be run in parallel. Any notion of tiling is handled by the + // fork tiling passes - the backend will lower a parallel fork with + // dimension N into an N-way thread launch. + ParallelFork, + // This reduce can be "run in parallel" - conceptually, the `reduct` + // backedge can be removed, and the reduce code can be merged into the + // parallel code. + ParallelReduce, + // This fork-join has no impeding control flow and the fork has a constant + // factor. + Vectorizable, + // This reduce can be re-associated. This may lower a sequential dependency + // chain into a reduction tree. + Associative, +} + +/* + * The authoritative enumeration of supported backends. Multiple backends may + * correspond to the same kind of hardware. + */ +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum Device { + LLVM, + NVVM, + // Internal nodes in the call graph are lowered to async Rust code that + // calls device functions (leaf nodes in the call graph), possibly + // concurrently. + AsyncRust, +} + +/* + * A single node may have multiple schedules. + */ +pub type FunctionSchedule = Vec<Vec<Schedule>>; + impl Module { /* * Printing out types, constants, and dynamic constants fully requires a @@ -671,9 +717,11 @@ impl Function { // Add to new_nodes. new_nodes.push(node); } - std::mem::swap(&mut new_nodes, &mut self.nodes); + // Step 4: update the schedules.
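An aside on Step 4: `fix_gravestones` comes from the `GraveUpdatable` trait, whose implementation for `FunctionSchedule` is not shown in this hunk. Below is a minimal sketch of what it plausibly does, assuming the gravestone convention of the deleted `Plan::repair` further down (node ID 0 is the start node, and any other node remapped to ID 0 was deleted); `fix_schedule_gravestones` is a hypothetical free function, not the actual trait method:

fn fix_schedule_gravestones(schedules: &mut FunctionSchedule, node_mapping: &Vec<NodeID>) {
    let old = std::mem::take(schedules);
    // One (possibly empty) schedule list per surviving node.
    let num_surviving = node_mapping
        .iter()
        .enumerate()
        .filter(|(old_idx, new_id)| *old_idx == 0 || new_id.idx() != 0)
        .count();
    schedules.resize(num_surviving, vec![]);
    for (old_idx, schedule) in old.into_iter().enumerate() {
        let new_id = node_mapping[old_idx];
        // Node ID 0 is the start node; any other node mapped to ID 0 is a
        // deleted gravestone, so its schedules are simply dropped.
        if old_idx == 0 || new_id.idx() != 0 {
            schedules[new_id.idx()] = schedule;
        }
    }
}

Whatever the exact implementation, the point is to keep `function.schedules` index-aligned with `function.nodes` after gravestone removal, which is the invariant the rest of this commit relies on.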
+ self.schedules.fix_gravestones(&node_mapping); + node_mapping } } diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index f7277cfa..abe7f46f 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -17,7 +17,6 @@ pub mod gcm; pub mod ir; pub mod loops; pub mod parse; -pub mod schedule; pub mod subgraph; pub mod typecheck; pub mod verify; @@ -33,7 +32,6 @@ pub use crate::gcm::*; pub use crate::ir::*; pub use crate::loops::*; pub use crate::parse::*; -pub use crate::schedule::*; pub use crate::subgraph::*; pub use crate::typecheck::*; pub use crate::verify::*; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 41d59441..3a47f8ea 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -122,9 +122,11 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a name: String::from(""), param_types: vec![], return_type: TypeID::new(0), - nodes: vec![], num_dynamic_constants: 0, - entry: true + entry: true, + nodes: vec![], + schedules: vec![], + device: None, }; context.function_ids.len() ]; @@ -251,15 +253,18 @@ fn parse_function<'a>( // Intern function name. context.borrow_mut().get_function_id(function_name); + let num_nodes = fixed_nodes.len(); Ok(( ir_text, Function { name: String::from(function_name), param_types: params.into_iter().map(|x| x.5).collect(), return_type, - nodes: fixed_nodes, num_dynamic_constants, entry: true, + nodes: fixed_nodes, + schedules: vec![vec![]; num_nodes], + device: None, }, )) } diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs deleted file mode 100644 index 2438a982..00000000 --- a/hercules_ir/src/schedule.rs +++ /dev/null @@ -1,818 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::iter::{repeat, zip}; - -use crate::*; - -/* - * An individual schedule is a single "directive" for the compiler to take into - * consideration at some point during the compilation pipeline. Each schedule is - * associated with a single node. - */ -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Schedule { - // This fork can be run in parallel and has a "natural" tiling, which may or - // may not be respected by certain backends. The field stores at least how - // many parallel tiles should be run concurrently, along each dimension. - // Some backends (such as GPU) may spawn more parallel tiles (each tile - // being a single thread in that case) along each axis. - ParallelFork(Box<[usize]>), - // This reduce can be "run in parallel" - conceptually, the `reduct` - // backedge can be removed, and the reduce code can be merged into the - // parallel code. - ParallelReduce, - // This fork-join has no impeding control flow. The field stores the vector - // width. - Vectorizable(usize), - // This reduce can be re-associated. This may lower a sequential dependency - // chain into a reduction tree. - Associative, -} - -/* - * The authoritative enumeration of supported devices. Technically, a device - * refers to a specific backend, so different "devices" may refer to the same - * "kind" of hardware. - */ -#[derive(Debug, Clone, Copy)] -pub enum Device { - CPU, - GPU, - // Hercules function calls are placed in solitary function calls that are - // directly represented in the generated async Rust runtime code. - AsyncRust, -} - -/* - * A plan is a set of schedules associated with each node, plus partitioning - * information. Partitioning information is a mapping from node ID to partition - * ID. A plan's scope is a single function. 
- */ -#[derive(Debug, Clone)] -pub struct Plan { - pub schedules: Vec<Vec<Schedule>>, - pub partitions: Vec<PartitionID>, - pub partition_devices: Vec<Device>, - pub num_partitions: usize, -} - -define_id_type!(PartitionID); - -impl Plan { - /* - * Invert stored map from node to partition to map from partition to nodes. - */ - pub fn invert_partition_map(&self) -> Vec<Vec<NodeID>> { - let mut map = vec![vec![]; self.num_partitions]; - - for idx in 0..self.partitions.len() { - map[self.partitions[idx].idx()].push(NodeID::new(idx)); - } - - map - } - - /* - * GCM takes in a partial partitioning, but most of the time we have a full - * partitioning. - */ - pub fn make_partial_partitioning(&self) -> Vec<Option<PartitionID>> { - self.partitions.iter().map(|id| Some(*id)).collect() - } - - /* - * LEGACY API: This is the legacy mechanism for repairing plans. Future code - * should use the `repair_plan` function in editor.rs. - * Plans must be "repairable", in the sense that the IR that's referred to - * may change after many passes. Since a plan is an explicit side data - * structure, it must be updated after every change in the IR. - */ - pub fn repair(self, function: &Function, grave_mapping: &Vec<NodeID>) -> Self { - // Unpack the plan. - let old_inverse_partition_map = self.invert_partition_map(); - let Plan { - mut schedules, - partitions: _, - partition_devices, - num_partitions: _, - } = self; - - // Schedules of old nodes just get dropped. Since schedules don't hold - // necessary semantic information, we are free to drop them arbitrarily. - schedules.fix_gravestones(grave_mapping); - schedules.resize(function.nodes.len(), vec![]); - - // Delete now empty partitions. First, filter out deleted nodes from the - // partitions and simultaneously map old node IDs to new node IDs. Then, - // filter out empty partitions. - let (new_inverse_partition_map, new_devices): (Vec<Vec<NodeID>>, Vec<Device>) = - zip(old_inverse_partition_map, partition_devices) - .into_iter() - .map(|(contents, device)| { - ( - contents - .into_iter() - .filter_map(|id| { - if id.idx() == 0 || grave_mapping[id.idx()].idx() != 0 { - Some(grave_mapping[id.idx()]) - } else { - None - } - }) - .collect::<Vec<NodeID>>(), - device, - ) - }) - .filter(|(contents, _)| !contents.is_empty()) - .unzip(); - - // Calculate the number of nodes after deletion but before addition. Use - // this is iterate new nodes later. - let num_nodes_before_addition = new_inverse_partition_map.iter().flatten().count(); - assert!(new_inverse_partition_map - .iter() - .flatten() - .all(|id| id.idx() < num_nodes_before_addition)); - - // Calculate the nodes that need to be assigned to a partition. This - // starts as just the nodes that have been added by passes. - let mut new_node_ids: VecDeque<NodeID> = (num_nodes_before_addition..function.nodes.len()) - .map(NodeID::new) - .collect(); - - // Any partition no longer containing at least one control node needs to - // be liquidated. - let (new_inverse_partition_map, new_devices): (Vec<Vec<NodeID>>, Vec<Device>) = - zip(new_inverse_partition_map, new_devices) - .into_iter() - .filter_map(|(part, device)| { - if part.iter().any(|id| function.nodes[id.idx()].is_control()) { - Some((part, device)) - } else { - // Nodes in removed partitions need to be re-partitioned. - new_node_ids.extend(part); - None - } - }) - .unzip(); - - // Assign the node IDs that need to be partitioned to partitions. In the - // process, construct a map from node ID to partition ID. 
- let mut node_id_to_partition_id: HashMap<NodeID, PartitionID> = new_inverse_partition_map - .into_iter() - .enumerate() - .map(|(partition_idx, node_ids)| { - node_ids - .into_iter() - .map(|node_id| (node_id, PartitionID::new(partition_idx))) - .collect::<Vec<(NodeID, PartitionID)>>() - }) - .flatten() - .collect(); - - // Make a best effort to assign nodes to the partition of one of their - // uses. Prioritize earlier uses. TODO: since not all partitions are - // legal, this is almost certainly not complete. Think more about that. - 'workloop: while let Some(id) = new_node_ids.pop_front() { - for u in get_uses(&function.nodes[id.idx()]).as_ref() { - if let Some(partition_id) = node_id_to_partition_id.get(u) { - node_id_to_partition_id.insert(id, *partition_id); - continue 'workloop; - } - } - new_node_ids.push_back(id); - } - - // Reconstruct the partitions vector. - let num_partitions = new_devices.len(); - let mut partitions = vec![PartitionID::new(0); function.nodes.len()]; - for (k, v) in node_id_to_partition_id { - partitions[k.idx()] = v; - } - - // Reconstruct the whole plan. - Plan { - schedules, - partitions, - partition_devices: new_devices, - num_partitions, - } - } - - /* - * Verify that a partitioning is valid. - */ - pub fn verify_partitioning( - &self, - function: &Function, - def_use: &ImmutableDefUseMap, - fork_join_map: &HashMap<NodeID, NodeID>, - ) { - let partition_to_node_ids = self.invert_partition_map(); - - // First, verify that there is at most one control node in the partition - // with a control use outside the partition. A partition may only have - // zero such control nodes if it contains the start node. This also - // checks that each partition has at least one control node. - for nodes_in_partition in partition_to_node_ids.iter() { - let contains_start = nodes_in_partition - .iter() - .any(|id| function.nodes[id.idx()] == Node::Start); - let num_inter_partition_control_uses = nodes_in_partition - .iter() - .filter(|id| { - // An inter-partition control use is a control node, - function.nodes[id.idx()].is_control() - // where one of its uses, - && get_uses(&function.nodes[id.idx()]) - .as_ref() - .into_iter() - .any(|use_id| { - // that is itself a control node as well, - function.nodes[use_id.idx()].is_control() - // is in a different partition. - && self.partitions[use_id.idx()] != self.partitions[id.idx()] - }) - }) - .count(); - - assert!(num_inter_partition_control_uses + contains_start as usize == 1, "PANIC: Found an invalid partition based on the inter-partition control use criteria."); - } - - // Second, verify that fork-joins are not split amongst partitions. - for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_fork() { - let fork_part = self.partitions[id.idx()]; - - // The join must be in the same partition. - let join = fork_join_map[&id]; - assert_eq!( - fork_part, - self.partitions[join.idx()], - "PANIC: Join is in a different partition than its corresponding fork." - ); - - // The thread IDs must be in the same partition. - def_use - .get_users(id) - .into_iter() - .filter(|user| function.nodes[user.idx()].is_thread_id()) - .for_each(|thread_id| { - assert_eq!( - fork_part, - self.partitions[thread_id.idx()], - "PANIC: Thread ID is in a different partition than its fork use." - ) - }); - - // The reduces must be in the same partition. 
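All of this deleted machinery is superseded by the new representation: schedules now live directly on `Function`, so there is no side `Plan` to invert, repair, or verify against the IR after every pass. A before/after sketch for a consumer; `is_parallel_reduce` is a hypothetical helper, not part of the patch:

// Before this commit, schedule queries went through a side Plan:
//     plan.schedules[id.idx()].contains(&Schedule::ParallelReduce)
// After it, they read the schedules stored on the function itself:
fn is_parallel_reduce(function: &Function, id: NodeID) -> bool {
    function.schedules[id.idx()].contains(&Schedule::ParallelReduce)
}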
- def_use - .get_users(join) - .into_iter() - .filter(|user| function.nodes[user.idx()].is_reduce()) - .for_each(|reduce| { - assert_eq!( - fork_part, - self.partitions[reduce.idx()], - "PANIC: Reduce is in a different partition than its join use." - ) - }); - } - } - - // Third, verify that every data node has proper dominance relations - // with respect to the partitioning. In particular: - // 1. Every non-phi data node should be in a partition that is dominated - // by the partitions of every one of its uses. - // 2. Every data node should be in a partition that dominates the - // partitions of every one of its non-phi users. - // Compute a dominance relation between the partitions by constructing a - // partition control graph. - let partition_graph = partition_graph(function, def_use, self); - let dom = dominator(&partition_graph, NodeID::new(self.partitions[0].idx())); - for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_control() { - continue; - } - let part = self.partitions[id.idx()]; - - // Check condition #1. - if !function.nodes[id.idx()].is_phi() { - let uses = get_uses(&function.nodes[id.idx()]); - for use_id in uses.as_ref() { - let use_part = self.partitions[use_id.idx()]; - assert!(dom.does_dom(NodeID::new(use_part.idx()), NodeID::new(part.idx())), "PANIC: A data node has a partition use that doesn't dominate its partition."); - } - } - - // Check condition #2. - let users = def_use.get_users(id); - for user_id in users.as_ref() { - if !function.nodes[user_id.idx()].is_phi() { - let user_part = self.partitions[user_id.idx()]; - assert!(dom.does_dom(NodeID::new(part.idx()), NodeID::new(user_part.idx())), "PANIC: A data node has a partition user that isn't dominated by its partition."); - } - } - } - - // Fourth, verify that every projection node is in the same partition as - // its control use. - for id in (0..function.nodes.len()).map(NodeID::new) { - if let Node::Projection { - control, - selection: _, - } = function.nodes[id.idx()] - { - assert_eq!( - self.partitions[id.idx()], - self.partitions[control.idx()], - "PANIC: Found a projection node in a different partition than its control use." - ); - } - } - - // Fifth, verify that every partition has at least one partition - // successor xor has at least one return node. - for partition_idx in 0..self.num_partitions { - let has_successor = partition_graph.succs(NodeID::new(partition_idx)).count() > 0; - let has_return = partition_to_node_ids[partition_idx] - .iter() - .any(|node_id| function.nodes[node_id.idx()].is_return()); - assert!(has_successor ^ has_return, "PANIC: Found an invalid partition based on the partition return / control criteria."); - } - } - - /* - * Compute the top node for each partition. - */ - pub fn compute_top_nodes( - &self, - function: &Function, - control_subgraph: &Subgraph, - inverted_partition_map: &Vec<Vec<NodeID>>, - ) -> Vec<NodeID> { - inverted_partition_map - .into_iter() - .enumerate() - .map(|(part_idx, part)| { - // For each partition, find the "top" node. - *part - .iter() - .filter(move |id| { - // The "top" node is a control node having at least one - // control predecessor in another partition, or is a - // start node. Every predecessor in the control subgraph - // is a control node. 
- function.nodes[id.idx()].is_start() - || (function.nodes[id.idx()].is_control() - && control_subgraph - .preds(**id) - .filter(|pred_id| { - self.partitions[pred_id.idx()].idx() != part_idx - }) - .count() - > 0) - }) - // We assume here there is exactly one such top node per - // partition. Verify a partitioning with - // `verify_partitioning` before calling this method. - .next() - .unwrap() - }) - .collect() - } - - /* - * Compute the data inputs of each partition. - */ - pub fn compute_data_inputs(&self, function: &Function) -> Vec<Vec<NodeID>> { - let mut data_inputs = vec![vec![]; self.num_partitions]; - - // First consider the non-phi nodes in each partition. - for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_phi() { - continue; - } - - let data_inputs = &mut data_inputs[self.partitions[id.idx()].idx()]; - let uses = get_uses(&function.nodes[id.idx()]); - for use_id in uses.as_ref() { - // For every non-phi node, check each of its data uses. If the - // node and its use are in different partitions, then the use is - // a data input for the partition of the node. Also, don't add - // the same node to the data inputs list twice. Only consider - // non-constant uses data inputs - constant nodes are always - // rematerialized into the user partition. - if !function.nodes[use_id.idx()].is_control() - && !function.nodes[use_id.idx()].is_constant() - && self.partitions[id.idx()] != self.partitions[use_id.idx()] - && !data_inputs.contains(use_id) - { - data_inputs.push(*use_id); - } - } - } - - // Second consider the phi nodes in each partition. - for id in (0..function.nodes.len()).map(NodeID::new) { - if !function.nodes[id.idx()].is_phi() { - continue; - } - - let data_inputs = &mut data_inputs[self.partitions[id.idx()].idx()]; - let uses = get_uses(&function.nodes[id.idx()]); - for use_id in uses.as_ref() { - // For every phi node, if any one of its non-constant uses is - // defined in a different partition, then the phi node itself, - // not its outside uses, is considered a data input. This is - // because a phi node whose uses are all in a different - // partition should be lowered to a single parameter to the - // corresponding simple IR function. Note that for a phi node - // with some uses outside and some uses inside the partition, - // the uses outside the partition become a single parameter to - // the schedule IR function, and that parameter and all of the - // "inside" uses become the inputs to a phi inside the simple IR - // function. - if self.partitions[id.idx()] != self.partitions[use_id.idx()] - && !function.nodes[use_id.idx()].is_constant() - && !data_inputs.contains(&id) - { - data_inputs.push(id); - break; - } - } - } - - // Sort the node IDs to keep a consistent interface between partitions. - for data_inputs in &mut data_inputs { - data_inputs.sort(); - } - data_inputs - } - - /* - * Compute the data outputs of each partition. - */ - pub fn compute_data_outputs( - &self, - function: &Function, - def_use: &ImmutableDefUseMap, - ) -> Vec<Vec<NodeID>> { - let mut data_outputs = vec![vec![]; self.num_partitions]; - - for id in (0..function.nodes.len()).map(NodeID::new) { - // Only consider non-constant data nodes as data outputs, since - // constant nodes are rematerialized into the user partition. 
- if function.nodes[id.idx()].is_control() || function.nodes[id.idx()].is_constant() { - continue; - } - - let data_outputs = &mut data_outputs[self.partitions[id.idx()].idx()]; - let users = def_use.get_users(id); - for user_id in users.as_ref() { - // For every data node, check each of its users. If the node and - // its user are in different partitions, then the node is a data - // output for the partition of the node. Also, don't add the - // same node to the data outputs list twice. It doesn't matter - // how this data node is being used - all that matters is that - // it itself is a data node, and that it has a user outside the - // partition. This makes the code simpler than the inputs case. - if self.partitions[id.idx()] != self.partitions[user_id.idx()] - && !data_outputs.contains(&id) - { - data_outputs.push(id); - break; - } - } - } - - // Sort the node IDs to keep a consistent interface between partitions. - for data_outputs in &mut data_outputs { - data_outputs.sort(); - } - data_outputs - } - - pub fn renumber_partitions(&mut self) { - let mut renumber_partitions = HashMap::new(); - for id in self.partitions.iter_mut() { - let next_id = PartitionID::new(renumber_partitions.len()); - let old_id = *id; - let new_id = *renumber_partitions.entry(old_id).or_insert(next_id); - *id = new_id; - } - let mut new_devices = vec![Device::CPU; renumber_partitions.len()]; - for (old_id, new_id) in renumber_partitions.iter() { - new_devices[new_id.idx()] = self.partition_devices[old_id.idx()]; - } - self.partition_devices = new_devices; - self.num_partitions = renumber_partitions.len(); - } -} - -impl GraveUpdatable for Plan { - fn fix_gravestones(&mut self, grave_mapping: &Vec<NodeID>) { - self.schedules.fix_gravestones(grave_mapping); - self.partitions.fix_gravestones(grave_mapping); - self.renumber_partitions(); - } -} - -/* - * Infer parallel fork-joins. These are fork-joins with only parallel reduction - * variables and no parent fork-joins. - */ -pub fn infer_parallel_fork( - function: &Function, - def_use: &ImmutableDefUseMap, - fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, - plan: &mut Plan, -) { - for id in (0..function.nodes.len()).map(NodeID::new) { - let Node::Fork { - control: _, - ref factors, - } = function.nodes[id.idx()] - else { - continue; - }; - let join_id = fork_join_map[&id]; - let all_parallel_reduce = def_use.get_users(join_id).as_ref().into_iter().all(|user| { - plan.schedules[user.idx()].contains(&Schedule::ParallelReduce) - || function.nodes[user.idx()].is_control() - }); - let top_level = fork_join_nest[&id].len() == 1 && fork_join_nest[&id][0] == id; - if all_parallel_reduce && top_level { - let tiling = repeat(4).take(factors.len()).collect(); - plan.schedules[id.idx()].push(Schedule::ParallelFork(tiling)); - } - } -} - -/* - * Infer parallel reductions consisting of a simple cycle between a Reduce node - * and a Write node, where indices of the Write are position indices using the - * ThreadID nodes attached to the corresponding Fork. This procedure also adds - * the ParallelReduce schedule to Reduce nodes reducing over a parallelized - * Reduce, as long as the base Write node also has position indices of the - * ThreadID of the outer fork. In other words, the complete Reduce chain is - * annotated with ParallelReduce, as long as each ThreadID dimension appears in - * the positional indexing of the original Write. 
- */ -pub fn infer_parallel_reduce( - function: &Function, - fork_join_map: &HashMap<NodeID, NodeID>, - plan: &mut Plan, -) { - for id in (0..function.nodes.len()) - .map(NodeID::new) - .filter(|id| function.nodes[id.idx()].is_reduce()) - { - let mut first_control = None; - let mut last_reduce = id; - let mut chain_id = id; - - // Walk down Reduce chain until we reach the Reduce potentially looping - // with the Write. Note the control node of the first Reduce, since this - // will tell us which Thread ID to look for in the Write. - while let Node::Reduce { - control, - init: _, - reduct, - } = function.nodes[chain_id.idx()] - { - if first_control.is_none() { - first_control = Some(control); - } - - last_reduce = chain_id; - chain_id = reduct; - } - - // Check for a Write-Reduce tight cycle. - if let Node::Write { - collect, - data: _, - indices, - } = &function.nodes[chain_id.idx()] - && *collect == last_reduce - { - // If there is a Write-Reduce tight cycle, get the position indices. - let positions = indices - .iter() - .filter_map(|index| { - if let Index::Position(indices) = index { - Some(indices) - } else { - None - } - }) - .flat_map(|pos| pos.iter()); - - // Get the Forks corresponding to uses of bare ThreadIDs. - let fork_thread_id_pairs = positions.filter_map(|id| { - if let Node::ThreadID { control, dimension } = function.nodes[id.idx()] { - Some((control, dimension)) - } else { - None - } - }); - let mut forks = HashMap::<NodeID, Vec<usize>>::new(); - for (fork, dim) in fork_thread_id_pairs { - forks.entry(fork).or_default().push(dim); - } - - // Check if one of the Forks correspond to the Join associated with - // the Reduce being considered, and has all of its dimensions - // represented in the indexing. - let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { - rep_dims.sort(); - rep_dims.dedup(); - fork_join_map[&id] == first_control.unwrap() - && function.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() - }); - - if is_parallel { - plan.schedules[id.idx()].push(Schedule::ParallelReduce); - } - } - } -} - -/* - * Infer vectorizable fork-joins. Just check that there are no control nodes - * between a fork and its join and the factor is a constant. - */ -pub fn infer_vectorizable( - function: &Function, - dynamic_constants: &Vec<DynamicConstant>, - fork_join_map: &HashMap<NodeID, NodeID>, - plan: &mut Plan, -) { - for id in (0..function.nodes.len()) - .map(NodeID::new) - .filter(|id| function.nodes[id.idx()].is_join()) - { - let u = get_uses(&function.nodes[id.idx()]).as_ref()[0]; - if let Some(join) = fork_join_map.get(&u) - && *join == id - { - let factors = function.nodes[u.idx()].try_fork().unwrap().1; - if factors.len() == 1 - && let Some(width) = evaluate_dynamic_constant(factors[0], dynamic_constants) - { - plan.schedules[u.idx()].push(Schedule::Vectorizable(width)); - } - } - } -} - -/* - * Infer associative reduction loops. 
- */ -pub fn infer_associative(function: &Function, plan: &mut Plan) { - let is_associative = |op| match op { - BinaryOperator::Add - | BinaryOperator::Mul - | BinaryOperator::Or - | BinaryOperator::And - | BinaryOperator::Xor => true, - _ => false, - }; - - for (id, reduct) in (0..function.nodes.len()).map(NodeID::new).filter_map(|id| { - function.nodes[id.idx()] - .try_reduce() - .map(|(_, _, reduct)| (id, reduct)) - }) { - if let Node::Binary { left, right, op } = function.nodes[reduct.idx()] - && (left == id || right == id) - && is_associative(op) - { - plan.schedules[id.idx()].push(Schedule::Associative); - } - } -} - -/* - * Create partitions corresponding to fork-join nests. Also, split the "top- - * level" partition into sub-partitions that are connected graphs. Place data - * nodes using ThreadID or Reduce nodes in the corresponding fork-join nest's - * partition. - */ -pub fn partition_out_forks( - function: &Function, - reverse_postorder: &Vec<NodeID>, - fork_join_map: &HashMap<NodeID, NodeID>, - bbs: &Vec<NodeID>, - plan: &mut Plan, -) { - #[allow(non_local_definitions)] - impl Semilattice for NodeID { - fn meet(a: &Self, b: &Self) -> Self { - if a.idx() < b.idx() { - *a - } else { - *b - } - } - - fn bottom() -> Self { - NodeID::new(0) - } - - fn top() -> Self { - NodeID::new(!0) - } - } - - // Step 1: do dataflow analysis over control nodes to identify a - // representative node for each partition. Each fork not taking as input the - // ID of another fork node introduces its own ID as a partition - // representative. Each join node propagates the fork ID if it's not the - // fork pointing to the join in the fork join map - otherwise, it introduces - // its user as a representative node ID for a partition. A region node taking - // multiple node IDs as input belongs to the partition with the smaller - // representative node ID. - let mut representatives = forward_dataflow( - function, - reverse_postorder, - |inputs: &[&NodeID], node_id: NodeID| match function.nodes[node_id.idx()] { - Node::Start => NodeID::new(0), - Node::Fork { - control, - factors: _, - } => { - // Start a partition if the preceding partition isn't a fork - // partition and the predecessor isn't the join for the - // predecessor fork partition. Otherwise, be part of the parent - // fork partition. - if *inputs[0] != NodeID::top() - && function.nodes[inputs[0].idx()].is_fork() - && fork_join_map.get(&inputs[0]) != Some(&control) - { - inputs[0].clone() - } else { - node_id - } - } - Node::Join { control: _ } => inputs[0].clone(), - _ => { - // If the previous node is a join and terminates a fork's - // partition, then start a new partition here. Otherwise, just - // meet over the input lattice values. Set all data nodes to be - // in the !0 partition. - if !function.nodes[node_id.idx()].is_control() { - NodeID::top() - } else if zip(inputs, get_uses(&function.nodes[node_id.idx()]).as_ref()) - .any(|(part_id, pred_id)| fork_join_map.get(part_id) == Some(pred_id)) - { - node_id - } else { - inputs - .iter() - .filter(|id| { - ***id != NodeID::top() && function.nodes[id.idx()].is_control() - }) - .fold(NodeID::top(), |a, b| NodeID::meet(&a, b)) - } - } - }, - ); - - // Step 2: assign data nodes to the partitions of the control nodes they are - // assigned to by GCM. - for idx in 0..function.nodes.len() { - if !function.nodes[idx].is_control() { - representatives[idx] = representatives[bbs[idx].idx()]; - } - } - - // Step 3: deduplicate representative node IDs. 
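The schedule-inference passes deleted from this file are not lost: the diffstat adds a 175-line hercules_opt/src/schedule.rs, which presumably rebuilds them on top of `FunctionEditor` so that annotations can be maintained as the IR is edited. A hedged sketch of what the associative case could look like in that style (the actual signatures in the new file are not shown in this section and may differ):

pub fn infer_associative(editor: &mut FunctionEditor) {
    let is_associative = |op| {
        matches!(
            op,
            BinaryOperator::Add
                | BinaryOperator::Mul
                | BinaryOperator::Or
                | BinaryOperator::And
                | BinaryOperator::Xor
        )
    };
    for id in editor.node_ids().collect::<Vec<_>>() {
        editor.edit(|edit| {
            // A reduce is associative if its `reduct` is a binary op that
            // uses the reduce itself as one of its operands.
            if let Some((_, _, reduct)) = edit.get_node(id).try_reduce()
                && let Node::Binary { left, right, op } = edit.get_node(reduct)
                && (*left == id || *right == id)
                && is_associative(*op)
            {
                edit.add_schedule(id, Schedule::Associative)
            } else {
                Ok(edit)
            }
        });
    }
}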
- let mut representative_to_partition_ids = HashMap::new(); - for rep in &representatives { - if !representative_to_partition_ids.contains_key(rep) { - representative_to_partition_ids - .insert(rep, PartitionID::new(representative_to_partition_ids.len())); - } - } - - // Step 4: update plan. - plan.num_partitions = representative_to_partition_ids.len(); - for id in (0..function.nodes.len()).map(NodeID::new) { - plan.partitions[id.idx()] = representative_to_partition_ids[&representatives[id.idx()]]; - } - - plan.partition_devices = vec![Device::CPU; plan.num_partitions]; -} - -/* - * Set the device for all partitions containing a fork to the GPU. - */ -pub fn place_fork_partitions_on_gpu(function: &Function, plan: &mut Plan) { - for idx in 0..function.nodes.len() { - if function.nodes[idx].is_fork() { - plan.partition_devices[plan.partitions[idx].idx()] = Device::GPU; - } - } -} diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs index 3549718a..de342579 100644 --- a/hercules_ir/src/subgraph.rs +++ b/hercules_ir/src/subgraph.rs @@ -246,86 +246,3 @@ pub fn control_subgraph(function: &Function, def_use: &ImmutableDefUseMap) -> Su && (!function.nodes[node.idx()].is_start() || node.idx() == 0) }) } - -/* - * Construct a subgraph representing the control relations between partitions. - * Technically, this isn't a "sub"graph of the function graph, since partition - * nodes don't correspond to nodes in the original function. - */ -pub fn partition_graph(function: &Function, def_use: &ImmutableDefUseMap, plan: &Plan) -> Subgraph { - let partition_to_node_ids = plan.invert_partition_map(); - - let mut subgraph = Subgraph { - nodes: (0..plan.num_partitions).map(NodeID::new).collect(), - node_numbers: (0..plan.num_partitions) - .map(|idx| (NodeID::new(idx), idx as u32)) - .collect(), - first_forward_edges: vec![], - forward_edges: vec![], - first_backward_edges: vec![], - backward_edges: vec![], - original_num_nodes: plan.num_partitions as u32, - }; - - // Step 1: collect backward edges from use info. - for partition in partition_to_node_ids.iter() { - // Record the source of the edges (the current partition). - let old_num_edges = subgraph.backward_edges.len(); - subgraph.first_backward_edges.push(old_num_edges as u32); - for node in partition - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Look at all the control uses of control nodes in that partition. - let uses = get_uses(&function.nodes[node.idx()]); - for use_id in uses - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Add a backward edge to any different partition we are using - // and don't add duplicate backward edges. - if plan.partitions[use_id.idx()] != plan.partitions[node.idx()] - && !subgraph.backward_edges[old_num_edges..] - .contains(&(plan.partitions[use_id.idx()].idx() as u32)) - { - subgraph - .backward_edges - .push(plan.partitions[use_id.idx()].idx() as u32); - } - } - } - } - - // Step 2: collect forward edges from user (def_use) info. - for partition in partition_to_node_ids.iter() { - // Record the source of the edges (the current partition). - let old_num_edges = subgraph.forward_edges.len(); - subgraph.first_forward_edges.push(old_num_edges as u32); - for node in partition - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Look at all the control users of control nodes in that partition. 
- let users = def_use.get_users(*node); - for user_id in users - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Add a forward edge to any different partition that we are a - // user of and don't add duplicate forward edges. - if plan.partitions[user_id.idx()] != plan.partitions[node.idx()] - && !subgraph.forward_edges[old_num_edges..] - .contains(&(plan.partitions[user_id.idx()].idx() as u32)) - { - subgraph - .forward_edges - .push(plan.partitions[user_id.idx()].idx() as u32); - } - } - } - } - - subgraph -} diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index aa3d0e68..3f53ea40 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -194,6 +194,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { // Add a constant IR node for this constant let cons_node = edit.add_node(Node::Constant { id: cons_id }); // Replace the original node with the constant node + edit.sub_edit(old_id, cons_node); edit = edit.replace_all_uses(old_id, cons_node)?; edit.delete_node(old_id) }); @@ -227,6 +228,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { control, data: new_data, }); + edit.sub_edit(phi_id, new_node); edit = edit.replace_all_uses(phi_id, new_node)?; edit.delete_node(phi_id) }); @@ -248,6 +250,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { .collect(); editor.edit(|mut edit| { let new_node = edit.add_node(Node::Region { preds: new_preds }); + edit.sub_edit(region_id, new_node); edit = edit.replace_all_uses(region_id, new_node)?; edit.delete_node(region_id) }); @@ -276,12 +279,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { // The reachable users iterator will contain one user if we need to // remove this branch node. if reachable_users.len() == 1 { - // The user is a Read node, which in turn has one user. - assert!( - editor.get_users(the_reachable_user).len() == 1, - "Control Read node doesn't have exactly one user." - ); - + // The user is a projection node, which in turn has one user. 
editor.edit(|mut edit| { // Replace all uses of the single reachable user with the node preceding the // branch node diff --git a/hercules_opt/src/delete_uncalled.rs b/hercules_opt/src/delete_uncalled.rs index fe99b5fb..78ab4285 100644 --- a/hercules_opt/src/delete_uncalled.rs +++ b/hercules_opt/src/delete_uncalled.rs @@ -67,8 +67,9 @@ pub fn delete_uncalled( dynamic_constants: dynamic_constants.clone(), args: args.clone(), }); - let edit = edit.delete_node(callsite)?; - edit.replace_all_uses(callsite, new_node) + edit.sub_edit(callsite, new_node); + let edit = edit.replace_all_uses(callsite, new_node)?; + edit.delete_node(callsite) }); assert!( success, diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 0ff58822..0c97abff 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -4,32 +4,21 @@ extern crate hercules_ir; extern crate itertools; use std::cell::{Ref, RefCell}; -use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; -use std::iter::FromIterator; +use std::collections::{BTreeMap, HashSet}; use std::mem::take; use std::ops::Deref; use self::bitvec::prelude::*; use self::either::Either; -use self::itertools::Itertools; -use self::hercules_ir::antideps::*; -use self::hercules_ir::dataflow::*; use self::hercules_ir::def_use::*; -use self::hercules_ir::dom::*; -use self::hercules_ir::gcm::*; use self::hercules_ir::ir::*; -use self::hercules_ir::loops::*; -use self::hercules_ir::schedule::*; -use self::hercules_ir::subgraph::*; - -pub type Edit = (HashSet<NodeID>, HashSet<NodeID>); /* - * Helper object for editing Hercules functions in a trackable manner. Edits are - * recorded in order to repair partitions and debug info. - * Edits must be made atomically, that is, only one `.edit` may be called at a time - * across all editors. + * Helper object for editing Hercules functions in a trackable manner. Edits + * must be made atomically, that is, only one `.edit` may be called at a time + * across all editors, and individual edits must leave the function in a valid + * state. */ #[derive(Debug)] pub struct FunctionEditor<'a> { @@ -44,19 +33,6 @@ pub struct FunctionEditor<'a> { // Most optimizations need def use info, so provide an iteratively updated // mutable version that's automatically updated based on recorded edits. mut_def_use: Vec<HashSet<NodeID>>, - // Record edits as a mapping from sets of node IDs to sets of node IDs. The - // sets on the "left" side of this map should be mutually disjoint, and the - // sets on the "right" side should also be mutually disjoint. All of the - // node IDs on the left side should be deleted node IDs or IDs of unmodified - // nodes, and all of the node IDs on the right side should be added node IDs - // or IDs of unmodified nodes. In other words, there should be no added node - // IDs on the left side, and no deleted node IDs on the right side. These - // mappings are stored sequentially in a list, rather than as a map. This is - // because a transformation may iteratively update a function - i.e., a node - // ID added in iteration N may be deleted in iteration N + M. To maintain a - // more precise history of edits, we store each edit individually, which - // allows us to make more precise repairs of partitions and debug info.
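This deleted edit log was the hook for the old plan-repair machinery. Its replacement is the `sub_edits` list on `FunctionEdit`: instead of reconstructing node correspondences from set-to-set edits after the fact, a pass states each correspondence directly with `sub_edit`, and the editor copies schedules across it when the edit commits. A hedged sketch of the resulting idiom, already visible in the ccp.rs and delete_uncalled.rs hunks above (`rewrite_node` is a hypothetical helper, not part of the patch):

// Replace `old` with a freshly built node while preserving its schedules.
fn rewrite_node(editor: &mut FunctionEditor, old: NodeID, replacement: Node) -> bool {
    editor.edit(|mut edit| {
        let new = edit.add_node(replacement);
        // Record the direct correspondence so that any schedules on `old`
        // (e.g. Schedule::ParallelReduce) are propagated onto `new`.
        edit.sub_edit(old, new);
        edit = edit.replace_all_uses(old, new)?;
        edit.delete_node(old)
    })
}

Passes can also attach or drop annotations directly during an edit through `add_schedule` and `clear_schedule`, defined on `FunctionEdit` below.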
- edits: Vec<Edit>, // The pass manager may indicate that only a certain subset of nodes should // be modified in a function - what this actually means is that some nodes // are off limits for deletion (equivalently modification) or being replaced @@ -77,6 +53,8 @@ pub struct FunctionEdit<'a: 'b, 'b> { added_nodeids: HashSet<NodeID>, // Keep track of added and use updated nodes. added_and_updated_nodes: BTreeMap<NodeID, Node>, + // Keep track of added and updated schedules. + added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>, // Keep track of added (dynamic) constants and types added_constants: Vec<Constant>, added_dynamic_constants: Vec<DynamicConstant>, @@ -84,6 +62,8 @@ pub struct FunctionEdit<'a: 'b, 'b> { // Compute a def-use map entries iteratively. updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>, updated_return_type: Option<TypeID>, + // Keep track of which deleted and added node IDs directly correspond. + sub_edits: Vec<(NodeID, NodeID)>, } impl<'a: 'b, 'b> FunctionEditor<'a> { @@ -111,7 +91,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { dynamic_constants, types, mut_def_use, - edits: vec![], mutable_nodes, } } @@ -125,12 +104,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor: self, deleted_nodeids: HashSet::new(), added_nodeids: HashSet::new(), + added_and_updated_nodes: BTreeMap::new(), + added_and_updated_schedules: BTreeMap::new(), added_constants: Vec::new().into(), added_dynamic_constants: Vec::new().into(), added_types: Vec::new().into(), - added_and_updated_nodes: BTreeMap::new(), updated_def_use: BTreeMap::new(), updated_return_type: None, + sub_edits: vec![], }; if let Ok(populated_edit) = edit(edit_obj) { @@ -140,12 +121,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor, deleted_nodeids, added_nodeids, + added_and_updated_nodes, + added_and_updated_schedules, added_constants, added_dynamic_constants, added_types, - added_and_updated_nodes: added_and_updated, updated_def_use, updated_return_type, + sub_edits, } = populated_edit; // Step 1: update the mutable def use map. for (u, new_users) in updated_def_use { @@ -165,7 +148,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } // Step 2: add and update nodes. - for (id, node) in added_and_updated { + for (id, node) in added_and_updated_nodes { if id.idx() < editor.function.nodes.len() { editor.function.nodes[id.idx()] = node; } else { @@ -176,7 +159,16 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } - // Step 3: delete nodes. This is done using "gravestones", where a + // Step 3: add and update schedules. + editor + .function + .schedules + .resize(editor.function.nodes.len(), vec![]); + for (id, schedule) in added_and_updated_schedules { + editor.function.schedules[id.idx()] = schedule; + } + + // Step 4: delete nodes. This is done using "gravestones", where a // node other than node ID 0 being a start node is considered a // gravestone. for id in deleted_nodeids.iter() { @@ -185,16 +177,24 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor.function.nodes[id.idx()] = Node::Start; } - // Step 4: add a single edit to the edit list. - editor.edits.push((deleted_nodeids, added_nodeids)); + // Step 5: propagate schedules along sub-edit edges. + for (src, dst) in sub_edits { + let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]); + for src_schedule in editor.function.schedules[src.idx()].iter() { + if !dst_schedules.contains(src_schedule) { + dst_schedules.push(src_schedule.clone()); + } + } + editor.function.schedules[dst.idx()] = dst_schedules; + } - // Step 5: update the length of mutable_nodes. 
All added nodes are + // Step 6: update the length of mutable_nodes. All added nodes are // mutable. editor .mutable_nodes .resize(editor.function.nodes.len(), true); - // Step 6: update types and constants + // Step 7: update types and constants. let mut editor_constants = editor.constants.borrow_mut(); let mut editor_dynamic_constants = editor.dynamic_constants.borrow_mut(); let mut editor_types = editor.types.borrow_mut(); @@ -203,7 +203,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor_dynamic_constants.extend(added_dynamic_constants); editor_types.extend(added_types); - // Step 7: update return type if necessary + // Step 8: update return type if necessary. if let Some(return_type) = updated_return_type { editor.function.return_type = return_type; } @@ -244,10 +244,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { self.mutable_nodes[id.idx()] } - pub fn edits(self) -> Vec<Edit> { - self.edits - } - pub fn node_ids(&self) -> impl ExactSizeIterator<Item = NodeID> { let num = self.function.nodes.len(); (0..num).map(NodeID::new) @@ -377,6 +373,12 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } + pub fn sub_edit(&mut self, src: NodeID, dst: NodeID) { + assert!(!self.added_nodeids.contains(&src)); + assert!(!self.deleted_nodeids.contains(&dst)); + self.sub_edits.push((src, dst)); + } + pub fn get_name(&self) -> &str { &self.editor.function.name } @@ -397,6 +399,43 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } + pub fn get_schedule(&self, id: NodeID) -> &Vec<Schedule> { + // The user may get the schedule of a to-be deleted node. + if let Some(schedule) = self.added_and_updated_schedules.get(&id) { + // Refer to added or updated schedule. + schedule + } else { + // Refer to the original schedule of this node. + &self.editor.function.schedules[id.idx()] + } + } + + pub fn add_schedule(mut self, id: NodeID, schedule: Schedule) -> Result<Self, Self> { + if self.is_mutable(id) { + if let Some(schedules) = self.added_and_updated_schedules.get_mut(&id) { + schedules.push(schedule); + } else { + let mut schedules = self.editor.function.schedules[id.idx()].clone(); + if !schedules.contains(&schedule) { + schedules.push(schedule); + } + self.added_and_updated_schedules.insert(id, schedules); + } + Ok(self) + } else { + Err(self) + } + } + + pub fn clear_schedule(mut self, id: NodeID) -> Result<Self, Self> { + if self.is_mutable(id) { + self.added_and_updated_schedules.insert(id, vec![]); + Ok(self) + } else { + Err(self) + } + } + pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ { assert!(!self.deleted_nodeids.contains(&id)); if let Some(users) = self.updated_def_use.get(&id) { @@ -539,250 +578,6 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } -/* - * Simplify an edit sequence into a single, larger, edit. - */ -fn collapse_edits(edits: &[Edit]) -> Edit { - let mut total_edit = Edit::default(); - - let mut all_additions: HashSet<NodeID> = HashSet::new(); - - for edit in edits { - assert!(edit.0.is_disjoint(&edit.1), "PANIC: Edit sequence is malformed - can't add and delete the same node ID in a single edit."); - assert!( - total_edit.0.is_disjoint(&edit.0), - "PANIC: Edit sequence is malformed - can't delete the same node ID twice." - ); - assert!( - total_edit.1.is_disjoint(&edit.1), - "PANIC: Edit sequence is malformed - can't add the same node ID twice." 
- ); - - for delete in edit.0.iter() { - if !all_additions.contains(delete) { - total_edit.0.insert(*delete); - } - total_edit.1.remove(delete); - } - - for addition in edit.1.iter() { - total_edit.0.remove(addition); - total_edit.1.insert(*addition); - all_additions.insert(*addition); - } - } - - total_edit -} - -/* - * Plans can be repaired - this entails repairing schedules as well as - * partitions. `new_function` is the function after the edits have occurred, but - * before gravestones have been removed. - */ -pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { - // Step 1: collapse all of the edits into a single edit. For repairing - // partitions, we don't need to consider the intermediate edit states. - let total_edit = collapse_edits(edits); - - // Step 2: drop schedules for deleted nodes and create empty schedule lists - // for added nodes. - for deleted in total_edit.0.iter() { - // Nodes that were created and deleted using the same editor don't have - // an existing schedule, so ignore them - if deleted.idx() < plan.schedules.len() { - plan.schedules[deleted.idx()] = vec![]; - } - } - if !total_edit.1.is_empty() { - assert_eq!( - total_edit.1.iter().max().unwrap().idx() + 1, - new_function.nodes.len() - ); - plan.schedules.resize(new_function.nodes.len(), vec![]); - } - - // Step 3: figure out the order to add nodes to partitions. Roughly, we look - // at the added nodes in reverse postorder and partition by control/data. We - // first add control nodes to partitions using node-specific rules. We then - // add data nodes based on the partitions of their immediate control uses - // and users. - let def_use = def_use(new_function); - let rev_po = reverse_postorder(&def_use); - let added_control_nodes: Vec<NodeID> = rev_po - .iter() - .filter(|id| total_edit.1.contains(id) && new_function.nodes[id.idx()].is_control()) - .map(|id| *id) - .collect(); - let added_data_nodes: Vec<NodeID> = rev_po - .iter() - .filter(|id| total_edit.1.contains(id) && !new_function.nodes[id.idx()].is_control()) - .map(|id| *id) - .collect(); - - // Step 4: figure out the partitions for added control nodes. - // Do a bunch of analysis that basically boils down to finding what fork- - // joins are top-level. - let control_subgraph = control_subgraph(new_function, &def_use); - let dom = dominator(&control_subgraph, NodeID::new(0)); - let fork_join_map = fork_join_map(new_function, &control_subgraph); - let fork_join_nesting = compute_fork_join_nesting(new_function, &dom, &fork_join_map); - // While building, the new partitions map uses Option since we don't have - // partitions for new nodes yet, and we need to record that specifically for - // computing the partitions of region nodes. - let mut new_partitions: Vec<Option<PartitionID>> = - take(&mut plan.partitions).into_iter().map(Some).collect(); - new_partitions.resize(new_function.nodes.len(), None); - // Iterate the added control nodes using a worklist. - let mut worklist = VecDeque::from(added_control_nodes); - while let Some(control_id) = worklist.pop_front() { - let node = &new_function.nodes[control_id.idx()]; - // There are a few cases where this control node needs to start a new - // partition: - // 1. It's a non-gravestone start node. This is any start node visited - // by the reverse postorder. - // 2. It's a return node. - // 3. It's a top-level fork. - // 4. One of its control predecessors is a top-level join. - // 5. 
It's a region node where not every predecessor is in the same - // partition (equivalently, not every predecessor is in the same - // partition - only region nodes can have multiple predecessors). - // 6. It's a region node with a call user. - // 7. Its predecessor is a region node with a call user. - let top_level_fork = node.is_fork() && fork_join_nesting[&control_id].len() == 1; - let top_level_join = control_subgraph.preds(control_id).any(|pred| { - new_function.nodes[pred.idx()].is_join() && fork_join_nesting[&pred].len() == 1 - }); - // It's not possible for every predecessor to not have been assigned a - // partition yet because of reverse postorder traversal. - let multi_pred_region = !control_subgraph - .preds(control_id) - .map(|pred| new_partitions[pred.idx()]) - .all_equal(); - let region_with_call_user = |id: NodeID| { - new_function.nodes[id.idx()].is_region() - && def_use - .get_users(id) - .as_ref() - .into_iter() - .any(|id| new_function.nodes[id.idx()].is_call()) - }; - let call_region = region_with_call_user(control_id); - let pred_is_call_region = control_subgraph - .preds(control_id) - .any(|pred| region_with_call_user(pred)); - - if node.is_start() - || node.is_return() - || top_level_fork - || top_level_join - || multi_pred_region - || call_region - || pred_is_call_region - { - // This control node goes in a new partition. - let part_id = PartitionID::new(plan.num_partitions); - plan.num_partitions += 1; - new_partitions[control_id.idx()] = Some(part_id); - } else { - // This control node goes in the partition of any one of its - // predecessors. They're all the same by condition 3 above. - let any_pred = control_subgraph.preds(control_id).next().unwrap(); - if new_partitions[any_pred.idx()].is_some() { - new_partitions[control_id.idx()] = new_partitions[any_pred.idx()]; - } else { - worklist.push_back(control_id); - } - } - } - - // Step 5: figure out the partitions for added data nodes. - let antideps = antideps(&new_function, &def_use); - let loops = loops(&control_subgraph, NodeID::new(0), &dom, &fork_join_map); - let bbs = gcm( - new_function, - &def_use, - &rev_po, - &dom, - &antideps, - &loops, - &fork_join_map, - ); - for data_idx in 0..new_function.nodes.len() { - new_partitions[data_idx] = new_partitions[bbs[data_idx].idx()]; - } - - // Step 6: create a solitary gravestone partition. This will get removed - // when gravestone nodes are removed. - let gravestone_partition = PartitionID::new(plan.num_partitions); - plan.num_partitions += 1; - for (idx, node) in new_function.nodes.iter().enumerate() { - if idx > 0 && node.is_start() { - new_partitions[idx] = Some(gravestone_partition); - } - } - - // Step 7: wrap everything up. - plan.partitions = new_partitions.into_iter().map(|id| id.unwrap()).collect(); - plan.partition_devices - .resize(plan.num_partitions, Device::CPU); - // Place call partitions on the "AsyncRust" device. - for idx in 0..new_function.nodes.len() { - if new_function.nodes[idx].is_call() { - plan.partition_devices[plan.partitions[idx].idx()] = Device::AsyncRust; - } - } -} - -/* - * Default plans can be constructed by conservatively inferring schedules and - * creating partitions by "repairing" a partition where the edit is adding every - * node in the function. 
- */ -pub fn default_plan( - function: &Function, - dynamic_constants: &Vec<DynamicConstant>, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, - fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, - bbs: &Vec<NodeID>, -) -> Plan { - // Start by creating a completely bare-bones plan doing nothing interesting. - let mut plan = Plan { - schedules: vec![vec![]; function.nodes.len()], - partitions: vec![], - partition_devices: vec![], - num_partitions: 1, - }; - - // Infer a partitioning by using `repair_plan`, where the "edit" is creating - // the entire function. - let edit = ( - HashSet::new(), - HashSet::from_iter((0..function.nodes.len()).map(NodeID::new)), - ); - repair_plan(&mut plan, function, &[edit]); - plan.renumber_partitions(); - - // Infer schedules. - infer_parallel_reduce(function, fork_join_map, &mut plan); - infer_parallel_fork( - function, - def_use, - fork_join_map, - fork_join_nesting, - &mut plan, - ); - infer_vectorizable(function, dynamic_constants, fork_join_map, &mut plan); - infer_associative(function, &mut plan); - - // TODO: uncomment once GPU backend is implemented. - // place_fork_partitions_on_gpu(function, &mut plan); - - plan -} - #[cfg(test)] mod editor_tests { #[allow(unused_imports)] @@ -790,6 +585,7 @@ mod editor_tests { use std::mem::replace; + use self::hercules_ir::dataflow::reverse_postorder; use self::hercules_ir::parse::parse; fn canonicalize(function: &mut Function) -> Vec<Option<NodeID>> { diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs index df3652df..232b43f7 100644 --- a/hercules_opt/src/fork_concat_split.rs +++ b/hercules_opt/src/fork_concat_split.rs @@ -53,6 +53,7 @@ pub fn fork_split( control: acc_fork, factors: Box::new([factor]), }); + edit.sub_edit(*fork, acc_fork); new_tids.push(edit.add_node(Node::ThreadID { control: acc_fork, dimension: 0, @@ -68,6 +69,7 @@ pub fn fork_split( let mut joins = vec![]; for _ in new_tids.iter() { acc_join = edit.add_node(Node::Join { control: acc_join }); + edit.sub_edit(*join, acc_join); joins.push(acc_join); } @@ -89,17 +91,18 @@ pub fn fork_split( } else { NodeID::new(num_nodes + join_idx - 1) }; - let reduce = edit.add_node(Node::Reduce { + let new_reduce = edit.add_node(Node::Reduce { control: *join, init, reduct, }); - assert_eq!(reduce, NodeID::new(num_nodes + join_idx)); + assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx)); + edit.sub_edit(*reduce, new_reduce); if join_idx == 0 { - inner_reduce = reduce; + inner_reduce = new_reduce; } if join_idx == joins.len() - 1 { - outer_reduce = reduce; + outer_reduce = new_reduce; } } new_reduces.push((inner_reduce, outer_reduce)); @@ -110,6 +113,7 @@ pub fn fork_split( edit = edit.replace_all_uses(*join, acc_join)?; for tid in tids.iter() { let dim = edit.get_node(*tid).try_thread_id().unwrap().1; + edit.sub_edit(*tid, new_tids[dim]); edit = edit.replace_all_uses(*tid, new_tids[dim])?; } for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) { diff --git a/hercules_opt/src/gvn.rs b/hercules_opt/src/gvn.rs index b8360dd9..a9db0cd8 100644 --- a/hercules_opt/src/gvn.rs +++ b/hercules_opt/src/gvn.rs @@ -39,7 +39,8 @@ pub fn gvn(editor: &mut FunctionEditor) { // `number`. We want to replace `work` with `number`, which means // 1. replacing all uses of `work` with `number` // 2. 
deleting `work` - let success = editor.edit(|edit| { + let success = editor.edit(|mut edit| { + edit.sub_edit(work, *number); let edit = edit.replace_all_uses(work, *number)?; edit.delete_node(work) }); diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 6b9e006d..1c485ecc 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -6,7 +6,6 @@ use std::iter::zip; use self::hercules_ir::callgraph::*; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; -use self::hercules_ir::schedule::*; use crate::*; @@ -14,11 +13,7 @@ use crate::*; * Top level function to run inlining. Currently, inlines every function call, * since mutual recursion is not valid in Hercules IR. */ -pub fn inline( - editors: &mut [FunctionEditor], - callgraph: &CallGraph, - mut plans: Option<&mut Vec<Plan>>, -) { +pub fn inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) { // Step 1: run topological sort on the call graph to inline the "deepest" // function first. let topo = callgraph.topo(); @@ -54,16 +49,11 @@ pub fn inline( // references, we need to do some weirdness here to simultaneously get: // 1. A mutable reference to the function we're modifying. // 2. Shared references to all of the functions called by that function. - // We need to get the same for plans, if we receive them. let callees = callgraph.get_callees(to_inline_id); let editor_refs = get_mut_and_immuts(editors, to_inline_id, callees); - let plan_refs = plans - .as_mut() - .map(|plans| get_mut_and_immuts(*plans, to_inline_id, callees)); inline_func( editor_refs.0, editor_refs.1, - plan_refs, &single_return_nodes, &dc_param_idx_to_dc_id, ); @@ -75,7 +65,7 @@ pub fn inline( * 1. A single mutable reference. * 2. Several shared references. * Where none of the references alias. We need to use this both for function - * editors and plans. + * editors and schedules. */ fn get_mut_and_immuts<'a, T, I: ID>( mut_refs: &'a mut [T], @@ -113,7 +103,6 @@ fn get_mut_and_immuts<'a, T, I: ID>( fn inline_func( editor: &mut FunctionEditor, called: HashMap<FunctionID, &FunctionEditor>, - plans: Option<(&mut Plan, HashMap<FunctionID, &Plan>)>, single_return_nodes: &Vec<Option<NodeID>>, dc_param_idx_to_dc_id: &Vec<DynamicConstantID>, ) { @@ -149,7 +138,7 @@ fn inline_func( let called_return_data = called_return_uses.as_ref()[1]; // Perform the actual edit. - let success = editor.edit(|mut edit| { + editor.edit(|mut edit| { // Insert the nodes from the called function. There are a few // special cases: // - Start: don't add start nodes - later, we'll replace_all_uses on @@ -217,6 +206,11 @@ fn inline_func( // Add the node and check that the IDs line up. let add_id = edit.add_node(node); assert_eq!(add_id, old_id_to_new_id(NodeID::new(idx))); + // Copy the schedule from the callee. 
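+ // Schedules are stored per-node on the Function, so the callee's list is
+ // indexed by the old node ID and re-attached to the node just added. The
+ // `?` below cannot actually fail: nodes added by this edit are always
+ // mutable.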
+ let callee_schedule = &called_func.schedules[idx]; + for schedule in callee_schedule { + edit = edit.add_schedule(add_id, schedule.clone())?; + } } // Stitch the control use of the inlined start node with the diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 3ab35535..c6cf448b 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -276,6 +276,7 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi control: return_control, data: compressed_data_id, }); + edit.sub_edit(old_return_id, new_return_id); let edit = edit.replace_all_uses(old_return_id, new_return_id)?; edit.delete_node(old_return_id) }); @@ -388,10 +389,11 @@ fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_edi collect: old_data, indices: Box::new([Index::Field(0)]), }); - edit.add_node(Node::Return { + let new_return_id = edit.add_node(Node::Return { control: old_control, data: extracted_singleton_id, }); + edit.sub_edit(old_return_id, new_return_id); edit.set_return_type(tys[0]); edit.delete_node(old_return_id) diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 5a429e14..c4cf21a9 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -14,6 +14,7 @@ pub mod outline; pub mod pass; pub mod phi_elim; pub mod pred; +pub mod schedule; pub mod sroa; pub mod unforkify; pub mod utils; @@ -32,6 +33,7 @@ pub use crate::outline::*; pub use crate::pass::*; pub use crate::phi_elim::*; pub use crate::pred::*; +pub use crate::schedule::*; pub use crate::sroa::*; pub use crate::unforkify::*; pub use crate::utils::*; diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index eb8d386c..70062bbc 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -201,9 +201,11 @@ pub fn outline( .chain(callee_succ_return_idx.map(|_| u32_ty)) .collect(), )), - nodes: vec![], num_dynamic_constants: edit.get_num_dynamic_constant_params(), entry: false, + nodes: vec![], + schedules: vec![], + device: None, }; // Re-number nodes in the partition. @@ -413,6 +415,13 @@ pub fn outline( }); } + // Copy the schedules into the new callee. + outlined.schedules.resize(outlined.nodes.len(), vec![]); + for id in partition.iter() { + let callee_id = convert_id(*id); + outlined.schedules[callee_id.idx()] = edit.get_schedule(*id).clone(); + } + // Step 3: edit the original function to call the outlined function. let dynamic_constants = (0..edit.get_num_dynamic_constant_params()) .map(|idx| edit.add_dynamic_constant(DynamicConstant::Parameter(idx as usize))) diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 006bd371..90556612 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -38,6 +38,7 @@ pub enum Pass { DeleteUncalled, ForkSplit, Unforkify, + InferSchedules, Verify, // Parameterized over whether analyses that aid visualization are necessary. // Useful to set to false if displaying a potentially broken module. @@ -74,9 +75,6 @@ pub struct PassManager { pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub bbs: Option<Vec<Vec<NodeID>>>, pub callgraph: Option<CallGraph>, - - // Current plan. 
- pub plans: Option<Vec<Plan>>, } impl PassManager { @@ -98,7 +96,6 @@ impl PassManager { data_nodes_in_fork_joins: None, bbs: None, callgraph: None, - plans: None, } } @@ -324,54 +321,6 @@ impl PassManager { } } - pub fn set_plans(&mut self, plans: Vec<Plan>) { - self.plans = Some(plans); - } - - pub fn make_plans(&mut self) { - if self.plans.is_none() { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_fork_join_maps(); - self.make_fork_join_nests(); - self.make_bbs(); - let def_uses = self.def_uses.as_ref().unwrap().iter(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); - let bbs = self.bbs.as_ref().unwrap().iter(); - self.plans = Some( - zip( - self.module.functions.iter(), - zip( - def_uses, - zip( - reverse_postorders, - zip(fork_join_maps, zip(fork_join_nests, bbs)), - ), - ), - ) - .map( - |( - function, - (def_use, (reverse_postorder, (fork_join_map, (fork_join_nest, bb)))), - )| { - default_plan( - function, - &self.module.dynamic_constants, - def_use, - reverse_postorder, - fork_join_map, - fork_join_nest, - bb, - ) - }, - ) - .collect(), - ); - } - } - pub fn run_passes(&mut self) { for pass in self.passes.clone().iter() { match pass { @@ -397,21 +346,13 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } Pass::InterproceduralSROA => { self.make_def_uses(); self.make_typing(); - let mut plans = self.plans.as_mut(); let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); let dynamic_constants_ref = @@ -438,21 +379,12 @@ impl PassManager { interprocedural_sroa(&mut editors); - let function_edits: Vec<_> = - editors.into_iter().map(|e| e.edits()).enumerate().collect(); - self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - for (idx, edits) in function_edits { - if let Some(plans) = plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], &edits); - } - let grave_mapping = &self.module.functions[idx].delete_gravestones(); - if let Some(plans) = plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.clear_analyses(); @@ -481,14 +413,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -514,14 +439,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if 
let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -539,7 +457,6 @@ impl PassManager { &loops[idx], ) } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::PhiElim => { @@ -564,14 +481,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -588,7 +498,6 @@ impl PassManager { &def_uses[idx], ) } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::Predication => { @@ -596,12 +505,10 @@ impl PassManager { self.make_reverse_postorders(); self.make_doms(); self.make_fork_join_maps(); - self.make_plans(); let def_uses = self.def_uses.as_ref().unwrap(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let doms = self.doms.as_ref().unwrap(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let plans = self.plans.as_ref().unwrap(); for idx in 0..self.module.functions.len() { predication( &mut self.module.functions[idx], @@ -609,10 +516,8 @@ impl PassManager { &reverse_postorders[idx], &doms[idx], &fork_join_maps[idx], - &plans[idx].schedules, ) } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::SROA => { @@ -641,14 +546,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -673,21 +571,14 @@ impl PassManager { ) }) .collect(); - // Inlining is special in that it may modify partitions in a - // inter-procedural fashion. 
- inline(&mut editors, callgraph, self.plans.as_mut()); + inline(&mut editors, callgraph); + self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits: Vec<_> = editors.into_iter().map(|editor| editor.edits()).collect(); - for idx in 0..edits.len() { - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], &edits[idx]); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.clear_analyses(); } @@ -715,11 +606,6 @@ impl PassManager { collapse_returns(editor); ensure_between_control_flow(editor); } - let mut edits: Vec<_> = editors - .into_iter() - .enumerate() - .map(|(idx, editor)| (idx, editor.edits())) - .collect(); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); @@ -766,27 +652,9 @@ impl PassManager { self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - edits.extend( - editors - .into_iter() - .enumerate() - .map(|(idx, editor)| (idx, editor.edits())), - ); - for (func_idx, edit) in edits { - if let Some(plans) = self.plans.as_mut() { - repair_plan( - &mut plans[func_idx], - &self.module.functions[func_idx], - &edit, - ); - } - } - for idx in 0..self.module.functions.len() { - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.module.functions.extend(new_funcs); self.clear_analyses(); @@ -821,15 +689,8 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits: Vec<_> = editors.into_iter().map(|editor| editor.edits()).collect(); - for idx in 0..edits.len() { - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], &edits[idx]); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.fix_deleted_functions(&new_idx); @@ -863,14 +724,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -898,14 +752,38 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - 
plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); + } + self.clear_analyses(); + } + Pass::InferSchedules => { + self.make_def_uses(); + self.make_fork_join_maps(); + let def_uses = self.def_uses.as_ref().unwrap(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); + for idx in 0..self.module.functions.len() { + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( + &mut self.module.functions[idx], + &constants_ref, + &dynamic_constants_ref, + &types_ref, + &def_uses[idx], + ); + infer_parallel_reduce(&mut editor, &fork_join_maps[idx]); + infer_parallel_fork(&mut editor, &fork_join_maps[idx]); + infer_vectorizable(&mut editor, &fork_join_maps[idx]); + infer_associative(&mut editor); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -930,17 +808,6 @@ impl PassManager { self.doms = Some(doms); self.postdoms = Some(postdoms); self.fork_join_maps = Some(fork_join_maps); - - // Verify the plan, if it exists. - if let Some(plans) = &self.plans { - for idx in 0..self.module.functions.len() { - plans[idx].verify_partitioning( - &self.module.functions[idx], - &self.def_uses.as_ref().unwrap()[idx], - &self.fork_join_maps.as_ref().unwrap()[idx], - ); - } - } } Pass::Xdot(force_analyses) => { self.make_reverse_postorders(); @@ -955,7 +822,6 @@ impl PassManager { self.doms.as_ref(), self.fork_join_maps.as_ref(), self.bbs.as_ref(), - self.plans.as_ref(), ); } Pass::Codegen(output_dir, module_name) => { @@ -1064,21 +930,6 @@ impl PassManager { } } - fn legacy_repair_plan(&mut self) { - // Cleanup the module after passes. Delete gravestone nodes. Repair the - // plans. - for idx in 0..self.module.functions.len() { - let grave_mapping = self.module.functions[idx].delete_gravestones(); - let plans = &mut self.plans; - let functions = &self.module.functions; - if let Some(plans) = plans.as_mut() { - take_mut::take(&mut plans[idx], |plan| { - plan.repair(&functions[idx], &grave_mapping) - }); - } - } - } - fn clear_analyses(&mut self) { self.def_uses = None; self.reverse_postorders = None; @@ -1092,8 +943,6 @@ impl PassManager { self.antideps = None; self.bbs = None; self.callgraph = None; - - // Don't clear the plan - this is repaired, not reconstructed. } pub fn get_module(self) -> Module { diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs index f8478ded..09d9753d 100644 --- a/hercules_opt/src/pred.rs +++ b/hercules_opt/src/pred.rs @@ -10,7 +10,6 @@ use self::bitvec::prelude::*; use self::hercules_ir::def_use::*; use self::hercules_ir::dom::*; use self::hercules_ir::ir::*; -use self::hercules_ir::schedule::*; /* * Top level function to convert acyclic control flow in vectorized fork-joins @@ -22,7 +21,6 @@ pub fn predication( reverse_postorder: &Vec<NodeID>, dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, - schedules: &Vec<Vec<Schedule>>, ) { // Detect forks with vectorize schedules. 
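// Vectorizable is a per-node annotation stored on the function itself,
// typically attached by infer_vectorizable in hercules_opt::schedule.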
let vector_forks: Vec<_> = function
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
new file mode 100644
index 00000000..b65b4d04
--- /dev/null
+++ b/hercules_opt/src/schedule.rs
@@ -0,0 +1,175 @@
+extern crate hercules_ir;
+
+use std::collections::HashMap;
+
+use self::hercules_ir::def_use::*;
+use self::hercules_ir::ir::*;
+
+use crate::*;
+
+/*
+ * Infer parallel fork-joins. These are fork-joins with only parallel reduction
+ * variables.
+ */
+pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+ for id in editor.node_ids() {
+ let func = editor.func();
+ let Node::Fork {
+ control: _,
+ factors: _,
+ } = func.nodes[id.idx()]
+ else {
+ continue;
+ };
+ let join_id = fork_join_map[&id];
+ let all_parallel_reduce = editor.get_users(join_id).all(|user| {
+ func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
+ || func.nodes[user.idx()].is_control()
+ });
+ if all_parallel_reduce {
+ editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
+ }
+ }
+}
+
+/*
+ * Infer parallel reductions consisting of a simple cycle between a Reduce node
+ * and a Write node, where indices of the Write are position indices using the
+ * ThreadID nodes attached to the corresponding Fork. This procedure also adds
+ * the ParallelReduce schedule to Reduce nodes reducing over a parallelized
+ * Reduce, as long as the base Write node also has position indices of the
+ * ThreadID of the outer fork. In other words, the complete Reduce chain is
+ * annotated with ParallelReduce, as long as each ThreadID dimension appears in
+ * the positional indexing of the original Write.
+ */
+pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+ for id in editor.node_ids() {
+ let func = editor.func();
+ if !func.nodes[id.idx()].is_reduce() {
+ continue;
+ }
+
+ let mut first_control = None;
+ let mut last_reduce = id;
+ let mut chain_id = id;
+
+ // Walk down the Reduce chain until we reach the Reduce potentially looping
+ // with the Write. Note the control node of the first Reduce, since this
+ // will tell us which Thread ID to look for in the Write.
+ while let Node::Reduce {
+ control,
+ init: _,
+ reduct,
+ } = func.nodes[chain_id.idx()]
+ {
+ if first_control.is_none() {
+ first_control = Some(control);
+ }
+
+ last_reduce = chain_id;
+ chain_id = reduct;
+ }
+
+ // Check for a Write-Reduce tight cycle.
+ if let Node::Write {
+ collect,
+ data: _,
+ indices,
+ } = &func.nodes[chain_id.idx()]
+ && *collect == last_reduce
+ {
+ // If there is a Write-Reduce tight cycle, get the position indices.
+ let positions = indices
+ .iter()
+ .filter_map(|index| {
+ if let Index::Position(indices) = index {
+ Some(indices)
+ } else {
+ None
+ }
+ })
+ .flat_map(|pos| pos.iter());
+
+ // Get the Forks corresponding to uses of bare ThreadIDs.
+ let fork_thread_id_pairs = positions.filter_map(|id| {
+ if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
+ Some((control, dimension))
+ } else {
+ None
+ }
+ });
+ let mut forks = HashMap::<NodeID, Vec<usize>>::new();
+ for (fork, dim) in fork_thread_id_pairs {
+ forks.entry(fork).or_default().push(dim);
+ }
+
+ // Check if one of the Forks corresponds to the Join associated with
+ // the Reduce being considered and has all of its dimensions
+ // represented in the indexing.
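+ // Sorting and deduplicating the dimension list first means a simple
+ // length comparison against the number of fork dimensions suffices.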
+ let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { + rep_dims.sort(); + rep_dims.dedup(); + fork_join_map[&id] == first_control.unwrap() + && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() + }); + + if is_parallel { + editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce)); + } + } + } +} + +/* + * Infer vectorizable fork-joins. Just check that there are no control nodes + * between a fork and its join and the factor is a constant. + */ +pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { + for id in editor.node_ids() { + let func = editor.func(); + if !func.nodes[id.idx()].is_join() { + continue; + } + + let u = get_uses(&func.nodes[id.idx()]).as_ref()[0]; + if let Some(join) = fork_join_map.get(&u) + && *join == id + { + let factors = func.nodes[u.idx()].try_fork().unwrap().1; + if factors.len() == 1 + && evaluate_dynamic_constant(factors[0], &editor.get_dynamic_constants()).is_some() + { + editor.edit(|edit| edit.add_schedule(u, Schedule::Vectorizable)); + } + } + } +} + +/* + * Infer associative reduction loops. + */ +pub fn infer_associative(editor: &mut FunctionEditor) { + let is_associative = |op| match op { + BinaryOperator::Add + | BinaryOperator::Mul + | BinaryOperator::Or + | BinaryOperator::And + | BinaryOperator::Xor => true, + _ => false, + }; + + for id in editor.node_ids() { + let func = editor.func(); + if let Node::Reduce { + control: _, + init: _, + reduct, + } = func.nodes[id.idx()] + && let Node::Binary { left, right, op } = func.nodes[reduct.idx()] + && (left == id || right == id) + && is_associative(op) + { + editor.edit(|edit| edit.add_schedule(id, Schedule::Associative)); + } + } +} diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 67c904ff..a73ecb2b 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -100,10 +100,11 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let control = *control; let new_data = reconstruct_product(editor, types[&data], *data, &mut product_nodes); editor.edit(|mut edit| { - edit.add_node(Node::Return { + let new_return = edit.add_node(Node::Return { control, data: new_data, }); + edit.sub_edit(node, new_return); edit.delete_node(node) }); } @@ -145,6 +146,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: dynamic_constants, args: new_args.into(), }); + edit.sub_edit(node, new_call); let edit = edit.replace_all_uses(node, new_call)?; let edit = edit.delete_node(node)?; @@ -344,7 +346,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } if ready { - fields.zip_list(data_fields).for_each(|idx, (res, data)| { + fields.zip_list(data_fields).for_each(|_, (res, data)| { to_insert.insert( res.idx(), Node::Phi { @@ -374,7 +376,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: (field_map.get(&init), field_map.get(&reduct)) { fields.zip(init_fields).zip(reduct_fields).for_each( - |idx, ((res, init), reduct)| { + |_, ((res, init), reduct)| { to_insert.insert( res.idx(), Node::Reduce { @@ -409,7 +411,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: fields .zip(thn_fields) .zip(els_fields) - .for_each(|idx, ((res, thn), els)| { + .for_each(|_, ((res, thn), els)| { to_insert.insert( res.idx(), Node::Ternary { @@ -458,7 +460,10 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: None => new, }; - editor.edit(|edit| 
edit.replace_all_uses(old, new)); + editor.edit(|mut edit| { + edit.sub_edit(old, new); + edit.replace_all_uses(old, new) + }); replaced_by.insert(old, new); let mut replaced = vec![]; diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index f31b7409..61c86a27 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -125,10 +125,14 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No edit = edit.replace_all_uses(*fork, proj_back_id)?; edit = edit.replace_all_uses(*join, proj_exit_id)?; + edit.sub_edit(*fork, region_id); + edit.sub_edit(*join, if_id); for tid in tids.iter() { + edit.sub_edit(*tid, indvar_id); edit = edit.replace_all_uses(*tid, indvar_id)?; } for (reduce, phi_id) in zip(reduces.iter(), phi_ids) { + edit.sub_edit(*reduce, phi_id); edit = edit.replace_all_uses(*reduce, phi_id)?; } diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index c39faef9..4ff5d9fc 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -120,6 +120,8 @@ pub fn compile_ir( ) -> Result<(), ErrorMessage> { let mut pm = match schedule { JunoSchedule::None => hercules_opt::pass::PassManager::new(module), + _ => todo!(), + /* JunoSchedule::DefaultSchedule => { let mut pm = hercules_opt::pass::PassManager::new(module); pm.make_plans(); @@ -143,6 +145,7 @@ pub fn compile_ir( } } } + */ }; if verify.verify() || verify.verify_all() { pm.add_pass(hercules_opt::pass::Pass::Verify); @@ -187,6 +190,7 @@ pub fn compile_ir( add_pass!(pm, verify, Outline); add_pass!(pm, verify, InterproceduralSROA); add_pass!(pm, verify, SROA); + add_pass!(pm, verify, InferSchedules); add_pass!(pm, verify, ForkSplit); add_pass!(pm, verify, Unforkify); add_pass!(pm, verify, GVN); diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 36ea79e9..7e558d6b 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -8,7 +8,6 @@ use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; use self::hercules_ir::ir::*; -use self::hercules_ir::schedule::*; mod parser; use crate::parser::lexer; @@ -47,6 +46,7 @@ pub enum LabeledStructure { Branch(NodeID), // If node } +/* pub fn schedule(module: &Module, info: FunctionMap, schedule: String) -> Result<Vec<Plan>, String> { if let Ok(mut file) = File::open(schedule) { let mut contents = String::new(); @@ -337,3 +337,4 @@ fn generate_schedule( .map(|(f, p)| (f, p.into())) .collect::<HashMap<_, _>>()) } +*/ -- GitLab
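A minimal sketch of how a pass composes the new editor APIs, assuming a
FunctionEditor `editor` and hypothetical node IDs `old` and `new` (the names
and the choice of schedule are illustrative only, not from this patch):

    let success = editor.edit(|mut edit| {
        // Record that `new` is a replacement for `old`.
        edit.sub_edit(old, new);
        // Rewrite uses, annotate the replacement, and delete the old node,
        // all within a single atomic edit.
        let edit = edit.replace_all_uses(old, new)?;
        let edit = edit.add_schedule(new, Schedule::ParallelFork)?;
        edit.delete_node(old)
    });

Here `success` is the boolean returned by `editor.edit`, as in the gvn change
above; the `?` operators hand the edit back unchanged if any step touches an
immutable node.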