#![feature(concat_idents)]
mod rust_cfd;
mod setup;

use clap::Parser;

pub use crate::setup::*;

use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox, HerculesMutBoxTo};

// Expand `juno_build::juno!` for the "euler" and "pre_euler" names. The
// `runner!(euler)` and `runner!(pre_euler)` calls further down use the same
// identifiers. NOTE(review): presumably these pull in the code generated for
// the corresponding Juno modules at build time — confirm against juno_build.
juno_build::juno!("euler");
juno_build::juno!("pre_euler");

// Command-line arguments consumed by `cfd_harness`, parsed through clap's
// derive API.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, so doc comments here would change the CLI output.
#[derive(Parser)]
#[clap(author, version, about, long_about = None)]
pub struct CFDInputs {
    // Handed to `read_domain_geometry` as the input to load.
    pub data_file: String,
    // Iteration count forwarded unchanged to both the Juno and the Rust runs.
    pub iterations: usize,
    // Handed to `read_domain_geometry`; `cfd_harness` asserts that it is a
    // multiple of 16 before doing anything else.
    pub block_size: usize,
    // Boolean switch with the long name `pre-euler` and no short name. When
    // set, `cfd_harness` runs the `pre_euler` variants instead of `euler`.
    #[clap(short = None, long = Some("pre-euler"))]
    pub pre_euler: bool,
}

/// Runs the Juno `euler` entry point and returns its five result arrays
/// flattened into a single vector.
///
/// Every array input is wrapped in a Hercules box, the two scalar inputs are
/// cast to `u64`, and the asynchronous runner call is driven to completion
/// with `async_std::task::block_on`.
///
/// # Parameters
/// - `nelr`, `iterations`: forwarded to the runner as `u64`.
/// - `variables`: owned buffer handed to the runner through a
///   `HerculesMutBox`.
/// - `areas`, `elements_surrounding_elements`, `normals`, `ff_variable`:
///   read-only inputs handed to the runner through `HerculesImmBox`.
/// - `ff_fc_density_energy`, `ff_fc_momentum_x`, `ff_fc_momentum_y`,
///   `ff_fc_momentum_z`: each is copied into a three-element vector
///   (`x`, `y`, `z`) before being boxed.
///
/// # Returns
/// The runner's outputs concatenated in the order `density`, `x_m`, `y_m`,
/// `z_m`, `energy`.
fn run_euler(
    nelr: usize,
    iterations: usize,
    mut variables: AlignedSlice<f32>,
    areas: &[f32],
    elements_surrounding_elements: &[i32],
    normals: &[f32],
    ff_variable: &[f32],
    ff_fc_density_energy: &Float3,
    ff_fc_momentum_x: &Float3,
    ff_fc_momentum_y: &Float3,
    ff_fc_momentum_z: &Float3,
) -> Vec<f32> {
    let mut variables = HerculesMutBox::from(variables.as_mut_slice());
    let areas = HerculesImmBox::from(areas);
    let elements_surrounding_elements = HerculesImmBox::from(elements_surrounding_elements);
    let normals = HerculesImmBox::from(normals);
    let ff_variable = HerculesImmBox::from(ff_variable);

    // TODO: Make hercules box handle structs, for now we'll copy into a vec
    let ff_fc_density_energy = vec![
        ff_fc_density_energy.x,
        ff_fc_density_energy.y,
        ff_fc_density_energy.z,
    ];
    let ff_fc_density_energy = HerculesImmBox::from(ff_fc_density_energy.as_slice());
    let ff_fc_momentum_x = vec![ff_fc_momentum_x.x, ff_fc_momentum_x.y, ff_fc_momentum_x.z];
    let ff_fc_momentum_x = HerculesImmBox::from(ff_fc_momentum_x.as_slice());
    let ff_fc_momentum_y = vec![ff_fc_momentum_y.x, ff_fc_momentum_y.y, ff_fc_momentum_y.z];
    let ff_fc_momentum_y = HerculesImmBox::from(ff_fc_momentum_y.as_slice());
    let ff_fc_momentum_z = vec![ff_fc_momentum_z.x, ff_fc_momentum_z.y, ff_fc_momentum_z.z];
    let ff_fc_momentum_z = HerculesImmBox::from(ff_fc_momentum_z.as_slice());

    let mut runner = runner!(euler);
    let (density, x_m, y_m, z_m, energy) = async_std::task::block_on(async {
        runner
            .run(
                nelr as u64,
                iterations as u64,
                variables.to(),
                areas.to(),
                elements_surrounding_elements.to(),
                normals.to(),
                ff_variable.to(),
                ff_fc_density_energy.to(),
                ff_fc_momentum_x.to(),
                ff_fc_momentum_y.to(),
                ff_fc_momentum_z.to(),
            )
            .await
    });

    // Keep every result box in a named local: the slice handed out by
    // `as_slice` borrows from its box, so the box has to stay alive until the
    // values have been copied out below.
    let mut density = HerculesMutBox::from(density);
    let mut x_m = HerculesMutBox::from(x_m);
    let mut y_m = HerculesMutBox::from(y_m);
    let mut z_m = HerculesMutBox::from(z_m);
    let mut energy = HerculesMutBox::from(energy);

    // The accumulator is extended in place, so it must be declared `mut`.
    // The `&f32` closure annotations keep the element type of the boxes
    // explicit.
    let mut total = Vec::new();
    total.extend(density.as_slice().iter().map(|v: &f32| *v));
    total.extend(x_m.as_slice().iter().map(|v: &f32| *v));
    total.extend(y_m.as_slice().iter().map(|v: &f32| *v));
    total.extend(z_m.as_slice().iter().map(|v: &f32| *v));
    total.extend(energy.as_slice().iter().map(|v: &f32| *v));
    total
}

/// Runs the Juno `pre_euler` entry point and returns its five result arrays
/// flattened into a single vector.
///
/// Every array input is wrapped in a Hercules box, the two scalar inputs are
/// cast to `u64`, and the asynchronous runner call is driven to completion
/// with `async_std::task::block_on`.
///
/// # Parameters
/// - `nelr`, `iterations`: forwarded to the runner as `u64`.
/// - `variables`: owned buffer handed to the runner through a
///   `HerculesMutBox`.
/// - `areas`, `elements_surrounding_elements`, `normals`, `ff_variable`:
///   read-only inputs handed to the runner through `HerculesImmBox`.
/// - `ff_fc_density_energy`, `ff_fc_momentum_x`, `ff_fc_momentum_y`,
///   `ff_fc_momentum_z`: each is copied into a three-element vector
///   (`x`, `y`, `z`) before being boxed.
///
/// # Returns
/// The runner's outputs concatenated in the order `density`, `x_m`, `y_m`,
/// `z_m`, `energy`.
fn run_pre_euler(
    nelr: usize,
    iterations: usize,
    mut variables: AlignedSlice<f32>,
    areas: &[f32],
    elements_surrounding_elements: &[i32],
    normals: &[f32],
    ff_variable: &[f32],
    ff_fc_density_energy: &Float3,
    ff_fc_momentum_x: &Float3,
    ff_fc_momentum_y: &Float3,
    ff_fc_momentum_z: &Float3,
) -> Vec<f32> {
    let mut variables = HerculesMutBox::from(variables.as_mut_slice());
    let areas = HerculesImmBox::from(areas);
    let elements_surrounding_elements = HerculesImmBox::from(elements_surrounding_elements);
    let normals = HerculesImmBox::from(normals);
    let ff_variable = HerculesImmBox::from(ff_variable);

    // TODO: Make hercules box handle structs, for now we'll copy into a vec
    let ff_fc_density_energy = vec![
        ff_fc_density_energy.x,
        ff_fc_density_energy.y,
        ff_fc_density_energy.z,
    ];
    let ff_fc_density_energy = HerculesImmBox::from(ff_fc_density_energy.as_slice());
    let ff_fc_momentum_x = vec![ff_fc_momentum_x.x, ff_fc_momentum_x.y, ff_fc_momentum_x.z];
    let ff_fc_momentum_x = HerculesImmBox::from(ff_fc_momentum_x.as_slice());
    let ff_fc_momentum_y = vec![ff_fc_momentum_y.x, ff_fc_momentum_y.y, ff_fc_momentum_y.z];
    let ff_fc_momentum_y = HerculesImmBox::from(ff_fc_momentum_y.as_slice());
    let ff_fc_momentum_z = vec![ff_fc_momentum_z.x, ff_fc_momentum_z.y, ff_fc_momentum_z.z];
    let ff_fc_momentum_z = HerculesImmBox::from(ff_fc_momentum_z.as_slice());

    let mut runner = runner!(pre_euler);

    // The runner-side view of `variables` is created before entering the
    // async block, as in the original code.
    let variables = variables.to();

    let (density, x_m, y_m, z_m, energy) = async_std::task::block_on(async {
        runner
            .run(
                nelr as u64,
                iterations as u64,
                variables,
                areas.to(),
                elements_surrounding_elements.to(),
                normals.to(),
                ff_variable.to(),
                ff_fc_density_energy.to(),
                ff_fc_momentum_x.to(),
                ff_fc_momentum_y.to(),
                ff_fc_momentum_z.to(),
            )
            .await
    });

    // Keep every result box in a named local: the slice handed out by
    // `as_slice` borrows from its box, so the box has to stay alive until the
    // values have been copied out below.
    let mut density = HerculesMutBox::from(density);
    let mut x_m = HerculesMutBox::from(x_m);
    let mut y_m = HerculesMutBox::from(y_m);
    let mut z_m = HerculesMutBox::from(z_m);
    let mut energy = HerculesMutBox::from(energy);

    // The accumulator is extended in place, so it must be declared `mut`.
    // The `&f32` closure annotations keep the element type of the boxes
    // explicit.
    let mut total = Vec::new();
    total.extend(density.as_slice().iter().map(|v: &f32| *v));
    total.extend(x_m.as_slice().iter().map(|v: &f32| *v));
    total.extend(y_m.as_slice().iter().map(|v: &f32| *v));
    total.extend(z_m.as_slice().iter().map(|v: &f32| *v));
    total.extend(energy.as_slice().iter().map(|v: &f32| *v));
    total
}

