diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index e962b81dfaf28ece80f49a150139f5774c186771..b910a128116fb8fb39de29475b93ffa70a12dfcd 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -532,6 +532,24 @@ where let fork_thread_id_pairs = node_indices(indices).filter_map(|id| { if let Node::ThreadID { control, dimension } = nodes[id.idx()] { Some((control, dimension)) + } else if let Node::Binary { + op: BinaryOperator::Add, + left: tid, + right: cons, + } = nodes[id.idx()] + && let Node::ThreadID { control, dimension } = nodes[tid.idx()] + && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) + { + Some((control, dimension)) + } else if let Node::Binary { + op: BinaryOperator::Add, + left: cons, + right: tid, + } = nodes[id.idx()] + && let Node::ThreadID { control, dimension } = nodes[tid.idx()] + && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) + { + Some((control, dimension)) } else { None } diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn index 356bb3d91836ba0994cad56315b9a5588b0df8b7..94c4334c1cae17a396384ad6135432e3e80f70e3 100644 --- a/juno_samples/rodinia/backprop/src/backprop.jn +++ b/juno_samples/rodinia/backprop/src/backprop.jn @@ -7,9 +7,9 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f @res let result : f32[m + 1]; result[0] = 1.0; - for j in 1..=m { + @outer_loop for j in 1..=m { let sum = 0.0; - for k in 0..=n { + @inner_loop for k in 0..=n { sum += weights[k, j] * vals[k]; } result[j] = squash(sum); diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch index d1fe89536f5d0c73551a57162e7176be16629bb5..d59fd5f582dbb33ce5d786896b17673e0986776f 100644 --- a/juno_samples/rodinia/backprop/src/cpu.sch +++ b/juno_samples/rodinia/backprop/src/cpu.sch @@ -15,20 +15,16 @@ delete-uncalled(*); no-memset(layer_forward@res); lift-dc-math(*); loop-bound-canon(*); -dce(*); +simpl!(*); lift-dc-math(*); +slf(*); fixpoint { forkify(*); fork-guard-elim(*); fork-coalesce(*); } +simpl!(*); fork-split(*); -gvn(*); -phi-elim(*); -dce(*); unforkify(*); -gvn(*); -phi-elim(*); -dce(*); gcm(*);