Commit 012eec22 authored by rarbore2

Merge branch 'sync-opt' into 'main'

Sync optimizations

See merge request !227
parents 9ea73aab 3ece0fa1
Pipeline #202113 passed
 extern crate bitvec;
 extern crate hercules_ir;
+use std::cell::RefCell;
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::{Error, Write};
@@ -170,6 +171,62 @@ pub fn gpu_codegen<W: Write>(
         threads_per_warp: 32,
     };
+
+    // Check for whether we need to emit synchronization for joins.
+    let avoid_join_sync = (|| {
+        // Check for simple block/thread fork structure.
+        let Some(root_forks) = fork_tree.get(&NodeID::new(0)) else {
+            return false;
+        };
+        if root_forks.len() != 1 {
+            return false;
+        }
+        let block = *root_forks.into_iter().next().unwrap();
+        let Some(block_forks) = fork_tree.get(&block) else {
+            return false;
+        };
+        if block_forks.len() != 1 {
+            return false;
+        }
+        let thread = *block_forks.into_iter().next().unwrap();
+        if let Some(thread_forks) = fork_tree.get(&thread) && !thread_forks.is_empty() {
+            return false;
+        }
+
+        // Check that the results from the thread fork aren't needed further
+        // inside this kernel.
+        let thread_join = fork_join_map[&thread];
+        let block_join = fork_join_map[&block];
+        let thread_reduces: Vec<_> = def_use_map.get_users(thread_join).as_ref().into_iter().filter(|id| function.nodes[id.idx()].is_reduce()).collect();
+        let block_reduces: Vec<_> = def_use_map.get_users(block_join).as_ref().into_iter().filter(|id| function.nodes[id.idx()].is_reduce()).collect();
+        for id in thread_reduces {
+            if !function.schedules[id.idx()].contains(&Schedule::ParallelReduce) {
+                return false;
+            }
+            let users = def_use_map.get_users(*id);
+            if users.len() > 1 {
+                return false;
+            }
+            let user = users.into_iter().next().unwrap();
+            if !block_reduces.contains(&user) {
+                return false;
+            }
+        }
+        for id in block_reduces {
+            if !function.schedules[id.idx()].contains(&Schedule::ParallelReduce) {
+                return false;
+            }
+            let users = def_use_map.get_users(*id);
+            if users.len() > 1 {
+                return false;
+            }
+            let user = users.into_iter().next().unwrap();
+            if !function.nodes[user.idx()].is_return() {
+                return false;
+            }
+        }
+        true
+    })();
 
     let ctx = GPUContext {
         module_name,
         function,
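The closure above gates the whole optimization: join synchronization can only be skipped when the kernel is exactly one block-level fork wrapping one thread-level fork, and every thread-level reduction result flows into a single block-level ParallelReduce whose only user is the return. Below is a minimal, self-contained sketch of the structural half of that check, using simplified hypothetical types (the real fork_tree maps a fork's NodeID to the forks nested directly under it, with NodeID::new(0) as the root):

use std::collections::HashMap;

type NodeId = usize;

// Returns the single fork nested directly under `parent`, if there is
// exactly one.
fn single_nested_fork(fork_tree: &HashMap<NodeId, Vec<NodeId>>, parent: NodeId) -> Option<NodeId> {
    match fork_tree.get(&parent) {
        Some(children) if children.len() == 1 => Some(children[0]),
        _ => None,
    }
}

// True for the block/thread shape the optimization accepts: one root
// (block) fork, one (thread) fork inside it, and nothing nested deeper.
fn is_block_thread_shape(fork_tree: &HashMap<NodeId, Vec<NodeId>>) -> bool {
    let Some(block) = single_nested_fork(fork_tree, 0) else {
        return false;
    };
    let Some(thread) = single_nested_fork(fork_tree, block) else {
        return false;
    };
    fork_tree.get(&thread).map_or(true, |forks| forks.is_empty())
}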
@@ -191,6 +248,8 @@ pub fn gpu_codegen<W: Write>(
         control_data_phi_map,
         return_parameters,
         kernel_params,
+        avoid_join_sync,
+        generated_sync: RefCell::new(false),
     };
     ctx.codegen_function(w)
 }
@@ -221,6 +280,8 @@ struct GPUContext<'a> {
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
     return_parameters: Vec<Option<usize>>,
     kernel_params: &'a GPUKernelParams,
+    avoid_join_sync: bool,
+    generated_sync: RefCell<bool>,
 }
 
 /*
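generated_sync is a RefCell<bool> because the emission methods all take &self: interior mutability lets any code path that writes a sync instruction record that fact without threading &mut self through the entire code generator. A minimal sketch of the pattern, with hypothetical names (EmitCtx and emit_sync are illustrative, not hercules identifiers):

use std::cell::RefCell;
use std::fmt::{Error, Write};

struct EmitCtx {
    generated_sync: RefCell<bool>,
}

impl EmitCtx {
    // Takes &self, yet can still record that a sync was produced.
    fn emit_sync(&self, w: &mut String) -> Result<(), Error> {
        write!(w, "\tblock.sync();\n")?;
        *self.generated_sync.borrow_mut() = true;
        Ok(())
    }
}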
@@ -277,8 +338,6 @@ impl GPUContext<'_> {
         let mut dynamic_shared_offset = "0".to_string();
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
-        self.codegen_helpers(&mut top)?;
-        write!(w, "{}", top)?;
 
         // Setup for CUDA's "goto" for control flow between basic blocks.
         let mut gotos: BTreeMap<_, _> = (0..self.function.nodes.len())
@@ -327,6 +386,8 @@ impl GPUContext<'_> {
         };
 
         // Emit all GPU kernel code from previous steps
+        self.codegen_helpers(&mut top, *self.generated_sync.borrow())?;
+        write!(w, "{}", top)?;
         self.codegen_goto_start(&mut thread_block_tiles)?;
         write!(w, "{}", thread_block_tiles)?;
         let mut kernel_body = String::new();
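Note the reordering the previous two hunks accomplish: codegen_helpers used to run before any kernel code was generated, but it is now called only after the code from the earlier steps has been emitted, so *self.generated_sync.borrow() already reflects whether any cooperative-groups sync was produced and the preamble can be specialized to match.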
@@ -575,16 +636,25 @@ namespace cg = cooperative_groups;
      * are from CUDA's cooperative groups API and are used specifically for reads
      * and writes.
      */
-    fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
-        write!(
-            w,
-            "\t__shared__ cge::block_tile_memory<1024> block_sync_shared;\n"
-        )?;
+    fn codegen_helpers(&self, w: &mut String, need_sync_shared: bool) -> Result<(), Error> {
+        if need_sync_shared {
+            write!(
+                w,
+                "\t__shared__ cge::block_tile_memory<1024> block_sync_shared;\n"
+            )?;
+        }
         write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
-        write!(
-            w,
-            "\tcg::thread_block block = cge::this_thread_block(block_sync_shared);\n"
-        )?;
+        if need_sync_shared {
+            write!(
+                w,
+                "\tcg::thread_block block = cge::this_thread_block(block_sync_shared);\n"
+            )?;
+        } else {
+            write!(
+                w,
+                "\tcg::thread_block block = cg::this_thread_block();\n"
+            )?;
+        }
         Ok(())
     }
 
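Reconstructed from the write! calls above, the two possible preambles are: with need_sync_shared == true,

	__shared__ cge::block_tile_memory<1024> block_sync_shared;
	cg::grid_group grid = cg::this_grid();
	cg::thread_block block = cge::this_thread_block(block_sync_shared);

and with need_sync_shared == false,

	cg::grid_group grid = cg::this_grid();
	cg::thread_block block = cg::this_thread_block();

so kernels that never synchronize no longer pay for the __shared__ block_tile_memory allocation.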
@@ -1344,7 +1414,7 @@ namespace cg = cooperative_groups;
                 write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?;
                 write!(w, "{}}}\n", tabs)?;
                 write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
-                //write!(w, "__syncthreads\n")?;
+                *self.generated_sync.borrow_mut() = true;
             }
         }
         // Dynamic constants emitted at top
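Setting generated_sync at this zeroing barrier, and at the join barrier in the next hunk, is what feeds the need_sync_shared flag: if no code path ever records a sync, codegen_helpers omits the block_sync_shared declaration entirely.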
@@ -1803,10 +1873,12 @@ namespace cg = cooperative_groups;
                 write!(w_term, "\t}}\n")?;
                 tabs += 1;
             }
-            let fork = self.join_fork_map.get(&id).unwrap();
-            let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
-            write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
-            //write!(w_term, "\t__syncthreads;\n")?;
+            if !self.avoid_join_sync {
+                let fork = self.join_fork_map.get(&id).unwrap();
+                let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
+                write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
+                *self.generated_sync.borrow_mut() = true;
+            }
         }
         // If the Fork was parallelized, each thread or UsedPerId tile of
         // threads only runs one ThreadID, so we can jump straight to the
......