Skip to content
Snippets Groups Projects
backprop.jn 2.82 KiB
Newer Older
  • Learn to ignore specific revisions
  • Aaron Councilman's avatar
    Aaron Councilman committed
    // Logistic sigmoid activation: 1 / (1 + e^(-x)).
    fn squash(x: f32) -> f32 {
      let denom = 1.0 + exp!(-x);
      return 1.0 / denom;
    }
    
    // Forward pass through one fully connected layer with sigmoid (squash) activation.
    //
    // vals    : source-layer activations, indices 0..=n. The inner loop starts at
    //           k = 0, so vals[0] is consumed like any other input.
    // weights : weights[k, j] connects source unit k to destination unit j.
    //
    // Returns the destination activations:
    //   result[0] = 1.0 (a constant that acts as the bias input when this result
    //               is fed into another layer_forward call)
    //   result[j] = squash(sum over k in 0..=n of weights[k, j] * vals[k]), j in 1..=m
    //
    // NOTE(review): the @res / @outer_loop / @inner_loop labels are presumably
    // referenced by an external schedule; keep their names and placement stable.
    fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
      @res let result : f32[m + 1];

      // Constant bias slot for whichever layer consumes this result.
      result[0] = 1.0;

      @outer_loop for j in 1..=m {
        let sum = 0.0;
        @inner_loop for k in 0..=n {
          sum += weights[k, j] * vals[k];
        }
        result[j] = squash(sum);
      }

      return result;
    }
    
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    // Error terms for the output layer.
    //
    // For each output unit j in 1..=n, with a = actual[j] and t = target[j]:
    //   delta[j] = a * (1 - a) * (t - a)
    // i.e. the raw error (t - a) scaled by a * (1 - a), the sigmoid derivative
    // written in terms of the sigmoid's output.
    //
    // Returns (sum of |delta[j]| over j in 1..=n, delta). Index 0 is skipped:
    // target[0] and actual[0] are not read, and delta[0] is never written here.
    fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
      let errsum = 0.0;
      let delta : f32[n + 1];

      for j in 1..=n {
        let a = actual[j];
        let t = target[j];
        delta[j] = a * (1.0 - a) * (t - a);
        errsum += abs!(delta[j]);
      }

      return errsum, delta;
    }
    
    // Error terms for the hidden layer, backpropagated from the output deltas.
    //
    // For each hidden unit j in 1..=hidden_n, with h = hidden_vals[j]:
    //   delta[j] = h * (1 - h) * (sum over k in 1..=output_n of out_delta[k] * hidden_weights[j, k])
    //
    // hidden_weights[j, k] connects hidden unit j to output unit k.
    // Returns (sum of |delta[j]| over j in 1..=hidden_n, delta). Index 0 is skipped
    // on both sides: out_delta[0] and hidden_vals[0] are not read, and delta[0] is
    // never written here.
    fn hidden_error<hidden_n, output_n: usize>(
      out_delta: f32[output_n + 1],
      hidden_weights: f32[hidden_n + 1, output_n + 1],
      hidden_vals: f32[hidden_n + 1],
    ) -> f32, f32[hidden_n + 1] {
      let errsum = 0.0;
      let delta : f32[hidden_n + 1];

      for j in 1..=hidden_n {
        let h = hidden_vals[j];

        // Weighted sum of the downstream (output) deltas fed by hidden unit j.
        let sum = 0.0;
        for k in 1..=output_n {
          sum += out_delta[k] * hidden_weights[j, k];
        }

        delta[j] = h * (1.0 - h) * sum;
        errsum += abs!(delta[j]);
      }

      return errsum, delta;
    }
    
    // Learning rate: scales the delta * activation term of each weight step in adjust_weights.
    const ETA: f32 = 0.3;
    // Momentum coefficient: fraction of the previous step added into each new step in adjust_weights.
    const MOMENTUM: f32 = 0.3;
    
    // Weight update with momentum for one layer.
    //
    // For every destination unit j in 1..=m and source unit k in 0..=n:
    //   step               = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j]
    //   weights[k, j]      += step
    //   prev_weights[k, j]  = step
    //
    // Despite its name, prev_weights holds the previous *step* (weight change), not
    // previous weight values; that is what feeds the momentum term on the next call.
    // Column j = 0 is never touched (j starts at 1), and delta[0] is not read.
    //
    // Returns (updated weights, updated prev_weights).
    fn adjust_weights<n, m: usize>(
      delta: f32[m + 1],
      vals: f32[n + 1],
      weights: f32[n + 1, m + 1],
      prev_weights: f32[n + 1, m + 1]
    ) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
      for j in 1..=m {
        for k in 0..=n {
          let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
          weights[k, j] += new_dw;
          prev_weights[k, j] = new_dw;
        }
      }

      return weights, prev_weights;
    }
    
    // Entry point: one training step of an input -> hidden -> output network.
    //
    // 1. Forward pass: input_vals -> hidden_vals -> output_vals.
    // 2. Error terms: output deltas against `target`, then hidden deltas. Both are
    //    computed from the incoming (not yet adjusted) hidden_weights, because the
    //    adjustment in step 3 happens afterwards.
    // 3. Weight updates with momentum: hidden->output weights first, then
    //    input->hidden weights. The two updates do not depend on each other.
    //
    // Returns, in order:
    //   output error sum, hidden error sum,
    //   updated input weights, updated input prev-weights,
    //   updated hidden weights, updated hidden prev-weights.
    #[entry]
    fn backprop<input_n, hidden_n, output_n: usize>(
      input_vals: f32[input_n + 1],
      input_weights: f32[input_n + 1, hidden_n + 1],
      hidden_weights: f32[hidden_n + 1, output_n + 1],
      target: f32[output_n + 1],
      input_prev_weights: f32[input_n + 1, hidden_n + 1],
      hidden_prev_weights: f32[hidden_n + 1, output_n + 1],
    ) -> f32, f32,
         f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
         f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
      // Forward pass.
      let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
      let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);

      // Error terms; hidden_weights here are still the forward-pass weights.
      let out_err, out_delta = output_error::<output_n>(target, output_vals);
      let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);

      // Weight updates (shadow the incoming arrays with their updated versions).
      let hidden_weights, hidden_prev_weights
        = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
      let input_weights, input_prev_weights
        = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);

      return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
    }