From 702bf42ad8a1ebc155725ced49d3d3a8586a644d Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Thu, 26 Sep 2024 09:29:07 -0500
Subject: [PATCH] Fixes to backend to support matmul

---
 Cargo.lock                          |  11 ++
 Cargo.toml                          |   2 +-
 hercules_cg/Cargo.toml              |   1 +
 hercules_cg/src/cpu.rs              | 191 +++++++++++++++++-----------
 hercules_cg/src/lib.rs              |   2 +
 hercules_cg/src/sched_dot.rs        | 178 ++++++++++++++++++++++++++
 hercules_cg/src/sched_gen.rs        |  15 ++-
 hercules_cg/src/sched_ir.rs         | 107 +++++++++++++++-
 hercules_cg/src/sched_schedule.rs   |  57 ++++++++-
 hercules_ir/src/dot.rs              |   4 +-
 hercules_opt/src/pass.rs            |  30 +++++
 hercules_rt_proc/src/lib.rs         |   4 +-
 hercules_samples/matmul/matmul.hir  |   6 +-
 hercules_samples/matmul/src/main.rs |  14 +-
 14 files changed, 521 insertions(+), 101 deletions(-)
 create mode 100644 hercules_cg/src/sched_dot.rs

diff --git a/Cargo.lock b/Cargo.lock
index cc0a005b..6dfa8e9c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -588,6 +588,7 @@ dependencies = [
  "bitvec",
  "hercules_ir",
  "ordered-float",
+ "rand",
  "serde",
 ]
 
@@ -636,6 +637,16 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "hercules_matmul"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "clap",
+ "hercules_rt",
+ "rand",
+]
+
 [[package]]
 name = "hercules_opt"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 67d0cea4..31639dce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,6 @@ members = [
 	"juno_frontend",
 
 	"hercules_samples/dot",
-	#"hercules_samples/matmul",
+	"hercules_samples/matmul",
 	#"hercules_samples/task_parallel"
 ]
diff --git a/hercules_cg/Cargo.toml b/hercules_cg/Cargo.toml
index 8658a637..8a60956b 100644
--- a/hercules_cg/Cargo.toml
+++ b/hercules_cg/Cargo.toml
@@ -4,6 +4,7 @@ version = "0.1.0"
 authors = ["Russel Arbore <rarbore2@illinois.edu>"]
 
 [dependencies]
+rand = "*"
 ordered-float = "*"
 bitvec = "*"
 serde = { version = "*", features = ["derive"] }
diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 4f93707f..898f83d3 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -2,6 +2,7 @@ extern crate bitvec;
 
 use std::collections::{HashMap, VecDeque};
 use std::fmt::{Error, Write};
+use std::iter::once;
 
 use self::bitvec::prelude::*;
 
@@ -30,37 +31,22 @@ use crate::*;
  *    be identical, even though for both, one of v1 or v2 doesn't dominate the
  *    PartitionExit. What *should* happen here is that each PartitionExit gets
  *    lowered to an LLVM `ret`, where the non-dominating output is set to
- *    `undef` This works since in the original, un-partitioned, Hercules IR,
+ *    `undef`. This works since in the original, un-partitioned, Hercules IR,
  *    defs must dominate uses, so we won't run into a situation where a returned
  *    `undef` value is actually read. What happens currently is that the
  *    generated LLVM will `ret` `%v1` and `%v2`, which LLVM won't compile (since
  *    the code wouldn't be in SSA form). This should get fixed when we start
  *    compiling more complicated codes.
  *
- * 2. We're a bit, "loose", with how certain basic blocks involving fork-joins
- *    are handled. In particular, the following things need to be handled
- *    properly, but aren't yet:
- *    - Phis in a block with a reduce block predecessor need to generate LLVM
- *      phis that actually depend on the top parallel block of that fork-join,
- *      not the reduce block. This is because we hijack the top block to be the
- *      loop header, rather than use the reduce block to be the loop header.
- *    - The above also applies in the case of phis generated from thread IDs and
- *      reduction variables inside the top parallel block of a fork-join. This
- *      case occurs when there is a parallel-reduce section inside another
- *      parallel-reduce section.
- *    - The above also applies in the case of a parallel-reduce section
- *      immediately following another parallel-reduce section (a reduce block
- *      jumps to the top parallel block of another parallel-reduce section).
- *
- * 3. Handle >= 3D fork-joins and array accesses. This isn't conceptually
+ * 2. Handle >= 3D fork-joins and array accesses. This isn't conceptually
  *    difficult, but generating the LLVM code to implement these is annoying.
  *
- * 4. Handle ABI properly when taking in / returning structs taking moret han 16
+ * 3. Handle ABI properly when taking in / returning structs taking more than 16
  *    bytes. When a passed / returned struct takes more than 16 bytes, it needs
  *    to be passed around via pointers. This is one of many platform specific C
  *    ABI rules we need to handle to be properly called from Rust (that 16 byte
  *    rule is actually x86-64 specific). I'm honestly not sure how to handle
- *    this well. We avoid running into the manifestation of this problem by for
+ *    this well. We avoid running into the manifestation of this problem for
  *    some samples by removing unneeded parameters / return values from
  *    partitions at the schedule IR level, which we should do anyway, but this
  *    isn't a complete solution.
@@ -81,6 +67,61 @@ pub fn cpu_compile<W: Write>(
     let svalue_types = sched_svalue_types(function);
     let parallel_reduce_infos = sched_parallel_reduce_sections(function);
 
+    // Calculate the names of each block. For blocks that are the top or bottom
+    // blocks of sequential fork-joins, references outside the fork-join
+    // actually need to refer to the header block. This is a bit complicated to
+    // handle, and we use these names in several places, so pre-calculate the
+    // block names. Intuitively, if we are "inside" a sequential fork-join,
+    // references to the top or bottom blocks actually refer to those blocks,
+    // while if we are "outside" the sequential fork-join, references to either
+    // the top or bottom block actually refer to the loop header block.
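+    // For example (hypothetical indices): if block 4 is the top parallel block
+    // of sequential fork-join 1, then the name stored for (block 4, inside
+    // fork-join 1) is "bb_4", while the name stored for (block 4, outside any
+    // fork-join) is "fork_join_seq_header_1".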
+    let mut block_names = HashMap::new();
+    for (block_idx, block) in function.blocks.iter().enumerate() {
+        for fork_join_id in parallel_reduce_infos
+            .keys()
+            .map(|id| Some(*id))
+            .chain(once(None))
+        {
+            let block_id = BlockID::new(block_idx);
+            let possible_parent = block.kind.try_fork_join_id();
+            let mut walk = fork_join_id;
+
+            // Check if the location in the key is "inside" the location of the
+            // block.
+            let is_inside = if let Some(parent) = possible_parent {
+                loop {
+                    if let Some(step) = walk {
+                        if step == parent {
+                            // If we see the block's location, then the key
+                            // location is "inside".
+                            break true;
+                        } else {
+                            // Otherwise, walk up to the parent fork-join and
+                            // keep looking.
+                            walk = parallel_reduce_infos[&step].parent_fork_join_id;
+                        }
+                    } else {
+                        // If we don't find the block, then the key location is
+                        // "outside" the block's parallel-reduce.
+                        break false;
+                    }
+                }
+            } else {
+                // Every location is "inside" the top level sequential section.
+                true
+            };
+
+            if is_inside {
+                block_names.insert((block_id, fork_join_id), format!("bb_{}", block_idx));
+            } else {
+                block_names.insert(
+                    (block_id, fork_join_id),
+                    format!("fork_join_seq_header_{}", possible_parent.unwrap().idx()),
+                );
+            }
+        }
+    }
+
     // Generate a dummy uninitialized global - this is needed so that there'll
     // be a non-empty .bss section in the ELF object file.
     write!(
@@ -115,11 +156,8 @@ pub fn cpu_compile<W: Write>(
 
     // Emit the function body. Emit each block, one at a time.
     for (block_idx, block) in function.blocks.iter().enumerate() {
-        // Emit the header for the block.
-        write!(w, "bb_{}:\n", block_idx)?;
-
-        // For "tops" of sequential fork-joins, we hijack to the top block to be
-        // the loop header for the fork-join loop.
+        // For "tops" of sequential fork-joins, we emit a special top block to
+        // be the loop header for the fork-join loop.
         if let Some(fork_join_id) = block.kind.try_parallel()
             && parallel_reduce_infos[&fork_join_id]
                 .top_parallel_block
@@ -127,13 +165,22 @@ pub fn cpu_compile<W: Write>(
                 == block_idx
         {
             emit_fork_join_seq_header(
+                fork_join_id,
                 &parallel_reduce_infos[&fork_join_id],
                 &svalue_types,
+                &block_names,
                 block_idx,
                 w,
             )?;
         }
 
+        // Emit the header for the block.
+        write!(
+            w,
+            "{}:\n",
+            &block_names[&(BlockID::new(block_idx), block.kind.try_fork_join_id())]
+        )?;
+
         // For each basic block, emit instructions in that block. Emit using a
         // worklist over the dependency graph.
         let mut emitted = bitvec![u8, Lsb0; 0; block.insts.len()];
@@ -152,10 +199,12 @@ pub fn cpu_compile<W: Write>(
                 emit_inst(
                     block.virt_regs[inst_id.idx_1()].0,
                     &block.insts[inst_idx],
+                    block.kind.try_fork_join_id(),
                     block
                         .kind
                         .try_fork_join_id()
                         .map(|fork_join_id| &parallel_reduce_infos[&fork_join_id]),
+                    &block_names,
                     &svalue_types,
                     w,
                 )?;
@@ -271,7 +320,9 @@ fn emit_svalue<W: Write>(
 fn emit_inst<W: Write>(
     virt_reg: usize,
     inst: &SInst,
+    location: Option<ForkJoinID>,
     parallel_reduce_info: Option<&ParallelReduceInfo>,
+    block_names: &HashMap<(BlockID, Option<ForkJoinID>), String>,
     types: &HashMap<SValue, SType>,
     w: &mut W,
 ) -> Result<(), Error> {
@@ -295,17 +346,19 @@ fn emit_inst<W: Write>(
                     Some((pred_block_id, svalue)) => {
                         write!(w, "[ ")?;
                         emit_svalue(svalue, false, types, w)?;
-                        write!(w, ", %bb_{} ]", pred_block_id.idx())?;
+                        write!(w, ", %{} ]", &block_names[&(*pred_block_id, location)])?;
                         Ok(())
                     }
                     None => write!(w, ", "),
                 })
                 .collect::<Result<(), Error>>()?;
         }
-        SInst::ThreadID { dimension } => {
-            let block = parallel_reduce_info.unwrap().top_parallel_block;
+        SInst::ThreadID {
+            dimension,
+            fork_join,
+        } => {
             emit_assign(w)?;
-            write!(w, "add i64 0, %thread_id_{}_{}", block.idx(), dimension)?;
+            write!(w, "add i64 0, %thread_id_{}_{}", fork_join.idx(), dimension)?;
         }
         SInst::ReductionVariable { number } => {
             write!(w, "; Already emitted reduction variable #{number}.")?;
@@ -318,11 +371,11 @@ fn emit_inst<W: Write>(
             if reduce_exit.is_some() {
                 write!(
                     w,
-                    "br label %bb_{}",
-                    parallel_reduce_info.unwrap().top_parallel_block.idx()
+                    "br label %fork_join_seq_header_{}",
+                    location.unwrap().idx(),
                 )?;
             } else {
-                write!(w, "br label %bb_{}", target.idx())?;
+                write!(w, "br label %{}", &block_names[&(*target, location)])?;
             }
         }
         SInst::Branch {
@@ -334,9 +387,9 @@ fn emit_inst<W: Write>(
             emit_svalue(cond, true, types, w)?;
             write!(
                 w,
-                ", label %bb_{}, label %bb_{}",
-                true_target.idx(),
-                false_target.idx()
+                ", label %{}, label %{}",
+                &block_names[&(*true_target, location)],
+                &block_names[&(*false_target, location)],
             )?;
         }
         SInst::PartitionExit { data_outputs } => {
@@ -444,7 +497,7 @@ fn emit_inst<W: Write>(
         } => {
             emit_linear_index_calc(virt_reg, position, bounds, types, w)?;
             write!(w, "%store_ptr_{} = getelementptr ", virt_reg)?;
-            emit_type(&types[&self_svalue], w)?;
+            emit_type(&types[value], w)?;
             write!(w, ", ")?;
             emit_svalue(array, true, types, w)?;
             write!(w, ", i64 %calc_linear_idx_{}\n  ", virt_reg)?;
@@ -463,39 +516,33 @@ fn emit_inst<W: Write>(
  * Emit the loop header implementing a sequential fork-join.
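+ *
+ * Roughly, for a hypothetical 1D fork-join 0 whose top parallel block is bb_4,
+ * the emitted header looks like:
+ *
+ *   fork_join_seq_header_0:
+ *     %linear_4 = phi i64 [ 0, %<entry> ], [ %linear_4_inc, %<reduce> ]
+ *     <phis for any reduction variables>
+ *     <compute %bound_4 from the thread counts>
+ *     %thread_id_0_0 = add i64 0, %linear_4
+ *     %cond_4 = icmp ult i64 %linear_4, %bound_4
+ *     br i1 %cond_4, label %bb_4, label %<successor>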
  */
 fn emit_fork_join_seq_header<W: Write>(
+    fork_join_id: ForkJoinID,
     info: &ParallelReduceInfo,
     types: &HashMap<SValue, SType>,
+    block_names: &HashMap<(BlockID, Option<ForkJoinID>), String>,
     block_idx: usize,
     w: &mut W,
 ) -> Result<(), Error> {
+    // Start the header of the loop.
+    write!(w, "fork_join_seq_header_{}:\n", fork_join_id.idx())?;
+
     // Emit the phis for the linear loop index variable and the reduction
     // variables.
-    // TODO: handle these cases:
-    // 1. A parallel-reduce section is nested inside another parallel-reduce
-    //    section. If the predecessor block is itself a top parallel block, the
-    //    predecessor needs to be the _body block, not the original block.
-    // 2. A parallel-reduce section is immediately followed by another parallel-
-    //    reduce section. If the predecessor block is itself a reduce block, the
-    //    predecessor needs to be the top parallel block of the previous
-    //    parallel-reduce section, not its reduce block.
-    let entry_pred = info.predecessor;
-    let loop_pred = info.reduce_block;
+    let entry_name = &block_names[&(info.predecessor, Some(fork_join_id))];
+    let loop_name = &block_names[&(info.reduce_block, Some(fork_join_id))];
     write!(
         w,
-        "  %linear_{} = phi i64 [ 0, %bb_{} ], [ %linear_{}_inc, %bb_{} ]\n",
-        block_idx,
-        entry_pred.idx(),
-        block_idx,
-        loop_pred.idx(),
+        "  %linear_{} = phi i64 [ 0, %{} ], [ %linear_{}_inc, %{} ]\n",
+        block_idx, entry_name, block_idx, loop_name,
     )?;
     for (var_num, virt_reg) in info.reduction_variables.iter() {
         write!(w, "  %v{} = phi ", virt_reg)?;
         emit_type(&types[&SValue::VirtualRegister(*virt_reg)], w)?;
         write!(w, " [ ")?;
         emit_svalue(&info.reduce_inits[*var_num], false, types, w)?;
-        write!(w, ", %bb_{} ], [ ", entry_pred.idx())?;
+        write!(w, ", %{} ], [ ", entry_name)?;
         emit_svalue(&info.reduce_reducts[*var_num], false, types, w)?;
-        write!(w, ", %bb_{} ]\n", loop_pred.idx())?;
+        write!(w, ", %{} ]\n", loop_name)?;
     }
 
     // Calculate the loop bounds.
@@ -513,42 +560,28 @@ fn emit_fork_join_seq_header<W: Write>(
         todo!("TODO: Handle the 3 or more dimensional fork-join case.")
     }
 
-    // Emit the branch.
-    write!(
-        w,
-        "  %cond_{} = icmp ult i64 %linear_{}, %bound_{}\n",
-        block_idx, block_idx, block_idx
-    )?;
-    write!(
-        w,
-        "  br i1 %cond_{}, label %bb_{}_body, label %bb_{}\n",
-        block_idx,
-        block_idx,
-        info.successor.idx()
-    )?;
-
-    // Start the body of the loop.
-    write!(w, "bb_{}_body:\n", block_idx)?;
-
     // Calculate the multi-dimensional thread indices.
     if info.thread_counts.len() == 1 {
         write!(
             w,
             "  %thread_id_{}_0 = add i64 0, %linear_{}\n",
-            block_idx, block_idx
+            fork_join_id.idx(),
+            block_idx
         )?;
     } else if info.thread_counts.len() == 2 {
         write!(
             w,
             "  %thread_id_{}_0 = udiv i64 %linear_{}, ",
-            block_idx, block_idx
+            fork_join_id.idx(),
+            block_idx
         )?;
         emit_svalue(&info.thread_counts[1], false, types, w)?;
         write!(w, "\n")?;
         write!(
             w,
             "  %thread_id_{}_1 = urem i64 %linear_{}, ",
-            block_idx, block_idx
+            fork_join_id.idx(),
+            block_idx
         )?;
         emit_svalue(&info.thread_counts[1], false, types, w)?;
         write!(w, "\n")?;
@@ -563,6 +596,20 @@ fn emit_fork_join_seq_header<W: Write>(
         block_idx, block_idx
     )?;
 
+    // Emit the branch.
+    write!(
+        w,
+        "  %cond_{} = icmp ult i64 %linear_{}, %bound_{}\n",
+        block_idx, block_idx, block_idx
+    )?;
+    let top_name = &block_names[&(BlockID::new(block_idx), Some(fork_join_id))];
+    let succ_name = &block_names[&(info.successor, Some(fork_join_id))];
+    write!(
+        w,
+        "  br i1 %cond_{}, label %{}, label %{}\n",
+        block_idx, top_name, succ_name
+    )?;
+
     Ok(())
 }
 
diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index b9a67eba..3ecb506a 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -2,12 +2,14 @@
 
 pub mod cpu;
 pub mod manifest;
+pub mod sched_dot;
 pub mod sched_gen;
 pub mod sched_ir;
 pub mod sched_schedule;
 
 pub use crate::cpu::*;
 pub use crate::manifest::*;
+pub use crate::sched_dot::*;
 pub use crate::sched_gen::*;
 pub use crate::sched_ir::*;
 pub use crate::sched_schedule::*;
diff --git a/hercules_cg/src/sched_dot.rs b/hercules_cg/src/sched_dot.rs
new file mode 100644
index 00000000..e21ca9c5
--- /dev/null
+++ b/hercules_cg/src/sched_dot.rs
@@ -0,0 +1,178 @@
+extern crate bitvec;
+extern crate rand;
+
+use std::collections::{HashMap, VecDeque};
+use std::env::temp_dir;
+use std::fmt::Write;
+use std::fs::File;
+use std::io::Write as _;
+use std::iter::zip;
+use std::process::Command;
+
+use self::bitvec::prelude::*;
+
+use self::rand::Rng;
+
+use crate::*;
+
+/*
+ * Top level function to compute a dot graph for a schedule IR module, and
+ * immediately render it using xdot.
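+ * Invoked by Pass::SchedXdot in hercules_opt.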
+ */
+pub fn xdot_sched_module(module: &SModule) {
+    let mut tmp_path = temp_dir();
+    let mut rng = rand::thread_rng();
+    let num: u64 = rng.gen();
+    tmp_path.push(format!("sched_dot_{}.dot", num));
+    let mut file = File::create(tmp_path.clone()).expect("PANIC: Unable to open output file.");
+    let mut contents = String::new();
+    write_dot(&module, &mut contents).expect("PANIC: Unable to generate output file contents.");
+    file.write_all(contents.as_bytes())
+        .expect("PANIC: Unable to write output file contents.");
+    Command::new("xdot")
+        .args([tmp_path])
+        .output()
+        .expect("PANIC: Couldn't execute xdot. Is xdot installed?");
+}
+
+/*
+ * Top level function to write a schedule IR module out as a dot graph.
+ */
+pub fn write_dot<W: Write>(module: &SModule, w: &mut W) -> std::fmt::Result {
+    write_digraph_header(w)?;
+
+    for (function_name, function) in module.functions.iter() {
+        // Schedule the SFunction to form a linear ordering of instructions.
+        let dep_graph = sched_dependence_graph(function);
+        let mut block_to_inst_list = (0..function.blocks.len())
+            .map(|block_idx| (block_idx, vec![]))
+            .collect::<HashMap<usize, Vec<(&SInst, usize)>>>();
+        for (block_idx, block) in function.blocks.iter().enumerate() {
+            let mut emitted = bitvec![u8, Lsb0; 0; block.insts.len()];
+            let mut worklist = VecDeque::from((0..block.insts.len()).collect::<Vec<_>>());
+            while let Some(inst_idx) = worklist.pop_front() {
+                let inst_id = InstID::new(block_idx, inst_idx);
+                let dependencies = &dep_graph[&inst_id];
+                let all_uses_emitted = dependencies
+                    .into_iter()
+                    // Check that all used instructions in this block...
+                    .filter(|inst_id| inst_id.idx_0() == block_idx)
+                    // were already emitted.
+                    .all(|inst_id| emitted[inst_id.idx_1()]);
+                // Phis don't need to wait for all of their uses to be added.
+                if block.insts[inst_idx].is_phi() || all_uses_emitted {
+                    block_to_inst_list
+                        .get_mut(&block_idx)
+                        .unwrap()
+                        .push((&block.insts[inst_idx], block.virt_regs[inst_idx].0));
+                    emitted.set(inst_id.idx_1(), true);
+                } else {
+                    worklist.push_back(inst_idx);
+                }
+            }
+        }
+
+        // An SFunction is a subgraph.
+        write_subgraph_header(function_name, w)?;
+
+        // Each SBlock is a nested subgraph.
+        for (block_idx, block) in function.blocks.iter().enumerate() {
+            write_block_header(function_name, block_idx, "lightblue", w)?;
+
+            // Emit the instructions in scheduled order.
+            write_block(function_name, block_idx, &block_to_inst_list[&block_idx], w)?;
+
+            write_graph_footer(w)?;
+
+            // Add control edges.
+            for succ in block.successors().as_ref() {
+                write_control_edge(function_name, block_idx, succ.idx(), w)?;
+            }
+        }
+
+        write_graph_footer(w)?;
+    }
+
+    write_graph_footer(w)?;
+    Ok(())
+}
+
+fn write_digraph_header<W: Write>(w: &mut W) -> std::fmt::Result {
+    write!(w, "digraph \"Module\" {{\n")?;
+    write!(w, "compound=true\n")?;
+    Ok(())
+}
+
+fn write_subgraph_header<W: Write>(function_name: &SFunctionName, w: &mut W) -> std::fmt::Result {
+    write!(w, "subgraph {} {{\n", function_name)?;
+    write!(w, "label=\"{}\"\n", function_name)?;
+    write!(w, "bgcolor=ivory4\n")?;
+    write!(w, "cluster=true\n")?;
+    Ok(())
+}
+
+fn write_block_header<W: Write>(
+    function_name: &SFunctionName,
+    block_idx: usize,
+    color: &str,
+    w: &mut W,
+) -> std::fmt::Result {
+    write!(w, "subgraph {}_block_{} {{\n", function_name, block_idx)?;
+    write!(w, "label=\"\"\n")?;
+    write!(w, "style=rounded\n")?;
+    write!(w, "bgcolor={}\n", color)?;
+    write!(w, "cluster=true\n")?;
+    Ok(())
+}
+
+fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result {
+    write!(w, "}}\n")?;
+    Ok(())
+}
+
+fn write_block<W: Write>(
+    function_name: &SFunctionName,
+    block_idx: usize,
+    insts: &[(&SInst, usize)],
+    w: &mut W,
+) -> std::fmt::Result {
+    write!(
+        w,
+        "{}_{} [xlabel={}, label=\"{{",
+        function_name, block_idx, block_idx
+    )?;
+    for token in insts.into_iter().map(|token| Some(token)).intersperse(None) {
+        match token {
+            Some((inst, virt_reg)) => {
+                write!(w, "%{} = {}(", virt_reg, inst.upper_case_name())?;
+                for token in sched_get_uses(inst).map(|u| Some(u)).intersperse(None) {
+                    match token {
+                        Some(SValue::VirtualRegister(use_virt_reg)) => {
+                            write!(w, "%{}", use_virt_reg)?
+                        }
+                        Some(SValue::Constant(scons)) => write!(w, "{:?}", scons)?,
+                        None => write!(w, ", ")?,
+                    }
+                }
+                write!(w, ")")?;
+            }
+            None => write!(w, " | ")?,
+        }
+    }
+    write!(w, "}}\", shape = \"record\"];\n")?;
+    Ok(())
+}
+
+fn write_control_edge<W: Write>(
+    function_name: &SFunctionName,
+    src: usize,
+    dst: usize,
+    w: &mut W,
+) -> std::fmt::Result {
+    write!(
+        w,
+        "{}_{} -> {}_{} [color=\"black\"];\n",
+        function_name, src, function_name, dst
+    )?;
+    Ok(())
+}
diff --git a/hercules_cg/src/sched_gen.rs b/hercules_cg/src/sched_gen.rs
index c9696801..98d2a202 100644
--- a/hercules_cg/src/sched_gen.rs
+++ b/hercules_cg/src/sched_gen.rs
@@ -711,6 +711,7 @@ impl<'a> FunctionContext<'a> {
                     id,
                     &control_id_to_block_id,
                     &data_id_to_svalue,
+                    &fork_node_id_to_fork_join_id,
                     &mut blocks,
                     partition_idx,
                     manifest,
@@ -772,6 +773,7 @@ impl<'a> FunctionContext<'a> {
         id: NodeID,
         control_id_to_block_id: &HashMap<NodeID, BlockID>,
         data_id_to_svalue: &HashMap<NodeID, SValue>,
+        fork_node_id_to_fork_join_id: &HashMap<NodeID, ForkJoinID>,
         blocks: &mut Vec<SBlock>,
         partition_idx: usize,
         manifest: &Manifest,
@@ -804,7 +806,7 @@ impl<'a> FunctionContext<'a> {
                     None
                 };
                 let reduce_exit = if self.function.nodes[id.idx()].is_join() {
-                    Some(self.compile_reduce_exit(dst, data_id_to_svalue))
+                    Some(self.compile_reduce_exit(id, data_id_to_svalue))
                 } else {
                     None
                 };
@@ -969,11 +971,12 @@ impl<'a> FunctionContext<'a> {
                 }
             }
 
-            Node::ThreadID {
-                control: _,
-                dimension,
-            } => {
-                block.insts.push(SInst::ThreadID { dimension });
+            Node::ThreadID { control, dimension } => {
+                let fork_join = fork_node_id_to_fork_join_id[&control];
+                block.insts.push(SInst::ThreadID {
+                    dimension,
+                    fork_join,
+                });
                 block
                     .virt_regs
                     .push((self_virt_reg(), SType::UnsignedInteger64));
diff --git a/hercules_cg/src/sched_ir.rs b/hercules_cg/src/sched_ir.rs
index 56a2bb51..3db4a9e1 100644
--- a/hercules_cg/src/sched_ir.rs
+++ b/hercules_cg/src/sched_ir.rs
@@ -260,6 +260,7 @@ pub enum SInst {
     },
     ThreadID {
         dimension: usize,
+        fork_join: ForkJoinID,
     },
     ReductionVariable {
         number: usize,
@@ -359,9 +360,13 @@ impl SInst {
         self.is_jump() || self.is_partition_exit() || self.is_return()
     }
 
-    pub fn try_thread_id(&self) -> Option<usize> {
-        if let SInst::ThreadID { dimension } = self {
-            Some(*dimension)
+    pub fn try_thread_id(&self) -> Option<(usize, ForkJoinID)> {
+        if let SInst::ThreadID {
+            dimension,
+            fork_join,
+        } = self
+        {
+            Some((*dimension, *fork_join))
         } else {
             None
         }
@@ -403,6 +408,61 @@ impl SInst {
             _ => BlockSuccessors::Zero,
         }
     }
+
+    pub fn upper_case_name(&self) -> &'static str {
+        match self {
+            SInst::Phi { inputs: _ } => "Phi",
+            SInst::ThreadID {
+                dimension: _,
+                fork_join: _,
+            } => "ThreadID",
+            SInst::ReductionVariable { number: _ } => "ReductionVariable",
+            SInst::Jump {
+                target: _,
+                parallel_entry: _,
+                reduce_exit: _,
+            } => "Jump",
+            SInst::Branch {
+                cond: _,
+                false_target: _,
+                true_target: _,
+            } => "Branch",
+            SInst::PartitionExit { data_outputs: _ } => "PartitionExit",
+            SInst::Return { value: _ } => "Return",
+            SInst::Unary { input: _, op } => op.upper_case_name(),
+            SInst::Binary {
+                left: _,
+                right: _,
+                op,
+            } => op.upper_case_name(),
+            SInst::Ternary {
+                first: _,
+                second: _,
+                third: _,
+                op,
+            } => op.upper_case_name(),
+            SInst::ProductExtract {
+                product: _,
+                indices: _,
+            } => "ProductExtract",
+            SInst::ProductInsert {
+                product: _,
+                data: _,
+                indices: _,
+            } => "ProductInsert",
+            SInst::ArrayLoad {
+                array: _,
+                position: _,
+                bounds: _,
+            } => "ArrayLoad",
+            SInst::ArrayStore {
+                array: _,
+                value: _,
+                position: _,
+                bounds: _,
+            } => "ArrayStore",
+        }
+    }
 }
 
 #[derive(Debug, PartialEq, Eq)]
@@ -452,6 +512,16 @@ pub enum SUnaryOperator {
     Cast(SType),
 }
 
+impl SUnaryOperator {
+    pub fn upper_case_name(&self) -> &'static str {
+        match self {
+            SUnaryOperator::Not => "Not",
+            SUnaryOperator::Neg => "Neg",
+            SUnaryOperator::Cast(_) => "Cast",
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum SBinaryOperator {
     Add,
@@ -472,11 +542,42 @@ pub enum SBinaryOperator {
     RSh,
 }
 
+impl SBinaryOperator {
+    pub fn upper_case_name(&self) -> &'static str {
+        match self {
+            SBinaryOperator::Add => "Add",
+            SBinaryOperator::Sub => "Sub",
+            SBinaryOperator::Mul => "Mul",
+            SBinaryOperator::Div => "Div",
+            SBinaryOperator::Rem => "Rem",
+            SBinaryOperator::LT => "LT",
+            SBinaryOperator::LTE => "LTE",
+            SBinaryOperator::GT => "GT",
+            SBinaryOperator::GTE => "GTE",
+            SBinaryOperator::EQ => "EQ",
+            SBinaryOperator::NE => "NE",
+            SBinaryOperator::Or => "Or",
+            SBinaryOperator::And => "And",
+            SBinaryOperator::Xor => "Xor",
+            SBinaryOperator::LSh => "LSh",
+            SBinaryOperator::RSh => "RSh",
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum STernaryOperator {
     Select,
 }
 
+impl STernaryOperator {
+    pub fn upper_case_name(&self) -> &'static str {
+        match self {
+            STernaryOperator::Select => "Select",
+        }
+    }
+}
+
 pub type SFunctionName = String;
 
 define_id_type!(ArrayID);
diff --git a/hercules_cg/src/sched_schedule.rs b/hercules_cg/src/sched_schedule.rs
index ac4583df..ee16f5bd 100644
--- a/hercules_cg/src/sched_schedule.rs
+++ b/hercules_cg/src/sched_schedule.rs
@@ -9,7 +9,10 @@ use crate::*;
 pub fn sched_get_uses(inst: &SInst) -> Box<dyn Iterator<Item = &SValue> + '_> {
     match inst {
         SInst::Phi { inputs } => Box::new(inputs.iter().map(|(_, svalue)| svalue)),
-        SInst::ThreadID { dimension: _ } => Box::new(empty()),
+        SInst::ThreadID {
+            dimension: _,
+            fork_join: _,
+        } => Box::new(empty()),
         SInst::ReductionVariable { number: _ } => Box::new(empty()),
         SInst::Jump {
             target: _,
@@ -214,6 +217,7 @@ pub fn sched_svalue_types(function: &SFunction) -> HashMap<SValue, SType> {
 /*
  * Analysis information for one fork-join.
  */
+#[derive(Debug)]
 pub struct ParallelReduceInfo {
     // The block that jumps into the parallel section.
     pub predecessor: BlockID,
@@ -240,6 +244,11 @@ pub struct ParallelReduceInfo {
     // Map from reduction variable number to virtual register of the
     // corresponding reduction variable instruction.
     pub reduction_variables: HashMap<usize, usize>,
+
+    // If this parallel-reduce section is inside another parallel-reduce, store
+    // the parent's ForkJoinID. Parallel-reduce sections in an SFunction form a
+    // forest.
+    pub parent_fork_join_id: Option<ForkJoinID>,
 }
 
 /*
@@ -275,14 +284,18 @@ pub fn sched_parallel_reduce_sections(
                     .try_parallel()
                     .unwrap();
 
-                // Traverse the blocks until finding a jump to a reduce block.
+                // Traverse the blocks until finding a jump to the corresponding
+                // reduce block.
                 let mut queue = VecDeque::from(vec![top_parallel_block]);
                 let mut visited = HashSet::new();
                 visited.insert(top_parallel_block);
                 let mut bfs_dest = None;
                 while let Some(bfs) = queue.pop_front() {
                     for succ in function.blocks[bfs.idx()].successors().as_ref() {
-                        if function.blocks[succ.idx()].kind.try_reduce().is_some() {
+                        if let Some(reduce_fork_join_id) =
+                            function.blocks[succ.idx()].kind.try_reduce()
+                            && reduce_fork_join_id == fork_join_id
+                        {
                             bfs_dest = Some((bfs, *succ));
                         } else if !visited.contains(succ) {
                             queue.push_back(*succ);
@@ -310,7 +323,9 @@ pub fn sched_parallel_reduce_sections(
                         function.blocks[parallel_block.idx()].insts.iter(),
                         function.blocks[parallel_block.idx()].virt_regs.iter(),
                     ) {
-                        if let Some(dim) = inst.try_thread_id() {
+                        if let Some((dim, tid_fork_join)) = inst.try_thread_id()
+                            && tid_fork_join == fork_join_id
+                        {
                             thread_ids.get_mut(&dim).unwrap().push(*virt_reg);
                         }
                     }
@@ -342,11 +357,45 @@ pub fn sched_parallel_reduce_sections(
 
                     thread_ids,
                     reduction_variables,
+
+                    parent_fork_join_id: None,
                 };
                 result.insert(fork_join_id, info);
             }
         }
     }
 
+    // Compute the parallel-reduce forest last, since this requires some info we
+    // just computed above.
+    let mut parents = HashMap::new();
+    for (fork_join_id, parallel_reduce_info) in result.iter() {
+        let mut pred_block = parallel_reduce_info.predecessor;
+
+        // Keep looking at predecessors of adjacent parallel-reduce sections
+        // until we reach a block that is parallel (that section is the parent)
+        // or sequential (this parallel-reduce is a root).
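+        // For example, in the matmul sample's fork-join nest (k inside j
+        // inside i), the k section's parent is the j section, the j section's
+        // parent is the i section, and the i section is a root.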
+        let parent = loop {
+            match function.blocks[pred_block.idx()].kind {
+                // If the predecessor is sequential, then this parallel-reduce
+                // is a root.
+                SBlockKind::Sequential => break None,
+                // If the predecessor is parallel, then this parallel-reduce is
+                // inside that parallel-reduce.
+                SBlockKind::Parallel(parent) => break Some(parent),
+                // If the predecessor is reduce, then that parallel-reduce is a
+                // child of the same parent. Iterate on its predecessor.
+                SBlockKind::Reduce(adjacent) => {
+                    pred_block = result[&adjacent].predecessor;
+                }
+            }
+        };
+        parents.insert(*fork_join_id, parent);
+    }
+
+    // Insert the information into the parallel reduce info map.
+    for (fork_join_id, parallel_reduce_info) in result.iter_mut() {
+        parallel_reduce_info.parent_fork_join_id = parents[fork_join_id];
+    }
+
     result
 }
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 92bfdaaf..328ecb77 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -354,9 +354,9 @@ fn write_node<W: Write>(
     // If this is a node with additional information, add that to the node
     // label.
     let label = if suffix.is_empty() {
-        node.lower_case_name().to_owned()
+        node.upper_case_name().to_owned()
     } else {
-        format!("{} ({})", node.lower_case_name(), suffix)
+        format!("{} ({})", node.upper_case_name(), suffix)
     };
 
     let mut iter = schedules.into_iter();
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index 395f24d6..188d428d 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -34,6 +34,7 @@ pub enum Pass {
     // Parameterized over whether analyses that aid visualization are necessary.
     // Useful to set to false if displaying a potentially broken module.
     Xdot(bool),
+    SchedXdot,
     // Parameterized by output file name.
     Codegen(String),
 }
@@ -445,6 +446,35 @@ impl PassManager {
                     // Xdot doesn't require clearing analysis results.
                     continue;
                 }
+                Pass::SchedXdot => {
+                    self.make_def_uses();
+                    self.make_reverse_postorders();
+                    self.make_typing();
+                    self.make_control_subgraphs();
+                    self.make_fork_join_maps();
+                    self.make_fork_join_nests();
+                    self.make_antideps();
+                    self.make_bbs();
+                    self.make_plans();
+
+                    let smodule = sched_compile(
+                        &self.module,
+                        self.def_uses.as_ref().unwrap(),
+                        self.reverse_postorders.as_ref().unwrap(),
+                        self.typing.as_ref().unwrap(),
+                        self.control_subgraphs.as_ref().unwrap(),
+                        self.fork_join_maps.as_ref().unwrap(),
+                        self.fork_join_nests.as_ref().unwrap(),
+                        self.antideps.as_ref().unwrap(),
+                        self.bbs.as_ref().unwrap(),
+                        self.plans.as_ref().unwrap(),
+                    );
+
+                    xdot_sched_module(&smodule);
+
+                    // Xdot doesn't require clearing analysis results.
+                    continue;
+                }
                 Pass::Codegen(output_file_name) => {
                     self.make_def_uses();
                     self.make_reverse_postorders();
diff --git a/hercules_rt_proc/src/lib.rs b/hercules_rt_proc/src/lib.rs
index 0056ccd7..69625c19 100644
--- a/hercules_rt_proc/src/lib.rs
+++ b/hercules_rt_proc/src/lib.rs
@@ -225,8 +225,8 @@ fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String,
                     ParameterKind::DataInput(id) => {
                         write!(rust_code, "node_{}.assume_init(), ", id.idx())?
                     }
-                    ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?,
-                    ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}", id.idx())?,
+                    ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}, ", idx)?,
+                    ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}, ", id.idx())?,
                 }
             }
             write!(rust_code, ");\n")?;
diff --git a/hercules_samples/matmul/matmul.hir b/hercules_samples/matmul/matmul.hir
index 8c34a316..6207eb88 100644
--- a/hercules_samples/matmul/matmul.hir
+++ b/hercules_samples/matmul/matmul.hir
@@ -1,11 +1,11 @@
 fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
   c = constant(array(f32, #0, #2), [])
   i_ctrl = fork(start, #0)
-  i_idx = thread_id(i_ctrl)
+  i_idx = thread_id(i_ctrl, 0)
   j_ctrl = fork(i_ctrl, #2)
-  j_idx = thread_id(j_ctrl)
+  j_idx = thread_id(j_ctrl, 0)
   k_ctrl = fork(j_ctrl, #1)
-  k_idx = thread_id(k_ctrl)
+  k_idx = thread_id(k_ctrl, 0)
   k_join_ctrl = join(k_ctrl)
   j_join_ctrl = join(k_join_ctrl)
   i_join_ctrl = join(j_join_ctrl)
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 732aece4..baeab268 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -12,15 +12,13 @@ fn main() {
     async_std::task::block_on(async {
         let mut a = vec![1.0, 2.0, 3.0, 4.0];
         let mut b = vec![5.0, 6.0, 7.0, 8.0];
-        let c = matmul(a.as_mut_ptr(), b.as_mut_ptr(), 2, 2, 2).await;
+        let mut c = vec![0.0, 0.0, 0.0, 0.0];
         unsafe {
-            println!(
-                "[[{}, {}], [{}, {}]]",
-                *c,
-                *c.offset(1),
-                *c.offset(2),
-                *c.offset(3)
-            );
+            matmul(a.as_mut_ptr(), b.as_mut_ptr(), c.as_mut_ptr(), 2, 2, 2).await;
         }
+        println!(
+            "[[{}, {}], [{}, {}]]",
+            c[0], c[1], c[2], c[3]
+        );
     });
 }
-- 
GitLab