diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 6ad38fc0e5e59f18771a9ee56819bd6269cd868a..552dc3a30ea9f03813838c6c6d279f8f9e941de2 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -334,7 +334,7 @@ impl<'a> CPUContext<'a> { } } else { let (_, offsets) = &self.backing_allocation[&Device::LLVM]; - let offset = offsets[&id]; + let offset = offsets[&id].0; write!( body, " {} = getelementptr i8, ptr %backing, i64 %dc{}\n", diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 76aba7e030492706a367bbb5f31107014010a129..5f2feedd28b09633e456efa4cec670d68f694fb1 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -1290,7 +1290,7 @@ namespace cg = cooperative_groups; if !is_primitive && state == KernelState::OutBlock { assert!(self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant), "PANIC: The CUDA backend cannot lower a global memory constant that has to be reset to zero. This is because we cannot efficiently implement a memset to the underlying memory of the constant due to the need for a grid level sync. Consider floating this collection outside the CUDA function and into an AsyncRust function, or attaching the NoResetConstant schedule to indicate that no memset is semantically necessary."); let (_, offsets) = &self.backing_allocation[&Device::CUDA]; - let offset = offsets[&id]; + let offset = offsets[&id].0; write!( w, "{}{} = backing + dc{};\n", diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 446231de1677a34e6e6e0be43b380b6737905f5e..9866400cc8adf9413c16ab5010b8ad3f515ec543 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -16,7 +16,7 @@ use std::collections::BTreeMap; use hercules_ir::*; -pub const LARGEST_ALIGNMENT: usize = 32; +pub const LARGEST_ALIGNMENT: usize = 64; /* * The alignment of a type does not depend on dynamic constants. @@ -53,10 +53,15 @@ pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>; /* * The allocation information of each function is a size of the backing memory * needed and offsets into that backing memory per constant object and call node - * in the function. + * in the function (as well as their individual sizes). */ -pub type FunctionBackingAllocation = - BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>; +pub type FunctionBackingAllocation = BTreeMap< + Device, + ( + DynamicConstantID, + BTreeMap<NodeID, (DynamicConstantID, DynamicConstantID)>, + ), +>; pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>; pub const BACKED_DEVICES: [Device; 2] = [Device::LLVM, Device::CUDA]; diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 884129c712863a49d535d23f46893d8ed246c56b..3db0f16f35c711372841ced574c72cc50ee007ec 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -80,7 +80,7 @@ pub fn rt_codegen<W: Write>( typing: &Vec<TypeID>, control_subgraph: &Subgraph, fork_join_map: &HashMap<NodeID, NodeID>, - fork_control_map: &HashMap<NodeID, HashSet<NodeID>>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, fork_tree: &HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, collection_objects: &CollectionObjects, @@ -103,7 +103,7 @@ pub fn rt_codegen<W: Write>( control_subgraph, fork_join_map, join_fork_map: &join_fork_map, - fork_control_map, + fork_join_nest, fork_tree, nodes_in_fork_joins, collection_objects, @@ -124,7 +124,7 @@ struct RTContext<'a> { control_subgraph: &'a Subgraph, fork_join_map: &'a HashMap<NodeID, NodeID>, join_fork_map: &'a HashMap<NodeID, NodeID>, - fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>, + fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>, fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, collection_objects: &'a CollectionObjects, @@ -529,7 +529,7 @@ impl<'a> RTContext<'a> { write!(block, "{} = ", self.get_value(id, bb, true))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { - Constant::Boolean(val) => write!(block, "{}bool", val)?, + Constant::Boolean(val) => write!(block, "{}", val)?, Constant::Integer8(val) => write!(block, "{}i8", val)?, Constant::Integer16(val) => write!(block, "{}i16", val)?, Constant::Integer32(val) => write!(block, "{}i32", val)?, @@ -559,7 +559,7 @@ impl<'a> RTContext<'a> { Constant::Product(ty, _) | Constant::Summation(ty, _, _) | Constant::Array(ty) => { - let (device, offset) = self.backing_allocations[&self.func_id] + let (device, (offset, _)) = self.backing_allocations[&self.func_id] .iter() .filter_map(|(device, (_, offsets))| { offsets.get(&id).map(|id| (*device, *id)) @@ -676,13 +676,28 @@ impl<'a> RTContext<'a> { prefix, self.module.functions[callee_id.idx()].name )?; - for (device, offset) in self.backing_allocations[&self.func_id] + for (device, (offset, size)) in self.backing_allocations[&self.func_id] .iter() .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id))) { - write!(block, "backing_{}.byte_add(", device.name())?; + write!(block, "backing_{}.byte_add(((", device.name())?; self.codegen_dynamic_constant(offset, block)?; - write!(block, " as usize), ")? + let forks = &self.fork_join_nest[&bb]; + if !forks.is_empty() { + write!(block, ") + ")?; + let mut linear_thread = "0".to_string(); + for fork in forks { + let factors = func.nodes[fork.idx()].try_fork().unwrap().1; + for (factor_idx, factor) in factors.into_iter().enumerate() { + linear_thread = format!("({} *", linear_thread); + self.codegen_dynamic_constant(*factor, &mut linear_thread)?; + write!(linear_thread, " + tid_{}_{})", fork.idx(), factor_idx)?; + } + } + write!(block, "{} * (", linear_thread)?; + self.codegen_dynamic_constant(size, block)?; + } + write!(block, ")) as usize), ")? } for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 60f4fb1cadcdea835975b292c25b7fac4bf3e35d..a8fc3c67e7476a01c6d5cdc512c45c6c8eda5db1 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -202,7 +202,7 @@ pub fn collection_objects( let fco = &collection_objects[&callee]; if fco.returned[*selection] .iter() - .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) + .any(|returned| fco.origins[returned.idx()].try_parameter().is_none()) { // If the callee may return a new collection object, then // this data projection node originates a single collection object. The @@ -412,6 +412,7 @@ pub fn collection_objects( } } + assert_eq!(objects_per_node.len(), func.nodes.len()); let fco = FunctionCollectionObjects { objects_per_node, mutated, diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index aff1f9c52eb8735db828620a15e27a6395b20f9f..f21137f8d5ea02347765b114d8c487ef5f3ceaaf 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -44,10 +44,17 @@ pub fn xdot_module( file.write_all(contents.as_bytes()) .expect("PANIC: Unable to write output file contents."); println!("Graphviz written to: {}", tmp_path.display()); - Command::new("xdot") + let mut xdot_process = Command::new("xdot") .args([&tmp_path]) - .output() + .spawn() .expect("PANIC: Couldn't execute xdot. Is xdot installed?"); + assert!( + xdot_process + .wait() + .map(|status| status.success()) + .unwrap_or(false), + "PANIC: Xdot failed to execute." + ) } /* diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e635b3c00d7bfa0090376d8056e65d8d01e60ce2..ff0f0283767996914e8f9b2274ed9a6d538b1812 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -695,7 +695,8 @@ pub fn fork_coalesce_helper( editor.edit(|mut edit| { let new_tid = edit.add_node(new_tid); - let edit = edit.replace_all_uses(tid, new_tid)?; + let mut edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); Ok(edit) }); } @@ -915,7 +916,9 @@ pub fn chunk_all_forks_unguarded( }; for (fork, _) in fork_join_map { - chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order); + if editor.is_mutable(*fork) { + chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order); + } } } // Splits a dimension of a single fork join into multiple. @@ -1598,3 +1601,130 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { } } } + +/* + * Extends the dimensions of a fork-join to be a multiple of a number and gates + * the execution of the body. + */ +pub fn extend_all_forks( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + multiple: usize, +) { + for (fork, join) in fork_join_map { + if editor.is_mutable(*fork) { + extend_fork(editor, *fork, *join, multiple); + } + } +} + +fn extend_fork(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, multiple: usize) { + let nodes = &editor.func().nodes; + let (fork_pred, factors) = nodes[fork.idx()].try_fork().unwrap(); + let factors = factors.to_vec(); + let fork_succ = editor + .get_users(fork) + .filter(|id| nodes[id.idx()].is_control()) + .next() + .unwrap(); + let join_pred = nodes[join.idx()].try_join().unwrap(); + let ctrl_between = fork != join_pred; + let reduces: Vec<_> = editor + .get_users(join) + .filter_map(|id| nodes[id.idx()].try_reduce().map(|x| (id, x))) + .collect(); + + editor.edit(|mut edit| { + // We can round up a dynamic constant A to a multiple of another dynamic + // constant B via the following math: + // ((A + B - 1) / B) * B + let new_factors: Vec<_> = factors + .iter() + .map(|factor| { + let b = edit.add_dynamic_constant(DynamicConstant::Constant(multiple)); + let apb = edit.add_dynamic_constant(DynamicConstant::add(*factor, b)); + let o = edit.add_dynamic_constant(DynamicConstant::Constant(1)); + let apbmo = edit.add_dynamic_constant(DynamicConstant::sub(apb, o)); + let apbmodb = edit.add_dynamic_constant(DynamicConstant::div(apbmo, b)); + edit.add_dynamic_constant(DynamicConstant::mul(apbmodb, b)) + }) + .collect(); + + // Create the new control structure. + let new_fork = edit.add_node(Node::Fork { + control: fork_pred, + factors: new_factors.into_boxed_slice(), + }); + edit = edit.replace_all_uses_where(fork, new_fork, |id| *id != fork_succ)?; + edit.sub_edit(fork, new_fork); + let conds: Vec<_> = factors + .iter() + .enumerate() + .map(|(idx, old_factor)| { + let tid = edit.add_node(Node::ThreadID { + control: new_fork, + dimension: idx, + }); + let old_bound = edit.add_node(Node::DynamicConstant { id: *old_factor }); + edit.add_node(Node::Binary { + op: BinaryOperator::LT, + left: tid, + right: old_bound, + }) + }) + .collect(); + let cond = conds + .into_iter() + .reduce(|left, right| { + edit.add_node(Node::Binary { + op: BinaryOperator::And, + left, + right, + }) + }) + .unwrap(); + let branch = edit.add_node(Node::If { + control: new_fork, + cond, + }); + let false_proj = edit.add_node(Node::ControlProjection { + control: branch, + selection: 0, + }); + let true_proj = edit.add_node(Node::ControlProjection { + control: branch, + selection: 1, + }); + if ctrl_between { + edit = edit.replace_all_uses_where(fork, true_proj, |id| *id == fork_succ)?; + } + let bottom_region = edit.add_node(Node::Region { + preds: Box::new([false_proj, if ctrl_between { join_pred } else { true_proj }]), + }); + let new_join = edit.add_node(Node::Join { + control: bottom_region, + }); + edit = edit.replace_all_uses(join, new_join)?; + edit.sub_edit(join, new_join); + edit = edit.delete_node(fork)?; + edit = edit.delete_node(join)?; + + // Update the reduces to use phis on the region node to gate their execution. + for (reduce, (_, init, reduct)) in reduces { + let phi = edit.add_node(Node::Phi { + control: bottom_region, + data: Box::new([reduce, reduct]), + }); + let new_reduce = edit.add_node(Node::Reduce { + control: new_join, + init, + reduct: phi, + }); + edit = edit.replace_all_uses(reduce, new_reduce)?; + edit.sub_edit(reduce, new_reduce); + edit = edit.delete_node(reduce)?; + } + + Ok(edit) + }); +} diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index d311970517260466d8712c430c20ab38fb4aec2b..c612acac1e114fcb3a73edcf0468d2c6f7c84acd 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -82,6 +82,7 @@ pub fn gcm( control_subgraph: &Subgraph, dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, loops: &LoopTree, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, @@ -120,6 +121,18 @@ pub fn gcm( return None; } + if add_extra_collection_dims( + editor, + typing, + fork_join_map, + fork_join_nest, + objects, + devices, + &bbs, + ) { + return None; + } + let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else { return None; }; @@ -139,6 +152,7 @@ pub fn gcm( let backing_allocation = object_allocation( editor, typing, + fork_join_nest, &node_colors, &alignments, &liveness, @@ -1027,6 +1041,126 @@ fn spill_clones( } } +/* + * Look for mutated collections placed inside fork-joins in AsyncRust functions. + * These collections should be duplicated across the size of the fork-join. + */ +fn add_extra_collection_dims( + editor: &mut FunctionEditor, + typing: &Vec<TypeID>, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + objects: &CollectionObjects, + devices: &Vec<Device>, + bbs: &BasicBlocks, +) -> bool { + if devices[editor.func_id().idx()] == Device::AsyncRust { + // Look for collection constant nodes inside fork-joins that are mutated + // inside the fork-join, aren't involved in any of the reduces of the + // fork-join, and have a user that isn't a direct read based on all of + // the thread IDs. + let fco = &objects[&editor.func_id()]; + let candidates: Vec<_> = editor + .node_ids() + .filter(|id| { + editor.func().nodes[id.idx()].is_constant() + && !editor.get_type(typing[id.idx()]).is_primitive() + }) + .collect(); + for id in candidates { + // Check all of the above conditions. + let nodes = &editor.func().nodes; + if editor.get_users(id).len() != 1 { + continue; + } + let forks = &fork_join_nest[&bbs.0[id.idx()]]; + if forks.is_empty() { + continue; + } + let object = fco.objects(id)[0]; + let mutated_inside = fco + .mutators(object) + .into_iter() + .any(|id| &fork_join_nest[&bbs.0[id.idx()]] == forks); + if !mutated_inside { + continue; + } + let in_reduce = forks.into_iter().any(|id| { + let join = fork_join_map[id]; + let mut reduces = editor + .get_users(join) + .filter(|id| nodes[id.idx()].is_reduce()); + reduces.any(|id| fco.objects(id).contains(&object)) + }); + if in_reduce { + continue; + } + if let Node::Read { + collect: _, + ref indices, + } = nodes[editor.get_users(id).next().unwrap().idx()] + && let Index::Position(ref pos) = indices[0] + && { + let tid_pos: BTreeSet<(NodeID, usize)> = pos + .into_iter() + .filter_map(|id| nodes[id.idx()].try_thread_id()) + .collect(); + let reference: BTreeSet<(NodeID, usize)> = forks + .into_iter() + .flat_map(|id| { + (0..nodes[id.idx()].try_fork().unwrap().1.len()).map(|dim| (*id, dim)) + }) + .collect(); + tid_pos == reference + } + { + continue; + } + + // We know that this collection needs to be replicated across the + // fork-join dimensions, so do that. + let ty = typing[id.idx()]; + let num_dims: Vec<_> = forks + .into_iter() + .rev() + .map(|id| nodes[id.idx()].try_fork().unwrap().1.len()) + .collect(); + let factors = forks + .into_iter() + .rev() + .flat_map(|id| nodes[id.idx()].try_fork().unwrap().1.into_iter()) + .map(|dc| *dc) + .collect(); + let array_ty = Type::Array(ty, factors); + let success = editor.edit(|mut edit| { + let new_ty = edit.add_type(array_ty); + let new_cons = edit.add_zero_constant(new_ty); + let new_cons = edit.add_node(Node::Constant { id: new_cons }); + let mut tids = vec![]; + for (fork, num_dims) in forks.into_iter().rev().zip(num_dims) { + for dim in 0..num_dims { + tids.push(edit.add_node(Node::ThreadID { + control: *fork, + dimension: dim, + })); + } + } + let read = edit.add_node(Node::Read { + collect: new_cons, + indices: Box::new([Index::Position(tids.into_boxed_slice())]), + }); + edit.sub_edit(id, new_cons); + edit = edit.replace_all_uses(id, read)?; + edit = edit.delete_node(id)?; + Ok(edit) + }); + assert!(success); + return true; + } + } + false +} + type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>; /* @@ -1507,12 +1641,13 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> fn object_allocation( editor: &mut FunctionEditor, typing: &Vec<TypeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, node_colors: &FunctionNodeColors, alignments: &Vec<usize>, _liveness: &Liveness, backing_allocations: &BackingAllocations, ) -> FunctionBackingAllocation { - let mut fba = BTreeMap::new(); + let mut fba = FunctionBackingAllocation::new(); let node_ids = editor.node_ids(); editor.edit(|mut edit| { @@ -1526,13 +1661,13 @@ fn object_allocation( let (total, offsets) = fba.entry(device).or_insert_with(|| (zero, BTreeMap::new())); *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]); - offsets.insert(id, *total); let type_size = type_size(&mut edit, typing[id.idx()], alignments); + offsets.insert(id, (*total, type_size)); *total = edit.add_dynamic_constant(DynamicConstant::add(*total, type_size)); } } Node::Call { - control: _, + control, function: callee, ref dynamic_constants, args: _, @@ -1554,7 +1689,6 @@ fn object_allocation( // We don't know the alignment requirement of the memory // in the callee, so just assume the largest alignment. *total = align(&mut edit, *total, LARGEST_ALIGNMENT); - offsets.insert(id, *total); // Substitute the dynamic constant parameters in the // callee's backing size. callee_backing_size = substitute_dynamic_constants( @@ -1562,9 +1696,26 @@ fn object_allocation( callee_backing_size, &mut edit, ); + offsets.insert(id, (*total, callee_backing_size)); + // Multiply the backing allocation size of the + // callee by the number of parallel threads that + // will call the function. + let forks = &fork_join_nest[&control]; + let factors: Vec<_> = forks + .into_iter() + .rev() + .flat_map(|id| edit.get_node(*id).try_fork().unwrap().1.into_iter()) + .map(|dc| *dc) + .collect(); + let mut multiplied_callee_backing_size = callee_backing_size; + for factor in factors { + multiplied_callee_backing_size = edit.add_dynamic_constant( + DynamicConstant::mul(multiplied_callee_backing_size, factor), + ); + } *total = edit.add_dynamic_constant(DynamicConstant::add( *total, - callee_backing_size, + multiplied_callee_backing_size, )); } } diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index d19a0a5a16e4438e8746ca312428759ea28ed556..a5954ca03c32441709a2ce5452e9b2584f9c6155 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -13,7 +13,7 @@ use std::sync::OnceLock; * src/rt.rs (the RT backend). */ -pub const LARGEST_ALIGNMENT: usize = 32; +pub const LARGEST_ALIGNMENT: usize = 64; pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 { let ptr = alloc(Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap()); @@ -150,7 +150,7 @@ extern "C" { fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8); } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Copy)] pub struct HerculesCPURef<'a> { ptr: NonNull<u8>, size: usize, @@ -165,7 +165,7 @@ pub struct HerculesCPURefMut<'a> { } #[cfg(feature = "cuda")] -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Copy)] pub struct HerculesCUDARef<'a> { ptr: NonNull<u8>, size: usize, diff --git a/juno_samples/cava/benches/cava_bench.rs b/juno_samples/cava/benches/cava_bench.rs index b8dd3ce26901e94872a5d1b28f1494c5bdfb8a77..41b34230f30af5739f0487eb3d38c9d51f384458 100644 --- a/juno_samples/cava/benches/cava_bench.rs +++ b/juno_samples/cava/benches/cava_bench.rs @@ -25,6 +25,12 @@ fn cava_bench(c: &mut Criterion) { let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) = prepare_hercules_inputs(&raw_image, &cam_model); let mut r = runner!(cava); + let image = image.to(); + let tstw = tstw.to(); + let ctrl_pts = ctrl_pts.to(); + let weights = weights.to(); + let coefs = coefs.to(); + let tonemap = tonemap.to(); group.bench_function("cava bench small", |b| { b.iter(|| { @@ -32,12 +38,12 @@ fn cava_bench(c: &mut Criterion) { rows as u64, cols as u64, num_ctrl_pts as u64, - image.to(), - tstw.to(), - ctrl_pts.to(), - weights.to(), - coefs.to(), - tonemap.to(), + image, + tstw, + ctrl_pts, + weights, + coefs, + tonemap, )); }) }); @@ -55,6 +61,12 @@ fn cava_bench(c: &mut Criterion) { let (rows, cols, num_ctrl_pts, image, tstw, ctrl_pts, weights, coefs, tonemap) = prepare_hercules_inputs(&raw_image, &cam_model); let mut r = runner!(cava); + let image = image.to(); + let tstw = tstw.to(); + let ctrl_pts = ctrl_pts.to(); + let weights = weights.to(); + let coefs = coefs.to(); + let tonemap = tonemap.to(); group.bench_function("cava bench full", |b| { b.iter(|| { @@ -62,12 +74,12 @@ fn cava_bench(c: &mut Criterion) { rows as u64, cols as u64, num_ctrl_pts as u64, - image.to(), - tstw.to(), - ctrl_pts.to(), - weights.to(), - coefs.to(), - tonemap.to(), + image, + tstw, + ctrl_pts, + weights, + coefs, + tonemap, )); }) }); diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index dbe799f9f23e63ed40157c9d57f5c4c8d9b4eb23..4d02b2cdb5f843774ebae9b873388e03c82071b0 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -142,7 +142,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( ) -> f32[CHAN, row, col] { @res let result : f32[CHAN, row, col]; - for r = 0 to row { + @image_loop for r = 0 to row { for c = 0 to col { @l2 let l2_dist : f32[num_ctrl_pts]; for cp = 0 to num_ctrl_pts { diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 3ac2f326115bb4aafb80c6a2d4b3cd024096db8c..8f22b37d4c3a77a31e8d2467c8e6e130c281513f 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -49,6 +49,7 @@ simpl!(fuse1); write-predication(fuse1); simpl!(fuse1); parallel-reduce(fuse1@loop); +fork-extend[8](fuse1); inline(fuse2); no-memset(fuse2@res); @@ -112,6 +113,14 @@ fixpoint { simpl!(fuse4); array-slf(fuse4); simpl!(fuse4); +let par = fuse4@image_loop \ fuse4@channel_loop; +fork-tile[4, 1, false, false](par); +fork-tile[4, 0, false, false](par); +fork-interchange[1, 2](par); +let split = fork-split(par); +let fuse4_body = outline(split.cava_3.fj2); +fork-coalesce(fuse4, fuse4_body); +simpl!(fuse4, fuse4_body); no-memset(fuse5@res1); no-memset(fuse5@res2); @@ -127,8 +136,8 @@ simpl!(fuse5); delete-uncalled(*); simpl!(*); -fork-split(fuse1, fuse2, fuse3, fuse4, fuse5); -unforkify(fuse1, fuse2, fuse3, fuse4, fuse5); +fork-split(fuse1, fuse2, fuse3, fuse4_body, fuse5); +unforkify(fuse1, fuse2, fuse3, fuse4_body, fuse5); simpl!(*); diff --git a/juno_samples/edge_detection/Cargo.toml b/juno_samples/edge_detection/Cargo.toml index 483724d8e4c2c7bcd057990ce5e149923e90cc3b..fa4ca1ff09ea0ebc404998ad6119e886d1d95b4c 100644 --- a/juno_samples/edge_detection/Cargo.toml +++ b/juno_samples/edge_detection/Cargo.toml @@ -33,3 +33,4 @@ criterion = { version = "0.5", features = ["html_reports"] } [[bench]] name = "edge_detection_bench" harness = false +required-features = ["opencv"] diff --git a/juno_samples/edge_detection/benches/edge_detection_bench.rs b/juno_samples/edge_detection/benches/edge_detection_bench.rs index 806a886510666f471335227b201daa74c000f4e1..760352754ccd39df878fa45523c5b7e76dd84f98 100644 --- a/juno_samples/edge_detection/benches/edge_detection_bench.rs +++ b/juno_samples/edge_detection/benches/edge_detection_bench.rs @@ -63,6 +63,10 @@ fn edge_detection_bench(c: &mut Criterion) { let num_frames = 5; let mut r = runner!(edge_detection); + let gaussian_filter_h = gaussian_filter_h.to(); + let structure_h = structure_h.to(); + let sx_h = sx_h.to(); + let sy_h = sy_h.to(); let frames: Vec<_> = (0..num_frames).map(|_| load_frame(&mut video)).collect(); @@ -87,10 +91,10 @@ fn edge_detection_bench(c: &mut Criterion) { sz as u64, sb as u64, input_h.to(), - gaussian_filter_h.to(), - structure_h.to(), - sx_h.to(), - sy_h.to(), + gaussian_filter_h, + structure_h, + sx_h, + sy_h, theta, ) .await diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch index ead722ce0adfcf61ccc6dc79c70e1ba76d0e8eeb..4bd3254b1773a804e90154eec8385c8237b62fbd 100644 --- a/juno_samples/edge_detection/src/cpu.sch +++ b/juno_samples/edge_detection/src/cpu.sch @@ -24,6 +24,14 @@ predication(gaussian_smoothing); simpl!(gaussian_smoothing); predication(gaussian_smoothing); simpl!(gaussian_smoothing); +let par = gaussian_smoothing@image_loop \ gaussian_smoothing@filter_loop; +fork-tile[4, 1, false, false](par); +fork-tile[4, 0, false, false](par); +fork-interchange[1, 2](par); +let split = fork-split(par); +let gaussian_smoothing_body = outline(split._0_gaussian_smoothing.fj2); +fork-coalesce(gaussian_smoothing, gaussian_smoothing_body); +simpl!(gaussian_smoothing, gaussian_smoothing_body); no-memset(laplacian_estimate@res, laplacian_estimate@shr1, laplacian_estimate@shr2); fixpoint { @@ -32,6 +40,15 @@ fixpoint { fork-coalesce(laplacian_estimate); } simpl!(laplacian_estimate); +let par = laplacian_estimate@image_loop \ laplacian_estimate@filter_loop; +fork-tile[4, 1, false, false](par); +fork-tile[4, 0, false, false](par); +fork-interchange[1, 2](par); +let split = fork-split(par); +let body = split._1_laplacian_estimate.fj2 | laplacian_estimate.shr1 | laplacian_estimate.shr2; +let laplacian_estimate_body = outline(body); +fork-coalesce(laplacian_estimate, laplacian_estimate_body); +simpl!(laplacian_estimate, laplacian_estimate_body); no-memset(zero_crossings@res, zero_crossings@shr1, zero_crossings@shr2); fixpoint { @@ -40,6 +57,15 @@ fixpoint { fork-coalesce(zero_crossings); } simpl!(zero_crossings); +let par = zero_crossings@image_loop \ zero_crossings@filter_loop; +fork-tile[4, 1, false, false](par); +fork-tile[4, 0, false, false](par); +fork-interchange[1, 2](par); +let split = fork-split(par); +let body = split._2_zero_crossings.fj2 | zero_crossings.shr1 | zero_crossings.shr2; +let zero_crossings_body = outline(body); +fork-coalesce(zero_crossings, zero_crossings_body); +simpl!(zero_crossings, zero_crossings_body); no-memset(gradient@res); fixpoint { @@ -81,8 +107,8 @@ simpl!(reject_zero_crossings); async-call(edge_detection@le, edge_detection@zc); -fork-split(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings); -unforkify(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings); +fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings); +unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings); simpl!(*); diff --git a/juno_samples/edge_detection/src/edge_detection.jn b/juno_samples/edge_detection/src/edge_detection.jn index e1413488e95d324e154ac478c2131db666fcbbf8..58f364dc1cea77b27c83ba2340bfbd11fde07f31 100644 --- a/juno_samples/edge_detection/src/edge_detection.jn +++ b/juno_samples/edge_detection/src/edge_detection.jn @@ -7,11 +7,11 @@ fn gaussian_smoothing<n, m, gs : usize>( // Define the gaussian radius as half the gaussian size const gr = gs / 2; - for row = 0 to n { + @image_loop for row = 0 to n { for col = 0 to m { let smoothed = 0.0; - for i = 0 to gs { + @filter_loop for i = 0 to gs { for j = 0 to gs { let val = input[if row + i < gr then 0 else if row + i - gr > n - 1 then n - 1 @@ -41,23 +41,23 @@ fn laplacian_estimate<n, m, sz: usize>( @res let result : f32[n, m]; - for row = 0 to n { + @image_loop for row = 0 to n { for col = 0 to m { // Copy data for dilation filter @shr1 let imageArea : f32[sz, sz]; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { - imageArea[i, j] = if row + i < r then MIN_BR - else if row + i - r > n - 1 then MIN_BR - else if col + j < r then MIN_BR - else if col + j - r > m - 1 then MIN_BR - else input[row + i - r, col + j - r]; + imageArea[i, j] = if row + i < r + || row + i - r > n - 1 + || col + j < r + || col + j - r > m - 1 then MIN_BR + else input[row + i - r, col + j - r]; } } // Compute pixel of dilated image let dilated = MIN_BR; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { dilated = max!(dilated, imageArea[i, j] * structure[i, j]); } @@ -65,19 +65,19 @@ fn laplacian_estimate<n, m, sz: usize>( // Data copy for erotion filter @shr2 let imageArea : f32[sz, sz]; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { - imageArea[i, j] = if row + i < r then MAX_BR - else if row + i - r > n - 1 then MAX_BR - else if col + j < r then MAX_BR - else if col + j - r > m - 1 then MAX_BR - else input[row + i - r, col + j - r]; + imageArea[i, j] = if row + i < r + || row + i - r > n - 1 + || col + j < r + || col + j - r > m - 1 then MAX_BR + else input[row + i - r, col + j - r]; } } // Compute pixel of eroded image let eroded = MAX_BR; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { eroded = min!(eroded, imageArea[i, j] * structure[i, j]); } @@ -99,25 +99,24 @@ fn zero_crossings<n, m, sz: usize>( @res let result : f32[n, m]; - for row = 0 to n { + @image_loop for row = 0 to n { for col = 0 to m { // Data copy for dilation filter @shr1 let imageArea : f32[sz, sz]; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { - imageArea[i, j] = if row + i < r then MIN_BR - else if row + i - r > n - 1 then MIN_BR - else if col + j < r then MIN_BR - else if col + j - r > m - 1 then MIN_BR - else if input[row + i - r, col + j - r] > MIN_BR - then MAX_BR + imageArea[i, j] = if row + i < r + || row + i - r > n - 1 + || col + j < r + || col + j - r > m - 1 then MIN_BR + else if input[row + i - r, col + j - r] > MIN_BR then MAX_BR else MIN_BR; } } // Compute the pixel of dilated image let dilated = MIN_BR; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { dilated = max!(dilated, imageArea[i, j] * structure[i, j]); } @@ -125,21 +124,20 @@ fn zero_crossings<n, m, sz: usize>( // Data copy for erotion filter @shr2 let imageArea : f32[sz, sz]; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { - imageArea[i, j] = if row + i < r then MAX_BR - else if row + i - r > n - 1 then MAX_BR - else if col + j < r then MAX_BR - else if col + j - r > m - 1 then MAX_BR - else if input[row + i - r, col + j - r] > MIN_BR - then MAX_BR + imageArea[i, j] = if row + i < r + || row + i - r > n - 1 + || col + j < r + || col + j - r > m - 1 then MAX_BR + else if input[row + i - r, col + j - r] > MIN_BR then MAX_BR else MIN_BR; } } // Compute the pixel of eroded image let eroded = MAX_BR; - for i = 0 to sz { + @filter_loop for i = 0 to sz { for j = 0 to sz { eroded = min!(eroded, imageArea[i, j] * structure[i, j]); } diff --git a/juno_samples/edge_detection/src/lib.rs b/juno_samples/edge_detection/src/lib.rs index dab84cf6206c3cb9b816c88c53c8ddfbec102994..aa44e2e7e7215ee21560fefefc8c748610e183ce 100644 --- a/juno_samples/edge_detection/src/lib.rs +++ b/juno_samples/edge_detection/src/lib.rs @@ -234,7 +234,14 @@ pub fn edge_detection_harness(args: EdgeDetectionInputs) { theta, ); - assert_eq!(result, rust_result); + let mut all = true; + for idx in 0..rust_result.len() { + if result[idx] != rust_result[idx] { + all = false; + println!("Found mismatch in images at {}.", idx); + } + } + assert!(all); println!("Frames {} match", i); if display_verify { diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 76dcbdf68ce0ac88f8a2a22bda364f60a88e78bb..f46c91d6a84a08b2258332af1dc5d6a662d86639 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -3,7 +3,7 @@ gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); cpu(auto.test1); cpu(auto.test2); cpu(auto.test3); @@ -11,6 +11,7 @@ cpu(auto.test4); cpu(auto.test5); cpu(auto.test7); cpu(auto.test8); +cpu(auto.test9); let test1_cpu = auto.test1; rename["test1_cpu"](test1_cpu); @@ -51,11 +52,11 @@ fixpoint panic after 20 { unroll(auto.test1); } -fork-split(auto.test2, auto.test3, auto.test4, auto.test5); +fork-split(auto.test2, auto.test3, auto.test4, auto.test5, auto.test9); gvn(*); phi-elim(*); dce(*); -unforkify(auto.test2, auto.test3, auto.test4, auto.test5); +unforkify(auto.test2, auto.test3, auto.test4, auto.test5, auto.test9); ccp(*); gvn(*); phi-elim(*); @@ -93,4 +94,6 @@ dce(auto.test8); simplify-cfg(auto.test8); dce(auto.test8); +no-memset(test9@const); + gcm(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index bfb5564be29e444d065c3caaa40afb760d678aa7..334fc2bfe4f745cec9004fdf9ebdf80d11818c0f 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -122,3 +122,28 @@ fn test8(input : i32) -> i32[8] { } return out; } + +#[entry] +fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] { + const z = 3; + const rad = z / 2; + @const let out : i32[r, c]; + + for ir = 0 to r { + for ic = 0 to c { + let acc = 0; + @filter_loop for sr = 0 to z { + for sc = 0 to z { + acc += if ir + sr < rad then 0 + else if ir + sr - rad > r - 1 then 0 + else if ic + sc < rad then 0 + else if ic + sc - rad > c - 1 then 0 + else input[ir + sr - rad, ic + sc - rad]; + } + } + out[ir, ic] = acc; + } + } + + return out; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 364673cd2aedf323ca7727dc1ea8515188373a8d..81dc8d9854776931f4598a9010837008796baaf8 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -7,12 +7,13 @@ no-memset(test3@const3); no-memset(test6@const); no-memset(test8@const1); no-memset(test8@const2); +no-memset(test9@const); gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); gpu(auto.test1); gpu(auto.test2); gpu(auto.test3); @@ -20,6 +21,7 @@ gpu(auto.test4); gpu(auto.test5); gpu(auto.test7); gpu(auto.test8); +gpu(auto.test9); ip-sroa(*); sroa(*); @@ -34,7 +36,7 @@ fixpoint panic after 20 { } fixpoint panic after 20 { - fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8); + fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8, auto.test9); } gvn(*); @@ -71,6 +73,8 @@ dce(auto.test8); simplify-cfg(auto.test8); dce(auto.test8); +no-memset(test9@const); + ip-sroa(*); sroa(*); dce(*); @@ -78,7 +82,7 @@ ccp(*); gvn(*); phi-elim(*); dce(*); -gcm(*); +gcm(*); float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index cd715cace80ba3e58d0971b7956ec1e957b839e2..e66309b22b0650feaab315829dd412f6275e9a99 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -1,6 +1,6 @@ #![feature(concat_idents)] -use hercules_rt::runner; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; juno_build::juno!("fork_join_tests"); @@ -57,6 +57,23 @@ fn main() { let output = r.run(0).await; let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; assert(&correct, output); + + let mut r = runner!(test9); + let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + let input = HerculesImmBox::from(&input as &[i32]); + let output = r.run(3, 3, input.to()).await; + let correct = vec![ + 1 + 2 + 4 + 5, + 1 + 2 + 3 + 4 + 5 + 6, + 2 + 3 + 5 + 6, + 1 + 2 + 4 + 5 + 7 + 8, + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9, + 2 + 3 + 5 + 6 + 8 + 9, + 4 + 5 + 7 + 8, + 4 + 5 + 6 + 7 + 8 + 9, + 5 + 6 + 8 + 9, + ]; + assert(&correct, output); }); } diff --git a/juno_samples/rodinia/cfd/src/main.rs b/juno_samples/rodinia/cfd/src/main.rs index 1ce6b89ae370eede1e0ca88c20aeff0f1261b66f..fab241fa6ead74d7ab52c774a6a6b0c6984426cc 100644 --- a/juno_samples/rodinia/cfd/src/main.rs +++ b/juno_samples/rodinia/cfd/src/main.rs @@ -151,7 +151,7 @@ fn cfd_harness(args: CFDInputs) { pre_euler, } = args; - assert!(block_size % 8 == 0, "Hercules expects all arrays to be 32-byte aligned, cfd uses structs of arrays that are annoying to deal with if the block_size is not a multiple of 8"); + assert!(block_size % 16 == 0, "Hercules expects all arrays to be 64-byte aligned, cfd uses structs of arrays that are annoying to deal with if the block_size is not a multiple of 16"); let FarFieldConditions { ff_variable, @@ -245,7 +245,7 @@ fn test_euler() { cfd_harness(CFDInputs { data_file: "data/fvcorr.domn.097K".to_string(), iterations: 1, - block_size: 8, + block_size: 16, pre_euler: false, }); } @@ -255,7 +255,7 @@ fn test_pre_euler() { cfd_harness(CFDInputs { data_file: "data/fvcorr.domn.097K".to_string(), iterations: 1, - block_size: 8, + block_size: 16, pre_euler: true, }); } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 13990ef9df632e2a01914d2690152a3d1462e739..3c288ca7a61c855b3982ba6ed66c215c9e2942fe 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -131,6 +131,7 @@ impl FromStr for Appliable { "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), + "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index bbecc6ff190af7077f646c525291f5da7655f0fa..3a087c0d40093c6363e67332c1bd489f22727a42 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -15,6 +15,7 @@ pub enum Pass { ForkChunk, ForkCoalesce, ForkDimMerge, + ForkExtend, ForkFissionBufferize, ForkFusion, ForkGuardElim, @@ -53,6 +54,7 @@ impl Pass { match self { Pass::ArrayToProduct => num == 0 || num == 1, Pass::ForkChunk => num == 4, + Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, Pass::ForkInterchange => num == 2, Pass::Print => num == 1, @@ -68,6 +70,7 @@ impl Pass { match self { Pass::ArrayToProduct => "0 or 1", Pass::ForkChunk => "4", + Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", Pass::ForkInterchange => "2", Pass::Print => "1", diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d5e280b442de3a6571bdfa960cef652a4dac117d..5f2fa4cce02e04ca4499cbff8806356f81cc86bc 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -900,6 +900,7 @@ impl PassManager { self.make_typing(); self.make_control_subgraphs(); self.make_fork_join_maps(); + self.make_fork_join_nests(); self.make_fork_control_maps(); self.make_fork_trees(); self.make_nodes_in_fork_joins(); @@ -917,6 +918,7 @@ impl PassManager { typing: Some(typing), control_subgraphs: Some(control_subgraphs), fork_join_maps: Some(fork_join_maps), + fork_join_nests: Some(fork_join_nests), fork_control_maps: Some(fork_control_maps), fork_trees: Some(fork_trees), nodes_in_fork_joins: Some(nodes_in_fork_joins), @@ -990,7 +992,7 @@ impl PassManager { &typing[idx], &control_subgraphs[idx], &fork_join_maps[idx], - &fork_control_maps[idx], + &fork_join_nests[idx], &fork_trees[idx], &nodes_in_fork_joins[idx], &collection_objects, @@ -2105,6 +2107,7 @@ fn run_pass( pm.make_control_subgraphs(); pm.make_doms(); pm.make_fork_join_maps(); + pm.make_fork_join_nests(); pm.make_loops(); pm.make_reduce_cycles(); pm.make_collection_objects(); @@ -2115,6 +2118,7 @@ fn run_pass( let typing = pm.typing.take().unwrap(); let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let fork_join_nests = pm.fork_join_nests.take().unwrap(); let loops = pm.loops.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); @@ -2136,6 +2140,7 @@ fn run_pass( &control_subgraphs[id.idx()], &doms[id.idx()], &fork_join_maps[id.idx()], + &fork_join_nests[id.idx()], &loops[id.idx()], &reduce_cycles[id.idx()], &collection_objects, @@ -2642,6 +2647,30 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkExtend => { + assert_eq!(args.len(), 1); + let Some(Value::Integer { val: multiple }) = args.get(0) else { + return Err(SchedulerError::PassError { + pass: "forkExtend".to_string(), + error: "expected integer argument".to_string(), + }); + }; + + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + extend_all_forks(&mut func, fork_join_map, *multiple); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkFissionBufferize => { assert!(args.len() == 1 || args.len() == 2); let Some(Value::Label {