Skip to content
Snippets Groups Projects

More optimizations

Merged rarbore2 requested to merge more_opt3 into main
3 files
+ 66
- 32
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -6,10 +6,9 @@ fn squash(x: f32) -> f32 {
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
@res let result : f32[m + 1];
result[0] = 1.0;
@outer_loop for j in 1..=m {
let sum = 0.0;
@inner_loop for k in 0..=n {
let sum = weights[0, j] * vals[0];
@inner_loop for k in 1..=n {
sum += weights[k, j] * vals[k];
}
result[j] = squash(sum);
@@ -19,13 +18,16 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
}
fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
let errsum = 0.0;
let delta : f32[n + 1];
for j in 1..=n {
@loop1 @res let delta : f32[n + 1];
@loop1 delta[0] = 0.0;
@loop1 for j in 1..=n {
let a = actual[j];
let t = target[j];
delta[j] = a * (1.0 - a) * (t - a);
}
let errsum = 0.0;
@loop2 for j in 1..=n {
errsum += abs!(delta[j]);
}
@@ -37,10 +39,9 @@ fn hidden_error<hidden_n, output_n: usize>(
hidden_weights: f32[hidden_n + 1, output_n + 1],
hidden_vals: f32[hidden_n + 1],
) -> f32, f32[hidden_n + 1] {
let errsum = 0.0;
let delta : f32[hidden_n + 1];
for j in 1..=hidden_n {
@loop1 @res let delta : f32[hidden_n + 1];
@loop1 delta[0] = 0.0;
@loop1 for j in 1..=hidden_n {
let h = hidden_vals[j];
let sum = 0.0;
@@ -49,6 +50,10 @@ fn hidden_error<hidden_n, output_n: usize>(
}
delta[j] = h * (1.0 - h) * sum;
}
let errsum = 0.0;
@loop2 for j in 1..=hidden_n {
errsum += abs!(delta[j]);
}
@@ -89,8 +94,8 @@ fn backprop<input_n, hidden_n, output_n: usize>(
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
let out_err, out_delta = output_error::<output_n>(target, output_vals);
let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
@output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
@hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
let hidden_weights, hidden_prev_weights
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
Loading