// backprop.jn — multilayer-perceptron backpropagation kernels:
// sigmoid forward pass, output/hidden error terms, and momentum weight update.
// Logistic sigmoid activation: maps any real input into (0, 1).
// squash(x) = 1 / (1 + e^(-x)).
fn squash(x: f32) -> f32 {
  let denom = 1.0 + exp!(-x);
  return 1.0 / denom;
}

// Forward pass through one fully-connected layer with sigmoid activation.
// Index 0 of `vals` and `result` is the bias unit (held at 1.0); the real
// neurons occupy indices 1..=n / 1..=m, hence the `+ 1` in every dimension.
//
// Params:  vals    — activations of the previous layer (vals[0] is the bias).
//          weights — weights[k, j] connects input unit k to output unit j.
// Returns: activations of this layer, with result[0] set to 1.0 as the bias
//          input for the next layer.
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
  // @res presumably marks `result` as the output buffer; @outer_loop /
  // @inner_loop look like named scheduling stages — TODO confirm against the
  // language reference.
  @res let result : f32[m + 1];
  result[0] = 1.0;
  @outer_loop for j in 1..=m {
    // Seed the accumulator with the bias contribution (k = 0), then add the
    // weighted activations of the remaining units.
    let sum = weights[0, j] * vals[0];
    @inner_loop for k in 1..=n {
      sum += weights[k, j] * vals[k];
    }
    result[j] = squash(sum);
  }

  return result;
}

// Output-layer error terms for backpropagation.
// delta[j] = a * (1 - a) * (t - a): the sigmoid derivative a(1-a) times the
// residual (t - a) — the standard output delta for a sigmoid MLP.
//
// Params:  target — desired activations (index 0 is the unused bias slot).
//          actual — activations produced by the forward pass.
// Returns: (errsum, delta) where errsum is the sum of |delta[j]| over the
//          real units and delta[0] is zeroed so the bias slot never
//          contributes downstream.
fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
  // @loop1/@loop2 look like named scheduling stages grouping the statements
  // they tag — TODO confirm semantics against the language reference.
  @loop1 @res let delta : f32[n + 1];
  @loop1 delta[0] = 0.0;
  @loop1 for j in 1..=n {
    let a = actual[j];
    let t = target[j];
    delta[j] = a * (1.0 - a) * (t - a);
  }

  // Aggregate the absolute error as a scalar training-progress metric.
  let errsum = 0.0;
  @loop2 for j in 1..=n {
    errsum += abs!(delta[j]);
  }

  return errsum, delta;
}

// Hidden-layer error terms, back-propagated from the output deltas.
// delta[j] = h * (1 - h) * sum_k(out_delta[k] * hidden_weights[j, k]):
// the sigmoid derivative times the output error weighted by the outgoing
// connections of hidden unit j.
//
// Params:  out_delta      — deltas computed by output_error (index 0 unused).
//          hidden_weights — hidden_weights[j, k] connects hidden j to output k.
//          hidden_vals    — hidden-layer activations from the forward pass.
// Returns: (errsum, delta) analogous to output_error: the summed absolute
//          hidden error and the per-unit deltas, with delta[0] zeroed.
fn hidden_error<hidden_n, output_n: usize>(
  out_delta: f32[output_n + 1],
  hidden_weights: f32[hidden_n + 1, output_n + 1],
  hidden_vals: f32[hidden_n + 1],
) -> f32, f32[hidden_n + 1] {
  // @loop1/@loop2 look like named scheduling stages — TODO confirm semantics.
  @loop1 @res let delta : f32[hidden_n + 1];
  @loop1 delta[0] = 0.0;
  @loop1 for j in 1..=hidden_n {
    let h = hidden_vals[j];

    // Weighted sum of the downstream (output-layer) deltas reachable from
    // hidden unit j.
    let sum = 0.0;
    for k in 1..=output_n {
      sum += out_delta[k] * hidden_weights[j, k];
    }

    delta[j] = h * (1.0 - h) * sum;
  }

  let errsum = 0.0;
  @loop2 for j in 1..=hidden_n {
    errsum += abs!(delta[j]);
  }

  return errsum, delta;
}

// Learning rate: scales each gradient step in adjust_weights.
const ETA : f32 = 0.3;
// Momentum coefficient: fraction of the previous weight update carried over
// into the current one.
const MOMENTUM : f32 = 0.3;

// Gradient-descent-with-momentum weight update for one layer.
// new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_dw, applied in place.
//
// NOTE(review): despite the name, `prev_weights` is read AND written as the
// previous weight *update* (the momentum term), not the previous weight
// values — see the store of `new_dw` into it below. The name is misleading
// but the math matches the classic backprop momentum rule.
//
// Params:  delta        — error terms for the downstream layer (index 0 unused).
//          vals         — upstream activations; vals[0] is the bias unit, so
//                         k starts at 0 to update bias weights too.
//          weights      — weight matrix, updated in place.
//          prev_weights — previous updates (momentum state), updated in place.
// Returns: the updated (weights, prev_weights) pair.
fn adjust_weights<n, m: usize>(
  delta: f32[m + 1],
  vals: f32[n + 1],
  weights: f32[n + 1, m + 1],
  prev_weights: f32[n + 1, m + 1]
) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
  // j skips 0: the bias slot of the downstream layer has no incoming delta.
  @outer_loop for j in 1..=m {
    @inner_loop for k in 0..=n {
      let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
      weights[k, j] += new_dw;
      prev_weights[k, j] = new_dw;
    }
  }

  return weights, prev_weights;
}

// One full training step for a 3-layer (input → hidden → output) sigmoid MLP:
// forward pass, error computation, then momentum weight updates.
//
// Index 0 of every vector/matrix dimension is the bias slot, hence `+ 1`.
//
// Params:  input_vals          — input activations (input_vals[0] = bias).
//          input_weights       — input→hidden weights, updated this step.
//          hidden_weights      — hidden→output weights, updated this step.
//          target              — desired output activations.
//          input_prev_weights  — momentum state for the input layer
//                                (previous weight updates — see adjust_weights).
//          hidden_prev_weights — momentum state for the hidden layer.
// Returns: (output error sum, hidden error sum, and all four updated
//          weight/momentum matrices).
#[entry]
fn backprop<input_n, hidden_n, output_n: usize>(
  input_vals: f32[input_n + 1],
  input_weights: f32[input_n + 1, hidden_n + 1],
  hidden_weights: f32[hidden_n + 1, output_n + 1],
  target: f32[output_n + 1],
  input_prev_weights: f32[input_n + 1, hidden_n + 1],
  hidden_prev_weights: f32[hidden_n + 1, output_n + 1],
) -> f32, f32,
     f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
     f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
  // Forward pass: input → hidden → output.
  @forward_input let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
  @forward_hidden let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);

  // Error terms: output deltas first, then hidden deltas back-propagated
  // through the (not yet updated) hidden weights.
  @output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
  @hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);

  // Weight updates: hidden layer first so hidden_error above used the
  // pre-update weights (standard backprop ordering).
  @adjust_hidden let hidden_weights, hidden_prev_weights
    = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
  @adjust_input let input_weights, input_prev_weights
    = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);

  return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
}