Skip to content
Snippets Groups Projects
Commit 7b8363fe authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'clone_detection' into 'main'

Refactor analysis of collection objects

See merge request !89
parents 109d2f73 ee9a4536
No related branches found
No related tags found
1 merge request!89Refactor analysis of collection objects
Pipeline #200779 passed
...@@ -2,10 +2,31 @@ ...@@ -2,10 +2,31 @@
pub mod cpu; pub mod cpu;
pub mod device; pub mod device;
pub mod mem;
pub mod rt; pub mod rt;
pub use crate::cpu::*; pub use crate::cpu::*;
pub use crate::device::*; pub use crate::device::*;
pub use crate::mem::*;
pub use crate::rt::*; pub use crate::rt::*;
extern crate hercules_ir;
use self::hercules_ir::*;
/*
 * Compute the alignment (in bytes) of a type. The alignment of a type does
 * not depend on dynamic constants, so no dynamic constant context is needed.
 */
pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
    match &types[ty.idx()] {
        // Control values have no runtime representation to align.
        Type::Control => panic!(),
        // Scalars align to their own size.
        Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1,
        Type::Integer16 | Type::UnsignedInteger16 => 2,
        Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4,
        Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8,
        // An aggregate aligns to its most-aligned member; an empty aggregate
        // defaults to an alignment of 1.
        Type::Product(members) | Type::Summation(members) => members
            .iter()
            .map(|member| get_type_alignment(types, *member))
            .max()
            .unwrap_or(1),
        // An array aligns like its element type.
        Type::Array(elem, _) => get_type_alignment(types, *elem),
    }
}
extern crate bitvec;
extern crate hercules_ir;
use std::collections::{BTreeMap, BTreeSet};
use self::bitvec::prelude::*;
use self::hercules_ir::*;
/*
 * Result of memory object analysis for a single function. Memory objects are
 * numbered locally as indices into `memory_object_to_origin`.
 */
#[derive(Debug)]
pub struct MemoryObjects {
    // For each node (indexed by NodeID), the memory objects its output may be.
    node_id_to_memory_objects: Vec<Vec<usize>>,
    // For each memory object, the node that originates it.
    memory_object_to_origin: Vec<NodeID>,
    // For each parameter index, the memory object it originates, if any
    // (primitive-typed parameters originate none).
    parameter_index_to_memory_object: Vec<Option<usize>>,
    // The memory objects that may flow into a return of this function.
    possibly_returned_memory_objects: Vec<usize>,
}
impl MemoryObjects {
    /// The memory objects the output of the given node may be.
    pub fn memory_objects(&self, id: NodeID) -> &Vec<usize> {
        &self.node_id_to_memory_objects[id.idx()]
    }

    /// The node that originates the given memory object.
    pub fn origin(&self, memory_object: usize) -> NodeID {
        self.memory_object_to_origin[memory_object]
    }

    /// The memory object originated by the parameter with the given index, if
    /// any (primitive-typed parameters originate no memory object).
    pub fn memory_object_of_parameter(&self, parameter: usize) -> Option<usize> {
        self.parameter_index_to_memory_object[parameter]
    }

    /// The memory objects that may be returned by this function.
    pub fn returned_memory_objects(&self) -> &Vec<usize> {
        &self.possibly_returned_memory_objects
    }

    /// The total number of memory objects in this function.
    pub fn num_memory_objects(&self) -> usize {
        self.memory_object_to_origin.len()
    }
}
/*
 * Result of memory object mutability analysis, covering every function in a
 * module.
 */
#[derive(Debug)]
pub struct MemoryObjectsMutability {
    // For each function, one bit per memory object: may it be mutated inside
    // that function?
    func_to_memory_object_to_mutable: Vec<BitVec<u8, Lsb0>>,
}
impl MemoryObjectsMutability {
    /// Whether the given memory object may be mutated inside the given
    /// function (directly by a write, or transitively through a call).
    pub fn is_mutable(&self, id: FunctionID, memory_object: usize) -> bool {
        self.func_to_memory_object_to_mutable[id.idx()][memory_object]
    }
}
/*
 * Each node is assigned a set of memory objects output-ed from the node. This
 * is just a set of memory object IDs (usize).
 */
#[derive(PartialEq, Eq, Clone, Debug)]
struct MemoryObjectLattice {
    // The set of memory objects this dataflow value may correspond to.
    objs: BTreeSet<usize>,
}
impl Semilattice for MemoryObjectLattice {
    /// Meet is set union: a node may be any memory object that either input
    /// value could be.
    fn meet(a: &Self, b: &Self) -> Self {
        MemoryObjectLattice {
            // `.copied()` is the idiomatic form of `.map(|x| *x)`.
            objs: a.objs.union(&b.objs).copied().collect(),
        }
    }

    /// Top is the empty set - no memory objects.
    fn top() -> Self {
        MemoryObjectLattice {
            objs: BTreeSet::new(),
        }
    }

    fn bottom() -> Self {
        // Technically, this lattice is unbounded - technically technically, the
        // lattice is bounded by the number of memory objects in a given
        // instance, but incorporating this information is not possible in our
        // Semilattice interface. Luckily bottom() isn't necessary if we never
        // call it, which we don't here.
        panic!()
    }
}
/*
 * Top level function to analyze memory objects in a Hercules function. These
 * are distinct collections (products, summations, arrays) that are used in a
 * function where we try to disambiguate a string of values produced in the
 * immutable value semantics of Hercules IR into a smaller amount of distinct
 * memory objects that can be modified in-place.
 */
