diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index afc016a48a39a091cf6638ade071d40cbb901afa..8f186aa7e4ff6f0a7436d94604b127ff2f1f3ea2 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -622,23 +622,23 @@ extern \"C\" {} {}(", write!(pass_args, "ret")?; write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; } - write!(w, "\tcudaError_t err;\n"); + write!(w, "\tcudaError_t err;\n")?; write!( w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args )?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; write!(w, "\tcudaDeviceSynchronize();\n")?; - write!(w, "\terr = cudaGetLastError();\n"); + write!(w, "\terr = cudaGetLastError();\n")?; write!( w, "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n" - ); + )?; if has_ret_var { // Copy return from device to host, whether it's primitive value or collection pointer write!(w, "\t{} host_ret;\n", ret_type)?; @@ -1150,7 +1150,8 @@ extern \"C\" {} {}(", // for all threads. Otherwise, it can be inside or outside block fork. // If inside, it's stored in shared memory so we "allocate" it once // and parallelize memset to 0. If outside, we initialize as offset - // to backing, but if multi-block grid, don't memset to avoid grid-level sync. + // to backing, but if multi-block grid, don't memset to avoid grid- + // level sync. 
Node::Constant { id: cons_id } => { let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive(); let cg_tile = match state { @@ -1192,9 +1193,7 @@ extern \"C\" {} {}(", )?; } if !is_primitive - && (state != KernelState::OutBlock - || is_block_parallel.is_none() - || !is_block_parallel.unwrap()) + && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false)) { let data_size = self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects)); diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 846347b0cec605fc2485ae9b847213e2783d37b8..5c575ea1dded3ec314f6d5d9aa8deed025e6d532 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { * list of indices B. */ pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { - if indices_a.len() < indices_b.len() { + if indices_a.len() > indices_b.len() { return false; } diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 271bfaf1da55f6c8ab342d06853532eb5ce99fff..3ff6d2fe9e716f8121f84a595e0dcf0e5c576e26 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1,5 +1,5 @@ use std::cell::Ref; -use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter::{empty, once, zip, FromIterator}; use bitvec::prelude::*; @@ -76,6 +76,7 @@ pub fn gcm( dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, devices: &Vec<Device>, object_device_demands: &FunctionObjectDeviceDemands, @@ -88,8 +89,10 @@ pub fn gcm( reverse_postorder, dom, loops, + reduce_cycles, fork_join_map, objects, + devices, ); let liveness = liveness_dataflow( @@ -172,8 +175,10 @@ fn basic_blocks( reverse_postorder: &Vec<NodeID>, dom: &DomTree, loops: &LoopTree, + reduce_cycles: 
&HashMap<NodeID, HashSet<NodeID>>, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, + devices: &Vec<Device>, ) -> BasicBlocks { let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; @@ -244,6 +249,9 @@ fn basic_blocks( // but not forwarding read - forwarding reads are collapsed, and the // bottom read is treated as reading from the transitive parent of the // forwarding read(s). + // 3: If the node producing the collection is a reduce node, then any read + // users that aren't in the reduce's cycle shouldn't anti-depend on any + // mutators in the reduce cycle. let mut antideps = BTreeSet::new(); for id in reverse_postorder.iter() { // Find a terminating read node and the collections it reads. @@ -269,6 +277,10 @@ fn basic_blocks( // TODO: make this less outrageously inefficient. let func_objects = &objects[&func_id]; for root in roots.iter() { + let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles + .get(root) + .map(|cycle| !cycle.contains(&id)) + .unwrap_or(false); let root_early = schedule_early[root.idx()].unwrap(); let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new(); let mut workset = BTreeSet::new(); @@ -296,6 +308,11 @@ fn basic_blocks( && mutating_objects(function, func_id, *mutator, objects) .any(|mutated| read_objs.contains(&mutated)) && id != mutator + && (!root_is_reduce_and_read_isnt_in_cycle + || !reduce_cycles + .get(root) + .map(|cycle| cycle.contains(mutator)) + .unwrap_or(false)) { antideps.insert((*id, *mutator)); } @@ -421,9 +438,18 @@ fn basic_blocks( // If the next node further up the dominator tree is in a shallower // loop nest or if we can get out of a reduce loop when we don't // need to be in one, place this data node in a higher-up location. - // Only do this is the node isn't a constant or undef. + // Only do this if the node isn't a constant or undef - if a + // node is a constant or undef, we want its placement to be as + // control dependent as possible, even inside loops. 
In GPU + // functions specifically, lift constants that may be returned + // outside fork-joins. let is_constant_or_undef = function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef(); + let is_gpu_returned = devices[func_id.idx()] == Device::CUDA + && objects[&func_id] + .objects(id) + .into_iter() + .any(|obj| objects[&func_id].returned_objects().contains(obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -444,7 +470,10 @@ fn basic_blocks( // loop use the reduce node forming the loop, so the dominator chain // will consist of one block, and this loop won't ever iterate. let currently_at_join = function.nodes[location.idx()].is_join(); - if !is_constant_or_undef && (shallower_nest || currently_at_join) { + + if (!is_constant_or_undef || is_gpu_returned) + && (shallower_nest || currently_at_join) + { location = control_node; } } diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 073cfd1ebd75da0f9dea8b931e30fa1dd790eb95..55e0a37e2194ff172add63c1b436dccd1cfd833d 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -33,7 +33,7 @@ fn test3(input : i32) -> i32[3, 3] { let arr2 : i32[3, 3]; for i = 0 to 3 { for j = 0 to 3 { - arr2[i, j] = arr1[3 - i, 3 - j]; + arr2[i, j] = arr1[2 - i, 2 - j]; } } let arr3 : i32[3, 3]; @@ -44,3 +44,18 @@ fn test3(input : i32) -> i32[3, 3] { } return arr3; } + +#[entry] +fn test4(input : i32) -> i32[4, 4] { + let arr : i32[4, 4]; + for i = 0 to 4 { + for j = 0 to 4 { + let acc = arr[i, j]; + for k = 0 to 7 { + acc += input; + } + arr[i, j] = acc; + } + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 80f1bbc9e1e6ec78a6f9c77909d084b068c37aa0..bf35caea74ac0bc273bb5cd43f5dffd66e13ee3b 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ 
b/juno_samples/fork_join_tests/src/gpu.sch @@ -6,6 +6,7 @@ let out = auto-outline(*); gpu(out.test1); gpu(out.test2); gpu(out.test3); +gpu(out.test4); ip-sroa(*); sroa(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 4384ecd5b186c7ffe7f3a48d22334989a72f9f0e..cbd42c50ac95d102f248477a3bee0cd517eb9d5b 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -32,6 +32,11 @@ fn main() { let output = r.run(0).await; let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; assert(correct, output); + + let mut r = runner!(test4); + let output = r.run(9).await; + let correct = vec![63i32; 16]; + assert(correct, output); }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index dd2ae73ac5342da0da5d78006e18e62993ed91ea..28e9de7631e2a4ea1eabb9caf5cb0ea5b00677de 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1576,6 +1576,7 @@ fn run_pass( pm.make_doms(); pm.make_fork_join_maps(); pm.make_loops(); + pm.make_reduce_cycles(); pm.make_collection_objects(); pm.make_devices(); pm.make_object_device_demands(); @@ -1586,6 +1587,7 @@ fn run_pass( let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let collection_objects = pm.collection_objects.take().unwrap(); let devices = pm.devices.take().unwrap(); @@ -1607,6 +1609,7 @@ fn run_pass( &doms[id.idx()], &fork_join_maps[id.idx()], &loops[id.idx()], + &reduce_cycles[id.idx()], &collection_objects, &devices, &object_device_demands[id.idx()],