#![feature(concat_idents)]
mod graphics;
mod rust_srad;

pub use graphics::*;

use clap::Parser;

use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox, HerculesMutBoxTo};

juno_build::juno!("srad");

// Command-line arguments for the SRAD harness, parsed by clap.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API,
// `///` doc comments on the struct or its fields become part of the generated
// help text, so adding them would change the program's CLI output. Field
// order also matters: fields without `short`/`long` are positional arguments
// taken in declaration order.
#[derive(Parser)]
#[clap(author, version, about, long_about = None)]
pub struct SRADInputs {
    // Forwarded to the `srad` runner and to `rust_srad::srad` as `niter`
    // (presumably the number of iterations to run).
    pub niter: usize,
    // Forwarded unchanged to the `srad` runner and to `rust_srad::srad`.
    pub lambda: f32,
    // Row count the input image is resized to before processing.
    pub nrows: usize,
    // Column count the input image is resized to before processing.
    pub ncols: usize,
    // Handed to `read_graphics` to load the input image
    // (presumably a file path -- confirm against `graphics`).
    pub image: String,
    // When present, the kernel result is written there via `write_graphics`.
    #[clap(short, long)]
    pub output: Option<String>,
    // When set, `rust_srad::srad` is also run and its result is compared
    // against the kernel result.
    #[clap(short, long)]
    pub verify: bool,
    // When present, the `rust_srad` result is written there via
    // `write_graphics`; only consulted when `verify` is set.
    #[clap(long = "output-verify", value_name = "PATH")]
    pub output_verify: Option<String>,
}

pub fn srad_harness(args: SRADInputs) {
    async_std::task::block_on(async {
        let SRADInputs {
            niter,
            lambda,
            nrows,
            ncols,
            image,
            output,
            verify,
            output_verify,
        } = args;

        let Image {
            image: image_ori,
            max,
            rows: image_ori_rows,
            cols: image_ori_cols,
        } = read_graphics(image);
        let image = resize(&image_ori, image_ori_rows, image_ori_cols, nrows, ncols);
        let mut image_h = HerculesMutBox::from(image.clone());

        let mut iN = (0..nrows).map(|i| i as i32 - 1).collect::<Vec<_>>();
        let mut iS = (0..nrows).map(|i| i as i32 + 1).collect::<Vec<_>>();
        let mut jW = (0..ncols).map(|j| j as i32 - 1).collect::<Vec<_>>();
        let mut jE = (0..ncols).map(|j| j as i32 + 1).collect::<Vec<_>>();

        // Fix boundary conditions
        iN[0] = 0;
        iS[nrows - 1] = (nrows - 1) as i32;
        jW[0] = 0;
        jE[ncols - 1] = (ncols - 1) as i32;

        let iN_h = HerculesImmBox::from(iN.as_slice());
        let iS_h = HerculesImmBox::from(iS.as_slice());
        let jW_h = HerculesImmBox::from(jW.as_slice());
        let jE_h = HerculesImmBox::from(jE.as_slice());

        let mut runner = runner!(srad);
        let result: Vec<f32> = HerculesMutBox::from(
            runner
                .run(
                    nrows as u64,
                    ncols as u64,
                    niter as u64,
                    image_h.to(),
                    iN_h.to(),
                    iS_h.to(),
                    jW_h.to(),
                    jE_h.to(),
                    max,
                    lambda,
                )
                .await,
        )
        .as_slice()
        .to_vec();

        if let Some(output) = output {
            write_graphics(output, &result, nrows, ncols, max);
        }

        if verify {
            let mut rust_result = image;
            rust_srad::srad(
                nrows,
                ncols,
                niter,
                &mut rust_result,
                &iN,
                &iS,
                &jW,
                &jE,
                max,
                lambda,
            );

            if let Some(output) = output_verify {
                write_graphics(output, &rust_result, nrows, ncols, max);
            }

            let max_diff = result
                .iter()
                .zip(rust_result.iter())
                .map(|(a, b)| (*a as i32 - *b as i32).abs())
                .max()
                .unwrap_or(0);
            assert!(
                max_diff <= 2,
                "Verification failed: maximum pixel difference of {} exceeds threshold of 1",
                max_diff
            );
        }
    })
}