// Returns the element at index (rows * cols) / 2 of the ascending-sorted
// contents of `m`. For an even element count that is the upper of the two
// middle values. `m` is only read; sorting happens on a flattened copy.
fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a {
  @median {
    const n : usize = rows * cols;

    // Row-major copy of the matrix into a 1-D scratch buffer.
    @tmp let tmp : a[rows * cols];
    for flat = 0 to n {
      tmp[flat] = m[flat / cols, flat % cols];
    }

    // Bubble sort: every pass moves the largest value of the still-unsorted
    // prefix to the end of that prefix, so the prefix shrinks by one per pass.
    @medianOuter for pass = 0 to n - 1 {
      for k = 0 to n - pass - 1 {
        let lo : a = tmp[k];
        let hi : a = tmp[k+1];
        if lo > hi {
          // Adjacent pair is out of order: write the two values back swapped.
          tmp[k] = hi;
          tmp[k+1] = lo;
        }
      }
    }

    return tmp[n / 2];
  }
}

// Size of the leading (channel) dimension of every image array in this file,
// and of the second dimension of the coefficient / lookup tables.
// Several functions index channels 0, 1 and 2 with literals, so the code
// assumes this value is 3.
const CHAN : u64 = 3;

// Converts an 8-bit image to f32 by dividing every sample by 255.0, so the
// output samples lie in [0, 1]. The output has the same shape as the input.
fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] {
  @res1 let res : f32[CHAN, row, col];

  for ch = 0 to CHAN {
    for y = 0 to row {
      for x = 0 to col {
        // Widen to f32 first, then normalise.
        let raw : f32 = input[ch, y, x] as f32;
        res[ch, y, x] = raw / 255.0;
      }
    }
  }

  return res;
}

// Fills in all three channels of every interior pixel from the pixel itself
// and its immediate neighbours. Which neighbours are used depends on the
// parity of the row and column index. Channels 0 / 1 / 2 are referred to as
// red / green / blue below, following the naming used for them in this
// function's history (R*, G*, B*).
//
// Only pixels with 1 <= r < row-1 and 1 <= c < col-1 are written; the
// one-pixel border of `res` is never assigned here.
fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
  @res2 let res : f32[CHAN, row, col];

  @loop for r = 1 to row-1 {
    for c = 1 to col-1 {
      if r % 2 == 0 {
        if c % 2 == 0 {
          // Even row, even column: red from left/right, blue from up/down.
          let red_left   = input[0, r, c-1];
          let red_right  = input[0, r, c+1];
          let blue_up    = input[2, r-1, c];
          let blue_down  = input[2, r+1, c];
          res[0, r, c] = (red_left + red_right) / 2;
          // NOTE(review): channel 1 is doubled here, unchanged from the
          // original; confirm the factor is intended.
          res[1, r, c] = input[1, r, c] * 2;
          res[2, r, c] = (blue_up + blue_down) / 2;
        } else {
          // Even row, odd column: red is taken as-is, green from the four
          // edge neighbours, blue from the four diagonal neighbours.
          let green_up    = input[1, r-1, c];
          let green_down  = input[1, r+1, c];
          let green_left  = input[1, r, c-1];
          let green_right = input[1, r, c+1];
          let blue_up_left    = input[2, r-1, c-1];
          let blue_up_right   = input[2, r-1, c+1];
          let blue_down_left  = input[2, r+1, c-1];
          let blue_down_right = input[2, r+1, c+1];
          res[0, r, c] = input[0, r, c];
          // NOTE(review): four samples divided by 2 (not 4), unchanged from
          // the original; confirm this is intended.
          res[1, r, c] = (green_up + green_down + green_left + green_right) / 2;
          res[2, r, c] = (blue_up_left + blue_up_right + blue_down_left + blue_down_right) / 4;
        }
      } else {
        if c % 2 == 0 {
          // Odd row, even column: red from the four diagonal neighbours,
          // green from the four edge neighbours, blue is taken as-is.
          let red_up_left    = input[0, r-1, c-1];
          let red_down_left  = input[0, r+1, c-1];
          let red_up_right   = input[0, r-1, c+1];
          let red_down_right = input[0, r+1, c+1];
          let green_up    = input[1, r-1, c];
          let green_down  = input[1, r+1, c];
          let green_left  = input[1, r, c-1];
          let green_right = input[1, r, c+1];
          res[0, r, c] = (red_up_left + red_down_left + red_up_right + red_down_right) / 4;
          // NOTE(review): four samples divided by 2 (not 4), unchanged from
          // the original; confirm this is intended.
          res[1, r, c] = (green_up + green_down + green_left + green_right) / 2;
          res[2, r, c] = input[2, r, c];
        } else {
          // Odd row, odd column: red from up/down, blue from left/right.
          let red_up     = input[0, r-1, c];
          let red_down   = input[0, r+1, c];
          let blue_left  = input[2, r, c-1];
          let blue_right = input[2, r, c+1];
          res[0, r, c] = (red_up + red_down) / 2;
          // NOTE(review): channel 1 is doubled here, unchanged from the
          // original; confirm the factor is intended.
          res[1, r, c] = input[1, r, c] * 2;
          res[2, r, c] = (blue_left + blue_right) / 2;
        }
      }
    }
  }

  return res;
}

// Applies a 3x3 median filter to each channel independently. Every interior
// pixel is replaced by the median (as computed by medianMatrix) of the 3x3
// window centred on it; pixels on the one-pixel border are copied through
// unchanged.
fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
  @res let res : f32[CHAN, row, col];

  for chan = 0 to CHAN {
    for r = 0 to row {
      for c = 0 to col {
        if r < 1 || r >= row - 1 || c < 1 || c >= col - 1 {
          // Border pixel: no full 3x3 window exists, so pass it through.
          res[chan, r, c] = input[chan, r, c];
        } else {
          // Gather the 3x3 neighbourhood; window index 1 is the centre,
          // hence the "- 1" offset on both axes.
          @filter let filter : f32[3, 3];
          for dr = 0 to 3 {
            for dc = 0 to 3 {
              filter[dr, dc] = input[chan, r + dr - 1, c + dc - 1];
            }
          }

          res[chan, r, c] = medianMatrix::<f32, 3, 3>(filter);
        }
      }
    }
  }

  return res;
}

// Applies a CHAN x CHAN linear transform to every pixel. Output channel
// `chan` is the sum over input channels k = 0..2 of
// input[k, r, c] * tstw_trans[k, chan], with negative results replaced by 0.
fn transform<row : usize, col : usize>
  (input : f32[CHAN, row, col], tstw_trans : f32[CHAN, CHAN])
    -> f32[CHAN, row, col] {
  @res let result : f32[CHAN, row, col];

  for chan = 0 to CHAN {
    for r = 0 to row {
      for c = 0 to col {
        // Contribution of each input channel to this output channel.
        let from0 : f32 = input[0, r, c] * tstw_trans[0, chan];
        let from1 : f32 = input[1, r, c] * tstw_trans[1, chan];
        let from2 : f32 = input[2, r, c] * tstw_trans[2, chan];
        let mixed : f32 = from0 + from1 + from2;

        // Negative values are floored at zero.
        result[chan, r, c] = max!::<f32>(mixed, 0);
      }
    }
  }

  return result;
}

// Per-pixel mapping built from two parts, evaluated for each output channel:
//   1. a weighted sum, over all control points, of the Euclidean distance
//      between the pixel (as a 3-vector) and that control point, and
//   2. an affine term coefs[0, chan] + sum_k coefs[k+1, chan] * pixel[k].
// The sum of both parts is floored at 0 and stored in the result.
fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
  input : f32[CHAN, row, col],
  ctrl_pts : f32[num_ctrl_pts, CHAN],
  weights  : f32[num_ctrl_pts, CHAN],
  coefs : f32[4, CHAN]
) -> f32[CHAN, row, col] {
  @res let result : f32[CHAN, row, col];

  @image_loop for r = 0 to row {
    for c = 0 to col {
      // The three channel values of this pixel are reused many times below.
      let px0 : f32 = input[0, r, c];
      let px1 : f32 = input[1, r, c];
      let px2 : f32 = input[2, r, c];

      // Distance from this pixel to every control point.
      @l2 let l2_dist : f32[num_ctrl_pts];
      @cp_loop for cp = 0 to num_ctrl_pts {
        let d0 = px0 - ctrl_pts[cp, 0];
        let d1 = px1 - ctrl_pts[cp, 1];
        let d2 = px2 - ctrl_pts[cp, 2];
        let sq = d0 * d0 + d1 * d1 + d2 * d2;
        l2_dist[cp] = sqrt!::<f32>(sq);
      }

      @channel_loop for chan = 0 to CHAN {
        // Part 1: distance-weighted sum over the control points.
        let weighted : f32 = 0.0;
        @cp_loop for cp = 0 to num_ctrl_pts {
          weighted += l2_dist[cp] * weights[cp, chan];
        }

        // Part 2: affine term in the pixel's own channel values.
        let affine : f32 = coefs[0, chan] + coefs[1, chan] * px0
                                          + coefs[2, chan] * px1
                                          + coefs[3, chan] * px2;

        result[chan, r, c] = max!::<f32>(weighted + affine, 0);
      }
    }
  }

  return result;
}

// Replaces every sample by a table lookup: the sample is scaled by 255,
// clamped to [0, 255], truncated to an 8-bit index x, and the result is
// tone_map[x, chan].
//
// input    : image with samples nominally in [0, 1]
// tone_map : 256-row lookup table with one column per channel
// returns  : image of the same shape holding the looked-up values
fn tone_map<row : usize, col:usize>
  (input : f32[CHAN, row, col], tone_map : f32[256, CHAN]) -> f32[CHAN, row, col] {
  @res1 let result : f32[CHAN, row, col];

  for chan = 0 to CHAN {
    for r = 0 to row {
      for c = 0 to col {
        // Nothing guarantees the input stays within [0, 1], and the result of
        // converting an out-of-range float to u8 depends on the backend.
        // Clamp first (the same clamp descale applies) so the index is always
        // well defined; in-range samples are unaffected.
        let scaled : f32 = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255);
        let x = scaled as u8;
        result[chan, r, c] = tone_map[x as usize, chan];
      }
    }
  }

  return result;
}

// Converts an f32 image back to 8 bits: every sample is multiplied by 255,
// clamped to [0, 255] and then cast to u8.
fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] {
  @res2 let res : u8[CHAN, row, col];

  for ch = 0 to CHAN {
    for y = 0 to row {
      for x = 0 to col {
        let scaled  : f32 = input[ch, y, x] * 255;
        // Clamp from below, then from above, before narrowing.
        let floored : f32 = max!::<f32>(scaled, 0);
        let clamped : f32 = min!::<f32>(floored, 255);
        res[ch, y, x] = clamped as u8;
      }
    }
  }

  return res;
}

// Entry point of the pipeline. Runs, in order: scale, demosaic, denoise,
// transform (with TsTw), gamut (with ctrl_pts / weights / coefs), tone_map
// (with tonemap) and descale, feeding each stage's output into the next.
// The @fuse1..@fuse5 labels group the stages; presumably an external schedule
// refers to them, so they are kept exactly as they were (TODO confirm).
#[entry]
fn cava<r, c, num_ctrl_pts : usize>(
  input : u8[CHAN, r, c],
  TsTw : f32[CHAN, CHAN],
  ctrl_pts : f32[num_ctrl_pts, CHAN],
  weights : f32[num_ctrl_pts, CHAN],
  coefs : f32[4, CHAN],
  tonemap : f32[256, CHAN],
) -> u8[CHAN, r, c] {
  // u8 -> f32, then reconstruct all channels of the interior pixels.
  @fuse1 let as_float   = scale::<r, c>(input);
  @fuse1 let full_color = demosaic::<r, c>(as_float);

  // 3x3 median filter.
  @fuse2 let filtered   = denoise::<r, c>(full_color);

  // Linear channel transform followed by the control-point mapping.
  @fuse3 let mixed      = transform::<r, c>(filtered, TsTw);
  @fuse4 let mapped     = gamut::<r, c, num_ctrl_pts>(mixed, ctrl_pts, weights, coefs);

  // Table lookup, then back to u8.
  @fuse5 let toned      = tone_map::<r, c>(mapped, tonemap);
  @fuse5 let as_bytes   = descale::<r, c>(toned);
  return as_bytes;
}