From cc90b62246ffed252546a40e26dba9d868bae6b3 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 31 Jan 2025 10:51:59 -0600 Subject: [PATCH 1/9] Start analysis for collections that don't need to be zeroed --- hercules_ir/src/collections.rs | 105 ++++++++++++++++++++++++++++++++- hercules_opt/src/slf.rs | 6 +- 2 files changed, 107 insertions(+), 4 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 9f421221..33b25731 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -1,4 +1,5 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::iter::once; use crate::*; @@ -347,3 +348,105 @@ pub fn collection_objects( collection_objects } + +/* + * The zero lattice determine what indices a collection has been written to so + * far. Reads from collection constants not at an index that's been written to + * return zero - if any such read exists for a collection constant, then that + * collection constant must be memset to zero explicitly. This is a similar + * analysis to store-to-load forwarding, but we are looking for cases where the + * store to forward might be a zero coming from a zero memset. + */ +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ZeroLattice { + written_to: BTreeSet<Box<[Index]>>, +} + +impl Semilattice for ZeroLattice { + fn meet(a: &Self, b: &Self) -> Self { + // Merge the two sets. Find equal indices sets between `a` and `b`. + let mut ret = BTreeSet::new(); + for indices in a.written_to.iter() { + if b.written_to.contains(indices) { + // If both sets have the same indices, add it to the meet value. + ret.insert(indices.clone()); + } + } + ZeroLattice { written_to: ret } + } + + fn top() -> Self { + ZeroLattice { + written_to: once(Box::new([]) as Box<[Index]>).collect(), + } + } + + fn bottom() -> Self { + ZeroLattice { + written_to: BTreeSet::new(), + } + } +} + +/* + * Analysis that determines what collection constants in a function don't need + * to be lowered to memsets. If a collection constant is only read at indices + * that it has first been written to, then the backing memory for the constant + * doesn't have to be set to zero, since those zeros are never read. + */ +pub fn no_reset_constant_collections( + function: &Function, + types: &Vec<Type>, + reverse_postorder: &Vec<NodeID>, + typing: &Vec<TypeID>, + objects: &CollectionObjects, + reduce_einsum: &(MathEnv, HashMap<NodeID, NodeID>), +) -> BTreeSet<NodeID> { + // First, run a dataflow analysis that determine at each collection node + // what indices have definitely been written to. + let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| { + match function.nodes[id.idx()] { + Node::Phi { + control: _, + data: _, + } + | Node::Ternary { + op: TernaryOperator::Select, + first: _, + second: _, + third: _, + } => inputs.into_iter().fold(ZeroLattice::top(), |acc, input| { + ZeroLattice::meet(&acc, input) + }), + Node::Reduce { + control: _, + init: _, + reduct: _, + } => { + // If the einsum for this reduce node is a full array + // comprehension, then every array element is written to, and + // the empty indices set (the whole collection) is considered as + // written to. + let (env, exprs) = reduce_einsum; + if let MathExpr::Comprehension(_, _) = env[exprs[&id].idx()] { + ZeroLattice::top() + } + // Otherwise, meet the `init` and `reduct` inputs. 
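                // The meet is set intersection (see `ZeroLattice::meet`), so
                // only index sets that both the `init` and `reduct` sides have
                // definitely written survive here.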
+ else { + ZeroLattice::meet(&inputs[0], &inputs[1]) + } + } + Node::Write { + collect: _, + data: _, + ref indices, + } => { + let mut value = inputs[0].clone(); + value.written_to.insert(indices.clone()); + value + } + _ => ZeroLattice::bottom(), + } + }); + todo!() +} diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs index 981a0cce..92acb2a8 100644 --- a/hercules_opt/src/slf.rs +++ b/hercules_opt/src/slf.rs @@ -15,7 +15,7 @@ use crate::*; * insert an indices set with an empty positions list and a `None` value. */ #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct SLFLattice { +struct SLFLattice { subvalues: BTreeMap<Box<[Index]>, Option<NodeID>>, } @@ -25,7 +25,7 @@ impl Semilattice for SLFLattice { // keep their known sub-value if they're equivalent. All other indices // sets in `a` or `b` map to `None`. let mut ret = BTreeMap::new(); - for (indices, a_subvalue) in &a.subvalues { + for (indices, a_subvalue) in a.subvalues.iter() { if let Some(b_subvalue) = b.subvalues.get(indices) && a_subvalue == b_subvalue { @@ -39,7 +39,7 @@ impl Semilattice for SLFLattice { ret.insert(indices.clone(), None); } } - for (indices, _) in &b.subvalues { + for (indices, _) in b.subvalues.iter() { // Any indices sets in `b` that aren't in `ret` are indices sets // that aren't in `a`, so the sub-value isn't known. ret.entry(indices.clone()).or_insert(None); -- GitLab From 6c5517509d70820cab300df9d675387ec7882433 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 31 Jan 2025 11:13:42 -0600 Subject: [PATCH 2/9] Progress on no_reset_constant_collections --- hercules_ir/src/collections.rs | 57 ++++++++++++++++++++++++++++++++-- hercules_opt/src/utils.rs | 18 +++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 33b25731..5d50b14a 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -1,6 +1,8 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::iter::once; +use either::Either; + use crate::*; /* @@ -61,6 +63,13 @@ impl CollectionObjectOrigin { _ => None, } } + + pub fn try_constant(&self) -> Option<NodeID> { + match self { + CollectionObjectOrigin::Constant(id) => Some(*id), + _ => None, + } + } } impl FunctionCollectionObjects { @@ -399,10 +408,10 @@ pub fn no_reset_constant_collections( types: &Vec<Type>, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>, - objects: &CollectionObjects, + objects: &FunctionCollectionObjects, reduce_einsum: &(MathEnv, HashMap<NodeID, NodeID>), ) -> BTreeSet<NodeID> { - // First, run a dataflow analysis that determine at each collection node + // First, run a dataflow analysis that determines at each collection node // what indices have definitely been written to. let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| { match function.nodes[id.idx()] { @@ -448,5 +457,49 @@ pub fn no_reset_constant_collections( _ => ZeroLattice::bottom(), } }); + + // Second, collect triples from reads in the program containing: + // 1. What indices are read at that point. + // 2. What collection constants may be read. + // 3. What indices were written to by that point. + let full_indices: Box<[Index]> = Box::new([]); + let triples = (0..function.nodes.len()).filter_map(|idx| { + // Item #1. 
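        // A "read" here is anything that observes a collection's contents: an
        // explicit Read node, the `data` operand of a Write (whose bytes get
        // copied into the destination), or a collection passed as a Call
        // argument. The latter two are treated as reading the whole
        // collection, i.e. the empty indices list.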
+ let id = NodeID::new(idx); + let (collect, indices) = match function.nodes[id.idx()] { + Node::Read { + collect, + ref indices, + } => (Either::Left(once(collect)), indices), + Node::Write { + collect: _, + data, + indices: _, + } => (Either::Left(once(data)), &full_indices), + Node::Call { + control: _, + function: _, + dynamic_constants: _, + ref args, + } => (Either::Right(args.into_iter().map(|id| *id)), &full_indices), + _ => return None, + }; + + // Item #2. + let constants: Vec<_> = collect + .map(|collect| objects.objects(collect).into_iter()) + .flatten() + .filter_map(|obj| objects.origin(*obj).try_constant()) + .collect(); + if constants.is_empty() { + return None; + } + + // Item #3. + let written_to = &lattice[id.idx()].written_to; + + Some((indices, constants, written_to)) + }); + todo!() } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 2ab4e094..a2ca5028 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -370,6 +370,24 @@ pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> boo true } +/* + * Helper function to determine if a list of indices A definitely contains a + * list of indices B. + */ +pub(crate) fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { + if indices_a.len() < indices_b.len() { + return false; + } + + for (idx1, idx2) in zip(indices_a, indices_b) { + if idx1 != idx2 { + return false; + } + } + + true +} + pub type DenseNodeMap<T> = Vec<T>; pub type SparseNodeMap<T> = HashMap<NodeID, T>; -- GitLab From 7b313282541b12bdd028101ea09b544d1dbec3d1 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 10:29:46 -0600 Subject: [PATCH 3/9] Finix no reset analysis --- hercules_ir/src/collections.rs | 106 ++++++++++++++++++++------------- hercules_ir/src/ir.rs | 67 +++++++++++++++++++++ hercules_opt/src/utils.rs | 66 -------------------- 3 files changed, 132 insertions(+), 107 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 5d50b14a..e57a7756 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -1,5 +1,5 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::iter::once; +use std::iter::{once, repeat, zip}; use either::Either; @@ -459,47 +459,71 @@ pub fn no_reset_constant_collections( }); // Second, collect triples from reads in the program containing: - // 1. What indices are read at that point. - // 2. What collection constants may be read. - // 3. What indices were written to by that point. + // 1. What indices are read. + // 2. What node is read. + // 3. What indices were written into the node already. let full_indices: Box<[Index]> = Box::new([]); - let triples = (0..function.nodes.len()).filter_map(|idx| { - // Item #1. - let id = NodeID::new(idx); - let (collect, indices) = match function.nodes[id.idx()] { - Node::Read { - collect, - ref indices, - } => (Either::Left(once(collect)), indices), - Node::Write { - collect: _, - data, - indices: _, - } => (Either::Left(once(data)), &full_indices), - Node::Call { - control: _, - function: _, - dynamic_constants: _, - ref args, - } => (Either::Right(args.into_iter().map(|id| *id)), &full_indices), - _ => return None, - }; - - // Item #2. 
- let constants: Vec<_> = collect - .map(|collect| objects.objects(collect).into_iter()) - .flatten() - .filter_map(|obj| objects.origin(*obj).try_constant()) - .collect(); - if constants.is_empty() { - return None; + let triples = (0..function.nodes.len()) + .map(|idx| { + // Items #1 and #2. + let id = NodeID::new(idx); + let indices_and_nodes = match function.nodes[id.idx()] { + Node::Read { + collect, + ref indices, + } => Either::Left(zip(once(indices), once(collect))), + Node::Write { + collect: _, + data, + indices: _, + } + | Node::Return { control: _, data } => { + Either::Left(zip(once(&full_indices), once(data))) + } + Node::Call { + control: _, + function: _, + dynamic_constants: _, + ref args, + } => Either::Right(zip(repeat(&full_indices), args.into_iter().map(|id| *id))), + _ => return None, + }; + + // Item #3. + Some( + indices_and_nodes + .map(|(indices, read)| (indices, read, &lattice[read.idx()].written_to)), + ) + }) + .flatten() + .flatten(); + + // Third, look at each triple and check if there is a read that reads at + // indices not contained by any of the written to indices. If so, then any + // constant collection that may originate the object at that place has to be + // explicitly memset to zero. The result starts as every constant collection + // in the function, and is progressively shaved of constants that need to be + // reset explicitly. + let mut result: BTreeSet<_> = (0..function.nodes.len()) + .filter(|idx| { + function.nodes[*idx].is_constant() && !types[typing[*idx].idx()].is_primitive() + }) + .map(NodeID::new) + .collect(); + for (indices, node, written_to) in triples { + if !written_to + .into_iter() + .all(|written_to| indices_contain_other_indices(written_to, indices)) + { + for constant in objects + .objects(node) + .into_iter() + .filter_map(|obj| objects.origin(*obj).try_constant()) + { + result.remove(&constant); + } } + } - // Item #3. - let written_to = &lattice[id.idx()].written_to; - - Some((indices, constants, written_to)) - }); - - todo!() + result } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 187b3f98..30c2e4fa 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1,5 +1,6 @@ use std::collections::HashSet; use std::fmt::Write; +use std::iter::zip; use std::ops::Coroutine; use std::ops::CoroutineState; use std::pin::Pin; @@ -1801,6 +1802,72 @@ impl Device { } } +/* + * Helper function to tell if two lists of indices have the same structure. + */ +pub fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool { + if indices1.len() == indices2.len() { + let mut equiv = true; + for pair in zip(indices1, indices2) { + equiv = equiv + && match pair { + (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2, + (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2, + (Index::Position(ref pos1), Index::Position(ref pos2)) => { + assert_eq!(pos1.len(), pos2.len()); + true + } + _ => false, + }; + } + equiv + } else { + false + } +} + +/* + * Helper function to determine if two lists of indices may overlap. + */ +pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { + for pair in zip(indices1, indices2) { + match pair { + // Check that the field numbers are the same. + (Index::Field(idx1), Index::Field(idx2)) => { + if idx1 != idx2 { + return false; + } + } + // Variant indices always may overlap, since it's the same + // underlying memory. Position indices always may overlap, since the + // indexing nodes may be the same at runtime. 
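            // Only a pair of differing Field indices proves the two accesses
            // are disjoint; every other comparable pair is conservatively
            // treated as possibly overlapping.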
+ (Index::Variant(_), Index::Variant(_)) | (Index::Position(_), Index::Position(_)) => {} + _ => panic!(), + } + } + // `zip` will exit as soon as either iterator is done - two sets of indices + // may overlap when one indexes a larger sub-value than the other. + true +} + +/* + * Helper function to determine if a list of indices A definitely contains a + * list of indices B. + */ +pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { + if indices_a.len() < indices_b.len() { + return false; + } + + for (idx1, idx2) in zip(indices_a, indices_b) { + if idx1 != idx2 { + return false; + } + } + + true +} + /* * Rust things to make newtyped IDs usable. */ diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index a2ca5028..0c6c2fac 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -322,72 +322,6 @@ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID } } -/* - * Helper function to tell if two lists of indices have the same structure. - */ -pub(crate) fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool { - if indices1.len() == indices2.len() { - let mut equiv = true; - for pair in zip(indices1, indices2) { - equiv = equiv - && match pair { - (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2, - (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2, - (Index::Position(ref pos1), Index::Position(ref pos2)) => { - assert_eq!(pos1.len(), pos2.len()); - true - } - _ => false, - }; - } - equiv - } else { - false - } -} - -/* - * Helper function to determine if two lists of indices may overlap. - */ -pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { - for pair in zip(indices1, indices2) { - match pair { - // Check that the field numbers are the same. - (Index::Field(idx1), Index::Field(idx2)) => { - if idx1 != idx2 { - return false; - } - } - // Variant indices always may overlap, since it's the same - // underlying memory. Position indices always may overlap, since the - // indexing nodes may be the same at runtime. - (Index::Variant(_), Index::Variant(_)) | (Index::Position(_), Index::Position(_)) => {} - _ => panic!(), - } - } - // `zip` will exit as soon as either iterator is done - two sets of indices - // may overlap when one indexes a larger sub-value than the other. - true -} - -/* - * Helper function to determine if a list of indices A definitely contains a - * list of indices B. 
- */ -pub(crate) fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool { - if indices_a.len() < indices_b.len() { - return false; - } - - for (idx1, idx2) in zip(indices_a, indices_b) { - if idx1 != idx2 { - return false; - } - } - - true -} - pub type DenseNodeMap<T> = Vec<T>; pub type SparseNodeMap<T> = HashMap<NodeID, T>; -- GitLab From 6a760fa689ff829a0e9c48486eefe0691209542f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 10:34:05 -0600 Subject: [PATCH 4/9] add link search path for juno_build --- juno_build/build.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/juno_build/build.rs b/juno_build/build.rs index 7ba34c8c..b2464ecb 100644 --- a/juno_build/build.rs +++ b/juno_build/build.rs @@ -1,8 +1,9 @@ fn main() { #[cfg(feature = "cuda")] - println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/"); - #[cfg(feature = "cuda")] - println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64"); - #[cfg(feature = "cuda")] - println!("cargo:rustc-link-lib=cudart"); + { + println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/"); + println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64"); + println!("cargo::rustc-link-search=native=/opt/cuda/lib/"); + println!("cargo:rustc-link-lib=cudart"); + } } -- GitLab From 0268c3a4cf01d1be18238eec0ee13862c4ba9ffc Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 11:02:31 -0600 Subject: [PATCH 5/9] use analysis to schedule constant nodes --- hercules_ir/src/collections.rs | 8 +-- hercules_ir/src/ir.rs | 2 + hercules_opt/src/editor.rs | 13 +++-- hercules_opt/src/schedule.rs | 15 +++++- hercules_samples/matmul/src/cpu.sch | 7 ++- juno_scheduler/src/pm.rs | 78 ++++++++++++++++++++++++----- 6 files changed, 98 insertions(+), 25 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index e57a7756..d9d6c8f7 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -409,7 +409,7 @@ pub fn no_reset_constant_collections( reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>, objects: &FunctionCollectionObjects, - reduce_einsum: &(MathEnv, HashMap<NodeID, NodeID>), + reduce_einsum: &(MathEnv, HashMap<NodeID, MathID>), ) -> BTreeSet<NodeID> { // First, run a dataflow analysis that determines at each collection node // what indices have definitely been written to. @@ -437,7 +437,9 @@ pub fn no_reset_constant_collections( // the empty indices set (the whole collection) is considered as // written to. let (env, exprs) = reduce_einsum; - if let MathExpr::Comprehension(_, _) = env[exprs[&id].idx()] { + if let Some(expr) = exprs.get(&id) + && let MathExpr::Comprehension(_, _) = env[expr.idx()] + { ZeroLattice::top() } // Otherwise, meet the `init` and `reduct` inputs. @@ -513,7 +515,7 @@ pub fn no_reset_constant_collections( for (indices, node, written_to) in triples { if !written_to .into_iter() - .all(|written_to| indices_contain_other_indices(written_to, indices)) + .any(|written_to| indices_contain_other_indices(written_to, indices)) { for constant in objects .objects(node) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 30c2e4fa..846347b0 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -327,6 +327,8 @@ pub enum Schedule { // This reduce can be re-associated. This may lower a sequential dependency // chain into a reduction tree. 
TightAssociative, + // This constant node doesn't need to be memset to zero. + NoResetConstant, } /* diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 1d586057..43a88b7c 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -557,19 +557,22 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { pub fn add_schedule(mut self, id: NodeID, schedule: Schedule) -> Result<Self, Self> { if self.is_mutable(id) { if let Some(schedules) = self.added_and_updated_schedules.get_mut(&id) { - schedules.push(schedule); + if !schedules.contains(&schedule) { + schedules.push(schedule); + } } else { - let mut schedules = self + let empty = vec![]; + let schedules = self .editor .function .schedules .get(id.idx()) - .unwrap_or(&vec![]) - .clone(); + .unwrap_or(&empty); if !schedules.contains(&schedule) { + let mut schedules = schedules.clone(); schedules.push(schedule); + self.added_and_updated_schedules.insert(id, schedules); } - self.added_and_updated_schedules.insert(id, schedules); } Ok(self) } else { diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index f9f720be..fe894e47 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use hercules_ir::def_use::*; use hercules_ir::ir::*; @@ -181,3 +181,16 @@ pub fn infer_tight_associative( } } } + +/* + * From analysis result of which constants don't need to be reset, add schedules + * to those constant nodes. + */ +pub fn infer_no_reset_constants( + editor: &mut FunctionEditor, + no_reset_constants: &BTreeSet<NodeID>, +) { + for id in no_reset_constants { + editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant)); + } +} diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch index aeed7e10..0321e13d 100644 --- a/hercules_samples/matmul/src/cpu.sch +++ b/hercules_samples/matmul/src/cpu.sch @@ -7,15 +7,14 @@ auto-outline(*); ip-sroa(*); sroa(*); dce(*); -infer-schedules(*); +fixpoint { + infer-schedules(*); +} fork-split(*); unforkify(*); dce(*); -float-collections(*); gvn(*); phi-elim(*); dce(*); -infer-schedules(*); - gcm(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2371e0f2..f6fe2fc1 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -184,6 +184,7 @@ pub struct PassManager { pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>, + pub no_reset_constants: Option<Vec<BTreeSet<NodeID>>>, pub collection_objects: Option<CollectionObjects>, pub callgraph: Option<CallGraph>, pub devices: Option<Vec<Device>>, @@ -222,6 +223,7 @@ impl PassManager { reduce_cycles: None, data_nodes_in_fork_joins: None, reduce_einsums: None, + no_reset_constants: None, collection_objects: None, callgraph: None, devices: None, @@ -345,7 +347,12 @@ impl PassManager { if self.fork_control_maps.is_none() { self.make_fork_join_nests(); self.fork_control_maps = Some( - self.fork_join_nests.as_ref().unwrap().iter().map(fork_control_map).collect(), + self.fork_join_nests + .as_ref() + .unwrap() + .iter() + .map(fork_control_map) + .collect(), ); } } @@ -463,6 +470,43 @@ impl PassManager { } } + pub fn make_no_reset_constants(&mut self) { + if self.no_reset_constants.is_none() { + self.make_reverse_postorders(); + self.make_typing(); + self.make_collection_objects(); 
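            // The reduce einsums are needed so the analysis can treat a reduce
            // that is a full array comprehension as writing its entire
            // collection.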
+ self.make_reduce_einsums(); + let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); + let typing = self.typing.as_ref().unwrap().iter(); + let collection_objects = self.collection_objects.as_ref().unwrap().iter(); + let reduce_einsums = self.reduce_einsums.as_ref().unwrap().iter(); + self.no_reset_constants = Some( + self.functions + .iter() + .zip(reverse_postorders) + .zip(typing) + .zip(collection_objects) + .zip(reduce_einsums) + .map( + |( + (((function, reverse_postorder), typing), collection_object), + reduce_einsum, + )| { + no_reset_constant_collections( + function, + &self.types.borrow(), + reverse_postorder, + typing, + collection_object.1, + reduce_einsum, + ) + }, + ) + .collect(), + ); + } + } + pub fn make_collection_objects(&mut self) { if self.collection_objects.is_none() { self.make_reverse_postorders(); @@ -535,6 +579,7 @@ impl PassManager { self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; self.reduce_einsums = None; + self.no_reset_constants = None; self.collection_objects = None; self.callgraph = None; self.devices = None; @@ -716,8 +761,8 @@ impl PassManager { let mut llvm_path = tmp_dir.path().to_path_buf(); llvm_path.push(format!("{}.ll", module_name)); println!("{}", llvm_path.display()); - let mut file = File::create(&llvm_path) - .expect("PANIC: Unable to open output LLVM IR file."); + let mut file = + File::create(&llvm_path).expect("PANIC: Unable to open output LLVM IR file."); file.write_all(llvm_ir.as_bytes()) .expect("PANIC: Unable to write output LLVM IR file contents."); @@ -738,13 +783,17 @@ impl PassManager { let mut ar_args = vec!["crus", &output_archive, &llvm_object]; - let cuda_object = format!("{}/{}_cuda.o", tmp_dir.path().to_str().unwrap(), module_name); + let cuda_object = format!( + "{}/{}_cuda.o", + tmp_dir.path().to_str().unwrap(), + module_name + ); if cfg!(feature = "cuda") { // Write the CUDA IR into a temporary file. let mut cuda_path = tmp_dir.path().to_path_buf(); cuda_path.push(format!("{}.cu", module_name)); - let mut file = File::create(&cuda_path) - .expect("PANIC: Unable to open output CUDA IR file."); + let mut file = + File::create(&cuda_path).expect("PANIC: Unable to open output CUDA IR file."); file.write_all(cuda_ir.as_bytes()) .expect("PANIC: Unable to write output CUDA IR file contents."); @@ -770,8 +819,8 @@ impl PassManager { // Write the Rust runtime into a file. 
let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); println!("{}", output_rt); - let mut file = File::create(&output_rt) - .expect("PANIC: Unable to open output Rust runtime file."); + let mut file = + File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); file.write_all(rust_rt.as_bytes()) .expect("PANIC: Unable to write output Rust runtime file contents."); @@ -1592,12 +1641,16 @@ fn run_pass( assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_reduce_cycles(); + pm.make_no_reset_constants(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); - for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(reduce_cycles.iter()) + let no_reset_constants = pm.no_reset_constants.take().unwrap(); + for (((func, fork_join_map), reduce_cycles), no_reset_constants) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + .zip(no_reset_constants.iter()) { let Some(mut func) = func else { continue; @@ -1606,6 +1659,7 @@ fn run_pass( infer_parallel_fork(&mut func, fork_join_map); infer_vectorizable(&mut func, fork_join_map); infer_tight_associative(&mut func, reduce_cycles); + infer_no_reset_constants(&mut func, no_reset_constants); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab From f4809b03f948e3f5f30581b2db5ac95ae9be4e27 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 11:10:20 -0600 Subject: [PATCH 6/9] Use NoResetConstant in backends --- hercules_cg/src/cpu.rs | 16 +- hercules_cg/src/gpu.rs | 622 ++++++++++++++++++++++------ hercules_cg/src/rt.rs | 18 +- hercules_samples/matmul/src/gpu.sch | 7 +- 4 files changed, 519 insertions(+), 144 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 344554b6..f6a1f309 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -288,13 +288,15 @@ impl<'a> CPUContext<'a> { self.get_value(id, false), offset.idx() )?; - let data_size = self.codegen_type_size(self.typing[id.idx()], body)?; - write!( - body, - " call void @llvm.memset.p0.i64({}, i8 0, i64 {}, i1 false)\n", - self.get_value(id, true), - data_size, - )?; + if !self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant) { + let data_size = self.codegen_type_size(self.typing[id.idx()], body)?; + write!( + body, + " call void @llvm.memset.p0.i64({}, i8 0, i64 {}, i1 false)\n", + self.get_value(id, true), + data_size, + )?; + } } } Node::DynamicConstant { id: dc_id } => { diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index d7a6d258..81e31396 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -3,7 +3,7 @@ extern crate hercules_ir; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Error, Write}; -use std::fs::{OpenOptions, File}; +use std::fs::{File, OpenOptions}; use std::io::Write as _; use self::hercules_ir::*; @@ -85,7 +85,6 @@ pub fn gpu_codegen<W: Write>( .map(NodeID::new) .collect(); - let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map .iter() .map(|(fork, join)| (*join, *fork)) @@ -104,7 +103,7 @@ pub fn gpu_codegen<W: Write>( } = &function.nodes[reduce_node.idx()] { match function.nodes[control.idx()] { - Node::Join {..} => { + Node::Join { .. 
} => { let fork_node = join_fork_map[control]; fork_reduce_map .entry(fork_node) @@ -126,9 +125,10 @@ pub fn gpu_codegen<W: Write>( } for idx in 0..function.nodes.len() { if function.nodes[idx].is_fork() { - assert!(fork_reduce_map - .get(&NodeID::new(idx)) - .is_some_and(|reduces| !reduces.is_empty()), + assert!( + fork_reduce_map + .get(&NodeID::new(idx)) + .is_some_and(|reduces| !reduces.is_empty()), "Fork node {} has no reduce nodes", idx ); @@ -144,14 +144,18 @@ pub fn gpu_codegen<W: Write>( panic!("Phi's control must be a region node"); }; for (i, &pred) in preds.iter().enumerate() { - control_data_phi_map.entry(pred).or_default().push((data[i], NodeID::new(idx))); + control_data_phi_map + .entry(pred) + .or_default() + .push((data[i], NodeID::new(idx))); } } } let return_parameter = if collection_objects.returned_objects().len() == 1 { - collection_objects.origin(*collection_objects.returned_objects() - .first().unwrap()).try_parameter() + collection_objects + .origin(*collection_objects.returned_objects().first().unwrap()) + .try_parameter() } else { None }; @@ -280,14 +284,20 @@ impl GPUContext<'_> { // If there are no forks, fast forward to single-block, single-thread codegen let (num_blocks, num_threads) = if self.fork_join_map.is_empty() { - self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?; + self.codegen_data_control_no_forks( + &HashSet::new(), + &mut dynamic_shared_offset, + &mut gotos, + )?; ("1".to_string(), "1".to_string()) } else { // Create structures and determine block and thread parallelization strategy let (root_forks, num_blocks, is_block_parallel) = self.get_root_forks_and_num_blocks(self.fork_tree); - let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel); - let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork); + let (thread_root_root_fork, thread_root_forks) = + self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel); + let (fork_thread_quota_map, num_threads) = + self.get_thread_quotas(self.fork_tree, thread_root_root_fork); // TODO: Uncomment and adjust once we know logic of extra dim. This will affect constant // collections, reads, and writes. 
// let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map); @@ -319,7 +329,12 @@ impl GPUContext<'_> { // Emit host launch code let mut host_launch = String::new(); - self.codegen_launch_code(num_blocks, num_threads, &dynamic_shared_offset, &mut host_launch)?; + self.codegen_launch_code( + num_blocks, + num_threads, + &dynamic_shared_offset, + &mut host_launch, + )?; write!(w, "{}", host_launch)?; Ok(()) @@ -329,7 +344,9 @@ impl GPUContext<'_> { * Emit kernel headers, signature, arguments, and dynamic shared memory declaration */ fn codegen_kernel_begin(&self, has_ret_var: bool, w: &mut String) -> Result<(), Error> { - write!(w, " + write!( + w, + " #include <assert.h> #include <stdio.h> #include <stddef.h> @@ -390,10 +407,7 @@ namespace cg = cooperative_groups; if !first_param { write!(w, ", ")?; } - write!( - w, - "void* __restrict__ ret", - )?; + write!(w, "void* __restrict__ ret",)?; } // Type is char since it's simplest to use single bytes for indexing @@ -425,24 +439,48 @@ namespace cg = cooperative_groups; } } DynamicConstant::Add(args) => { - let rhs = args.iter().map(|arg| format!("dc{}", arg.idx())).collect::<Vec<_>>().join(" + "); + let rhs = args + .iter() + .map(|arg| format!("dc{}", arg.idx())) + .collect::<Vec<_>>() + .join(" + "); write!(w, "\t{} = {};\n", dc_val, rhs)? } DynamicConstant::Mul(args) => { - let rhs = args.iter().map(|arg| format!("dc{}", arg.idx())).collect::<Vec<_>>().join(" * "); + let rhs = args + .iter() + .map(|arg| format!("dc{}", arg.idx())) + .collect::<Vec<_>>() + .join(" * "); write!(w, "\t{} = {};\n", dc_val, rhs)? } DynamicConstant::Min(args) => { - let rhs_but_last: String = args.iter().take(args.len() - 1).map(|arg| format!("min(dc{}, ", arg.idx())).collect(); + let rhs_but_last: String = args + .iter() + .take(args.len() - 1) + .map(|arg| format!("min(dc{}, ", arg.idx())) + .collect(); let rhs_last = format!("dc{}", args.last().unwrap().idx()); let rhs_end: String = std::iter::repeat(")").take(args.len() - 1).collect(); - write!(w, "\t{} = {}{}{};\n", dc_val, rhs_but_last, rhs_last, rhs_end)? + write!( + w, + "\t{} = {}{}{};\n", + dc_val, rhs_but_last, rhs_last, rhs_end + )? } DynamicConstant::Max(args) => { - let rhs_but_last: String = args.iter().take(args.len() - 1).map(|arg| format!("max(dc{}, ", arg.idx())).collect(); + let rhs_but_last: String = args + .iter() + .take(args.len() - 1) + .map(|arg| format!("max(dc{}, ", arg.idx())) + .collect(); let rhs_last = format!("dc{}", args.last().unwrap().idx()); let rhs_end: String = std::iter::repeat(")").take(args.len() - 1).collect(); - write!(w, "\t{} = {}{}{};\n", dc_val, rhs_but_last, rhs_last, rhs_end)? + write!( + w, + "\t{} = {}{}{};\n", + dc_val, rhs_but_last, rhs_last, rhs_end + )? } DynamicConstant::Sub(left, right) => { write!(w, "\t{} = dc{} - dc{};\n", dc_val, left.idx(), right.idx())? 
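For reference, the n-ary dynamic constant arms above fold their arguments into flat or nested expressions. With hypothetical dynamic constants dc1, dc2, and dc3 feeding dc0 and dc4 (the names are illustrative, not taken from a real module), the emitted CUDA lines look roughly like:

	dc0 = dc1 + dc2 + dc3;
	dc4 = min(dc1, min(dc2, dc3));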
@@ -464,9 +502,10 @@ namespace cg = cooperative_groups; */ fn codegen_declare_data(&self, w: &mut String) -> Result<(), Error> { for id in (0..self.function.nodes.len()).map(NodeID::new) { - if !self.function.nodes[id.idx()].is_control() && - !self.function.nodes[id.idx()].is_dynamic_constant() && - !self.function.nodes[id.idx()].is_parameter() { + if !self.function.nodes[id.idx()].is_control() + && !self.function.nodes[id.idx()].is_dynamic_constant() + && !self.function.nodes[id.idx()].is_parameter() + { write!(w, "\t{};\n", self.get_value(id, true, false))?; } if self.function.nodes[id.idx()].is_phi() { @@ -493,7 +532,12 @@ namespace cg = cooperative_groups; Ok(()) } - fn codegen_gotos(&self, goto_debug: bool, gotos: &mut BTreeMap<NodeID, CudaGoto>, w: &mut String) -> Result<(), Error> { + fn codegen_gotos( + &self, + goto_debug: bool, + gotos: &mut BTreeMap<NodeID, CudaGoto>, + w: &mut String, + ) -> Result<(), Error> { write!(w, "\n")?; for (id, goto) in gotos.iter() { let goto_block = self.get_block_name(*id, false); @@ -513,13 +557,24 @@ namespace cg = cooperative_groups; Ok(()) } - fn codegen_launch_code(&self, num_blocks: String, num_threads: String, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> { + fn codegen_launch_code( + &self, + num_blocks: String, + num_threads: String, + dynamic_shared_offset: &str, + w: &mut String, + ) -> Result<(), Error> { // The following steps are for host-side C function arguments, but we also // need to pass arguments to kernel, so we keep track of the arguments here. let ret_type = self.get_type(self.function.return_type, false); let mut pass_args = String::new(); - write!(w, " -extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; + write!( + w, + " +extern \"C\" {} {}(", + ret_type.clone(), + self.function.name + )?; let mut first_param = true; // The first parameter is a pointer to GPU backing memory, if it's // needed. 
@@ -566,16 +621,30 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; } write!(w, "\tcudaError_t err;\n"); - write!(w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args)?; + write!( + w, + "\t{}_gpu<<<{}, {}, {}>>>({});\n", + self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args + )?; write!(w, "\terr = cudaGetLastError();\n"); - write!(w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"); + write!( + w, + "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n" + ); write!(w, "\tcudaDeviceSynchronize();\n")?; write!(w, "\terr = cudaGetLastError();\n"); - write!(w, "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n"); + write!( + w, + "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n" + ); if has_ret_var { // Copy return from device to host, whether it's primitive value or collection pointer write!(w, "\t{} host_ret;\n", ret_type)?; - write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?; + write!( + w, + "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", + ret_type + )?; write!(w, "\treturn host_ret;\n")?; } else { write!(w, "\treturn p{};\n", self.return_parameter.unwrap())?; @@ -604,9 +673,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; panic!("Expected fork node"); }; let reduces = &self.fork_reduce_map[root_fork]; - if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) - { - let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * "); + if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) { + let fork_size = factors + .iter() + .map(|dc| format!("dc{}", dc.idx())) + .collect::<Vec<_>>() + .join(" * "); (root_forks, fork_size, true) } else { (root_forks, "1".to_string(), false) @@ -626,7 +698,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; ) -> (NodeID, HashSet<NodeID>) { if is_block_parallel { let root_fork = root_forks.iter().next().unwrap(); - (*root_fork, fork_tree.get(&root_fork).unwrap().iter().copied().collect()) + ( + *root_fork, + fork_tree.get(&root_fork).unwrap().iter().copied().collect(), + ) } else { (NodeID::new(0), root_forks.clone()) } @@ -676,12 +751,17 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; .map(|child| (child, self.recurse_thread_quotas(*child, fork_tree, false))) .fold( (HashMap::new(), HashMap::new(), 1), - |(mut subsubtree_map, mut children_quota_map, subtree_quota), (child, (curr_map, curr_quota, use_curr))| { + |(mut subsubtree_map, mut children_quota_map, subtree_quota), + (child, (curr_map, curr_quota, use_curr))| { subsubtree_map.extend(curr_map); if use_curr { children_quota_map.insert(child, curr_quota); } - (subsubtree_map, children_quota_map, subtree_quota.max(curr_quota)) + ( + subsubtree_map, + children_quota_map, + subtree_quota.max(curr_quota), + ) }, ); // First update children_quota_map items with full information and add @@ -695,7 +775,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } let subtree_map = subsubtree_map; if is_root { - return (subtree_map, subtree_quota, true) + return (subtree_map, subtree_quota, true); } // A node can only be considered for parallelization if: // a) it has statically known size @@ 
-725,7 +805,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // is possible. if reduces.iter().any(|&reduce| { self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) - }) || fork_size > self.kernel_params.max_num_threads / subtree_quota { + }) || fork_size > self.kernel_params.max_num_threads / subtree_quota + { if fork_size >= subtree_quota { (HashMap::new(), fork_size, true) } else { @@ -770,9 +851,22 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let post_init = &mut goto.post_init; let body = &mut goto.body; let term = &mut goto.term; - let mut tabs = self.codegen_control_node(control, None, None, None, init, post_init, term)?; + let mut tabs = + self.codegen_control_node(control, None, None, None, init, post_init, term)?; for data in self.bbs.1[control.idx()].iter() { - self.codegen_data_node(*data, KernelState::OutBlock, Some(false), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?; + self.codegen_data_node( + *data, + KernelState::OutBlock, + Some(false), + None, + None, + None, + false, + extra_dim_collects, + dynamic_shared_offset, + body, + &mut tabs, + )?; } Ok(()) }) @@ -801,9 +895,22 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let post_init = &mut goto.post_init; let body = &mut goto.body; let term = &mut goto.term; - let mut tabs = self.codegen_control_node(*control, None, None, None, init, post_init, term)?; + let mut tabs = + self.codegen_control_node(*control, None, None, None, init, post_init, term)?; for data in self.bbs.1[control.idx()].iter() { - self.codegen_data_node(*data, state, Some(is_block_parallel), None, None, None, false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?; + self.codegen_data_node( + *data, + state, + Some(is_block_parallel), + None, + None, + None, + false, + extra_dim_collects, + dynamic_shared_offset, + body, + &mut tabs, + )?; } } // Then generate data and control for the single block fork if it exists @@ -815,9 +922,29 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let post_init = &mut goto.post_init; let body = &mut goto.body; let term = &mut goto.term; - let mut tabs = self.codegen_control_node(*control, Some(num_threads), Some(num_threads), Some(1), init, post_init, term)?; + let mut tabs = self.codegen_control_node( + *control, + Some(num_threads), + Some(num_threads), + Some(1), + init, + post_init, + term, + )?; for data in self.bbs.1[control.idx()].iter() { - self.codegen_data_node(*data, state, None, Some(num_threads), None, Some(block_fork.unwrap()), false, extra_dim_collects, dynamic_shared_offset, body, &mut tabs)?; + self.codegen_data_node( + *data, + state, + None, + Some(num_threads), + None, + Some(block_fork.unwrap()), + false, + extra_dim_collects, + dynamic_shared_offset, + body, + &mut tabs, + )?; } } } @@ -862,14 +989,19 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let reduces = &self.fork_reduce_map[&curr_fork]; let reducts = if parallel_factor.is_some() { reduces - .iter() - .map(|&reduce| { - let Node::Reduce { control: _, init: _, reduct} = &self.function.nodes[reduce.idx()] else { - panic!("Expected reduce node"); - }; - *reduct - }) - .collect() + .iter() + .map(|&reduce| { + let Node::Reduce { + control: _, + init: _, + reduct, + } = &self.function.nodes[reduce.idx()] + else { + panic!("Expected reduce node"); + }; + *reduct + }) + .collect() } else { HashSet::new() }; @@ -879,7 +1011,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; 
let post_init = &mut goto.post_init; let body = &mut goto.body; let term = &mut goto.term; - let mut tabs = self.codegen_control_node(*control, Some(available_thread_quota), Some(use_thread_quota), parallel_factor, init, post_init, term)?; + let mut tabs = self.codegen_control_node( + *control, + Some(available_thread_quota), + Some(use_thread_quota), + parallel_factor, + init, + post_init, + term, + )?; for data in self.bbs.1[control.idx()].iter() { self.codegen_data_node( *data, @@ -928,8 +1068,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let define_variable = self.get_value(id, false, false).to_string(); let tabs = "\t".repeat(*num_tabs); match &self.function.nodes[id.idx()] { - Node::Phi { control: _, data: _ } => { - write!(w, "{}{} = {}_tmp;\n", tabs, define_variable, define_variable)?; + Node::Phi { + control: _, + data: _, + } => { + write!( + w, + "{}{} = {}_tmp;\n", + tabs, define_variable, define_variable + )?; } Node::ThreadID { control, dimension } => { let Node::Fork { factors, .. } = &self.function.nodes[control.idx()] else { @@ -950,11 +1097,22 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // No dependence on threadIdx.x because each used thread // will run this Fork serially let fork_iter = self.get_fork_iter(*control, false); - write!(w, "{}{} = ({} / {}) % {};\n", tabs, define_variable, fork_iter, divide, modulo)?; + write!( + w, + "{}{} = ({} / {}) % {};\n", + tabs, define_variable, fork_iter, divide, modulo + )?; } else { // We can directly use use_thread_quota and not worry about available // because Fork basic block's init section already does gating - write!(w, "{}{} = (threadIdx.x % {}) / {};\n", tabs, define_variable, use_thread_quota.unwrap(), use_thread_quota.unwrap() / parallel_factor.unwrap())?; + write!( + w, + "{}{} = (threadIdx.x % {}) / {};\n", + tabs, + define_variable, + use_thread_quota.unwrap(), + use_thread_quota.unwrap() / parallel_factor.unwrap() + )?; } } _ => { @@ -995,7 +1153,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive(); let cg_tile = match state { KernelState::OutBlock | KernelState::InBlock => "block".to_string(), - KernelState::InThread => self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId), + KernelState::InThread => { + self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId) + } }; if !is_primitive && state != KernelState::OutBlock { write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?; @@ -1018,13 +1178,29 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; *num_tabs -= 1; } if !is_primitive && state == KernelState::OutBlock { + assert!(self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant), "PANIC: The CUDA backend cannot lower a global memory constant that has to be reset to zero. This is because we cannot efficiently implement a memset to the underlying memory of the constant due to the need for a grid level sync. 
Consider floating this collection outside the CUDA function and into an AsyncRust function, or attaching the NoResetConstant schedule to indicate that no memset is semantically necessary."); let (_, offsets) = &self.backing_allocation[&Device::CUDA]; let offset = offsets[&id]; - write!(w, "{}{} = backing + dc{};\n", tabs, define_variable, offset.idx())?; + write!( + w, + "{}{} = backing + dc{};\n", + tabs, + define_variable, + offset.idx() + )?; } - if !is_primitive && (state != KernelState::OutBlock || is_block_parallel.is_none() || !is_block_parallel.unwrap()) { - let data_size = self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects)); - write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?; + if !is_primitive + && (state != KernelState::OutBlock + || is_block_parallel.is_none() + || !is_block_parallel.unwrap()) + { + let data_size = + self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects)); + write!( + w, + "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", + tabs, cg_tile, data_size, cg_tile + )?; write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?; write!(w, "{}}}\n", tabs)?; write!(w, "{}{}.sync();\n", tabs, cg_tile)?; @@ -1083,14 +1259,25 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let mut left_val = self.get_value(*left, false, false); let mut right_val = self.get_value(*right, false, false); let id_type = self.typing[id.idx()]; - if matches!(op, BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And - | BinaryOperator::Xor) && is_special_reduct { + if matches!( + op, + BinaryOperator::Add + | BinaryOperator::Or + | BinaryOperator::And + | BinaryOperator::Xor + ) && is_special_reduct + { // For parallelized associative Reduces, use the cooperative // groups reduce API. Associative multiplication is not // supported. We need to use CGType::Use not CGType::UsePerId // because for parallelized reduction we only have one thread // per ThreadID and the reduction is over Use, not UsePerId. - let (reduce_val, non_reduce_val) = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] { + let (reduce_val, non_reduce_val) = if let Node::Reduce { + control: _, + init: _, + reduct: _, + } = &self.function.nodes[left.idx()] + { (left_val, right_val) } else { (right_val, left_val) @@ -1107,7 +1294,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; _ => unreachable!(), }; let id_type_name = self.get_type(id_type, false); - write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name)?; + write!( + w, + "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", + tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name + )?; // Setup binop between reduce's init and reduced reduct. 
Since it's associative, // we can change binop ordering left_val = define_variable.clone(); @@ -1184,7 +1375,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let id_type = self.typing[id.idx()]; if matches!(intrinsic, Intrinsic::Max | Intrinsic::Min) && is_special_reduct { // Similar to associative Binops - let non_reduce_arg = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[args[0].idx()] { + let non_reduce_arg = if let Node::Reduce { + control: _, + init: _, + reduct: _, + } = &self.function.nodes[args[0].idx()] + { self.get_value(args[1], false, false) } else { self.get_value(args[0], false, false) @@ -1197,33 +1393,44 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; _ => unreachable!(), }; let id_type_name = self.get_type(id_type, false); - write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, non_reduce_arg, cg_tile, cg_op, id_type_name)?; + write!( + w, + "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", + tabs, define_variable, non_reduce_arg, cg_tile, cg_op, id_type_name + )?; } else { let ty = &self.types[id_type.idx()]; let intrinsic = self.codegen_intrinsic(intrinsic, ty); - let args = args.iter() + let args = args + .iter() .map(|arg| self.get_value(*arg, false, false)) .collect::<Vec<_>>() .join(", "); write!( w, "{}{} = {}({});\n", - tabs, - define_variable, - intrinsic, - args, + tabs, define_variable, intrinsic, args, )?; } } // Read of primitive requires load after pointer math. Node::Read { collect, indices } => { - let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects); + let collect_with_indices = + self.codegen_collect(*collect, indices, extra_dim_collects); let data_type_id = self.typing[id.idx()]; if self.types[data_type_id.idx()].is_primitive() { let type_name = self.get_type(data_type_id, true); - write!(w, "{}{} = *reinterpret_cast<{}>({});\n", tabs, define_variable, type_name, collect_with_indices)?; + write!( + w, + "{}{} = *reinterpret_cast<{}>({});\n", + tabs, define_variable, type_name, collect_with_indices + )?; } else { - write!(w, "{}{} = {};\n", tabs, define_variable, collect_with_indices)?; + write!( + w, + "{}{} = {};\n", + tabs, define_variable, collect_with_indices + )?; } } // Write of primitive needs a thread rank gate for safety. 
Write of @@ -1233,24 +1440,43 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; data, indices, } => { - let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects); + let collect_with_indices = + self.codegen_collect(*collect, indices, extra_dim_collects); let data_variable = self.get_value(*data, false, false); let data_type_id = self.typing[data.idx()]; let cg_tile = match state { KernelState::OutBlock | KernelState::InBlock => "block".to_string(), - KernelState::InThread => self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId), + KernelState::InThread => { + self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId) + } }; if self.types[data_type_id.idx()].is_primitive() { write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?; let type_name = self.get_type(data_type_id, true); - write!(w, "{}\t*reinterpret_cast<{}>({}) = {};\n", tabs, type_name, collect_with_indices, data_variable)?; + write!( + w, + "{}\t*reinterpret_cast<{}>({}) = {};\n", + tabs, type_name, collect_with_indices, data_variable + )?; write!(w, "{}}}\n", tabs)?; } else { let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects)); - write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?; - write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, collect_with_indices, data_variable)?; + write!( + w, + "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", + tabs, cg_tile, data_size, cg_tile + )?; + write!( + w, + "{}\t*({} + i) = *({} + i);\n", + tabs, collect_with_indices, data_variable + )?; write!(w, "{}}}\n", tabs)?; - write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?; + write!( + w, + "{}if ({}.thread_rank() < {} % {}.size()) {{\n", + tabs, cg_tile, data_size, cg_tile + )?; write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?; write!(w, "{}}}\n", tabs)?; } @@ -1259,7 +1485,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?; } _ => { - panic!("Unsupported data node type: {:?}", self.function.nodes[id.idx()]) + panic!( + "Unsupported data node type: {:?}", + self.function.nodes[id.idx()] + ) } } // Since reducts are responsible for updating Reduce nodes, we check and @@ -1292,7 +1521,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let tabs = match &self.function.nodes[id.idx()] { Node::Start | Node::Region { preds: _ } - | Node::Projection { control: _, selection: _ } => { + | Node::Projection { + control: _, + selection: _, + } => { let succ = self.control_subgraph.succs(id).next().unwrap(); write!(w_term, "\tgoto {};\n", self.get_block_name(succ, false))?; 1 @@ -1309,13 +1541,32 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; "\tif ({}) {{\n", self.get_value(*cond, false, false) )?; - write!(w_term, "\t\tgoto {};\n", if succ1_is_true { succ1_block_name.clone() } else { succ2_block_name.clone() })?; + write!( + w_term, + "\t\tgoto {};\n", + if succ1_is_true { + succ1_block_name.clone() + } else { + succ2_block_name.clone() + } + )?; write!(w_term, "\t}} else {{\n")?; - write!(w_term, "\t\tgoto {};\n", if succ1_is_true { succ2_block_name } else { succ1_block_name })?; + write!( + w_term, + "\t\tgoto {};\n", + if succ1_is_true { + 
succ2_block_name + } else { + succ1_block_name + } + )?; write!(w_term, "\t}}\n")?; 1 } - Node::Fork { control: _, factors: _ } => { + Node::Fork { + control: _, + factors: _, + } => { // We create a cooperative group tile for each of: used threads per // thread ID- for reads and writes-, used threads across all thread // IDs- for parallelized reductions-, and available threads- to @@ -1332,12 +1583,24 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } else { use_thread_quota }; - write!(w_init, "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", use_thread_per_id, cg_tile, use_thread_per_id)?; + write!( + w_init, + "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", + use_thread_per_id, cg_tile, use_thread_per_id + )?; let cg_tile_use = self.get_cg_tile(id, CGType::Use); - write!(w_init, "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", use_thread_quota, cg_tile_use, use_thread_quota)?; + write!( + w_init, + "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", + use_thread_quota, cg_tile_use, use_thread_quota + )?; let available_thread_quota = available_thread_quota.unwrap(); let cg_tile_available = self.get_cg_tile(id, CGType::Available); - write!(w_init, "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", available_thread_quota, cg_tile_available, available_thread_quota)?; + write!( + w_init, + "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n", + available_thread_quota, cg_tile_available, available_thread_quota + )?; if parallel_factor.is_none() { write!(w_init, "\t{} = 0;\n", self.get_fork_iter(id, true))?; write!(w_init, "\tgoto {};\n", self.get_block_name(id, true))?; @@ -1347,9 +1610,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // threads. If unused, we jump straight to the Join, and if used, // we jump to successor like normal. 
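            // For example, with available_thread_quota = 64 and
            // use_thread_quota = 48, threads whose threadIdx.x % 64 falls in
            // 48..63 skip straight to the join.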
let succ = self.control_subgraph.succs(id).next().unwrap(); - if let Some(available_thread_quota) = available_thread_quota && let Some(use_thread_quota) = use_thread_quota && use_thread_quota < available_thread_quota { - let w_target = if parallel_factor.is_none() { w_post_init } else { w_init }; - write!(w_target, "\tif (threadIdx.x % {} < {}) {{\n", available_thread_quota, use_thread_quota)?; + if let Some(available_thread_quota) = available_thread_quota + && let Some(use_thread_quota) = use_thread_quota + && use_thread_quota < available_thread_quota + { + let w_target = if parallel_factor.is_none() { + w_post_init + } else { + w_init + }; + write!( + w_target, + "\tif (threadIdx.x % {} < {}) {{\n", + available_thread_quota, use_thread_quota + )?; write!(w_term, "\t\tgoto {};\n", self.get_block_name(succ, false))?; write!(w_term, "\t}}\n")?; write!(w_term, "\telse {{\n")?; @@ -1376,7 +1650,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let available_thread_quota = available_thread_quota.unwrap(); let use_thread_quota = use_thread_quota.unwrap(); if use_thread_quota < available_thread_quota { - write!(w_init, "\tif (threadIdx.x % {} < {}) {{\n", available_thread_quota, use_thread_quota)?; + write!( + w_init, + "\tif (threadIdx.x % {} < {}) {{\n", + available_thread_quota, use_thread_quota + )?; write!(w_term, "\t}}\n")?; tabs += 1; } @@ -1415,7 +1693,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let return_val = self.get_value(*data, false, false); let return_type_ptr = self.get_type(self.function.return_type, true); write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; - write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?; + write!( + w_term, + "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", + return_type_ptr, return_val + )?; write!(w_term, "\t}}\n")?; } write!(w_term, "\treturn;\n")?; @@ -1432,7 +1714,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; * This function emits collection name + pointer math for the provided indices. * All collection types use char pointers. */ - fn codegen_collect(&self, collect: NodeID, indices: &[Index], extra_dim_collects: &HashSet<TypeID>) -> String { + fn codegen_collect( + &self, + collect: NodeID, + indices: &[Index], + extra_dim_collects: &HashSet<TypeID>, + ) -> String { let mut index_ptr = "0".to_string(); let type_id = self.typing[collect.idx()]; for index in indices { @@ -1514,7 +1801,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; Constant::UnsignedInteger32(val) => write!(w, "{}{} = {}ul;\n", tabs, name, val)?, Constant::Integer64(val) => write!(w, "{}{} = {}ll;\n", tabs, name, val)?, Constant::UnsignedInteger64(val) => write!(w, "{}{} = {}ull;\n", tabs, name, val)?, - Constant::Float32(val) => write!(w, "{}{} = {}f;\n", tabs, name, format_float(**val as f64))?, + Constant::Float32(val) => { + write!(w, "{}{} = {}f;\n", tabs, name, format_float(**val as f64))? + } Constant::Float64(val) => write!(w, "{}{} = {};\n", tabs, name, format_float(**val))?, // All three following collections involve align then allocate from the // single dynamic shared memory buffer by using and updating the offset. 
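        // Concretely: dynamic_shared_offset is first rounded up to the type's
        // alignment, ((offset + align - 1) / align) * align, the collection is
        // placed at dynamic_shared + dynamic_shared_offset, and the offset is
        // then bumped by the type's size.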
@@ -1522,9 +1811,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; if allow_allocate { let alignment = self.get_alignment(*type_id); let size = self.get_size(*type_id, None, extra_dim_collects); - *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment); - write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?; - write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?; + *dynamic_shared_offset = format!( + "(({} + {} - 1) / {}) * {}", + dynamic_shared_offset, alignment, alignment, alignment + ); + write!( + w, + "{}dynamic_shared_offset = {};\n", + tabs, dynamic_shared_offset + )?; + write!( + w, + "{}{} = dynamic_shared + dynamic_shared_offset;\n", + tabs, name + )?; *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size); } let Type::Product(type_fields) = &self.types[type_id.idx()] else { @@ -1546,7 +1846,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; num_tabs, )?; } else if !field_constant.is_array() { - self.codegen_constant(format!("{}+{}", name, offset), constant_fields[i], false, extra_dim_collects, dynamic_shared_offset, w, num_tabs)?; + self.codegen_constant( + format!("{}+{}", name, offset), + constant_fields[i], + false, + extra_dim_collects, + dynamic_shared_offset, + w, + num_tabs, + )?; } } } @@ -1554,9 +1862,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; if allow_allocate { let alignment = self.get_alignment(*type_id); let size = self.get_size(*type_id, None, extra_dim_collects); - *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment); - write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?; - write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?; + *dynamic_shared_offset = format!( + "(({} + {} - 1) / {}) * {}", + dynamic_shared_offset, alignment, alignment, alignment + ); + write!( + w, + "{}dynamic_shared_offset = {};\n", + tabs, dynamic_shared_offset + )?; + write!( + w, + "{}{} = dynamic_shared + dynamic_shared_offset;\n", + tabs, name + )?; *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size); } // No offset updating needed since all variants start at 0 @@ -1565,7 +1884,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; }; let variant_constant = &self.constants[field.idx()]; if variant_constant.is_scalar() { - let variant_type = self.get_type(self.typing[variants[*variant as usize].idx()], true); + let variant_type = + self.get_type(self.typing[variants[*variant as usize].idx()], true); self.codegen_constant( format!("*reinterpret_cast<{}>({})", variant_type, name), *field, @@ -1576,7 +1896,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; num_tabs, )?; } else if !variant_constant.is_array() { - self.codegen_constant(name, *field, false, extra_dim_collects, dynamic_shared_offset, w, num_tabs)?; + self.codegen_constant( + name, + *field, + false, + extra_dim_collects, + dynamic_shared_offset, + w, + num_tabs, + )?; }; } Constant::Array(type_id) => { @@ -1585,9 +1913,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } let alignment = self.get_alignment(*type_id); let size = self.get_size(*type_id, None, extra_dim_collects); - *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment); - write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?; - 
write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?; + *dynamic_shared_offset = format!( + "(({} + {} - 1) / {}) * {}", + dynamic_shared_offset, alignment, alignment, alignment + ); + write!( + w, + "{}dynamic_shared_offset = {};\n", + tabs, dynamic_shared_offset + )?; + write!( + w, + "{}{} = dynamic_shared + dynamic_shared_offset;\n", + tabs, name + )?; *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size); } } @@ -1600,10 +1939,21 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; * and offset to 2nd field. This is useful for constant initialization and read/write * index math. */ - fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String { + fn get_size( + &self, + type_id: TypeID, + num_fields: Option<usize>, + extra_dim_collects: Option<&HashSet<TypeID>>, + ) -> String { match &self.types[type_id.idx()] { Type::Array(element_type, extents) => { - let array_size = if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { "1".to_string() } else { multiply_dcs(extents) }; + let array_size = if extra_dim_collects.is_some() + && extra_dim_collects.unwrap().contains(&type_id) + { + "1".to_string() + } else { + multiply_dcs(extents) + }; format!("{} * {}", self.get_alignment(*element_type), array_size) } Type::Product(fields) => { @@ -1612,7 +1962,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; .iter() .enumerate() .filter(|(i, _)| i < num_fields) - .map(|(_, id)| (self.get_size(*id, None, extra_dim_collects), self.get_alignment(*id))) + .map(|(_, id)| { + ( + self.get_size(*id, None, extra_dim_collects), + self.get_alignment(*id), + ) + }) .fold(String::from("0"), |acc, (size, align)| { if acc == "0" { size @@ -1637,16 +1992,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // The argmax variant by size is not guaranteed to be same as // argmax variant by alignment, eg product of 3 4-byte primitives // vs 1 8-byte primitive, so we need to calculate both. 
- let max_size = variants.iter().map(|id| self.get_size(*id, None, extra_dim_collects)).fold( - String::from("0"), - |acc, x| { + let max_size = variants + .iter() + .map(|id| self.get_size(*id, None, extra_dim_collects)) + .fold(String::from("0"), |acc, x| { if acc == "0" { x } else { format!("umax({}, {})", acc, x) } - }, - ); + }); let max_alignment = variants .iter() .map(|id| self.get_alignment(*id)) @@ -1793,7 +2148,17 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } fn get_cg_tile(&self, fork: NodeID, cg_type: CGType) -> String { - format!("cg_{}{}", self.get_value(fork, false, false), if cg_type == CGType::Use { "_use" } else if cg_type == CGType::Available { "_available" } else { "" }) + format!( + "cg_{}{}", + self.get_value(fork, false, false), + if cg_type == CGType::Use { + "_use" + } else if cg_type == CGType::Available { + "_available" + } else { + "" + } + ) } fn get_fork_iter(&self, fork: NodeID, ty: bool) -> String { @@ -1805,7 +2170,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } fn get_block_name(&self, id: NodeID, post: bool) -> String { - format!("bb_{}{}", self.get_value(id, false, false), if post { "_post" } else { "" }) + format!( + "bb_{}{}", + self.get_value(id, false, false), + if post { "_post" } else { "" } + ) } /* @@ -1860,7 +2229,10 @@ fn multiply_dcs(dcs: &[DynamicConstantID]) -> String { if dcs.is_empty() { "1".to_string() } else { - dcs.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ") + dcs.iter() + .map(|dc| format!("dc{}", dc.idx())) + .collect::<Vec<_>>() + .join(" * ") } } diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 916d6520..cbef5a00 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -358,14 +358,16 @@ impl<'a> RTContext<'a> { } } write!(block, ";\n")?; - if let Some((size, device)) = size_and_device { - write!( - block, - " ::hercules_rt::__{}_zero_mem({}, {} as usize);\n", - device.name(), - self.get_value(id), - size - )?; + if !func.schedules[id.idx()].contains(&Schedule::NoResetConstant) { + if let Some((size, device)) = size_and_device { + write!( + block, + " ::hercules_rt::__{}_zero_mem({}, {} as usize);\n", + device.name(), + self.get_value(id), + size + )?; + } } } Node::Call { diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch index c0a1a5ce..4303d376 100644 --- a/hercules_samples/matmul/src/gpu.sch +++ b/hercules_samples/matmul/src/gpu.sch @@ -13,9 +13,8 @@ phi-elim(*); dce(*); forkify(*); -infer-schedules(*); +fixpoint { + infer-schedules(*); +} gcm(*); -float-collections(*); -dce(*); -gcm(*); -- GitLab From 4a70e1c75c78291c36ae4249de79b4425d77d8d6 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 11:14:44 -0600 Subject: [PATCH 7/9] fix implicit_clone --- juno_samples/implicit_clone/src/gpu.sch | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/juno_samples/implicit_clone/src/gpu.sch b/juno_samples/implicit_clone/src/gpu.sch index 0f7c8021..50c4ae8e 100644 --- a/juno_samples/implicit_clone/src/gpu.sch +++ b/juno_samples/implicit_clone/src/gpu.sch @@ -15,6 +15,8 @@ dce(*); infer-schedules(*); gcm(*); -float-collections(*); -dce(*); -gcm(*); +fixpoint { + float-collections(*); + dce(*); + gcm(*); +} -- GitLab From 7114b2083ff64e8fdec1969038a11d8477862576 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 11:18:17 -0600 Subject: [PATCH 8/9] fix cava gpu ci for now --- juno_samples/cava/src/gpu.sch | 5 
+++++ 1 file changed, 5 insertions(+) diff --git a/juno_samples/cava/src/gpu.sch b/juno_samples/cava/src/gpu.sch index a5570b8d..594cbfa3 100644 --- a/juno_samples/cava/src/gpu.sch +++ b/juno_samples/cava/src/gpu.sch @@ -17,3 +17,8 @@ dce(*); infer-schedules(*); gcm(*); +fixpoint { + float-collections(*); + dce(*); + gcm(*); +} -- GitLab From 6b39e69c801eb11c11e3b51a2506467b1c79598f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sat, 1 Feb 2025 16:40:40 -0600 Subject: [PATCH 9/9] fix for edge_detection --- Cargo.toml | 8 ++++---- juno_samples/edge_detection/src/gpu.sch | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f4976106..f7b9322a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,14 +21,14 @@ members = [ "hercules_samples/ccp", "juno_samples/simple3", - "juno_samples/patterns", + "juno_samples/patterns", "juno_samples/matmul", "juno_samples/casts_and_intrinsics", "juno_samples/nested_ccp", "juno_samples/antideps", "juno_samples/implicit_clone", - "juno_samples/cava", + "juno_samples/cava", "juno_samples/concat", - "juno_samples/schedule_test", - "juno_samples/edge_detection", + "juno_samples/schedule_test", + "juno_samples/edge_detection", ] diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch index d7330c67..a1bf06a4 100644 --- a/juno_samples/edge_detection/src/gpu.sch +++ b/juno_samples/edge_detection/src/gpu.sch @@ -17,3 +17,8 @@ dce(*); infer-schedules(*); gcm(*); +fixpoint { + float-collections(*); + dce(*); + gcm(*); +} -- GitLab
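
As background for the GPU backend hunks earlier in this series: the Fork lowering builds up a device-side preamble of cooperative-group tiles (one per thread ID, one over all used threads, one over all available threads) plus a `threadIdx.x % available < use` guard around the fork body. The sketch below shows the rough shape of that preamble in plain CUDA; the tile sizes, guard constants, and kernel name are hypothetical stand-ins for the quotas the fork analysis computes, not text the backend emits verbatim.

    #include <cooperative_groups.h>

    namespace cg = cooperative_groups;

    // Hypothetical quotas: 4 threads per thread ID, 8 threads used across
    // all thread IDs, 32 threads available to this fork.
    __global__ void fork_preamble_sketch() {
        cg::thread_block block = cg::this_thread_block();

        // One tile each for: per-thread-ID reads/writes, all used threads
        // (parallelized reductions), and all available threads.
        cg::thread_block_tile<4> cg_tile = cg::tiled_partition<4>(block);
        cg::thread_block_tile<8> cg_tile_use = cg::tiled_partition<8>(block);
        cg::thread_block_tile<32> cg_tile_available =
            cg::tiled_partition<32>(block);

        // When fewer threads are used than are available, only the used
        // slice of each 32-thread chunk runs the fork body; the rest idle
        // until the matching join.
        if (threadIdx.x % 32 < 8) {
            // ... per-node code for the fork body goes here ...
        }
    }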