diff --git a/Cargo.lock b/Cargo.lock
index ebe621e801437e1d3512943a39df7382fa578f72..fdcbaf8426dd64fca782b34240256cf149657303 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1414,6 +1414,16 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_multi_return"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_patterns"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 42d2813535ca9c00facef65212269555d0173f27..01f8cc138b7a2dfc2f3a6cc2882558961fd6b63f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,6 +26,7 @@ members = [
 	"juno_samples/matmul",
 	"juno_samples/median_window",
 	"juno_samples/multi_device",
+	"juno_samples/multi_return",
 	"juno_samples/patterns",
 	"juno_samples/products",
 	"juno_samples/rodinia/backprop",
diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 27daf2a10c7f3cacb915127389c625f4b6c671f3..6ad38fc0e5e59f18771a9ee56819bd6269cd868a 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -60,20 +60,36 @@ struct LLVMBlock {
 impl<'a> CPUContext<'a> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // Dump the function signature.
-        if self.types[self.function.return_type.idx()].is_primitive() {
-            write!(
-                w,
-                "define dso_local {} @{}(",
-                self.get_type(self.function.return_type),
-                self.function.name
-            )?;
+        if self.function.return_types.len() == 1 {
+            let return_type = self.function.return_types[0];
+            if self.types[return_type.idx()].is_primitive() {
+                write!(
+                    w,
+                    "define dso_local {} @{}(",
+                    self.get_type(return_type),
+                    self.function.name
+                )?;
+            } else {
+                write!(
+                    w,
+                    "define dso_local nonnull noundef {} @{}(",
+                    self.get_type(return_type),
+                    self.function.name
+                )?;
+            }
         } else {
             write!(
                 w,
-                "define dso_local nonnull noundef {} @{}(",
-                self.get_type(self.function.return_type),
-                self.function.name
+                "%return.{} = type {{ {} }}\n",
+                self.function.name,
+                self.function
+                    .return_types
+                    .iter()
+                    .map(|t| self.get_type(*t))
+                    .collect::<Vec<_>>()
+                    .join(", "),
             )?;
+            write!(w, "define dso_local void @{}(", self.function.name,)?;
         }
         let mut first_param = true;
         // The first parameter is a pointer to CPU backing memory, if it's
@@ -110,6 +126,17 @@ impl<'a> CPUContext<'a> {
                 )?;
             }
         }
+        // Lastly, if the function has multiple returns, is a pointer to the return struct
+        if self.function.return_types.len() != 1 {
+            if !first_param {
+                write!(w, ", ")?;
+            }
+            write!(
+                w,
+                "ptr noalias nofree nonnull noundef sret(%return.{}) %ret.ptr",
+                self.function.name,
+            )?;
+        }
         write!(w, ") {{\n")?;
 
         let mut blocks: BTreeMap<_, _> = (0..self.function.nodes.len())
@@ -171,7 +198,7 @@ impl<'a> CPUContext<'a> {
             // successor and are otherwise simple.
             Node::Start
             | Node::Region { preds: _ }
-            | Node::Projection {
+            | Node::ControlProjection {
                 control: _,
                 selection: _,
             } => {
@@ -186,7 +213,9 @@ impl<'a> CPUContext<'a> {
                 let mut succs = self.control_subgraph.succs(id);
                 let succ1 = succs.next().unwrap();
                 let succ2 = succs.next().unwrap();
-                let succ1_is_true = self.function.nodes[succ1.idx()].try_projection(1).is_some();
+                let succ1_is_true = self.function.nodes[succ1.idx()]
+                    .try_control_projection(1)
+                    .is_some();
                 write!(
                     term,
                     "  br {}, label %{}, label %{}\n",
@@ -195,9 +224,35 @@ impl<'a> CPUContext<'a> {
                     self.get_block_name(if succ1_is_true { succ2 } else { succ1 }),
                 )?
             }
-            Node::Return { control: _, data } => {
-                let term = &mut blocks.get_mut(&id).unwrap().term;
-                write!(term, "  ret {}\n", self.get_value(data, true))?
+            Node::Return {
+                control: _,
+                ref data,
+            } => {
+                if data.len() == 1 {
+                    let ret_data = data[0];
+                    let term = &mut blocks.get_mut(&id).unwrap().term;
+                    write!(term, "  ret {}\n", self.get_value(ret_data, true))?
+                } else {
+                    let term = &mut blocks.get_mut(&id).unwrap().term;
+                    // Generate gep and stores into the output pointer
+                    for (idx, val) in data.iter().enumerate() {
+                        write!(
+                            term,
+                            "  %ret_ptr.{} = getelementptr inbounds %return.{}, ptr %ret.ptr, i32 0, i32 {}\n",
+                            idx,
+                            self.function.name,
+                            idx,
+                        )?;
+                        write!(
+                            term,
+                            "  store {}, ptr %ret_ptr.{}\n",
+                            self.get_value(*val, true),
+                            idx,
+                        )?;
+                    }
+                    // Finally return void
+                    write!(term, "  ret void\n")?
+                }
             }
             _ => panic!(
                 "PANIC: Can't lower {:?} in {}.",
@@ -808,7 +863,7 @@ impl<'a> CPUContext<'a> {
      */
     fn codegen_type_size(&self, ty: TypeID, body: &mut String) -> Result<String, Error> {
         match self.types[ty.idx()] {
-            Type::Control => panic!(),
+            Type::Control | Type::MultiReturn(_) => panic!(),
             Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
                 Ok("1".to_string())
             }
@@ -974,7 +1029,7 @@ fn convert_intrinsic(intrinsic: &Intrinsic, ty: &Type) -> String {
             } else {
                 panic!()
             }
-        },
+        }
         Intrinsic::ACos => "acos",
         Intrinsic::ASin => "asin",
         Intrinsic::ATan => "atan",
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 931071cb2747ee03cb3e87f342f3737e6eb82404..76aba7e030492706a367bbb5f31107014010a129 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -150,13 +150,19 @@ pub fn gpu_codegen<W: Write>(
         }
     }
 
-    let return_parameter = if collection_objects.returned_objects().len() == 1 {
-        collection_objects
-            .origin(*collection_objects.returned_objects().first().unwrap())
-            .try_parameter()
-    } else {
-        None
-    };
+    // Tracks for each return value whether it is always the same parameter
+    // collection
+    let return_parameters = (0..function.return_types.len())
+        .map(|idx| {
+            if collection_objects.returned_objects(idx).len() == 1 {
+                collection_objects
+                    .origin(*collection_objects.returned_objects(idx).first().unwrap())
+                    .try_parameter()
+            } else {
+                None
+            }
+        })
+        .collect::<Vec<_>>();
 
     let kernel_params = &GPUKernelParams {
         max_num_threads: 1024,
@@ -181,7 +187,7 @@ pub fn gpu_codegen<W: Write>(
         fork_reduce_map,
         reduct_reduce_map,
         control_data_phi_map,
-        return_parameter,
+        return_parameters,
         kernel_params,
     };
     ctx.codegen_function(w)
@@ -210,7 +216,7 @@ struct GPUContext<'a> {
     fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
-    return_parameter: Option<usize>,
+    return_parameters: Vec<Option<usize>>,
     kernel_params: &'a GPUKernelParams,
 }
 
@@ -262,7 +268,9 @@ impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
-        self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
+        self.codegen_kernel_preamble(&mut top)?;
+        self.codegen_return_struct(&mut top)?;
+        self.codegen_kernel_begin(&mut top)?;
         let mut dynamic_shared_offset = "0".to_string();
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
@@ -339,10 +347,7 @@ impl GPUContext<'_> {
         Ok(())
     }
 
-    /*
-     * Emit kernel headers, signature, arguments, and dynamic shared memory declaration
-     */
-    fn codegen_kernel_begin(&self, has_ret_var: bool, w: &mut String) -> Result<(), Error> {
+    fn codegen_kernel_preamble<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         write!(
             w,
             "
@@ -366,8 +371,28 @@ namespace cg = cooperative_groups;
 #define isqrt(a) ((int)sqrtf((float)(a)))
 
 ",
-        )?;
+        )
+    }
+
+    fn codegen_return_struct<W: Write>(&self, w: &mut W) -> Result<(), Error> {
+        write!(
+            w,
+            "struct return_{} {{ {} }};\n",
+            self.function.name,
+            self.function
+                .return_types
+                .iter()
+                .enumerate()
+                .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx))
+                .collect::<Vec<_>>()
+                .join(" "),
+        )
+    }
 
+    /*
+     * Emit kernel signature, arguments, and dynamic shared memory declaration
+     */
+    fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         write!(
             w,
             "__global__ void __launch_bounds__({}) {}_gpu(",
@@ -403,11 +428,23 @@ namespace cg = cooperative_groups;
             };
             write!(w, "{} p{}", param_type, idx)?;
         }
-        if has_ret_var {
+        let ret_fields = self
+            .return_parameters
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, param)| {
+                if param.is_some() {
+                    None
+                } else {
+                    Some((idx, self.function.return_types[idx]))
+                }
+            })
+            .collect::<Vec<(usize, TypeID)>>();
+        if !ret_fields.is_empty() {
             if !first_param {
                 write!(w, ", ")?;
             }
-            write!(w, "void* __restrict__ ret",)?;
+            write!(w, "return_{}* __restrict__ ret", self.function.name)?;
         }
 
         // Type is char since it's simplest to use single bytes for indexing
@@ -599,17 +636,17 @@ namespace cg = cooperative_groups;
         dynamic_shared_offset: &str,
         w: &mut String,
     ) -> Result<(), Error> {
-        // The following steps are for host-side C function arguments, but we also
-        // need to pass arguments to kernel, so we keep track of the arguments here.
-        let ret_type = self.get_type(self.function.return_type, false);
         let mut pass_args = String::new();
-        write!(
-            w,
-            "
-extern \"C\" {} {}(",
-            ret_type.clone(),
-            self.function.name
-        )?;
+
+        let is_multi_return = self.function.return_types.len() != 1;
+        write!(w, "extern \"C\" ")?;
+        if is_multi_return {
+            write!(w, "void")?;
+        } else {
+            write!(w, "{}", self.get_type(self.function.return_types[0], false))?;
+        }
+        write!(w, " {}(", self.function.name)?;
+
         let mut first_param = true;
         // The first parameter is a pointer to GPU backing memory, if it's
         // needed.
@@ -641,20 +678,46 @@ extern \"C\" {} {}(",
             write!(w, "{} p{}", param_type, idx)?;
             write!(pass_args, "p{}", idx)?;
         }
+        // If the function is multi-return, the last argument is the return pointer
+        // This is a CPU pointer, we will allocate a separate pointer used for the kernel's return
+        // arguments (if any)
+        if is_multi_return {
+            if !first_param {
+                write!(w, ", ")?;
+            }
+            write!(w, "return_{}* ret_ptr", self.function.name)?;
+        }
         write!(w, ") {{\n")?;
         // For case of dynamic block count
         self.codegen_dynamic_constants(w)?;
-        let has_ret_var = self.return_parameter.is_none();
-        if has_ret_var {
-            // Allocate return parameter and lift to kernel argument
-            let ret_type_pnt = self.get_type(self.function.return_type, true);
-            write!(w, "\t{} ret;\n", ret_type_pnt)?;
+
+        let (kernel_returns, param_returns) = self.return_parameters.iter().enumerate().fold(
+            (vec![], vec![]),
+            |(mut kernel_returns, mut param_returns), (idx, param)| {
+                if let Some(param_idx) = param {
+                    param_returns.push((idx, param_idx));
+                } else {
+                    kernel_returns.push((idx, self.function.return_types[idx]));
+                }
+                (kernel_returns, param_returns)
+            },
+        );
+
+        if !kernel_returns.is_empty() {
+            // Allocate kernel return struct
+            write!(w, "\treturn_{}* ret_cuda;\n", self.function.name)?;
+            write!(
+                w,
+                "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n",
+                self.function.name
+            )?;
+            // Add the return pointer to the kernel arguments
             if !first_param {
                 write!(pass_args, ", ")?;
             }
-            write!(pass_args, "ret")?;
-            write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
+            write!(pass_args, "ret_cuda")?;
         }
+
         write!(w, "\tcudaError_t err;\n")?;
         write!(
             w,
@@ -666,18 +729,39 @@ extern \"C\" {} {}(",
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
         )?;
-        if has_ret_var {
-            // Copy return from device to host, whether it's primitive value or collection pointer
-            write!(w, "\t{} host_ret;\n", ret_type)?;
-            write!(
-                w,
-                "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n",
-                ret_type
-            )?;
-            write!(w, "\treturn host_ret;\n")?;
+
+        if !is_multi_return {
+            if kernel_returns.is_empty() {
+                // A single return of a parameter, we can just return it directly
+                write!(w, "\treturn p{};\n", param_returns[0].1)?;
+            } else {
+                // A single return of a value computed on the device, we create a stack allocation
+                // and retrieve the value from the device and then return it
+                write!(w, "\t return_{} ret_host;\n", self.function.name)?;
+                write!(w,
+                    "\tcudaMemcpy(&ret_host, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n",
+                    self.function.name,
+                )?;
+                write!(w, "\treturn ret_host.f0;\n")?;
+            }
         } else {
-            write!(w, "\treturn p{};\n", self.return_parameter.unwrap())?;
+            // Multi return is handle via an output pointer provided to this function
+            // If there are kernel returns then we copy those back from the device and then fill in
+            // the parameter returns
+            if !kernel_returns.is_empty() {
+                // Copy from the device directly into the output struct
+                write!(
+                    w,
+                    "\tcudaMemcpy(ret_ptr, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n",
+                    self.function.name,
+                )?;
+            }
+            for (field_idx, param_idx) in param_returns {
+                write!(w, "\tret_ptr->f{} = p{};\n", field_idx, param_idx)?;
+            }
+            write!(w, "\treturn;\n")?;
         }
+
         write!(w, "}}\n")?;
         Ok(())
     }
@@ -1545,7 +1629,7 @@ extern \"C\" {} {}(",
         let tabs = match &self.function.nodes[id.idx()] {
             Node::Start
             | Node::Region { preds: _ }
-            | Node::Projection {
+            | Node::ControlProjection {
                 control: _,
                 selection: _,
             } => {
@@ -1557,7 +1641,9 @@ extern \"C\" {} {}(",
                 let mut succs = self.control_subgraph.succs(id);
                 let succ1 = succs.next().unwrap();
                 let succ2 = succs.next().unwrap();
-                let succ1_is_true = self.function.nodes[succ1.idx()].try_projection(1).is_some();
+                let succ1_is_true = self.function.nodes[succ1.idx()]
+                    .try_control_projection(1)
+                    .is_some();
                 let succ1_block_name = self.get_block_name(succ1, false);
                 let succ2_block_name = self.get_block_name(succ2, false);
                 write!(
@@ -1710,20 +1796,26 @@ extern \"C\" {} {}(",
                 }
                 tabs
             }
-            Node::Return { control: _, data } => {
-                if self.return_parameter.is_none() {
-                    // Since we lift return into a kernel argument, we write to that
-                    // argument upon return.
-                    let return_val = self.get_value(*data, false, false);
-                    let return_type_ptr = self.get_type(self.function.return_type, true);
-                    write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
-                    write!(
-                        w_term,
-                        "\t\t*(reinterpret_cast<{}>(ret)) = {};\n",
-                        return_type_ptr, return_val
-                    )?;
-                    write!(w_term, "\t}}\n")?;
+            Node::Return {
+                control: _,
+                ref data,
+            } => {
+                write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
+                for (idx, (data, param)) in
+                    data.iter().zip(self.return_parameters.iter()).enumerate()
+                {
+                    // For return values that are not identical to some parameter, we write it into
+                    // the output struct
+                    if !param.is_some() {
+                        write!(
+                            w_term,
+                            "\t\tret->f{} = {};\n",
+                            idx,
+                            self.get_value(*data, false, false)
+                        )?;
+                    }
                 }
+                write!(w_term, "\t}}\n")?;
                 write!(w_term, "\treturn;\n")?;
                 1
             }
diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index af2420d83a550e7a50ff22962302612f21627995..446231de1677a34e6e6e0be43b380b6737905f5e 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -23,7 +23,7 @@ pub const LARGEST_ALIGNMENT: usize = 32;
  */
 pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
     match types[ty.idx()] {
-        Type::Control => panic!(),
+        Type::Control | Type::MultiReturn(_) => panic!(),
         Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1,
         Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2,
         Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4,
@@ -46,7 +46,7 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
 pub type FunctionNodeColors = (
     BTreeMap<NodeID, Device>,
     Vec<Option<Device>>,
-    Option<Device>,
+    Vec<Option<Device>>,
 );
 pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>;
 
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 7cbb43ad54b439c1a61eb4a5f0bf927fc2f37eae..884129c712863a49d535d23f46893d8ed246c56b 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -192,40 +192,75 @@ impl<'a> RTContext<'a> {
             }
             write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
         }
-        write!(w, ") -> {} {{", self.get_type(func.return_type))?;
+        write!(w, ") -> ")?;
+        self.write_rust_return_type(w, &func.return_types)?;
+        write!(w, " {{")?;
 
         // Dump signatures for called device functions.
-        write!(w, "extern \"C\" {{")?;
+        // For single-return functions we directly expose the device function
+        // while for multi-return functions we generate a wrapper which handles
+        // allocation of the return struct and extracting values from it. This
+        // ensures that device function signatures match what they would be in
+        // AsyncRust
         for callee_id in self.callgraph.get_callees(self.func_id) {
             if self.devices[callee_id.idx()] == Device::AsyncRust {
                 continue;
             }
             let callee = &self.module.functions[callee_id.idx()];
-            write!(w, "fn {}(", callee.name)?;
-            let mut first_param = true;
-            if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) {
-                first_param = false;
-                write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?;
+            let is_single_return = callee.return_types.len() == 1;
+            if is_single_return {
+                write!(w, "extern \"C\" {{")?;
             }
-            for idx in 0..callee.num_dynamic_constants {
-                if first_param {
-                    first_param = false;
-                } else {
-                    write!(w, ", ")?;
+            self.write_device_signature_async(w, *callee_id, !is_single_return)?;
+            if is_single_return {
+                write!(w, ";}}")?;
+            } else {
+                // Generate the wrapper function for multi-return device functions
+                write!(w, " {{ ")?;
+                // Define the return struct
+                write!(
+                    w,
+                    "#[repr(C)] struct ReturnStruct {{ {} }} ",
+                    callee
+                        .return_types
+                        .iter()
+                        .enumerate()
+                        .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t)))
+                        .collect::<Vec<_>>()
+                        .join(", "),
+                )?;
+                // Declare the extern function's signature
+                write!(w, "extern \"C\" {{ ")?;
+                self.write_device_signature(w, *callee_id)?;
+                write!(w, "; }}")?;
+                // Create the return struct
+                write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?;
+                // Call the device function
+                write!(w, "{}(", callee.name)?;
+                if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()])
+                {
+                    write!(w, "backing, ")?;
                 }
-                write!(w, "dc{}: u64", idx)?;
-            }
-            for (idx, ty) in callee.param_types.iter().enumerate() {
-                if first_param {
-                    first_param = false;
-                } else {
-                    write!(w, ", ")?;
+                for idx in 0..callee.num_dynamic_constants {
+                    write!(w, "dc{}, ", idx)?;
+                }
+                for idx in 0..callee.param_types.len() {
+                    write!(w, "p{}, ", idx)?;
                 }
-                write!(w, "p{}: {}", idx, self.get_type(*ty))?;
+                write!(w, "ret_struct.as_mut_ptr());")?;
+                // Extract the result into a Rust product
+                write!(w, "let ret_struct = ret_struct.assume_init();")?;
+                write!(
+                    w,
+                    "({})",
+                    (0..callee.return_types.len())
+                        .map(|idx| format!("ret_struct.f{}", idx))
+                        .collect::<Vec<_>>()
+                        .join(", "),
+                )?;
+                write!(w, "}}")?;
             }
-            write!(w, ") -> {};", self.get_type(callee.return_type))?;
         }
-        write!(w, "}}")?;
 
         // Set up the root environment for the function. An environment is set
         // up for every created task in async closures, and there needs to be a
@@ -301,7 +336,7 @@ impl<'a> RTContext<'a> {
             // successor and are otherwise simple.
             Node::Start
             | Node::Region { preds: _ }
-            | Node::Projection {
+            | Node::ControlProjection {
                 control: _,
                 selection: _,
             } => {
@@ -320,7 +355,7 @@ impl<'a> RTContext<'a> {
                 let mut succs = self.control_subgraph.succs(id);
                 let succ1 = succs.next().unwrap();
                 let succ2 = succs.next().unwrap();
-                let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some();
+                let succ1_is_true = func.nodes[succ1.idx()].try_control_projection(1).is_some();
                 write!(
                     epilogue,
                     "control_token = if {} {{{}}} else {{{}}};}}",
@@ -329,11 +364,25 @@ impl<'a> RTContext<'a> {
                     if succ1_is_true { succ2 } else { succ1 }.idx(),
                 )?;
             }
-            Node::Return { control: _, data } => {
+            Node::Return {
+                control: _,
+                ref data,
+            } => {
                 let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                 write!(prologue, "{} => {{", id.idx())?;
                 let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
-                write!(epilogue, "return {};}}", self.get_value(data, id, false))?;
+                if data.len() == 1 {
+                    write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?;
+                } else {
+                    write!(
+                        epilogue,
+                        "return ({});}}",
+                        data.iter()
+                            .map(|v| self.get_value(*v, id, false))
+                            .collect::<Vec<_>>()
+                            .join(", "),
+                    )?;
+                }
             }
             // Fork nodes open a new environment for defining an async closure.
             Node::Fork {
@@ -590,8 +639,8 @@ impl<'a> RTContext<'a> {
                 ref args,
             } => {
                 assert_eq!(control, bb);
-                // The device backends ensure that device functions have the
-                // same interface as AsyncRust functions.
+                // The device backends and the wrappers we generated earlier ensure that device
+                // functions have the same interface as AsyncRust functions.
                 let block = &mut blocks.get_mut(&bb).unwrap();
                 let block = &mut block.data;
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
@@ -644,6 +693,33 @@ impl<'a> RTContext<'a> {
                 }
                 write!(block, "){};", postfix)?;
             }
+            Node::DataProjection { data, selection } => {
+                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                let Node::Call {
+                    function: callee_id,
+                    ..
+                } = func.nodes[data.idx()]
+                else {
+                    panic!()
+                };
+                if self.module.functions[callee_id.idx()].return_types.len() == 1 {
+                    assert!(selection == 0);
+                    write!(
+                        block,
+                        "{} = {};",
+                        self.get_value(id, bb, true),
+                        self.get_value(data, bb, false),
+                    )?;
+                } else {
+                    write!(
+                        block,
+                        "{} = {}.{};",
+                        self.get_value(id, bb, true),
+                        self.get_value(data, bb, false),
+                        selection,
+                    )?;
+                }
+            }
             Node::IntrinsicCall {
                 intrinsic,
                 ref args,
@@ -1041,7 +1117,7 @@ impl<'a> RTContext<'a> {
      */
     fn codegen_type_size(&self, ty: TypeID) -> String {
         match self.module.types[ty.idx()] {
-            Type::Control => panic!(),
+            Type::Control | Type::MultiReturn(_) => panic!(),
             Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
                 "1".to_string()
             }
@@ -1149,15 +1225,7 @@ impl<'a> RTContext<'a> {
                     if is_reduce_on_child { "reduce" } else { "node" },
                     idx,
                     self.get_type(self.typing[idx]),
-                    if self.module.types[self.typing[idx].idx()].is_bool() {
-                        "false"
-                    } else if self.module.types[self.typing[idx].idx()].is_integer() {
-                        "0"
-                    } else if self.module.types[self.typing[idx].idx()].is_float() {
-                        "0.0"
-                    } else {
-                        "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
-                    }
+                    self.get_default_value(self.typing[idx]),
                 )?;
             }
         }
@@ -1195,9 +1263,9 @@ impl<'a> RTContext<'a> {
         // references.
         let func = self.get_func();
         let param_devices = &self.node_colors.1;
-        let return_device = self.node_colors.2;
+        let return_devices = &self.node_colors.2;
         let mut param_muts = vec![false; func.param_types.len()];
-        let mut return_mut = true;
+        let mut return_muts = vec![true; func.return_types.len()];
         let objects = &self.collection_objects[&self.func_id];
         for idx in 0..func.param_types.len() {
             if let Some(object) = objects.param_to_object(idx)
@@ -1206,11 +1274,14 @@ impl<'a> RTContext<'a> {
                 param_muts[idx] = true;
             }
         }
-        for object in objects.returned_objects() {
-            if let Some(idx) = objects.origin(*object).try_parameter()
-                && !param_muts[idx]
-            {
-                return_mut = false;
+        let num_returns = func.return_types.len();
+        for idx in 0..num_returns {
+            for object in objects.returned_objects(idx) {
+                if let Some(param_idx) = objects.origin(*object).try_parameter()
+                    && !param_muts[param_idx]
+                {
+                    return_muts[idx] = false;
+                }
             }
         }
 
@@ -1240,27 +1311,42 @@ impl<'a> RTContext<'a> {
         }
         write!(w, "}}}}")?;
 
-        // Every reference that may be returned has the same lifetime. Every
-        // other reference gets its own unique lifetime.
-        let returned_origins: HashSet<_> = self.collection_objects[&self.func_id]
-            .returned_objects()
-            .into_iter()
-            .map(|obj| self.collection_objects[&self.func_id].origin(*obj))
+        // Each returned reference, input reference, and the runner will have
+        // its own lifetime. We use lifetime bounds to ensure that the runner
+        // and parameters are borrowed for the lifetimes needed by the outputs
+        let returned_origins: Vec<HashSet<_>> = (0..num_returns)
+            .map(|idx| {
+                objects
+                    .returned_objects(idx)
+                    .iter()
+                    .map(|obj| objects.origin(*obj))
+                    .collect()
+            })
             .collect();
 
-        write!(w, "async fn run<'runner, 'returned")?;
-        for idx in 0..func.param_types.len() {
-            write!(w, ", 'p{}", idx)?;
+        write!(w, "async fn run<'runner:")?;
+        for (ret_idx, origins) in returned_origins.iter().enumerate() {
+            if origins.iter().any(|origin| !origin.is_parameter()) {
+                write!(w, " 'r{} +", ret_idx)?;
+            }
         }
-        write!(
-            w,
-            ">(&'{} mut self",
-            if returned_origins.iter().any(|origin| !origin.is_parameter()) {
-                "returned"
-            } else {
-                "runner"
+        for idx in 0..num_returns {
+            write!(w, ", 'r{}", idx)?;
+        }
+        for idx in 0..func.param_types.len() {
+            write!(w, ", 'p{}:", idx)?;
+            for (ret_idx, origins) in returned_origins.iter().enumerate() {
+                if origins.iter().any(|origin| {
+                    origin
+                        .try_parameter()
+                        .map(|oidx| idx == oidx)
+                        .unwrap_or(false)
+                }) {
+                    write!(w, " 'r{} +", ret_idx)?;
+                }
             }
-        )?;
+        }
+        write!(w, ">(&'runner mut self")?;
         for idx in 0..func.num_dynamic_constants {
             write!(w, ", dc_p{}: u64", idx)?;
         }
@@ -1276,37 +1362,38 @@ impl<'a> RTContext<'a> {
                 let mutability = if param_muts[idx] { "Mut" } else { "" };
                 write!(
                     w,
-                    ", p{}: ::hercules_rt::Hercules{}Ref{}<'{}>",
-                    idx,
-                    device,
-                    mutability,
-                    if returned_origins.iter().any(|origin| origin
-                        .try_parameter()
-                        .map(|oidx| idx == oidx)
-                        .unwrap_or(false))
-                    {
-                        "returned".to_string()
-                    } else {
-                        format!("p{}", idx)
-                    }
+                    ", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>",
+                    idx, device, mutability, idx,
                 )?;
             }
         }
-        if self.module.types[func.return_type.idx()].is_primitive() {
-            write!(w, ") -> {} {{", self.get_type(func.return_type))?;
-        } else {
-            let device = match return_device {
-                Some(Device::LLVM) | None => "CPU",
-                Some(Device::CUDA) => "CUDA",
-                _ => panic!(),
-            };
-            let mutability = if return_mut { "Mut" } else { "" };
-            write!(
-                w,
-                ") -> ::hercules_rt::Hercules{}Ref{}<'returned> {{",
-                device, mutability
-            )?;
-        }
+        write!(
+            w,
+            ") -> {}{}{} {{",
+            if num_returns != 1 { "(" } else { "" },
+            func.return_types
+                .iter()
+                .enumerate()
+                .map(
+                    |(ret_idx, typ)| if self.module.types[typ.idx()].is_primitive() {
+                        self.get_type(*typ)
+                    } else {
+                        let device = match return_devices[ret_idx] {
+                            Some(Device::LLVM) | None => "CPU",
+                            Some(Device::CUDA) => "CUDA",
+                            _ => panic!(),
+                        };
+                        let mutability = if return_muts[ret_idx] { "Mut" } else { "" };
+                        format!(
+                            "::hercules_rt::Hercules{}Ref{}<'r{}>",
+                            device, mutability, ret_idx
+                        )
+                    }
+                )
+                .collect::<Vec<_>>()
+                .join(", "),
+            if num_returns != 1 { ")" } else { "" },
+        )?;
 
         // Start with possibly re-allocating the backing memory if it's not
         // large enough.
@@ -1362,22 +1449,48 @@ impl<'a> RTContext<'a> {
             write!(w, "p{}, ", idx)?;
         }
         write!(w, ").await;")?;
-        if self.module.types[func.return_type.idx()].is_primitive() {
-            write!(w, "        ret")?;
+        // Return the result, appropriately wrapping pointers
+        if num_returns == 1 {
+            if self.module.types[func.return_types[0].idx()].is_primitive() {
+                write!(w, "ret")?;
+            } else {
+                let device = match return_devices[0] {
+                    Some(Device::LLVM) | None => "CPU",
+                    Some(Device::CUDA) => "CUDA",
+                    _ => panic!(),
+                };
+                let mutability = if return_muts[0] { "Mut" } else { "" };
+                write!(
+                    w,
+                    "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)",
+                    device,
+                    mutability,
+                    self.codegen_type_size(func.return_types[0])
+                )?;
+            }
         } else {
-            let device = match return_device {
-                Some(Device::LLVM) | None => "CPU",
-                Some(Device::CUDA) => "CUDA",
-                _ => panic!(),
-            };
-            let mutability = if return_mut { "Mut" } else { "" };
-            write!(
-                w,
-                "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)",
-                device,
-                mutability,
-                self.codegen_type_size(func.return_type)
-            )?;
+            write!(w, "(")?;
+            for (idx, typ) in func.return_types.iter().enumerate() {
+                if self.module.types[typ.idx()].is_primitive() {
+                    write!(w, "ret.{},", idx)?;
+                } else {
+                    let device = match return_devices[idx] {
+                        Some(Device::LLVM) | None => "CPU",
+                        Some(Device::CUDA) => "CUDA",
+                        _ => panic!(),
+                    };
+                    let mutability = if return_muts[idx] { "Mut" } else { "" };
+                    write!(
+                        w,
+                        "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.{}.0, {} as usize),",
+                        device,
+                        mutability,
+                        idx,
+                        self.codegen_type_size(func.return_types[idx]),
+                    )?;
+                }
+            }
+            write!(w, ")")?;
         }
         write!(w, "}}}}")?;
 
@@ -1435,8 +1548,123 @@ impl<'a> RTContext<'a> {
         }
     }
 
-    fn get_type(&self, id: TypeID) -> &'static str {
-        convert_type(&self.module.types[id.idx()])
+    fn get_type(&self, id: TypeID) -> String {
+        convert_type(&self.module.types[id.idx()], &self.module.types)
+    }
+
+    fn get_default_value(&self, idx: TypeID) -> String {
+        let typ = &self.module.types[idx.idx()];
+        if typ.is_bool() {
+            "false".to_string()
+        } else if typ.is_integer() {
+            "0".to_string()
+        } else if typ.is_float() {
+            "0.0".to_string()
+        } else if let Some(ts) = typ.try_multi_return() {
+            format!(
+                "({})",
+                ts.iter()
+                    .map(|t| self.get_default_value(*t))
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )
+        } else {
+            "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())".to_string()
+        }
+    }
+
+    fn write_rust_return_type<W: Write>(&self, w: &mut W, tys: &[TypeID]) -> Result<(), Error> {
+        if tys.len() == 1 {
+            write!(w, "{}", self.get_type(tys[0]))
+        } else {
+            write!(
+                w,
+                "({})",
+                tys.iter()
+                    .map(|t| self.get_type(*t))
+                    .collect::<Vec<_>>()
+                    .join(", "),
+            )
+        }
+    }
+
+    // Writes the signature of a device function as if it were an async function, in particular
+    // this means that if the function is multi-return it will return a product in the produced
+    // Rust code
+    // Writes from the "fn" keyword up to the end of the return type
+    fn write_device_signature_async<W: Write>(
+        &self,
+        w: &mut W,
+        func_id: FunctionID,
+        is_unsafe: bool,
+    ) -> Result<(), Error> {
+        let func = &self.module.functions[func_id.idx()];
+        write!(
+            w,
+            "{}fn {}(",
+            if is_unsafe { "unsafe " } else { "" },
+            func.name
+        )?;
+        let mut first_param = true;
+        if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
+            first_param = false;
+            write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?;
+        }
+        for idx in 0..func.num_dynamic_constants {
+            if first_param {
+                first_param = false;
+            } else {
+                write!(w, ", ")?;
+            }
+            write!(w, "dc{}: u64", idx)?;
+        }
+        for (idx, ty) in func.param_types.iter().enumerate() {
+            if first_param {
+                first_param = false;
+            } else {
+                write!(w, ", ")?;
+            }
+            write!(w, "p{}: {}", idx, self.get_type(*ty))?;
+        }
+        write!(w, ") -> ")?;
+        self.write_rust_return_type(w, &func.return_types)
+    }
+
+    // Writes the true signature of a device function
+    // Compared to the _async version this converts multi-return into a return struct
+    fn write_device_signature<W: Write>(
+        &self,
+        w: &mut W,
+        func_id: FunctionID,
+    ) -> Result<(), Error> {
+        let func = &self.module.functions[func_id.idx()];
+        write!(w, "fn {}(", func.name)?;
+        let mut first_param = true;
+        if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
+            first_param = false;
+            write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?;
+        }
+        for idx in 0..func.num_dynamic_constants {
+            if first_param {
+                first_param = false;
+            } else {
+                write!(w, ", ")?;
+            }
+            write!(w, "dc{}: u64", idx)?;
+        }
+        for (idx, ty) in func.param_types.iter().enumerate() {
+            if first_param {
+                first_param = false;
+            } else {
+                write!(w, ", ")?;
+            }
+            write!(w, "p{}: {}", idx, self.get_type(*ty))?;
+        }
+        if func.return_types.len() == 1 {
+            write!(w, ") -> {}", self.get_type(func.return_types[0]))
+        } else {
+            write!(w, ", ret_ptr: *mut ReturnStruct)")
+        }
     }
 
     fn library_prim_ty(&self, id: TypeID) -> &'static str {
@@ -1459,21 +1687,30 @@ impl<'a> RTContext<'a> {
     }
 }
 
-fn convert_type(ty: &Type) -> &'static str {
+fn convert_type(ty: &Type, types: &[Type]) -> String {
     match ty {
-        Type::Boolean => "bool",
-        Type::Integer8 => "i8",
-        Type::Integer16 => "i16",
-        Type::Integer32 => "i32",
-        Type::Integer64 => "i64",
-        Type::UnsignedInteger8 => "u8",
-        Type::UnsignedInteger16 => "u16",
-        Type::UnsignedInteger32 => "u32",
-        Type::UnsignedInteger64 => "u64",
-        Type::Float32 => "f32",
-        Type::Float64 => "f64",
+        Type::Boolean => "bool".to_string(),
+        Type::Integer8 => "i8".to_string(),
+        Type::Integer16 => "i16".to_string(),
+        Type::Integer32 => "i32".to_string(),
+        Type::Integer64 => "i64".to_string(),
+        Type::UnsignedInteger8 => "u8".to_string(),
+        Type::UnsignedInteger16 => "u16".to_string(),
+        Type::UnsignedInteger32 => "u32".to_string(),
+        Type::UnsignedInteger64 => "u64".to_string(),
+        Type::Float32 => "f32".to_string(),
+        Type::Float64 => "f64".to_string(),
         Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
-            "::hercules_rt::__RawPtrSendSync"
+            "::hercules_rt::__RawPtrSendSync".to_string()
+        }
+        Type::MultiReturn(ts) => {
+            format!(
+                "({})",
+                ts.iter()
+                    .map(|t| convert_type(&types[t.idx()], types))
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )
         }
         _ => panic!(),
     }
diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs
index 40538cef34e34b886348ce38bbacbb4e596d00a7..3e966d534e5514c08110f65f996e8e3f2aed29ba 100644
--- a/hercules_ir/src/build.rs
+++ b/hercules_ir/src/build.rs
@@ -392,6 +392,7 @@ impl<'a> Builder<'a> {
     pub fn create_constant_zero(&mut self, typ: TypeID) -> ConstantID {
         match &self.module.types[typ.idx()] {
             Type::Control => panic!("Cannot create constant for control types"),
+            Type::MultiReturn(..) => panic!("Cannot create constant for multi-return types"),
             Type::Boolean => self.create_constant_bool(false),
             Type::Integer8 => self.create_constant_i8(0),
             Type::Integer16 => self.create_constant_i16(0),
@@ -503,7 +504,7 @@ impl<'a> Builder<'a> {
         &mut self,
         name: &str,
         param_types: Vec<TypeID>,
-        return_type: TypeID,
+        return_types: Vec<TypeID>,
         num_dynamic_constants: u32,
         entry: bool,
     ) -> BuilderResult<(FunctionID, NodeID)> {
@@ -515,7 +516,7 @@ impl<'a> Builder<'a> {
         self.module.functions.push(Function {
             name: name.to_owned(),
             param_types,
-            return_type,
+            return_types,
             num_dynamic_constants,
             entry,
             nodes: vec![Node::Start],
@@ -594,11 +595,15 @@ impl NodeBuilder {
         };
     }
 
-    pub fn build_projection(&mut self, control: NodeID, selection: usize) {
-        self.node = Node::Projection { control, selection };
+    pub fn build_control_projection(&mut self, control: NodeID, selection: usize) {
+        self.node = Node::ControlProjection { control, selection };
     }
 
-    pub fn build_return(&mut self, control: NodeID, data: NodeID) {
+    pub fn build_data_projection(&mut self, data: NodeID, selection: usize) {
+        self.node = Node::DataProjection { data, selection };
+    }
+
+    pub fn build_return(&mut self, control: NodeID, data: Box<[NodeID]>) {
         self.node = Node::Return { control, data };
     }
 
diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 6b631519d69cf2a548164eaad58cb2574c6b70c3..60f4fb1cadcdea835975b292c25b7fac4bf3e35d 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -40,7 +40,7 @@ use crate::*;
 pub enum CollectionObjectOrigin {
     Parameter(usize),
     Constant(NodeID),
-    Call(NodeID),
+    DataProjection(NodeID),
     Undef(NodeID),
 }
 
@@ -50,7 +50,7 @@ define_id_type!(CollectionObjectID);
 pub struct FunctionCollectionObjects {
     objects_per_node: Vec<Vec<CollectionObjectID>>,
     mutated: Vec<Vec<NodeID>>,
-    returned: Vec<CollectionObjectID>,
+    returned: Vec<Vec<CollectionObjectID>>,
     origins: Vec<CollectionObjectOrigin>,
 }
 
@@ -92,8 +92,14 @@ impl FunctionCollectionObjects {
             .map(CollectionObjectID::new)
     }
 
-    pub fn returned_objects(&self) -> &Vec<CollectionObjectID> {
-        &self.returned
+    pub fn returned_objects(&self, selection: usize) -> &Vec<CollectionObjectID> {
+        &self.returned[selection]
+    }
+
+    pub fn all_returned_objects(&self) -> impl Iterator<Item = CollectionObjectID> + '_ {
+        self.returned
+            .iter()
+            .flat_map(|colls| colls.iter().map(|c| *c))
     }
 
     pub fn is_mutated(&self, object: CollectionObjectID) -> bool {
@@ -165,8 +171,9 @@ pub fn collection_objects(
         let typing = &typing[func_id.idx()];
         let reverse_postorder = &reverse_postorders[func_id.idx()];
 
-        // Find collection objects originating at parameters, constants, calls,
-        // or undefs. Each node may *originate* one collection object.
+        // Find collection objects originating at parameters, constants,
+        // data projections (of calls), or undefs.
+        // Each of these nodes may *originate* one collection object.
         let param_origins = func
             .param_types
             .iter()
@@ -181,24 +188,31 @@ pub fn collection_objects(
                 Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => {
                     Some(CollectionObjectOrigin::Constant(NodeID::new(idx)))
                 }
-                Node::Call {
-                    control: _,
-                    function: callee,
-                    dynamic_constants: _,
-                    args: _,
-                } if {
+                Node::DataProjection { data, selection } => {
+                    let Node::Call {
+                        control: _,
+                        function: callee,
+                        dynamic_constants: _,
+                        args: _,
+                    } = func.nodes[data.idx()]
+                    else {
+                        panic!("Data-projection's data is not a call node");
+                    };
+
                     let fco = &collection_objects[&callee];
-                    fco.returned
+                    if fco.returned[*selection]
                         .iter()
-                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_none())
-                } =>
-                {
-                    // If the callee may return a new collection object, then
-                    // this call node originates a single collection object. The
-                    // node may output multiple collection objects, say if the
-                    // callee may return an object passed in as a parameter -
-                    // this is determined later.
-                    Some(CollectionObjectOrigin::Call(NodeID::new(idx)))
+                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_some())
+                    {
+                        // If the callee may return a new collection object, then
+                        // this data projection node originates a single collection object. The
+                        // node may output multiple collection objects, say if the
+                        // callee may return an object passed in as a parameter -
+                        // this is determined later.
+                        Some(CollectionObjectOrigin::DataProjection(NodeID::new(idx)))
+                    } else {
+                        None
+                    }
                 }
                 Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => {
                     Some(CollectionObjectOrigin::Undef(NodeID::new(idx)))
@@ -216,8 +230,8 @@ pub fn collection_objects(
         // - Reduce: reduces over an object, similar to phis.
         // - Parameter: may originate an object.
         // - Constant: may originate an object.
-        // - Call: may originate an object and may return an object passed in as
-        //   a parameter.
+        // - DataProjection: may originate an object and may return an object
+        //   passed in to its associated call as a parameter.
         // - LibraryCall: may return an object passed in as a parameter, but may
         //   not originate an object.
         // - Read: may extract a smaller object from the input - this is
@@ -228,7 +242,13 @@ pub fn collection_objects(
         //   mutation.
         // - Undef: may originate a dummy object.
         // - Ternary (select): selects between two objects, may output either.
-        let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
+        let lattice = dataflow_global(func, reverse_postorder, |global_input, id| {
+            let inputs = get_uses(&func.nodes[id.idx()])
+                .as_ref()
+                .iter()
+                .map(|id| &global_input[id.idx()])
+                .collect::<Vec<_>>();
+
             match func.nodes[id.idx()] {
                 Node::Phi {
                     control: _,
@@ -267,22 +287,28 @@ pub fn collection_objects(
                         objs: obj.into_iter().collect(),
                     }
                 }
-                Node::Call {
-                    control: _,
-                    function: callee,
-                    dynamic_constants: _,
-                    args: _,
-                } if !types[typing[id.idx()].idx()].is_primitive() => {
+                Node::DataProjection { data, selection }
+                    if !types[typing[id.idx()].idx()].is_primitive() =>
+                {
+                    let Node::Call {
+                        control: _,
+                        function: callee,
+                        dynamic_constants: _,
+                        ref args,
+                    } = func.nodes[data.idx()]
+                    else {
+                        panic!();
+                    };
+
                     let new_obj = origins
                         .iter()
-                        .position(|origin| *origin == CollectionObjectOrigin::Call(id))
+                        .position(|origin| *origin == CollectionObjectOrigin::DataProjection(id))
                         .map(CollectionObjectID::new);
                     let fco = &collection_objects[&callee];
-                    let param_objs = fco
-                        .returned
+                    let param_objs = fco.returned[selection]
                         .iter()
                         .filter_map(|returned| fco.origins[returned.idx()].try_parameter())
-                        .map(|param_index| inputs[param_index + 1]);
+                        .map(|param_index| &global_input[args[param_index].idx()]);
 
                     let mut objs: BTreeSet<_> = new_obj.into_iter().collect();
                     for param_objs in param_objs {
@@ -324,14 +350,20 @@ pub fn collection_objects(
             .map(|l| l.objs.into_iter().collect())
             .collect();
 
-        // Look at the collection objects that each return may take as input.
-        let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new();
+        // Look at the collection objects that each return value may take as input.
+        let mut returned: Vec<BTreeSet<CollectionObjectID>> =
+            vec![BTreeSet::new(); func.return_types.len()];
         for node in func.nodes.iter() {
             if let Node::Return { control: _, data } = node {
-                returned.extend(&objects_per_node[data.idx()]);
+                for (idx, node) in data.iter().enumerate() {
+                    returned[idx].extend(&objects_per_node[node.idx()]);
+                }
             }
         }
-        let returned = returned.into_iter().collect();
+        let returned = returned
+            .into_iter()
+            .map(|set| set.into_iter().collect())
+            .collect();
 
         // Determine which objects are potentially mutated.
         let mut mutated = vec![vec![]; origins.len()];
@@ -500,16 +532,17 @@ pub fn no_reset_constant_collections(
                     collect: _,
                     data,
                     indices: _,
+                } => Either::Left(zip(once(&full_indices), once(data))),
+                Node::Return {
+                    control: _,
+                    ref data,
                 }
-                | Node::Return { control: _, data } => {
-                    Either::Left(zip(once(&full_indices), once(data)))
-                }
-                Node::Call {
+                | Node::Call {
                     control: _,
                     function: _,
                     dynamic_constants: _,
-                    ref args,
-                } => Either::Right(zip(repeat(&full_indices), args.into_iter().map(|id| *id))),
+                    args: ref data,
+                } => Either::Right(zip(repeat(&full_indices), data.into_iter().map(|id| *id))),
                 _ => return None,
             };
 
diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index ff0e08edc8c15f76ea38243eebea7cba1940cd19..e9ba4576d06d8ba4c140f0b27328b57c9b933896 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -156,7 +156,11 @@ pub fn get_uses(node: &Node) -> NodeUses {
             init,
             reduct,
         } => NodeUses::Three([*control, *init, *reduct]),
-        Node::Return { control, data } => NodeUses::Two([*control, *data]),
+        Node::Return { control, data } => {
+            let mut uses: Vec<NodeID> = vec![*control];
+            uses.extend(data);
+            NodeUses::Variable(uses.into_boxed_slice())
+        }
         Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]),
         Node::Constant { id: _ } => NodeUses::One([NodeID::new(0)]),
         Node::DynamicConstant { id: _ } => NodeUses::One([NodeID::new(0)]),
@@ -222,10 +226,11 @@ pub fn get_uses(node: &Node) -> NodeUses {
                 NodeUses::Two([*collect, *data])
             }
         }
-        Node::Projection {
+        Node::ControlProjection {
             control,
             selection: _,
         } => NodeUses::One([*control]),
+        Node::DataProjection { data, selection: _ } => NodeUses::One([*data]),
         Node::Undef { ty: _ } => NodeUses::One([NodeID::new(0)]),
     }
 }
@@ -260,7 +265,9 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
             init,
             reduct,
         } => NodeUsesMut::Three([control, init, reduct]),
-        Node::Return { control, data } => NodeUsesMut::Two([control, data]),
+        Node::Return { control, data } => {
+            NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect())
+        }
         Node::Parameter { index: _ } => NodeUsesMut::Zero,
         Node::Constant { id: _ } => NodeUsesMut::Zero,
         Node::DynamicConstant { id: _ } => NodeUsesMut::Zero,
@@ -326,10 +333,11 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
                 NodeUsesMut::Two([collect, data])
             }
         }
-        Node::Projection {
+        Node::ControlProjection {
             control,
             selection: _,
         } => NodeUsesMut::One([control]),
+        Node::DataProjection { data, selection: _ } => NodeUsesMut::One([data]),
         Node::Undef { ty: _ } => NodeUsesMut::Zero,
     }
 }
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 921a813d7c9d680defe2e9dd8fa750ddf960d0cb..aff1f9c52eb8735db828620a15e27a6395b20f9f 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -349,6 +349,11 @@ fn write_node<W: Write>(
                 }
             }
         }
+        Node::ControlProjection {
+            control: _,
+            selection,
+        }
+        | Node::DataProjection { data: _, selection } => write!(&mut suffix, "{}", selection)?,
         _ => {}
     };
 
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index f91efe584c7422e2d7e1e542fed7141fcf684f53..5dfe2915f5f3e30f56b6665dc27d23cd40cca3d4 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -39,7 +39,7 @@ pub struct Module {
 pub struct Function {
     pub name: String,
     pub param_types: Vec<TypeID>,
-    pub return_type: TypeID,
+    pub return_types: Vec<TypeID>,
     pub num_dynamic_constants: u32,
     pub entry: bool,
 
@@ -77,6 +77,7 @@ pub enum Type {
     Product(Box<[TypeID]>),
     Summation(Box<[TypeID]>),
     Array(TypeID, Box<[DynamicConstantID]>),
+    MultiReturn(Box<[TypeID]>),
 }
 
 /*
@@ -186,7 +187,7 @@ pub enum Node {
     },
     Return {
         control: NodeID,
-        data: NodeID,
+        data: Box<[NodeID]>,
     },
     Parameter {
         index: usize,
@@ -237,10 +238,14 @@ pub enum Node {
         data: NodeID,
         indices: Box<[Index]>,
     },
-    Projection {
+    ControlProjection {
         control: NodeID,
         selection: usize,
     },
+    DataProjection {
+        data: NodeID,
+        selection: usize,
+    },
     Undef {
         ty: TypeID,
     },
@@ -434,6 +439,17 @@ impl Module {
                 }
                 write!(w, ")")
             }
+            Type::MultiReturn(fields) => {
+                write!(w, "MultiReturn(")?;
+                for idx in 0..fields.len() {
+                    let field_ty_id = fields[idx];
+                    self.write_type(field_ty_id, w)?;
+                    if idx + 1 < fields.len() {
+                        write!(w, ", ")?;
+                    }
+                }
+                write!(w, ")")
+            }
         }?;
 
         Ok(())
@@ -847,6 +863,14 @@ impl Type {
         }
     }
 
+    pub fn is_multireturn(&self) -> bool {
+        if let Type::MultiReturn(_) = self {
+            true
+        } else {
+            false
+        }
+    }
+
     pub fn is_bool(&self) -> bool {
         self == &Type::Boolean
     }
@@ -955,6 +979,14 @@ impl Type {
         }
     }
 
+    pub fn try_multi_return(&self) -> Option<&[TypeID]> {
+        if let Type::MultiReturn(ts) = self {
+            Some(ts)
+        } else {
+            None
+        }
+    }
+
     pub fn num_bits(&self) -> u8 {
         match self {
             Type::Boolean => 1,
@@ -1294,12 +1326,19 @@ impl Node {
     );
     define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ });
     define_pattern_predicate!(
-        is_projection,
-        Node::Projection {
+        is_control_projection,
+        Node::ControlProjection {
             control: _,
             selection: _
         }
     );
+    define_pattern_predicate!(
+        is_data_projection,
+        Node::DataProjection {
+            data: _,
+            selection: _
+        }
+    );
 
     define_pattern_predicate!(is_undef, Node::Undef { ty: _ });
 
@@ -1319,14 +1358,22 @@ impl Node {
         }
     }
 
-    pub fn try_proj(&self) -> Option<(NodeID, usize)> {
-        if let Node::Projection { control, selection } = self {
+    pub fn try_control_proj(&self) -> Option<(NodeID, usize)> {
+        if let Node::ControlProjection { control, selection } = self {
             Some((*control, *selection))
         } else {
             None
         }
     }
 
+    pub fn try_data_proj(&self) -> Option<(NodeID, usize)> {
+        if let Node::DataProjection { data, selection } = self {
+            Some((*data, *selection))
+        } else {
+            None
+        }
+    }
+
     pub fn try_phi(&self) -> Option<(NodeID, &[NodeID])> {
         if let Node::Phi { control, data } = self {
             Some((*control, data))
@@ -1335,9 +1382,9 @@ impl Node {
         }
     }
 
-    pub fn try_return(&self) -> Option<(NodeID, NodeID)> {
+    pub fn try_return(&self) -> Option<(NodeID, &[NodeID])> {
         if let Node::Return { control, data } = self {
-            Some((*control, *data))
+            Some((*control, data))
         } else {
             None
         }
@@ -1511,8 +1558,8 @@ impl Node {
         }
     }
 
-    pub fn try_projection(&self, branch: usize) -> Option<NodeID> {
-        if let Node::Projection { control, selection } = self
+    pub fn try_control_projection(&self, branch: usize) -> Option<NodeID> {
+        if let Node::ControlProjection { control, selection } = self
             && branch == *selection
         {
             Some(*control)
@@ -1592,10 +1639,14 @@ impl Node {
                 data: _,
                 indices: _,
             } => "Write",
-            Node::Projection {
+            Node::ControlProjection {
                 control: _,
                 selection: _,
-            } => "Projection",
+            } => "ControlProjection",
+            Node::DataProjection {
+                data: _,
+                selection: _,
+            } => "DataProjection",
             Node::Undef { ty: _ } => "Undef",
         }
     }
@@ -1671,10 +1722,14 @@ impl Node {
                 data: _,
                 indices: _,
             } => "write",
-            Node::Projection {
+            Node::ControlProjection {
                 control: _,
                 selection: _,
-            } => "projection",
+            } => "control_projection",
+            Node::DataProjection {
+                data: _,
+                selection: _,
+            } => "data_projection",
             Node::Undef { ty: _ } => "undef",
         }
     }
@@ -1687,7 +1742,7 @@ impl Node {
             || self.is_fork()
             || self.is_join()
             || self.is_return()
-            || self.is_projection()
+            || self.is_control_projection()
     }
 }
 
diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs
index a38df8e98f9e4527ebe23ffbc39a4184ce714f90..9462df4db528c410834ab4a5052ef823aa6a92e0 100644
--- a/hercules_ir/src/parse.rs
+++ b/hercules_ir/src/parse.rs
@@ -140,7 +140,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a
         Function {
             name: String::from(""),
             param_types: vec![],
-            return_type: TypeID::new(0),
+            return_types: vec![],
             num_dynamic_constants: 0,
             entry: true,
             nodes: vec![],
@@ -247,7 +247,15 @@ fn parse_function<'a>(
     let ir_text = nom::character::complete::char(')')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0;
-    let (ir_text, return_type) = parse_type_id(ir_text, context)?;
+    let (ir_text, return_types) = nom::multi::separated_list1(
+        (
+            nom::character::complete::multispace0,
+            nom::character::complete::char(','),
+            nom::character::complete::multispace0,
+        ),
+        |text| parse_type_id(text, context),
+    )
+    .parse(ir_text)?;
     let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?;
 
     // `nodes`, as returned by parsing, is in parse order, which may differ from
@@ -288,7 +296,7 @@ fn parse_function<'a>(
         Function {
             name: String::from(function_name),
             param_types: params.into_iter().map(|x| x.5).collect(),
-            return_type,
+            return_types,
             num_dynamic_constants,
             entry: true,
             nodes: fixed_nodes,
@@ -336,7 +344,8 @@ fn parse_node<'a>(
         "return" => parse_return(ir_text, context)?,
         "constant" => parse_constant_node(ir_text, context)?,
         "dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?,
-        "projection" => parse_projection(ir_text, context)?,
+        "control_projection" => parse_control_projection(ir_text, context)?,
+        "data_projection" => parse_data_projection(ir_text, context)?,
         // Unary and binary ops are spelled out in the textual format, but we
         // parse them into Unary or Binary node kinds.
         "not" => parse_unary(ir_text, context, UnaryOperator::Not)?,
@@ -491,9 +500,28 @@ fn parse_return<'a>(
     ir_text: &'a str,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Node> {
-    let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?;
+    let ir_text = nom::character::complete::multispace0(ir_text)?.0;
+    let ir_text = nom::character::complete::char('(')(ir_text)?.0;
+    let (ir_text, control) = parse_identifier(ir_text)?;
+    let ir_text = nom::character::complete::multispace0(ir_text)?.0;
+    let ir_text = nom::character::complete::char(',')(ir_text)?.0;
+    let ir_text = nom::character::complete::multispace0(ir_text)?.0;
+    let (ir_text, data) = nom::multi::separated_list1(
+        (
+            nom::character::complete::multispace0,
+            nom::character::complete::char(','),
+            nom::character::complete::multispace0,
+        ),
+        parse_identifier,
+    )
+    .parse(ir_text)?;
+    let ir_text = nom::character::complete::multispace0(ir_text)?.0;
+    let ir_text = nom::character::complete::char(')')(ir_text)?.0;
     let control = context.borrow_mut().get_node_id(control);
-    let data = context.borrow_mut().get_node_id(data);
+    let data = data
+        .into_iter()
+        .map(|d| context.borrow_mut().get_node_id(d))
+        .collect();
     Ok((ir_text, Node::Return { control, data }))
 }
 
@@ -726,7 +754,7 @@ fn parse_index<'a>(
     Ok((ir_text, idx))
 }
 
-fn parse_projection<'a>(
+fn parse_control_projection<'a>(
     ir_text: &'a str,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Node> {
@@ -735,13 +763,29 @@ fn parse_projection<'a>(
     let control = context.borrow_mut().get_node_id(control);
     Ok((
         ir_text,
-        Node::Projection {
+        Node::ControlProjection {
             control,
             selection: index,
         },
     ))
 }
 
+fn parse_data_projection<'a>(
+    ir_text: &'a str,
+    context: &RefCell<Context<'a>>,
+) -> nom::IResult<&'a str, Node> {
+    let parse_usize = |x| parse_prim::<usize>(x, "1234567890");
+    let (ir_text, (data, index)) = parse_tuple2(parse_identifier, parse_usize)(ir_text)?;
+    let data = context.borrow_mut().get_node_id(data);
+    Ok((
+        ir_text,
+        Node::DataProjection {
+            data,
+            selection: index,
+        },
+    ))
+}
+
 fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char('(')(ir_text)?.0;
@@ -1002,7 +1046,7 @@ fn parse_constant<'a>(
 ) -> nom::IResult<&'a str, Constant> {
     let (ir_text, constant) = match ty {
         // There are not control constants.
-        Type::Control => Err(nom::Err::Error(nom::error::Error {
+        Type::Control | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error {
             input: ir_text,
             code: nom::error::ErrorKind::IsNot,
         }))?,
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index 1ff890db4ff4ed3678be0338106f4add2b829aba..2a3f9fb1aa86dd092d048bf79b51aa118d2489b8 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -423,8 +423,8 @@ fn typeflow(
             control: _,
             data: _,
         } => {
-            if inputs.len() != 2 {
-                return Error(String::from("Return node must have exactly two inputs."));
+            if inputs.len() < 1 {
+                return Error(String::from("Return node must have at least one input."));
             }
 
             // Check type of control input first, since this may produce an
@@ -441,12 +441,18 @@ fn typeflow(
                 return inputs[0].clone();
             }
 
-            if let Concrete(id) = inputs[1] {
-                if *id != function.return_type {
-                    return Error(String::from("Return node's data input type must be the same as the function's return type."));
+            for (idx, (input, return_type)) in inputs[1..]
+                .iter()
+                .zip(function.return_types.iter())
+                .enumerate()
+            {
+                if let Concrete(id) = input {
+                    if *id != *return_type {
+                        return Error(format!("Return node's data input at index {} does not match function's return type.", idx));
+                    }
+                } else if input.is_error() {
+                    return (*input).clone();
                 }
-            } else if inputs[1].is_error() {
-                return inputs[1].clone();
             }
 
             Concrete(get_type_id(
@@ -759,7 +765,7 @@ fn typeflow(
                 }
             }
 
-            Concrete(subst.type_subst(callee.return_type))
+            Concrete(subst.build_return_type(&callee.return_types))
         }
         Node::IntrinsicCall { intrinsic, args: _ } => {
             let num_params = match intrinsic {
@@ -1061,13 +1067,32 @@ fn typeflow(
                 TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg),
             }
         }
-        Node::Projection {
+        Node::ControlProjection {
             control: _,
             selection: _,
         } => {
             // Type is the type of the _if node
             inputs[0].clone()
         }
+        Node::DataProjection { data: _, selection } => {
+            if let Concrete(type_id) = inputs[0] {
+                match &types[type_id.idx()] {
+                    Type::MultiReturn(types) => {
+                        if *selection >= types.len() {
+                            return Error(String::from("Data projection's selection must be in range of the multi-return being indexed"));
+                        }
+                        return Concrete(types[*selection]);
+                    }
+                    _ => {
+                        return Error(String::from(
+                            "Data projection node must read from multi-return value.",
+                        ));
+                    }
+                }
+            }
+
+            inputs[0].clone()
+        }
         Node::Undef { ty } => TypeSemilattice::Concrete(*ty),
     }
 }
@@ -1138,6 +1163,11 @@ impl<'a> DCSubst<'a> {
         }
     }
 
+    fn build_return_type(&mut self, tys: &[TypeID]) -> TypeID {
+        let tys = tys.iter().map(|t| self.type_subst(*t)).collect();
+        self.intern_type(Type::MultiReturn(tys))
+    }
+
     fn type_subst(&mut self, typ: TypeID) -> TypeID {
         match &self.types[typ.idx()] {
             Type::Control
@@ -1172,6 +1202,7 @@ impl<'a> DCSubst<'a> {
                 let new_elem = self.type_subst(elem);
                 self.intern_type(Type::Array(new_elem, new_dims))
             }
+            Type::MultiReturn(..) => panic!("A multi-return type should never be substituted"),
         }
     }
 
diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs
index f188932e3a362760cc8855b43c9fa9fea21cbe42..b50ab0d211f9563fe422add6511d0ff03c1b53a3 100644
--- a/hercules_ir/src/verify.rs
+++ b/hercules_ir/src/verify.rs
@@ -251,11 +251,11 @@ fn verify_structure(
                     Err(format!("If node must have 2 users, not {}.", users.len()))?;
                 }
                 if let (
-                    Node::Projection {
+                    Node::ControlProjection {
                         control: _,
                         selection: result1,
                     },
-                    Node::Projection {
+                    Node::ControlProjection {
                         control: _,
                         selection: result2,
                     },
@@ -290,7 +290,8 @@ fn verify_structure(
                     Err("ThreadID node's control input must be a fork node.")?;
                 }
             }
-            // Call nodes must depend on a region node.
+            // Call nodes must depend on a region node and its only users must
+            // be DataProjections.
             Node::Call {
                 control,
                 function: _,
@@ -300,6 +301,11 @@ fn verify_structure(
                 if !function.nodes[control.idx()].is_region() {
                     Err("Call node's control input must be a region node.")?;
                 }
+                for user in users {
+                    if !function.nodes[user.idx()].is_data_projection() {
+                        Err("Call node users must be DataProjection nodes.")?;
+                    }
+                }
             }
             // Reduce nodes must depend on a join node.
             Node::Reduce {
@@ -339,7 +345,7 @@ fn verify_structure(
                     }
                     let mut users_covered = bitvec![u8, Lsb0; 0; users.len()];
                     for user in users {
-                        if let Node::Projection {
+                        if let Node::ControlProjection {
                             control: _,
                             ref selection,
                         } = function.nodes[user.idx()]
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index b626148c936e384b6bc0a9aaf951c35a9c4b4736..87d23c11aeb4759a8b1c288fb27ba33845f1dc18 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -697,6 +697,11 @@ fn ccp_flow_function(
             }),
             constant: ConstantLattice::bottom(),
         },
+        // Data projections are uninterpretable.
+        Node::DataProjection { data, selection: _ } => CCPLattice {
+            reachability: inputs[data.idx()].reachability.clone(),
+            constant: ConstantLattice::bottom(),
+        },
         Node::IntrinsicCall { intrinsic, args } => {
             let mut new_reachability = ReachabilityLattice::bottom();
             let mut new_constant = ConstantLattice::top();
@@ -961,8 +966,9 @@ fn ccp_flow_function(
                 constant: ConstantLattice::bottom(),
             }
         }
-        // Projection handles reachability when following an if or match.
-        Node::Projection { control, selection } => match &editor.func().nodes[control.idx()] {
+        // Control projection handles reachability when following an if or match.
+        Node::ControlProjection { control, selection } => match &editor.func().nodes[control.idx()]
+        {
             Node::If { control: _, cond } => {
                 let cond_constant = &inputs[cond.idx()].constant;
                 let if_reachability = &inputs[control.idx()].reachability;
diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 57fe204245fc0adf75c8a046cfd9db25ebed119e..9cf5af72e2945b35337a7e1da871d9a4c5e06fb5 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -69,7 +69,7 @@ pub struct FunctionEdit<'a: 'b, 'b> {
     // Compute a def-use map entries iteratively.
     updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
     updated_param_types: Option<Vec<TypeID>>,
-    updated_return_type: Option<TypeID>,
+    updated_return_types: Option<Vec<TypeID>>,
     // Keep track of which deleted and added node IDs directly correspond.
     sub_edits: Vec<(NodeID, NodeID)>,
 }
@@ -208,7 +208,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
             added_labels: Vec::new().into(),
             updated_def_use: BTreeMap::new(),
             updated_param_types: None,
-            updated_return_type: None,
+            updated_return_types: None,
             sub_edits: vec![],
         };
 
@@ -228,7 +228,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
                 added_labels,
                 updated_def_use,
                 updated_param_types,
-                updated_return_type,
+                updated_return_types,
                 sub_edits,
             } = populated_edit;
 
@@ -358,8 +358,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
             }
 
             // Step 9: update return type if necessary.
-            if let Some(return_type) = updated_return_type {
-                editor.function.return_type = return_type;
+            if let Some(return_types) = updated_return_types {
+                editor.function.return_types = return_types;
             }
 
             true
@@ -768,6 +768,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
             }
             Type::Summation(tys) => Constant::Summation(id, 0, self.add_zero_constant(tys[0])),
             Type::Array(_, _) => Constant::Array(id),
+            Type::MultiReturn(_) => {
+                panic!("PANIC: Can't create zero constant for multi-return types.")
+            }
         };
         self.add_constant(constant_to_construct)
     }
@@ -791,6 +794,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
             Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
                 panic!("PANIC: Can't create one constant of a collection type.")
             }
+            Type::MultiReturn(_) => {
+                panic!("PANIC: Can't create one constant for multi-return types.")
+            }
         };
         self.add_constant(constant_to_construct)
     }
@@ -814,6 +820,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
             Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
                 panic!("PANIC: Can't create largest constant of a collection type.")
             }
+            Type::MultiReturn(_) => {
+                panic!("PANIC: Can't create largest constant for multi-return types.")
+            }
         };
         self.add_constant(constant_to_construct)
     }
@@ -837,6 +846,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
             Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
                 panic!("PANIC: Can't create smallest constant of a collection type.")
             }
+            Type::MultiReturn(_) => {
+                panic!("PANIC: Can't create smallest constant for multi-return types.")
+            }
         };
         self.add_constant(constant_to_construct)
     }
@@ -881,8 +893,8 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
         self.updated_param_types = Some(tys);
     }
 
-    pub fn set_return_type(&mut self, ty: TypeID) {
-        self.updated_return_type = Some(ty);
+    pub fn set_return_types(&mut self, tys: Vec<TypeID>) {
+        self.updated_return_types = Some(tys);
     }
 }
 
diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs
index df40e60f89f0490cacb35d6e9754f3b134ed1483..c480f266f683cec84f0517db9842dd98566f22b9 100644
--- a/hercules_opt/src/fork_guard_elim.rs
+++ b/hercules_opt/src/fork_guard_elim.rs
@@ -95,7 +95,7 @@ fn guarded_fork(
     });
 
     // Whose predecessor is a read from an if
-    let Node::Projection {
+    let Node::ControlProjection {
         control: if_node,
         ref selection,
     } = function.nodes[control.idx()]
@@ -226,7 +226,7 @@ fn guarded_fork(
         return None;
     };
     // Other predecessor needs to be the other projection from the guard's if
-    let Node::Projection {
+    let Node::ControlProjection {
         control: if_node2,
         ref selection,
     } = function.nodes[other_pred.idx()]
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index f240589300520d680c1e126ba93728977f1d17e8..d311970517260466d8712c430c20ab38fb4aec2b 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -127,7 +127,7 @@ pub fn gcm(
     let mut alignments = vec![];
     Ref::map(editor.get_types(), |types| {
         for idx in 0..types.len() {
-            if types[idx].is_control() {
+            if types[idx].is_control() || types[idx].is_multireturn() {
                 alignments.push(0);
             } else {
                 alignments.push(get_type_alignment(types, TypeID::new(idx)));
@@ -255,6 +255,12 @@ fn basic_blocks(
                 dynamic_constants: _,
                 args: _,
             } => bbs[idx] = Some(control),
+            Node::DataProjection { data, selection: _ } => {
+                let Node::Call { control, .. } = function.nodes[data.idx()] else {
+                    panic!();
+                };
+                bbs[idx] = Some(control);
+            }
             Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
             _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
             _ => {}
@@ -505,10 +511,11 @@ fn basic_blocks(
                     || function.nodes[id.idx()].is_undef())
                     && !types[typing[id.idx()].idx()].is_primitive();
                 let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
-                    && objects[&func_id]
-                        .objects(id)
-                        .into_iter()
-                        .any(|obj| objects[&func_id].returned_objects().contains(obj));
+                    && objects[&func_id].objects(id).into_iter().any(|obj| {
+                        objects[&func_id]
+                            .all_returned_objects()
+                            .any(|ret| ret == *obj)
+                    });
                 let old_nest = loops
                     .header_of(location)
                     .map(|header| loops.nesting(header).unwrap());
@@ -646,9 +653,9 @@ fn terminating_reads<'a>(
             ref args,
         } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
             let objects = &objects[&callee];
-            let returns = objects.returned_objects();
+            let mut returns = objects.all_returned_objects();
             let param_obj = objects.param_to_object(idx)?;
-            if !objects.is_mutated(param_obj) && !returns.contains(&param_obj) {
+            if !objects.is_mutated(param_obj) && !returns.any(|ret| ret == param_obj) {
                 Some(*arg)
             } else {
                 None
@@ -692,9 +699,9 @@ fn forwarding_reads<'a>(
             ref args,
         } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
             let objects = &objects[&callee];
-            let returns = objects.returned_objects();
+            let mut returns = objects.all_returned_objects();
             let param_obj = objects.param_to_object(idx)?;
-            if !objects.is_mutated(param_obj) && returns.contains(&param_obj) {
+            if !objects.is_mutated(param_obj) && returns.any(|ret| ret == param_obj) {
                 Some(*arg)
             } else {
                 None
@@ -1218,7 +1225,7 @@ fn color_nodes(
     let mut func_colors = (
         BTreeMap::new(),
         vec![None; editor.func().param_types.len()],
-        None,
+        vec![None; editor.func().return_types.len()],
     );
 
     // Assigning nodes to devices is tricky due to function calls. Technically,
@@ -1320,17 +1327,29 @@ fn color_nodes(
                         equations.push((UTerm::Node(*arg), UTerm::Device(device)));
                     }
                 }
+            }
+            Node::DataProjection { data, selection } => {
+                let Node::Call {
+                    control: _,
+                    function: callee,
+                    dynamic_constants: _,
+                    ref args,
+                } = &nodes[data.idx()]
+                else {
+                    panic!()
+                };
 
-                // If the callee has a definite device for the returned value,
-                // add an equation for the call node itself.
-                if let Some(device) = node_colors[&callee].2 {
+                // If the callee has a definite device for this returned value,
+                // add an equation for the data projection node itself.
+                if let Some(device) = node_colors[&callee].2[selection] {
                     equations.push((UTerm::Node(id), UTerm::Device(device)));
                 }
 
-                // For any object that may be returned by the callee that
-                // originates as a parameter in the callee, the device of the
-                // corresponding argument and call node itself must be equal.
-                for ret in objects[&callee].returned_objects() {
+                // For any object that may be returned in this position by the
+                // callee that originates as a parameter in the callee, the
+                // device of the corresponding argument and the data projection
+                // must be equal.
+                for ret in objects[&callee].returned_objects(selection) {
                     if let Some(idx) = objects[&callee].origin(*ret).try_parameter() {
                         equations.push((UTerm::Node(args[idx]), UTerm::Node(id)));
                     }
@@ -1365,11 +1384,17 @@ fn color_nodes(
                 {
                     assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first.");
                     func_colors.1[index] = Some(*device);
-                } else if let Node::Return { control: _, data } = nodes[id.idx()]
-                    && let Some(device) = func_colors.0.get(&data)
+                } else if let Node::Return {
+                    control: _,
+                    ref data,
+                } = nodes[id.idx()]
                 {
-                    assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix.");
-                    func_colors.2 = Some(*device);
+                    for (idx, val) in data.iter().enumerate() {
+                        if let Some(device) = func_colors.0.get(val) {
+                            assert!(func_colors.2[idx].is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix.");
+                            func_colors.2[idx] = Some(*device);
+                        }
+                    }
                 }
             }
             Some(func_colors)
@@ -1420,6 +1445,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
     let ty = edit.get_type(ty_id).clone();
     let size = match ty {
         Type::Control => panic!(),
+        Type::MultiReturn(_) => panic!(),
         Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
             edit.add_dynamic_constant(DynamicConstant::Constant(1))
         }
diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs
index 6e308274c65e0fbf14bb938e297a188f8ebaf0f3..99187dd2bbfb2a9bbc2bf5937cb40f5246ba5b29 100644
--- a/hercules_opt/src/inline.rs
+++ b/hercules_opt/src/inline.rs
@@ -1,5 +1,4 @@
 use std::collections::HashMap;
-use std::iter::zip;
 
 use hercules_ir::callgraph::*;
 use hercules_ir::def_use::*;
@@ -125,13 +124,25 @@ fn inline_func(
         assert_eq!(call_pred.as_ref().len(), 1);
         let call_pred = call_pred.as_ref()[0];
         let called_func = called[&function].func();
+        let call_users = editor.get_users(id);
+        let call_projs = call_users
+            .map(|node_id| {
+                (
+                    node_id,
+                    editor.func().nodes[node_id.idx()]
+                        .try_data_proj()
+                        .expect("PANIC: Call user is not a data projection")
+                        .1,
+                )
+            })
+            .collect::<Vec<_>>();
         // We can't inline calls to functions with multiple returns.
         let Some(called_return) = single_return_nodes[function.idx()] else {
             continue;
         };
         let called_return_uses = get_uses(&called_func.nodes[called_return.idx()]);
         let called_return_pred = called_return_uses.as_ref()[0];
-        let called_return_data = called_return_uses.as_ref()[1];
+        let called_return_data = &called_return_uses.as_ref()[1..];
 
         // Perform the actual edit.
         editor.edit(|mut edit| {
@@ -200,6 +211,13 @@ fn inline_func(
                 },
             )?;
 
+            // Replace and delete the call's (data projection) users
+            for (proj_id, proj_idx) in call_projs {
+                let proj_val = called_return_data[proj_idx];
+                edit = edit.replace_all_uses(proj_id, old_id_to_new_id(proj_val))?;
+                edit = edit.delete_node(proj_id)?;
+            }
+
             // Stitch uses of parameter nodes in the inlined function to the IDs
             // of arguments provided to the call node.
             for (node_idx, node) in called_func.nodes.iter().enumerate() {
@@ -209,12 +227,7 @@ fn inline_func(
                 }
             }
 
-            // Finally, delete the call node.
-            if let Node::Parameter { index } = called_func.nodes[called_return_data.idx()] {
-                edit = edit.replace_all_uses(id, args[index])?;
-            } else {
-                edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?;
-            }
+            // Finally delete the call node
             edit = edit.delete_node(control)?;
             edit = edit.delete_node(id)?;
 
diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs
index 944ef8fd02e54b6164c7eb185c5e5e9aa27b5a28..ad4ce19ede72c3fd373e67f1a51f46de00c47241 100644
--- a/hercules_opt/src/interprocedural_sroa.rs
+++ b/hercules_opt/src/interprocedural_sroa.rs
@@ -5,466 +5,209 @@ use hercules_ir::ir::*;
 
 use crate::*;
 
-/**
- * Given an editor for each function in a module, return V s.t.
- * V[i] = true iff every call node to the function with index i
- * is editable. If there are no calls to this function, V[i] = true.
- */
-fn get_editable_callsites(editors: &mut Vec<FunctionEditor>) -> Vec<bool> {
-    let mut callsites_editable = vec![true; editors.len()];
-    for editor in editors {
-        for (idx, (_, function, _, _)) in editor
-            .func()
-            .nodes
-            .iter()
-            .enumerate()
-            .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c)))
-        {
-            if !editor.is_mutable(NodeID::new(idx)) {
-                callsites_editable[function.idx()] = false;
-            }
-        }
-    }
-    callsites_editable
-}
-
-/**
- * Given a type tree, return a Vec containing all leaves which are not units.
- */
-fn get_nonempty_leaves(edit: &FunctionEdit, type_id: &TypeID) -> Vec<TypeID> {
-    let ty = edit.get_type(*type_id).clone();
-    match ty {
-        Type::Product(type_ids) => {
-            let mut leaves = vec![];
-            for type_id in type_ids {
-                leaves.extend(get_nonempty_leaves(&edit, &type_id))
-            }
-            leaves
-        }
-        _ => vec![*type_id],
-    }
-}
-
-/**
- * Given a `source` NodeID which produces a product containing
- * all nonempty leaves of the type tree for `type_id` in order, build
- * a node producing the `type_id`.
+/*
+ * Top-level function for running interprocedural analysis.
  *
- * `offset` represents the index at which to begin reading
- * elements of the `source` product.
+ * IP SROA expects that all nodes in all functions provided to it can be edited,
+ * since it needs to be able to modify both the functions whose types are being
+ * changed and call sites of those functions. What functions to run IP SROA on
+ * is therefore specified by a separate argument.
  *
- * Returns a 3-tuple of
- * 1. Node producing the `type`
- * 2. "Next" offset, i.e. `offset` + number of reads performed to build (1)
- * 3. List of node IDs which read `source` (tracked so that these will not
- *    be replaced by replace_all_uses_where)
+ * This optimization also takes an allow_sroa_arrays arguments (like non-IP
+ * SROA) which controls whether it will break up products of arrays.
  */
-fn build_uncompressed_product(
-    edit: &mut FunctionEdit,
-    source: &NodeID,
-    type_id: &TypeID,
-    offset: usize,
-) -> (NodeID, usize, Vec<NodeID>) {
-    let ty = edit.get_type(*type_id).clone();
-    match ty {
-        Type::Product(child_type_ids) => {
-            // Step 1. Create an empty constant for the type. We'll write
-            // child values into this constant.
-            let empty_constant_id = edit.add_zero_constant(*type_id);
-            let empty_constant_node = edit.add_node(Node::Constant {
-                id: empty_constant_id,
-            });
-            // Step 2. Build a node that generates each inner type.
-            // Since `source` contains nonempty leaves *in order*,
-            // we must process inner types in order; as part of this,
-            // inner type i+1 must read from where inner type i left off,
-            // hence we track the `current_offset` at which we are reading.
-            // Similarly, to combine results of all recursive calls,
-            // we keep the invariant that, at iteration i+1, currently_writing_to
-            // is an instance of `type_id` for which the first i elements
-            // have been populated based on inorder nonempty leaves
-            // (and, at iteration 0, it is empty).
-            let mut current_offset = offset;
-            let mut currently_writing_to = empty_constant_node;
-            let mut readers = vec![];
-            for (idx, child_type_id) in child_type_ids.iter().enumerate() {
-                let (child_data, next_offset, child_readers) =
-                    build_uncompressed_product(edit, source, child_type_id, current_offset);
-                current_offset = next_offset;
-                currently_writing_to = edit.add_node(Node::Write {
-                    collect: currently_writing_to,
-                    data: child_data,
-                    indices: Box::new([Index::Field(idx)]),
-                });
-                readers.extend(child_readers)
-            }
-            (currently_writing_to, current_offset, readers)
-        }
-        _ => {
-            // If the type is not a product, then we've reached a nonempty
-            // leaf, which we must read from source. Since this is a single
-            // read, the new offset increases by only 1.
-            let reader = edit.add_node(Node::Read {
-                collect: *source,
-                indices: Box::new([Index::Field(offset)]),
-            });
-            (reader, offset + 1, vec![reader])
+pub fn interprocedural_sroa(
+    editors: &mut Vec<FunctionEditor>,
+    types: &Vec<Vec<TypeID>>,
+    func_selection: &Vec<bool>,
+    allow_sroa_arrays: bool,
+) {
+    let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| {
+        editor.get_type(typ).is_product()
+            && (allow_sroa_arrays || !type_contains_array(editor, typ))
+    };
+
+    let callsites = get_callsites(editors);
+
+    for ((func_id, apply), callsites) in (0..func_selection.len())
+        .map(FunctionID::new)
+        .zip(func_selection.iter())
+        .zip(callsites.into_iter())
+    {
+        if !apply {
+            continue;
         }
-    }
-}
-
-/**
- * Given a node with a product value, read the product's values
- * *in order* into the nonempty leaves of a product type represented
- * by type_id. Returns the ID of the resulting node, as well as the IDs
- * of all nodes which read from `node_id`.
- */
-fn uncompress_product(
-    edit: &mut FunctionEdit,
-    node_id: &NodeID,
-    type_id: &TypeID,
-) -> (NodeID, Vec<NodeID>) {
-    let (uncompressed_value, _, readers) = build_uncompressed_product(edit, node_id, type_id, 0);
-    (uncompressed_value, readers)
-}
 
-/**
-* Let `read_from` be a node with a value of type `type_id`.
-* Let `source` be a product value.
-* Returns a node representing the value obtained by writing
-* nonempty leaves of `read_from` *in order* into `source`,
-* starting at `offset`.
-*
-* `source` should be a product type with at least enough indices
-* to support this operation. Typically, `build_compressed_product`
-* should be called initially with a `source` created by adding a
-* zero constant for the flattened `type_id`.
-*
-* Returns:
-* 1. The ID of the node to which all nonempty leaves have been written
-* 2. The first offset after `offset` which was not written to.
-*/
-fn build_compressed_product(
-    mut edit: &mut FunctionEdit,
-    source: &NodeID,
-    type_id: &TypeID,
-    offset: usize,
-    read_from: &NodeID,
-) -> (NodeID, usize) {
-    let ty = edit.get_type(*type_id).clone();
-    match ty {
-        Type::Product(child_type_ids) => {
-            // Iterate through child types in order. For each type, construct
-            // a node that reads the corresponding value from `read_from`,
-            // and pass it as the node to read from in the recursive call.
-            let mut next_offset = offset;
-            let mut next_destination = *source;
-            for (idx, child_type_id) in child_type_ids.iter().enumerate() {
-                let child_value = edit.add_node(Node::Read {
-                    collect: *read_from,
-                    indices: Box::new([Index::Field(idx)]),
-                });
-                (next_destination, next_offset) = build_compressed_product(
-                    &mut edit,
-                    &next_destination,
-                    &child_type_id,
-                    next_offset,
-                    &child_value,
-                );
+        let editor: &mut FunctionEditor = &mut editors[func_id.idx()];
+        let return_types = &editor.func().return_types.to_vec();
+
+        // We determine the new return types of the function and track a map
+        // that tells us how the old return values are constructed from the
+        // new ones
+        let mut new_return_types = vec![];
+        let mut old_return_type_map = vec![];
+        let mut changed = false;
+
+        for ret_typ in return_types.iter() {
+            if !can_sroa_type(editor, *ret_typ) {
+                old_return_type_map.push(IndexTree::Leaf(new_return_types.len()));
+                new_return_types.push(*ret_typ);
+            } else {
+                let (types, index) = sroa_type(editor, *ret_typ, new_return_types.len());
+                old_return_type_map.push(index);
+                new_return_types.extend(types);
+                changed = true;
             }
-            (next_destination, next_offset)
-        }
-        _ => {
-            let writer = edit.add_node(Node::Write {
-                collect: *source,
-                data: *read_from,
-                indices: Box::new([Index::Field(offset)]),
-            });
-            (writer, offset + 1)
         }
-    }
-}
 
-/**
- * Given a node which has a value of the given type (which must be a product)
- * generate a new product node created by inserting nonempty leaves of the
- * source node *in order*. Returns the ID of this node, as well as the ID of
- * its type.
- */
-fn compress_product(
-    edit: &mut FunctionEdit,
-    node_id: &NodeID,
-    type_id: &TypeID,
-) -> (NodeID, TypeID) {
-    let nonempty_leaves = get_nonempty_leaves(&edit, &type_id);
-    let compressed_type = Type::Product(nonempty_leaves.into_boxed_slice());
-    let compressed_type_id = edit.add_type(compressed_type);
-
-    let empty_compressed_constant_id = edit.add_zero_constant(compressed_type_id);
-    let empty_compressed_node_id = edit.add_node(Node::Constant {
-        id: empty_compressed_constant_id,
-    });
-
-    let (compressed_value, _) =
-        build_compressed_product(edit, &empty_compressed_node_id, type_id, 0, node_id);
-
-    (compressed_value, compressed_type_id)
-}
-
-fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) {
-    // Track whether we successfully applied edits to return statements,
-    // so that callsites are only modified when returns were. This is
-    // initialized to false, so that `is_compressed` is false when
-    // the corresponding entry in `callsites_editable` is false.
-    let mut is_compressed = vec![false; editors.len()];
-    let old_return_type_ids: Vec<_> = editors
-        .iter()
-        .map(|editor| editor.func().return_type)
-        .collect();
-
-    // Step 1. Track mapping of dynamic constant indexes to ids, so that
-    // we can substitute when generating empty constants later. The reason
-    // this works is that the following property is satisfied:
-    //   Let f and g be two functions such that f has d_f dynamic constants
-    //   and g has d_g dynamic constants. Wlog assume d_f < d_g. Then, the
-    //   first d_f dynamic constants of g are the dynamic constants of f.
-    // For any call node, the ith dynamic constant in the node is provided
-    // for the ith dynamic constant of the function called. So, when we need
-    // to take a type and replace d function dynamic constants with their
-    // values from a call, it suffices to look at the first d entries of
-    // dc_param_idx_to_dc_id to get the id of the dynamic constants in the function,
-    // and then replace dc_param_idx_to_dc_id[i] with call.dynamic_constants[i],
-    // for all i.
-    let max_num_dc_params = editors
-        .iter()
-        .map(|editor| editor.func().num_dynamic_constants)
-        .max()
-        .unwrap();
-    let mut dc_args = vec![];
-    editors[0].edit(|mut edit| {
-        dc_args = (0..max_num_dc_params as usize)
-            .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i)))
-            .collect();
-        Ok(edit)
-    });
-
-    // Step 2. Modify the return type of all editors corresponding to a function
-    // for which we can edit every callsite, and the return type is a product.
-    for (idx, editor) in editors.iter_mut().enumerate() {
-        if !all_callsites_editable[idx] {
+        // If the return type is not changed by IP SROA, skip to the next function
+        if !changed {
             continue;
         }
 
-        let old_return_id = NodeID::new(
-            (0..editor.func().nodes.len())
-                .filter(|idx| editor.func().nodes[*idx].is_return())
-                .next()
-                .unwrap(),
-        );
-        let old_return_type_id = old_return_type_ids[idx];
-
-        is_compressed[idx] = editor.get_type(editor.func().return_type).is_product()
-            && editor.edit(|mut edit| {
-                let return_node = edit.get_node(old_return_id);
-                let (return_control, return_data) = return_node.try_return().unwrap();
-
-                let (compressed_data_id, compressed_type_id) =
-                    compress_product(&mut edit, &return_data, &old_return_type_id);
+        // Now, modify each return in the current function and the return type
+        let return_nodes = editor
+            .func()
+            .nodes
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, node)| {
+                if node.try_return().is_some() {
+                    Some(NodeID::new(idx))
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+        let success = editor.edit(|mut edit| {
+            for node in return_nodes {
+                let Node::Return { control, data } = edit.get_node(node) else {
+                    panic!()
+                };
+                let control = *control;
+                let data = data.to_vec();
+
+                let mut new_data = vec![];
+                for (idx, (data_id, update_info)) in
+                    data.into_iter().zip(old_return_type_map.iter()).enumerate()
+                {
+                    if let IndexTree::Leaf(new_idx) = update_info {
+                        // Unchanged return value
+                        assert!(new_data.len() == *new_idx);
+                        new_data.push(data_id);
+                    } else {
+                        // SROA'd return value
+                        let reads = generate_reads_edit(&mut edit, return_types[idx], data_id);
+                        reads.zip(update_info).for_each(|_, (read_id, ret_idx)| {
+                            assert!(new_data.len() == **ret_idx);
+                            new_data.push(*read_id);
+                        });
+                    }
+                }
 
-                edit.set_return_type(compressed_type_id);
-                let new_return_id = edit.add_node(Node::Return {
-                    control: return_control,
-                    data: compressed_data_id,
+                let new_ret = edit.add_node(Node::Return {
+                    control,
+                    data: new_data.into(),
                 });
-                edit.sub_edit(old_return_id, new_return_id);
-                let edit = edit.replace_all_uses(old_return_id, new_return_id)?;
-                edit.delete_node(old_return_id)
-            });
-    }
-
-    // Step 3: For every editor, update all mutable callsites corresponding to
-    // calls to functions which have been compressed. Since we only compress returns
-    // for functions for which every callsite is mutable, this should never fail,
-    // so we panic if it does.
-    for (_, editor) in editors.iter_mut().enumerate() {
-        let call_node_ids: Vec<_> = (0..editor.func().nodes.len())
-            .map(NodeID::new)
-            .filter(|id| editor.func().nodes[id.idx()].is_call())
-            .filter(|id| editor.is_mutable(*id))
-            .collect();
-
-        for call_node_id in call_node_ids {
-            let (_, function_id, ref dynamic_constants, _) =
-                editor.func().nodes[call_node_id.idx()].try_call().unwrap();
-            if !is_compressed[function_id.idx()] {
-                continue;
+                edit.sub_edit(node, new_ret);
+                edit = edit.delete_node(node)?;
             }
 
-            // Before creating the uncompressed product, we must update
-            // the type of the uncompressed product to reflect the dynamic
-            // constants provided when calling the function. Since we can
-            // only replace one constant at a time, we need to map
-            // constants to dummy values, and then map these to the
-            // replacement values (this prevents the case of replacements
-            // (0->1), (1->2) causing conflicts when we have [0, 1], we should
-            // get [1, 2], not [2, 2], which a naive loop would generate).
-
-            // A similar loop exists in the inline pass but at the node level.
-            // If this becomes a common pattern, it would be worth creating
-            // a better abstraction around bulk replacement.
-
-            let new_dcs = (*dynamic_constants).to_vec();
-            let old_dcs = dc_args[..new_dcs.len()].to_vec();
-            assert_eq!(old_dcs.len(), new_dcs.len());
-            let substs = old_dcs
-                .into_iter()
-                .zip(new_dcs.into_iter())
-                .collect::<HashMap<_, _>>();
-
-            let edit_successful = editor.edit(|mut edit| {
-                let substituted = substitute_dynamic_constants_in_type(
-                    &substs,
-                    old_return_type_ids[function_id.idx()],
-                    &mut edit,
-                );
-
-                let (expanded_product, readers) =
-                    uncompress_product(&mut edit, &call_node_id, &substituted);
-                edit.replace_all_uses_where(call_node_id, expanded_product, |id| {
-                    !readers.contains(id)
-                })
-            });
-
-            if !edit_successful {
-                panic!("Tried and failed to edit mutable callsite!");
+            edit.set_return_types(new_return_types);
+
+            Ok(edit)
+        });
+        assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");
+
+        // Finally, update calls of this function
+        // In particular, we actually don't have to update the call node at all but have to update
+        // its DataProjection users
+        for (caller, callsite) in callsites {
+            let editor = &mut editors[caller.idx()];
+            assert!(editor.func_id() == caller);
+            let projs = editor.get_users(callsite).collect::<Vec<_>>();
+            for proj_id in projs {
+                let Node::DataProjection { data: _, selection } = editor.node(proj_id) else {
+                    panic!("Call has a non data-projection user");
+                };
+                let new_return_info = &old_return_type_map[*selection];
+                let typ = types[caller.idx()][proj_id.idx()];
+                replace_returned_value(editor, proj_id, typ, new_return_info, callsite);
             }
         }
     }
 }
 
-fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) {
-    // Track whether we removed a singleton product from the return of each
-    // editor's function. Defaults to false so that if the function was not
-    // edited (i.e. because not all callsites are editable), then no callsites
-    // will be edited.
-    let mut singleton_removed = vec![false; editors.len()];
-    let old_return_type_ids: Vec<_> = editors
-        .iter()
-        .map(|editor| editor.func().return_type)
-        .collect();
-
-    // Step 1. For all editors which correspond to a function for whic hall
-    // callsites are editable, modify their return type by extracting the
-    // value from the singleton and returning it directly.
-    for (idx, editor) in editors.iter_mut().enumerate() {
-        if !all_callsites_editable[idx] {
-            continue;
-        }
-
-        let return_type = editor.get_type(old_return_type_ids[idx]).clone();
-        singleton_removed[idx] = match return_type {
-            Type::Product(tys) if tys.len() == 1 && all_callsites_editable[idx] => {
-                let old_return_id = NodeID::new(
-                    (0..editor.func().nodes.len())
-                        .filter(|idx| editor.func().nodes[*idx].is_return())
-                        .next()
-                        .unwrap(),
-                );
-
-                editor.edit(|mut edit| {
-                    let (old_control, old_data) =
-                        edit.get_node(old_return_id).try_return().unwrap();
-
-                    let extracted_singleton_id = edit.add_node(Node::Read {
-                        collect: old_data,
-                        indices: Box::new([Index::Field(0)]),
-                    });
-                    let new_return_id = edit.add_node(Node::Return {
-                        control: old_control,
-                        data: extracted_singleton_id,
-                    });
-                    edit.sub_edit(old_return_id, new_return_id);
-                    edit.set_return_type(tys[0]);
-
-                    edit.delete_node(old_return_id)
-                })
+fn sroa_type(
+    editor: &FunctionEditor,
+    typ: TypeID,
+    type_index: usize,
+) -> (Vec<TypeID>, IndexTree<usize>) {
+    match &*editor.get_type(typ) {
+        Type::Product(ts) => {
+            let mut res_types = vec![];
+            let mut index = type_index;
+            let mut children = vec![];
+            for t in ts {
+                let (types, child) = sroa_type(editor, *t, index);
+                index += types.len();
+                res_types.extend(types);
+                children.push(child);
             }
-            _ => false,
+            (res_types, IndexTree::Node(children))
         }
+        _ => (vec![typ], IndexTree::Leaf(type_index)),
     }
+}
 
-    // Step 2. For each editor, find all callsites and reconstruct
-    // the singleton product at each if the return of the corresponding
-    // function was modified. This should always succeed since we only
-    // edited functions for which all callsites were mutable, so panic
-    // if an edit does not succeed.
-    for editor in editors.iter_mut() {
-        let call_node_ids: Vec<_> = (0..editor.func().nodes.len())
-            .map(NodeID::new)
-            .filter(|id| editor.func().nodes[id.idx()].is_call())
-            .filter(|id| editor.is_mutable(*id))
-            .collect();
-
-        for call_node_id in call_node_ids {
-            let (_, function, dc_args, _) =
-                editor.func().nodes[call_node_id.idx()].try_call().unwrap();
-
-            let dc_args = dc_args.to_vec();
-
-            if singleton_removed[function.idx()] {
-                let edit_successful = editor.edit(|mut edit| {
-                    let dc_params = (0..dc_args.len())
-                        .map(|param_idx| {
-                            edit.add_dynamic_constant(DynamicConstant::Parameter(param_idx))
-                        })
-                        .collect::<Vec<_>>();
-                    let substs = dc_params
-                        .into_iter()
-                        .zip(dc_args.into_iter())
-                        .collect::<HashMap<_, _>>();
-
-                    let substituted = substitute_dynamic_constants_in_type(
-                        &substs,
-                        old_return_type_ids[function.idx()],
-                        &mut edit,
-                    );
-                    let empty_constant_id = edit.add_zero_constant(substituted);
-                    let empty_node_id = edit.add_node(Node::Constant {
-                        id: empty_constant_id,
-                    });
-
-                    let restored_singleton_id = edit.add_node(Node::Write {
-                        collect: empty_node_id,
-                        data: call_node_id,
-                        indices: Box::new([Index::Field(0)]),
-                    });
-                    edit.replace_all_uses_where(call_node_id, restored_singleton_id, |id| {
-                        *id != restored_singleton_id
-                    })
-                });
+// Returns a list for each function of the call sites of that function
+fn get_callsites(editors: &Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)>> {
+    let mut callsites = vec![vec![]; editors.len()];
 
-                if !edit_successful {
-                    panic!("Tried and failed to edit mutable callsite!");
-                }
-            }
+    for editor in editors {
+        let caller = editor.func_id();
+        for (callsite, (_, callee, _, _)) in editor
+            .func()
+            .nodes
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c)))
+        {
+            assert!(editor.is_mutable(NodeID::new(callsite)), "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");
+            callsites[callee.idx()].push((caller, NodeID::new(callsite)));
         }
     }
+
+    callsites
 }
 
-pub fn interprocedural_sroa(editors: &mut Vec<FunctionEditor>) {
-    // SROA is implemented in two phases. First, we flatten (or "compress")
-    // all product return types, so that they are only depth 1 products,
-    // and do not contain any empty products.
-    // Next, if any return type is now a singleton product, we
-    // remove the singleton and just retun the type directly.
-    // We only apply these changes to functions for which
-    // all their callsites are editable.
-    let all_callsites_editable = get_editable_callsites(editors);
-    compress_return_products(editors, &all_callsites_editable);
-    remove_return_singletons(editors, &all_callsites_editable);
+// Replaces a projection node (from before the function signature change) based on the of_new_call
+// description (which tells how to construct the value from the new returned values).
+fn replace_returned_value(
+    editor: &mut FunctionEditor,
+    proj_id: NodeID,
+    proj_typ: TypeID,
+    of_new_call: &IndexTree<usize>,
+    call_node: NodeID,
+) {
+    let constant = generate_constant(editor, proj_typ);
+
+    let success = editor.edit(|mut edit| {
+        let mut new_val = edit.add_node(Node::Constant { id: constant });
+        of_new_call.for_each(|idx, selection| {
+            let new_proj = edit.add_node(Node::DataProjection {
+                data: call_node,
+                selection: *selection,
+            });
+            new_val = edit.add_node(Node::Write {
+                collect: new_val,
+                data: new_proj,
+                indices: idx.clone().into(),
+            });
+        });
 
-    // Run DCE to prevent issues with schedule repair.
-    for editor in editors.iter_mut() {
-        dce(editor);
-    }
+        edit = edit.replace_all_uses(proj_id, new_val)?;
+        edit.delete_node(proj_id)
+    });
+    assert!(success);
 }
diff --git a/hercules_opt/src/loop_bound_canon.rs b/hercules_opt/src/loop_bound_canon.rs
index c127c617b1b8e3d1b5cc58af87bb1db26e1b931a..edda6b63cb033cae86b722109c8f6b57f530639b 100644
--- a/hercules_opt/src/loop_bound_canon.rs
+++ b/hercules_opt/src/loop_bound_canon.rs
@@ -111,7 +111,7 @@ pub fn canonicalize_single_loop_bounds(
 
     // FIXME: This is quite fragile.
     let guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| {
-        let Node::Projection {
+        let Node::ControlProjection {
             control,
             selection: _,
         } = editor.node(loop_pred)
diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs
index 088e57750df0265883472ab111d5d0a6e5f21c6e..27c268626416d4bb916c5e6823f350505cf36571 100644
--- a/hercules_opt/src/outline.rs
+++ b/hercules_opt/src/outline.rs
@@ -191,12 +191,11 @@ pub fn outline(
     editor.edit(|mut edit| {
         // Step 2: assemble the outlined function.
         let u32_ty = edit.add_type(Type::UnsignedInteger32);
-        let return_types: Box<[_]> = return_idx_to_inside_id
+        let return_types: Vec<_> = return_idx_to_inside_id
             .iter()
             .map(|id| typing[id.idx()])
             .chain(callee_succ_return_idx.map(|_| u32_ty))
             .collect();
-        let single_return = return_types.len() == 1;
 
         let mut outlined = Function {
             name: format!(
@@ -209,11 +208,7 @@ pub fn outline(
                 .map(|id| typing[id.idx()])
                 .chain(callee_pred_param_idx.map(|_| u32_ty))
                 .collect(),
-            return_type: if single_return {
-                return_types[0]
-            } else {
-                edit.add_type(Type::Product(return_types))
-            },
+            return_types,
             num_dynamic_constants: edit.get_num_dynamic_constant_params(),
             entry: false,
             nodes: vec![],
@@ -374,7 +369,6 @@ pub fn outline(
         outlined.nodes.extend(select_top_phi_inputs);
 
         // Add the return nodes.
-        let cons_id = edit.add_zero_constant(outlined.return_type);
         for ((exit, _), dom_return_values) in
             zip(exit_points.iter(), exit_point_dom_return_values.iter())
         {
@@ -409,29 +403,10 @@ pub fn outline(
                 data_ids.push(cons_node_id);
             }
 
-            // Build the return value
-            let construct_id = if single_return {
-                assert!(data_ids.len() == 1);
-                data_ids.pop().unwrap()
-            } else {
-                let mut construct_id = NodeID::new(outlined.nodes.len());
-                outlined.nodes.push(Node::Constant { id: cons_id });
-                for (idx, data) in data_ids.into_iter().enumerate() {
-                    let write = Node::Write {
-                        collect: construct_id,
-                        data: data,
-                        indices: Box::new([Index::Field(idx)]),
-                    };
-                    construct_id = NodeID::new(outlined.nodes.len());
-                    outlined.nodes.push(write);
-                }
-                construct_id
-            };
-
             // Return the return product.
             outlined.nodes.push(Node::Return {
                 control: convert_id(*exit),
-                data: construct_id,
+                data: data_ids.into(),
             });
         }
 
@@ -526,29 +501,25 @@ pub fn outline(
             (new_region_id, call_id)
         };
 
-        // Create the read nodes from the call node to get the outputs of the
-        // outlined function (if there are multiple returned values)
-        let output_reads: Vec<_> = if single_return {
-            vec![call_id]
-        } else {
-            (0..return_idx_to_inside_id.len())
-                .map(|idx| {
-                    let read = Node::Read {
-                        collect: call_id,
-                        indices: Box::new([Index::Field(idx)]),
-                    };
-                    edit.add_node(read)
-                })
-                .collect()
-        };
-        let indicator_read = callee_succ_return_idx.map(|idx| {
-            let read = Node::Read {
-                collect: call_id,
-                indices: Box::new([Index::Field(idx)]),
+        // Create the data projection nodes from the call node to get the outputs of the outlined
+        // function
+        let output_projs: Vec<_> = (0..return_idx_to_inside_id.len())
+            .map(|idx| {
+                let proj = Node::DataProjection {
+                    data: call_id,
+                    selection: idx,
+                };
+                edit.add_node(proj)
+            })
+            .collect();
+        let indicator_proj = callee_succ_return_idx.map(|idx| {
+            let proj = Node::DataProjection {
+                data: call_id,
+                selection: idx,
             };
-            edit.add_node(read)
+            edit.add_node(proj)
         });
-        for (old_id, new_id) in zip(return_idx_to_inside_id.iter(), output_reads.iter()) {
+        for (old_id, new_id) in zip(return_idx_to_inside_id.iter(), output_projs.iter()) {
             edit = edit.replace_all_uses(*old_id, *new_id)?;
         }
 
@@ -565,18 +536,18 @@ pub fn outline(
             });
             let cmp_id = edit.add_node(Node::Binary {
                 op: BinaryOperator::EQ,
-                left: indicator_read.unwrap(),
+                left: indicator_proj.unwrap(),
                 right: indicator_cons_node_id,
             });
             let if_id = edit.add_node(Node::If {
                 control: if_tree_acc,
                 cond: cmp_id,
             });
-            let false_id = edit.add_node(Node::Projection {
+            let false_id = edit.add_node(Node::ControlProjection {
                 control: if_id,
                 selection: 0,
             });
-            let true_id = edit.add_node(Node::Projection {
+            let true_id = edit.add_node(Node::ControlProjection {
                 control: if_id,
                 selection: 1,
             });
diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs
index 644c69d0df34d327c2c2e34bf8e0a915ddd68233..ed7c3a855b016608aa194cc9f2cd89f05d836bde 100644
--- a/hercules_opt/src/pred.rs
+++ b/hercules_opt/src/pred.rs
@@ -26,7 +26,7 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                     // Look for two projections with the same branch.
                     let preds = preds.into_iter().filter_map(|id| {
                         nodes[id.idx()]
-                            .try_proj()
+                            .try_control_proj()
                             .map(|(branch, selection)| (*id, branch, selection))
                     });
                     // Index projections by if branch.
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 8865f863934233cb3675c57319c5e2746e339aaf..e658ff8818246b31b4dbf3c9dd7ad0372c59f3c6 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -27,9 +27,10 @@ use crate::*;
  *   are broken up into ternary nodes for the individual fields
  *
  * - Call: the call node can use a product value as an argument to another
- *   function, and can produce a product value as a result. Argument values
- *   will be constructed at the call site and the return value will be broken
- *   into individual fields
+ *   function, argument values will be constructed at the call site
+ *
+ * - DataProjection: data projection nodes can produce a product value that was
+ *   returned by a function, we will break the value into individual fields
  *
  * - Read: the read node reads primitive fields from product values - these get
  *   replaced by a direct use of the field value
@@ -71,8 +72,9 @@ pub fn sroa(
 
     // First: determine all nodes which interact with products (as described above)
     let mut product_nodes: Vec<NodeID> = vec![];
-    // We track call and return nodes separately since they (may) require constructing new products
-    // for the call's arguments or the return's value
+    // We track call, data projection, and return nodes separately since they (may) require
+    // constructing new products for the call's arguments, data projection's value, or a
+    // returned value
     let mut call_return_nodes: Vec<NodeID> = vec![];
 
     for node in reverse_postorder {
@@ -303,37 +305,57 @@ pub fn sroa(
                 }
             }
 
-            // We add all calls to the call/return list and check their arguments later
-            Node::Call { .. } => call_return_nodes.push(*node),
-            Node::Return { control: _, data } if can_sroa_type(editor, types[&data]) => {
-                call_return_nodes.push(*node)
+            // We add all calls and returns to the call/return list and check their
+            // arguments/return values later
+            Node::Call { .. } | Node::Return { .. } => call_return_nodes.push(*node),
+            // We add DataProjetion nodes that produce SROAable values
+            Node::DataProjection { .. } if can_sroa_type(editor, types[&node]) => {
+                call_return_nodes.push(*node);
             }
 
             _ => (),
         }
     }
 
-    // Next, we handle calls and returns. For returns, we will insert nodes that read each field of
-    // the returned product and then write them into a new product. These writes are not put into
-    // the list of product nodes since they must remain but the reads are so that they will be
-    // replaced later on.
-    // For calls, we do a similar process for each (product) argument. Additionally, if the call
-    // returns a product, we create reads for each field in that product and store it into our
-    // field map
+    // Next, we handle calls and returns. For returns, for each returned value that is a product,
+    // we will insert nodes that read each field of it and then write them into a new product.
+    // The writes we create are not put into the list of product nodes since they must remain but
+    // the reads are put in the list so that they will be replaced later on.
+    // For calls, we do a similar process for each (product) argument.
+    // For data projection that produce product values, we create reads for each field of that
+    // product and store it into our field map
     for node in call_return_nodes {
         match &editor.func().nodes[node.idx()] {
             Node::Return { control, data } => {
-                assert!(can_sroa_type(editor, types[&data]));
                 let control = *control;
-                let new_data = reconstruct_product(editor, types[&data], *data, &mut product_nodes);
-                editor.edit(|mut edit| {
-                    let new_return = edit.add_node(Node::Return {
-                        control,
-                        data: new_data,
+                let data = data.to_vec();
+
+                let (new_data, changed) =
+                    data.into_iter()
+                        .fold((vec![], false), |(mut vals, changed), val_id| {
+                            if !can_sroa_type(editor, types[&val_id]) {
+                                vals.push(val_id);
+                                (vals, changed)
+                            } else {
+                                vals.push(reconstruct_product(
+                                    editor,
+                                    types[&val_id],
+                                    val_id,
+                                    &mut product_nodes,
+                                ));
+                                (vals, true)
+                            }
+                        });
+                if changed {
+                    editor.edit(|mut edit| {
+                        let new_return = edit.add_node(Node::Return {
+                            control,
+                            data: new_data.into(),
+                        });
+                        edit.sub_edit(node, new_return);
+                        edit.delete_node(node)
                     });
-                    edit.sub_edit(node, new_return);
-                    edit.delete_node(node)
-                });
+                }
             }
             Node::Call {
                 control,
@@ -344,55 +366,44 @@ pub fn sroa(
                 let control = *control;
                 let function = *function;
                 let dynamic_constants = dynamic_constants.clone();
-                let args = args.clone();
+                let args = args.to_vec();
+
+                let (new_args, changed) =
+                    args.into_iter()
+                        .fold((vec![], false), |(mut vals, changed), arg| {
+                            if !can_sroa_type(editor, types[&arg]) {
+                                vals.push(arg);
+                                (vals, changed)
+                            } else {
+                                vals.push(reconstruct_product(
+                                    editor,
+                                    types[&arg],
+                                    arg,
+                                    &mut product_nodes,
+                                ));
+                                (vals, true)
+                            }
+                        });
 
-                // If the call returns a product that we can sroa, we generate reads for each field
-                let fields = if can_sroa_type(editor, types[&node]) {
-                    Some(generate_reads(editor, types[&node], node))
-                } else {
-                    None
-                };
+                if changed {
+                    editor.edit(|mut edit| {
+                        let new_call = edit.add_node(Node::Call {
+                            control,
+                            function,
+                            dynamic_constants,
+                            args: new_args.into(),
+                        });
+                        edit.sub_edit(node, new_call);
+                        let edit = edit.replace_all_uses(node, new_call)?;
+                        let edit = edit.delete_node(node)?;
 
-                let mut new_args = vec![];
-                for arg in args {
-                    if can_sroa_type(editor, types[&arg]) {
-                        new_args.push(reconstruct_product(
-                            editor,
-                            types[&arg],
-                            arg,
-                            &mut product_nodes,
-                        ));
-                    } else {
-                        new_args.push(arg);
-                    }
-                }
-                editor.edit(|mut edit| {
-                    let new_call = edit.add_node(Node::Call {
-                        control,
-                        function,
-                        dynamic_constants,
-                        args: new_args.into(),
+                        Ok(edit)
                     });
-                    edit.sub_edit(node, new_call);
-                    let edit = edit.replace_all_uses(node, new_call)?;
-                    let edit = edit.delete_node(node)?;
-
-                    // Since we've replaced uses of calls with the new node, we update the type
-                    // information so that we can retrieve the type of the new call if needed
-                    // Because the other nodes we've created so far are only used in very
-                    // particular ways (i.e. are not used by arbitrary nodes) we don't need their
-                    // type information but do for the new calls
-                    types.insert(new_call, types[&node]);
-
-                    match fields {
-                        None => {}
-                        Some(fields) => {
-                            field_map.insert(new_call, fields);
-                        }
-                    }
-
-                    Ok(edit)
-                });
+                }
+            }
+            Node::DataProjection { .. } => {
+                assert!(can_sroa_type(editor, types[&node]));
+                field_map.insert(node, generate_reads(editor, types[&node], node));
             }
             _ => panic!("Processing non-call or return node"),
         }
@@ -725,7 +736,7 @@ pub fn sroa(
     });
 }
 
-fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool {
+pub fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool {
     match &*editor.get_type(typ) {
         Type::Array(_, _) => true,
         Type::Product(ts) | Type::Summation(ts) => {
@@ -967,20 +978,31 @@ fn reconstruct_product(
 
 // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
 // returns an IndexTree that tracks the nodes reading each leaf field
-fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
-    let res = generate_reads_at_index(editor, typ, val, vec![]);
-    res
+pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
+    let mut result = None;
+
+    editor.edit(|mut edit| {
+        result = Some(generate_reads_edit(&mut edit, typ, val));
+        Ok(edit)
+    });
+
+    result.unwrap()
+}
+
+// The same as generate_reads but for if we have a FunctionEdit rather than a FunctionEditor
+pub fn generate_reads_edit(edit: &mut FunctionEdit, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
+    generate_reads_at_index_edit(edit, typ, val, vec![])
 }
 
 // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
 // fields within this sub-value of val and return the correspondence list
-fn generate_reads_at_index(
-    editor: &mut FunctionEditor,
+fn generate_reads_at_index_edit(
+    edit: &mut FunctionEdit,
     typ: TypeID,
     val: NodeID,
     idx: Vec<Index>,
 ) -> IndexTree<NodeID> {
-    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
+    let ts: Option<Vec<TypeID>> = if let Some(ts) = edit.get_type(typ).try_product() {
         Some(ts.into())
     } else {
         None
@@ -993,22 +1015,18 @@ fn generate_reads_at_index(
         for (i, t) in ts.into_iter().enumerate() {
             let mut new_idx = idx.clone();
             new_idx.push(Index::Field(i));
-            fields.push(generate_reads_at_index(editor, t, val, new_idx));
+            fields.push(generate_reads_at_index_edit(edit, t, val, new_idx));
         }
         IndexTree::Node(fields)
     } else {
         // For non-product types, we've reached a leaf so we generate the read and return it's
         // information
-        let mut read_id = None;
-        editor.edit(|mut edit| {
-            read_id = Some(edit.add_node(Node::Read {
-                collect: val,
-                indices: idx.clone().into(),
-            }));
-            Ok(edit)
+        let read_id = edit.add_node(Node::Read {
+            collect: val,
+            indices: idx.into(),
         });
 
-        IndexTree::Leaf(read_id.expect("Add node canont fail"))
+        IndexTree::Leaf(read_id)
     }
 }
 
@@ -1024,7 +1042,7 @@ macro_rules! add_const {
 }
 
 // Given a type, builds a default constant of that type
-fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
+pub fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
     let t = editor.get_type(typ).clone();
 
     match t {
@@ -1055,6 +1073,7 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
             add_const!(editor, Constant::Array(typ))
         }
         Type::Control => panic!("Cannot create constant of control type"),
+        Type::MultiReturn(_) => panic!("Cannot create constant of multi-return type"),
     }
 }
 
diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index b44ed8df82b494b7da0aff006587246c501f8e5d..2d6cf7b36f44864d67ef4ee505203de12ee59bff 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -205,11 +205,11 @@ pub fn unforkify(
         control: fork_control,
         cond: guard_cond_id,
     };
-    let guard_taken_proj = Node::Projection {
+    let guard_taken_proj = Node::ControlProjection {
         control: guard_if_id,
         selection: 1,
     };
-    let guard_skipped_proj = Node::Projection {
+    let guard_skipped_proj = Node::ControlProjection {
         control: guard_if_id,
         selection: 0,
     };
@@ -224,11 +224,11 @@ pub fn unforkify(
         control: join_control,
         cond: neq_id,
     };
-    let proj_back = Node::Projection {
+    let proj_back = Node::ControlProjection {
         control: if_id,
         selection: 1,
     };
-    let proj_exit = Node::Projection {
+    let proj_exit = Node::ControlProjection {
         control: if_id,
         selection: 0,
     };
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 793fe9fabfee830eef97687097aa5e40177a369d..e962b81dfaf28ece80f49a150139f5774c186771 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -244,31 +244,42 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
     if returns.len() == 1 {
         return Some(returns[0]);
     }
-    let preds_before_returns: Vec<NodeID> = returns
+    let preds_before_returns: Box<[NodeID]> = returns
         .iter()
         .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[0])
         .collect();
-    let data_to_return: Vec<NodeID> = returns
-        .iter()
-        .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[1])
+
+    let num_return_data = editor.func().return_types.len();
+    let data_to_return: Vec<Box<[NodeID]>> = (0..num_return_data)
+        .map(|idx| {
+            returns
+                .iter()
+                .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[idx + 1])
+                .collect()
+        })
         .collect();
 
     // All of the old returns get replaced in a single edit.
     let mut new_return = None;
     editor.edit(|mut edit| {
         let region = edit.add_node(Node::Region {
-            preds: preds_before_returns.into_boxed_slice(),
-        });
-        let phi = edit.add_node(Node::Phi {
-            control: region,
-            data: data_to_return.into_boxed_slice(),
+            preds: preds_before_returns,
         });
+        let return_vals = data_to_return
+            .into_iter()
+            .map(|data| {
+                edit.add_node(Node::Phi {
+                    control: region,
+                    data,
+                })
+            })
+            .collect();
         for ret in returns {
             edit = edit.delete_node(ret)?;
         }
         new_return = Some(edit.add_node(Node::Return {
             control: region,
-            data: phi,
+            data: return_vals,
         }));
         Ok(edit)
     });
@@ -293,10 +304,11 @@ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID
             .filter(|id| editor.func().nodes[id.idx()].is_control())
             .next()
             .unwrap();
-        let Node::Return { control, data } = editor.func().nodes[ret.idx()] else {
+        let Node::Return { control, ref data } = editor.func().nodes[ret.idx()] else {
             panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.")
         };
         assert_eq!(control, NodeID::new(0), "PANIC: The only other control node in a Hercules function, the return node, is not using the start node.");
+        let data = data.clone();
         let mut region_id = None;
         editor.edit(|mut edit| {
             edit = edit.delete_node(ret)?;
diff --git a/hercules_samples/call/src/call.hir b/hercules_samples/call/src/call.hir
index cecee343288370e9ce51564e8cc1fc40149bd94c..77f5db2de2b4fa0508048d8155d8745e34ca3168 100644
--- a/hercules_samples/call/src/call.hir
+++ b/hercules_samples/call/src/call.hir
@@ -2,8 +2,10 @@ fn myfunc(x: u64) -> u64
   cr1 = region(start)
   cr2 = region(cr1)
   c = constant(u64, 24)
-  y = call<16>(add, cr1, x, x)
-  z = call<10>(add, cr2, x, c)
+  cy = call<16>(add, cr1, x, x)
+  y = data_projection(cy, 0)
+  cz = call<10>(add, cr2, x, c)
+  z = data_projection(cz, 0)
   w = add(y, z)
   r = return(cr2, w)
 
diff --git a/hercules_samples/ccp/src/ccp.hir b/hercules_samples/ccp/src/ccp.hir
index b8e939942be05f75d85411f05016ed3484ed1bf0..e07df1d37a97054a19605578fe2511525a0c229b 100644
--- a/hercules_samples/ccp/src/ccp.hir
+++ b/hercules_samples/ccp/src/ccp.hir
@@ -7,14 +7,14 @@ fn tricky(x: i32) -> i32
   val = phi(loop, one, later_val)
   b = ne(one, val)
   if1 = if(loop, b)
-  if1_false = projection(if1, 0)
-  if1_true = projection(if1, 1)
+  if1_false = control_projection(if1, 0)
+  if1_true = control_projection(if1, 1)
   middle = region(if1_false, if1_true)
   inter_val = sub(two, val)
   later_val = phi(middle, inter_val, two)
   idx_dec = sub(idx, one)
   cond = gte(idx_dec, one)
   if2 = if(middle, cond)
-  if2_false = projection(if2, 0)
-  if2_true = projection(if2, 1)
+  if2_false = control_projection(if2, 0)
+  if2_true = control_projection(if2, 1)
   r = return(if2_false, later_val)
diff --git a/hercules_samples/fac/src/fac.hir b/hercules_samples/fac/src/fac.hir
index e43dd8cae1a605bca7c3ceac4eb7c029665e86e6..aaf55c1de38cca2c6b024be061eea22c50ad5e6d 100644
--- a/hercules_samples/fac/src/fac.hir
+++ b/hercules_samples/fac/src/fac.hir
@@ -8,6 +8,6 @@ fn fac(x: i32) -> i32
   fac_acc = mul(fac, idx_inc)
   in_bounds = lt(idx_inc, x)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   r = return(if_false, fac_acc)
diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs
index 2e352644cc816a3fe5c43427fa7190708508bb82..8a57783911924803186f829b7efef3b8833d6abd 100644
--- a/hercules_test/hercules_interpreter/src/interpreter.rs
+++ b/hercules_test/hercules_interpreter/src/interpreter.rs
@@ -529,6 +529,13 @@ impl<'a> FunctionExecutionState<'a> {
 
                 state.run()
             }
+            Node::DataProjection { data, selection } => {
+                let data = self.handle_data(token, *data);
+                let InterpreterVal::MultiReturn(vs) = data else {
+                    panic!();
+                };
+                vs[*selection].clone()
+            }
             Node::Read { collect, indices } => {
                 let collection = self.handle_data(token, *collect);
                 if let InterpreterVal::Undef(_) = collection {
@@ -745,7 +752,7 @@ impl<'a> FunctionExecutionState<'a> {
                         .succs(ctrl_token.curr)
                         .find(|n| {
                             self.get_function().nodes[n.idx()]
-                                .try_projection(cond)
+                                .try_control_projection(cond)
                                 .is_some()
                         })
                         .expect("PANIC: No outgoing valid outgoing edge.");
@@ -753,7 +760,7 @@ impl<'a> FunctionExecutionState<'a> {
                     let ctrl_token = ctrl_token.moved_to(next);
                     vec![ctrl_token]
                 }
-                Node::Projection { .. } => {
+                Node::ControlProjection { .. } => {
                     let next: NodeID = self
                         .get_control_subgraph()
                         .succs(ctrl_token.curr)
@@ -861,8 +868,11 @@ impl<'a> FunctionExecutionState<'a> {
                     }
                 }
                 Node::Return { control: _, data } => {
-                    let result = self.handle_data(&ctrl_token, *data);
-                    break 'outer result;
+                    let results = data
+                        .iter()
+                        .map(|data| self.handle_data(&ctrl_token, *data))
+                        .collect();
+                    break 'outer InterpreterVal::MultiReturn(results);
                 }
                 _ => {
                     panic!("PANIC: Unexpected node in control subgraph {:?}", node);
diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs
index dfc290b253666c8251370450f9f2893fe78d8830..0f5716e75b64f3753e22d074bd93210fee7ddfd1 100644
--- a/hercules_test/hercules_interpreter/src/value.rs
+++ b/hercules_test/hercules_interpreter/src/value.rs
@@ -36,6 +36,8 @@ pub enum InterpreterVal {
     // These can be freely? casted
     DynamicConstant(usize),
     ThreadID(usize),
+
+    MultiReturn(Box<[InterpreterVal]>),
 }
 
 #[derive(Clone)]
@@ -848,6 +850,7 @@ impl<'a> InterpreterVal {
                     Type::Product(_) => todo!(),
                     Type::Summation(_) => todo!(),
                     Type::Array(type_id, _) => todo!(),
+                    Type::MultiReturn(_) => todo!(),
                 }
             }
             (_, Self::Undef(v)) => InterpreterVal::Undef(v),
diff --git a/hercules_test/test_inputs/call.hir b/hercules_test/test_inputs/call.hir
index 447489343cce6e010c342add119aa0c4fa31058b..0ebf7c7dc79510ee2ac5535fc6c6f510c5187e6a 100644
--- a/hercules_test/test_inputs/call.hir
+++ b/hercules_test/test_inputs/call.hir
@@ -1,7 +1,8 @@
 fn myfunc(x: i32) -> i32
-  y = call(add, x, x)
+  cy = call(add, x, x)
+  y = data_projection(cy, 0)
   r = return(start, y)
 
 fn add(x: i32, y: i32) -> i32
   w = add(x, y)
-  r = return(start, w)
\ No newline at end of file
+  r = return(start, w)
diff --git a/hercules_test/test_inputs/call_dc_params.hir b/hercules_test/test_inputs/call_dc_params.hir
index 5ccf2686848379b4b31eec3c2201d924984ad2e0..b8da97918fd7758ddb80e96533fe2d81a4139e77 100644
--- a/hercules_test/test_inputs/call_dc_params.hir
+++ b/hercules_test/test_inputs/call_dc_params.hir
@@ -1,9 +1,10 @@
 fn myfunc(x: u64) -> u64
-  y = call<10, 4>(add, x, x)
+  cy = call<10, 4>(add, x, x)
+  y = data_projection(cy, 0)
   r = return(start, y)
 
 fn add<2>(x: u64, y: u64) -> u64
   b = dynamic_constant(#1)
   r = return(start, z)
   w = add(x, y)
-  z = add(b, w)
\ No newline at end of file
+  z = add(b, w)
diff --git a/hercules_test/test_inputs/ccp_example.hir b/hercules_test/test_inputs/ccp_example.hir
index 25b7379e19f488c2d5ecf8cb377161758c62bb98..f8004b636e5df90299b27886778e5bd56799dfd0 100644
--- a/hercules_test/test_inputs/ccp_example.hir
+++ b/hercules_test/test_inputs/ccp_example.hir
@@ -6,14 +6,14 @@ fn tricky(x: i32) -> i32
   val = phi(loop, one, later_val)
   b = ne(one, val)
   if1 = if(loop, b)
-  if1_false = projection(if1, 0)
-  if1_true = projection(if1, 1)
+  if1_false = control_projection(if1, 0)
+  if1_true = control_projection(if1, 1)
   middle = region(if1_false, if1_true)
   inter_val = sub(two, val)
   later_val = phi(middle, inter_val, two)
   idx_dec = sub(idx, one)
   cond = gte(idx_dec, one)
   if2 = if(middle, cond)
-  if2_false = projection(if2, 0)
-  if2_true = projection(if2, 1)
+  if2_false = control_projection(if2, 0)
+  if2_true = control_projection(if2, 1)
   r = return(if2_false, later_val)
diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir
index 0cc13b2fe21ab6cb434b9a13b3dc212c125394aa..b7458a438b4af6dbd3809897a5fd0469244c206c 100644
--- a/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir
+++ b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir
@@ -11,8 +11,8 @@ fn fun<2>(x: u64) -> u64
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   j = join(if_false)
   tid = thread_id(f, 0)
   add1 = add(reduce1, idx)
@@ -20,4 +20,4 @@ fn fun<2>(x: u64) -> u64
   add2 = add(reduce2, idx_inc)
   reduce2 =  reduce(j, zero, add2)
   out1 = add(reduce1, reduce2)
-  z = return(j, out1)
\ No newline at end of file
+  z = return(j, out1)
diff --git a/hercules_test/test_inputs/forkify/alternate_bounds.hir b/hercules_test/test_inputs/forkify/alternate_bounds.hir
index 4a9ba0153448e4c9339e901d3ba3a10e027bad56..7de8cf1e606bf2f231ad3391fad50e7b04d32d93 100644
--- a/hercules_test/test_inputs/forkify/alternate_bounds.hir
+++ b/hercules_test/test_inputs/forkify/alternate_bounds.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red_add)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red_add)
diff --git a/hercules_test/test_inputs/forkify/broken_sum.hir b/hercules_test/test_inputs/forkify/broken_sum.hir
index d15ef5613e271cd8685660785f5505dedbf40ec9..75b12350da88214c761654a47cffb0943015afa9 100644
--- a/hercules_test/test_inputs/forkify/broken_sum.hir
+++ b/hercules_test/test_inputs/forkify/broken_sum.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red_add)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red_add)
diff --git a/hercules_test/test_inputs/forkify/control_after_condition.hir b/hercules_test/test_inputs/forkify/control_after_condition.hir
index db40225bac26f5d6687b4b36bd46e06c963c4815..a1a97fba3c5cd3655ccf47f6b97b1a63dbb6bf9d 100644
--- a/hercules_test/test_inputs/forkify/control_after_condition.hir
+++ b/hercules_test/test_inputs/forkify/control_after_condition.hir
@@ -11,8 +11,8 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32
   rem = rem(idx, two_idx)
   odd = eq(rem, one_idx)
   negate_if = if(loop_continue, odd)
-  negate_if_false = projection(negate_if, 0)
-  negate_if_true = projection(negate_if, 1)
+  negate_if_false = control_projection(negate_if, 0)
+  negate_if_true = control_projection(negate_if, 1)
   negate_bottom = region(negate_if_false, negate_if_true)
   read = read(a, position(idx))
   read_neg = neg(read)
@@ -20,6 +20,6 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read_phi)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  loop_exit = projection(if, 0)
-  loop_continue = projection(if, 1)
-  r = return(loop_exit, red)
\ No newline at end of file
+  loop_exit = control_projection(if, 0)
+  loop_continue = control_projection(if, 1)
+  r = return(loop_exit, red)
diff --git a/hercules_test/test_inputs/forkify/control_before_condition.hir b/hercules_test/test_inputs/forkify/control_before_condition.hir
index f24b565a56df1e9ed98e2637a8b66ef27ff6e7de..e351d7142a44df6c33e9598b56511fd84227fee7 100644
--- a/hercules_test/test_inputs/forkify/control_before_condition.hir
+++ b/hercules_test/test_inputs/forkify/control_before_condition.hir
@@ -11,8 +11,8 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32
   rem = rem(idx, two_idx)
   odd = eq(rem, one_idx)
   negate_if = if(loop, odd)
-  negate_if_false = projection(negate_if, 0)
-  negate_if_true = projection(negate_if, 1)
+  negate_if_false = control_projection(negate_if, 0)
+  negate_if_true = control_projection(negate_if, 1)
   negate_bottom = region(negate_if_false, negate_if_true)
   read = read(a, position(idx))
   read_neg = neg(read)
@@ -20,6 +20,6 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read_phi)
   in_bounds = lt(idx, bound)
   if = if(negate_bottom, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red)
diff --git a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir
index f5ec4370b6befc4d45535fd25f024db60b55a0d0..7599e6ecbfda68266d40b0908d99ed89c5b4321c 100644
--- a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir
+++ b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir
@@ -16,18 +16,18 @@ fn loop<3>(a: u32) -> i32
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
   outer_bound = dynamic_constant(#1)
   outer_outer_bound = dynamic_constant(#2)
   outer_outer_loop = region(start, outer_if_false)
   outer_outer_var = phi(outer_outer_loop, zero_var, outer_var)
   outer_outer_if  = if(outer_outer_loop, outer_outer_in_bounds)
-  outer_outer_if_false = projection(outer_outer_if, 0)
-  outer_outer_if_true = projection(outer_outer_if, 1)
+  outer_outer_if_false = control_projection(outer_outer_if, 0)
+  outer_outer_if_true = control_projection(outer_outer_if, 1)
   outer_outer_idx = phi(outer_outer_loop, zero_idx, outer_outer_idx_inc, outer_outer_idx)
   outer_outer_idx_inc = add(outer_outer_idx, one_idx)
   outer_outer_in_bounds = lt(outer_outer_idx, outer_outer_bound)
diff --git a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir
index 8dda179bf36d82d29bbbda48686191fa46a02fb6..8f7d5e48032571d15e1632dcd58067555b91279b 100644
--- a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir
+++ b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir
@@ -11,6 +11,6 @@ fn loop<1>(a: u64) -> u64
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, var_inc)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, var_inc)
diff --git a/hercules_test/test_inputs/forkify/inner_fork.hir b/hercules_test/test_inputs/forkify/inner_fork.hir
index e2c96a68b85ee605d5804baf338ada9b493e155f..c603dc42d350e4a9e516de9f9442ac3ae353e617 100644
--- a/hercules_test/test_inputs/forkify/inner_fork.hir
+++ b/hercules_test/test_inputs/forkify/inner_fork.hir
@@ -6,7 +6,7 @@ fn loop<2>(a: u32) -> i32
   inner_bound = dynamic_constant(#0)
   outer_bound = dynamic_constant(#1)
   outer_loop = region(start, inner_join)
-  outer_if_true = projection(outer_if, 1)
+  outer_if_true = control_projection(outer_if, 1)
   inner_fork = fork(outer_if_true, #0)
   inner_join = join(inner_fork)
   outer_var = phi(outer_loop, zero_var, inner_var)
@@ -17,6 +17,6 @@ fn loop<2>(a: u32) -> i32
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
+  outer_if_false = control_projection(outer_if, 0)
   r = return(outer_if_false, outer_var)
- 
\ No newline at end of file
+ 
diff --git a/hercules_test/test_inputs/forkify/inner_fork_complex.hir b/hercules_test/test_inputs/forkify/inner_fork_complex.hir
index 91eb00fae9dd7e043155c1f6316e50b38337bad2..c9488f7f6d3a3ebe69ffb5cbaa4d75623542df55 100644
--- a/hercules_test/test_inputs/forkify/inner_fork_complex.hir
+++ b/hercules_test/test_inputs/forkify/inner_fork_complex.hir
@@ -8,14 +8,14 @@ fn loop<2>(a: u32) -> u64
   inner_bound = dynamic_constant(#0)
   outer_bound = dynamic_constant(#1)
   outer_loop = region(start, inner_condition_true_projection, inner_condition_false_projection )
-  outer_if_true = projection(outer_if, 1)
+  outer_if_true = control_projection(outer_if, 1)
   other_phi_weird = phi(outer_loop, zero_var, inner_var, other_phi_weird)
   inner_fork = fork(outer_if_true, #0)
   inner_join = join(inner_fork)
   inner_condition_eq = eq(outer_idx, two)
   inner_condition_if = if(inner_join, inner_condition_eq)
-  inner_condition_true_projection = projection(inner_condition_if, 1)
-  inner_condition_false_projection = projection(inner_condition_if, 0)
+  inner_condition_true_projection = control_projection(inner_condition_if, 1)
+  inner_condition_false_projection = control_projection(inner_condition_if, 0)
   outer_var = phi(outer_loop, zero_var, inner_var, inner_var)
   inner_var = reduce(inner_join, outer_var, inner_var_inc)
   inner_var_inc = add(inner_var, inner_var_inc_3)
@@ -26,7 +26,7 @@ fn loop<2>(a: u32) -> u64
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
+  outer_if_false = control_projection(outer_if, 0)
   ret_val = add(outer_var, other_phi_weird)
   r = return(outer_if_false, ret_val)
- 
\ No newline at end of file
+ 
diff --git a/hercules_test/test_inputs/forkify/loop_array_sum.hir b/hercules_test/test_inputs/forkify/loop_array_sum.hir
index f9972b5917c200b93b5775fd4a6e501318e8c548..884d22d469c068f180b8a0724325541ee44f7da4 100644
--- a/hercules_test/test_inputs/forkify/loop_array_sum.hir
+++ b/hercules_test/test_inputs/forkify/loop_array_sum.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red)
diff --git a/hercules_test/test_inputs/forkify/loop_simple_iv.hir b/hercules_test/test_inputs/forkify/loop_simple_iv.hir
index c25b9a2cf006b67087df0b8dc652f72400557090..c671b94c27f05d5de124c0a4e705f259668400da 100644
--- a/hercules_test/test_inputs/forkify/loop_simple_iv.hir
+++ b/hercules_test/test_inputs/forkify/loop_simple_iv.hir
@@ -7,6 +7,6 @@ fn loop<1>(a: u32) -> u64
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, idx)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, idx)
diff --git a/hercules_test/test_inputs/forkify/loop_sum.hir b/hercules_test/test_inputs/forkify/loop_sum.hir
index fd9c4debc163600c01e661b127b166358ac9c6db..a236ddf7323cb673608f5341f74dcea9699a3f8d 100644
--- a/hercules_test/test_inputs/forkify/loop_sum.hir
+++ b/hercules_test/test_inputs/forkify/loop_sum.hir
@@ -11,6 +11,6 @@ fn loop<1>(a: u32) -> i32
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, var)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, var)
diff --git a/hercules_test/test_inputs/forkify/loop_tid_sum.hir b/hercules_test/test_inputs/forkify/loop_tid_sum.hir
index 2d3ca34db88efa5007b5fb933f0bb0e4b55e63e4..6a1e2c56f5c821be9c04c192e1d33fee083669af 100644
--- a/hercules_test/test_inputs/forkify/loop_tid_sum.hir
+++ b/hercules_test/test_inputs/forkify/loop_tid_sum.hir
@@ -11,6 +11,6 @@ fn loop<1>(a: u64) -> u64
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, var)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, var)
diff --git a/hercules_test/test_inputs/forkify/merged_phi_cycle.hir b/hercules_test/test_inputs/forkify/merged_phi_cycle.hir
index cee473a08740b48935d4ff571bc003bbd9908729..2b276d3ec1c560d6e014dfdd3c1ddc93374f8c2f 100644
--- a/hercules_test/test_inputs/forkify/merged_phi_cycle.hir
+++ b/hercules_test/test_inputs/forkify/merged_phi_cycle.hir
@@ -13,6 +13,6 @@ fn sum<1>(a: i32) -> u64
   second_red_add_2 = add(first_red_add, two)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, first_red_add_2)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, first_red_add_2)
diff --git a/hercules_test/test_inputs/forkify/nested_loop2.hir b/hercules_test/test_inputs/forkify/nested_loop2.hir
index 0f29ec747dd67bae4c5b7d8b1250ead925257e9d..c3c7d8e5b7f0a3f23f28216ae4190f40edb0d7c5 100644
--- a/hercules_test/test_inputs/forkify/nested_loop2.hir
+++ b/hercules_test/test_inputs/forkify/nested_loop2.hir
@@ -17,9 +17,9 @@ fn loop<2>(a: u32) -> i32
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
-  r = return(outer_if_false, outer_var)
\ No newline at end of file
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
+  r = return(outer_if_false, outer_var)
diff --git a/hercules_test/test_inputs/forkify/nested_tid_sum.hir b/hercules_test/test_inputs/forkify/nested_tid_sum.hir
index 5539202d297d5124bd3258e56568c1614e4e942e..f7e4bda4112d62dd4598b0d73ccba583b12cdd4b 100644
--- a/hercules_test/test_inputs/forkify/nested_tid_sum.hir
+++ b/hercules_test/test_inputs/forkify/nested_tid_sum.hir
@@ -17,9 +17,9 @@ fn loop<2>(a: u32) -> u64
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
-  r = return(outer_if_false, outer_var)
\ No newline at end of file
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
+  r = return(outer_if_false, outer_var)
diff --git a/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir b/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir
index 9221fd47d857c9215a7b8912e68088b5809238b5..50634a2c93095fa0d891e8b54ad18fb0f0fbab88 100644
--- a/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir
+++ b/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir
@@ -18,9 +18,9 @@ fn loop<2>(a: u32) -> u64
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
-  r = return(outer_if_false, outer_var)
\ No newline at end of file
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
+  r = return(outer_if_false, outer_var)
diff --git a/hercules_test/test_inputs/forkify/phi_loop4.hir b/hercules_test/test_inputs/forkify/phi_loop4.hir
index e69ecc3daf264359426acd7b5dbf9ff84fd96c4c..9ce594da94a34065023ab1397f031cab9b7710ff 100644
--- a/hercules_test/test_inputs/forkify/phi_loop4.hir
+++ b/hercules_test/test_inputs/forkify/phi_loop4.hir
@@ -11,6 +11,6 @@ fn loop<1>(a: u32) -> i32
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, var_inc)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, var_inc)
diff --git a/hercules_test/test_inputs/forkify/split_phi_cycle.hir b/hercules_test/test_inputs/forkify/split_phi_cycle.hir
index 96de73c8e054fb9cb45ade6dfe9150fcfc79334f..a233230bcb29152fa2d10382ae66884c4aba3f37 100644
--- a/hercules_test/test_inputs/forkify/split_phi_cycle.hir
+++ b/hercules_test/test_inputs/forkify/split_phi_cycle.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: i32) -> u64
   first_red_add_2 = add(first_red_add, two)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, first_red_add_2)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, first_red_add_2)
diff --git a/hercules_test/test_inputs/forkify/super_nested_loop.hir b/hercules_test/test_inputs/forkify/super_nested_loop.hir
index 6853efbfc2b620860644bc94486fb09a09e131f0..b568b85aa8a3a2912af3b1fa8c5b92d10a2a7ece 100644
--- a/hercules_test/test_inputs/forkify/super_nested_loop.hir
+++ b/hercules_test/test_inputs/forkify/super_nested_loop.hir
@@ -16,18 +16,18 @@ fn loop<3>(a: u32) -> i32
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(outer_loop, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
   outer_bound = dynamic_constant(#1)
   outer_outer_bound = dynamic_constant(#2)
   outer_outer_loop = region(start, outer_if_false)
   outer_outer_var = phi(outer_outer_loop, zero_var, outer_var)
   outer_outer_if  = if(outer_outer_loop, outer_outer_in_bounds)
-  outer_outer_if_false = projection(outer_outer_if, 0)
-  outer_outer_if_true = projection(outer_outer_if, 1)
+  outer_outer_if_false = control_projection(outer_outer_if, 0)
+  outer_outer_if_true = control_projection(outer_outer_if, 1)
   outer_outer_idx = phi(outer_outer_loop, zero_idx, outer_outer_idx_inc, outer_outer_idx)
   outer_outer_idx_inc = add(outer_outer_idx, one_idx)
   outer_outer_in_bounds = lt(outer_outer_idx, outer_outer_bound)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir
index 4df92a18a9895c88cf27f143285999ef2218bfcf..a6ae209b9bb8661164022c12eb4d7425f69fa444 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir
@@ -9,6 +9,6 @@ fn sum<1>(a: u32) -> u64
   red_add = add(red, one_idx)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir
index 8b4431bfb446237b6a66088d7a7d2339d9c889d1..9bd6b626d4ae38a1611ffa131be6933c141a9e0d 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir
@@ -15,9 +15,9 @@ fn sum<1>(a: u64) -> u64
   red_add2 = add(red, inner_phi)
   in_bounds = lt(idx_inc, bound)
   if = if(inner_ctrl, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   plus_ten = add(red_add, ten)
   red_add_2_plus_blah = add(red2, plus_ten)
   final_add = add(inner_phi, red_add_2_plus_blah)
-  r = return(if_false, final_add)
\ No newline at end of file
+  r = return(if_false, final_add)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir
index f4adf6435968ddaccc285c8d6132e7c9dd91c973..2801a1656db8db1397d1b5d32dcbbd40acd9f76e 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir
@@ -13,9 +13,9 @@ fn sum<1>(a: u64) -> u64
   red_add = add(red, two)
   in_bounds = lt(idx_inc, bound)
   if = if(inner_ctrl, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   plus_ten = add(red_add, ten)
   red_add_2_plus_blah = add(inner_phi, plus_ten)
   final_add = add(inner_phi, red_add_2_plus_blah)
-  r = return(if_false, final_add)
\ No newline at end of file
+  r = return(if_false, final_add)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir
index 52f701727c33305c0719e46aacf717bf8b220fcb..edbec0c5929cfb056d4015d81b775e2a5ba4bbbb 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir
@@ -20,9 +20,9 @@ fn loop<2>(a: u64) -> u64
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx_inc, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(inner_if_false, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
-  r = return(outer_if_false, inner_var_inc)
\ No newline at end of file
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
+  r = return(outer_if_false, inner_var_inc)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir
index f295b39166fc16bb4560d9100bb84526281dda84..4e81871c1e6a7978d9fbafb912e27c3e0e62ff1c 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir
@@ -17,9 +17,9 @@ fn loop<2>(a: u32) -> i32
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx_inc, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(inner_if_false, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
-  r = return(outer_if_false, inner_var_inc)
\ No newline at end of file
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
+  r = return(outer_if_false, inner_var_inc)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir
index e5401779a28503677f4a5e51c703c4197b433d4e..98477c91e40d6f0099aedb1ccc9b5e227aa8c7f4 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir
@@ -20,9 +20,9 @@ fn loop<2>(a: array(u64, #1)) -> u64
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx_inc, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(inner_if_false, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
-  r = return(outer_if_false, inner_var_inc)
\ No newline at end of file
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
+  r = return(outer_if_false, inner_var_inc)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir
index b979ad42ce0522cc9b27d1fab07736cc73831590..eee77b6c22d13f11bd859e70d9a2deba26604fc3 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir
@@ -5,8 +5,8 @@ fn loop<2>(a: u64) -> u64
   one_var = constant(u64, 1)
   ten = constant(u64, 10)
   outer_guard_if = if(start, outer_guard_lt)
-  outer_guard_if_false = projection(outer_guard_if, 0)
-  outer_guard_if_true = projection(outer_guard_if, 1)
+  outer_guard_if_false = control_projection(outer_guard_if, 0)
+  outer_guard_if_true = control_projection(outer_guard_if, 1)
   outer_guard_lt = lt(zero_idx, outer_bound)
   outer_join = region(outer_guard_if_false, outer_if_false)
   outer_join_var = phi(outer_join, zero_idx, join_var)
@@ -16,8 +16,8 @@ fn loop<2>(a: u64) -> u64
   inner_loop = region(guard_if_true, inner_if_true)
   guard_lt = lt(zero_idx, inner_bound)
   guard_if = if(outer_loop, guard_lt)
-  guard_if_true = projection(guard_if, 1)
-  guard_if_false = projection(guard_if, 0)
+  guard_if_true = control_projection(guard_if, 1)
+  guard_if_false = control_projection(guard_if, 0)
   guard_join = region(guard_if_false, inner_if_false)
   inner_idx = phi(inner_loop, zero_idx, inner_idx_inc)
   inner_idx_inc = add(inner_idx, one_idx)
@@ -26,15 +26,15 @@ fn loop<2>(a: u64) -> u64
   outer_idx_inc = add(outer_idx, one_idx)
   outer_in_bounds = lt(outer_idx_inc, outer_bound)
   inner_if = if(inner_loop, inner_in_bounds)
-  inner_if_false = projection(inner_if, 0)
-  inner_if_true = projection(inner_if, 1)
+  inner_if_false = control_projection(inner_if, 0)
+  inner_if_true = control_projection(inner_if, 1)
   outer_if = if(guard_join, outer_in_bounds)
-  outer_if_false = projection(outer_if, 0)
-  outer_if_true = projection(outer_if, 1)
+  outer_if_false = control_projection(outer_if, 0)
+  outer_if_true = control_projection(outer_if, 1)
   outer_var = phi(outer_loop, zero_var, join_var)
   inner_var = phi(inner_loop, outer_var, inner_var_inc)
   blah = mul(outer_idx, ten)
   blah2 = add(blah, inner_idx)
   inner_var_inc = add(inner_var, blah2)
   join_var = phi(guard_join, outer_var, inner_var_inc)
-  r = return(outer_join, outer_join_var)
\ No newline at end of file
+  r = return(outer_join, outer_join_var)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir
index 2fe4ca57345dd0e6c4dd94e399dba58e3cab81a4..4a6e8cd64ac6d016200979db1b4c09731ac125a2 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir
@@ -13,9 +13,9 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   plus_ten = add(red_add, ten)
   mult = mul(read, three)
   final = add(plus_ten, mult)
-  r = return(if_false, final)
\ No newline at end of file
+  r = return(if_false, final)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir
index 760ae5ad690382c42e6760af690b18e5ea36a6b2..f735c8c6c1c41c00ea885d6a847e1a43ad7a34b3 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir
@@ -13,9 +13,9 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   plus_ten = add(red, ten)
   mult = mul(read, three)
   final = add(plus_ten, mult)
-  r = return(if_false, final)
\ No newline at end of file
+  r = return(if_false, final)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir
index 4b9375090dd5c13a127e9db85d7ce6dc8f2f7d75..c2f5e30a5cc826b10f1f5978b2d3fdaef5582351 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir
@@ -11,7 +11,7 @@ fn sum<1>(a: u64) -> u64
   red_add = add(red, two)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   plus_ten = add(red_add, ten)
-  r = return(if_false, plus_ten)
\ No newline at end of file
+  r = return(if_false, plus_ten)
diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir
index fd06eb7dd22f64022cc797a549946bfda5e8b7cf..f7a4af06b7b3502cc73d783bbb2f367b660594d8 100644
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir
+++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir
@@ -12,8 +12,8 @@ fn sum<1>(a: u64) -> u64
   blah = phi(loop, zero_idx, red_add)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   plus_ten = add(red_add, ten)
   plus_blah = add(blah, red_add)
-  r = return(if_false, plus_blah)
\ No newline at end of file
+  r = return(if_false, plus_blah)
diff --git a/hercules_test/test_inputs/loop_analysis/broken_sum.hir b/hercules_test/test_inputs/loop_analysis/broken_sum.hir
index d15ef5613e271cd8685660785f5505dedbf40ec9..75b12350da88214c761654a47cffb0943015afa9 100644
--- a/hercules_test/test_inputs/loop_analysis/broken_sum.hir
+++ b/hercules_test/test_inputs/loop_analysis/broken_sum.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red_add)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red_add)
diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir
index 4df92a18a9895c88cf27f143285999ef2218bfcf..a6ae209b9bb8661164022c12eb4d7425f69fa444 100644
--- a/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir
+++ b/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir
@@ -9,6 +9,6 @@ fn sum<1>(a: u32) -> u64
   red_add = add(red, one_idx)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red)
diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir
index a4732cdeaa1f474548fba0293e21a900448d1791..bfdb673faffdc1835d5b3f96ea51da6331279c0a 100644
--- a/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir
+++ b/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir
@@ -4,8 +4,8 @@ fn sum<1>(a: u64) -> u64
   bound = dynamic_constant(#0)
   guard_lt = lt(zero_idx, bound)
   guard = if(start, guard_lt)
-  guard_true = projection(guard, 1)
-  guard_false = projection(guard, 0)
+  guard_true = control_projection(guard, 1)
+  guard_false = control_projection(guard, 0)
   loop = region(guard_true, if_true)
   inner_side_effect = region(loop)
   idx = phi(loop, zero_idx, idx_inc)
@@ -15,7 +15,7 @@ fn sum<1>(a: u64) -> u64
   join_phi = phi(final, zero_idx, red_add)
   in_bounds = lt(idx_inc, bound)
   if = if(inner_side_effect, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   final = region(guard_false, if_false)
-  r = return(final, join_phi)
\ No newline at end of file
+  r = return(final, join_phi)
diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir
index 9e22e14baba40dcda34de4a9d36c05dfe73f11eb..d48fe0622ad8b6f53024f929b90ac1962eb07154 100644
--- a/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir
+++ b/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir
@@ -10,6 +10,6 @@ fn sum<1>(a: u64) -> u64
   red_add = add(red, one_idx)
   in_bounds = lt(idx_inc, bound)
   if = if(inner_side_effect, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red_add)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red_add)
diff --git a/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir b/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir
index 42269040615520dd5ff2c151bd545a94445f520f..435b62686bd219b94824a71438e780538bf69738 100644
--- a/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir
+++ b/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: i32) -> u64
   red_add = add(outer_red, idx)
   in_bounds = lt(idx_inc, bound)
   if = if(inner_region, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, inner_red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, inner_red)
diff --git a/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir b/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir
index a751952dcde83e7ffc0ee64f506314724f7bd745..d1e2d4d64a46d40cb2ad0feb3c6bcbb9eeb0a8c1 100644
--- a/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir
+++ b/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir
@@ -13,6 +13,6 @@ fn sum<1>(a: i32) -> u64
   red_mul = mul(red_add, idx)
   in_bounds = lt(idx_inc, bound)
   if = if(inner_region, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, inner_red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, inner_red)
diff --git a/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir b/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir
index f9972b5917c200b93b5775fd4a6e501318e8c548..884d22d469c068f180b8a0724325541ee44f7da4 100644
--- a/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir
+++ b/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red)
diff --git a/hercules_test/test_inputs/loop_analysis/loop_body_count.hir b/hercules_test/test_inputs/loop_analysis/loop_body_count.hir
index c6f3cbf649484f9b00f0fcc9cd208a8b4811284f..5ec745ba4e2f65780c2d43b7c5769114b7adf431 100644
--- a/hercules_test/test_inputs/loop_analysis/loop_body_count.hir
+++ b/hercules_test/test_inputs/loop_analysis/loop_body_count.hir
@@ -11,6 +11,6 @@ fn loop<1>(a: u64) -> u64
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, var)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, var)
diff --git a/hercules_test/test_inputs/loop_analysis/loop_sum.hir b/hercules_test/test_inputs/loop_analysis/loop_sum.hir
index fd9c4debc163600c01e661b127b166358ac9c6db..a236ddf7323cb673608f5341f74dcea9699a3f8d 100644
--- a/hercules_test/test_inputs/loop_analysis/loop_sum.hir
+++ b/hercules_test/test_inputs/loop_analysis/loop_sum.hir
@@ -11,6 +11,6 @@ fn loop<1>(a: u32) -> i32
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, var)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, var)
diff --git a/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir b/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir
index b756f0901fb7a66a3feb83d1611aa1711bcb5601..799cc6d9d4db4b5527566cbd9f1f222979624bcb 100644
--- a/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir
+++ b/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir
@@ -12,8 +12,8 @@ fn loop<1>(b: prod(u64, u64)) -> prod(u64, u64)
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   tuple1 = write(c, var, field(0))
   tuple2 = write(tuple1, idx, field(1))
-  r = return(if_false, tuple2)
\ No newline at end of file
+  r = return(if_false, tuple2)
diff --git a/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir b/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir
index b756f0901fb7a66a3feb83d1611aa1711bcb5601..799cc6d9d4db4b5527566cbd9f1f222979624bcb 100644
--- a/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir
+++ b/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir
@@ -12,8 +12,8 @@ fn loop<1>(b: prod(u64, u64)) -> prod(u64, u64)
   idx_inc = add(idx, one_idx)
   in_bounds = lt(idx, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   tuple1 = write(c, var, field(0))
   tuple2 = write(tuple1, idx, field(1))
-  r = return(if_false, tuple2)
\ No newline at end of file
+  r = return(if_false, tuple2)
diff --git a/hercules_test/test_inputs/simple2.hir b/hercules_test/test_inputs/simple2.hir
index af5ac284b0715a0d10de73571575fd8e2b8a1ac5..d4f1bebe754d08ecb48739f4499551452ffa78bc 100644
--- a/hercules_test/test_inputs/simple2.hir
+++ b/hercules_test/test_inputs/simple2.hir
@@ -8,6 +8,6 @@ fn simple2(x: i32) -> i32
   fac_acc = mul(fac, idx_inc)
   in_bounds = lt(idx_inc, x)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, fac_acc)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, fac_acc)
diff --git a/hercules_test/test_inputs/strset.hir b/hercules_test/test_inputs/strset.hir
index 4c8b32eeb1ab3875e053818ec7745e74f7b502f1..e8615f21e5a7c8d3148b1673f6e12b0fe6deaee8 100644
--- a/hercules_test/test_inputs/strset.hir
+++ b/hercules_test/test_inputs/strset.hir
@@ -12,6 +12,6 @@ fn strset<1>(str: array(u8, #0), byte: u8) -> array(u8, #0)
   continue = ne(read, byte)
   if_cond = and(continue, in_bounds)
   if = if(loop, if_cond)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   r = return(if_false, str_inc)
diff --git a/hercules_test/test_inputs/sum_int1.hir b/hercules_test/test_inputs/sum_int1.hir
index 4a9ba0153448e4c9339e901d3ba3a10e027bad56..7de8cf1e606bf2f231ad3391fad50e7b04d32d93 100644
--- a/hercules_test/test_inputs/sum_int1.hir
+++ b/hercules_test/test_inputs/sum_int1.hir
@@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read)
   in_bounds = lt(idx_inc, bound)
   if = if(loop, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
-  r = return(if_false, red_add)
\ No newline at end of file
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
+  r = return(if_false, red_add)
diff --git a/hercules_test/test_inputs/sum_int2.hir b/hercules_test/test_inputs/sum_int2.hir
index b5e9a5c010275206c964708074219cf58568fe08..bc614d4e789364ad74fb81b9129f7c758c5cc685 100644
--- a/hercules_test/test_inputs/sum_int2.hir
+++ b/hercules_test/test_inputs/sum_int2.hir
@@ -11,8 +11,8 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32
   rem = rem(idx, two_idx)
   odd = eq(rem, one_idx)
   negate_if = if(loop, odd)
-  negate_if_false = projection(negate_if, 0)
-  negate_if_true = projection(negate_if, 1)
+  negate_if_false = control_projection(negate_if, 0)
+  negate_if_true = control_projection(negate_if, 1)
   negate_bottom = region(negate_if_false, negate_if_true)
   read = read(a, position(idx))
   read_neg = neg(read)
@@ -20,6 +20,6 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32
   red_add = add(red, read_phi)
   in_bounds = lt(idx_inc, bound)
   if = if(negate_bottom, in_bounds)
-  if_false = projection(if, 0)
-  if_true = projection(if, 1)
+  if_false = control_projection(if, 0)
+  if_true = control_projection(if, 1)
   r = return(if_false, red_add)
diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs
index a3197041da5a47b50f36ce81f0e36647083e3c17..0902bc61d429794dd67ea737f20553ae3185ba21 100644
--- a/juno_frontend/src/codegen.rs
+++ b/juno_frontend/src/codegen.rs
@@ -118,15 +118,18 @@ impl CodeGenerator<'_> {
                     param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty));
                 }
 
-                let return_type =
-                    solver_inst.lower_type(&mut self.builder.builder, func.return_type);
+                let return_types = func
+                    .return_types
+                    .iter()
+                    .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t))
+                    .collect::<Vec<_>>();
 
                 let (func_id, entry) = self
                     .builder
                     .create_function(
                         &name,
                         param_types,
-                        return_type,
+                        return_types,
                         func.num_dyn_consts as u32,
                         func.entry,
                     )
@@ -264,10 +267,16 @@ impl CodeGenerator<'_> {
                 // loop
                 Some(block_exit)
             }
-            Stmt::ReturnStmt { expr } => {
-                let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, cur_block);
+            Stmt::ReturnStmt { exprs } => {
+                let mut vals = vec![];
+                let mut block = cur_block;
+                for expr in exprs {
+                    let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, block);
+                    vals.push(val_ret);
+                    block = block_ret;
+                }
                 let mut return_node = self.builder.allocate_node();
-                return_node.build_return(block_ret, val_ret);
+                return_node.build_return(block, vals.into());
                 self.builder.add_node(return_node);
                 None
             }
@@ -482,6 +491,7 @@ impl CodeGenerator<'_> {
                 ty_args,
                 dyn_consts,
                 args,
+                num_returns, // number of non-inout returns (which are first)
                 ..
             } => {
                 // We start by lowering the type arguments to TypeIDs
@@ -541,30 +551,26 @@ impl CodeGenerator<'_> {
 
                 // Read each of the "inout values" and perform the SSA update
                 let has_inouts = !inouts.is_empty();
-                // TODO: We should omit unit returns, if we do so the + 1 below is not needed
                 for (idx, var) in inouts.into_iter().enumerate() {
-                    let index = self.builder.builder.create_field_index(idx + 1);
-                    let mut read = self.builder.allocate_node();
-                    let read_id = read.id();
-                    read.build_read(call_id, vec![index].into());
-                    self.builder.add_node(read);
+                    let mut proj = self.builder.allocate_node();
+                    let proj_id = proj.id();
+                    proj.build_data_projection(call_id, num_returns + idx);
+                    self.builder.add_node(proj);
 
-                    ssa.write_variable(var, block, read_id);
+                    ssa.write_variable(var, block, proj_id);
                 }
 
-                // Read the "actual return" value and return it
-                let result = if !has_inouts {
-                    call_id
-                } else {
-                    let value_index = self.builder.builder.create_field_index(0);
-                    let mut read = self.builder.allocate_node();
-                    let read_id = read.id();
-                    read.build_read(call_id, vec![value_index].into());
-                    self.builder.add_node(read);
-                    read_id
-                };
+                (call_id, block)
+            }
+            Expr::CallExtract { call, index, .. } => {
+                let (call, block) = self.codegen_expr(call, types, ssa, cur_block);
+
+                let mut proj = self.builder.allocate_node();
+                let proj_id = proj.id();
+                proj.build_data_projection(call, *index);
+                self.builder.add_node(proj);
 
-                (result, block)
+                (proj_id, block)
             }
             Expr::Intrinsic {
                 id,
diff --git a/juno_frontend/src/labeled_builder.rs b/juno_frontend/src/labeled_builder.rs
index 15bed6c26d29485b3133c1a89bea1540994693d2..869485e91a938320a0d2bbfb4fa116768947d236 100644
--- a/juno_frontend/src/labeled_builder.rs
+++ b/juno_frontend/src/labeled_builder.rs
@@ -33,14 +33,14 @@ impl<'a> LabeledBuilder<'a> {
         &mut self,
         name: &str,
         param_types: Vec<TypeID>,
-        return_type: TypeID,
+        return_types: Vec<TypeID>,
         num_dynamic_constants: u32,
         entry: bool,
     ) -> Result<(FunctionID, NodeID), String> {
         let (func, entry) = self.builder.create_function(
             name,
             param_types,
-            return_type,
+            return_types,
             num_dynamic_constants,
             entry,
         )?;
diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y
index be9161aaf4515123995602b8d6077fc8157a1c96..b9efe1faa565a58aca57125db6e390c6c75e4cdd 100644
--- a/juno_frontend/src/lang.y
+++ b/juno_frontend/src/lang.y
@@ -167,17 +167,17 @@ ConstDecl -> Result<Top, ()>
 FuncDecl -> Result<Top, ()>
   : PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?,
-                          ty_vars : $4?, args : $6?, ty : None, body : $8? }) }
+                          ty_vars : $4?, args : $6?, rets: vec![], body : $8? }) }
   | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?),
-                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : None,
+                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: vec![],
                           body : $9? }) }
-  | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts
+  | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?,
-                          ty_vars : $4?, args : $6?, ty : Some($9?), body : $10? }) }
-  | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts
+                          ty_vars : $4?, args : $6?, rets: $9?, body : $10? }) }
+  | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?),
-                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : Some($10?),
+                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: $10?,
                           body : $11? }) }
   ;
 Arguments -> Result<Vec<(Option<Span>, VarBind)>, ()>
@@ -198,6 +198,18 @@ VarBind -> Result<VarBind, ()>
   | Pattern ':' Type { Ok(VarBind{ span : $span, pattern : $1?, typ : Some($3?) }) }
   ;
 
+LetBind -> Result<LetBind, ()>
+  : VarBind { 
+      let VarBind { span, pattern, typ } = $1?;
+      Ok(LetBind::Single { span, pattern, typ }) 
+  }
+  | PatternsCommaS ',' Pattern {
+      let mut pats = $1?;
+      pats.push($3?);
+      Ok(LetBind::Multi { span: $span, patterns: pats })
+  }
+  ;
+
 Pattern -> Result<Pattern, ()>
   : '_'                   { Ok(Pattern::Wildcard { span : $span }) }
   | IntLit                { let (span, base) = $1?;
@@ -240,9 +252,9 @@ StructPatterns -> Result<(VecDeque<(Id, Pattern)>, bool), ()>
   ;
 
 Stmt -> Result<Stmt, ()>
-  : 'let' VarBind ';'
+  : 'let' LetBind ';'
       { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : None }) }
-  | 'let' VarBind '=' Expr ';'
+  | 'let' LetBind '=' Expr ';'
       { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : Some($4?) }) }
   | 'const' VarBind ';'
       { Ok(Stmt::ConstStmt{ span : $span, var : $2?, init : None }) }
@@ -305,10 +317,8 @@ Stmt -> Result<Stmt, ()>
                           inclusive: true, step: None, body: Box::new($8?) }) }
   | 'while' NonStructExpr Stmts
       { Ok(Stmt::WhileStmt{ span : $span, cond : $2?, body : Box::new($3?) }) }
-  | 'return' ';'
-      { Ok(Stmt::ReturnStmt{ span : $span, expr : None }) }
-  | 'return' Expr ';'
-      { Ok(Stmt::ReturnStmt{ span : $span, expr : Some($2?)}) }
+  | 'return' Exprs ';'
+      { Ok(Stmt::ReturnStmt{ span : $span, vals: $2?}) }
   | 'break' ';'
       { Ok(Stmt::BreakStmt{ span : $span }) }
   | 'continue' ';'
@@ -659,6 +669,15 @@ pub struct VarBind { pub span : Span, pub pattern : Pattern, pub typ : Option<Ty
 #[derive(Debug)]
 pub struct Case { pub span : Span, pub pat : Vec<Pattern>, pub body : Stmt }
 
+// Let bindings are different from other bindings because they can be used to
+// destruct multi-return function values, and so can actually contain multiple
+// patterns
+#[derive(Debug)]
+pub enum LetBind {
+  Single { span: Span, pattern: Pattern, typ: Option<Type> },
+  Multi  { span: Span, patterns: Vec<Pattern> },
+}
+
 #[derive(Debug)]
 pub enum Top {
   Import    { span : Span, name : ImportName },
@@ -666,7 +685,7 @@ pub enum Top {
   ConstDecl { span : Span, public : bool, name : Id, ty : Option<Type>, body : Expr },
   FuncDecl  { span : Span, public : bool, attr : Option<Span>, name : Id, ty_vars : Vec<TypeVar>,
               args : Vec<(Option<Span>, VarBind)>, // option is for inout
-              ty : Option<Type>, body : Stmt },
+              rets : Vec<Type>, body : Stmt },
   ModDecl   { span : Span, public : bool, name : Id, body : Vec<Top> },
 }
 
@@ -688,7 +707,7 @@ pub enum Type {
 
 #[derive(Debug)]
 pub enum Stmt {
-  LetStmt    { span : Span, var : VarBind, init : Option<Expr> },
+  LetStmt    { span : Span, var : LetBind, init : Option<Expr> },
   ConstStmt  { span : Span, var : VarBind, init : Option<Expr> },
   AssignStmt { span : Span, lhs : LExpr, assign : AssignOp, assign_span : Span, rhs : Expr },
   IfStmt     { span : Span, cond : Expr, thn : Box<Stmt>, els : Option<Box<Stmt>> },
@@ -697,7 +716,7 @@ pub enum Stmt {
   ForStmt    { span : Span, var : VarBind, init : Expr, bound : Expr,
                inclusive: bool, step : Option<(bool, Span, IntBase)>, body : Box<Stmt> },
   WhileStmt  { span : Span, cond : Expr, body : Box<Stmt> },
-  ReturnStmt { span : Span, expr : Option<Expr> },
+  ReturnStmt { span : Span, vals : Vec<Expr> },
   BreakStmt  { span : Span },
   ContinueStmt { span : Span },
   BlockStmt    { span : Span, body : Vec<Stmt> },
diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs
index bfd4cf7f52f46274db5ad42fd495884fd8195fbd..f0736d2b63296b8006e874ce629aa08ac6ec3de6 100644
--- a/juno_frontend/src/semant.rs
+++ b/juno_frontend/src/semant.rs
@@ -46,7 +46,7 @@ enum Entity {
         index: usize,
         type_args: Vec<parser::Kind>,
         args: Vec<(types::Type, bool)>,
-        return_type: types::Type,
+        return_types: Vec<types::Type>,
     },
 }
 
@@ -202,7 +202,7 @@ pub struct Function {
     pub num_dyn_consts: usize,
     pub num_type_args: usize,
     pub arguments: Vec<(usize, Type)>,
-    pub return_type: Type,
+    pub return_types: Vec<Type>,
     pub body: Stmt,
     pub entry: bool,
 }
@@ -236,7 +236,7 @@ pub enum Stmt {
         body: Box<Stmt>,
     },
     ReturnStmt {
-        expr: Expr,
+        exprs: Vec<Expr>,
     },
     BreakStmt {},
     ContinueStmt {},
@@ -328,6 +328,14 @@ pub enum Expr {
         ty_args: Vec<Type>,
         dyn_consts: Vec<DynConst>,
         args: Vec<Either<Expr, usize>>,
+        // Include the number of Juno returns (i.e. non-inouts) for codegen
+        num_returns: usize,
+        typ: Type,
+    },
+    // A projection from a call
+    CallExtract {
+        call: Box<Expr>,
+        index: usize,
         typ: Type,
     },
     Intrinsic {
@@ -415,57 +423,20 @@ fn convert_binary_op(op: parser::BinaryOp) -> BinaryOp {
 impl Expr {
     pub fn get_type(&self) -> Type {
         match self {
-            Expr::Variable { var: _, typ }
-            | Expr::DynConst { val: _, typ }
-            | Expr::Read {
-                index: _,
-                val: _,
-                typ,
-            }
-            | Expr::Write {
-                index: _,
-                val: _,
-                rep: _,
-                typ,
-            }
-            | Expr::Tuple { vals: _, typ }
-            | Expr::Union {
-                tag: _,
-                val: _,
-                typ,
-            }
-            | Expr::Constant { val: _, typ }
-            | Expr::UnaryExp {
-                op: _,
-                expr: _,
-                typ,
-            }
-            | Expr::BinaryExp {
-                op: _,
-                lhs: _,
-                rhs: _,
-                typ,
-            }
-            | Expr::CastExpr { expr: _, typ }
-            | Expr::CondExpr {
-                cond: _,
-                thn: _,
-                els: _,
-                typ,
-            }
-            | Expr::CallExpr {
-                func: _,
-                ty_args: _,
-                dyn_consts: _,
-                args: _,
-                typ,
-            }
-            | Expr::Intrinsic {
-                id: _,
-                ty_args: _,
-                args: _,
-                typ,
-            }
+            Expr::Variable { typ, .. }
+            | Expr::DynConst { typ, .. }
+            | Expr::Read { typ, .. }
+            | Expr::Write { typ, .. }
+            | Expr::Tuple { typ, .. }
+            | Expr::Union { typ, .. }
+            | Expr::Constant { typ, .. }
+            | Expr::UnaryExp { typ, .. }
+            | Expr::BinaryExp { typ, .. }
+            | Expr::CastExpr { typ, .. }
+            | Expr::CondExpr { typ, .. }
+            | Expr::CallExpr { typ, .. }
+            | Expr::CallExtract { typ, .. }
+            | Expr::Intrinsic { typ, .. }
             | Expr::Zero { typ } => *typ,
         }
     }
@@ -650,7 +621,7 @@ fn analyze_program(
                 name,
                 ty_vars,
                 args,
-                ty,
+                rets,
                 body,
             } => {
                 // TODO: Handle public
@@ -778,44 +749,40 @@ fn analyze_program(
                     }
                 }
 
-                let return_type = {
-                    // A missing return type is implicitly void
-                    let ty = ty.unwrap_or(parser::Type::PrimType {
-                        span: span,
-                        typ: parser::Primitive::Void,
-                    });
-                    match process_type(
-                        ty,
-                        num_dyn_const,
-                        lexer,
-                        &mut stringtab,
-                        &env,
-                        &mut types,
-                        true,
-                    ) {
-                        Ok(ty) => ty,
-                        Err(mut errs) => {
-                            errors.append(&mut errs);
-                            types.new_primitive(types::Primitive::Unit)
+                let return_types = rets
+                    .into_iter()
+                    .map(|ty| {
+                        match process_type(
+                            ty,
+                            num_dyn_const,
+                            lexer,
+                            &mut stringtab,
+                            &env,
+                            &mut types,
+                            true,
+                        ) {
+                            Ok(ty) => ty,
+                            Err(mut errs) => {
+                                errors.append(&mut errs);
+                                // Type we return doesn't matter, error will be propagated upwards
+                                // next, but need to return something
+                                types.new_primitive(types::Primitive::Unit)
+                            }
                         }
-                    }
-                };
+                    })
+                    .collect::<Vec<_>>();
 
                 if !errors.is_empty() {
                     return Err(errors);
                 }
 
                 // Compute the proper type accounting for the inouts (which become returns)
-                let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
-
-                let mut return_types = vec![return_type];
-                return_types.extend(inout_types);
-                // TODO: Ideally we would omit unit returns
-                let pure_return_type = if return_types.len() == 1 {
-                    return_types.pop().unwrap()
-                } else {
-                    types.new_tuple(return_types)
-                };
+                let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
+                let pure_return_types = return_types
+                    .clone()
+                    .into_iter()
+                    .chain(inout_types.into_iter())
+                    .collect::<Vec<_>>();
 
                 // Finally, we have a properly built environment and we can
                 // start processing the body
@@ -827,7 +794,7 @@ fn analyze_program(
                     &mut env,
                     &mut types,
                     false,
-                    return_type,
+                    &return_types,
                     &inouts,
                     &mut labels,
                 )?;
@@ -835,21 +802,11 @@ fn analyze_program(
                 if end_reachable {
                     // The end of a function being reachable (i.e. there is some possible path
                     // where there is no return statement) is an error unless the return type is
-                    // void
-                    if types.unify_void(return_type) {
+                    // empty
+                    if return_types.is_empty() {
                         // Insert return at the end
                         body = Stmt::BlockStmt {
-                            body: vec![
-                                body,
-                                generate_return(
-                                    Expr::Tuple {
-                                        vals: vec![],
-                                        typ: types.new_primitive(types::Primitive::Unit),
-                                    },
-                                    &inouts,
-                                    &mut types,
-                                ),
-                            ],
+                            body: vec![body, generate_return(vec![], &inouts)],
                         };
                     } else {
                         Err(singleton_error(ErrorMessage::SemanticError(
@@ -876,7 +833,7 @@ fn analyze_program(
                             .iter()
                             .map(|(ty, is, _)| (*ty, *is))
                             .collect::<Vec<_>>(),
-                        return_type: return_type,
+                        return_types,
                     },
                 );
 
@@ -889,7 +846,7 @@ fn analyze_program(
                         .iter()
                         .map(|(t, _, v)| (*v, *t))
                         .collect::<Vec<_>>(),
-                    return_type: pure_return_type,
+                    return_types: pure_return_types,
                     body: body,
                     entry: entry,
                 });
@@ -1610,21 +1567,17 @@ fn process_stmt(
     env: &mut Env<usize, Entity>,
     types: &mut TypeSolver,
     in_loop: bool,
-    return_type: Type,
+    return_types: &[Type],
     inouts: &Vec<Expr>,
     labels: &mut StringTable,
 ) -> Result<(Stmt, bool), ErrorMessages> {
     match stmt {
-        parser::Stmt::LetStmt {
-            span,
-            var:
-                VarBind {
-                    span: v_span,
-                    pattern,
-                    typ,
-                },
-            init,
-        } => {
+        parser::Stmt::LetStmt { span, var, init } => {
+            let (_, pattern, typ) = match var {
+                LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ),
+                LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None),
+            };
+
             if typ.is_none() && init.is_none() {
                 return Err(singleton_error(ErrorMessage::SemanticError(
                     span_to_loc(span, lexer),
@@ -1676,12 +1629,72 @@ fn process_stmt(
 
             let mut res = vec![];
             res.push(Stmt::AssignStmt { var, val });
-            res.extend(
-                process_irrefutable_pattern(
-                    pattern, false, var, typ, lexer, stringtab, env, types, false,
-                )?
-                .0,
-            );
+
+            match pattern {
+                Either::Left(pattern) => {
+                    if let Some(return_types) = types.get_return_types(typ) {
+                        return Err(singleton_error(ErrorMessage::SemanticError(
+                            span_to_loc(span, lexer),
+                            format!("Expected {} patterns, found 1 pattern", return_types.len()),
+                        )));
+                    }
+                    res.extend(
+                        process_irrefutable_pattern(
+                            pattern, false, var, typ, lexer, stringtab, env, types, false,
+                        )?
+                        .0,
+                    );
+                }
+                Either::Right(patterns) => {
+                    let Some(return_types) = types.get_return_types(typ) else {
+                        return Err(singleton_error(ErrorMessage::SemanticError(
+                            span_to_loc(span, lexer),
+                            format!("Expected 1 pattern, found {} patterns", patterns.len()),
+                        )));
+                    };
+                    if return_types.len() != patterns.len() {
+                        return Err(singleton_error(ErrorMessage::SemanticError(
+                            span_to_loc(span, lexer),
+                            format!(
+                                "Expected {} pattern, found {} patterns",
+                                return_types.len(),
+                                patterns.len()
+                            ),
+                        )));
+                    }
+
+                    // Process each pattern after extracting the appropriate value from the call
+                    for (index, (pat, ret_typ)) in patterns
+                        .into_iter()
+                        .zip(return_types.clone().into_iter())
+                        .enumerate()
+                    {
+                        let extract_var = env.uniq();
+                        res.push(Stmt::AssignStmt {
+                            var: extract_var,
+                            val: Expr::CallExtract {
+                                call: Box::new(Expr::Variable { var, typ }),
+                                index,
+                                typ: ret_typ,
+                            },
+                        });
+                        res.extend(
+                            process_irrefutable_pattern(
+                                pat,
+                                false,
+                                extract_var,
+                                ret_typ,
+                                lexer,
+                                stringtab,
+                                env,
+                                types,
+                                false,
+                            )?
+                            .0,
+                        );
+                    }
+                }
+            }
 
             Ok((Stmt::BlockStmt { body: res }, true))
         }
@@ -1689,7 +1702,7 @@ fn process_stmt(
             span,
             var:
                 VarBind {
-                    span: v_span,
+                    span: _v_span,
                     pattern,
                     typ,
                 },
@@ -1935,7 +1948,7 @@ fn process_stmt(
                 env,
                 types,
                 in_loop,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             );
@@ -1952,7 +1965,7 @@ fn process_stmt(
                     env,
                     types,
                     in_loop,
-                    return_type,
+                    return_types,
                     inouts,
                     labels,
                 )
@@ -2110,7 +2123,7 @@ fn process_stmt(
                 env,
                 types,
                 true,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             )?;
@@ -2214,7 +2227,7 @@ fn process_stmt(
                 env,
                 types,
                 true,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             );
@@ -2241,36 +2254,50 @@ fn process_stmt(
                 true,
             ))
         }
-        parser::Stmt::ReturnStmt { span, expr } => {
-            let return_val = if expr.is_none() && types.unify_void(return_type) {
-                Expr::Constant {
-                    val: (Literal::Unit, return_type),
-                    typ: return_type,
-                }
-            } else if expr.is_none() {
-                Err(singleton_error(ErrorMessage::SemanticError(
+        parser::Stmt::ReturnStmt { span, vals } => {
+            if return_types.len() != vals.len() {
+                return Err(singleton_error(ErrorMessage::SemanticError(
                     span_to_loc(span, lexer),
                     format!(
-                        "Expected return of type {} found no return value",
-                        unparse_type(types, return_type, stringtab)
+                        "Expected {} return values found {}",
+                        return_types.len(),
+                        vals.len(),
                     ),
-                )))?
-            } else {
-                let val = process_expr(expr.unwrap(), num_dyn_const, lexer, stringtab, env, types)?;
-                let typ = val.get_type();
-                if !types.unify(return_type, typ) {
-                    Err(singleton_error(ErrorMessage::TypeError(
-                        span_to_loc(span, lexer),
-                        unparse_type(types, return_type, stringtab),
-                        unparse_type(types, typ, stringtab),
-                    )))?
-                }
-                val
-            };
+                )));
+            }
+
+            let return_vals = vals
+                .into_iter()
+                .zip(return_types.iter())
+                .map(|(expr, typ)| {
+                    let expr_span = expr.span();
+                    let val = process_expr(expr, num_dyn_const, lexer, stringtab, env, types)?;
+                    if types.unify(*typ, val.get_type()) {
+                        Ok(val)
+                    } else {
+                        Err(singleton_error(ErrorMessage::TypeError(
+                            span_to_loc(expr_span, lexer),
+                            unparse_type(types, *typ, stringtab),
+                            unparse_type(types, val.get_type(), stringtab),
+                        )))
+                    }
+                })
+                .fold(Ok(vec![]), |res, val| match (res, val) {
+                    (Ok(mut res), Ok(val)) => {
+                        res.push(val);
+                        Ok(res)
+                    }
+                    (Ok(_), Err(msg)) => Err(msg),
+                    (Err(msg), Ok(_)) => Err(msg),
+                    (Err(mut msgs), Err(msg)) => {
+                        msgs.extend(msg);
+                        Err(msgs)
+                    }
+                })?;
 
-            // We return a tuple of the return value and of the inout variables
+            // We return both the actual return values and the inout arguments
             // Statements after a return are never reachable
-            Ok((generate_return(return_val, inouts, types), false))
+            Ok((generate_return(return_vals, inouts), false))
         }
         parser::Stmt::BreakStmt { span } => {
             if !in_loop {
@@ -2318,7 +2345,7 @@ fn process_stmt(
                     env,
                     types,
                     in_loop,
-                    return_type,
+                    return_types,
                     inouts,
                     labels,
                 ) {
@@ -2384,7 +2411,7 @@ fn process_stmt(
                 env,
                 types,
                 in_loop,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             )?;
@@ -4723,7 +4750,7 @@ fn process_expr(
                     index: function,
                     type_args: kinds,
                     args: func_args,
-                    return_type,
+                    return_types,
                 }) => {
                     let func = *function;
 
@@ -4813,11 +4840,11 @@ fn process_expr(
                         }
                         tys
                     };
-                    let return_typ = if let Some(res) =
-                        types.instantiate(*return_type, &type_vars, &dyn_consts)
-                    {
-                        res
-                    } else {
+                    let return_types = return_types
+                        .iter()
+                        .map(|t| types.instantiate(*t, &type_vars, &dyn_consts))
+                        .collect::<Option<Vec<_>>>();
+                    let Some(return_types) = return_types else {
                         return Err(singleton_error(ErrorMessage::SemanticError(
                             span_to_loc(span, lexer),
                             "Failure in variable substitution".to_string(),
@@ -4887,13 +4914,29 @@ fn process_expr(
                     if !errors.is_empty() {
                         Err(errors)
                     } else {
-                        Ok(Expr::CallExpr {
-                            func: func,
+                        let single_type = if return_types.len() == 1 {
+                            Some(return_types[0])
+                        } else {
+                            None
+                        };
+                        let num_returns = return_types.len();
+                        let call = Expr::CallExpr {
+                            func,
                             ty_args: type_vars,
-                            dyn_consts: dyn_consts,
+                            dyn_consts,
                             args: arg_vals,
-                            typ: return_typ,
-                        })
+                            num_returns,
+                            typ: types.new_multi_return(return_types),
+                        };
+                        if let Some(return_type) = single_type {
+                            Ok(Expr::CallExtract {
+                                call: Box::new(call),
+                                index: 0,
+                                typ: return_type,
+                            })
+                        } else {
+                            Ok(call)
+                        }
                     }
                 }
             }
@@ -5024,25 +5067,9 @@ fn process_expr(
     }
 }
 
-fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt {
-    let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
-
-    let mut return_types = vec![expr.get_type()];
-    return_types.extend(inout_types);
-
-    let mut return_vals = vec![expr];
-    return_vals.extend_from_slice(inouts);
-
-    let val = if return_vals.len() == 1 {
-        return_vals.pop().unwrap()
-    } else {
-        Expr::Tuple {
-            vals: return_vals,
-            typ: types.new_tuple(return_types),
-        }
-    };
-
-    Stmt::ReturnStmt { expr: val }
+fn generate_return(mut exprs: Vec<Expr>, inouts: &[Expr]) -> Stmt {
+    exprs.extend_from_slice(inouts);
+    Stmt::ReturnStmt { exprs: exprs }
 }
 
 fn convert_primitive(prim: parser::Primitive) -> types::Primitive {
@@ -5098,6 +5125,7 @@ fn process_irrefutable_pattern(
                     "Bound variables must be local names, without a package separator".to_string(),
                 )));
             }
+            assert!(types.get_return_types(typ).is_none());
 
             let nm = intern_package_name(&name, lexer, stringtab)[0];
             let variable = env.uniq();
diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs
index 7076d62259e04cf17f2e629825bd9fd7b7f05747..9dbc0bfd23594ff566c5a1cb82f4d26ba8ac1381 100644
--- a/juno_frontend/src/ssa.rs
+++ b/juno_frontend/src/ssa.rs
@@ -45,10 +45,10 @@ impl SSA {
         let right_proj = right_builder.id();
 
         // True branch
-        left_builder.build_projection(if_builder.id(), 1);
+        left_builder.build_control_projection(if_builder.id(), 1);
 
         // False branch
-        right_builder.build_projection(if_builder.id(), 0);
+        right_builder.build_control_projection(if_builder.id(), 0);
 
         builder.add_node(left_builder);
         builder.add_node(right_builder);
diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs
index edb51db5c03eaef1ff1b6aa162a7767d38adc161..4099c56704846454fa0fcea2a9cdc6dbc9ac1531 100644
--- a/juno_frontend/src/types.rs
+++ b/juno_frontend/src/types.rs
@@ -204,6 +204,11 @@ enum TypeForm {
         kind: parser::Kind,
         loc: Location,
     },
+
+    // The types of call nodes are MultiReturns
+    MultiReturn {
+        types: Vec<Type>,
+    },
 }
 
 #[derive(Debug)]
@@ -279,6 +284,10 @@ impl TypeSolver {
         })
     }
 
+    pub fn new_multi_return(&mut self, types: Vec<Type>) -> Type {
+        self.create_type(TypeForm::MultiReturn { types })
+    }
+
     fn create_type(&mut self, typ: TypeForm) -> Type {
         let idx = self.types.len();
         self.types.push(typ);
@@ -543,26 +552,12 @@ impl TypeSolver {
                 }
             }
 
+            // Note that MultReturn types never unify with anything (even itself), this is
+            // intentional and makes it so that the only way MultiReturns can be used is to
+            // destruct them
             _ => false,
         }
     }
-    /*
-        pub fn is_tuple(&self, Type { val } : Type) -> bool {
-            match &self.types[val] {
-                TypeForm::Tuple(_) => true,
-                TypeForm::OtherType(t) => self.is_tuple(*t),
-                _ => false,
-            }
-        }
-
-        pub fn get_num_fields(&self, Type { val } : Type) -> Option<usize> {
-            match &self.types[val] {
-                TypeForm::Tuple(fields) => { Some(fields.len()) },
-                TypeForm::OtherType(t) => self.get_num_fields(*t),
-                _ => None,
-            }
-        }
-    */
 
     // Returns the types of the fields of a tuple
     pub fn get_fields(&self, Type { val }: Type) -> Option<&Vec<Type>> {
@@ -676,26 +671,13 @@ impl TypeSolver {
         }
     }
 
-    /*
-        pub fn get_constructor_list(&self, Type { val } : Type) -> Option<Vec<usize>> {
-            match &self.types[val] {
-                TypeForm::Union { name : _, id : _, constr : _, names } => {
-                    Some(names.keys().map(|i| *i).collect::<Vec<_>>())
-                },
-                TypeForm::OtherType(t) => self.get_constructor_list(*t),
-                _ => None,
-            }
-        }
-
-
-        fn is_type_var_num(&self, num : usize, Type { val } : Type) -> bool {
-            match &self.types[val] {
-                TypeForm::TypeVar { name : _, index, .. } => *index == num,
-                TypeForm::OtherType(t) => self.is_type_var_num(num, *t),
-                _ => false,
-            }
+    pub fn get_return_types(&self, Type { val }: Type) -> Option<&Vec<Type>> {
+        match &self.types[val] {
+            TypeForm::MultiReturn { types } => Some(types),
+            TypeForm::OtherType { other, .. } => self.get_return_types(*other),
+            _ => None,
         }
-    */
+    }
 
     pub fn to_string(&self, Type { val }: Type, stringtab: &dyn Fn(usize) -> String) -> String {
         match &self.types[val] {
@@ -724,6 +706,11 @@ impl TypeSolver {
             | TypeForm::Struct { name, .. }
             | TypeForm::Union { name, .. } => stringtab(*name),
             TypeForm::AnyOfKind { kind, .. } => kind.to_string(),
+            TypeForm::MultiReturn { types } => types
+                .iter()
+                .map(|t| self.to_string(*t, stringtab))
+                .collect::<Vec<_>>()
+                .join(", "),
         }
     }
 
@@ -825,6 +812,9 @@ impl TypeSolver {
                     Some(Type { val })
                 }
             }
+            TypeForm::MultiReturn { .. } => {
+                panic!("Multi-Return types should never be instantiated")
+            }
         }
     }
 
@@ -969,6 +959,9 @@ impl TypeSolverInst<'_> {
                 TypeForm::AnyOfKind { .. } => {
                     panic!("TypeSolverInst only works on solved types which do not have AnyOfKinds")
                 }
+                TypeForm::MultiReturn { .. } => {
+                    panic!("MultiReturn types should never be lowered")
+                }
             };
 
             match solution {
diff --git a/juno_samples/multi_return/Cargo.toml b/juno_samples/multi_return/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..0fb3de94cfff22a8720d084ad2bd7e43c87a4cd4
--- /dev/null
+++ b/juno_samples/multi_return/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "juno_multi_return"
+version = "0.1.0"
+authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_multi_return"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
diff --git a/juno_samples/multi_return/build.rs b/juno_samples/multi_return/build.rs
new file mode 100644
index 0000000000000000000000000000000000000000..3a8f9b1c503f8cc5fbca0baf06ff8edbd300bf9e
--- /dev/null
+++ b/juno_samples/multi_return/build.rs
@@ -0,0 +1,15 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    JunoCompiler::new()
+        .file_in_src("multi_return.jn")
+        .unwrap()
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
+        .unwrap()
+        .build()
+        .unwrap();
+}
diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch
new file mode 100644
index 0000000000000000000000000000000000000000..d89b3b8bb819f3254491b94a4b26d205bfbc9afa
--- /dev/null
+++ b/juno_samples/multi_return/src/cpu.sch
@@ -0,0 +1,35 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+ip-sroa(*);
+sroa(*);
+
+ip-sroa[true](rolling_sum);
+sroa[true](rolling_sum, rolling_sum_prod);
+
+dce(*);
+
+forkify(*);
+fork-guard-elim(*);
+gvn(*);
+dce(*);
+
+inline(*);
+delete-uncalled(*);
+
+let out = auto-outline(*);
+cpu(out.rolling_sum_prod);
+
+fork-fusion(out.rolling_sum_prod);
+gvn(*);
+dce(*);
+
+float-collections(*);
+
+unforkify(*);
+gvn(*);
+ccp(*);
+dce(*);
+
+gcm(*);
diff --git a/juno_samples/multi_return/src/gpu.sch b/juno_samples/multi_return/src/gpu.sch
new file mode 100644
index 0000000000000000000000000000000000000000..f690086b5c213eca585f9640d0f7419364800e11
--- /dev/null
+++ b/juno_samples/multi_return/src/gpu.sch
@@ -0,0 +1,31 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+ip-sroa(*);
+sroa(*);
+
+ip-sroa[true](rolling_sum);
+sroa[true](rolling_sum, rolling_sum_prod);
+
+dce(*);
+
+forkify(*);
+fork-guard-elim(*);
+gvn(*);
+dce(*);
+
+inline(*);
+delete-uncalled(*);
+
+let out = auto-outline(*);
+gpu(out.rolling_sum_prod);
+
+fork-fusion(out.rolling_sum_prod);
+gvn(*);
+dce(*);
+
+float-collections(*);
+unforkify(*);
+
+gcm(*);
diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b0fd169f71fdd585e22cce89cb12088b7388154e
--- /dev/null
+++ b/juno_samples/multi_return/src/main.rs
@@ -0,0 +1,39 @@
+#![feature(concat_idents)]
+
+juno_build::juno!("multi_return");
+
+use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
+
+fn main() {
+    const N: usize = 32;
+    let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect();
+    let arg = HerculesImmBox::from(a.as_ref());
+    let mut r = runner!(rolling_sum_prod);
+    let (sums, sum, prods, prod) =
+        async_std::task::block_on(async { r.run(N as u64, arg.to()).await });
+
+    let mut sums = HerculesMutBox::<f32>::from(sums);
+    let mut prods = HerculesMutBox::<f32>::from(prods);
+
+    let (expected_sums, expected_sum) = a.iter().fold((vec![0.0], 0.0), |(mut sums, sum), v| {
+        let new_sum = sum + v;
+        sums.push(new_sum);
+        (sums, new_sum)
+    });
+    let (expected_prods, expected_prod) =
+        a.iter().fold((vec![1.0], 1.0), |(mut prods, prod), v| {
+            let new_prod = prod * v;
+            prods.push(new_prod);
+            (prods, new_prod)
+        });
+
+    assert_eq!(sum, expected_sum);
+    assert_eq!(sums.as_slice(), expected_sums.as_slice());
+    assert_eq!(prod, expected_prod);
+    assert_eq!(prods.as_slice(), expected_prods.as_slice());
+}
+
+#[test]
+fn test_multi_return() {
+    main()
+}
diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn
new file mode 100644
index 0000000000000000000000000000000000000000..30b5576e89270834e7563e8f8b6eba9b00c23673
--- /dev/null
+++ b/juno_samples/multi_return/src/multi_return.jn
@@ -0,0 +1,32 @@
+fn rolling_sum<t: number, n: usize>(x: t[n]) -> (t, t[n + 1]) {
+  let rolling_sum: t[n + 1];
+  let sum = 0;
+
+  for i in 0..n {
+    rolling_sum[i] = sum;
+    sum += x[i];
+  }
+  rolling_sum[n] = sum;
+
+  return (sum, rolling_sum);
+}
+
+fn rolling_prod<t: number, n: usize>(x: t[n]) -> t, t[n + 1] {
+  let rolling_prod: t[n + 1];
+  let prod = 1;
+
+  for i in 0..n {
+    rolling_prod[i] = prod;
+    prod *= x[i];
+  }
+  rolling_prod[n] = prod;
+
+  return prod, rolling_prod;
+}
+
+#[entry]
+fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32, f32[n + 1], f32 {
+  let (sum, rsum) = rolling_sum::<_, n>(x);
+  let prod, rprod = rolling_prod::<_, n>(x);
+  return rsum, sum, rprod, prod;
+}
diff --git a/juno_samples/products/src/gpu.sch b/juno_samples/products/src/gpu.sch
index 5ef4c479550bcbe4f5d9ddeffb03efde4162cf7b..0a734bb21dfbe51ad7aabf37085ff72b40b7e8f7 100644
--- a/juno_samples/products/src/gpu.sch
+++ b/juno_samples/products/src/gpu.sch
@@ -5,7 +5,6 @@ dce(*);
 let out = auto-outline(*);
 gpu(out.product_read);
 
-ip-sroa(*);
 sroa(*);
 reuse-products(*);
 crc(*);
diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn
index 2927dbb59ef55681160fae56dcea64a3b8329a27..c7f4345bc5dc89d2fb55ee96231e6b5f6604ef4f 100644
--- a/juno_samples/rodinia/backprop/src/backprop.jn
+++ b/juno_samples/rodinia/backprop/src/backprop.jn
@@ -18,7 +18,7 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
   return result;
 }
 
-fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> (f32, f32[n + 1]) {
+fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
   let errsum = 0.0;
   let delta : f32[n + 1];
 
@@ -29,14 +29,14 @@ fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> (f32, f32[n
     errsum += abs!(delta[j]);
   }
 
-  return (errsum, delta);
+  return errsum, delta;
 }
 
 fn hidden_error<hidden_n, output_n: usize>(
   out_delta: f32[output_n + 1],
   hidden_weights: f32[hidden_n + 1, output_n + 1],
   hidden_vals: f32[hidden_n + 1],
-) -> (f32, f32[hidden_n + 1]) {
+) -> f32, f32[hidden_n + 1] {
   let errsum = 0.0;
   let delta : f32[hidden_n + 1];
 
@@ -52,7 +52,7 @@ fn hidden_error<hidden_n, output_n: usize>(
     errsum += abs!(delta[j]);
   }
 
-  return (errsum, delta);
+  return errsum, delta;
 }
 
 const ETA : f32 = 0.3;
@@ -63,7 +63,7 @@ fn adjust_weights<n, m: usize>(
   vals: f32[n + 1],
   weights: f32[n + 1, m + 1],
   prev_weights: f32[n + 1, m + 1]
-) -> (f32[n + 1, m + 1], f32[n + 1, m + 1]) {
+) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
   for j in 1..=m {
     for k in 0..=n {
       let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
@@ -72,7 +72,7 @@ fn adjust_weights<n, m: usize>(
     }
   }
 
-  return (weights, prev_weights);
+  return weights, prev_weights;
 }
 
 #[entry]
@@ -83,21 +83,19 @@ fn backprop<input_n, hidden_n, output_n: usize>(
   target: f32[output_n + 1],
   input_prev_weights: f32[input_n + 1, hidden_n + 1],
   hidden_prev_weights: f32[hidden_n + 1, output_n + 1],
-//) -> (f32, f32,
-//      f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
-//      f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1]) {
-) -> (f32, f32, f32) {
+) -> f32, f32,
+     f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
+     f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
   let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
   let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
 
-  let (out_err, out_delta) = output_error::<output_n>(target, output_vals);
-  let (hid_err, hid_delta) = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
+  let out_err, out_delta = output_error::<output_n>(target, output_vals);
+  let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
 
-  let (hidden_weights, hidden_prev_weights)
+  let hidden_weights, hidden_prev_weights
     = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
-  let (input_weights, input_prev_weights)
+  let input_weights, input_prev_weights
     = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
 
-  return (out_err, hid_err, input_weights[0, 0] + input_prev_weights[0, 0] + hidden_weights[0, 0] + hidden_prev_weights[0, 0]);
-  //return (input_weights, input_prev_weights, hidden_weights, hidden_prev_weights);
+  return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
 }
diff --git a/juno_samples/rodinia/backprop/src/main.rs b/juno_samples/rodinia/backprop/src/main.rs
index 848b0abb02cdca6d26d9aa2961e3e0206844fa19..fa80a7a51cba6581f3305398f5e3f91da05ad877 100644
--- a/juno_samples/rodinia/backprop/src/main.rs
+++ b/juno_samples/rodinia/backprop/src/main.rs
@@ -37,25 +37,26 @@ fn run_backprop(
     let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights.to_vec());
 
     let mut runner = runner!(backprop);
-    let res = HerculesMutBox::from(async_std::task::block_on(async {
-        runner
-            .run(
-                input_n,
-                hidden_n,
-                output_n,
-                input_vals.to(),
-                input_weights.to(),
-                hidden_weights.to(),
-                target.to(),
-                input_prev_weights.to(),
-                hidden_prev_weights.to(),
-            )
-            .await
-    }))
-    .as_slice()
-    .to_vec();
-    let out_err = res[0];
-    let hid_err = res[1];
+    let (out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights) =
+        async_std::task::block_on(async {
+            runner
+                .run(
+                    input_n,
+                    hidden_n,
+                    output_n,
+                    input_vals.to(),
+                    input_weights.to(),
+                    hidden_weights.to(),
+                    target.to(),
+                    input_prev_weights.to(),
+                    hidden_prev_weights.to(),
+                )
+                .await
+        });
+    let mut input_weights = HerculesMutBox::from(input_weights);
+    let mut hidden_weights = HerculesMutBox::from(hidden_weights);
+    let mut input_prev_weights = HerculesMutBox::from(input_prev_weights);
+    let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights);
 
     (
         out_err,
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 4ac5a732d12f87b6eb76c2771c2c5addfa42cb50..a888cf74dc223a8e52466daa1bacdec533809de4 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -56,6 +56,7 @@ impl Pass {
             Pass::Print => num == 1,
             Pass::Rename => num == 1,
             Pass::SROA => num == 0 || num == 1,
+            Pass::InterproceduralSROA => num == 0 || num == 1,
             Pass::Xdot => num == 0 || num == 1,
             _ => num == 0,
         }
@@ -70,6 +71,7 @@ impl Pass {
             Pass::Print => "1",
             Pass::Rename => "1",
             Pass::SROA => "0 or 1",
+            Pass::InterproceduralSROA => "0 or 1",
             Pass::Xdot => "0 or 1",
             _ => "0",
         }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 675cfe1ca9f25f74f7c6ad3c969f3d2b9b9524fd..84b25811f7f387fe5822ea514a732842a0e77b02 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2077,16 +2077,33 @@ fn run_pass(
             pm.clear_analyses();
         }
         Pass::InterproceduralSROA => {
-            assert!(args.is_empty());
-            if let Some(_) = selection {
-                return Err(SchedulerError::PassError {
-                    pass: "interproceduralSROA".to_string(),
-                    error: "must be applied to the entire module".to_string(),
-                });
-            }
+            let sroa_with_arrays = match args.get(0) {
+                Some(Value::Boolean { val }) => *val,
+                Some(_) => {
+                    return Err(SchedulerError::PassError {
+                        pass: "sroa".to_string(),
+                        error: "expected boolean argument".to_string(),
+                    });
+                }
+                None => false,
+            };
+
+            let selection =
+                selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError {
+                    pass: "xdot".to_string(),
+                    error: "expected coarse-grained selection (can't partially xdot a function)"
+                        .to_string(),
+                })?;
+            let mut bool_selection = vec![false; pm.functions.len()];
+            selection
+                .into_iter()
+                .for_each(|func| bool_selection[func.idx()] = true);
+
+            pm.make_typing();
+            let typing = pm.typing.take().unwrap();
 
             let mut editors = build_editors(pm);
-            interprocedural_sroa(&mut editors);
+            interprocedural_sroa(&mut editors, &typing, &bool_selection, sroa_with_arrays);
 
             for func in editors {
                 changed |= func.modified();
@@ -2717,21 +2734,16 @@ fn run_pass(
                 None => true,
             };
 
-            let mut bool_selection = vec![];
-            if let Some(selection) = selection {
-                bool_selection = vec![false; pm.functions.len()];
-                for loc in selection {
-                    let CodeLocation::Function(id) = loc else {
-                        return Err(SchedulerError::PassError {
-                        pass: "xdot".to_string(),
-                        error: "expected coarse-grained selection (can't partially xdot a function)".to_string(),
-                    });
-                    };
-                    bool_selection[id.idx()] = true;
-                }
-            } else {
-                bool_selection = vec![true; pm.functions.len()];
-            }
+            let selection =
+                selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError {
+                    pass: "xdot".to_string(),
+                    error: "expected coarse-grained selection (can't partially xdot a function)"
+                        .to_string(),
+                })?;
+            let mut bool_selection = vec![false; pm.functions.len()];
+            selection
+                .into_iter()
+                .for_each(|func| bool_selection[func.idx()] = true);
 
             pm.make_reverse_postorders();
             if force_analyses {