diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 4069cb02bcf1ed0a93f5c40df432253e842ead9a..d802a274e95b151db0efefe1ec53ba1fdb43a185 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -1,6 +1,7 @@
 extern crate bitvec;
 extern crate hercules_ir;
 
+use std::cell::RefCell;
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::{Error, Write};
 
@@ -170,6 +171,62 @@ pub fn gpu_codegen<W: Write>(
         threads_per_warp: 32,
     };
 
+    // Check for whether we need to emit synchronization for joins.
+    let avoid_join_sync = (|| {
+        // Check for simple block/thread fork structure.
+        let Some(root_forks) = fork_tree.get(&NodeID::new(0)) else {
+            return false;
+        };
+        if root_forks.len() != 1 {
+            return false;
+        }
+        let block = *root_forks.into_iter().next().unwrap();
+        let Some(block_forks) = fork_tree.get(&block) else {
+            return false;
+        };
+        if block_forks.len() != 1 {
+            return false;
+        }
+        let thread = *block_forks.into_iter().next().unwrap();
+        if let Some(thread_forks) = fork_tree.get(&thread) && !thread_forks.is_empty() {
+            return false;
+        }
+
+        // Check that the results from the thread fork aren't needed further
+        // inside this kernel.
+        let thread_join = fork_join_map[&thread];
+        let block_join = fork_join_map[&block];
+        let thread_reduces: Vec<_> = def_use_map.get_users(thread_join).as_ref().into_iter().filter(|id| function.nodes[id.idx()].is_reduce()).collect();
+        let block_reduces: Vec<_> = def_use_map.get_users(block_join).as_ref().into_iter().filter(|id| function.nodes[id.idx()].is_reduce()).collect();
+        for id in thread_reduces {
+            if !function.schedules[id.idx()].contains(&Schedule::ParallelReduce) {
+                return false;
+            }
+            let users = def_use_map.get_users(*id);
+            if users.len() > 1 {
+                return false;
+            }
+            let user = users.into_iter().next().unwrap();
+            if !block_reduces.contains(&user) {
+                return false;
+            }
+        }
+        for id in block_reduces {
+            if !function.schedules[id.idx()].contains(&Schedule::ParallelReduce) {
+                return false;
+            }
+            let users = def_use_map.get_users(*id);
+            if users.len() > 1 {
+                return false;
+            }
+            let user = users.into_iter().next().unwrap();
+            if !function.nodes[user.idx()].is_return() {
+                return false;
+            }
+        }
+        true
+    })();
+
     let ctx = GPUContext {
         module_name,
         function,
@@ -191,6 +248,8 @@ pub fn gpu_codegen<W: Write>(
         control_data_phi_map,
         return_parameters,
         kernel_params,
+        avoid_join_sync,
+        generated_sync: RefCell::new(false),
     };
     ctx.codegen_function(w)
 }
@@ -221,6 +280,8 @@ struct GPUContext<'a> {
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
     return_parameters: Vec<Option<usize>>,
     kernel_params: &'a GPUKernelParams,
+    avoid_join_sync: bool,
+    generated_sync: RefCell<bool>,
 }
 
 /*
@@ -277,8 +338,6 @@ impl GPUContext<'_> {
         let mut dynamic_shared_offset = "0".to_string();
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
-        self.codegen_helpers(&mut top)?;
-        write!(w, "{}", top)?;
 
         // Setup for CUDA's "goto" for control flow between basic blocks.
         let mut gotos: BTreeMap<_, _> = (0..self.function.nodes.len())
@@ -327,6 +386,8 @@ impl GPUContext<'_> {
         };
 
         // Emit all GPU kernel code from previous steps
+        self.codegen_helpers(&mut top, *self.generated_sync.borrow())?;
+        write!(w, "{}", top)?;
         self.codegen_goto_start(&mut thread_block_tiles)?;
         write!(w, "{}", thread_block_tiles)?;
         let mut kernel_body = String::new();
@@ -575,16 +636,25 @@ namespace cg = cooperative_groups;
      * are from CUDA's cooperative groups API and are used specifically for reads
      * and writes.
      */
-    fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
-        write!(
-            w,
-            "\t__shared__ cge::block_tile_memory<1024> block_sync_shared;\n"
-        )?;
+    fn codegen_helpers(&self, w: &mut String, need_sync_shared: bool) -> Result<(), Error> {
+        if need_sync_shared {
+            write!(
+                w,
+                "\t__shared__ cge::block_tile_memory<1024> block_sync_shared;\n"
+            )?;
+        }
         write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
-        write!(
-            w,
-            "\tcg::thread_block block = cge::this_thread_block(block_sync_shared);\n"
-        )?;
+        if need_sync_shared {
+            write!(
+                w,
+                "\tcg::thread_block block = cge::this_thread_block(block_sync_shared);\n"
+            )?;
+        } else {
+            write!(
+                w,
+                "\tcg::thread_block block = cg::this_thread_block();\n"
+            )?;
+        }
         Ok(())
     }
 
@@ -1344,7 +1414,7 @@ namespace cg = cooperative_groups;
                         write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?;
                         write!(w, "{}}}\n", tabs)?;
                         write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
-                        //write!(w, "__syncthreads\n")?;
+                        *self.generated_sync.borrow_mut() = true;
                     }
                 }
                 // Dynamic constants emitted at top
@@ -1803,10 +1873,12 @@ namespace cg = cooperative_groups;
                     write!(w_term, "\t}}\n")?;
                     tabs += 1;
                 }
-                let fork = self.join_fork_map.get(&id).unwrap();
-                let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
-                write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
-                //write!(w_term, "\t__syncthreads;\n")?;
+                if !self.avoid_join_sync {
+                    let fork = self.join_fork_map.get(&id).unwrap();
+                    let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
+                    write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
+                    *self.generated_sync.borrow_mut() = true;
+                }
             }
             // If the Fork was parallelized, each thread or UsedPerId tile of
             // threads only runs one ThreadID, so we can jump straight to the
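For reference, after this change codegen_helpers writes one of two kernel prologues, chosen by whether any cooperative-groups .sync() call was generated for the kernel body (the generated_sync flag). The emitted CUDA lines below are quoted from the write! calls in the hunks above; the comments are explanatory only.

    // need_sync_shared == true: shared scratch backs the block tiles whose .sync() calls appear in the body
    __shared__ cge::block_tile_memory<1024> block_sync_shared;
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cge::this_thread_block(block_sync_shared);

    // need_sync_shared == false (avoid_join_sync kernels with no generated sync): no tile memory reserved
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cg::this_thread_block();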