diff --git a/Cargo.lock b/Cargo.lock index a1eb77de19e2b05c2d96c44fe6579b9b71fdd62f..80ef4516f66677992a07755b3bdb788245e63c3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -861,6 +861,7 @@ dependencies = [ "serde", "take_mut", "tempfile", + "union-find", ] [[package]] @@ -2098,6 +2099,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "union-find" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "039142448432983c34b64739f8526f8f233a1eec7a66e61b6ab29acfa781194e" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 84f6aca83e508d905ad0e13f0670e7d45c18d22b..9f22884dc9811a1a79235c0ea5ddf1111f7d22a0 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -11,6 +11,7 @@ tempfile = "*" either = "*" itertools = "*" take_mut = "*" +union-find = "*" postcard = { version = "*", features = ["alloc"] } serde = { version = "*", features = ["derive"] } hercules_cg = { path = "../hercules_cg" } diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index a7df9bd9c8409daf4daa004dcd2b29a3ec2660c8..5ea9485d108ea6454d856bf164d990ea5d7895f8 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -3,6 +3,7 @@ use std::iter::{empty, once, zip, FromIterator}; use bitvec::prelude::*; use either::Either; +use union_find::{QuickFindUf, UnionBySize, UnionFind}; use hercules_cg::*; use hercules_ir::*; @@ -551,6 +552,35 @@ fn mutating_objects<'a>( } } +fn mutating_writes<'a>( + function: &'a Function, + mutator: NodeID, + objects: &'a CollectionObjects, +) -> Box<dyn Iterator<Item = NodeID> + 'a> { + match function.nodes[mutator.idx()] { + Node::Write { + collect, + data: _, + indices: _, + } => Box::new(once(collect)), + Node::Call { + control: _, + function: callee, + dynamic_constants: _, + ref args, + } => 
Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { + let callee_objects = &objects[&callee]; + let param_obj = callee_objects.param_to_object(idx)?; + if callee_objects.is_mutated(param_obj) { + Some(*arg) + } else { + None + } + })), + _ => Box::new(empty()), + } +} + type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>; /* @@ -579,27 +609,60 @@ fn spill_clones( // Step 2: compute an interference graph from the liveness result. This // graph contains a vertex per node ID producing a collection value and an // edge per pair of node IDs that interfere. Nodes A and B interfere if node - // A is defined right above a point where node B is live. + // A is defined right above a point where node B is live and A != B. Extra + // edges are drawn for forwarding reads - when there is a node A that is a + // forwarding read of a node B, A and B really have the same live range for + // the purpose of determining when spills are necessary, since forwarding + // reads can be thought of as nothing but pointer math. For this purpose, we + // maintain a union-find of nodes that form a forwarding read DAG (notably, + // phis and reduces are not considered forwarding reads). The more precise + // version of the interference condition is nodes A and B interfere if node + // A is defined right above a point where a node C is live where C is in the + // same union-find class as B. + + // Assemble the union-find to group forwarding read DAGs. + let mut union_find = QuickFindUf::<UnionBySize>::new(editor.func().nodes.len()); + for id in editor.node_ids() { + for forwarding_read in forwarding_reads(editor.func(), editor.func_id(), id, objects) { + union_find.union(id.idx(), forwarding_read.idx()); + } + } + + // Figure out which classes contain which node IDs, since we need to iterate + // the disjoint sets.
+ let mut disjoint_sets: BTreeMap<usize, Vec<NodeID>> = BTreeMap::new(); + for id in editor.node_ids() { + disjoint_sets + .entry(union_find.find(id.idx())) + .or_default() + .push(id); + } + + // Create the graph. let mut edges = vec![]; for (bb, liveness) in liveness { let insts = &bbs.1[bb.idx()]; for (node, live) in zip(insts, liveness.into_iter().skip(1)) { for live_node in live { - if *node != live_node { - edges.push((*node, live_node)); + for live_node in disjoint_sets[&union_find.find(live_node.idx())].iter() { + if *node != *live_node { + edges.push((*node, *live_node)); + } } } } } - // Step 3: filter edges (A, B) to just see edges where A uses B and A isn't - // a terminating read. These are the edges that may require a spill. + // Step 3: filter edges (A, B) to just see edges where A mutates B or A is a + // phi or reduce using B. These are the edges that may require a spill. let mut spill_edges = edges.into_iter().filter(|(a, b)| { - get_uses(&editor.func().nodes[a.idx()]) - .as_ref() - .into_iter() - .any(|u| *u == *b) - && !terminating_reads(editor.func(), editor.func_id(), *a, objects).any(|id| id == *b) + mutating_writes(editor.func(), *a, objects).any(|id| id == *b) + || (get_uses(&editor.func().nodes[a.idx()]) + .as_ref() + .into_iter() + .any(|u| *u == *b) + && (editor.func().nodes[a.idx()].is_phi() + || editor.func().nodes[a.idx()].is_reduce())) }); // Step 4: if there is a spill edge, spill it and return true. Otherwise,