diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 0be838c620761e8726590e2dbaf7bfdb7a82e3df..d2813388e0e7a1d7bd1696ffbb641e629096e2c2 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -6,6 +6,8 @@ fn main() {
         JunoCompiler::new()
             .file_in_src("matmul.jn")
             .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
             .build()
             .unwrap();
     }
diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
new file mode 100644
index 0000000000000000000000000000000000000000..7577aebba60a03f3b21de8c54b83be083c557e7c
--- /dev/null
+++ b/juno_samples/matmul/src/cpu.sch
@@ -0,0 +1,62 @@
+macro optimize!(X) {
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  ip-sroa(X);
+  sroa(X);
+  dce(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+}
+
+macro codegen-prep!(X) {
+  optimize!(X);
+  gcm(X);
+  float-collections(X);
+  dce(X);
+  gcm(X);
+}
+
+macro forkify!(X) {
+  fixpoint {
+    forkify(X);
+    fork-guard-elim(X);
+  }
+}
+
+macro fork-tile![n](X) {
+  fork-tile[n, 0, false, true](X);
+}
+
+macro parallelize!(X) {
+  parallel-fork(X);
+  parallel-reduce(X);
+}
+
+macro unforkify!(X) {
+  fork-split(X);
+  unforkify(X);
+}
+
+optimize!(*);
+forkify!(*);
+associative(matmul@outer);
+
+// Parallelize by computing output array as 16 chunks
+let par = matmul@outer \ matmul@inner;
+fork-tile![4](par);
+let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1;
+parallelize!(outer \ inner);
+
+let body = outline(inner);
+cpu(body);
+
+// Tile for cache, assuming 64B cache lines
+fork-split(body);
+fork-tile![16](body);
+let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1;
+
+reduce-slf(inner);
+unforkify!(body);
+codegen-prep!(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 7b8c5020de1688bc509a07ed16226c1d3714ecd4..0f473f0d3b96e56b796e6436109335adf5d5f038 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -134,6 +134,7 @@ impl FromStr for Appliable {
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
+            "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 71a185ba65b4cd5ade237cbfe164cce8067a475c..cf4b655859a08b177a1f2308df686025876b35f1 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -20,6 +20,7 @@ pub enum Pass {
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
+    ForkReshape,
     ForkSplit,
     ForkUnroll,
     Forkify,
@@ -57,6 +58,7 @@ impl Pass {
             Pass::ForkExtend => num == 1,
             Pass::ForkFissionBufferize => num == 2 || num == 1,
             Pass::ForkInterchange => num == 2,
+            Pass::ForkReshape => true,
             Pass::Print => num == 1,
             Pass::Rename => num == 1,
             Pass::SROA => num == 0 || num == 1,
@@ -73,6 +75,7 @@ impl Pass {
             Pass::ForkExtend => "1",
             Pass::ForkFissionBufferize => "1 or 2",
             Pass::ForkInterchange => "2",
+            Pass::ForkReshape => "any",
             Pass::Print => "1",
             Pass::Rename => "1",
             Pass::SROA => "0 or 1",
diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y
index 5903e42941cef535128765cb6ab98601987d6d84..69d50014a57865807bd9f82311a8852c21e648da 100644
--- a/juno_scheduler/src/lang.y
+++ b/juno_scheduler/src/lang.y
@@ -61,9 +61,9 @@ Expr -> Expr
   | Expr '@' 'ID'
       { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
   | '(' Exprs ')'
-      { exprs_to_expr($span, $2) }
+      { Expr::Tuple { span: $span, exps: $2 } }
   | '[' Exprs ']'
-      { exprs_to_expr($span, $2) }
+      { Expr::Tuple { span: $span, exps: $2 } }
   | '{' Schedule '}'
       { Expr::BlockExpr { span: $span, body: Box::new($2) } }
   | '<' Fields '>'
@@ -203,11 +203,3 @@ pub struct MacroDecl {
   pub selection_name: Span,
   pub def: Box<OperationList>,
 }
-
-fn exprs_to_expr(span: Span, mut exps: Vec<Expr>) -> Expr {
-  if exps.len() == 1 {
-    exps.pop().unwrap()
-  } else {
-    Expr::Tuple { span, exps }
-  }
-}
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index ef9ec038c12f10ffb2a025a7aa347a5ee5d02d9d..362597186adcaac2c8c0106c7550d49d52fa6ea6 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2901,6 +2901,220 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkReshape => {
+            let mut shape = vec![];
+            let mut loops = BTreeSet::new();
+            let mut fork_count = 0;
+
+            for arg in args {
+                let Value::Tuple { values } = arg else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "expected each argument to be a list of integers".to_string(),
+                    });
+                };
+
+                let mut indices = vec![];
+                for val in values {
+                    let Value::Integer { val: idx } = val else {
+                        return Err(SchedulerError::PassError {
+                            pass: "fork-reshape".to_string(),
+                            error: "expected each argument to be a list of integers".to_string(),
+                        });
+                    };
+                    indices.push(idx);
+                    loops.insert(idx);
+                    fork_count += 1;
+                }
+                shape.push(indices);
+            }
+
+            if loops != (0..fork_count).collect() {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(),
+                });
+            }
+
+            let Some((nodes, func_id)) = selection_as_set(pm, selection) else {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: "must be applied to nodes in a single function".to_string(),
+                });
+            };
+            let func = func_id.idx();
+
+            pm.make_def_uses();
+            pm.make_fork_join_maps();
+            pm.make_loops();
+            pm.make_reduce_cycles();
+
+            let def_uses = pm.def_uses.take().unwrap();
+            let mut fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+
+            let def_use = &def_uses[func];
+            let fork_join_map = &mut fork_join_maps[func];
+            let loops = &loops[func];
+            let reduce_cycles = &reduce_cycles[func];
+
+            let mut editor = FunctionEditor::new(
+                &mut pm.functions[func],
+                func_id,
+                &pm.constants,
+                &pm.dynamic_constants,
+                &pm.types,
+                &pm.labels,
+                def_use,
+            );
+
+            // There should be exactly one fork nest in the selection and it should contain
+            // exactly fork_count forks (counting each dimension of each fork)
+            // We determine the loops (ordered top-down) that are contained in the selection
+            // (in particular the header is in the selection) and its a fork-join (the header
+            // is a fork)
+            let mut loops = loops.bottom_up_loops().into_iter().rev()
+                .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork());
+            let Some((top_fork_head, top_fork_body)) = loops.next() else {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name),
+                });
+            };
+            // All the remaining forks need to be contained in the top fork body
+            let mut forks = vec![top_fork_head];
+            let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len();
+            for (head, _) in loops {
+                if !top_fork_body[head.idx()] {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "selection includes multiple non-nested forks".to_string(),
+                    });
+                } else {
+                    forks.push(head);
+                    num_dims += editor.node(head).try_fork().unwrap().1.len();
+                }
+            }
+
+            if num_dims != fork_count {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name),
+                });
+            }
+
+            // Now, we coalesce all of these forks into one so that we can interchange them
+            let mut forks = forks.into_iter();
+            let top_fork = forks.next().unwrap();
+            let mut cur_fork = top_fork;
+            for next_fork in forks {
+                let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "failed to coalesce forks".to_string(),
+                    });
+                };
+                cur_fork = new_fork;
+                fork_join_map.insert(new_fork, new_join);
+            }
+            let join = *fork_join_map.get(&cur_fork).unwrap();
+
+            // Now we have just one fork and we can perform the interchanges we need
+            // To do this, we track two maps: from original index to current index and from
+            // current index to original index
+            let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>();
+            let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>();
+
+            // Now, starting from the first (outermost) index we move the desired fork bound
+            // into place
+            for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() {
+                let cur_idx = orig_to_cur[*original_idx];
+                let swapping = cur_to_orig[idx];
+
+                // If the desired factor is already in the correct place, do nothing
+                if cur_idx == idx {
+                    continue;
+                }
+                assert!(idx < cur_idx);
+                let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "failed to interchange forks".to_string(),
+                    });
+                };
+                cur_fork = fork_res;
+
+                // Update our maps
+                orig_to_cur[*original_idx] = idx;
+                orig_to_cur[swapping] = cur_idx;
+                cur_to_orig[idx] = *original_idx;
+                cur_to_orig[cur_idx] = swapping;
+            }
+
+            // Finally we split the fork into the desired pieces. We do this by first splitting
+            // the fork into individual forks and then coalesce the chunks together
+            // Not sure how split_fork could fail, so if it does panic is fine
+            let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap();
+
+            for (fork, join) in forks.iter().zip(joins.iter()) {
+                fork_join_map.insert(*fork, *join);
+            }
+
+            // Finally coalesce the chunks together
+            let mut fork_idx = 0;
+            let mut final_forks = vec![];
+            for chunk in shape.iter() {
+                let chunk_len = chunk.len();
+
+                let mut cur_fork = forks[fork_idx];
+                for i in 1..chunk_len {
+                    let next_fork = forks[fork_idx + i];
+                    // Again, not sure at this point how coalesce could fail, so panic if it
+                    // does
+                    let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap();
+                    cur_fork = new_fork;
+                    fork_join_map.insert(new_fork, new_join);
+                }
+
+                fork_idx += chunk_len;
+                final_forks.push(cur_fork);
+            }
+
+            // Label each fork and return the labels
+            // We've trashed our analyses at this point, so rerun them so that we can determine the
+            // nodes in each of the result fork-joins
+            pm.clear_analyses();
+            pm.make_def_uses();
+            pm.make_nodes_in_fork_joins();
+            
+            let def_uses = pm.def_uses.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            let def_use = &def_uses[func];
+            let nodes_in_fork_joins = &nodes_in_fork_joins[func];
+            
+            let mut editor = FunctionEditor::new(
+                &mut pm.functions[func],
+                func_id,
+                &pm.constants,
+                &pm.dynamic_constants,
+                &pm.types,
+                &pm.labels,
+                def_use,
+            );
+
+            let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied()))
+                .into_iter()
+                .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] })
+                .collect();
+
+            result = Value::Tuple { values: labels };
+            changed = true;
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::WritePredication => {
             assert!(args.is_empty());
             for func in build_selection(pm, selection, false) {