Commit 4897c13f authored by Praneet Rathi

minor

parent 249294c5
1 merge request: !115 GPU backend
@@ -459,13 +459,11 @@ namespace cg = cooperative_groups;
     }
     /*
-     * Emit helper registers that are used throughout the kernel. alignment
-     * is for proper dynamic shared memory allocation. grid and block are
-     * from CUDA's cooperative groups API and are used specifically for reads and
-     * writes.
+     * Emit helper registers that are used throughout the kernel. grid and block
+     * are from CUDA's cooperative groups API and are used specifically for reads
+     * and writes.
      */
     fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
-        write!(w, "\tsize_t alignment;\n")?;
         write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
         write!(w, "\tcg::thread_block block = cg::this_thread_block();\n")?;
         Ok(())
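For context, a minimal sketch of the preamble a generated kernel begins with after this change; the kernel name and body are hypothetical, and only the two helper registers correspond to what codegen_helpers writes (the size_t alignment register is no longer emitted):

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Hypothetical generated kernel: only the two helper registers match the
// write! strings above; the kernel name and the body are illustrative.
// Assumes a cooperative launch so this_grid() yields a usable grid_group.
__global__ void example_kernel() {
    cg::grid_group grid = cg::this_grid();              // used for grid-wide reads/writes
    cg::thread_block block = cg::this_thread_block();   // used for block-wide reads/writes
    // ... generated gotos follow here ...
}
```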
@@ -479,9 +477,15 @@ namespace cg = cooperative_groups;
     fn codegen_gotos(&self, gotos: &mut BTreeMap<NodeID, CudaGoto>, w: &mut String) -> Result<(), Error> {
         write!(w, "\n")?;
-        for (_, goto) in gotos.iter() {
+        for (id, goto) in gotos.iter() {
+            let goto_block = self.get_block_name(*id, false);
+            write!(w, "{}:\n", goto_block)?;
             write!(w, "{}\n", goto.init)?;
-            write!(w, "{}\n", goto.post_init)?;
+            if !goto.post_init.is_empty() {
+                let goto_block = self.get_block_name(*id, true);
+                write!(w, "{}:\n", goto_block)?;
+                write!(w, "{}\n", goto.post_init)?;
+            }
             write!(w, "{}\n", goto.body)?;
             write!(w, "{}\n\n", goto.term)?;
         }
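A rough sketch of the block layout codegen_gotos now emits inside a kernel; the label names and statements are hypothetical, only the label / init / optional post-init label / body / terminator ordering comes from this change:

```cuda
// Hypothetical emitted shape for two basic blocks. The post-init label is
// only emitted when the block actually has post_init code.
__global__ void example_gotos() {
    int acc = 0;
bb_0:                  // get_block_name(id, false)
    acc = 1;           // goto.init
bb_0_post:             // get_block_name(id, true), only when post_init is non-empty
    acc += 2;          // goto.post_init
    acc *= 3;          // goto.body
    goto bb_1;         // goto.term
bb_1:
    return;            // terminator of the last block
}
```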
@@ -886,7 +890,11 @@ namespace cg = cooperative_groups;
             Node::Constant { id: cons_id } => {
                 let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive();
                 if (!is_primitive) {
-                    let cg_tile = self.get_cg_tile(id, CGType::UsePerId);
+                    let cg_tile = match state {
+                        KernelState::OutBlock
+                        | KernelState::InBlock => "block".to_string(),
+                        KernelState::InThread => self.get_cg_tile(id, CGType::UsePerId),
+                    };
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
                     *num_tabs += 1;
                 }
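A sketch of the generated gate for a non-primitive Constant: one thread of the chosen cooperative group (now block for the OutBlock/InBlock states, the per-ThreadID tile otherwise) initializes the collection while the rest skip it. The shared buffer, its size, and the trailing sync are assumptions for illustration:

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Hypothetical expansion for a non-primitive constant in a block-level state.
__global__ void example_constant() {
    cg::thread_block block = cg::this_thread_block();
    __shared__ float cons[64];                  // hypothetical collection constant
    if (block.thread_rank() == 0) {             // gate emitted when !is_primitive
        for (int i = 0; i < 64; ++i) {
            cons[i] = 0.0f;
        }
    }
    block.sync();                               // assumed follow-up sync, not shown in this hunk
}
```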
@@ -967,6 +975,8 @@ namespace cg = cooperative_groups;
                         } else {
                             left_val
                         };
+                        // Special reduct is only enabled for thread parallelization
+                        // so don't need to worry about grid and block cases
                         let cg_tile = self.get_cg_tile(id, CGType::Use);
                         #[allow(unreachable_patterns)]
                         let cg_op = match op {
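The cg_op match feeds a cooperative-groups reduction over the Use tile; a hedged sketch of what such a reduction looks like on a thread-level tile (the 32-wide tile, the cg::plus operator, and the output handling are assumptions, not taken from the diff):

```cuda
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

// Hypothetical tile-level reduction: each thread contributes one value and
// cg::reduce combines them with the operator the cg_op match selects.
__global__ void example_reduce(const float* data, float* out) {
    cg::thread_block block = cg::this_thread_block();
    auto cg_tile = cg::tiled_partition<32>(block);        // assumed tile width
    float val = data[block.thread_rank()];
    float sum = cg::reduce(cg_tile, val, cg::plus<float>());
    if (block.thread_rank() == 0) {
        *out = sum;                                        // hypothetical: publish one tile's result
    }
}
```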
@@ -1086,7 +1096,7 @@ namespace cg = cooperative_groups;
             // The first three can all use cooperative groups memcpy and the last
             // one can't. However, the C++/CUDA semantics for the last three are
             // identical, so we differentiate the cases by data type instead of
-            // data source and destination.
+            // data src/dest, with only collection type using collective group.
             Node::Read { collect, indices } => {
                 let is_char = self.is_char(self.typing[collect.idx()]);
                 let collect_with_indices = self.codegen_collect(*collect, indices, is_char);
@@ -1100,15 +1110,19 @@ namespace cg = cooperative_groups;
                     }
                 } else {
                     let nested_fork = nesting_fork.unwrap();
-                    let cg_tile = self.get_cg_tile(nested_fork, CGType::UsePerId);
+                    let cg_tile = match state {
+                        KernelState::OutBlock => "grid".to_string(),
+                        KernelState::InBlock => "block".to_string(),
+                        KernelState::InThread => self.get_cg_tile(nested_fork, CGType::UsePerId),
+                    };
                     let data_size = self.get_size(data_type_id, None);
                     write!(w, "{}cg::memcpy_async({}, {}, {}, {});\n", tabs, cg_tile, define_variable, collect_with_indices, data_size)?;
                     write!(w, "{}cg::wait({});\n", tabs, cg_tile)?;
                 }
             }
-            // For write, the cases are the same, but since we're using C++/CUDA
-            // not-thread-safe write semantics, we need to gate the write with
-            // a thread rank check.
+            // For write, the cases are the same, but when using C++ dereference
+            // semantics, we need to gate the write with a thread rank check for
+            // thread safety.
             Node::Write {
                 collect,
                 data,
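For the collection-typed Read case, the emitted cg::memcpy_async / cg::wait pair expands to something like the following, with the group now being grid, block, or the per-ThreadID tile depending on state. The array names, destination, and copy size are hypothetical:

```cuda
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
namespace cg = cooperative_groups;

// Hypothetical expansion of the generated pair for the "block" case: every
// thread of the group cooperates in the copy, then waits for it to complete.
__global__ void example_read(const float* collect) {
    cg::thread_block block = cg::this_thread_block();
    __shared__ float dst[64];                              // hypothetical destination
    cg::memcpy_async(block, dst, collect, 64 * sizeof(float));
    cg::wait(block);
    // dst is now safe to read by any thread of the block
}
```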
@@ -1119,7 +1133,11 @@ namespace cg = cooperative_groups;
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
                 let nested_fork = nesting_fork.unwrap();
-                let cg_tile = self.get_cg_tile(nested_fork, CGType::UsePerId);
+                let cg_tile = match state {
+                    KernelState::OutBlock => "grid".to_string(),
+                    KernelState::InBlock => "block".to_string(),
+                    KernelState::InThread => self.get_cg_tile(nested_fork, CGType::UsePerId),
+                };
                 if self.types[data_type_id.idx()].is_primitive() {
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
                     if is_char {
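The primitive Write case is gated so only one thread of the group performs the plain C++ store; a minimal sketch, where the pointer, value, and trailing sync are assumptions:

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Hypothetical expansion of the gated primitive Write: the store is an
// ordinary dereference, so only thread 0 of the group executes it.
__global__ void example_write(float* collect, float data) {
    cg::thread_block block = cg::this_thread_block();     // "block" case of the new match
    if (block.thread_rank() == 0) {
        collect[0] = data;                                 // the guarded write
    }
    block.sync();                                          // assumed; the hunk only shows the gate
}
```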
@@ -1244,9 +1262,9 @@ namespace cg = cooperative_groups;
                 write!(w_init, "\tif (threadIdx.x % {} < {}) {{\n", available_thread_quota, use_thread_quota)?;
                 write!(w_term, "\t}}\n")?;
             }
-            let cg_tile_available = self.get_cg_tile(id, CGType::Available);
-            write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
         }
+        let cg_tile_available = self.get_cg_tile(id, CGType::Available);
+        write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
         // If the Fork was parallelized, each thread or UsedPerId tile of
         // threads only runs one ThreadID, so we can jump straight to the
         // successor. Else, we jump back to the Fork until all ThreadIDs
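Moving the Available-tile sync out of the conditional means every thread, including those masked off by the thread-quota guard, now reaches the barrier; a sketch of the emitted shape, where the quotas, tile width, and body are hypothetical:

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Hypothetical Fork prologue/epilogue after this change: the sync on the
// Available tile sits outside the quota guard, so no thread skips it.
__global__ void example_fork() {
    cg::thread_block block = cg::this_thread_block();
    auto cg_tile_available = cg::tiled_partition<32>(block);   // CGType::Available tile (assumed width)
    const unsigned use_thread_quota = 16, available_thread_quota = 32;
    if (threadIdx.x % available_thread_quota < use_thread_quota) {
        // ... Fork body runs only on the used threads ...
    }
    cg_tile_available.sync();                                   // now emitted outside the guard
}
```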