Skip to content
Snippets Groups Projects

GPU backend

Merged prathi3 requested to merge gpu-cg into main
4 files
+ 31
56
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 27
52
@@ -183,21 +183,19 @@ pub fn gpu_codegen<W: Write>(
threads_per_warp: 32,
};
let label_data_for_phi = || -> HashMap<NodeID, Vec<NodeID>> {
let mut label_data_for_phi = HashMap::new();
for (idx, node) in function.nodes.iter().enumerate() {
if let Node::Phi { control: _, data } = node {
for &data_id in data.iter() {
label_data_for_phi
.entry(data_id)
.or_insert(vec![])
.push(NodeID::new(idx));
}
// Map from control to pairs of data to update phi
// For each phi, we go to its region and get region's controls
let control_data_phi_map: &mut HashMap<NodeID, Vec<(NodeID, NodeID)>> = &mut HashMap::new();
for (idx, node) in function.nodes.iter().enumerate() {
if let Node::Phi { control, data } = node {
let Node::Region { preds } = &function.nodes[control.idx()] else {
panic!("Phi's control must be a region node");
};
for (i, &pred) in preds.iter().enumerate() {
control_data_phi_map.entry(pred).or_insert(vec![]).push((data[i], NodeID::new(idx)));
}
}
label_data_for_phi
};
let label_data_for_phi = &label_data_for_phi();
}
let def_use_map = &def_use(function);
@@ -215,7 +213,7 @@ pub fn gpu_codegen<W: Write>(
join_fork_map,
fork_reduce_map,
reduct_reduce_map,
label_data_for_phi,
control_data_phi_map,
return_type_id,
};
ctx.codegen_function(w)
@@ -241,7 +239,7 @@ struct GPUContext<'a> {
join_fork_map: &'a HashMap<NodeID, NodeID>,
fork_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
reduct_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
label_data_for_phi: &'a HashMap<NodeID, Vec<NodeID>>,
control_data_phi_map: &'a HashMap<NodeID, Vec<(NodeID, NodeID)>>,
return_type_id: &'a TypeID,
}
@@ -490,6 +488,9 @@ namespace cg = cooperative_groups;
!self.function.nodes[id.idx()].is_parameter() {
write!(w, "\t{};\n", self.get_value(id, true, false))?;
}
if self.function.nodes[id.idx()].is_phi() {
write!(w, "\t{}_tmp;\n", self.get_value(id, true, false))?;
}
}
Ok(())
}
@@ -900,9 +901,6 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_node(*data, KernelState::OutBlock, None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
}
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_phi(*data, tabs, body)?;
}
Ok(())
})
}
@@ -935,9 +933,6 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_node(*data, state, None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
}
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_phi(*data, tabs, body)?;
}
}
// Then generate data and control for the single block fork if it exists
if block_fork.is_some() {
@@ -952,9 +947,6 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_node(*data, state, Some(num_threads), None, Some(block_fork.unwrap()), false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?;
}
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_phi(*data, tabs, body)?;
}
}
}
// Then generate for the thread fork tree through Fork node traversal.
@@ -1034,9 +1026,6 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
&mut tabs,
)?;
}
for data in self.bbs.1[control.idx()].iter() {
self.codegen_data_phi(*data, tabs, body)?;
}
}
for child in fork_tree.get(&curr_fork).unwrap() {
self.codegen_data_control_traverse(
@@ -1071,12 +1060,9 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
let define_variable = self.get_value(id, false, false).to_string();
let tabs = "\t".repeat(*num_tabs);
match &self.function.nodes[id.idx()] {
// Phi registers emitted at top and the data nodes it uses will
// update the phi
Node::Phi {
control: _,
data: _,
} => {}
Node::Phi { control: _, data: _ } => {
write!(w, "{}{} = {}_tmp;\n", tabs, define_variable, define_variable)?;
}
Node::ThreadID { control, dimension } => {
let Node::Fork { factors, .. } = &self.function.nodes[control.idx()] else {
panic!("Expected ThreadID's control to be a fork node");
@@ -1413,8 +1399,8 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
panic!("Unsupported data node type")
}
}
// Since reducts are responsible for updating Reduce nodes,
// we check and emit those for each data node.
// Since reducts are responsible for updating Reduce nodes, we check and
// emit those for each data node.
if let Some(reduces) = self.reduct_reduce_map.get(&id) {
let val = self.get_value(id, false, false);
for reduce in reduces {
@@ -1425,22 +1411,6 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
Ok(())
}
/*
* Update Phi assignments for each data node. This is run after all data nodes
* for given control block have been emitted.
*/
fn codegen_data_phi(&self, id: NodeID, num_tabs: usize, w: &mut String) -> Result<(), Error> {
let tabs = "\t".repeat(num_tabs);
if let Some(phis) = self.label_data_for_phi.get(&id) {
let val = self.get_value(id, false, false);
for phi in phis {
let phi_val = self.get_value(*phi, false, false);
write!(w, "{}{} = {};\n", tabs, phi_val, val)?;
}
}
Ok(())
}
fn codegen_control_node(
&self,
id: NodeID,
@@ -1451,6 +1421,11 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
w_post_init: &mut String,
w_term: &mut String,
) -> Result<usize, Error> {
for (data, phi) in self.control_data_phi_map.get(&id).unwrap_or(&vec![]).iter() {
let data = self.get_value(*data, false, false);
let phi = self.get_value(*phi, false, false);
write!(w_term, "\t{}_tmp = {};\n", phi, data)?;
}
let tabs = match &self.function.nodes[id.idx()] {
Node::Start
| Node::Region { preds: _ }
@@ -1572,7 +1547,7 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
// we write to that parameter upon return.
if self.types[self.typing[data.idx()].idx()].is_primitive() {
let return_val = self.get_value(*data, false, false);
write!(w_term, "\tif (threadIdx.x == 0) {{\n")?;
write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
write!(w_term, "\t\t*ret = {};\n", return_val)?;
write!(w_term, "\t}}\n")?;
}
Loading