Skip to content
Snippets Groups Projects
Commit 54438766 authored by Russel Arbore
Browse files

some backprop opt

parent b1970233
No related branches found
No related tags found
2 merge requests: !215 "Large benches", !214 "More optimizations"
Pipeline #202009 failed
...@@ -6,10 +6,9 @@ fn squash(x: f32) -> f32 { ...@@ -6,10 +6,9 @@ fn squash(x: f32) -> f32 {
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] { fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
@res let result : f32[m + 1]; @res let result : f32[m + 1];
result[0] = 1.0; result[0] = 1.0;
@outer_loop for j in 1..=m { @outer_loop for j in 1..=m {
let sum = 0.0; let sum = weights[0, j] * vals[0];
@inner_loop for k in 0..=n { @inner_loop for k in 1..=n {
sum += weights[k, j] * vals[k]; sum += weights[k, j] * vals[k];
} }
result[j] = squash(sum); result[j] = squash(sum);
...@@ -19,13 +18,16 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f ...@@ -19,13 +18,16 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
} }
fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] { fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
let errsum = 0.0; @loop1 @res let delta : f32[n + 1];
let delta : f32[n + 1]; @loop1 delta[0] = 0.0;
@loop1 for j in 1..=n {
for j in 1..=n {
let a = actual[j]; let a = actual[j];
let t = target[j]; let t = target[j];
delta[j] = a * (1.0 - a) * (t - a); delta[j] = a * (1.0 - a) * (t - a);
}
let errsum = 0.0;
@loop2 for j in 1..=n {
errsum += abs!(delta[j]); errsum += abs!(delta[j]);
} }
...@@ -37,10 +39,9 @@ fn hidden_error<hidden_n, output_n: usize>( ...@@ -37,10 +39,9 @@ fn hidden_error<hidden_n, output_n: usize>(
hidden_weights: f32[hidden_n + 1, output_n + 1], hidden_weights: f32[hidden_n + 1, output_n + 1],
hidden_vals: f32[hidden_n + 1], hidden_vals: f32[hidden_n + 1],
) -> f32, f32[hidden_n + 1] { ) -> f32, f32[hidden_n + 1] {
let errsum = 0.0; @loop1 @res let delta : f32[hidden_n + 1];
let delta : f32[hidden_n + 1]; @loop1 delta[0] = 0.0;
@loop1 for j in 1..=hidden_n {
for j in 1..=hidden_n {
let h = hidden_vals[j]; let h = hidden_vals[j];
let sum = 0.0; let sum = 0.0;
...@@ -49,6 +50,10 @@ fn hidden_error<hidden_n, output_n: usize>( ...@@ -49,6 +50,10 @@ fn hidden_error<hidden_n, output_n: usize>(
} }
delta[j] = h * (1.0 - h) * sum; delta[j] = h * (1.0 - h) * sum;
}
let errsum = 0.0;
@loop2 for j in 1..=hidden_n {
errsum += abs!(delta[j]); errsum += abs!(delta[j]);
} }
...@@ -89,8 +94,8 @@ fn backprop<input_n, hidden_n, output_n: usize>( ...@@ -89,8 +94,8 @@ fn backprop<input_n, hidden_n, output_n: usize>(
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights); let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights); let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
let out_err, out_delta = output_error::<output_n>(target, output_vals); @output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals); @hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
let hidden_weights, hidden_prev_weights let hidden_weights, hidden_prev_weights
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights); = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
......
...@@ -12,7 +12,7 @@ simpl!(*); ...@@ -12,7 +12,7 @@ simpl!(*);
inline(layer_forward); inline(layer_forward);
delete-uncalled(*); delete-uncalled(*);
no-memset(layer_forward@res); no-memset(layer_forward@res, output_error@res, hidden_error@res);
lift-dc-math(*); lift-dc-math(*);
loop-bound-canon(*); loop-bound-canon(*);
simpl!(*); simpl!(*);
...@@ -25,6 +25,8 @@ fixpoint { ...@@ -25,6 +25,8 @@ fixpoint {
} }
reduce-slf(*); reduce-slf(*);
simpl!(*); simpl!(*);
fork-interchange[0, 1](adjust_weights);
simpl!(*);
fork-split(*); fork-split(*);
unforkify(*); unforkify(*);
......
gvn(*); macro simpl!(X) {
dce(*); ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
no-memset(layer_forward@res, output_error@res, hidden_error@res);
phi-elim(*); phi-elim(*);
dce(*); let output_loop1 = outline(output_error@loop1);
crc(*); let output_loop2 = outline(output_error@loop2);
dce(*); let hidden_loop1 = outline(hidden_error@loop1);
slf(*); let hidden_loop2 = outline(hidden_error@loop2);
dce(*); simpl!(*);
inline(layer_forward, backprop@output_error, backprop@hidden_error);
delete-uncalled(*);
gpu(layer_forward, output_loop1, output_loop2, hidden_loop1, hidden_loop2, adjust_weights);
const-inline(*);
let auto = auto-outline(backprop); lift-dc-math(*);
gpu(auto.backprop); loop-bound-canon(*);
simpl!(*);
lift-dc-math(*);
slf(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
reduce-slf(*);
simpl!(*);
inline(auto.backprop); fork-tile[16, 0, false, true](layer_forward@inner_loop);
inline(auto.backprop); let out = fork-split(layer_forward@inner_loop);
delete-uncalled(*); clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
let fission = fork-fission[out._1_layer_forward.fj0](layer_forward);
simpl!(layer_forward);
sroa[true](*); fork-dim-merge(adjust_weights);
dce(*); simpl!(adjust_weights);
float-collections(*); fork-extend[32](adjust_weights);
reuse-products(*); fork-tile[32, 0, false, true](adjust_weights);
dce(*); fork-split(adjust_weights);
simpl!(adjust_weights);
xdot[true](*);
gcm(*); gcm(*);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment