From 9de4d3bb84a0e9c34d6119b3d82f559b1ae89f95 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Sat, 22 Feb 2025 18:28:49 -0600
Subject: [PATCH 01/15] Make fork coalesce use one edit to handle partial
 selection properly

---
 hercules_opt/src/fork_transforms.rs | 47 ++++++++++++-----------------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e635b3c0..ec111e69 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -686,40 +686,33 @@ pub fn fork_coalesce_helper(
     // CHECKME / FIXME: Might need to be added the other way.
     new_factors.append(&mut inner_dims.to_vec());
 
-    for tid in inner_tids {
-        let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
-        let new_tid = Node::ThreadID {
-            control: fork,
-            dimension: dim + num_outer_dims,
-        };
+    editor.edit(|mut edit| {
+        for tid in inner_tids {
+            let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
+            let new_tid = Node::ThreadID {
+                control: fork,
+                dimension: dim + num_outer_dims,
+            };
 
-        editor.edit(|mut edit| {
             let new_tid = edit.add_node(new_tid);
-            let edit = edit.replace_all_uses(tid, new_tid)?;
-            Ok(edit)
-        });
-    }
+            edit = edit.replace_all_uses(tid, new_tid)?;
+        }
 
-    // Fuse Reductions
-    for (outer_reduce, inner_reduce) in pairs {
-        let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
-            .try_reduce()
-            .unwrap();
-        let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
-            .try_reduce()
-            .unwrap();
-        editor.edit(|mut edit| {
+        // Fuse Reductions
+        for (outer_reduce, inner_reduce) in pairs {
+            let (_, outer_init, _) = edit.get_node(outer_reduce)
+                .try_reduce()
+                .unwrap();
+            let (_, inner_init, _) = edit.get_node(inner_reduce)
+                .try_reduce()
+                .unwrap();
             // Set inner init to outer init.
             edit =
                 edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
             edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
             edit = edit.delete_node(outer_reduce)?;
+        }
 
-            Ok(edit)
-        });
-    }
-
-    editor.edit(|mut edit| {
         let new_fork = Node::Fork {
             control: outer_pred,
             factors: new_factors.into(),
@@ -734,9 +727,7 @@ pub fn fork_coalesce_helper(
         edit = edit.delete_node(outer_fork)?;
 
         Ok(edit)
-    });
-
-    true
+    })
 }
 
 pub fn split_any_fork(
-- 
GitLab


From 972c1e6ebe2f616181f4a8f7a5d90c500f9ac518 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Sat, 22 Feb 2025 18:42:41 -0600
Subject: [PATCH 02/15] Propagate ParallelFork in coalesce and interchange

---
 hercules_opt/src/fork_transforms.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ec111e69..9fceb892 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -719,6 +719,11 @@ pub fn fork_coalesce_helper(
         };
         let new_fork = edit.add_node(new_fork);
 
+        if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork)
+            && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) {
+            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
+        }
+
         edit = edit.replace_all_uses(inner_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_join, inner_join)?;
@@ -1271,6 +1276,9 @@ fn fork_interchange(
             edit = edit.delete_node(old_id)?;
         }
         let new_fork = edit.add_node(new_fork);
+        if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
+            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
+        }
         edit = edit.replace_all_uses(fork, new_fork)?;
         edit.delete_node(fork)
     });
-- 
GitLab


From 199a8a80e72e8377faedea785a879847edf11ff7 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Sat, 22 Feb 2025 18:44:44 -0600
Subject: [PATCH 03/15] Manual gpu schedule

---
 juno_samples/matmul/src/gpu.sch | 63 +++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 19 deletions(-)

diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 76808149..effdc6b2 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -1,26 +1,51 @@
-phi-elim(*);
+macro optimize!(X) {
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  ip-sroa(X);
+  sroa(X);
+  dce(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+}
+
+macro codegen!(X) {
+  gcm(*);
+  float-collections(*);
+  dce(*);
+  gcm(*);
+}
 
-forkify(*);
-fork-guard-elim(*);
-dce(*);
+optimize!(*);
 
-fixpoint {
-  reduce-slf(*);
-  slf(*);
-  infer-schedules(*);
+fixpoint panic after 20 {
+  forkify(matmul);
+  fork-guard-elim(matmul);
 }
-fork-coalesce(*);
-infer-schedules(*);
-dce(*);
-rewrite(*);
-fixpoint {
-  simplify-cfg(*);
-  dce(*);
+
+optimize!(*);
+
+fixpoint panic after 20 {
+  reduce-slf(matmul);
+  slf(matmul);
+  infer-schedules(matmul);
 }
+dce(matmul);
 
-ip-sroa(*);
-sroa(*);
-dce(*);
+// Tile outer and middle loops into 32x32 sized blocks
+fork-tile[32, 0, false, true](matmul@outer \ matmul@inner);
+// Merge outer and middle loops and interchange so blocks are first
+fork-coalesce(matmul@outer \ matmul@inner);
+fork-interchange[1, 2](matmul@outer \ matmul@inner);
+// Split forks
+let split = fork-split(matmul);
+// Join the threads and then blocks into a single fork each
+fork-coalesce(split.matmul.fj2 \ matmul@inner);
+fork-coalesce(split.matmul.fj0 \ split.matmul.fj2);
 
+let auto = auto-outline(*);
 float-collections(*);
-gcm(*);
+gpu(auto.matmul);
+
+codegen!(*);
-- 
GitLab


From 4b1630b2537bd82e2d52c24509797e9c5bef77d5 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 24 Feb 2025 14:27:36 -0600
Subject: [PATCH 04/15] Restore gpu schedule

---
 juno_samples/matmul/src/gpu.sch | 63 ++++++++++-----------------------
 1 file changed, 19 insertions(+), 44 deletions(-)

diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index effdc6b2..76808149 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -1,51 +1,26 @@
-macro optimize!(X) {
-  gvn(X);
-  phi-elim(X);
-  dce(X);
-  ip-sroa(X);
-  sroa(X);
-  dce(X);
-  gvn(X);
-  phi-elim(X);
-  dce(X);
-}
-
-macro codegen!(X) {
-  gcm(*);
-  float-collections(*);
-  dce(*);
-  gcm(*);
-}
+phi-elim(*);
 
-optimize!(*);
+forkify(*);
+fork-guard-elim(*);
+dce(*);
 
-fixpoint panic after 20 {
-  forkify(matmul);
-  fork-guard-elim(matmul);
+fixpoint {
+  reduce-slf(*);
+  slf(*);
+  infer-schedules(*);
 }
-
-optimize!(*);
-
-fixpoint panic after 20 {
-  reduce-slf(matmul);
-  slf(matmul);
-  infer-schedules(matmul);
+fork-coalesce(*);
+infer-schedules(*);
+dce(*);
+rewrite(*);
+fixpoint {
+  simplify-cfg(*);
+  dce(*);
 }
-dce(matmul);
 
-// Tile outer and middle loops into 32x32 sized blocks
-fork-tile[32, 0, false, true](matmul@outer \ matmul@inner);
-// Merge outer and middle loops and interchange so blocks are first
-fork-coalesce(matmul@outer \ matmul@inner);
-fork-interchange[1, 2](matmul@outer \ matmul@inner);
-// Split forks
-let split = fork-split(matmul);
-// Join the threads and then blocks into a single fork each
-fork-coalesce(split.matmul.fj2 \ matmul@inner);
-fork-coalesce(split.matmul.fj0 \ split.matmul.fj2);
+ip-sroa(*);
+sroa(*);
+dce(*);
 
-let auto = auto-outline(*);
 float-collections(*);
-gpu(auto.matmul);
-
-codegen!(*);
+gcm(*);
-- 
GitLab


From d179193de2081aec72dfe06baa96dd58612b98a4 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 24 Feb 2025 15:28:02 -0600
Subject: [PATCH 05/15] Parallel tiled cpu schedule

---
 juno_samples/matmul/src/cpu.sch | 63 +++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)
 create mode 100644 juno_samples/matmul/src/cpu.sch

diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
new file mode 100644
index 00000000..bef45ca2
--- /dev/null
+++ b/juno_samples/matmul/src/cpu.sch
@@ -0,0 +1,63 @@
+macro optimize!(X) {
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  ip-sroa(X);
+  sroa(X);
+  dce(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+}
+
+macro codegen!(X) {
+  gcm(*);
+  float-collections(*);
+  dce(*);
+  gcm(*);
+}
+
+optimize!(*);
+
+fixpoint panic after 20 {
+  forkify(matmul);
+  fork-guard-elim(matmul);
+}
+
+// Mark the whole loop nest as associative, any order of iterations is equivalent
+associative(matmul@outer);
+
+// Tile the outer 2 loops to create 16 parallel threads (each responsible for
+// computing one block of the output
+let par = matmul@outer \ matmul@inner;
+fork-tile[4, 0, false, true](par);
+fork-coalesce(par);
+fork-interchange[0, 1](par);
+fork-interchange[2, 3](par);
+fork-interchange[1, 2](par);
+
+let split = fork-split(*);
+fork-coalesce(split.matmul.fj0 \ split.matmul.fj2);
+parallel-fork(split.matmul.fj0 \ split.matmul.fj2);
+
+// Pull the body of the parallel loop out into its own device function
+let body = outline(split.matmul.fj2);
+cpu(body);
+
+// Tile the loop nest for cache performance; 16x16x16 tile
+fork-tile[16, 0, false, true](body);
+fixpoint { fork-coalesce(body); }
+
+fork-interchange[1, 2](body);
+fork-interchange[3, 4](body);
+fork-interchange[2, 3](body);
+
+optimize!(*);
+
+fork-split(body);
+reduce-slf(*);
+unforkify(body);
+
+optimize!(*);
+
+codegen!(*);
-- 
GitLab


From 568b399d103e9d91d29754d1f0510df3b949695b Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 24 Feb 2025 15:29:07 -0600
Subject: [PATCH 06/15] Use cpu schedule

---
 juno_samples/matmul/build.rs | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 0be838c6..d2813388 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -6,6 +6,8 @@ fn main() {
         JunoCompiler::new()
             .file_in_src("matmul.jn")
             .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
             .build()
             .unwrap();
     }
-- 
GitLab


From aec9cf1fe4bf40f000d472d63e79cdf66b6b0cb8 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 26 Feb 2025 11:45:20 -0600
Subject: [PATCH 07/15] parallel cpu schedule

---
 juno_samples/matmul/src/cpu.sch | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
index bef45ca2..2fcf3108 100644
--- a/juno_samples/matmul/src/cpu.sch
+++ b/juno_samples/matmul/src/cpu.sch
@@ -60,4 +60,7 @@ unforkify(body);
 
 optimize!(*);
 
+parallel-reduce(split.matmul.fj0);
+xdot[true](*);
+
 codegen!(*);
-- 
GitLab


From 10f9373682232536ad0c400e4365d88e2f7050f0 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 26 Feb 2025 14:49:03 -0600
Subject: [PATCH 08/15] Clean-up cpu schedule

---
 juno_samples/matmul/src/cpu.sch | 54 +++++++++++++++++++++------------
 1 file changed, 34 insertions(+), 20 deletions(-)

diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
index 2fcf3108..fb08c254 100644
--- a/juno_samples/matmul/src/cpu.sch
+++ b/juno_samples/matmul/src/cpu.sch
@@ -10,27 +10,45 @@ macro optimize!(X) {
   dce(X);
 }
 
-macro codegen!(X) {
-  gcm(*);
-  float-collections(*);
-  dce(*);
-  gcm(*);
+macro codegen-prep!(X) {
+  optimize!(X);
+  gcm(X);
+  float-collections(X);
+  dce(X);
+  gcm(X);
 }
 
-optimize!(*);
+macro forkify!(X) {
+  fixpoint {
+    forkify(X);
+    fork-guard-elim(X);
+  }
+}
+
+macro fork-tile![n](X) {
+  fork-tile[n, 0, false, true](X);
+}
+
+macro parallelize!(X) {
+  parallel-fork(X);
+  parallel-reduce(X);
+}
 
-fixpoint panic after 20 {
-  forkify(matmul);
-  fork-guard-elim(matmul);
+macro unforkify!(X) {
+  fork-split(X);
+  unforkify(X);
 }
 
+optimize!(*);
+forkify!(*);
+
 // Mark the whole loop nest as associative, any order of iterations is equivalent
 associative(matmul@outer);
 
 // Tile the outer 2 loops to create 16 parallel threads (each responsible for
 // computing one block of the output
 let par = matmul@outer \ matmul@inner;
-fork-tile[4, 0, false, true](par);
+fork-tile![4](par);
 fork-coalesce(par);
 fork-interchange[0, 1](par);
 fork-interchange[2, 3](par);
@@ -38,29 +56,25 @@ fork-interchange[1, 2](par);
 
 let split = fork-split(*);
 fork-coalesce(split.matmul.fj0 \ split.matmul.fj2);
-parallel-fork(split.matmul.fj0 \ split.matmul.fj2);
+
+parallelize!(split.matmul.fj0 \ split.matmul.fj2);
 
 // Pull the body of the parallel loop out into its own device function
 let body = outline(split.matmul.fj2);
 cpu(body);
 
 // Tile the loop nest for cache performance; 16x16x16 tile
-fork-tile[16, 0, false, true](body);
+fork-tile![16](body);
 fixpoint { fork-coalesce(body); }
-
 fork-interchange[1, 2](body);
 fork-interchange[3, 4](body);
 fork-interchange[2, 3](body);
 
-optimize!(*);
-
 fork-split(body);
 reduce-slf(*);
-unforkify(body);
-
-optimize!(*);
 
-parallel-reduce(split.matmul.fj0);
 xdot[true](*);
 
-codegen!(*);
+unforkify!(body);
+
+codegen-prep!(*);
-- 
GitLab


From 205e68e2d6fbb1649641750bc80a301f0743deeb Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 26 Feb 2025 18:38:52 -0600
Subject: [PATCH 09/15] Add tuples to scheduling language

---
 juno_scheduler/src/compile.rs | 23 ++++++++++++++++++++
 juno_scheduler/src/ir.rs      |  7 ++++++
 juno_scheduler/src/lang.y     | 32 +++++++++++++++++++++------
 juno_scheduler/src/pm.rs      | 41 ++++++++++++++++++++++++++++++++++-
 4 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 3c288ca7..7b8c5020 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -488,6 +488,29 @@ fn compile_expr(
                 rhs: Box::new(rhs),
             }))
         }
+        parser::Expr::Tuple {
+            span: _,
+            exps,
+        } => {
+            let exprs = exps.into_iter()
+                .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros))
+                .fold(Ok(vec![]),
+                    |mut res, exp| {
+                        let mut res = res?;
+                        res.push(exp?);
+                        Ok(res)
+                })?;
+            Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs }))
+        }
+        parser::Expr::TupleField {
+            span: _,
+            lhs,
+            field,
+        } => {
+            let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
+            let field = lexer.span_str(field).parse().expect("Parsing");
+            Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { lhs: Box::new(lhs), field }))
+        }
     }
 }
 
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 3a087c0d..71a185ba 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -127,6 +127,13 @@ pub enum ScheduleExp {
         lhs: Box<ScheduleExp>,
         rhs: Box<ScheduleExp>,
     },
+    Tuple {
+        exprs: Vec<ScheduleExp>,
+    },
+    TupleField {
+        lhs: Box<ScheduleExp>,
+        field: usize,
+    },
     // This is used to "box" a selection by evaluating it at one point and then
     // allowing it to be used as a selector later on
     Selection {
diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y
index 3b030e1d..5903e429 100644
--- a/juno_scheduler/src/lang.y
+++ b/juno_scheduler/src/lang.y
@@ -56,10 +56,14 @@ Expr -> Expr
       { Expr::String { span: $span } }
   | Expr '.' 'ID'
       { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
+  | Expr '.' 'INT'
+      { Expr::TupleField { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
   | Expr '@' 'ID'
       { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
-  | '(' Expr ')'
-      { $2 }
+  | '(' Exprs ')'
+      { exprs_to_expr($span, $2) }
+  | '[' Exprs ']'
+      { exprs_to_expr($span, $2) }
   | '{' Schedule '}'
       { Expr::BlockExpr { span: $span, body: Box::new($2) } }
   | '<' Fields '>'
@@ -73,14 +77,18 @@ Expr -> Expr
   ;
 
 Args -> Vec<Expr>
-  :               { vec![] }
-  | '[' Exprs ']' { rev($2) }
+  :                { vec![] }
+  | '[' RExprs ']' { rev($2) }
   ;
 
 Exprs -> Vec<Expr>
-  :                 { vec![] }
-  | Expr            { vec![$1] }
-  | Expr ',' Exprs  { snoc($1, $3) }
+  : RExprs  { rev($1) }
+  ;
+
+RExprs -> Vec<Expr>
+  :                  { vec![] }
+  | Expr             { vec![$1] }
+  | Expr ',' RExprs  { snoc($1, $3) }
   ;
 
 Fields -> Vec<(Span, Expr)>
@@ -180,6 +188,8 @@ pub enum Expr {
   BlockExpr   { span: Span, body: Box<OperationList> },
   Record      { span: Span, fields: Vec<(Span, Expr)> },
   SetOp       { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> },
+  Tuple       { span: Span, exps: Vec<Expr> },
+  TupleField  { span: Span, lhs: Box<Expr>, field: Span },
 }
 
 pub enum Selector {
@@ -193,3 +203,11 @@ pub struct MacroDecl {
   pub selection_name: Span,
   pub def: Box<OperationList>,
 }
+
+fn exprs_to_expr(span: Span, mut exps: Vec<Expr>) -> Expr {
+  if exps.len() == 1 {
+    exps.pop().unwrap()
+  } else {
+    Expr::Tuple { span, exps }
+  }
+}
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5f2fa4cc..ef9ec038 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -294,6 +294,9 @@ pub enum Value {
     Record {
         fields: HashMap<String, Value>,
     },
+    Tuple {
+        values: Vec<Value>,
+    },
     Everything {},
     Selection {
         selection: Vec<Value>,
@@ -371,6 +374,11 @@ impl Value {
                     "Expected code selection, found record".to_string(),
                 ));
             }
+            Value::Tuple { .. } => {
+                return Err(SchedulerError::SemanticError(
+                    "Expected code selection, found tuple".to_string(),
+                ));
+            }
             Value::Integer { .. } => {
                 return Err(SchedulerError::SemanticError(
                     "Expected code selection, found integer".to_string(),
@@ -1291,6 +1299,7 @@ fn interp_expr(
                 | Value::Integer { .. }
                 | Value::Boolean { .. }
                 | Value::String { .. }
+                | Value::Tuple { .. }
                 | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())),
                 Value::JunoFunction { func } => {
                     match pm.labels.borrow().iter().position(|s| s == field) {
@@ -1463,7 +1472,24 @@ fn interp_expr(
                 }
                 Ok((Value::Selection { selection: values }, changed))
             }
-        },
+        }
+        ScheduleExp::Tuple { exprs } => {
+            let mut vals = vec![];
+            let mut changed = false;
+            for exp in exprs {
+                let (val, change) = interp_expr(pm, exp, stringtab, env, functions)?;
+                vals.push(val);
+                changed = changed || change;
+            }
+            Ok((Value::Tuple { values: vals }, changed))
+        }
+        ScheduleExp::TupleField { lhs, field } => {
+            let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?;
+            match val {
+                Value::Tuple { values } if *field < values.len() => Ok((vec_take(values, *field), changed)),
+                _ => Err(SchedulerError::SemanticError(format!("No field at index {}", field))),
+            }
+        }
     }
 }
 
@@ -1521,6 +1547,15 @@ fn update_value(
                 Some(Value::Record { fields: new_fields })
             }
         }
+        // For tuples, if we deleted values like we do for records this would mess up the indices
+        // which would behave very strangely. Instead if any field cannot be updated then we
+        // eliminate the entire value
+        Value::Tuple { values } => {
+            values.into_iter()
+                .map(|v| update_value(v, func_idx, juno_func_idx))
+                .collect::<Option<Vec<_>>>()
+                .map(|values| Value::Tuple { values })
+        }
         Value::JunoFunction { func } => {
             juno_func_idx[func.idx]
                 .clone()
@@ -3016,3 +3051,7 @@ where
     });
     labels
 }
+
+fn vec_take<T>(mut v: Vec<T>, index: usize) -> T {
+    v.swap_remove(index)
+}
-- 
GitLab


From 3c8eaae28662fc423b2514d9ede60ab9ef32227f Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 28 Feb 2025 09:51:46 -0600
Subject: [PATCH 10/15] add_node only adds to def-use map if not already
 present

---
 hercules_opt/src/editor.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 17cea325..ce344699 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -457,7 +457,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
     pub fn add_node(&mut self, node: Node) -> NodeID {
         let id = NodeID::new(self.num_node_ids());
         // Added nodes need to have an entry in the def-use map.
-        self.updated_def_use.insert(id, HashSet::new());
+        self.updated_def_use.entry(id).or_insert(HashSet::new());
         // Added nodes use other nodes, and we need to update their def-use
         // entries.
         for u in get_uses(&node).as_ref() {
-- 
GitLab


From 4754a909271bdc228294e501cbaa1aa067fbf599 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 28 Feb 2025 09:52:32 -0600
Subject: [PATCH 11/15] Return new fork ID from fork coalesce and interchange

---
 hercules_opt/src/fork_transforms.rs | 47 +++++++++++++++++------------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ff0f0283..cb0e7de4 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -578,7 +578,7 @@ pub fn fork_coalesce(
     // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
     // something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
     for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
-        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
+        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
             return true;
         }
     }
@@ -587,13 +587,15 @@ pub fn fork_coalesce(
 
 /** Opposite of fork split, takes two fork-joins
     with no control between them, and merges them into a single fork-join.
+    Returns None if the forks could not be merged and the NodeIDs of the
+    resulting fork and join if it succeeds in merging them.
 */
 pub fn fork_coalesce_helper(
     editor: &mut FunctionEditor,
     outer_fork: NodeID,
     inner_fork: NodeID,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> bool {
+) -> Option<(NodeID, NodeID)> {
     // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
 
     let outer_join = fork_join_map[&outer_fork];
@@ -621,19 +623,19 @@ pub fn fork_coalesce_helper(
             reduct: _,
         } = inner_reduce_node
         else {
-            return false;
+            return None;
         };
 
         // FIXME: check this condition better (i.e reduce might not be attached to join)
         if *inner_control != inner_join {
-            return false;
+            return None;
         };
         if *inner_init != outer_reduce {
-            return false;
+            return None;
         };
 
         if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
-            return false;
+            return None;
         } else {
             pairs.insert(outer_reduce, inner_reduce);
         }
@@ -645,11 +647,11 @@ pub fn fork_coalesce_helper(
         .filter(|node| editor.func().nodes[node.idx()].is_control())
         .next()
     else {
-        return false;
+        return None;
     };
 
     if user != inner_fork {
-        return false;
+        return None;
     }
 
     let Some(user) = editor
@@ -657,11 +659,11 @@ pub fn fork_coalesce_helper(
         .filter(|node| editor.func().nodes[node.idx()].is_control())
         .next()
     else {
-        return false;
+        return None;
     };
 
     if user != outer_join {
-        return false;
+        return None;
     }
 
     // Checklist:
@@ -709,10 +711,10 @@ pub fn fork_coalesce_helper(
         let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
             .try_reduce()
             .unwrap();
-        editor.edit(|mut edit| {
+        let success = editor.edit(|mut edit| {
             // Set inner init to outer init.
             edit =
-                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
+                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
             edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
             edit = edit.delete_node(outer_reduce)?;
 
@@ -720,12 +722,15 @@ pub fn fork_coalesce_helper(
         });
     }
 
+    let mut new_fork = NodeID::new(0);
+    let new_join = inner_join; // We reuse the inner join as the join of the new fork
+
     editor.edit(|mut edit| {
-        let new_fork = Node::Fork {
+        let new_fork_node = Node::Fork {
             control: outer_pred,
             factors: new_factors.into(),
         };
-        let new_fork = edit.add_node(new_fork);
+        new_fork = edit.add_node(new_fork_node);
 
         edit = edit.replace_all_uses(inner_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_fork, new_fork)?;
@@ -737,7 +742,7 @@ pub fn fork_coalesce_helper(
         Ok(edit)
     });
 
-    true
+    Some((new_fork, new_join))
 }
 
 pub fn split_any_fork(
@@ -760,7 +765,7 @@ pub fn split_any_fork(
  * Useful for code generation. A single iteration of `fork_split` only splits
  * at most one fork-join, it must be called repeatedly to split all fork-joins.
  */
-pub(crate) fn split_fork(
+pub fn split_fork(
     editor: &mut FunctionEditor,
     fork: NodeID,
     join: NodeID,
@@ -1215,13 +1220,13 @@ pub fn fork_interchange_all_forks(
     }
 }
 
-fn fork_interchange(
+pub fn fork_interchange(
     editor: &mut FunctionEditor,
     fork: NodeID,
     join: NodeID,
     first_dim: usize,
     second_dim: usize,
-) {
+) -> Option<NodeID> {
     // Check that every reduce on the join is parallel or associative.
     let nodes = &editor.func().nodes;
     let schedules = &editor.func().schedules;
@@ -1234,7 +1239,7 @@ fn fork_interchange(
         })
     {
         // If not, we can't necessarily do interchange.
-        return;
+        return None;
     }
 
     let Node::Fork {
@@ -1276,6 +1281,7 @@ fn fork_interchange(
     let mut factors = factors.clone();
     factors.swap(first_dim, second_dim);
     let new_fork = Node::Fork { control, factors };
+    let mut new_fork_id = None;
     editor.edit(|mut edit| {
         for (old_id, new_tid) in fix_tids {
             let new_id = edit.add_node(new_tid);
@@ -1283,9 +1289,12 @@ fn fork_interchange(
             edit = edit.delete_node(old_id)?;
         }
         let new_fork = edit.add_node(new_fork);
+        new_fork_id = Some(new_fork);
         edit = edit.replace_all_uses(fork, new_fork)?;
         edit.delete_node(fork)
     });
+
+    new_fork_id
 }
 
 /*
-- 
GitLab


From 751ded74756f4d650e23e8cac8ce55081106c8cb Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 28 Feb 2025 10:59:01 -0600
Subject: [PATCH 12/15] Improve fork coalesce check for intermediate control

---
 hercules_opt/src/fork_transforms.rs | 20 ++++----------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index cb0e7de4..12b91194 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -642,27 +642,15 @@ pub fn fork_coalesce_helper(
     }
 
     // Check for control between join-join and fork-fork
-    let Some(user) = editor
-        .get_users(outer_fork)
-        .filter(|node| editor.func().nodes[node.idx()].is_control())
-        .next()
-    else {
-        return None;
-    };
+    let (control, _) = editor.node(inner_fork).try_fork().unwrap();
 
-    if user != inner_fork {
+    if control != outer_fork {
         return None;
     }
 
-    let Some(user) = editor
-        .get_users(inner_join)
-        .filter(|node| editor.func().nodes[node.idx()].is_control())
-        .next()
-    else {
-        return None;
-    };
+    let control = editor.node(outer_join).try_join().unwrap();
 
-    if user != outer_join {
+    if control != inner_join {
         return None;
     }
 
-- 
GitLab


From 2d577f79433c36f7c00e2368f9505dfee0efb644 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 28 Feb 2025 10:59:16 -0600
Subject: [PATCH 13/15] Fork-reshape

---
 juno_samples/matmul/build.rs    |   2 +
 juno_samples/matmul/src/cpu.sch |  62 +++++++++
 juno_scheduler/src/compile.rs   |   1 +
 juno_scheduler/src/ir.rs        |   3 +
 juno_scheduler/src/lang.y       |  12 +-
 juno_scheduler/src/pm.rs        | 214 ++++++++++++++++++++++++++++++++
 6 files changed, 284 insertions(+), 10 deletions(-)
 create mode 100644 juno_samples/matmul/src/cpu.sch

diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 0be838c6..d2813388 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -6,6 +6,8 @@ fn main() {
         JunoCompiler::new()
             .file_in_src("matmul.jn")
             .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
             .build()
             .unwrap();
     }
diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
new file mode 100644
index 00000000..7577aebb
--- /dev/null
+++ b/juno_samples/matmul/src/cpu.sch
@@ -0,0 +1,62 @@
+macro optimize!(X) {
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  ip-sroa(X);
+  sroa(X);
+  dce(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+}
+
+macro codegen-prep!(X) {
+  optimize!(X);
+  gcm(X);
+  float-collections(X);
+  dce(X);
+  gcm(X);
+}
+
+macro forkify!(X) {
+  fixpoint {
+    forkify(X);
+    fork-guard-elim(X);
+  }
+}
+
+macro fork-tile![n](X) {
+  fork-tile[n, 0, false, true](X);
+}
+
+macro parallelize!(X) {
+  parallel-fork(X);
+  parallel-reduce(X);
+}
+
+macro unforkify!(X) {
+  fork-split(X);
+  unforkify(X);
+}
+
+optimize!(*);
+forkify!(*);
+associative(matmul@outer);
+
+// Parallelize by computing output array as 16 chunks
+let par = matmul@outer \ matmul@inner;
+fork-tile![4](par);
+let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1;
+parallelize!(outer \ inner);
+
+let body = outline(inner);
+cpu(body);
+
+// Tile for cache, assuming 64B cache lines
+fork-split(body);
+fork-tile![16](body);
+let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1;
+
+reduce-slf(inner);
+unforkify!(body);
+codegen-prep!(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 7b8c5020..0f473f0d 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -134,6 +134,7 @@ impl FromStr for Appliable {
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
+            "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 71a185ba..cf4b6558 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -20,6 +20,7 @@ pub enum Pass {
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
+    ForkReshape,
     ForkSplit,
     ForkUnroll,
     Forkify,
@@ -57,6 +58,7 @@ impl Pass {
             Pass::ForkExtend => num == 1,
             Pass::ForkFissionBufferize => num == 2 || num == 1,
             Pass::ForkInterchange => num == 2,
+            Pass::ForkReshape => true,
             Pass::Print => num == 1,
             Pass::Rename => num == 1,
             Pass::SROA => num == 0 || num == 1,
@@ -73,6 +75,7 @@ impl Pass {
             Pass::ForkExtend => "1",
             Pass::ForkFissionBufferize => "1 or 2",
             Pass::ForkInterchange => "2",
+            Pass::ForkReshape => "any",
             Pass::Print => "1",
             Pass::Rename => "1",
             Pass::SROA => "0 or 1",
diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y
index 5903e429..69d50014 100644
--- a/juno_scheduler/src/lang.y
+++ b/juno_scheduler/src/lang.y
@@ -61,9 +61,9 @@ Expr -> Expr
   | Expr '@' 'ID'
       { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
   | '(' Exprs ')'
-      { exprs_to_expr($span, $2) }
+      { Expr::Tuple { span: $span, exps: $2 } }
   | '[' Exprs ']'
-      { exprs_to_expr($span, $2) }
+      { Expr::Tuple { span: $span, exps: $2 } }
   | '{' Schedule '}'
       { Expr::BlockExpr { span: $span, body: Box::new($2) } }
   | '<' Fields '>'
@@ -203,11 +203,3 @@ pub struct MacroDecl {
   pub selection_name: Span,
   pub def: Box<OperationList>,
 }
-
-fn exprs_to_expr(span: Span, mut exps: Vec<Expr>) -> Expr {
-  if exps.len() == 1 {
-    exps.pop().unwrap()
-  } else {
-    Expr::Tuple { span, exps }
-  }
-}
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index ef9ec038..36259718 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2901,6 +2901,220 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkReshape => {
+            let mut shape = vec![];
+            let mut loops = BTreeSet::new();
+            let mut fork_count = 0;
+
+            for arg in args {
+                let Value::Tuple { values } = arg else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "expected each argument to be a list of integers".to_string(),
+                    });
+                };
+
+                let mut indices = vec![];
+                for val in values {
+                    let Value::Integer { val: idx } = val else {
+                        return Err(SchedulerError::PassError {
+                            pass: "fork-reshape".to_string(),
+                            error: "expected each argument to be a list of integers".to_string(),
+                        });
+                    };
+                    indices.push(idx);
+                    loops.insert(idx);
+                    fork_count += 1;
+                }
+                shape.push(indices);
+            }
+
+            if loops != (0..fork_count).collect() {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(),
+                });
+            }
+
+            let Some((nodes, func_id)) = selection_as_set(pm, selection) else {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: "must be applied to nodes in a single function".to_string(),
+                });
+            };
+            let func = func_id.idx();
+
+            pm.make_def_uses();
+            pm.make_fork_join_maps();
+            pm.make_loops();
+            pm.make_reduce_cycles();
+
+            let def_uses = pm.def_uses.take().unwrap();
+            let mut fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+
+            let def_use = &def_uses[func];
+            let fork_join_map = &mut fork_join_maps[func];
+            let loops = &loops[func];
+            let reduce_cycles = &reduce_cycles[func];
+
+            let mut editor = FunctionEditor::new(
+                &mut pm.functions[func],
+                func_id,
+                &pm.constants,
+                &pm.dynamic_constants,
+                &pm.types,
+                &pm.labels,
+                def_use,
+            );
+
+            // There should be exactly one fork nest in the selection and it should contain
+            // exactly fork_count forks (counting each dimension of each fork)
+            // We determine the loops (ordered top-down) that are contained in the selection
+            // (in particular the header is in the selection) and its a fork-join (the header
+            // is a fork)
+            let mut loops = loops.bottom_up_loops().into_iter().rev()
+                .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork());
+            let Some((top_fork_head, top_fork_body)) = loops.next() else {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name),
+                });
+            };
+            // All the remaining forks need to be contained in the top fork body
+            let mut forks = vec![top_fork_head];
+            let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len();
+            for (head, _) in loops {
+                if !top_fork_body[head.idx()] {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "selection includes multiple non-nested forks".to_string(),
+                    });
+                } else {
+                    forks.push(head);
+                    num_dims += editor.node(head).try_fork().unwrap().1.len();
+                }
+            }
+
+            if num_dims != fork_count {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name),
+                });
+            }
+
+            // Now, we coalesce all of these forks into one so that we can interchange them
+            let mut forks = forks.into_iter();
+            let top_fork = forks.next().unwrap();
+            let mut cur_fork = top_fork;
+            for next_fork in forks {
+                let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "failed to coalesce forks".to_string(),
+                    });
+                };
+                cur_fork = new_fork;
+                fork_join_map.insert(new_fork, new_join);
+            }
+            let join = *fork_join_map.get(&cur_fork).unwrap();
+
+            // Now we have just one fork and we can perform the interchanges we need
+            // To do this, we track two maps: from original index to current index and from
+            // current index to original index
+            let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>();
+            let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>();
+
+            // Now, starting from the first (outermost) index we move the desired fork bound
+            // into place
+            for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() {
+                let cur_idx = orig_to_cur[*original_idx];
+                let swapping = cur_to_orig[idx];
+
+                // If the desired factor is already in the correct place, do nothing
+                if cur_idx == idx {
+                    continue;
+                }
+                assert!(idx < cur_idx);
+                let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "failed to interchange forks".to_string(),
+                    });
+                };
+                cur_fork = fork_res;
+
+                // Update our maps
+                orig_to_cur[*original_idx] = idx;
+                orig_to_cur[swapping] = cur_idx;
+                cur_to_orig[idx] = *original_idx;
+                cur_to_orig[cur_idx] = swapping;
+            }
+
+            // Finally we split the fork into the desired pieces. We do this by first splitting
+            // the fork into individual forks and then coalesce the chunks together
+            // Not sure how split_fork could fail, so if it does panic is fine
+            let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap();
+
+            for (fork, join) in forks.iter().zip(joins.iter()) {
+                fork_join_map.insert(*fork, *join);
+            }
+
+            // Finally coalesce the chunks together
+            let mut fork_idx = 0;
+            let mut final_forks = vec![];
+            for chunk in shape.iter() {
+                let chunk_len = chunk.len();
+
+                let mut cur_fork = forks[fork_idx];
+                for i in 1..chunk_len {
+                    let next_fork = forks[fork_idx + i];
+                    // Again, not sure at this point how coalesce could fail, so panic if it
+                    // does
+                    let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap();
+                    cur_fork = new_fork;
+                    fork_join_map.insert(new_fork, new_join);
+                }
+
+                fork_idx += chunk_len;
+                final_forks.push(cur_fork);
+            }
+
+            // Label each fork and return the labels
+            // We've trashed our analyses at this point, so rerun them so that we can determine the
+            // nodes in each of the result fork-joins
+            pm.clear_analyses();
+            pm.make_def_uses();
+            pm.make_nodes_in_fork_joins();
+            
+            let def_uses = pm.def_uses.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            let def_use = &def_uses[func];
+            let nodes_in_fork_joins = &nodes_in_fork_joins[func];
+            
+            let mut editor = FunctionEditor::new(
+                &mut pm.functions[func],
+                func_id,
+                &pm.constants,
+                &pm.dynamic_constants,
+                &pm.types,
+                &pm.labels,
+                def_use,
+            );
+
+            let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied()))
+                .into_iter()
+                .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] })
+                .collect();
+
+            result = Value::Tuple { values: labels };
+            changed = true;
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::WritePredication => {
             assert!(args.is_empty());
             for func in build_selection(pm, selection, false) {
-- 
GitLab


From 39e2ea9372530618a5f6aab8cf6c8b371849b637 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 28 Feb 2025 11:16:08 -0600
Subject: [PATCH 14/15] Allow let-tuple destructuring in scheduling language

---
 juno_samples/matmul/src/cpu.sch |  5 ++---
 juno_scheduler/src/compile.rs   | 18 ++++++++++++++++++
 juno_scheduler/src/lang.y       |  3 +++
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
index 7577aebb..69f1811d 100644
--- a/juno_samples/matmul/src/cpu.sch
+++ b/juno_samples/matmul/src/cpu.sch
@@ -46,16 +46,15 @@ associative(matmul@outer);
 // Parallelize by computing output array as 16 chunks
 let par = matmul@outer \ matmul@inner;
 fork-tile![4](par);
-let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1;
+let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par);
 parallelize!(outer \ inner);
 
 let body = outline(inner);
 cpu(body);
 
 // Tile for cache, assuming 64B cache lines
-fork-split(body);
 fork-tile![16](body);
-let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1;
+let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body);
 
 reduce-slf(inner);
 unforkify!(body);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 0f473f0d..51ba3699 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -207,6 +207,24 @@ fn compile_stmt(
                 exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?,
             }])
         }
+        parser::Stmt::LetsStmt { span: _, vars, expr } => {
+            let tmp = format!("{}_tmp", macros.uniq());
+            Ok(std::iter::once(ir::ScheduleStmt::Let {
+                var: tmp.clone(),
+                exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?,
+            }).chain(vars.into_iter().enumerate()
+                .map(|(idx, v)| {
+                    let var = lexer.span_str(v).to_string();
+                    ir::ScheduleStmt::Let {
+                        var,
+                        exp: ir::ScheduleExp::TupleField {
+                            lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }),
+                            field: idx,
+                        }
+                    }
+                })
+            ).collect())
+        }
         parser::Stmt::AssignStmt { span: _, var, rhs } => {
             let var = lexer.span_str(var).to_string();
             Ok(vec![ir::ScheduleStmt::Assign {
diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y
index 69d50014..451f035b 100644
--- a/juno_scheduler/src/lang.y
+++ b/juno_scheduler/src/lang.y
@@ -19,6 +19,8 @@ Schedule -> OperationList
 Stmt -> Stmt
   : 'let' 'ID' '=' Expr ';'
       { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } }
+  | 'let' '(' Ids ')' '=' Expr ';'
+      { Stmt::LetsStmt { span: $span, vars: rev($3), expr: $6 } }
   | 'ID' '=' Expr ';'
       { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } }
   | Expr ';'
@@ -157,6 +159,7 @@ pub enum OperationList {
 
 pub enum Stmt {
   LetStmt    { span: Span, var: Span, expr: Expr },
+  LetsStmt   { span: Span, vars: Vec<Span>, expr: Expr },
   AssignStmt { span: Span, var: Span, rhs: Expr },
   ExprStmt   { span: Span, exp: Expr },
   Fixpoint   { span: Span, limit: FixpointLimit, body: Box<OperationList> },
-- 
GitLab


From 7370cb51e84d2992fcde703ed4bd8090250e9000 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 28 Feb 2025 11:47:28 -0600
Subject: [PATCH 15/15] Formatting

---
 hercules_opt/src/fork_transforms.rs | 19 ++++----
 hercules_samples/matmul/src/main.rs |  3 +-
 juno_scheduler/src/compile.rs       | 52 +++++++++++---------
 juno_scheduler/src/pm.rs            | 75 ++++++++++++++++++++---------
 4 files changed, 93 insertions(+), 56 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 9a16c99c..e6db0345 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -693,15 +693,11 @@ pub fn fork_coalesce_helper(
         }
         // Fuse Reductions
         for (outer_reduce, inner_reduce) in pairs {
-            let (_, outer_init, _) = edit.get_node(outer_reduce)
-                .try_reduce()
-                .unwrap();
-            let (_, inner_init, _) = edit.get_node(inner_reduce)
-                .try_reduce()
-                .unwrap();
+            let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
+            let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();
             // Set inner init to outer init.
             edit =
-                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
+                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
             edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
             edit = edit.delete_node(outer_reduce)?;
         }
@@ -712,8 +708,13 @@ pub fn fork_coalesce_helper(
         };
         new_fork = edit.add_node(new_fork_node);
 
-        if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork)
-            && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) {
+        if edit
+            .get_schedule(outer_fork)
+            .contains(&Schedule::ParallelFork)
+            && edit
+                .get_schedule(inner_fork)
+                .contains(&Schedule::ParallelFork)
+        {
             edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
         }
 
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 27727664..00e5b873 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -25,7 +25,8 @@ fn main() {
         let a = HerculesImmBox::from(a.as_ref());
         let b = HerculesImmBox::from(b.as_ref());
         let mut r = runner!(matmul);
-        let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
+        let mut c: HerculesMutBox<i32> =
+            HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
         assert_eq!(c.as_slice(), correct_c.as_ref());
     });
 }
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 32d2c8d1..9d020c64 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -208,23 +208,27 @@ fn compile_stmt(
                 exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?,
             }])
         }
-        parser::Stmt::LetsStmt { span: _, vars, expr } => {
+        parser::Stmt::LetsStmt {
+            span: _,
+            vars,
+            expr,
+        } => {
             let tmp = format!("{}_tmp", macros.uniq());
             Ok(std::iter::once(ir::ScheduleStmt::Let {
                 var: tmp.clone(),
                 exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?,
-            }).chain(vars.into_iter().enumerate()
-                .map(|(idx, v)| {
-                    let var = lexer.span_str(v).to_string();
-                    ir::ScheduleStmt::Let {
-                        var,
-                        exp: ir::ScheduleExp::TupleField {
-                            lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }),
-                            field: idx,
-                        }
-                    }
-                })
-            ).collect())
+            })
+            .chain(vars.into_iter().enumerate().map(|(idx, v)| {
+                let var = lexer.span_str(v).to_string();
+                ir::ScheduleStmt::Let {
+                    var,
+                    exp: ir::ScheduleExp::TupleField {
+                        lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }),
+                        field: idx,
+                    },
+                }
+            }))
+            .collect())
         }
         parser::Stmt::AssignStmt { span: _, var, rhs } => {
             let var = lexer.span_str(var).to_string();
@@ -508,17 +512,14 @@ fn compile_expr(
                 rhs: Box::new(rhs),
             }))
         }
-        parser::Expr::Tuple {
-            span: _,
-            exps,
-        } => {
-            let exprs = exps.into_iter()
+        parser::Expr::Tuple { span: _, exps } => {
+            let exprs = exps
+                .into_iter()
                 .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros))
-                .fold(Ok(vec![]),
-                    |mut res, exp| {
-                        let mut res = res?;
-                        res.push(exp?);
-                        Ok(res)
+                .fold(Ok(vec![]), |mut res, exp| {
+                    let mut res = res?;
+                    res.push(exp?);
+                    Ok(res)
                 })?;
             Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs }))
         }
@@ -529,7 +530,10 @@ fn compile_expr(
         } => {
             let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
             let field = lexer.span_str(field).parse().expect("Parsing");
-            Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { lhs: Box::new(lhs), field }))
+            Ok(ExprResult::Expr(ir::ScheduleExp::TupleField {
+                lhs: Box::new(lhs),
+                field,
+            }))
         }
     }
 }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index f0b55eca..456df2ed 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1475,7 +1475,7 @@ fn interp_expr(
                 }
                 Ok((Value::Selection { selection: values }, changed))
             }
-        }
+        },
         ScheduleExp::Tuple { exprs } => {
             let mut vals = vec![];
             let mut changed = false;
@@ -1489,8 +1489,13 @@ fn interp_expr(
         ScheduleExp::TupleField { lhs, field } => {
             let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?;
             match val {
-                Value::Tuple { values } if *field < values.len() => Ok((vec_take(values, *field), changed)),
-                _ => Err(SchedulerError::SemanticError(format!("No field at index {}", field))),
+                Value::Tuple { values } if *field < values.len() => {
+                    Ok((vec_take(values, *field), changed))
+                }
+                _ => Err(SchedulerError::SemanticError(format!(
+                    "No field at index {}",
+                    field
+                ))),
             }
         }
     }
@@ -1553,12 +1558,11 @@ fn update_value(
         // For tuples, if we deleted values like we do for records this would mess up the indices
         // which would behave very strangely. Instead if any field cannot be updated then we
         // eliminate the entire value
-        Value::Tuple { values } => {
-            values.into_iter()
-                .map(|v| update_value(v, func_idx, juno_func_idx))
-                .collect::<Option<Vec<_>>>()
-                .map(|values| Value::Tuple { values })
-        }
+        Value::Tuple { values } => values
+            .into_iter()
+            .map(|v| update_value(v, func_idx, juno_func_idx))
+            .collect::<Option<Vec<_>>>()
+            .map(|values| Value::Tuple { values }),
         Value::JunoFunction { func } => {
             juno_func_idx[func.idx]
                 .clone()
@@ -2963,7 +2967,9 @@ fn run_pass(
             if loops != (0..fork_count).collect() {
                 return Err(SchedulerError::PassError {
                     pass: "fork-reshape".to_string(),
-                    error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(),
+                    error:
+                        "expected forks to be numbered sequentially from 0 and used exactly once"
+                            .to_string(),
                 });
             }
 
@@ -3005,12 +3011,19 @@ fn run_pass(
             // We determine the loops (ordered top-down) that are contained in the selection
             // (in particular the header is in the selection) and its a fork-join (the header
             // is a fork)
-            let mut loops = loops.bottom_up_loops().into_iter().rev()
+            let mut loops = loops
+                .bottom_up_loops()
+                .into_iter()
+                .rev()
                 .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork());
             let Some((top_fork_head, top_fork_body)) = loops.next() else {
                 return Err(SchedulerError::PassError {
                     pass: "fork-reshape".to_string(),
-                    error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name),
+                    error: format!(
+                        "expected {} forks found 0 in {}",
+                        fork_count,
+                        editor.func().name
+                    ),
                 });
             };
             // All the remaining forks need to be contained in the top fork body
@@ -3031,7 +3044,10 @@ fn run_pass(
             if num_dims != fork_count {
                 return Err(SchedulerError::PassError {
                     pass: "fork-reshape".to_string(),
-                    error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name),
+                    error: format!(
+                        "expected {} forks, found {} in {}",
+                        fork_count, num_dims, pm.functions[func].name
+                    ),
                 });
             }
 
@@ -3040,7 +3056,9 @@ fn run_pass(
             let top_fork = forks.next().unwrap();
             let mut cur_fork = top_fork;
             for next_fork in forks {
-                let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else {
+                let Some((new_fork, new_join)) =
+                    fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map)
+                else {
                     return Err(SchedulerError::PassError {
                         pass: "fork-reshape".to_string(),
                         error: "failed to coalesce forks".to_string(),
@@ -3068,7 +3086,8 @@ fn run_pass(
                     continue;
                 }
                 assert!(idx < cur_idx);
-                let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else {
+                let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx)
+                else {
                     return Err(SchedulerError::PassError {
                         pass: "fork-reshape".to_string(),
                         error: "failed to interchange forks".to_string(),
@@ -3103,7 +3122,9 @@ fn run_pass(
                     let next_fork = forks[fork_idx + i];
                     // Again, not sure at this point how coalesce could fail, so panic if it
                     // does
-                    let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap();
+                    let (new_fork, new_join) =
+                        fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map)
+                            .unwrap();
                     cur_fork = new_fork;
                     fork_join_map.insert(new_fork, new_join);
                 }
@@ -3118,13 +3139,13 @@ fn run_pass(
             pm.clear_analyses();
             pm.make_def_uses();
             pm.make_nodes_in_fork_joins();
-            
+
             let def_uses = pm.def_uses.take().unwrap();
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
 
             let def_use = &def_uses[func];
             let nodes_in_fork_joins = &nodes_in_fork_joins[func];
-            
+
             let mut editor = FunctionEditor::new(
                 &mut pm.functions[func],
                 func_id,
@@ -3135,10 +3156,20 @@ fn run_pass(
                 def_use,
             );
 
-            let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied()))
-                .into_iter()
-                .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] })
-                .collect();
+            let labels = create_labels_for_node_sets(
+                &mut editor,
+                final_forks
+                    .into_iter()
+                    .map(|fork| nodes_in_fork_joins[&fork].iter().copied()),
+            )
+            .into_iter()
+            .map(|(_, label)| Value::Label {
+                labels: vec![LabelInfo {
+                    func: func_id,
+                    label,
+                }],
+            })
+            .collect();
 
             result = Value::Tuple { values: labels };
             changed = true;
-- 
GitLab