From 10574e4972a56f1707d3aa8272dfe62ebdbf0422 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Sun, 22 Dec 2024 00:06:23 -0600 Subject: [PATCH] Refactor schedules --- hercules_cg/src/cpu.rs | 2 +- hercules_ir/src/build.rs | 13 +- hercules_ir/src/dot.rs | 152 ++--- hercules_ir/src/ir.rs | 52 +- hercules_ir/src/lib.rs | 2 - hercules_ir/src/parse.rs | 11 +- hercules_ir/src/schedule.rs | 818 ----------------------- hercules_ir/src/subgraph.rs | 83 --- hercules_opt/src/ccp.rs | 10 +- hercules_opt/src/delete_uncalled.rs | 5 +- hercules_opt/src/editor.rs | 370 +++------- hercules_opt/src/fork_concat_split.rs | 12 +- hercules_opt/src/gvn.rs | 3 +- hercules_opt/src/inline.rs | 22 +- hercules_opt/src/interprocedural_sroa.rs | 4 +- hercules_opt/src/lib.rs | 2 + hercules_opt/src/outline.rs | 11 +- hercules_opt/src/pass.rs | 251 ++----- hercules_opt/src/pred.rs | 2 - hercules_opt/src/schedule.rs | 175 +++++ hercules_opt/src/sroa.rs | 15 +- hercules_opt/src/unforkify.rs | 4 + juno_frontend/src/lib.rs | 4 + juno_scheduler/src/lib.rs | 3 +- 24 files changed, 492 insertions(+), 1534 deletions(-) delete mode 100644 hercules_ir/src/schedule.rs create mode 100644 hercules_opt/src/schedule.rs diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index d9bf505c..a8f16790 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -1,7 +1,7 @@ extern crate bitvec; extern crate hercules_ir; -use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; +use std::collections::{BTreeMap, VecDeque}; use std::fmt::{Error, Write}; use std::iter::{zip, FromIterator}; use std::sync::atomic::{AtomicUsize, Ordering}; diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 6dd5a3e6..78c4eca4 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -36,6 +36,7 @@ pub struct NodeBuilder { id: NodeID, function_id: FunctionID, node: Node, + schedules: Vec<Schedule>, } /* @@ -468,9 +469,11 @@ impl<'a> Builder<'a> { name: name.to_owned(), param_types, return_type, - nodes: vec![Node::Start], num_dynamic_constants, entry, + nodes: vec![Node::Start], + schedules: vec![vec![]], + device: None, }); Ok((id, NodeID::new(0))) } @@ -480,10 +483,12 @@ impl<'a> Builder<'a> { self.module.functions[function.idx()] .nodes .push(Node::Start); + self.module.functions[function.idx()].schedules.push(vec![]); NodeBuilder { id, function_id: function, node: Node::Start, + schedules: vec![], } } @@ -492,6 +497,8 @@ impl<'a> Builder<'a> { Err("Can't add node from a NodeBuilder before NodeBuilder has built a node.")? 
} self.module.functions[builder.function_id.idx()].nodes[builder.id.idx()] = builder.node; + self.module.functions[builder.function_id.idx()].schedules[builder.id.idx()] = + builder.schedules; Ok(()) } } @@ -606,4 +613,8 @@ impl NodeBuilder { indices, }; } + + pub fn add_schedule(&mut self, schedule: Schedule) { + self.schedules.push(schedule); + } } diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index d23a4972..a8754e78 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -21,7 +21,6 @@ pub fn xdot_module( doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, bbs: Option<&Vec<Vec<NodeID>>>, - plans: Option<&Vec<Plan>>, ) { let mut tmp_path = temp_dir(); let mut rng = rand::thread_rng(); @@ -35,7 +34,6 @@ pub fn xdot_module( doms, fork_join_maps, bbs, - plans, &mut contents, ) .expect("PANIC: Unable to generate output file contents."); @@ -58,7 +56,6 @@ pub fn write_dot<W: Write>( doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, bbs: Option<&Vec<Vec<NodeID>>>, - plans: Option<&Vec<Plan>>, w: &mut W, ) -> std::fmt::Result { write_digraph_header(w)?; @@ -66,101 +63,74 @@ pub fn write_dot<W: Write>( for function_id in (0..module.functions.len()).map(FunctionID::new) { let function = &module.functions[function_id.idx()]; let reverse_postorder = &reverse_postorders[function_id.idx()]; - let plan = plans.map(|plans| &plans[function_id.idx()]); let mut reverse_postorder_node_numbers = vec![0; function.nodes.len()]; for (idx, id) in reverse_postorder.iter().enumerate() { reverse_postorder_node_numbers[id.idx()] = idx; } write_subgraph_header(function_id, module, w)?; - let mut partition_to_node_map = plan.map(|plan| plan.invert_partition_map()); - // Step 1: draw IR graph itself. This includes all IR nodes and all edges // between IR nodes. - for partition_idx in 0..plan.map_or(1, |plan| plan.num_partitions) { - // First, write all the nodes in their subgraph. - let partition_color = plan.map(|plan| match plan.partition_devices[partition_idx] { - Device::CPU => "lightblue", - Device::GPU => "darkseagreen", - Device::AsyncRust => "plum2", - }); - if let Some(partition_color) = partition_color { - write_partition_header(function_id, partition_idx, module, partition_color, w)?; - } - let nodes_ids = if let Some(partition_to_node_map) = &mut partition_to_node_map { - let mut empty = vec![]; - std::mem::swap(&mut partition_to_node_map[partition_idx], &mut empty); - empty - } else { - (0..function.nodes.len()) - .map(NodeID::new) - .collect::<Vec<_>>() - }; - for node_id in nodes_ids.iter() { - let node = &function.nodes[node_id.idx()]; - let dst_control = node.is_control(); + for node_id in (0..function.nodes.len()).map(NodeID::new) { + let node = &function.nodes[node_id.idx()]; + let dst_control = node.is_control(); + + // Control nodes are dark red, data nodes are dark blue. + let color = if dst_control { "darkred" } else { "darkblue" }; + + write_node( + node_id, + function_id, + color, + module, + &function.schedules[node_id.idx()], + w, + )?; + } + + for node_id in (0..function.nodes.len()).map(NodeID::new) { + let node = &function.nodes[node_id.idx()]; + let dst_control = node.is_control(); + for u in def_use::get_uses(&node).as_ref() { + let src_control = function.nodes[u.idx()].is_control(); - // Control nodes are dark red, data nodes are dark blue. - let color = if dst_control { "darkred" } else { "darkblue" }; + // An edge between control nodes is dashed. An edge between data + // nodes is filled. 
An edge between a control node and a data + node is dotted. + let style = if dst_control && src_control { + "dashed" + } else if !dst_control && !src_control { + "" + } else { + "dotted" + }; - write_node( - *node_id, + // To have a consistent layout, we will add "back edges" in the + // IR graph as backward facing edges in the graphviz output, so + // that they don't mess up the layout. There isn't necessarily a + // precise definition of a "back edge" in Hercules IR. I've + // found that the clearest output graphs result from treating + // edges to phi nodes as back edges when the phi node appears + // before the use in the reverse postorder, and treating a + // control edge as a back edge when the destination appears before + // the source in the reverse postorder. + let is_back_edge = reverse_postorder_node_numbers[node_id.idx()] + < reverse_postorder_node_numbers[u.idx()] + && (node.is_phi() + || (function.nodes[node_id.idx()].is_control() + && function.nodes[u.idx()].is_control())); + write_edge( + node_id, function_id, - color, + *u, + function_id, + !is_back_edge, + "black", + style, module, - plan.map_or(&vec![], |plan| &plan.schedules[node_id.idx()]), w, )?; } - if plans.is_some() { - write_graph_footer(w)?; - } - - // Second, write all the edges coming out of a node. - for node_id in nodes_ids.iter() { - let node = &function.nodes[node_id.idx()]; - let dst_control = node.is_control(); - for u in def_use::get_uses(&node).as_ref() { - let src_control = function.nodes[u.idx()].is_control(); - - // An edge between control nodes is dashed. An edge between data - // nodes is filled. An edge between a control node and a data - // node is dotted. - let style = if dst_control && src_control { - "dashed" - } else if !dst_control && !src_control { - "" - } else { - "dotted" - }; - - // To have a consistent layout, we will add "back edges" in the - // IR graph as backward facing edges in the graphviz output, so - // that they don't mess up the layout. There isn't necessarily a - // precise definition of a "back edge" in Hercules IR. I've - // found what makes for the most clear output graphs is treating - // edges to phi nodes as back edges when the phi node appears - // before the use in the reverse postorder, and treating a - // control edge a back edge when the destination appears before - // the source in the reverse postorder. - let is_back_edge = reverse_postorder_node_numbers[node_id.idx()] - < reverse_postorder_node_numbers[u.idx()] - && (node.is_phi() - || (function.nodes[node_id.idx()].is_control() - && function.nodes[u.idx()].is_control())); - write_edge( - *node_id, - function_id, - *u, - function_id, - !is_back_edge, - "black", - style, - module, - w, - )?; - } - } } // Step 2: draw dominance edges in dark green.
Don't draw post dominance @@ -258,22 +228,6 @@ fn write_subgraph_header<W: Write>( Ok(()) } -fn write_partition_header<W: Write>( - function_id: FunctionID, - partition_idx: usize, - module: &Module, - color: &str, - w: &mut W, -) -> std::fmt::Result { - let function = &module.functions[function_id.idx()]; - write!(w, "subgraph {}_{} {{\n", function.name, partition_idx)?; - write!(w, "label=\"\"\n")?; - write!(w, "style=rounded\n")?; - write!(w, "bgcolor={}\n", color)?; - write!(w, "cluster=true\n")?; - Ok(()) -} - fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result { write!(w, "}}\n")?; Ok(()) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index cef8e43a..d4eed8e2 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -40,9 +40,13 @@ pub struct Function { pub name: String, pub param_types: Vec<TypeID>, pub return_type: TypeID, - pub nodes: Vec<Node>, pub num_dynamic_constants: u32, pub entry: bool, + + pub nodes: Vec<Node>, + + pub schedules: FunctionSchedule, + pub device: Option<Device>, } /* @@ -296,6 +300,48 @@ pub enum Intrinsic { Tanh, } +/* + * An individual schedule is a single "directive" for the compiler to take into + * consideration at some point during the compilation pipeline. Each schedule is + * associated with a single node. + */ +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum Schedule { + // This fork can be run in parallel. Any notion of tiling is handled by the + // fork tiling passes - the backend will lower a parallel fork with + // dimension N into an N-way thread launch. + ParallelFork, + // This reduce can be "run in parallel" - conceptually, the `reduct` + // backedge can be removed, and the reduce code can be merged into the + // parallel code. + ParallelReduce, + // This fork-join has no impeding control flow and the fork has a constant + // factor. + Vectorizable, + // This reduce can be re-associated. This may lower a sequential dependency + // chain into a reduction tree. + Associative, +} + +/* + * The authoritative enumeration of supported backends. Multiple backends may + * correspond to the same kind of hardware. + */ +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum Device { + LLVM, + NVVM, + // Internal nodes in the call graph are lowered to async Rust code that + // calls device functions (leaf nodes in the call graph), possibly + // concurrently. + AsyncRust, +} + +/* + * A single node may have multiple schedules. + */ +pub type FunctionSchedule = Vec<Vec<Schedule>>; + impl Module { /* * Printing out types, constants, and dynamic constants fully requires a @@ -671,9 +717,11 @@ impl Function { // Add to new_nodes. new_nodes.push(node); } - std::mem::swap(&mut new_nodes, &mut self.nodes); + // Step 4: update the schedules.
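An aside on Step 4: `fix_gravestones` comes from the `GraveUpdatable` trait, whose implementation for `FunctionSchedule` is not shown in this hunk. Below is a minimal sketch of what it plausibly does, assuming the gravestone convention of the deleted `Plan::repair` further down (node ID 0 is the start node, and any other node remapped to ID 0 was deleted); `fix_schedule_gravestones` is a hypothetical free function, not the actual trait method:

fn fix_schedule_gravestones(schedules: &mut FunctionSchedule, node_mapping: &Vec<NodeID>) {
    let old = std::mem::take(schedules);
    // One (possibly empty) schedule list per surviving node.
    let num_surviving = node_mapping
        .iter()
        .enumerate()
        .filter(|(old_idx, new_id)| *old_idx == 0 || new_id.idx() != 0)
        .count();
    schedules.resize(num_surviving, vec![]);
    for (old_idx, schedule) in old.into_iter().enumerate() {
        let new_id = node_mapping[old_idx];
        // Node ID 0 is the start node; any other node mapped to ID 0 is a
        // deleted gravestone, so its schedules are simply dropped.
        if old_idx == 0 || new_id.idx() != 0 {
            schedules[new_id.idx()] = schedule;
        }
    }
}

Whatever the exact implementation, the point is to keep `function.schedules` index-aligned with `function.nodes` after gravestone removal, which is the invariant the rest of this commit relies on.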
+ self.schedules.fix_gravestones(&node_mapping); + node_mapping } } diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index f7277cfa..abe7f46f 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -17,7 +17,6 @@ pub mod gcm; pub mod ir; pub mod loops; pub mod parse; -pub mod schedule; pub mod subgraph; pub mod typecheck; pub mod verify; @@ -33,7 +32,6 @@ pub use crate::gcm::*; pub use crate::ir::*; pub use crate::loops::*; pub use crate::parse::*; -pub use crate::schedule::*; pub use crate::subgraph::*; pub use crate::typecheck::*; pub use crate::verify::*; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 41d59441..3a47f8ea 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -122,9 +122,11 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a name: String::from(""), param_types: vec![], return_type: TypeID::new(0), - nodes: vec![], num_dynamic_constants: 0, - entry: true + entry: true, + nodes: vec![], + schedules: vec![], + device: None, }; context.function_ids.len() ]; @@ -251,15 +253,18 @@ fn parse_function<'a>( // Intern function name. context.borrow_mut().get_function_id(function_name); + let num_nodes = fixed_nodes.len(); Ok(( ir_text, Function { name: String::from(function_name), param_types: params.into_iter().map(|x| x.5).collect(), return_type, - nodes: fixed_nodes, num_dynamic_constants, entry: true, + nodes: fixed_nodes, + schedules: vec![vec![]; num_nodes], + device: None, }, )) } diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs deleted file mode 100644 index 2438a982..00000000 --- a/hercules_ir/src/schedule.rs +++ /dev/null @@ -1,818 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::iter::{repeat, zip}; - -use crate::*; - -/* - * An individual schedule is a single "directive" for the compiler to take into - * consideration at some point during the compilation pipeline. Each schedule is - * associated with a single node. - */ -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Schedule { - // This fork can be run in parallel and has a "natural" tiling, which may or - // may not be respected by certain backends. The field stores at least how - // many parallel tiles should be run concurrently, along each dimension. - // Some backends (such as GPU) may spawn more parallel tiles (each tile - // being a single thread in that case) along each axis. - ParallelFork(Box<[usize]>), - // This reduce can be "run in parallel" - conceptually, the `reduct` - // backedge can be removed, and the reduce code can be merged into the - // parallel code. - ParallelReduce, - // This fork-join has no impeding control flow. The field stores the vector - // width. - Vectorizable(usize), - // This reduce can be re-associated. This may lower a sequential dependency - // chain into a reduction tree. - Associative, -} - -/* - * The authoritative enumeration of supported devices. Technically, a device - * refers to a specific backend, so different "devices" may refer to the same - * "kind" of hardware. - */ -#[derive(Debug, Clone, Copy)] -pub enum Device { - CPU, - GPU, - // Hercules function calls are placed in solitary function calls that are - // directly represented in the generated async Rust runtime code. - AsyncRust, -} - -/* - * A plan is a set of schedules associated with each node, plus partitioning - * information. Partitioning information is a mapping from node ID to partition - * ID. A plan's scope is a single function. 
- */ -#[derive(Debug, Clone)] -pub struct Plan { - pub schedules: Vec<Vec<Schedule>>, - pub partitions: Vec<PartitionID>, - pub partition_devices: Vec<Device>, - pub num_partitions: usize, -} - -define_id_type!(PartitionID); - -impl Plan { - /* - * Invert stored map from node to partition to map from partition to nodes. - */ - pub fn invert_partition_map(&self) -> Vec<Vec<NodeID>> { - let mut map = vec![vec![]; self.num_partitions]; - - for idx in 0..self.partitions.len() { - map[self.partitions[idx].idx()].push(NodeID::new(idx)); - } - - map - } - - /* - * GCM takes in a partial partitioning, but most of the time we have a full - * partitioning. - */ - pub fn make_partial_partitioning(&self) -> Vec<Option<PartitionID>> { - self.partitions.iter().map(|id| Some(*id)).collect() - } - - /* - * LEGACY API: This is the legacy mechanism for repairing plans. Future code - * should use the `repair_plan` function in editor.rs. - * Plans must be "repairable", in the sense that the IR that's referred to - * may change after many passes. Since a plan is an explicit side data - * structure, it must be updated after every change in the IR. - */ - pub fn repair(self, function: &Function, grave_mapping: &Vec<NodeID>) -> Self { - // Unpack the plan. - let old_inverse_partition_map = self.invert_partition_map(); - let Plan { - mut schedules, - partitions: _, - partition_devices, - num_partitions: _, - } = self; - - // Schedules of old nodes just get dropped. Since schedules don't hold - // necessary semantic information, we are free to drop them arbitrarily. - schedules.fix_gravestones(grave_mapping); - schedules.resize(function.nodes.len(), vec![]); - - // Delete now empty partitions. First, filter out deleted nodes from the - // partitions and simultaneously map old node IDs to new node IDs. Then, - // filter out empty partitions. - let (new_inverse_partition_map, new_devices): (Vec<Vec<NodeID>>, Vec<Device>) = - zip(old_inverse_partition_map, partition_devices) - .into_iter() - .map(|(contents, device)| { - ( - contents - .into_iter() - .filter_map(|id| { - if id.idx() == 0 || grave_mapping[id.idx()].idx() != 0 { - Some(grave_mapping[id.idx()]) - } else { - None - } - }) - .collect::<Vec<NodeID>>(), - device, - ) - }) - .filter(|(contents, _)| !contents.is_empty()) - .unzip(); - - // Calculate the number of nodes after deletion but before addition. Use - // this is iterate new nodes later. - let num_nodes_before_addition = new_inverse_partition_map.iter().flatten().count(); - assert!(new_inverse_partition_map - .iter() - .flatten() - .all(|id| id.idx() < num_nodes_before_addition)); - - // Calculate the nodes that need to be assigned to a partition. This - // starts as just the nodes that have been added by passes. - let mut new_node_ids: VecDeque<NodeID> = (num_nodes_before_addition..function.nodes.len()) - .map(NodeID::new) - .collect(); - - // Any partition no longer containing at least one control node needs to - // be liquidated. - let (new_inverse_partition_map, new_devices): (Vec<Vec<NodeID>>, Vec<Device>) = - zip(new_inverse_partition_map, new_devices) - .into_iter() - .filter_map(|(part, device)| { - if part.iter().any(|id| function.nodes[id.idx()].is_control()) { - Some((part, device)) - } else { - // Nodes in removed partitions need to be re-partitioned. - new_node_ids.extend(part); - None - } - }) - .unzip(); - - // Assign the node IDs that need to be partitioned to partitions. In the - // process, construct a map from node ID to partition ID. 
- let mut node_id_to_partition_id: HashMap<NodeID, PartitionID> = new_inverse_partition_map - .into_iter() - .enumerate() - .map(|(partition_idx, node_ids)| { - node_ids - .into_iter() - .map(|node_id| (node_id, PartitionID::new(partition_idx))) - .collect::<Vec<(NodeID, PartitionID)>>() - }) - .flatten() - .collect(); - - // Make a best effort to assign nodes to the partition of one of their - // uses. Prioritize earlier uses. TODO: since not all partitions are - // legal, this is almost certainly not complete. Think more about that. - 'workloop: while let Some(id) = new_node_ids.pop_front() { - for u in get_uses(&function.nodes[id.idx()]).as_ref() { - if let Some(partition_id) = node_id_to_partition_id.get(u) { - node_id_to_partition_id.insert(id, *partition_id); - continue 'workloop; - } - } - new_node_ids.push_back(id); - } - - // Reconstruct the partitions vector. - let num_partitions = new_devices.len(); - let mut partitions = vec![PartitionID::new(0); function.nodes.len()]; - for (k, v) in node_id_to_partition_id { - partitions[k.idx()] = v; - } - - // Reconstruct the whole plan. - Plan { - schedules, - partitions, - partition_devices: new_devices, - num_partitions, - } - } - - /* - * Verify that a partitioning is valid. - */ - pub fn verify_partitioning( - &self, - function: &Function, - def_use: &ImmutableDefUseMap, - fork_join_map: &HashMap<NodeID, NodeID>, - ) { - let partition_to_node_ids = self.invert_partition_map(); - - // First, verify that there is at most one control node in the partition - // with a control use outside the partition. A partition may only have - // zero such control nodes if it contains the start node. This also - // checks that each partition has at least one control node. - for nodes_in_partition in partition_to_node_ids.iter() { - let contains_start = nodes_in_partition - .iter() - .any(|id| function.nodes[id.idx()] == Node::Start); - let num_inter_partition_control_uses = nodes_in_partition - .iter() - .filter(|id| { - // An inter-partition control use is a control node, - function.nodes[id.idx()].is_control() - // where one of its uses, - && get_uses(&function.nodes[id.idx()]) - .as_ref() - .into_iter() - .any(|use_id| { - // that is itself a control node as well, - function.nodes[use_id.idx()].is_control() - // is in a different partition. - && self.partitions[use_id.idx()] != self.partitions[id.idx()] - }) - }) - .count(); - - assert!(num_inter_partition_control_uses + contains_start as usize == 1, "PANIC: Found an invalid partition based on the inter-partition control use criteria."); - } - - // Second, verify that fork-joins are not split amongst partitions. - for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_fork() { - let fork_part = self.partitions[id.idx()]; - - // The join must be in the same partition. - let join = fork_join_map[&id]; - assert_eq!( - fork_part, - self.partitions[join.idx()], - "PANIC: Join is in a different partition than its corresponding fork." - ); - - // The thread IDs must be in the same partition. - def_use - .get_users(id) - .into_iter() - .filter(|user| function.nodes[user.idx()].is_thread_id()) - .for_each(|thread_id| { - assert_eq!( - fork_part, - self.partitions[thread_id.idx()], - "PANIC: Thread ID is in a different partition than its fork use." - ) - }); - - // The reduces must be in the same partition. 
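All of this deleted machinery is superseded by the new representation: schedules now live directly on `Function`, so there is no side `Plan` to invert, repair, or verify against the IR after every pass. A before/after sketch for a consumer; `is_parallel_reduce` is a hypothetical helper, not part of the patch:

// Before this commit, schedule queries went through a side Plan:
//     plan.schedules[id.idx()].contains(&Schedule::ParallelReduce)
// After it, they read the schedules stored on the function itself:
fn is_parallel_reduce(function: &Function, id: NodeID) -> bool {
    function.schedules[id.idx()].contains(&Schedule::ParallelReduce)
}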
- def_use - .get_users(join) - .into_iter() - .filter(|user| function.nodes[user.idx()].is_reduce()) - .for_each(|reduce| { - assert_eq!( - fork_part, - self.partitions[reduce.idx()], - "PANIC: Reduce is in a different partition than its join use." - ) - }); - } - } - - // Third, verify that every data node has proper dominance relations - // with respect to the partitioning. In particular: - // 1. Every non-phi data node should be in a partition that is dominated - // by the partitions of every one of its uses. - // 2. Every data node should be in a partition that dominates the - // partitions of every one of its non-phi users. - // Compute a dominance relation between the partitions by constructing a - // partition control graph. - let partition_graph = partition_graph(function, def_use, self); - let dom = dominator(&partition_graph, NodeID::new(self.partitions[0].idx())); - for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_control() { - continue; - } - let part = self.partitions[id.idx()]; - - // Check condition #1. - if !function.nodes[id.idx()].is_phi() { - let uses = get_uses(&function.nodes[id.idx()]); - for use_id in uses.as_ref() { - let use_part = self.partitions[use_id.idx()]; - assert!(dom.does_dom(NodeID::new(use_part.idx()), NodeID::new(part.idx())), "PANIC: A data node has a partition use that doesn't dominate its partition."); - } - } - - // Check condition #2. - let users = def_use.get_users(id); - for user_id in users.as_ref() { - if !function.nodes[user_id.idx()].is_phi() { - let user_part = self.partitions[user_id.idx()]; - assert!(dom.does_dom(NodeID::new(part.idx()), NodeID::new(user_part.idx())), "PANIC: A data node has a partition user that isn't dominated by its partition."); - } - } - } - - // Fourth, verify that every projection node is in the same partition as - // its control use. - for id in (0..function.nodes.len()).map(NodeID::new) { - if let Node::Projection { - control, - selection: _, - } = function.nodes[id.idx()] - { - assert_eq!( - self.partitions[id.idx()], - self.partitions[control.idx()], - "PANIC: Found a projection node in a different partition than its control use." - ); - } - } - - // Fifth, verify that every partition has at least one partition - // successor xor has at least one return node. - for partition_idx in 0..self.num_partitions { - let has_successor = partition_graph.succs(NodeID::new(partition_idx)).count() > 0; - let has_return = partition_to_node_ids[partition_idx] - .iter() - .any(|node_id| function.nodes[node_id.idx()].is_return()); - assert!(has_successor ^ has_return, "PANIC: Found an invalid partition based on the partition return / control criteria."); - } - } - - /* - * Compute the top node for each partition. - */ - pub fn compute_top_nodes( - &self, - function: &Function, - control_subgraph: &Subgraph, - inverted_partition_map: &Vec<Vec<NodeID>>, - ) -> Vec<NodeID> { - inverted_partition_map - .into_iter() - .enumerate() - .map(|(part_idx, part)| { - // For each partition, find the "top" node. - *part - .iter() - .filter(move |id| { - // The "top" node is a control node having at least one - // control predecessor in another partition, or is a - // start node. Every predecessor in the control subgraph - // is a control node. 
- function.nodes[id.idx()].is_start() - || (function.nodes[id.idx()].is_control() - && control_subgraph - .preds(**id) - .filter(|pred_id| { - self.partitions[pred_id.idx()].idx() != part_idx - }) - .count() - > 0) - }) - // We assume here there is exactly one such top node per - // partition. Verify a partitioning with - // `verify_partitioning` before calling this method. - .next() - .unwrap() - }) - .collect() - } - - /* - * Compute the data inputs of each partition. - */ - pub fn compute_data_inputs(&self, function: &Function) -> Vec<Vec<NodeID>> { - let mut data_inputs = vec![vec![]; self.num_partitions]; - - // First consider the non-phi nodes in each partition. - for id in (0..function.nodes.len()).map(NodeID::new) { - if function.nodes[id.idx()].is_phi() { - continue; - } - - let data_inputs = &mut data_inputs[self.partitions[id.idx()].idx()]; - let uses = get_uses(&function.nodes[id.idx()]); - for use_id in uses.as_ref() { - // For every non-phi node, check each of its data uses. If the - // node and its use are in different partitions, then the use is - // a data input for the partition of the node. Also, don't add - // the same node to the data inputs list twice. Only consider - // non-constant uses data inputs - constant nodes are always - // rematerialized into the user partition. - if !function.nodes[use_id.idx()].is_control() - && !function.nodes[use_id.idx()].is_constant() - && self.partitions[id.idx()] != self.partitions[use_id.idx()] - && !data_inputs.contains(use_id) - { - data_inputs.push(*use_id); - } - } - } - - // Second consider the phi nodes in each partition. - for id in (0..function.nodes.len()).map(NodeID::new) { - if !function.nodes[id.idx()].is_phi() { - continue; - } - - let data_inputs = &mut data_inputs[self.partitions[id.idx()].idx()]; - let uses = get_uses(&function.nodes[id.idx()]); - for use_id in uses.as_ref() { - // For every phi node, if any one of its non-constant uses is - // defined in a different partition, then the phi node itself, - // not its outside uses, is considered a data input. This is - // because a phi node whose uses are all in a different - // partition should be lowered to a single parameter to the - // corresponding simple IR function. Note that for a phi node - // with some uses outside and some uses inside the partition, - // the uses outside the partition become a single parameter to - // the schedule IR function, and that parameter and all of the - // "inside" uses become the inputs to a phi inside the simple IR - // function. - if self.partitions[id.idx()] != self.partitions[use_id.idx()] - && !function.nodes[use_id.idx()].is_constant() - && !data_inputs.contains(&id) - { - data_inputs.push(id); - break; - } - } - } - - // Sort the node IDs to keep a consistent interface between partitions. - for data_inputs in &mut data_inputs { - data_inputs.sort(); - } - data_inputs - } - - /* - * Compute the data outputs of each partition. - */ - pub fn compute_data_outputs( - &self, - function: &Function, - def_use: &ImmutableDefUseMap, - ) -> Vec<Vec<NodeID>> { - let mut data_outputs = vec![vec![]; self.num_partitions]; - - for id in (0..function.nodes.len()).map(NodeID::new) { - // Only consider non-constant data nodes as data outputs, since - // constant nodes are rematerialized into the user partition. 
- if function.nodes[id.idx()].is_control() || function.nodes[id.idx()].is_constant() { - continue; - } - - let data_outputs = &mut data_outputs[self.partitions[id.idx()].idx()]; - let users = def_use.get_users(id); - for user_id in users.as_ref() { - // For every data node, check each of its users. If the node and - // its user are in different partitions, then the node is a data - // output for the partition of the node. Also, don't add the - // same node to the data outputs list twice. It doesn't matter - // how this data node is being used - all that matters is that - // it itself is a data node, and that it has a user outside the - // partition. This makes the code simpler than the inputs case. - if self.partitions[id.idx()] != self.partitions[user_id.idx()] - && !data_outputs.contains(&id) - { - data_outputs.push(id); - break; - } - } - } - - // Sort the node IDs to keep a consistent interface between partitions. - for data_outputs in &mut data_outputs { - data_outputs.sort(); - } - data_outputs - } - - pub fn renumber_partitions(&mut self) { - let mut renumber_partitions = HashMap::new(); - for id in self.partitions.iter_mut() { - let next_id = PartitionID::new(renumber_partitions.len()); - let old_id = *id; - let new_id = *renumber_partitions.entry(old_id).or_insert(next_id); - *id = new_id; - } - let mut new_devices = vec![Device::CPU; renumber_partitions.len()]; - for (old_id, new_id) in renumber_partitions.iter() { - new_devices[new_id.idx()] = self.partition_devices[old_id.idx()]; - } - self.partition_devices = new_devices; - self.num_partitions = renumber_partitions.len(); - } -} - -impl GraveUpdatable for Plan { - fn fix_gravestones(&mut self, grave_mapping: &Vec<NodeID>) { - self.schedules.fix_gravestones(grave_mapping); - self.partitions.fix_gravestones(grave_mapping); - self.renumber_partitions(); - } -} - -/* - * Infer parallel fork-joins. These are fork-joins with only parallel reduction - * variables and no parent fork-joins. - */ -pub fn infer_parallel_fork( - function: &Function, - def_use: &ImmutableDefUseMap, - fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, - plan: &mut Plan, -) { - for id in (0..function.nodes.len()).map(NodeID::new) { - let Node::Fork { - control: _, - ref factors, - } = function.nodes[id.idx()] - else { - continue; - }; - let join_id = fork_join_map[&id]; - let all_parallel_reduce = def_use.get_users(join_id).as_ref().into_iter().all(|user| { - plan.schedules[user.idx()].contains(&Schedule::ParallelReduce) - || function.nodes[user.idx()].is_control() - }); - let top_level = fork_join_nest[&id].len() == 1 && fork_join_nest[&id][0] == id; - if all_parallel_reduce && top_level { - let tiling = repeat(4).take(factors.len()).collect(); - plan.schedules[id.idx()].push(Schedule::ParallelFork(tiling)); - } - } -} - -/* - * Infer parallel reductions consisting of a simple cycle between a Reduce node - * and a Write node, where indices of the Write are position indices using the - * ThreadID nodes attached to the corresponding Fork. This procedure also adds - * the ParallelReduce schedule to Reduce nodes reducing over a parallelized - * Reduce, as long as the base Write node also has position indices of the - * ThreadID of the outer fork. In other words, the complete Reduce chain is - * annotated with ParallelReduce, as long as each ThreadID dimension appears in - * the positional indexing of the original Write. 
- */ -pub fn infer_parallel_reduce( - function: &Function, - fork_join_map: &HashMap<NodeID, NodeID>, - plan: &mut Plan, -) { - for id in (0..function.nodes.len()) - .map(NodeID::new) - .filter(|id| function.nodes[id.idx()].is_reduce()) - { - let mut first_control = None; - let mut last_reduce = id; - let mut chain_id = id; - - // Walk down Reduce chain until we reach the Reduce potentially looping - // with the Write. Note the control node of the first Reduce, since this - // will tell us which Thread ID to look for in the Write. - while let Node::Reduce { - control, - init: _, - reduct, - } = function.nodes[chain_id.idx()] - { - if first_control.is_none() { - first_control = Some(control); - } - - last_reduce = chain_id; - chain_id = reduct; - } - - // Check for a Write-Reduce tight cycle. - if let Node::Write { - collect, - data: _, - indices, - } = &function.nodes[chain_id.idx()] - && *collect == last_reduce - { - // If there is a Write-Reduce tight cycle, get the position indices. - let positions = indices - .iter() - .filter_map(|index| { - if let Index::Position(indices) = index { - Some(indices) - } else { - None - } - }) - .flat_map(|pos| pos.iter()); - - // Get the Forks corresponding to uses of bare ThreadIDs. - let fork_thread_id_pairs = positions.filter_map(|id| { - if let Node::ThreadID { control, dimension } = function.nodes[id.idx()] { - Some((control, dimension)) - } else { - None - } - }); - let mut forks = HashMap::<NodeID, Vec<usize>>::new(); - for (fork, dim) in fork_thread_id_pairs { - forks.entry(fork).or_default().push(dim); - } - - // Check if one of the Forks correspond to the Join associated with - // the Reduce being considered, and has all of its dimensions - // represented in the indexing. - let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { - rep_dims.sort(); - rep_dims.dedup(); - fork_join_map[&id] == first_control.unwrap() - && function.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() - }); - - if is_parallel { - plan.schedules[id.idx()].push(Schedule::ParallelReduce); - } - } - } -} - -/* - * Infer vectorizable fork-joins. Just check that there are no control nodes - * between a fork and its join and the factor is a constant. - */ -pub fn infer_vectorizable( - function: &Function, - dynamic_constants: &Vec<DynamicConstant>, - fork_join_map: &HashMap<NodeID, NodeID>, - plan: &mut Plan, -) { - for id in (0..function.nodes.len()) - .map(NodeID::new) - .filter(|id| function.nodes[id.idx()].is_join()) - { - let u = get_uses(&function.nodes[id.idx()]).as_ref()[0]; - if let Some(join) = fork_join_map.get(&u) - && *join == id - { - let factors = function.nodes[u.idx()].try_fork().unwrap().1; - if factors.len() == 1 - && let Some(width) = evaluate_dynamic_constant(factors[0], dynamic_constants) - { - plan.schedules[u.idx()].push(Schedule::Vectorizable(width)); - } - } - } -} - -/* - * Infer associative reduction loops. 
- */ -pub fn infer_associative(function: &Function, plan: &mut Plan) { - let is_associative = |op| match op { - BinaryOperator::Add - | BinaryOperator::Mul - | BinaryOperator::Or - | BinaryOperator::And - | BinaryOperator::Xor => true, - _ => false, - }; - - for (id, reduct) in (0..function.nodes.len()).map(NodeID::new).filter_map(|id| { - function.nodes[id.idx()] - .try_reduce() - .map(|(_, _, reduct)| (id, reduct)) - }) { - if let Node::Binary { left, right, op } = function.nodes[reduct.idx()] - && (left == id || right == id) - && is_associative(op) - { - plan.schedules[id.idx()].push(Schedule::Associative); - } - } -} - -/* - * Create partitions corresponding to fork-join nests. Also, split the "top- - * level" partition into sub-partitions that are connected graphs. Place data - * nodes using ThreadID or Reduce nodes in the corresponding fork-join nest's - * partition. - */ -pub fn partition_out_forks( - function: &Function, - reverse_postorder: &Vec<NodeID>, - fork_join_map: &HashMap<NodeID, NodeID>, - bbs: &Vec<NodeID>, - plan: &mut Plan, -) { - #[allow(non_local_definitions)] - impl Semilattice for NodeID { - fn meet(a: &Self, b: &Self) -> Self { - if a.idx() < b.idx() { - *a - } else { - *b - } - } - - fn bottom() -> Self { - NodeID::new(0) - } - - fn top() -> Self { - NodeID::new(!0) - } - } - - // Step 1: do dataflow analysis over control nodes to identify a - // representative node for each partition. Each fork not taking as input the - // ID of another fork node introduces its own ID as a partition - // representative. Each join node propagates the fork ID if it's not the - // fork pointing to the join in the fork join map - otherwise, it introduces - // its user as a representative node ID for a partition. A region node taking - // multiple node IDs as input belongs to the partition with the smaller - // representative node ID. - let mut representatives = forward_dataflow( - function, - reverse_postorder, - |inputs: &[&NodeID], node_id: NodeID| match function.nodes[node_id.idx()] { - Node::Start => NodeID::new(0), - Node::Fork { - control, - factors: _, - } => { - // Start a partition if the preceding partition isn't a fork - // partition and the predecessor isn't the join for the - // predecessor fork partition. Otherwise, be part of the parent - // fork partition. - if *inputs[0] != NodeID::top() - && function.nodes[inputs[0].idx()].is_fork() - && fork_join_map.get(&inputs[0]) != Some(&control) - { - inputs[0].clone() - } else { - node_id - } - } - Node::Join { control: _ } => inputs[0].clone(), - _ => { - // If the previous node is a join and terminates a fork's - // partition, then start a new partition here. Otherwise, just - // meet over the input lattice values. Set all data nodes to be - // in the !0 partition. - if !function.nodes[node_id.idx()].is_control() { - NodeID::top() - } else if zip(inputs, get_uses(&function.nodes[node_id.idx()]).as_ref()) - .any(|(part_id, pred_id)| fork_join_map.get(part_id) == Some(pred_id)) - { - node_id - } else { - inputs - .iter() - .filter(|id| { - ***id != NodeID::top() && function.nodes[id.idx()].is_control() - }) - .fold(NodeID::top(), |a, b| NodeID::meet(&a, b)) - } - } - }, - ); - - // Step 2: assign data nodes to the partitions of the control nodes they are - // assigned to by GCM. - for idx in 0..function.nodes.len() { - if !function.nodes[idx].is_control() { - representatives[idx] = representatives[bbs[idx].idx()]; - } - } - - // Step 3: deduplicate representative node IDs. 
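The schedule-inference passes deleted from this file are not lost: the diffstat adds a 175-line hercules_opt/src/schedule.rs, which presumably rebuilds them on top of `FunctionEditor` so that annotations can be maintained as the IR is edited. A hedged sketch of what the associative case could look like in that style (the actual signatures in the new file are not shown in this section and may differ):

pub fn infer_associative(editor: &mut FunctionEditor) {
    let is_associative = |op| {
        matches!(
            op,
            BinaryOperator::Add
                | BinaryOperator::Mul
                | BinaryOperator::Or
                | BinaryOperator::And
                | BinaryOperator::Xor
        )
    };
    for id in editor.node_ids().collect::<Vec<_>>() {
        editor.edit(|edit| {
            // A reduce is associative if its `reduct` is a binary op that
            // uses the reduce itself as one of its operands.
            if let Some((_, _, reduct)) = edit.get_node(id).try_reduce()
                && let Node::Binary { left, right, op } = edit.get_node(reduct)
                && (*left == id || *right == id)
                && is_associative(*op)
            {
                edit.add_schedule(id, Schedule::Associative)
            } else {
                Ok(edit)
            }
        });
    }
}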
- let mut representative_to_partition_ids = HashMap::new(); - for rep in &representatives { - if !representative_to_partition_ids.contains_key(rep) { - representative_to_partition_ids - .insert(rep, PartitionID::new(representative_to_partition_ids.len())); - } - } - - // Step 4: update plan. - plan.num_partitions = representative_to_partition_ids.len(); - for id in (0..function.nodes.len()).map(NodeID::new) { - plan.partitions[id.idx()] = representative_to_partition_ids[&representatives[id.idx()]]; - } - - plan.partition_devices = vec![Device::CPU; plan.num_partitions]; -} - -/* - * Set the device for all partitions containing a fork to the GPU. - */ -pub fn place_fork_partitions_on_gpu(function: &Function, plan: &mut Plan) { - for idx in 0..function.nodes.len() { - if function.nodes[idx].is_fork() { - plan.partition_devices[plan.partitions[idx].idx()] = Device::GPU; - } - } -} diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs index 3549718a..de342579 100644 --- a/hercules_ir/src/subgraph.rs +++ b/hercules_ir/src/subgraph.rs @@ -246,86 +246,3 @@ pub fn control_subgraph(function: &Function, def_use: &ImmutableDefUseMap) -> Su && (!function.nodes[node.idx()].is_start() || node.idx() == 0) }) } - -/* - * Construct a subgraph representing the control relations between partitions. - * Technically, this isn't a "sub"graph of the function graph, since partition - * nodes don't correspond to nodes in the original function. - */ -pub fn partition_graph(function: &Function, def_use: &ImmutableDefUseMap, plan: &Plan) -> Subgraph { - let partition_to_node_ids = plan.invert_partition_map(); - - let mut subgraph = Subgraph { - nodes: (0..plan.num_partitions).map(NodeID::new).collect(), - node_numbers: (0..plan.num_partitions) - .map(|idx| (NodeID::new(idx), idx as u32)) - .collect(), - first_forward_edges: vec![], - forward_edges: vec![], - first_backward_edges: vec![], - backward_edges: vec![], - original_num_nodes: plan.num_partitions as u32, - }; - - // Step 1: collect backward edges from use info. - for partition in partition_to_node_ids.iter() { - // Record the source of the edges (the current partition). - let old_num_edges = subgraph.backward_edges.len(); - subgraph.first_backward_edges.push(old_num_edges as u32); - for node in partition - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Look at all the control uses of control nodes in that partition. - let uses = get_uses(&function.nodes[node.idx()]); - for use_id in uses - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Add a backward edge to any different partition we are using - // and don't add duplicate backward edges. - if plan.partitions[use_id.idx()] != plan.partitions[node.idx()] - && !subgraph.backward_edges[old_num_edges..] - .contains(&(plan.partitions[use_id.idx()].idx() as u32)) - { - subgraph - .backward_edges - .push(plan.partitions[use_id.idx()].idx() as u32); - } - } - } - } - - // Step 2: collect forward edges from user (def_use) info. - for partition in partition_to_node_ids.iter() { - // Record the source of the edges (the current partition). - let old_num_edges = subgraph.forward_edges.len(); - subgraph.first_forward_edges.push(old_num_edges as u32); - for node in partition - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Look at all the control users of control nodes in that partition. 
- let users = def_use.get_users(*node); - for user_id in users - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - { - // Add a forward edge to any different partition that we are a - // user of and don't add duplicate forward edges. - if plan.partitions[user_id.idx()] != plan.partitions[node.idx()] - && !subgraph.forward_edges[old_num_edges..] - .contains(&(plan.partitions[user_id.idx()].idx() as u32)) - { - subgraph - .forward_edges - .push(plan.partitions[user_id.idx()].idx() as u32); - } - } - } - } - - subgraph -} diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index aa3d0e68..3f53ea40 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -194,6 +194,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { // Add a constant IR node for this constant let cons_node = edit.add_node(Node::Constant { id: cons_id }); // Replace the original node with the constant node + edit.sub_edit(old_id, cons_node); edit = edit.replace_all_uses(old_id, cons_node)?; edit.delete_node(old_id) }); @@ -227,6 +228,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { control, data: new_data, }); + edit.sub_edit(phi_id, new_node); edit = edit.replace_all_uses(phi_id, new_node)?; edit.delete_node(phi_id) }); @@ -248,6 +250,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { .collect(); editor.edit(|mut edit| { let new_node = edit.add_node(Node::Region { preds: new_preds }); + edit.sub_edit(region_id, new_node); edit = edit.replace_all_uses(region_id, new_node)?; edit.delete_node(region_id) }); @@ -276,12 +279,7 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { // The reachable users iterator will contain one user if we need to // remove this branch node. if reachable_users.len() == 1 { - // The user is a Read node, which in turn has one user. - assert!( - editor.get_users(the_reachable_user).len() == 1, - "Control Read node doesn't have exactly one user." - ); - + // The user is a projection node, which in turn has one user. 
editor.edit(|mut edit| { // Replace all uses of the single reachable user with the node preceding the // branch node diff --git a/hercules_opt/src/delete_uncalled.rs b/hercules_opt/src/delete_uncalled.rs index fe99b5fb..78ab4285 100644 --- a/hercules_opt/src/delete_uncalled.rs +++ b/hercules_opt/src/delete_uncalled.rs @@ -67,8 +67,9 @@ pub fn delete_uncalled( dynamic_constants: dynamic_constants.clone(), args: args.clone(), }); - let edit = edit.delete_node(callsite)?; - edit.replace_all_uses(callsite, new_node) + edit.sub_edit(callsite, new_node); + let edit = edit.replace_all_uses(callsite, new_node)?; + edit.delete_node(callsite) }); assert!( success, diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 0ff58822..0c97abff 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -4,32 +4,21 @@ extern crate hercules_ir; extern crate itertools; use std::cell::{Ref, RefCell}; -use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; -use std::iter::FromIterator; +use std::collections::{BTreeMap, HashSet}; use std::mem::take; use std::ops::Deref; use self::bitvec::prelude::*; use self::either::Either; -use self::itertools::Itertools; -use self::hercules_ir::antideps::*; -use self::hercules_ir::dataflow::*; use self::hercules_ir::def_use::*; -use self::hercules_ir::dom::*; -use self::hercules_ir::gcm::*; use self::hercules_ir::ir::*; -use self::hercules_ir::loops::*; -use self::hercules_ir::schedule::*; -use self::hercules_ir::subgraph::*; - -pub type Edit = (HashSet<NodeID>, HashSet<NodeID>); /* - * Helper object for editing Hercules functions in a trackable manner. Edits are - * recorded in order to repair partitions and debug info. - * Edits must be made atomically, that is, only one `.edit` may be called at a time - * across all editors. + * Helper object for editing Hercules functions in a trackable manner. Edits + * must be made atomically, that is, only one `.edit` may be called at a time + * across all editors, and individual edits must leave the function in a valid + * state. */ #[derive(Debug)] pub struct FunctionEditor<'a> { @@ -44,19 +33,6 @@ pub struct FunctionEditor<'a> { // Most optimizations need def use info, so provide an iteratively updated // mutable version that's automatically updated based on recorded edits. mut_def_use: Vec<HashSet<NodeID>>, - // Record edits as a mapping from sets of node IDs to sets of node IDs. The - // sets on the "left" side of this map should be mutually disjoint, and the - // sets on the "right" side should also be mutually disjoint. All of the - // node IDs on the left side should be deleted node IDs or IDs of unmodified - // nodes, and all of the node IDs on the right side should be added node IDs - // or IDs of unmodified nodes. In other words, there should be no added node - // IDs on the left side, and no deleted node IDs on the right side. These - // mappings are stored sequentially in a list, rather than as a map. This is - // because a transformation may iteratively update a function - i.e., a node - // ID added in iteration N may be deleted in iteration N + M. To maintain a - // more precise history of edits, we store each edit individually, which - // allows us to make more precise repairs of partitions and debug info.
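This deleted edit log was the hook for the old plan-repair machinery. Its replacement is the `sub_edits` list on `FunctionEdit`: instead of reconstructing node correspondences from set-to-set edits after the fact, a pass states each correspondence directly with `sub_edit`, and the editor copies schedules across it when the edit commits. A hedged sketch of the resulting idiom, already visible in the ccp.rs and delete_uncalled.rs hunks above (`rewrite_node` is a hypothetical helper, not part of the patch):

// Replace `old` with a freshly built node while preserving its schedules.
fn rewrite_node(editor: &mut FunctionEditor, old: NodeID, replacement: Node) -> bool {
    editor.edit(|mut edit| {
        let new = edit.add_node(replacement);
        // Record the direct correspondence so that any schedules on `old`
        // (e.g. Schedule::ParallelReduce) are propagated onto `new`.
        edit.sub_edit(old, new);
        edit = edit.replace_all_uses(old, new)?;
        edit.delete_node(old)
    })
}

Passes can also attach or drop annotations directly during an edit through `add_schedule` and `clear_schedule`, defined on `FunctionEdit` below.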
- edits: Vec<Edit>, // The pass manager may indicate that only a certain subset of nodes should // be modified in a function - what this actually means is that some nodes // are off limits for deletion (equivalently modification) or being replaced @@ -77,6 +53,8 @@ pub struct FunctionEdit<'a: 'b, 'b> { added_nodeids: HashSet<NodeID>, // Keep track of added and use updated nodes. added_and_updated_nodes: BTreeMap<NodeID, Node>, + // Keep track of added and updated schedules. + added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>, // Keep track of added (dynamic) constants and types added_constants: Vec<Constant>, added_dynamic_constants: Vec<DynamicConstant>, @@ -84,6 +62,8 @@ pub struct FunctionEdit<'a: 'b, 'b> { // Compute a def-use map entries iteratively. updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>, updated_return_type: Option<TypeID>, + // Keep track of which deleted and added node IDs directly correspond. + sub_edits: Vec<(NodeID, NodeID)>, } impl<'a: 'b, 'b> FunctionEditor<'a> { @@ -111,7 +91,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { dynamic_constants, types, mut_def_use, - edits: vec![], mutable_nodes, } } @@ -125,12 +104,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor: self, deleted_nodeids: HashSet::new(), added_nodeids: HashSet::new(), + added_and_updated_nodes: BTreeMap::new(), + added_and_updated_schedules: BTreeMap::new(), added_constants: Vec::new().into(), added_dynamic_constants: Vec::new().into(), added_types: Vec::new().into(), - added_and_updated_nodes: BTreeMap::new(), updated_def_use: BTreeMap::new(), updated_return_type: None, + sub_edits: vec![], }; if let Ok(populated_edit) = edit(edit_obj) { @@ -140,12 +121,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor, deleted_nodeids, added_nodeids, + added_and_updated_nodes, + added_and_updated_schedules, added_constants, added_dynamic_constants, added_types, - added_and_updated_nodes: added_and_updated, updated_def_use, updated_return_type, + sub_edits, } = populated_edit; // Step 1: update the mutable def use map. for (u, new_users) in updated_def_use { @@ -165,7 +148,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } // Step 2: add and update nodes. - for (id, node) in added_and_updated { + for (id, node) in added_and_updated_nodes { if id.idx() < editor.function.nodes.len() { editor.function.nodes[id.idx()] = node; } else { @@ -176,7 +159,16 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } - // Step 3: delete nodes. This is done using "gravestones", where a + // Step 3: add and update schedules. + editor + .function + .schedules + .resize(editor.function.nodes.len(), vec![]); + for (id, schedule) in added_and_updated_schedules { + editor.function.schedules[id.idx()] = schedule; + } + + // Step 4: delete nodes. This is done using "gravestones", where a // node other than node ID 0 being a start node is considered a // gravestone. for id in deleted_nodeids.iter() { @@ -185,16 +177,24 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor.function.nodes[id.idx()] = Node::Start; } - // Step 4: add a single edit to the edit list. - editor.edits.push((deleted_nodeids, added_nodeids)); + // Step 5: propagate schedules along sub-edit edges. + for (src, dst) in sub_edits { + let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]); + for src_schedule in editor.function.schedules[src.idx()].iter() { + if !dst_schedules.contains(src_schedule) { + dst_schedules.push(src_schedule.clone()); + } + } + editor.function.schedules[dst.idx()] = dst_schedules; + } - // Step 5: update the length of mutable_nodes. 
All added nodes are + // Step 6: update the length of mutable_nodes. All added nodes are // mutable. editor .mutable_nodes .resize(editor.function.nodes.len(), true); - // Step 6: update types and constants + // Step 7: update types and constants. let mut editor_constants = editor.constants.borrow_mut(); let mut editor_dynamic_constants = editor.dynamic_constants.borrow_mut(); let mut editor_types = editor.types.borrow_mut(); @@ -203,7 +203,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor_dynamic_constants.extend(added_dynamic_constants); editor_types.extend(added_types); - // Step 7: update return type if necessary + // Step 8: update return type if necessary. if let Some(return_type) = updated_return_type { editor.function.return_type = return_type; } @@ -244,10 +244,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { self.mutable_nodes[id.idx()] } - pub fn edits(self) -> Vec<Edit> { - self.edits - } - pub fn node_ids(&self) -> impl ExactSizeIterator<Item = NodeID> { let num = self.function.nodes.len(); (0..num).map(NodeID::new) @@ -377,6 +373,12 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } + pub fn sub_edit(&mut self, src: NodeID, dst: NodeID) { + assert!(!self.added_nodeids.contains(&src)); + assert!(!self.deleted_nodeids.contains(&dst)); + self.sub_edits.push((src, dst)); + } + pub fn get_name(&self) -> &str { &self.editor.function.name } @@ -397,6 +399,43 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } + pub fn get_schedule(&self, id: NodeID) -> &Vec<Schedule> { + // The user may get the schedule of a to-be deleted node. + if let Some(schedule) = self.added_and_updated_schedules.get(&id) { + // Refer to added or updated schedule. + schedule + } else { + // Refer to the original schedule of this node. + &self.editor.function.schedules[id.idx()] + } + } + + pub fn add_schedule(mut self, id: NodeID, schedule: Schedule) -> Result<Self, Self> { + if self.is_mutable(id) { + if let Some(schedules) = self.added_and_updated_schedules.get_mut(&id) { + schedules.push(schedule); + } else { + let mut schedules = self.editor.function.schedules[id.idx()].clone(); + if !schedules.contains(&schedule) { + schedules.push(schedule); + } + self.added_and_updated_schedules.insert(id, schedules); + } + Ok(self) + } else { + Err(self) + } + } + + pub fn clear_schedule(mut self, id: NodeID) -> Result<Self, Self> { + if self.is_mutable(id) { + self.added_and_updated_schedules.insert(id, vec![]); + Ok(self) + } else { + Err(self) + } + } + pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ { assert!(!self.deleted_nodeids.contains(&id)); if let Some(users) = self.updated_def_use.get(&id) { @@ -539,250 +578,6 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } -/* - * Simplify an edit sequence into a single, larger, edit. - */ -fn collapse_edits(edits: &[Edit]) -> Edit { - let mut total_edit = Edit::default(); - - let mut all_additions: HashSet<NodeID> = HashSet::new(); - - for edit in edits { - assert!(edit.0.is_disjoint(&edit.1), "PANIC: Edit sequence is malformed - can't add and delete the same node ID in a single edit."); - assert!( - total_edit.0.is_disjoint(&edit.0), - "PANIC: Edit sequence is malformed - can't delete the same node ID twice." - ); - assert!( - total_edit.1.is_disjoint(&edit.1), - "PANIC: Edit sequence is malformed - can't add the same node ID twice." 
- ); - - for delete in edit.0.iter() { - if !all_additions.contains(delete) { - total_edit.0.insert(*delete); - } - total_edit.1.remove(delete); - } - - for addition in edit.1.iter() { - total_edit.0.remove(addition); - total_edit.1.insert(*addition); - all_additions.insert(*addition); - } - } - - total_edit -} - -/* - * Plans can be repaired - this entails repairing schedules as well as - * partitions. `new_function` is the function after the edits have occurred, but - * before gravestones have been removed. - */ -pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { - // Step 1: collapse all of the edits into a single edit. For repairing - // partitions, we don't need to consider the intermediate edit states. - let total_edit = collapse_edits(edits); - - // Step 2: drop schedules for deleted nodes and create empty schedule lists - // for added nodes. - for deleted in total_edit.0.iter() { - // Nodes that were created and deleted using the same editor don't have - // an existing schedule, so ignore them - if deleted.idx() < plan.schedules.len() { - plan.schedules[deleted.idx()] = vec![]; - } - } - if !total_edit.1.is_empty() { - assert_eq!( - total_edit.1.iter().max().unwrap().idx() + 1, - new_function.nodes.len() - ); - plan.schedules.resize(new_function.nodes.len(), vec![]); - } - - // Step 3: figure out the order to add nodes to partitions. Roughly, we look - // at the added nodes in reverse postorder and partition by control/data. We - // first add control nodes to partitions using node-specific rules. We then - // add data nodes based on the partitions of their immediate control uses - // and users. - let def_use = def_use(new_function); - let rev_po = reverse_postorder(&def_use); - let added_control_nodes: Vec<NodeID> = rev_po - .iter() - .filter(|id| total_edit.1.contains(id) && new_function.nodes[id.idx()].is_control()) - .map(|id| *id) - .collect(); - let added_data_nodes: Vec<NodeID> = rev_po - .iter() - .filter(|id| total_edit.1.contains(id) && !new_function.nodes[id.idx()].is_control()) - .map(|id| *id) - .collect(); - - // Step 4: figure out the partitions for added control nodes. - // Do a bunch of analysis that basically boils down to finding what fork- - // joins are top-level. - let control_subgraph = control_subgraph(new_function, &def_use); - let dom = dominator(&control_subgraph, NodeID::new(0)); - let fork_join_map = fork_join_map(new_function, &control_subgraph); - let fork_join_nesting = compute_fork_join_nesting(new_function, &dom, &fork_join_map); - // While building, the new partitions map uses Option since we don't have - // partitions for new nodes yet, and we need to record that specifically for - // computing the partitions of region nodes. - let mut new_partitions: Vec<Option<PartitionID>> = - take(&mut plan.partitions).into_iter().map(Some).collect(); - new_partitions.resize(new_function.nodes.len(), None); - // Iterate the added control nodes using a worklist. - let mut worklist = VecDeque::from(added_control_nodes); - while let Some(control_id) = worklist.pop_front() { - let node = &new_function.nodes[control_id.idx()]; - // There are a few cases where this control node needs to start a new - // partition: - // 1. It's a non-gravestone start node. This is any start node visited - // by the reverse postorder. - // 2. It's a return node. - // 3. It's a top-level fork. - // 4. One of its control predecessors is a top-level join. - // 5. 
It's a region node where not every predecessor is in the same - // partition (equivalently, not every predecessor is in the same - // partition - only region nodes can have multiple predecessors). - // 6. It's a region node with a call user. - // 7. Its predecessor is a region node with a call user. - let top_level_fork = node.is_fork() && fork_join_nesting[&control_id].len() == 1; - let top_level_join = control_subgraph.preds(control_id).any(|pred| { - new_function.nodes[pred.idx()].is_join() && fork_join_nesting[&pred].len() == 1 - }); - // It's not possible for every predecessor to not have been assigned a - // partition yet because of reverse postorder traversal. - let multi_pred_region = !control_subgraph - .preds(control_id) - .map(|pred| new_partitions[pred.idx()]) - .all_equal(); - let region_with_call_user = |id: NodeID| { - new_function.nodes[id.idx()].is_region() - && def_use - .get_users(id) - .as_ref() - .into_iter() - .any(|id| new_function.nodes[id.idx()].is_call()) - }; - let call_region = region_with_call_user(control_id); - let pred_is_call_region = control_subgraph - .preds(control_id) - .any(|pred| region_with_call_user(pred)); - - if node.is_start() - || node.is_return() - || top_level_fork - || top_level_join - || multi_pred_region - || call_region - || pred_is_call_region - { - // This control node goes in a new partition. - let part_id = PartitionID::new(plan.num_partitions); - plan.num_partitions += 1; - new_partitions[control_id.idx()] = Some(part_id); - } else { - // This control node goes in the partition of any one of its - // predecessors. They're all the same by condition 3 above. - let any_pred = control_subgraph.preds(control_id).next().unwrap(); - if new_partitions[any_pred.idx()].is_some() { - new_partitions[control_id.idx()] = new_partitions[any_pred.idx()]; - } else { - worklist.push_back(control_id); - } - } - } - - // Step 5: figure out the partitions for added data nodes. - let antideps = antideps(&new_function, &def_use); - let loops = loops(&control_subgraph, NodeID::new(0), &dom, &fork_join_map); - let bbs = gcm( - new_function, - &def_use, - &rev_po, - &dom, - &antideps, - &loops, - &fork_join_map, - ); - for data_idx in 0..new_function.nodes.len() { - new_partitions[data_idx] = new_partitions[bbs[data_idx].idx()]; - } - - // Step 6: create a solitary gravestone partition. This will get removed - // when gravestone nodes are removed. - let gravestone_partition = PartitionID::new(plan.num_partitions); - plan.num_partitions += 1; - for (idx, node) in new_function.nodes.iter().enumerate() { - if idx > 0 && node.is_start() { - new_partitions[idx] = Some(gravestone_partition); - } - } - - // Step 7: wrap everything up. - plan.partitions = new_partitions.into_iter().map(|id| id.unwrap()).collect(); - plan.partition_devices - .resize(plan.num_partitions, Device::CPU); - // Place call partitions on the "AsyncRust" device. - for idx in 0..new_function.nodes.len() { - if new_function.nodes[idx].is_call() { - plan.partition_devices[plan.partitions[idx].idx()] = Device::AsyncRust; - } - } -} - -/* - * Default plans can be constructed by conservatively inferring schedules and - * creating partitions by "repairing" a partition where the edit is adding every - * node in the function. 
- */ -pub fn default_plan( - function: &Function, - dynamic_constants: &Vec<DynamicConstant>, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, - fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, - bbs: &Vec<NodeID>, -) -> Plan { - // Start by creating a completely bare-bones plan doing nothing interesting. - let mut plan = Plan { - schedules: vec![vec![]; function.nodes.len()], - partitions: vec![], - partition_devices: vec![], - num_partitions: 1, - }; - - // Infer a partitioning by using `repair_plan`, where the "edit" is creating - // the entire function. - let edit = ( - HashSet::new(), - HashSet::from_iter((0..function.nodes.len()).map(NodeID::new)), - ); - repair_plan(&mut plan, function, &[edit]); - plan.renumber_partitions(); - - // Infer schedules. - infer_parallel_reduce(function, fork_join_map, &mut plan); - infer_parallel_fork( - function, - def_use, - fork_join_map, - fork_join_nesting, - &mut plan, - ); - infer_vectorizable(function, dynamic_constants, fork_join_map, &mut plan); - infer_associative(function, &mut plan); - - // TODO: uncomment once GPU backend is implemented. - // place_fork_partitions_on_gpu(function, &mut plan); - - plan -} - #[cfg(test)] mod editor_tests { #[allow(unused_imports)] @@ -790,6 +585,7 @@ mod editor_tests { use std::mem::replace; + use self::hercules_ir::dataflow::reverse_postorder; use self::hercules_ir::parse::parse; fn canonicalize(function: &mut Function) -> Vec<Option<NodeID>> { diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs index df3652df..232b43f7 100644 --- a/hercules_opt/src/fork_concat_split.rs +++ b/hercules_opt/src/fork_concat_split.rs @@ -53,6 +53,7 @@ pub fn fork_split( control: acc_fork, factors: Box::new([factor]), }); + edit.sub_edit(*fork, acc_fork); new_tids.push(edit.add_node(Node::ThreadID { control: acc_fork, dimension: 0, @@ -68,6 +69,7 @@ pub fn fork_split( let mut joins = vec![]; for _ in new_tids.iter() { acc_join = edit.add_node(Node::Join { control: acc_join }); + edit.sub_edit(*join, acc_join); joins.push(acc_join); } @@ -89,17 +91,18 @@ pub fn fork_split( } else { NodeID::new(num_nodes + join_idx - 1) }; - let reduce = edit.add_node(Node::Reduce { + let new_reduce = edit.add_node(Node::Reduce { control: *join, init, reduct, }); - assert_eq!(reduce, NodeID::new(num_nodes + join_idx)); + assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx)); + edit.sub_edit(*reduce, new_reduce); if join_idx == 0 { - inner_reduce = reduce; + inner_reduce = new_reduce; } if join_idx == joins.len() - 1 { - outer_reduce = reduce; + outer_reduce = new_reduce; } } new_reduces.push((inner_reduce, outer_reduce)); @@ -110,6 +113,7 @@ pub fn fork_split( edit = edit.replace_all_uses(*join, acc_join)?; for tid in tids.iter() { let dim = edit.get_node(*tid).try_thread_id().unwrap().1; + edit.sub_edit(*tid, new_tids[dim]); edit = edit.replace_all_uses(*tid, new_tids[dim])?; } for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) { diff --git a/hercules_opt/src/gvn.rs b/hercules_opt/src/gvn.rs index b8360dd9..a9db0cd8 100644 --- a/hercules_opt/src/gvn.rs +++ b/hercules_opt/src/gvn.rs @@ -39,7 +39,8 @@ pub fn gvn(editor: &mut FunctionEditor) { // `number`. We want to replace `work` with `number`, which means // 1. replacing all uses of `work` with `number` // 2. 
deleting `work` - let success = editor.edit(|edit| { + let success = editor.edit(|mut edit| { + edit.sub_edit(work, *number); let edit = edit.replace_all_uses(work, *number)?; edit.delete_node(work) }); diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 6b9e006d..1c485ecc 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -6,7 +6,6 @@ use std::iter::zip; use self::hercules_ir::callgraph::*; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; -use self::hercules_ir::schedule::*; use crate::*; @@ -14,11 +13,7 @@ use crate::*; * Top level function to run inlining. Currently, inlines every function call, * since mutual recursion is not valid in Hercules IR. */ -pub fn inline( - editors: &mut [FunctionEditor], - callgraph: &CallGraph, - mut plans: Option<&mut Vec<Plan>>, -) { +pub fn inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) { // Step 1: run topological sort on the call graph to inline the "deepest" // function first. let topo = callgraph.topo(); @@ -54,16 +49,11 @@ pub fn inline( // references, we need to do some weirdness here to simultaneously get: // 1. A mutable reference to the function we're modifying. // 2. Shared references to all of the functions called by that function. - // We need to get the same for plans, if we receive them. let callees = callgraph.get_callees(to_inline_id); let editor_refs = get_mut_and_immuts(editors, to_inline_id, callees); - let plan_refs = plans - .as_mut() - .map(|plans| get_mut_and_immuts(*plans, to_inline_id, callees)); inline_func( editor_refs.0, editor_refs.1, - plan_refs, &single_return_nodes, &dc_param_idx_to_dc_id, ); @@ -75,7 +65,7 @@ pub fn inline( * 1. A single mutable reference. * 2. Several shared references. * Where none of the references alias. We need to use this both for function - * editors and plans. + * editors and schedules. */ fn get_mut_and_immuts<'a, T, I: ID>( mut_refs: &'a mut [T], @@ -113,7 +103,6 @@ fn get_mut_and_immuts<'a, T, I: ID>( fn inline_func( editor: &mut FunctionEditor, called: HashMap<FunctionID, &FunctionEditor>, - plans: Option<(&mut Plan, HashMap<FunctionID, &Plan>)>, single_return_nodes: &Vec<Option<NodeID>>, dc_param_idx_to_dc_id: &Vec<DynamicConstantID>, ) { @@ -149,7 +138,7 @@ fn inline_func( let called_return_data = called_return_uses.as_ref()[1]; // Perform the actual edit. - let success = editor.edit(|mut edit| { + editor.edit(|mut edit| { // Insert the nodes from the called function. There are a few // special cases: // - Start: don't add start nodes - later, we'll replace_all_uses on @@ -217,6 +206,11 @@ fn inline_func( // Add the node and check that the IDs line up. let add_id = edit.add_node(node); assert_eq!(add_id, old_id_to_new_id(NodeID::new(idx))); + // Copy the schedule from the callee. 
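+ // Schedules are stored per-node on the Function, so the callee's list is
+ // indexed by the old node ID and re-attached to the node just added. The
+ // `?` below cannot actually fail: nodes added by this edit are always
+ // mutable.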
+ let callee_schedule = &called_func.schedules[idx]; + for schedule in callee_schedule { + edit = edit.add_schedule(add_id, schedule.clone())?; + } } // Stitch the control use of the inlined start node with the diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 3ab35535..c6cf448b 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -276,6 +276,7 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi control: return_control, data: compressed_data_id, }); + edit.sub_edit(old_return_id, new_return_id); let edit = edit.replace_all_uses(old_return_id, new_return_id)?; edit.delete_node(old_return_id) }); @@ -388,10 +389,11 @@ fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_edi collect: old_data, indices: Box::new([Index::Field(0)]), }); - edit.add_node(Node::Return { + let new_return_id = edit.add_node(Node::Return { control: old_control, data: extracted_singleton_id, }); + edit.sub_edit(old_return_id, new_return_id); edit.set_return_type(tys[0]); edit.delete_node(old_return_id) diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 5a429e14..c4cf21a9 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -14,6 +14,7 @@ pub mod outline; pub mod pass; pub mod phi_elim; pub mod pred; +pub mod schedule; pub mod sroa; pub mod unforkify; pub mod utils; @@ -32,6 +33,7 @@ pub use crate::outline::*; pub use crate::pass::*; pub use crate::phi_elim::*; pub use crate::pred::*; +pub use crate::schedule::*; pub use crate::sroa::*; pub use crate::unforkify::*; pub use crate::utils::*; diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index eb8d386c..70062bbc 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -201,9 +201,11 @@ pub fn outline( .chain(callee_succ_return_idx.map(|_| u32_ty)) .collect(), )), - nodes: vec![], num_dynamic_constants: edit.get_num_dynamic_constant_params(), entry: false, + nodes: vec![], + schedules: vec![], + device: None, }; // Re-number nodes in the partition. @@ -413,6 +415,13 @@ pub fn outline( }); } + // Copy the schedules into the new callee. + outlined.schedules.resize(outlined.nodes.len(), vec![]); + for id in partition.iter() { + let callee_id = convert_id(*id); + outlined.schedules[callee_id.idx()] = edit.get_schedule(*id).clone(); + } + // Step 3: edit the original function to call the outlined function. let dynamic_constants = (0..edit.get_num_dynamic_constant_params()) .map(|idx| edit.add_dynamic_constant(DynamicConstant::Parameter(idx as usize))) diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 006bd371..90556612 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -38,6 +38,7 @@ pub enum Pass { DeleteUncalled, ForkSplit, Unforkify, + InferSchedules, Verify, // Parameterized over whether analyses that aid visualization are necessary. // Useful to set to false if displaying a potentially broken module. @@ -74,9 +75,6 @@ pub struct PassManager { pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub bbs: Option<Vec<Vec<NodeID>>>, pub callgraph: Option<CallGraph>, - - // Current plan. 
- pub plans: Option<Vec<Plan>>, } impl PassManager { @@ -98,7 +96,6 @@ impl PassManager { data_nodes_in_fork_joins: None, bbs: None, callgraph: None, - plans: None, } } @@ -324,54 +321,6 @@ impl PassManager { } } - pub fn set_plans(&mut self, plans: Vec<Plan>) { - self.plans = Some(plans); - } - - pub fn make_plans(&mut self) { - if self.plans.is_none() { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_fork_join_maps(); - self.make_fork_join_nests(); - self.make_bbs(); - let def_uses = self.def_uses.as_ref().unwrap().iter(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); - let bbs = self.bbs.as_ref().unwrap().iter(); - self.plans = Some( - zip( - self.module.functions.iter(), - zip( - def_uses, - zip( - reverse_postorders, - zip(fork_join_maps, zip(fork_join_nests, bbs)), - ), - ), - ) - .map( - |( - function, - (def_use, (reverse_postorder, (fork_join_map, (fork_join_nest, bb)))), - )| { - default_plan( - function, - &self.module.dynamic_constants, - def_use, - reverse_postorder, - fork_join_map, - fork_join_nest, - bb, - ) - }, - ) - .collect(), - ); - } - } - pub fn run_passes(&mut self) { for pass in self.passes.clone().iter() { match pass { @@ -397,21 +346,13 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } Pass::InterproceduralSROA => { self.make_def_uses(); self.make_typing(); - let mut plans = self.plans.as_mut(); let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); let dynamic_constants_ref = @@ -438,21 +379,12 @@ impl PassManager { interprocedural_sroa(&mut editors); - let function_edits: Vec<_> = - editors.into_iter().map(|e| e.edits()).enumerate().collect(); - self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - for (idx, edits) in function_edits { - if let Some(plans) = plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], &edits); - } - let grave_mapping = &self.module.functions[idx].delete_gravestones(); - if let Some(plans) = plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.clear_analyses(); @@ -481,14 +413,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -514,14 +439,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if 
let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -539,7 +457,6 @@ impl PassManager { &loops[idx], ) } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::PhiElim => { @@ -564,14 +481,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -588,7 +498,6 @@ impl PassManager { &def_uses[idx], ) } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::Predication => { @@ -596,12 +505,10 @@ impl PassManager { self.make_reverse_postorders(); self.make_doms(); self.make_fork_join_maps(); - self.make_plans(); let def_uses = self.def_uses.as_ref().unwrap(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let doms = self.doms.as_ref().unwrap(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let plans = self.plans.as_ref().unwrap(); for idx in 0..self.module.functions.len() { predication( &mut self.module.functions[idx], @@ -609,10 +516,8 @@ impl PassManager { &reverse_postorders[idx], &doms[idx], &fork_join_maps[idx], - &plans[idx].schedules, ) } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::SROA => { @@ -641,14 +546,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -673,21 +571,14 @@ impl PassManager { ) }) .collect(); - // Inlining is special in that it may modify partitions in a - // inter-procedural fashion. 
- inline(&mut editors, callgraph, self.plans.as_mut()); + inline(&mut editors, callgraph); + self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits: Vec<_> = editors.into_iter().map(|editor| editor.edits()).collect(); - for idx in 0..edits.len() { - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], &edits[idx]); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.clear_analyses(); } @@ -715,11 +606,6 @@ impl PassManager { collapse_returns(editor); ensure_between_control_flow(editor); } - let mut edits: Vec<_> = editors - .into_iter() - .enumerate() - .map(|(idx, editor)| (idx, editor.edits())) - .collect(); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); @@ -766,27 +652,9 @@ impl PassManager { self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - edits.extend( - editors - .into_iter() - .enumerate() - .map(|(idx, editor)| (idx, editor.edits())), - ); - for (func_idx, edit) in edits { - if let Some(plans) = self.plans.as_mut() { - repair_plan( - &mut plans[func_idx], - &self.module.functions[func_idx], - &edit, - ); - } - } - for idx in 0..self.module.functions.len() { - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.module.functions.extend(new_funcs); self.clear_analyses(); @@ -821,15 +689,8 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits: Vec<_> = editors.into_iter().map(|editor| editor.edits()).collect(); - for idx in 0..edits.len() { - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], &edits[idx]); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + for func in self.module.functions.iter_mut() { + func.delete_gravestones(); } self.fix_deleted_functions(&new_idx); @@ -863,14 +724,7 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -898,14 +752,38 @@ impl PassManager { self.module.dynamic_constants = dynamic_constants_ref.take(); self.module.types = types_ref.take(); - let edits = &editor.edits(); - if let Some(plans) = self.plans.as_mut() { - repair_plan(&mut plans[idx], &self.module.functions[idx], edits); - } - let grave_mapping = self.module.functions[idx].delete_gravestones(); - if let Some(plans) = self.plans.as_mut() { - 
plans[idx].fix_gravestones(&grave_mapping); - } + self.module.functions[idx].delete_gravestones(); + } + self.clear_analyses(); + } + Pass::InferSchedules => { + self.make_def_uses(); + self.make_fork_join_maps(); + let def_uses = self.def_uses.as_ref().unwrap(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); + for idx in 0..self.module.functions.len() { + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( + &mut self.module.functions[idx], + &constants_ref, + &dynamic_constants_ref, + &types_ref, + &def_uses[idx], + ); + infer_parallel_reduce(&mut editor, &fork_join_maps[idx]); + infer_parallel_fork(&mut editor, &fork_join_maps[idx]); + infer_vectorizable(&mut editor, &fork_join_maps[idx]); + infer_associative(&mut editor); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); } @@ -930,17 +808,6 @@ impl PassManager { self.doms = Some(doms); self.postdoms = Some(postdoms); self.fork_join_maps = Some(fork_join_maps); - - // Verify the plan, if it exists. - if let Some(plans) = &self.plans { - for idx in 0..self.module.functions.len() { - plans[idx].verify_partitioning( - &self.module.functions[idx], - &self.def_uses.as_ref().unwrap()[idx], - &self.fork_join_maps.as_ref().unwrap()[idx], - ); - } - } } Pass::Xdot(force_analyses) => { self.make_reverse_postorders(); @@ -955,7 +822,6 @@ impl PassManager { self.doms.as_ref(), self.fork_join_maps.as_ref(), self.bbs.as_ref(), - self.plans.as_ref(), ); } Pass::Codegen(output_dir, module_name) => { @@ -1064,21 +930,6 @@ impl PassManager { } } - fn legacy_repair_plan(&mut self) { - // Cleanup the module after passes. Delete gravestone nodes. Repair the - // plans. - for idx in 0..self.module.functions.len() { - let grave_mapping = self.module.functions[idx].delete_gravestones(); - let plans = &mut self.plans; - let functions = &self.module.functions; - if let Some(plans) = plans.as_mut() { - take_mut::take(&mut plans[idx], |plan| { - plan.repair(&functions[idx], &grave_mapping) - }); - } - } - } - fn clear_analyses(&mut self) { self.def_uses = None; self.reverse_postorders = None; @@ -1092,8 +943,6 @@ impl PassManager { self.antideps = None; self.bbs = None; self.callgraph = None; - - // Don't clear the plan - this is repaired, not reconstructed. } pub fn get_module(self) -> Module { diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs index f8478ded..09d9753d 100644 --- a/hercules_opt/src/pred.rs +++ b/hercules_opt/src/pred.rs @@ -10,7 +10,6 @@ use self::bitvec::prelude::*; use self::hercules_ir::def_use::*; use self::hercules_ir::dom::*; use self::hercules_ir::ir::*; -use self::hercules_ir::schedule::*; /* * Top level function to convert acyclic control flow in vectorized fork-joins @@ -22,7 +21,6 @@ pub fn predication( reverse_postorder: &Vec<NodeID>, dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, - schedules: &Vec<Vec<Schedule>>, ) { // Detect forks with vectorize schedules. 
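// Vectorizable is a per-node annotation stored on the function itself,
// typically attached by infer_vectorizable in hercules_opt::schedule.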
let vector_forks: Vec<_> = function
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
new file mode 100644
index 00000000..b65b4d04
--- /dev/null
+++ b/hercules_opt/src/schedule.rs
@@ -0,0 +1,175 @@
+extern crate hercules_ir;
+
+use std::collections::HashMap;
+
+use self::hercules_ir::def_use::*;
+use self::hercules_ir::ir::*;
+
+use crate::*;
+
+/*
+ * Infer parallel fork-joins. These are fork-joins with only parallel reduction
+ * variables.
+ */
+pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+ for id in editor.node_ids() {
+ let func = editor.func();
+ let Node::Fork {
+ control: _,
+ factors: _,
+ } = func.nodes[id.idx()]
+ else {
+ continue;
+ };
+ let join_id = fork_join_map[&id];
+ let all_parallel_reduce = editor.get_users(join_id).all(|user| {
+ func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
+ || func.nodes[user.idx()].is_control()
+ });
+ if all_parallel_reduce {
+ editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
+ }
+ }
+}
+
+/*
+ * Infer parallel reductions consisting of a simple cycle between a Reduce node
+ * and a Write node, where indices of the Write are position indices using the
+ * ThreadID nodes attached to the corresponding Fork. This procedure also adds
+ * the ParallelReduce schedule to Reduce nodes reducing over a parallelized
+ * Reduce, as long as the base Write node also has position indices of the
+ * ThreadID of the outer fork. In other words, the complete Reduce chain is
+ * annotated with ParallelReduce, as long as each ThreadID dimension appears in
+ * the positional indexing of the original Write.
+ */
+pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+ for id in editor.node_ids() {
+ let func = editor.func();
+ if !func.nodes[id.idx()].is_reduce() {
+ continue;
+ }
+
+ let mut first_control = None;
+ let mut last_reduce = id;
+ let mut chain_id = id;
+
+ // Walk down the Reduce chain until we reach the Reduce potentially looping
+ // with the Write. Note the control node of the first Reduce, since this
+ // will tell us which Thread ID to look for in the Write.
+ while let Node::Reduce {
+ control,
+ init: _,
+ reduct,
+ } = func.nodes[chain_id.idx()]
+ {
+ if first_control.is_none() {
+ first_control = Some(control);
+ }
+
+ last_reduce = chain_id;
+ chain_id = reduct;
+ }
+
+ // Check for a Write-Reduce tight cycle.
+ if let Node::Write {
+ collect,
+ data: _,
+ indices,
+ } = &func.nodes[chain_id.idx()]
+ && *collect == last_reduce
+ {
+ // If there is a Write-Reduce tight cycle, get the position indices.
+ let positions = indices
+ .iter()
+ .filter_map(|index| {
+ if let Index::Position(indices) = index {
+ Some(indices)
+ } else {
+ None
+ }
+ })
+ .flat_map(|pos| pos.iter());
+
+ // Get the Forks corresponding to uses of bare ThreadIDs.
+ let fork_thread_id_pairs = positions.filter_map(|id| {
+ if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
+ Some((control, dimension))
+ } else {
+ None
+ }
+ });
+ let mut forks = HashMap::<NodeID, Vec<usize>>::new();
+ for (fork, dim) in fork_thread_id_pairs {
+ forks.entry(fork).or_default().push(dim);
+ }
+
+ // Check if one of the Forks corresponds to the Join associated with
+ // the Reduce being considered and has all of its dimensions
+ // represented in the indexing.
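+ // Sorting and deduplicating the dimension list first means a simple
+ // length comparison against the number of fork dimensions suffices.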
+ let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { + rep_dims.sort(); + rep_dims.dedup(); + fork_join_map[&id] == first_control.unwrap() + && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() + }); + + if is_parallel { + editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce)); + } + } + } +} + +/* + * Infer vectorizable fork-joins. Just check that there are no control nodes + * between a fork and its join and the factor is a constant. + */ +pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { + for id in editor.node_ids() { + let func = editor.func(); + if !func.nodes[id.idx()].is_join() { + continue; + } + + let u = get_uses(&func.nodes[id.idx()]).as_ref()[0]; + if let Some(join) = fork_join_map.get(&u) + && *join == id + { + let factors = func.nodes[u.idx()].try_fork().unwrap().1; + if factors.len() == 1 + && evaluate_dynamic_constant(factors[0], &editor.get_dynamic_constants()).is_some() + { + editor.edit(|edit| edit.add_schedule(u, Schedule::Vectorizable)); + } + } + } +} + +/* + * Infer associative reduction loops. + */ +pub fn infer_associative(editor: &mut FunctionEditor) { + let is_associative = |op| match op { + BinaryOperator::Add + | BinaryOperator::Mul + | BinaryOperator::Or + | BinaryOperator::And + | BinaryOperator::Xor => true, + _ => false, + }; + + for id in editor.node_ids() { + let func = editor.func(); + if let Node::Reduce { + control: _, + init: _, + reduct, + } = func.nodes[id.idx()] + && let Node::Binary { left, right, op } = func.nodes[reduct.idx()] + && (left == id || right == id) + && is_associative(op) + { + editor.edit(|edit| edit.add_schedule(id, Schedule::Associative)); + } + } +} diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 67c904ff..a73ecb2b 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -100,10 +100,11 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let control = *control; let new_data = reconstruct_product(editor, types[&data], *data, &mut product_nodes); editor.edit(|mut edit| { - edit.add_node(Node::Return { + let new_return = edit.add_node(Node::Return { control, data: new_data, }); + edit.sub_edit(node, new_return); edit.delete_node(node) }); } @@ -145,6 +146,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: dynamic_constants, args: new_args.into(), }); + edit.sub_edit(node, new_call); let edit = edit.replace_all_uses(node, new_call)?; let edit = edit.delete_node(node)?; @@ -344,7 +346,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } if ready { - fields.zip_list(data_fields).for_each(|idx, (res, data)| { + fields.zip_list(data_fields).for_each(|_, (res, data)| { to_insert.insert( res.idx(), Node::Phi { @@ -374,7 +376,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: (field_map.get(&init), field_map.get(&reduct)) { fields.zip(init_fields).zip(reduct_fields).for_each( - |idx, ((res, init), reduct)| { + |_, ((res, init), reduct)| { to_insert.insert( res.idx(), Node::Reduce { @@ -409,7 +411,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: fields .zip(thn_fields) .zip(els_fields) - .for_each(|idx, ((res, thn), els)| { + .for_each(|_, ((res, thn), els)| { to_insert.insert( res.idx(), Node::Ternary { @@ -458,7 +460,10 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: None => new, }; - editor.edit(|edit| 
edit.replace_all_uses(old, new)); + editor.edit(|mut edit| { + edit.sub_edit(old, new); + edit.replace_all_uses(old, new) + }); replaced_by.insert(old, new); let mut replaced = vec![]; diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index f31b7409..61c86a27 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -125,10 +125,14 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No edit = edit.replace_all_uses(*fork, proj_back_id)?; edit = edit.replace_all_uses(*join, proj_exit_id)?; + edit.sub_edit(*fork, region_id); + edit.sub_edit(*join, if_id); for tid in tids.iter() { + edit.sub_edit(*tid, indvar_id); edit = edit.replace_all_uses(*tid, indvar_id)?; } for (reduce, phi_id) in zip(reduces.iter(), phi_ids) { + edit.sub_edit(*reduce, phi_id); edit = edit.replace_all_uses(*reduce, phi_id)?; } diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index c39faef9..4ff5d9fc 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -120,6 +120,8 @@ pub fn compile_ir( ) -> Result<(), ErrorMessage> { let mut pm = match schedule { JunoSchedule::None => hercules_opt::pass::PassManager::new(module), + _ => todo!(), + /* JunoSchedule::DefaultSchedule => { let mut pm = hercules_opt::pass::PassManager::new(module); pm.make_plans(); @@ -143,6 +145,7 @@ pub fn compile_ir( } } } + */ }; if verify.verify() || verify.verify_all() { pm.add_pass(hercules_opt::pass::Pass::Verify); @@ -187,6 +190,7 @@ pub fn compile_ir( add_pass!(pm, verify, Outline); add_pass!(pm, verify, InterproceduralSROA); add_pass!(pm, verify, SROA); + add_pass!(pm, verify, InferSchedules); add_pass!(pm, verify, ForkSplit); add_pass!(pm, verify, Unforkify); add_pass!(pm, verify, GVN); diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 36ea79e9..7e558d6b 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -8,7 +8,6 @@ use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; use self::hercules_ir::ir::*; -use self::hercules_ir::schedule::*; mod parser; use crate::parser::lexer; @@ -47,6 +46,7 @@ pub enum LabeledStructure { Branch(NodeID), // If node } +/* pub fn schedule(module: &Module, info: FunctionMap, schedule: String) -> Result<Vec<Plan>, String> { if let Ok(mut file) = File::open(schedule) { let mut contents = String::new(); @@ -337,3 +337,4 @@ fn generate_schedule( .map(|(f, p)| (f, p.into())) .collect::<HashMap<_, _>>()) } +*/ -- GitLab
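A minimal sketch of how a pass composes the new editor APIs, assuming a
FunctionEditor `editor` and hypothetical node IDs `old` and `new` (the names
and the choice of schedule are illustrative only, not from this patch):

    let success = editor.edit(|mut edit| {
        // Record that `new` is a replacement for `old`.
        edit.sub_edit(old, new);
        // Rewrite uses, annotate the replacement, and delete the old node,
        // all within a single atomic edit.
        let edit = edit.replace_all_uses(old, new)?;
        let edit = edit.add_schedule(new, Schedule::ParallelFork)?;
        edit.delete_node(old)
    });

Here `success` is the boolean returned by `editor.edit`, as in the gvn change
above; the `?` operators hand the edit back unchanged if any step touches an
immutable node.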