From 28ea01a1afcbb630b570e360693defb3e76c70d9 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 28 Feb 2025 13:24:48 -0600 Subject: [PATCH] Infer more indices as parallel --- hercules_opt/src/utils.rs | 18 ++++++++++++++++++ juno_samples/rodinia/backprop/src/backprop.jn | 4 ++-- juno_samples/rodinia/backprop/src/cpu.sch | 10 +++------- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index e962b81d..b910a128 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -532,6 +532,24 @@ where let fork_thread_id_pairs = node_indices(indices).filter_map(|id| { if let Node::ThreadID { control, dimension } = nodes[id.idx()] { Some((control, dimension)) + } else if let Node::Binary { + op: BinaryOperator::Add, + left: tid, + right: cons, + } = nodes[id.idx()] + && let Node::ThreadID { control, dimension } = nodes[tid.idx()] + && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) + { + Some((control, dimension)) + } else if let Node::Binary { + op: BinaryOperator::Add, + left: cons, + right: tid, + } = nodes[id.idx()] + && let Node::ThreadID { control, dimension } = nodes[tid.idx()] + && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) + { + Some((control, dimension)) } else { None } diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn index 356bb3d9..94c4334c 100644 --- a/juno_samples/rodinia/backprop/src/backprop.jn +++ b/juno_samples/rodinia/backprop/src/backprop.jn @@ -7,9 +7,9 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f @res let result : f32[m + 1]; result[0] = 1.0; - for j in 1..=m { + @outer_loop for j in 1..=m { let sum = 0.0; - for k in 0..=n { + @inner_loop for k in 0..=n { sum += weights[k, j] * vals[k]; } result[j] = squash(sum); diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch index d1fe8953..d59fd5f5 100644 --- a/juno_samples/rodinia/backprop/src/cpu.sch +++ b/juno_samples/rodinia/backprop/src/cpu.sch @@ -15,20 +15,16 @@ delete-uncalled(*); no-memset(layer_forward@res); lift-dc-math(*); loop-bound-canon(*); -dce(*); +simpl!(*); lift-dc-math(*); +slf(*); fixpoint { forkify(*); fork-guard-elim(*); fork-coalesce(*); } +simpl!(*); fork-split(*); -gvn(*); -phi-elim(*); -dce(*); unforkify(*); -gvn(*); -phi-elim(*); -dce(*); gcm(*); -- GitLab