diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 7430dbca0008498d32cedc285ea7684830915772..59ee4a8ad54cea9627ae85185fe5fe04198e72cc 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -435,8 +435,44 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: }); // Replace uses of old reads + // Because a read that is being replaced could also be the node some other read is being + // replaced by (if the first read is then written into a product that is then read from again) + // we need to track what nodes have already been replaced (and by what) so we can properly + // replace uses without leaving users of nodes that should be deleted. + // replaced_by tracks what a node has been replaced by while replaced_of tracks everything that + // maps to a particular node (which is needed to maintain the data structure efficiently) + let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new(); + let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new(); for (old, new) in to_replace { + let new = match replaced_by.get(&new) { + Some(res) => *res, + None => new, + }; + editor.edit(|edit| edit.replace_all_uses(old, new)); + replaced_by.insert(old, new); + + let mut replaced = vec![]; + match replaced_of.get_mut(&old) { + Some(res) => { + std::mem::swap(res, &mut replaced); + } + None => {} + } + + let new_of = match replaced_of.get_mut(&new) { + Some(res) => res, + None => { + replaced_of.insert(new, vec![]); + replaced_of.get_mut(&new).unwrap() + } + }; + new_of.push(old); + + for n in replaced { + replaced_by.insert(n, new); + new_of.push(n); + } } // Remove nodes diff --git a/juno_samples/products.jn b/juno_samples/products.jn index b97f1088e3ff502b203f3b866338be9aae481bc1..d39ca2464d35d63c40f688f7f7d1354d1ab0a56a 100644 --- a/juno_samples/products.jn +++ b/juno_samples/products.jn @@ -6,6 +6,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) { return res; } -fn test(x : i32, y : f32) -> (i32, f32) { - return test_call(x, y); +fn test(x : i32, y : f32) -> (f32, i32) { + let res = test_call(x, y); + return (res.1, res.0); }