Fork fission bufferize

Merged: Xavier Routh requested to merge fork-fission-bufferize into main
All threads resolved!
3 files changed: +12, −108
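This change removes leftover debugging output and dead experimental code from the fork-fission bufferization pass: the stray `println!` calls, a large commented-out block from an earlier bufferization attempt, and the unused `std::thread::ThreadId` import. It also drops redundant parentheses from the return type of `copy_subgraph_in_edit`.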
 use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
 use std::iter::zip;
-use std::thread::ThreadId;
 use bimap::BiMap;
 use itertools::Itertools;
@@ -102,7 +101,7 @@ pub fn find_reduce_dependencies<'a>(
 pub fn copy_subgraph_in_edit<'a, 'b>(
     mut edit: FunctionEdit<'a, 'b>,
     subgraph: HashSet<NodeID>,
-) -> (Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>) {
+) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
     let mut map: HashMap<NodeID, NodeID> = HashMap::new();
     // Copy nodes in subgraph
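A note on the signature cleanup: the outer parentheses around the `Result` were redundant, and the type reads the same without them. The error arm carries the `FunctionEdit` back, so a caller inside `editor.edit` can propagate failure with `?` without losing the pending edit. A minimal sketch of the calling pattern, using only names that appear in the later hunks of this diff:

```rust
// Sketch of the calling convention (mirrors the hunk later in this diff).
// On success the edit comes back together with a map from each original
// NodeID to the NodeID of its fresh copy; on failure the edit is returned
// unconsumed, so `?` hands it straight back to `editor.edit`.
let edit_result = editor.edit(|edit| {
    let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
    // `map[&n]` is the copy of original node `n`; rewire uses here.
    edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
    Ok(edit)
});
```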
@@ -185,11 +184,8 @@ pub fn find_bufferize_edges(
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     data_label: &LabelID,
 ) -> HashSet<(NodeID, NodeID)> {
-    // println!("func: {:?}", editor.func_id());
     let mut edges: HashSet<_> = HashSet::new();
-    println!("ndoes in fork joins: {:?}", &nodes_in_fork_joins[&fork]);
-    // print labels
     for node in &nodes_in_fork_joins[&fork] {
         // Edge from *has data label* to *doesn't have data label*
         let node_labels = &editor.func().labels[node.idx()];
@@ -216,7 +212,6 @@ pub fn find_bufferize_edges(
             edges.insert((*node, user));
         }
     }
-    println!("edges: {:?}", edges);
     edges
 }
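For context, `find_bufferize_edges` computes the cut set the pass will place buffers across. The context lines above only show the skeleton; a plausible reconstruction of the loop body follows, where the `get_users` iteration and the exact membership checks are assumptions on my part, not part of this diff:

```rust
// Assumed shape of the loop, reconstructed from the surviving context lines:
// record an edge whenever a node carrying `data_label` feeds a user inside
// the same fork-join that does not carry the label.
for node in &nodes_in_fork_joins[&fork] {
    // Edge from *has data label* to *doesn't have data label*.
    let node_labels = &editor.func().labels[node.idx()];
    if !node_labels.contains(data_label) {
        continue;
    }
    for user in editor.get_users(*node) {
        if nodes_in_fork_joins[&fork].contains(&user)
            && !editor.func().labels[user.idx()].contains(data_label)
        {
            edges.insert((*node, user));
        }
    }
}
```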
@@ -238,7 +233,6 @@ where
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
-    println!("fork_label {:?}", fork_label);
     for l in forks {
         let fork_info = Loop {
             header: l.0,
@@ -247,13 +241,10 @@ where
         let fork = fork_info.header;
         let join = fork_join_map[&fork];
-        println!("fork labels: {:?}", editor.func().labels[fork.idx()]);
         if !editor.func().labels[fork.idx()].contains(fork_label) {
             continue;
         }
-        println!("fork: {:?}", fork);
         let edges = find_bufferize_edges(
             editor,
             fork,
@@ -384,11 +375,6 @@ where
     let edit_result = editor.edit(|edit| {
         let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
-        // Put new subgraph after old subgraph
-        // println!("map: {:?}", map);
-        // println!("join: {:?}, fork: {:?}", join, fork);
-        // println!("fork_sccueue: {:?}", join_successor);
         edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
         edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
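The two `replace_all_uses_where` calls splice the copied subgraph in after the original fork-join. Reading off the predicates (primed names below denote the copies, i.e. `map[&fork]` and `map[&join]`; the diagram is an inference, not part of the diff):

```rust
// Before: the copy still hangs off the original fork's predecessor.
//   fork_pred -> fork  -> ... -> join  -> join_successor
//   fork_pred -> fork' -> ... -> join'
//
// First call:  fork' now uses join instead of fork_pred.
// Second call: join_successor now uses join' instead of join.
//
// After:
//   fork_pred -> fork -> ... -> join -> fork' -> ... -> join' -> join_successor
```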
@@ -398,8 +384,6 @@ where
             .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
-        // Add buffers to old subgraph
         let new_join = map[&join];
         let new_fork = map[&fork];
@@ -465,100 +449,12 @@ where
         Ok(edit)
     });
-    println!("edit_result: {:?}", edit_result);
     if edit_result == false {
         todo!();
         return None;
     }
     Some((fork, new_fork_id))
-    // let internal_control: Vec<NodeID> = Vec::new();
-    // // Create fork + join + Thread control
-    // let mut new_fork_id = NodeID::new(0);
-    // let mut new_join_id = NodeID::new(0);
-    // editor.edit(|mut edit| {
-    //     new_join_id = edit.add_node(Node::Join { control: fork });
-    //     let factors = edit.get_node(fork).try_fork().unwrap().1;
-    //     new_fork_id = edit.add_node(Node::Fork {
-    //         control: new_join_id,
-    //         factors: factors.into(),
-    //     });
-    //     edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
-    // });
-    // for (src, dst) in bufferized_edges {
-    //     // FIXME: Disgusting cloning and allocationing and iterators.
-    //     let factors: Vec<_> = editor.func().nodes[fork.idx()]
-    //         .try_fork()
-    //         .unwrap()
-    //         .1
-    //         .iter()
-    //         .cloned()
-    //         .collect();
-    //     editor.edit(|mut edit| {
-    //         // Create write to buffer
-    //         let thread_stuff_it = factors.into_iter().enumerate();
-    //         // FIxme: try to use unzip here? Idk why it wasn't working.
-    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
-    //             edit.add_node(Node::ThreadID {
-    //                 control: fork,
-    //                 dimension: dim,
-    //             })
-    //         });
-    //         let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
-    //         // Assume 1-d fork only for now.
-    //         // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
-    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
-    //         let write = edit.add_node(Node::Write {
-    //             collect: NodeID::new(0),
-    //             data: src,
-    //             indices: vec![position_idx].into(),
-    //         });
-    //         let ele_type = types[src.idx()];
-    //         let empty_buffer = edit.add_type(hercules_ir::Type::Array(
-    //             ele_type,
-    //             array_dims.collect::<Vec<_>>().into_boxed_slice(),
-    //         ));
-    //         let empty_buffer = edit.add_zero_constant(empty_buffer);
-    //         let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
-    //         let reduce = Node::Reduce {
-    //             control: new_join_id,
-    //             init: empty_buffer,
-    //             reduct: write,
-    //         };
-    //         let reduce = edit.add_node(reduce);
-    //         // Fix write node
-    //         edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
-    //         // Create read from buffer
-    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
-    //             edit.add_node(Node::ThreadID {
-    //                 control: new_fork_id,
-    //                 dimension: dim,
-    //             })
-    //         });
-    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
-    //         let read = edit.add_node(Node::Read {
-    //             collect: reduce,
-    //             indices: vec![position_idx].into(),
-    //         });
-    //         edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
-    //         Ok(edit)
-    //     });
-    // }
-    // (fork, new_fork_id)
 }

 /** Split a 1D fork into a separate fork for each reduction. */