From b1f0d045b696ad8a6f64cd4ddf26b409629ebb3c Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 12 Feb 2025 15:29:47 -0600 Subject: [PATCH 1/7] Describe array to product analysis --- hercules_opt/src/array_to_prod.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 hercules_opt/src/array_to_prod.rs diff --git a/hercules_opt/src/array_to_prod.rs b/hercules_opt/src/array_to_prod.rs new file mode 100644 index 00000000..ad1a2471 --- /dev/null +++ b/hercules_opt/src/array_to_prod.rs @@ -0,0 +1,30 @@ +use hercules_ir::ir::*; + +use crate::*; + +/* + * Top level function for array to product which will convert constant + * sized arrays into products if the array is only accessed at indices which + * are constants. + * + * To identify the collections we can convert we look at each constant-sized + * array constant and compute the set which includes the constant node and is + * closed under the following properties: + * - For each collection in the set, its uses are in the set + * - For each node that uses a collection, all collections it uses are in the + * set + * From this set, we then determine whether this whole set can be converted to + * operating on products, rather than arrays, as follows + * - Each read and write node must be to a constant index + * - It may not contain any arguments (we could generate code to read a an + * array argument into a product, but do not do so for now) + * - There are call or return nodes in the set (this would mean that the + * collections are consumed by a call or return, again we could reconstruct + * the array where needed but do not do so for now and so have this + * restriction) + * - All nodes in the set are editable (if we cannot modify some node then the + * conversion will fail) + */ +pub fn array_to_product(editor: &mut FunctionEditor) { + todo!() +} -- GitLab From 810e262845123241ddaee8e2ebd1ea972d6a825f Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 12 Feb 2025 20:03:41 -0600 Subject: [PATCH 2/7] Array to product pass --- hercules_opt/src/array_to_prod.rs | 304 +++++++++++++++++++++++++++++- hercules_opt/src/lib.rs | 2 + juno_scheduler/src/compile.rs | 3 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 20 ++ 5 files changed, 328 insertions(+), 2 deletions(-) diff --git a/hercules_opt/src/array_to_prod.rs b/hercules_opt/src/array_to_prod.rs index ad1a2471..5908a4e5 100644 --- a/hercules_opt/src/array_to_prod.rs +++ b/hercules_opt/src/array_to_prod.rs @@ -1,7 +1,13 @@ +use hercules_ir::define_id_type; use hercules_ir::ir::*; +use bitvec::prelude::*; + use crate::*; +use std::collections::{HashMap, HashSet}; +use std::marker::PhantomData; + /* * Top level function for array to product which will convert constant * sized arrays into products if the array is only accessed at indices which @@ -25,6 +31,300 @@ use crate::*; * - All nodes in the set are editable (if we cannot modify some node then the * conversion will fail) */ -pub fn array_to_product(editor: &mut FunctionEditor) { - todo!() +pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID]) { + let replace_nodes = array_usage_analysis(editor, types); + let num_nodes = editor.func().nodes.len(); + + for node_idx in 0..num_nodes { + if !replace_nodes[node_idx] { + continue; + } + let node = NodeID::new(node_idx); + + // We can replace the array(s) this node uses with a product. What we have to do depends on + // the type of the node + match &editor.func().nodes[node_idx] { + // Phi, Reduce, and Ternary just use the whole collection, they do not need to change, + // except as they will be modified by replace_all_uses_of + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Ternary { + op: TernaryOperator::Select, + .. + } => {} + Node::Constant { id } => { + assert!(editor.get_constant(*id).is_array()); + let element: TypeID = editor.get_type(types[node_idx]).try_element_type().unwrap(); + let dims: Vec<usize> = editor + .get_type(types[node_idx]) + .try_extents() + .unwrap() + .iter() + .map(|dc| editor.get_dynamic_constant(*dc).try_constant().unwrap()) + .collect(); + // Replace the constant by a product that is a product (for each dimension) and the + // elements are zero'd + editor.edit(|mut edit| { + let element_zero = edit.add_zero_constant(element); + let (constant, _) = dims.into_iter().rfold( + (element_zero, element), + |(cur_const, cur_type), dim| { + let new_type = edit.add_type(Type::Product(vec![cur_type; dim].into())); + let new_const = edit.add_constant(Constant::Product( + new_type, + vec![cur_const; dim].into(), + )); + (new_const, new_type) + }, + ); + let new_val = edit.add_node(Node::Constant { id: constant }); + let edit = edit.replace_all_uses(node, new_val)?; + edit.delete_node(node) + }); + } + Node::Read { collect, indices } => { + let collect = *collect; + let new_indices = convert_indices_to_prod(editor, indices); + editor.edit(|mut edit| { + let new_val = edit.add_node(Node::Read { + collect, + indices: new_indices, + }); + let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?; + edit.delete_node(node) + }); + } + Node::Write { + collect, + data, + indices, + } => { + let collect = *collect; + let data = *data; + let new_indices = convert_indices_to_prod(editor, indices); + editor.edit(|mut edit| { + let new_val = edit.add_node(Node::Write { + collect, + data, + indices: new_indices, + }); + let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?; + edit.delete_node(node) + }); + } + node => panic!("Node cannot be replaced: {:?}", node), + } + } +} + +fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[Index]> { + let mut result = vec![]; + + for index in indices { + match index { + Index::Position(positions) => { + for pos in positions { + let const_id = editor.func().nodes[pos.idx()] + .try_constant() + .expect("Array position must be constant"); + match *editor.get_constant(const_id) { + Constant::UnsignedInteger64(idx) => result.push(Index::Field(idx as usize)), + ref val => panic!("Position should be u64 constant: {:?}", val), + } + } + } + index => panic!("Index cannot be replaced: {:?}", index), + } + } + + result.into() +} + +// Given the editor, while compute a mask of which nodes are to be converted +// from using a constant sized array into using a product +fn array_usage_analysis(editor: &FunctionEditor, types: &[TypeID]) -> BitVec<u8, Lsb0> { + let num_nodes = editor.func().nodes.len(); + + // Step 1: identify the constant nodes that are constant sized arrays, these are what we are + // interested in converting into products + let sources = editor + .func() + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| { + let Node::Constant { id } = node else { + return None; + }; + let Constant::Array(array_type) = *editor.get_constant(*id) else { + return None; + }; + let typ = editor.get_type(array_type); + let Some(dims) = typ.try_extents() else { + return None; + }; + if dims + .iter() + .all(|dc| editor.get_dynamic_constant(*dc).is_constant()) + { + Some(NodeID::new(idx)) + } else { + None + } + }) + .collect::<Vec<_>>(); + + // Step 2: collect the collection information we need for the (whole) function. For each node + // that returns a collection (that in reference semantics returns the same reference as some of + // its inputs) union with all of its users. The nodes that matter in this are arguments, + // constants, writes, phis, selects, and reduces with array types. + let mut analysis = UnionFind::new(); + for node_idx in 0..num_nodes { + let node_id = NodeID::new(node_idx); + if editor.get_type(types[node_idx]).is_array() { + match editor.func().nodes[node_idx] { + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Ternary { + op: TernaryOperator::Select, + .. + } + | Node::Write { .. } => { + for user in editor.get_users(node_id) { + analysis.union(node_id, user); + } + } + _ => {} + } + } + } + + let sets = analysis.sets(&sources); + + // Step 3: determine which sets can be converted and mark the nodes in those sets + let mut result = bitvec![u8, Lsb0; 0; num_nodes]; + + for nodes in sets { + if nodes + .iter() + .all(|node_id| editor.is_mutable(*node_id) && can_replace(editor, *node_id)) + { + for node_id in nodes { + result.set(node_id.idx(), true); + } + } + } + + result +} + +fn can_replace(editor: &FunctionEditor, node: NodeID) -> bool { + match &editor.func().nodes[node.idx()] { + // Reads and writes must be at constant indices + Node::Read { indices, .. } | Node::Write { indices, .. } => { + indices.iter().all(|idx| match idx { + Index::Position(pos) => pos + .iter() + .all(|node| editor.func().nodes[node.idx()].is_constant()), + _ => false, + }) + } + // phi, reduce, constants, and select can always be replaced if their users and uses allow + // it, which is handled by the construction of the set + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Constant { .. } + | Node::Ternary { + op: TernaryOperator::Select, + .. + } => true, + // No other nodes allow replacement + _ => false, + } +} + +define_id_type!(SetID); + +#[derive(Clone, Debug)] +struct UnionFindNode { + parent: SetID, + rank: usize, +} + +#[derive(Clone, Debug)] +struct UnionFind<T> { + sets: Vec<UnionFindNode>, + _phantom: PhantomData<T>, +} + +impl<T: ID> UnionFind<T> { + pub fn new() -> Self { + UnionFind { + sets: vec![], + _phantom: PhantomData, + } + } + + fn extend_past(&mut self, size: usize) { + for i in self.sets.len()..=size { + // The new nodes we add are in their own sets and have rank 0 + self.sets.push(UnionFindNode { + parent: SetID::new(i), + rank: 0, + }); + } + } + + pub fn find(&mut self, x: T) -> SetID { + self.extend_past(x.idx()); + self.find_set(x.idx()) + } + + fn find_set(&mut self, x: usize) -> SetID { + let mut parent = self.sets[x].parent; + if parent.idx() != x { + parent = self.find_set(parent.idx()); + self.sets[x].parent = parent; + } + parent + } + + pub fn union(&mut self, x: T, y: T) { + let x = self.find(x); + let y = self.find(y); + self.link(x, y); + } + + fn link(&mut self, x: SetID, y: SetID) { + if self.sets[x.idx()].rank > self.sets[y.idx()].rank { + self.sets[y.idx()].parent = x; + } else { + self.sets[x.idx()].parent = y; + if self.sets[x.idx()].rank == self.sets[y.idx()].rank { + self.sets[y.idx()].rank += 1; + } + } + } + + pub fn sets(&mut self, keys: &[T]) -> Vec<Vec<T>> { + let key_index = keys + .iter() + .enumerate() + .map(|(i, k)| (self.find(*k), i)) + .collect::<HashMap<SetID, usize>>(); + let mut result = vec![vec![]; keys.len()]; + + let num_elements = self.sets.len(); + for i in 0..num_elements { + let set = self.find_set(i); + let Some(idx) = key_index.get(&set) else { + continue; + }; + result[*idx].push(T::new(i)); + } + + result + } } diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 7187508a..ed16ca35 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -1,5 +1,6 @@ #![feature(let_chains)] +pub mod array_to_prod; pub mod ccp; pub mod crc; pub mod dce; @@ -25,6 +26,7 @@ pub mod sroa; pub mod unforkify; pub mod utils; +pub use crate::array_to_prod::*; pub use crate::ccp::*; pub use crate::crc::*; pub use crate::dce::*; diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 4ea8dfb5..34c13a71 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -93,6 +93,9 @@ impl FromStr for Appliable { fn from_str(s: &str) -> Result<Self, Self::Err> { match s { "array-slf" => Ok(Appliable::Pass(ir::Pass::ArraySLF)), + "array-to-product" | "array-to-prod" | "a2p" => { + Ok(Appliable::Pass(ir::Pass::ArrayToProduct)) + } "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)), "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)), "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 0ecac39a..e07eae43 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -3,6 +3,7 @@ use hercules_ir::ir::{Device, Schedule}; #[derive(Debug, Copy, Clone)] pub enum Pass { ArraySLF, + ArrayToProduct, AutoOutline, CCP, CRC, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index e9c681cd..e8327c39 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1366,6 +1366,26 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ArrayToProduct => { + // TODO: We might allow an (optional) maximum size as an argument to this pass for + // nicer control for the user + assert!(args.is_empty()); + pm.make_typing(); + let typing = pm.typing.take().unwrap(); + + for (func, types) in build_selection(pm, selection, false) + .into_iter() + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + array_to_product(&mut func, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::AutoOutline => { let Some(funcs) = selection_of_functions(pm, selection) else { return Err(SchedulerError::PassError { -- GitLab From 6bbf60c3ca2a4c610f84118b01746e777e7b5ce6 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 12 Feb 2025 20:29:27 -0600 Subject: [PATCH 3/7] Add test case for array to product --- Cargo.lock | 10 ++++ Cargo.toml | 1 + juno_samples/median_window/Cargo.toml | 21 +++++++ juno_samples/median_window/build.rs | 11 ++++ juno_samples/median_window/src/cpu.sch | 29 +++++++++ juno_samples/median_window/src/gpu.sch | 29 +++++++++ juno_samples/median_window/src/main.rs | 26 ++++++++ juno_samples/median_window/src/median.jn | 76 ++++++++++++++++++++++++ 8 files changed, 203 insertions(+) create mode 100644 juno_samples/median_window/Cargo.toml create mode 100644 juno_samples/median_window/build.rs create mode 100644 juno_samples/median_window/src/cpu.sch create mode 100644 juno_samples/median_window/src/gpu.sch create mode 100644 juno_samples/median_window/src/main.rs create mode 100644 juno_samples/median_window/src/median.jn diff --git a/Cargo.lock b/Cargo.lock index ffb61f4d..0d85fdf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1189,6 +1189,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_median_window" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_multi_device" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3e86bad0..2c8903b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,5 @@ members = [ "juno_samples/fork_join_tests", "juno_samples/multi_device", "juno_samples/product_read", + "juno_samples/median_window", ] diff --git a/juno_samples/median_window/Cargo.toml b/juno_samples/median_window/Cargo.toml new file mode 100644 index 00000000..1e376d5d --- /dev/null +++ b/juno_samples/median_window/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "juno_median_window" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_median_window" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" diff --git a/juno_samples/median_window/build.rs b/juno_samples/median_window/build.rs new file mode 100644 index 00000000..a6c29d5b --- /dev/null +++ b/juno_samples/median_window/build.rs @@ -0,0 +1,11 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("median.jn") + .unwrap() + .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/median_window/src/cpu.sch b/juno_samples/median_window/src/cpu.sch new file mode 100644 index 00000000..bab309a6 --- /dev/null +++ b/juno_samples/median_window/src/cpu.sch @@ -0,0 +1,29 @@ +gvn(*); +phi-elim(*); +dce(*); + +inline(*); + +let out = auto-outline(*); +cpu(out.median_window); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +array-to-product(*); +sroa(*); +phi-elim(*); +predication(*); +simplify-cfg(*); +dce(*); +gvn(*); + +infer-schedules(*); +gcm(*); +float-collections(*); +dce(*); +gcm(*); diff --git a/juno_samples/median_window/src/gpu.sch b/juno_samples/median_window/src/gpu.sch new file mode 100644 index 00000000..df4e04df --- /dev/null +++ b/juno_samples/median_window/src/gpu.sch @@ -0,0 +1,29 @@ +gvn(*); +phi-elim(*); +dce(*); + +inline(*); + +let out = auto-outline(*); +gpu(out.median_window); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +array-to-product(*); +sroa(*); +phi-elim(*); +predication(*); +simplify-cfg(*); +dce(*); +gvn(*); + +infer-schedules(*); +gcm(*); +float-collections(*); +dce(*); +gcm(*); diff --git a/juno_samples/median_window/src/main.rs b/juno_samples/median_window/src/main.rs new file mode 100644 index 00000000..c515ac4b --- /dev/null +++ b/juno_samples/median_window/src/main.rs @@ -0,0 +1,26 @@ +#![feature(concat_idents)] + +juno_build::juno!("median"); + +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; + +fn main() { + let m = vec![86, 72, 14, 5, 55, + 25, 98, 89, 3, 66, + 44, 81, 27, 3, 40, + 18, 4, 57, 93, 34, + 70, 50, 50, 18, 34]; + let m = HerculesImmBox::from(m.as_slice()); + + let mut r = runner!(median_window); + let res = + async_std::task::block_on(async { + r.run(m.to()).await + }); + assert_eq!(res, 57); +} + +#[test] +fn test_median_window() { + main() +} diff --git a/juno_samples/median_window/src/median.jn b/juno_samples/median_window/src/median.jn new file mode 100644 index 00000000..2c80882e --- /dev/null +++ b/juno_samples/median_window/src/median.jn @@ -0,0 +1,76 @@ +// TODO: Once unrolling is merged, use loops instead of this nightmare +fn median_matrix<t: number>(m: t[3, 3]) -> t { + let tmp : t[3 * 3]; + tmp[0] = m[0, 0]; + tmp[1] = m[0, 1]; + tmp[2] = m[0, 2]; + tmp[3] = m[1, 0]; + tmp[4] = m[1, 1]; + tmp[5] = m[1, 2]; + tmp[6] = m[2, 0]; + tmp[7] = m[2, 1]; + tmp[8] = m[2, 2]; + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } + if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } + if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } + if tmp[5] > tmp[6] { let t = tmp[5]; tmp[5] = tmp[6]; tmp[6] = t; } + if tmp[6] > tmp[7] { let t = tmp[6]; tmp[6] = tmp[7]; tmp[7] = t; } + if tmp[7] > tmp[8] { let t = tmp[7]; tmp[7] = tmp[8]; tmp[8] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } + if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } + if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } + if tmp[5] > tmp[6] { let t = tmp[5]; tmp[5] = tmp[6]; tmp[6] = t; } + if tmp[6] > tmp[7] { let t = tmp[6]; tmp[6] = tmp[7]; tmp[7] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } + if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } + if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } + if tmp[5] > tmp[6] { let t = tmp[5]; tmp[5] = tmp[6]; tmp[6] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } + if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } + if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } + if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } + + if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + + return tmp[4]; +} + +#[entry] +fn median_window(x: i32[5, 5]) -> i32 { + let window: i32[3, 3]; + + window[0, 0] = x[1, 1]; + window[0, 1] = x[1, 2]; + window[0, 2] = x[1, 3]; + window[1, 0] = x[2, 1]; + window[1, 1] = x[2, 2]; + window[1, 2] = x[2, 3]; + window[2, 0] = x[3, 1]; + window[2, 1] = x[3, 2]; + window[2, 2] = x[3, 3]; + + return median_matrix::<_>(window); +} -- GitLab From 1b97d09f8f7243d18c3328aebe7006ae1fce2316 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 13 Feb 2025 09:56:20 -0600 Subject: [PATCH 4/7] Clean-up example --- juno_samples/median_window/src/median.jn | 89 +++++++----------------- 1 file changed, 24 insertions(+), 65 deletions(-) diff --git a/juno_samples/median_window/src/median.jn b/juno_samples/median_window/src/median.jn index 2c80882e..38ae407b 100644 --- a/juno_samples/median_window/src/median.jn +++ b/juno_samples/median_window/src/median.jn @@ -1,76 +1,35 @@ -// TODO: Once unrolling is merged, use loops instead of this nightmare -fn median_matrix<t: number>(m: t[3, 3]) -> t { - let tmp : t[3 * 3]; - tmp[0] = m[0, 0]; - tmp[1] = m[0, 1]; - tmp[2] = m[0, 2]; - tmp[3] = m[1, 0]; - tmp[4] = m[1, 1]; - tmp[5] = m[1, 2]; - tmp[6] = m[2, 0]; - tmp[7] = m[2, 1]; - tmp[8] = m[2, 2]; +fn median_matrix<t: number, n, m: usize>(x: t[n, m]) -> t { + let tmp : t[n * m]; + for i = 0 to n { + for j = 0 to m { + tmp[i * m + j] = x[i, j]; + } + } - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } - if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } - if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } - if tmp[5] > tmp[6] { let t = tmp[5]; tmp[5] = tmp[6]; tmp[6] = t; } - if tmp[6] > tmp[7] { let t = tmp[6]; tmp[6] = tmp[7]; tmp[7] = t; } - if tmp[7] > tmp[8] { let t = tmp[7]; tmp[7] = tmp[8]; tmp[8] = t; } + const cnt = n * m; - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } - if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } - if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } - if tmp[5] > tmp[6] { let t = tmp[5]; tmp[5] = tmp[6]; tmp[6] = t; } - if tmp[6] > tmp[7] { let t = tmp[6]; tmp[6] = tmp[7]; tmp[7] = t; } - - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } - if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } - if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } - if tmp[5] > tmp[6] { let t = tmp[5]; tmp[5] = tmp[6]; tmp[6] = t; } - - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } - if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } - if tmp[4] > tmp[5] { let t = tmp[4]; tmp[4] = tmp[5]; tmp[5] = t; } - - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } - if tmp[3] > tmp[4] { let t = tmp[3]; tmp[3] = tmp[4]; tmp[4] = t; } - - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - if tmp[2] > tmp[3] { let t = tmp[2]; tmp[2] = tmp[3]; tmp[3] = t; } - - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } - if tmp[1] > tmp[2] { let t = tmp[1]; tmp[1] = tmp[2]; tmp[2] = t; } - - if tmp[0] > tmp[1] { let t = tmp[0]; tmp[0] = tmp[1]; tmp[1] = t; } + for i = 0 to cnt - 1 { + for j = 0 to cnt - i - 1 { + if tmp[j] > tmp[j + 1] { + let t = tmp[j]; + tmp[j] = tmp[j + 1]; + tmp[j + 1] = t; + } + } + } - return tmp[4]; + return tmp[cnt / 2]; } #[entry] fn median_window(x: i32[5, 5]) -> i32 { let window: i32[3, 3]; - window[0, 0] = x[1, 1]; - window[0, 1] = x[1, 2]; - window[0, 2] = x[1, 3]; - window[1, 0] = x[2, 1]; - window[1, 1] = x[2, 2]; - window[1, 2] = x[2, 3]; - window[2, 0] = x[3, 1]; - window[2, 1] = x[3, 2]; - window[2, 2] = x[3, 3]; + for i = 0 to 3 { + for j = 0 to 3 { + window[i, j] = x[i + 1, j + 1]; + } + } - return median_matrix::<_>(window); + return median_matrix::<_, 3, 3>(window); } -- GitLab From 25688a4e369ae78237fd17df6a85b5f33e09fc8b Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 13 Feb 2025 20:53:24 -0600 Subject: [PATCH 5/7] Update schedules for median to forkify and unroll some loops --- juno_samples/median_window/src/cpu.sch | 11 +++++++++++ juno_samples/median_window/src/gpu.sch | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/juno_samples/median_window/src/cpu.sch b/juno_samples/median_window/src/cpu.sch index bab309a6..a3eded93 100644 --- a/juno_samples/median_window/src/cpu.sch +++ b/juno_samples/median_window/src/cpu.sch @@ -3,6 +3,7 @@ phi-elim(*); dce(*); inline(*); +delete-uncalled(*); let out = auto-outline(*); cpu(out.median_window); @@ -14,6 +15,16 @@ gvn(*); phi-elim(*); dce(*); +fixpoint { + forkify(*); + fork-guard-elim(*); + unroll(*); +} + +ccp(*); +gvn(*); +dce(*); + array-to-product(*); sroa(*); phi-elim(*); diff --git a/juno_samples/median_window/src/gpu.sch b/juno_samples/median_window/src/gpu.sch index df4e04df..86e17466 100644 --- a/juno_samples/median_window/src/gpu.sch +++ b/juno_samples/median_window/src/gpu.sch @@ -3,6 +3,7 @@ phi-elim(*); dce(*); inline(*); +delete-uncalled(*); let out = auto-outline(*); gpu(out.median_window); @@ -14,6 +15,16 @@ gvn(*); phi-elim(*); dce(*); +fixpoint { + forkify(*); + fork-guard-elim(*); + unroll(*); +} + +ccp(*); +gvn(*); +dce(*); + array-to-product(*); sroa(*); phi-elim(*); -- GitLab From 772dc81f184e9b5771a5e19121b7f4158cf11d0c Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 14 Feb 2025 09:12:06 -0600 Subject: [PATCH 6/7] Add optional max-size argument to array to prod --- hercules_opt/src/array_to_prod.rs | 42 +++++++++++++++++++++++-------- juno_scheduler/src/compile.rs | 23 +++++++++++------ juno_scheduler/src/ir.rs | 24 +++++++++++++----- juno_scheduler/src/pm.rs | 16 +++++++++--- 4 files changed, 78 insertions(+), 27 deletions(-) diff --git a/hercules_opt/src/array_to_prod.rs b/hercules_opt/src/array_to_prod.rs index 5908a4e5..7b7745c0 100644 --- a/hercules_opt/src/array_to_prod.rs +++ b/hercules_opt/src/array_to_prod.rs @@ -30,11 +30,16 @@ use std::marker::PhantomData; * restriction) * - All nodes in the set are editable (if we cannot modify some node then the * conversion will fail) + * + * The max_size argument allows the user to specify a limit on the size of arrays + * that should be converted to products. If the number of elements in the array + * is larger than the max size the array will not be converted. */ -pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID]) { - let replace_nodes = array_usage_analysis(editor, types); +pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID], max_size: Option<usize>) { + let replace_nodes = array_usage_analysis(editor, types, max_size); let num_nodes = editor.func().nodes.len(); + // Replace nodes for node_idx in 0..num_nodes { if !replace_nodes[node_idx] { continue; @@ -142,11 +147,15 @@ fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[I // Given the editor, while compute a mask of which nodes are to be converted // from using a constant sized array into using a product -fn array_usage_analysis(editor: &FunctionEditor, types: &[TypeID]) -> BitVec<u8, Lsb0> { +fn array_usage_analysis( + editor: &FunctionEditor, + types: &[TypeID], + max_size: Option<usize>, +) -> BitVec<u8, Lsb0> { let num_nodes = editor.func().nodes.len(); - // Step 1: identify the constant nodes that are constant sized arrays, these are what we are - // interested in converting into products + // Step 1: identify the constant nodes that are constant sized arrays no larger than the + // max_size, these are what we are interested in converting into products let sources = editor .func() .nodes @@ -163,11 +172,24 @@ fn array_usage_analysis(editor: &FunctionEditor, types: &[TypeID]) -> BitVec<u8, let Some(dims) = typ.try_extents() else { return None; }; - if dims - .iter() - .all(|dc| editor.get_dynamic_constant(*dc).is_constant()) - { - Some(NodeID::new(idx)) + // Compute the total number of elements, the result is None if some dimension is not a + // constant and otherwise is Some(num_elements) which we can then compare with max_size + if let Some(elements) = dims.iter().fold(Some(1), |prod, dc| { + prod.and_then(|prod| { + editor + .get_dynamic_constant(*dc) + .try_constant() + .map(|dim| prod * dim) + }) + }) { + if let Some(max_size) = max_size + && elements > max_size + { + // Too many elements, don't convert + None + } else { + Some(NodeID::new(idx)) + } } else { None } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 0cf9d208..e0e08e95 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -18,7 +18,7 @@ pub enum ScheduleCompilerError { UndefinedMacro(String, Location), NoSuchPass(String, Location), IncorrectArguments { - expected: usize, + expected: String, actual: usize, loc: Location, }, @@ -81,11 +81,20 @@ enum Appliable { } impl Appliable { - fn num_args(&self) -> usize { + // Tests whether a given number of arguments is a valid number of arguments for this + fn is_valid_num_args(&self, num: usize) -> bool { match self { - Appliable::Pass(pass) => pass.num_args(), + Appliable::Pass(pass) => pass.is_valid_num_args(num), // Delete uncalled, Schedules, and devices do not take arguments - _ => 0, + _ => num == 0, + } + } + + // Returns a description of the number of arguments this requires + fn valid_arg_nums(&self) -> &'static str { + match self { + Appliable::Pass(pass) => pass.valid_arg_nums(), + _ => "0", } } } @@ -309,9 +318,9 @@ fn compile_expr( .parse() .map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?; - if args.len() != func.num_args() { + if !func.is_valid_num_args(args.len()) { return Err(ScheduleCompilerError::IncorrectArguments { - expected: func.num_args(), + expected: func.valid_arg_nums().to_string(), actual: args.len(), loc: lexer.line_col(span), }); @@ -368,7 +377,7 @@ fn compile_expr( if args.len() != params.len() { return Err(ScheduleCompilerError::IncorrectArguments { - expected: params.len(), + expected: params.len().to_string(), actual: args.len(), loc: lexer.line_col(span), }); diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 26e6ddc9..480fee64 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -40,13 +40,25 @@ pub enum Pass { } impl Pass { - pub fn num_args(&self) -> usize { + pub fn is_valid_num_args(&self, num: usize) -> bool { match self { - Pass::Xdot => 1, - Pass::ForkChunk => 4, - Pass::ForkFissionBufferize => 2, - Pass::ForkInterchange => 2, - _ => 0, + Pass::ArrayToProduct => num == 0 || num == 1, + Pass::Xdot => num == 0 || num == 1, + Pass::ForkChunk => num == 4, + Pass::ForkFissionBufferize => num == 2, + Pass::ForkInterchange => num == 2, + _ => num == 0, + } + } + + pub fn valid_arg_nums(&self) -> &'static str { + match self { + Pass::ArrayToProduct => "0 or 1", + Pass::Xdot => "0 or 1", + Pass::ForkChunk => "4", + Pass::ForkFissionBufferize => "2", + Pass::ForkInterchange => "2", + _ => "0", } } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 5464125e..9aa098a0 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1522,9 +1522,17 @@ fn run_pass( pm.clear_analyses(); } Pass::ArrayToProduct => { - // TODO: We might allow an (optional) maximum size as an argument to this pass for - // nicer control for the user - assert!(args.is_empty()); + assert!(args.len() <= 1); + let max_size = match args.get(0) { + Some(Value::Integer { val }) => Some(*val), + Some(_) => { + return Err(SchedulerError::PassError { + pass: "array-to-product".to_string(), + error: "expected integer argument".to_string(), + }); + } + None => None, + }; pm.make_typing(); let typing = pm.typing.take().unwrap(); @@ -1535,7 +1543,7 @@ fn run_pass( let Some(mut func) = func else { continue; }; - array_to_product(&mut func, types); + array_to_product(&mut func, types, max_size); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab From 664daa176615ba093cd81de03113a70ab2b37fb9 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 10:36:46 -0600 Subject: [PATCH 7/7] fully unroll / predicate median window --- hercules_opt/src/lift_dc_math.rs | 7 ------- juno_samples/median_window/src/cpu.sch | 14 ++++++++------ juno_samples/median_window/src/gpu.sch | 15 +++++++++------ juno_samples/median_window/src/median.jn | 2 +- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/hercules_opt/src/lift_dc_math.rs b/hercules_opt/src/lift_dc_math.rs index 8256c889..119fbd4d 100644 --- a/hercules_opt/src/lift_dc_math.rs +++ b/hercules_opt/src/lift_dc_math.rs @@ -23,13 +23,6 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) { }; DynamicConstant::Constant(cons as usize) } - Node::DynamicConstant { id } => { - let Some(cons) = evaluate_dynamic_constant(id, &*editor.get_dynamic_constants()) - else { - continue; - }; - DynamicConstant::Constant(cons) - } Node::Binary { op, left, right } => { let (left, right) = if let ( Node::DynamicConstant { id: left }, diff --git a/juno_samples/median_window/src/cpu.sch b/juno_samples/median_window/src/cpu.sch index a3eded93..fd44ba57 100644 --- a/juno_samples/median_window/src/cpu.sch +++ b/juno_samples/median_window/src/cpu.sch @@ -15,12 +15,18 @@ gvn(*); phi-elim(*); dce(*); +forkify(*); +fork-guard-elim(*); +lift-dc-math(*); +forkify(*); +fork-guard-elim(*); +fork-unroll(out.median_window@outer); +lift-dc-math(*); fixpoint { forkify(*); fork-guard-elim(*); - unroll(*); + fork-unroll(*); } - ccp(*); gvn(*); dce(*); @@ -33,8 +39,4 @@ simplify-cfg(*); dce(*); gvn(*); -infer-schedules(*); -gcm(*); -float-collections(*); -dce(*); gcm(*); diff --git a/juno_samples/median_window/src/gpu.sch b/juno_samples/median_window/src/gpu.sch index 86e17466..d9ca4d81 100644 --- a/juno_samples/median_window/src/gpu.sch +++ b/juno_samples/median_window/src/gpu.sch @@ -15,12 +15,18 @@ gvn(*); phi-elim(*); dce(*); +forkify(*); +fork-guard-elim(*); +lift-dc-math(*); +forkify(*); +fork-guard-elim(*); +fork-unroll(out.median_window@outer); +lift-dc-math(*); fixpoint { forkify(*); fork-guard-elim(*); - unroll(*); + fork-unroll(*); } - ccp(*); gvn(*); dce(*); @@ -33,8 +39,5 @@ simplify-cfg(*); dce(*); gvn(*); -infer-schedules(*); -gcm(*); -float-collections(*); -dce(*); gcm(*); + diff --git a/juno_samples/median_window/src/median.jn b/juno_samples/median_window/src/median.jn index 38ae407b..bbb02e49 100644 --- a/juno_samples/median_window/src/median.jn +++ b/juno_samples/median_window/src/median.jn @@ -8,7 +8,7 @@ fn median_matrix<t: number, n, m: usize>(x: t[n, m]) -> t { const cnt = n * m; - for i = 0 to cnt - 1 { + @outer for i = 0 to cnt - 1 { for j = 0 to cnt - i - 1 { if tmp[j] > tmp[j + 1] { let t = tmp[j]; -- GitLab