From 2b326bd3dc54df360ddd41fb4d6da183944eacf7 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 18 Feb 2025 16:48:04 -0600
Subject: [PATCH] IR changes for multi-return

---
 hercules_ir/src/build.rs       | 15 +++++---
 hercules_ir/src/collections.rs | 14 ++++----
 hercules_ir/src/def_use.rs     | 22 +++++++++---
 hercules_ir/src/ir.rs          | 63 +++++++++++++++++++++++++---------
 hercules_ir/src/parse.rs       | 55 ++++++++++++++++++++++++-----
 hercules_ir/src/typecheck.rs   | 44 ++++++++++++++++++++----
 hercules_ir/src/verify.rs      | 14 +++++---
 7 files changed, 176 insertions(+), 51 deletions(-)

diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs
index 40538cef..3e966d53 100644
--- a/hercules_ir/src/build.rs
+++ b/hercules_ir/src/build.rs
@@ -392,6 +392,7 @@ impl<'a> Builder<'a> {
     pub fn create_constant_zero(&mut self, typ: TypeID) -> ConstantID {
         match &self.module.types[typ.idx()] {
             Type::Control => panic!("Cannot create constant for control types"),
+            Type::MultiReturn(..) => panic!("Cannot create constant for multi-return types"),
             Type::Boolean => self.create_constant_bool(false),
             Type::Integer8 => self.create_constant_i8(0),
             Type::Integer16 => self.create_constant_i16(0),
@@ -503,7 +504,7 @@ impl<'a> Builder<'a> {
         &mut self,
         name: &str,
         param_types: Vec<TypeID>,
-        return_type: TypeID,
+        return_types: Vec<TypeID>,
         num_dynamic_constants: u32,
         entry: bool,
     ) -> BuilderResult<(FunctionID, NodeID)> {
@@ -515,7 +516,7 @@ impl<'a> Builder<'a> {
         self.module.functions.push(Function {
             name: name.to_owned(),
             param_types,
-            return_type,
+            return_types,
             num_dynamic_constants,
             entry,
             nodes: vec![Node::Start],
@@ -594,11 +595,15 @@ impl NodeBuilder {
         };
     }
 
-    pub fn build_projection(&mut self, control: NodeID, selection: usize) {
-        self.node = Node::Projection { control, selection };
+    pub fn build_control_projection(&mut self, control: NodeID, selection: usize) {
+        self.node = Node::ControlProjection { control, selection };
     }
 
-    pub fn build_return(&mut self, control: NodeID, data: NodeID) {
+    pub fn build_data_projection(&mut self, data: NodeID, selection: usize) {
+        self.node = Node::DataProjection { data, selection };
+    }
+
+    pub fn build_return(&mut self, control: NodeID, data: Box<[NodeID]>) {
         self.node = Node::Return { control, data };
     }
 
diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 6b631519..c4e71f8b 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -328,7 +328,9 @@ pub fn collection_objects(
         let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new();
         for node in func.nodes.iter() {
             if let Node::Return { control: _, data } = node {
-                returned.extend(&objects_per_node[data.idx()]);
+                for returned_value in data.iter() {
+                    returned.extend(&objects_per_node[returned_value.idx()]);
+                }
             }
         }
         let returned = returned.into_iter().collect();
@@ -500,16 +502,16 @@ pub fn no_reset_constant_collections(
                     collect: _,
                     data,
                     indices: _,
-                }
-                | Node::Return { control: _, data } => {
+                } => {
                     Either::Left(zip(once(&full_indices), once(data)))
                 }
-                Node::Call {
+                Node::Return { control: _, ref data }
+                | Node::Call {
                     control: _,
                     function: _,
                     dynamic_constants: _,
-                    ref args,
-                } => Either::Right(zip(repeat(&full_indices), args.into_iter().map(|id| *id))),
+                    args: ref data,
+                } => Either::Right(zip(repeat(&full_indices), data.into_iter().map(|id| *id))),
                 _ => return None,
             };
 
diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index ff0e08ed..a99c8a23 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -156,7 +156,11 @@ pub fn get_uses(node: &Node) -> NodeUses {
             init,
             reduct,
         } => NodeUses::Three([*control, *init, *reduct]),
-        Node::Return { control, data } => NodeUses::Two([*control, *data]),
+        Node::Return { control, data } => {
+            // Control goes first, as in get_uses_mut; typeflow's Return case
+            // treats inputs[1..] as the data inputs.
+            NodeUses::Variable(std::iter::once(*control).chain(data.iter().copied()).collect())
+        }
         Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]),
         Node::Constant { id: _ } => NodeUses::One([NodeID::new(0)]),
         Node::DynamicConstant { id: _ } => NodeUses::One([NodeID::new(0)]),
@@ -222,10 +226,14 @@ pub fn get_uses(node: &Node) -> NodeUses {
                 NodeUses::Two([*collect, *data])
             }
         }
-        Node::Projection {
+        Node::ControlProjection {
             control,
             selection: _,
         } => NodeUses::One([*control]),
+        Node::DataProjection {
+            data,
+            selection: _,
+        } => NodeUses::One([*data]),
         Node::Undef { ty: _ } => NodeUses::One([NodeID::new(0)]),
     }
 }
@@ -260,7 +268,9 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
             init,
             reduct,
         } => NodeUsesMut::Three([control, init, reduct]),
-        Node::Return { control, data } => NodeUsesMut::Two([control, data]),
+        Node::Return { control, data } => {
+            NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect())
+        }
         Node::Parameter { index: _ } => NodeUsesMut::Zero,
         Node::Constant { id: _ } => NodeUsesMut::Zero,
         Node::DynamicConstant { id: _ } => NodeUsesMut::Zero,
@@ -326,10 +336,14 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
                 NodeUsesMut::Two([collect, data])
             }
         }
-        Node::Projection {
+        Node::ControlProjection {
             control,
             selection: _,
         } => NodeUsesMut::One([control]),
+        Node::DataProjection {
+            data,
+            selection: _,
+        } => NodeUsesMut::One([data]),
         Node::Undef { ty: _ } => NodeUsesMut::Zero,
     }
 }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index bf9698b3..68fdc26c 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -39,7 +39,7 @@ pub struct Module {
 pub struct Function {
     pub name: String,
     pub param_types: Vec<TypeID>,
-    pub return_type: TypeID,
+    pub return_types: Vec<TypeID>,
     pub num_dynamic_constants: u32,
     pub entry: bool,
 
@@ -77,6 +77,7 @@ pub enum Type {
     Product(Box<[TypeID]>),
     Summation(Box<[TypeID]>),
     Array(TypeID, Box<[DynamicConstantID]>),
+    MultiReturn(Box<[TypeID]>),
 }
 
 /*
@@ -186,7 +187,7 @@ pub enum Node {
     },
     Return {
         control: NodeID,
-        data: NodeID,
+        data: Box<[NodeID]>,
     },
     Parameter {
         index: usize,
@@ -237,10 +238,14 @@ pub enum Node {
         data: NodeID,
         indices: Box<[Index]>,
     },
-    Projection {
+    ControlProjection {
         control: NodeID,
         selection: usize,
     },
+    DataProjection {
+        data: NodeID,
+        selection: usize,
+    },
     Undef {
         ty: TypeID,
     },
@@ -434,6 +439,17 @@ impl Module {
                 }
                 write!(w, ")")
             }
+            Type::MultiReturn(fields) => {
+                write!(w, "MultiReturn(")?;
+                for idx in 0..fields.len() {
+                    let field_ty_id = fields[idx];
+                    self.write_type(field_ty_id, w)?;
+                    if idx + 1 < fields.len() {
+                        write!(w, ", ")?;
+                    }
+                }
+                write!(w, ")")
+            }
         }?;
 
         Ok(())
@@ -1262,12 +1278,19 @@ impl Node {
     );
     define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ });
     define_pattern_predicate!(
-        is_projection,
-        Node::Projection {
+        is_control_projection,
+        Node::ControlProjection {
             control: _,
             selection: _
         }
     );
+    define_pattern_predicate!(
+        is_data_projection,
+        Node::DataProjection {
+            data: _,
+            selection: _
+        }
+    );
 
     define_pattern_predicate!(is_undef, Node::Undef { ty: _ });
 
@@ -1287,8 +1310,8 @@ impl Node {
         }
     }
 
-    pub fn try_proj(&self) -> Option<(NodeID, usize)> {
-        if let Node::Projection { control, selection } = self {
+    pub fn try_control_proj(&self) -> Option<(NodeID, usize)> {
+        if let Node::ControlProjection { control, selection } = self {
             Some((*control, *selection))
         } else {
             None
@@ -1303,9 +1326,9 @@ impl Node {
         }
     }
 
-    pub fn try_return(&self) -> Option<(NodeID, NodeID)> {
+    pub fn try_return(&self) -> Option<(NodeID, &[NodeID])> {
         if let Node::Return { control, data } = self {
-            Some((*control, *data))
+            Some((*control, data))
         } else {
             None
         }
@@ -1479,8 +1502,8 @@ impl Node {
         }
     }
 
-    pub fn try_projection(&self, branch: usize) -> Option<NodeID> {
-        if let Node::Projection { control, selection } = self
+    pub fn try_control_projection(&self, branch: usize) -> Option<NodeID> {
+        if let Node::ControlProjection { control, selection } = self
             && branch == *selection
         {
             Some(*control)
@@ -1560,10 +1583,14 @@ impl Node {
                 data: _,
                 indices: _,
             } => "Write",
-            Node::Projection {
+            Node::ControlProjection {
                 control: _,
                 selection: _,
-            } => "Projection",
+            } => "ControlProjection",
+            Node::DataProjection {
+                data: _,
+                selection: _,
+            } => "DataProjection",
             Node::Undef { ty: _ } => "Undef",
         }
     }
@@ -1639,10 +1666,14 @@ impl Node {
                 data: _,
                 indices: _,
             } => "write",
-            Node::Projection {
+            Node::ControlProjection {
                 control: _,
                 selection: _,
-            } => "projection",
+            } => "control_projection",
+            Node::DataProjection {
+                data: _,
+                selection: _,
+            } => "data_projection",
             Node::Undef { ty: _ } => "undef",
         }
     }
@@ -1655,7 +1686,7 @@ impl Node {
             || self.is_fork()
             || self.is_join()
             || self.is_return()
-            || self.is_projection()
+            || self.is_control_projection()
     }
 }
 
diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs
index f1f4153a..42730f77 100644
--- a/hercules_ir/src/parse.rs
+++ b/hercules_ir/src/parse.rs
@@ -139,7 +139,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a
         Function {
             name: String::from(""),
             param_types: vec![],
-            return_type: TypeID::new(0),
+            return_types: vec![],
             num_dynamic_constants: 0,
             entry: true,
             nodes: vec![],
@@ -245,7 +245,14 @@ fn parse_function<'a>(
     let ir_text = nom::character::complete::char(')')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0;
-    let (ir_text, return_type) = parse_type_id(ir_text, context)?;
+    let (ir_text, return_types) = nom::multi::separated_list1(
+        nom::sequence::tuple((
+            nom::character::complete::multispace0,
+            nom::character::complete::char(','),
+            nom::character::complete::multispace0,
+        )),
+        |text| parse_type_id(text, context),
+    )(ir_text)?;
     let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context))(ir_text)?;
 
     // `nodes`, as returned by parsing, is in parse order, which may differ from
@@ -286,7 +293,7 @@ fn parse_function<'a>(
         Function {
             name: String::from(function_name),
             param_types: params.into_iter().map(|x| x.5).collect(),
-            return_type,
+            return_types,
             num_dynamic_constants,
             entry: true,
             nodes: fixed_nodes,
@@ -334,7 +341,8 @@ fn parse_node<'a>(
         "return" => parse_return(ir_text, context)?,
         "constant" => parse_constant_node(ir_text, context)?,
         "dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?,
-        "projection" => parse_projection(ir_text, context)?,
+        "control_projection" => parse_control_projection(ir_text, context)?,
+        "data_projection" => parse_data_projection(ir_text, context)?,
         // Unary and binary ops are spelled out in the textual format, but we
         // parse them into Unary or Binary node kinds.
         "not" => parse_unary(ir_text, context, UnaryOperator::Not)?,
@@ -489,9 +497,21 @@ fn parse_return<'a>(
     ir_text: &'a str,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Node> {
-    let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?;
+    let ir_text = nom::character::complete::multispace0(ir_text)?.0;
+    let ir_text = nom::character::complete::char('(')(ir_text)?.0;
+    let (ir_text, control) = parse_identifier(ir_text)?;
+    let mut comma = nom::sequence::tuple((
+        nom::character::complete::multispace0,
+        nom::character::complete::char(','),
+        nom::character::complete::multispace0,
+    ));
+    let ir_text = comma(ir_text)?.0;
+    let (ir_text, data) = nom::multi::separated_list1(comma, parse_identifier)(ir_text)?;
+    // The '(' was consumed by hand above, so the matching ')' must be consumed as well.
+    let ir_text = nom::character::complete::multispace0(ir_text)?.0;
+    let ir_text = nom::character::complete::char(')')(ir_text)?.0;
     let control = context.borrow_mut().get_node_id(control);
-    let data = context.borrow_mut().get_node_id(data);
+    let data = data.into_iter().map(|d| context.borrow_mut().get_node_id(d)).collect();
     Ok((ir_text, Node::Return { control, data }))
 }
 
@@ -719,7 +739,7 @@ fn parse_index<'a>(
     Ok((ir_text, idx))
 }
 
-fn parse_projection<'a>(
+fn parse_control_projection<'a>(
     ir_text: &'a str,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Node> {
@@ -728,13 +748,29 @@ fn parse_projection<'a>(
     let control = context.borrow_mut().get_node_id(control);
     Ok((
         ir_text,
-        Node::Projection {
+        Node::ControlProjection {
             control,
             selection: index,
         },
     ))
 }
 
+fn parse_data_projection<'a>(
+    ir_text: &'a str,
+    context: &RefCell<Context<'a>>,
+) -> nom::IResult<&'a str, Node> {
+    let parse_usize = |x| parse_prim::<usize>(x, "1234567890");
+    let (ir_text, (data, index)) = parse_tuple2(parse_identifier, parse_usize)(ir_text)?;
+    let data = context.borrow_mut().get_node_id(data);
+    Ok((
+        ir_text,
+        Node::DataProjection {
+            data,
+            selection: index,
+        },
+    ))
+}
+
 fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char('(')(ir_text)?.0;
@@ -991,7 +1027,8 @@ fn parse_constant<'a>(
 ) -> nom::IResult<&'a str, Constant> {
     let (ir_text, constant) = match ty {
         // There are not control constants.
-        Type::Control => Err(nom::Err::Error(nom::error::Error {
+        Type::Control
+        | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error {
             input: ir_text,
             code: nom::error::ErrorKind::IsNot,
         }))?,
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index 1ff890db..dca11fe7 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -441,12 +441,14 @@ fn typeflow(
                 return inputs[0].clone();
             }
 
-            if let Concrete(id) = inputs[1] {
-                if *id != function.return_type {
-                    return Error(String::from("Return node's data input type must be the same as the function's return type."));
+            for (idx, (input, return_type)) in inputs[1..].iter().zip(function.return_types.iter()).enumerate() {
+                if let Concrete(id) = input {
+                    if *id != *return_type {
+                        return Error(format!("Return node's data input at index {} does not match function's return type.", idx));
+                    }
+                } else if input.is_error() {
+                    return (*input).clone();
                 }
-            } else if inputs[1].is_error() {
-                return inputs[1].clone();
             }
 
             Concrete(get_type_id(
@@ -759,7 +761,7 @@ fn typeflow(
                 }
             }
 
-            Concrete(subst.type_subst(callee.return_type))
+            Concrete(subst.build_return_type(&callee.return_types))
         }
         Node::IntrinsicCall { intrinsic, args: _ } => {
             let num_params = match intrinsic {
@@ -1061,13 +1063,35 @@ fn typeflow(
                 TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg),
             }
         }
-        Node::Projection {
+        Node::ControlProjection {
             control: _,
             selection: _,
         } => {
             // Type is the type of the _if node
             inputs[0].clone()
         }
+        Node::DataProjection {
+            data: _,
+            selection,
+        } => {
+            if let Concrete(type_id) = inputs[0] {
+                match &types[type_id.idx()] {
+                    Type::MultiReturn(types) => {
+                        if *selection >= types.len() {
+                            return Error(String::from("Data projection's selection must be in range of the multi-return being indexed"));
+                        }
+                        return Concrete(types[*selection]); // The selected field's type.
+                    }
+                    _ => {
+                        return Error(String::from(
+                            "Data projection node must read from multi-return value.",
+                        ));
+                    }
+                }
+            }
+            // An input that is not yet concrete (or is an error) is passed through unchanged.
+            inputs[0].clone()
+        }
         Node::Undef { ty } => TypeSemilattice::Concrete(*ty),
     }
 }
@@ -1138,6 +1162,11 @@ impl<'a> DCSubst<'a> {
         }
     }
 
+    fn build_return_type(&mut self, tys: &[TypeID]) -> TypeID {
+        let tys = tys.iter().map(|t| self.type_subst(*t)).collect();
+        self.intern_type(Type::MultiReturn(tys))
+    }
+
     fn type_subst(&mut self, typ: TypeID) -> TypeID {
         match &self.types[typ.idx()] {
             Type::Control
@@ -1172,6 +1201,7 @@ impl<'a> DCSubst<'a> {
                 let new_elem = self.type_subst(elem);
                 self.intern_type(Type::Array(new_elem, new_dims))
             }
+            Type::MultiReturn(..) => panic!("A multi-return type should never be substituted"),
         }
     }
 
diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs
index f188932e..b50ab0d2 100644
--- a/hercules_ir/src/verify.rs
+++ b/hercules_ir/src/verify.rs
@@ -251,11 +251,11 @@ fn verify_structure(
                     Err(format!("If node must have 2 users, not {}.", users.len()))?;
                 }
                 if let (
-                    Node::Projection {
+                    Node::ControlProjection {
                         control: _,
                         selection: result1,
                     },
-                    Node::Projection {
+                    Node::ControlProjection {
                         control: _,
                         selection: result2,
                     },
@@ -290,7 +290,8 @@ fn verify_structure(
                     Err("ThreadID node's control input must be a fork node.")?;
                 }
             }
-            // Call nodes must depend on a region node.
+            // Call nodes must depend on a region node, and their only users
+            // must be DataProjection nodes.
             Node::Call {
                 control,
                 function: _,
@@ -300,6 +301,11 @@ fn verify_structure(
                 if !function.nodes[control.idx()].is_region() {
                     Err("Call node's control input must be a region node.")?;
                 }
+                for user in users {
+                    if !function.nodes[user.idx()].is_data_projection() {
+                        Err("Call node users must be DataProjection nodes.")?;
+                    }
+                }
             }
             // Reduce nodes must depend on a join node.
             Node::Reduce {
@@ -339,7 +345,7 @@ fn verify_structure(
                     }
                     let mut users_covered = bitvec![u8, Lsb0; 0; users.len()];
                     for user in users {
-                        if let Node::Projection {
+                        if let Node::ControlProjection {
                             control: _,
                             ref selection,
                         } = function.nodes[user.idx()]
-- 
GitLab