From 4606a2a7219ef33437a8995526c873ab21016980 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 29 Dec 2024 15:07:09 -0800
Subject: [PATCH] Refactor object analysis

---
 hercules_cg/src/lib.rs         |  25 ++-
 hercules_cg/src/mem.rs         | 291 ---------------------------
 hercules_cg/src/rt.rs          | 153 +++++++--------
 hercules_ir/src/collections.rs | 348 +++++++++++++++++++++++++++++++++
 hercules_ir/src/lib.rs         |   2 +
 hercules_opt/src/pass.rs       |  37 ++--
 6 files changed, 466 insertions(+), 390 deletions(-)
 delete mode 100644 hercules_cg/src/mem.rs
 create mode 100644 hercules_ir/src/collections.rs

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 952ce368..c579b7e9 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -2,10 +2,31 @@
 
 pub mod cpu;
 pub mod device;
-pub mod mem;
 pub mod rt;
 
 pub use crate::cpu::*;
 pub use crate::device::*;
-pub use crate::mem::*;
 pub use crate::rt::*;
+
+extern crate hercules_ir;
+
+use self::hercules_ir::*;
+
+/*
+ * The alignment of a type does not depend on dynamic constants.
+ */
+pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
+    match types[ty.idx()] {
+        Type::Control => panic!(),
+        Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1,
+        Type::Integer16 | Type::UnsignedInteger16 => 2,
+        Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4,
+        Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8,
+        Type::Product(ref members) | Type::Summation(ref members) => members
+            .into_iter()
+            .map(|id| get_type_alignment(types, *id))
+            .max()
+            .unwrap_or(1),
+        Type::Array(elem, _) => get_type_alignment(types, elem),
+    }
+}
diff --git a/hercules_cg/src/mem.rs b/hercules_cg/src/mem.rs
deleted file mode 100644
index 0c053455..00000000
--- a/hercules_cg/src/mem.rs
+++ /dev/null
@@ -1,291 +0,0 @@
-extern crate bitvec;
-extern crate hercules_ir;
-
-use std::collections::{BTreeMap, BTreeSet};
-
-use self::bitvec::prelude::*;
-
-use self::hercules_ir::*;
-
-#[derive(Debug)]
-pub struct MemoryObjects {
-    node_id_to_memory_objects: Vec<Vec<usize>>,
-    memory_object_to_origin: Vec<NodeID>,
-    parameter_index_to_memory_object: Vec<Option<usize>>,
-    possibly_returned_memory_objects: Vec<usize>,
-}
-
-impl MemoryObjects {
-    pub fn memory_objects(&self, id: NodeID) -> &Vec<usize> {
-        &self.node_id_to_memory_objects[id.idx()]
-    }
-
-    pub fn origin(&self, memory_object: usize) -> NodeID {
-        self.memory_object_to_origin[memory_object]
-    }
-
-    pub fn memory_object_of_parameter(&self, parameter: usize) -> Option<usize> {
-        self.parameter_index_to_memory_object[parameter]
-    }
-
-    pub fn returned_memory_objects(&self) -> &Vec<usize> {
-        &self.possibly_returned_memory_objects
-    }
-
-    pub fn num_memory_objects(&self) -> usize {
-        self.memory_object_to_origin.len()
-    }
-}
-
-#[derive(Debug)]
-pub struct MemoryObjectsMutability {
-    func_to_memory_object_to_mutable: Vec<BitVec<u8, Lsb0>>,
-}
-
-impl MemoryObjectsMutability {
-    pub fn is_mutable(&self, id: FunctionID, memory_object: usize) -> bool {
-        self.func_to_memory_object_to_mutable[id.idx()][memory_object]
-    }
-}
-
-/*
- * Each node is assigned a set of memory objects output-ed from the node. This
- * is just a set of memory object IDs (usize).
- */
-#[derive(PartialEq, Eq, Clone, Debug)]
-struct MemoryObjectLattice {
-    objs: BTreeSet<usize>,
-}
-
-impl Semilattice for MemoryObjectLattice {
-    fn meet(a: &Self, b: &Self) -> Self {
-        MemoryObjectLattice {
-            objs: a.objs.union(&b.objs).map(|x| *x).collect(),
-        }
-    }
-
-    fn top() -> Self {
-        MemoryObjectLattice {
-            objs: BTreeSet::new(),
-        }
-    }
-
-    fn bottom() -> Self {
-        // Technically, this lattice is unbounded - technically technically, the
-        // lattice is bounded by the number of memory objects in a given
-        // instance, but incorporating this information is not possible in our
-        // Semilattice inferface. Luckily bottom() isn't necessary if we never
-        // call it, which we don't here.
-        panic!()
-    }
-}
-
-/*
- * Top level function to analyze memory objects in a Hercules function. These
- * are distinct collections (products, summations, arrays) that are used in a
- * function where we try to disambiguate a string of values produced in the
- * immutable value semantics of Hercules IR into a smaller amount of distinct
- * memory object that can be modified in-place.
- */
-pub fn memory_objects(
-    function: &Function,
-    types: &Vec<Type>,
-    reverse_postorder: &Vec<NodeID>,
-    typing: &Vec<TypeID>,
-) -> MemoryObjects {
-    // Find memory objects originating at parameters, constants, calls, or
-    // undefs.
-    let memory_object_to_origin: Vec<_> = function
-        .nodes
-        .iter()
-        .enumerate()
-        .filter(|(idx, node)| {
-            (node.is_parameter() || node.is_constant() || node.is_call() || node.is_undef())
-                && !types[typing[*idx].idx()].is_primitive()
-        })
-        .map(|(idx, _)| NodeID::new(idx))
-        .collect();
-    let node_id_to_originating_memory_obj: BTreeMap<_, _> = memory_object_to_origin
-        .iter()
-        .enumerate()
-        .map(|(idx, id)| (*id, idx))
-        .collect();
-
-    // Map parameter index to memory object, if applicable. Panic if two
-    // parameter nodes with the same index are found - those really should get
-    // removed by GVN!
-    let mut parameter_index_to_memory_object = vec![None; function.param_types.len()];
-    for (memory_object, origin) in memory_object_to_origin.iter().enumerate() {
-        if let Some(param) = function.nodes[origin.idx()].try_parameter() {
-            assert!(
-                parameter_index_to_memory_object[param].is_none(),
-                "PANIC: Found multiple parameter nodes with the same index."
-            );
-            parameter_index_to_memory_object[param] = Some(memory_object);
-        }
-    }
-
-    // Run dataflow analysis to figure out which memory objects each data node
-    // may be. Note that there's a strict subset of data nodes that can assigned
-    // memory objects:
-    //
-    // - Phi: selects between memory objects in SSA form, may be assigned
-    //   multiple possible memory objects.
-    // - Reduce: reduces over a memory object, similar to phis.
-    // - Parameter: may originate a memory object.
-    // - Constant: may originate a memory object.
-    // - Call: may originate a memory object - if doesn't originate a memory
-    //   object, doesn't become one based on arguments, as arguments are passed
-    //   to callee.
-    // - Read: may extract a smaller memory object from input - this is
-    //   considered to be the same memory object as the input, as no copy takes
-    //   place.
-    // - Write: updates a memory object.
-    // - Undef: may originate a dummy memory object.
-    //
-    // Some notable omissions are:
-    //
-    // - Return: doesn't technically "output" a memory object, but may consume
-    //   one. As in the logic with calls not returning a memory object, returns
-    //   are not assigned memory objects.
-    // - Ternary (select): selecting over memory objects is a gray area
-    //   currently. Bail if we see a select over memory objects.
-    assert!(!function.nodes.iter().enumerate().any(|(idx, node)| node
-        .try_ternary(TernaryOperator::Select)
-        .is_some()
-        && !types[typing[idx].idx()].is_primitive()));
-    let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| {
-        match function.nodes[id.idx()] {
-            Node::Phi {
-                control: _,
-                data: _,
-            }
-            | Node::Reduce {
-                control: _,
-                init: _,
-                reduct: _,
-            } => inputs
-                .into_iter()
-                .fold(MemoryObjectLattice::top(), |acc, input| {
-                    MemoryObjectLattice::meet(&acc, input)
-                }),
-            Node::Parameter { index: _ }
-            | Node::Constant { id: _ }
-            | Node::Call {
-                control: _,
-                function: _,
-                dynamic_constants: _,
-                args: _,
-            }
-            | Node::Undef { ty: _ }
-                if let Some(obj) = node_id_to_originating_memory_obj.get(&id) =>
-            {
-                MemoryObjectLattice {
-                    objs: [*obj].iter().map(|x| *x).collect(),
-                }
-            }
-            Node::Read {
-                collect: _,
-                indices: _,
-            }
-            | Node::Write {
-                collect: _,
-                data: _,
-                indices: _,
-            } => inputs[0].clone(),
-            _ => MemoryObjectLattice::top(),
-        }
-    });
-
-    // Look at the memory objects the data input to each return could be.
-    let mut possibly_returned_memory_objects = BTreeSet::new();
-    for node in function.nodes.iter() {
-        if let Node::Return { control: _, data } = node {
-            possibly_returned_memory_objects = possibly_returned_memory_objects
-                .union(&lattice[data.idx()].objs)
-                .map(|x| *x)
-                .collect();
-        }
-    }
-    let possibly_returned_memory_objects = possibly_returned_memory_objects.into_iter().collect();
-
-    let node_id_to_memory_objects = lattice
-        .into_iter()
-        .map(|lattice| lattice.objs.into_iter().collect())
-        .collect();
-    MemoryObjects {
-        node_id_to_memory_objects,
-        memory_object_to_origin,
-        parameter_index_to_memory_object,
-        possibly_returned_memory_objects,
-    }
-}
-
-/*
- * Determine if each memory object in each function is mutated or not.
- */
-pub fn memory_objects_mutability(
-    module: &Module,
-    callgraph: &CallGraph,
-    memory_objects: &Vec<MemoryObjects>,
-) -> MemoryObjectsMutability {
-    let mut mutated: Vec<_> = memory_objects
-        .iter()
-        .map(|memory_objects| bitvec![u8, Lsb0; 0; memory_objects.num_memory_objects()])
-        .collect();
-    let topo = callgraph.topo();
-
-    for func_id in topo {
-        // A memory object is mutated when:
-        // 1. The object is the subject of a write node.
-        // 2. The object is passed as argument to a function that mutates it.
-        for (idx, node) in module.functions[func_id.idx()].nodes.iter().enumerate() {
-            if node.is_write() {
-                // Every memory object that the write itself corresponds to it
-                // mutable in this function.
-                for memory_object in memory_objects[func_id.idx()].memory_objects(NodeID::new(idx))
-                {
-                    mutated[func_id.idx()].set(*memory_object, true);
-                }
-            } else if let Some((_, callee_id, _, args)) = node.try_call() {
-                for (param_idx, arg) in args.into_iter().enumerate() {
-                    // If this parameter corresponds to a memory object and it's
-                    // mutable in the callee...
-                    if let Some(param_callee_memory_object) =
-                        memory_objects[callee_id.idx()].memory_object_of_parameter(param_idx)
-                        && mutated[callee_id.idx()][param_callee_memory_object]
-                    {
-                        // Then every memory object corresponding to the
-                        // argument node in this function is mutable.
-                        for memory_object in memory_objects[func_id.idx()].memory_objects(*arg) {
-                            mutated[func_id.idx()].set(*memory_object, true);
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    MemoryObjectsMutability {
-        func_to_memory_object_to_mutable: mutated,
-    }
-}
-
-/*
- * The alignment of a type does not depend on dynamic constants.
- */
-pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
-    match types[ty.idx()] {
-        Type::Control => panic!(),
-        Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1,
-        Type::Integer16 | Type::UnsignedInteger16 => 2,
-        Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4,
-        Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8,
-        Type::Product(ref members) | Type::Summation(ref members) => members
-            .into_iter()
-            .map(|id| get_type_alignment(types, *id))
-            .max()
-            .unwrap_or(1),
-        Type::Array(elem, _) => get_type_alignment(types, elem),
-    }
-}
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 305ecf9b..b5623d1a 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -23,10 +23,9 @@ pub fn rt_codegen<W: Write>(
     typing: &Vec<TypeID>,
     control_subgraph: &Subgraph,
     bbs: &Vec<NodeID>,
+    collection_objects: &CollectionObjects,
     callgraph: &CallGraph,
     devices: &Vec<Device>,
-    memory_objects: &Vec<MemoryObjects>,
-    memory_objects_mutability: &MemoryObjectsMutability,
     w: &mut W,
 ) -> Result<(), Error> {
     let ctx = RTContext {
@@ -36,10 +35,9 @@ pub fn rt_codegen<W: Write>(
         typing,
         control_subgraph,
         bbs,
+        collection_objects,
         callgraph,
         devices,
-        memory_objects,
-        _memory_objects_mutability: memory_objects_mutability,
     };
     ctx.codegen_function(w)
 }
@@ -51,12 +49,9 @@ struct RTContext<'a> {
     typing: &'a Vec<TypeID>,
     control_subgraph: &'a Subgraph,
     bbs: &'a Vec<NodeID>,
+    collection_objects: &'a CollectionObjects,
     callgraph: &'a CallGraph,
     devices: &'a Vec<Device>,
-    memory_objects: &'a Vec<MemoryObjects>,
-    // TODO: use once memory objects are passed in a custom type where this
-    // actually matters.
-    _memory_objects_mutability: &'a MemoryObjectsMutability,
 }
 
 impl<'a> RTContext<'a> {
@@ -66,7 +61,7 @@ impl<'a> RTContext<'a> {
         // Dump the function signature.
         write!(
             w,
-            "#[allow(unused_variables,unused_mut)]async fn {}(",
+            "#[allow(unused_variables,unused_mut)]\nasync fn {}(",
             func.name
         )?;
         let mut first_param = true;
@@ -99,10 +94,10 @@ impl<'a> RTContext<'a> {
         write!(w, ") -> {} {{\n", self.get_type_interface(func.return_type))?;
 
         // Copy the "interface" parameters to "non-interface" parameters.
-        // The purpose of this is to convert memory objects from a Box<[u8]>
+        // The purpose of this is to convert collection objects from a Box<[u8]>
         // type to a *mut u8 type. This name copying is done so that we can
-        // easily construct memory objects just after this by moving the
-        // "interface" parameters.
+        // easily construct objects just after this by moving the "inferface"
+        // parameters.
         for (idx, ty) in func.param_types.iter().enumerate() {
             if self.module.types[ty.idx()].is_primitive() {
                 write!(w, "    let p{} = p_i{};\n", idx, idx)?;
@@ -115,45 +110,42 @@ impl<'a> RTContext<'a> {
             }
         }
 
-        // Collect the boxes representing ownership over memory objects for this
-        // function. The actual emitted computation is done entirely using
+        // Collect the boxes representing ownership over collection objects for
+        // this function. The actual emitted computation is done entirely using
         // pointers, so these get emitted to hold onto ownership over the
         // underlying memory and to automatically clean them up when this
-        // function returns. Memory objects are inside Options, since their
+        // function returns. Collection objects are inside Options, since their
         // ownership may get passed to other called RT functions. If this
-        // function returns a memory object, then at the very end, right before
-        // the return, the to-be-returned pointer is compared against the owned
-        // memory objects - it should match exactly one of those objects, and
-        // that box is what's actually returned.
+        // function returns a collection object, then at the very end, right
+        // before the return, the to-be-returned pointer is compared against the
+        // owned collection objects - it should match exactly one of those
+        // objects, and that box is what's actually returned.
         let mem_obj_ty = "::core::option::Option<::std::boxed::Box<[u8]>>";
-        for memory_object in 0..self.memory_objects[self.func_id.idx()].num_memory_objects() {
-            let origin = self.memory_objects[self.func_id.idx()].origin(memory_object);
-            match func.nodes[origin.idx()] {
-                Node::Parameter { index } => write!(
+        for object in self.collection_objects[&self.func_id].iter_objects() {
+            match self.collection_objects[&self.func_id].origin(object) {
+                CollectionObjectOrigin::Parameter(index) => write!(
                     w,
-                    "    let mut mem_obj{}: {} = Some(p_i{});\n",
-                    memory_object, mem_obj_ty, index
+                    "    let mut obj{}: {} = Some(p_i{});\n",
+                    object.idx(),
+                    mem_obj_ty,
+                    index
                 )?,
-                Node::Constant { id: _ } => {
-                    let size = self.codegen_type_size(self.typing[origin.idx()]);
+                CollectionObjectOrigin::Constant(id) => {
+                    let size = self.codegen_type_size(self.typing[id.idx()]);
                     write!(
                         w,
-                        "    let mut mem_obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n",
-                        memory_object, mem_obj_ty, size
+                        "    let mut obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n",
+                        object.idx(),
+                        mem_obj_ty,
+                        size
                     )?
                 }
-                Node::Call {
-                    control: _,
-                    function: _,
-                    dynamic_constants: _,
-                    args: _,
-                }
-                | Node::Undef { ty: _ } => write!(
+                CollectionObjectOrigin::Call(_) | CollectionObjectOrigin::Undef(_) => write!(
                     w,
-                    "    let mut mem_obj{}: {} = None;\n",
-                    memory_object, mem_obj_ty,
+                    "    let mut obj{}: {} = None;\n",
+                    object.idx(),
+                    mem_obj_ty,
                 )?,
-                _ => panic!(),
             }
         }
 
@@ -308,19 +300,19 @@ impl<'a> RTContext<'a> {
             }
             Node::Return { control: _, data } => {
                 let block = &mut blocks.get_mut(&id).unwrap();
-                let memory_objects = self.memory_objects[self.func_id.idx()].memory_objects(data);
-                if memory_objects.is_empty() {
+                let objects = self.collection_objects[&self.func_id].objects(data);
+                if objects.is_empty() {
                     write!(block, "                return {};\n", self.get_value(data))?
                 } else {
-                    // If the value to return is a memory object, figure out
-                    // which memory object it actually is at runtime and return
-                    // that box.
-                    for memory_object in memory_objects {
-                        write!(block, "                if let Some(mut mem_obj) = mem_obj{} && ::std::boxed::Box::as_mut_ptr(&mut mem_obj) as *mut u8 == {} {{\n", memory_object, self.get_value(data))?;
-                        write!(block, "                    return mem_obj;\n")?;
+                    // If the value to return is a collection object, figure out
+                    // which object it actually is at runtime and return that
+                    // box.
+                    for object in objects {
+                        write!(block, "                if let Some(mut obj) = obj{} && ::std::boxed::Box::as_mut_ptr(&mut obj) as *mut u8 == {} {{\n", object.idx(), self.get_value(data))?;
+                        write!(block, "                    return obj;\n")?;
                         write!(block, "                }}\n")?;
                     }
-                    write!(block, "                panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known memory objects.\");\n")?
+                    write!(block, "                panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known collection objects.\");\n")?
                 }
             }
             _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
@@ -363,14 +355,13 @@ impl<'a> RTContext<'a> {
                     Constant::Float32(val) => write!(block, "{}f32", val)?,
                     Constant::Float64(val) => write!(block, "{}f64", val)?,
                     Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => {
-                        let memory_objects =
-                            self.memory_objects[self.func_id.idx()].memory_objects(id);
-                        assert_eq!(memory_objects.len(), 1);
-                        let memory_object = memory_objects[0];
+                        let objects = self.collection_objects[&self.func_id].objects(id);
+                        assert_eq!(objects.len(), 1);
+                        let object = objects[0];
                         write!(
                             block,
-                            "::std::boxed::Box::as_mut_ptr(mem_obj{}.as_mut().unwrap()) as *mut u8",
-                            memory_object
+                            "::std::boxed::Box::as_mut_ptr(obj{}.as_mut().unwrap()) as *mut u8",
+                            object.idx()
                         )?
                     }
                 }
@@ -400,43 +391,40 @@ impl<'a> RTContext<'a> {
                         }
                         write!(block, ") }};\n")?;
 
-                        // When a CPU function is called that returns a memory
-                        // object, that memory object must have come from one of
-                        // its parameters. Dynamically figure out which one it
-                        // came from, so that we can move it to the slot of the
-                        // output memory object.
-                        let call_memory_objects =
-                            self.memory_objects[self.func_id.idx()].memory_objects(id);
-                        if !call_memory_objects.is_empty() {
-                            assert_eq!(call_memory_objects.len(), 1);
-                            let call_memory_object = call_memory_objects[0];
-
-                            let callee_returned_memory_objects =
-                                self.memory_objects[callee_id.idx()].returned_memory_objects();
+                        // When a CPU function is called that returns a
+                        // collection object, that object must have come from
+                        // one of its parameters. Dynamically figure out which
+                        // one it came from, so that we can move it to the slot
+                        // of the output object.
+                        let call_objects = self.collection_objects[&self.func_id].objects(id);
+                        if !call_objects.is_empty() {
+                            assert_eq!(call_objects.len(), 1);
+                            let call_object = call_objects[0];
+
+                            let callee_returned_objects =
+                                self.collection_objects[&callee_id].returned_objects();
                             let possible_params: Vec<_> =
                                 (0..self.module.functions[callee_id.idx()].param_types.len())
                                     .filter(|idx| {
-                                        let memory_object_of_param = self.memory_objects
-                                            [callee_id.idx()]
-                                        .memory_object_of_parameter(*idx);
+                                        let object_of_param = self.collection_objects[&callee_id]
+                                            .param_to_object(*idx);
                                         // Look at parameters that could be the
                                         // source of the memory object returned
                                         // by the function.
-                                        memory_object_of_param
-                                            .map(|memory_object_of_param| {
-                                                callee_returned_memory_objects
-                                                    .contains(&memory_object_of_param)
+                                        object_of_param
+                                            .map(|object_of_param| {
+                                                callee_returned_objects.contains(&object_of_param)
                                             })
                                             .unwrap_or(false)
                                     })
                                     .collect();
-                            let arg_memory_objects = args
+                            let arg_objects = args
                                 .into_iter()
                                 .enumerate()
                                 .filter(|(idx, _)| possible_params.contains(idx))
                                 .map(|(_, arg)| {
-                                    self.memory_objects[self.func_id.idx()]
-                                        .memory_objects(*arg)
+                                    self.collection_objects[&self.func_id]
+                                        .objects(*arg)
                                         .into_iter()
                                 })
                                 .flatten();
@@ -446,23 +434,24 @@ impl<'a> RTContext<'a> {
                             // returned by the call. Move that memory object
                             // into the memory object of the call.
                             let mut first_obj = true;
-                            for arg_memory_object in arg_memory_objects {
+                            for arg_object in arg_objects {
                                 write!(block, "                ")?;
                                 if first_obj {
                                     first_obj = false;
                                 } else {
                                     write!(block, "else ")?;
                                 }
-                                write!(block, "if let Some(mem_obj) = mem_obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(mem_obj) as *mut u8 == {} {{\n", arg_memory_object, self.get_value(id))?;
+                                write!(block, "if let Some(obj) = obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(obj) as *mut u8 == {} {{\n", arg_object.idx(), self.get_value(id))?;
                                 write!(
                                     block,
-                                    "                    mem_obj{} = mem_obj{}.take();\n",
-                                    call_memory_object, arg_memory_object
+                                    "                    obj{} = obj{}.take();\n",
+                                    call_object.idx(),
+                                    arg_object.idx()
                                 )?;
                                 write!(block, "                }}\n")?;
                             }
                             write!(block, "                else {{\n")?;
-                            write!(block, "                    panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known memory objects.\");\n")?;
+                            write!(block, "                    panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known collection objects.\");\n")?;
                             write!(block, "                }}\n")?;
                         }
                     }
diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
new file mode 100644
index 00000000..c3dafe7a
--- /dev/null
+++ b/hercules_ir/src/collections.rs
@@ -0,0 +1,348 @@
+extern crate bitvec;
+
+use std::collections::{BTreeMap, BTreeSet};
+
+use self::bitvec::prelude::*;
+
+use crate::*;
+
+/*
+ * Analysis result that finds "collection objects" in Hercules IR. This analysis
+ * is inter-procedural, since collection objects can be passed / returned to /
+ * from called functions. This analysis also tracks which collection objects are
+ * mutated "in" a function - a collection object is mutated in a function if
+ * that function contains a write node that may write to that collection object
+ * or if that function contains a call node that may take that collection object
+ * as an argument and that collection parameter of that function is mutated.
+ * Collection objects are numbered locally - the following nodes may originate a
+ * collection object:
+ *
+ * - Parameter: each parameter index gets assigned a single collection object,
+ *   each parameter node gets assigned the object of its index.
+ * - Constant: each collection constant node gets assigned a single collection
+ *   object.
+ * - Call: each function is analyzed to determine which collection objects (of
+ *   its parameters or an object it originates) may be returned; a call node
+ *   originates a new collection object if it may return an object originated
+ *   inside the callee.
+ * - Undef: each undef node with a non-primitive type gets assigned a single
+ *   collection object.
+ *
+ * The analysis contains the following information:
+ *
+ * - For each node in each function, which collection objects may be on the
+ *   output of the node?
+ * - For each function, which collection objects may be mutated inside that
+ *   function?
+ * - For each function, which collection objects may be returned?
+ * - For each collection object, how was it originated?
+ */
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum CollectionObjectOrigin {
+    Parameter(usize),
+    Constant(NodeID),
+    Call(NodeID),
+    Undef(NodeID),
+}
+
+define_id_type!(CollectionObjectID);
+
+#[derive(Debug, Clone)]
+pub struct FunctionCollectionObjects {
+    objects_per_node: Vec<Vec<CollectionObjectID>>,
+    mutated: BitVec<u8, Lsb0>,
+    returned: Vec<CollectionObjectID>,
+    origins: Vec<CollectionObjectOrigin>,
+}
+
+pub type CollectionObjects = BTreeMap<FunctionID, FunctionCollectionObjects>;
+
+impl CollectionObjectOrigin {
+    fn try_parameter(&self) -> Option<usize> {
+        match self {
+            CollectionObjectOrigin::Parameter(index) => Some(*index),
+            _ => None,
+        }
+    }
+}
+
+impl FunctionCollectionObjects {
+    pub fn objects(&self, id: NodeID) -> &Vec<CollectionObjectID> {
+        &self.objects_per_node[id.idx()]
+    }
+
+    pub fn origin(&self, object: CollectionObjectID) -> CollectionObjectOrigin {
+        self.origins[object.idx()]
+    }
+
+    pub fn param_to_object(&self, index: usize) -> Option<CollectionObjectID> {
+        self.origins
+            .iter()
+            .position(|origin| *origin == CollectionObjectOrigin::Parameter(index))
+            .map(CollectionObjectID::new)
+    }
+
+    pub fn returned_objects(&self) -> &Vec<CollectionObjectID> {
+        &self.returned
+    }
+
+    pub fn is_mutated(&self, object: CollectionObjectID) -> bool {
+        self.mutated[object.idx()]
+    }
+
+    pub fn num_objects(&self) -> usize {
+        self.origins.len()
+    }
+
+    pub fn iter_objects(&self) -> impl Iterator<Item = CollectionObjectID> {
+        (0..self.num_objects()).map(CollectionObjectID::new)
+    }
+}
+
+/*
+ * Each node is assigned a set of collection objects output-ed from the node.
+ * This is just a set of collection object IDs (usize).
+ */
+#[derive(PartialEq, Eq, Clone, Debug)]
+struct CollectionObjectLattice {
+    objs: BTreeSet<CollectionObjectID>,
+}
+
+impl Semilattice for CollectionObjectLattice {
+    fn meet(a: &Self, b: &Self) -> Self {
+        CollectionObjectLattice {
+            objs: a.objs.union(&b.objs).map(|x| *x).collect(),
+        }
+    }
+
+    fn top() -> Self {
+        CollectionObjectLattice {
+            objs: BTreeSet::new(),
+        }
+    }
+
+    fn bottom() -> Self {
+        // Technically, this lattice is unbounded - technically technically, the
+        // lattice is bounded by the number of collection objects in a given
+        // function, but incorporating this information is not possible in our
+        // Semilattice inferface. Luckily bottom() isn't necessary if we never
+        // call it, which we don't for this analysis.
+        panic!()
+    }
+}
+
+/*
+ * Top level function to analyze collection objects in a Hercules module.
+ */
+pub fn collection_objects(
+    module: &Module,
+    reverse_postorders: &Vec<Vec<NodeID>>,
+    typing: &ModuleTyping,
+    callgraph: &CallGraph,
+) -> CollectionObjects {
+    // Analyze functions in reverse topological order, since the analysis of a
+    // function depends on all functions it calls.
+    let mut collection_objects: CollectionObjects = BTreeMap::new();
+    let topo = callgraph.topo();
+
+    for func_id in topo {
+        let func = &module.functions[func_id.idx()];
+        let typing = &typing[func_id.idx()];
+        let reverse_postorder = &reverse_postorders[func_id.idx()];
+
+        // Find collection objects originating at parameters, constants, calls,
+        // or undefs. Each node may *originate* one collection object.
+        let param_origins = func
+            .param_types
+            .iter()
+            .enumerate()
+            .filter(|(_, ty_id)| !module.types[ty_id.idx()].is_primitive())
+            .map(|(idx, _)| CollectionObjectOrigin::Parameter(idx));
+        let other_origins = func
+            .nodes
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, node)| match node {
+                Node::Constant { id: _ } if !module.types[typing[idx].idx()].is_primitive() => {
+                    Some(CollectionObjectOrigin::Constant(NodeID::new(idx)))
+                }
+                Node::Call {
+                    control: _,
+                    function: callee,
+                    dynamic_constants: _,
+                    args: _,
+                } if {
+                    let fco = &collection_objects[&callee];
+                    fco.returned
+                        .iter()
+                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_none())
+                } =>
+                {
+                    // If the callee may return a new collection object, then
+                    // this call node originates a single collection object. The
+                    // node may output multiple collection objects, say if the
+                    // callee may return an object passed in as a parameter -
+                    // this is determined later.
+                    Some(CollectionObjectOrigin::Call(NodeID::new(idx)))
+                }
+                Node::Undef { ty: _ } if !module.types[typing[idx].idx()].is_primitive() => {
+                    Some(CollectionObjectOrigin::Undef(NodeID::new(idx)))
+                }
+                _ => None,
+            });
+        let origins: Vec<_> = param_origins.chain(other_origins).collect();
+
+        // Run dataflow analysis to figure out which collection objects each
+        // data node may output. Note that there's a strict subset of data nodes
+        // that can output collection objects:
+        //
+        // - Phi: selects between objects in SSA form, may be assigned multiple
+        //   possible objects.
+        // - Reduce: reduces over an object, similar to phis.
+        // - Parameter: may originate an object.
+        // - Constant: may originate an object.
+        // - Call: may originate an object and may return an object passed in as
+        //   a parameter.
+        // - Read: may extract a smaller object from the input - this is
+        //   considered to be the same object as the input, as no copy takes
+        //   place.
+        // - Write: updates an object - this is considered to be the same object
+        //   as the input object, as the write gets lowered to an in-place
+        //   mutation.
+        // - Undef: may originate a dummy object.
+        // - Ternary (select): selects between two objects, may output either.
+        let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
+            match func.nodes[id.idx()] {
+                Node::Phi {
+                    control: _,
+                    data: _,
+                }
+                | Node::Reduce {
+                    control: _,
+                    init: _,
+                    reduct: _,
+                }
+                | Node::Ternary {
+                    op: TernaryOperator::Select,
+                    first: _,
+                    second: _,
+                    third: _,
+                } => inputs
+                    .into_iter()
+                    .fold(CollectionObjectLattice::top(), |acc, input| {
+                        CollectionObjectLattice::meet(&acc, input)
+                    }),
+                Node::Parameter { index } => {
+                    let obj = origins
+                        .iter()
+                        .position(|origin| *origin == CollectionObjectOrigin::Parameter(index))
+                        .map(CollectionObjectID::new);
+                    CollectionObjectLattice {
+                        objs: obj.into_iter().collect(),
+                    }
+                }
+                Node::Constant { id: _ } => {
+                    let obj = origins
+                        .iter()
+                        .position(|origin| *origin == CollectionObjectOrigin::Constant(id))
+                        .map(CollectionObjectID::new);
+                    CollectionObjectLattice {
+                        objs: obj.into_iter().collect(),
+                    }
+                }
+                Node::Call {
+                    control: _,
+                    function: callee,
+                    dynamic_constants: _,
+                    args: _,
+                } if !module.types[typing[id.idx()].idx()].is_primitive() => {
+                    let new_obj = origins
+                        .iter()
+                        .position(|origin| *origin == CollectionObjectOrigin::Call(id))
+                        .map(CollectionObjectID::new);
+                    let fco = &collection_objects[&callee];
+                    let param_objs = fco
+                        .returned
+                        .iter()
+                        .filter_map(|returned| fco.origins[returned.idx()].try_parameter())
+                        .map(|param_index| inputs[param_index + 1]);
+
+                    let mut objs: BTreeSet<_> = new_obj.into_iter().collect();
+                    for param_objs in param_objs {
+                        objs.extend(&param_objs.objs);
+                    }
+                    CollectionObjectLattice { objs }
+                }
+                Node::Undef { ty: _ } => {
+                    let obj = origins
+                        .iter()
+                        .position(|origin| *origin == CollectionObjectOrigin::Undef(id))
+                        .map(CollectionObjectID::new);
+                    CollectionObjectLattice {
+                        objs: obj.into_iter().collect(),
+                    }
+                }
+                Node::Read {
+                    collect: _,
+                    indices: _,
+                }
+                | Node::Write {
+                    collect: _,
+                    data: _,
+                    indices: _,
+                } => inputs[0].clone(),
+                _ => CollectionObjectLattice::top(),
+            }
+        });
+        let objects_per_node: Vec<Vec<_>> = lattice
+            .into_iter()
+            .map(|l| l.objs.into_iter().collect())
+            .collect();
+
+        // Look at the collection objects that each return may take as input.
+        let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new();
+        for node in func.nodes.iter() {
+            if let Node::Return { control: _, data } = node {
+                returned.extend(&objects_per_node[data.idx()]);
+            }
+        }
+        let returned = returned.into_iter().collect();
+
+        // Determine which objects are potentially mutated.
+        let mut mutated = bitvec![u8, Lsb0; 0; origins.len()];
+        for (idx, node) in func.nodes.iter().enumerate() {
+            if node.is_write() {
+                // Every object that the write itself corresponds to is mutable
+                // in this function.
+                for object in objects_per_node[idx].iter() {
+                    mutated.set(object.idx(), true);
+                }
+            } else if let Some((_, callee, _, args)) = node.try_call() {
+                let fco = &collection_objects[&callee];
+                for (param_idx, arg) in args.into_iter().enumerate() {
+                    // If this parameter corresponds to an object and it's
+                    // mutable in the callee...
+                    if let Some(param_callee_object) = fco.param_to_object(param_idx)
+                        && fco.is_mutated(param_callee_object)
+                    {
+                        // Then every object corresponding to the argument node
+                        // in this function is mutable.
+                        for object in objects_per_node[arg.idx()].iter() {
+                            mutated.set(object.idx(), true);
+                        }
+                    }
+                }
+            }
+        }
+
+        let fco = FunctionCollectionObjects {
+            objects_per_node,
+            mutated,
+            returned,
+            origins,
+        };
+        collection_objects.insert(func_id, fco);
+    }
+
+    collection_objects
+}
diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs
index abe7f46f..05e5e2e8 100644
--- a/hercules_ir/src/lib.rs
+++ b/hercules_ir/src/lib.rs
@@ -9,6 +9,7 @@
 pub mod antideps;
 pub mod build;
 pub mod callgraph;
+pub mod collections;
 pub mod dataflow;
 pub mod def_use;
 pub mod dom;
@@ -24,6 +25,7 @@ pub mod verify;
 pub use crate::antideps::*;
 pub use crate::build::*;
 pub use crate::callgraph::*;
+pub use crate::collections::*;
 pub use crate::dataflow::*;
 pub use crate::def_use::*;
 pub use crate::dom::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index 24ed0e4e..932ddee9 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -74,6 +74,7 @@ pub struct PassManager {
     pub antideps: Option<Vec<Vec<(NodeID, NodeID)>>>,
     pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub bbs: Option<Vec<Vec<NodeID>>>,
+    pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
 }
 
@@ -95,6 +96,7 @@ impl PassManager {
             antideps: None,
             data_nodes_in_fork_joins: None,
             bbs: None,
+            collection_objects: None,
             callgraph: None,
         }
     }
@@ -315,6 +317,23 @@ impl PassManager {
         }
     }
 
+    pub fn make_collection_objects(&mut self) {
+        if self.collection_objects.is_none() {
+            self.make_reverse_postorders();
+            self.make_typing();
+            self.make_callgraph();
+            let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
+            let typing = self.typing.as_ref().unwrap();
+            let callgraph = self.callgraph.as_ref().unwrap();
+            self.collection_objects = Some(collection_objects(
+                &self.module,
+                reverse_postorders,
+                typing,
+                callgraph,
+            ));
+        }
+    }
+
     pub fn make_callgraph(&mut self) {
         if self.callgraph.is_none() {
             self.callgraph = Some(callgraph(&self.module));
@@ -844,26 +863,15 @@ impl PassManager {
                     self.make_typing();
                     self.make_control_subgraphs();
                     self.make_bbs();
+                    self.make_collection_objects();
                     self.make_callgraph();
                     let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
                     let typing = self.typing.as_ref().unwrap();
                     let control_subgraphs = self.control_subgraphs.as_ref().unwrap();
                     let bbs = self.bbs.as_ref().unwrap();
+                    let collection_objects = self.collection_objects.as_ref().unwrap();
                     let callgraph = self.callgraph.as_ref().unwrap();
 
-                    let memory_objects: Vec<_> = (0..self.module.functions.len())
-                        .map(|idx| {
-                            memory_objects(
-                                &self.module.functions[idx],
-                                &self.module.types,
-                                &reverse_postorders[idx],
-                                &typing[idx],
-                            )
-                        })
-                        .collect();
-                    let memory_objects_mutable =
-                        memory_objects_mutability(&self.module, &callgraph, &memory_objects);
-
                     let devices = device_placement(&self.module.functions, &callgraph);
 
                     let mut rust_rt = String::new();
@@ -889,10 +897,9 @@ impl PassManager {
                                 &typing[idx],
                                 &control_subgraphs[idx],
                                 &bbs[idx],
+                                &collection_objects,
                                 &callgraph,
                                 &devices,
-                                &memory_objects,
-                                &memory_objects_mutable,
                                 &mut rust_rt,
                             )
                             .unwrap(),
-- 
GitLab