From 52c0eb8a37c28f91e9ad6c431bbfdf9242631256 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Sun, 2 Feb 2025 16:54:06 -0600 Subject: [PATCH] More fork tests, bug fixes in fork passes and GPU backend --- Cargo.lock | 10 + Cargo.toml | 1 + hercules_cg/src/gpu.rs | 85 +++-- hercules_ir/src/fork_join_analysis.rs | 34 +- hercules_ir/src/ir.rs | 2 +- hercules_ir/src/loops.rs | 4 + hercules_opt/src/float_collections.rs | 27 +- hercules_opt/src/fork_concat_split.rs | 141 -------- hercules_opt/src/fork_transforms.rs | 155 ++++++++ hercules_opt/src/gcm.rs | 77 +++- hercules_opt/src/lib.rs | 2 - hercules_opt/src/unforkify.rs | 335 +++++++++--------- juno_samples/fork_join_tests/Cargo.toml | 21 ++ juno_samples/fork_join_tests/build.rs | 24 ++ juno_samples/fork_join_tests/src/cpu.sch | 42 +++ .../fork_join_tests/src/fork_join_tests.jn | 61 ++++ juno_samples/fork_join_tests/src/gpu.sch | 33 ++ juno_samples/fork_join_tests/src/main.rs | 46 +++ juno_scheduler/src/pm.rs | 26 +- 19 files changed, 752 insertions(+), 374 deletions(-) delete mode 100644 hercules_opt/src/fork_concat_split.rs create mode 100644 juno_samples/fork_join_tests/Cargo.toml create mode 100644 juno_samples/fork_join_tests/build.rs create mode 100644 juno_samples/fork_join_tests/src/cpu.sch create mode 100644 juno_samples/fork_join_tests/src/fork_join_tests.jn create mode 100644 juno_samples/fork_join_tests/src/gpu.sch create mode 100644 juno_samples/fork_join_tests/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index b8bf2278..af7902c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1130,6 +1130,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_fork_join_tests" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_frontend" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f7b9322a..890d7924 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,4 +31,5 @@ members = [ "juno_samples/concat", "juno_samples/schedule_test", "juno_samples/edge_detection", + "juno_samples/fork_join_tests", ] diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 81e31396..1086d1aa 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -323,7 +323,10 @@ impl GPUContext<'_> { // Emit all GPU kernel code from previous steps let mut kernel_body = String::new(); - self.codegen_gotos(false, &mut gotos, &mut kernel_body)?; + let rev_po = self.control_subgraph.rev_po(NodeID::new(0)); + write!(w, "\n")?; + self.codegen_goto(false, &mut gotos, NodeID::new(0), &mut kernel_body)?; + self.codegen_gotos(false, &mut gotos, &rev_po, NodeID::new(0), &mut kernel_body)?; write!(w, "{}", kernel_body)?; write!(w, "}}\n")?; @@ -536,27 +539,62 @@ namespace cg = cooperative_groups; &self, goto_debug: bool, gotos: &mut BTreeMap<NodeID, CudaGoto>, + rev_po: &Vec<NodeID>, + root: NodeID, w: &mut String, ) -> Result<(), Error> { - write!(w, "\n")?; - for (id, goto) in gotos.iter() { - let goto_block = self.get_block_name(*id, false); - write!(w, "{}:\n", goto_block)?; - if goto_debug { - write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?; - } - write!(w, "{}", goto.init)?; - if !goto.post_init.is_empty() { - let goto_block = self.get_block_name(*id, true); - write!(w, "{}:\n", goto_block)?; - write!(w, "{}", goto.post_init)?; + // Print the blocks in a kind of silly way to avoid errors aroun + // initialization of fork variables and gotos. 
+ let mut not_forks = vec![]; + let mut forks = vec![]; + let not_fork_controls = &self.fork_control_map[&root]; + for bb in rev_po + .into_iter() + .filter(|id| not_fork_controls.contains(id) && **id != root) + { + not_forks.push(*bb); + } + if let Some(fork_controls) = &self.fork_tree.get(&root) { + for bb in rev_po + .into_iter() + .filter(|id| fork_controls.contains(id) && **id != root) + { + forks.push(*bb); } - write!(w, "{}", goto.body)?; - write!(w, "{}\n", goto.term)?; + } + for id in not_forks { + self.codegen_goto(goto_debug, gotos, id, w)?; + } + for root in forks { + self.codegen_goto(goto_debug, gotos, root, w)?; + self.codegen_gotos(goto_debug, gotos, rev_po, root, w)?; } Ok(()) } + fn codegen_goto( + &self, + goto_debug: bool, + gotos: &mut BTreeMap<NodeID, CudaGoto>, + bb: NodeID, + w: &mut String, + ) -> Result<(), Error> { + let goto = &gotos[&bb]; + let goto_block = self.get_block_name(bb, false); + write!(w, "{}:\n", goto_block)?; + if goto_debug { + write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?; + } + write!(w, "{}", goto.init)?; + if !goto.post_init.is_empty() { + let goto_block = self.get_block_name(bb, true); + write!(w, "{}:\n", goto_block)?; + write!(w, "{}", goto.post_init)?; + } + write!(w, "{}", goto.body)?; + write!(w, "{}\n", goto.term) + } + fn codegen_launch_code( &self, num_blocks: String, @@ -620,23 +658,23 @@ extern \"C\" {} {}(", write!(pass_args, "ret")?; write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; } - write!(w, "\tcudaError_t err;\n"); + write!(w, "\tcudaError_t err;\n")?; write!( w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args )?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; write!(w, "\tcudaDeviceSynchronize();\n")?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; if has_ret_var { // Copy return from device to host, whether it's primitive value or collection pointer write!(w, "\t{} host_ret;\n", ret_type)?; @@ -1148,7 +1186,8 @@ extern \"C\" {} {}(", // for all threads. Otherwise, it can be inside or outside block fork. // If inside, it's stored in shared memory so we "allocate" it once // and parallelize memset to 0. If outside, we initialize as offset - // to backing, but if multi-block grid, don't memset to avoid grid-level sync. + // to backing, but if multi-block grid, don't memset to avoid grid- + // level sync. 
Node::Constant { id: cons_id } => { let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive(); let cg_tile = match state { @@ -1190,9 +1229,7 @@ extern \"C\" {} {}(", )?; } if !is_primitive - && (state != KernelState::OutBlock - || is_block_parallel.is_none() - || !is_block_parallel.unwrap()) + && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false)) { let data_size = self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects)); diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 263fa952..7a098a35 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -140,23 +140,23 @@ fn reduce_cycle_dfs_helper( } current_visited.insert(iter); - let found_reduce = get_uses(&function.nodes[iter.idx()]) - .as_ref() - .into_iter() - .any(|u| { - !current_visited.contains(u) - && !function.nodes[u.idx()].is_control() - && isnt_outside_fork_join(*u) - && reduce_cycle_dfs_helper( - function, - *u, - fork, - reduce, - current_visited, - in_cycle, - fork_join_nest, - ) - }); + let mut found_reduce = false; + + // This doesn't short circuit on purpose. + for u in get_uses(&function.nodes[iter.idx()]).as_ref() { + found_reduce |= !current_visited.contains(u) + && !function.nodes[u.idx()].is_control() + && isnt_outside_fork_join(*u) + && reduce_cycle_dfs_helper( + function, + *u, + fork, + reduce, + current_visited, + in_cycle, + fork_join_nest, + ) + } if found_reduce { in_cycle.insert(iter); } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 846347b0..5c575ea1 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { * list of indices B. */ pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { - if indices_a.len() < indices_b.len() { + if indices_a.len() > indices_b.len() { return false; } diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs index a425c442..daf85968 100644 --- a/hercules_ir/src/loops.rs +++ b/hercules_ir/src/loops.rs @@ -40,6 +40,10 @@ impl LoopTree { self.loops[&header].0.iter_ones().map(NodeID::new) } + pub fn nodes_in_loop_bitvec(&self, header: NodeID) -> &BitVec<u8, Lsb0> { + &self.loops[&header].0 + } + pub fn is_in_loop(&self, header: NodeID, is_in: NodeID) -> bool { header == self.root || self.loops[&header].0[is_in.idx()] } diff --git a/hercules_opt/src/float_collections.rs b/hercules_opt/src/float_collections.rs index faa38375..6ef050c2 100644 --- a/hercules_opt/src/float_collections.rs +++ b/hercules_opt/src/float_collections.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use hercules_ir::*; use crate::*; @@ -7,27 +9,36 @@ use crate::*; * allowed. */ pub fn float_collections( - editors: &mut [FunctionEditor], + editors: &mut BTreeMap<FunctionID, FunctionEditor>, typing: &ModuleTyping, callgraph: &CallGraph, devices: &Vec<Device>, ) { - let topo = callgraph.topo(); + let topo: Vec<_> = callgraph + .topo() + .into_iter() + .filter(|id| editors.contains_key(&id)) + .collect(); for to_float_id in topo { // Collection constants float until reaching an AsyncRust function. if devices[to_float_id.idx()] == Device::AsyncRust { continue; } + // Check that all callers are in the selection as well. 
+ for caller in callgraph.get_callers(to_float_id) { + assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller); + } + // Find the target constant nodes in the function. - let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()] + let cons: Vec<(NodeID, Node)> = editors[&to_float_id] .func() .nodes .iter() .enumerate() .filter(|(_, node)| { node.try_constant() - .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar()) + .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar()) .unwrap_or(false) }) .map(|(idx, node)| (NodeID::new(idx), node.clone())) @@ -37,12 +48,12 @@ pub fn float_collections( } // Each constant node becomes a new parameter. - let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone(); + let mut new_param_types = editors[&to_float_id].func().param_types.clone(); let old_num_params = new_param_types.len(); for (id, _) in cons.iter() { new_param_types.push(typing[to_float_id.idx()][id.idx()]); } - let success = editors[to_float_id.idx()].edit(|mut edit| { + let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| { for (idx, (id, _)) in cons.iter().enumerate() { let param = edit.add_node(Node::Parameter { index: idx + old_num_params, @@ -59,7 +70,7 @@ pub fn float_collections( // Add constants in callers and pass them into calls. for caller in callgraph.get_callers(to_float_id) { - let calls: Vec<(NodeID, Node)> = editors[caller.idx()] + let calls: Vec<(NodeID, Node)> = editors[&caller] .func() .nodes .iter() @@ -71,7 +82,7 @@ pub fn float_collections( }) .map(|(idx, node)| (NodeID::new(idx), node.clone())) .collect(); - let success = editors[caller.idx()].edit(|mut edit| { + let success = editors.get_mut(&caller).unwrap().edit(|mut edit| { let cons_ids: Vec<_> = cons .iter() .map(|(_, node)| edit.add_node(node.clone())) diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs deleted file mode 100644 index bb3a2cff..00000000 --- a/hercules_opt/src/fork_concat_split.rs +++ /dev/null @@ -1,141 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::iter::zip; - -use hercules_ir::ir::*; - -use crate::*; - -/* - * Split multi-dimensional fork-joins into separate one-dimensional fork-joins. - * Useful for code generation. A single iteration of `fork_split` only splits - * at most one fork-join, it must be called repeatedly to split all fork-joins. - */ -pub fn fork_split( - editor: &mut FunctionEditor, - fork_join_map: &HashMap<NodeID, NodeID>, - reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, -) { - // A single multi-dimensional fork becomes multiple forks, a join becomes - // multiple joins, a thread ID becomes a thread ID on the correct - // fork, and a reduce becomes multiple reduces to shuffle the reduction - // value through the fork-join nest. 
- for (fork, join) in fork_join_map { - let nodes = &editor.func().nodes; - let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap(); - if factors.len() < 2 { - continue; - } - let factors: Box<[DynamicConstantID]> = factors.into(); - let join_control = nodes[join.idx()].try_join().unwrap(); - let tids: Vec<_> = editor - .get_users(*fork) - .filter(|id| nodes[id.idx()].is_thread_id()) - .collect(); - let reduces: Vec<_> = editor - .get_users(*join) - .filter(|id| nodes[id.idx()].is_reduce()) - .collect(); - - let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces - .iter() - .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce))) - .flatten() - .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user)) - .collect(); - - editor.edit(|mut edit| { - // Create the forks and a thread ID per fork. - let mut acc_fork = fork_control; - let mut new_tids = vec![]; - for factor in factors { - acc_fork = edit.add_node(Node::Fork { - control: acc_fork, - factors: Box::new([factor]), - }); - edit.sub_edit(*fork, acc_fork); - new_tids.push(edit.add_node(Node::ThreadID { - control: acc_fork, - dimension: 0, - })); - } - - // Create the joins. - let mut acc_join = if join_control == *fork { - acc_fork - } else { - join_control - }; - let mut joins = vec![]; - for _ in new_tids.iter() { - acc_join = edit.add_node(Node::Join { control: acc_join }); - edit.sub_edit(*join, acc_join); - joins.push(acc_join); - } - - // Create the reduces. - let mut new_reduces = vec![]; - for reduce in reduces.iter() { - let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap(); - let num_nodes = edit.num_node_ids(); - let mut inner_reduce = NodeID::new(0); - let mut outer_reduce = NodeID::new(0); - for (join_idx, join) in joins.iter().enumerate() { - let init = if join_idx == joins.len() - 1 { - init - } else { - NodeID::new(num_nodes + join_idx + 1) - }; - let reduct = if join_idx == 0 { - reduct - } else { - NodeID::new(num_nodes + join_idx - 1) - }; - let new_reduce = edit.add_node(Node::Reduce { - control: *join, - init, - reduct, - }); - assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx)); - edit.sub_edit(*reduce, new_reduce); - if join_idx == 0 { - inner_reduce = new_reduce; - } - if join_idx == joins.len() - 1 { - outer_reduce = new_reduce; - } - } - new_reduces.push((inner_reduce, outer_reduce)); - } - - // Replace everything. - edit = edit.replace_all_uses(*fork, acc_fork)?; - edit = edit.replace_all_uses(*join, acc_join)?; - for tid in tids.iter() { - let dim = edit.get_node(*tid).try_thread_id().unwrap().1; - edit.sub_edit(*tid, new_tids[dim]); - edit = edit.replace_all_uses(*tid, new_tids[dim])?; - } - for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) { - edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| { - data_in_reduce_cycle.contains(&(*id, *reduce)) - })?; - edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| { - !data_in_reduce_cycle.contains(&(*id, *reduce)) - })?; - } - - // Delete all the old stuff. 
- edit = edit.delete_node(*fork)?; - edit = edit.delete_node(*join)?; - for tid in tids { - edit = edit.delete_node(tid)?; - } - for reduce in reduces { - edit = edit.delete_node(reduce)?; - } - - Ok(edit) - }); - break; - } -} diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index a4605bec..e23f586f 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::iter::zip; use bimap::BiMap; use itertools::Itertools; @@ -538,3 +539,157 @@ pub fn fork_coalesce_helper( true } + +pub fn split_all_forks( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) { + for (fork, join) in fork_join_map { + if let Some((forks, _)) = split_fork(editor, *fork, *join, reduce_cycles) + && forks.len() > 1 + { + break; + } + } +} + +/* + * Split multi-dimensional fork-joins into separate one-dimensional fork-joins. + * Useful for code generation. A single iteration of `fork_split` only splits + * at most one fork-join, it must be called repeatedly to split all fork-joins. + */ +pub(crate) fn split_fork( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) -> Option<(Vec<NodeID>, Vec<NodeID>)> { + // A single multi-dimensional fork becomes multiple forks, a join becomes + // multiple joins, a thread ID becomes a thread ID on the correct + // fork, and a reduce becomes multiple reduces to shuffle the reduction + // value through the fork-join nest. + let nodes = &editor.func().nodes; + let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap(); + if factors.len() < 2 { + return Some((vec![fork], vec![join])); + } + let factors: Box<[DynamicConstantID]> = factors.into(); + let join_control = nodes[join.idx()].try_join().unwrap(); + let tids: Vec<_> = editor + .get_users(fork) + .filter(|id| nodes[id.idx()].is_thread_id()) + .collect(); + let reduces: Vec<_> = editor + .get_users(join) + .filter(|id| nodes[id.idx()].is_reduce()) + .collect(); + + let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces + .iter() + .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce))) + .flatten() + .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user)) + .collect(); + + let mut new_forks = vec![]; + let mut new_joins = vec![]; + let success = editor.edit(|mut edit| { + // Create the forks and a thread ID per fork. + let mut acc_fork = fork_control; + let mut new_tids = vec![]; + for factor in factors { + acc_fork = edit.add_node(Node::Fork { + control: acc_fork, + factors: Box::new([factor]), + }); + new_forks.push(acc_fork); + edit.sub_edit(fork, acc_fork); + new_tids.push(edit.add_node(Node::ThreadID { + control: acc_fork, + dimension: 0, + })); + } + + // Create the joins. + let mut acc_join = if join_control == fork { + acc_fork + } else { + join_control + }; + for _ in new_tids.iter() { + acc_join = edit.add_node(Node::Join { control: acc_join }); + edit.sub_edit(join, acc_join); + new_joins.push(acc_join); + } + + // Create the reduces. 
+ let mut new_reduces = vec![]; + for reduce in reduces.iter() { + let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap(); + let num_nodes = edit.num_node_ids(); + let mut inner_reduce = NodeID::new(0); + let mut outer_reduce = NodeID::new(0); + for (join_idx, join) in new_joins.iter().enumerate() { + let init = if join_idx == new_joins.len() - 1 { + init + } else { + NodeID::new(num_nodes + join_idx + 1) + }; + let reduct = if join_idx == 0 { + reduct + } else { + NodeID::new(num_nodes + join_idx - 1) + }; + let new_reduce = edit.add_node(Node::Reduce { + control: *join, + init, + reduct, + }); + assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx)); + edit.sub_edit(*reduce, new_reduce); + if join_idx == 0 { + inner_reduce = new_reduce; + } + if join_idx == new_joins.len() - 1 { + outer_reduce = new_reduce; + } + } + new_reduces.push((inner_reduce, outer_reduce)); + } + + // Replace everything. + edit = edit.replace_all_uses(fork, acc_fork)?; + edit = edit.replace_all_uses(join, acc_join)?; + for tid in tids.iter() { + let dim = edit.get_node(*tid).try_thread_id().unwrap().1; + edit.sub_edit(*tid, new_tids[dim]); + edit = edit.replace_all_uses(*tid, new_tids[dim])?; + } + for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) { + edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| { + data_in_reduce_cycle.contains(&(*id, *reduce)) + })?; + edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| { + !data_in_reduce_cycle.contains(&(*id, *reduce)) + })?; + } + + // Delete all the old stuff. + edit = edit.delete_node(fork)?; + edit = edit.delete_node(join)?; + for tid in tids { + edit = edit.delete_node(tid)?; + } + for reduce in reduces { + edit = edit.delete_node(reduce)?; + } + + Ok(edit) + }); + if success { + Some((new_forks, new_joins)) + } else { + None + } +} diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 271bfaf1..d9505fde 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1,5 +1,5 @@ use std::cell::Ref; -use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter::{empty, once, zip, FromIterator}; use bitvec::prelude::*; @@ -40,6 +40,13 @@ use crate::*; * necessarily valid anymore, so this function is called in a loop in the pass * manager until no more spills are found. * + * GCM also generally tries to massage the code to be properly formed for the + * device backends in other ways. For example, reduction cycles through the + * `init` inputs of an inner reduce and a use of a non-parallel outer reduce in + * nested fork-joins is not schedulable, since the outer reduce doesn't dominate + * the fork corresponding to the inner reduce. In such cases, the outer fork- + * join must be split and unforkified. + * * GCM is additionally complicated by the need to generate code that references * objects across multiple devices. 
In particular, GCM makes sure that every * object lives on exactly one device, so that references to that object always @@ -76,11 +83,16 @@ pub fn gcm( dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, devices: &Vec<Device>, object_device_demands: &FunctionObjectDeviceDemands, backing_allocations: &BackingAllocations, ) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> { + if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) { + return None; + } + let bbs = basic_blocks( editor.func(), editor.func_id(), @@ -88,8 +100,10 @@ pub fn gcm( reverse_postorder, dom, loops, + reduce_cycles, fork_join_map, objects, + devices, ); let liveness = liveness_dataflow( @@ -156,6 +170,37 @@ pub fn gcm( Some((bbs, node_colors, backing_allocation)) } +/* + * Do misc. fixups on the IR, such as unforkifying sequential outer forks with + * problematic reduces. + */ +fn preliminary_fixups( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + let nodes = &editor.func().nodes; + for (reduce, cycle) in reduce_cycles { + if cycle.into_iter().any(|id| nodes[id.idx()].is_reduce()) { + let join = nodes[reduce.idx()].try_reduce().unwrap().0; + let fork = fork_join_map + .into_iter() + .filter(|(_, j)| join == **j) + .map(|(f, _)| *f) + .next() + .unwrap(); + let (forks, _) = split_fork(editor, fork, join, reduce_cycles).unwrap(); + if forks.len() > 1 { + return true; + } + unforkify(editor, fork, join, loops); + return true; + } + } + false +} + /* * Top level global code motion function. Assigns each data node to one of its * immediate control use / user nodes, forming (unordered) basic blocks. Returns @@ -172,8 +217,10 @@ fn basic_blocks( reverse_postorder: &Vec<NodeID>, dom: &DomTree, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, + devices: &Vec<Device>, ) -> BasicBlocks { let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; @@ -244,6 +291,9 @@ fn basic_blocks( // but not forwarding read - forwarding reads are collapsed, and the // bottom read is treated as reading from the transitive parent of the // forwarding read(s). + // 3: If the node producing the collection is a reduce node, then any read + // users that aren't in the reduce's cycle shouldn't anti-depend user any + // mutators in the reduce cycle. let mut antideps = BTreeSet::new(); for id in reverse_postorder.iter() { // Find a terminating read node and the collections it reads. @@ -269,6 +319,10 @@ fn basic_blocks( // TODO: make this less outrageously inefficient. 
let func_objects = &objects[&func_id]; for root in roots.iter() { + let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles + .get(root) + .map(|cycle| !cycle.contains(&id)) + .unwrap_or(false); let root_early = schedule_early[root.idx()].unwrap(); let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new(); let mut workset = BTreeSet::new(); @@ -296,6 +350,11 @@ fn basic_blocks( && mutating_objects(function, func_id, *mutator, objects) .any(|mutated| read_objs.contains(&mutated)) && id != mutator + && (!root_is_reduce_and_read_isnt_in_cycle + || !reduce_cycles + .get(root) + .map(|cycle| cycle.contains(mutator)) + .unwrap_or(false)) { antideps.insert((*id, *mutator)); } @@ -421,9 +480,18 @@ fn basic_blocks( // If the next node further up the dominator tree is in a shallower // loop nest or if we can get out of a reduce loop when we don't // need to be in one, place this data node in a higher-up location. - // Only do this is the node isn't a constant or undef. + // Only do this is the node isn't a constant or undef - if a + // node is a constant or undef, we want its placement to be as + // control dependent as possible, even inside loops. In GPU + // functions specifically, lift constants that may be returned + // outside fork-joins. let is_constant_or_undef = function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef(); + let is_gpu_returned = devices[func_id.idx()] == Device::CUDA + && objects[&func_id] + .objects(id) + .into_iter() + .any(|obj| objects[&func_id].returned_objects().contains(obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -444,7 +512,10 @@ fn basic_blocks( // loop use the reduce node forming the loop, so the dominator chain // will consist of one block, and this loop won't ever iterate. let currently_at_join = function.nodes[location.idx()].is_join(); - if !is_constant_or_undef && (shallower_nest || currently_at_join) { + + if (!is_constant_or_undef || is_gpu_returned) + && (shallower_nest || currently_at_join) + { location = control_node; } } diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index e3cca161..48475f2f 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -6,7 +6,6 @@ pub mod dce; pub mod delete_uncalled; pub mod editor; pub mod float_collections; -pub mod fork_concat_split; pub mod fork_guard_elim; pub mod fork_transforms; pub mod forkify; @@ -31,7 +30,6 @@ pub use crate::dce::*; pub use crate::delete_uncalled::*; pub use crate::editor::*; pub use crate::float_collections::*; -pub use crate::fork_concat_split::*; pub use crate::fork_guard_elim::*; pub use crate::fork_transforms::*; pub use crate::forkify::*; diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 85ffd233..7451b0ad 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -100,11 +100,24 @@ pub fn calculate_fork_nodes( */ // FIXME: Only works on fully split fork nests. 
-pub fn unforkify( +pub fn unforkify_all( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, loop_tree: &LoopTree, ) { + for l in loop_tree.bottom_up_loops().into_iter().rev() { + if !editor.node(l.0).is_fork() { + continue; + } + + let fork = l.0; + let join = fork_join_map[&fork]; + + unforkify(editor, fork, join, loop_tree); + } +} + +pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_tree: &LoopTree) { let mut zero_cons_id = ConstantID::new(0); let mut one_cons_id = ConstantID::new(0); assert!(editor.edit(|mut edit| { @@ -118,180 +131,170 @@ pub fn unforkify( // control insides of the fork-join should become the successor of the true // projection node, and what was the use of the join should become a use of // the new region. - for l in loop_tree.bottom_up_loops().into_iter().rev() { - if !editor.node(l.0).is_fork() { - continue; - } - - let fork = &l.0; - let join = &fork_join_map[&fork]; - - let fork_nodes = calculate_fork_nodes(editor, l.1, *fork); + let fork_nodes = calculate_fork_nodes(editor, loop_tree.nodes_in_loop_bitvec(fork), fork); - let nodes = &editor.func().nodes; - let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap(); - if factors.len() > 1 { - // For now, don't convert multi-dimensional fork-joins. Rely on pass - // that splits fork-joins. - continue; - } - let join_control = nodes[join.idx()].try_join().unwrap(); - let tids: Vec<_> = editor - .get_users(*fork) - .filter(|id| nodes[id.idx()].is_thread_id()) - .collect(); - let reduces: Vec<_> = editor - .get_users(*join) - .filter(|id| nodes[id.idx()].is_reduce()) - .collect(); + let nodes = &editor.func().nodes; + let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap(); + if factors.len() > 1 { + // For now, don't convert multi-dimensional fork-joins. Rely on pass + // that splits fork-joins. 
+ return; + } + let join_control = nodes[join.idx()].try_join().unwrap(); + let tids: Vec<_> = editor + .get_users(fork) + .filter(|id| nodes[id.idx()].is_thread_id()) + .collect(); + let reduces: Vec<_> = editor + .get_users(join) + .filter(|id| nodes[id.idx()].is_reduce()) + .collect(); - let num_nodes = editor.node_ids().len(); - let region_id = NodeID::new(num_nodes); - let if_id = NodeID::new(num_nodes + 1); - let proj_back_id = NodeID::new(num_nodes + 2); - let proj_exit_id = NodeID::new(num_nodes + 3); - let zero_id = NodeID::new(num_nodes + 4); - let one_id = NodeID::new(num_nodes + 5); - let indvar_id = NodeID::new(num_nodes + 6); - let add_id = NodeID::new(num_nodes + 7); - let dc_id = NodeID::new(num_nodes + 8); - let neq_id = NodeID::new(num_nodes + 9); + let num_nodes = editor.node_ids().len(); + let region_id = NodeID::new(num_nodes); + let if_id = NodeID::new(num_nodes + 1); + let proj_back_id = NodeID::new(num_nodes + 2); + let proj_exit_id = NodeID::new(num_nodes + 3); + let zero_id = NodeID::new(num_nodes + 4); + let one_id = NodeID::new(num_nodes + 5); + let indvar_id = NodeID::new(num_nodes + 6); + let add_id = NodeID::new(num_nodes + 7); + let dc_id = NodeID::new(num_nodes + 8); + let neq_id = NodeID::new(num_nodes + 9); - let guard_if_id = NodeID::new(num_nodes + 10); - let guard_join_id = NodeID::new(num_nodes + 11); - let guard_taken_proj_id = NodeID::new(num_nodes + 12); - let guard_skipped_proj_id = NodeID::new(num_nodes + 13); - let guard_cond_id = NodeID::new(num_nodes + 14); + let guard_if_id = NodeID::new(num_nodes + 10); + let guard_join_id = NodeID::new(num_nodes + 11); + let guard_taken_proj_id = NodeID::new(num_nodes + 12); + let guard_skipped_proj_id = NodeID::new(num_nodes + 13); + let guard_cond_id = NodeID::new(num_nodes + 14); - let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new); - let s = num_nodes + 15 + reduces.len(); - let join_phi_ids = (s..s + reduces.len()).map(NodeID::new); + let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new); + let s = num_nodes + 15 + reduces.len(); + let join_phi_ids = (s..s + reduces.len()).map(NodeID::new); - let guard_cond = Node::Binary { - left: zero_id, - right: dc_id, - op: BinaryOperator::LT, - }; - let guard_if = Node::If { - control: fork_control, - cond: guard_cond_id, - }; - let guard_taken_proj = Node::Projection { - control: guard_if_id, - selection: 1, - }; - let guard_skipped_proj = Node::Projection { - control: guard_if_id, - selection: 0, - }; - let guard_join = Node::Region { - preds: Box::new([guard_skipped_proj_id, proj_exit_id]), - }; + let guard_cond = Node::Binary { + left: zero_id, + right: dc_id, + op: BinaryOperator::LT, + }; + let guard_if = Node::If { + control: fork_control, + cond: guard_cond_id, + }; + let guard_taken_proj = Node::Projection { + control: guard_if_id, + selection: 1, + }; + let guard_skipped_proj = Node::Projection { + control: guard_if_id, + selection: 0, + }; + let guard_join = Node::Region { + preds: Box::new([guard_skipped_proj_id, proj_exit_id]), + }; - let region = Node::Region { - preds: Box::new([guard_taken_proj_id, proj_back_id]), - }; - let if_node = Node::If { - control: join_control, - cond: neq_id, - }; - let proj_back = Node::Projection { - control: if_id, - selection: 1, - }; - let proj_exit = Node::Projection { - control: if_id, - selection: 0, - }; - let zero = Node::Constant { id: zero_cons_id }; - let one = Node::Constant { id: one_cons_id }; - let indvar = Node::Phi { - control: region_id, - data: 
Box::new([zero_id, add_id]), - }; - let add = Node::Binary { - op: BinaryOperator::Add, - left: indvar_id, - right: one_id, - }; - let dc = Node::DynamicConstant { id: factors[0] }; - let neq = Node::Binary { - op: BinaryOperator::NE, - left: add_id, - right: dc_id, - }; - let (phis, join_phis): (Vec<_>, Vec<_>) = reduces - .iter() - .map(|reduce_id| { - let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap(); - ( - Node::Phi { - control: region_id, - data: Box::new([init, reduct]), - }, - Node::Phi { - control: guard_join_id, - data: Box::new([init, reduct]), - }, - ) - }) - .unzip(); + let region = Node::Region { + preds: Box::new([guard_taken_proj_id, proj_back_id]), + }; + let if_node = Node::If { + control: join_control, + cond: neq_id, + }; + let proj_back = Node::Projection { + control: if_id, + selection: 1, + }; + let proj_exit = Node::Projection { + control: if_id, + selection: 0, + }; + let zero = Node::Constant { id: zero_cons_id }; + let one = Node::Constant { id: one_cons_id }; + let indvar = Node::Phi { + control: region_id, + data: Box::new([zero_id, add_id]), + }; + let add = Node::Binary { + op: BinaryOperator::Add, + left: indvar_id, + right: one_id, + }; + let dc = Node::DynamicConstant { id: factors[0] }; + let neq = Node::Binary { + op: BinaryOperator::NE, + left: add_id, + right: dc_id, + }; + let (phis, join_phis): (Vec<_>, Vec<_>) = reduces + .iter() + .map(|reduce_id| { + let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap(); + ( + Node::Phi { + control: region_id, + data: Box::new([init, reduct]), + }, + Node::Phi { + control: guard_join_id, + data: Box::new([init, reduct]), + }, + ) + }) + .unzip(); - editor.edit(|mut edit| { - assert_eq!(edit.add_node(region), region_id); - assert_eq!(edit.add_node(if_node), if_id); - assert_eq!(edit.add_node(proj_back), proj_back_id); - assert_eq!(edit.add_node(proj_exit), proj_exit_id); - assert_eq!(edit.add_node(zero), zero_id); - assert_eq!(edit.add_node(one), one_id); - assert_eq!(edit.add_node(indvar), indvar_id); - assert_eq!(edit.add_node(add), add_id); - assert_eq!(edit.add_node(dc), dc_id); - assert_eq!(edit.add_node(neq), neq_id); - assert_eq!(edit.add_node(guard_if), guard_if_id); - assert_eq!(edit.add_node(guard_join), guard_join_id); - assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id); - assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id); - assert_eq!(edit.add_node(guard_cond), guard_cond_id); + editor.edit(|mut edit| { + assert_eq!(edit.add_node(region), region_id); + assert_eq!(edit.add_node(if_node), if_id); + assert_eq!(edit.add_node(proj_back), proj_back_id); + assert_eq!(edit.add_node(proj_exit), proj_exit_id); + assert_eq!(edit.add_node(zero), zero_id); + assert_eq!(edit.add_node(one), one_id); + assert_eq!(edit.add_node(indvar), indvar_id); + assert_eq!(edit.add_node(add), add_id); + assert_eq!(edit.add_node(dc), dc_id); + assert_eq!(edit.add_node(neq), neq_id); + assert_eq!(edit.add_node(guard_if), guard_if_id); + assert_eq!(edit.add_node(guard_join), guard_join_id); + assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id); + assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id); + assert_eq!(edit.add_node(guard_cond), guard_cond_id); - for (phi_id, phi) in zip(phi_ids.clone(), &phis) { - assert_eq!(edit.add_node(phi.clone()), phi_id); - } - for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) { - assert_eq!(edit.add_node(phi.clone()), phi_id); - } + for (phi_id, phi) in zip(phi_ids.clone(), &phis) { + 
assert_eq!(edit.add_node(phi.clone()), phi_id); + } + for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) { + assert_eq!(edit.add_node(phi.clone()), phi_id); + } - edit = edit.replace_all_uses(*fork, region_id)?; - edit = edit.replace_all_uses_where(*join, guard_join_id, |usee| *usee != if_id)?; - edit.sub_edit(*fork, region_id); - edit.sub_edit(*join, if_id); - for tid in tids.iter() { - edit.sub_edit(*tid, indvar_id); - edit = edit.replace_all_uses(*tid, indvar_id)?; - } - for (((reduce, phi_id), phi), join_phi_id) in - zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids) - { - edit.sub_edit(*reduce, phi_id); - let Node::Phi { control: _, data } = phi else { - panic!() - }; - edit = edit.replace_all_uses_where(*reduce, join_phi_id, |usee| { - !fork_nodes.contains(usee) - })?; //, |usee| *usee != *reduct)?; - edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| { - fork_nodes.contains(usee) || *usee == data[1] - })?; - edit = edit.delete_node(*reduce)?; - } + edit = edit.replace_all_uses(fork, region_id)?; + edit = edit.replace_all_uses_where(join, guard_join_id, |usee| *usee != if_id)?; + edit.sub_edit(fork, region_id); + edit.sub_edit(join, if_id); + for tid in tids.iter() { + edit.sub_edit(*tid, indvar_id); + edit = edit.replace_all_uses(*tid, indvar_id)?; + } + for (((reduce, phi_id), phi), join_phi_id) in + zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids) + { + edit.sub_edit(*reduce, phi_id); + let Node::Phi { control: _, data } = phi else { + panic!() + }; + edit = edit + .replace_all_uses_where(*reduce, join_phi_id, |usee| !fork_nodes.contains(usee))?; //, |usee| *usee != *reduct)?; + edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| { + fork_nodes.contains(usee) || *usee == data[1] + })?; + edit = edit.delete_node(*reduce)?; + } - edit = edit.delete_node(*fork)?; - edit = edit.delete_node(*join)?; - for tid in tids { - edit = edit.delete_node(tid)?; - } + edit = edit.delete_node(fork)?; + edit = edit.delete_node(join)?; + for tid in tids { + edit = edit.delete_node(tid)?; + } - Ok(edit) - }); - } + Ok(edit) + }); } diff --git a/juno_samples/fork_join_tests/Cargo.toml b/juno_samples/fork_join_tests/Cargo.toml new file mode 100644 index 00000000..a109e782 --- /dev/null +++ b/juno_samples/fork_join_tests/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "juno_fork_join_tests" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_fork_join_tests" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" diff --git a/juno_samples/fork_join_tests/build.rs b/juno_samples/fork_join_tests/build.rs new file mode 100644 index 00000000..796e9f32 --- /dev/null +++ b/juno_samples/fork_join_tests/build.rs @@ -0,0 +1,24 @@ +use juno_build::JunoCompiler; + +fn main() { + #[cfg(not(feature = "cuda"))] + { + JunoCompiler::new() + .file_in_src("fork_join_tests.jn") + .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() + .build() + .unwrap(); + } + #[cfg(feature = "cuda")] + { + JunoCompiler::new() + .file_in_src("fork_join_tests.jn") + .unwrap() + .schedule_in_src("gpu.sch") + .unwrap() + .build() + .unwrap(); + } +} diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch new file mode 100644 index 00000000..0263c275 --- 
/dev/null +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -0,0 +1,42 @@ +gvn(*); +phi-elim(*); +dce(*); + +let out = auto-outline(*); +cpu(out.test1); +cpu(out.test2); +cpu(out.test3); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} +gvn(*); +phi-elim(*); +dce(*); + +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + infer-schedules(*); +} +fork-split(*); +gvn(*); +phi-elim(*); +dce(*); +unforkify(*); +gvn(*); +phi-elim(*); +dce(*); + +gcm(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn new file mode 100644 index 00000000..55e0a37e --- /dev/null +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -0,0 +1,61 @@ +#[entry] +fn test1(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + arr[i, j] = input; + } + } + return arr; +} + +#[entry] +fn test2(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 8 { + for j = 0 to 4 { + for k = 0 to 4 { + arr[j, k] += input; + } + } + } + return arr; +} + +#[entry] +fn test3(input : i32) -> i32[3, 3] { + let arr1 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr1[i, j] = (i + j) as i32 + input; + } + } + let arr2 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr2[i, j] = arr1[2 - i, 2 - j]; + } + } + let arr3 : i32[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + arr3[i, j] = arr2[i, j] + 7; + } + } + return arr3; +} + +#[entry] +fn test4(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + let acc = arr[i, j]; + for k = 0 to 7 { + acc += input; + } + arr[i, j] = acc; + } + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch new file mode 100644 index 00000000..701c347c --- /dev/null +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -0,0 +1,33 @@ +gvn(*); +phi-elim(*); +dce(*); + +let out = auto-outline(*); +gpu(out.test1); +gpu(out.test2); +gpu(out.test3); +gpu(out.test4); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} + +gvn(*); +phi-elim(*); +dce(*); + +fixpoint panic after 20 { + infer-schedules(*); +} + +float-collections(test2, out.test2, test4, out.test4); +gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs new file mode 100644 index 00000000..cbd42c50 --- /dev/null +++ b/juno_samples/fork_join_tests/src/main.rs @@ -0,0 +1,46 @@ +#![feature(concat_idents)] + +use hercules_rt::runner; + +juno_build::juno!("fork_join_tests"); + +fn main() { + #[cfg(not(feature = "cuda"))] + let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| { + assert_eq!(output.as_slice::<i32>(), &correct); + }; + + #[cfg(feature = "cuda")] + let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| { + let mut dst = vec![0i32; 16]; + let output = output.to_cpu_ref(&mut dst); + assert_eq!(output.as_slice::<i32>(), &correct); + }; + + async_std::task::block_on(async { + let mut r = runner!(test1); + let output = r.run(5).await; + let correct = vec![5i32; 16]; + assert(correct, output); + + let mut r = runner!(test2); + let output = r.run(3).await; + let correct = vec![24i32; 16]; + assert(correct, output); + + let mut r = runner!(test3); + let output = r.run(0).await; + let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; + assert(correct, 
output); + + let mut r = runner!(test4); + let output = r.run(9).await; + let correct = vec![63i32; 16]; + assert(correct, output); + }); +} + +#[test] +fn implicit_clone_test() { + main(); +} diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f6fe2fc1..9c818707 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -575,6 +575,8 @@ impl PassManager { self.postdoms = None; self.fork_join_maps = None; self.fork_join_nests = None; + self.fork_control_maps = None; + self.fork_trees = None; self.loops = None; self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; @@ -1303,7 +1305,7 @@ fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, - selection: Option<Vec<CodeLocation>>, + mut selection: Option<Vec<CodeLocation>>, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), @@ -1439,13 +1441,6 @@ fn run_pass( } Pass::FloatCollections => { assert!(args.is_empty()); - if let Some(_) = selection { - return Err(SchedulerError::PassError { - pass: "floatCollections".to_string(), - error: "must be applied to the entire module".to_string(), - }); - } - pm.make_typing(); pm.make_callgraph(); pm.make_devices(); @@ -1453,11 +1448,15 @@ fn run_pass( let callgraph = pm.callgraph.take().unwrap(); let devices = pm.devices.take().unwrap(); - let mut editors = build_editors(pm); + // Modify the selection to include callers of selected functions. + let mut editors = build_selection(pm, selection) + .into_iter() + .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor))) + .collect(); float_collections(&mut editors, &typing, &callgraph, &devices); for func in editors { - changed |= func.modified(); + changed |= func.1.modified(); } pm.delete_gravestones(); @@ -1506,7 +1505,7 @@ fn run_pass( let Some(mut func) = func else { continue; }; - fork_split(&mut func, fork_join_map, reduce_cycles); + split_all_forks(&mut func, fork_join_map, reduce_cycles); changed |= func.modified(); inner_changed |= func.modified(); } @@ -1568,6 +1567,7 @@ fn run_pass( pm.make_doms(); pm.make_fork_join_maps(); pm.make_loops(); + pm.make_reduce_cycles(); pm.make_collection_objects(); pm.make_devices(); pm.make_object_device_demands(); @@ -1578,6 +1578,7 @@ fn run_pass( let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let collection_objects = pm.collection_objects.take().unwrap(); let devices = pm.devices.take().unwrap(); @@ -1599,6 +1600,7 @@ fn run_pass( &doms[id.idx()], &fork_join_maps[id.idx()], &loops[id.idx()], + &reduce_cycles[id.idx()], &collection_objects, &devices, &object_device_demands[id.idx()], @@ -1874,7 +1876,7 @@ fn run_pass( let Some(mut func) = func else { continue; }; - unforkify(&mut func, fork_join_map, loop_tree); + unforkify_all(&mut func, fork_join_map, loop_tree); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab
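
Note on the fork lowering exercised by cpu.sch above: the schedule runs fork-split(*) to break each multi-dimensional fork-join into a nest of one-dimensional fork-joins (split_fork in hercules_opt/src/fork_transforms.rs), then unforkify(*) to lower each of those into a guarded sequential loop (the `0 < dc` guard, phi-based induction variable, and post-increment `!=` exit test built in hercules_opt/src/unforkify.rs). The sketch below is a plain-Rust analogue of the code shape this pipeline produces for a 2-D fork-join such as the one in test1; it is illustrative only, does not use the Hercules or Juno APIs, and the local names (test1_lowered, di, dj) are hypothetical.

// Plain-Rust analogue of fork-split + unforkify applied to test1's 4x4 fork-join.
fn test1_lowered(input: i32) -> [[i32; 4]; 4] {
    let mut arr = [[0i32; 4]; 4];
    // fork-split: one 2-D fork (factors [4, 4]) becomes two nested 1-D fork-joins.
    // unforkify: each 1-D fork-join becomes a sequential loop guarded by a
    // zero-trip-count check, with the thread ID replaced by a counting phi.
    let (di, dj) = (4usize, 4usize);
    if 0 < di {
        let mut i = 0;
        loop {
            if 0 < dj {
                let mut j = 0;
                loop {
                    arr[i][j] = input;
                    j += 1;
                    // Mirrors the `add != dc` test: keep looping until the
                    // incremented induction variable reaches the fork factor.
                    if j == dj {
                        break;
                    }
                }
            }
            i += 1;
            if i == di {
                break;
            }
        }
    }
    arr
}

fn main() {
    // Mirrors the CPU path of the test harness in main.rs: test1(5) should
    // yield a 4x4 array of fives.
    assert!(test1_lowered(5).iter().flatten().all(|&x| x == 5));
}

The outer guard corresponds to the guard_cond/guard_if nodes in unforkify: when the fork factor (a dynamic constant) is zero the loop body is skipped entirely, and the join-side phis select the reduces' init values via the guard_skipped_proj edge instead of the loop-carried values.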