Skip to content
Snippets Groups Projects
Commit 501ed914 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Fix backprop

parent aca44724
No related branches found
No related tags found
1 merge request: !186 (Rodinia)
Pipeline #201714 failed
......@@ -5,6 +5,7 @@ fn squash(x: f32) -> f32 {
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
let result : f32[m + 1];
result[0] = 1.0;
for j = 0 to m {
let sum = 0.0;
......@@ -57,13 +58,17 @@ fn hidden_error<hidden_n, output_n: usize>(
const ETA : f32 = 0.3;
const MOMENTUM : f32 = 0.3;
fn adjust_weights<n, m: usize>(delta: f32[m + 1], vals: f32[n + 1], weights: f32[n + 1, m + 1], prev_weights: f32[n + 1, m + 1])
-> (f32[n + 1, m + 1], f32[n + 1, m + 1]) {
fn adjust_weights<n, m: usize>(
delta: f32[m + 1],
vals: f32[n + 1],
weights: f32[n + 1, m + 1],
prev_weights: f32[n + 1, m + 1]
) -> (f32[n + 1, m + 1], f32[n + 1, m + 1]) {
for j = 0 to m {
for k = 0 to n {
let new_dw = ETA * delta[j+1] * vals[k+1] + MOMENTUM * prev_weights[k+1, j+1];
weights[k+1, j+1] += new_dw;
prev_weights[k+1, j+1] = new_dw;
for k = 0 to n+1 {
let new_dw = ETA * delta[j+1] * vals[k] + MOMENTUM * prev_weights[k, j+1];
weights[k, j+1] += new_dw;
prev_weights[k, j+1] = new_dw;
}
}
......@@ -78,17 +83,21 @@ fn backprop<input_n, hidden_n, output_n: usize>(
target: f32[output_n + 1],
input_prev_weights: f32[input_n + 1, hidden_n + 1],
hidden_prev_weights: f32[hidden_n + 1, output_n + 1],
//) -> (f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
//) -> (f32, f32,
// f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
// f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1]) {
) -> f32 {
// TODO: We may need to set some first elements to 1.0 in some places
) -> (f32, f32, f32) {
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
let (out_err, out_delta) = output_error::<output_n>(target, output_vals);
let (hid_err, hid_delta) = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
let (hidden_weights, hidden_prev_weights) = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
let (input_weights, input_prev_weights) = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
return (input_weights[0, 0] + input_prev_weights[0, 0] + hidden_weights[0, 0] + hidden_prev_weights[0, 0]);
let (hidden_weights, hidden_prev_weights)
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
let (input_weights, input_prev_weights)
= adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
return (out_err, hid_err, input_weights[0, 0] + input_prev_weights[0, 0] + hidden_weights[0, 0] + hidden_prev_weights[0, 0]);
//return (input_weights, input_prev_weights, hidden_weights, hidden_prev_weights);
}
......@@ -27,7 +27,7 @@ fn run_backprop(
target: &[f32],
input_prev_weights: &[f32],
hidden_prev_weights: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
) -> (f32, f32, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
let input_vals = HerculesImmBox::from(input_vals);
let target = HerculesImmBox::from(target);
......@@ -37,21 +37,26 @@ fn run_backprop(
let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights.to_vec());
let mut runner = runner!(backprop);
let _ = async_std::task::block_on(async {
runner.run(input_n,
hidden_n,
output_n,
input_vals.to(),
input_weights.to(),
hidden_weights.to(),
target.to(),
input_prev_weights.to(),
hidden_prev_weights.to(),
)
.await
});
(input_weights.as_slice().to_vec(), hidden_weights.as_slice().to_vec(),
let res = HerculesMutBox::from(
async_std::task::block_on(async {
runner.run(input_n,
hidden_n,
output_n,
input_vals.to(),
input_weights.to(),
hidden_weights.to(),
target.to(),
input_prev_weights.to(),
hidden_prev_weights.to(),
)
.await
})
).as_slice().to_vec();
let out_err = res[0];
let hid_err = res[1];
(out_err, hid_err,
input_weights.as_slice().to_vec(), hidden_weights.as_slice().to_vec(),
input_prev_weights.as_slice().to_vec(), hidden_prev_weights.as_slice().to_vec())
}
......@@ -78,7 +83,8 @@ fn backprop_harness(args: BackpropInputs) {
let input_prev_weights = vec![0.0; (input_n + 1) * (hidden_n + 1)];
let hidden_prev_weights = vec![0.0; (hidden_n + 1) * (output_n + 1)];
let (juno_input_weights, juno_hidden_weights,
let (juno_out_err, juno_hid_err,
juno_input_weights, juno_hidden_weights,
juno_input_prev_weights, juno_hidden_prev_weights)
= run_backprop(
input_n as u64,
......@@ -92,7 +98,8 @@ fn backprop_harness(args: BackpropInputs) {
&hidden_prev_weights,
);
let (rust_input_weights, rust_hidden_weights,
let (rust_out_err, rust_hid_err,
rust_input_weights, rust_hidden_weights,
rust_input_prev_weights, rust_hidden_prev_weights)
= rust_backprop::backprop(
input_n,
......@@ -106,10 +113,28 @@ fn backprop_harness(args: BackpropInputs) {
hidden_prev_weights,
);
assert_eq!(juno_input_weights, rust_input_weights);
assert_eq!(juno_hidden_weights, rust_hidden_weights);
assert_eq!(juno_input_prev_weights, rust_input_prev_weights);
assert_eq!(juno_hidden_prev_weights, rust_hidden_prev_weights);
assert_eq!(juno_out_err, rust_out_err);
assert_eq!(juno_hid_err, rust_hid_err);
if juno_input_weights != rust_input_weights {
let diff = juno_input_weights
.into_iter()
.zip(rust_input_weights.into_iter())
.enumerate()
.filter(|(i, (x, y))| x != y)
.collect::<Vec<_>>();
println!("Mismatch at {} locations", diff.len());
println!("{:?}", diff);
panic!("Input weights do not match after training");
}
if juno_hidden_weights != rust_hidden_weights {
panic!("Hidden weights do not match after training");
}
if juno_input_prev_weights != rust_input_prev_weights {
panic!("Input prev_weights do not match after training");
}
if juno_hidden_prev_weights != rust_hidden_prev_weights {
panic!("Hidden prev_weights do not match after training");
}
}
fn main() {
......
......@@ -13,18 +13,24 @@ fn layer_forward(n: usize, m: usize, vals: &[f32], weights: &[f32]) -> Vec<f32>
result
}
/// Computes the output-layer deltas together with their accumulated
/// absolute error.
///
/// For every `j` in `1..=n` the delta is `o * (1 - o) * (t - o)`, where
/// `o = actual[j]` and `t = target[j]`. Index 0 is never visited by the loop,
/// so the returned vector keeps `0.0` in that slot.
///
/// Returns `(error, deltas)`: `error` is the sum of `|deltas[j]|` accumulated
/// in ascending `j` order, and `deltas` has `n + 1` elements.
///
/// # Panics
///
/// Panics if `target` or `actual` holds fewer than `n + 1` elements.
fn output_error(n: usize, target: &[f32], actual: &[f32]) -> (f32, Vec<f32>) {
    let mut result = vec![0.0f32; n + 1];
    let mut error = 0.0f32;
    // Keep the plain ascending loop: the harness compares this error against
    // another implementation with exact equality, so the order of the
    // floating-point additions must not change.
    for j in 1..=n {
        let o = actual[j];
        let t = target[j];
        result[j] = o * (1.0 - o) * (t - o);
        error += result[j].abs();
    }
    (error, result)
}
fn hidden_error(n: usize, m: usize, delta: &[f32], weights: &[f32], actual: &[f32]) -> Vec<f32> {
fn hidden_error(n: usize, m: usize, delta: &[f32], weights: &[f32], actual: &[f32]) -> (f32, Vec<f32>) {
let mut result = vec![0.0; n + 1];
let mut error = 0.0;
for j in 1..=n {
let h = actual[j];
let mut sum = 0.0;
......@@ -32,8 +38,10 @@ fn hidden_error(n: usize, m: usize, delta: &[f32], weights: &[f32], actual: &[f3
sum += delta[k] * weights[j * (m + 1) + k];
}
result[j] = h * (1.0 - h) * sum;
error += result[j].abs();
}
result
(error, result)
}
fn adjust_weights(
......@@ -46,7 +54,7 @@ fn adjust_weights(
) -> (Vec<f32>, Vec<f32>) {
for j in 1..=m {
for k in 0..=n {
let new_dw = (0.3 * delta[j] * vals[k]) + (0.3 * weights[k * (m + 1) + j]);
let new_dw = (0.3 * delta[j] * vals[k]) + (0.3 * prev_weights[k * (m + 1) + j]);
weights[k * (m + 1) + j] += new_dw;
prev_weights[k * (m + 1) + j] = new_dw;
}
......@@ -65,17 +73,17 @@ pub fn backprop(
target: &[f32],
input_prev_weights: Vec<f32>,
hidden_prev_weights: Vec<f32>,
) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
) -> (f32, f32, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
let hidden_vals = layer_forward(input_n, hidden_n, input_vals, &input_weights);
let output_vals = layer_forward(hidden_n, output_n, &hidden_vals, &hidden_weights);
let out_delta = output_error(output_n, target, &output_vals);
let hid_delta = hidden_error(hidden_n, output_n, &out_delta, &hidden_weights, &hidden_vals);
let (out_err, out_delta) = output_error(output_n, target, &output_vals);
let (hid_err, hid_delta) = hidden_error(hidden_n, output_n, &out_delta, &hidden_weights, &hidden_vals);
let (hidden_weights, hidden_prev_weights)
= adjust_weights(hidden_n, output_n, &out_delta, &hidden_vals, hidden_weights, hidden_prev_weights);
let (input_weights, input_prev_weights)
= adjust_weights(input_n, hidden_n, &hid_delta, &input_vals, input_weights, input_prev_weights);
(input_weights, hidden_weights, input_prev_weights, hidden_prev_weights)
(out_err, hid_err, input_weights, hidden_weights, input_prev_weights, hidden_prev_weights)
}
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment