diff --git a/Cargo.lock b/Cargo.lock index 9fb973e6f63cf70eb1ca685235a08ef25c7923f1..a84f3d2c48283e1ef8dfdf0b441e18cc8e6b39ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1252,7 +1252,7 @@ dependencies = [ ] [[package]] -name = "juno_product_read" +name = "juno_products" version = "0.1.0" dependencies = [ "async-std", diff --git a/Cargo.toml b/Cargo.toml index 3e86bad053f8847c9f10d18117635de1cfa76250..0ed8f64bbfe490464d7a201837f2f9ba133326c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,5 +33,5 @@ members = [ "juno_samples/edge_detection", "juno_samples/fork_join_tests", "juno_samples/multi_device", - "juno_samples/product_read", + "juno_samples/products", ] diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 7187508a31240071849a26b85100e89607786b2f..17c55bbe71e74bf34fb1a3cf536b4dc2a8c36e9a 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -18,6 +18,7 @@ pub mod lift_dc_math; pub mod outline; pub mod phi_elim; pub mod pred; +pub mod reuse_products; pub mod schedule; pub mod simplify_cfg; pub mod slf; @@ -43,6 +44,7 @@ pub use crate::lift_dc_math::*; pub use crate::outline::*; pub use crate::phi_elim::*; pub use crate::pred::*; +pub use crate::reuse_products::*; pub use crate::schedule::*; pub use crate::simplify_cfg::*; pub use crate::slf::*; diff --git a/hercules_opt/src/reuse_products.rs b/hercules_opt/src/reuse_products.rs new file mode 100644 index 0000000000000000000000000000000000000000..eb0b4a657ccca8a18af141f5c416416a0f1b6af3 --- /dev/null +++ b/hercules_opt/src/reuse_products.rs @@ -0,0 +1,212 @@ +use std::collections::HashMap; + +use hercules_ir::ir::*; + +use crate::*; + +/* + * Reuse Products is an optimization pass which identifies when two product + * values are identical because each field of the "source" product is read and + * then written into the "destination" product and then replaces the destination + * product by the source product. 
+ * + * This pattern can occur in our code because SROA and IP SROA are both + * aggressive about breaking products into their fields and reconstructing + * products right where needed, so if a function returns a product that is + * produced by a call node, these optimizations will produce code that reads the + * fields out of the call node and then writes them into the product that is + * returned. + * + * This optimization does not delete any nodes other than the destination nodes; + * if other nodes become dead as a result, the cleanup is left to DCE. + * + * The analysis for this starts by labeling each product source node (arguments, + * constants, and call nodes) with itself as the source of all of its fields. + * Then, these field sources are propagated along read and write nodes. At the + * end, every node with a product value is labeled by the source (node and + * index) of each of its fields. If a node's fields are exactly the fields of + * some other node (i.e. it is exactly the same value as that other node), we + * replace it with that other node. + */ +pub fn reuse_products( + editor: &mut FunctionEditor, + reverse_postorder: &Vec<NodeID>, + types: &Vec<TypeID>, +) { + let mut source_nodes = vec![]; + let mut read_write_nodes = vec![]; + + for node in reverse_postorder { + match &editor.node(node) { + Node::Parameter { .. } | Node::Constant { .. } | Node::Call { .. } + if editor.get_type(types[node.idx()]).is_product() => + { + source_nodes.push(*node) + } + Node::Write { .. } if editor.get_type(types[node.idx()]).is_product() => { + read_write_nodes.push(*node) + } + Node::Read { collect, .. 
} if editor.get_type(types[collect.idx()]).is_product() => { + read_write_nodes.push(*node) + } + _ => (), + } + } + + let mut product_nodes: HashMap<NodeID, IndexTree<(NodeID, Vec<Index>)>> = HashMap::new(); + + for source in source_nodes { + product_nodes.insert( + source, + generate_source_info(editor, source, types[source.idx()]), + ); + } + + for node in read_write_nodes { + match editor.node(node) { + Node::Read { collect, indices } => { + let Some(collect) = product_nodes.get(collect) else { + continue; + }; + let result = collect.lookup(indices); + product_nodes.insert(node, result.clone()); + } + Node::Write { + collect, + data, + indices, + } => { + let Some(collect) = product_nodes.get(collect) else { + continue; + }; + let Some(data) = product_nodes.get(data) else { + continue; + }; + let result = collect.clone().replace(indices, data.clone()); + product_nodes.insert(node, result); + } + _ => panic!("Non read/write node"), + } + } + + // Note that we don't have to worry about some node A being equivalent to node B but node B + // being equivalent to node C and being replaced first causing an issue when we try to replace + // node A with B. + // This cannot occur since the only nodes something can be equivalent with are the source nodes + // and they are all equivalent to precisely themselves which we ignore. 
+ for (node, data) in product_nodes { + let Some(replace_with) = is_other_product(editor, types, data) else { + continue; + }; + + if replace_with != node { + editor.edit(|edit| { + let edit = edit.replace_all_uses(node, replace_with)?; + edit.delete_node(node) + }); + } + } +} + +fn generate_source_info( + editor: &FunctionEditor, + source: NodeID, + typ: TypeID, +) -> IndexTree<(NodeID, Vec<Index>)> { + generate_source_info_at_index(editor, source, typ, vec![]) +} + +fn generate_source_info_at_index( + editor: &FunctionEditor, + source: NodeID, + typ: TypeID, + idx: Vec<Index>, +) -> IndexTree<(NodeID, Vec<Index>)> { + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; + + if let Some(ts) = ts { + // Recurse on each field with an extended index and appropriate type + let mut fields = vec![]; + for (i, t) in ts.into_iter().enumerate() { + let mut new_idx = idx.clone(); + new_idx.push(Index::Field(i)); + fields.push(generate_source_info_at_index(editor, source, t, new_idx)); + } + IndexTree::Node(fields) + } else { + // We've reached the leaf + IndexTree::Leaf((source, idx)) + } +} + +fn is_other_product( + editor: &FunctionEditor, + types: &Vec<TypeID>, + node: IndexTree<(NodeID, Vec<Index>)>, +) -> Option<NodeID> { + let Some(other_node) = find_only_node(&node) else { + return None; + }; + + if matches_fields_index(editor, types[other_node.idx()], &node, vec![]) { + Some(other_node) + } else { + None + } +} + +fn find_only_node(tree: &IndexTree<(NodeID, Vec<Index>)>) -> Option<NodeID> { + match tree { + IndexTree::Leaf((node, _)) => Some(*node), + IndexTree::Node(fields) => fields + .iter() + .map(|t| find_only_node(t)) + .reduce(|n, m| match (n, m) { + (Some(n), Some(m)) if n == m => Some(n), + (_, _) => None, + }) + .flatten(), + } +} + +fn matches_fields_index( + editor: &FunctionEditor, + typ: TypeID, + tree: &IndexTree<(NodeID, Vec<Index>)>, + index: Vec<Index>, +) -> bool { + match 
tree { + IndexTree::Leaf((_, idx)) => { + // If in the original value we still have a product, these can't match + if editor.get_type(typ).is_product() { + false + } else { + *idx == index + } + } + IndexTree::Node(fields) => { + let ts: Vec<TypeID> = if let Some(ts) = editor.get_type(typ).try_product() { + ts.into() + } else { + return false; + }; + + if fields.len() != ts.len() { + return false; + } + + ts.into_iter() + .zip(fields.iter()) + .enumerate() + .all(|(i, (ty, field))| { + let mut new_index = index.clone(); + new_index.push(Index::Field(i)); + matches_fields_index(editor, ty, field, new_index) + }) + } + } +} diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 3210094ded7006001e6a82ee356161221cc0b468..dbb2f8cef91ee4d67f65818128113082ed063fc1 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -700,13 +700,13 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // An index tree is used to store results at many index lists #[derive(Clone, Debug)] -enum IndexTree<T> { +pub enum IndexTree<T> { Leaf(T), Node(Vec<IndexTree<T>>), } impl<T: std::fmt::Debug> IndexTree<T> { - fn lookup(&self, idx: &[Index]) -> &IndexTree<T> { + pub fn lookup(&self, idx: &[Index]) -> &IndexTree<T> { self.lookup_idx(idx, 0) } @@ -725,7 +725,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } - fn set(self, idx: &[Index], val: T) -> IndexTree<T> { + pub fn set(self, idx: &[Index], val: T) -> IndexTree<T> { self.set_idx(idx, val, 0) } @@ -756,7 +756,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } - fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> { + pub fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> { self.replace_idx(idx, val, 0) } @@ -787,7 +787,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } - fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> { + pub fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> { match (self, other) { 
(IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)), (IndexTree::Node(t), IndexTree::Node(a)) => { @@ -801,7 +801,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } - fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> { + pub fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> { match self { IndexTree::Leaf(t) => { let mut res = vec![]; @@ -835,7 +835,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } - fn for_each<F>(&self, mut f: F) + pub fn for_each<F>(&self, mut f: F) where F: FnMut(&Vec<Index>, &T), { diff --git a/juno_samples/product_read/Cargo.toml b/juno_samples/products/Cargo.toml similarity index 88% rename from juno_samples/product_read/Cargo.toml rename to juno_samples/products/Cargo.toml index d466f5550b77e426040842339684a1e8906b22fa..34878a07f2e4fac223075536bf5559b8e1f4b132 100644 --- a/juno_samples/product_read/Cargo.toml +++ b/juno_samples/products/Cargo.toml @@ -1,11 +1,11 @@ [package] -name = "juno_product_read" +name = "juno_products" version = "0.1.0" authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] edition = "2021" [[bin]] -name = "juno_product_read" +name = "juno_products" path = "src/main.rs" [features] diff --git a/juno_samples/product_read/build.rs b/juno_samples/products/build.rs similarity index 81% rename from juno_samples/product_read/build.rs rename to juno_samples/products/build.rs index 2bd5172e661e65e2284a986e2d710cd890d71b90..6d621961581406b08cb5e85f6ff03e63df29b84a 100644 --- a/juno_samples/product_read/build.rs +++ b/juno_samples/products/build.rs @@ -4,7 +4,7 @@ fn main() { #[cfg(not(feature = "cuda"))] { JunoCompiler::new() - .file_in_src("product_read.jn") + .file_in_src("products.jn") .unwrap() .build() .unwrap(); @@ -12,7 +12,7 @@ fn main() { #[cfg(feature = "cuda")] { JunoCompiler::new() - .file_in_src("product_read.jn") + .file_in_src("products.jn") .unwrap() .schedule_in_src("gpu.sch") .unwrap() diff --git 
a/juno_samples/product_read/src/gpu.sch b/juno_samples/products/src/gpu.sch similarity index 91% rename from juno_samples/product_read/src/gpu.sch rename to juno_samples/products/src/gpu.sch index 549b421561da50023719fd432fe8d943a4ab37f6..5ef4c479550bcbe4f5d9ddeffb03efde4162cf7b 100644 --- a/juno_samples/product_read/src/gpu.sch +++ b/juno_samples/products/src/gpu.sch @@ -7,6 +7,7 @@ gpu(out.product_read); ip-sroa(*); sroa(*); +reuse-products(*); crc(*); dce(*); gvn(*); diff --git a/juno_samples/product_read/src/main.rs b/juno_samples/products/src/main.rs similarity index 63% rename from juno_samples/product_read/src/main.rs rename to juno_samples/products/src/main.rs index 5211098ceebd6d7b15871ba0dd73cdbefb993313..b8abb59d5bc3b2cbbb0ef13ec1f44c8d0734e1bb 100644 --- a/juno_samples/product_read/src/main.rs +++ b/juno_samples/products/src/main.rs @@ -2,7 +2,7 @@ use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; -juno_build::juno!("product_read"); +juno_build::juno!("products"); fn main() { async_std::task::block_on(async { @@ -11,6 +11,11 @@ fn main() { let mut r = runner!(product_read); let res : Vec<i32> = HerculesMutBox::from(r.run(input.to()).await).as_slice().to_vec(); assert_eq!(res, vec![0, 1, 2, 3]); + + // Technically this returns a product of two i32s, but we can interpret that as an array + let mut r = runner!(product_return); + let res : Vec<i32> = HerculesMutBox::from(r.run(42, 17).await).as_slice().to_vec(); + assert_eq!(res, vec![42, 17]); }); } diff --git a/juno_samples/product_read/src/product_read.jn b/juno_samples/products/src/products.jn similarity index 72% rename from juno_samples/product_read/src/product_read.jn rename to juno_samples/products/src/products.jn index 7bf74a105b32099341f299c38898f6f6c08eb467..4f56368ec35463e9d7ac545b7ed08ef288b0178b 100644 --- a/juno_samples/product_read/src/product_read.jn +++ b/juno_samples/products/src/products.jn @@ -7,3 +7,8 @@ fn product_read(input: (i32, i32)[2]) -> i32[4] { 
result[3] = input[1].1; return result; } + +#[entry] +fn product_return(x: i32, y: i32) -> (i32, i32) { + return (x, y); +} diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 6b40001c2b3324176913b1e843934d422ed2e711..7887b9b392eaacb8b11dd93bf508db600ba3d121 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -119,6 +119,7 @@ impl FromStr for Appliable { "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)), + "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)), "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)), "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs index 3f4af107c45d87e6a89b198483b25abae6156a78..0621f8deb9145614a3919003a32b023d8c70b2b7 100644 --- a/juno_scheduler/src/default.rs +++ b/juno_scheduler/src/default.rs @@ -45,6 +45,8 @@ pub fn default_schedule() -> ScheduleStmt { SROA, PhiElim, DCE, + ReuseProducts, + DCE, CCP, SimplifyCFG, DCE, @@ -88,6 +90,7 @@ pub fn default_schedule() -> ScheduleStmt { AutoOutline, InterproceduralSROA, SROA, + ReuseProducts, SimplifyCFG, InferSchedules, DCE, diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 840f25a6e9dc986ab064adecbeba822ca47016d8..5bfb4e216e3ca79403c07958dd0ae481f38e5247 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -28,6 +28,7 @@ pub enum Pass { PhiElim, Predication, ReduceSLF, + ReuseProducts, SLF, SROA, Serialize, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f59834eddad670cceeb4954812fff5926f39f862..342f875b5aa60151ce60cb6ba952a2f54696ef0d 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2011,6 +2011,27 @@ fn run_pass( changed |= func.modified(); } } + 
Pass::ReuseProducts => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + pm.make_typing(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + + for ((func, reverse_postorder), types) in build_selection(pm, selection, false) + .into_iter() + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + reuse_products(&mut func, reverse_postorder, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::SLF => { assert!(args.is_empty()); pm.make_reverse_postorders();