diff --git a/Cargo.lock b/Cargo.lock index 8bb64bd279179ee50d9e10e8285ea815bd8a99a7..623fc35c9260676fc9b683bd63e96ac7cbc31a2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,6 +259,12 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "bimap" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" + [[package]] name = "bincode" version = "1.3.3" @@ -282,9 +288,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "bitstream-io" @@ -366,9 +372,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.9" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "jobserver", "libc", @@ -418,9 +424,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -428,9 +434,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ 
-525,9 +531,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "deranged" @@ -538,6 +544,26 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "dot" version = "0.1.0" @@ -819,6 +845,23 @@ dependencies = [ "serde", ] +[[package]] +name = "hercules_interpreter" +version = "0.1.0" +dependencies = [ + "bitvec", + "clap", + "derive_more", + "hercules_ir", + "hercules_opt", + "itertools 0.14.0", + "juno_scheduler", + "ordered-float", + "postcard", + "rand", + "serde", +] + [[package]] name = "hercules_ir" version = "0.1.0" @@ -834,14 +877,17 @@ dependencies = [ name = "hercules_opt" version = "0.1.0" dependencies = [ + "bimap", "bitvec", "either", "hercules_cg", "hercules_ir", "itertools 0.14.0", + "nestify", "ordered-float", "postcard", "serde", + "slotmap", "take_mut", "tempfile", "union-find", @@ -851,6 +897,21 @@ dependencies = [ name = "hercules_rt" version = "0.1.0" +[[package]] +name = "hercules_tests" +version = "0.1.0" +dependencies = [ + "bitvec", + "clap", + "hercules_interpreter", + "hercules_ir", + "hercules_opt", + "itertools 0.14.0", + "juno_scheduler", + "ordered-float", + "rand", +] + [[package]] name = 
"hermit-abi" version = "0.4.0" @@ -898,9 +959,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown", @@ -1094,6 +1155,8 @@ dependencies = [ "juno_utils", "lrlex", "lrpar", + "postcard", + "serde", "tempfile", ] @@ -1157,7 +1220,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "libc", "redox_syscall", ] @@ -1180,9 +1243,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" dependencies = [ "value-bag", ] @@ -1286,14 +1349,26 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", "simd-adler32", ] +[[package]] +name = "nestify" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d7249f7122d4e8a40f3b1b1850b763d2f864bf8e4b712427f024f8a167ea17" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -1547,6 
+1622,30 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.93" @@ -1713,7 +1812,7 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", ] [[package]] @@ -1762,11 +1861,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "errno", "libc", "linux-raw-sys", @@ -1787,9 +1886,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" @@ -1856,6 +1955,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slotmap" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" +dependencies = [ + "version_check", +] + [[package]] name = "smallvec" version = "1.13.2" @@ -2084,9 +2192,9 @@ checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "11cd88e12b17c6494200a9c1b683a04fcac9573ed74cd1b62aeb2727c5592243" [[package]] name = "unicode-width" @@ -2140,6 +2248,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "vob" version = "3.0.3" diff --git a/Cargo.toml b/Cargo.toml index 6adefdba2d9d684698abfc6a58c98d186b81422d..d31c59f7914683e43cceb4455928f8e52b23621d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,8 @@ members = [ "juno_scheduler", "juno_build", - #"hercules_test/hercules_interpreter", - #"hercules_test/hercules_tests", + "hercules_test/hercules_interpreter", + "hercules_test/hercules_tests", "hercules_samples/dot", "hercules_samples/matmul", @@ -26,7 +26,7 @@ members = [ "juno_samples/nested_ccp", "juno_samples/antideps", "juno_samples/implicit_clone", - "juno_samples/cava", + "juno_samples/cava", "juno_samples/concat", - "juno_samples/schedule_test", + "juno_samples/schedule_test", ] diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 3750c4f6abbac3a774269c729eaded8afcc204c3..5374385288c4aa4940eafd293c58b0beeabbc5e3 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -615,14 +615,14 @@ impl<'a> CPUContext<'a> { )?, DynamicConstant::Min(left, right) 
=> write!( body, - " %dc{} = call @llvm.umin.i64(i64%dc{},i64%dc{})\n", + " %dc{} = call i64 @llvm.umin.i64(i64%dc{},i64%dc{})\n", dc.idx(), left.idx(), right.idx() )?, DynamicConstant::Max(left, right) => write!( body, - " %dc{} = call @llvm.umax.i64(i64%dc{},i64%dc{})\n", + " %dc{} = call i64 @llvm.umax.i64(i64%dc{},i64%dc{})\n", dc.idx(), left.idx(), right.idx() diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 0b271e857df2269cbdfc51e223424141435558c6..bf7806dcc371112419bbf7e21ef3b206df3a2e32 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1356,6 +1356,36 @@ impl Node { } } + pub fn is_zero_dc(&self, dynamic_constants: &Vec<DynamicConstant>) -> bool { + if let Node::DynamicConstant { id } = self + && dynamic_constants[id.idx()].try_constant() == Some(0) + { + true + } else { + false + } + } + + pub fn is_one_dc(&self, dynamic_constants: &Vec<DynamicConstant>) -> bool { + if let Node::DynamicConstant { id } = self + && dynamic_constants[id.idx()].try_constant() == Some(1) + { + true + } else { + false + } + } + + pub fn is_one_constant(&self, constants: &Vec<Constant>) -> bool { + if let Node::Constant { id } = self + && constants[id.idx()].is_one() + { + true + } else { + false + } + } + pub fn try_projection(&self, branch: usize) -> Option<NodeID> { if let Node::Projection { control, selection } = self && branch == *selection diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index 5ee5f1d26850c32e7bd292602185dbb8327167ce..f188932e3a362760cc8855b43c9fa9fea21cbe42 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -129,6 +129,7 @@ fn verify_structure( match function.nodes[user.idx()] { Node::Parameter { index: _ } | Node::Constant { id: _ } + | Node::Undef { ty: _ } | Node::DynamicConstant { id: _ } => {} _ => { if function.nodes[user.idx()].is_control() { @@ -300,7 +301,7 @@ fn verify_structure( Err("Call node's control input must be a region node.")?; } } - // Collect nodes must depend on 
a join node. + // Reduce nodes must depend on a join node. Node::Reduce { control, init: _, @@ -308,7 +309,7 @@ fn verify_structure( } => { if let Node::Join { control: _ } = function.nodes[control.idx()] { } else { - Err("Collect node's control input must be a join node.")?; + Err("Reduce node's control input must be a join node.")?; } } // Return nodes must have no users. @@ -498,8 +499,8 @@ fn verify_dominance_relationships( // Every use of a thread ID must be postdominated by // the thread ID's fork's corresponding join node. We // don't need to check for the case where the thread ID - // flows through the collect node out of the fork-join, - // because after the collect, the thread ID is no longer + // flows through the reduce node out of the fork-join, + // because after the reduce, the thread ID is no longer // considered an immediate control output use. if postdom.contains(this_id) && !postdom.does_dom(*fork_join_map.get(&control).unwrap(), this_id) diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 9f22884dc9811a1a79235c0ea5ddf1111f7d22a0..1db5dc2f555709341ef331149602c29a6bf917c3 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -11,8 +11,11 @@ tempfile = "*" either = "*" itertools = "*" take_mut = "*" +slotmap = "*" union-find = "*" postcard = { version = "*", features = ["alloc"] } serde = { version = "*", features = ["derive"] } hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } +nestify = "*" +bimap = "*" \ No newline at end of file diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 68693e8b78a85d71928cfb4dc295cacd864a8189..92d52a716710d255ff9d1a0ef8a9de20a7c88fa9 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -677,7 +677,9 @@ fn ccp_flow_function( (BinaryOperator::RSh, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Some(Constant::UnsignedInteger64(left_val >> right_val)), _ => panic!("Unsupported 
combination of binary operation and constant values. Did typechecking succeed?") }; - new_cons.map(|c| ConstantLattice::Constant(c)).unwrap_or(ConstantLattice::bottom()) + new_cons + .map(|c| ConstantLattice::Constant(c)) + .unwrap_or(ConstantLattice::bottom()) } else if (left_constant.is_top() && !right_constant.is_bottom()) || (!left_constant.is_bottom() && right_constant.is_top()) { diff --git a/hercules_opt/src/device_placement.rs b/hercules_opt/src/device_placement.rs deleted file mode 100644 index 2badd69df428db12a0f3a46ac4aee05f8154d171..0000000000000000000000000000000000000000 --- a/hercules_opt/src/device_placement.rs +++ /dev/null @@ -1,3 +0,0 @@ -use hercules_ir::ir::*; - -use crate::*; diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 8b90710e737ccd1e96293cc8a3f05bccb3ea7153..39f1184cc947a35418641a817a86321343f101fc 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::cell::{Ref, RefCell}; use std::collections::{BTreeMap, HashSet}; use std::mem::take; @@ -335,6 +336,10 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { self.function_id } + pub fn node(&self, node: impl Borrow<NodeID>) -> &Node { + &self.function.nodes[node.borrow().idx()] + } + pub fn get_types(&self) -> Ref<'_, Vec<Type>> { self.types.borrow() } @@ -351,6 +356,15 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { self.mut_def_use[id.idx()].iter().map(|x| *x) } + pub fn get_uses(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ { + get_uses(&self.function.nodes[id.idx()]) + .as_ref() + .into_iter() + .map(|x| *x) + .collect::<Vec<_>>() + .into_iter() + } + pub fn get_type(&self, id: TypeID) -> Ref<'_, Type> { Ref::map(self.types.borrow(), |types| &types[id.idx()]) } @@ -405,6 +419,10 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { self.editor.function.nodes.len() + self.added_nodeids.len() } + pub fn copy_node(&mut self, node: NodeID) -> NodeID { + 
self.add_node(self.editor.func().nodes[node.idx()].clone()) + } + pub fn add_node(&mut self, node: Node) -> NodeID { let id = NodeID::new(self.num_node_ids()); // Added nodes need to have an entry in the def-use map. diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs index 186cd6a6eaad70ba4b26e25dee8714c7988ee611..bb3a2cff556077d2bf3fe54a7fa21d0dd6d4e4b9 100644 --- a/hercules_opt/src/fork_concat_split.rs +++ b/hercules_opt/src/fork_concat_split.rs @@ -7,7 +7,8 @@ use crate::*; /* * Split multi-dimensional fork-joins into separate one-dimensional fork-joins. - * Useful for code generation. + * Useful for code generation. A single iteration of `fork_split` only splits + * at most one fork-join, it must be called repeatedly to split all fork-joins. */ pub fn fork_split( editor: &mut FunctionEditor, @@ -135,5 +136,6 @@ pub fn fork_split( Ok(edit) }); + break; } } diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 842c83086f9ecedbbe5c8c96bf160de8968a953a..1abb89672ae1d5c4f0f34578ca9d8eb2d69a2bc0 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -1,8 +1,8 @@ use std::collections::{HashMap, HashSet}; -use hercules_ir::get_uses_mut; -use hercules_ir::ir::*; -use hercules_ir::ImmutableDefUseMap; +use hercules_ir::*; + +use crate::*; /* * This is a Hercules IR transformation that: @@ -17,6 +17,33 @@ use hercules_ir::ImmutableDefUseMap; * guard remains and in these cases the guard is no longer needed. 
*/ +// Simplify factors through max +enum Factor { + Max(usize, DynamicConstantID), + Normal(DynamicConstantID), +} + +impl Factor { + fn get_id(&self) -> DynamicConstantID { + match self { + Factor::Max(_, dynamic_constant_id) => *dynamic_constant_id, + Factor::Normal(dynamic_constant_id) => *dynamic_constant_id, + } + } +} + +struct GuardedFork { + fork: NodeID, + join: NodeID, + guard_if: NodeID, + fork_taken_proj: NodeID, + fork_skipped_proj: NodeID, + guard_pred: NodeID, + guard_join_region: NodeID, + phi_reduce_map: HashMap<NodeID, NodeID>, + factor: Factor, // The factor that matches the guard +} + /* Given a node index and the node itself, return None if the node is not * a guarded fork where we can eliminate the guard. * If the node is a fork with a guard we can eliminate returns a tuple of @@ -28,27 +55,40 @@ use hercules_ir::ImmutableDefUseMap; * - A map of NodeIDs for the phi nodes to the reduce they should be replaced * with, and also the region that joins the guard's branches mapping to the * fork's join NodeID + * - If the replication factor is a max that can be eliminated. */ fn guarded_fork( - function: &Function, - constants: &Vec<Constant>, + editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, - def_use: &ImmutableDefUseMap, - index: usize, - node: &Node, -) -> Option<( - NodeID, - Box<[DynamicConstantID]>, - NodeID, - NodeID, - NodeID, - NodeID, - HashMap<NodeID, NodeID>, -)> { + node: NodeID, +) -> Option<GuardedFork> { + let function = editor.func(); + // Identify fork nodes - let Node::Fork { control, factors } = node else { + let Node::Fork { control, factors } = &function.nodes[node.idx()] else { return None; }; + + let mut factors = factors.iter().enumerate().map(|(idx, dc)| { + let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else { + return Factor::Normal(*dc); + }; + + // There really needs to be a better way to work w/ associativity. 
+ let binding = [(l, r), (r, l)]; + let id = binding.iter().find_map(|(a, b)| { + let DynamicConstant::Constant(1) = *editor.get_dynamic_constant(*a) else { + return None; + }; + Some(b) + }); + + match id { + Some(v) => Factor::Max(idx, *v), + None => Factor::Normal(*dc), + } + }); + // Whose predecessor is a read from an if let Node::Projection { control: if_node, @@ -70,47 +110,93 @@ fn guarded_fork( return None; }; let branch_idx = *selection; - // branchIdx == 1 means the true branch so we want the condition to be - // 0 < n or n > 0 - if branch_idx == 1 - && !((op == BinaryOperator::LT - && function.nodes[left.idx()].is_zero_constant(constants) - && factors - .iter() - .any(|factor| function.nodes[right.idx()].try_dynamic_constant() == Some(*factor))) - || (op == BinaryOperator::GT - && function.nodes[right.idx()].is_zero_constant(constants) - && factors.iter().any(|factor| { - function.nodes[left.idx()].try_dynamic_constant() == Some(*factor) - }))) - { - return None; - } - // branchIdx == 0 means the false branch so we want the condition to be - // n < 0 or 0 > n - if branch_idx == 0 - && !((op == BinaryOperator::LT - && factors - .iter() - .any(|factor| function.nodes[left.idx()].try_dynamic_constant() == Some(*factor)) - && function.nodes[right.idx()].is_zero_constant(constants)) - || (op == BinaryOperator::GT - && factors.iter().any(|factor| { - function.nodes[right.idx()].try_dynamic_constant() == Some(*factor) - }) - && function.nodes[left.idx()].is_zero_constant(constants))) - { - return None; - } + + let factor = { + // branchIdx == 1 means the true branch so we want the condition to be + // 0 < n or n > 0 + if branch_idx == 1 { + [ + (left, BinaryOperator::LT, right), + (right, BinaryOperator::GT, left), + ] + .iter() + .find_map(|(pattern_zero, pattern_op, pattern_factor)| { + // Match Op + if op != *pattern_op { + return None; + } + // Match Zero + if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) + || editor + 
.node(pattern_zero) + .is_zero_dc(&editor.get_dynamic_constants())) + { + return None; + } + + // Match Factor + let factor = factors.find(|factor| { + match ( + &function.nodes[pattern_factor.idx()], + &*editor.get_dynamic_constant(factor.get_id()), + ) { + (Node::Constant { id }, DynamicConstant::Constant(v)) => { + let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id) + else { + return false; + }; + pattern_v == (*v as u64) + } + (Node::DynamicConstant { id }, _) => *id == factor.get_id(), + _ => false, + } + }); + factor + }) + } + // branchIdx == 0 means the false branch so we want the condition to be + // n < 0 or 0 > n + else if branch_idx == 0 { + [ + (right, BinaryOperator::LT, left), + (left, BinaryOperator::GT, right), + ] + .iter() + .find_map(|(pattern_zero, pattern_op, pattern_factor)| { + // Match Op + if op != *pattern_op { + return None; + } + // Match Zero + if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) + || editor + .node(pattern_zero) + .is_zero_dc(&editor.get_dynamic_constants())) + { + return None; + } + + // Match Factor + let factor = factors.find(|factor| { + function.nodes[pattern_factor.idx()].try_dynamic_constant() + == Some(factor.get_id()) + }); + factor + }) + } else { + None + } + }; + + let Some(factor) = factor else { return None }; // Identify the join node and its users - let join_id = fork_join_map.get(&NodeID::new(index))?; - let join_users = def_use.get_users(*join_id); + let join_id = fork_join_map.get(&node)?; // Find the unique control use of the join; if it's not a region we can't // eliminate this guard - let join_control = join_users - .iter() + let join_control = editor + .get_users(*join_id) .filter(|n| function.nodes[n.idx()].is_region()) .collect::<Vec<_>>(); if join_control.len() != 1 { @@ -134,7 +220,7 @@ fn guarded_fork( } else { return None; }; - // Other predecessor needs to be the other read from the guard's if + // Other predecessor needs to be the other 
projection from the guard's if let Node::Projection { control: if_node2, ref selection, @@ -152,14 +238,13 @@ fn guarded_fork( // Finally, identify the phi nodes associated with the region and match // them with the reduce nodes of the fork-join - let reduce_nodes = join_users - .iter() + let reduce_nodes = editor + .get_users(*join_id) .filter(|n| function.nodes[n.idx()].is_reduce()) .collect::<HashSet<_>>(); // Construct a map from phi nodes indices to the reduce node index - let phi_nodes = def_use - .get_users(*join_control) - .iter() + let phi_nodes = editor + .get_users(join_control) .filter_map(|n| { let Node::Phi { control: _, @@ -169,25 +254,25 @@ fn guarded_fork( return None; }; if data.len() != 2 { - return Some((*n, None)); + return Some((n, None)); } let (init_idx, reduce_node) = if reduce_nodes.contains(&data[0]) { (1, data[0]) } else if reduce_nodes.contains(&data[1]) { (0, data[1]) } else { - return Some((*n, None)); + return Some((n, None)); }; let Node::Reduce { control: _, init, .. 
} = function.nodes[reduce_node.idx()] else { - return Some((*n, None)); + return Some((n, None)); }; if data[init_idx] != init { - return Some((*n, None)); + return Some((n, None)); } - Some((*n, Some(reduce_node))) + Some((n, Some(reduce_node))) }) .collect::<HashMap<_, _>>(); @@ -197,69 +282,91 @@ fn guarded_fork( return None; } - let mut phi_nodes = phi_nodes + let phi_nodes = phi_nodes .into_iter() .map(|(phi, red)| (phi, red.unwrap())) .collect::<HashMap<_, _>>(); - // We also add a map from the region to the join to this map so we only - // need one map to handle all node replacements in the elimination process - phi_nodes.insert(*join_control, *join_id); - // Finally, we return this node's index along with // - The replication factor of the fork // - The if node // - The true and false reads of the if // - The guard's predecessor // - The map from phi nodes to reduce nodes and the region to the join - Some(( - NodeID::new(index), - factors.clone(), - if_node, - *control, - other_pred, - if_pred, - phi_nodes, - )) + Some(GuardedFork { + fork: node, + join: *join_id, + guard_if: if_node, + fork_taken_proj: *control, + fork_skipped_proj: other_pred, + guard_pred: if_pred, + guard_join_region: join_control, + phi_reduce_map: phi_nodes, + factor, + }) } /* * Top level function to run fork guard elimination, as described above. - * Deletes nodes by setting nodes to gravestones. Works with a function already - * containing gravestones. 
*/ -pub fn fork_guard_elim( - function: &mut Function, - constants: &Vec<Constant>, - fork_join_map: &HashMap<NodeID, NodeID>, - def_use: &ImmutableDefUseMap, -) { - let guard_info = function - .nodes - .iter() - .enumerate() - .filter_map(|(i, n)| guarded_fork(function, constants, fork_join_map, def_use, i, n)) +pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { + let guard_info = editor + .node_ids() + .filter_map(|node| guarded_fork(editor, fork_join_map, node)) .collect::<Vec<_>>(); - for (fork_node, factors, guard_node, guard_proj1, guard_proj2, guard_pred, map) in guard_info { - function.nodes[guard_node.idx()] = Node::Start; - function.nodes[guard_proj1.idx()] = Node::Start; - function.nodes[guard_proj2.idx()] = Node::Start; - function.nodes[fork_node.idx()] = Node::Fork { - control: guard_pred, - factors, + for GuardedFork { + fork, + join, + fork_taken_proj, + fork_skipped_proj, + guard_pred, + phi_reduce_map, + factor, + guard_if, + guard_join_region, + } in guard_info + { + let new_fork_info = if let Factor::Max(idx, dc) = factor { + let Node::Fork { + control: _, + mut factors, + } = editor.func().nodes[fork.idx()].clone() + else { + unreachable!() + }; + factors[idx] = dc; + let new_fork = Node::Fork { + control: guard_pred, + factors, + }; + Some(new_fork) + } else { + None }; - for (idx, node) in function.nodes.iter_mut().enumerate() { - let node_idx = NodeID::new(idx); - if map.contains_key(&node_idx) { - *node = Node::Start; + editor.edit(|mut edit| { + edit = + edit.replace_all_uses_where(fork_taken_proj, guard_pred, |usee| *usee == fork)?; + edit = edit.delete_node(guard_if)?; + edit = edit.delete_node(fork_taken_proj)?; + edit = edit.delete_node(fork_skipped_proj)?; + edit = edit.replace_all_uses(guard_join_region, join)?; + edit = edit.delete_node(guard_join_region)?; + // Delete region node + + for (phi, reduce) in phi_reduce_map.iter() { + edit = edit.replace_all_uses(*phi, *reduce)?; + edit = 
edit.delete_node(*phi)?; } - for u in get_uses_mut(node).as_mut() { - if let Some(replacement) = map.get(u) { - **u = *replacement; - } + + if let Some(new_fork_info) = new_fork_info { + let new_fork = edit.add_node(new_fork_info); + edit = edit.replace_all_uses(fork, new_fork)?; + edit = edit.delete_node(fork)?; } - } + + Ok(edit) + }); } } diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs new file mode 100644 index 0000000000000000000000000000000000000000..a4605bec7824255b0098cafabb50ed5446773947 --- /dev/null +++ b/hercules_opt/src/fork_transforms.rs @@ -0,0 +1,540 @@ +use std::collections::{HashMap, HashSet}; + +use bimap::BiMap; +use itertools::Itertools; + +use hercules_ir::*; + +use crate::*; + +type ForkID = usize; + +/** Places each reduce node into its own fork */ +pub fn default_reduce_partition( + editor: &FunctionEditor, + _fork: NodeID, + join: NodeID, +) -> SparseNodeMap<ForkID> { + let mut map = SparseNodeMap::new(); + + editor + .get_users(join) + .filter(|id| editor.func().nodes[id.idx()].is_reduce()) + .enumerate() + .for_each(|(fork, reduce)| { + map.insert(reduce, fork); + }); + + map +} + +// TODO: Refine these conditions. +/** */ +pub fn find_reduce_dependencies<'a>( + function: &'a Function, + reduce: NodeID, + fork: NodeID, +) -> impl IntoIterator<Item = NodeID> + 'a { + let len = function.nodes.len(); + + let mut visited: DenseNodeMap<bool> = vec![false; len]; + let mut depdendent: DenseNodeMap<bool> = vec![false; len]; + + // Does `fork` need to be a parameter here? It never changes. If this was a closure could it just capture it? 
+ fn recurse( + function: &Function, + node: NodeID, + fork: NodeID, + dependent_map: &mut DenseNodeMap<bool>, + visited: &mut DenseNodeMap<bool>, + ) -> () { + // return through dependent_map { + + if visited[node.idx()] { + return; + } + + visited[node.idx()] = true; + + if node == fork { + dependent_map[node.idx()] = true; + return; + } + + let binding = get_uses(&function.nodes[node.idx()]); + let uses = binding.as_ref(); + + for used in uses { + recurse(function, *used, fork, dependent_map, visited); + } + + dependent_map[node.idx()] = uses.iter().map(|id| dependent_map[id.idx()]).any(|a| a); + return; + } + + // Note: HACKY, the condition we want is 'all nodes on any path from the fork to the reduce (in the forward graph), or the reduce to the fork (in the directed graph)' + // cycles break this, but we assume for now that the only cycles are ones that involve the reduce node + // NOTE: control may break this (i.e. a loop inside the fork is a cycle that isn't the reduce) + // the current solution is just to mark the reduce as dependent at the start of traversing the graph. + depdendent[reduce.idx()] = true; + + recurse(function, reduce, fork, &mut depdendent, &mut visited); + + // Return node IDs that are dependent + let ret_val: Vec<_> = depdendent + .iter() + .enumerate() + .filter_map(|(idx, dependent)| { + if *dependent { + Some(NodeID::new(idx)) + } else { + None + } + }) + .collect(); + + ret_val +} + +pub fn copy_subgraph( + editor: &mut FunctionEditor, + subgraph: HashSet<NodeID>, +) -> ( + HashSet<NodeID>, + HashMap<NodeID, NodeID>, + Vec<(NodeID, NodeID)>, +) // returns all new nodes, a map from old nodes to new nodes, and + // a vec of pairs of nodes (old node, outside node) s.t old node -> outside node, + // outside means not part of the original subgraph. 
+{ + let mut map: HashMap<NodeID, NodeID> = HashMap::new(); + let mut new_nodes: HashSet<NodeID> = HashSet::new(); + + // Copy nodes + for old_id in subgraph.iter() { + editor.edit(|mut edit| { + let new_id = edit.copy_node(*old_id); + map.insert(*old_id, new_id); + new_nodes.insert(new_id); + Ok(edit) + }); + } + + // Update edges to new nodes + for old_id in subgraph.iter() { + // Replace all uses of old_id w/ new_id, where the use is in new_node + editor.edit(|edit| { + edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id)) + }); + } + + // Get all users that aren't in new_nodes. + let mut outside_users = Vec::new(); + + for node in new_nodes.iter() { + for user in editor.get_users(*node) { + if !new_nodes.contains(&user) { + outside_users.push((*node, user)); + } + } + } + + (new_nodes, map, outside_users) +} + +pub fn fork_fission<'a>( + editor: &'a mut FunctionEditor, + _control_subgraph: &Subgraph, + _types: &Vec<TypeID>, + _loop_tree: &LoopTree, + fork_join_map: &HashMap<NodeID, NodeID>, +) -> () { + let forks: Vec<_> = editor + .func() + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| { + if node.is_fork() { + Some(NodeID::new(idx)) + } else { + None + } + }) + .collect(); + + let control_pred = NodeID::new(0); + + // This does the reduction fission: + for fork in forks.clone() { + // FIXME: If there is control in between fork and join, don't just give up. + let join = fork_join_map[&fork]; + let join_pred = editor.func().nodes[join.idx()].try_join().unwrap(); + if join_pred != fork { + todo!("Can't do fork fission on nodes with internal control") + // Inner control LOOPs are hard + // inner control in general *should* work right now without modifications. + } + let reduce_partition = default_reduce_partition(editor, fork, join); + fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork); + } +} + +/** Split a 1D fork into two forks, placing select intermediate data into buffers. 
*/ +pub fn fork_bufferize_fission_helper<'a>( + editor: &'a mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. + _original_control_pred: NodeID, // What the new fork connects to. + types: &Vec<TypeID>, + fork: NodeID, +) -> (NodeID, NodeID) { + // Returns the two forks that it generates. + + // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork. + + // Copy fork + control intermediates + join to new fork + join, + // How does control get partitioned? + // (depending on how it affects the data nodes on each side of the bufferized_edges) + // may end up in each loop, fix me later. + // place new fork + join after join of first. + + // Only handle fork+joins with no inner control for now. + + // Create fork + join + Thread control + let join = fork_join_map[&fork]; + let mut new_fork_id = NodeID::new(0); + let mut new_join_id = NodeID::new(0); + + editor.edit(|mut edit| { + new_join_id = edit.add_node(Node::Join { control: fork }); + let factors = edit.get_node(fork).try_fork().unwrap().1; + new_fork_id = edit.add_node(Node::Fork { + control: new_join_id, + factors: factors.into(), + }); + edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) + }); + + for (src, dst) in bufferized_edges { + // FIXME: Disgusting cloning and allocationing and iterators. + let factors: Vec<_> = editor.func().nodes[fork.idx()] + .try_fork() + .unwrap() + .1 + .iter() + .cloned() + .collect(); + editor.edit(|mut edit| { + // Create write to buffer + + let thread_stuff_it = factors.into_iter().enumerate(); + + // FIxme: try to use unzip here? Idk why it wasn't working. + let tids = thread_stuff_it.clone().map(|(dim, _)| { + edit.add_node(Node::ThreadID { + control: fork, + dimension: dim, + }) + }); + + let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + + // Assume 1-d fork only for now. 
+ // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); + let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + let write = edit.add_node(Node::Write { + collect: NodeID::new(0), + data: src, + indices: vec![position_idx].into(), + }); + let ele_type = types[src.idx()]; + let empty_buffer = edit.add_type(hercules_ir::Type::Array( + ele_type, + array_dims.collect::<Vec<_>>().into_boxed_slice(), + )); + let empty_buffer = edit.add_zero_constant(empty_buffer); + let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); + let reduce = Node::Reduce { + control: new_join_id, + init: empty_buffer, + reduct: write, + }; + let reduce = edit.add_node(reduce); + // Fix write node + edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; + + // Create read from buffer + let tids = thread_stuff_it.clone().map(|(dim, _)| { + edit.add_node(Node::ThreadID { + control: new_fork_id, + dimension: dim, + }) + }); + + let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + + let read = edit.add_node(Node::Read { + collect: reduce, + indices: vec![position_idx].into(), + }); + + edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; + + Ok(edit) + }); + } + + (fork, new_fork_id) +} + +/** Split a 1D fork into a separate fork for each reduction. */ +pub fn fork_reduce_fission_helper<'a>( + editor: &'a mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split, + original_control_pred: NodeID, // What the new fork connects to. 
+ + fork: NodeID, +) -> (NodeID, NodeID) { + let join = fork_join_map[&fork]; + + let mut new_control_pred: NodeID = original_control_pred; + // Important edges are: Reduces, + + // NOTE: + // Say two reduce are in a fork, s.t reduce A depends on reduce B + // If user wants A and B in separate forks: + // - we can simply refuse + // - or we can duplicate B + + let mut new_fork = NodeID::new(0); + let mut new_join = NodeID::new(0); + + // Gets everything between fork & join that this reduce needs. (ALL CONTROL) + for reduce in reduce_partition { + let reduce = reduce.0; + + let function = editor.func(); + let subgraph = find_reduce_dependencies(function, reduce, fork); + + let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect(); + + subgraph.insert(join); + subgraph.insert(fork); + subgraph.insert(reduce); + + let (_, mapping, _) = copy_subgraph(editor, subgraph); + + new_fork = mapping[&fork]; + new_join = mapping[&join]; + + editor.edit(|mut edit| { + // Atttach new_fork after control_pred + let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone(); + edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| { + *usee == new_fork + })?; + + // Replace uses of reduce + edit = edit.replace_all_uses(reduce, mapping[&reduce])?; + Ok(edit) + }); + + new_control_pred = new_join; + } + + editor.edit(|mut edit| { + // Replace original join w/ new final join + edit = edit.replace_all_uses_where(join, new_join, |_| true)?; + + // Delete original join (all reduce users have been moved) + edit = edit.delete_node(join)?; + + // Replace all users of original fork, and then delete it, leftover users will be DCE'd. 
+ edit = edit.replace_all_uses(fork, new_fork)?; + edit.delete_node(fork) + }); + + (new_fork, new_join) +} + +pub fn fork_coalesce( + editor: &mut FunctionEditor, + loops: &LoopTree, + fork_join_map: &HashMap<NodeID, NodeID>, +) -> bool { + let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| { + if editor.func().nodes[k.idx()].is_fork() { + Some(k) + } else { + None + } + }); + + let fork_joins: Vec<_> = fork_joins.collect(); + // FIXME: Add a postorder traversal to optimize this. + + // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early. + // something like: `fork_joins.postorder_iter().windows(2)` is ideal here. + for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) { + if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) { + return true; + } + } + return false; +} + +/** Opposite of fork split, takes two fork-joins + with no control between them, and merges them into a single fork-join. +*/ +pub fn fork_coalesce_helper( + editor: &mut FunctionEditor, + outer_fork: NodeID, + inner_fork: NodeID, + fork_join_map: &HashMap<NodeID, NodeID>, +) -> bool { + // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork. 
+ + let outer_join = fork_join_map[&outer_fork]; + let inner_join = fork_join_map[&inner_fork]; + + let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner + + // FIXME: Iterate all control uses of joins to really collect all reduces + // (reduces can be attached to inner control) + for outer_reduce in editor + .get_users(outer_join) + .filter(|node| editor.func().nodes[node.idx()].is_reduce()) + { + // check that inner reduce is of the inner join + let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()] + .try_reduce() + .unwrap(); + + let inner_reduce = outer_reduct; + let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()]; + + let Node::Reduce { + control: inner_control, + init: inner_init, + reduct: _, + } = inner_reduce_node + else { + return false; + }; + + // FIXME: check this condition better (i.e reduce might not be attached to join) + if *inner_control != inner_join { + return false; + }; + if *inner_init != outer_reduce { + return false; + }; + + if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) { + return false; + } else { + pairs.insert(outer_reduce, inner_reduce); + } + } + + // Check for control between join-join and fork-fork + let Some(user) = editor + .get_users(outer_fork) + .filter(|node| editor.func().nodes[node.idx()].is_control()) + .next() + else { + return false; + }; + + if user != inner_fork { + return false; + } + + let Some(user) = editor + .get_users(inner_join) + .filter(|node| editor.func().nodes[node.idx()].is_control()) + .next() + else { + return false; + }; + + if user != outer_join { + return false; + } + + // Checklist: + // Increment inner TIDs + // Add outer fork's dimension to front of inner fork. + // Fuse reductions + // - Initializer becomes outer initializer + // Replace uses of outer fork w/ inner fork. + // Replace uses of outer join w/ inner join. 
+ // Delete outer fork-join + + let inner_tids: Vec<NodeID> = editor + .get_users(inner_fork) + .filter(|node| editor.func().nodes[node.idx()].is_thread_id()) + .collect(); + + let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap(); + let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap(); + let num_outer_dims = outer_dims.len(); + let mut new_factors = outer_dims.to_vec(); + + // CHECKME / FIXME: Might need to be added the other way. + new_factors.append(&mut inner_dims.to_vec()); + + for tid in inner_tids { + let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap(); + let new_tid = Node::ThreadID { + control: fork, + dimension: dim + num_outer_dims, + }; + + editor.edit(|mut edit| { + let new_tid = edit.add_node(new_tid); + let edit = edit.replace_all_uses(tid, new_tid)?; + Ok(edit) + }); + } + + // Fuse Reductions + for (outer_reduce, inner_reduce) in pairs { + let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()] + .try_reduce() + .unwrap(); + let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] + .try_reduce() + .unwrap(); + editor.edit(|mut edit| { + // Set inner init to outer init. 
+ edit = + edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; + edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; + edit = edit.delete_node(outer_reduce)?; + + Ok(edit) + }); + } + + editor.edit(|mut edit| { + let new_fork = Node::Fork { + control: outer_pred, + factors: new_factors.into(), + }; + let new_fork = edit.add_node(new_fork); + + edit = edit.replace_all_uses(inner_fork, new_fork)?; + edit = edit.replace_all_uses(outer_fork, new_fork)?; + edit = edit.replace_all_uses(outer_join, inner_join)?; + edit = edit.delete_node(outer_join)?; + edit = edit.delete_node(inner_fork)?; + edit = edit.delete_node(outer_fork)?; + + Ok(edit) + }); + + true +} diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index fb53a5e4a84bb08da7606aea914d368878c94b41..ce9ac1412f1253bff6589ec668db63725183ca6c 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -1,239 +1,547 @@ +use std::collections::HashMap; +use std::collections::HashSet; use std::iter::zip; +use std::iter::FromIterator; -use hercules_ir::def_use::*; -use hercules_ir::ir::*; -use hercules_ir::loops::*; +use itertools::Itertools; +use nestify::nest; -/* - * Top level function to convert natural loops with simple induction variables - * into fork-joins. +use hercules_ir::*; + +use crate::*; + +/* + * TODO: Forkify currently makes a bunch of small edits - this needs to be + * changed so that every loop that gets forkified corresponds to a single edit + * + sub-edits. This would allow us to run forkify on a subset of a function. */ pub fn forkify( - function: &mut Function, - constants: &Vec<Constant>, - dynamic_constants: &mut Vec<DynamicConstant>, - def_use: &ImmutableDefUseMap, + editor: &mut FunctionEditor, + control_subgraph: &Subgraph, + fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, -) { - // Ignore loops that are already fork-joins. 
TODO: re-calculate def_use per - // loop, since it's technically invalidated after each individual loop - // modification. +) -> bool { let natural_loops = loops .bottom_up_loops() .into_iter() - .rev() - .filter(|(k, _)| function.nodes[k.idx()].is_region()); + .filter(|(k, _)| editor.func().nodes[k.idx()].is_region()); + + let natural_loops: Vec<_> = natural_loops.collect(); + + for l in natural_loops { + // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses. + if forkify_loop( + editor, + control_subgraph, + fork_join_map, + &Loop { + header: l.0, + control: l.1.clone(), + }, + ) { + return true; + } + } + return false; +} + +/** Given a node used as a loop bound, return a dynamic constant ID. */ +pub fn get_node_as_dc( + editor: &mut FunctionEditor, + node: NodeID, +) -> Result<DynamicConstantID, String> { + // Check for a constant used as loop bound. + match editor.node(node) { + Node::DynamicConstant { + id: dynamic_constant_id, + } => Ok(*dynamic_constant_id), + Node::Constant { id: constant_id } => { + let dc = match *editor.get_constant(*constant_id) { + Constant::Integer8(x) => DynamicConstant::Constant(x as _), + Constant::Integer16(x) => DynamicConstant::Constant(x as _), + Constant::Integer32(x) => DynamicConstant::Constant(x as _), + Constant::Integer64(x) => DynamicConstant::Constant(x as _), + Constant::UnsignedInteger8(x) => DynamicConstant::Constant(x as _), + Constant::UnsignedInteger16(x) => DynamicConstant::Constant(x as _), + Constant::UnsignedInteger32(x) => DynamicConstant::Constant(x as _), + Constant::UnsignedInteger64(x) => DynamicConstant::Constant(x as _), + _ => return Err("Invalid constant as loop bound".to_string()), + }; + + let mut b = DynamicConstantID::new(0); + editor.edit(|mut edit| { + b = edit.add_dynamic_constant(dc); + Ok(edit) + }); + // Return the ID of the dynamic constant that is generated from the constant + // or dynamic constant that is the existing loop bound + 
Ok(b) + } + _ => Err("Blah".to_owned()), + } +} + +/** + Top level function to convert natural loops with simple induction variables + into fork-joins. +*/ +pub fn forkify_loop( + editor: &mut FunctionEditor, + control_subgraph: &Subgraph, + _fork_join_map: &HashMap<NodeID, NodeID>, + l: &Loop, +) -> bool { + let function = editor.func(); + + let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else { + return false; + }; + + let LoopExit::Conditional { + if_node: loop_if, + condition_node, + } = loop_condition.clone() + else { + return false; + }; + + // Compute loop variance + let loop_variance = compute_loop_variance(editor, l); + let ivs = compute_induction_vars(editor.func(), l, &loop_variance); + let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition); + let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else { + return false; + }; - // Detect loops that have a simple loop induction variable. TODO: proper - // affine analysis to recognize other cases of linear induction variables. - let affine_loops: Vec<_> = natural_loops + // FIXME: Make sure IV is not used outside the loop. + + // Get bound + let bound = match canonical_iv { + InductionVariable::Basic { + node: _, + initializer: _, + update: _, + final_value, + } => final_value + .map(|final_value| get_node_as_dc(editor, final_value)) + .and_then(|r| r.ok()), + InductionVariable::SCEV(_) => return false, + }; + + let Some(bound_dc_id) = bound else { + return false; + }; + + let function = editor.func(); + + // Check if it is do-while loop. 
+ let loop_exit_projection = editor + .get_users(loop_if) + .filter(|id| !l.control[id.idx()]) + .next() + .unwrap(); + + let loop_continue_projection = editor + .get_users(loop_if) + .filter(|id| l.control[id.idx()]) + .next() + .unwrap(); + + let loop_preds: Vec<_> = editor + .get_uses(l.header) + .filter(|id| !l.control[id.idx()]) + .collect(); + + if loop_preds.len() != 1 { + return false; + } + + let loop_pred = loop_preds[0]; + + if !editor + .get_uses(l.header) + .contains(&loop_continue_projection) + { + return false; + } + + // Get all phis used outside of the loop, they need to be reductionable. + // For now just assume all phis will be phis used outside of the loop, except for the canonical iv. + // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one + // we currently have. + let loop_nodes = calculate_loop_nodes(editor, l); + + // Check phis to see if they are reductionable, only PHIs depending on the loop are considered, + let candidate_phis: Vec<_> = editor + .get_users(l.header) + .filter(|id| function.nodes[id.idx()].is_phi()) + .filter(|id| *id != canonical_iv.phi()) + .collect(); + + let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes) .into_iter() - .filter_map(|(header, contents)| { - // Get the single loop contained predecessor of the loop header. - let header_uses = get_uses(&function.nodes[header.idx()]); - let mut pred_loop = header_uses.as_ref().iter().filter(|id| contents[id.idx()]); - let single_pred_loop = pred_loop.next()?; - if pred_loop.next().is_some() || header_uses.as_ref().len() != 2 { - return None; - } + .collect(); - // Check for a very particular loop indexing structure. 
- let if_ctrl = function.nodes[single_pred_loop.idx()].try_projection(1)?; - let (_, if_cond) = function.nodes[if_ctrl.idx()].try_if()?; - let (idx, bound) = function.nodes[if_cond.idx()].try_binary(BinaryOperator::LT)?; - let (phi, one) = function.nodes[idx.idx()].try_binary(BinaryOperator::Add)?; - let (should_be_header, pred_datas) = function.nodes[phi.idx()].try_phi()?; - let one_c_id = function.nodes[one.idx()].try_constant()?; + // TODO: Handle multiple loop body lasts. + // If there are multiple candidates for loop body last, return false. + if editor + .get_uses(loop_if) + .filter(|id| l.control[id.idx()]) + .count() + > 1 + { + return false; + } - if should_be_header != header || !constants[one_c_id.idx()].is_one() { - return None; - } + let loop_body_last = editor.get_uses(loop_if).next().unwrap(); - // Check that phi's if predecessor is the add node, and check that the - // phi's other predecessors are zeros. - zip(header_uses.as_ref().iter(), pred_datas.iter()) - .position(|(c, d)| *c == *single_pred_loop && *d == idx)?; - if zip(header_uses.as_ref().iter(), pred_datas.iter()) - .filter(|(c, d)| { - (**c != *single_pred_loop) - && !function.nodes[d.idx()].is_zero_constant(constants) - }) - .count() - != 0 - { - return None; - } + if reductionable_phis + .iter() + .any(|phi| !matches!(phi, LoopPHI::Reductionable { .. })) + { + return false; + } - // Check for constant used as loop bound. Do this last, since we may - // create a new dynamic constant here. - let bound_dc_id = - if let Some(bound_dc_id) = function.nodes[bound.idx()].try_dynamic_constant() { - bound_dc_id - } else if let Some(bound_c_id) = function.nodes[bound.idx()].try_constant() { - // Create new dynamic constant that reflects this constant. 
- let dc = match constants[bound_c_id.idx()] { - Constant::Integer8(x) => DynamicConstant::Constant(x as _), - Constant::Integer16(x) => DynamicConstant::Constant(x as _), - Constant::Integer32(x) => DynamicConstant::Constant(x as _), - Constant::Integer64(x) => DynamicConstant::Constant(x as _), - Constant::UnsignedInteger8(x) => DynamicConstant::Constant(x as _), - Constant::UnsignedInteger16(x) => DynamicConstant::Constant(x as _), - Constant::UnsignedInteger32(x) => DynamicConstant::Constant(x as _), - Constant::UnsignedInteger64(x) => DynamicConstant::Constant(x as _), - _ => return None, - }; - - // The new dynamic constant may already be interned. - let maybe_already_in = dynamic_constants - .iter() - .enumerate() - .find(|(_, x)| **x == dc) - .map(|(idx, _)| idx); - if let Some(bound_dc_idx) = maybe_already_in { - DynamicConstantID::new(bound_dc_idx) - } else { - let id = DynamicConstantID::new(dynamic_constants.len()); - dynamic_constants.push(dc); - id - } - } else { - return None; - }; + let phi_latches: Vec<_> = reductionable_phis + .iter() + .map(|phi| { + let LoopPHI::Reductionable { + phi: _, + data_cycle: _, + continue_latch, + is_associative: _, + } = phi + else { + unreachable!() + }; + continue_latch + }) + .collect(); - Some((header, phi, contents, bound_dc_id)) + let stop_on: HashSet<_> = editor + .node_ids() + .filter(|node| { + if editor.node(node).is_phi() { + return true; + } + if editor.node(node).is_reduce() { + return true; + } + if editor.node(node).is_control() { + return true; + } + if phi_latches.contains(&node) { + return true; + } + + false }) .collect(); - // Convert affine loops into fork-joins. 
- for (header, idx_phi, contents, dc_id) in affine_loops { - let header_uses = get_uses(&function.nodes[header.idx()]); - let header_uses: Vec<_> = header_uses.as_ref().into_iter().map(|x| *x).collect(); + // Outside loop users of IV, then exit; + // Unless the outside user is through the loop latch of a reducing phi, + // then we know how to replace this edge, so its fine! + let iv_users: Vec<_> = + walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect(); - // Get the control portions of the loop that need to be grafted. - let loop_pred = header_uses - .iter() - .filter(|id| !contents[id.idx()]) - .next() - .unwrap(); - let loop_true_read = header_uses - .iter() - .filter(|id| contents[id.idx()]) - .next() - .unwrap(); - let loop_end = function.nodes[loop_true_read.idx()] - .try_projection(1) - .unwrap(); - let loop_false_read = *def_use - .get_users(loop_end) - .iter() - .filter_map(|id| { - if function.nodes[id.idx()].try_projection(0).is_some() { - Some(id) - } else { - None - } - }) - .next() - .unwrap(); + if iv_users + .iter() + .any(|node| !loop_nodes.contains(&node) && *node != loop_if) + { + return false; + } + + // Start Transformation: + + // Graft everything between header and loop condition + // Attach join to right before header (after loop_body_last, unless loop body last *is* the header). + // Attach fork to right after loop_continue_projection. + + // // Create fork and join nodes: + let mut join_id = NodeID::new(0); + let mut fork_id = NodeID::new(0); - // Create fork and join nodes. + // Turn dc bound into max (1, bound), + let bound_dc_id = { + let mut max_id = DynamicConstantID::new(0); + editor.edit(|mut edit| { + // FIXME: Maybe add_dynamic_constant should intern? + let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1)); + max_id = edit.add_dynamic_constant(DynamicConstant::Max(one_id, bound_dc_id)); + Ok(edit) + }); + max_id + }; + + // FIXME: (@xrouth) double check handling of control in loop body. 
+ editor.edit(|mut edit| { let fork = Node::Fork { - control: *loop_pred, - factors: Box::new([dc_id]), + control: loop_pred, + factors: Box::new([bound_dc_id]), }; - let fork_id = NodeID::new(function.nodes.len()); - function.nodes.push(fork); + fork_id = edit.add_node(fork); let join = Node::Join { - control: if header == get_uses(&function.nodes[loop_end.idx()]).as_ref()[0] { + control: if l.header == loop_body_last { fork_id } else { - function.nodes[loop_end.idx()].try_if().unwrap().0 + loop_body_last }, }; - let join_id = NodeID::new(function.nodes.len()); - function.nodes.push(join); - - // Convert reducing phi nodes to reduce nodes. - let reduction_phis: Vec<_> = def_use - .get_users(header) - .iter() - .filter(|id| **id != idx_phi && function.nodes[id.idx()].is_phi()) - .collect(); - for reduction_phi in reduction_phis { - // Loop predecessor input to phi is the reduction initializer. - let init = *zip( - header_uses.iter(), - function.nodes[reduction_phi.idx()] - .try_phi() - .unwrap() - .1 - .iter(), - ) - .filter(|(c, _)| **c == *loop_pred) - .next() - .unwrap() - .1; - - // Loop back edge input to phi is the reduction induction variable. - let reduct = *zip( - header_uses.iter(), - function.nodes[reduction_phi.idx()] - .try_phi() - .unwrap() - .1 - .iter(), - ) - .filter(|(c, _)| **c == *loop_true_read) - .next() - .unwrap() - .1; - - // Create reduction node. + + join_id = edit.add_node(join); + + Ok(edit) + }); + + let function = editor.func(); + let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); + let dimension = factors.len() - 1; + + // Create ThreadID + editor.edit(|mut edit| { + let thread_id = Node::ThreadID { + control: fork_id, + dimension: dimension, + }; + let thread_id_id = edit.add_node(thread_id); + + // Replace uses that are inside with the thread id + edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { + loop_nodes.contains(node) + })?; + + // Replace uses that are outside with DC - 1. 
Or just give up. + let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id }); + edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| { + !loop_nodes.contains(node) + })?; + + edit.delete_node(canonical_iv.phi()) + }); + + for reduction_phi in reductionable_phis { + let LoopPHI::Reductionable { + phi, + data_cycle: _, + continue_latch, + is_associative: _, + } = reduction_phi + else { + panic!(); + }; + + let function = editor.func(); + + let init = *zip( + editor.get_uses(l.header), + function.nodes[phi.idx()].try_phi().unwrap().1.iter(), + ) + .filter(|(c, _)| *c == loop_pred) + .next() + .unwrap() + .1; + + editor.edit(|mut edit| { let reduce = Node::Reduce { control: join_id, init, - reduct, + reduct: continue_latch, }; - let reduce_id = NodeID::new(function.nodes.len()); - function.nodes.push(reduce); + let reduce_id = edit.add_node(reduce); - // Edit users of phis. - for user in def_use.get_users(*reduction_phi) { - get_uses_mut(&mut function.nodes[user.idx()]).map(*reduction_phi, reduce_id); - } + edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?; + edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| { + !loop_nodes.contains(usee) && *usee != reduce_id + })?; + edit.delete_node(phi) + }); + } - // Edit users of uses of phis. - for user in def_use.get_users(reduct) { - get_uses_mut(&mut function.nodes[user.idx()]).map(reduct, reduce_id); - } + // Replace all uses of the loop header with the fork + editor.edit(|edit| edit.replace_all_uses(l.header, fork_id)); - // Delete reducing phi. - function.nodes[reduction_phi.idx()] = Node::Start; - } + editor.edit(|edit| edit.replace_all_uses(loop_continue_projection, fork_id)); - // Convert index phi node to thread ID node. 
- let thread_id = Node::ThreadID { - control: fork_id, - dimension: 0, - }; - let thread_id_id = NodeID::new(function.nodes.len()); - function.nodes.push(thread_id); + editor.edit(|edit| edit.replace_all_uses(loop_exit_projection, join_id)); - for user in def_use.get_users(idx_phi) { - get_uses_mut(&mut function.nodes[user.idx()]).map(idx_phi, thread_id_id); - } - for user in def_use.get_users(header) { - get_uses_mut(&mut function.nodes[user.idx()]).map(header, fork_id); - } - for user in def_use.get_users(loop_false_read) { - get_uses_mut(&mut function.nodes[user.idx()]).map(loop_false_read, join_id); - } + // Get rid of loop condition + // DCE should get these, but delete them ourselves because we are nice :) + editor.edit(|mut edit| { + edit = edit.delete_node(loop_continue_projection)?; + edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. + edit = edit.delete_node(loop_exit_projection)?; + edit = edit.delete_node(loop_if)?; + edit = edit.delete_node(l.header)?; + Ok(edit) + }); + + return true; +} - function.nodes[idx_phi.idx()] = Node::Start; - function.nodes[header.idx()] = Node::Start; - function.nodes[loop_end.idx()] = Node::Start; - function.nodes[loop_true_read.idx()] = Node::Start; - function.nodes[loop_false_read.idx()] = Node::Start; +nest! { + #[derive(Debug)] + pub enum LoopPHI { + Reductionable { + phi: NodeID, + data_cycle: HashSet<NodeID>, // All nodes in a data cycle with this phi + continue_latch: NodeID, + is_associative: bool, + }, + LoopDependant(NodeID), + UsedByDependant(NodeID), } } + +impl LoopPHI { + pub fn get_phi(&self) -> NodeID { + match self { + LoopPHI::Reductionable { phi, .. } => *phi, + LoopPHI::LoopDependant(node_id) => *node_id, + LoopPHI::UsedByDependant(node_id) => *node_id, + } + } +} + +/** +Checks some conditions on loop variables that will need to be converted into reductions to be forkified. + - The phi is in a cycle *in the loop* with itself. 
+ - Every cycle *in the loop* containing the phi does not contain any other phi of the loop header. + - The phi does not immediatley (not blocked by another phi or another reduce) use any other phis of the loop header. + */ +pub fn analyze_phis<'a>( + editor: &'a FunctionEditor, + natural_loop: &'a Loop, + phis: &'a [NodeID], + loop_nodes: &'a HashSet<NodeID>, +) -> impl Iterator<Item = LoopPHI> + 'a { + + // Find data cycles within the loop of this phi, + // Start from the phis loop_continue_latch, and walk its uses until we find the original phi. + + phis.into_iter().map(move |phi| { + let stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { + let data = &editor.func().nodes[node.idx()]; + + // External Phi + if let Node::Phi { control, data: _ } = data { + if *control != natural_loop.header { + return true; + } + } + + // This phi + if node == phi { + return true; + } + + // External Reduce + if let Node::Reduce { + control, + init: _, + reduct: _, + } = data + { + if !natural_loop.control[control.idx()] { + return true; + } else { + return false; + } + } + + // Data Cycles Only + if data.is_control() { + return true; + } + + return false; + }) + .collect(); + + let continue_idx = editor + .get_uses(natural_loop.header) + .position(|node| natural_loop.control[node.idx()]) + .unwrap(); + + let loop_continue_latch = editor.node(phi).try_phi().unwrap().1[continue_idx]; + + let uses = walk_all_uses_stop_on(loop_continue_latch, editor, stop_on.clone()); + let users = walk_all_users_stop_on(*phi, editor, stop_on.clone()); + + let other_stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { + let data = &editor.func().nodes[node.idx()]; + + // Phi, Reduce + if data.is_phi() { + return true; + } + + if data.is_reduce() { + return true; + } + + // External Control + if data.is_control() { + return true; + } + + return false; + }) + .collect(); + + let mut uses_for_dependance = + walk_all_users_stop_on(loop_continue_latch, editor, other_stop_on); 
+ + let set1: HashSet<_> = HashSet::from_iter(uses); + let set2: HashSet<_> = HashSet::from_iter(users); + + let intersection: HashSet<_> = set1.intersection(&set2).cloned().collect(); + + // If this phi uses any other phis the node is loop dependant, + // we use `phis` because this phi can actually contain the loop iv and its fine. + if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) { + LoopPHI::LoopDependant(*phi) + } else if intersection.clone().iter().next().is_some() { + // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need + // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined + // by the time the reduce is triggered (at the end of the loop's internal control). + + // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. + // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce. + if intersection + .iter() + .filter(|node| **node != loop_continue_latch ) + .any(|data_node| { + editor + .get_users(*data_node) + .any(|user| !loop_nodes.contains(&user)) + }) + { + // This phi can be made into a reduce in different ways, if the cycle is associative (contains all the same kind of associative op) + // 3) Split the cycle into two phis, add them or multiply them together at the end. + // 4) Split the cycle into two reduces, add them or multiply them together at the end. + // Somewhere else should handle this. + return LoopPHI::LoopDependant(*phi); + } + + // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify + // i.e as described above. 
+ let is_associative = false; + + // No nodes in the data cycle are used outside of the loop, besides the latched value of the phi + LoopPHI::Reductionable { + phi: *phi, + data_cycle: intersection, + continue_latch: loop_continue_latch, + is_associative, + } + } else { + // No cycles exist, this isn't a reduction. + LoopPHI::LoopDependant(*phi) + } + }) +} diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index b5822a519a5ca204db3a80d581493ea37ded511b..462d10871565bfed9b351d2627c44af8ca778ffc 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1022,7 +1022,7 @@ fn liveness_dataflow( * device clones when a single node may potentially be on different devices. */ fn color_nodes( - editor: &mut FunctionEditor, + _editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, objects: &FunctionCollectionObjects, object_device_demands: &FunctionObjectDeviceDemands, @@ -1138,7 +1138,7 @@ fn object_allocation( typing: &Vec<TypeID>, node_colors: &FunctionNodeColors, alignments: &Vec<usize>, - liveness: &Liveness, + _liveness: &Liveness, backing_allocations: &BackingAllocations, ) -> FunctionBackingAllocation { let mut fba = BTreeMap::new(); diff --git a/hercules_opt/src/ivar.rs b/hercules_opt/src/ivar.rs new file mode 100644 index 0000000000000000000000000000000000000000..f7252d29b66f9fc1882849206bbbf5b327a0f307 --- /dev/null +++ b/hercules_opt/src/ivar.rs @@ -0,0 +1,602 @@ +use std::collections::HashSet; + +use bitvec::prelude::*; +use nestify::nest; + +use hercules_ir::*; + +use crate::*; + +#[derive(Debug)] +pub struct LoopVarianceInfo { + pub loop_header: NodeID, + pub map: DenseNodeMap<LoopVariance>, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LoopVariance { + Unknown, + Invariant, + Variant, +} + +type NodeVec = BitVec<u8, Lsb0>; + +#[derive(Clone, Debug)] +pub struct Loop { + pub header: NodeID, + pub control: NodeVec, // +} + +impl Loop { + pub fn get_all_nodes(&self) -> NodeVec { + let mut all_loop_nodes = 
self.control.clone(); + all_loop_nodes.set(self.header.idx(), true); + all_loop_nodes + } +} + +nest! { + #[derive(Clone, Copy, Debug, PartialEq)]* + pub enum InductionVariable { + pub Basic { + node: NodeID, + initializer: NodeID, + update: NodeID, + final_value: Option<NodeID>, + }, + SCEV(NodeID), // TODO @(xrouth) + } +} + +impl InductionVariable { + pub fn phi(&self) -> NodeID { + match self { + InductionVariable::Basic { + node, + initializer: _, + update: _, + final_value: _, + } => *node, + InductionVariable::SCEV(_) => todo!(), + } + } +} + +// TODO: Optimize. +pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> HashSet<NodeID> { + // Stop on PHIs / reduces outside of loop. + let stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { + let data = &editor.func().nodes[node.idx()]; + + // External Phi + if let Node::Phi { control, data: _ } = data { + if !natural_loop.control[control.idx()] { + return true; + } + } + // External Reduce + if let Node::Reduce { + control, + init: _, + reduct: _, + } = data + { + if !natural_loop.control[control.idx()] { + return true; + } + } + + // External Control + if data.is_control() && !natural_loop.control[node.idx()] { + return true; + } + + return false; + }) + .collect(); + + let phis: Vec<_> = editor + .node_ids() + .filter(|node| { + let Node::Phi { control, data: _ } = editor.func().nodes[node.idx()] else { + return false; + }; + natural_loop.control[control.idx()] + }) + .collect(); + + let all_users: HashSet<NodeID> = phis + .clone() + .iter() + .flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) + .chain(phis.clone()) + .collect(); + + let all_uses: HashSet<_> = phis + .clone() + .iter() + .flat_map(|phi| walk_all_uses_stop_on(*phi, editor, stop_on.clone())) + .chain(phis.clone()) + .filter(|node| { + // Get rid of nodes in stop_on + !stop_on.contains(node) + }) + .collect(); + + all_users + .intersection(&all_uses) + .chain(phis.iter()) + .cloned() + 
.collect() +} + +/** returns PHIs that are on any regions inside the loop. */ +pub fn get_all_loop_phis<'a>( + function: &'a Function, + l: &'a Loop, +) -> impl Iterator<Item = NodeID> + 'a { + function + .nodes + .iter() + .enumerate() + .filter_map(move |(node_id, node)| { + if let Some((control, _)) = node.try_phi() { + if l.control[control.idx()] { + Some(NodeID::new(node_id)) + } else { + None + } + } else { + None + } + }) +} + +// FIXME: Need a trait that Editor and Function both implement, that gives us UseDefInfo + +/** Given a loop determine for each data node if the value might change upon each iteration of the loop */ +pub fn compute_loop_variance(editor: &FunctionEditor, l: &Loop) -> LoopVarianceInfo { + // Gather all Phi nodes that are controlled by this loop. + let mut loop_vars: Vec<NodeID> = vec![]; + + for node_id in editor.get_users(l.header) { + let node = &editor.func().nodes[node_id.idx()]; + if let Some((control, _)) = node.try_phi() { + if l.control[control.idx()] { + loop_vars.push(node_id); + } + } + } + + let len = editor.func().nodes.len(); + + let mut all_loop_nodes = l.control.clone(); + + all_loop_nodes.set(l.header.idx(), true); + + let mut variance_map: DenseNodeMap<LoopVariance> = vec![LoopVariance::Unknown; len]; + + fn recurse( + function: &Function, + node: NodeID, + all_loop_nodes: &BitVec<u8, Lsb0>, + variance_map: &mut DenseNodeMap<LoopVariance>, + visited: &mut DenseNodeMap<bool>, + ) -> LoopVariance { + if visited[node.idx()] { + return variance_map[node.idx()]; + } + + visited[node.idx()] = true; + + let node_variance = match variance_map[node.idx()] { + LoopVariance::Invariant => LoopVariance::Invariant, + LoopVariance::Variant => LoopVariance::Variant, + LoopVariance::Unknown => { + let mut node_variance = LoopVariance::Invariant; + + // Two conditions cause something to be loop variant: + for node_use in get_uses(&function.nodes[node.idx()]).as_ref() { + // 1) The use is a PHI *controlled* by the loop + if let 
Some((control, _)) = function.nodes[node_use.idx()].try_phi() { + if *all_loop_nodes.get(control.idx()).unwrap() { + node_variance = LoopVariance::Variant; + break; + } + } + + // 2) Any of the nodes uses are loop variant + if recurse(function, *node_use, all_loop_nodes, variance_map, visited) + == LoopVariance::Variant + { + node_variance = LoopVariance::Variant; + break; + } + } + + variance_map[node.idx()] = node_variance; + + node_variance + } + }; + + return node_variance; + } + + let mut visited: DenseNodeMap<bool> = vec![false; len]; + + for node in (0..len).map(NodeID::new) { + recurse( + editor.func(), + node, + &all_loop_nodes, + &mut variance_map, + &mut visited, + ); + } + + return LoopVarianceInfo { + loop_header: l.header, + map: variance_map, + }; +} + +nest! { +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LoopExit { + Conditional { + if_node: NodeID, + condition_node: NodeID, + }, + Unconditional(NodeID) +} +} + +pub fn get_loop_exit_conditions( + function: &Function, + l: &Loop, + control_subgraph: &Subgraph, +) -> Option<LoopExit> { + // impl IntoIterator<Item = LoopExit> + // DFS Traversal on loop control subgraph until we find a node that is outside the loop, find the last IF on this path. + let mut last_if_on_path: DenseNodeMap<Option<NodeID>> = vec![None; function.nodes.len()]; + + // FIXME: (@xrouth) THIS IS MOST CERTAINLY BUGGED + // this might be bugged... i.e might need to udpate `last if` even if already defined. + // needs to be `saturating` kinda, more iterative. May need to visit nodes more than once? + + // FIXME: (@xrouth) Right now we assume only one exit from the loop, later: check for multiple exits on the loop, + // either as an assertion here or some other part of forkify or analysis. 
+ let mut bag_of_control_nodes = vec![l.header]; + let mut visited: DenseNodeMap<bool> = vec![false; function.nodes.len()]; + + let mut final_if: Option<NodeID> = None; + + // do WFS + while !bag_of_control_nodes.is_empty() { + let node = bag_of_control_nodes.pop().unwrap(); + if visited[node.idx()] { + continue; + } + visited[node.idx()] = true; + + final_if = if function.nodes[node.idx()].is_if() { + Some(node) + } else { + last_if_on_path[node.idx()] + }; + + if !l.control[node.idx()] { + break; + } + + for succ in control_subgraph.succs(node) { + last_if_on_path[succ.idx()] = final_if; + bag_of_control_nodes.push(succ.clone()); + } + } + + final_if.map(|v| LoopExit::Conditional { + if_node: v, + condition_node: if let Node::If { control: _, cond } = function.nodes[v.idx()] { + cond + } else { + unreachable!() + }, + }) +} + +/** Returns true only when `ivar` is a `Basic` induction variable with a known `final_value` whose `initializer` and `update` nodes both satisfy `is_constant()`. Returns false for `SCEV` induction variables and for unbounded `Basic` ones. */ +pub fn has_const_fields(editor: &FunctionEditor, ivar: InductionVariable) -> bool { + match ivar { + InductionVariable::Basic { + node: _, + initializer, + update, + final_value, + } => { + /* an IV without a computed bound never counts as having const fields */ + if final_value.is_none() { + return false; + } + /* every field must be constant; the previous `any(!is_constant)` form reported the opposite */ + [initializer, update] + .iter() + .all(|node| editor.node(node).is_constant()) + } + InductionVariable::SCEV(_) => false, + } +} + +/* Loop has any IV from range 0....N, N can be dynconst iterates +1 per iteration */ +// IVs need to be bounded... 
+/** Returns the first `Basic` induction variable in `ivs` whose initializer is a zero constant or zero dynamic constant, whose update is a one constant or one dynamic constant, and whose `final_value` is present and is itself a constant or dynamic constant node. `SCEV` induction variables never match. */ +pub fn has_canonical_iv<'a>( + editor: &FunctionEditor, + _l: &Loop, + ivs: &'a [InductionVariable], +) -> Option<&'a InductionVariable> { + ivs.iter().find(|iv| match iv { + InductionVariable::Basic { + node: _, + initializer, + update, + final_value, + } => { + (editor + .node(initializer) + .is_zero_constant(&editor.get_constants()) + || editor + .node(initializer) + .is_zero_dc(&editor.get_dynamic_constants())) + && (editor.node(update).is_one_constant(&editor.get_constants()) + || editor + .node(update) + .is_one_dc(&editor.get_dynamic_constants())) + /* the bound must exist AND be a constant / dynamic constant; `.is_some()` here would ignore the closure's result */ + && (final_value + .map(|val| { + editor.node(val).is_constant() || editor.node(val).is_dynamic_constant() + }) + .unwrap_or(false)) + } + InductionVariable::SCEV(_) => false, + }) +} + +// Need a transformation that forces all IVs to be SCEVs of an IV from range 0...N, +1, else places them in a separate loop? +pub fn compute_induction_vars( + function: &Function, + l: &Loop, + _loop_variance: &LoopVarianceInfo, +) -> Vec<InductionVariable> { + // 1) Gather PHIs contained in the loop. + // FIXME: (@xrouth) Should this just be PHIs controlled by the header? + let mut loop_vars: Vec<NodeID> = vec![]; + + for (node_id, node) in function.nodes.iter().enumerate() { + if let Some((control, _)) = node.try_phi() { + if l.control[control.idx()] { + loop_vars.push(NodeID::new(node_id)); + } + } + } + + // FIXME: (@xrouth) For now, only compute variables that have one assignment, + // (look into this:) possibly treat multiple assignment as separate induction variables. 
+ let mut induction_variables: Vec<InductionVariable> = vec![]; + + /* For each PHI controlled by the loop, check how it is modified */ + + // It's initializer needs to be loop invariant, it's update needs to be loop variant. 
+ for phi_id in loop_vars { + let phi_node = &function.nodes[phi_id.idx()]; + let (region, data) = phi_node.try_phi().unwrap(); + let region_node = &function.nodes[region.idx()]; + let Node::Region { + preds: region_inputs, + } = region_node + else { + continue; + }; + + // The initializer index is the first index of the inputs to the region node of that isn't in the loop. (what is loop_header, wtf...) + // FIXME (@xrouth): If there is control flow in the loop, we won't find ... WHAT + let Some(initializer_idx) = region_inputs + .iter() + .position(|&node_id| !l.control[node_id.idx()]) + else { + continue; + }; + + let initializer_id = data[initializer_idx]; + + // Check dynamic constancy: + let initializer = &function.nodes[initializer_id.idx()]; + + // In the case of a non 0 starting value: + // - a new dynamic constant or constant may need to be created that is the difference between the initiailizer and the loop bounds. + // Initializer does not necessarily have to be constant, but this is fine for now. + if !(initializer.is_dynamic_constant() || initializer.is_constant()) { + continue; + } + + // Check all data inputs to this phi, that aren't the initializer (i.e the value the comes from control outside of the loop) + // For now we expect only one initializer. 
+ let data_inputs = data + .iter() + .enumerate() + /* skip the initializer by its position in the phi's data list; `initializer_idx` is an index, not a node ID */ + .filter(|(idx, _)| *idx != initializer_idx) + .map(|(_, data_id)| data_id); + + for data_id in data_inputs { + let node = &function.nodes[data_id.idx()]; + for bop in [BinaryOperator::Add] { + //, BinaryOperator::Mul, BinaryOperator::Sub] { + if let Some((a, b)) = node.try_binary(bop) { + /* try both operand orders: (phi, step) and (step, phi) */ + let iv = [(a, b), (b, a)] + .iter() + .find_map(|(pattern_phi, pattern_const)| { + /* the phi operand must match AND the other operand must be a constant or dynamic constant */ + if *pattern_phi == phi_id + && (function.nodes[pattern_const.idx()].is_constant() + || function.nodes[pattern_const.idx()].is_dynamic_constant()) + { + return Some(InductionVariable::Basic { + node: phi_id, + initializer: initializer_id, + /* the update is the non-phi operand, whichever side of the binary op it is on */ + update: *pattern_const, + final_value: None, + }); + } else { + None + } + }); + if let Some(iv) = iv { + induction_variables.push(iv); + } + } + } + } + } + + induction_variables +} + +// Find loop iterations +pub fn compute_iv_ranges( + editor: &FunctionEditor, + l: &Loop, + induction_vars: Vec<InductionVariable>, + loop_condition: &LoopExit, +) -> Vec<InductionVariable> { + let condition_node = match loop_condition { + LoopExit::Conditional { + if_node: _, + condition_node, + } => condition_node, + LoopExit::Unconditional(_) => todo!(), + }; + + // Find IVs used by the loop condition, not across loop iterations. + // without leaving the loop. + let stop_on: HashSet<_> = editor + .node_ids() + .filter(|node_id| { + if let Node::Phi { control, data: _ } = editor.node(node_id) { + *control == l.header + } else { + false + } + }) + .collect(); + + // Bound IVs used in loop bound. 
+ let loop_bound_uses: HashSet<_> = + walk_all_uses_stop_on(*condition_node, editor, stop_on).collect(); + let (loop_bound_ivs, other_ivs): (Vec<InductionVariable>, Vec<InductionVariable>) = + induction_vars + .into_iter() + .partition(|f| loop_bound_uses.contains(&f.phi())); + + let Some(iv) = loop_bound_ivs.first() else { + return other_ivs; + }; + + if loop_bound_ivs.len() > 1 { + return loop_bound_ivs.into_iter().chain(other_ivs).collect(); + } + + // FIXME: DO linear algerbra to solve for loop bounds with multiple variables involved. + let final_value = match &editor.func().nodes[condition_node.idx()] { + Node::Phi { + control: _, + data: _, + } => None, + Node::Reduce { + control: _, + init: _, + reduct: _, + } => None, + Node::Parameter { index: _ } => None, + Node::Constant { id: _ } => None, + Node::Unary { input: _, op: _ } => None, + Node::Ternary { + first: _, + second: _, + third: _, + op: _, + } => None, + Node::Binary { left, right, op } => { + match op { + BinaryOperator::LT => { + // Check for a loop guard condition. 
+ // left < right + if *left == iv.phi() + && (editor.func().nodes[right.idx()].is_constant() + || editor.func().nodes[right.idx()].is_dynamic_constant()) + { + Some(*right) + } + // left + const < right, + else if let Node::Binary { + left: inner_left, + right: inner_right, + op: _, + } = editor.node(left) + { + let pattern = [(inner_left, inner_right), (inner_right, inner_left)] + .iter() + .find_map(|(pattern_iv, pattern_constant)| { + if iv.phi() == **pattern_iv + && (editor.node(*pattern_constant).is_constant() + || editor.node(*pattern_constant).is_dynamic_constant()) + { + // FIXME: pattern_constant can be anything >= loop_update expression, + let update = match iv { + InductionVariable::Basic { + node: _, + initializer: _, + update, + final_value: _, + } => update, + InductionVariable::SCEV(_) => todo!(), + }; + if *pattern_constant == update { + Some(*right) + } else { + None + } + } else { + None + } + }); + pattern.iter().cloned().next() + } else { + None + } + } + BinaryOperator::LTE => None, + BinaryOperator::GT => None, + BinaryOperator::GTE => None, + BinaryOperator::EQ => None, + BinaryOperator::NE => None, + _ => None, + } + } + _ => None, + }; + + let basic = match iv { + InductionVariable::Basic { + node, + initializer, + update, + final_value: _, + } => InductionVariable::Basic { + node: *node, + initializer: *initializer, + update: *update, + final_value, + }, + InductionVariable::SCEV(_) => todo!(), + }; + + // Propagate bounds to other IVs. 
+ vec![basic].into_iter().chain(other_ivs).collect() +} diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 4a90f698efda8ff5f0c6b1dd4817dfbd615ed370..e3cca1612354fee6b5544264b97a37f1352bbafd 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -4,16 +4,17 @@ pub mod ccp; pub mod crc; pub mod dce; pub mod delete_uncalled; -pub mod device_placement; pub mod editor; pub mod float_collections; pub mod fork_concat_split; pub mod fork_guard_elim; +pub mod fork_transforms; pub mod forkify; pub mod gcm; pub mod gvn; pub mod inline; pub mod interprocedural_sroa; +pub mod ivar; pub mod lift_dc_math; pub mod outline; pub mod phi_elim; @@ -28,16 +29,17 @@ pub use crate::ccp::*; pub use crate::crc::*; pub use crate::dce::*; pub use crate::delete_uncalled::*; -pub use crate::device_placement::*; pub use crate::editor::*; pub use crate::float_collections::*; pub use crate::fork_concat_split::*; pub use crate::fork_guard_elim::*; +pub use crate::fork_transforms::*; pub use crate::forkify::*; pub use crate::gcm::*; pub use crate::gvn::*; pub use crate::inline::*; pub use crate::interprocedural_sroa::*; +pub use crate::ivar::*; pub use crate::lift_dc_math::*; pub use crate::outline::*; pub use crate::phi_elim::*; diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index e59c815da12b505cadc807c4d87e6a2ef913d3fa..8fe978c5c9554fa7d0fd42f480ff724dcdc9cb36 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -4,7 +4,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use hercules_ir::def_use::*; use hercules_ir::dom::*; -use hercules_ir::fork_join_analysis::*; use hercules_ir::ir::*; use hercules_ir::subgraph::*; diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index 2c8209aae002beffca3c8f7ee40d6986fed49fb0..f9f720bef2c00f97facfaf90bfd880bc06beaf5a 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -29,7 +29,7 @@ pub fn infer_parallel_fork(editor: &mut 
FunctionEditor, fork_join_map: &HashMap< /* * Infer parallel reductions consisting of a simple cycle between a Reduce node * and a Write node, where indices of the Write are position indices using the - * ThreadID nodes attached to the corresponding Fork, and data of the Write is + * ThreadID nodes attached to the corresponding Fork, and data of the Write is * not in the Reduce node's cycle. This procedure also adds the ParallelReduce * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the * base Write node also has position indices of the ThreadID of the outer fork. @@ -37,7 +37,11 @@ pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap< * as long as each ThreadID dimension appears in the positional indexing of the * original Write. */ -pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { +pub fn infer_parallel_reduce( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) { for id in editor.node_ids() { let func = editor.func(); if !func.nodes[id.idx()].is_reduce() { @@ -146,11 +150,17 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N * operation's operands must be the Reduce node, and all other operands must * not be in the Reduce node's cycle. 
*/ -pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { - let is_binop_associative = |op| matches!(op, - BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor); - let is_intrinsic_associative = |intrinsic| matches!(intrinsic, - Intrinsic::Max | Intrinsic::Min); +pub fn infer_tight_associative( + editor: &mut FunctionEditor, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) { + let is_binop_associative = |op| { + matches!( + op, + BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor + ) + }; + let is_intrinsic_associative = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); for id in editor.node_ids() { let func = editor.func(); @@ -162,8 +172,8 @@ pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &Hash && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } if ((left == id && !reduce_cycles[&id].contains(&right)) || (right == id && !reduce_cycles[&id].contains(&left))) && - is_binop_associative(op)) || - matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } + is_binop_associative(op)) + || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg))))) { diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index a5df7a7c404e820a92d27842211dd0e3396dae41..85ffd233dad79ca3339525cdf4542493d8c20124 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -1,17 +1,110 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::iter::zip; -use hercules_ir::ir::*; +use bitvec::{order::Lsb0, vec::BitVec}; +use hercules_ir::{ir::*, LoopTree}; use crate::*; +type NodeVec = BitVec<u8, Lsb0>; +pub fn calculate_fork_nodes( + editor: 
&FunctionEditor, + inner_control: &NodeVec, + fork: NodeID, +) -> HashSet<NodeID> { + // Stop on PHIs / reduces outside of loop. + let stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { + let data = &editor.func().nodes[node.idx()]; + + // External Phi + if let Node::Phi { control, data: _ } = data { + if match inner_control.get(control.idx()) { + Some(v) => !*v, // + None => true, // Doesn't exist, must be external + } { + return true; + } + } + // External Reduce + if let Node::Reduce { + control, + init: _, + reduct: _, + } = data + { + if match inner_control.get(control.idx()) { + Some(v) => !*v, // + None => true, // Doesn't exist, must be external + } { + return true; + } + } + + // External Control + if data.is_control() { + return match inner_control.get(node.idx()) { + Some(v) => !*v, // + None => true, // Doesn't exist, must be external + }; + } + // else + return false; + }) + .collect(); + + let reduces: Vec<_> = editor + .node_ids() + .filter(|node| { + let Node::Reduce { control, .. } = editor.func().nodes[node.idx()] else { + return false; + }; + match inner_control.get(control.idx()) { + Some(v) => *v, + None => false, + } + }) + .chain( + editor + .get_users(fork) + .filter(|node| editor.node(node).is_thread_id()), + ) + .collect(); + + let all_users: HashSet<NodeID> = reduces + .clone() + .iter() + .flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) + .chain(reduces.clone()) + .collect(); + + let all_uses: HashSet<_> = reduces + .clone() + .iter() + .flat_map(|phi| walk_all_uses_stop_on(*phi, editor, stop_on.clone())) + .chain(reduces) + .filter(|node| { + // Get rid of nodes in stop_on + !stop_on.contains(node) + }) + .collect(); + + all_users.intersection(&all_uses).cloned().collect() +} /* * Convert forks back into loops right before codegen when a backend is not * lowering a fork-join to vector / parallel code. 
Lowering fork-joins into * sequential loops in LLVM is actually not entirely trivial, so it's easier to * just do this transformation within Hercules IR. */ -pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { + +// FIXME: Only works on fully split fork nests. +pub fn unforkify( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + loop_tree: &LoopTree, +) { let mut zero_cons_id = ConstantID::new(0); let mut one_cons_id = ConstantID::new(0); assert!(editor.edit(|mut edit| { @@ -25,7 +118,16 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No // control insides of the fork-join should become the successor of the true // projection node, and what was the use of the join should become a use of // the new region. - for (fork, join) in fork_join_map { + for l in loop_tree.bottom_up_loops().into_iter().rev() { + if !editor.node(l.0).is_fork() { + continue; + } + + let fork = &l.0; + let join = &fork_join_map[&fork]; + + let fork_nodes = calculate_fork_nodes(editor, l.1, *fork); + let nodes = &editor.func().nodes; let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap(); if factors.len() > 1 { @@ -54,20 +156,43 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No let add_id = NodeID::new(num_nodes + 7); let dc_id = NodeID::new(num_nodes + 8); let neq_id = NodeID::new(num_nodes + 9); - let phi_ids = (num_nodes + 10..num_nodes + 10 + reduces.len()).map(NodeID::new); + + let guard_if_id = NodeID::new(num_nodes + 10); + let guard_join_id = NodeID::new(num_nodes + 11); + let guard_taken_proj_id = NodeID::new(num_nodes + 12); + let guard_skipped_proj_id = NodeID::new(num_nodes + 13); + let guard_cond_id = NodeID::new(num_nodes + 14); + + let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new); + let s = num_nodes + 15 + reduces.len(); + let join_phi_ids = (s..s + reduces.len()).map(NodeID::new); + + let guard_cond = 
Node::Binary { + left: zero_id, + right: dc_id, + op: BinaryOperator::LT, + }; + let guard_if = Node::If { + control: fork_control, + cond: guard_cond_id, + }; + let guard_taken_proj = Node::Projection { + control: guard_if_id, + selection: 1, + }; + let guard_skipped_proj = Node::Projection { + control: guard_if_id, + selection: 0, + }; + let guard_join = Node::Region { + preds: Box::new([guard_skipped_proj_id, proj_exit_id]), + }; let region = Node::Region { - preds: Box::new([ - fork_control, - if join_control == *fork { - proj_back_id - } else { - join_control - }, - ]), + preds: Box::new([guard_taken_proj_id, proj_back_id]), }; let if_node = Node::If { - control: region_id, + control: join_control, cond: neq_id, }; let proj_back = Node::Projection { @@ -92,19 +217,25 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No let dc = Node::DynamicConstant { id: factors[0] }; let neq = Node::Binary { op: BinaryOperator::NE, - left: indvar_id, + left: add_id, right: dc_id, }; - let phis: Vec<_> = reduces + let (phis, join_phis): (Vec<_>, Vec<_>) = reduces .iter() .map(|reduce_id| { let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap(); - Node::Phi { - control: region_id, - data: Box::new([init, reduct]), - } + ( + Node::Phi { + control: region_id, + data: Box::new([init, reduct]), + }, + Node::Phi { + control: guard_join_id, + data: Box::new([init, reduct]), + }, + ) }) - .collect(); + .unzip(); editor.edit(|mut edit| { assert_eq!(edit.add_node(region), region_id); @@ -117,21 +248,41 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No assert_eq!(edit.add_node(add), add_id); assert_eq!(edit.add_node(dc), dc_id); assert_eq!(edit.add_node(neq), neq_id); - for (phi_id, phi) in zip(phi_ids.clone(), phis) { - assert_eq!(edit.add_node(phi), phi_id); + assert_eq!(edit.add_node(guard_if), guard_if_id); + assert_eq!(edit.add_node(guard_join), guard_join_id); + 
assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id); + assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id); + assert_eq!(edit.add_node(guard_cond), guard_cond_id); + + for (phi_id, phi) in zip(phi_ids.clone(), &phis) { + assert_eq!(edit.add_node(phi.clone()), phi_id); + } + for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) { + assert_eq!(edit.add_node(phi.clone()), phi_id); } - edit = edit.replace_all_uses(*fork, proj_back_id)?; - edit = edit.replace_all_uses(*join, proj_exit_id)?; + edit = edit.replace_all_uses(*fork, region_id)?; + edit = edit.replace_all_uses_where(*join, guard_join_id, |usee| *usee != if_id)?; edit.sub_edit(*fork, region_id); edit.sub_edit(*join, if_id); for tid in tids.iter() { edit.sub_edit(*tid, indvar_id); edit = edit.replace_all_uses(*tid, indvar_id)?; } - for (reduce, phi_id) in zip(reduces.iter(), phi_ids) { + for (((reduce, phi_id), phi), join_phi_id) in + zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids) + { edit.sub_edit(*reduce, phi_id); - edit = edit.replace_all_uses(*reduce, phi_id)?; + let Node::Phi { control: _, data } = phi else { + panic!() + }; + edit = edit.replace_all_uses_where(*reduce, join_phi_id, |usee| { + !fork_nodes.contains(usee) + })?; //, |usee| *usee != *reduct)?; + edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| { + fork_nodes.contains(usee) || *usee == data[1] + })?; + edit = edit.delete_node(*reduce)?; } edit = edit.delete_node(*fork)?; @@ -139,9 +290,6 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No for tid in tids { edit = edit.delete_node(tid)?; } - for reduce in reduces { - edit = edit.delete_node(reduce)?; - } Ok(edit) }); diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index aa0d53fe32f855d5a1f9ad689fab44ef330e7a76..7ad48c1c09cc8542d1b521e3d8e12fe271ef1d39 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -1,7 +1,10 @@ +use std::collections::HashMap; +use 
std::collections::HashSet; use std::iter::zip; use hercules_ir::def_use::*; use hercules_ir::ir::*; +use nestify::nest; use crate::*; @@ -376,3 +379,105 @@ pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> boo // may overlap when one indexes a larger sub-value than the other. true } + +pub type DenseNodeMap<T> = Vec<T>; +pub type SparseNodeMap<T> = HashMap<NodeID, T>; + +nest! { +// +#[derive(Clone, Debug)] +pub struct NodeIterator<'a> { + pub direction: + #[derive(Clone, Debug, PartialEq)] + pub enum Direction { + Uses, + Users, + }, + visited: DenseNodeMap<bool>, + stack: Vec<NodeID>, + func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor. + // `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search. + stop_on: HashSet<NodeID>, // Don't add neighbors of these. +} +} + +pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> { + let len = editor.func().nodes.len(); + NodeIterator { + direction: Direction::Uses, + visited: vec![false; len], + stack: vec![node], + func: editor, + stop_on: HashSet::new(), + } +} + +pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> { + let len = editor.func().nodes.len(); + NodeIterator { + direction: Direction::Users, + visited: vec![false; len], + stack: vec![node], + func: editor, + stop_on: HashSet::new(), + } +} + +pub fn walk_all_uses_stop_on<'a>( + node: NodeID, + editor: &'a FunctionEditor<'a>, + stop_on: HashSet<NodeID>, +) -> NodeIterator<'a> { + let len = editor.func().nodes.len(); + let uses = editor.get_uses(node).collect(); + NodeIterator { + direction: Direction::Uses, + visited: vec![false; len], + stack: uses, + func: editor, + stop_on, + } +} + +pub fn walk_all_users_stop_on<'a>( + node: NodeID, + editor: &'a FunctionEditor<'a>, + stop_on: HashSet<NodeID>, +) -> NodeIterator<'a> { + let len = editor.func().nodes.len(); 
+ let users = editor.get_users(node).collect(); + NodeIterator { + direction: Direction::Users, + visited: vec![false; len], + stack: users, + func: editor, + stop_on, + } +} + +/** Stack-based walk over uses or users (per `direction`). Each call pops nodes until it finds an unvisited one, marks it visited, pushes its neighbors unless the node is in `stop_on`, and yields it. Nodes in `stop_on` are still yielded; only their neighbors are not explored. */ +impl<'a> Iterator for NodeIterator<'a> { + type Item = NodeID; + + fn next(&mut self) -> Option<Self::Item> { + while let Some(current) = self.stack.pop() { + if !self.visited[current.idx()] { + self.visited[current.idx()] = true; + + if !self.stop_on.contains(&current) { + if self.direction == Direction::Uses { + for neighbor in self.func.get_uses(current) { + self.stack.push(neighbor) + } + } else { + for neighbor in self.func.get_users(current) { + self.stack.push(neighbor) + } + } + } + + return Some(current); + } + } + None + } +} diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs index f895af867a019dfd23381a4df2d9a02f80a032f8..c15ca97fa4b0730622f28e6cf16f7ab24de7310a 100644 --- a/hercules_samples/matmul/build.rs +++ b/hercules_samples/matmul/build.rs @@ -4,7 +4,7 @@ fn main() { JunoCompiler::new() .ir_in_src("matmul.hir") .unwrap() - //.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + // .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) .schedule_in_src("cpu.sch") .unwrap() .build() diff --git a/hercules_test/hercules_interpreter/Cargo.toml b/hercules_test/hercules_interpreter/Cargo.toml index d41caff805f96f17d173450db40e6ce396faeead..6e02b9b83b8b37de7e7102b2d580fa7bc54aa6da 100644 --- a/hercules_test/hercules_interpreter/Cargo.toml +++ b/hercules_test/hercules_interpreter/Cargo.toml @@ -9,6 +9,9 @@ clap = { version = "*", features = ["derive"] } rand = "*" hercules_ir = { path = "../../hercules_ir" } hercules_opt = { path = "../../hercules_opt" 
} +juno_scheduler = { path = "../../juno_scheduler" } itertools = "*" ordered-float = "*" -derive_more = {version = "*", features = ["from"]} \ No newline at end of file +derive_more = {version = "*", features = ["from"]} +postcard = { version = "*", features = 
["alloc"] } +serde = { version = "*", features = ["derive"] } \ No newline at end of file diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 621260e5a25597e3dedcc279348e0be7c1fb9472..a78330e4f08075be053593b41dba0f412687f5f1 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -1,15 +1,17 @@ - +use std::collections::hash_map::Entry::Occupied; use std::collections::HashMap; use std::panic; -use std::collections::hash_map::Entry::Occupied; use itertools::Itertools; +use std::cmp::{max, min}; use hercules_ir::*; use crate::value; use value::*; +const VERBOSE: bool = false; + /* High level design details / discussion for this: * * This crate includes tools for interpreting a hercules IR module. Execution model / flow is based on @@ -41,8 +43,8 @@ pub struct FunctionContext<'a> { fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>, } -impl <'a> FunctionContext<'a> { - pub fn new ( +impl<'a> FunctionContext<'a> { + pub fn new( control_subgraph: &'a Subgraph, def_use: &'a ImmutableDefUseMap, fork_join_map: &'a HashMap<NodeID, NodeID>, // Map forks -> joins @@ -57,12 +59,47 @@ impl <'a> FunctionContext<'a> { } } -pub fn dyn_const_value(dc: &DynamicConstant, dyn_const_params: &[usize]) -> usize { +// TODO: (@xrouth) I feel like this funcitonality should be provided by the manager that holds and allocates dynamic constants & IDs. 
+pub fn dyn_const_value( + dc: &DynamicConstantID, + dyn_const_values: &[DynamicConstant], + dyn_const_params: &[usize], +) -> usize { + let dc = &dyn_const_values[dc.idx()]; match dc { DynamicConstant::Constant(v) => *v, DynamicConstant::Parameter(v) => dyn_const_params[*v], + DynamicConstant::Add(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + + dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Sub(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + - dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Mul(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + * dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Div(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + / dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Rem(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + % dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Max(a, b) => max( + dyn_const_value(a, dyn_const_values, dyn_const_params), + dyn_const_value(b, dyn_const_values, dyn_const_params), + ), + DynamicConstant::Min(a, b) => min( + dyn_const_value(a, dyn_const_values, dyn_const_params), + dyn_const_value(b, dyn_const_values, dyn_const_params), + ), } } + // Each control token stores a current position, and also a mapping of fork nodes -> thread idx. 
#[derive(Debug, Clone, Eq, PartialEq)] pub struct ControlToken { @@ -78,7 +115,12 @@ pub struct ControlToken { impl ControlToken { pub fn moved_to(&self, next: NodeID) -> ControlToken { - ControlToken { curr: next, prev: self.curr, thread_indicies: self.thread_indicies.clone(), phi_values: self.phi_values.clone() } + ControlToken { + curr: next, + prev: self.curr, + thread_indicies: self.thread_indicies.clone(), + phi_values: self.phi_values.clone(), + } } } impl<'a> FunctionExecutionState<'a> { @@ -89,7 +131,15 @@ impl<'a> FunctionExecutionState<'a> { function_contexts: &'a Vec<FunctionContext>, dynamic_constant_params: Vec<usize>, ) -> Self { - assert_eq!(args.len(), module.functions[function_id.idx()].param_types.len()); + println!( + "param types: {:?}", + module.functions[function_id.idx()].param_types + ); + + assert_eq!( + args.len(), + module.functions[function_id.idx()].param_types.len() + ); FunctionExecutionState { args, @@ -123,15 +173,10 @@ impl<'a> FunctionExecutionState<'a> { } /* Drives PHI values of this region for a control token, returns the next control node. */ - pub fn handle_region( - &mut self, - token: &mut ControlToken, - preds: &Box<[NodeID]>, - ) -> NodeID { - + pub fn handle_region(&mut self, token: &mut ControlToken, preds: &Box<[NodeID]>) -> NodeID { let prev = token.prev; let node = token.curr; - + // Gather PHI nodes for this region node. 
let phis: Vec<NodeID> = self .get_def_use() @@ -178,8 +223,12 @@ impl<'a> FunctionExecutionState<'a> { .try_phi() .expect("PANIC: handle_phi on non-phi node."); let value_node = data[edge]; - // println!("Latching PHI value of node {:?}", value_node.idx()); + let value = self.handle_data(token, value_node); + if VERBOSE { + println!("Latching PHI {:?} to {:?}", phi.idx(), value); + } + (phi, value) } @@ -199,26 +248,32 @@ impl<'a> FunctionExecutionState<'a> { }) .collect(); - for reduction in reduces { - self.handle_reduction(&token, reduction); + for reduction in &reduces { + self.handle_reduction(&token, *reduction); } let thread_values = self.get_thread_factors(&token, join); - // dbg!(thread_values.clone()); // This and_modify doesn't do aynthing?? self.join_counters .entry((thread_values.clone(), join)) .and_modify(|v| *v -= 1); + if VERBOSE { + println!( + "join, thread_values : {:?}, {:?}", + join, + thread_values.clone() + ); + } if *self .join_counters - .get(&(thread_values, join)) + .get(&(thread_values.clone(), join)) .expect("PANIC: join counter not initialized") == 0 { let curr = token.curr; token.prev = curr; - token.thread_indicies.pop(); // Get rid of this thread index. + token.thread_indicies.truncate(thread_values.len()); // Get rid of this thread index. token.curr = self.get_control_subgraph().succs(join).next().unwrap(); Some(token) } else { @@ -236,16 +291,36 @@ impl<'a> FunctionExecutionState<'a> { // Take the top N entries such that it matches the length of the TRF in the control token. // Get the depth of the control token that is requesting this reduction node. 
- let fork_levels = nested_forks.len(); - let len = fork_levels - 1; + // Sum over all thread dimensions in nested forks + let fork_levels: usize = nested_forks + .iter() + .map(|ele| { + self.get_function().nodes[ele.idx()] + .try_fork() + .unwrap() + .1 + .len() + }) + .sum(); + + let len = if nested_forks.is_empty() { + fork_levels - 1 + } else { + fork_levels + - (self.get_function().nodes[nested_forks.first().unwrap().idx()] + .try_fork() + .unwrap() + .1 + .len()) + }; + let mut thread_values = token.thread_indicies.clone(); thread_values.truncate(len); thread_values } - pub fn intialize_reduction(&mut self, token_at_fork: &ControlToken, reduce: NodeID) { - + pub fn initialize_reduction(&mut self, token_at_fork: &ControlToken, reduce: NodeID) { let token = token_at_fork; let (control, init, _) = &self.get_function().nodes[reduce.idx()] @@ -255,12 +330,16 @@ impl<'a> FunctionExecutionState<'a> { let thread_values = self.get_thread_factors(token, *control); let init = self.handle_data(&token, *init); - // Q (@xrouth): It is UB to have the initializer depend on things within the fork-join section? do we check for that? - // A: Should be done in verify (TODO). + + if VERBOSE { + println!( + "reduction {:?} initialized to: {:?} on thread {:?}", + reduce, init, thread_values + ); + } self.reduce_values - .entry((thread_values.clone(), reduce)) - .or_insert(init); + .insert((thread_values.clone(), reduce), init); } // Drive the reduction, this will be invoked for each control token. @@ -271,58 +350,86 @@ impl<'a> FunctionExecutionState<'a> { let thread_values = self.get_thread_factors(token, *control); - // If empty set to default (figure out how to not repeat this check) - // TODO: (Can we do it upon entry to the fork node?) (YES!) 
- let data = self.handle_data(&token, *reduct); - /* - println!( - "reduction write: {:?}, {:?}, {:?}", - thread_values, reduce, data - ); */ + + if VERBOSE { + println!( + "reduction {:?} write of {:?} on thread {:?}", + reduce, data, thread_values + ); + } self.reduce_values.insert((thread_values, reduce), data); } pub fn handle_data(&mut self, token: &ControlToken, node: NodeID) -> InterpreterVal { - // println!("Data Node: {} {:?}", node.idx(), &self.get_function().nodes[node.idx()]); - // Partial borrow complaint. :/ - match &self.module.functions[self.function_id.idx()].nodes[node.idx()]{ - Node::Phi { control: _, data: _ } => (*token + match &self.module.functions[self.function_id.idx()].nodes[node.idx()] { + Node::Phi { + control: _, + data: _, + } => (*token .phi_values .get(&node) - .expect("PANIC: Phi value not latched.")) + .expect(&format!("PANIC: Phi {:?} value not latched.", node))) .clone(), - Node::ThreadID { control } => { + Node::ThreadID { control, dimension } => { // `control` is the fork that drives this node. - let nesting_level = self + + let nested_forks = self .get_fork_join_nest() .get(control) .expect("PANIC: No nesting information for thread index!") - .len(); - let v = token.thread_indicies[nesting_level - 1]; // Might have to -1? + .clone(); + + // Skip forks until we get to this level. + // How many forks are outer? idfk. + let outer_forks: Vec<NodeID> = nested_forks + .iter() + .cloned() + .take_while(|fork| *fork != node) + .collect(); + + let fork_levels: usize = outer_forks + .iter() + .skip(1) + .map(|ele| { + self.get_function().nodes[ele.idx()] + .try_fork() + .unwrap() + .1 + .len() + }) + .sum(); + + // Dimension might need to instead be dimensions - dimension + let v = token.thread_indicies[fork_levels + dimension]; // Might have to -1? 
+ if VERBOSE { + println!( + "node: {:?} gives tid: {:?} for thread: {:?}, dim: {:?}", + node, v, token.thread_indicies, dimension + ); + } InterpreterVal::DynamicConstant((v).into()) } // If we read from a reduction that is the same depth as this thread, we need to write back to it before anyone else reads from it. // This probably isn't the exact condition, but somethign similar. Anyways, we achieve correctness by iterating control nodes recursively. Node::Reduce { control, - init, + init: _, reduct: _, } => { let thread_values = self.get_thread_factors(token, *control); - // println!("reduction read: {:?}, {:?}", thread_values, node); - let entry = self - .reduce_values - .entry((thread_values, node)); - + let entry = self.reduce_values.entry((thread_values.clone(), node)); + let val = match entry { Occupied(v) => v.get().clone(), - std::collections::hash_map::Entry::Vacant(_) => panic!("Reduce has not been initialized!"), + std::collections::hash_map::Entry::Vacant(_) => panic!( + "Ctrl token: {:?}, Reduce {:?} has not been initialized!, TV: {:?}", + token, node, thread_values + ), }; - // println!("value: {:?}", val.clone()); val } Node::Parameter { index } => self.args[*index].clone(), @@ -333,21 +440,22 @@ impl<'a> FunctionExecutionState<'a> { &self.module.constants, &self.module.types, &self.module.dynamic_constants, - &self.dynamic_constant_params + &self.dynamic_constant_params, ) } Node::DynamicConstant { id } => { - let dyn_con = &self.module.dynamic_constants[id.idx()]; - let v = match dyn_con { - DynamicConstant::Constant(v) => v, - DynamicConstant::Parameter(v) => &self.dynamic_constant_params[*v], - }; + let v = dyn_const_value( + id, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ); + // TODO: Figure out what type / semantics are of thread ID and dynamic const. 
- InterpreterVal::DynamicConstant((*v).into()) + InterpreterVal::UnsignedInteger64(v.try_into().expect("too big dyn const!")) } Node::Unary { input, op } => { let val = self.handle_data(token, *input); - InterpreterVal::unary_op(*op, val) + InterpreterVal::unary_op(&self.module.types, *op, val) } Node::Binary { left, right, op } => { let left = self.handle_data(token, *left); @@ -376,25 +484,26 @@ impl<'a> FunctionExecutionState<'a> { } } Node::Call { + control: _, function, dynamic_constants, args, } => { + let args = args + .into_iter() + .map(|arg_node| self.handle_data(token, *arg_node)) + .collect(); - let args = args.into_iter() - .map(|arg_node| self.handle_data(token, *arg_node)) - .collect(); - - - let dynamic_constant_params = dynamic_constants.into_iter() - .map(|id| { - let dyn_con = &self.module.dynamic_constants[id.idx()]; - let v = match dyn_con { - DynamicConstant::Constant(v) => *v, - DynamicConstant::Parameter(v) => self.dynamic_constant_params[*v], - }; - v - }).collect_vec(); + let dynamic_constant_params = dynamic_constants + .into_iter() + .map(|id| { + dyn_const_value( + id, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) + .collect_vec(); let mut state = FunctionExecutionState::new( args, @@ -408,8 +517,19 @@ impl<'a> FunctionExecutionState<'a> { } Node::Read { collect, indices } => { let collection = self.handle_data(token, *collect); + if let InterpreterVal::Undef(_) = collection { + collection + } else { + let result = self.handle_read(token, collection.clone(), indices); - self.handle_read(token, collection, indices) + if VERBOSE { + println!( + "{:?} read value : {:?} from {:?}, {:?} at index {:?}", + node, result, collect, collection, indices + ); + } + result + } } Node::Write { collect, @@ -417,9 +537,14 @@ impl<'a> FunctionExecutionState<'a> { indices, } => { let collection = self.handle_data(token, *collect); - let data = self.handle_data(token, *data); - self.handle_write(token, collection, data, 
indices) + if let InterpreterVal::Undef(_) = collection { + collection + } else { + let data = self.handle_data(token, *data); + self.handle_write(token, collection, data, indices) + } } + Node::Undef { ty } => InterpreterVal::Undef(*ty), _ => todo!(), } } @@ -431,13 +556,19 @@ impl<'a> FunctionExecutionState<'a> { data: InterpreterVal, indices: &[Index], ) -> InterpreterVal { - let index = &indices[0]; - // TODO (@xrouth): Recurse on writes correctly - let val = match index { - Index::Field(_) => todo!(), - Index::Variant(_) => todo!(), - Index::Position(array_indices) => { + let val = match indices.first() { + Some(Index::Field(idx)) => { + if let InterpreterVal::Product(type_id, mut vals) = collection { + vals[*idx] = data; + InterpreterVal::Product(type_id, vals) + } else { + panic!("PANIC: Field index on not a product type") + } + } + None => collection, + Some(Index::Variant(_)) => todo!(), + Some(Index::Position(array_indices)) => { // Arrays also have inner indices... // Recover dimensional data from types. 
let array_indices: Vec<_> = array_indices @@ -451,12 +582,21 @@ impl<'a> FunctionExecutionState<'a> { .try_extents() .expect("PANIC: wrong type for array") .into_iter() - .map(|extent| dyn_const_value(&self.module.dynamic_constants[extent.idx()], &self.dynamic_constant_params)) + .map(|extent| { + dyn_const_value( + extent, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) .collect(); let idx = InterpreterVal::array_idx(&extents, &array_indices); - //println!("idx: {:?}", idx); - vals[idx] = data; - InterpreterVal::Array(type_id, vals) + if idx >= vals.len() { + InterpreterVal::Undef(type_id) + } else { + vals[idx] = data; + InterpreterVal::Array(type_id, vals) + } } else { panic!("PANIC: Position index on not an array") } @@ -484,6 +624,10 @@ impl<'a> FunctionExecutionState<'a> { .map(|idx| self.handle_data(token, *idx).as_usize()) .collect(); + if VERBOSE { + println!("read at rt indicies: {:?}", array_indices); + } + // TODO: Implemenet . try_array() and other try_conversions on the InterpreterVal type if let InterpreterVal::Array(type_id, vals) = collection { // TODO: Make this its own funciton to reuse w/ array_size @@ -491,9 +635,20 @@ impl<'a> FunctionExecutionState<'a> { .try_extents() .expect("PANIC: wrong type for array") .into_iter() - .map(|extent| dyn_const_value(&self.module.dynamic_constants[extent.idx()], &self.dynamic_constant_params)) + .map(|extent| { + dyn_const_value( + extent, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) .collect(); - vals[InterpreterVal::array_idx(&extents, &array_indices)].clone() + // FIXME: This type may be wrong. 
+ let ret = vals + .get(InterpreterVal::array_idx(&extents, &array_indices)) + .unwrap_or(&InterpreterVal::Undef(type_id)) + .clone(); + ret } else { panic!("PANIC: Position index on not an array") } @@ -521,32 +676,40 @@ impl<'a> FunctionExecutionState<'a> { let mut live_tokens: Vec<ControlToken> = Vec::new(); live_tokens.push(start_token); - // To do reduction nodes correctly we have to traverse control tokens in a depth-first fashion (i.e immediately handle spawned threads). 'outer: loop { - let mut ctrl_token = live_tokens.pop().expect("PANIC: Interpreter ran out of control tokens without returning."); - - /* println!( - "\n\nNew Token at: Control State: {} threads: {:?}, {:?}", - ctrl_token.curr.idx(), - ctrl_token.thread_indicies.clone(), - &self.get_function().nodes[ctrl_token.curr.idx()] - ); */ + let mut ctrl_token = live_tokens + .pop() + .expect("PANIC: Interpreter ran out of control tokens without returning."); + + // TODO: (@xrouth): Enable this + PHI latch logging wi/ a simple debug flag. + // Tracking PHI vals and control state is very useful for debugging. + + if VERBOSE { + println!( + "control token {} {}", + ctrl_token.curr.idx(), + &self.get_function().nodes[ctrl_token.curr.idx()].lower_case_name() + ); + } // TODO: Rust is annoying and can't recognize that this is a partial borrow. - // Can't partial borrow, so need a clone. + // Can't partial borrow, so need a clone. 
let node = &self.get_function().nodes[ctrl_token.curr.idx()].clone(); let new_tokens = match node { Node::Start => { - let next: NodeID = self.get_control_subgraph().succs(ctrl_token.curr).next().unwrap(); + let next: NodeID = self + .get_control_subgraph() + .succs(ctrl_token.curr) + .next() + .unwrap(); let ctrl_token = ctrl_token.moved_to(next); - + vec![ctrl_token] } Node::Region { preds } => { - - // Updates + // Updates let next = self.handle_region(&mut ctrl_token, &preds); let ctrl_token = ctrl_token.moved_to(next); @@ -558,6 +721,7 @@ impl<'a> FunctionExecutionState<'a> { // Convert condition to usize let cond: usize = match cond { InterpreterVal::Boolean(v) => v.into(), + InterpreterVal::Undef(_) => panic!("PANIC: Undef reached IF"), _ => panic!("PANIC: Invalid condition for IF, please typecheck."), }; @@ -576,7 +740,11 @@ impl<'a> FunctionExecutionState<'a> { vec![ctrl_token] } Node::Projection { .. } => { - let next: NodeID = self.get_control_subgraph().succs(ctrl_token.curr).next().unwrap(); + let next: NodeID = self + .get_control_subgraph() + .succs(ctrl_token.curr) + .next() + .unwrap(); let ctrl_token = ctrl_token.moved_to(next); @@ -584,28 +752,52 @@ impl<'a> FunctionExecutionState<'a> { } Node::Match { control: _, sum: _ } => todo!(), - Node::Fork { control: _, factor } => { + Node::Fork { + control: _, + factors, + } => { let fork = ctrl_token.curr; - let dyn_con = &self.module.dynamic_constants[factor.idx()]; + // if factors.len() > 1 { + // panic!("multi-dimensional forks unimplemented") + // } + + let factors = factors + .iter() + .map(|f| { + dyn_const_value( + &f, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) + .rev(); - let thread_factor = match dyn_con { - DynamicConstant::Constant(v) => v, - DynamicConstant::Parameter(v) => &self.dynamic_constant_params[*v], - }.clone(); + let n_tokens: usize = factors.clone().product(); - // Update control token - let next = 
self.get_control_subgraph().succs(ctrl_token.curr).nth(0).unwrap(); + // Update control token + let next = self + .get_control_subgraph() + .succs(ctrl_token.curr) + .nth(0) + .unwrap(); let ctrl_token = ctrl_token.moved_to(next); - let mut tokens_to_add = Vec::with_capacity(thread_factor); + let mut tokens_to_add = Vec::with_capacity(n_tokens); - assert_ne!(thread_factor, 0); + assert_ne!(n_tokens, 0); // Token is at its correct sontrol succesor already. + // Add the new thread index. - for i in 0..(thread_factor) { + let num_outer_dims = ctrl_token.thread_indicies.len(); + for i in 0..n_tokens { + let mut temp = i; let mut new_token = ctrl_token.clone(); // Copy map, curr, prev, etc. - new_token.thread_indicies.push(i); // Stack of thread indicies + + for (_, dim) in factors.clone().enumerate().rev() { + new_token.thread_indicies.insert(num_outer_dims, temp % dim); // Stack of thread indicies + temp /= dim; + } tokens_to_add.push(new_token); } @@ -630,15 +822,21 @@ impl<'a> FunctionExecutionState<'a> { } }) .collect(); - + for reduction in reduces { // TODO: Is this the correct reduction? 
- self.intialize_reduction(&ctrl_token, reduction); + self.initialize_reduction(&ctrl_token, reduction); } + if VERBOSE { + println!( + "tf, fork, join, n_tokens: {:?}, {:?}, {:?}, {:?}", + thread_factors, fork, join, n_tokens + ); + } + self.join_counters.insert((thread_factors, join), n_tokens); - self.join_counters.insert((thread_factors, join), thread_factor); - + tokens_to_add.reverse(); tokens_to_add } Node::Join { control: _ } => { @@ -653,7 +851,6 @@ impl<'a> FunctionExecutionState<'a> { } Node::Return { control: _, data } => { let result = self.handle_data(&ctrl_token, *data); - // println!("result = {:?}", result); break 'outer result; } _ => { @@ -664,9 +861,6 @@ impl<'a> FunctionExecutionState<'a> { for i in new_tokens { live_tokens.push(i); } - } } } - - diff --git a/hercules_test/hercules_interpreter/src/lib.rs b/hercules_test/hercules_interpreter/src/lib.rs index 89fae51aded9271c84219f2394290ebbcec4123b..66f8c4eac35baa0845464541c1eece8335c53430 100644 --- a/hercules_test/hercules_interpreter/src/lib.rs +++ b/hercules_test/hercules_interpreter/src/lib.rs @@ -2,16 +2,23 @@ pub mod interpreter; pub mod value; use std::fs::File; +use std::io::Read; use hercules_ir::Module; use hercules_ir::TypeID; +use hercules_ir::ID; + +pub use juno_scheduler::PassManager; pub use crate::interpreter::*; pub use crate::value::*; -// Get a vec of -pub fn into_interp_val(module: &Module, wrapper: InterpreterWrapper, target_ty_id: TypeID) -> InterpreterVal -{ +// Get a vec of +pub fn into_interp_val( + module: &Module, + wrapper: InterpreterWrapper, + target_ty_id: TypeID, +) -> InterpreterVal { match wrapper { InterpreterWrapper::Boolean(v) => InterpreterVal::Boolean(v), InterpreterWrapper::Integer8(v) => InterpreterVal::Integer8(v), @@ -29,31 +36,25 @@ pub fn into_interp_val(module: &Module, wrapper: InterpreterWrapper, target_ty_i InterpreterWrapper::Array(array) => { let ty = &module.types[target_ty_id.idx()]; - let ele_type = ty.try_element_type().expect("PANIC: Type 
ID"); - // unwrap -> map to rust type, check - + ty.try_element_type() + .expect("PANIC: Invalid parameter type"); + let mut values = vec![]; - + for i in 0..array.len() { values.push(into_interp_val(module, array[i].clone(), TypeID::new(0))); } - + InterpreterVal::Array(target_ty_id, values.into_boxed_slice()) } } -} - -pub fn array_from_interp_val<T: Clone>(module: &Module, interp_val: InterpreterVal) -> Vec<T> - where value::InterpreterVal: Into<T> -{ - vec![] } // Recursively turns rt args into interpreter wrappers. #[macro_export] macro_rules! parse_rt_args { ($arg:expr) => { - { + { let mut values: Vec<InterpreterWrapper> = vec![]; @@ -63,7 +64,7 @@ macro_rules! parse_rt_args { } }; ( $arg:expr, $($tail_args:expr), +) => { - { + { let mut values: Vec<InterpreterWrapper> = vec![]; values.push($arg.into()); @@ -85,9 +86,16 @@ pub fn parse_file(path: &str) -> Module { module } +pub fn parse_module_from_hbin(path: &str) -> hercules_ir::ir::Module { + let mut file = File::open(path).expect("PANIC: Unable to open input file."); + let mut buffer = vec![]; + file.read_to_end(&mut buffer).unwrap(); + postcard::from_bytes(&buffer).unwrap() +} + #[macro_export] macro_rules! interp_module { - ($module:ident, $dynamic_constants:expr, $($args:expr), *) => { + ($module:ident, $entry_func:expr, $dynamic_constants:expr, $($args:expr), *) => { { //let hir_file = String::from($path); @@ -96,27 +104,22 @@ macro_rules! 
interp_module { let dynamic_constants: Vec<usize> = $dynamic_constants.into(); let module = $module.clone(); //parse_file(hir_file); - let mut pm = hercules_opt::pass::PassManager::new(module); - pm.add_pass(hercules_opt::pass::Pass::Verify); - - pm.run_passes(); + let mut pm = PassManager::new(module); + pm.make_typing(); pm.make_reverse_postorders(); pm.make_doms(); pm.make_fork_join_maps(); pm.make_fork_join_nests(); pm.make_control_subgraphs(); - pm.make_plans(); let reverse_postorders = pm.reverse_postorders.as_ref().unwrap().clone(); let doms = pm.doms.as_ref().unwrap().clone(); let fork_join_maps = pm.fork_join_maps.as_ref().unwrap().clone(); let fork_join_nests = pm.fork_join_nests.as_ref().unwrap().clone(); - let plans = pm.plans.as_ref().unwrap().clone(); let control_subgraphs = pm.control_subgraphs.as_ref().unwrap().clone(); let def_uses = pm.def_uses.as_ref().unwrap().clone(); let module = pm.get_module(); - let mut function_contexts = vec![]; for idx in 0..module.functions.len() { @@ -129,7 +132,7 @@ macro_rules! interp_module { function_contexts.push(context); } - let function_number = 0; + let function_number = $entry_func; let parameter_types = &module.functions[function_number].param_types; @@ -148,27 +151,19 @@ macro_rules! interp_module { }; } - #[macro_export] macro_rules! 
interp_file_with_passes { ($path:literal, $dynamic_constants:expr, $passes:expr, $($args:expr), *) => { { let module = parse_file($path); - - let result_before = interp_module!(module, $dynamic_constants, $($args), *); - - let mut pm = hercules_opt::pass::PassManager::new(module.clone()); - for pass in $passes { - pm.add_pass(pass); - } + let result_before = interp_module!(module, $dynamic_constants, $($args), *); - pm.run_passes(); + let module = run_schedule_on_hercules(module, None).unwrap(); - let module = pm.get_module(); - let result_after = interp_module!(module, $dynamic_constants, $($args), *); + let result_after = interp_module!(module, $dynamic_constants, $($args), *); assert_eq!(result_after, result_before); } }; -} \ No newline at end of file +} diff --git a/hercules_test/hercules_interpreter/src/main.rs b/hercules_test/hercules_interpreter/src/main.rs deleted file mode 100644 index 5db31cd730fe802dd9ccbf1b8e0d603c736fb196..0000000000000000000000000000000000000000 --- a/hercules_test/hercules_interpreter/src/main.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::fs::File; -use std::io::prelude::*; - -use clap::Parser; - -use hercules_ir::*; - -use hercules_interpreter::interpreter::*; -use hercules_interpreter::*; -use hercules_interpreter::value; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - hir_file: String, - - #[arg(short, long, default_value_t = String::new())] - output: String, -} - -fn main() { - let args = Args::parse(); - let module = parse_file(&args.hir_file); - let ret_val = interp_module!(module, [2, 3, 4], 1, 3); - - println!("ret val: {:?}", ret_val); -} - diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index e032bd5b711a713c55c6cd911c99e327ff1d2e0b..53911e05c2333a0e9b30c5bfdacb854f8409f692 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -29,6 +29,7 @@ pub 
enum InterpreterVal { Float32(ordered_float::OrderedFloat<f32>), Float64(ordered_float::OrderedFloat<f64>), + Undef(TypeID), Product(TypeID, Box<[InterpreterVal]>), Summation(TypeID, u32, Box<[InterpreterVal]>), Array(TypeID, Box<[InterpreterVal]>), // TypeID of the array Type (not the element type) @@ -76,6 +77,7 @@ where { fn from(value: Vec<T>) -> Self { let mut values = vec![]; + values.reserve(value.len()); for i in 0..value.len() { values.push(value[i].clone().into()); } @@ -89,8 +91,23 @@ where { fn from(value: &[T]) -> Self { let mut values = vec![]; + values.reserve(value.len()); for i in 0..value.len() { - values[i] = value[i].clone().into() + values.push(value[i].clone().into()); + } + InterpreterWrapper::Array(values.into_boxed_slice()) + } +} + +impl<T> From<Box<[T]>> for InterpreterWrapper +where + T: Into<InterpreterWrapper> + Clone, +{ + fn from(value: Box<[T]>) -> Self { + let mut values = vec![]; + values.reserve(value.len()); + for i in 0..value.len() { + values.push(value[i].clone().into()); } InterpreterWrapper::Array(values.into_boxed_slice()) } @@ -138,7 +155,10 @@ impl<'a> InterpreterVal { Constant::Float32(v) => Self::Float32(v), Constant::Float64(v) => Self::Float64(v), - Constant::Product(_, _) => todo!(), + Constant::Product(ref type_id, ref constant_ids) => { + // Self::Product((), ()) + todo!() + } Constant::Summation(_, _, _) => todo!(), Constant::Array(type_id) => { // TODO: This is currently only implemented for arrays of primitive types, implement zero initializers for other types. 
@@ -149,7 +169,7 @@ impl<'a> InterpreterVal { .expect("PANIC: wrong type for array") .into_iter() .map(|extent| { - dyn_const_value(&dynamic_constants[extent.idx()], &dynamic_constant_params) + dyn_const_value(extent, &dynamic_constants, &dynamic_constant_params) }) .collect(); @@ -193,6 +213,14 @@ impl<'a> InterpreterVal { left: InterpreterVal, right: InterpreterVal, ) -> InterpreterVal { + // If either are undef, propogate undef + if let InterpreterVal::Undef(v) = left { + return InterpreterVal::Undef(v); + } + if let InterpreterVal::Undef(v) = right { + return InterpreterVal::Undef(v); + } + // Do some type conversion first. let left = match left { InterpreterVal::DynamicConstant(v) => match right { @@ -213,7 +241,8 @@ impl<'a> InterpreterVal { InterpreterVal::UnsignedInteger64(v.try_into().unwrap()) } InterpreterVal::DynamicConstant(_) => { - panic!("PANIC: Some math on dynamic constants is unimplemented") + InterpreterVal::UnsignedInteger64(v.try_into().unwrap()) + //panic!("PANIC: Some math on dynamic constants is unimplemented") } // InterpreterVal::ThreadID(_) => InterpreterVal::Boolean(v), _ => panic!("PANIC: Some math on dynamic constants is unimplemented"), @@ -241,7 +270,8 @@ impl<'a> InterpreterVal { InterpreterVal::UnsignedInteger64(v.try_into().unwrap()) } InterpreterVal::DynamicConstant(_) => { - panic!("PANIC: Some math on dynamic constants is unimplemented") + InterpreterVal::UnsignedInteger64(v.try_into().unwrap()) + //panic!("PANIC: Some math on dynamic constants is unimplemented") } _ => panic!("PANIC: Some math on dynamic constants is unimplemented"), }, @@ -772,7 +802,7 @@ impl<'a> InterpreterVal { } } - pub fn unary_op(op: UnaryOperator, val: InterpreterVal) -> Self { + pub fn unary_op(types: &Vec<Type>, op: UnaryOperator, val: InterpreterVal) -> Self { match (op, val) { (UnaryOperator::Not, Self::Boolean(val)) => Self::Boolean(!val), (UnaryOperator::Not, Self::Integer8(val)) => Self::Integer8(!val), @@ -789,7 +819,28 @@ impl<'a> 
InterpreterVal { (UnaryOperator::Neg, Self::Integer64(val)) => Self::Integer64(-val), (UnaryOperator::Neg, Self::Float32(val)) => Self::Float32(-val), (UnaryOperator::Neg, Self::Float64(val)) => Self::Float64(-val), - (UnaryOperator::Cast(_), _) => todo!("Write cast impl"), + (UnaryOperator::Cast(type_id), val) => { + // FIXME: This probably doesn't work. + let val = val.as_i128(); + match types[type_id.idx()] { + Type::Control => todo!(), + Type::Boolean => todo!(), + Type::Integer8 => todo!(), + Type::Integer16 => todo!(), + Type::Integer32 => Self::Integer32(val.try_into().unwrap()), + Type::Integer64 => todo!(), + Type::UnsignedInteger8 => todo!(), + Type::UnsignedInteger16 => todo!(), + Type::UnsignedInteger32 => todo!(), + Type::UnsignedInteger64 => Self::UnsignedInteger64(val.try_into().unwrap()), + Type::Float32 => todo!(), + Type::Float64 => todo!(), + Type::Product(_) => todo!(), + Type::Summation(_) => todo!(), + Type::Array(type_id, _) => todo!(), + } + } + (_, Self::Undef(v)) => InterpreterVal::Undef(v), _ => panic!("Unsupported combination of unary operation and constant value. 
Did typechecking succeed?") } } @@ -811,6 +862,23 @@ impl<'a> InterpreterVal { } } + pub fn as_i128(&self) -> i128 { + match *self { + InterpreterVal::Boolean(v) => v.try_into().unwrap(), + InterpreterVal::Integer8(v) => v.try_into().unwrap(), + InterpreterVal::Integer16(v) => v.try_into().unwrap(), + InterpreterVal::Integer32(v) => v.try_into().unwrap(), + InterpreterVal::Integer64(v) => v.try_into().unwrap(), + InterpreterVal::UnsignedInteger8(v) => v.try_into().unwrap(), + InterpreterVal::UnsignedInteger16(v) => v.try_into().unwrap(), + InterpreterVal::UnsignedInteger32(v) => v.try_into().unwrap(), + InterpreterVal::UnsignedInteger64(v) => v.try_into().unwrap(), + InterpreterVal::DynamicConstant(v) => v.try_into().unwrap(), + InterpreterVal::ThreadID(v) => v.try_into().unwrap(), + _ => panic!("PANIC: Value not castable to usize"), + } + } + // Defines row major / how we layout our arrays pub fn array_idx(extents: &[usize], indices: &[usize]) -> usize { let a = extents diff --git a/hercules_test/hercules_tests/Cargo.toml b/hercules_test/hercules_tests/Cargo.toml index 9bd6fe7b7ecdf1537af7e1c3dab430b84a3b7601..8c140e75145be2ca12626b4e4d3c8123a578e329 100644 --- a/hercules_test/hercules_tests/Cargo.toml +++ b/hercules_test/hercules_tests/Cargo.toml @@ -9,6 +9,7 @@ clap = { version = "*", features = ["derive"] } rand = "*" hercules_ir = { path = "../../hercules_ir" } hercules_opt = { path = "../../hercules_opt" } +juno_scheduler = { path = "../../juno_scheduler" } hercules_interpreter = { path = "../hercules_interpreter" } itertools = "*" ordered-float = "*" diff --git a/hercules_test/hercules_tests/tests/fork_transform_tests.rs b/hercules_test/hercules_tests/tests/fork_transform_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..3799ca0ac7e8abe9907603269692fbd438c4e33d --- /dev/null +++ b/hercules_test/hercules_tests/tests/fork_transform_tests.rs @@ -0,0 +1,99 @@ +use std::{env, fs::File, io::Read, path::Path}; + +use 
hercules_interpreter::*; +use hercules_ir::ID; +use juno_scheduler::ir::*; + +use juno_scheduler::pass; +use juno_scheduler::{default_schedule, run_schedule_on_hercules}; +use rand::Rng; + +#[test] +fn fission_simple1() { + let module = parse_file("../test_inputs/fork_transforms/fork_fission/simple1.hir"); + let dyn_consts = [10]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let sched = Some(default_schedule![ + Verify, //Xdot, + Unforkify, //Xdot, + DCE, Verify, + ]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + println!("result: {:?}", result_2); + assert_eq!(result_1, result_2) +} + +// #[test] +// fn fission_simple2() { +// let module = parse_file("../test_inputs/fork_transforms/fork_fission/simple2.hir"); +// let dyn_consts = [10]; +// let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. +// let result_1 = interp_module!(module, 0, dyn_consts, 2); + +// println!("result: {:?}", result_1); + +// let sched: Option<ScheduleStmt> = Some(default_schedule![ +// Verify, +// ForkFission, +// DCE, +// Verify, +// ]); + +// let module = run_schedule_on_hercules(module, sched).unwrap(); +// let result_2 = interp_module!(module, 0, dyn_consts, 2); +// println!("result: {:?}", result_2); +// assert_eq!(result_1, result_2) +// } + +// #[ignore] // Wait +// #[test] +// fn fission_tricky() { +// // This either crashes or gives wrong result depending on the order which reduces are observed in. +// let module = parse_file("../test_inputs/fork_transforms/fork_fission/tricky.hir"); +// let dyn_consts = [10]; +// let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
+// let result_1 = interp_module!(module, 0, dyn_consts, 2); + +// println!("result: {:?}", result_1); + +// let sched: Option<ScheduleStmt> = Some(default_schedule![ +// Verify, +// ForkFission, +// DCE, +// Verify, +// ]); + +// let module = run_schedule_on_hercules(module, sched).unwrap(); +// let result_2 = interp_module!(module, 0, dyn_consts, 2); +// println!("result: {:?}", result_2); +// assert_eq!(result_1, result_2) +// } + +// #[ignore] +// #[test] +// fn inner_loop() { +// let module = parse_file("../test_inputs/fork_transforms/fork_fission/inner_loop.hir"); +// let dyn_consts = [10, 20]; +// let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. +// let result_1 = interp_module!(module, 0, dyn_consts, 2); + +// println!("result: {:?}", result_1); + +// let sched: Option<ScheduleStmt> = Some(default_schedule![ +// Verify, +// ForkFission, +// DCE, +// Verify, +// ]); + +// let module = run_schedule_on_hercules(module, sched).unwrap(); +// let result_2 = interp_module!(module, 0, dyn_consts, 2); +// println!("result: {:?}", result_2); +// assert_eq!(result_1, result_2) +// } diff --git a/hercules_test/hercules_tests/tests/forkify_tests.rs b/hercules_test/hercules_tests/tests/forkify_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..8ba8e1354ec8f1c080e2d8b556fe427493af75ae --- /dev/null +++ b/hercules_test/hercules_tests/tests/forkify_tests.rs @@ -0,0 +1,294 @@ +use std::{env, fs::File, io::Read, path::Path}; + +use hercules_interpreter::*; +use hercules_ir::ID; + +use hercules_interpreter::*; +use juno_scheduler::ir::*; +use juno_scheduler::pass; + +use juno_scheduler::{default_schedule, run_schedule_on_hercules}; +use rand::Rng; + +#[test] +#[ignore] +fn inner_fork_chain() { + let module = parse_file("../test_inputs/forkify/inner_fork_chain.hir"); + let dyn_consts = [10]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
+ // let result_1 = interp_module!(module, 0, dyn_consts, 2); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, PhiElim, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2); + println!("result: {:?}", result_2); + //assert_eq!(result_1, result_2) +} + +#[test] +fn loop_simple_iv() { + let module = parse_file("../test_inputs/forkify/loop_simple_iv.hir"); + let dyn_consts = [10]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2); + println!("result: {:?}", result_2); + assert_eq!(result_1, result_2) +} + +#[test] +#[ignore] +fn merged_phi_cycle() { + let module = parse_file("../test_inputs/forkify/merged_phi_cycle.hir"); + let dyn_consts = [10]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2); + println!("result: {:?}", result_2); + assert_eq!(result_1, result_2) +} + +#[test] +fn split_phi_cycle() { + let module = parse_file("../test_inputs/forkify/split_phi_cycle.hir"); + let dyn_consts = [10]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
+ let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2); + println!("result: {:?}", result_2); + assert_eq!(result_1, result_2) +} + +#[test] +fn loop_sum() { + let module = parse_file("../test_inputs/forkify/loop_sum.hir"); + let dyn_consts = [20]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + assert_eq!(result_1, result_2); + println!("{:?}, {:?}", result_1, result_2); +} + +#[test] +fn loop_tid_sum() { + let module = parse_file("../test_inputs/forkify/loop_tid_sum.hir"); + let dyn_consts = [20]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + assert_eq!(result_1, result_2); + println!("{:?}, {:?}", result_1, result_2); +} + +#[test] +fn loop_array_sum() { + let module = parse_file("../test_inputs/forkify/loop_array_sum.hir"); + let len = 5; + let dyn_consts = [len]; + let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. 
+ let result_1 = interp_module!(module, 0, dyn_consts, params.clone()); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, params); + assert_eq!(result_1, result_2); + println!("{:?}, {:?}", result_1, result_2); +} + +/** Nested loop 2 is 2 nested loops with different dyn var parameter dimensions. + * It is a add of 1 for each iteration, so the result should be dim1 x dim2 + * The loop PHIs are structured such that on every outer iteration, inner loop increment is set to the running sum, + * Notice how there is no outer_var_inc. + * + * The alternative, seen in nested_loop1, is to intiailize the inner loop to 0 every time, and track + * the outer sum more separaetly. + * + * Idk what im yapping about. +*/ +#[test] +fn nested_loop2() { + let module = parse_file("../test_inputs/forkify/nested_loop2.hir"); + let len = 5; + let dyn_consts = [5, 6]; + let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + assert_eq!(result_1, result_2); +} + +#[test] +fn super_nested_loop() { + let module = parse_file("../test_inputs/forkify/super_nested_loop.hir"); + let len = 5; + let dyn_consts = [5, 10, 15]; + let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2); + assert_eq!(result_1, result_2); +} + +/** + * Tests forkify on a loop where there is control in between the continue projection + * and the header. 
aka control *after* the `loop condition / guard`. This should forkify. + */ +#[test] +fn control_after_condition() { + let module = parse_file("../test_inputs/forkify/control_after_condition.hir"); + + let size = 10; + let dyn_consts = [size]; + let mut vec = vec![0; size]; + let mut rng = rand::thread_rng(); + + for x in vec.iter_mut() { + *x = rng.gen::<i32>() / 100; + } + + let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, vec); + assert_eq!(result_1, result_2); +} + +/** + * Tests forkify on a loop where there is control before the loop condition, so in between the header + * and the loop condition. This should not forkify. + * + * This example is bugged, it reads out of bounds even before forkify. + */ +#[ignore] +#[test] +fn control_before_condition() { + let module = parse_file("../test_inputs/forkify/control_before_condition.hir"); + + let size = 11; + let dyn_consts = [size - 1]; + let mut vec = vec![0; size]; + let mut rng = rand::thread_rng(); + + for x in vec.iter_mut() { + *x = rng.gen::<i32>() / 100; + } + + let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, vec); + assert_eq!(result_1, result_2); +} + +#[test] +fn nested_tid_sum() { + let module = parse_file("../test_inputs/forkify/nested_tid_sum.hir"); + let len = 5; + let dyn_consts = [5, 6]; + let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. 
+ let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + assert_eq!(result_1, result_2); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_3 = interp_module!(module, 0, dyn_consts, 2); + + println!("{:?}, {:?}, {:?}", result_1, result_2, result_3); +} + +#[test] +fn nested_tid_sum_2() { + let module = parse_file("../test_inputs/forkify/nested_tid_sum_2.hir"); + let len = 5; + let dyn_consts = [5, 6]; + let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + assert_eq!(result_1, result_2); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_3 = interp_module!(module, 0, dyn_consts, 2); + + println!("{:?}, {:?}, {:?}", result_1, result_2, result_3); +} + +/** Tests weird control in outer loop for possible 2d fork-join pair. */ +#[test] +fn inner_fork_complex() { + let module = parse_file("../test_inputs/forkify/inner_fork_complex.hir"); + let dyn_consts = [5, 6]; + let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. 
+ let result_1 = interp_module!(module, 0, dyn_consts, 10); + + println!("result: {:?}", result_1); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 10); + assert_eq!(result_1, result_2); + println!("{:?}, {:?}", result_1, result_2); +} diff --git a/hercules_test/hercules_tests/tests/interpreter_tests.rs b/hercules_test/hercules_tests/tests/interpreter_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..a779c70b263b64d5088a4c2dfc23fcd402951b0b --- /dev/null +++ b/hercules_test/hercules_tests/tests/interpreter_tests.rs @@ -0,0 +1,63 @@ +use std::env; + +use hercules_interpreter::*; +use hercules_interpreter::*; +use hercules_ir::ID; +use juno_scheduler::ir::*; +use juno_scheduler::pass; + +use juno_scheduler::{default_schedule, run_schedule_on_hercules}; +use rand::Rng; + +#[test] +fn twodeefork() { + let module = parse_file("../test_inputs/2d_fork.hir"); + let d1 = 2; + let d2 = 3; + let dyn_consts = [d1, d2]; + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + let sched = Some(default_schedule![ + Verify, ForkSplit, //Xdot, + Unforkify, //Xdot, + DCE, Verify, + ]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + + let res = (d1 as i32 * d2 as i32); + let result_2: InterpreterWrapper = res.into(); + println!("result: {:?}", result_1); // Should be d1 * d2. 
+} + +#[test] +fn threedee() { + let module = parse_file("../test_inputs/3d_fork.hir"); + let d1 = 2; + let d2 = 3; + let d3 = 5; + let dyn_consts = [d1, d2, 5]; + let result_1 = interp_module!(module, 0, dyn_consts, 2); + + let sched = Some(default_schedule![ + Verify, ForkSplit, //Xdot, + Unforkify, //Xdot, + DCE, Verify, + ]); + + let module = run_schedule_on_hercules(module, sched).unwrap(); + let result_2 = interp_module!(module, 0, dyn_consts, 2); + + let res = (d1 as i32 * d2 as i32 * d3 as i32); + let result_2: InterpreterWrapper = res.into(); + println!("result: {:?}", result_1); // Should be d1 * d2. +} + +#[test] +fn fivedeefork() { + let module = parse_file("../test_inputs/5d_fork.hir"); + let dyn_consts = [1, 2, 3, 4, 5]; + let result_1 = interp_module!(module, 0, dyn_consts, 2); + println!("result: {:?}", result_1); // Should be 1 * 2 * 3 * 4 * 5; +} diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..5832a161a18f18ea43860c1c5d6364385d0f187f --- /dev/null +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -0,0 +1,436 @@ +use std::{env, fs::File, io::Read, path::Path}; + +use hercules_interpreter::*; +use hercules_ir::ID; +use juno_scheduler::ir::*; +use juno_scheduler::pass; + +use juno_scheduler::{default_schedule, run_schedule_on_hercules}; +use rand::random; +use rand::Rng; + +// Tests canonicalization + +#[ignore] +#[test] +fn loop_trip_count() { + let module = parse_file("../test_inputs/loop_analysis/loop_trip_count.hir"); + let dyn_consts = [10]; + let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
+ let result_1 = interp_module!(module, 0, dyn_consts, 2); + + println!("result: {:?}", result_1); +} + +// Test canonicalization +#[test] +#[ignore] +fn alternate_bounds_use_after_loop_no_tid() { + let len = 1; + let dyn_consts = [len]; + + let module = + parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); + + println!("result: {:?}", result_1); + + let schedule = default_schedule![ + Forkify, + ]; + + let module = run_schedule_on_hercules(module, Some(schedule)).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +// Test canonicalization +#[test] +#[ignore] +fn alternate_bounds_use_after_loop() { + let len = 4; + let dyn_consts = [len]; + + let a = vec![3, 4, 5, 6]; + let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, a.clone()); + + println!("result: {:?}", result_1); + + let schedule = Some(default_schedule![ + Forkify, + ]); + + let module = run_schedule_on_hercules(module, schedule).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, a.clone()); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +// Test canonicalization +#[test] +#[ignore] +fn alternate_bounds_use_after_loop2() { + let len = 4; + let dyn_consts = [len]; + + let a = vec![3, 4, 5, 6]; + let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, a.clone()); + + println!("result: {:?}", result_1); + + let schedule = Some(default_schedule![]); + + let module = run_schedule_on_hercules(module, schedule).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, a.clone()); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +// Test 
canonicalization +#[test] +fn do_while_separate_body() { + let len = 2; + let dyn_consts = [len]; + + let a = vec![3, 4, 5, 6]; + let module = parse_file("../test_inputs/loop_analysis/do_while_separate_body2.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 2i32); + + println!("result: {:?}", result_1); + + let schedule = Some(default_schedule![ + PhiElim, + Forkify, + ]); + + let module = run_schedule_on_hercules(module, schedule).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2i32); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +#[test] +fn alternate_bounds_internal_control() { + let len = 4; + let dyn_consts = [len]; + + let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_internal_control.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); + + println!("result: {:?}", result_1); + + let schedule = Some(default_schedule![ + PhiElim, + Forkify, + ]); + + let module = run_schedule_on_hercules(module, schedule).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +#[test] +fn alternate_bounds_internal_control2() { + let len = 2; + let dyn_consts = [len]; + + let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_internal_control2.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); + + println!("result: {:?}", result_1); + + let schedule = Some(default_schedule![ + PhiElim, + Forkify, + ]); + + let module = run_schedule_on_hercules(module, schedule).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +#[test] +fn alternate_bounds_nested_do_loop() { + let len = 1; + let dyn_consts = [10, 5]; + + let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir"); + let result_1 = 
interp_module!(module, 0, dyn_consts, 3); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +#[test] +fn alternate_bounds_nested_do_loop_array() { + let len = 1; + let dyn_consts = [10, 5]; + + let a = vec![4u64, 4, 4, 4, 4]; + let module = + parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, a.clone()); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, a); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +#[test] +fn alternate_bounds_nested_do_loop_guarded() { + let len = 1; + let dyn_consts = [3, 2]; + + let module = + parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); + + let mut pm = PassManager::new(module.clone()); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + println!("{:?}", result_1); + println!("{:?}", result_2); + + assert_eq!(result_1, result_2); +} + +// Tests a do while loop that only iterates once, +// canonicalization *should not* transform this to a while loop, as there is no +// guard that replicates the loop condition. 
+#[ignore] +#[test] +fn do_loop_not_continued() { + // let len = 1; + // let dyn_consts = [len]; + // let params = vec![1, 2, 3, 4, 5]; + + // let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir"); + // let result_1 = interp_module!(module, 0,dyn_consts, params); + + // println!("result: {:?}", result_1); +} + +// Tests a do while loop that is guarded, so should be canonicalized +// It also has +#[test] +fn do_loop_complex_immediate_guarded() { + let len = 1; + let dyn_consts = [len]; + + let module = parse_file("../test_inputs/loop_analysis/do_loop_immediate_guard.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); + assert_eq!(result_1, result_2); +} + +#[ignore] +#[test] +fn loop_canonical_sum() { + let len = 1; + let dyn_consts = [len]; + let params = vec![1, 2, 3, 4, 5]; + + let module = parse_file("../test_inputs/loop_analysis/loop_array_sum.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, params); + + println!("result: {:?}", result_1); +} + +#[test] +#[ignore] +fn antideps_pipeline() { + let len = 1; + let dyn_consts = [2, 2, 2]; + + // FIXME: This path should not leave the crate + let module = parse_module_from_hbin("../../juno_samples/antideps/antideps.hbin"); + let result_1 = interp_module!(module, 0, dyn_consts, 9i32); + + println!("result: {:?}", result_1); + + let module = run_schedule_on_hercules(module, None).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 9i32); + assert_eq!(result_1, result_2); +} + +#[test] +#[ignore] +fn implicit_clone_pipeline() { + let len = 1; + let dyn_consts = [2, 2, 2]; + + // FIXME: This path should not leave the crate + let module = parse_module_from_hbin("../../juno_samples/implicit_clone/out.hbin"); + let result_1 = interp_module!(module, 0, dyn_consts, 2u64, 2u64); + + 
println!("result: {:?}", result_1); + let schedule = default_schedule![ + Forkify, + ForkGuardElim, + Forkify, + ForkGuardElim, + Forkify, + ForkGuardElim, + DCE, + ForkSplit, + Unforkify, + GVN, + DCE, + DCE, + AutoOutline, + InterproceduralSROA, + SROA, + InferSchedules, + DCE, + GCM, + DCE, + FloatCollections, + GCM, + ]; + let module = run_schedule_on_hercules(module, Some(schedule)).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, 2u64, 2u64); + assert_eq!(result_1, result_2); +} + +#[test] +#[ignore] +fn look_at_local() { + const I: usize = 4; + const J: usize = 4; + const K: usize = 4; + let a: Vec<i32> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let b: Vec<i32> = (0..J * K).map(|_| random::<i32>() % 100).collect(); + let dyn_consts = [I, J, K]; + let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + for i in 0..I { + for k in 0..K { + for j in 0..J { + correct_c[i * K + k] += a[i * J + j] * b[j * K + k]; + } + } + } + + let module = parse_module_from_hbin( + "/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin", + ); + + let schedule = Some(default_schedule![ + ]); + + let result_1 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone()); + + let module = run_schedule_on_hercules(module.clone(), schedule).unwrap(); + + let schedule = Some(default_schedule![ + Unforkify, Verify, + ]); + + let module = run_schedule_on_hercules(module.clone(), schedule).unwrap(); + + let result_2 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone()); + + println!("golden: {:?}", correct_c); + println!("result: {:?}", result_2); +} +#[test] +#[ignore] +fn matmul_pipeline() { + let len = 1; + + const I: usize = 4; + const J: usize = 4; + const K: usize = 4; + let a: Vec<i32> = (0i32..(I * J) as i32).map(|v| v + 1).collect(); + let b: Vec<i32> = ((I * J) as i32..(J * K) as i32 + (I * J) as i32) + .map(|v| v + 1) + .collect(); + let a: Vec<i32> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let 
b: Vec<i32> = (0..J * K).map(|_| random::<i32>() % 100).collect(); + let dyn_consts = [I, J, K]; + + // FIXME: This path should not leave the crate + let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin"); + // + let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + for i in 0..I { + for k in 0..K { + for j in 0..J { + correct_c[i * K + k] += a[i * J + j] * b[j * K + k]; + } + } + } + + let result_1 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone()); + + println!("golden: {:?}", correct_c); + println!("result: {:?}", result_1); + + let InterpreterVal::Array(_, d) = result_1.clone() else { + panic!() + }; + let InterpreterVal::Integer32(value) = d[0] else { + panic!() + }; + assert_eq!(correct_c[0], value); + + let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]); + + module = run_schedule_on_hercules(module, schedule).unwrap(); + + let result_2 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone()); + + println!("result: {:?}", result_2); + assert_eq!(result_1, result_2); +} diff --git a/hercules_test/hercules_tests/tests/opt_tests.rs b/hercules_test/hercules_tests/tests/opt_tests.rs index 388dfeddf38eecab3ffb191a344682395e71d5cd..2f85b78b6103d84d7aacb9f35e34b834aea56005 100644 --- a/hercules_test/hercules_tests/tests/opt_tests.rs +++ b/hercules_test/hercules_tests/tests/opt_tests.rs @@ -3,198 +3,205 @@ use std::env; use rand::Rng; use hercules_interpreter::*; -use hercules_opt::pass::Pass; - -#[test] -fn matmul_int() { - let module = parse_file("../test_inputs/matmul_int.hir"); - let dyn_consts = [2, 2, 2]; - let m1 = vec![3, 4, 5, 6]; - let m2 = vec![7, 8, 9, 10]; - let result_1 = interp_module!(module, dyn_consts, m1.clone(), m2.clone()); - - let mut pm = hercules_opt::pass::PassManager::new(module.clone()); - - let passes = vec![ - Pass::Verify, - Pass::CCP, - Pass::DCE, - Pass::GVN, - Pass::DCE, - Pass::Forkify, - Pass::DCE, - Pass::Predication, - Pass::DCE, - ]; - - for pass in passes 
{ - pm.add_pass(pass); - } - pm.run_passes(); - - let module = pm.get_module(); - let result_2 = interp_module!(module, dyn_consts, m1, m2); - assert_eq!(result_1, result_2) -} - -#[test] -fn ccp_example() { - let module = parse_file("../test_inputs/ccp_example.hir"); - let dyn_consts = []; - let x = 34; - let result_1 = interp_module!(module, dyn_consts, x); - - let mut pm = hercules_opt::pass::PassManager::new(module.clone()); - - let passes = vec![ - Pass::Verify, - Pass::CCP, - Pass::DCE, - Pass::GVN, - Pass::DCE, - Pass::Forkify, - Pass::DCE, - Pass::Predication, - Pass::DCE, - ]; - - for pass in passes { - pm.add_pass(pass); - } - pm.run_passes(); - - let module = pm.get_module(); - let result_2 = interp_module!(module, dyn_consts, x); - assert_eq!(result_1, result_2) -} - -#[test] -fn gvn_example() { - let module = parse_file("../test_inputs/gvn_example.hir"); - - let dyn_consts = []; - let x: i32 = rand::random(); - let x = x / 32; - let y: i32 = rand::random(); - let y = y / 32; // prevent overflow, - let result_1 = interp_module!(module, dyn_consts, x, y); - - let mut pm = hercules_opt::pass::PassManager::new(module.clone()); - - let passes = vec![ - Pass::Verify, - Pass::CCP, - Pass::DCE, - Pass::GVN, - Pass::DCE, - Pass::Forkify, - Pass::DCE, - Pass::Predication, - Pass::DCE, - ]; - - for pass in passes { - pm.add_pass(pass); - } - pm.run_passes(); - - let module = pm.get_module(); - let result_2 = interp_module!(module, dyn_consts, x, y); - assert_eq!(result_1, result_2) -} - -#[test] -fn sum_int() { - let module = parse_file("../test_inputs/sum_int1.hir"); - - let size = 100; - let dyn_consts = [size]; - let mut vec = vec![0; size]; - let mut rng = rand::thread_rng(); - - for x in vec.iter_mut() { - *x = rng.gen::<i32>() / 100; - } - - let result_1 = interp_module!(module, dyn_consts, vec.clone()); - - let mut pm = hercules_opt::pass::PassManager::new(module.clone()); - - let passes = vec![ - Pass::Verify, - Pass::CCP, - Pass::DCE, - Pass::GVN, - 
Pass::DCE, - Pass::Forkify, - Pass::DCE, - Pass::Predication, - Pass::DCE, - ]; - - for pass in passes { - pm.add_pass(pass); - } - pm.run_passes(); - - let module = pm.get_module(); - let result_2 = interp_module!(module, dyn_consts, vec); - assert_eq!(result_1, result_2) -} - -#[test] -fn sum_int2() { - let module = parse_file("../test_inputs/sum_int2.hir"); - - let size = 100; - let dyn_consts = [size]; - let mut vec = vec![0; size]; - let mut rng = rand::thread_rng(); - - for x in vec.iter_mut() { - *x = rng.gen::<i32>() / 100; - } - - let result_1 = interp_module!(module, dyn_consts, vec.clone()); - - let mut pm = hercules_opt::pass::PassManager::new(module.clone()); - - let passes = vec![ - Pass::Verify, - Pass::CCP, - Pass::DCE, - Pass::GVN, - Pass::DCE, - Pass::Forkify, - Pass::DCE, - Pass::Predication, - Pass::DCE, - ]; - - for pass in passes { - pm.add_pass(pass); - } - pm.run_passes(); - - let module = pm.get_module(); - let result_2 = interp_module!(module, dyn_consts, vec); - assert_eq!(result_1, result_2) -} - -#[test] -fn sum_int2_smaller() { - interp_file_with_passes!("../test_inputs/sum_int2.hir", - [100], - vec![ - Pass::Verify, - Pass::CCP, - Pass::DCE, - Pass::GVN, - Pass::DCE, - Pass::Forkify, - Pass::DCE, - Pass::Predication, - Pass::DCE, - ], - vec![1; 100]); -} +use hercules_ir::ID; +use juno_scheduler::*; + +// #[test] +// fn matmul_int() { +// let module = parse_file("../test_inputs/matmul_int.hir"); +// let dyn_consts = [2, 2, 2]; +// let m1 = vec![3, 4, 5, 6]; +// let m2 = vec![7, 8, 9, 10]; +// let result_1 = interp_module!(module, 0, dyn_consts, m1.clone(), m2.clone()); + +// let mut pm = hercules_opt::pass::PassManager::new(module.clone()); + +// let passes = vec![ +// // Pass::Verify, +// // Pass::CCP, +// // Pass::DCE, +// // Pass::GVN, +// // Pass::DCE, +// // Pass::Forkify, +// // Pass::DCE, +// // Pass::Predication, +// // Pass::DCE, +// ]; + +// for pass in passes { +// pm.add_pass(pass); +// } +// pm.run_passes(); + +// let 
module = pm.get_module(); +// let result_2 = interp_module!(module, 0, dyn_consts, m1, m2); +// // println!("result: {:?}", result_1); +// assert_eq!(result_1, result_2) +// } + +// #[test] +// fn ccp_example() { +// let module = parse_file("../test_inputs/ccp_example.hir"); +// let dyn_consts = []; +// let x = 34; +// let result_1 = interp_module!(module, 0, dyn_consts, x); + +// let mut pm = hercules_opt::pass::PassManager::new(module.clone()); + +// let passes = vec![ +// Pass::Verify, +// Pass::CCP, +// Pass::DCE, +// Pass::GVN, +// Pass::DCE, +// Pass::Forkify, +// Pass::DCE, +// Pass::Predication, +// Pass::DCE, +// ]; + +// for pass in passes { +// pm.add_pass(pass); +// } +// pm.run_passes(); + +// let module = pm.get_module(); +// let result_2 = interp_module!(module, 0, dyn_consts, x); +// assert_eq!(result_1, result_2) +// } + +// #[test] +// fn gvn_example() { +// let module = parse_file("../test_inputs/gvn_example.hir"); + +// let dyn_consts = []; +// let x: i32 = rand::random(); +// let x = x / 32; +// let y: i32 = rand::random(); +// let y = y / 32; // prevent overflow, +// let result_1 = interp_module!(module, 0, dyn_consts, x, y); + +// let mut pm = hercules_opt::pass::PassManager::new(module.clone()); + +// let passes = vec![ +// Pass::Verify, +// Pass::CCP, +// Pass::DCE, +// Pass::GVN, +// Pass::DCE, +// Pass::Forkify, +// Pass::DCE, +// Pass::Predication, +// Pass::DCE, +// ]; + +// for pass in passes { +// pm.add_pass(pass); +// } +// pm.run_passes(); + +// let module = pm.get_module(); +// let result_2 = interp_module!(module, 0, dyn_consts, x, y); +// assert_eq!(result_1, result_2) +// } + +// #[test] +// fn sum_int() { +// let module = parse_file("../test_inputs/sum_int1.hir"); + +// let size = 2; +// let dyn_consts = [size]; +// let mut vec = vec![0; size]; +// let mut rng = rand::thread_rng(); + +// for x in vec.iter_mut() { +// *x = rng.gen::<i32>() / 100; +// } + +// println!("{:?}", vec); + +// let result_1 = interp_module!(module, 0, 
dyn_consts, vec.clone()); + +// println!("{:?}", result_1); + +// let mut pm = hercules_opt::pass::PassManager::new(module.clone()); + +// let passes = vec![ +// Pass::Verify, +// Pass::CCP, +// Pass::DCE, +// Pass::GVN, +// Pass::DCE, +// Pass::Forkify, +// Pass::DCE, +// Pass::Predication, +// Pass::DCE, +// ]; + +// for pass in passes { +// pm.add_pass(pass); +// } +// pm.run_passes(); + +// let module = pm.get_module(); +// let result_2 = interp_module!(module, 0, dyn_consts, vec); + +// assert_eq!(result_1, result_2) +// } + +// #[test] +// fn sum_int2() { +// let module = parse_file("../test_inputs/sum_int2.hir"); + +// let size = 10; +// let dyn_consts = [size]; +// let mut vec = vec![0; size]; +// let mut rng = rand::thread_rng(); + +// for x in vec.iter_mut() { +// *x = rng.gen::<i32>() / 100; +// } + +// let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); + +// let mut pm = hercules_opt::pass::PassManager::new(module.clone()); + +// let passes = vec![ +// Pass::Verify, +// Pass::CCP, +// Pass::DCE, +// Pass::GVN, +// Pass::DCE, +// Pass::Forkify, +// Pass::DCE, +// Pass::Predication, +// Pass::DCE, +// ]; + +// for pass in passes { +// pm.add_pass(pass); +// } +// pm.run_passes(); + +// let module = pm.get_module(); +// let result_2 = interp_module!(module, 0, dyn_consts, vec); +// assert_eq!(result_1, result_2) +// } + +// #[test] +// fn sum_int2_smaller() { +// interp_file_with_passes!("../test_inputs/sum_int2.hir", +// [100], +// vec![ +// Pass::Verify, +// Pass::CCP, +// Pass::DCE, +// Pass::GVN, +// Pass::DCE, +// Pass::Forkify, +// Pass::DCE, +// Pass::Predication, +// Pass::DCE, +// ], +// vec![1; 100]); +// } diff --git a/hercules_test/test_inputs/2d_fork.hir b/hercules_test/test_inputs/2d_fork.hir new file mode 100644 index 0000000000000000000000000000000000000000..e784c1db2d11640d6987ce3b21f34038e534b4e3 --- /dev/null +++ b/hercules_test/test_inputs/2d_fork.hir @@ -0,0 +1,8 @@ +fn twodeefork<2>(x: i32) -> i32 + zero = 
constant(i32, 0) + one = constant(i32, 1) + f = fork(start, #1, #0) + j = join(f) + add = add(r, one) + r = reduce(j, zero, add) + z = return(j, r) \ No newline at end of file diff --git a/hercules_test/test_inputs/3d_fork.hir b/hercules_test/test_inputs/3d_fork.hir new file mode 100644 index 0000000000000000000000000000000000000000..746fd902ce8fb8349852d3accaff451003058f44 --- /dev/null +++ b/hercules_test/test_inputs/3d_fork.hir @@ -0,0 +1,8 @@ +fn twodeefork<3>(x: i32) -> i32 + zero = constant(i32, 0) + one = constant(i32, 1) + f = fork(start, #2, #1, #0) + j = join(f) + add = add(r, one) + r = reduce(j, zero, add) + z = return(j, r) \ No newline at end of file diff --git a/hercules_test/test_inputs/5d_fork.hir b/hercules_test/test_inputs/5d_fork.hir new file mode 100644 index 0000000000000000000000000000000000000000..942996013f6aada238ea461f37ad365e7f3d0d0e --- /dev/null +++ b/hercules_test/test_inputs/5d_fork.hir @@ -0,0 +1,8 @@ +fn fivedeefork<5>(x: i32) -> i32 + zero = constant(i32, 0) + one = constant(i32, 1) + f = fork(start, #4, #3, #2, #1, #0) + j = join(f) + add = add(r, one) + r = reduce(j, zero, add) + z = return(j, r) \ No newline at end of file diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/inner_control.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_control.hir new file mode 100644 index 0000000000000000000000000000000000000000..052bbdb829ad1c6e33f304f500f3f7e1d2838324 --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_control.hir @@ -0,0 +1,15 @@ +fn fun<2>(x: u64) -> u64 + zero = constant(u64, 0) + one = constant(u64, 1) + two = constant(u64, 2) + f = fork(start, #0) + f2 = fork(f, #1) + j2 = join(f2) + j = join(j2) + tid = thread_id(f, 0) + add1 = add(reduce1, one) + reduce1 = reduce(j, zero, add1) + add2 = add(reduce2, two) + reduce2 = reduce(j, zero, add2) + out1 = add(reduce1, reduce2) + z = return(j, out1) \ No newline at end of file diff --git 
a/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir new file mode 100644 index 0000000000000000000000000000000000000000..0cc13b2fe21ab6cb434b9a13b3dc212c125394aa --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir @@ -0,0 +1,23 @@ +fn fun<2>(x: u64) -> u64 + zero = constant(u64, 0) + one = constant(u64, 1) + two = constant(u64, 2) + f = fork(start, #0) + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = dynamic_constant(#1) + loop = region(f, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + j = join(if_false) + tid = thread_id(f, 0) + add1 = add(reduce1, idx) + reduce1 = reduce(j, zero, add1) + add2 = add(reduce2, idx_inc) + reduce2 = reduce(j, zero, add2) + out1 = add(reduce1, reduce2) + z = return(j, out1) \ No newline at end of file diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/intermediate_buffer_simple.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/intermediate_buffer_simple.hir new file mode 100644 index 0000000000000000000000000000000000000000..75e0f157fcf9e0ecf1cddcd8821669646473f10a --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/fork_fission/intermediate_buffer_simple.hir @@ -0,0 +1,10 @@ +fn fun<1>(x: u64) -> u64 + zero = constant(u64, 0) + one = constant(u64, 1) + two = constant(u64, 2) + f = fork(start, #0) + j = join(f) + tid = thread_id(f, 0) + add1 = add(reduce1, two) + reduce1 = reduce(j, zero, add1) + z = return(j, reduce1) diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/simple1.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/simple1.hir new file mode 100644 index 0000000000000000000000000000000000000000..aaed60d9039a092b08a2c96223ac639872304bf4 --- /dev/null +++ 
b/hercules_test/test_inputs/fork_transforms/fork_fission/simple1.hir @@ -0,0 +1,13 @@ +fn fun<1>(x: u64) -> u64 + zero = constant(u64, 0) + one = constant(u64, 1) + two = constant(u64, 2) + f = fork(start, #0) + j = join(f) + tid = thread_id(f, 0) + add1 = add(reduce1, one) + reduce1 = reduce(j, zero, add1) + add2 = add(reduce2, two) + reduce2 = reduce(j, zero, add2) + out1 = add(reduce1, reduce2) + z = return(j, out1) diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/simple2.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/simple2.hir new file mode 100644 index 0000000000000000000000000000000000000000..6be6d2c7c782f5432a22a33694e7aa04c607b421 --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/fork_fission/simple2.hir @@ -0,0 +1,19 @@ +fn fun<1>(x: u64) -> u64 + zero = constant(u64, 0) + one = constant(u64, 1) + two = constant(u64, 2) + f = fork(start, #0) + j = join(f) + tid = thread_id(f, 0) + add1 = add(reduce1, one) + reduce1 = reduce(j, zero, add1) + add2 = add(reduce2, two) + reduce2 = reduce(j, zero, add2) + add3 = add(reduce3, tid) + reduce3 = reduce(j, zero, add3) + add4 = mul(reduce4, tid) + reduce4 = reduce(j, zero, add4) + out1 = add(reduce1, reduce2) + out2 = add(reduce3, reduce4) + out3 = add(out1, out2) + z = return(j, out3) diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/tricky.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/tricky.hir new file mode 100644 index 0000000000000000000000000000000000000000..6fb895c4147e14cb798c804c05db41695e64266e --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/fork_fission/tricky.hir @@ -0,0 +1,13 @@ +fn fun<1>(x: u64) -> u64 + zero = constant(u64, 0) + one = constant(u64, 1) + two = constant(u64, 2) + f = fork(start, #0) + j = join(f) + tid = thread_id(f, 0) + add1 = add(reduce1, one) + reduce1 = reduce(j, zero, add1) + add2 = add(reduce2, reduce1) + reduce2 = reduce(j, zero, add2) + out1 = add(reduce1, reduce2) + z = return(j, 
out1) diff --git a/hercules_test/test_inputs/fork_transforms/fork_fusion.hir b/hercules_test/test_inputs/fork_transforms/fork_fusion.hir new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hercules_test/test_inputs/fork_transforms/fork_interchange.hir b/hercules_test/test_inputs/fork_transforms/fork_interchange.hir new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hercules_test/test_inputs/fork_transforms/matmul_int.hir b/hercules_test/test_inputs/fork_transforms/matmul_int.hir new file mode 100644 index 0000000000000000000000000000000000000000..ab0f384a563ccb6144e59b811745fe5aa76f08dd --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/matmul_int.hir @@ -0,0 +1,18 @@ +fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2) + c = constant(array(i32, #0, #2), []) + i_j_ctrl = fork(start, #0, #2) + i_idx = thread_id(i_j_ctrl, 0) + j_idx = thread_id(i_j_ctrl, 1) + k_ctrl = fork(i_j_ctrl, #1) + k_idx = thread_id(k_ctrl, 0) + k_join_ctrl = join(k_ctrl) + i_j_join_ctrl = join(k_join_ctrl) + r = return(i_j_join_ctrl, update_i_j_c) + zero = constant(i32, 0) + a_val = read(a, position(i_idx, k_idx)) + b_val = read(b, position(k_idx, j_idx)) + mul = mul(a_val, b_val) + add = add(mul, dot) + dot = reduce(k_join_ctrl, zero, add) + update_c = write(update_i_j_c, dot, position(i_idx, j_idx)) + update_i_j_c = reduce(i_j_join_ctrl, c, update_c) \ No newline at end of file diff --git a/hercules_test/test_inputs/fork_transforms/tiled_matmul_int.hir b/hercules_test/test_inputs/fork_transforms/tiled_matmul_int.hir new file mode 100644 index 0000000000000000000000000000000000000000..ab0f384a563ccb6144e59b811745fe5aa76f08dd --- /dev/null +++ b/hercules_test/test_inputs/fork_transforms/tiled_matmul_int.hir @@ -0,0 +1,18 @@ +fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2) + c = 
constant(array(i32, #0, #2), []) + i_j_ctrl = fork(start, #0, #2) + i_idx = thread_id(i_j_ctrl, 0) + j_idx = thread_id(i_j_ctrl, 1) + k_ctrl = fork(i_j_ctrl, #1) + k_idx = thread_id(k_ctrl, 0) + k_join_ctrl = join(k_ctrl) + i_j_join_ctrl = join(k_join_ctrl) + r = return(i_j_join_ctrl, update_i_j_c) + zero = constant(i32, 0) + a_val = read(a, position(i_idx, k_idx)) + b_val = read(b, position(k_idx, j_idx)) + mul = mul(a_val, b_val) + add = add(mul, dot) + dot = reduce(k_join_ctrl, zero, add) + update_c = write(update_i_j_c, dot, position(i_idx, j_idx)) + update_i_j_c = reduce(i_j_join_ctrl, c, update_c) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/alternate_bounds.hir b/hercules_test/test_inputs/forkify/alternate_bounds.hir new file mode 100644 index 0000000000000000000000000000000000000000..4a9ba0153448e4c9339e901d3ba3a10e027bad56 --- /dev/null +++ b/hercules_test/test_inputs/forkify/alternate_bounds.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx)) + red_add = add(red, read) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red_add) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/broken_sum.hir b/hercules_test/test_inputs/forkify/broken_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..d15ef5613e271cd8685660785f5505dedbf40ec9 --- /dev/null +++ b/hercules_test/test_inputs/forkify/broken_sum.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, if_true) + 
idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx)) + red_add = add(red, read) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red_add) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/control_after_condition.hir b/hercules_test/test_inputs/forkify/control_after_condition.hir new file mode 100644 index 0000000000000000000000000000000000000000..db40225bac26f5d6687b4b36bd46e06c963c4815 --- /dev/null +++ b/hercules_test/test_inputs/forkify/control_after_condition.hir @@ -0,0 +1,25 @@ +fn alt_sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two_idx = constant(u64, 2) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, negate_bottom) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + rem = rem(idx, two_idx) + odd = eq(rem, one_idx) + negate_if = if(loop_continue, odd) + negate_if_false = projection(negate_if, 0) + negate_if_true = projection(negate_if, 1) + negate_bottom = region(negate_if_false, negate_if_true) + read = read(a, position(idx)) + read_neg = neg(read) + read_phi = phi(negate_bottom, read, read_neg) + red_add = add(red, read_phi) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + loop_exit = projection(if, 0) + loop_continue = projection(if, 1) + r = return(loop_exit, red) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/control_before_condition.hir b/hercules_test/test_inputs/forkify/control_before_condition.hir new file mode 100644 index 0000000000000000000000000000000000000000..f24b565a56df1e9ed98e2637a8b66ef27ff6e7de --- /dev/null +++ b/hercules_test/test_inputs/forkify/control_before_condition.hir @@ -0,0 +1,25 @@ +fn alt_sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx 
= constant(u64, 1) + two_idx = constant(u64, 2) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + rem = rem(idx, two_idx) + odd = eq(rem, one_idx) + negate_if = if(loop, odd) + negate_if_false = projection(negate_if, 0) + negate_if_true = projection(negate_if, 1) + negate_bottom = region(negate_if_false, negate_if_true) + read = read(a, position(idx)) + read_neg = neg(read) + read_phi = phi(negate_bottom, read, read_neg) + red_add = add(red, read_phi) + in_bounds = lt(idx, bound) + if = if(negate_bottom, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir new file mode 100644 index 0000000000000000000000000000000000000000..f5ec4370b6befc4d45535fd25f024db60b55a0d0 --- /dev/null +++ b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir @@ -0,0 +1,35 @@ +fn loop<3>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + inner_bound = dynamic_constant(#0) + outer_loop = region(outer_outer_if_true, inner_if_false) + inner_loop = region(outer_if_true, inner_if_true) + outer_var = phi(outer_loop, outer_outer_var, inner_var) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, one_var) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = 
projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + outer_bound = dynamic_constant(#1) + outer_outer_bound = dynamic_constant(#2) + outer_outer_loop = region(start, outer_if_false) + outer_outer_var = phi(outer_outer_loop, zero_var, outer_var) + outer_outer_if = if(outer_outer_loop, outer_outer_in_bounds) + outer_outer_if_false = projection(outer_outer_if, 0) + outer_outer_if_true = projection(outer_outer_if, 1) + outer_outer_idx = phi(outer_outer_loop, zero_idx, outer_outer_idx_inc, outer_outer_idx) + outer_outer_idx_inc = add(outer_outer_idx, one_idx) + outer_outer_in_bounds = lt(outer_outer_idx, outer_outer_bound) + r = return(outer_outer_if_false, inner_var) + diff --git a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..8dda179bf36d82d29bbbda48686191fa46a02fb6 --- /dev/null +++ b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir @@ -0,0 +1,16 @@ +fn loop<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, idx) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, var_inc) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/inner_fork.hir b/hercules_test/test_inputs/forkify/inner_fork.hir new file mode 100644 index 0000000000000000000000000000000000000000..e2c96a68b85ee605d5804baf338ada9b493e155f --- /dev/null +++ 
b/hercules_test/test_inputs/forkify/inner_fork.hir @@ -0,0 +1,22 @@ +fn loop<2>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, inner_join) + outer_if_true = projection(outer_if, 1) + inner_fork = fork(outer_if_true, #0) + inner_join = join(inner_fork) + outer_var = phi(outer_loop, zero_var, inner_var) + inner_var = reduce(inner_fork, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, inner_idx) + inner_idx = thread_id(inner_fork, 0) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + r = return(outer_if_false, outer_var) + \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/inner_fork_complex.hir b/hercules_test/test_inputs/forkify/inner_fork_complex.hir new file mode 100644 index 0000000000000000000000000000000000000000..91eb00fae9dd7e043155c1f6316e50b38337bad2 --- /dev/null +++ b/hercules_test/test_inputs/forkify/inner_fork_complex.hir @@ -0,0 +1,32 @@ +fn loop<2>(a: u32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + ten = constant(u64, 10) + two = constant(u64, 2) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, inner_condition_true_projection, inner_condition_false_projection ) + outer_if_true = projection(outer_if, 1) + other_phi_weird = phi(outer_loop, zero_var, inner_var, other_phi_weird) + inner_fork = fork(outer_if_true, #0) + inner_join = join(inner_fork) + inner_condition_eq = eq(outer_idx, two) + inner_condition_if = if(inner_join, inner_condition_eq) + inner_condition_true_projection = projection(inner_condition_if, 1) + 
inner_condition_false_projection = projection(inner_condition_if, 0) + outer_var = phi(outer_loop, zero_var, inner_var, inner_var) + inner_var = reduce(inner_join, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, inner_var_inc_3) + inner_var_inc_2 = mul(ten, outer_idx) + inner_var_inc_3 = add(inner_var_inc_2, inner_idx) + inner_idx = thread_id(inner_fork, 0) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx_inc) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + ret_val = add(outer_var, other_phi_weird) + r = return(outer_if_false, ret_val) + \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/loop_array_sum.hir b/hercules_test/test_inputs/forkify/loop_array_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..f9972b5917c200b93b5775fd4a6e501318e8c548 --- /dev/null +++ b/hercules_test/test_inputs/forkify/loop_array_sum.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx)) + red_add = add(red, read) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/loop_simple_iv.hir b/hercules_test/test_inputs/forkify/loop_simple_iv.hir new file mode 100644 index 0000000000000000000000000000000000000000..c25b9a2cf006b67087df0b8dc652f72400557090 --- /dev/null +++ b/hercules_test/test_inputs/forkify/loop_simple_iv.hir @@ -0,0 +1,12 @@ +fn loop<1>(a: u32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = 
dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, idx) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/loop_sum.hir b/hercules_test/test_inputs/forkify/loop_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..fd9c4debc163600c01e661b127b166358ac9c6db --- /dev/null +++ b/hercules_test/test_inputs/forkify/loop_sum.hir @@ -0,0 +1,16 @@ +fn loop<1>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, one_var) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, var) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/loop_tid_sum.hir b/hercules_test/test_inputs/forkify/loop_tid_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..2d3ca34db88efa5007b5fb933f0bb0e4b55e63e4 --- /dev/null +++ b/hercules_test/test_inputs/forkify/loop_tid_sum.hir @@ -0,0 +1,16 @@ +fn loop<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, idx) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, var) \ No newline at end of file diff --git 
a/hercules_test/test_inputs/forkify/merged_phi_cycle.hir b/hercules_test/test_inputs/forkify/merged_phi_cycle.hir new file mode 100644 index 0000000000000000000000000000000000000000..cee473a08740b48935d4ff571bc003bbd9908729 --- /dev/null +++ b/hercules_test/test_inputs/forkify/merged_phi_cycle.hir @@ -0,0 +1,18 @@ +fn sum<1>(a: i32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 2) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + first_red = phi(loop, zero_idx, first_red_add) + second_red = phi(loop, zero_idx, first_red_add_2) + first_red_add = add(first_red, idx) + second_red_add_1 = add(first_red, idx) + second_red_add_2 = add(first_red_add, two) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, first_red_add_2) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/nested_loop2.hir b/hercules_test/test_inputs/forkify/nested_loop2.hir new file mode 100644 index 0000000000000000000000000000000000000000..0f29ec747dd67bae4c5b7d8b1250ead925257e9d --- /dev/null +++ b/hercules_test/test_inputs/forkify/nested_loop2.hir @@ -0,0 +1,25 @@ +fn loop<2>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, inner_if_false) + inner_loop = region(outer_if_true, inner_if_true) + outer_var = phi(outer_loop, zero_var, inner_var) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, one_var) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, 
one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + r = return(outer_if_false, outer_var) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/nested_tid_sum.hir b/hercules_test/test_inputs/forkify/nested_tid_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..5539202d297d5124bd3258e56568c1614e4e942e --- /dev/null +++ b/hercules_test/test_inputs/forkify/nested_tid_sum.hir @@ -0,0 +1,25 @@ +fn loop<2>(a: u32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, inner_if_false) + inner_loop = region(outer_if_true, inner_if_true) + outer_var = phi(outer_loop, zero_var, inner_var) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, inner_idx) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + r = return(outer_if_false, outer_var) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir b/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir new file mode 100644 index 
0000000000000000000000000000000000000000..9221fd47d857c9215a7b8912e68088b5809238b5 --- /dev/null +++ b/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir @@ -0,0 +1,26 @@ +fn loop<2>(a: u32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, inner_if_false) + inner_loop = region(outer_if_true, inner_if_true) + outer_var = phi(outer_loop, zero_var, inner_var) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + iv_mul = mul(inner_idx, outer_idx) + inner_var_inc = add(inner_var, iv_mul) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + r = return(outer_if_false, outer_var) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/phi_loop4.hir b/hercules_test/test_inputs/forkify/phi_loop4.hir new file mode 100644 index 0000000000000000000000000000000000000000..e69ecc3daf264359426acd7b5dbf9ff84fd96c4c --- /dev/null +++ b/hercules_test/test_inputs/forkify/phi_loop4.hir @@ -0,0 +1,16 @@ +fn loop<1>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, one_var) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, 
in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, var_inc) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/split_phi_cycle.hir b/hercules_test/test_inputs/forkify/split_phi_cycle.hir new file mode 100644 index 0000000000000000000000000000000000000000..96de73c8e054fb9cb45ade6dfe9150fcfc79334f --- /dev/null +++ b/hercules_test/test_inputs/forkify/split_phi_cycle.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: i32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 2) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + first_red = phi(loop, zero_idx, first_red_add_2) + first_red_add = add(first_red, idx) + first_red_add_2 = add(first_red_add, two) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, first_red_add_2) \ No newline at end of file diff --git a/hercules_test/test_inputs/forkify/super_nested_loop.hir b/hercules_test/test_inputs/forkify/super_nested_loop.hir new file mode 100644 index 0000000000000000000000000000000000000000..6853efbfc2b620860644bc94486fb09a09e131f0 --- /dev/null +++ b/hercules_test/test_inputs/forkify/super_nested_loop.hir @@ -0,0 +1,35 @@ +fn loop<3>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + inner_bound = dynamic_constant(#0) + outer_loop = region(outer_outer_if_true, inner_if_false) + inner_loop = region(outer_if_true, inner_if_true) + outer_var = phi(outer_loop, outer_outer_var, inner_var) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, one_var) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, 
outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(outer_loop, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + outer_bound = dynamic_constant(#1) + outer_outer_bound = dynamic_constant(#2) + outer_outer_loop = region(start, outer_if_false) + outer_outer_var = phi(outer_outer_loop, zero_var, outer_var) + outer_outer_if = if(outer_outer_loop, outer_outer_in_bounds) + outer_outer_if_false = projection(outer_outer_if, 0) + outer_outer_if_true = projection(outer_outer_if, 1) + outer_outer_idx = phi(outer_outer_loop, zero_idx, outer_outer_idx_inc, outer_outer_idx) + outer_outer_idx_inc = add(outer_outer_idx, one_idx) + outer_outer_in_bounds = lt(outer_outer_idx, outer_outer_bound) + r = return(outer_outer_if_false, outer_outer_var) + diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir new file mode 100644 index 0000000000000000000000000000000000000000..4df92a18a9895c88cf27f143285999ef2218bfcf --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir @@ -0,0 +1,14 @@ +fn sum<1>(a: u32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, one_idx) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir new file mode 100644 index 
0000000000000000000000000000000000000000..8b4431bfb446237b6a66088d7a7d2339d9c889d1 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir @@ -0,0 +1,23 @@ +fn sum<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 2) + ten = constant(u64, 10) + bound = dynamic_constant(#0) + loop = region(start, if_true) + inner_ctrl = region(loop) + inner_phi = phi(inner_ctrl, idx) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, two) + red2 = phi(loop, zero_idx, red_add2) + red_add2 = add(red, inner_phi) + in_bounds = lt(idx_inc, bound) + if = if(inner_ctrl, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + plus_ten = add(red_add, ten) + red_add_2_plus_blah = add(red2, plus_ten) + final_add = add(inner_phi, red_add_2_plus_blah) + r = return(if_false, final_add) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir new file mode 100644 index 0000000000000000000000000000000000000000..f4adf6435968ddaccc285c8d6132e7c9dd91c973 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir @@ -0,0 +1,21 @@ +fn sum<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 2) + ten = constant(u64, 10) + bound = dynamic_constant(#0) + loop = region(start, if_true) + inner_ctrl = region(loop) + inner_phi = phi(inner_ctrl, idx) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, two) + in_bounds = lt(idx_inc, bound) + if = if(inner_ctrl, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + plus_ten = add(red_add, ten) + red_add_2_plus_blah = add(inner_phi, plus_ten) + final_add = 
add(inner_phi, red_add_2_plus_blah) + r = return(if_false, final_add) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir new file mode 100644 index 0000000000000000000000000000000000000000..52f701727c33305c0719e46aacf717bf8b220fcb --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir @@ -0,0 +1,28 @@ +fn loop<2>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + ten = constant(u64, 10) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, outer_if_true) + inner_loop = region(outer_loop, inner_if_true) + outer_var = phi(outer_loop, zero_var, inner_var_inc) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, blah2) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + blah = mul(outer_idx, ten) + blah2 = add(blah, inner_idx) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx_inc, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx_inc, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(inner_if_false, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + r = return(outer_if_false, inner_var_inc) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir new file mode 100644 index 0000000000000000000000000000000000000000..f295b39166fc16bb4560d9100bb84526281dda84 --- /dev/null +++ 
b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir @@ -0,0 +1,25 @@ +fn loop<2>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, outer_if_true) + inner_loop = region(outer_loop, inner_if_true) + outer_var = phi(outer_loop, zero_var, inner_var_inc) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, one_var) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx_inc, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx_inc, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(inner_if_false, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + r = return(outer_if_false, inner_var_inc) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir new file mode 100644 index 0000000000000000000000000000000000000000..e5401779a28503677f4a5e51c703c4197b433d4e --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir @@ -0,0 +1,28 @@ +fn loop<2>(a: array(u64, #1)) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + ten = constant(u64, 10) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(start, outer_if_true) + inner_loop = region(outer_loop, inner_if_true) + outer_var = phi(outer_loop, 
zero_var, inner_var_inc) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + inner_var_inc = add(inner_var, blah2) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + blah = read(a, position(outer_idx)) + blah2 = add(blah, inner_idx) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx_inc, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx_inc, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(inner_if_false, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + r = return(outer_if_false, inner_var_inc) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir new file mode 100644 index 0000000000000000000000000000000000000000..b979ad42ce0522cc9b27d1fab07736cc73831590 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir @@ -0,0 +1,40 @@ +fn loop<2>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + ten = constant(u64, 10) + outer_guard_if = if(start, outer_guard_lt) + outer_guard_if_false = projection(outer_guard_if, 0) + outer_guard_if_true = projection(outer_guard_if, 1) + outer_guard_lt = lt(zero_idx, outer_bound) + outer_join = region(outer_guard_if_false, outer_if_false) + outer_join_var = phi(outer_join, zero_idx, join_var) + inner_bound = dynamic_constant(#0) + outer_bound = dynamic_constant(#1) + outer_loop = region(outer_guard_if_true, outer_if_true) + inner_loop = region(guard_if_true, inner_if_true) + guard_lt = lt(zero_idx, inner_bound) + guard_if = if(outer_loop, guard_lt) + 
guard_if_true = projection(guard_if, 1) + guard_if_false = projection(guard_if, 0) + guard_join = region(guard_if_false, inner_if_false) + inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) + inner_idx_inc = add(inner_idx, one_idx) + inner_in_bounds = lt(inner_idx_inc, inner_bound) + outer_idx = phi(outer_loop, zero_idx, outer_idx_inc, outer_idx) + outer_idx_inc = add(outer_idx, one_idx) + outer_in_bounds = lt(outer_idx_inc, outer_bound) + inner_if = if(inner_loop, inner_in_bounds) + inner_if_false = projection(inner_if, 0) + inner_if_true = projection(inner_if, 1) + outer_if = if(guard_join, outer_in_bounds) + outer_if_false = projection(outer_if, 0) + outer_if_true = projection(outer_if, 1) + outer_var = phi(outer_loop, zero_var, join_var) + inner_var = phi(inner_loop, outer_var, inner_var_inc) + blah = mul(outer_idx, ten) + blah2 = add(blah, inner_idx) + inner_var_inc = add(inner_var, blah2) + join_var = phi(guard_join, outer_var, inner_var_inc) + r = return(outer_join, outer_join_var) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir new file mode 100644 index 0000000000000000000000000000000000000000..2fe4ca57345dd0e6c4dd94e399dba58e3cab81a4 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir @@ -0,0 +1,21 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + ten = constant(i32, 10) + three = constant(i32, 3) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx)) + red_add = add(red, read) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + plus_ten = add(red_add, ten) + mult = mul(read, 
three) + final = add(plus_ten, mult) + r = return(if_false, final) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir new file mode 100644 index 0000000000000000000000000000000000000000..760ae5ad690382c42e6760af690b18e5ea36a6b2 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir @@ -0,0 +1,21 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + ten = constant(i32, 10) + three = constant(i32, 3) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx_inc)) + red_add = add(red, read) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + plus_ten = add(red, ten) + mult = mul(read, three) + final = add(plus_ten, mult) + r = return(if_false, final) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir new file mode 100644 index 0000000000000000000000000000000000000000..4b9375090dd5c13a127e9db85d7ce6dc8f2f7d75 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir @@ -0,0 +1,17 @@ +fn sum<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 2) + ten = constant(u64, 10) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, two) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true 
= projection(if, 1) + plus_ten = add(red_add, ten) + r = return(if_false, plus_ten) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir new file mode 100644 index 0000000000000000000000000000000000000000..fd06eb7dd22f64022cc797a549946bfda5e8b7cf --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir @@ -0,0 +1,19 @@ +fn sum<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 2) + ten = constant(u64, 10) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, two) + blah = phi(loop, zero_idx, red_add) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + plus_ten = add(red_add, ten) + plus_blah = add(blah, red_add) + r = return(if_false, plus_blah) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/broken_sum.hir b/hercules_test/test_inputs/loop_analysis/broken_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..d15ef5613e271cd8685660785f5505dedbf40ec9 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/broken_sum.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx)) + red_add = add(red, read) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red_add) \ No newline at end of file 
diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir new file mode 100644 index 0000000000000000000000000000000000000000..4df92a18a9895c88cf27f143285999ef2218bfcf --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir @@ -0,0 +1,14 @@ +fn sum<1>(a: u32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, one_idx) + in_bounds = lt(idx_inc, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir new file mode 100644 index 0000000000000000000000000000000000000000..a4732cdeaa1f474548fba0293e21a900448d1791 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir @@ -0,0 +1,21 @@ +fn sum<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = dynamic_constant(#0) + guard_lt = lt(zero_idx, bound) + guard = if(start, guard_lt) + guard_true = projection(guard, 1) + guard_false = projection(guard, 0) + loop = region(guard_true, if_true) + inner_side_effect = region(loop) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, one_idx) + join_phi = phi(final, zero_idx, red_add) + in_bounds = lt(idx_inc, bound) + if = if(inner_side_effect, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + final = region(guard_false, if_false) + r = return(final, join_phi) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir 
b/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir new file mode 100644 index 0000000000000000000000000000000000000000..9e22e14baba40dcda34de4a9d36c05dfe73f11eb --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir @@ -0,0 +1,15 @@ +fn sum<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + inner_side_effect = region(loop) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_idx, red_add) + red_add = add(red, one_idx) + in_bounds = lt(idx_inc, bound) + if = if(inner_side_effect, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red_add) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir b/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir new file mode 100644 index 0000000000000000000000000000000000000000..42269040615520dd5ff2c151bd545a94445f520f --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: i32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + inner_region = region(loop) + inner_red = phi(inner_region, red_add) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + outer_red = phi(loop, zero_idx, inner_red) + red_add = add(outer_red, idx) + in_bounds = lt(idx_inc, bound) + if = if(inner_region, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, inner_red) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir b/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir new file mode 100644 index 0000000000000000000000000000000000000000..a751952dcde83e7ffc0ee64f506314724f7bd745 --- /dev/null +++ 
b/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir @@ -0,0 +1,18 @@ +fn sum<1>(a: i32) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + two = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + inner_region = region(loop) + inner_red = phi(inner_region, red_mul) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + outer_red = phi(loop, zero_idx, inner_red) + red_add = add(outer_red, idx) + red_mul = mul(red_add, idx) + in_bounds = lt(idx_inc, bound) + if = if(inner_region, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, inner_red) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir b/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..f9972b5917c200b93b5775fd4a6e501318e8c548 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir @@ -0,0 +1,16 @@ +fn sum<1>(a: array(i32, #0)) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_inc = constant(i32, 0) + bound = dynamic_constant(#0) + loop = region(start, if_true) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + red = phi(loop, zero_inc, red_add) + read = read(a, position(idx)) + red_add = add(red, read) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, red) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/loop_body_count.hir b/hercules_test/test_inputs/loop_analysis/loop_body_count.hir new file mode 100644 index 0000000000000000000000000000000000000000..c6f3cbf649484f9b00f0fcc9cd208a8b4811284f --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/loop_body_count.hir @@ -0,0 +1,16 @@ +fn loop<1>(a: u64) -> u64 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var 
= constant(u64, 0) + one_var = constant(u64, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, one_var) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, var) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/loop_sum.hir b/hercules_test/test_inputs/loop_analysis/loop_sum.hir new file mode 100644 index 0000000000000000000000000000000000000000..fd9c4debc163600c01e661b127b166358ac9c6db --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/loop_sum.hir @@ -0,0 +1,16 @@ +fn loop<1>(a: u32) -> i32 + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(i32, 0) + one_var = constant(i32, 1) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, one_var) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + r = return(if_false, var) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir b/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir new file mode 100644 index 0000000000000000000000000000000000000000..b756f0901fb7a66a3feb83d1611aa1711bcb5601 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir @@ -0,0 +1,19 @@ +fn loop<1>(b: prod(u64, u64)) -> prod(u64, u64) + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + c = constant(prod(u64, u64), (0, 0)) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, one_var) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) 
+ in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + tuple1 = write(c, var, field(0)) + tuple2 = write(tuple1, idx, field(1)) + r = return(if_false, tuple2) \ No newline at end of file diff --git a/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir b/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir new file mode 100644 index 0000000000000000000000000000000000000000..b756f0901fb7a66a3feb83d1611aa1711bcb5601 --- /dev/null +++ b/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir @@ -0,0 +1,19 @@ +fn loop<1>(b: prod(u64, u64)) -> prod(u64, u64) + zero_idx = constant(u64, 0) + one_idx = constant(u64, 1) + zero_var = constant(u64, 0) + one_var = constant(u64, 1) + c = constant(prod(u64, u64), (0, 0)) + bound = dynamic_constant(#0) + loop = region(start, if_true) + var = phi(loop, zero_var, var_inc) + var_inc = add(var, one_var) + idx = phi(loop, zero_idx, idx_inc) + idx_inc = add(idx, one_idx) + in_bounds = lt(idx, bound) + if = if(loop, in_bounds) + if_false = projection(if, 0) + if_true = projection(if, 1) + tuple1 = write(c, var, field(0)) + tuple2 = write(tuple1, idx, field(1)) + r = return(if_false, tuple2) \ No newline at end of file diff --git a/hercules_test/test_inputs/matmul_int.hir b/hercules_test/test_inputs/matmul_int.hir index 1e496babc9aaa8877f4607e2f8d572a0e5ce9e25..ab0f384a563ccb6144e59b811745fe5aa76f08dd 100644 --- a/hercules_test/test_inputs/matmul_int.hir +++ b/hercules_test/test_inputs/matmul_int.hir @@ -1,21 +1,18 @@ fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2) c = constant(array(i32, #0, #2), []) - i_ctrl = fork(start, #0) - i_idx = thread_id(i_ctrl) - j_ctrl = fork(i_ctrl, #2) - j_idx = thread_id(j_ctrl) - k_ctrl = fork(j_ctrl, #1) - k_idx = thread_id(k_ctrl) + i_j_ctrl = fork(start, #0, #2) + i_idx = thread_id(i_j_ctrl, 0) + j_idx = thread_id(i_j_ctrl, 1) + k_ctrl = fork(i_j_ctrl, #1) + k_idx = 
thread_id(k_ctrl, 0) k_join_ctrl = join(k_ctrl) - j_join_ctrl = join(k_join_ctrl) - i_join_ctrl = join(j_join_ctrl) - r = return(i_join_ctrl, update_i_c) + i_j_join_ctrl = join(k_join_ctrl) + r = return(i_j_join_ctrl, update_i_j_c) zero = constant(i32, 0) a_val = read(a, position(i_idx, k_idx)) b_val = read(b, position(k_idx, j_idx)) mul = mul(a_val, b_val) add = add(mul, dot) dot = reduce(k_join_ctrl, zero, add) - updated_c = write(update_j_c, dot, position(i_idx, j_idx)) - update_j_c = reduce(j_join_ctrl, update_i_c, updated_c) - update_i_c = reduce(i_join_ctrl, c, update_j_c) + update_c = write(update_i_j_c, dot, position(i_idx, j_idx)) + update_i_j_c = reduce(i_j_join_ctrl, c, update_c) \ No newline at end of file diff --git a/juno_samples/cava/build.rs b/juno_samples/cava/build.rs index 929d3eba3e1c83f185c2c0ff256450b05247c80d..7f60f8019f105dda64cf12a2664668efbd637662 100644 --- a/juno_samples/cava/build.rs +++ b/juno_samples/cava/build.rs @@ -1,4 +1,3 @@ -extern crate juno_build; use juno_build::JunoCompiler; fn main() { diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 73a75a94f67edfab905c8c3830191dc342da337f..8ad6824f01d4d94a76f088ba1540ba44d2ce7b71 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -59,7 +59,10 @@ fn run_cava( tonemap, ) .await - }).as_slice::<u8>().to_vec().into_boxed_slice() + }) + .as_slice::<u8>() + .to_vec() + .into_boxed_slice() } enum Error { diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 926fbc33ecfa5ab31b40a92f778bb4d3b7f6a77e..511bf483099ba78cc62754b39146517aa3623103 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -4,6 +4,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + //.schedule_in_src("sched.sch") + //.unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 
fa5d1f04d48cdf48cf377e8f3d08de80d30e688e..624ee5652a78d9c2ab7bc84d3974bf2df5b02838 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -24,10 +24,14 @@ fn main() { let a = HerculesCPURef::from_slice(&a); let b = HerculesCPURef::from_slice(&b); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let c = r + .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) + .await; assert_eq!(c.as_slice::<i32>(), &*correct_c); let mut r = runner!(tiled_64_matmul); - let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let tiled_c = r + .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) + .await; assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); }); } @@ -36,4 +40,3 @@ fn main() { fn matmul_test() { main(); } - diff --git a/juno_samples/matmul/src/sched.sch b/juno_samples/matmul/src/sched.sch new file mode 100644 index 0000000000000000000000000000000000000000..3999f92389317ffdabec519d26af5a603ba17177 --- /dev/null +++ b/juno_samples/matmul/src/sched.sch @@ -0,0 +1,76 @@ +macro juno-setup!(X) { + gvn(X); + dce(X); + phi-elim(X); +} + +macro default!(X) { + dce(X); + crc(X); + dce(X); + slf(X); + dce(X); + inline(X); + ip-sroa(X); + sroa(X); + phi-elim(X); + dce(X); + ccp(X); + dce(X); + gvn(X); + dce(X); + write-predication(X); + phi-elim(X); + dce(X); + crc(X); + dce(X); + slf(X); + dce(X); + predication(X); + dce(X); + ccp(X); + dce(X); + gvn(X); + dce(X); + lift-dc-math(X); + dce(X); + gvn(X); + dce(X); +} + +macro codegen-prep!(X) { + verify(*); + ip-sroa(*); + sroa(*); + infer-schedules(X); + dce(X); + gcm(X); + dce(X); + phi-elim(X); + float-collections(X); + gcm(X); +} + +juno-setup!(*); +default!(*); +// your stuff here. 
+ +fixpoint stop after 13 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); + phi-elim(*); + dce(*); +} + +xdot[true](*); +// serialize(*); + +fork-split(*); +unforkify(*); + +gvn(*); +dce(*); + +auto-outline(*); +codegen-prep!(*); diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 1c837d4a32764abb179b6a05fd5225b808ea764a..04ab156c543fab4112eb0179eb1821a67c817d7a 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -18,3 +18,5 @@ hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } hercules_opt = { path = "../hercules_opt" } juno_utils = { path = "../juno_utils" } +postcard = { version = "*", features = ["alloc"] } +serde = { version = "*", features = ["derive"] } \ No newline at end of file diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index fc3c1279606dd33c4fe79e1b03c9ec345cbfc61b..11a8ec53b1fd52ec37fb260d3219c849760aee49 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -4,8 +4,7 @@ use crate::parser; use juno_utils::env::Env; use juno_utils::stringtab::StringTable; -extern crate hercules_ir; -use self::hercules_ir::ir::{Device, Schedule}; +use hercules_ir::ir::{Device, Schedule}; use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; @@ -117,8 +116,11 @@ impl FromStr for Appliable { "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)), + "fork-coalesce" => Ok(Appliable::Pass(ir::Pass::ForkCoalesce)), "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), + "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), + "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), diff --git a/juno_scheduler/src/default.rs 
b/juno_scheduler/src/default.rs index 46b51b43b8040105f92d452df8f939522df40d2c..fd45a3713dfa48ac425a0c453707e133198e1d79 100644 --- a/juno_scheduler/src/default.rs +++ b/juno_scheduler/src/default.rs @@ -1,5 +1,6 @@ use crate::ir::*; +#[macro_export] macro_rules! pass { ($p:ident) => { ScheduleStmt::Let { @@ -13,6 +14,7 @@ macro_rules! pass { }; } +#[macro_export] macro_rules! default_schedule { () => { ScheduleStmt::Block { @@ -64,8 +66,9 @@ pub fn default_schedule() -> ScheduleStmt { DCE, GVN, DCE, - /*Forkify,*/ - /*ForkGuardElim,*/ + // Forkify, + // ForkGuardElim, + // ForkCoalesce, DCE, ForkSplit, Unforkify, diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 381c3475e0b4f52285bf333f59cbe175d60fce60..d6a41baf99d8ed3cec6ab8183ae52e9956e6c5b0 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -1,6 +1,4 @@ -extern crate hercules_ir; - -use self::hercules_ir::ir::{Device, Schedule}; +use hercules_ir::ir::{Device, Schedule}; #[derive(Debug, Copy, Clone)] pub enum Pass { @@ -12,6 +10,7 @@ pub enum Pass { FloatCollections, ForkGuardElim, ForkSplit, + ForkCoalesce, Forkify, GCM, GVN, @@ -28,6 +27,7 @@ pub enum Pass { WritePredication, Verify, Xdot, + Serialize, } impl Pass { diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 1caafe4f0f3a0275130e47bb230f8a2047865cd6..571d1fbf6da74e9ee454871cbe2fc59f43b5597e 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -14,7 +14,7 @@ use crate::parser::lexer; mod compile; mod default; -mod ir; +pub mod ir; pub mod labels; mod pm; @@ -22,7 +22,7 @@ use crate::compile::*; use crate::default::*; use crate::ir::*; use crate::labels::*; -use crate::pm::*; +pub use crate::pm::*; // Given a schedule's filename parse and process the schedule fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> { @@ -107,6 +107,45 @@ pub fn schedule_juno( .map_err(|e| format!("Scheduling Error: {}", e)) } +pub fn run_schedule_on_hercules( + module: 
Module, + sched: Option<ScheduleStmt>, +) -> Result<Module, String> { + let sched = if let Some(sched) = sched { + sched + } else { + default_schedule() + }; + + // Prepare the scheduler's string table and environment + // For this, we put all of the Hercules function names into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + for (idx, func) in module.functions.iter().enumerate() { + let func_name = strings.lookup_string(func.name.clone()); + env.insert( + func_name, + Value::HerculesFunction { + func: FunctionID::new(idx), + }, + ); + } + + env.open_scope(); + schedule_module( + module, + sched, + strings, + env, + JunoFunctions { func_ids: vec![] }, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) +} + pub fn schedule_hercules( module: Module, sched_filename: Option<String>, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index c388833ca6efbc5d1c4f9f7ec346980cbfc1b918..9888f3d2f2bd052b818049f3c225c614db54dbe7 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -4,6 +4,8 @@ use hercules_cg::*; use hercules_ir::*; use hercules_opt::*; +use serde::Deserialize; +use serde::Serialize; use tempfile::TempDir; use juno_utils::env::Env; @@ -132,7 +134,7 @@ impl Value { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum SchedulerError { UndefinedVariable(String), UndefinedField(String), @@ -159,8 +161,8 @@ impl fmt::Display for SchedulerError { } } -#[derive(Debug)] -struct PassManager { +#[derive(Debug, Clone)] +pub struct PassManager { functions: Vec<Function>, types: RefCell<Vec<Type>>, constants: RefCell<Vec<Constant>>, @@ -189,7 +191,7 @@ struct PassManager { } impl PassManager { - fn new(module: Module) -> Self { + pub fn new(module: Module) -> Self { let Module { functions, types, @@ -500,6 +502,31 @@ impl PassManager { res } + pub fn get_module(&self) -> Module { + let PassManager { + functions, + types, 
+ constants, + dynamic_constants, + labels, + typing: _, + control_subgraphs: _, + bbs: _, + collection_objects: _, + callgraph: _, + .. + } = self; + + let module = Module { + functions: functions.to_vec(), + types: types.clone().into_inner(), + constants: constants.clone().into_inner(), + dynamic_constants: dynamic_constants.clone().into_inner(), + labels: labels.clone().into_inner(), + }; + module + } + fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> { self.make_typing(); self.make_control_subgraphs(); @@ -629,6 +656,18 @@ pub fn schedule_codegen( pm.codegen(output_dir, module_name) } +pub fn schedule_module( + module: Module, + schedule: ScheduleStmt, + mut stringtab: StringTable, + mut env: Env<usize, Value>, + functions: JunoFunctions, +) -> Result<Module, SchedulerError> { + let mut pm = PassManager::new(module); + let _ = schedule_interpret(&mut pm, &schedule, &mut stringtab, &mut env, &functions)?; + Ok(pm.get_module()) +} + // Interpreter for statements and expressions returns a bool indicating whether // any optimization ran and changed the IR. 
This is used for implementing // the fixpoint @@ -1239,30 +1278,85 @@ fn run_pass( pm.clear_analyses(); } Pass::ForkGuardElim => { - todo!("Fork Guard Elim doesn't use editor") - } - Pass::ForkSplit => { assert!(args.is_empty()); pm.make_fork_join_maps(); - pm.make_reduce_cycles(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let reduce_cycles = pm.reduce_cycles.take().unwrap(); - for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) + for (func, fork_join_map) in build_selection(pm, selection) .into_iter() .zip(fork_join_maps.iter()) - .zip(reduce_cycles.iter()) { let Some(mut func) = func else { continue; }; - fork_split(&mut func, fork_join_map, reduce_cycles); + fork_guard_elim(&mut func, fork_join_map); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } + Pass::Serialize => { + // FIXME: How to get module name here? + let output_file = "out.hbin"; + let module = pm.clone().get_module().clone(); + let module_contents: Vec<u8> = postcard::to_allocvec(&module).unwrap(); + let mut file = + File::create(&output_file).expect("PANIC: Unable to open output module file."); + file.write_all(&module_contents) + .expect("PANIC: Unable to write output module file contents."); + } + Pass::ForkSplit => { + assert!(args.is_empty()); + loop { + let mut inner_changed = false; + pm.make_fork_join_maps(); + pm.make_reduce_cycles(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection.clone()) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_split(&mut func, fork_join_map, reduce_cycles); + changed |= func.modified(); + inner_changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + + if !inner_changed { + break; + } + } + } Pass::Forkify => { - todo!("Forkify doesn't 
use editor") + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_control_subgraphs(); + pm.make_loops(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + for (((func, fork_join_map), loop_nest), control_subgraph) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(control_subgraphs.iter()) + { + let Some(mut func) = func else { + continue; + }; + // TODO: uses direct return from forkify for now instead of + // func.modified, see comment on top of `forkify` for why. Fix + // this eventually. + changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest); + } + pm.delete_gravestones(); + pm.clear_analyses(); } Pass::GCM => { assert!(args.is_empty()); @@ -1576,21 +1670,50 @@ fn run_pass( Pass::Unforkify => { assert!(args.is_empty()); pm.make_fork_join_maps(); + pm.make_loops(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) + for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) .into_iter() .zip(fork_join_maps.iter()) + .zip(loops.iter()) { let Some(mut func) = func else { continue; }; - unforkify(&mut func, fork_join_map); + unforkify(&mut func, fork_join_map, loop_tree); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkCoalesce => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_control_subgraphs(); + pm.make_loops(); + pm.make_reduce_cycles(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + for (((func, fork_join_map), loop_nest), control_subgraph) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + 
.zip(control_subgraphs.iter()) + { + let Some(mut func) = func else { + continue; + }; + changed |= fork_coalesce(&mut func, loop_nest, fork_join_map); + // func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::WritePredication => { assert!(args.is_empty()); for func in build_selection(pm, selection) { diff --git a/juno_utils/src/stringtab.rs b/juno_utils/src/stringtab.rs index e151b830d7b51d3baa4f64ae32214afca81a9eab..45ee08644642692acb36bff277081745f492c3d1 100644 --- a/juno_utils/src/stringtab.rs +++ b/juno_utils/src/stringtab.rs @@ -1,6 +1,4 @@ -extern crate serde; - -use self::serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap;