diff --git a/Cargo.lock b/Cargo.lock index a84f3d2c48283e1ef8dfdf0b441e18cc8e6b39ee..4431cf5d3eeac7bcb576e90de5f4dae2efabfd20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1231,6 +1231,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_median_window" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_multi_device" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 0ed8f64bbfe490464d7a201837f2f9ba133326c3..6514046bcada4372c94dfd73649d7d2329e28f36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,5 @@ members = [ "juno_samples/fork_join_tests", "juno_samples/multi_device", "juno_samples/products", + "juno_samples/median_window", ] diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 9c6c5f174a4c3b1c4a19ce15597e5385ddae7aa4..0e0840853bdcb83caa954991f40397b33e1fd5eb 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -24,8 +24,8 @@ pub fn xdot_module( bbs: Option<&Vec<BasicBlocks>>, ) { let mut tmp_path = temp_dir(); - let mut rng = rand::thread_rng(); - let num: u64 = rng.gen(); + let mut rng = rand::rng(); + let num: u64 = rng.random(); tmp_path.push(format!("hercules_dot_{}.dot", num)); let mut file = File::create(&tmp_path).expect("PANIC: Unable to open output file."); let mut contents = String::new(); @@ -43,11 +43,11 @@ pub fn xdot_module( .expect("PANIC: Unable to generate output file contents."); file.write_all(contents.as_bytes()) .expect("PANIC: Unable to write output file contents."); + println!("Graphviz written to: {}", tmp_path.display()); Command::new("xdot") .args([&tmp_path]) .output() .expect("PANIC: Couldn't execute xdot. 
Is xdot installed?"); - println!("Graphviz written to: {}", tmp_path.display()); } /* @@ -108,6 +108,13 @@ pub fn write_dot<W: Write>( for node_id in (0..function.nodes.len()).map(NodeID::new) { let node = &function.nodes[node_id.idx()]; + let skip = node.is_constant() + || node.is_dynamic_constant() + || node.is_undef() + || node.is_parameter(); + if skip { + continue; + } let dst_control = node.is_control(); for u in def_use::get_uses(&node).as_ref() { let src_control = function.nodes[u.idx()].is_control(); diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 3fcc6af029a50839c0b382be371db6fa593e1119..c584a3fd01da993f95ecc07f8e4a251834053faf 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -86,7 +86,7 @@ pub fn compute_fork_join_nesting( pub fn reduce_cycles( function: &Function, fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { let reduces = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) @@ -111,7 +111,7 @@ pub fn reduce_cycles( reduce, &mut current_visited, &mut in_cycle, - fork_join_nest, + nodes_in_fork_joins, ); result.insert(reduce, in_cycle); } @@ -126,17 +126,8 @@ fn reduce_cycle_dfs_helper( reduce: NodeID, current_visited: &mut HashSet<NodeID>, in_cycle: &mut HashSet<NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> bool { - let isnt_outside_fork_join = |id: NodeID| { - let node = &function.nodes[id.idx()]; - node.try_phi() - .map(|(control, _)| control) - .or(node.try_reduce().map(|(control, _, _)| control)) - .map(|control| fork_join_nest[&control].contains(&fork)) - .unwrap_or(true) - }; - if iter == reduce || in_cycle.contains(&iter) { return true; } @@ -148,7 +139,7 @@ fn reduce_cycle_dfs_helper( for u in 
get_uses(&function.nodes[iter.idx()]).as_ref() { found_reduce |= !current_visited.contains(u) && !function.nodes[u.idx()].is_control() - && isnt_outside_fork_join(*u) + && nodes_in_fork_joins[&fork].contains(u) && reduce_cycle_dfs_helper( function, *u, @@ -156,7 +147,7 @@ fn reduce_cycle_dfs_helper( reduce, current_visited, in_cycle, - fork_join_nest, + nodes_in_fork_joins, ) } if found_reduce { diff --git a/hercules_opt/src/array_to_prod.rs b/hercules_opt/src/array_to_prod.rs new file mode 100644 index 0000000000000000000000000000000000000000..7b7745c0db3a017410551589993414bb2dab8031 --- /dev/null +++ b/hercules_opt/src/array_to_prod.rs @@ -0,0 +1,352 @@ +use hercules_ir::define_id_type; +use hercules_ir::ir::*; + +use bitvec::prelude::*; + +use crate::*; + +use std::collections::{HashMap, HashSet}; +use std::marker::PhantomData; + +/* + * Top level function for array to product which will convert constant + * sized arrays into products if the array is only accessed at indices which + * are constants. 
+ * + * To identify the collections we can convert we look at each constant-sized + * array constant and compute the set which includes the constant node and is + * closed under the following properties: + * - For each collection in the set, its uses are in the set + * - For each node that uses a collection, all collections it uses are in the + * set + * From this set, we then determine whether this whole set can be converted to + * operating on products, rather than arrays, as follows + * - Each read and write node must be to a constant index + * - It may not contain any arguments (we could generate code to read an + * array argument into a product, but do not do so for now) + * - There are no call or return nodes in the set (this would mean that the + * collections are consumed by a call or return, again we could reconstruct + * the array where needed but do not do so for now and so have this + * restriction) + * - All nodes in the set are editable (if we cannot modify some node then the + * conversion will fail) + * + * The max_size argument allows the user to specify a limit on the size of arrays + * that should be converted to products. If the number of elements in the array + * is larger than the max size the array will not be converted. + */ +pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID], max_size: Option<usize>) { + let replace_nodes = array_usage_analysis(editor, types, max_size); + let num_nodes = editor.func().nodes.len(); + + // Replace nodes + for node_idx in 0..num_nodes { + if !replace_nodes[node_idx] { + continue; + } + let node = NodeID::new(node_idx); + + // We can replace the array(s) this node uses with a product. What we have to do depends on + // the type of the node + match &editor.func().nodes[node_idx] { + // Phi, Reduce, and Ternary just use the whole collection, they do not need to change, + // except as they will be modified by replace_all_uses_of + Node::Phi { .. } + | Node::Reduce { .. 
} + | Node::Ternary { + op: TernaryOperator::Select, + .. + } => {} + Node::Constant { id } => { + assert!(editor.get_constant(*id).is_array()); + let element: TypeID = editor.get_type(types[node_idx]).try_element_type().unwrap(); + let dims: Vec<usize> = editor + .get_type(types[node_idx]) + .try_extents() + .unwrap() + .iter() + .map(|dc| editor.get_dynamic_constant(*dc).try_constant().unwrap()) + .collect(); + // Replace the constant by a product that is a product (for each dimension) and the + // elements are zero'd + editor.edit(|mut edit| { + let element_zero = edit.add_zero_constant(element); + let (constant, _) = dims.into_iter().rfold( + (element_zero, element), + |(cur_const, cur_type), dim| { + let new_type = edit.add_type(Type::Product(vec![cur_type; dim].into())); + let new_const = edit.add_constant(Constant::Product( + new_type, + vec![cur_const; dim].into(), + )); + (new_const, new_type) + }, + ); + let new_val = edit.add_node(Node::Constant { id: constant }); + let edit = edit.replace_all_uses(node, new_val)?; + edit.delete_node(node) + }); + } + Node::Read { collect, indices } => { + let collect = *collect; + let new_indices = convert_indices_to_prod(editor, indices); + editor.edit(|mut edit| { + let new_val = edit.add_node(Node::Read { + collect, + indices: new_indices, + }); + let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?; + edit.delete_node(node) + }); + } + Node::Write { + collect, + data, + indices, + } => { + let collect = *collect; + let data = *data; + let new_indices = convert_indices_to_prod(editor, indices); + editor.edit(|mut edit| { + let new_val = edit.add_node(Node::Write { + collect, + data, + indices: new_indices, + }); + let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?; + edit.delete_node(node) + }); + } + node => panic!("Node cannot be replaced: {:?}", node), + } + } +} + +fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[Index]> { + let mut result = vec![]; + + 
for index in indices { + match index { + Index::Position(positions) => { + for pos in positions { + let const_id = editor.func().nodes[pos.idx()] + .try_constant() + .expect("Array position must be constant"); + match *editor.get_constant(const_id) { + Constant::UnsignedInteger64(idx) => result.push(Index::Field(idx as usize)), + ref val => panic!("Position should be u64 constant: {:?}", val), + } + } + } + index => panic!("Index cannot be replaced: {:?}", index), + } + } + + result.into() +} + +// Given the editor, while compute a mask of which nodes are to be converted +// from using a constant sized array into using a product +fn array_usage_analysis( + editor: &FunctionEditor, + types: &[TypeID], + max_size: Option<usize>, +) -> BitVec<u8, Lsb0> { + let num_nodes = editor.func().nodes.len(); + + // Step 1: identify the constant nodes that are constant sized arrays no larger than the + // max_size, these are what we are interested in converting into products + let sources = editor + .func() + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| { + let Node::Constant { id } = node else { + return None; + }; + let Constant::Array(array_type) = *editor.get_constant(*id) else { + return None; + }; + let typ = editor.get_type(array_type); + let Some(dims) = typ.try_extents() else { + return None; + }; + // Compute the total number of elements, the result is None if some dimension is not a + // constant and otherwise is Some(num_elements) which we can then compare with max_size + if let Some(elements) = dims.iter().fold(Some(1), |prod, dc| { + prod.and_then(|prod| { + editor + .get_dynamic_constant(*dc) + .try_constant() + .map(|dim| prod * dim) + }) + }) { + if let Some(max_size) = max_size + && elements > max_size + { + // Too many elements, don't convert + None + } else { + Some(NodeID::new(idx)) + } + } else { + None + } + }) + .collect::<Vec<_>>(); + + // Step 2: collect the collection information we need for the (whole) function. 
For each node + // that returns a collection (that in reference semantics returns the same reference as some of + // its inputs) union with all of its users. The nodes that matter in this are arguments, + // constants, writes, phis, selects, and reduces with array types. + let mut analysis = UnionFind::new(); + for node_idx in 0..num_nodes { + let node_id = NodeID::new(node_idx); + if editor.get_type(types[node_idx]).is_array() { + match editor.func().nodes[node_idx] { + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Ternary { + op: TernaryOperator::Select, + .. + } + | Node::Write { .. } => { + for user in editor.get_users(node_id) { + analysis.union(node_id, user); + } + } + _ => {} + } + } + } + + let sets = analysis.sets(&sources); + + // Step 3: determine which sets can be converted and mark the nodes in those sets + let mut result = bitvec![u8, Lsb0; 0; num_nodes]; + + for nodes in sets { + if nodes + .iter() + .all(|node_id| editor.is_mutable(*node_id) && can_replace(editor, *node_id)) + { + for node_id in nodes { + result.set(node_id.idx(), true); + } + } + } + + result +} + +fn can_replace(editor: &FunctionEditor, node: NodeID) -> bool { + match &editor.func().nodes[node.idx()] { + // Reads and writes must be at constant indices + Node::Read { indices, .. } | Node::Write { indices, .. } => { + indices.iter().all(|idx| match idx { + Index::Position(pos) => pos + .iter() + .all(|node| editor.func().nodes[node.idx()].is_constant()), + _ => false, + }) + } + // phi, reduce, constants, and select can always be replaced if their users and uses allow + // it, which is handled by the construction of the set + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Constant { .. } + | Node::Ternary { + op: TernaryOperator::Select, + .. 
+ } => true, + // No other nodes allow replacement + _ => false, + } +} + +define_id_type!(SetID); + +#[derive(Clone, Debug)] +struct UnionFindNode { + parent: SetID, + rank: usize, +} + +#[derive(Clone, Debug)] +struct UnionFind<T> { + sets: Vec<UnionFindNode>, + _phantom: PhantomData<T>, +} + +impl<T: ID> UnionFind<T> { + pub fn new() -> Self { + UnionFind { + sets: vec![], + _phantom: PhantomData, + } + } + + fn extend_past(&mut self, size: usize) { + for i in self.sets.len()..=size { + // The new nodes we add are in their own sets and have rank 0 + self.sets.push(UnionFindNode { + parent: SetID::new(i), + rank: 0, + }); + } + } + + pub fn find(&mut self, x: T) -> SetID { + self.extend_past(x.idx()); + self.find_set(x.idx()) + } + + fn find_set(&mut self, x: usize) -> SetID { + let mut parent = self.sets[x].parent; + if parent.idx() != x { + parent = self.find_set(parent.idx()); + self.sets[x].parent = parent; + } + parent + } + + pub fn union(&mut self, x: T, y: T) { + let x = self.find(x); + let y = self.find(y); + self.link(x, y); + } + + fn link(&mut self, x: SetID, y: SetID) { + if self.sets[x.idx()].rank > self.sets[y.idx()].rank { + self.sets[y.idx()].parent = x; + } else { + self.sets[x.idx()].parent = y; + if self.sets[x.idx()].rank == self.sets[y.idx()].rank { + self.sets[y.idx()].rank += 1; + } + } + } + + pub fn sets(&mut self, keys: &[T]) -> Vec<Vec<T>> { + let key_index = keys + .iter() + .enumerate() + .map(|(i, k)| (self.find(*k), i)) + .collect::<HashMap<SetID, usize>>(); + let mut result = vec![vec![]; keys.len()]; + + let num_elements = self.sets.len(); + for i in 0..num_elements { + let set = self.find_set(i); + let Some(idx) = key_index.get(&set) else { + continue; + }; + result[*idx].push(T::new(i)); + } + + result + } +} diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 2adfddd8d643386adea67a83957e87f7184b2c77..774220df1491dc2e9c81bb0d51f4fad5203a6e24 100644 --- a/hercules_opt/src/forkify.rs +++ 
b/hercules_opt/src/forkify.rs @@ -30,7 +30,7 @@ pub fn forkify( for l in natural_loops { // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses. - if forkify_loop( + if editor.is_mutable(l.0) && forkify_loop( editor, control_subgraph, fork_join_map, @@ -166,6 +166,7 @@ pub fn forkify_loop( return false; } + // Get all phis used outside of the loop, they need to be reductionable. // For now just assume all phis will be phis used outside of the loop, except for the canonical iv. // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one @@ -182,7 +183,6 @@ pub fn forkify_loop( let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes) .into_iter() .collect(); - // TODO: Handle multiple loop body lasts. // If there are multiple candidates for loop body last, return false. if editor @@ -327,7 +327,7 @@ pub fn forkify_loop( .collect(); // Start failable edit: - editor.edit(|mut edit| { + let result = editor.edit(|mut edit| { let thread_id = Node::ThreadID { control: fork_id, dimension: dimension, @@ -405,7 +405,7 @@ pub fn forkify_loop( Ok(edit) }); - return true; + return result; } nest! { @@ -457,7 +457,7 @@ pub fn analyze_phis<'a>( // External Phi if let Node::Phi { control, data: _ } = data { - if *control != natural_loop.header { + if !natural_loop.control[control.idx()] { return true; } } @@ -539,16 +539,7 @@ pub fn analyze_phis<'a>( // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // by the time the reduce is triggered (at the end of the loop's internal control). - // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control. - // Which is not allowed. 
- if intersection - .iter() - .any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) - || editor.node(loop_continue_latch).is_phi() - { - return LoopPHI::ControlDependant(*phi); - } - + // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce. if intersection diff --git a/hercules_opt/src/ivar.rs b/hercules_opt/src/ivar.rs index f7252d29b66f9fc1882849206bbbf5b327a0f307..edadd71722698158c77fa1efaf81e48f3c7afcc9 100644 --- a/hercules_opt/src/ivar.rs +++ b/hercules_opt/src/ivar.rs @@ -1,3 +1,4 @@ +use core::panic; use std::collections::HashSet; use bitvec::prelude::*; @@ -73,8 +74,13 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has // External Phi if let Node::Phi { control, data: _ } = data { - if !natural_loop.control[control.idx()] { - return true; + match natural_loop.control.get(control.idx()) { + Some(v) => if !*v { + return true; + }, + None => { + panic!("unexpceted index: {:?} for loop {:?}", control, natural_loop.header); + }, } } // External Reduce @@ -84,14 +90,26 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has reduct: _, } = data { - if !natural_loop.control[control.idx()] { - return true; + match natural_loop.control.get(control.idx()) { + Some(v) => if !*v { + return true; + }, + None => { + panic!("unexpceted index: {:?} for loop {:?}", control, natural_loop.header); + }, } } // External Control - if data.is_control() && !natural_loop.control[node.idx()] { - return true; + if data.is_control() { + match natural_loop.control.get(node.idx()) { + Some(v) => if !*v { + return true; + }, + None => { + panic!("unexpceted index: {:?} for loop {:?}", node, natural_loop.header); + }, + } } return false; diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 
17c55bbe71e74bf34fb1a3cf536b4dc2a8c36e9a..a349230e0add49e07b6a17dad05a358b17b3a8fa 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -1,5 +1,6 @@ #![feature(let_chains)] +pub mod array_to_prod; pub mod ccp; pub mod crc; pub mod dce; @@ -26,6 +27,7 @@ pub mod sroa; pub mod unforkify; pub mod utils; +pub use crate::array_to_prod::*; pub use crate::ccp::*; pub use crate::crc::*; pub use crate::dce::*; diff --git a/hercules_opt/src/lift_dc_math.rs b/hercules_opt/src/lift_dc_math.rs index 8256c889085a9b2902c6d4d5c8fd5a9fa2e77429..119fbd4d890a2af7a22b7526ae2a188b91d8fc3a 100644 --- a/hercules_opt/src/lift_dc_math.rs +++ b/hercules_opt/src/lift_dc_math.rs @@ -23,13 +23,6 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) { }; DynamicConstant::Constant(cons as usize) } - Node::DynamicConstant { id } => { - let Some(cons) = evaluate_dynamic_constant(id, &*editor.get_dynamic_constants()) - else { - continue; - }; - DynamicConstant::Constant(cons) - } Node::Binary { op, left, right } => { let (left, right) = if let ( Node::DynamicConstant { id: left }, diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 366792c3cfcb9d20b35da0760ea5e408aa26e0c9..f95f1ebc89dfaa4a798f561899963b734c766b90 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -1,22 +1,24 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a { - const n : usize = rows * cols; - - @tmp let tmp : a[rows * cols]; - for i = 0 to rows * cols { - tmp[i] = m[i / cols, i % cols]; - } - - for i = 0 to n - 1 { - for j = 0 to n - i - 1 { - if tmp[j] > tmp[j+1] { - let t : a = tmp[j]; - tmp[j] = tmp[j+1]; - tmp[j+1] = t; + @median { + const n : usize = rows * cols; + + @tmp let tmp : a[rows * cols]; + for i = 0 to rows * cols { + tmp[i] = m[i / cols, i % cols]; + } + + @medianOuter for i = 0 to n - 1 { + for j = 0 to n - i - 1 { + if tmp[j] > tmp[j+1] { + let t : a = tmp[j]; + tmp[j] = tmp[j+1]; + tmp[j+1] = t; + } } } + + return 
tmp[n / 2]; } - - return tmp[n / 2]; } const CHAN : u64 = 3; @@ -27,7 +29,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, for chan = 0 to CHAN { for r = 0 to row { for c = 0 to col { - res[chan, r, c] = input[chan, r, c] as f32 / 255; + res[chan, r, c] = input[chan, r, c] as f32 * (1.0 / 255.0); } } } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index d51b8b943111950e73dc1f9dd5ab126b33ae6b14..2ef7152f6746ac9cac7a7e8033089049b748a4df 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -51,8 +51,32 @@ fixpoint { fork-coalesce(fuse2); } simpl!(fuse2); -array-slf(fuse2); +predication(fuse2); simpl!(fuse2); + +let median = outline(fuse2@median); +fork-unroll(median@medianOuter); +simpl!(median); +fixpoint { + forkify(median); + fork-guard-elim(median); +} +simpl!(median); +fixpoint { + fork-unroll(median); +} +ccp(median); +array-to-product(median); +sroa(median); +predication(median); +simpl!(median); + +inline(fuse2); +ip-sroa(*); +sroa(*); +array-slf(fuse2); +xdot[true](fuse2); +fork-split(fuse2); unforkify(fuse2); no-memset(fuse3@res); @@ -91,4 +115,4 @@ unforkify(fuse5); simpl!(*); delete-uncalled(*); -gcm(*); \ No newline at end of file +gcm(*); diff --git a/juno_samples/median_window/Cargo.toml b/juno_samples/median_window/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..1e376d5d23f6ba7c1fd41f169f3f955a36c0a7d4 --- /dev/null +++ b/juno_samples/median_window/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "juno_median_window" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_median_window" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = 
"0.1.0" +async-std = "*" diff --git a/juno_samples/median_window/build.rs b/juno_samples/median_window/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..a6c29d5b2184490b673dc381532562ba50d169ef --- /dev/null +++ b/juno_samples/median_window/build.rs @@ -0,0 +1,11 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("median.jn") + .unwrap() + .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/median_window/src/cpu.sch b/juno_samples/median_window/src/cpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..fd44ba572e2d906209b85fe63ed9f819ed973757 --- /dev/null +++ b/juno_samples/median_window/src/cpu.sch @@ -0,0 +1,42 @@ +gvn(*); +phi-elim(*); +dce(*); + +inline(*); +delete-uncalled(*); + +let out = auto-outline(*); +cpu(out.median_window); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +forkify(*); +fork-guard-elim(*); +lift-dc-math(*); +forkify(*); +fork-guard-elim(*); +fork-unroll(out.median_window@outer); +lift-dc-math(*); +fixpoint { + forkify(*); + fork-guard-elim(*); + fork-unroll(*); +} +ccp(*); +gvn(*); +dce(*); + +array-to-product(*); +sroa(*); +phi-elim(*); +predication(*); +simplify-cfg(*); +dce(*); +gvn(*); + +gcm(*); diff --git a/juno_samples/median_window/src/gpu.sch b/juno_samples/median_window/src/gpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..d9ca4d81a571236a828c904119f23273ef9e0b44 --- /dev/null +++ b/juno_samples/median_window/src/gpu.sch @@ -0,0 +1,43 @@ +gvn(*); +phi-elim(*); +dce(*); + +inline(*); +delete-uncalled(*); + +let out = auto-outline(*); +gpu(out.median_window); + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +forkify(*); +fork-guard-elim(*); +lift-dc-math(*); +forkify(*); +fork-guard-elim(*); +fork-unroll(out.median_window@outer); +lift-dc-math(*); +fixpoint { + forkify(*); + 
fork-guard-elim(*); + fork-unroll(*); +} +ccp(*); +gvn(*); +dce(*); + +array-to-product(*); +sroa(*); +phi-elim(*); +predication(*); +simplify-cfg(*); +dce(*); +gvn(*); + +gcm(*); + diff --git a/juno_samples/median_window/src/main.rs b/juno_samples/median_window/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..c515ac4b57f049d27bb800c9eb7dc1ac9463589c --- /dev/null +++ b/juno_samples/median_window/src/main.rs @@ -0,0 +1,26 @@ +#![feature(concat_idents)] + +juno_build::juno!("median"); + +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; + +fn main() { + let m = vec![86, 72, 14, 5, 55, + 25, 98, 89, 3, 66, + 44, 81, 27, 3, 40, + 18, 4, 57, 93, 34, + 70, 50, 50, 18, 34]; + let m = HerculesImmBox::from(m.as_slice()); + + let mut r = runner!(median_window); + let res = + async_std::task::block_on(async { + r.run(m.to()).await + }); + assert_eq!(res, 57); +} + +#[test] +fn test_median_window() { + main() +} diff --git a/juno_samples/median_window/src/median.jn b/juno_samples/median_window/src/median.jn new file mode 100644 index 0000000000000000000000000000000000000000..bbb02e4939fb737842e8d202c373e23b8e598097 --- /dev/null +++ b/juno_samples/median_window/src/median.jn @@ -0,0 +1,35 @@ +fn median_matrix<t: number, n, m: usize>(x: t[n, m]) -> t { + let tmp : t[n * m]; + for i = 0 to n { + for j = 0 to m { + tmp[i * m + j] = x[i, j]; + } + } + + const cnt = n * m; + + @outer for i = 0 to cnt - 1 { + for j = 0 to cnt - i - 1 { + if tmp[j] > tmp[j + 1] { + let t = tmp[j]; + tmp[j] = tmp[j + 1]; + tmp[j + 1] = t; + } + } + } + + return tmp[cnt / 2]; +} + +#[entry] +fn median_window(x: i32[5, 5]) -> i32 { + let window: i32[3, 3]; + + for i = 0 to 3 { + for j = 0 to 3 { + window[i, j] = x[i + 1, j + 1]; + } + } + + return median_matrix::<_, 3, 3>(window); +} diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 4a8edc50a620b426d4b7eacf91653c476354e9c1..b0a85f593ca8721bd2fa541ec17818a27888d616 100644 --- 
a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -18,7 +18,7 @@ pub enum ScheduleCompilerError { UndefinedMacro(String, Location), NoSuchPass(String, Location), IncorrectArguments { - expected: usize, + expected: String, actual: usize, loc: Location, }, @@ -81,11 +81,20 @@ enum Appliable { } impl Appliable { - fn num_args(&self) -> usize { + // Tests whether a given number of arguments is a valid number of arguments for this + fn is_valid_num_args(&self, num: usize) -> bool { match self { - Appliable::Pass(pass) => pass.num_args(), + Appliable::Pass(pass) => pass.is_valid_num_args(num), // Delete uncalled, Schedules, and devices do not take arguments - _ => 0, + _ => num == 0, + } + } + + // Returns a description of the number of arguments this requires + fn valid_arg_nums(&self) -> &'static str { + match self { + Appliable::Pass(pass) => pass.valid_arg_nums(), + _ => "0", } } } @@ -96,6 +105,9 @@ impl FromStr for Appliable { fn from_str(s: &str) -> Result<Self, Self::Err> { match s { "array-slf" => Ok(Appliable::Pass(ir::Pass::ArraySLF)), + "array-to-product" | "array-to-prod" | "a2p" => { + Ok(Appliable::Pass(ir::Pass::ArrayToProduct)) + } "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)), "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)), "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)), @@ -308,9 +320,9 @@ fn compile_expr( .parse() .map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?; - if args.len() != func.num_args() { + if !func.is_valid_num_args(args.len()) { return Err(ScheduleCompilerError::IncorrectArguments { - expected: func.num_args(), + expected: func.valid_arg_nums().to_string(), actual: args.len(), loc: lexer.line_col(span), }); @@ -367,7 +379,7 @@ fn compile_expr( if args.len() != params.len() { return Err(ScheduleCompilerError::IncorrectArguments { - expected: params.len(), + expected: params.len().to_string(), actual: args.len(), loc: lexer.line_col(span), }); diff --git 
a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 27e63aa0d104a6a23db1d9e248f0fbef82749b14..b490ef71b595d6a612797dcda007361f3d5dc608 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -3,6 +3,7 @@ use hercules_ir::ir::{Device, Schedule}; #[derive(Debug, Copy, Clone)] pub enum Pass { ArraySLF, + ArrayToProduct, AutoOutline, CCP, CRC, @@ -40,12 +41,27 @@ pub enum Pass { } impl Pass { - pub fn num_args(&self) -> usize { + pub fn is_valid_num_args(&self, num: usize) -> bool { match self { - Pass::Xdot | Pass::Print => 1, - Pass::ForkFissionBufferize | Pass::ForkInterchange => 2, - Pass::ForkChunk => 4, - _ => 0, + Pass::ArrayToProduct => num == 0 || num == 1, + Pass::ForkChunk => num == 4, + Pass::ForkFissionBufferize => num == 2, + Pass::ForkInterchange => num == 2, + Pass::Print => num == 1, + Pass::Xdot => num == 0 || num == 1, + _ => num == 0, + } + } + + pub fn valid_arg_nums(&self) -> &'static str { + match self { + Pass::ArrayToProduct => "0 or 1", + Pass::ForkChunk => "4", + Pass::ForkFissionBufferize => "2", + Pass::ForkInterchange => "2", + Pass::Print => "1", + Pass::Xdot => "0 or 1", + _ => "0", } } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2434920751ef1b82f927929e6e648051e7e206d8..b1d4c93feb54facd63eba6a7825dbab60c405fb5 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -390,16 +390,16 @@ impl PassManager { pub fn make_reduce_cycles(&mut self) { if self.reduce_cycles.is_none() { self.make_fork_join_maps(); - self.make_fork_join_nests(); + self.make_nodes_in_fork_joins(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); + let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter(); self.reduce_cycles = Some( self.functions .iter() .zip(fork_join_maps) - .zip(fork_join_nests) - .map(|((function, fork_join_map), fork_join_nest)| { - reduce_cycles(function, fork_join_map, 
fork_join_nest) + .zip(nodes_in_fork_joins) + .map(|((function, fork_join_map), nodes_in_fork_joins)| { + reduce_cycles(function, fork_join_map, nodes_in_fork_joins) }) .collect(), ); @@ -1521,6 +1521,34 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ArrayToProduct => { + assert!(args.len() <= 1); + let max_size = match args.get(0) { + Some(Value::Integer { val }) => Some(*val), + Some(_) => { + return Err(SchedulerError::PassError { + pass: "array-to-product".to_string(), + error: "expected integer argument".to_string(), + }); + } + None => None, + }; + pm.make_typing(); + let typing = pm.typing.take().unwrap(); + + for (func, types) in build_selection(pm, selection, false) + .into_iter() + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + array_to_product(&mut func, types, max_size); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::AutoOutline => { let Some(funcs) = selection_of_functions(pm, selection) else { return Err(SchedulerError::PassError { @@ -1859,9 +1887,9 @@ fn run_pass( let Some(mut func) = func else { continue; }; - forkify(&mut func, control_subgraph, fork_join_map, loop_nest); - changed |= func.modified(); - inner_changed |= func.modified(); + let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest); + changed |= c; + inner_changed |= c; } pm.delete_gravestones(); pm.clear_analyses(); @@ -2306,7 +2334,13 @@ fn run_pass( let Some(mut func) = func else { continue; }; - chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size, *tile_order); + chunk_all_forks_unguarded( + &mut func, + fork_join_map, + *dim_idx, + *tile_size, + *tile_order, + ); changed |= func.modified(); } pm.delete_gravestones();