/// Reports whether `x` and `y` differ by strictly less than `1e-5` in
/// absolute terms.
///
/// If the difference is NaN (for example when either input is NaN, or both
/// are the same infinity) the result is `false`, because a NaN never compares
/// as less than the threshold.
fn compare_float(x: f32, y: f32) -> bool {
    let gap = f32::abs(x - y);
    gap < 1e-5
}

/// Element-wise approximate equality of two slices.
///
/// Returns `false` straight away when the lengths differ; otherwise every
/// pair of corresponding elements must satisfy `compare_float`. Two empty
/// slices compare equal.
fn compare_floats(xs: &[f32], ys: &[f32]) -> bool {
    if xs.len() != ys.len() {
        return false;
    }
    xs.iter().zip(ys).all(|(&lhs, &rhs)| compare_float(lhs, rhs))
}

/// Runs the selected CFD variant through both the Juno runner and the Rust
/// implementation in `rust_cfd`, then checks that the two results agree.
///
/// The far-field conditions come from `set_far_field_conditions`, the
/// geometry from `read_domain_geometry(data_file, block_size)`, and the
/// initial variables from `initialize_variables`. When `args.pre_euler` is set
/// the `pre_euler` variants are used, otherwise the `euler` variants. In both
/// cases the Juno run is executed first, on a clone of the initial variables.
///
/// # Panics
/// - if `block_size` is not a multiple of 16;
/// - if the two results have different lengths, or any pair of elements fails
///   `compare_floats`.
pub fn cfd_harness(args: CFDInputs) {
    let CFDInputs {
        data_file,
        iterations,
        block_size,
        pre_euler,
    } = args;

    assert!(block_size % 16 == 0, "Hercules expects all arrays to be 64-byte aligned, cfd uses structs of arrays that are annoying to deal with if the block_size is not a multiple of 16");

    let FarFieldConditions {
        ff_variable,
        ff_fc_momentum_x,
        ff_fc_momentum_y,
        ff_fc_momentum_z,
        ff_fc_density_energy,
    } = set_far_field_conditions();

    let GeometryData {
        nelr,
        areas,
        elements_surrounding_elements,
        normals,
    } = read_domain_geometry(data_file, block_size);

    let variables = initialize_variables(nelr, ff_variable.as_slice());

    // Borrow the slice views once; both implementations receive the same ones.
    let area_view = areas.as_slice();
    let neighbour_view = elements_surrounding_elements.as_slice();
    let normal_view = normals.as_slice();
    let far_field_view = ff_variable.as_slice();

    // Pick the variant once and produce both results together. The Juno run
    // gets a clone so the Rust run can consume the original buffer afterwards.
    let (res_juno, res_rust) = if pre_euler {
        let juno = run_pre_euler(
            nelr,
            iterations,
            variables.clone(),
            area_view,
            neighbour_view,
            normal_view,
            far_field_view,
            &ff_fc_density_energy,
            &ff_fc_momentum_x,
            &ff_fc_momentum_y,
            &ff_fc_momentum_z,
        );
        let reference = rust_cfd::pre_euler(
            nelr,
            iterations,
            variables,
            area_view,
            neighbour_view,
            normal_view,
            far_field_view,
            &ff_fc_density_energy,
            &ff_fc_momentum_x,
            &ff_fc_momentum_y,
            &ff_fc_momentum_z,
        );
        (juno, reference)
    } else {
        let juno = run_euler(
            nelr,
            iterations,
            variables.clone(),
            area_view,
            neighbour_view,
            normal_view,
            far_field_view,
            &ff_fc_density_energy,
            &ff_fc_momentum_x,
            &ff_fc_momentum_y,
            &ff_fc_momentum_z,
        );
        let reference = rust_cfd::euler(
            nelr,
            iterations,
            variables,
            area_view,
            neighbour_view,
            normal_view,
            far_field_view,
            &ff_fc_density_energy,
            &ff_fc_momentum_x,
            &ff_fc_momentum_y,
            &ff_fc_momentum_z,
        );
        (juno, reference)
    };

    if compare_floats(&res_juno, res_rust.as_slice()) {
        return;
    }

    // Report a length mismatch specifically before the generic failure.
    assert_eq!(res_juno.len(), res_rust.as_slice().len());
    panic!("Mismatch in results");
}