Newer
Older
// Activation function: the logistic sigmoid 1 / (1 + e^(-x)).
fn squash(x: f32) -> f32 {
  let denom = 1.0 + exp!(-x);
  return 1.0 / denom;
}
// Forward pass through one fully connected layer.
//
// For every unit j in 1..=m this accumulates weights[k, j] * vals[k] over
// k in 0..=n and stores squash(sum) in result[j].
//
// Parameters:
//   vals    - activations of the feeding layer, n + 1 entries.
//   weights - connection weights, indexed [source k, destination j].
// Returns the m + 1 activations of this layer.
//
// NOTE(review): the declaration of `result`, both loop headers and the
// initialisation of `sum` were missing from this copy of the file. They are
// restored here with the same index ranges adjust_weights uses
// (j in 1..=m, k in 0..=n) - confirm against the upstream source.
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
  let result : f32[m + 1];
  // NOTE(review): assumes slot 0 is the constant bias input (the k = 0 term
  // read by the next layer and by adjust_weights) - TODO confirm.
  result[0] = 1.0;

  for j in 1..=m {
    let sum = 0.0;
    for k in 0..=n {
      sum += weights[k, j] * vals[k];
    }
    result[j] = squash(sum);
  }

  return result;
}
// Computes the error terms of the output layer.
//
// For each j in 1..=n, with a = actual[j] and t = target[j]:
//   delta[j] = a * (1 - a) * (t - a)
//
// Parameters:
//   target - desired output values.
//   actual - values produced by the network.
// Returns two values: the sum of |delta[j]| over j in 1..=n, and the delta
// array itself. delta[0] is never written by this function.
fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
  let errsum = 0.0;
  let delta : f32[n + 1];

  for j in 1..=n {
    let a = actual[j];
    let t = target[j];
    delta[j] = a * (1.0 - a) * (t - a);
    errsum += abs!(delta[j]);
  }

  // The signature promises (f32, f32[n + 1]); without this statement the
  // function computed both values and then dropped them.
  return errsum, delta;
}
// Computes the error terms of the hidden layer from the output-layer deltas.
//
// For each hidden unit j in 1..=hidden_n, with h = hidden_vals[j]:
//   sum      = sum over k in 1..=output_n of out_delta[k] * hidden_weights[j, k]
//   delta[j] = h * (1 - h) * sum
//
// Parameters:
//   out_delta      - error terms of the output layer.
//   hidden_weights - hidden-to-output weights, indexed [hidden j, output k].
//   hidden_vals    - activations of the hidden layer.
// Returns two values: the sum of |delta[j]| over j in 1..=hidden_n, and the
// delta array itself. delta[0] is never written by this function.
//
// NOTE(review): the end of the signature and the return statement were
// missing from this copy of the file. They are restored to match the call
// site in backprop, which binds two results and passes the second one on as
// an f32[hidden_n + 1] delta array - confirm against the upstream source.
fn hidden_error<hidden_n, output_n: usize>(
  out_delta: f32[output_n + 1],
  hidden_weights: f32[hidden_n + 1, output_n + 1],
  hidden_vals: f32[hidden_n + 1],
) -> f32, f32[hidden_n + 1] {
  let errsum = 0.0;
  let delta : f32[hidden_n + 1];

  for j in 1..=hidden_n {
    let h = hidden_vals[j];

    let sum = 0.0;
    for k in 1..=output_n {
      sum += out_delta[k] * hidden_weights[j, k];
    }

    delta[j] = h * (1.0 - h) * sum;
    errsum += abs!(delta[j]);
  }

  return errsum, delta;
}
// Step size: scales the delta[j] * vals[k] term of every weight update
// computed in adjust_weights.
const ETA : f32 = 0.3;
// Momentum factor: scales the previous update (prev_weights[k, j]) that
// adjust_weights adds into the new update.
const MOMENTUM : f32 = 0.3;
// Applies one weight update to a layer.
//
// For every destination unit j in 1..=m and every source k in 0..=n:
//   new_dw              = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j]
//   weights[k, j]      += new_dw
//   prev_weights[k, j]  = new_dw
//
// Parameters:
//   delta        - error terms of the destination layer.
//   vals         - activations of the source layer.
//   weights      - connection weights, indexed [source k, destination j].
//   prev_weights - the update applied on the previous call; despite the name
//                  it holds weight changes, not weights, and is overwritten
//                  with this call's changes.
// Returns the updated weights followed by the updated prev_weights. Column
// j = 0 of both arrays is left as passed in.
//
// NOTE(review): the end of the signature and the return statement were
// missing from this copy of the file. They are restored to match backprop,
// which returns both arrays with the same shapes - confirm against the
// upstream source.
fn adjust_weights<n, m: usize>(
  delta: f32[m + 1],
  vals: f32[n + 1],
  weights: f32[n + 1, m + 1],
  prev_weights: f32[n + 1, m + 1]
) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
  for j in 1..=m {
    for k in 0..=n {
      let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
      weights[k, j] += new_dw;
      prev_weights[k, j] = new_dw;
    }
  }

  return weights, prev_weights;
}
#[entry]
fn backprop<input_n, hidden_n, output_n: usize>(
input_vals: f32[input_n + 1],
input_weights: f32[input_n + 1, hidden_n + 1],
hidden_weights: f32[hidden_n + 1, output_n + 1],
target: f32[output_n + 1],
input_prev_weights: f32[input_n + 1, hidden_n + 1],
hidden_prev_weights: f32[hidden_n + 1, output_n + 1],
) -> f32, f32,
f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
let out_err, out_delta = output_error::<output_n>(target, output_vals);
let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
= adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;