use std::cmp;

/// Number of planes in every image buffer handled by this file.
///
/// Buffers are flat slices addressed as `(chan * rows + row) * cols + col`, so a
/// full image holds `CHAN * rows * cols` elements. The demosaic step names
/// planes 0, 1 and 2 `r`, `g` and `b` respectively.
pub const CHAN: usize = 3;

/// Returns the larger of `x` and `y`.
///
/// Thin wrapper around [`f32::max`]; NaN handling is whatever that method does.
fn f32_max(x: f32, y: f32) -> f32 {
    f32::max(x, y)
}

fn scale(rows: usize, cols: usize, input: &[u8]) -> Vec<f32> {
    let mut output = vec![0.0; CHAN * rows * cols];

    for chan in 0..CHAN {
        for row in 0..rows {
            for col in 0..cols {
                let index = (chan * rows + row) * cols + col;
                output[index] = input[index] as f32 * 1.0 / 255.0;
            }
        }
    }

    output
}

fn descale(rows: usize, cols: usize, input: &[f32]) -> Vec<u8> {
    let mut output = vec![0; CHAN * rows * cols];

    for chan in 0..CHAN {
        for row in 0..rows {
            for col in 0..cols {
                let index = (chan * rows + row) * cols + col;
                output[index] = cmp::min(cmp::max((input[index] * 255.0) as u8, 0), 255);
            }
        }
    }

    output
}

/// Rebuilds all three planes of a plane-major image from neighbouring samples.
///
/// Only interior pixels (`1 <= row < rows - 1`, `1 <= col < cols - 1`) are
/// written; the one-pixel border of the result stays `0.0`. The parity of
/// `(row, col)` selects which neighbours are combined:
///
/// * even row, even col: plane 0 from left/right, plane 1 doubled in place,
///   plane 2 from above/below;
/// * even row, odd col: plane 0 copied, plane 1 from the four edge neighbours
///   (sum divided by 2), plane 2 from the four diagonal neighbours (averaged);
/// * odd row, even col: plane 0 from the four diagonal neighbours (averaged),
///   plane 1 from the four edge neighbours (sum divided by 2), plane 2 copied;
/// * odd row, odd col: plane 0 from above/below, plane 1 doubled in place,
///   plane 2 from left/right.
///
/// Images with fewer than three rows or columns, including empty ones, have
/// no interior and yield an all-zero buffer of `CHAN * rows * cols` elements.
///
/// # Panics
///
/// Panics if the image has an interior and `input` is too short to index it.
fn demosaic(rows: usize, cols: usize, input: &[f32]) -> Vec<f32> {
    let mut result = vec![0.0; CHAN * rows * cols];

    // `saturating_sub` keeps the ranges empty when a dimension is 0 instead of
    // underflowing `usize`.
    for row in 1..rows.saturating_sub(1) {
        for col in 1..cols.saturating_sub(1) {
            let index_0 = row * cols + col;
            let index_1 = (rows + row) * cols + col;
            let index_2 = (2 * rows + row) * cols + col;

            match (row % 2 == 0, col % 2 == 0) {
                (true, true) => {
                    let r1 = input[index_0 - 1];
                    let r2 = input[index_0 + 1];
                    let b1 = input[index_2 - cols];
                    let b2 = input[index_2 + cols];
                    result[index_0] = (r1 + r2) / 2.0;
                    result[index_1] = input[index_1] * 2.0;
                    result[index_2] = (b1 + b2) / 2.0;
                }
                (true, false) => {
                    let g1 = input[index_1 - cols];
                    let g2 = input[index_1 + cols];
                    let g3 = input[index_1 - 1];
                    let g4 = input[index_1 + 1];
                    let b1 = input[index_2 - cols - 1];
                    let b2 = input[index_2 - cols + 1];
                    let b3 = input[index_2 + cols - 1];
                    let b4 = input[index_2 + cols + 1];
                    result[index_0] = input[index_0];
                    result[index_1] = (g1 + g2 + g3 + g4) / 2.0;
                    result[index_2] = (b1 + b2 + b3 + b4) / 4.0;
                }
                (false, true) => {
                    let r1 = input[index_0 - cols - 1];
                    let r2 = input[index_0 + cols - 1];
                    let r3 = input[index_0 - cols + 1];
                    let r4 = input[index_0 + cols + 1];
                    let g1 = input[index_1 - cols];
                    let g2 = input[index_1 + cols];
                    let g3 = input[index_1 - 1];
                    let g4 = input[index_1 + 1];
                    result[index_0] = (r1 + r2 + r3 + r4) / 4.0;
                    result[index_1] = (g1 + g2 + g3 + g4) / 2.0;
                    result[index_2] = input[index_2];
                }
                (false, false) => {
                    let r1 = input[index_0 - cols];
                    let r2 = input[index_0 + cols];
                    let b1 = input[index_2 - 1];
                    let b2 = input[index_2 + 1];
                    result[index_0] = (r1 + r2) / 2.0;
                    result[index_1] = input[index_1] * 2.0;
                    result[index_2] = (b1 + b2) / 2.0;
                }
            }
        }
    }

    result
}

/// Applies a 3x3 median filter to each plane independently.
///
/// Interior pixels are replaced by the median (fifth of nine values ordered
/// with `f32::total_cmp`) of their 3x3 neighbourhood within the same plane.
/// Pixels on the outermost row or column of a plane are copied unchanged.
///
/// # Panics
///
/// Panics if `input` holds fewer than `CHAN * rows * cols` samples and the
/// image is non-empty.
fn denoise(rows: usize, cols: usize, input: &[f32]) -> Vec<f32> {
    let mut result = vec![0.0; CHAN * rows * cols];

    for chan in 0..CHAN {
        for row in 0..rows {
            for col in 0..cols {
                let center = (chan * rows + row) * cols + col;

                let on_border = row == 0 || row + 1 >= rows || col == 0 || col + 1 >= cols;
                if on_border {
                    result[center] = input[center];
                    continue;
                }

                // Gather the neighbourhood row by row; `slot` walks the window in
                // row-major order, so `slot / 3` and `slot % 3` are offsets 0..=2.
                let mut window = [0.0_f32; 9];
                for (slot, sample) in window.iter_mut().enumerate() {
                    let near_row = row + slot / 3 - 1;
                    let near_col = col + slot % 3 - 1;
                    *sample = input[(chan * rows + near_row) * cols + near_col];
                }

                // Values that compare equal under `total_cmp` are bit-identical, so an
                // unstable sort yields the same ordered array.
                window.sort_unstable_by(|a, b| a.total_cmp(b));
                result[center] = window[4];
            }
        }
    }

    result
}

/// Mixes the three input planes into each output plane with a 3x3 matrix.
///
/// For every pixel, output plane `chan` is
/// `input[0] * m[0][chan] + input[1] * m[1][chan] + input[2] * m[2][chan]`,
/// where `m[k][chan]` is read from `tstw_trans[k * CHAN + chan]`. Negative
/// results are raised to `0.0`.
///
/// # Panics
///
/// Panics if the image is non-empty and `input` or `tstw_trans` is too short.
fn transform(rows: usize, cols: usize, input: &[f32], tstw_trans: &[f32]) -> Vec<f32> {
    let mut result = vec![0.0; CHAN * rows * cols];
    let plane = rows * cols;

    for chan in 0..CHAN {
        // `pixel` is the row-major offset inside one plane, i.e. `row * cols + col`.
        for pixel in 0..plane {
            let mixed = input[pixel] * tstw_trans[chan]
                + input[plane + pixel] * tstw_trans[CHAN + chan]
                + input[2 * plane + pixel] * tstw_trans[2 * CHAN + chan];
            result[chan * plane + pixel] = f32_max(mixed, 0.0);
        }
    }

    result
}

/// Maps every pixel through a weighted sum of distances to control points
/// plus an affine term.
///
/// For each pixel `(p0, p1, p2)` taken from planes 0, 1 and 2:
///
/// 1. the Euclidean distance to every control point
///    `ctrl_pts[cp * 3 .. cp * 3 + 3]` is computed;
/// 2. for each output plane `chan`, the distances are weighted by
///    `weights[cp * CHAN + chan]` and summed, then
///    `coefs[chan] + coefs[CHAN + chan] * p0 + coefs[2 * CHAN + chan] * p1
///    + coefs[3 * CHAN + chan] * p2` is added;
/// 3. negative results are raised to `0.0`.
///
/// # Panics
///
/// Panics if the image is non-empty and any of the slices is too short for
/// the indices above.
fn gamut_map(
    rows: usize,
    cols: usize,
    num_ctrl_pts: usize,
    input: &[f32],
    ctrl_pts: &[f32],
    weights: &[f32],
    coefs: &[f32],
) -> Vec<f32> {
    let mut result = vec![0.0; CHAN * rows * cols];
    // Scratch buffer reused for every pixel.
    let mut distances = vec![0.0_f32; num_ctrl_pts];
    let plane = rows * cols;

    for pixel in 0..plane {
        let p0 = input[pixel];
        let p1 = input[plane + pixel];
        let p2 = input[2 * plane + pixel];

        for (cp, dist) in distances.iter_mut().enumerate() {
            let d0 = p0 - ctrl_pts[cp * 3];
            let d1 = p1 - ctrl_pts[cp * 3 + 1];
            let d2 = p2 - ctrl_pts[cp * 3 + 2];
            *dist = (d0 * d0 + d1 * d1 + d2 * d2).sqrt();
        }

        for chan in 0..CHAN {
            // Explicit left fold from 0.0 keeps the floating-point summation order fixed.
            let radial = distances
                .iter()
                .enumerate()
                .fold(0.0_f32, |acc, (cp, &dist)| {
                    acc + dist * weights[cp * CHAN + chan]
                });
            let affine = coefs[chan]
                + coefs[CHAN + chan] * p0
                + coefs[2 * CHAN + chan] * p1
                + coefs[3 * CHAN + chan] * p2;
            result[chan * plane + pixel] = f32_max(radial + affine, 0.0);
        }
    }

    result
}

/// Replaces every sample with an entry from a per-plane lookup table.
///
/// Each sample is multiplied by 255 and cast to `u8` (truncating, saturating
/// to `0..=255`, NaN to 0) to obtain a level; the output is
/// `tone_map[level * CHAN + chan]`.
///
/// # Panics
///
/// Panics if `input` holds fewer than `CHAN * rows * cols` samples, or if a
/// looked-up index lies outside `tone_map`.
fn tone_map(rows: usize, cols: usize, input: &[f32], tone_map: &[f32]) -> Vec<f32> {
    let mut result = vec![0.0; CHAN * rows * cols];
    let plane = rows * cols;

    for chan in 0..CHAN {
        let start = chan * plane;
        let end = start + plane;

        let outputs = result[start..end].iter_mut();
        let samples = input[start..end].iter();
        for (out, &sample) in outputs.zip(samples) {
            let level = usize::from((sample * 255.0) as u8);
            *out = tone_map[level * CHAN + chan];
        }
    }

    result
}

/// Runs the full image pipeline on a plane-major 8-bit image.
///
/// The stages are applied in this order, each consuming the previous stage's
/// output: `scale`, `demosaic`, `denoise`, `transform` (with `tstw`),
/// `gamut_map` (with `num_ctrl_pts`, `ctrl_pts`, `weights`, `coefs`),
/// `tone_map` (with `tonemap`) and finally `descale`, whose byte buffer is
/// returned.
pub fn cava(
    rows: usize,
    cols: usize,
    num_ctrl_pts: usize,
    input: &[u8],
    tstw: &[f32],
    ctrl_pts: &[f32],
    weights: &[f32],
    coefs: &[f32],
    tonemap: &[f32],
) -> Vec<u8> {
    // Each stage shadows the previous one; only the latest buffer is needed.
    let stage = scale(rows, cols, input);
    let stage = demosaic(rows, cols, &stage);
    let stage = denoise(rows, cols, &stage);
    let stage = transform(rows, cols, &stage, tstw);
    let stage = gamut_map(rows, cols, num_ctrl_pts, &stage, ctrl_pts, weights, coefs);
    let stage = tone_map(rows, cols, &stage, tonemap);

    descale(rows, cols, &stage)
}