diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 1df5338ddb3a4fa1a3ff65db27a6c9571dd21bb0..7367a63e51fad5dc3fb5dfb5de0e2426f46b41e3 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -310,23 +310,33 @@ pub fn fork_fission<'a>( reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, loop_tree: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, + fork_label: LabelID, ) -> Vec<NodeID> { - let mut forks: Vec<_> = loop_tree + let forks: Vec<_> = loop_tree .bottom_up_loops() .into_iter() .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) .collect(); let mut created_forks = Vec::new(); + // This does the reduction fission - for fork in forks.clone() { + for fork in forks { let join = fork_join_map[&fork.0]; + + // FIXME: Don't make multiple forks for reduces that are in cycles with each other. let reduce_partition = default_reduce_partition(editor, fork.0, join); + if !editor.func().labels[fork.0.idx()].contains(&fork_label) { + continue; + } + if editor.is_mutable(fork.0) { created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0); - if !created_forks.is_empty() { - break; + if created_forks.is_empty() { + continue; + } else { + return created_forks; } } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 71a6eb928fd54afd7f6e79e41f1f611bda8ed5ca..a3bb146b3f8d456083daeb4ba62b8bb5369b7a71 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -98,15 +98,11 @@ dce(auto.test8); no-memset(test9@const); fork-split(auto.test10); -xdot[true](auto.test10); -fork-fission-reduces(auto.test10); -dce(auto.test10); -xdot[false](auto.test10); +let fission_out = fork-fission-reduces[test10@loop3](auto.test10); +fork-fusion(fission_out.test10_8.fj0, fission_out.test10_8.fj1); dce(auto.test10); simplify-cfg(auto.test10); dce(auto.test10); -xdot[true](auto.test10); - fork-split(auto.test10); unforkify(auto.test10); diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 81dc8d9854776931f4598a9010837008796baaf8..0d828f489cea8251b0b008efb1a3ddfae857547f 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -13,7 +13,7 @@ gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10); gpu(auto.test1); gpu(auto.test2); gpu(auto.test3); @@ -22,6 +22,7 @@ gpu(auto.test5); gpu(auto.test7); gpu(auto.test8); gpu(auto.test9); +gpu(auto.test10); ip-sroa(*); sroa(*); @@ -75,6 +76,12 @@ dce(auto.test8); no-memset(test9@const); +fork-split(auto.test10); +fork-fission-reduces(auto.test10); +dce(auto.test10); +simplify-cfg(auto.test10); +dce(auto.test10); + ip-sroa(*); sroa(*); dce(*); diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index d381e861c323f5d1326e26cccb2cf916a0bf735f..dd7a76a7da6b58bb5bc723116e84318b4b6346dc 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -60,6 +60,7 @@ impl Pass { Pass::ForkChunk => num == 4, Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, + Pass::ForkFission => num == 1, Pass::ForkInterchange => num == 2, Pass::ForkReshape => true, Pass::InterproceduralSROA => num == 0 || num == 1, @@ -78,6 +79,7 @@ impl Pass { Pass::ForkChunk => "4", Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", + Pass::ForkFission => "1", Pass::ForkInterchange => "2", Pass::ForkReshape => "any", Pass::InterproceduralSROA => "0 or 1", diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9281925a96dda6fbad44c746bf824206f13c34d7..666ebdcd726ff3347fedac45471d20e33ed9d995 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2742,7 +2742,18 @@ fn run_pass( pm.clear_analyses(); }, Pass::ForkFission => { - assert!(args.len() == 0); + assert!(args.len() == 1); + + let Some(Value::Label { + labels: fork_labels, + }) = args.get(0) + else { + return Err(SchedulerError::PassError { + pass: "forkFission".to_string(), + error: "expected label argument".to_string(), + }); + }; + pm.make_fork_join_maps(); pm.make_loops(); @@ -2755,6 +2766,17 @@ fn run_pass( let mut created_fork_joins = vec![vec![]; pm.functions.len()]; + // assert only one function is in the selection. + let num_functions = build_selection(pm, selection.clone(), false) + .iter() + .filter(|func| func.is_some()) + .count(); + + assert!(num_functions <= 1); + assert_eq!(fork_labels.len(), 1); + + let fork_label = fork_labels[0].label; + for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins, ) in build_selection(pm, selection, false) .into_iter() @@ -2773,6 +2795,7 @@ fn run_pass( reduce_cycles, loop_tree, fork_join_map, + fork_label, ); let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; @@ -2810,7 +2833,7 @@ fn run_pass( // level of the split fork-joins being referred to. let mut func_record = HashMap::new(); for (idx, label) in labels { - let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" }; + let fmt = format!("fj{idx}"); func_record.insert( fmt.to_string(), Value::Label { @@ -2874,14 +2897,6 @@ fn run_pass( let fork_label = fork_labels[0].label; - // assert only one function is in the selection. - let num_functions = build_selection(pm, selection.clone(), false) - .iter() - .filter(|func| func.is_some()) - .count(); - - assert!(num_functions <= 1); - for ( ((((func, fork_join_map), loop_tree), typing), reduce_cycles), nodes_in_fork_joins,