diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index f82bbdfb9fd7e6d4950fedaf78c882d539dce896..ea486c949d110dc4472226100d2a7f33d2a46a67 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::iter::zip; use std::thread::ThreadId; +use std::hash::Hash; use bimap::BiMap; use itertools::Itertools; @@ -98,11 +99,11 @@ pub fn find_reduce_dependencies<'a>( ret_val } -pub fn copy_subgraph_in_edit<'a>( - mut edit: FunctionEdit<'a, 'a>, +pub fn copy_subgraph_in_edit<'a, 'b>( + mut edit: FunctionEdit<'a, 'b>, subgraph: HashSet<NodeID>, ) -> ( - Result<(FunctionEdit<'a, 'a>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'a>>// a map from old nodes to new nodes + Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>// a map from old nodes to new nodes // A list of (inside *old* node, outside node) s.t insde old node -> outside node. // The caller probably wants ) { @@ -169,14 +170,56 @@ pub fn copy_subgraph( (new_nodes, map, outside_users) } +fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> bool { + // A strict superset must be larger than its subset + if set1.len() <= set2.len() { + return false; + } + + // Every element in set2 must be in set1 + set2.iter().all(|item| set1.contains(item)) +} -pub fn fork_fission_bufferize_toplevel<'a>( - editor: &'a mut FunctionEditor<'a>, - loop_tree: &'a LoopTree, - fork_join_map: &'a HashMap<NodeID, NodeID>, - data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, - typing: &'a Vec<TypeID> -) -> bool { +pub fn find_bufferize_edges( + editor: & mut FunctionEditor, + fork: NodeID, + loop_tree: & LoopTree, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: & HashMap<NodeID, HashSet<NodeID>>, +) -> HashSet<(NodeID, NodeID)> { + + // println!("func: {:?}", editor.func_id()); + let mut edges: HashSet<_> = HashSet::new(); + // print labels + for node in &nodes_in_fork_joins[&fork] { + println!("node: {:?}, label {:?}, ", node, editor.func().labels[node.idx()]); + let node_labels = &editor.func().labels[node.idx()]; + for usee in editor.get_uses(*node) { + // If usee labels is a superset of this node labels, then make an edge. + let usee_labels = &editor.func().labels[usee.idx()]; + // strict superset + if !(usee_labels.is_superset(node_labels) && usee_labels.len() > node_labels.len()) { + continue; + } + + if editor.node(usee).is_control() || editor.node(node).is_control() { + continue; + } + + edges.insert((usee, *node)); + } + } + println!("edges: {:?}", edges); + edges +} + +pub fn ff_bufferize_any_fork<'a, 'b>( + editor: &'b mut FunctionEditor<'a>, + loop_tree: &'b LoopTree, + fork_join_map: &'b HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, + typing: &'b Vec<TypeID> +) -> Option<(NodeID, NodeID)> where 'a: 'b { let forks: Vec<_> = loop_tree .bottom_up_loops() @@ -185,18 +228,22 @@ pub fn fork_fission_bufferize_toplevel<'a>( .collect(); for l in forks { - let fork_info = &Loop { + let fork_info = Loop { header: l.0, control: l.1.clone(), }; let fork = fork_info.header; let join = fork_join_map[&fork]; - let mut edges = HashSet::new(); - edges.insert((NodeID::new(8), NodeID::new(3))); - let result = fork_bufferize_fission_helper(editor, fork_info, edges, data_node_in_fork_joins, typing, fork, join); - return result.is_some(); + + let edges = find_bufferize_edges(editor, fork, &loop_tree, &fork_join_map, &nodes_in_fork_joins); + let result = fork_bufferize_fission_helper(editor, &fork_info, &edges, nodes_in_fork_joins, typing, fork, join); + if result.is_none() { + continue + } else { + return result; + } } - return false; + return None; } @@ -239,15 +286,15 @@ pub fn fork_fission<'a>( } /** Split a 1D fork into two forks, placing select intermediate data into buffers. */ -pub fn fork_bufferize_fission_helper<'a>( - editor: &'a mut FunctionEditor<'a>, +pub fn fork_bufferize_fission_helper<'a, 'b>( + editor: &'b mut FunctionEditor<'a>, l: &Loop, - bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. - data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, - types: &Vec<TypeID>, + bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. + data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>, + types: &'b Vec<TypeID>, fork: NodeID, join: NodeID -) -> Option<(NodeID, NodeID)> { +) -> Option<(NodeID, NodeID)> where 'a: 'b { // Returns the two forks that it generates. if bufferized_edges.is_empty() { @@ -287,13 +334,15 @@ pub fn fork_bufferize_fission_helper<'a>( let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap(); let mut new_fork_id = NodeID::new(0); - editor.edit(|edit| { + + let edit_result = editor.edit(|edit| { let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?; // Put new subgraph after old subgraph println!("map: {:?}", map); println!("join: {:?}, fork: {:?}", join, fork); println!("fork_sccueue: {:?}", join_successor); + edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?; edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?; @@ -302,6 +351,8 @@ pub fn fork_bufferize_fission_helper<'a>( edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; } + + // Add buffers to old subgraph let new_join = map[&join]; @@ -326,7 +377,7 @@ pub fn fork_bufferize_fission_helper<'a>( new_tids.push(new_id); } - for (src, dst) in &bufferized_edges { + for (src, dst) in bufferized_edges { let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); let position_idx = Index::Position(old_tids.clone().into_boxed_slice()); @@ -366,14 +417,18 @@ pub fn fork_bufferize_fission_helper<'a>( } - - new_fork_id = new_fork; Ok(edit) }); + println!("edit_result: {:?}", edit_result); + if edit_result == false { + todo!(); + return None + } + Some((fork, new_fork_id)) // let internal_control: Vec<NodeID> = Vec::new(); diff --git a/juno_samples/fork_join_tests/src/blah.sch b/juno_samples/fork_join_tests/src/blah.sch new file mode 100644 index 0000000000000000000000000000000000000000..52dea702fc0c9057d764eb25f60a27aa52c7a530 --- /dev/null +++ b/juno_samples/fork_join_tests/src/blah.sch @@ -0,0 +1,34 @@ + +xdot[true](*); + +fixpoint panic after 20 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); + dce(*); +} + +xdot[true](*); + +//gvn(*); +//phi-elim(*); +//dce(*); + +//gvn(*); +//phi-elim(*); +//dce(*); + +//fixpoint panic after 20 { +// infer-schedules(*); +//} + +//fork-split(*); +//gvn(*); +//phi-elim(*); +//dce(*); +//unforkify(*); +//gvn(*); +//phi-elim(*); +//dce(*); + +//gcm(*); diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 947d0dc86a72b55c1adf99d1d4ab4cc77f209c35..e04d8dfec9e7318bec95a62671da78e281701ed1 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -7,13 +7,10 @@ dce(*); let out = auto-outline(test1, test2, test3, test4, test5); cpu(out.test1); -<<<<<<< Updated upstream cpu(out.test2); cpu(out.test3); cpu(out.test4); cpu(out.test5); -======= ->>>>>>> Stashed changes ip-sroa(*); @@ -33,9 +30,6 @@ fixpoint panic after 20 { dce(*); gvn(*); -xdot[true](*); -fork-fission-bufferize(*); -xdot[true](*); gvn(*); phi-elim(*); dce(*); @@ -53,10 +47,12 @@ gvn(*); phi-elim(*); dce(*); -xdot[true](*); -fork-fission-bufferize(test7@loop, test7@bufferize1, test7@bufferize2, test7@bufferize3, test7@bufferize4); +fork-fission-bufferize(test7); +dce(*); + fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); +fork-split(*); //let out = outline(out.test6.fj1); let out7 = auto-outline(test7); @@ -69,6 +65,8 @@ sroa(*); unforkify(out.test6); unforkify(out7.test7); dce(*); +unforkify(*); +dce(*); ccp(*); gvn(*); phi-elim(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 90d06c2fe2c7139aaee4d3dd6288a82ffa87e055..ae3be778b42343d2a23404e98ddcb3fb9218ed3a 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -9,7 +9,6 @@ fn test1(input : i32) -> i32[4, 4] { return arr; } -/** #[entry] fn test2(input : i32) -> i32[4, 4] { let arr : i32[4, 4]; @@ -73,7 +72,6 @@ fn test5(input : i32) -> i32[4] { } return arr1; } -<<<<<<< Updated upstream #[entry] fn test6(input: i32) -> i32[1024] { @@ -87,12 +85,22 @@ fn test6(input: i32) -> i32[1024] { #[entry] fn test7(input : i32) -> i32[8] { let arr : i32[8]; + let out : i32[8]; + + for i = 0 to 8 { + arr[i] = i as i32; + } + @loop for i = 0 to 8 { - @bufferize1 let a = arr[i]; - @bufferize2 let b = a + arr[7-i]; - @bufferize3 let c = b * i as i32; - @bufferize4 let d = c; - arr[i] = d; + let b: i32; + @bufferize1 { + let a = arr[i]; + let a2 = a + arr[7-i]; + b = a2 * i as i32; + } + let c = b; + let d = c + 10; + out[i] = d; } - return arr; + return out; } \ No newline at end of file diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 33c5602e4c3e4df4aab3cbbd9a88a666870023e0..caf956a1652d7e638480b3176128469b6a574b22 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -23,7 +23,6 @@ fn main() { let correct = vec![5i32; 16]; assert(&correct, output); -<<<<<<< Updated upstream let mut r = runner!(test2); let output = r.run(3).await; let correct = vec![24i32; 16]; @@ -44,31 +43,10 @@ fn main() { let correct = vec![7i32; 4]; assert(&correct, output); - let mut r = runner!(test6); - let output = r.run(73).await; - let correct = (73i32..73i32+1024i32).collect(); + let mut r = runner!(test7); + let output = r.run(0).await; + let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; assert(&correct, output); -======= - // let mut r = runner!(test2); - // let output = r.run(3).await; - // let correct = vec![24i32; 16]; - // assert(correct, output); - - // let mut r = runner!(test3); - // let output = r.run(0).await; - // let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; - // assert(correct, output); - - // let mut r = runner!(test4); - // let output = r.run(9).await; - // let correct = vec![63i32; 16]; - // assert(correct, output); - - // let mut r = runner!(test5); - // let output = r.run(4).await; - // let correct = vec![7i32; 4]; - // assert(correct, output); ->>>>>>> Stashed changes }); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 28b0cbf524c4607317ce773d25e85bdf4da2ef89..2c5a3687131881c6b025af60c6425dd8d42886d7 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2015,6 +2015,8 @@ fn run_pass( } Pass::ForkFissionBufferize => { assert!(args.is_empty()); + let mut created_fork_joins = vec![vec![]; pm.functions.len()]; + pm.make_fork_join_maps(); pm.make_typing(); pm.make_loops(); @@ -2034,9 +2036,63 @@ fn run_pass( let Some(mut func) = func else { continue; }; - let result = fork_fission_bufferize_toplevel(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing); - changed |= result; + if let Some((fork1, fork2)) = + ff_bufferize_any_fork(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing) + { + let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; + created_fork_joins.push(fork1); + created_fork_joins.push(fork2); + } + changed |= func.modified(); + } + + pm.make_nodes_in_fork_joins(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + let mut new_fork_joins = HashMap::new(); + + for (mut func, created_fork_joins) in + build_editors(pm).into_iter().zip(created_fork_joins) + { + // For every function, create a label for every level of fork- + // joins resulting from the split. + let name = func.func().name.clone(); + let func_id = func.func_id(); + let labels = create_labels_for_node_sets( + &mut func, + created_fork_joins.into_iter().map(|fork| { + nodes_in_fork_joins[func_id.idx()][&fork] + .iter() + .map(|id| *id) + }) + , + ); + + // Assemble those labels into a record for this function. The + // format of the records is <function>.<f>, where N is the + // level of the split fork-joins being referred to. + todo!(); + // FIXME: What if there are multiple bufferized forks per function? + let mut func_record = HashMap::new(); + for (idx, label) in labels { + func_record.insert( + format!("fj{}", idx), + Value::Label { + labels: vec![LabelInfo { + func: func_id, + label: label, + }], + }, + ); + } + + // Try to avoid creating unnecessary record entries. + if !func_record.is_empty() { + new_fork_joins.entry(name).insert_entry(Value::Record { + fields: func_record, + }); + } } + pm.delete_gravestones(); pm.clear_analyses(); }