// HPVM puts the extract and compress kernels in host code, whereas the CUDA and
// OpenCL versions of Rodinia put them on the device.
// Maps every pixel x of `image` to exp(x / max), in place.
// The image is indexed as image[column, row].
fn extract<nrows, ncols: usize>(inout image: f32[ncols, nrows], max: f32) {
  for col in 0..ncols {
    for row in 0..nrows {
      let scaled = image[col, row] / max;
      image[col, row] = exp!(scaled);
    }
  }
}

// Maps every pixel x of `image` to ln(x) * max, in place; this undoes extract.
// The image is indexed as image[column, row].
fn compress<nrows, ncols: usize>(inout image: f32[ncols, nrows], max: f32) {
  for col in 0..ncols {
    for row in 0..nrows {
      let logged = ln!(image[col, row]);
      image[col, row] = logged * max;
    }
  }
}

// For unclear reasons, the original Rodinia source stores the image in
// column-major order, so here image[j, i] is column j, row i.
// Entry point: runs `niter` iterations of speckle-reducing anisotropic
// diffusion (SRAD) on `image`, in place.
//
// `image` is indexed as image[column, row] (f32[ncols, nrows]). The pipeline is:
// 1. Map the image through extract (x -> exp(x / max)).
// 2. Diffuse it `niter` times; `lambda` scales each update step.
// 3. Map it back through compress (x -> ln(x) * max).
//
// Border pixels use clamped (replicated) neighbours. The south/east neighbour
// index is computed as min!(k + 1, n - 1) instead of min!(k, n - 2) + 1. The
// two forms agree for n >= 2, but only the former stays in bounds when a
// dimension is 1 (n - 2 would underflow as usize).
#[entry]
fn srad<nrows, ncols: usize>(
  niter: usize,
  inout image: f32[ncols, nrows],
  max: f32,
  lambda: f32,
) {
  const nelems = nrows * ncols;

  extract::<nrows, ncols>(&image, max);

  for iter in 0..niter {
    // Statistics over the whole image (the region of interest is the full
    // image here). q0sqr = variance / mean^2, i.e. the squared coefficient
    // of variation.
    let sum = 0;
    let sum2 = 0;

    // These loops should really be interchanged, but they aren't in the
    // Rodinia source (though they are in the HPVM source)
    @loop1 for i in 0..nrows {
      for j in 0..ncols {
        let tmp = image[j, i];
        sum += tmp;
        sum2 += tmp * tmp;
      }
    }

    let meanROI = sum / nelems as f32;
    let varROI  = (sum2 / nelems as f32) - meanROI * meanROI;
    let q0sqr   = varROI / (meanROI * meanROI);

    // Per-pixel directional differences and diffusion coefficients for this
    // iteration.
    @scratch let dN : f32[ncols, nrows];
    @scratch let dS : f32[ncols, nrows];
    @scratch let dE : f32[ncols, nrows];
    @scratch let dW : f32[ncols, nrows];

    @scratch let c : f32[ncols, nrows];

    @loop2 for j in 0..ncols {
      for i in 0..nrows {
        let Jc = image[j, i];
        // Clamped neighbour indices: row 0 is its own north neighbour,
        // row nrows - 1 its own south neighbour, and likewise for columns.
        let iN = max!(i, 1) - 1;
        let iS = min!(i + 1, nrows - 1);
        let jW = max!(j, 1) - 1;
        let jE = min!(j + 1, ncols - 1);

        dN[j, i] = image[j, iN as u64] - Jc;
        dS[j, i] = image[j, iS as u64] - Jc;
        dW[j, i] = image[jW as u64, i] - Jc;
        dE[j, i] = image[jE as u64, i] - Jc;

        // Normalized squared gradient magnitude and normalized Laplacian.
        let G2 = (dN[j, i] * dN[j, i] + dS[j, i] * dS[j, i]
                + dW[j, i] * dW[j, i] + dE[j, i] * dE[j, i]) / (Jc * Jc);

        let L = (dN[j, i] + dS[j, i] + dW[j, i] + dE[j, i]) / Jc;

        let num  = (0.5 * G2) - ((1.0 / 16.0) * (L * L));
        let den  = 1 + (0.25 * L);
        let qsqr = num / (den * den);

        let den = (qsqr - q0sqr) / (q0sqr * (1 + q0sqr));
        let val = 1.0 / (1.0 + den);

        // Saturate the diffusion coefficient to [0, 1].
        if val < 0      { c[j, i] = 0; }
        else if val > 1 { c[j, i] = 1; }
        else            { c[j, i] = val; }
      }
    }

    @loop3 for j in 0..ncols {
      for i in 0..nrows {
        let iS = min!(i + 1, nrows - 1);
        let jE = min!(j + 1, ncols - 1);

        // North and west use the pixel's own coefficient; only south and east
        // read a neighbour's. NOTE(review): this asymmetry appears deliberate
        // (it matches the reference SRAD discretization) - confirm before
        // changing.
        let cN = c[j, i];
        let cS = c[j, iS as u64];
        let cW = c[j, i];
        let cE = c[jE as u64, i];

        let D = cN * dN[j, i] + cS * dS[j, i] + cW * dW[j, i] + cE * dE[j, i];
        image[j, i] = image[j, i] + 0.25 * lambda * D;
      }
    }
  }

  compress::<nrows, ncols>(&image, max);
}