From 58b7928b8e7cd5d197b06cddc81674dce3bf8468 Mon Sep 17 00:00:00 2001 From: prrathi <prrathi10@gmail.com> Date: Sat, 25 Jan 2025 18:22:09 +0000 Subject: [PATCH] things work --- hercules_cg/src/gpu.rs | 150 +++++++++++++++++++++-------------------- 1 file changed, 78 insertions(+), 72 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 074281f8..c3cb6634 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -79,6 +79,10 @@ pub fn gpu_codegen<W: Write>( * - Add float8, float16, bfloat16 dtypes if they come */ + // Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations + let block_parallel_override = false; + let thread_parallel_override = false; + let reduce_nodes: Vec<NodeID> = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) .map(NodeID::new) @@ -109,7 +113,6 @@ pub fn gpu_codegen<W: Write>( .entry(fork_node) .or_default() .push(*reduce_node); - println!("reduce_node: {:?}, fork_node: {:?}, join: {:?}", reduce_node, fork_node, control); } Node::Region { preds: _ } => { // TODO: map region node to fork node @@ -118,14 +121,10 @@ pub fn gpu_codegen<W: Write>( panic!("Reduce's control must be a join or region node"); } } - if !function.schedules[reduce_node.idx()].contains(&Schedule::ParallelReduce) - && !function.schedules[reduce_node.idx()].contains(&Schedule::TightAssociative) - { - reduct_reduce_map - .entry(*reduct) - .or_default() - .push(*reduce_node); - } + reduct_reduce_map + .entry(*reduct) + .or_default() + .push(*reduce_node); } } for idx in 0..function.nodes.len() { @@ -186,6 +185,8 @@ pub fn gpu_codegen<W: Write>( control_data_phi_map, return_parameter, kernel_params, + block_parallel_override, + thread_parallel_override, }; ctx.codegen_function(w) } @@ -214,6 +215,8 @@ struct GPUContext<'a> { control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>, return_parameter: Option<usize>, kernel_params: &'a GPUKernelParams, + block_parallel_override: bool, + thread_parallel_override: bool, } /* @@ -298,6 +301,7 @@ impl GPUContext<'_> { let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel); writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap(); let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork); + writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap(); // TODO: Uncomment and adjust once we know logic of extra dim // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map); let extra_dim_collects = HashSet::new(); @@ -586,7 +590,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; self.collection_objects.origin(*object).try_parameter().is_some() }) }), "All collection reduces in block fork must originate from parameters"); - if true || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) + if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) { let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * "); (root_forks, fork_size, true) @@ -693,7 +697,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; && fork_size.is_power_of_two() && reduces.iter().all(|&reduce| { self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce) - || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + || 
self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) }) { // If there's an associative Reduce, parallelize the larger factor @@ -706,7 +710,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // restriction doesn't help for parallel Writes, so nested parallelization // is possible. if reduces.iter().any(|&reduce| { - self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) }) || fork_size > self.kernel_params.max_num_threads / subtree_quota { if fork_size >= subtree_quota { (HashMap::new(), fork_size, true) @@ -1069,8 +1073,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } }, Node::Binary { op, left, right } => { - let left_val = self.get_value(*left, false, false); - let right_val = self.get_value(*right, false, false); + let mut left_val = self.get_value(*left, false, false); + let mut right_val = self.get_value(*right, false, false); let id_type = self.typing[id.idx()]; if matches!(op, BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor) && is_special_reduct { @@ -1079,14 +1083,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // supported. We need to use CGType::Use not CGType::UsePerId // because for parallelized reduction we only have one thread // per ThreadID and the reduction is over Use, not UsePerId. - let non_reduce_arg = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] { - right_val + let (reduce_val, non_reduce_val) = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] { + (left_val, right_val) } else { - left_val + (right_val, left_val) }; // Special reduct is only enabled for thread parallelization // so don't need to worry about grid and block cases - let cg_tile = self.get_cg_tile(id, CGType::Use); + let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use); #[allow(unreachable_patterns)] let cg_op = match op { BinaryOperator::Add => "plus", @@ -1096,57 +1100,60 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; _ => unreachable!(), }; let id_type_name = self.get_type(id_type, false); - write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_arg, cg_op, id_type_name)?; - } else { - match (op, &self.types[id_type.idx()]) { - (BinaryOperator::Or, Type::Boolean) => write!( - w, - "{}{} = {} || {};\n", - tabs, define_variable, left_val, right_val, - )?, - (BinaryOperator::And, Type::Boolean) => write!( - w, - "{}{} = {} && {};\n", - tabs, define_variable, left_val, right_val, - )?, - (BinaryOperator::Rem, Type::Float32) => write!( - w, - "{}{} = fmodf({}, {});\n", - tabs, define_variable, left_val, right_val, - )?, - (BinaryOperator::Rem, Type::Float64) => write!( - w, - "{}{} = fmod({}, {});\n", - tabs, define_variable, left_val, right_val, - )?, - (op, _) => write!( - w, - "{}{} = {} {} {};\n", - tabs, - define_variable, - left_val, - match op { - BinaryOperator::Add => "+", - BinaryOperator::Sub => "-", - BinaryOperator::Mul => "*", - BinaryOperator::Div => "/", - BinaryOperator::Rem => "%", - BinaryOperator::LT => "<", - BinaryOperator::LTE => "<=", - BinaryOperator::GT => ">", - BinaryOperator::GTE => ">=", - BinaryOperator::EQ => "==", - BinaryOperator::NE => "!=", - BinaryOperator::Or => "|", - BinaryOperator::And => "&", - BinaryOperator::Xor => "^", - 
BinaryOperator::LSh => "<<", - BinaryOperator::RSh => ">>", - }, - right_val, - )?, - }; + write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name)?; + // Setup binop between reduce's init and reduced reduct. Since it's associative, + // we can change binop ordering + left_val = define_variable.clone(); + right_val = reduce_val; } + match (op, &self.types[id_type.idx()]) { + (BinaryOperator::Or, Type::Boolean) => write!( + w, + "{}{} = {} || {};\n", + tabs, define_variable, left_val, right_val, + )?, + (BinaryOperator::And, Type::Boolean) => write!( + w, + "{}{} = {} && {};\n", + tabs, define_variable, left_val, right_val, + )?, + (BinaryOperator::Rem, Type::Float32) => write!( + w, + "{}{} = fmodf({}, {});\n", + tabs, define_variable, left_val, right_val, + )?, + (BinaryOperator::Rem, Type::Float64) => write!( + w, + "{}{} = fmod({}, {});\n", + tabs, define_variable, left_val, right_val, + )?, + (op, _) => write!( + w, + "{}{} = {} {} {};\n", + tabs, + define_variable, + left_val, + match op { + BinaryOperator::Add => "+", + BinaryOperator::Sub => "-", + BinaryOperator::Mul => "*", + BinaryOperator::Div => "/", + BinaryOperator::Rem => "%", + BinaryOperator::LT => "<", + BinaryOperator::LTE => "<=", + BinaryOperator::GT => ">", + BinaryOperator::GTE => ">=", + BinaryOperator::EQ => "==", + BinaryOperator::NE => "!=", + BinaryOperator::Or => "|", + BinaryOperator::And => "&", + BinaryOperator::Xor => "^", + BinaryOperator::LSh => "<<", + BinaryOperator::RSh => ">>", + }, + right_val, + )?, + }; } Node::Ternary { op, @@ -1175,7 +1182,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } else { self.get_value(args[0], false, false) }; - let cg_tile = self.get_cg_tile(id, CGType::Use); + let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use); #[allow(unreachable_patterns)] let cg_op = match intrinsic { Intrinsic::Max => "max", @@ -1248,8 +1255,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?; } _ => { - println!("Unsupported data node type: {:?}", self.function.nodes[id.idx()]); - panic!("Unsupported data node type") + panic!("Unsupported data node type: {:?}", self.function.nodes[id.idx()]) } } // Since reducts are responsible for updating Reduce nodes, we check and @@ -1404,8 +1410,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?; write!(w_term, "\t}}\n")?; - write!(w_term, "\treturn;\n")?; } + write!(w_term, "\treturn;\n")?; 1 } _ => { -- GitLab