From 4606a2a7219ef33437a8995526c873ab21016980 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 29 Dec 2024 15:07:09 -0800 Subject: [PATCH] Refactor object analysis --- hercules_cg/src/lib.rs | 25 ++- hercules_cg/src/mem.rs | 291 --------------------------- hercules_cg/src/rt.rs | 153 +++++++-------- hercules_ir/src/collections.rs | 348 +++++++++++++++++++++++++++++++++ hercules_ir/src/lib.rs | 2 + hercules_opt/src/pass.rs | 37 ++-- 6 files changed, 466 insertions(+), 390 deletions(-) delete mode 100644 hercules_cg/src/mem.rs create mode 100644 hercules_ir/src/collections.rs diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 952ce368..c579b7e9 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -2,10 +2,31 @@ pub mod cpu; pub mod device; -pub mod mem; pub mod rt; pub use crate::cpu::*; pub use crate::device::*; -pub use crate::mem::*; pub use crate::rt::*; + +extern crate hercules_ir; + +use self::hercules_ir::*; + +/* + * The alignment of a type does not depend on dynamic constants. 
+ */ +pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { + match types[ty.idx()] { + Type::Control => panic!(), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1, + Type::Integer16 | Type::UnsignedInteger16 => 2, + Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, + Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8, + Type::Product(ref members) | Type::Summation(ref members) => members + .into_iter() + .map(|id| get_type_alignment(types, *id)) + .max() + .unwrap_or(1), + Type::Array(elem, _) => get_type_alignment(types, elem), + } +} diff --git a/hercules_cg/src/mem.rs b/hercules_cg/src/mem.rs deleted file mode 100644 index 0c053455..00000000 --- a/hercules_cg/src/mem.rs +++ /dev/null @@ -1,291 +0,0 @@ -extern crate bitvec; -extern crate hercules_ir; - -use std::collections::{BTreeMap, BTreeSet}; - -use self::bitvec::prelude::*; - -use self::hercules_ir::*; - -#[derive(Debug)] -pub struct MemoryObjects { - node_id_to_memory_objects: Vec<Vec<usize>>, - memory_object_to_origin: Vec<NodeID>, - parameter_index_to_memory_object: Vec<Option<usize>>, - possibly_returned_memory_objects: Vec<usize>, -} - -impl MemoryObjects { - pub fn memory_objects(&self, id: NodeID) -> &Vec<usize> { - &self.node_id_to_memory_objects[id.idx()] - } - - pub fn origin(&self, memory_object: usize) -> NodeID { - self.memory_object_to_origin[memory_object] - } - - pub fn memory_object_of_parameter(&self, parameter: usize) -> Option<usize> { - self.parameter_index_to_memory_object[parameter] - } - - pub fn returned_memory_objects(&self) -> &Vec<usize> { - &self.possibly_returned_memory_objects - } - - pub fn num_memory_objects(&self) -> usize { - self.memory_object_to_origin.len() - } -} - -#[derive(Debug)] -pub struct MemoryObjectsMutability { - func_to_memory_object_to_mutable: Vec<BitVec<u8, Lsb0>>, -} - -impl MemoryObjectsMutability { - pub fn is_mutable(&self, id: FunctionID, memory_object: usize) -> bool { - 
self.func_to_memory_object_to_mutable[id.idx()][memory_object] - } -} - -/* - * Each node is assigned a set of memory objects output-ed from the node. This - * is just a set of memory object IDs (usize). - */ -#[derive(PartialEq, Eq, Clone, Debug)] -struct MemoryObjectLattice { - objs: BTreeSet<usize>, -} - -impl Semilattice for MemoryObjectLattice { - fn meet(a: &Self, b: &Self) -> Self { - MemoryObjectLattice { - objs: a.objs.union(&b.objs).map(|x| *x).collect(), - } - } - - fn top() -> Self { - MemoryObjectLattice { - objs: BTreeSet::new(), - } - } - - fn bottom() -> Self { - // Technically, this lattice is unbounded - technically technically, the - // lattice is bounded by the number of memory objects in a given - // instance, but incorporating this information is not possible in our - // Semilattice inferface. Luckily bottom() isn't necessary if we never - // call it, which we don't here. - panic!() - } -} - -/* - * Top level function to analyze memory objects in a Hercules function. These - * are distinct collections (products, summations, arrays) that are used in a - * function where we try to disambiguate a string of values produced in the - * immutable value semantics of Hercules IR into a smaller amount of distinct - * memory object that can be modified in-place. - */ -pub fn memory_objects( - function: &Function, - types: &Vec<Type>, - reverse_postorder: &Vec<NodeID>, - typing: &Vec<TypeID>, -) -> MemoryObjects { - // Find memory objects originating at parameters, constants, calls, or - // undefs. 
- let memory_object_to_origin: Vec<_> = function - .nodes - .iter() - .enumerate() - .filter(|(idx, node)| { - (node.is_parameter() || node.is_constant() || node.is_call() || node.is_undef()) - && !types[typing[*idx].idx()].is_primitive() - }) - .map(|(idx, _)| NodeID::new(idx)) - .collect(); - let node_id_to_originating_memory_obj: BTreeMap<_, _> = memory_object_to_origin - .iter() - .enumerate() - .map(|(idx, id)| (*id, idx)) - .collect(); - - // Map parameter index to memory object, if applicable. Panic if two - // parameter nodes with the same index are found - those really should get - // removed by GVN! - let mut parameter_index_to_memory_object = vec![None; function.param_types.len()]; - for (memory_object, origin) in memory_object_to_origin.iter().enumerate() { - if let Some(param) = function.nodes[origin.idx()].try_parameter() { - assert!( - parameter_index_to_memory_object[param].is_none(), - "PANIC: Found multiple parameter nodes with the same index." - ); - parameter_index_to_memory_object[param] = Some(memory_object); - } - } - - // Run dataflow analysis to figure out which memory objects each data node - // may be. Note that there's a strict subset of data nodes that can assigned - // memory objects: - // - // - Phi: selects between memory objects in SSA form, may be assigned - // multiple possible memory objects. - // - Reduce: reduces over a memory object, similar to phis. - // - Parameter: may originate a memory object. - // - Constant: may originate a memory object. - // - Call: may originate a memory object - if doesn't originate a memory - // object, doesn't become one based on arguments, as arguments are passed - // to callee. - // - Read: may extract a smaller memory object from input - this is - // considered to be the same memory object as the input, as no copy takes - // place. - // - Write: updates a memory object. - // - Undef: may originate a dummy memory object. 
- // - // Some notable omissions are: - // - // - Return: doesn't technically "output" a memory object, but may consume - // one. As in the logic with calls not returning a memory object, returns - // are not assigned memory objects. - // - Ternary (select): selecting over memory objects is a gray area - // currently. Bail if we see a select over memory objects. - assert!(!function.nodes.iter().enumerate().any(|(idx, node)| node - .try_ternary(TernaryOperator::Select) - .is_some() - && !types[typing[idx].idx()].is_primitive())); - let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| { - match function.nodes[id.idx()] { - Node::Phi { - control: _, - data: _, - } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } => inputs - .into_iter() - .fold(MemoryObjectLattice::top(), |acc, input| { - MemoryObjectLattice::meet(&acc, input) - }), - Node::Parameter { index: _ } - | Node::Constant { id: _ } - | Node::Call { - control: _, - function: _, - dynamic_constants: _, - args: _, - } - | Node::Undef { ty: _ } - if let Some(obj) = node_id_to_originating_memory_obj.get(&id) => - { - MemoryObjectLattice { - objs: [*obj].iter().map(|x| *x).collect(), - } - } - Node::Read { - collect: _, - indices: _, - } - | Node::Write { - collect: _, - data: _, - indices: _, - } => inputs[0].clone(), - _ => MemoryObjectLattice::top(), - } - }); - - // Look at the memory objects the data input to each return could be. 
- let mut possibly_returned_memory_objects = BTreeSet::new(); - for node in function.nodes.iter() { - if let Node::Return { control: _, data } = node { - possibly_returned_memory_objects = possibly_returned_memory_objects - .union(&lattice[data.idx()].objs) - .map(|x| *x) - .collect(); - } - } - let possibly_returned_memory_objects = possibly_returned_memory_objects.into_iter().collect(); - - let node_id_to_memory_objects = lattice - .into_iter() - .map(|lattice| lattice.objs.into_iter().collect()) - .collect(); - MemoryObjects { - node_id_to_memory_objects, - memory_object_to_origin, - parameter_index_to_memory_object, - possibly_returned_memory_objects, - } -} - -/* - * Determine if each memory object in each function is mutated or not. - */ -pub fn memory_objects_mutability( - module: &Module, - callgraph: &CallGraph, - memory_objects: &Vec<MemoryObjects>, -) -> MemoryObjectsMutability { - let mut mutated: Vec<_> = memory_objects - .iter() - .map(|memory_objects| bitvec![u8, Lsb0; 0; memory_objects.num_memory_objects()]) - .collect(); - let topo = callgraph.topo(); - - for func_id in topo { - // A memory object is mutated when: - // 1. The object is the subject of a write node. - // 2. The object is passed as argument to a function that mutates it. - for (idx, node) in module.functions[func_id.idx()].nodes.iter().enumerate() { - if node.is_write() { - // Every memory object that the write itself corresponds to it - // mutable in this function. - for memory_object in memory_objects[func_id.idx()].memory_objects(NodeID::new(idx)) - { - mutated[func_id.idx()].set(*memory_object, true); - } - } else if let Some((_, callee_id, _, args)) = node.try_call() { - for (param_idx, arg) in args.into_iter().enumerate() { - // If this parameter corresponds to a memory object and it's - // mutable in the callee... 
- if let Some(param_callee_memory_object) = - memory_objects[callee_id.idx()].memory_object_of_parameter(param_idx) - && mutated[callee_id.idx()][param_callee_memory_object] - { - // Then every memory object corresponding to the - // argument node in this function is mutable. - for memory_object in memory_objects[func_id.idx()].memory_objects(*arg) { - mutated[func_id.idx()].set(*memory_object, true); - } - } - } - } - } - } - - MemoryObjectsMutability { - func_to_memory_object_to_mutable: mutated, - } -} - -/* - * The alignment of a type does not depend on dynamic constants. - */ -pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { - match types[ty.idx()] { - Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1, - Type::Integer16 | Type::UnsignedInteger16 => 2, - Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, - Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8, - Type::Product(ref members) | Type::Summation(ref members) => members - .into_iter() - .map(|id| get_type_alignment(types, *id)) - .max() - .unwrap_or(1), - Type::Array(elem, _) => get_type_alignment(types, elem), - } -} diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 305ecf9b..b5623d1a 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -23,10 +23,9 @@ pub fn rt_codegen<W: Write>( typing: &Vec<TypeID>, control_subgraph: &Subgraph, bbs: &Vec<NodeID>, + collection_objects: &CollectionObjects, callgraph: &CallGraph, devices: &Vec<Device>, - memory_objects: &Vec<MemoryObjects>, - memory_objects_mutability: &MemoryObjectsMutability, w: &mut W, ) -> Result<(), Error> { let ctx = RTContext { @@ -36,10 +35,9 @@ pub fn rt_codegen<W: Write>( typing, control_subgraph, bbs, + collection_objects, callgraph, devices, - memory_objects, - _memory_objects_mutability: memory_objects_mutability, }; ctx.codegen_function(w) } @@ -51,12 +49,9 @@ struct RTContext<'a> { typing: &'a Vec<TypeID>, control_subgraph: &'a 
Subgraph, bbs: &'a Vec<NodeID>, + collection_objects: &'a CollectionObjects, callgraph: &'a CallGraph, devices: &'a Vec<Device>, - memory_objects: &'a Vec<MemoryObjects>, - // TODO: use once memory objects are passed in a custom type where this - // actually matters. - _memory_objects_mutability: &'a MemoryObjectsMutability, } impl<'a> RTContext<'a> { @@ -66,7 +61,7 @@ impl<'a> RTContext<'a> { // Dump the function signature. write!( w, - "#[allow(unused_variables,unused_mut)]async fn {}(", + "#[allow(unused_variables,unused_mut)]\nasync fn {}(", func.name )?; let mut first_param = true; @@ -99,10 +94,10 @@ impl<'a> RTContext<'a> { write!(w, ") -> {} {{\n", self.get_type_interface(func.return_type))?; // Copy the "interface" parameters to "non-interface" parameters. - // The purpose of this is to convert memory objects from a Box<[u8]> + // The purpose of this is to convert collection objects from a Box<[u8]> // type to a *mut u8 type. This name copying is done so that we can - // easily construct memory objects just after this by moving the - // "interface" parameters. + // easily construct objects just after this by moving the "interface" + // parameters. for (idx, ty) in func.param_types.iter().enumerate() { if self.module.types[ty.idx()].is_primitive() { write!(w, " let p{} = p_i{};\n", idx, idx)?; @@ -115,45 +110,42 @@ impl<'a> RTContext<'a> { } } - // Collect the boxes representing ownership over memory objects for this - // function. The actual emitted computation is done entirely using + // Collect the boxes representing ownership over collection objects for + // this function. The actual emitted computation is done entirely using // pointers, so these get emitted to hold onto ownership over the // underlying memory and to automatically clean them up when this - // function returns. Memory objects are inside Options, since their + // function returns. Collection objects are inside Options, since their // ownership may get passed to other called RT functions. 
If this - // function returns a memory object, then at the very end, right before - // the return, the to-be-returned pointer is compared against the owned - // memory objects - it should match exactly one of those objects, and - // that box is what's actually returned. + // function returns a collection object, then at the very end, right + // before the return, the to-be-returned pointer is compared against the + // owned collection objects - it should match exactly one of those + // objects, and that box is what's actually returned. let mem_obj_ty = "::core::option::Option<::std::boxed::Box<[u8]>>"; - for memory_object in 0..self.memory_objects[self.func_id.idx()].num_memory_objects() { - let origin = self.memory_objects[self.func_id.idx()].origin(memory_object); - match func.nodes[origin.idx()] { - Node::Parameter { index } => write!( + for object in self.collection_objects[&self.func_id].iter_objects() { + match self.collection_objects[&self.func_id].origin(object) { + CollectionObjectOrigin::Parameter(index) => write!( w, - " let mut mem_obj{}: {} = Some(p_i{});\n", - memory_object, mem_obj_ty, index + " let mut obj{}: {} = Some(p_i{});\n", + object.idx(), + mem_obj_ty, + index )?, - Node::Constant { id: _ } => { - let size = self.codegen_type_size(self.typing[origin.idx()]); + CollectionObjectOrigin::Constant(id) => { + let size = self.codegen_type_size(self.typing[id.idx()]); write!( w, - " let mut mem_obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n", - memory_object, mem_obj_ty, size + " let mut obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n", + object.idx(), + mem_obj_ty, + size )? 
} - Node::Call { - control: _, - function: _, - dynamic_constants: _, - args: _, - } - | Node::Undef { ty: _ } => write!( + CollectionObjectOrigin::Call(_) | CollectionObjectOrigin::Undef(_) => write!( w, - " let mut mem_obj{}: {} = None;\n", - memory_object, mem_obj_ty, + " let mut obj{}: {} = None;\n", + object.idx(), + mem_obj_ty, )?, - _ => panic!(), } } @@ -308,19 +300,19 @@ impl<'a> RTContext<'a> { } Node::Return { control: _, data } => { let block = &mut blocks.get_mut(&id).unwrap(); - let memory_objects = self.memory_objects[self.func_id.idx()].memory_objects(data); - if memory_objects.is_empty() { + let objects = self.collection_objects[&self.func_id].objects(data); + if objects.is_empty() { write!(block, " return {};\n", self.get_value(data))? } else { - // If the value to return is a memory object, figure out - // which memory object it actually is at runtime and return - // that box. - for memory_object in memory_objects { - write!(block, " if let Some(mut mem_obj) = mem_obj{} && ::std::boxed::Box::as_mut_ptr(&mut mem_obj) as *mut u8 == {} {{\n", memory_object, self.get_value(data))?; - write!(block, " return mem_obj;\n")?; + // If the value to return is a collection object, figure out + // which object it actually is at runtime and return that + // box. + for object in objects { + write!(block, " if let Some(mut obj) = obj{} && ::std::boxed::Box::as_mut_ptr(&mut obj) as *mut u8 == {} {{\n", object.idx(), self.get_value(data))?; + write!(block, " return obj;\n")?; write!(block, " }}\n")?; } - write!(block, " panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known memory objects.\");\n")? + write!(block, " panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known collection objects.\");\n")? 
} } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), @@ -363,14 +355,13 @@ impl<'a> RTContext<'a> { Constant::Float32(val) => write!(block, "{}f32", val)?, Constant::Float64(val) => write!(block, "{}f64", val)?, Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => { - let memory_objects = - self.memory_objects[self.func_id.idx()].memory_objects(id); - assert_eq!(memory_objects.len(), 1); - let memory_object = memory_objects[0]; + let objects = self.collection_objects[&self.func_id].objects(id); + assert_eq!(objects.len(), 1); + let object = objects[0]; write!( block, - "::std::boxed::Box::as_mut_ptr(mem_obj{}.as_mut().unwrap()) as *mut u8", - memory_object + "::std::boxed::Box::as_mut_ptr(obj{}.as_mut().unwrap()) as *mut u8", + object.idx() )? } } @@ -400,43 +391,40 @@ impl<'a> RTContext<'a> { } write!(block, ") }};\n")?; - // When a CPU function is called that returns a memory - // object, that memory object must have come from one of - // its parameters. Dynamically figure out which one it - // came from, so that we can move it to the slot of the - // output memory object. - let call_memory_objects = - self.memory_objects[self.func_id.idx()].memory_objects(id); - if !call_memory_objects.is_empty() { - assert_eq!(call_memory_objects.len(), 1); - let call_memory_object = call_memory_objects[0]; - - let callee_returned_memory_objects = - self.memory_objects[callee_id.idx()].returned_memory_objects(); + // When a CPU function is called that returns a + // collection object, that object must have come from + // one of its parameters. Dynamically figure out which + // one it came from, so that we can move it to the slot + // of the output object. 
+ let call_objects = self.collection_objects[&self.func_id].objects(id); + if !call_objects.is_empty() { + assert_eq!(call_objects.len(), 1); + let call_object = call_objects[0]; + + let callee_returned_objects = + self.collection_objects[&callee_id].returned_objects(); let possible_params: Vec<_> = (0..self.module.functions[callee_id.idx()].param_types.len()) .filter(|idx| { - let memory_object_of_param = self.memory_objects - [callee_id.idx()] - .memory_object_of_parameter(*idx); + let object_of_param = self.collection_objects[&callee_id] + .param_to_object(*idx); // Look at parameters that could be the // source of the memory object returned // by the function. - memory_object_of_param - .map(|memory_object_of_param| { - callee_returned_memory_objects - .contains(&memory_object_of_param) + object_of_param + .map(|object_of_param| { + callee_returned_objects.contains(&object_of_param) }) .unwrap_or(false) }) .collect(); - let arg_memory_objects = args + let arg_objects = args .into_iter() .enumerate() .filter(|(idx, _)| possible_params.contains(idx)) .map(|(_, arg)| { - self.memory_objects[self.func_id.idx()] - .memory_objects(*arg) + self.collection_objects[&self.func_id] + .objects(*arg) .into_iter() }) .flatten(); @@ -446,23 +434,24 @@ impl<'a> RTContext<'a> { // returned by the call. Move that memory object // into the memory object of the call. 
let mut first_obj = true; - for arg_memory_object in arg_memory_objects { + for arg_object in arg_objects { write!(block, " ")?; if first_obj { first_obj = false; } else { write!(block, "else ")?; } - write!(block, "if let Some(mem_obj) = mem_obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(mem_obj) as *mut u8 == {} {{\n", arg_memory_object, self.get_value(id))?; + write!(block, "if let Some(obj) = obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(obj) as *mut u8 == {} {{\n", arg_object.idx(), self.get_value(id))?; write!( block, - " mem_obj{} = mem_obj{}.take();\n", - call_memory_object, arg_memory_object + " obj{} = obj{}.take();\n", + call_object.idx(), + arg_object.idx() )?; write!(block, " }}\n")?; } write!(block, " else {{\n")?; - write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known memory objects.\");\n")?; + write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known collection objects.\");\n")?; write!(block, " }}\n")?; } } diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs new file mode 100644 index 00000000..c3dafe7a --- /dev/null +++ b/hercules_ir/src/collections.rs @@ -0,0 +1,348 @@ +extern crate bitvec; + +use std::collections::{BTreeMap, BTreeSet}; + +use self::bitvec::prelude::*; + +use crate::*; + +/* + * Analysis result that finds "collection objects" in Hercules IR. This analysis + * is inter-procedural, since collection objects can be passed / returned to / + * from called functions. This analysis also tracks which collection objects are + * mutated "in" a function - a collection object is mutated in a function if + * that function contains a write node that may write to that collection object + * or if that function contains a call node that may take that collection object + * as an argument and that collection parameter of that function is mutated. 
+ * Collection objects are numbered locally - the following nodes may originate a + * collection object: + * + * - Parameter: each parameter index gets assigned a single collection object, + * each parameter node gets assigned the object of its index. + * - Constant: each collection constant node gets assigned a single collection + * object. + * - Call: each function is analyzed to determine which collection objects (of + * its parameters or an object it originates) may be returned; a call node + * originates a new collection object if it may return an object originated + * inside the callee. + * - Undef: each undef node with a non-primitive type gets assigned a single + * collection object. + * + * The analysis contains the following information: + * + * - For each node in each function, which collection objects may be on the + * output of the node? + * - For each function, which collection objects may be mutated inside that + * function? + * - For each function, which collection objects may be returned? + * - For each collection object, how was it originated? 
+ */ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CollectionObjectOrigin { + Parameter(usize), + Constant(NodeID), + Call(NodeID), + Undef(NodeID), +} + +define_id_type!(CollectionObjectID); + +#[derive(Debug, Clone)] +pub struct FunctionCollectionObjects { + objects_per_node: Vec<Vec<CollectionObjectID>>, + mutated: BitVec<u8, Lsb0>, + returned: Vec<CollectionObjectID>, + origins: Vec<CollectionObjectOrigin>, +} + +pub type CollectionObjects = BTreeMap<FunctionID, FunctionCollectionObjects>; + +impl CollectionObjectOrigin { + fn try_parameter(&self) -> Option<usize> { + match self { + CollectionObjectOrigin::Parameter(index) => Some(*index), + _ => None, + } + } +} + +impl FunctionCollectionObjects { + pub fn objects(&self, id: NodeID) -> &Vec<CollectionObjectID> { + &self.objects_per_node[id.idx()] + } + + pub fn origin(&self, object: CollectionObjectID) -> CollectionObjectOrigin { + self.origins[object.idx()] + } + + pub fn param_to_object(&self, index: usize) -> Option<CollectionObjectID> { + self.origins + .iter() + .position(|origin| *origin == CollectionObjectOrigin::Parameter(index)) + .map(CollectionObjectID::new) + } + + pub fn returned_objects(&self) -> &Vec<CollectionObjectID> { + &self.returned + } + + pub fn is_mutated(&self, object: CollectionObjectID) -> bool { + self.mutated[object.idx()] + } + + pub fn num_objects(&self) -> usize { + self.origins.len() + } + + pub fn iter_objects(&self) -> impl Iterator<Item = CollectionObjectID> { + (0..self.num_objects()).map(CollectionObjectID::new) + } +} + +/* + * Each node is assigned a set of collection objects output-ed from the node. + * This is just a set of collection object IDs (usize). 
+ */ +#[derive(PartialEq, Eq, Clone, Debug)] +struct CollectionObjectLattice { + objs: BTreeSet<CollectionObjectID>, +} + +impl Semilattice for CollectionObjectLattice { + fn meet(a: &Self, b: &Self) -> Self { + CollectionObjectLattice { + objs: a.objs.union(&b.objs).map(|x| *x).collect(), + } + } + + fn top() -> Self { + CollectionObjectLattice { + objs: BTreeSet::new(), + } + } + + fn bottom() -> Self { + // Technically, this lattice is unbounded - technically technically, the + // lattice is bounded by the number of collection objects in a given + // function, but incorporating this information is not possible in our + // Semilattice interface. Luckily bottom() isn't necessary if we never + // call it, which we don't for this analysis. + panic!() + } +} + +/* + * Top level function to analyze collection objects in a Hercules module. + */ +pub fn collection_objects( + module: &Module, + reverse_postorders: &Vec<Vec<NodeID>>, + typing: &ModuleTyping, + callgraph: &CallGraph, +) -> CollectionObjects { + // Analyze functions in reverse topological order, since the analysis of a + // function depends on all functions it calls. + let mut collection_objects: CollectionObjects = BTreeMap::new(); + let topo = callgraph.topo(); + + for func_id in topo { + let func = &module.functions[func_id.idx()]; + let typing = &typing[func_id.idx()]; + let reverse_postorder = &reverse_postorders[func_id.idx()]; + + // Find collection objects originating at parameters, constants, calls, + // or undefs. Each node may *originate* one collection object. 
+ let param_origins = func + .param_types + .iter() + .enumerate() + .filter(|(_, ty_id)| !module.types[ty_id.idx()].is_primitive()) + .map(|(idx, _)| CollectionObjectOrigin::Parameter(idx)); + let other_origins = func + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| match node { + Node::Constant { id: _ } if !module.types[typing[idx].idx()].is_primitive() => { + Some(CollectionObjectOrigin::Constant(NodeID::new(idx))) + } + Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args: _, + } if { + let fco = &collection_objects[&callee]; + fco.returned + .iter() + .any(|returned| fco.origins[returned.idx()].try_parameter().is_none()) + } => + { + // If the callee may return a new collection object, then + // this call node originates a single collection object. The + // node may output multiple collection objects, say if the + // callee may return an object passed in as a parameter - + // this is determined later. + Some(CollectionObjectOrigin::Call(NodeID::new(idx))) + } + Node::Undef { ty: _ } if !module.types[typing[idx].idx()].is_primitive() => { + Some(CollectionObjectOrigin::Undef(NodeID::new(idx))) + } + _ => None, + }); + let origins: Vec<_> = param_origins.chain(other_origins).collect(); + + // Run dataflow analysis to figure out which collection objects each + // data node may output. Note that there's a strict subset of data nodes + // that can output collection objects: + // + // - Phi: selects between objects in SSA form, may be assigned multiple + // possible objects. + // - Reduce: reduces over an object, similar to phis. + // - Parameter: may originate an object. + // - Constant: may originate an object. + // - Call: may originate an object and may return an object passed in as + // a parameter. + // - Read: may extract a smaller object from the input - this is + // considered to be the same object as the input, as no copy takes + // place. 
+ // - Write: updates an object - this is considered to be the same object + // as the input object, as the write gets lowered to an in-place + // mutation. + // - Undef: may originate a dummy object. + // - Ternary (select): selects between two objects, may output either. + let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| { + match func.nodes[id.idx()] { + Node::Phi { + control: _, + data: _, + } + | Node::Reduce { + control: _, + init: _, + reduct: _, + } + | Node::Ternary { + op: TernaryOperator::Select, + first: _, + second: _, + third: _, + } => inputs + .into_iter() + .fold(CollectionObjectLattice::top(), |acc, input| { + CollectionObjectLattice::meet(&acc, input) + }), + Node::Parameter { index } => { + let obj = origins + .iter() + .position(|origin| *origin == CollectionObjectOrigin::Parameter(index)) + .map(CollectionObjectID::new); + CollectionObjectLattice { + objs: obj.into_iter().collect(), + } + } + Node::Constant { id: _ } => { + let obj = origins + .iter() + .position(|origin| *origin == CollectionObjectOrigin::Constant(id)) + .map(CollectionObjectID::new); + CollectionObjectLattice { + objs: obj.into_iter().collect(), + } + } + Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args: _, + } if !module.types[typing[id.idx()].idx()].is_primitive() => { + let new_obj = origins + .iter() + .position(|origin| *origin == CollectionObjectOrigin::Call(id)) + .map(CollectionObjectID::new); + let fco = &collection_objects[&callee]; + let param_objs = fco + .returned + .iter() + .filter_map(|returned| fco.origins[returned.idx()].try_parameter()) + .map(|param_index| inputs[param_index + 1]); + + let mut objs: BTreeSet<_> = new_obj.into_iter().collect(); + for param_objs in param_objs { + objs.extend(&param_objs.objs); + } + CollectionObjectLattice { objs } + } + Node::Undef { ty: _ } => { + let obj = origins + .iter() + .position(|origin| *origin == CollectionObjectOrigin::Undef(id)) + .map(CollectionObjectID::new); +
CollectionObjectLattice { + objs: obj.into_iter().collect(), + } + } + Node::Read { + collect: _, + indices: _, + } + | Node::Write { + collect: _, + data: _, + indices: _, + } => inputs[0].clone(), + _ => CollectionObjectLattice::top(), + } + }); + let objects_per_node: Vec<Vec<_>> = lattice + .into_iter() + .map(|l| l.objs.into_iter().collect()) + .collect(); + + // Look at the collection objects that each return may take as input. + let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new(); + for node in func.nodes.iter() { + if let Node::Return { control: _, data } = node { + returned.extend(&objects_per_node[data.idx()]); + } + } + let returned = returned.into_iter().collect(); + + // Determine which objects are potentially mutated. + let mut mutated = bitvec![u8, Lsb0; 0; origins.len()]; + for (idx, node) in func.nodes.iter().enumerate() { + if node.is_write() { + // Every object that the write itself corresponds to is mutable + // in this function. + for object in objects_per_node[idx].iter() { + mutated.set(object.idx(), true); + } + } else if let Some((_, callee, _, args)) = node.try_call() { + let fco = &collection_objects[&callee]; + for (param_idx, arg) in args.into_iter().enumerate() { + // If this parameter corresponds to an object and it's + // mutable in the callee... + if let Some(param_callee_object) = fco.param_to_object(param_idx) + && fco.is_mutated(param_callee_object) + { + // Then every object corresponding to the argument node + // in this function is mutable. 
+ for object in objects_per_node[arg.idx()].iter() { + mutated.set(object.idx(), true); + } + } + } + } + } + + let fco = FunctionCollectionObjects { + objects_per_node, + mutated, + returned, + origins, + }; + collection_objects.insert(func_id, fco); + } + + collection_objects +} diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index abe7f46f..05e5e2e8 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -9,6 +9,7 @@ pub mod antideps; pub mod build; pub mod callgraph; +pub mod collections; pub mod dataflow; pub mod def_use; pub mod dom; @@ -24,6 +25,7 @@ pub mod verify; pub use crate::antideps::*; pub use crate::build::*; pub use crate::callgraph::*; +pub use crate::collections::*; pub use crate::dataflow::*; pub use crate::def_use::*; pub use crate::dom::*; diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 24ed0e4e..932ddee9 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -74,6 +74,7 @@ pub struct PassManager { pub antideps: Option<Vec<Vec<(NodeID, NodeID)>>>, pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub bbs: Option<Vec<Vec<NodeID>>>, + pub collection_objects: Option<CollectionObjects>, pub callgraph: Option<CallGraph>, } @@ -95,6 +96,7 @@ impl PassManager { antideps: None, data_nodes_in_fork_joins: None, bbs: None, + collection_objects: None, callgraph: None, } } @@ -315,6 +317,23 @@ impl PassManager { } } + pub fn make_collection_objects(&mut self) { + if self.collection_objects.is_none() { + self.make_reverse_postorders(); + self.make_typing(); + self.make_callgraph(); + let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); + let typing = self.typing.as_ref().unwrap(); + let callgraph = self.callgraph.as_ref().unwrap(); + self.collection_objects = Some(collection_objects( + &self.module, + reverse_postorders, + typing, + callgraph, + )); + } + } + pub fn make_callgraph(&mut self) { if self.callgraph.is_none() { self.callgraph = 
Some(callgraph(&self.module)); @@ -844,26 +863,15 @@ impl PassManager { self.make_typing(); self.make_control_subgraphs(); self.make_bbs(); + self.make_collection_objects(); self.make_callgraph(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let typing = self.typing.as_ref().unwrap(); let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); let bbs = self.bbs.as_ref().unwrap(); + let collection_objects = self.collection_objects.as_ref().unwrap(); let callgraph = self.callgraph.as_ref().unwrap(); - let memory_objects: Vec<_> = (0..self.module.functions.len()) - .map(|idx| { - memory_objects( - &self.module.functions[idx], - &self.module.types, - &reverse_postorders[idx], - &typing[idx], - ) - }) - .collect(); - let memory_objects_mutable = - memory_objects_mutability(&self.module, &callgraph, &memory_objects); - let devices = device_placement(&self.module.functions, &callgraph); let mut rust_rt = String::new(); @@ -889,10 +897,9 @@ impl PassManager { &typing[idx], &control_subgraphs[idx], &bbs[idx], + &collection_objects, &callgraph, &devices, - &memory_objects, - &memory_objects_mutable, &mut rust_rt, ) .unwrap(), -- GitLab