
GPU backend

Merged: prathi3 requested to merge gpu-cg into main
2 files changed: +11 -30
@@ -79,10 +79,6 @@ pub fn gpu_codegen<W: Write>(
* - Add float8, float16, bfloat16 dtypes if they come
*/
-// Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations
-let block_parallel_override = false;
-let thread_parallel_override = false;
let reduce_nodes: Vec<NodeID> = (0..function.nodes.len())
.filter(|idx| function.nodes[*idx].is_reduce())
.map(NodeID::new)
@@ -164,8 +160,6 @@ pub fn gpu_codegen<W: Write>(
threads_per_warp: 32,
};
-std::fs::write("out.txt", "debug\n\n").unwrap();
let ctx = GPUContext {
function,
types,
@@ -185,8 +179,6 @@ pub fn gpu_codegen<W: Write>(
control_data_phi_map,
return_parameter,
kernel_params,
-block_parallel_override,
-thread_parallel_override,
};
ctx.codegen_function(w)
}
@@ -215,8 +207,6 @@ struct GPUContext<'a> {
control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
return_parameter: Option<usize>,
kernel_params: &'a GPUKernelParams,
-block_parallel_override: bool,
-thread_parallel_override: bool,
}
/*
@@ -265,8 +255,6 @@ enum CGType {
impl GPUContext<'_> {
fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
-let mut file = OpenOptions::new().append(true).open("out.txt").unwrap();
// Emit all code up to the "goto" to Start's block
let mut top = String::new();
self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
@@ -289,19 +277,14 @@ impl GPUContext<'_> {
// If there are no forks, fast forward to single-block, single-thread codegen
let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
writeln!(file, "shortcut to 1b1t").unwrap();
self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?;
("1".to_string(), "1".to_string())
} else {
writeln!(file, "no shortcut! fork tree: {:?}", self.fork_tree).unwrap();
// Create structures and determine block and thread parallelization strategy
let (root_forks, num_blocks, is_block_parallel) =
self.get_root_forks_and_num_blocks(&self.fork_tree);
writeln!(file, "is_block_parallel: {}", is_block_parallel).unwrap();
let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel);
writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap();
let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork);
writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap();
self.get_root_forks_and_num_blocks(self.fork_tree);
let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
// TODO: Uncomment and adjust once we know logic of extra dim
// let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
let extra_dim_collects = HashSet::new();
@@ -590,7 +573,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
self.collection_objects.origin(*object).try_parameter().is_some()
})
}), "All collection reduces in block fork must originate from parameters");
-if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
+if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
{
let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ");
(root_forks, fork_size, true)
@@ -697,7 +680,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
&& fork_size.is_power_of_two()
&& reduces.iter().all(|&reduce| {
self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
-|| self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+|| self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
})
{
// If there's an associative Reduce, parallelize the larger factor
@@ -710,7 +693,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
// restriction doesn't help for parallel Writes, so nested parallelization
// is possible.
if reduces.iter().any(|&reduce| {
-self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
}) || fork_size > self.kernel_params.max_num_threads / subtree_quota {
if fork_size >= subtree_quota {
(HashMap::new(), fork_size, true)
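
For context on what remains once the overrides are gone, below is a small standalone sketch (hypothetical, simplified types; not code from this MR) of the schedule-driven checks the hunks above leave behind: block parallelism requires a ParallelFork schedule on the root fork, and thread-parallel reduces require a ParallelReduce or TightAssociative schedule.

    // Hedged sketch, not code from this MR: simplified, hypothetical types that
    // mirror the schedule checks left behind once the override flags are removed.
    #[derive(PartialEq)]
    enum Schedule {
        ParallelFork,
        ParallelReduce,
        TightAssociative,
    }

    // Block-level parallelism now needs an explicit ParallelFork schedule on the
    // root fork; the old block_parallel_override escape hatch is gone.
    fn block_parallel(fork_schedules: &[Schedule]) -> bool {
        fork_schedules.contains(&Schedule::ParallelFork)
    }

    // Thread-level parallel reduces now need a ParallelReduce or TightAssociative
    // schedule; the old thread_parallel_override shortcut is gone.
    fn reduce_parallel(reduce_schedules: &[Schedule]) -> bool {
        reduce_schedules.contains(&Schedule::ParallelReduce)
            || reduce_schedules.contains(&Schedule::TightAssociative)
    }

    fn main() {
        assert!(block_parallel(&[Schedule::ParallelFork]));
        assert!(!block_parallel(&[]));
        assert!(reduce_parallel(&[Schedule::TightAssociative]));
        assert!(!reduce_parallel(&[]));
    }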