diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b1c318bf73fb349e4717387c75c82b5583730363 --- /dev/null +++ b/LICENSE @@ -0,0 +1,219 @@ +The Hercules Compiler is under the Apache License v2.0 with LLVM Exceptions: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 6981a3da7e59176f73d6fecdde07fe636cc6aecf..d94f0e19191a028dfaded785c460164513b712a4 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -938,7 +938,7 @@ impl<'a> RTContext<'a> { let dst_device = self.node_colors.0[&collect]; write!( block, - "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});", + "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {} as usize);", src_device.name(), dst_device.name(), self.get_value(collect, bb, false), diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 0e332a0033c50585160242f04bfdcceb37f87ad5..6f0fdf4dcb04e4e9d5adfec80053be3ecfc2b08d 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -457,7 +457,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { pub fn add_node(&mut self, node: Node) -> NodeID { let id = NodeID::new(self.num_node_ids()); // Added nodes need to have an entry in the def-use map. - self.updated_def_use.insert(id, HashSet::new()); + self.updated_def_use.entry(id).or_insert(HashSet::new()); // Added nodes use other nodes, and we need to update their def-use // entries. for u in get_uses(&node).as_ref() { diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ff0f0283767996914e8f9b2274ed9a6d538b1812..e6db0345def31324243cdee2bdcb6b5cca5d9a7b 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -578,7 +578,7 @@ pub fn fork_coalesce( // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early. // something like: `fork_joins.postorder_iter().windows(2)` is ideal here. for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) { - if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) { + if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() { return true; } } @@ -587,13 +587,15 @@ pub fn fork_coalesce( /** Opposite of fork split, takes two fork-joins with no control between them, and merges them into a single fork-join. + Returns None if the forks could not be merged and the NodeIDs of the + resulting fork and join if it succeeds in merging them. */ pub fn fork_coalesce_helper( editor: &mut FunctionEditor, outer_fork: NodeID, inner_fork: NodeID, fork_join_map: &HashMap<NodeID, NodeID>, -) -> bool { +) -> Option<(NodeID, NodeID)> { // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork. let outer_join = fork_join_map[&outer_fork]; @@ -621,47 +623,35 @@ pub fn fork_coalesce_helper( reduct: _, } = inner_reduce_node else { - return false; + return None; }; // FIXME: check this condition better (i.e reduce might not be attached to join) if *inner_control != inner_join { - return false; + return None; }; if *inner_init != outer_reduce { - return false; + return None; }; if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) { - return false; + return None; } else { pairs.insert(outer_reduce, inner_reduce); } } // Check for control between join-join and fork-fork - let Some(user) = editor - .get_users(outer_fork) - .filter(|node| editor.func().nodes[node.idx()].is_control()) - .next() - else { - return false; - }; + let (control, _) = editor.node(inner_fork).try_fork().unwrap(); - if user != inner_fork { - return false; + if control != outer_fork { + return None; } - let Some(user) = editor - .get_users(inner_join) - .filter(|node| editor.func().nodes[node.idx()].is_control()) - .next() - else { - return false; - }; + let control = editor.node(outer_join).try_join().unwrap(); - if user != outer_join { - return false; + if control != inner_join { + return None; } // Checklist: @@ -686,46 +676,47 @@ pub fn fork_coalesce_helper( // CHECKME / FIXME: Might need to be added the other way. new_factors.append(&mut inner_dims.to_vec()); - for tid in inner_tids { - let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap(); - let new_tid = Node::ThreadID { - control: fork, - dimension: dim + num_outer_dims, - }; + let mut new_fork = NodeID::new(0); + let new_join = inner_join; // We'll reuse the inner join as the join of the new fork + + let success = editor.edit(|mut edit| { + for tid in inner_tids { + let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap(); + let new_tid = Node::ThreadID { + control: fork, + dimension: dim + num_outer_dims, + }; - editor.edit(|mut edit| { let new_tid = edit.add_node(new_tid); - let mut edit = edit.replace_all_uses(tid, new_tid)?; + edit = edit.replace_all_uses(tid, new_tid)?; edit.sub_edit(tid, new_tid); - Ok(edit) - }); - } - - // Fuse Reductions - for (outer_reduce, inner_reduce) in pairs { - let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()] - .try_reduce() - .unwrap(); - let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] - .try_reduce() - .unwrap(); - editor.edit(|mut edit| { + } + // Fuse Reductions + for (outer_reduce, inner_reduce) in pairs { + let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap(); + let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap(); // Set inner init to outer init. edit = edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; + } - Ok(edit) - }); - } - - editor.edit(|mut edit| { - let new_fork = Node::Fork { + let new_fork_node = Node::Fork { control: outer_pred, factors: new_factors.into(), }; - let new_fork = edit.add_node(new_fork); + new_fork = edit.add_node(new_fork_node); + + if edit + .get_schedule(outer_fork) + .contains(&Schedule::ParallelFork) + && edit + .get_schedule(inner_fork) + .contains(&Schedule::ParallelFork) + { + edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; + } edit = edit.replace_all_uses(inner_fork, new_fork)?; edit = edit.replace_all_uses(outer_fork, new_fork)?; @@ -737,7 +728,11 @@ pub fn fork_coalesce_helper( Ok(edit) }); - true + if success { + Some((new_fork, new_join)) + } else { + None + } } pub fn split_any_fork( @@ -760,7 +755,7 @@ pub fn split_any_fork( * Useful for code generation. A single iteration of `fork_split` only splits * at most one fork-join, it must be called repeatedly to split all fork-joins. */ -pub(crate) fn split_fork( +pub fn split_fork( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, @@ -1215,13 +1210,13 @@ pub fn fork_interchange_all_forks( } } -fn fork_interchange( +pub fn fork_interchange( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, first_dim: usize, second_dim: usize, -) { +) -> Option<NodeID> { // Check that every reduce on the join is parallel or associative. let nodes = &editor.func().nodes; let schedules = &editor.func().schedules; @@ -1234,7 +1229,7 @@ fn fork_interchange( }) { // If not, we can't necessarily do interchange. - return; + return None; } let Node::Fork { @@ -1276,6 +1271,7 @@ fn fork_interchange( let mut factors = factors.clone(); factors.swap(first_dim, second_dim); let new_fork = Node::Fork { control, factors }; + let mut new_fork_id = None; editor.edit(|mut edit| { for (old_id, new_tid) in fix_tids { let new_id = edit.add_node(new_tid); @@ -1283,9 +1279,17 @@ fn fork_interchange( edit = edit.delete_node(old_id)?; } let new_fork = edit.add_node(new_fork); + if edit.get_schedule(fork).contains(&Schedule::ParallelFork) { + edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; + } edit = edit.replace_all_uses(fork, new_fork)?; - edit.delete_node(fork) + edit = edit.delete_node(fork)?; + + new_fork_id = Some(new_fork); + Ok(edit) }); + + new_fork_id } /* diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index c612acac1e114fcb3a73edcf0468d2c6f7c84acd..d950941a4acba886f47cdd0e99cb3d9a48459636 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -879,8 +879,15 @@ fn spill_clones( || editor.func().nodes[a.idx()].is_reduce()) && !editor.func().nodes[a.idx()] .try_reduce() - .map(|(_, init, _)| { - init == *b + .map(|(_, init, reduct)| { + (init == *b || reduct == *b) + && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce) + }) + .unwrap_or(false) + && !editor.func().nodes[a.idx()] + .try_phi() + .map(|(_, data)| { + data.contains(b) && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce) }) .unwrap_or(false)) @@ -1302,39 +1309,53 @@ enum UTerm { Device(Device), } -fn unify( - mut equations: VecDeque<(UTerm, UTerm)>, -) -> Result<BTreeMap<NodeID, Device>, BTreeMap<NodeID, Device>> { +fn unify(mut equations: VecDeque<(UTerm, UTerm)>) -> Result<BTreeMap<NodeID, Device>, NodeID> { let mut theta = BTreeMap::new(); + // First, assign devices to nodes when a rule directly says to. + for _ in 0..equations.len() { + let (l, r) = equations.pop_front().unwrap(); + match (l, r) { + (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => { + if let Some(old_d) = theta.insert(n, d) + && old_d != d + { + return Err(n); + } + } + _ => equations.push_back((l, r)), + } + } + + // Second, iterate the rest of the rules until... + // 1. The rules are exhausted. All the nodes have device assignments. + // 2. No progress is being made. Some nodes may not have device assignments. + // 3. An inconsistency has been found. The inconsistency is returned. let mut no_progress_iters = 0; while no_progress_iters <= equations.len() && let Some((l, r)) = equations.pop_front() { - match (l, r) { - (UTerm::Node(_), UTerm::Node(_)) => { - if l != r { - equations.push_back((l, r)); - } - no_progress_iters += 1; - } - (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => { - theta.insert(n, d); - for (l, r) in equations.iter_mut() { - if *l == UTerm::Node(n) { - *l = UTerm::Device(d); - } - if *r == UTerm::Node(n) { - *r = UTerm::Device(d); - } + let (UTerm::Node(l), UTerm::Node(r)) = (l, r) else { + panic!(); + }; + + match (theta.get(&l), theta.get(&r)) { + (Some(ld), Some(rd)) => { + if ld != rd { + return Err(l); + } else { + no_progress_iters = 0; } - no_progress_iters = 0; } - (UTerm::Device(d1), UTerm::Device(d2)) if d1 == d2 => { + (Some(d), None) | (None, Some(d)) => { + let d = *d; + theta.insert(l, d); + theta.insert(r, d); no_progress_iters = 0; } - _ => { - return Err(theta); + (None, None) => { + equations.push_back((UTerm::Node(l), UTerm::Node(r))); + no_progress_iters += 1; } } } @@ -1377,8 +1398,8 @@ fn color_nodes( } if !editor.get_type(typing[id.idx()]).is_primitive() => { // Every input to a phi needs to be on the same device. The // phi itself is also on this device. - for (l, r) in zip(data.into_iter(), data.into_iter().skip(1).chain(once(&id))) { - equations.push((UTerm::Node(*l), UTerm::Node(*r))); + for data in data { + equations.push((UTerm::Node(*data), UTerm::Node(id))); } } Node::Reduce { @@ -1394,7 +1415,7 @@ fn color_nodes( } if !editor.get_type(typing[id.idx()]).is_primitive() => { // Every input to the reduce, and the reduce itself, are on // the same device. - equations.push((UTerm::Node(first), UTerm::Node(second))); + equations.push((UTerm::Node(first), UTerm::Node(id))); equations.push((UTerm::Node(second), UTerm::Node(id))); } Node::Constant { id: _ } @@ -1533,12 +1554,11 @@ fn color_nodes( } Some(func_colors) } - Err(progress) => { + Err(id) => { // If unification failed, then there's some node using a node in - // `progress` that's expecting a different type than what it got. - // Pick one and add potentially inter-device copies on each def-use - // edge. We'll clean these up later. - let (id, _) = progress.into_iter().next().unwrap(); + // that's expecting a different type than what it got. Add + // potentially inter-device copies on each def-use edge. We'll clean + // these up later. let users: Vec<_> = editor.get_users(id).collect(); let success = editor.edit(|mut edit| { let cons = edit.add_zero_constant(typing[id.idx()]); diff --git a/hercules_opt/src/loop_bound_canon.rs b/hercules_opt/src/loop_bound_canon.rs index 203e45f87df6f24d7b31438dd2dee8d7285ddf1a..3e194e2dda3319dd6c2c4ece4aea1af41e7e701b 100644 --- a/hercules_opt/src/loop_bound_canon.rs +++ b/hercules_opt/src/loop_bound_canon.rs @@ -73,8 +73,6 @@ pub fn canonicalize_single_loop_bounds( .into_iter() .partition(|f| loop_bound_iv_phis.contains(&f.phi())); - // println!("{:?}", loop_bound_ivs); - // Assume there is only one loop bound iv. if loop_bound_ivs.len() != 1 { @@ -248,9 +246,6 @@ pub fn canonicalize_single_loop_bounds( } else { None }; - // println!("condition node: {:?}", condition_node); - let users = editor.get_users(condition_node).collect_vec(); - // println!("{:?}", users); let mut condition_node = condition_node; @@ -294,7 +289,6 @@ pub fn canonicalize_single_loop_bounds( }; Ok(edit) }); - let update_expr_users: Vec<_> = editor .get_users(*update_expression) diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index 9bc7823ee7f5837cf49387170e548a9174340f42..10eca72e4b9c9aa6285f04cc8812d58f459d556f 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -1,6 +1,6 @@ use std::collections::{BTreeSet, HashMap, HashSet}; +use std::iter::once; -use hercules_ir::def_use::*; use hercules_ir::ir::*; use crate::*; @@ -42,6 +42,10 @@ pub fn infer_parallel_reduce( fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) { + let join_fork_map: HashMap<_, _> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); for id in editor.node_ids() { let func = editor.func(); if !func.nodes[id.idx()].is_reduce() { @@ -98,40 +102,11 @@ pub fn infer_parallel_reduce( && *collect == last_reduce && !reduce_cycles[&last_reduce].contains(data) { - // If there is a Write-Reduce tight cycle, get the position indices. - let positions = indices - .iter() - .filter_map(|index| { - if let Index::Position(indices) = index { - Some(indices) - } else { - None - } - }) - .flat_map(|pos| pos.iter()); - - // Get the Forks corresponding to uses of bare ThreadIDs. - let fork_thread_id_pairs = positions.filter_map(|id| { - if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] { - Some((control, dimension)) - } else { - None - } - }); - let mut forks = HashMap::<NodeID, Vec<usize>>::new(); - for (fork, dim) in fork_thread_id_pairs { - forks.entry(fork).or_default().push(dim); - } - - // Check if one of the Forks correspond to the Join associated with - // the Reduce being considered, and has all of its dimensions - // represented in the indexing. - let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { - rep_dims.sort(); - rep_dims.dedup(); - fork_join_map[&id] == first_control.unwrap() - && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() - }); + let is_parallel = indices_parallel_over_forks( + editor, + indices, + once(join_fork_map[&first_control.unwrap()]), + ); if is_parallel { editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce)); @@ -145,6 +120,7 @@ pub fn infer_parallel_reduce( * operands must be the Reduce node, and all other operands must not be in the * Reduce node's cycle. */ +#[rustfmt::skip] pub fn infer_monoid_reduce( editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index e962b81dfaf28ece80f49a150139f5774c186771..b910a128116fb8fb39de29475b93ffa70a12dfcd 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -532,6 +532,24 @@ where let fork_thread_id_pairs = node_indices(indices).filter_map(|id| { if let Node::ThreadID { control, dimension } = nodes[id.idx()] { Some((control, dimension)) + } else if let Node::Binary { + op: BinaryOperator::Add, + left: tid, + right: cons, + } = nodes[id.idx()] + && let Node::ThreadID { control, dimension } = nodes[tid.idx()] + && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) + { + Some((control, dimension)) + } else if let Node::Binary { + op: BinaryOperator::Add, + left: cons, + right: tid, + } = nodes[id.idx()] + && let Node::ThreadID { control, dimension } = nodes[tid.idx()] + && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) + { + Some((control, dimension)) } else { None } diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 277276648e905186bfeb54714fb00f7275f17b22..00e5b873e2061f98b911876900890deb8b3abcef 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -25,7 +25,8 @@ fn main() { let a = HerculesImmBox::from(a.as_ref()); let b = HerculesImmBox::from(b.as_ref()); let mut r = runner!(matmul); - let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + let mut c: HerculesMutBox<i32> = + HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); assert_eq!(c.as_slice(), correct_c.as_ref()); }); } diff --git a/juno_samples/cava/src/gpu.sch b/juno_samples/cava/src/gpu.sch index c8db124ede9b98220866a8c1cdc7b17cdfb8093f..bacfd3abca363a6dd93496c1adb23fa54c860b9f 100644 --- a/juno_samples/cava/src/gpu.sch +++ b/juno_samples/cava/src/gpu.sch @@ -120,6 +120,7 @@ simpl!(fuse4); //fork-tile[2, 0, false, true](fuse4@channel_loop); //fork-split(fuse4@channel_loop); //clean-monoid-reduces(fuse4); +unforkify(fuse4@channel_loop); no-memset(fuse5@res1); no-memset(fuse5@res2); diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch index 4bd3254b1773a804e90154eec8385c8237b62fbd..ec9e423dc4c160d08b61eaf45d2b75329886f94e 100644 --- a/juno_samples/edge_detection/src/cpu.sch +++ b/juno_samples/edge_detection/src/cpu.sch @@ -86,7 +86,7 @@ fixpoint { simpl!(max_gradient); fork-dim-merge(max_gradient); simpl!(max_gradient); -fork-tile[8, 0, false, false](max_gradient); +fork-tile[16, 0, false, false](max_gradient); let split = fork-split(max_gradient); clean-monoid-reduces(max_gradient); let out = outline(split._4_max_gradient.fj1); @@ -104,11 +104,18 @@ fixpoint { } predication(reject_zero_crossings); simpl!(reject_zero_crossings); +fork-tile[4, 1, false, false](reject_zero_crossings); +fork-tile[4, 0, false, false](reject_zero_crossings); +fork-interchange[1, 2](reject_zero_crossings); +let split = fork-split(reject_zero_crossings); +let reject_zero_crossings_body = outline(split._5_reject_zero_crossings.fj2); +fork-coalesce(reject_zero_crossings, reject_zero_crossings_body); +simpl!(reject_zero_crossings, reject_zero_crossings_body); async-call(edge_detection@le, edge_detection@zc); -fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings); -unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings); +fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings_body); +unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings_body); simpl!(*); diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 0be838c620761e8726590e2dbaf7bfdb7a82e3df..d2813388e0e7a1d7bd1696ffbb641e629096e2c2 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,6 +6,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..69f1811d08a5691de6dc10fa86f0720e416fb85a --- /dev/null +++ b/juno_samples/matmul/src/cpu.sch @@ -0,0 +1,61 @@ +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen-prep!(X) { + optimize!(X); + gcm(X); + float-collections(X); + dce(X); + gcm(X); +} + +macro forkify!(X) { + fixpoint { + forkify(X); + fork-guard-elim(X); + } +} + +macro fork-tile { + fork-tile[n, 0, false, true](X); +} + +macro parallelize!(X) { + parallel-fork(X); + parallel-reduce(X); +} + +macro unforkify!(X) { + fork-split(X); + unforkify(X); +} + +optimize!(*); +forkify!(*); +associative(matmul@outer); + +// Parallelize by computing output array as 16 chunks +let par = matmul@outer \ matmul@inner; +fork-tile; +let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); +parallelize!(outer \ inner); + +let body = outline(inner); +cpu(body); + +// Tile for cache, assuming 64B cache lines +fork-tile; +let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); + +reduce-slf(inner); +unforkify!(body); +codegen-prep!(*); diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn index 356bb3d91836ba0994cad56315b9a5588b0df8b7..94c4334c1cae17a396384ad6135432e3e80f70e3 100644 --- a/juno_samples/rodinia/backprop/src/backprop.jn +++ b/juno_samples/rodinia/backprop/src/backprop.jn @@ -7,9 +7,9 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f @res let result : f32[m + 1]; result[0] = 1.0; - for j in 1..=m { + @outer_loop for j in 1..=m { let sum = 0.0; - for k in 0..=n { + @inner_loop for k in 0..=n { sum += weights[k, j] * vals[k]; } result[j] = squash(sum); diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch index 432948076a85791c2a923d6ca286ad9e11206e0f..de34d660bcc5e3d95d58aa63524bffdbc0b8f67e 100644 --- a/juno_samples/rodinia/backprop/src/cpu.sch +++ b/juno_samples/rodinia/backprop/src/cpu.sch @@ -15,21 +15,17 @@ delete-uncalled(*); no-memset(layer_forward@res); lift-dc-math(*); loop-bound-canon(*); -dce(*); +simpl!(*); lift-dc-math(*); - +slf(*); fixpoint { forkify(*); fork-guard-elim(*); fork-coalesce(*); } +reduce-slf(*); +simpl!(*); fork-split(*); -gvn(*); -phi-elim(*); -dce(*); unforkify(*); -gvn(*); -phi-elim(*); -dce(*); gcm(*); diff --git a/juno_samples/rodinia/bfs/src/gpu.sch b/juno_samples/rodinia/bfs/src/gpu.sch index 0a3f4d7737ba22b77443e1f8e7b35ded35040e4c..6c4d027b77df936c5840237c211950d6c0430082 100644 --- a/juno_samples/rodinia/bfs/src/gpu.sch +++ b/juno_samples/rodinia/bfs/src/gpu.sch @@ -1,23 +1,39 @@ -gvn(*); -phi-elim(*); -dce(*); +macro simpl!(X) { + ccp(X); + simplify-cfg(X); + lift-dc-math(X); + gvn(X); + phi-elim(X); + dce(X); + infer-schedules(X); +} -let outline = auto-outline(bfs); -gpu(outline.bfs); +phi-elim(bfs); +no-memset(bfs@cost); +let cost_init = outline(bfs@cost_init); +let loop1 = outline(bfs@loop1); +let loop2 = outline(bfs@loop2); +gpu(loop1, loop2); -ip-sroa(*); -sroa(*); -dce(*); -gvn(*); -phi-elim(*); -dce(*); +simpl!(*); +predication(*); +const-inline(*); +simpl!(*); +fixpoint { + forkify(*); + fork-guard-elim(*); +} +simpl!(*); +predication(*); +simpl!(*); -//forkify(*); -infer-schedules(*); +unforkify(cost_init); +parallel-reduce(loop1); +forkify(*); +fork-guard-elim(*); +simpl!(*); +predication(*); +reduce-slf(*); +simpl!(*); gcm(*); -fixpoint { - float-collections(*); - dce(*); - gcm(*); -} diff --git a/juno_samples/rodinia/srad/src/cpu.sch b/juno_samples/rodinia/srad/src/cpu.sch index 2b45e8c956e10cb6af538282df98e32eb35b6b5e..a4cd49569aedadf42d3965bee040391f4d56cef9 100644 --- a/juno_samples/rodinia/srad/src/cpu.sch +++ b/juno_samples/rodinia/srad/src/cpu.sch @@ -29,6 +29,8 @@ fixpoint { } simpl!(*); fork-interchange[0, 1](loop1); +reduce-slf(*); +simpl!(*); fork-split(*); unforkify(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index bd27350a26d58f4e729b24d0026f12cd13ca7195..9d020c64ccef3b9c0a79694876c5b0ace606f938 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -135,6 +135,7 @@ impl FromStr for Appliable { "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), + "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), @@ -207,6 +208,28 @@ fn compile_stmt( exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }]) } + parser::Stmt::LetsStmt { + span: _, + vars, + expr, + } => { + let tmp = format!("{}_tmp", macros.uniq()); + Ok(std::iter::once(ir::ScheduleStmt::Let { + var: tmp.clone(), + exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, + }) + .chain(vars.into_iter().enumerate().map(|(idx, v)| { + let var = lexer.span_str(v).to_string(); + ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::TupleField { + lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), + field: idx, + }, + } + })) + .collect()) + } parser::Stmt::AssignStmt { span: _, var, rhs } => { let var = lexer.span_str(var).to_string(); Ok(vec![ir::ScheduleStmt::Assign { @@ -489,6 +512,29 @@ fn compile_expr( rhs: Box::new(rhs), })) } + parser::Expr::Tuple { span: _, exps } => { + let exprs = exps + .into_iter() + .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros)) + .fold(Ok(vec![]), |mut res, exp| { + let mut res = res?; + res.push(exp?); + Ok(res) + })?; + Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs })) + } + parser::Expr::TupleField { + span: _, + lhs, + field, + } => { + let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; + let field = lexer.span_str(field).parse().expect("Parsing"); + Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { + lhs: Box::new(lhs), + field, + })) + } } } diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 6aa85fe53689cf015497e56850ef0c197ccbdae0..ab1495b816c99452560d03c0addf77a5aec18974 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -21,6 +21,7 @@ pub enum Pass { ForkFusion, ForkGuardElim, ForkInterchange, + ForkReshape, ForkSplit, ForkUnroll, Forkify, @@ -59,6 +60,7 @@ impl Pass { Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, Pass::ForkInterchange => num == 2, + Pass::ForkReshape => true, Pass::InterproceduralSROA => num == 0 || num == 1, Pass::Print => num == 1, Pass::Rename => num == 1, @@ -76,6 +78,7 @@ impl Pass { Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", Pass::ForkInterchange => "2", + Pass::ForkReshape => "any", Pass::InterproceduralSROA => "0 or 1", Pass::Print => "1", Pass::Rename => "1", @@ -130,6 +133,13 @@ pub enum ScheduleExp { lhs: Box<ScheduleExp>, rhs: Box<ScheduleExp>, }, + Tuple { + exprs: Vec<ScheduleExp>, + }, + TupleField { + lhs: Box<ScheduleExp>, + field: usize, + }, // This is used to "box" a selection by evaluating it at one point and then // allowing it to be used as a selector later on Selection { diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 3b030e1d42bdb970cdfa67d21c4198dc89edea9e..451f035b8122ac1a694567327a44776c9326fff2 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -19,6 +19,8 @@ Schedule -> OperationList Stmt -> Stmt : 'let' 'ID' '=' Expr ';' { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } } + | 'let' '(' Ids ')' '=' Expr ';' + { Stmt::LetsStmt { span: $span, vars: rev($3), expr: $6 } } | 'ID' '=' Expr ';' { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } } | Expr ';' @@ -56,10 +58,14 @@ Expr -> Expr { Expr::String { span: $span } } | Expr '.' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } + | Expr '.' 'INT' + { Expr::TupleField { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } | Expr '@' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } - | '(' Expr ')' - { $2 } + | '(' Exprs ')' + { Expr::Tuple { span: $span, exps: $2 } } + | '[' Exprs ']' + { Expr::Tuple { span: $span, exps: $2 } } | '{' Schedule '}' { Expr::BlockExpr { span: $span, body: Box::new($2) } } | '<' Fields '>' @@ -73,14 +79,18 @@ Expr -> Expr ; Args -> Vec<Expr> - : { vec![] } - | '[' Exprs ']' { rev($2) } + : { vec![] } + | '[' RExprs ']' { rev($2) } ; Exprs -> Vec<Expr> - : { vec![] } - | Expr { vec![$1] } - | Expr ',' Exprs { snoc($1, $3) } + : RExprs { rev($1) } + ; + +RExprs -> Vec<Expr> + : { vec![] } + | Expr { vec![$1] } + | Expr ',' RExprs { snoc($1, $3) } ; Fields -> Vec<(Span, Expr)> @@ -149,6 +159,7 @@ pub enum OperationList { pub enum Stmt { LetStmt { span: Span, var: Span, expr: Expr }, + LetsStmt { span: Span, vars: Vec<Span>, expr: Expr }, AssignStmt { span: Span, var: Span, rhs: Expr }, ExprStmt { span: Span, exp: Expr }, Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> }, @@ -180,6 +191,8 @@ pub enum Expr { BlockExpr { span: Span, body: Box<OperationList> }, Record { span: Span, fields: Vec<(Span, Expr)> }, SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> }, + Tuple { span: Span, exps: Vec<Expr> }, + TupleField { span: Span, lhs: Box<Expr>, field: Span }, } pub enum Selector { diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 70d8e4278169ebdbe9985e00ede161acbe05c24d..456df2eda49b93a6c80327a090b6f6606ae711bb 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -294,6 +294,9 @@ pub enum Value { Record { fields: HashMap<String, Value>, }, + Tuple { + values: Vec<Value>, + }, Everything {}, Selection { selection: Vec<Value>, @@ -371,6 +374,11 @@ impl Value { "Expected code selection, found record".to_string(), )); } + Value::Tuple { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found tuple".to_string(), + )); + } Value::Integer { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found integer".to_string(), @@ -1294,6 +1302,7 @@ fn interp_expr( | Value::Integer { .. } | Value::Boolean { .. } | Value::String { .. } + | Value::Tuple { .. } | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())), Value::JunoFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { @@ -1467,6 +1476,28 @@ fn interp_expr( Ok((Value::Selection { selection: values }, changed)) } }, + ScheduleExp::Tuple { exprs } => { + let mut vals = vec![]; + let mut changed = false; + for exp in exprs { + let (val, change) = interp_expr(pm, exp, stringtab, env, functions)?; + vals.push(val); + changed = changed || change; + } + Ok((Value::Tuple { values: vals }, changed)) + } + ScheduleExp::TupleField { lhs, field } => { + let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?; + match val { + Value::Tuple { values } if *field < values.len() => { + Ok((vec_take(values, *field), changed)) + } + _ => Err(SchedulerError::SemanticError(format!( + "No field at index {}", + field + ))), + } + } } } @@ -1524,6 +1555,14 @@ fn update_value( Some(Value::Record { fields: new_fields }) } } + // For tuples, if we deleted values like we do for records this would mess up the indices + // which would behave very strangely. Instead if any field cannot be updated then we + // eliminate the entire value + Value::Tuple { values } => values + .into_iter() + .map(|v| update_value(v, func_idx, juno_func_idx)) + .collect::<Option<Vec<_>>>() + .map(|values| Value::Tuple { values }), Value::JunoFunction { func } => { juno_func_idx[func.idx] .clone() @@ -2897,6 +2936,247 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkReshape => { + let mut shape = vec![]; + let mut loops = BTreeSet::new(); + let mut fork_count = 0; + + for arg in args { + let Value::Tuple { values } = arg else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected each argument to be a list of integers".to_string(), + }); + }; + + let mut indices = vec![]; + for val in values { + let Value::Integer { val: idx } = val else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected each argument to be a list of integers".to_string(), + }); + }; + indices.push(idx); + loops.insert(idx); + fork_count += 1; + } + shape.push(indices); + } + + if loops != (0..fork_count).collect() { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: + "expected forks to be numbered sequentially from 0 and used exactly once" + .to_string(), + }); + } + + let Some((nodes, func_id)) = selection_as_set(pm, selection) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "must be applied to nodes in a single function".to_string(), + }); + }; + let func = func_id.idx(); + + pm.make_def_uses(); + pm.make_fork_join_maps(); + pm.make_loops(); + pm.make_reduce_cycles(); + + let def_uses = pm.def_uses.take().unwrap(); + let mut fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + + let def_use = &def_uses[func]; + let fork_join_map = &mut fork_join_maps[func]; + let loops = &loops[func]; + let reduce_cycles = &reduce_cycles[func]; + + let mut editor = FunctionEditor::new( + &mut pm.functions[func], + func_id, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ); + + // There should be exactly one fork nest in the selection and it should contain + // exactly fork_count forks (counting each dimension of each fork) + // We determine the loops (ordered top-down) that are contained in the selection + // (in particular the header is in the selection) and its a fork-join (the header + // is a fork) + let mut loops = loops + .bottom_up_loops() + .into_iter() + .rev() + .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork()); + let Some((top_fork_head, top_fork_body)) = loops.next() else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: format!( + "expected {} forks found 0 in {}", + fork_count, + editor.func().name + ), + }); + }; + // All the remaining forks need to be contained in the top fork body + let mut forks = vec![top_fork_head]; + let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len(); + for (head, _) in loops { + if !top_fork_body[head.idx()] { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "selection includes multiple non-nested forks".to_string(), + }); + } else { + forks.push(head); + num_dims += editor.node(head).try_fork().unwrap().1.len(); + } + } + + if num_dims != fork_count { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: format!( + "expected {} forks, found {} in {}", + fork_count, num_dims, pm.functions[func].name + ), + }); + } + + // Now, we coalesce all of these forks into one so that we can interchange them + let mut forks = forks.into_iter(); + let top_fork = forks.next().unwrap(); + let mut cur_fork = top_fork; + for next_fork in forks { + let Some((new_fork, new_join)) = + fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) + else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "failed to coalesce forks".to_string(), + }); + }; + cur_fork = new_fork; + fork_join_map.insert(new_fork, new_join); + } + let join = *fork_join_map.get(&cur_fork).unwrap(); + + // Now we have just one fork and we can perform the interchanges we need + // To do this, we track two maps: from original index to current index and from + // current index to original index + let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>(); + let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>(); + + // Now, starting from the first (outermost) index we move the desired fork bound + // into place + for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() { + let cur_idx = orig_to_cur[*original_idx]; + let swapping = cur_to_orig[idx]; + + // If the desired factor is already in the correct place, do nothing + if cur_idx == idx { + continue; + } + assert!(idx < cur_idx); + let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) + else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "failed to interchange forks".to_string(), + }); + }; + cur_fork = fork_res; + + // Update our maps + orig_to_cur[*original_idx] = idx; + orig_to_cur[swapping] = cur_idx; + cur_to_orig[idx] = *original_idx; + cur_to_orig[cur_idx] = swapping; + } + + // Finally we split the fork into the desired pieces. We do this by first splitting + // the fork into individual forks and then coalesce the chunks together + // Not sure how split_fork could fail, so if it does panic is fine + let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap(); + + for (fork, join) in forks.iter().zip(joins.iter()) { + fork_join_map.insert(*fork, *join); + } + + // Finally coalesce the chunks together + let mut fork_idx = 0; + let mut final_forks = vec![]; + for chunk in shape.iter() { + let chunk_len = chunk.len(); + + let mut cur_fork = forks[fork_idx]; + for i in 1..chunk_len { + let next_fork = forks[fork_idx + i]; + // Again, not sure at this point how coalesce could fail, so panic if it + // does + let (new_fork, new_join) = + fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) + .unwrap(); + cur_fork = new_fork; + fork_join_map.insert(new_fork, new_join); + } + + fork_idx += chunk_len; + final_forks.push(cur_fork); + } + + // Label each fork and return the labels + // We've trashed our analyses at this point, so rerun them so that we can determine the + // nodes in each of the result fork-joins + pm.clear_analyses(); + pm.make_def_uses(); + pm.make_nodes_in_fork_joins(); + + let def_uses = pm.def_uses.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + let def_use = &def_uses[func]; + let nodes_in_fork_joins = &nodes_in_fork_joins[func]; + + let mut editor = FunctionEditor::new( + &mut pm.functions[func], + func_id, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ); + + let labels = create_labels_for_node_sets( + &mut editor, + final_forks + .into_iter() + .map(|fork| nodes_in_fork_joins[&fork].iter().copied()), + ) + .into_iter() + .map(|(_, label)| Value::Label { + labels: vec![LabelInfo { + func: func_id, + label, + }], + }) + .collect(); + + result = Value::Tuple { values: labels }; + changed = true; + + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::WritePredication => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { @@ -3047,3 +3327,7 @@ where }); labels } + +fn vec_take<T>(mut v: Vec<T>, index: usize) -> T { + v.swap_remove(index) +}