From 369b8beee80194974a63352e454b354d90c7e739 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 19 Feb 2025 15:14:15 -0600
Subject: [PATCH] Collections analysis

---
 hercules_ir/src/collections.rs | 99 +++++++++++++++++++---------------
 1 file changed, 57 insertions(+), 42 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 06a53fdb..fb3e6bbd 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -40,7 +40,7 @@ use crate::*;
 pub enum CollectionObjectOrigin {
     Parameter(usize),
     Constant(NodeID),
-    Call(NodeID),
+    DataProjection(NodeID),
     Undef(NodeID),
 }
 
@@ -50,7 +50,7 @@ define_id_type!(CollectionObjectID);
 pub struct FunctionCollectionObjects {
     objects_per_node: Vec<Vec<CollectionObjectID>>,
     mutated: Vec<Vec<NodeID>>,
-    returned: Vec<CollectionObjectID>,
+    returned: Vec<Vec<CollectionObjectID>>,
     origins: Vec<CollectionObjectOrigin>,
 }
 
@@ -92,8 +92,8 @@ impl FunctionCollectionObjects {
             .map(CollectionObjectID::new)
     }
 
-    pub fn returned_objects(&self) -> &Vec<CollectionObjectID> {
-        &self.returned
+    pub fn returned_objects(&self, selection: usize) -> &Vec<CollectionObjectID> {
+        &self.returned[selection]
     }
 
     pub fn is_mutated(&self, object: CollectionObjectID) -> bool {
@@ -155,8 +155,6 @@ pub fn collection_objects(
     typing: &ModuleTyping,
     callgraph: &CallGraph,
 ) -> CollectionObjects {
-    panic!("Collections analysis needs to be updated to handle multi-return");
-
     // Analyze functions in reverse topological order, since the analysis of a
     // function depends on all functions it calls.
     let mut collection_objects: CollectionObjects = BTreeMap::new();
@@ -167,8 +165,9 @@ pub fn collection_objects(
         let typing = &typing[func_id.idx()];
         let reverse_postorder = &reverse_postorders[func_id.idx()];
 
-        // Find collection objects originating at parameters, constants, calls,
-        // or undefs. Each node may *originate* one collection object.
+        // Find collection objects originating at parameters, constants,
+        // data projections (of calls), or undefs.
+        // Each of these nodes may *originate* one collection object.
         let param_origins = func
             .param_types
             .iter()
@@ -183,24 +182,29 @@ pub fn collection_objects(
                 Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => {
                     Some(CollectionObjectOrigin::Constant(NodeID::new(idx)))
                 }
-                Node::Call {
-                    control: _,
-                    function: callee,
-                    dynamic_constants: _,
-                    args: _,
-                } if {
+                Node::DataProjection { data, selection } => {
+                    let Node::Call { 
+                        control: _,
+                        function: callee,
+                        dynamic_constants: _,
+                        args: _,
+                    } = func.nodes[data.idx()] else {
+                        panic!("Data-projection's data is not a call node");
+                    };
+
                     let fco = &collection_objects[&callee];
-                    fco.returned
-                        .iter()
-                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_none())
-                } =>
-                {
-                    // If the callee may return a new collection object, then
-                    // this call node originates a single collection object. The
-                    // node may output multiple collection objects, say if the
-                    // callee may return an object passed in as a parameter -
-                    // this is determined later.
-                    Some(CollectionObjectOrigin::Call(NodeID::new(idx)))
+                    if fco.returned[*selection]
+                          .iter()
+                          .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) {
+                        // If the callee may return a new collection object, then
+                        // this data projection node originates a single collection object. The
+                        // node may output multiple collection objects, say if the
+                        // callee may return an object passed in as a parameter -
+                        // this is determined later.
+                        Some(CollectionObjectOrigin::DataProjection(NodeID::new(idx)))
+                    } else {
+                        None
+                    }
                 }
                 Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => {
                     Some(CollectionObjectOrigin::Undef(NodeID::new(idx)))
@@ -218,8 +222,8 @@ pub fn collection_objects(
         // - Reduce: reduces over an object, similar to phis.
         // - Parameter: may originate an object.
         // - Constant: may originate an object.
-        // - Call: may originate an object and may return an object passed in as
-        //   a parameter.
+        // - DataProjection: may originate an object and may return an object
+        //   passed in to its associated call as a parameter.
         // - LibraryCall: may return an object passed in as a parameter, but may
         //   not originate an object.
         // - Read: may extract a smaller object from the input - this is
@@ -230,7 +234,13 @@ pub fn collection_objects(
         //   mutation.
         // - Undef: may originate a dummy object.
         // - Ternary (select): selects between two objects, may output either.
-        let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
+        let lattice = dataflow_global(func, reverse_postorder, |global_input, id| {
+            let inputs = get_uses(&func.nodes[id.idx()])
+                .as_ref()
+                .iter()
+                .map(|id| &global_input[id.idx()])
+                .collect::<Vec<_>>();
+
             match func.nodes[id.idx()] {
                 Node::Phi {
                     control: _,
@@ -269,22 +279,27 @@ pub fn collection_objects(
                         objs: obj.into_iter().collect(),
                     }
                 }
-                Node::Call {
-                    control: _,
-                    function: callee,
-                    dynamic_constants: _,
-                    args: _,
-                } if !types[typing[id.idx()].idx()].is_primitive() => {
+                Node::DataProjection { data, selection } 
+                if !types[typing[id.idx()].idx()].is_primitive() => {
+                    let Node::Call {
+                        control: _,
+                        function: callee,
+                        dynamic_constants: _,
+                        ref args,
+                    } = func.nodes[data.idx()] else {
+                        panic!();
+                    };
+
                     let new_obj = origins
                         .iter()
-                        .position(|origin| *origin == CollectionObjectOrigin::Call(id))
+                        .position(|origin| *origin == CollectionObjectOrigin::DataProjection(id))
                         .map(CollectionObjectID::new);
                     let fco = &collection_objects[&callee];
                     let param_objs = fco
-                        .returned
+                        .returned[selection]
                         .iter()
                         .filter_map(|returned| fco.origins[returned.idx()].try_parameter())
-                        .map(|param_index| inputs[param_index + 1]);
+                        .map(|param_index| &global_input[args[param_index].idx()]);
 
                     let mut objs: BTreeSet<_> = new_obj.into_iter().collect();
                     for param_objs in param_objs {
@@ -326,16 +341,16 @@ pub fn collection_objects(
             .map(|l| l.objs.into_iter().collect())
             .collect();
 
-        // Look at the collection objects that each return may take as input.
-        let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new();
+        // Look at the collection objects that each return value may take as input.
+        let mut returned: Vec<BTreeSet<CollectionObjectID>> = vec![BTreeSet::new(); func.return_types.len()];
         for node in func.nodes.iter() {
             if let Node::Return { control: _, data } = node {
-                for node in data {
-                    returned.extend(&objects_per_node[node.idx()]);
+                for (idx, node) in data.iter().enumerate() {
+                    returned[idx].extend(&objects_per_node[node.idx()]);
                 }
             }
         }
-        let returned = returned.into_iter().collect();
+        let returned = returned.into_iter().map(|set| set.into_iter().collect()).collect();
 
         // Determine which objects are potentially mutated.
         let mut mutated = vec![vec![]; origins.len()];
-- 
GitLab