Skip to content
Snippets Groups Projects
Commit 54438766 authored by Russel Arbore's avatar Russel Arbore
Browse files

some backprop opt

parent b1970233
No related branches found
No related tags found
2 merge requests!215Large benches,!214More optimizations
Pipeline #202009 failed
This commit is part of merge request !214. Comments created here will be created in the context of that merge request.
......@@ -6,10 +6,9 @@ fn squash(x: f32) -> f32 {
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
@res let result : f32[m + 1];
result[0] = 1.0;
@outer_loop for j in 1..=m {
let sum = 0.0;
@inner_loop for k in 0..=n {
let sum = weights[0, j] * vals[0];
@inner_loop for k in 1..=n {
sum += weights[k, j] * vals[k];
}
result[j] = squash(sum);
......@@ -19,13 +18,16 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
}
fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
let errsum = 0.0;
let delta : f32[n + 1];
for j in 1..=n {
@loop1 @res let delta : f32[n + 1];
@loop1 delta[0] = 0.0;
@loop1 for j in 1..=n {
let a = actual[j];
let t = target[j];
delta[j] = a * (1.0 - a) * (t - a);
}
let errsum = 0.0;
@loop2 for j in 1..=n {
errsum += abs!(delta[j]);
}
......@@ -37,10 +39,9 @@ fn hidden_error<hidden_n, output_n: usize>(
hidden_weights: f32[hidden_n + 1, output_n + 1],
hidden_vals: f32[hidden_n + 1],
) -> f32, f32[hidden_n + 1] {
let errsum = 0.0;
let delta : f32[hidden_n + 1];
for j in 1..=hidden_n {
@loop1 @res let delta : f32[hidden_n + 1];
@loop1 delta[0] = 0.0;
@loop1 for j in 1..=hidden_n {
let h = hidden_vals[j];
let sum = 0.0;
......@@ -49,6 +50,10 @@ fn hidden_error<hidden_n, output_n: usize>(
}
delta[j] = h * (1.0 - h) * sum;
}
let errsum = 0.0;
@loop2 for j in 1..=hidden_n {
errsum += abs!(delta[j]);
}
......@@ -89,8 +94,8 @@ fn backprop<input_n, hidden_n, output_n: usize>(
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
let out_err, out_delta = output_error::<output_n>(target, output_vals);
let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
@output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
@hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
let hidden_weights, hidden_prev_weights
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
......
......@@ -12,7 +12,7 @@ simpl!(*);
inline(layer_forward);
delete-uncalled(*);
no-memset(layer_forward@res);
no-memset(layer_forward@res, output_error@res, hidden_error@res);
lift-dc-math(*);
loop-bound-canon(*);
simpl!(*);
......@@ -25,6 +25,8 @@ fixpoint {
}
reduce-slf(*);
simpl!(*);
fork-interchange[0, 1](adjust_weights);
simpl!(*);
fork-split(*);
unforkify(*);
......
gvn(*);
dce(*);
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
no-memset(layer_forward@res, output_error@res, hidden_error@res);
phi-elim(*);
dce(*);
crc(*);
dce(*);
slf(*);
dce(*);
let output_loop1 = outline(output_error@loop1);
let output_loop2 = outline(output_error@loop2);
let hidden_loop1 = outline(hidden_error@loop1);
let hidden_loop2 = outline(hidden_error@loop2);
simpl!(*);
inline(layer_forward, backprop@output_error, backprop@hidden_error);
delete-uncalled(*);
gpu(layer_forward, output_loop1, output_loop2, hidden_loop1, hidden_loop2, adjust_weights);
const-inline(*);
let auto = auto-outline(backprop);
gpu(auto.backprop);
lift-dc-math(*);
loop-bound-canon(*);
simpl!(*);
lift-dc-math(*);
slf(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
reduce-slf(*);
simpl!(*);
inline(auto.backprop);
inline(auto.backprop);
delete-uncalled(*);
fork-tile[16, 0, false, true](layer_forward@inner_loop);
let out = fork-split(layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
let fission = fork-fission[out._1_layer_forward.fj0](layer_forward);
simpl!(layer_forward);
sroa[true](*);
dce(*);
float-collections(*);
reuse-products(*);
dce(*);
fork-dim-merge(adjust_weights);
simpl!(adjust_weights);
fork-extend[32](adjust_weights);
fork-tile[32, 0, false, true](adjust_weights);
fork-split(adjust_weights);
simpl!(adjust_weights);
xdot[true](*);
gcm(*);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment