Skip to content
Snippets Groups Projects
Commit d00ffc3a authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Progress on backprop

parent 47fc8f58
No related branches found
No related tags found
1 merge request!186Rodinia
......@@ -1128,6 +1128,7 @@ dependencies = [
"hercules_rt",
"juno_build",
"nom 6.2.2",
"rand 0.8.5",
"with_builtin_macros",
]
......
......@@ -649,7 +649,7 @@ impl<'a, T> From<HerculesCUDARefMut<'a>> for HerculesMutBox<'a, T> {
}
}
impl<'a, T> HerculesMutBox<'a, T>
impl<'a, 'b: 'a, T> HerculesMutBox<'b, T>
where
T: Default + Clone
{
......@@ -780,7 +780,7 @@ pub trait HerculesMutBoxTo<'a, T> {
fn to(&'a mut self) -> T;
}
impl<'a, T> HerculesMutBoxTo<'a, HerculesCPURefMut<'a>> for HerculesMutBox<'a, T>
impl<'a, 'b: 'a, T> HerculesMutBoxTo<'a, HerculesCPURefMut<'a>> for HerculesMutBox<'b, T>
where T: Default + Clone
{
fn to(&'a mut self) -> HerculesCPURefMut<'a> {
......@@ -789,7 +789,7 @@ where T: Default + Clone
}
#[cfg(feature = "cuda")]
impl<'a, T> HerculesMutBoxTo<'a, HerculesCUDARefMut<'a>> for HerculesMutBox<'a, T>
impl<'a, 'b: 'a, T> HerculesMutBoxTo<'a, HerculesCUDARefMut<'a>> for HerculesMutBox<'b, T>
where T: Default + Clone
{
fn to(&'a mut self) -> HerculesCUDARefMut<'a> {
......
......@@ -21,3 +21,4 @@ async-std = "*"
clap = { version = "*", features = ["derive"] }
with_builtin_macros = "0.1.0"
nom = "*"
rand = "*"
......@@ -78,8 +78,9 @@ fn backprop<input_n, hidden_n, output_n: usize>(
target: f32[output_n + 1],
input_prev_weights: f32[input_n + 1, hidden_n + 1],
hidden_prev_weights: f32[hidden_n + 1, output_n + 1],
) -> (f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1]) {
//) -> (f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
// f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1]) {
) -> f32 {
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
......@@ -88,5 +89,6 @@ fn backprop<input_n, hidden_n, output_n: usize>(
let (hidden_weights, hidden_prev_weights) = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
let (input_weights, input_prev_weights) = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
return (input_weights, input_prev_weights, hidden_weights, hidden_prev_weights);
return (input_weights[0, 0] + input_prev_weights[0, 0] + hidden_weights[0, 0] + hidden_prev_weights[0, 0]);
//return (input_weights, input_prev_weights, hidden_weights, hidden_prev_weights);
}
......@@ -24,7 +24,13 @@ inline(auto.backprop);
inline(auto.backprop);
delete-uncalled(*);
sroa(*);
dce(*);
float-collections(*);
gcm(*);
reuse-products(*);
dce(*);
xdot[true](*);
gcm(*);
......@@ -2,6 +2,65 @@
juno_build::juno!("backprop");
use hercules_rt::{runner, HerculesMutBox, HerculesImmBox, HerculesImmBoxTo, HerculesMutBoxTo};
use rand::Rng;
fn main() {
todo!()
let n_in = 2;
let n_hid = 4;
let n_out = 1;
let mut rng = rand::thread_rng();
let mut in_weights = (0..(n_in+1)*(n_hid+1)).map(|_| rng.gen::<f32>()).collect::<Vec<_>>();
let mut in_prev_weights = (0..(n_in+1)*(n_hid+1)).map(|_| rng.gen::<f32>()).collect::<Vec<_>>();
let mut hid_weights = (0..(n_hid+1)*(n_out+1)).map(|_| rng.gen::<f32>()).collect::<Vec<_>>();
let mut hid_prev_weights = (0..(n_hid+1)*(n_out+1)).map(|_| rng.gen::<f32>()).collect::<Vec<_>>();
println!("{:?}", in_weights);
println!("{:?}", in_prev_weights);
println!("{:?}", hid_weights);
println!("{:?}", hid_prev_weights);
println!("");
let input = vec![1.0, 0.0, 1.0];
let target = vec![1.0, 1.0];
{
let mut in_weights = HerculesMutBox::from(in_weights.as_mut_slice());
let mut in_prev_weights = HerculesMutBox::from(in_prev_weights.as_mut_slice());
let mut hid_weights = HerculesMutBox::from(hid_weights.as_mut_slice());
let mut hid_prev_weights = HerculesMutBox::from(hid_prev_weights.as_mut_slice());
let input = HerculesImmBox::from(input.as_slice());
let target = HerculesImmBox::from(target.as_slice());
let mut runner = runner!(backprop);
// Drop the result, we don't care about it
async_std::task::block_on(async {
let _ = runner.run(
n_in as u64,
n_hid as u64,
n_out as u64,
input.to(),
in_weights.to(),
hid_weights.to(),
target.to(),
in_prev_weights.to(),
hid_prev_weights.to(),
).await;
println!("{:?}", in_weights.as_slice());
println!("{:?}", in_prev_weights.as_slice());
println!("{:?}", hid_weights.as_slice());
println!("{:?}", hid_prev_weights.as_slice());
println!("");
});
}
println!("{:?}", in_weights);
println!("{:?}", in_prev_weights);
println!("{:?}", hid_weights);
println!("{:?}", hid_prev_weights);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment