From b4978ef526f248e205f2df22c4493e649b080162 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 10 Feb 2025 09:28:34 -0600
Subject: [PATCH] cleanup

---
 hercules_opt/src/fork_transforms.rs      | 106 +----------------------
 juno_samples/fork_join_tests/src/cpu.sch |   3 +
 juno_samples/fork_join_tests/src/main.rs |  11 ++-
 3 files changed, 12 insertions(+), 108 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ad81d90b..80ff7b8f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,7 +1,6 @@
 use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
 use std::iter::zip;
-use std::thread::ThreadId;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -102,7 +101,7 @@ pub fn find_reduce_dependencies<'a>(
 pub fn copy_subgraph_in_edit<'a, 'b>(
     mut edit: FunctionEdit<'a, 'b>,
     subgraph: HashSet<NodeID>,
-) -> (Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>) {
+) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
     let mut map: HashMap<NodeID, NodeID> = HashMap::new();
 
     // Copy nodes in subgraph
@@ -185,11 +184,8 @@ pub fn find_bufferize_edges(
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     data_label: &LabelID,
 ) -> HashSet<(NodeID, NodeID)> {
-    // println!("func: {:?}", editor.func_id());
     let mut edges: HashSet<_> = HashSet::new();
 
-    println!("ndoes in fork joins: {:?}", &nodes_in_fork_joins[&fork]);
-    // print labels
     for node in &nodes_in_fork_joins[&fork] {
         // Edge from *has data label** to doesn't have data label*
         let node_labels = &editor.func().labels[node.idx()];
@@ -216,7 +212,6 @@ pub fn find_bufferize_edges(
             edges.insert((*node, user));
         }
     }
-    println!("edges: {:?}", edges);
     edges
 }
 
@@ -238,7 +233,6 @@ where
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
-    println!("fork_label {:?}", fork_label);
     for l in forks {
         let fork_info = Loop {
             header: l.0,
@@ -247,13 +241,10 @@ where
         let fork = fork_info.header;
         let join = fork_join_map[&fork];
 
-        println!("fork labels: {:?}", editor.func().labels[fork.idx()]);
         if !editor.func().labels[fork.idx()].contains(fork_label) {
             continue;
         }
 
-        println!("fork: {:?}", fork);
-
         let edges = find_bufferize_edges(
             editor,
             fork,
@@ -384,11 +375,6 @@ where
     let edit_result = editor.edit(|edit| {
         let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
 
-        // Put new subgraph after old subgraph
-        // println!("map: {:?}", map);
-        // println!("join: {:?}, fork: {:?}", join, fork);
-        // println!("fork_sccueue: {:?}", join_successor);
-
         edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
         edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
 
@@ -398,8 +384,6 @@ where
                 .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
 
-        // Add buffers to old subgraph
-
         let new_join = map[&join];
         let new_fork = map[&fork];
 
@@ -465,100 +449,12 @@ where
         Ok(edit)
     });
 
-    println!("edit_result: {:?}", edit_result);
     if edit_result == false {
         todo!();
         return None;
     }
 
     Some((fork, new_fork_id))
-
-    // let internal_control: Vec<NodeID> = Vec::new();
-
-    // // Create fork + join + Thread control
-    // let mut new_fork_id = NodeID::new(0);
-    // let mut new_join_id = NodeID::new(0);
-
-    // editor.edit(|mut edit| {
-    //     new_join_id = edit.add_node(Node::Join { control: fork });
-    //     let factors = edit.get_node(fork).try_fork().unwrap().1;
-    //     new_fork_id = edit.add_node(Node::Fork {
-    //         control: new_join_id,
-    //         factors: factors.into(),
-    //     });
-    //     edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
-    // });
-
-    // for (src, dst) in bufferized_edges {
-    //     // FIXME: Disgusting cloning and allocationing and iterators.
-    //     let factors: Vec<_> = editor.func().nodes[fork.idx()]
-    //         .try_fork()
-    //         .unwrap()
-    //         .1
-    //         .iter()
-    //         .cloned()
-    //         .collect();
-    //     editor.edit(|mut edit| {
-    //         // Create write to buffer
-
-    //         let thread_stuff_it = factors.into_iter().enumerate();
-
-    //         // FIxme: try to use unzip here? Idk why it wasn't working.
-    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
-    //             edit.add_node(Node::ThreadID {
-    //                 control: fork,
-    //                 dimension: dim,
-    //             })
-    //         });
-
-    //         let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
-
-    //         // Assume 1-d fork only for now.
-    //         // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
-    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
-    //         let write = edit.add_node(Node::Write {
-    //             collect: NodeID::new(0),
-    //             data: src,
-    //             indices: vec![position_idx].into(),
-    //         });
-    //         let ele_type = types[src.idx()];
-    //         let empty_buffer = edit.add_type(hercules_ir::Type::Array(
-    //             ele_type,
-    //             array_dims.collect::<Vec<_>>().into_boxed_slice(),
-    //         ));
-    //         let empty_buffer = edit.add_zero_constant(empty_buffer);
-    //         let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
-    //         let reduce = Node::Reduce {
-    //             control: new_join_id,
-    //             init: empty_buffer,
-    //             reduct: write,
-    //         };
-    //         let reduce = edit.add_node(reduce);
-    //         // Fix write node
-    //         edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
-
-    //         // Create read from buffer
-    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
-    //             edit.add_node(Node::ThreadID {
-    //                 control: new_fork_id,
-    //                 dimension: dim,
-    //             })
-    //         });
-
-    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
-
-    //         let read = edit.add_node(Node::Read {
-    //             collect: reduce,
-    //             indices: vec![position_idx].into(),
-    //         });
-
-    //         edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
-
-    //         Ok(edit)
-    //     });
-    // }
-
-    // (fork, new_fork_id)
 }
 
 /** Split a 1D fork into a separate fork for each reduction. */
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index fd3fc0ff..38a38c2b 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -74,7 +74,10 @@ dce(auto.test7);
 let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
 dce(auto.test8);
 unforkify(auto.test8);
+dce(auto.test8);
 ccp(auto.test8);
 dce(auto.test8);
+simplify-cfg(auto.test8);
+dce(auto.test8);
 
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 4006afa8..1013a1f0 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -43,15 +43,20 @@ fn main() {
         let correct = vec![7i32; 4];
         assert(&correct, output);
 
-        let mut r = runner!(test8);
-        let output = r.run(0).await;
-        let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
+        let mut r = runner!(test6);
+        let output = r.run(73).await;
+        let correct = (73i32..73i32+1024i32).collect();
         assert(&correct, output);
 
         let mut r = runner!(test7);
         let output = r.run(42).await;
         let correct: i32 = (42i32..42i32+32i32).sum();
         assert_eq!(correct, output);
+
+        let mut r = runner!(test8);
+        let output = r.run(0).await;
+        let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
+        assert(&correct, output);
     });
 }
 
-- 
GitLab