From 53108467bf037b9aa1bcf036ee05d5521753d8fe Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 19 Feb 2025 08:16:53 -0600
Subject: [PATCH] Fixes to front-end and ir

---
 hercules_ir/src/def_use.rs   | 4 ++--
 hercules_ir/src/dot.rs       | 8 ++++++++
 hercules_ir/src/typecheck.rs | 6 +++---
 juno_frontend/src/codegen.rs | 7 +++----
 4 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index a99c8a23..99531345 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -157,8 +157,8 @@ pub fn get_uses(node: &Node) -> NodeUses {
             reduct,
         } => NodeUses::Three([*control, *init, *reduct]),
         Node::Return { control, data } => {
-            let mut uses: Vec<NodeID> = Vec::from(&data[..]);
-            uses.push(*control);
+            let mut uses: Vec<NodeID> = vec![*control];
+            uses.extend(data);
             NodeUses::Variable(uses.into_boxed_slice())
         }
         Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]),
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 921a813d..a7f890f8 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -349,6 +349,14 @@ fn write_node<W: Write>(
                 }
             }
         }
+        Node::ControlProjection {
+            control: _,
+            selection,
+        }
+        | Node::DataProjection {
+            data: _,
+            selection,
+        } => write!(&mut suffix, "{}", selection)?,
         _ => {}
     };
 
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index dca11fe7..919da640 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -423,8 +423,8 @@ fn typeflow(
             control: _,
             data: _,
         } => {
-            if inputs.len() != 2 {
-                return Error(String::from("Return node must have exactly two inputs."));
+            if inputs.len() < 1 {
+                return Error(String::from("Return node must have at least one input."));
             }
 
             // Check type of control input first, since this may produce an
@@ -1080,7 +1080,7 @@ fn typeflow(
                         if *selection >= types.len() {
                             return Error(String::from("Data projection's selection must be in range of the multi-return being indexed"));
                         }
-                        return Concrete(*type_id);
+                        return Concrete(types[*selection]);
                     }
                     _ => {
                         return Error(String::from(
diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs
index 4e6f89e8..533fd268 100644
--- a/juno_frontend/src/codegen.rs
+++ b/juno_frontend/src/codegen.rs
@@ -276,7 +276,7 @@ impl CodeGenerator<'_> {
                     block = block_ret;
                 }
                 let mut return_node = self.builder.allocate_node();
-                return_node.build_return(block, vals);
+                return_node.build_return(block, vals.into());
                 self.builder.add_node(return_node);
                 None
             }
@@ -552,10 +552,9 @@ impl CodeGenerator<'_> {
                 // Read each of the "inout values" and perform the SSA update
                 let has_inouts = !inouts.is_empty();
                 for (idx, var) in inouts.into_iter().enumerate() {
-                    let index = self.builder.builder.create_field_index(num_returns + idx);
                     let mut proj = self.builder.allocate_node();
                     let proj_id = proj.id();
-                    proj.build_data_projection(call_id, index);
+                    proj.build_data_projection(call_id, num_returns + idx);
                     self.builder.add_node(proj);
 
                     ssa.write_variable(var, block, proj_id);
@@ -568,7 +567,7 @@ impl CodeGenerator<'_> {
 
                 let mut proj = self.builder.allocate_node();
                 let proj_id = proj.id();
-                proj.build_data_projection(call, index);
+                proj.build_data_projection(call, *index);
                 self.builder.add_node(proj);
 
                 (proj_id, block)
-- 
GitLab