pub fn memory_objects(
    function: &Function,
    types: &Vec<Type>,
    reverse_postorder: &Vec<NodeID>,
    typing: &Vec<TypeID>,
) -> MemoryObjects {
    // Find memory objects originating at parameters, constants, calls, or
    // undefs. Only nodes with non-primitive types can originate memory objects.
    let memory_object_to_origin: Vec<_> = function
        .nodes
        .iter()
        .enumerate()
        .filter(|(idx, node)| {
            (node.is_parameter() || node.is_constant() || node.is_call() || node.is_undef())
                && !types[typing[*idx].idx()].is_primitive()
        })
        .map(|(idx, _)| NodeID::new(idx))
        .collect();
    // Invert the origin list so origin nodes can be looked up by NodeID.
    let node_id_to_originating_memory_obj: BTreeMap<_, _> = memory_object_to_origin
        .iter()
        .enumerate()
        .map(|(idx, id)| (*id, idx))
        .collect();

    // Map parameter index to memory object, if applicable. Panic if two
    // parameter nodes with the same index are found - those really should get
    // removed by GVN!
    let mut parameter_index_to_memory_object = vec![None; function.param_types.len()];
    for (memory_object, origin) in memory_object_to_origin.iter().enumerate() {
        if let Some(param) = function.nodes[origin.idx()].try_parameter() {
            assert!(
                parameter_index_to_memory_object[param].is_none(),
                "PANIC: Found multiple parameter nodes with the same index."
            );
            parameter_index_to_memory_object[param] = Some(memory_object);
        }
    }

    // Run dataflow analysis to figure out which memory objects each data node
    // may be. Note that there's a strict subset of data nodes that can be
    // assigned memory objects:
    //
    // - Phi: selects between memory objects in SSA form, may be assigned
    //   multiple possible memory objects.
    // - Reduce: reduces over a memory object, similar to phis.
    // - Parameter: may originate a memory object.
    // - Constant: may originate a memory object.
    // - Call: may originate a memory object - if doesn't originate a memory
    //   object, doesn't become one based on arguments, as arguments are passed
    //   to callee.
    // - Read: may extract a smaller memory object from input - this is
    //   considered to be the same memory object as the input, as no copy takes
    //   place.
    // - Write: updates a memory object.
    // - Undef: may originate a dummy memory object.
    //
    // Some notable omissions are:
    //
    // - Return: doesn't technically "output" a memory object, but may consume
    //   one. As in the logic with calls not returning a memory object, returns
    //   are not assigned memory objects.
    // - Ternary (select): selecting over memory objects is a gray area
    //   currently. Bail if we see a select over memory objects.
    assert!(!function.nodes.iter().enumerate().any(|(idx, node)| node
        .try_ternary(TernaryOperator::Select)
        .is_some()
        && !types[typing[idx].idx()].is_primitive()));
    let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| {
        match function.nodes[id.idx()] {
            Node::Phi {
                control: _,
                data: _,
            }
            | Node::Reduce {
                control: _,
                init: _,
                reduct: _,
            } => inputs
                .into_iter()
                .fold(MemoryObjectLattice::top(), |acc, input| {
                    MemoryObjectLattice::meet(&acc, input)
                }),
            Node::Parameter { index: _ }
            | Node::Constant { id: _ }
            | Node::Call {
                control: _,
                function: _,
                dynamic_constants: _,
                args: _,
            }
            | Node::Undef { ty: _ }
                if let Some(obj) = node_id_to_originating_memory_obj.get(&id) =>
            {
                MemoryObjectLattice {
                    // A single-element set; `once` avoids building a temporary
                    // array just to iterate it.
                    objs: std::iter::once(*obj).collect(),
                }
            }
            Node::Read {
                collect: _,
                indices: _,
            }
            | Node::Write {
                collect: _,
                data: _,
                indices: _,
            } => inputs[0].clone(),
            _ => MemoryObjectLattice::top(),
        }
    });

    // Look at the memory objects the data input to each return could be.
    let mut possibly_returned_memory_objects = BTreeSet::new();
    for node in function.nodes.iter() {
        if let Node::Return { control: _, data } = node {
            // Extend the set in place rather than rebuilding the union from
            // scratch for every return node.
            possibly_returned_memory_objects.extend(&lattice[data.idx()].objs);
        }
    }
    let possibly_returned_memory_objects = possibly_returned_memory_objects.into_iter().collect();
    let node_id_to_memory_objects = lattice
        .into_iter()
        .map(|lattice| lattice.objs.into_iter().collect())
        .collect();
    MemoryObjects {
        node_id_to_memory_objects,
        memory_object_to_origin,
        parameter_index_to_memory_object,
        possibly_returned_memory_objects,
    }
}
/*
 * Determine if each memory object in each function is mutated or not. This is
 * inter-procedural: an object is also mutated if it's passed as an argument to
 * a call whose corresponding callee parameter object is mutated.
 */
pub fn memory_objects_mutability(
    module: &Module,
    callgraph: &CallGraph,
    memory_objects: &Vec<MemoryObjects>,
) -> MemoryObjectsMutability {
    // One bit per memory object per function, initially all false (immutable).
    let mut mutated: Vec<_> = memory_objects
        .iter()
        .map(|memory_objects| bitvec![u8, Lsb0; 0; memory_objects.num_memory_objects()])
        .collect();
    // NOTE(review): relies on `callgraph.topo()` visiting callees before
    // callers, so a callee's mutability bits are final when a caller reads
    // them - confirm against CallGraph::topo.
    let topo = callgraph.topo();
    for func_id in topo {
        // A memory object is mutated when:
        // 1. The object is the subject of a write node.
        // 2. The object is passed as argument to a function that mutates it.
        for (idx, node) in module.functions[func_id.idx()].nodes.iter().enumerate() {
            if node.is_write() {
                // Every memory object that the write itself corresponds to is
                // mutable in this function.
                for memory_object in memory_objects[func_id.idx()].memory_objects(NodeID::new(idx))
                {
                    mutated[func_id.idx()].set(*memory_object, true);
                }
            } else if let Some((_, callee_id, _, args)) = node.try_call() {
                for (param_idx, arg) in args.into_iter().enumerate() {
                    // If this parameter corresponds to a memory object and it's
                    // mutable in the callee...
                    if let Some(param_callee_memory_object) =
                        memory_objects[callee_id.idx()].memory_object_of_parameter(param_idx)
                        && mutated[callee_id.idx()][param_callee_memory_object]
                    {
                        // Then every memory object corresponding to the
                        // argument node in this function is mutable.
                        for memory_object in memory_objects[func_id.idx()].memory_objects(*arg) {
                            mutated[func_id.idx()].set(*memory_object, true);
                        }
                    }
                }
            }
        }
    }
    MemoryObjectsMutability {
        func_to_memory_object_to_mutable: mutated,
    }
}
/*
 * Compute the alignment (in bytes) of a type. The alignment of a type does
 * not depend on dynamic constants.
 */
pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
    match types[ty.idx()] {
        // Control values have no runtime representation to align.
        Type::Control => panic!(),
        // Scalars align to their own size.
        Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1,
        Type::Integer16 | Type::UnsignedInteger16 => 2,
        Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4,
        Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8,
        // Aggregates align to their most-aligned member; an empty aggregate
        // defaults to an alignment of 1.
        Type::Product(ref members) | Type::Summation(ref members) => members
            .into_iter()
            .map(|id| get_type_alignment(types, *id))
            .max()
            .unwrap_or(1),
        // Arrays align like their element type.
        Type::Array(elem, _) => get_type_alignment(types, elem),
    }
}
...@@ -23,10 +23,9 @@ pub fn rt_codegen<W: Write>( ...@@ -23,10 +23,9 @@ pub fn rt_codegen<W: Write>(
typing: &Vec<TypeID>, typing: &Vec<TypeID>,
control_subgraph: &Subgraph, control_subgraph: &Subgraph,
bbs: &Vec<NodeID>, bbs: &Vec<NodeID>,
collection_objects: &CollectionObjects,
callgraph: &CallGraph, callgraph: &CallGraph,
devices: &Vec<Device>, devices: &Vec<Device>,
memory_objects: &Vec<MemoryObjects>,
memory_objects_mutability: &MemoryObjectsMutability,
w: &mut W, w: &mut W,
) -> Result<(), Error> { ) -> Result<(), Error> {
let ctx = RTContext { let ctx = RTContext {
...@@ -36,10 +35,9 @@ pub fn rt_codegen<W: Write>( ...@@ -36,10 +35,9 @@ pub fn rt_codegen<W: Write>(
typing, typing,
control_subgraph, control_subgraph,
bbs, bbs,
collection_objects,
callgraph, callgraph,
devices, devices,
memory_objects,
_memory_objects_mutability: memory_objects_mutability,
}; };
ctx.codegen_function(w) ctx.codegen_function(w)
} }
...@@ -51,12 +49,9 @@ struct RTContext<'a> { ...@@ -51,12 +49,9 @@ struct RTContext<'a> {
typing: &'a Vec<TypeID>, typing: &'a Vec<TypeID>,
control_subgraph: &'a Subgraph, control_subgraph: &'a Subgraph,
bbs: &'a Vec<NodeID>, bbs: &'a Vec<NodeID>,
collection_objects: &'a CollectionObjects,
callgraph: &'a CallGraph, callgraph: &'a CallGraph,
devices: &'a Vec<Device>, devices: &'a Vec<Device>,
memory_objects: &'a Vec<MemoryObjects>,
// TODO: use once memory objects are passed in a custom type where this
// actually matters.
_memory_objects_mutability: &'a MemoryObjectsMutability,
} }
impl<'a> RTContext<'a> { impl<'a> RTContext<'a> {
...@@ -66,7 +61,7 @@ impl<'a> RTContext<'a> { ...@@ -66,7 +61,7 @@ impl<'a> RTContext<'a> {
// Dump the function signature. // Dump the function signature.
write!( write!(
w, w,
"#[allow(unused_variables,unused_mut)]async fn {}(", "#[allow(unused_variables,unused_mut)]\nasync fn {}(",
func.name func.name
)?; )?;
let mut first_param = true; let mut first_param = true;
...@@ -99,10 +94,10 @@ impl<'a> RTContext<'a> { ...@@ -99,10 +94,10 @@ impl<'a> RTContext<'a> {
write!(w, ") -> {} {{\n", self.get_type_interface(func.return_type))?; write!(w, ") -> {} {{\n", self.get_type_interface(func.return_type))?;
// Copy the "interface" parameters to "non-interface" parameters. // Copy the "interface" parameters to "non-interface" parameters.
// The purpose of this is to convert memory objects from a Box<[u8]> // The purpose of this is to convert collection objects from a Box<[u8]>
// type to a *mut u8 type. This name copying is done so that we can // type to a *mut u8 type. This name copying is done so that we can
// easily construct memory objects just after this by moving the // easily construct objects just after this by moving the "inferface"
// "interface" parameters. // parameters.
for (idx, ty) in func.param_types.iter().enumerate() { for (idx, ty) in func.param_types.iter().enumerate() {
if self.module.types[ty.idx()].is_primitive() { if self.module.types[ty.idx()].is_primitive() {
write!(w, " let p{} = p_i{};\n", idx, idx)?; write!(w, " let p{} = p_i{};\n", idx, idx)?;
...@@ -115,45 +110,42 @@ impl<'a> RTContext<'a> { ...@@ -115,45 +110,42 @@ impl<'a> RTContext<'a> {
} }
} }
// Collect the boxes representing ownership over memory objects for this // Collect the boxes representing ownership over collection objects for
// function. The actual emitted computation is done entirely using // this function. The actual emitted computation is done entirely using
// pointers, so these get emitted to hold onto ownership over the // pointers, so these get emitted to hold onto ownership over the
// underlying memory and to automatically clean them up when this // underlying memory and to automatically clean them up when this
// function returns. Memory objects are inside Options, since their // function returns. Collection objects are inside Options, since their
// ownership may get passed to other called RT functions. If this // ownership may get passed to other called RT functions. If this
// function returns a memory object, then at the very end, right before // function returns a collection object, then at the very end, right
// the return, the to-be-returned pointer is compared against the owned // before the return, the to-be-returned pointer is compared against the
// memory objects - it should match exactly one of those objects, and // owned collection objects - it should match exactly one of those
// that box is what's actually returned. // objects, and that box is what's actually returned.
let mem_obj_ty = "::core::option::Option<::std::boxed::Box<[u8]>>"; let mem_obj_ty = "::core::option::Option<::std::boxed::Box<[u8]>>";
for memory_object in 0..self.memory_objects[self.func_id.idx()].num_memory_objects() { for object in self.collection_objects[&self.func_id].iter_objects() {
let origin = self.memory_objects[self.func_id.idx()].origin(memory_object); match self.collection_objects[&self.func_id].origin(object) {
match func.nodes[origin.idx()] { CollectionObjectOrigin::Parameter(index) => write!(
Node::Parameter { index } => write!(
w, w,
" let mut mem_obj{}: {} = Some(p_i{});\n", " let mut obj{}: {} = Some(p_i{});\n",
memory_object, mem_obj_ty, index object.idx(),
mem_obj_ty,
index
)?, )?,
Node::Constant { id: _ } => { CollectionObjectOrigin::Constant(id) => {
let size = self.codegen_type_size(self.typing[origin.idx()]); let size = self.codegen_type_size(self.typing[id.idx()]);
write!( write!(
w, w,
" let mut mem_obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n", " let mut obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n",
memory_object, mem_obj_ty, size object.idx(),
mem_obj_ty,
size
)? )?
} }
Node::Call { CollectionObjectOrigin::Call(_) | CollectionObjectOrigin::Undef(_) => write!(
control: _,
function: _,
dynamic_constants: _,
args: _,
}
| Node::Undef { ty: _ } => write!(
w, w,
" let mut mem_obj{}: {} = None;\n", " let mut obj{}: {} = None;\n",
memory_object, mem_obj_ty, object.idx(),
mem_obj_ty,
)?, )?,
_ => panic!(),
} }
} }
...@@ -308,19 +300,19 @@ impl<'a> RTContext<'a> { ...@@ -308,19 +300,19 @@ impl<'a> RTContext<'a> {
} }
Node::Return { control: _, data } => { Node::Return { control: _, data } => {
let block = &mut blocks.get_mut(&id).unwrap(); let block = &mut blocks.get_mut(&id).unwrap();
let memory_objects = self.memory_objects[self.func_id.idx()].memory_objects(data); let objects = self.collection_objects[&self.func_id].objects(data);
if memory_objects.is_empty() { if objects.is_empty() {
write!(block, " return {};\n", self.get_value(data))? write!(block, " return {};\n", self.get_value(data))?
} else { } else {
// If the value to return is a memory object, figure out // If the value to return is a collection object, figure out
// which memory object it actually is at runtime and return // which object it actually is at runtime and return that
// that box. // box.
for memory_object in memory_objects { for object in objects {
write!(block, " if let Some(mut mem_obj) = mem_obj{} && ::std::boxed::Box::as_mut_ptr(&mut mem_obj) as *mut u8 == {} {{\n", memory_object, self.get_value(data))?; write!(block, " if let Some(mut obj) = obj{} && ::std::boxed::Box::as_mut_ptr(&mut obj) as *mut u8 == {} {{\n", object.idx(), self.get_value(data))?;
write!(block, " return mem_obj;\n")?; write!(block, " return obj;\n")?;
write!(block, " }}\n")?; write!(block, " }}\n")?;
} }
write!(block, " panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known memory objects.\");\n")? write!(block, " panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known collection objects.\");\n")?
} }
} }
_ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
...@@ -363,14 +355,13 @@ impl<'a> RTContext<'a> { ...@@ -363,14 +355,13 @@ impl<'a> RTContext<'a> {
Constant::Float32(val) => write!(block, "{}f32", val)?, Constant::Float32(val) => write!(block, "{}f32", val)?,
Constant::Float64(val) => write!(block, "{}f64", val)?, Constant::Float64(val) => write!(block, "{}f64", val)?,
Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => { Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => {
let memory_objects = let objects = self.collection_objects[&self.func_id].objects(id);
self.memory_objects[self.func_id.idx()].memory_objects(id); assert_eq!(objects.len(), 1);
assert_eq!(memory_objects.len(), 1); let object = objects[0];
let memory_object = memory_objects[0];
write!( write!(
block, block,
"::std::boxed::Box::as_mut_ptr(mem_obj{}.as_mut().unwrap()) as *mut u8", "::std::boxed::Box::as_mut_ptr(obj{}.as_mut().unwrap()) as *mut u8",
memory_object object.idx()
)? )?
} }
} }
...@@ -400,43 +391,40 @@ impl<'a> RTContext<'a> { ...@@ -400,43 +391,40 @@ impl<'a> RTContext<'a> {
} }
write!(block, ") }};\n")?; write!(block, ") }};\n")?;
// When a CPU function is called that returns a memory // When a CPU function is called that returns a
// object, that memory object must have come from one of // collection object, that object must have come from
// its parameters. Dynamically figure out which one it // one of its parameters. Dynamically figure out which
// came from, so that we can move it to the slot of the // one it came from, so that we can move it to the slot
// output memory object. // of the output object.
let call_memory_objects = let call_objects = self.collection_objects[&self.func_id].objects(id);
self.memory_objects[self.func_id.idx()].memory_objects(id); if !call_objects.is_empty() {
if !call_memory_objects.is_empty() { assert_eq!(call_objects.len(), 1);
assert_eq!(call_memory_objects.len(), 1); let call_object = call_objects[0];
let call_memory_object = call_memory_objects[0];
let callee_returned_objects =
let callee_returned_memory_objects = self.collection_objects[&callee_id].returned_objects();
self.memory_objects[callee_id.idx()].returned_memory_objects();
let possible_params: Vec<_> = let possible_params: Vec<_> =
(0..self.module.functions[callee_id.idx()].param_types.len()) (0..self.module.functions[callee_id.idx()].param_types.len())
.filter(|idx| { .filter(|idx| {
let memory_object_of_param = self.memory_objects let object_of_param = self.collection_objects[&callee_id]
[callee_id.idx()] .param_to_object(*idx);
.memory_object_of_parameter(*idx);
// Look at parameters that could be the // Look at parameters that could be the
// source of the memory object returned // source of the memory object returned
// by the function. // by the function.
memory_object_of_param object_of_param
.map(|memory_object_of_param| { .map(|object_of_param| {
callee_returned_memory_objects callee_returned_objects.contains(&object_of_param)
.contains(&memory_object_of_param)
}) })
.unwrap_or(false) .unwrap_or(false)
}) })
.collect(); .collect();
let arg_memory_objects = args let arg_objects = args
.into_iter() .into_iter()
.enumerate() .enumerate()
.filter(|(idx, _)| possible_params.contains(idx)) .filter(|(idx, _)| possible_params.contains(idx))
.map(|(_, arg)| { .map(|(_, arg)| {
self.memory_objects[self.func_id.idx()] self.collection_objects[&self.func_id]
.memory_objects(*arg) .objects(*arg)
.into_iter() .into_iter()
}) })
.flatten(); .flatten();
...@@ -446,23 +434,24 @@ impl<'a> RTContext<'a> { ...@@ -446,23 +434,24 @@ impl<'a> RTContext<'a> {
// returned by the call. Move that memory object // returned by the call. Move that memory object
// into the memory object of the call. // into the memory object of the call.
let mut first_obj = true; let mut first_obj = true;
for arg_memory_object in arg_memory_objects { for arg_object in arg_objects {
write!(block, " ")?; write!(block, " ")?;
if first_obj { if first_obj {
first_obj = false; first_obj = false;
} else { } else {
write!(block, "else ")?; write!(block, "else ")?;
} }
write!(block, "if let Some(mem_obj) = mem_obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(mem_obj) as *mut u8 == {} {{\n", arg_memory_object, self.get_value(id))?; write!(block, "if let Some(obj) = obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(obj) as *mut u8 == {} {{\n", arg_object.idx(), self.get_value(id))?;
write!( write!(
block, block,
" mem_obj{} = mem_obj{}.take();\n", " obj{} = obj{}.take();\n",
call_memory_object, arg_memory_object call_object.idx(),
arg_object.idx()
)?; )?;
write!(block, " }}\n")?; write!(block, " }}\n")?;
} }
write!(block, " else {{\n")?; write!(block, " else {{\n")?;
write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known memory objects.\");\n")?; write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known collection objects.\");\n")?;
write!(block, " }}\n")?; write!(block, " }}\n")?;
} }
} }
......
extern crate bitvec;
use std::collections::{BTreeMap, BTreeSet};
use self::bitvec::prelude::*;
use crate::*;
/*
* Analysis result that finds "collection objects" in Hercules IR. This analysis
* is inter-procedural, since collection objects can be passed / returned to /
* from called functions. This analysis also tracks which collection objects are
* mutated "in" a function - a collection object is mutated in a function if
* that function contains a write node that may write to that collection object
* or if that function contains a call node that may take that collection object
* as an argument and that collection parameter of that function is mutated.
* Collection objects are numbered locally - the following nodes may originate a
* collection object:
*
* - Parameter: each parameter index gets assigned a single collection object,
* each parameter node gets assigned the object of its index.
* - Constant: each collection constant node gets assigned a single collection
* object.
* - Call: each function is analyzed to determine which collection objects (of
* its parameters or an object it originates) may be returned; a call node
* originates a new collection object if it may return an object originated
* inside the callee.
* - Undef: each undef node with a non-primitive type gets assigned a single
* collection object.
*
* The analysis contains the following information:
*
* - For each node in each function, which collection objects may be on the
* output of the node?
* - For each function, which collection objects may be mutated inside that
* function?
* - For each function, which collection objects may be returned?
* - For each collection object, how was it originated?
*/
/*
 * Every collection object has a single origin - the means by which it came
 * into existence in its function.
 */
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollectionObjectOrigin {
    // Originated by the parameter with this index.
    Parameter(usize),
    // Originated by this collection constant node.
    Constant(NodeID),
    // Originated by this call node (the callee may return a fresh object).
    Call(NodeID),
    // Originated by this non-primitive undef node.
    Undef(NodeID),
}
define_id_type!(CollectionObjectID);
/*
 * Per-function result of collection object analysis. Collection objects are
 * numbered locally within each function - a CollectionObjectID is an index
 * into `origins`.
 */
#[derive(Debug, Clone)]
pub struct FunctionCollectionObjects {
    // For each node (indexed by NodeID), the collection objects its output may
    // be.
    objects_per_node: Vec<Vec<CollectionObjectID>>,
    // One bit per collection object: may it be mutated in this function?
    mutated: BitVec<u8, Lsb0>,
    // The collection objects that may be returned by this function.
    returned: Vec<CollectionObjectID>,
    // For each collection object, how it came into existence.
    origins: Vec<CollectionObjectOrigin>,
}
pub type CollectionObjects = BTreeMap<FunctionID, FunctionCollectionObjects>;
impl CollectionObjectOrigin {
    /// Return the parameter index if this origin is a `Parameter`, otherwise
    /// `None`.
    fn try_parameter(&self) -> Option<usize> {
        if let CollectionObjectOrigin::Parameter(index) = self {
            Some(*index)
        } else {
            None
        }
    }
}
impl FunctionCollectionObjects {
    /// The collection objects the output of the given node may be.
    pub fn objects(&self, id: NodeID) -> &Vec<CollectionObjectID> {
        &self.objects_per_node[id.idx()]
    }

    /// How the given collection object came into existence.
    pub fn origin(&self, object: CollectionObjectID) -> CollectionObjectOrigin {
        self.origins[object.idx()]
    }

    /// The collection object originated by the parameter with the given index,
    /// if any. Linear scan over the origins list.
    pub fn param_to_object(&self, index: usize) -> Option<CollectionObjectID> {
        self.origins
            .iter()
            .position(|origin| *origin == CollectionObjectOrigin::Parameter(index))
            .map(CollectionObjectID::new)
    }

    /// The collection objects that may be returned by this function.
    pub fn returned_objects(&self) -> &Vec<CollectionObjectID> {
        &self.returned
    }

    /// Whether the given collection object may be mutated inside this
    /// function.
    pub fn is_mutated(&self, object: CollectionObjectID) -> bool {
        self.mutated[object.idx()]
    }

    /// The total number of collection objects in this function.
    pub fn num_objects(&self) -> usize {
        self.origins.len()
    }

    /// Iterate over every collection object ID in this function.
    pub fn iter_objects(&self) -> impl Iterator<Item = CollectionObjectID> {
        (0..self.num_objects()).map(CollectionObjectID::new)
    }
}
/*
 * Each node is assigned a set of collection objects output-ed from the node.
 * This is just a set of collection object IDs (usize).
 */
#[derive(PartialEq, Eq, Clone, Debug)]
struct CollectionObjectLattice {
    // The set of collection objects this dataflow value may correspond to.
    objs: BTreeSet<CollectionObjectID>,
}
impl Semilattice for CollectionObjectLattice {
    /// Meet is set union: a node may output any collection object that either
    /// input value could be.
    fn meet(a: &Self, b: &Self) -> Self {
        CollectionObjectLattice {
            // `.copied()` is the idiomatic form of `.map(|x| *x)`.
            objs: a.objs.union(&b.objs).copied().collect(),
        }
    }

    /// Top is the empty set - no collection objects.
    fn top() -> Self {
        CollectionObjectLattice {
            objs: BTreeSet::new(),
        }
    }

    fn bottom() -> Self {
        // Technically, this lattice is unbounded - technically technically, the
        // lattice is bounded by the number of collection objects in a given
        // function, but incorporating this information is not possible in our
        // Semilattice interface. Luckily bottom() isn't necessary if we never
        // call it, which we don't for this analysis.
        panic!()
    }
}
/*
 * Top level function to analyze collection objects in a Hercules module. For
 * each function, determines which collection objects each node may output,
 * which objects may be mutated, which may be returned, and each object's
 * origin.
 */
pub fn collection_objects(
    module: &Module,
    reverse_postorders: &Vec<Vec<NodeID>>,
    typing: &ModuleTyping,
    callgraph: &CallGraph,
) -> CollectionObjects {
    // Analyze functions in reverse topological order, since the analysis of a
    // function depends on all functions it calls.
    let mut collection_objects: CollectionObjects = BTreeMap::new();
    let topo = callgraph.topo();
    for func_id in topo {
        let func = &module.functions[func_id.idx()];
        let typing = &typing[func_id.idx()];
        let reverse_postorder = &reverse_postorders[func_id.idx()];

        // Find collection objects originating at parameters, constants, calls,
        // or undefs. Each node may *originate* one collection object.
        let param_origins = func
            .param_types
            .iter()
            .enumerate()
            .filter(|(_, ty_id)| !module.types[ty_id.idx()].is_primitive())
            .map(|(idx, _)| CollectionObjectOrigin::Parameter(idx));
        let other_origins = func
            .nodes
            .iter()
            .enumerate()
            .filter_map(|(idx, node)| match node {
                Node::Constant { id: _ } if !module.types[typing[idx].idx()].is_primitive() => {
                    Some(CollectionObjectOrigin::Constant(NodeID::new(idx)))
                }
                Node::Call {
                    control: _,
                    function: callee,
                    dynamic_constants: _,
                    args: _,
                } if {
                    // The callee's analysis is already complete here, due to
                    // the reverse topological processing order.
                    let fco = &collection_objects[&callee];
                    fco.returned
                        .iter()
                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_none())
                } =>
                {
                    // If the callee may return a new collection object, then
                    // this call node originates a single collection object. The
                    // node may output multiple collection objects, say if the
                    // callee may return an object passed in as a parameter -
                    // this is determined later.
                    Some(CollectionObjectOrigin::Call(NodeID::new(idx)))
                }
                Node::Undef { ty: _ } if !module.types[typing[idx].idx()].is_primitive() => {
                    Some(CollectionObjectOrigin::Undef(NodeID::new(idx)))
                }
                _ => None,
            });
        // Parameter origins come first, then node origins in node order.
        let origins: Vec<_> = param_origins.chain(other_origins).collect();

        // Run dataflow analysis to figure out which collection objects each
        // data node may output. Note that there's a strict subset of data nodes
        // that can output collection objects:
        //
        // - Phi: selects between objects in SSA form, may be assigned multiple
        //   possible objects.
        // - Reduce: reduces over an object, similar to phis.
        // - Parameter: may originate an object.
        // - Constant: may originate an object.
        // - Call: may originate an object and may return an object passed in as
        //   a parameter.
        // - Read: may extract a smaller object from the input - this is
        //   considered to be the same object as the input, as no copy takes
        //   place.
        // - Write: updates an object - this is considered to be the same object
        //   as the input object, as the write gets lowered to an in-place
        //   mutation.
        // - Undef: may originate a dummy object.
        // - Ternary (select): selects between two objects, may output either.
        let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
            match func.nodes[id.idx()] {
                Node::Phi {
                    control: _,
                    data: _,
                }
                | Node::Reduce {
                    control: _,
                    init: _,
                    reduct: _,
                }
                | Node::Ternary {
                    op: TernaryOperator::Select,
                    first: _,
                    second: _,
                    third: _,
                } => inputs
                    .into_iter()
                    .fold(CollectionObjectLattice::top(), |acc, input| {
                        CollectionObjectLattice::meet(&acc, input)
                    }),
                Node::Parameter { index } => {
                    let obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Parameter(index))
                        .map(CollectionObjectID::new);
                    CollectionObjectLattice {
                        objs: obj.into_iter().collect(),
                    }
                }
                Node::Constant { id: _ } => {
                    let obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Constant(id))
                        .map(CollectionObjectID::new);
                    CollectionObjectLattice {
                        objs: obj.into_iter().collect(),
                    }
                }
                Node::Call {
                    control: _,
                    function: callee,
                    dynamic_constants: _,
                    args: _,
                } if !module.types[typing[id.idx()].idx()].is_primitive() => {
                    // A call may output an object it originates itself...
                    let new_obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Call(id))
                        .map(CollectionObjectID::new);
                    // ...plus any object flowing into an argument whose
                    // corresponding callee parameter may be returned. The
                    // arguments' lattice values are offset by one in `inputs` -
                    // hence `param_index + 1` (presumably the first input is
                    // the call's control use).
                    let fco = &collection_objects[&callee];
                    let param_objs = fco
                        .returned
                        .iter()
                        .filter_map(|returned| fco.origins[returned.idx()].try_parameter())
                        .map(|param_index| inputs[param_index + 1]);
                    let mut objs: BTreeSet<_> = new_obj.into_iter().collect();
                    for param_objs in param_objs {
                        objs.extend(&param_objs.objs);
                    }
                    CollectionObjectLattice { objs }
                }
                Node::Undef { ty: _ } => {
                    let obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Undef(id))
                        .map(CollectionObjectID::new);
                    CollectionObjectLattice {
                        objs: obj.into_iter().collect(),
                    }
                }
                Node::Read {
                    collect: _,
                    indices: _,
                }
                | Node::Write {
                    collect: _,
                    data: _,
                    indices: _,
                } => inputs[0].clone(),
                _ => CollectionObjectLattice::top(),
            }
        });
        let objects_per_node: Vec<Vec<_>> = lattice
            .into_iter()
            .map(|l| l.objs.into_iter().collect())
            .collect();

        // Look at the collection objects that each return may take as input.
        let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new();
        for node in func.nodes.iter() {
            if let Node::Return { control: _, data } = node {
                returned.extend(&objects_per_node[data.idx()]);
            }
        }
        let returned = returned.into_iter().collect();

        // Determine which objects are potentially mutated.
        let mut mutated = bitvec![u8, Lsb0; 0; origins.len()];
        for (idx, node) in func.nodes.iter().enumerate() {
            if node.is_write() {
                // Every object that the write itself corresponds to is mutable
                // in this function.
                for object in objects_per_node[idx].iter() {
                    mutated.set(object.idx(), true);
                }
            } else if let Some((_, callee, _, args)) = node.try_call() {
                let fco = &collection_objects[&callee];
                for (param_idx, arg) in args.into_iter().enumerate() {
                    // If this parameter corresponds to an object and it's
                    // mutable in the callee...
                    if let Some(param_callee_object) = fco.param_to_object(param_idx)
                        && fco.is_mutated(param_callee_object)
                    {
                        // Then every object corresponding to the argument node
                        // in this function is mutable.
                        for object in objects_per_node[arg.idx()].iter() {
                            mutated.set(object.idx(), true);
                        }
                    }
                }
            }
        }

        let fco = FunctionCollectionObjects {
            objects_per_node,
            mutated,
            returned,
            origins,
        };
        collection_objects.insert(func_id, fco);
    }
    collection_objects
}
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
pub mod antideps; pub mod antideps;
pub mod build; pub mod build;
pub mod callgraph; pub mod callgraph;
pub mod collections;
pub mod dataflow; pub mod dataflow;
pub mod def_use; pub mod def_use;
pub mod dom; pub mod dom;
...@@ -24,6 +25,7 @@ pub mod verify; ...@@ -24,6 +25,7 @@ pub mod verify;
pub use crate::antideps::*; pub use crate::antideps::*;
pub use crate::build::*; pub use crate::build::*;
pub use crate::callgraph::*; pub use crate::callgraph::*;
pub use crate::collections::*;
pub use crate::dataflow::*; pub use crate::dataflow::*;
pub use crate::def_use::*; pub use crate::def_use::*;
pub use crate::dom::*; pub use crate::dom::*;
......
...@@ -74,6 +74,7 @@ pub struct PassManager { ...@@ -74,6 +74,7 @@ pub struct PassManager {
pub antideps: Option<Vec<Vec<(NodeID, NodeID)>>>, pub antideps: Option<Vec<Vec<(NodeID, NodeID)>>>,
pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
pub bbs: Option<Vec<Vec<NodeID>>>, pub bbs: Option<Vec<Vec<NodeID>>>,
pub collection_objects: Option<CollectionObjects>,
pub callgraph: Option<CallGraph>, pub callgraph: Option<CallGraph>,
} }
...@@ -95,6 +96,7 @@ impl PassManager { ...@@ -95,6 +96,7 @@ impl PassManager {
antideps: None, antideps: None,
data_nodes_in_fork_joins: None, data_nodes_in_fork_joins: None,
bbs: None, bbs: None,
collection_objects: None,
callgraph: None, callgraph: None,
} }
} }
...@@ -315,6 +317,23 @@ impl PassManager { ...@@ -315,6 +317,23 @@ impl PassManager {
} }
} }
pub fn make_collection_objects(&mut self) {
if self.collection_objects.is_none() {
self.make_reverse_postorders();
self.make_typing();
self.make_callgraph();
let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
let typing = self.typing.as_ref().unwrap();
let callgraph = self.callgraph.as_ref().unwrap();
self.collection_objects = Some(collection_objects(
&self.module,
reverse_postorders,
typing,
callgraph,
));
}
}
pub fn make_callgraph(&mut self) { pub fn make_callgraph(&mut self) {
if self.callgraph.is_none() { if self.callgraph.is_none() {
self.callgraph = Some(callgraph(&self.module)); self.callgraph = Some(callgraph(&self.module));
...@@ -844,26 +863,15 @@ impl PassManager { ...@@ -844,26 +863,15 @@ impl PassManager {
self.make_typing(); self.make_typing();
self.make_control_subgraphs(); self.make_control_subgraphs();
self.make_bbs(); self.make_bbs();
self.make_collection_objects();
self.make_callgraph(); self.make_callgraph();
let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
let typing = self.typing.as_ref().unwrap(); let typing = self.typing.as_ref().unwrap();
let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); let control_subgraphs = self.control_subgraphs.as_ref().unwrap();
let bbs = self.bbs.as_ref().unwrap(); let bbs = self.bbs.as_ref().unwrap();
let collection_objects = self.collection_objects.as_ref().unwrap();
let callgraph = self.callgraph.as_ref().unwrap(); let callgraph = self.callgraph.as_ref().unwrap();
let memory_objects: Vec<_> = (0..self.module.functions.len())
.map(|idx| {
memory_objects(
&self.module.functions[idx],
&self.module.types,
&reverse_postorders[idx],
&typing[idx],
)
})
.collect();
let memory_objects_mutable =
memory_objects_mutability(&self.module, &callgraph, &memory_objects);
let devices = device_placement(&self.module.functions, &callgraph); let devices = device_placement(&self.module.functions, &callgraph);
let mut rust_rt = String::new(); let mut rust_rt = String::new();
...@@ -889,10 +897,9 @@ impl PassManager { ...@@ -889,10 +897,9 @@ impl PassManager {
&typing[idx], &typing[idx],
&control_subgraphs[idx], &control_subgraphs[idx],
&bbs[idx], &bbs[idx],
&collection_objects,
&callgraph, &callgraph,
&devices, &devices,
&memory_objects,
&memory_objects_mutable,
&mut rust_rt, &mut rust_rt,
) )
.unwrap(), .unwrap(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment