Commit 453cbdc1 authored by Aaron Councilman

Interprocedural sroa

parent 82edc984
1 merge request: !59 Interprocedural sroa
......@@ -84,6 +84,10 @@ pub enum DeviceManifest {
parallel_launch: Box<[(usize, DynamicConstantID)]>,
},
GPU,
Call {
// This is a Hercules function name, not a schedule IR function name.
callee: String,
},
}
impl Manifest {
......
......@@ -183,6 +183,8 @@ pub fn gcm(
continue;
};
// #[feature(iter_collect_into)]
// Look between the LCA and the schedule early location to place the
// node. If a data node can't be scheduled to any control nodes in its
// partition (this may happen if all of the control nodes in a partition
......
......@@ -1086,6 +1086,27 @@ impl Node {
}
}
pub fn try_call(
&self,
) -> Option<(
NodeID,
FunctionID,
&Box<[DynamicConstantID]>,
&Box<[NodeID]>,
)> {
if let Node::Call {
control,
function,
dynamic_constants,
args,
} = self
{
Some((*control, *function, dynamic_constants, args))
} else {
None
}
}
pub fn try_dynamic_constant(&self) -> Option<DynamicConstantID> {
if let Node::DynamicConstant { id } = self {
Some(*id)
......
......@@ -303,6 +303,18 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
pub fn replace_all_uses(mut self, old: NodeID, new: NodeID) -> Result<Self, Self> {
self.replace_all_uses_where(old, new, |_| true)
}
pub fn replace_all_uses_where<P>(
mut self,
old: NodeID,
new: NodeID,
pred: P,
) -> Result<Self, Self>
where
P: Fn(&NodeID) -> bool,
{
// We can only replace uses of mutable nodes. Return Err if we try to
// replace uses of an immutable node, as it means the whole edit should
// be aborted.
......@@ -310,25 +322,32 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
// Update all of the users of the old node.
self.ensure_updated_def_use_entry(old);
for user_id in self.updated_def_use[&old].iter() {
// Replace uses of old with new.
let mut updated_user = self.get_node(*user_id).clone();
for u in get_uses_mut(&mut updated_user).as_mut() {
if **u == old {
**u = new;
if pred(user_id) {
// Replace uses of old with new.
let mut updated_user = self.get_node(*user_id).clone();
for u in get_uses_mut(&mut updated_user).as_mut() {
if **u == old {
**u = new;
}
}
// Add the updated user to added_and_updated.
self.added_and_updated_nodes.insert(*user_id, updated_user);
}
// Add the updated user to added_and_updated.
self.added_and_updated_nodes.insert(*user_id, updated_user);
}
// All of the users of the old node become users of the new node, so
// move all of the entries in the def-use from the old to the new.
let old_entries = take(self.updated_def_use.get_mut(&old).unwrap());
let (new_users, old_users): (Vec<_>, Vec<_>) = old_entries.into_iter().partition(pred);
self.ensure_updated_def_use_entry(new);
self.updated_def_use
.get_mut(&new)
.unwrap()
.extend(old_entries);
.extend(new_users);
self.updated_def_use
.get_mut(&old)
.unwrap()
.extend(old_users);
Ok(self)
} else {
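Aside: a minimal standalone sketch of the predicate-filtered replacement added here, using a plain HashMap as a toy def-use map rather than the real FunctionEdit machinery (only the def-use bookkeeping is modeled, not the node rewriting). Users accepted by the predicate migrate to the new node's entry; the rest keep using the old node.

use std::collections::HashMap;

// Toy def-use map: node id -> ids of nodes that use it.
fn replace_where<P: Fn(&usize) -> bool>(
    def_use: &mut HashMap<usize, Vec<usize>>,
    old: usize,
    new: usize,
    pred: P,
) {
    let users = def_use.remove(&old).unwrap_or_default();
    // Users satisfying the predicate move to `new`; the rest stay on `old`.
    let (moved, kept): (Vec<_>, Vec<_>) = users.into_iter().partition(&pred);
    def_use.entry(new).or_default().extend(moved);
    def_use.entry(old).or_default().extend(kept);
}

fn main() {
    let mut def_use: HashMap<usize, Vec<usize>> = HashMap::new();
    def_use.insert(0, vec![1, 2, 3]); // node 0 is used by nodes 1, 2, 3
    // Replace uses of node 0 with node 9 everywhere except in node 2.
    replace_where(&mut def_use, 0, 9, |u| *u != 2);
    assert_eq!(def_use[&9], vec![1, 3]);
    assert_eq!(def_use[&0], vec![2]);
}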
......@@ -376,7 +395,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
pub fn get_type(&self, id: TypeID) -> impl Deref + '_ {
pub fn get_type(&self, id: TypeID) -> impl Deref<Target = Type> + '_ {
if id.idx() < self.editor.types.borrow().len() {
Either::Left(Ref::map(self.editor.types.borrow(), |types| {
&types[id.idx()]
......@@ -408,6 +427,32 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
pub fn add_zero_constant(&mut self, id: TypeID) -> ConstantID {
let ty = self.get_type(id).clone();
let constant_to_construct = match ty {
Type::Boolean => Constant::Boolean(false),
Type::Integer8 => Constant::Integer8(0),
Type::Integer16 => Constant::Integer16(0),
Type::Integer32 => Constant::Integer32(0),
Type::Integer64 => Constant::Integer64(0),
Type::UnsignedInteger8 => Constant::UnsignedInteger8(0),
Type::UnsignedInteger16 => Constant::UnsignedInteger16(0),
Type::UnsignedInteger32 => Constant::UnsignedInteger32(0),
Type::UnsignedInteger64 => Constant::UnsignedInteger64(0),
Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(0.0)),
Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(0.0)),
Type::Control => panic!("Tried to get zero control element"),
Type::Product(tys) => {
let dummy_elems: Vec<_> =
tys.iter().map(|ty| self.add_zero_constant(*ty)).collect();
Constant::Product(id, dummy_elems.into_boxed_slice())
}
Type::Summation(tys) => Constant::Summation(id, 0, self.add_zero_constant(tys[0])),
Type::Array(tid, _) => Constant::Array(tid),
};
self.add_constant(constant_to_construct)
}
pub fn get_constant(&self, id: ConstantID) -> impl Deref + '_ {
if id.idx() < self.editor.constants.borrow().len() {
Either::Left(Ref::map(self.editor.constants.borrow(), |constants| {
......@@ -467,6 +512,8 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
fn collapse_edits(edits: &[Edit]) -> Edit {
let mut total_edit = Edit::default();
let mut all_additions: HashSet<NodeID> = HashSet::new();
for edit in edits {
assert!(edit.0.is_disjoint(&edit.1), "PANIC: Edit sequence is malformed - can't add and delete the same node ID in a single edit.");
assert!(
......@@ -479,13 +526,16 @@ fn collapse_edits(edits: &[Edit]) -> Edit {
);
for delete in edit.0.iter() {
total_edit.0.insert(*delete);
if !all_additions.contains(delete) {
total_edit.0.insert(*delete);
}
total_edit.1.remove(delete);
}
for addition in edit.1.iter() {
total_edit.0.remove(addition);
total_edit.1.insert(*addition);
all_additions.insert(*addition);
}
}
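Aside: a standalone sketch of the collapsing fix above, with HashSets of plain ids standing in for the real Edit type. A node added earlier in the sequence and deleted later should disappear entirely rather than be reported as a deletion of a pre-existing node.

use std::collections::HashSet;

// (deleted ids, added ids) for one toy edit.
type ToyEdit = (HashSet<usize>, HashSet<usize>);

fn collapse(edits: &[ToyEdit]) -> ToyEdit {
    let mut total: ToyEdit = (HashSet::new(), HashSet::new());
    let mut all_additions: HashSet<usize> = HashSet::new();
    for (dels, adds) in edits {
        for d in dels {
            // Only record a deletion for nodes that existed before the sequence.
            if !all_additions.contains(d) {
                total.0.insert(*d);
            }
            total.1.remove(d);
        }
        for a in adds {
            total.0.remove(a);
            total.1.insert(*a);
            all_additions.insert(*a);
        }
    }
    total
}

fn main() {
    // Edit 1 adds node 7; edit 2 deletes node 7 and the pre-existing node 3.
    let edits = vec![
        (HashSet::new(), HashSet::from([7])),
        (HashSet::from([7, 3]), HashSet::new()),
    ];
    let (dels, adds) = collapse(&edits);
    assert!(dels.contains(&3) && !dels.contains(&7));
    assert!(adds.is_empty());
}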
......@@ -543,10 +593,8 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) {
// While building, the new partitions map uses Option since we don't have
// partitions for new nodes yet, and we need to record that specifically for
// computing the partitions of region nodes.
let mut new_partitions: Vec<Option<PartitionID>> = take(&mut plan.partitions)
.into_iter()
.map(|part| Some(part))
.collect();
let mut new_partitions: Vec<Option<PartitionID>> =
take(&mut plan.partitions).into_iter().map(Some).collect();
new_partitions.resize(new_function.nodes.len(), None);
// Iterate the added control nodes using a worklist.
let mut worklist = VecDeque::from(added_control_nodes);
......
extern crate hercules_ir;
use self::hercules_ir::ir::*;
use crate::*;
/**
* Given an editor for each function in a module, return V s.t.
* V[i] = true iff every call node to the function with index i
* is editable. If there are no calls to this function, V[i] = true.
*/
fn get_editable_callsites(editors: &mut Vec<FunctionEditor>) -> Vec<bool> {
let mut callsites_editable = vec![true; editors.len()];
for editor in editors {
for (idx, (_, function, _, _)) in editor
.func()
.nodes
.iter()
.enumerate()
.filter_map(|(idx, node)| node.try_call().map(|c| (idx, c)))
{
if !editor.is_mutable(NodeID::new(idx)) {
callsites_editable[function.idx()] = false;
}
}
}
callsites_editable
}
/**
* Given a type tree, return a Vec containing all leaves which are not units.
*/
fn get_nonempty_leaves(edit: &FunctionEdit, type_id: &TypeID) -> Vec<TypeID> {
let ty = edit.get_type(*type_id).clone();
match ty {
Type::Product(type_ids) => {
let mut leaves = vec![];
for type_id in type_ids {
leaves.extend(get_nonempty_leaves(&edit, &type_id))
}
leaves
}
_ => vec![*type_id],
}
}
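Aside: a standalone model of this flattening over a toy type tree (not Hercules TypeIDs). Products are walked in order and only non-product leaves are kept, so an empty product contributes nothing.

#[derive(Clone, Debug, PartialEq)]
enum ToyTy {
    I32,
    F32,
    Product(Vec<ToyTy>),
}

// In-order collection of non-product leaves; an empty product yields no leaves.
fn nonempty_leaves(ty: &ToyTy) -> Vec<ToyTy> {
    match ty {
        ToyTy::Product(children) => children.iter().flat_map(nonempty_leaves).collect(),
        leaf => vec![leaf.clone()],
    }
}

fn main() {
    // ((i32, ()), f32) flattens to [i32, f32].
    let ty = ToyTy::Product(vec![
        ToyTy::Product(vec![ToyTy::I32, ToyTy::Product(vec![])]),
        ToyTy::F32,
    ]);
    assert_eq!(nonempty_leaves(&ty), vec![ToyTy::I32, ToyTy::F32]);
}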
/**
* Given a `source` NodeID which produces a product containing
* all nonempty leaves of the type tree for `type_id` in order, build
* a node producing the `type_id`.
*
* `offset` represents the index at which to begin reading
* elements of the `source` product.
*
* Returns a 3-tuple of
* 1. Node producing the `type`
* 2. "Next" offset, i.e. `offset` + number of reads performed to build (1)
* 3. List of node IDs which read `source` (tracked so that these will not
* be replaced by replace_all_uses_where)
*/
fn build_uncompressed_product(
edit: &mut FunctionEdit,
source: &NodeID,
type_id: &TypeID,
offset: usize,
) -> (NodeID, usize, Vec<NodeID>) {
let ty = edit.get_type(*type_id).clone();
match ty {
Type::Product(child_type_ids) => {
// Step 1. Create an empty constant for the type. We'll write
// child values into this constant.
let empty_constant_id = edit.add_zero_constant(*type_id);
let empty_constant_node = edit.add_node(Node::Constant {
id: empty_constant_id,
});
// Step 2. Build a node that generates each inner type.
// Since `source` contains nonempty leaves *in order*,
// we must process inner types in order; as part of this,
// inner type i+1 must read from where inner type i left off,
// hence we track the `current_offset` at which we are reading.
// Similarly, to combine results of all recursive calls,
// we keep the invariant that, at iteration i+1, currently_writing_to
// is an instance of `type_id` for which the first i elements
// have been populated based on inorder nonempty leaves
// (and, at iteration 0, it is empty).
let mut current_offset = offset;
let mut currently_writing_to = empty_constant_node;
let mut readers = vec![];
for (idx, child_type_id) in child_type_ids.iter().enumerate() {
let (child_data, next_offset, child_readers) =
build_uncompressed_product(edit, source, child_type_id, current_offset);
current_offset = next_offset;
currently_writing_to = edit.add_node(Node::Write {
collect: currently_writing_to,
data: child_data,
indices: Box::new([Index::Field(idx)]),
});
readers.extend(child_readers)
}
(currently_writing_to, current_offset, readers)
}
_ => {
// If the type is not a product, then we've reached a nonempty
// leaf, which we must read from source. Since this is a single
// read, the new offset increases by only 1.
let reader = edit.add_node(Node::Read {
collect: *source,
indices: Box::new([Index::Field(offset)]),
});
(reader, offset + 1, vec![reader])
}
}
}
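Aside: a standalone sketch of the offset threading in build_uncompressed_product, with a flat slice of toy leaf values standing in for the Read nodes on `source`. Each leaf consumes one slot and each product stitches its children back together in order.

#[derive(Clone, Debug, PartialEq)]
enum ToyTy {
    Leaf,
    Product(Vec<ToyTy>),
}

#[derive(Clone, Debug, PartialEq)]
enum ToyVal {
    Leaf(i64),
    Product(Vec<ToyVal>),
}

// Rebuild a value of shape `ty` from `flat`, starting at `offset`; returns the
// value and the next unread offset, mirroring how the pass threads offsets
// through its recursive calls.
fn uncompress(flat: &[i64], ty: &ToyTy, offset: usize) -> (ToyVal, usize) {
    match ty {
        ToyTy::Product(children) => {
            let mut fields = Vec::new();
            let mut off = offset;
            for child in children {
                let (v, next) = uncompress(flat, child, off);
                fields.push(v);
                off = next;
            }
            (ToyVal::Product(fields), off)
        }
        ToyTy::Leaf => (ToyVal::Leaf(flat[offset]), offset + 1),
    }
}

fn main() {
    // Shape ((leaf, ()), leaf) consumes exactly two flat elements.
    let ty = ToyTy::Product(vec![
        ToyTy::Product(vec![ToyTy::Leaf, ToyTy::Product(vec![])]),
        ToyTy::Leaf,
    ]);
    let (val, next) = uncompress(&[10, 20], &ty, 0);
    assert_eq!(next, 2);
    assert_eq!(
        val,
        ToyVal::Product(vec![
            ToyVal::Product(vec![ToyVal::Leaf(10), ToyVal::Product(vec![])]),
            ToyVal::Leaf(20),
        ])
    );
}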
/**
* Given a node with a product value, read the product's values
* *in order* into the nonempty leaves of a product type represented
* by type_id. Returns the ID of the resulting node, as well as the IDs
* of all nodes which read from `node_id`.
*/
fn uncompress_product(
edit: &mut FunctionEdit,
node_id: &NodeID,
type_id: &TypeID,
) -> (NodeID, Vec<NodeID>) {
let (uncompressed_value, _, readers) = build_uncompressed_product(edit, node_id, type_id, 0);
(uncompressed_value, readers)
}
/**
* Let `read_from` be a node with a value of type `type_id`.
* Let `source` be a product value.
* Returns a node representing the value obtained by writing
* nonempty leaves of `read_from` *in order* into `source`,
* starting at `offset`.
*
* `source` should be a product value whose type has at least enough fields
* to support this operation. Typically, `build_compressed_product`
* should be called initially with a `source` created by adding a
* zero constant for the flattened `type_id`.
*
* Returns:
* 1. The ID of the node to which all nonempty leaves have been written
* 2. The first offset after `offset` which was not written to.
*/
fn build_compressed_product(
mut edit: &mut FunctionEdit,
source: &NodeID,
type_id: &TypeID,
offset: usize,
read_from: &NodeID,
) -> (NodeID, usize) {
let ty = edit.get_type(*type_id).clone();
match ty {
Type::Product(child_type_ids) => {
// Iterate through child types in order. For each type, construct
// a node that reads the corresponding value from `read_from`,
// and pass it as the node to read from in the recursive call.
let mut next_offset = offset;
let mut next_destination = *source;
for (idx, child_type_id) in child_type_ids.iter().enumerate() {
let child_value = edit.add_node(Node::Read {
collect: *read_from,
indices: Box::new([Index::Field(idx)]),
});
(next_destination, next_offset) = build_compressed_product(
&mut edit,
&next_destination,
&child_type_id,
next_offset,
&child_value,
);
}
(next_destination, next_offset)
}
_ => {
let writer = edit.add_node(Node::Write {
collect: *source,
data: *read_from,
indices: Box::new([Index::Field(offset)]),
});
(writer, offset + 1)
}
}
}
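Aside: the mirror-image sketch for build_compressed_product, again with toy values instead of IR nodes. The nested value is walked in order and each non-product leaf lands in the next slot of the flat buffer; the returned offset is the first slot not written.

#[derive(Clone, Debug, PartialEq)]
enum ToyVal {
    Leaf(i64),
    Product(Vec<ToyVal>),
}

// Write the nonempty leaves of `value` into `flat` starting at `offset`;
// returns the first offset that was not written.
fn compress_into(flat: &mut Vec<i64>, value: &ToyVal, offset: usize) -> usize {
    match value {
        ToyVal::Product(fields) => {
            let mut off = offset;
            for field in fields {
                off = compress_into(flat, field, off);
            }
            off
        }
        ToyVal::Leaf(x) => {
            // Grow the buffer on demand; the pass instead starts from a zero
            // constant of the flattened type, which already has every slot.
            if flat.len() <= offset {
                flat.resize(offset + 1, 0);
            }
            flat[offset] = *x;
            offset + 1
        }
    }
}

fn main() {
    let val = ToyVal::Product(vec![
        ToyVal::Product(vec![ToyVal::Leaf(10), ToyVal::Product(vec![])]),
        ToyVal::Leaf(20),
    ]);
    let mut flat = Vec::new();
    let next = compress_into(&mut flat, &val, 0);
    assert_eq!((flat, next), (vec![10, 20], 2));
}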
/**
* Given a node which has a value of the given type (which must be a product)
* generate a new product node created by inserting nonempty leaves of the
* source node *in order*. Returns the ID of this node, as well as the ID of
* its type.
*/
fn compress_product(
edit: &mut FunctionEdit,
node_id: &NodeID,
type_id: &TypeID,
) -> (NodeID, TypeID) {
let nonempty_leaves = get_nonempty_leaves(&edit, &type_id);
let compressed_type = Type::Product(nonempty_leaves.into_boxed_slice());
let compressed_type_id = edit.add_type(compressed_type);
let empty_compressed_constant_id = edit.add_zero_constant(compressed_type_id);
let empty_compressed_node_id = edit.add_node(Node::Constant {
id: empty_compressed_constant_id,
});
let (compressed_value, _) =
build_compressed_product(edit, &empty_compressed_node_id, type_id, 0, node_id);
(compressed_value, compressed_type_id)
}
fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) {
// Track whether we successfully applied edits to return statements,
// so that callsites are only modified when their callee's return was.
// This is initialized to false, so that `is_compressed` stays false when
// the corresponding entry in `all_callsites_editable` is false.
let mut is_compressed = vec![false; editors.len()];
let old_return_type_ids: Vec<_> = editors
.iter()
.map(|editor| editor.func().return_type)
.collect();
// Step 1. Modify the return type of all editors corresponding to a function
// for which we can edit every callsite, and the return type is a product.
for (idx, editor) in editors.iter_mut().enumerate() {
if !all_callsites_editable[idx] {
continue;
}
let old_return_id = NodeID::new(
(0..editor.func().nodes.len())
.filter(|idx| editor.func().nodes[*idx].is_return())
.next()
.unwrap(),
);
let old_return_type_id = old_return_type_ids[idx];
is_compressed[idx] = editor.get_type(editor.func().return_type).is_product()
&& editor.edit(|mut edit| {
let return_node = edit.get_node(old_return_id);
let (return_control, return_data) = return_node.try_return().unwrap();
let (compressed_data_id, compressed_type_id) =
compress_product(&mut edit, &return_data, &old_return_type_id);
edit.set_return_type(compressed_type_id);
let new_return_id = edit.add_node(Node::Return {
control: return_control,
data: compressed_data_id,
});
let edit = edit.replace_all_uses(old_return_id, new_return_id)?;
edit.delete_node(old_return_id)
});
}
// Step 2: For every editor, update all mutable callsites corresponding to
// calls to functions which have been compressed. Since we only compress returns
// for functions for which every callsite is mutable, this should never fail,
// so we panic if it does.
for editor in editors.iter_mut() {
let call_node_ids: Vec<_> = (0..editor.func().nodes.len())
.map(NodeID::new)
.filter(|id| editor.func().nodes[id.idx()].is_call())
.filter(|id| editor.is_mutable(*id))
.collect();
for call_node_id in call_node_ids {
let (_, function_id, _, _) =
editor.func().nodes[call_node_id.idx()].try_call().unwrap();
if !is_compressed[function_id.idx()] {
continue;
}
let edit_successful = editor.edit(|mut edit| {
let (expanded_product, readers) = uncompress_product(
&mut edit,
&call_node_id,
&old_return_type_ids[function_id.idx()],
);
edit.replace_all_uses_where(call_node_id, expanded_product, |id| {
!readers.contains(id)
})
});
if !edit_successful {
panic!("Tried and failed to edit mutable callsite!");
}
}
}
}
fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) {
// Track whether we removed a singleton product from the return of each
// editor's function. Defaults to false so that if the function was not
// edited (i.e. because not all callsites are editable), then no callsites
// will be edited.
let mut singleton_removed = vec![false; editors.len()];
let old_return_type_ids: Vec<_> = editors
.iter()
.map(|editor| editor.func().return_type)
.collect();
// Step 1. For all editors which correspond to a function for which all
// callsites are editable, modify their return type by extracting the
// value from the singleton and returning it directly.
for (idx, editor) in editors.iter_mut().enumerate() {
if !all_callsites_editable[idx] {
continue;
}
let return_type = editor.get_type(old_return_type_ids[idx]).clone();
singleton_removed[idx] = match return_type {
Type::Product(tys) if tys.len() == 1 && all_callsites_editable[idx] => {
let old_return_id = NodeID::new(
(0..editor.func().nodes.len())
.filter(|idx| editor.func().nodes[*idx].is_return())
.next()
.unwrap(),
);
editor.edit(|mut edit| {
let (old_control, old_data) =
edit.get_node(old_return_id).try_return().unwrap();
let extracted_singleton_id = edit.add_node(Node::Read {
collect: old_data,
indices: Box::new([Index::Field(0)]),
});
edit.add_node(Node::Return {
control: old_control,
data: extracted_singleton_id,
});
edit.set_return_type(tys[0]);
edit.delete_node(old_return_id)
})
}
_ => false,
}
}
// Step 2. For each editor, find all callsites and reconstruct
// the singleton product at each one whose callee's return was
// modified. This should always succeed since we only
// edited functions for which all callsites were mutable, so panic
// if an edit does not succeed.
for editor in editors.iter_mut() {
let call_node_ids: Vec<_> = (0..editor.func().nodes.len())
.map(NodeID::new)
.filter(|id| editor.func().nodes[id.idx()].is_call())
.filter(|id| editor.is_mutable(*id))
.collect();
for call_node_id in call_node_ids {
let (_, function, _, _) = editor.func().nodes[call_node_id.idx()].try_call().unwrap();
if singleton_removed[function.idx()] {
let edit_successful = editor.edit(|mut edit| {
let empty_constant_id =
edit.add_zero_constant(old_return_type_ids[function.idx()]);
let empty_node_id = edit.add_node(Node::Constant {
id: empty_constant_id,
});
let restored_singleton_id = edit.add_node(Node::Write {
collect: empty_node_id,
data: call_node_id,
indices: Box::new([Index::Field(0)]),
});
edit.replace_all_uses_where(call_node_id, restored_singleton_id, |id| {
*id != restored_singleton_id
})
});
if !edit_successful {
panic!("Tried and failed to edit mutable callsite!");
}
}
}
}
}
pub fn interprocedural_sroa(editors: &mut Vec<FunctionEditor>) {
// SROA is implemented in two phases. First, we flatten (or "compress")
// all product return types, so that they are only depth-1 products
// and do not contain any empty products.
// Next, if any return type is now a singleton product, we
// remove the singleton and just return the element directly.
// We only apply these changes to functions for which
// all callsites are editable.
let all_callsites_editable = get_editable_callsites(editors);
compress_return_products(editors, &all_callsites_editable);
remove_return_singletons(editors, &all_callsites_editable);
// Run DCE to prevent issues with schedule repair.
for editor in editors.iter_mut() {
dce(editor);
}
}
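Aside: a hedged source-level illustration of the pass's net effect, written with invented function names and ordinary Rust tuples standing in for Hercules products. Compression flattens a nested return into its nonempty leaves, singleton removal unwraps a one-element product, and each callsite locally rebuilds the shape its users expect.

// Before: nested product return with an empty component.
fn callee_before() -> ((i32, ()), f32) {
    ((1, ()), 2.0)
}

// After compression: only the nonempty leaves, as a flat product.
fn callee_compressed() -> (i32, f32) {
    (1, 2.0)
}

// A singleton product return ...
fn singleton_before() -> (i64,) {
    (42,)
}

// ... is unwrapped entirely.
fn singleton_after() -> i64 {
    42
}

fn main() {
    // Callers reconstruct the original shape from the new return value,
    // mirroring uncompress_product and the singleton rebuild above.
    let (a, b) = callee_compressed();
    let rebuilt: ((i32, ()), f32) = ((a, ()), b);
    assert_eq!(rebuilt, callee_before());
    let rebuilt_singleton: (i64,) = (singleton_after(),);
    assert_eq!(rebuilt_singleton, singleton_before());
}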
......@@ -6,6 +6,7 @@ pub mod editor;
pub mod fork_guard_elim;
pub mod forkify;
pub mod gvn;
pub mod interprocedural_sroa;
pub mod inline;
pub mod pass;
pub mod phi_elim;
......@@ -18,6 +19,7 @@ pub use crate::editor::*;
pub use crate::fork_guard_elim::*;
pub use crate::forkify::*;
pub use crate::gvn::*;
pub use crate::interprocedural_sroa::*;
pub use crate::inline::*;
pub use crate::pass::*;
pub use crate::phi_elim::*;
......
......@@ -42,6 +42,7 @@ pub enum Pass {
Codegen(String, String),
// Parameterized over where to serialize module to.
Serialize(String),
InterproceduralSROA,
}
/*
......@@ -373,6 +374,54 @@ impl PassManager {
}
self.clear_analyses();
}
Pass::InterproceduralSROA => {
self.make_def_uses();
let mut plans = self.plans.as_mut();
let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants));
let dynamic_constants_ref =
RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
let def_uses = self.def_uses.as_ref().unwrap();
let mut editors: Vec<_> = self
.module
.functions
.iter_mut()
.enumerate()
.map(|(i, f)| {
FunctionEditor::new(
f,
&constants_ref,
&dynamic_constants_ref,
&types_ref,
&def_uses[i],
)
})
.collect();
interprocedural_sroa(&mut editors);
let function_edits: Vec<_> =
editors.into_iter().map(|e| e.edits()).enumerate().collect();
self.module.constants = constants_ref.take();
self.module.dynamic_constants = dynamic_constants_ref.take();
self.module.types = types_ref.take();
for (idx, edits) in function_edits {
if let Some(plans) = plans.as_mut() {
repair_plan(&mut plans[idx], &self.module.functions[idx], &edits);
}
let grave_mapping = &self.module.functions[idx].delete_gravestones();
if let Some(plans) = plans.as_mut() {
plans[idx].fix_gravestones(&grave_mapping);
}
}
self.clear_analyses();
}
Pass::CCP => {
self.make_def_uses();
self.make_reverse_postorders();
......
......@@ -156,7 +156,7 @@ pub fn compile_ir(
add_pass!(pm, verify, DCE);
add_pass!(pm, verify, GVN);
add_pass!(pm, verify, DCE);
//add_pass!(pm, verify, SROA);
add_pass!(pm, verify, InterproceduralSROA);
if x_dot {
pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
}
......
fn simple1_inner(x : i32, y : i32) -> i32 {
fn simple1(x : i32, y : i32) -> i32 {
return x + y;
}
fn simple(x : i32, y : i32) -> i32 {
return simple1_inner(x, y);
}
fn simple1(x : i32, y : i32) -> i32 {
return x + y;
}
fn caller(x: i32) -> i32 {
return simple1(x, 3);
}
......@@ -11,8 +11,8 @@ fn main() {
let mut a = vec![1, 2, 3, 4, 5, 6, 7, 8];
let mut b = vec![8, 7, 6, 5, 4, 3, 2, 1];
let c = unsafe { simple3(a.as_mut_ptr(), b.as_mut_ptr(), 8).await };
println!("{:?}", c);
assert_eq!(c.0, 120);
println!("{}", c);
assert_eq!(c, 120);
});
}
......