Skip to content
Snippets Groups Projects
Commit be6c3a90 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

New backprop schedule

parent 17d33842
No related branches found
No related tags found
2 merge requests: !215 Large benches, !214 More optimizations
Pipeline #202028 passed
...@@ -91,15 +91,15 @@ fn backprop<input_n, hidden_n, output_n: usize>( ...@@ -91,15 +91,15 @@ fn backprop<input_n, hidden_n, output_n: usize>(
) -> f32, f32, ) -> f32, f32,
f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] { f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights); @forward_input let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights); @forward_hidden let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
@output_error let out_err, out_delta = output_error::<output_n>(target, output_vals); @output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
@hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals); @hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
let hidden_weights, hidden_prev_weights @adjust_hidden let hidden_weights, hidden_prev_weights
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights); = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
let input_weights, input_prev_weights @adjust_input let input_weights, input_prev_weights
= adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights); = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights; return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
......
...@@ -30,11 +30,24 @@ simpl!(*); ...@@ -30,11 +30,24 @@ simpl!(*);
infer-schedules(*); infer-schedules(*);
// The first call to layer_forward can be parallelized by 16 (the size of the
// hidden layer) and the second can't be parallelized at all (the size of the
// output layer is 1)
inline(backprop@forward_input, backprop@forward_hidden);
let forward_input = outline(backprop@forward_input);
let forward_hidden = outline(backprop@forward_hidden);
fork-tile[16, 0, false, true](forward_input@outer_loop \ forward_input@inner_loop);
let (outer, inner) = fork-reshape[[1], [0]](forward_input@outer_loop \ forward_input@inner_loop);
let forward_input = outline(inner);
inline(backprop@forward_input);
delete-uncalled(*); delete-uncalled(*);
const-inline(*); const-inline(*);
simpl!(*); simpl!(*);
fork-split(*); fork-split(*);
unforkify(*); unforkify(output_error, hidden_error, adjust_weights, forward_hidden, forward_input);
simpl!(*);
gcm(*); gcm(*);
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.