diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 4f8f91ef177ba947d35d3ef3885d8273d059aa35..20a3e6cb1a2b663e2bd3a7fc41df82fe0ef443db 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -694,7 +694,7 @@ impl<'a> CPUContext<'a> {
                );
                write!(
                    body,
-                    " {} = call i64 @llvm.umin.i64(i64{},i64%dc{}))\n",
+                    " {} = call i64 @llvm.umin.i64(i64{},i64%dc{})\n",
                    new_val,
                    cur_value,
                    x.idx()
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 456f670ea8f9a5998031597ffba830596d766dee..7c4238921a6d7c96dc4645f4fdb123aa1bf1a594 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,6 +1,6 @@
 use std::collections::{HashMap, HashSet};
+use std::hash::Hash;
 use std::iter::zip;
-use std::thread::ThreadId;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -98,6 +98,28 @@ pub fn find_reduce_dependencies<'a>(
     ret_val
 }
 
+pub fn copy_subgraph_in_edit<'a, 'b>(
+    mut edit: FunctionEdit<'a, 'b>,
+    subgraph: HashSet<NodeID>,
+) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
+    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
+
+    // Copy nodes in subgraph
+    for old_id in subgraph.iter().cloned() {
+        let new_id = edit.copy_node(old_id);
+        map.insert(old_id, new_id);
+    }
+
+    // Update edges to new nodes
+    for old_id in subgraph.iter() {
+        edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| {
+            map.values().contains(node_id)
+        })?;
+    }
+
+    Ok((edit, map))
+}
+
 pub fn copy_subgraph(
     editor: &mut FunctionEditor,
     subgraph: HashSet<NodeID>,
@@ -144,6 +166,101 @@ pub fn copy_subgraph(
     (new_nodes, map, outside_users)
 }
 
+pub fn find_bufferize_edges(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    loop_tree: &LoopTree,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    data_label: &LabelID,
+) -> HashSet<(NodeID, NodeID)> {
+    let mut edges: HashSet<_> = HashSet::new();
+
+    for node in &nodes_in_fork_joins[&fork] {
+        // Edge from *has data label* to *doesn't have data label*
+        let node_labels = &editor.func().labels[node.idx()];
+
+        if !node_labels.contains(data_label) {
+            continue;
+        }
+
+        // Don't draw bufferize edges from fork tids
+        if editor.get_users(fork).contains(node) {
+            continue;
+        }
+
+        for user in editor.get_users(*node) {
+            let user_labels = &editor.func().labels[user.idx()];
+            if user_labels.contains(data_label) {
+                continue;
+            }
+
+            if editor.node(user).is_control() || editor.node(node).is_control() {
+                continue;
+            }
+
+            edges.insert((*node, user));
+        }
+    }
+    edges
+}
+
+pub fn ff_bufferize_any_fork<'a, 'b>(
+    editor: &'b mut FunctionEditor<'a>,
+    loop_tree: &'b LoopTree,
+    fork_join_map: &'b HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
+    typing: &'b Vec<TypeID>,
+    fork_label: &'b LabelID,
+    data_label: &'b LabelID,
+) -> Option<(NodeID, NodeID)>
+where
+    'a: 'b,
+{
+    let forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
+        .collect();
+
+    for l in forks {
+        let fork_info = Loop {
+            header: l.0,
+            control: l.1.clone(),
+        };
+        let fork = fork_info.header;
+        let join = fork_join_map[&fork];
+
+        if !editor.func().labels[fork.idx()].contains(fork_label) {
+            continue;
+        }
+
+        let edges = find_bufferize_edges(
+            editor,
+            fork,
+            &loop_tree,
+            &fork_join_map,
+            &nodes_in_fork_joins,
+            data_label,
+        );
+        let result = fork_bufferize_fission_helper(
+            editor,
+            &fork_info,
+            &edges,
+            nodes_in_fork_joins,
+            typing,
+            fork,
+            join,
+        );
+        if result.is_none() {
+            continue;
+        } else {
+            return result;
+        }
+    }
+    return None;
+}
+
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
     _control_subgraph: &Subgraph,
@@ -183,72 +300,107 @@
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
-pub fn fork_bufferize_fission_helper<'a>(
-    editor: &'a mut FunctionEditor,
-    fork_join_map: &HashMap<NodeID, NodeID>,
-    bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
-    _original_control_pred: NodeID, // What the new fork connects to.
-    types: &Vec<TypeID>,
+pub fn fork_bufferize_fission_helper<'a, 'b>(
+    editor: &'b mut FunctionEditor<'a>,
+    l: &Loop,
+    bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
+    data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
+    types: &'b Vec<TypeID>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
-    // Returns the two forks that it generates.
+    join: NodeID,
+) -> Option<(NodeID, NodeID)>
+where
+    'a: 'b,
+{
+    if bufferized_edges.is_empty() {
+        return None;
+    }
 
-    // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork.
+    let all_loop_nodes = l.get_all_nodes();
 
-    // Copy fork + control intermediates + join to new fork + join,
-    // How does control get partitioned?
-    // (depending on how it affects the data nodes on each side of the bufferized_edges)
-    // may end up in each loop, fix me later.
-    // place new fork + join after join of first.
+    // FIXME: Cloning hell.
+    let data_nodes = data_node_in_fork_joins[&fork].clone();
+    let loop_nodes = editor
+        .node_ids()
+        .filter(|node_id| all_loop_nodes[node_id.idx()]);
+    // Clone the subgraph that consists of this fork-join and all data and control nodes in it.
+    let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));
 
