From d391d16354863a51c1d4ee8e3de4bce44b54402e Mon Sep 17 00:00:00 2001 From: prrathi <prrathi10@gmail.com> Date: Wed, 29 Jan 2025 16:25:22 +0000 Subject: [PATCH] mm dot works --- hercules_cg/src/gpu.rs | 29 ++++++----------------------- hercules_samples/matmul/src/main.rs | 12 +++++------- 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index c3cb6634..ce52a20e 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -79,10 +79,6 @@ pub fn gpu_codegen<W: Write>( * - Add float8, float16, bfloat16 dtypes if they come */ - // Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations - let block_parallel_override = false; - let thread_parallel_override = false; - let reduce_nodes: Vec<NodeID> = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) .map(NodeID::new) @@ -164,8 +160,6 @@ pub fn gpu_codegen<W: Write>( threads_per_warp: 32, }; - std::fs::write("out.txt", "debug\n\n").unwrap(); - let ctx = GPUContext { function, types, @@ -185,8 +179,6 @@ pub fn gpu_codegen<W: Write>( control_data_phi_map, return_parameter, kernel_params, - block_parallel_override, - thread_parallel_override, }; ctx.codegen_function(w) } @@ -215,8 +207,6 @@ struct GPUContext<'a> { control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>, return_parameter: Option<usize>, kernel_params: &'a GPUKernelParams, - block_parallel_override: bool, - thread_parallel_override: bool, } /* @@ -265,8 +255,6 @@ enum CGType { impl GPUContext<'_> { fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { - let mut file = OpenOptions::new().append(true).open("out.txt").unwrap(); - // Emit all code up to the "goto" to Start's block let mut top = String::new(); self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?; @@ -289,19 +277,14 @@ impl GPUContext<'_> { // If there are no forks, fast forward to single-block, single-thread codegen let (num_blocks, num_threads) = if self.fork_join_map.is_empty() { - writeln!(file, "shortcut to 1b1t").unwrap(); self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?; ("1".to_string(), "1".to_string()) } else { - writeln!(file, "no shortcut! fork tree: {:?}", self.fork_tree).unwrap(); // Create structures and determine block and thread parallelization strategy let (root_forks, num_blocks, is_block_parallel) = - self.get_root_forks_and_num_blocks(&self.fork_tree); - writeln!(file, "is_block_parallel: {}", is_block_parallel).unwrap(); - let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel); - writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap(); - let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork); - writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap(); + self.get_root_forks_and_num_blocks(self.fork_tree); + let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel); + let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork); // TODO: Uncomment and adjust once we know logic of extra dim // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map); let extra_dim_collects = HashSet::new(); @@ -590,7 +573,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; self.collection_objects.origin(*object).try_parameter().is_some() }) }), "All collection reduces in block fork must originate from parameters"); - if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) + if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) { let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * "); (root_forks, fork_size, true) @@ -697,7 +680,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; && fork_size.is_power_of_two() && reduces.iter().all(|&reduce| { self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce) - || self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) }) { // If there's an associative Reduce, parallelize the larger factor @@ -710,7 +693,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // restriction doesn't help for parallel Writes, so nested parallelization // is possible. if reduces.iter().any(|&reduce| { - self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) }) || fork_size > self.kernel_params.max_num_threads / subtree_quota { if fork_size >= subtree_quota { (HashMap::new(), fork_size, true) diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 9421c773..7b6cfe79 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -10,13 +10,11 @@ juno_build::juno!("matmul"); fn main() { async_std::task::block_on(async { - const I: usize = 4; - const J: usize = 2; - const K: usize = 8; - // let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - // let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); - let mut a: Box<[i32]> = (0..I * J).map(|i| (i as i32) % 100).collect(); - let mut b: Box<[i32]> = (0..J * K).map(|i| (i as i32) % 100).collect(); + const I: usize = 256; + const J: usize = 8; // hardcoded constant in matmul.hir + const K: usize = 128; + let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { -- GitLab