Skip to content
Snippets Groups Projects
Commit 23103bc2 authored by Russel Arbore's avatar Russel Arbore
Browse files

holy shit that just worked

parent 9bc5101e
No related branches found
No related tags found
1 merge request: !165 Fork unrolling
Pipeline #201574 passed
...@@ -1186,7 +1186,8 @@ pub fn fork_unroll( ...@@ -1186,7 +1186,8 @@ pub fn fork_unroll(
join: NodeID, join: NodeID,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool { ) -> bool {
// We can only unroll forks with a compile time known factor list. // We can only unroll fork-joins with a compile time known factor list. For
// simplicity, just unroll fork-joins that have a single dimension.
let nodes = &editor.func().nodes; let nodes = &editor.func().nodes;
let Node::Fork { let Node::Fork {
control, control,
...@@ -1195,17 +1196,79 @@ pub fn fork_unroll( ...@@ -1195,17 +1196,79 @@ pub fn fork_unroll(
else { else {
panic!() panic!()
}; };
let mut cons_factors = vec![]; if factors.len() != 1 || editor.get_users(fork).count() != 2 {
for factor in factors { return false;
let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else {
return false;
};
cons_factors.push(cons);
} }
println!("{}: {:?}", editor.func().name, cons_factors); let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
return false;
};
let tid = editor
.get_users(fork)
.filter(|id| nodes[id.idx()].is_thread_id())
.next()
.unwrap();
let inits: HashMap<NodeID, NodeID> = editor
.get_users(join)
.filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
.collect();
editor.edit(|mut edit| { editor.edit(|mut edit| {
(); // Create a copy of the nodes in the fork join per unrolled iteration,
// excluding the fork itself, the join itself, the thread IDs of the fork,
// and the reduces on the join. Keep a running tally of the top control
// node and the current reduction value.
let mut top_control = control;
let mut current_reduces = inits;
for iter in 0..cons {
let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
let iter_tid = edit.add_node(Node::Constant { id: iter_cons });
// First, add a copy of each node in the fork join unmodified.
// Record the mapping from old ID to new ID.
let mut added_ids = HashSet::new();
let mut old_to_new_ids = HashMap::new();
let mut new_control = None;
let mut new_reduces = HashMap::new();
for node in nodes_in_fork_joins[&fork].iter() {
if *node == fork {
old_to_new_ids.insert(*node, top_control);
} else if *node == join {
new_control = Some(edit.get_node(*node).try_join().unwrap());
} else if *node == tid {
old_to_new_ids.insert(*node, iter_tid);
} else if let Some(current) = current_reduces.get(node) {
old_to_new_ids.insert(*node, *current);
new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
} else {
let new_node = edit.add_node(edit.get_node(*node).clone());
old_to_new_ids.insert(*node, new_node);
added_ids.insert(new_node);
}
}
// Second, replace all the uses in the just added nodes.
if let Some(new_control) = new_control {
top_control = old_to_new_ids[&new_control];
}
for (reduce, reduct) in new_reduces {
current_reduces.insert(reduce, old_to_new_ids[&reduct]);
}
for (old, new) in old_to_new_ids {
edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
}
}
// Hook up the control and reduce outputs to the rest of the function.
edit = edit.replace_all_uses(join, top_control)?;
for (reduce, reduct) in current_reduces {
edit = edit.replace_all_uses(reduce, reduct)?;
}
// Delete the old fork-join.
for node in nodes_in_fork_joins[&fork].iter() {
edit = edit.delete_node(*node)?;
}
Ok(edit) Ok(edit)
}) })
} }
...@@ -39,8 +39,10 @@ dce(*); ...@@ -39,8 +39,10 @@ dce(*);
fixpoint panic after 20 { fixpoint panic after 20 {
infer-schedules(*); infer-schedules(*);
} }
unroll(auto.test1); fork-split(auto.test1);
xdot[true](*); fixpoint panic after 20 {
unroll(auto.test1);
}
fork-split(auto.test2, auto.test3, auto.test4, auto.test5); fork-split(auto.test2, auto.test3, auto.test4, auto.test5);
gvn(*); gvn(*);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment