
GPU backend

Merged: prathi3 requested to merge gpu-cg into main
1 file  +78 -72
@@ -79,6 +79,10 @@ pub fn gpu_codegen<W: Write>(
* - Add float8, float16, bfloat16 dtypes if they come
*/
+ // Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations
+ let block_parallel_override = false;
+ let thread_parallel_override = false;
let reduce_nodes: Vec<NodeID> = (0..function.nodes.len())
.filter(|idx| function.nodes[*idx].is_reduce())
.map(NodeID::new)
@@ -109,7 +113,6 @@ pub fn gpu_codegen<W: Write>(
.entry(fork_node)
.or_default()
.push(*reduce_node);
- println!("reduce_node: {:?}, fork_node: {:?}, join: {:?}", reduce_node, fork_node, control);
}
Node::Region { preds: _ } => {
// TODO: map region node to fork node
@@ -118,14 +121,10 @@ pub fn gpu_codegen<W: Write>(
panic!("Reduce's control must be a join or region node");
}
}
- if !function.schedules[reduce_node.idx()].contains(&Schedule::ParallelReduce)
- && !function.schedules[reduce_node.idx()].contains(&Schedule::TightAssociative)
- {
- reduct_reduce_map
- .entry(*reduct)
- .or_default()
- .push(*reduce_node);
- }
+ reduct_reduce_map
+ .entry(*reduct)
+ .or_default()
+ .push(*reduce_node);
}
}
for idx in 0..function.nodes.len() {
@@ -186,6 +185,8 @@ pub fn gpu_codegen<W: Write>(
control_data_phi_map,
return_parameter,
kernel_params,
+ block_parallel_override,
+ thread_parallel_override,
};
ctx.codegen_function(w)
}
@@ -214,6 +215,8 @@ struct GPUContext<'a> {
control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
return_parameter: Option<usize>,
kernel_params: &'a GPUKernelParams,
+ block_parallel_override: bool,
+ thread_parallel_override: bool,
}
/*
@@ -298,6 +301,7 @@ impl GPUContext<'_> {
let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel);
writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap();
let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork);
writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap();
// TODO: Uncomment and adjust once we know logic of extra dim
// let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
let extra_dim_collects = HashSet::new();
@@ -586,7 +590,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
self.collection_objects.origin(*object).try_parameter().is_some()
})
}), "All collection reduces in block fork must originate from parameters");
- if true || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
+ if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
{
let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ");
(root_forks, fork_size, true)
@@ -693,7 +697,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
&& fork_size.is_power_of_two()
&& reduces.iter().all(|&reduce| {
self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
- || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+ || self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
})
{
// If there's an associative Reduce, parallelize the larger factor
@@ -706,7 +710,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
// restriction doesn't help for parallel Writes, so nested parallelization
// is possible.
if reduces.iter().any(|&reduce| {
- self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+ self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
}) || fork_size > self.kernel_params.max_num_threads / subtree_quota {
if fork_size >= subtree_quota {
(HashMap::new(), fork_size, true)
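To make the quota logic above concrete with made-up numbers (illustrative only, not from the source): with max_num_threads = 1024, subtree_quota = 16 and fork_size = 128, the check 128 > 1024 / 16 = 64 holds, so this fork and its subtree cannot both receive thread parallelism; because 128 >= 16, the fork is the larger factor and takes the quota itself, matching the (HashMap::new(), fork_size, true) return on the line above.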
@@ -1069,8 +1073,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
}
},
Node::Binary { op, left, right } => {
- let left_val = self.get_value(*left, false, false);
- let right_val = self.get_value(*right, false, false);
+ let mut left_val = self.get_value(*left, false, false);
+ let mut right_val = self.get_value(*right, false, false);
let id_type = self.typing[id.idx()];
if matches!(op, BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And
| BinaryOperator::Xor) && is_special_reduct {
@@ -1079,14 +1083,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
// supported. We need to use CGType::Use not CGType::UsePerId
// because for parallelized reduction we only have one thread
// per ThreadID and the reduction is over Use, not UsePerId.
- let non_reduce_arg = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] {
- right_val
+ let (reduce_val, non_reduce_val) = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] {
+ (left_val, right_val)
} else {
- left_val
+ (right_val, left_val)
};
// Special reduct is only enabled for thread parallelization
// so don't need to worry about grid and block cases
- let cg_tile = self.get_cg_tile(id, CGType::Use);
+ let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use);
#[allow(unreachable_patterns)]
let cg_op = match op {
BinaryOperator::Add => "plus",
@@ -1096,57 +1100,60 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
_ => unreachable!(),
};
let id_type_name = self.get_type(id_type, false);
write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_arg, cg_op, id_type_name)?;
} else {
match (op, &self.types[id_type.idx()]) {
(BinaryOperator::Or, Type::Boolean) => write!(
w,
"{}{} = {} || {};\n",
tabs, define_variable, left_val, right_val,
)?,
(BinaryOperator::And, Type::Boolean) => write!(
w,
"{}{} = {} && {};\n",
tabs, define_variable, left_val, right_val,
)?,
(BinaryOperator::Rem, Type::Float32) => write!(
w,
"{}{} = fmodf({}, {});\n",
tabs, define_variable, left_val, right_val,
)?,
(BinaryOperator::Rem, Type::Float64) => write!(
w,
"{}{} = fmod({}, {});\n",
tabs, define_variable, left_val, right_val,
)?,
(op, _) => write!(
w,
"{}{} = {} {} {};\n",
tabs,
define_variable,
left_val,
match op {
BinaryOperator::Add => "+",
BinaryOperator::Sub => "-",
BinaryOperator::Mul => "*",
BinaryOperator::Div => "/",
BinaryOperator::Rem => "%",
BinaryOperator::LT => "<",
BinaryOperator::LTE => "<=",
BinaryOperator::GT => ">",
BinaryOperator::GTE => ">=",
BinaryOperator::EQ => "==",
BinaryOperator::NE => "!=",
BinaryOperator::Or => "|",
BinaryOperator::And => "&",
BinaryOperator::Xor => "^",
BinaryOperator::LSh => "<<",
BinaryOperator::RSh => ">>",
},
right_val,
)?,
};
write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name)?;
// Setup binop between reduce's init and reduced reduct. Since it's associative,
// we can change binop ordering
left_val = define_variable.clone();
right_val = reduce_val;
}
match (op, &self.types[id_type.idx()]) {
(BinaryOperator::Or, Type::Boolean) => write!(
w,
"{}{} = {} || {};\n",
tabs, define_variable, left_val, right_val,
)?,
(BinaryOperator::And, Type::Boolean) => write!(
w,
"{}{} = {} && {};\n",
tabs, define_variable, left_val, right_val,
)?,
(BinaryOperator::Rem, Type::Float32) => write!(
w,
"{}{} = fmodf({}, {});\n",
tabs, define_variable, left_val, right_val,
)?,
(BinaryOperator::Rem, Type::Float64) => write!(
w,
"{}{} = fmod({}, {});\n",
tabs, define_variable, left_val, right_val,
)?,
(op, _) => write!(
w,
"{}{} = {} {} {};\n",
tabs,
define_variable,
left_val,
match op {
BinaryOperator::Add => "+",
BinaryOperator::Sub => "-",
BinaryOperator::Mul => "*",
BinaryOperator::Div => "/",
BinaryOperator::Rem => "%",
BinaryOperator::LT => "<",
BinaryOperator::LTE => "<=",
BinaryOperator::GT => ">",
BinaryOperator::GTE => ">=",
BinaryOperator::EQ => "==",
BinaryOperator::NE => "!=",
BinaryOperator::Or => "|",
BinaryOperator::And => "&",
BinaryOperator::Xor => "^",
BinaryOperator::LSh => "<<",
BinaryOperator::RSh => ">>",
},
right_val,
)?,
};
}
Node::Ternary {
op,
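A note on the hunk above for readers less familiar with cooperative groups: after this change, cg::reduce collapses the non-reduce operand across the tile, and the tile-wide result is then combined with the Reduce node's running value by falling through to the ordinary binary-op emission (left_val is rebound to the freshly defined variable, right_val to reduce_val). A rough CUDA sketch of the shape of code this emits, assuming a 32-wide tile and a float Add reduction; the names tile, partial, and acc are illustrative, not taken from the generated output:

    #include <cooperative_groups.h>
    #include <cooperative_groups/reduce.h>
    namespace cg = cooperative_groups;

    // Illustrative sketch only: one step of an associative Add reduce that
    // has been parallelized across a 32-thread tile.
    __device__ float reduce_step(cg::thread_block_tile<32> tile,
                                 float acc, float partial) {
        // cg::reduce yields the sum of every thread's partial value in the tile.
        float summed = cg::reduce(tile, partial, cg::plus<float>());
        // The running accumulator is folded in afterwards; associativity of Add
        // is what makes this reordering safe (the "we can change binop ordering"
        // comment in the hunk).
        return summed + acc;
    }

The hunk's comment notes the reordering is only valid because the operator is associative, which is also why the rewrite is limited to Add, Or, And, and Xor in the is_special_reduct check above.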
@@ -1175,7 +1182,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
} else {
self.get_value(args[0], false, false)
};
- let cg_tile = self.get_cg_tile(id, CGType::Use);
+ let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use);
#[allow(unreachable_patterns)]
let cg_op = match intrinsic {
Intrinsic::Max => "max",
@@ -1248,8 +1255,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
}
_ => {
- println!("Unsupported data node type: {:?}", self.function.nodes[id.idx()]);
- panic!("Unsupported data node type")
+ panic!("Unsupported data node type: {:?}", self.function.nodes[id.idx()])
}
}
// Since reducts are responsible for updating Reduce nodes, we check and
@@ -1404,8 +1410,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?;
write!(w_term, "\t}}\n")?;
write!(w_term, "\treturn;\n")?;
}
write!(w_term, "\treturn;\n")?;
1
}
_ => {
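On the last hunk: moving return; outside the if means every thread in the grid exits after the guarded store, rather than only thread rank 0 returning at that point. A minimal CUDA sketch of the emitted epilogue; the float* cast and the name ret_val stand in for the generated return type and value and are not taken from the source:

    cg::grid_group grid = cg::this_grid();
    // Only the first thread in the grid writes the return value through the
    // opaque return pointer...
    if (grid.thread_rank() == 0) {
        *(reinterpret_cast<float*>(ret)) = ret_val;
    }
    // ...but the return is no longer inside the guard, so all threads leave
    // the kernel together.
    return;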