diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index f6a1f309745adf23b44d093ed8b27e48cd92737f..20a3e6cb1a2b663e2bd3a7fc41df82fe0ef443db 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -694,7 +694,7 @@ impl<'a> CPUContext<'a> { ); write!( body, - " {} = call i64 @llvm.umin.i64(i64{},i64%dc{}))\n", + " {} = call i64 @llvm.umin.i64(i64{},i64%dc{})\n", new_val, cur_value, x.idx() @@ -719,7 +719,7 @@ impl<'a> CPUContext<'a> { ); write!( body, - " {} = call i64 @llvm.umax.i64(i64{},i64%dc{}))\n", + " {} = call i64 @llvm.umax.i64(i64{},i64%dc{})\n", new_val, cur_value, x.idx() diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 456f670ea8f9a5998031597ffba830596d766dee..f82bbdfb9fd7e6d4950fedaf78c882d539dce896 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -98,6 +98,31 @@ pub fn find_reduce_dependencies<'a>( ret_val } +pub fn copy_subgraph_in_edit<'a>( + mut edit: FunctionEdit<'a, 'a>, + subgraph: HashSet<NodeID>, +) -> ( + Result<(FunctionEdit<'a, 'a>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'a>>// a map from old nodes to new nodes + // A list of (inside *old* node, outside node) s.t insde old node -> outside node. + // The caller probably wants +) { + + let mut map: HashMap<NodeID, NodeID> = HashMap::new(); + + // Copy nodes in subgraph + for old_id in subgraph.iter().cloned() { + let new_id = edit.copy_node(old_id); + map.insert(old_id, new_id); + } + + // Update edges to new nodes + for old_id in subgraph.iter() { + edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| map.values().contains(node_id))?; + } + + Ok((edit, map)) +} + pub fn copy_subgraph( editor: &mut FunctionEditor, subgraph: HashSet<NodeID>, @@ -144,6 +169,37 @@ pub fn copy_subgraph( (new_nodes, map, outside_users) } + +pub fn fork_fission_bufferize_toplevel<'a>( + editor: &'a mut FunctionEditor<'a>, + loop_tree: &'a LoopTree, + fork_join_map: &'a HashMap<NodeID, NodeID>, + data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, + typing: &'a Vec<TypeID> +) -> bool { + + let forks: Vec<_> = loop_tree + .bottom_up_loops() + .into_iter() + .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) + .collect(); + + for l in forks { + let fork_info = &Loop { + header: l.0, + control: l.1.clone(), + }; + let fork = fork_info.header; + let join = fork_join_map[&fork]; + let mut edges = HashSet::new(); + edges.insert((NodeID::new(8), NodeID::new(3))); + let result = fork_bufferize_fission_helper(editor, fork_info, edges, data_node_in_fork_joins, typing, fork, join); + return result.is_some(); + } + return false; +} + + pub fn fork_fission<'a>( editor: &'a mut FunctionEditor, _control_subgraph: &Subgraph, @@ -184,71 +240,100 @@ pub fn fork_fission<'a>( /** Split a 1D fork into two forks, placing select intermediate data into buffers. */ pub fn fork_bufferize_fission_helper<'a>( - editor: &'a mut FunctionEditor, - fork_join_map: &HashMap<NodeID, NodeID>, + editor: &'a mut FunctionEditor<'a>, + l: &Loop, bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. - _original_control_pred: NodeID, // What the new fork connects to. + data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, types: &Vec<TypeID>, fork: NodeID, -) -> (NodeID, NodeID) { + join: NodeID +) -> Option<(NodeID, NodeID)> { // Returns the two forks that it generates. - // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork. + if bufferized_edges.is_empty() { + return None + } - // Copy fork + control intermediates + join to new fork + join, - // How does control get partitioned? - // (depending on how it affects the data nodes on each side of the bufferized_edges) - // may end up in each loop, fix me later. - // place new fork + join after join of first. + let all_loop_nodes = l.get_all_nodes(); - // Only handle fork+joins with no inner control for now. + // FIXME: Cloning hell. + let data_nodes = data_node_in_fork_joins[&fork].clone(); + let loop_nodes = editor.node_ids().filter(|node_id| all_loop_nodes[node_id.idx()]); + // Clone the subgraph that consists of this fork-join and all data and control nodes in it. + let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes)); - // Create fork + join + Thread control - let join = fork_join_map[&fork]; - let mut new_fork_id = NodeID::new(0); - let mut new_join_id = NodeID::new(0); + let mut outside_users = Vec::new(); // old_node, outside_user - editor.edit(|mut edit| { - new_join_id = edit.add_node(Node::Join { control: fork }); - let factors = edit.get_node(fork).try_fork().unwrap().1; - new_fork_id = edit.add_node(Node::Fork { - control: new_join_id, - factors: factors.into(), - }); - edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) - }); + for node in subgraph.iter() { + for user in editor.get_users(*node) { + if !subgraph.iter().contains(&user) { + outside_users.push((*node, user)); + } + } + } - for (src, dst) in bufferized_edges { - // FIXME: Disgusting cloning and allocationing and iterators. - let factors: Vec<_> = editor.func().nodes[fork.idx()] - .try_fork() - .unwrap() - .1 - .iter() - .cloned() - .collect(); - editor.edit(|mut edit| { - // Create write to buffer + let factors: Vec<_> = editor.func().nodes[fork.idx()] + .try_fork() + .unwrap() + .1 + .iter() + .cloned() + .collect(); - let thread_stuff_it = factors.into_iter().enumerate(); + let thread_stuff_it = factors.into_iter().enumerate(); - // FIxme: try to use unzip here? Idk why it wasn't working. - let tids = thread_stuff_it.clone().map(|(dim, _)| { - edit.add_node(Node::ThreadID { - control: fork, - dimension: dim, - }) + // Control succesors + let fork_pred = editor.get_uses(fork).filter(|a| editor.node(a).is_control()).next().unwrap(); + let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap(); + + let mut new_fork_id = NodeID::new(0); + editor.edit(|edit| { + let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?; + + // Put new subgraph after old subgraph + println!("map: {:?}", map); + println!("join: {:?}, fork: {:?}", join, fork); + println!("fork_sccueue: {:?}", join_successor); + edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?; + edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?; + + // Replace outside uses of reduces in old subgraph with reduces in new subgraph. + for (old_node, outside_user) in outside_users { + edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?; + } + + // Add buffers to old subgraph + + let new_join = map[&join]; + let new_fork = map[&fork]; + + // FIXME: Do this as part of copy subgraph? + // Add tids to original subgraph for indexing. + let mut old_tids = Vec::new(); + let mut new_tids = Vec::new(); + for (dim, _) in thread_stuff_it.clone() { + let old_id = edit.add_node(Node::ThreadID { + control: fork, + dimension: dim, }); - let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + let new_id = edit.add_node(Node::ThreadID { + control: new_fork, + dimension: dim, + }); - // Assume 1-d fork only for now. - // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); - let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + old_tids.push(old_id); + new_tids.push(new_id); + } + + for (src, dst) in &bufferized_edges { + let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + let position_idx = Index::Position(old_tids.clone().into_boxed_slice()); + let write = edit.add_node(Node::Write { collect: NodeID::new(0), - data: src, - indices: vec![position_idx].into(), + data: *src, + indices: vec![position_idx.clone()].into(), }); let ele_type = types[src.idx()]; let empty_buffer = edit.add_type(hercules_ir::Type::Array( @@ -258,36 +343,125 @@ pub fn fork_bufferize_fission_helper<'a>( let empty_buffer = edit.add_zero_constant(empty_buffer); let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); let reduce = Node::Reduce { - control: new_join_id, + control: join, init: empty_buffer, reduct: write, }; let reduce = edit.add_node(reduce); + // Fix write node edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; - // Create read from buffer - let tids = thread_stuff_it.clone().map(|(dim, _)| { - edit.add_node(Node::ThreadID { - control: new_fork_id, - dimension: dim, - }) - }); - let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + // Create reads from buffer + let position_idx = Index::Position(new_tids.clone().into_boxed_slice()); let read = edit.add_node(Node::Read { collect: reduce, indices: vec![position_idx].into(), }); - edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; + // Replaces uses of bufferized edge src with corresponding reduce and read in old subraph + edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?; + + } + + + - Ok(edit) - }); - } + new_fork_id = new_fork; + + Ok(edit) + }); - (fork, new_fork_id) + Some((fork, new_fork_id)) + + // let internal_control: Vec<NodeID> = Vec::new(); + + // // Create fork + join + Thread control + // let mut new_fork_id = NodeID::new(0); + // let mut new_join_id = NodeID::new(0); + + // editor.edit(|mut edit| { + // new_join_id = edit.add_node(Node::Join { control: fork }); + // let factors = edit.get_node(fork).try_fork().unwrap().1; + // new_fork_id = edit.add_node(Node::Fork { + // control: new_join_id, + // factors: factors.into(), + // }); + // edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) + // }); + + // for (src, dst) in bufferized_edges { + // // FIXME: Disgusting cloning and allocationing and iterators. + // let factors: Vec<_> = editor.func().nodes[fork.idx()] + // .try_fork() + // .unwrap() + // .1 + // .iter() + // .cloned() + // .collect(); + // editor.edit(|mut edit| { + // // Create write to buffer + + // let thread_stuff_it = factors.into_iter().enumerate(); + + // // FIxme: try to use unzip here? Idk why it wasn't working. + // let tids = thread_stuff_it.clone().map(|(dim, _)| { + // edit.add_node(Node::ThreadID { + // control: fork, + // dimension: dim, + // }) + // }); + + // let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor)); + + // // Assume 1-d fork only for now. + // // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); + // let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + // let write = edit.add_node(Node::Write { + // collect: NodeID::new(0), + // data: src, + // indices: vec![position_idx].into(), + // }); + // let ele_type = types[src.idx()]; + // let empty_buffer = edit.add_type(hercules_ir::Type::Array( + // ele_type, + // array_dims.collect::<Vec<_>>().into_boxed_slice(), + // )); + // let empty_buffer = edit.add_zero_constant(empty_buffer); + // let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); + // let reduce = Node::Reduce { + // control: new_join_id, + // init: empty_buffer, + // reduct: write, + // }; + // let reduce = edit.add_node(reduce); + // // Fix write node + // edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; + + // // Create read from buffer + // let tids = thread_stuff_it.clone().map(|(dim, _)| { + // edit.add_node(Node::ThreadID { + // control: new_fork_id, + // dimension: dim, + // }) + // }); + + // let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + + // let read = edit.add_node(Node::Read { + // collect: reduce, + // indices: vec![position_idx].into(), + // }); + + // edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; + + // Ok(edit) + // }); + // } + + // (fork, new_fork_id) } /** Split a 1D fork into a separate fork for each reduction. */ diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 795642b25ba25bf6f5c845065a24d4c98119aa82..f42a652002a51cea4daedad16374d8451c124928 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -385,7 +385,7 @@ fn matmul_pipeline() { let dyn_consts = [I, J, K]; // FIXME: This path should not leave the crate - let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin"); + let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin"); // let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { @@ -398,25 +398,20 @@ fn matmul_pipeline() { let result_1 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone()); - println!("golden: {:?}", correct_c); - println!("result: {:?}", result_1); + // println!("golden: {:?}", correct_c); + // println!("result: {:?}", result_1); - let InterpreterVal::Array(_, d) = result_1.clone() else { - panic!() - }; - let InterpreterVal::Integer32(value) = d[0] else { - panic!() - }; - assert_eq!(correct_c[0], value); + // let InterpreterVal::Array(_, d) = result_1.clone() else { + // panic!() + // }; + // let InterpreterVal::Integer32(value) = d[0] else { + // panic!() + // }; + // assert_eq!(correct_c[0], value); let schedule = Some(default_schedule![ - AutoOutline, - InterproceduralSROA, - SROA, - InferSchedules, - DCE, Xdot, - GCM + Verify, ]); module = run_schedule_on_hercules(module, schedule).unwrap(); diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index a557cd036c9aebcbc7a989bd584226f6647379c9..947d0dc86a72b55c1adf99d1d4ab4cc77f209c35 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -7,10 +7,13 @@ dce(*); let out = auto-outline(test1, test2, test3, test4, test5); cpu(out.test1); +<<<<<<< Updated upstream cpu(out.test2); cpu(out.test3); cpu(out.test4); cpu(out.test5); +======= +>>>>>>> Stashed changes ip-sroa(*); @@ -24,8 +27,15 @@ dce(*); fixpoint panic after 20 { forkify(*); fork-guard-elim(*); - fork-coalesce(*); + fork-coalesce(*); } + +dce(*); +gvn(*); + +xdot[true](*); +fork-fission-bufferize(*); +xdot[true](*); gvn(*); phi-elim(*); dce(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 128d3ce0198972b9767ad1879a594f83df3d4adc..90d06c2fe2c7139aaee4d3dd6288a82ffa87e055 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -9,6 +9,7 @@ fn test1(input : i32) -> i32[4, 4] { return arr; } +/** #[entry] fn test2(input : i32) -> i32[4, 4] { let arr : i32[4, 4]; @@ -72,6 +73,7 @@ fn test5(input : i32) -> i32[4] { } return arr1; } +<<<<<<< Updated upstream #[entry] fn test6(input: i32) -> i32[1024] { diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 19838fd7741d7bf366c23c36038a66ab636d96c9..33c5602e4c3e4df4aab3cbbd9a88a666870023e0 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -23,6 +23,7 @@ fn main() { let correct = vec![5i32; 16]; assert(&correct, output); +<<<<<<< Updated upstream let mut r = runner!(test2); let output = r.run(3).await; let correct = vec![24i32; 16]; @@ -47,6 +48,27 @@ fn main() { let output = r.run(73).await; let correct = (73i32..73i32+1024i32).collect(); assert(&correct, output); +======= + // let mut r = runner!(test2); + // let output = r.run(3).await; + // let correct = vec![24i32; 16]; + // assert(correct, output); + + // let mut r = runner!(test3); + // let output = r.run(0).await; + // let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; + // assert(correct, output); + + // let mut r = runner!(test4); + // let output = r.run(9).await; + // let correct = vec![63i32; 16]; + // assert(correct, output); + + // let mut r = runner!(test5); + // let output = r.run(4).await; + // let correct = vec![7i32; 4]; + // assert(correct, output); +>>>>>>> Stashed changes }); } diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 0be838c620761e8726590e2dbaf7bfdb7a82e3df..d2813388e0e7a1d7bd1696ffbb641e629096e2c2 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,6 +6,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 713c30d436fdb9316f92f7ecbeee3659f2924d83..08e952a4e535d36e270db4505210511ebacd91d3 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -109,6 +109,7 @@ impl FromStr for Appliable { "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) } + "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)), "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 9e85509f83a322da222e56257ffffc697ab14e30..8a6e04edfacf0f09ab45d7144c1c7f90d0b244fb 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -12,6 +12,7 @@ pub enum Pass { ForkSplit, ForkCoalesce, Forkify, + ForkFissionBufferize, ForkDimMerge, ForkChunk, GCM, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d4acac190cb7bd96770fbaf1a6bd7e4604d714d6..28b0cbf524c4607317ce773d25e85bdf4da2ef89 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1987,10 +1987,7 @@ fn run_pass( }); }; let Some(Value::Integer { val: dim_idx }) = args.get(1) else { - return Err(SchedulerError::PassError { - pass: "forkChunk".to_string(), - error: "expected integer argument".to_string(), - }); + panic!(); // How to error here? }; let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { return Err(SchedulerError::PassError { @@ -2002,9 +1999,10 @@ fn run_pass( assert_eq!(*guarded_flag, true); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -2015,13 +2013,41 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkFissionBufferize => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_typing(); + pm.make_loops(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(typing.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + let result = fork_fission_bufferize_toplevel(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing); + changed |= result; + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; @@ -2051,8 +2077,8 @@ fn run_pass( let Some(mut func) = func else { continue; }; - changed |= fork_coalesce(&mut func, loop_nest, fork_join_map); - // func.modified(); + fork_coalesce(&mut func, loop_nest, fork_join_map); + changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses();