From fac3f220759f59826bed3c7d5828c0f4fb468a10 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 19 Feb 2025 16:10:25 -0600
Subject: [PATCH] Fix collections and gcm

---
 hercules_cg/src/lib.rs                        |  4 +-
 hercules_ir/src/collections.rs                |  4 ++
 hercules_ir/src/ir.rs                         |  8 +++
 hercules_ir/src/parse.rs                      | 10 +--
 hercules_opt/src/gcm.rs                       | 64 +++++++++++++------
 juno_samples/multi_return/src/cpu.sch         |  5 ++
 juno_samples/multi_return/src/multi_return.jn |  2 +-
 7 files changed, 70 insertions(+), 27 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index af2420d8..446231de 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -23,7 +23,7 @@ pub const LARGEST_ALIGNMENT: usize = 32;
  */
 pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
     match types[ty.idx()] {
-        Type::Control => panic!(),
+        Type::Control | Type::MultiReturn(_) => panic!(),
         Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1,
         Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2,
         Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4,
@@ -46,7 +46,7 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
 pub type FunctionNodeColors = (
     BTreeMap<NodeID, Device>,
     Vec<Option<Device>>,
-    Option<Device>,
+    Vec<Option<Device>>,
 );
 pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>;
 
diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index fb3e6bbd..cc0703ab 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -96,6 +96,10 @@ impl FunctionCollectionObjects {
         &self.returned[selection]
     }
 
+    pub fn all_returned_objects(&self) -> impl Iterator<Item = CollectionObjectID> + '_ {
+        self.returned.iter().flat_map(|colls| colls.iter().map(|c| *c))
+    }
+
     pub fn is_mutated(&self, object: CollectionObjectID) -> bool {
         !self.mutators(object).is_empty()
     }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 7a0158fb..3d625a39 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -863,6 +863,14 @@ impl Type {
         }
     }
 
+    pub fn is_multireturn(&self) -> bool {
+        if let Type::MultiReturn(_) = self {
+            true
+        } else {
+            false
+        }
+    }
+
     pub fn is_bool(&self) -> bool {
         self == &Type::Boolean
     }
diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs
index a019f4d3..d61ff6e7 100644
--- a/hercules_ir/src/parse.rs
+++ b/hercules_ir/src/parse.rs
@@ -248,11 +248,11 @@ fn parse_function<'a>(
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0;
     let (ir_text, return_types) = nom::multi::separated_list1(
-        nom::sequence::tuple((
+        (
             nom::character::complete::multispace0,
             nom::character::complete::char(','),
             nom::character::complete::multispace0,
-        )),
+        ),
         |text| parse_type_id(text, context),
     ).parse(ir_text)?;
     let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?;
@@ -506,13 +506,13 @@ fn parse_return<'a>(
     let ir_text = nom::character::complete::char(',')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let (ir_text, data) = nom::multi::separated_list1(
-        nom::sequence::tuple((
+        (
             nom::character::complete::multispace0,
             nom::character::complete::char(','),
             nom::character::complete::multispace0,
-        )),
+        ),
         parse_identifier,
-    )(ir_text)?;
+    ).parse(ir_text)?;
     let control = context.borrow_mut().get_node_id(control);
     let data = data
         .into_iter()
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index f2405893..c2ec4e94 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -127,7 +127,7 @@ pub fn gcm(
     let mut alignments = vec![];
     Ref::map(editor.get_types(), |types| {
         for idx in 0..types.len() {
-            if types[idx].is_control() {
+            if types[idx].is_control() || types[idx].is_multireturn() {
                 alignments.push(0);
             } else {
                 alignments.push(get_type_alignment(types, TypeID::new(idx)));
@@ -255,6 +255,15 @@ fn basic_blocks(
                 dynamic_constants: _,
                 args: _,
             } => bbs[idx] = Some(control),
+            Node::DataProjection {
+                data,
+                selection: _,
+            } => {
+                let Node::Call { control, .. } = function.nodes[data.idx()] else {
+                    panic!();
+                };
+                bbs[idx] = Some(control);
+            }
             Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
             _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
             _ => {}
@@ -508,7 +517,7 @@ fn basic_blocks(
                     && objects[&func_id]
                         .objects(id)
                         .into_iter()
-                        .any(|obj| objects[&func_id].returned_objects().contains(obj));
+                        .any(|obj| objects[&func_id].all_returned_objects().any(|ret| ret == *obj));
                 let old_nest = loops
                     .header_of(location)
                     .map(|header| loops.nesting(header).unwrap());
@@ -646,9 +655,9 @@ fn terminating_reads<'a>(
             ref args,
         } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
             let objects = &objects[&callee];
-            let returns = objects.returned_objects();
+            let mut returns = objects.all_returned_objects();
             let param_obj = objects.param_to_object(idx)?;
-            if !objects.is_mutated(param_obj) && !returns.contains(&param_obj) {
+            if !objects.is_mutated(param_obj) && !returns.any(|ret| ret == param_obj) {
                 Some(*arg)
             } else {
                 None
@@ -692,9 +701,9 @@ fn forwarding_reads<'a>(
             ref args,
         } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
             let objects = &objects[&callee];
-            let returns = objects.returned_objects();
+            let mut returns = objects.all_returned_objects();
             let param_obj = objects.param_to_object(idx)?;
-            if !objects.is_mutated(param_obj) && returns.contains(&param_obj) {
+            if !objects.is_mutated(param_obj) && returns.any(|ret| ret == param_obj) {
                 Some(*arg)
             } else {
                 None
@@ -1218,7 +1227,7 @@ fn color_nodes(
     let mut func_colors = (
         BTreeMap::new(),
         vec![None; editor.func().param_types.len()],
-        None,
+        vec![None; editor.func().return_types.len()],
     );
 
     // Assigning nodes to devices is tricky due to function calls. Technically,
@@ -1320,17 +1329,31 @@ fn color_nodes(
                         equations.push((UTerm::Node(*arg), UTerm::Device(device)));
                     }
                 }
+            }
+            Node::DataProjection {
+                data,
+                selection,
+            } => {
+                let Node::Call {
+                    control: _,
+                    function: callee,
+                    dynamic_constants: _,
+                    ref args,
+                } = &nodes[data.idx()] else {
+                    panic!()
+                };
 
-                // If the callee has a definite device for the returned value,
-                // add an equation for the call node itself.
-                if let Some(device) = node_colors[&callee].2 {
+                // If the callee has a definite device for this returned value,
+                // add an equation for the data projection node itself.
+                if let Some(device) = node_colors[&callee].2[selection] {
                     equations.push((UTerm::Node(id), UTerm::Device(device)));
                 }
 
-                // For any object that may be returned by the callee that
-                // originates as a parameter in the callee, the device of the
-                // corresponding argument and call node itself must be equal.
-                for ret in objects[&callee].returned_objects() {
+                // For any object that may be returned in this position by the
+                // callee that originates as a parameter in the callee, the
+                // device of the corresponding argument and the data projection
+                // must be equal.
+                for ret in objects[&callee].returned_objects(selection) {
                     if let Some(idx) = objects[&callee].origin(*ret).try_parameter() {
                         equations.push((UTerm::Node(args[idx]), UTerm::Node(id)));
                     }
@@ -1365,11 +1388,13 @@ fn color_nodes(
                 {
                     assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first.");
                     func_colors.1[index] = Some(*device);
-                } else if let Node::Return { control: _, data } = nodes[id.idx()]
-                    && let Some(device) = func_colors.0.get(&data)
-                {
-                    assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix.");
-                    func_colors.2 = Some(*device);
+                } else if let Node::Return { control: _, ref data } = nodes[id.idx()] {
+                    for (idx, val) in data.iter().enumerate() {
+                        if let Some(device) = func_colors.0.get(val) {
+                            assert!(func_colors.2[idx].is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix.");
+                            func_colors.2[idx] = Some(*device);
+                        }
+                    }
                 }
             }
             Some(func_colors)
@@ -1420,6 +1445,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
     let ty = edit.get_type(ty_id).clone();
     let size = match ty {
         Type::Control => panic!(),
+        Type::MultiReturn(_) => panic!(),
         Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
             edit.add_dynamic_constant(DynamicConstant::Constant(1))
         }
diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch
index 03fb2585..972405f5 100644
--- a/juno_samples/multi_return/src/cpu.sch
+++ b/juno_samples/multi_return/src/cpu.sch
@@ -4,6 +4,10 @@ dce(*);
 
 ip-sroa(*);
 sroa(*);
+
+ip-sroa[true](rolling_sum);
+sroa[true](rolling_sum, rolling_sum_prod);
+
 dce(*);
 
 forkify(*);
@@ -29,3 +33,4 @@ ccp(*);
 dce(*);
 
 gcm(*);
+xdot[true](*);
diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn
index a49df91c..84bab015 100644
--- a/juno_samples/multi_return/src/multi_return.jn
+++ b/juno_samples/multi_return/src/multi_return.jn
@@ -1,4 +1,4 @@
-fn rolling_sum<t: number, n: usize>(x: t[n]) -> t, t[n + 1] {
+fn rolling_sum<t: number, n: usize>(x: t[n]) -> (t, t[n + 1]) {
   let rolling_sum: t[n + 1];
   let sum = 0;
 
-- 
GitLab