From cb4634e4cbebb396795ae8d2ceb9f94b408ebe89 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 22 Nov 2024 09:47:17 -0600
Subject: [PATCH] Fix bug when read later used in write

- Need to track nodes that are replaced as we perform replacements
---
 hercules_opt/src/sroa.rs | 36 ++++++++++++++++++++++++++++++++++++
 juno_samples/products.jn |  5 +++--
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 7430dbca..59ee4a8a 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -435,8 +435,44 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     });
 
     // Replace uses of old reads
+    // Because a read that is being replaced could also be the node some other read is being
+    // replaced by (if the first read is then written into a product that is then read from again)
+    // we need to track what nodes have already been replaced (and by what) so we can properly
+    // replace uses without leaving users of nodes that should be deleted.
+    // replaced_by tracks what a node has been replaced by while replaced_of tracks everything that
+    // maps to a particular node (which is needed to maintain the data structure efficiently)
+    let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new();
+    let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
     for (old, new) in to_replace {
+        let new = match replaced_by.get(&new) {
+            Some(res) => *res,
+            None => new,
+        };
+
         editor.edit(|edit| edit.replace_all_uses(old, new));
+        replaced_by.insert(old, new);
+
+        let mut replaced = vec![];
+        match replaced_of.get_mut(&old) {
+            Some(res) => {
+                std::mem::swap(res, &mut replaced);
+            }
+            None => {}
+        }
+
+        let new_of = match replaced_of.get_mut(&new) {
+            Some(res) => res,
+            None => {
+                replaced_of.insert(new, vec![]);
+                replaced_of.get_mut(&new).unwrap()
+            }
+        };
+        new_of.push(old);
+
+        for n in replaced {
+            replaced_by.insert(n, new);
+            new_of.push(n);
+        }
     }
 
     // Remove nodes
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
index b97f1088..d39ca246 100644
--- a/juno_samples/products.jn
+++ b/juno_samples/products.jn
@@ -6,6 +6,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) {
   return res;
 }
 
-fn test(x : i32, y : f32) -> (i32, f32) {
-  return test_call(x, y);
+fn test(x : i32, y : f32) -> (f32, i32) {
+  let res = test_call(x, y);
+  return (res.1, res.0);
 }
-- 
GitLab