-    // Only handle fork+joins with no inner control for now.
+    let mut outside_users = Vec::new(); // old_node, outside_user
+
+    for node in subgraph.iter() {
+        for user in editor.get_users(*node) {
+            if !subgraph.iter().contains(&user) {
+                outside_users.push((*node, user));
+            }
+        }
+    }
+
+    let factors: Vec<_> = editor.func().nodes[fork.idx()]
+        .try_fork()
+        .unwrap()
+        .1
+        .iter()
+        .cloned()
+        .collect();
+
+    let thread_stuff_it = factors.into_iter().enumerate();
+
+    // Control successors
+    let fork_pred = editor
+        .get_uses(fork)
+        .filter(|a| editor.node(a).is_control())
+        .next()
+        .unwrap();
+    let join_successor = editor
+        .get_users(join)
+        .filter(|a| editor.node(a).is_control())
+        .next()
+        .unwrap();
 
-    // Create fork + join + Thread control
-    let join = fork_join_map[&fork];
     let mut new_fork_id = NodeID::new(0);
-    let mut new_join_id = NodeID::new(0);
-    editor.edit(|mut edit| {
-        new_join_id = edit.add_node(Node::Join { control: fork });
-        let factors = edit.get_node(fork).try_fork().unwrap().1;
-        new_fork_id = edit.add_node(Node::Fork {
-            control: new_join_id,
-            factors: factors.into(),
-        });
-        edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
-    });
+    let edit_result = editor.edit(|edit| {
+        let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
 
-    for (src, dst) in bufferized_edges {
-        // FIXME: Disgusting cloning and allocationing and iterators.
-        let factors: Vec<_> = editor.func().nodes[fork.idx()]
-            .try_fork()
-            .unwrap()
-            .1
-            .iter()
-            .cloned()
-            .collect();
-        editor.edit(|mut edit| {
-            // Create write to buffer
+        edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
+        edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
+
+        // Replace outside uses of reduces in old subgraph with reduces in new subgraph.
+        for (old_node, outside_user) in outside_users {
+            edit = edit
+                .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
+        }
+
+        let new_fork = map[&fork];
 
-            let thread_stuff_it = factors.into_iter().enumerate();
+        // FIXME: Do this as part of copy subgraph?
+        // Add tids to original subgraph for indexing.
+        let mut old_tids = Vec::new();
+        let mut new_tids = Vec::new();
+        for (dim, _) in thread_stuff_it.clone() {
+            let old_id = edit.add_node(Node::ThreadID {
+                control: fork,
+                dimension: dim,
+            });
 
-            // FIxme: try to use unzip here? Idk why it wasn't working.
-            let tids = thread_stuff_it.clone().map(|(dim, _)| {
-                edit.add_node(Node::ThreadID {
-                    control: fork,
-                    dimension: dim,
-                })
+            let new_id = edit.add_node(Node::ThreadID {
+                control: new_fork,
+                dimension: dim,
             });
+            old_tids.push(old_id);
+            new_tids.push(new_id);
+        }
+
+        for (src, dst) in bufferized_edges {
             let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
+            let position_idx = Index::Position(old_tids.clone().into_boxed_slice());
 
-            // Assume 1-d fork only for now.
-            // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
-            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
             let write = edit.add_node(Node::Write {
                 collect: NodeID::new(0),
-                data: src,
-                indices: vec![position_idx].into(),
+                data: *src,
+                indices: vec![position_idx.clone()].into(),
             });
             let ele_type = types[src.idx()];
             let empty_buffer = edit.add_type(hercules_ir::Type::Array(
@@ -257,37 +409,40 @@ pub fn fork_bufferize_fission_helper<'a>(
             ));
             let empty_buffer = edit.add_zero_constant(empty_buffer);
             let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
+            edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;
             let reduce = Node::Reduce {
-                control: new_join_id,
+                control: join,
                 init: empty_buffer,
                 reduct: write,
             };
             let reduce = edit.add_node(reduce);
+            edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;
+
             // Fix write node
             edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
 
-            // Create read from buffer
-            let tids = thread_stuff_it.clone().map(|(dim, _)| {
-                edit.add_node(Node::ThreadID {
-                    control: new_fork_id,
-                    dimension: dim,
-                })
-            });
-
-            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
+            // Create reads from buffer
+            let position_idx = Index::Position(new_tids.clone().into_boxed_slice());
             let read = edit.add_node(Node::Read {
                 collect: reduce,
                 indices: vec![position_idx].into(),
             });
-            edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
+            // Replace uses of bufferized edge src with corresponding reduce and read in old subgraph
+            edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
+        }
 
-            Ok(edit)
-        });
-    }
+        new_fork_id = new_fork;
 
-    (fork, new_fork_id)
+        Ok(edit)
+    });
+
+    if edit_result {
+        Some((fork, new_fork_id))
+    } else {
+        None
+    }
 }
 
 /** Split a 1D fork into a separate fork for each reduction. */
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 795642b25ba25bf6f5c845065a24d4c98119aa82..f42a652002a51cea4daedad16374d8451c124928 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -385,7 +385,7 @@ fn matmul_pipeline() {
     let dyn_consts = [I, J, K];
 
     // FIXME: This path should not leave the crate
-    let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin");
+    let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin");
 
     // let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
     for i in 0..I {
@@ -398,25 +398,20 @@ fn matmul_pipeline() {
 
     let result_1 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone());
 
-    println!("golden: {:?}", correct_c);
-    println!("result: {:?}", result_1);
+    // println!("golden: {:?}", correct_c);
+    // println!("result: {:?}", result_1);
 
-    let InterpreterVal::Array(_, d) = result_1.clone() else {
-        panic!()
-    };
-    let InterpreterVal::Integer32(value) = d[0] else {
-        panic!()
-    };
-    assert_eq!(correct_c[0], value);
+    // let InterpreterVal::Array(_, d) = result_1.clone() else {
+    //     panic!()
+    // };
+    // let InterpreterVal::Integer32(value) = d[0] else {
+    //     panic!()
+    // };
+    // assert_eq!(correct_c[0], value);
 
     let schedule = Some(default_schedule![
-        AutoOutline,
-        InterproceduralSROA,
-        SROA,
-        InferSchedules,
-        DCE,
         Xdot,
-        GCM
+        Verify,
     ]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 8543d23d7ff9da8f6f0eecc844dc6bb0a77bcc2d..38a38c2bacfe480710b128ff44a24abfd0059df6 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -5,13 +5,15 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8);
 cpu(auto.test1);
 cpu(auto.test2);
 cpu(auto.test3);
 cpu(auto.test4);
 cpu(auto.test5);
 cpu(auto.test7);
+cpu(auto.test8);
+
 
 ip-sroa(*);
 sroa(*);
@@ -24,8 +26,12 @@ dce(*);
 fixpoint panic after 20 {
   forkify(*);
   fork-guard-elim(*);
-  fork-coalesce(*);
+  fork-coalesce(*);
 }
+
+dce(*);
+gvn(*);
+
 gvn(*);
 phi-elim(*);
 dce(*);
@@ -51,6 +57,7 @@ cpu(out);
 ip-sroa(*);
 sroa(*);
 unforkify(out);
+
 dce(*);
 ccp(*);
 gvn(*);
@@ -64,4 +71,13 @@ dce(auto.test7);
 simplify-cfg(auto.test7);
 dce(auto.test7);
 
+let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
+dce(auto.test8);
+unforkify(auto.test8);
+dce(auto.test8);
+ccp(auto.test8);
+dce(auto.test8);
+simplify-cfg(auto.test8);
+dce(auto.test8);
+
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 3f63f8209742a3c5487366783170ba9f38c4a147..886ab13b238402e10ad1480b853bf1a237dd41b4 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -94,3 +94,26 @@ fn test7(input: i32) -> i32 {
   }
   return sum;
 }
+
+#[entry]
+fn test8(input : i32) -> i32[8] {
+  @const1 let arr : i32[8];
+  @const2 let out : i32[8];
+
+  for i = 0 to 8 {
+    arr[i] = i as i32;
+  }
+
+  @loop for i = 0 to 8 {
+    let b: i32;
+    @bufferize1 {
+      let a = arr[i];
+      let a2 = a + arr[7-i];
+      b = a2 * i as i32;
+    }
+    let c = b;
+    let d = c + 10;
+    out[i] = d;
+  }
+  return out;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index bd1fd1d8465ba27778556987429757b0c47705ec..159fac94ee83c1dc7aa886545959e6cec2ed5c7a 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -5,18 +5,21 @@ no-memset(test3@const1);
 no-memset(test3@const2);
 no-memset(test3@const3);
 no-memset(test6@const);
+no-memset(test8@const1);
+no-memset(test8@const2);
 
 gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8);
 gpu(auto.test1);
 gpu(auto.test2);
 gpu(auto.test3);
 gpu(auto.test4);
 gpu(auto.test5);
 gpu(auto.test7);
+gpu(auto.test8);
 
 ip-sroa(*);
 sroa(*);
@@ -50,6 +53,14 @@ dce(auto.test7);
 simplify-cfg(auto.test7);
 dce(auto.test7);
 
+let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
+xdot[true](*);
+dce(auto.test8);
+ccp(auto.test8);
+dce(auto.test8);
+simplify-cfg(auto.test8);
+dce(auto.test8);
+
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index e08a8b1b405ca367c7c40242679e7950430db532..1013a1f028b9d204a442be21bc0a1966d025d25e 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -52,6 +52,11 @@ fn main() {
         let output = r.run(42).await;
         let correct: i32 = (42i32..42i32+32i32).sum();
         assert_eq!(correct, output);
+
+        let mut r = runner!(test8);
+        let output = r.run(0).await;
+        let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
+        assert(&correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 1d5044d86479f68e86c74d864a6f976604c4c6fe..0ca3638010bc1a026b44d56d5d49696633f7d9fa 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -110,6 +110,7 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
+            "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)),
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 88aed007d2106e7f77d2425b326278557175de01..0ec6f52cd678c6e42324f77ec2f0799deaca1535 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -13,6 +13,7 @@ pub enum Pass {
     ForkSplit,
     ForkCoalesce,
     Forkify,
+    ForkFissionBufferize,
     ForkDimMerge,
     ForkChunk,
     GCM,
@@ -39,6 +40,7 @@ impl Pass {
         match self {
             Pass::Xdot => 1,
             Pass::ForkChunk => 3,
+            Pass::ForkFissionBufferize => 2,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5343005e5cc8df5e4ce08100239467072bb872dd..20825c549825d7abac60abb82fa508dd9f3ae803 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2058,6 +2058,140 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkFissionBufferize => {
+            pm.make_fork_join_maps();
+
+            assert_eq!(args.len(), 2);
+            let Some(Value::Label {
+                labels: fork_labels,
+            }) = args.get(0)
+            else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkFissionBufferize".to_string(),
+                    error: "expected label argument".to_string(),
+                });
+            };
+
+            let Some(Value::Label {
+                labels: fork_data_labels,
+            }) = args.get(1)
+            else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkFissionBufferize".to_string(),
+                    error: "expected label argument".to_string(),
argument".to_string(), + }); + }; + + let mut created_fork_joins = vec![vec![]; pm.functions.len()]; + + pm.make_fork_join_maps(); + pm.make_typing(); + pm.make_loops(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + // assert only one function is in the selection. + let num_functions = build_selection(pm, selection.clone(), false) + .iter() + .filter(|func| func.is_some()) + .count(); + + assert!(num_functions <= 1); + assert_eq!(fork_labels.len(), 1); + assert_eq!(fork_data_labels.len(), 1); + + let fork_label = fork_labels[0].label; + let data_label = fork_data_labels[0].label; + + for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(typing.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + if let Some((fork1, fork2)) = ff_bufferize_any_fork( + &mut func, + loop_tree, + fork_join_map, + nodes_in_fork_joins, + typing, + &fork_label, + &data_label, + ) { + let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; + created_fork_joins.push(fork1); + created_fork_joins.push(fork2); + } + changed |= func.modified(); + } + + pm.clear_analyses(); + pm.make_nodes_in_fork_joins(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + let mut new_fork_joins = HashMap::new(); + + let fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone(); + + for (mut func, created_fork_joins) in + build_editors(pm).into_iter().zip(created_fork_joins) + { + // For every function, create a label for every level of fork- + // joins resulting from the split. + let name = func.func().name.clone(); + let func_id = func.func_id(); + let labels = create_labels_for_node_sets( + &mut func, + created_fork_joins.into_iter().map(|fork| { + nodes_in_fork_joins[func_id.idx()][&fork] + .iter() + .map(|id| *id) + }), + ); + + // Assemble those labels into a record for this function. The + // format of the records is <function>.<f>, where N is the + // level of the split fork-joins being referred to. + let mut func_record = HashMap::new(); + for (idx, label) in labels { + let fmt = if idx % 2 == 0 { + format!("fj_{}_top", fork_label_name) + } else { + format!("fj_{}_bottom", fork_label_name) + }; + func_record.insert( + fmt, + Value::Label { + labels: vec![LabelInfo { + func: func_id, + label: label, + }], + }, + ); + } + + // Try to avoid creating unnecessary record entries. + if !func_record.is_empty() { + new_fork_joins.entry(name).insert_entry(Value::Record { + fields: func_record, + }); + } + } + + pm.delete_gravestones(); + pm.clear_analyses(); + + result = Value::Record { + fields: new_fork_joins, + }; + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); @@ -2094,8 +2228,8 @@ fn run_pass( let Some(mut func) = func else { continue; }; - changed |= fork_coalesce(&mut func, loop_nest, fork_join_map); - // func.modified(); + fork_coalesce(&mut func, loop_nest, fork_join_map); + changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses();