Skip to content
Snippets Groups Projects
Commit b156fad6 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Parallelize backprop

parent 416ed7b1
No related branches found
No related tags found
2 merge requests!215Large benches,!214More optimizations
......@@ -69,8 +69,8 @@ fn adjust_weights<n, m: usize>(
weights: f32[n + 1, m + 1],
prev_weights: f32[n + 1, m + 1]
) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
for j in 1..=m {
for k in 0..=n {
@outer_loop for j in 1..=m {
@inner_loop for k in 0..=n {
let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
weights[k, j] += new_dw;
prev_weights[k, j] = new_dw;
......
......@@ -28,6 +28,29 @@ simpl!(*);
fork-interchange[0, 1](adjust_weights);
simpl!(*);
fork-split(*);
unforkify(*);
infer-schedules(*);
fork-tile[32, 0, false, true](layer_forward@outer_loop \ layer_forward@inner_loop);
let (forward_outer, forward_inner) = fork-reshape[[1], [0]](layer_forward@outer_loop \ layer_forward@inner_loop);
fork-tile[32, 0, false, true](adjust_weights);
let (adjust_outer, adjust_inner) = fork-reshape[[1], [0, 2]](adjust_weights);
let forward_body = outline(forward_inner);
let adjust_body = outline(adjust_inner);
rename["output_error"](output_error);
rename["hidden_error"](hidden_error);
let output_error_body = auto-outline(output_error).output_error;
let hidden_error_body = auto-outline(hidden_error).hidden_error;
inline(backprop);
delete-uncalled(*);
const-inline(*);
simpl!(*);
fork-split(forward_body, adjust_body, output_error_body, hidden_error_body);
unforkify(forward_body, adjust_body, output_error_body, hidden_error_body);
gcm(*);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment