From 2021ea3b0b5f45add9409b4c649470ac7a642ac6 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 08:14:04 -0600
Subject: [PATCH] Comment change in editor and format

---
 hercules_opt/src/editor.rs |   2 +
 hercules_opt/src/pass.rs   |   6 +-
 hercules_opt/src/sroa.rs   | 224 ++++++++++++++++++++++++++-----------
 3 files changed, 159 insertions(+), 73 deletions(-)

diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index bf7cede2..2410a8ef 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -505,6 +505,8 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) {
     // Step 2: drop schedules for deleted nodes and create empty schedule lists
     // for added nodes.
     for deleted in total_edit.0.iter() {
+        // Nodes that were created and deleted using the same editor don't have
+        // an existing schedule, so ignore them
         if deleted.idx() < plan.schedules.len() {
             plan.schedules[deleted.idx()] = vec![];
         }
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index a7d48efb..f8310ab7 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -506,11 +506,7 @@ impl PassManager {
                             &types_ref,
                             &def_uses[idx],
                         );
-                        sroa(
-                            &mut editor,
-                            &reverse_postorders[idx],
-                            &typing[idx]
-                        );
+                        sroa(&mut editor, &reverse_postorders[idx], &typing[idx]);
 
                         self.module.constants = constants_ref.take();
                         self.module.dynamic_constants = dynamic_constants_ref.take();
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 2dd0049b..d8e2021a 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -43,33 +43,43 @@ use crate::*;
 pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
     // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
     // that contains the corresponding fields of the original value
-    let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
+    let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
 
     // First: determine all nodes which interact with products (as described above)
-    let mut product_nodes : Vec<NodeID> = vec![];
+    let mut product_nodes: Vec<NodeID> = vec![];
     // We track call and return nodes separately since they (may) require constructing new products
     // for the call's arguments or the return's value
-    let mut call_return_nodes : Vec<NodeID> = vec![];
+    let mut call_return_nodes: Vec<NodeID> = vec![];
 
     let func = editor.func();
 
     for node in reverse_postorder {
         match func.nodes[node.idx()] {
-            Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. }
-            | Node::Constant { .. } | Node::Write { .. }
-            | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select }
-                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
-            
-            Node::Read { collect, .. }
-                if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node),
+            Node::Phi { .. }
+            | Node::Reduce { .. }
+            | Node::Parameter { .. }
+            | Node::Constant { .. }
+            | Node::Write { .. }
+            | Node::Ternary {
+                first: _,
+                second: _,
+                third: _,
+                op: TernaryOperator::Select,
+            } if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
+
+            Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => {
+                product_nodes.push(*node)
+            }
 
             // We add all calls to the call/return list and check their arguments later
             Node::Call { .. } => call_return_nodes.push(*node),
             Node::Return { control: _, data }
                 if editor.get_type(types[data.idx()]).is_product() =>
-                call_return_nodes.push(*node),
+            {
+                call_return_nodes.push(*node)
+            }
 
-            _ => ()
+            _ => (),
         }
     }
 
@@ -85,36 +95,54 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
             Node::Return { control, data } => {
                 assert!(editor.get_type(types[data.idx()]).is_product());
                 let control = *control;
-                let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
+                let new_data =
+                    reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
                 editor.edit(|mut edit| {
-                    edit.add_node(Node::Return { control, data: new_data });
+                    edit.add_node(Node::Return {
+                        control,
+                        data: new_data,
+                    });
                     edit.delete_node(node)
                 });
             }
-            Node::Call { control, function, dynamic_constants, args } => {
+            Node::Call {
+                control,
+                function,
+                dynamic_constants,
+                args,
+            } => {
                 let control = *control;
                 let function = *function;
                 let dynamic_constants = dynamic_constants.clone();
                 let args = args.clone();
 
                 // If the call returns a product, we generate reads for each field
-                let fields =
-                    if editor.get_type(types[node.idx()]).is_product() {
-                        Some(generate_reads(editor, types[node.idx()], node))
-                    } else {
-                        None
-                    };
+                let fields = if editor.get_type(types[node.idx()]).is_product() {
+                    Some(generate_reads(editor, types[node.idx()], node))
+                } else {
+                    None
+                };
 
                 let mut new_args = vec![];
                 for arg in args {
                     if editor.get_type(types[arg.idx()]).is_product() {
-                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes));
+                        new_args.push(reconstruct_product(
+                            editor,
+                            types[arg.idx()],
+                            arg,
+                            &mut product_nodes,
+                        ));
                     } else {
                         new_args.push(arg);
                     }
                 }
                 editor.edit(|mut edit| {
-                    let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() });
+                    let new_call = edit.add_node(Node::Call {
+                        control,
+                        function,
+                        dynamic_constants,
+                        args: new_args.into(),
+                    });
                     let edit = edit.replace_all_uses(node, new_call)?;
                     let edit = edit.delete_node(node)?;
 
@@ -128,15 +156,29 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     Ok(edit)
                 });
             }
-            _ => panic!("Processing non-call or return node")
+            _ => panic!("Processing non-call or return node"),
         }
     }
 
     enum WorkItem {
         Unhandled(NodeID),
-        AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> },
-        AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> },
-        AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> },
+        AllocatedPhi {
+            control: NodeID,
+            data: Vec<NodeID>,
+            fields: IndexTree<NodeID>,
+        },
+        AllocatedReduce {
+            control: NodeID,
+            init: NodeID,
+            reduct: NodeID,
+            fields: IndexTree<NodeID>,
+        },
+        AllocatedTernary {
+            first: NodeID,
+            second: NodeID,
+            third: NodeID,
+            fields: IndexTree<NodeID>,
+        },
     }
 
     // Now, we process the other nodes that deal with products.
@@ -155,15 +197,17 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 field_map.insert(node, generate_constant_fields(editor, id));
                 to_delete.push(node);
             }
-            _ => { worklist.push_back(WorkItem::Unhandled(node)); }
+            _ => {
+                worklist.push_back(WorkItem::Unhandled(node));
+            }
         }
     }
 
     // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
     // We track the current NodeID and add nodes whenever possible, otherwise we return values to
     // the worklist
-    let mut next_id : usize = editor.func().nodes.len();
-    let mut cur_id : usize = editor.func().nodes.len();
+    let mut next_id: usize = editor.func().nodes.len();
+    let mut cur_id: usize = editor.func().nodes.len();
     while let Some(mut item) = worklist.pop_front() {
         if let WorkItem::Unhandled(node) = item {
             match &editor.func().nodes[node.idx()] {
@@ -177,32 +221,59 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     item = WorkItem::AllocatedPhi {
                         control,
                         data: data.into(),
-                        fields };
+                        fields,
+                    };
                 }
-                Node::Reduce { control, init, reduct } => {
+                Node::Reduce {
+                    control,
+                    init,
+                    reduct,
+                } => {
                     let control = *control;
                     let init = *init;
                     let reduct = *reduct;
                     let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
                     field_map.insert(node, fields.clone());
 
-                    item = WorkItem::AllocatedReduce { control, init, reduct, fields };
+                    item = WorkItem::AllocatedReduce {
+                        control,
+                        init,
+                        reduct,
+                        fields,
+                    };
                 }
-                Node::Ternary { first, second, third, .. } => {
+                Node::Ternary {
+                    first,
+                    second,
+                    third,
+                    ..
+                } => {
                     let first = *first;
                     let second = *second;
                     let third = *third;
                     let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
                     field_map.insert(node, fields.clone());
 
-                    item = WorkItem::AllocatedTernary { first, second, third, fields };
+                    item = WorkItem::AllocatedTernary {
+                        first,
+                        second,
+                        third,
+                        fields,
+                    };
                 }
 
-                Node::Write { collect, data, indices } => {
+                Node::Write {
+                    collect,
+                    data,
+                    indices,
+                } => {
                     if let Some(index_map) = field_map.get(collect) {
                         if editor.get_type(types[data.idx()]).is_product() {
                             if let Some(data_idx) = field_map.get(data) {
-                                field_map.insert(node, index_map.clone().replace(indices, data_idx.clone()));
+                                field_map.insert(
+                                    node,
+                                    index_map.clone().replace(indices, data_idx.clone()),
+                                );
                                 to_delete.push(node);
                             } else {
                                 worklist.push_back(WorkItem::Unhandled(node));
@@ -220,9 +291,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         let read_info = index_map.lookup(indices);
                         match read_info {
                             IndexTree::Leaf(field) => {
-                                editor.edit(|edit| {
-                                    edit.replace_all_uses(node, *field)
-                                });
+                                editor.edit(|edit| edit.replace_all_uses(node, *field));
                             }
                             _ => {}
                         }
@@ -233,12 +302,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     }
                 }
 
-                _ => panic!("Unexpected node type")
+                _ => panic!("Unexpected node type"),
             }
         }
         match item {
             WorkItem::Unhandled(_) => {}
-            _ => todo!()
+            _ => todo!(),
         }
     }
 
@@ -268,7 +337,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
             if let Index::Field(i) = idx[n] {
                 match self {
                     IndexTree::Leaf(_) => panic!("Invalid field"),
-                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1),
+                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
                 }
             } else {
                 // TODO: This could be hit because of an array inside of a product
@@ -291,11 +360,11 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     IndexTree::Node(mut ts) => {
                         if i + 1 == ts.len() {
                             let t = ts.pop().unwrap();
-                            ts.push(t.set_idx(idx, val, n+1));
+                            ts.push(t.set_idx(idx, val, n + 1));
                         } else {
                             let mut t = ts.pop().unwrap();
                             std::mem::swap(&mut ts[i], &mut t);
-                            t = t.set_idx(idx, val, n+1);
+                            t = t.set_idx(idx, val, n + 1);
                             std::mem::swap(&mut ts[i], &mut t);
                             ts.push(t);
                         }
@@ -322,11 +391,11 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     IndexTree::Node(mut ts) => {
                         if i + 1 == ts.len() {
                             let t = ts.pop().unwrap();
-                            ts.push(t.replace_idx(idx, val, n+1));
+                            ts.push(t.replace_idx(idx, val, n + 1));
                         } else {
                             let mut t = ts.pop().unwrap();
                             std::mem::swap(&mut ts[i], &mut t);
-                            t = t.replace_idx(idx, val, n+1);
+                            t = t.replace_idx(idx, val, n + 1);
                             std::mem::swap(&mut ts[i], &mut t);
                             ts.push(t);
                         }
@@ -342,12 +411,16 @@ impl<T: std::fmt::Debug> IndexTree<T> {
     }
 
     fn for_each<F>(&self, mut f: F)
-        where F: FnMut(&Vec<Index>, &T) {
+    where
+        F: FnMut(&Vec<Index>, &T),
+    {
         self.for_each_idx(&mut vec![], &mut f);
     }
 
     fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
-        where F: FnMut(&Vec<Index>, &T) {
+    where
+        F: FnMut(&Vec<Index>, &T),
+    {
         match self {
             IndexTree::Leaf(t) => f(idx, t),
             IndexTree::Node(ts) => {
@@ -365,7 +438,12 @@ impl<T: std::fmt::Debug> IndexTree<T> {
 // from that value and then writing them into a new constant
 // This process also adds all the read nodes that are generated into the read_list so that the
 // reads can be eliminated by later parts of SROA
-fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID {
+fn reconstruct_product(
+    editor: &mut FunctionEditor,
+    typ: TypeID,
+    val: NodeID,
+    read_list: &mut Vec<NodeID>,
+) -> NodeID {
     let fields = generate_reads(editor, typ, val);
     let new_const = generate_constant(editor, typ);
 
@@ -381,7 +459,11 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, re
     fields.for_each(|idx: &Vec<Index>, val: &NodeID| {
         read_list.push(*val);
         editor.edit(|mut edit| {
-            value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() });
+            value = edit.add_node(Node::Write {
+                collect: value,
+                data: *val,
+                indices: idx.clone().into(),
+            });
             Ok(edit)
         });
     });
@@ -398,13 +480,17 @@ fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Inde
 
 // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
 // fields within this sub-value of val and return the correspondence list
-fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> {
-    let ts : Option<Vec<TypeID>> =
-        if let Some(ts) = editor.get_type(typ).try_product() {
-            Some(ts.into())
-        } else {
-            None
-        };
+fn generate_reads_at_index(
+    editor: &mut FunctionEditor,
+    typ: TypeID,
+    val: NodeID,
+    idx: Vec<Index>,
+) -> IndexTree<NodeID> {
+    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
+        Some(ts.into())
+    } else {
+        None
+    };
 
     if let Some(ts) = ts {
         // For product values, we will recurse down each of its fields with an extended index
@@ -421,7 +507,10 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID
         // information
         let mut read_id = None;
         editor.edit(|mut edit| {
-            read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() }));
+            read_id = Some(edit.add_node(Node::Read {
+                collect: val,
+                indices: idx.clone().into(),
+            }));
             Ok(edit)
         });
 
@@ -437,7 +526,7 @@ macro_rules! add_const {
             Ok(edit)
         });
         res.expect("Add constant cannot fail")
-    }}
+    }};
 }
 
 // Given a type, builds a default constant of that type
@@ -470,14 +559,14 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
         Type::Array(elem, _) => {
             add_const!(editor, Constant::Array(typ))
         }
-        Type::Control => panic!("Cannot create constant of control type")
+        Type::Control => panic!("Cannot create constant of control type"),
     }
 }
 
 // Given a constant cnst adds node to the function which are the constant values of each field and
 // returns a list of pairs of indices and the node that holds that index
 fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> {
-    let cs : Option<Vec<ConstantID>> =
+    let cs: Option<Vec<ConstantID>> =
         if let Some(cs) = editor.get_constant(cnst).try_product_fields() {
             Some(cs.into())
         } else {
@@ -503,12 +592,11 @@ fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> In
 // Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the
 // id provided
 fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
-    let ts : Option<Vec<TypeID>> =
-        if let Some(ts) = editor.get_type(typ).try_product() {
-            Some(ts.into())
-        } else {
-            None
-        };
+    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
+        Some(ts.into())
+    } else {
+        None
+    };
 
     if let Some(ts) = ts {
         let mut fields = vec![];
-- 
GitLab