From 0c59f141585a275d5ff826351c1f233c91af1798 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Thu, 20 Feb 2025 11:31:42 -0600
Subject: [PATCH] Multi return cpu functions

---
 hercules_cg/src/cpu.rs | 82 +++++++++++++++++++++++++++++++++---------
 1 file changed, 66 insertions(+), 16 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 27daf2a1..45c0f467 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -60,19 +60,33 @@ struct LLVMBlock {
 impl<'a> CPUContext<'a> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // Dump the function signature.
-        if self.types[self.function.return_type.idx()].is_primitive() {
-            write!(
-                w,
-                "define dso_local {} @{}(",
-                self.get_type(self.function.return_type),
-                self.function.name
-            )?;
+        if self.function.return_types.len() == 1 {
+            let return_type = self.function.return_types[0];
+            if self.types[return_type.idx()].is_primitive() {
+                write!(
+                    w,
+                    "define dso_local {} @{}(",
+                    self.get_type(return_type),
+                    self.function.name
+                )?;
+            } else {
+                write!(
+                    w,
+                    "define dso_local nonnull noundef {} @{}(",
+                    self.get_type(return_type),
+                    self.function.name
+                )?;
+            }
         } else {
             write!(
                 w,
-                "define dso_local nonnull noundef {} @{}(",
-                self.get_type(self.function.return_type),
-                self.function.name
+                "%return.{} = type {{ {} }}\n",
+                self.function.name,
+                self.function.return_types
+                    .iter()
+                    .map(|t| self.get_type(*t))
+                    .collect::<Vec<_>>()
+                    .join(", "),
             )?;
         }
         let mut first_param = true;
@@ -110,6 +124,19 @@ impl<'a> CPUContext<'a> {
                 )?;
             }
         }
+        // Lastly, if the function has multiple returns, is a pointer to the return struct
+        if self.function.return_types.len() != 1 {
+            if first_param {
+                first_param = false;
+            } else {
+                write!(w, ", ")?;
+            }
+            write!(
+                w,
+                "ptr noalias nofree nonnull noundef sret(%return.{}) %ret.ptr",
+                self.function.name,
+            )?;
+        }
         write!(w, ") {{\n")?;
 
         let mut blocks: BTreeMap<_, _> = (0..self.function.nodes.len())
@@ -171,7 +198,7 @@ impl<'a> CPUContext<'a> {
             // successor and are otherwise simple.
             Node::Start
             | Node::Region { preds: _ }
-            | Node::Projection {
+            | Node::ControlProjection {
                 control: _,
                 selection: _,
             } => {
@@ -186,7 +213,7 @@ impl<'a> CPUContext<'a> {
                 let mut succs = self.control_subgraph.succs(id);
                 let succ1 = succs.next().unwrap();
                 let succ2 = succs.next().unwrap();
-                let succ1_is_true = self.function.nodes[succ1.idx()].try_projection(1).is_some();
+                let succ1_is_true = self.function.nodes[succ1.idx()].try_control_projection(1).is_some();
                 write!(
                     term,
                     "  br {}, label %{}, label %{}\n",
@@ -195,9 +222,32 @@ impl<'a> CPUContext<'a> {
                     self.get_block_name(if succ1_is_true { succ2 } else { succ1 }),
                 )?
             }
-            Node::Return { control: _, data } => {
-                let term = &mut blocks.get_mut(&id).unwrap().term;
-                write!(term, "  ret {}\n", self.get_value(data, true))?
+            Node::Return { control: _, ref data } => {
+                if data.len() == 1 {
+                    let ret_data = data[0];
+                    let term = &mut blocks.get_mut(&id).unwrap().term;
+                    write!(term, "  ret {}\n", self.get_value(ret_data, true))?
+                } else {
+                    let term = &mut blocks.get_mut(&id).unwrap().term;
+                    // Generate gep and stores into the output pointer
+                    for (idx, val) in data.iter().enumerate() {
+                        write!(
+                            term,
+                            "  %ret_ptr.{} = getelementptr inbounds %return.{}, ptr %ret.ptr, i32 0, i32 {}\n",
+                            idx,
+                            self.function.name,
+                            idx,
+                        )?;
+                        write!(
+                            term,
+                            "  store {}, ptr %ret_ptr.{}\n",
+                            self.get_value(*val, true),
+                            idx,
+                        )?;
+                    }
+                    // Finally return void
+                    write!(term, "  ret void\n")?
+                }
             }
             _ => panic!(
                 "PANIC: Can't lower {:?} in {}.",
@@ -808,7 +858,7 @@ impl<'a> CPUContext<'a> {
      */
     fn codegen_type_size(&self, ty: TypeID, body: &mut String) -> Result<String, Error> {
         match self.types[ty.idx()] {
-            Type::Control => panic!(),
+            Type::Control | Type::MultiReturn(_) => panic!(),
             Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
                 Ok("1".to_string())
             }
-- 
GitLab