Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
Commits on Source (6)
Showing
with 274 additions and 151 deletions
...@@ -435,10 +435,9 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) { ...@@ -435,10 +435,9 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
} }
MathExpr::IntrinsicFunc(intrinsic, ref args) => { MathExpr::IntrinsicFunc(intrinsic, ref args) => {
print!("{}(", intrinsic.lower_case_name()); print!("{}(", intrinsic.lower_case_name());
debug_print_math_expr(id, env);
for arg in args { for arg in args {
print!(", ");
debug_print_math_expr(*arg, env); debug_print_math_expr(*arg, env);
print!(", ");
} }
print!(")"); print!(")");
} }
......
...@@ -822,7 +822,7 @@ fn typeflow( ...@@ -822,7 +822,7 @@ fn typeflow(
// We also return the return type from here // We also return the return type from here
match intrinsic { match intrinsic {
// Intrinsics that take any numeric type and return the same // Intrinsics that take any numeric type and return the same
Intrinsic::Abs => { Intrinsic::Abs | Intrinsic::Max | Intrinsic::Min => {
if let Concrete(id) = inputs[0] { if let Concrete(id) = inputs[0] {
if types[id.idx()].is_arithmetic() { if types[id.idx()].is_arithmetic() {
Concrete(*id) Concrete(*id)
...@@ -856,8 +856,6 @@ fn typeflow( ...@@ -856,8 +856,6 @@ fn typeflow(
| Intrinsic::Ln1P | Intrinsic::Ln1P
| Intrinsic::Log10 | Intrinsic::Log10
| Intrinsic::Log2 | Intrinsic::Log2
| Intrinsic::Max
| Intrinsic::Min
| Intrinsic::Round | Intrinsic::Round
| Intrinsic::Sin | Intrinsic::Sin
| Intrinsic::Sinh | Intrinsic::Sinh
......
...@@ -106,7 +106,7 @@ pub fn forkify_loop( ...@@ -106,7 +106,7 @@ pub fn forkify_loop(
else { else {
return false; return false;
}; };
// Compute loop variance // Compute loop variance
let loop_variance = compute_loop_variance(editor, l); let loop_variance = compute_loop_variance(editor, l);
let ivs = compute_induction_vars(editor.func(), l, &loop_variance); let ivs = compute_induction_vars(editor.func(), l, &loop_variance);
...@@ -530,10 +530,11 @@ pub fn analyze_phis<'a>( ...@@ -530,10 +530,11 @@ pub fn analyze_phis<'a>(
let intersection: HashSet<_> = set1.intersection(&set2).cloned().collect(); let intersection: HashSet<_> = set1.intersection(&set2).cloned().collect();
// If this phi uses any other phis the node is loop dependant, // If this phi uses any other phis the node is loop dependant,
// we use `phis` because this phi can actually contain the loop iv and its fine. // // we use `phis` because this phi can actually contain the loop iv and its fine.
if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) { // if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) {
LoopPHI::LoopDependant(*phi) // LoopPHI::LoopDependant(*phi)
} else if intersection.clone().iter().next().is_some() { // } else
if intersection.clone().iter().next().is_some() {
// PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
// to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
// by the time the reduce is triggered (at the end of the loop's internal control). // by the time the reduce is triggered (at the end of the loop's internal control).
......
...@@ -275,7 +275,7 @@ pub fn canonicalize_single_loop_bounds( ...@@ -275,7 +275,7 @@ pub fn canonicalize_single_loop_bounds(
let new_binop_node = edit.add_node(Node::Binary { left, right: blah, op: BinaryOperator::LT }); let new_binop_node = edit.add_node(Node::Binary { left, right: blah, op: BinaryOperator::LT });
edit = edit.replace_all_uses_where(binop_node, new_binop_node, |usee| *usee == if_node)?; edit = edit.replace_all_uses_where(binop_node, new_binop_node, |usee| *usee == if_node)?;
Some((init_id, bound_id, new_binop_node, if_node)) Some((init_id, blah, new_binop_node, if_node))
} else {guard_info}; } else {guard_info};
edit = edit.replace_all_uses_where(dc_bound_node, new_dc_bound_node, |usee| *usee == new_bop)?; edit = edit.replace_all_uses_where(dc_bound_node, new_dc_bound_node, |usee| *usee == new_bop)?;
...@@ -289,7 +289,6 @@ pub fn canonicalize_single_loop_bounds( ...@@ -289,7 +289,6 @@ pub fn canonicalize_single_loop_bounds(
}; };
Ok(edit) Ok(edit)
}); });
let update_expr_users: Vec<_> = editor let update_expr_users: Vec<_> = editor
.get_users(*update_expression) .get_users(*update_expression)
......
...@@ -3,7 +3,7 @@ gvn(*); ...@@ -3,7 +3,7 @@ gvn(*);
phi-elim(*); phi-elim(*);
dce(*); dce(*);
let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
cpu(auto.test1); cpu(auto.test1);
cpu(auto.test2); cpu(auto.test2);
cpu(auto.test3); cpu(auto.test3);
...@@ -12,6 +12,7 @@ cpu(auto.test5); ...@@ -12,6 +12,7 @@ cpu(auto.test5);
cpu(auto.test7); cpu(auto.test7);
cpu(auto.test8); cpu(auto.test8);
cpu(auto.test9); cpu(auto.test9);
cpu(auto.test10);
let test1_cpu = auto.test1; let test1_cpu = auto.test1;
rename["test1_cpu"](test1_cpu); rename["test1_cpu"](test1_cpu);
...@@ -94,6 +95,11 @@ dce(auto.test8); ...@@ -94,6 +95,11 @@ dce(auto.test8);
simplify-cfg(auto.test8); simplify-cfg(auto.test8);
dce(auto.test8); dce(auto.test8);
no-memset(test9@const); array-slf(auto.test10);
ccp(auto.test10);
dce(auto.test10);
simplify-cfg(auto.test10);
dce(auto.test10);
unforkify(auto.test10);
gcm(*); gcm(*);
...@@ -147,3 +147,16 @@ fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] { ...@@ -147,3 +147,16 @@ fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] {
return out; return out;
} }
#[entry]
fn test10(k1 : i32[8], k2 : i32[8], v : i32[8]) -> i32 {
@const let s : i32[8];
for i = 0 to 8 {
s[i] = v[k1[i] as u64];
}
let sum = 0;
for i = 0 to 8 {
sum += s[k2[i] as u64];
}
return sum;
}
\ No newline at end of file
...@@ -8,12 +8,13 @@ no-memset(test6@const); ...@@ -8,12 +8,13 @@ no-memset(test6@const);
no-memset(test8@const1); no-memset(test8@const1);
no-memset(test8@const2); no-memset(test8@const2);
no-memset(test9@const); no-memset(test9@const);
no-memset(test10@const);
gvn(*); gvn(*);
phi-elim(*); phi-elim(*);
dce(*); dce(*);
let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
gpu(auto.test1); gpu(auto.test1);
gpu(auto.test2); gpu(auto.test2);
gpu(auto.test3); gpu(auto.test3);
...@@ -22,6 +23,7 @@ gpu(auto.test5); ...@@ -22,6 +23,7 @@ gpu(auto.test5);
gpu(auto.test7); gpu(auto.test7);
gpu(auto.test8); gpu(auto.test8);
gpu(auto.test9); gpu(auto.test9);
gpu(auto.test10);
ip-sroa(*); ip-sroa(*);
sroa(*); sroa(*);
......
...@@ -74,6 +74,20 @@ fn main() { ...@@ -74,6 +74,20 @@ fn main() {
5 + 6 + 8 + 9, 5 + 6 + 8 + 9,
]; ];
assert(&correct, output); assert(&correct, output);
let mut r = runner!(test10);
let k1 = vec![0, 4, 3, 7, 3, 4, 2, 1];
let k2 = vec![6, 4, 3, 2, 4, 1, 0, 5];
let v = vec![3, -499, 4, 32, -2, 55, -74, 10];
let mut correct = 0;
for i in 0..8 {
correct += v[k1[k2[i] as usize] as usize];
}
let k1 = HerculesImmBox::from(&k1 as &[i32]);
let k2 = HerculesImmBox::from(&k2 as &[i32]);
let v = HerculesImmBox::from(&v as &[i32]);
let output = r.run(k1.to(), k2.to(), v.to()).await;
assert_eq!(output, correct);
}); });
} }
......
use juno_build::JunoCompiler; use juno_build::JunoCompiler;
fn main() { fn main() {
#[cfg(not(feature = "cuda"))] JunoCompiler::new()
{ .file_in_src("matmul.jn")
JunoCompiler::new() .unwrap()
.file_in_src("matmul.jn") .schedule_in_src("matmul.sch")
.unwrap() .unwrap()
.schedule_in_src("cpu.sch") .build()
.unwrap() .unwrap();
.build()
.unwrap();
}
#[cfg(feature = "cuda")]
{
JunoCompiler::new()
.file_in_src("matmul.jn")
.unwrap()
.schedule_in_src("gpu.sch")
.unwrap()
.build()
.unwrap();
}
} }
phi-elim(*);
forkify(*);
fork-guard-elim(*);
dce(*);
fixpoint {
reduce-slf(*);
slf(*);
infer-schedules(*);
}
fork-coalesce(*);
infer-schedules(*);
dce(*);
rewrite(*);
fixpoint {
simplify-cfg(*);
dce(*);
}
ip-sroa(*);
sroa(*);
dce(*);
float-collections(*);
gcm(*);
...@@ -41,21 +41,41 @@ macro unforkify!(X) { ...@@ -41,21 +41,41 @@ macro unforkify!(X) {
optimize!(*); optimize!(*);
forkify!(*); forkify!(*);
associative(matmul@outer);
// Parallelize by computing output array as 16 chunks if feature("cuda") {
let par = matmul@outer \ matmul@inner; fixpoint {
fork-tile![4](par); reduce-slf(*);
let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); slf(*);
parallelize!(outer \ inner); infer-schedules(*);
}
fork-coalesce(*);
infer-schedules(*);
dce(*);
rewrite(*);
fixpoint {
simplify-cfg(*);
dce(*);
}
let body = outline(inner); optimize!(*);
cpu(body); codegen-prep!(*);
} else {
associative(matmul@outer);
// Tile for cache, assuming 64B cache lines // Parallelize by computing output array as 16 chunks
fork-tile![16](body); let par = matmul@outer \ matmul@inner;
let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); fork-tile![4](par);
let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par);
parallelize!(outer \ inner);
reduce-slf(inner); let body = outline(inner);
unforkify!(body); cpu(body);
codegen-prep!(*);
// Tile for cache, assuming 64B cache lines
fork-tile![16](body);
let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body);
reduce-slf(inner);
unforkify!(body);
codegen-prep!(*);
}
...@@ -24,19 +24,6 @@ fn srad_bench(c: &mut Criterion) { ...@@ -24,19 +24,6 @@ fn srad_bench(c: &mut Criterion) {
} = read_graphics(image); } = read_graphics(image);
let image = resize(&image_ori, image_ori_rows, image_ori_cols, nrows, ncols); let image = resize(&image_ori, image_ori_rows, image_ori_cols, nrows, ncols);
let mut image_h = HerculesMutBox::from(image.clone()); let mut image_h = HerculesMutBox::from(image.clone());
let mut iN = (0..nrows).map(|i| i as i32 - 1).collect::<Vec<_>>();
let mut iS = (0..nrows).map(|i| i as i32 + 1).collect::<Vec<_>>();
let mut jW = (0..ncols).map(|j| j as i32 - 1).collect::<Vec<_>>();
let mut jE = (0..ncols).map(|j| j as i32 + 1).collect::<Vec<_>>();
// Fix boundary conditions
iN[0] = 0;
iS[nrows - 1] = (nrows - 1) as i32;
jW[0] = 0;
jE[ncols - 1] = (ncols - 1) as i32;
let iN_h = HerculesImmBox::from(iN.as_slice());
let iS_h = HerculesImmBox::from(iS.as_slice());
let jW_h = HerculesImmBox::from(jW.as_slice());
let jE_h = HerculesImmBox::from(jE.as_slice());
group.bench_function("srad bench", |b| { group.bench_function("srad bench", |b| {
b.iter(|| { b.iter(|| {
async_std::task::block_on(async { async_std::task::block_on(async {
...@@ -45,10 +32,6 @@ fn srad_bench(c: &mut Criterion) { ...@@ -45,10 +32,6 @@ fn srad_bench(c: &mut Criterion) {
ncols as u64, ncols as u64,
niter as u64, niter as u64,
image_h.to(), image_h.to(),
iN_h.to(),
iS_h.to(),
jW_h.to(),
jE_h.to(),
max, max,
lambda, lambda,
) )
......
...@@ -8,6 +8,7 @@ macro simpl!(X) { ...@@ -8,6 +8,7 @@ macro simpl!(X) {
infer-schedules(X); infer-schedules(X);
} }
no-memset(srad@scratch);
phi-elim(*); phi-elim(*);
let loop1 = outline(srad@loop1); let loop1 = outline(srad@loop1);
let loop2 = outline(srad@loop2); let loop2 = outline(srad@loop2);
...@@ -31,8 +32,18 @@ simpl!(*); ...@@ -31,8 +32,18 @@ simpl!(*);
fork-interchange[0, 1](loop1); fork-interchange[0, 1](loop1);
reduce-slf(*); reduce-slf(*);
simpl!(*); simpl!(*);
slf(*);
simpl!(*);
fork-tile[32, 0, false, false](loop2);
let split = fork-split(loop2);
let loop2_body = outline(split.srad_1.fj1);
simpl!(loop2, loop2_body);
inline(srad@loop2);
delete-uncalled(*);
fork-split(*); fork-split(extract, compress, loop1, loop2_body, loop3);
unforkify(*); unforkify(extract, compress, loop1, loop2_body, loop3);
gcm(*); gcm(*);
...@@ -8,6 +8,7 @@ macro simpl!(X) { ...@@ -8,6 +8,7 @@ macro simpl!(X) {
infer-schedules(X); infer-schedules(X);
} }
no-memset(srad@scratch);
phi-elim(*); phi-elim(*);
let sum_loop = outline(srad@loop1); let sum_loop = outline(srad@loop1);
let main_loops = outline(srad@loop2 | srad@loop3); let main_loops = outline(srad@loop2 | srad@loop3);
...@@ -41,15 +42,26 @@ fork-tile[32, 0, false, true](sum_loop); ...@@ -41,15 +42,26 @@ fork-tile[32, 0, false, true](sum_loop);
let out = fork-split(sum_loop); let out = fork-split(sum_loop);
clean-monoid-reduces(sum_loop); clean-monoid-reduces(sum_loop);
simpl!(sum_loop); simpl!(sum_loop);
let fission = fork-fission[out.srad_0.fj0](sum_loop);
let fission1 = fork-fission[out.srad_0.fj0](sum_loop);
simpl!(sum_loop);
fork-tile[32, 0, false, true](fission1.srad_0.fj_bottom);
let out = fork-split(fission1.srad_0.fj_bottom);
clean-monoid-reduces(sum_loop);
simpl!(sum_loop);
let fission2 = fork-fission[out.srad_0.fj0](sum_loop);
simpl!(sum_loop); simpl!(sum_loop);
fork-tile[32, 0, false, true](fission.srad_0.fj_bottom); fork-tile[32, 0, false, true](fission2.srad_0.fj_bottom);
let out = fork-split(fission.srad_0.fj_bottom); let out = fork-split(fission2.srad_0.fj_bottom);
clean-monoid-reduces(sum_loop); clean-monoid-reduces(sum_loop);
simpl!(sum_loop); simpl!(sum_loop);
let top = outline(fission.srad_0.fj_top);
let bottom = outline(out.srad_0.fj0); let first = outline(fission1.srad_0.fj_top);
gpu(top, bottom); let second = outline(fission2.srad_0.fj_top);
let third = outline(out.srad_0.fj0);
gpu(first, second, third);
const-inline[false](*);
ip-sroa(*); ip-sroa(*);
sroa(*); sroa(*);
simpl!(*); simpl!(*);
...@@ -60,4 +72,16 @@ dce(main_loops); ...@@ -60,4 +72,16 @@ dce(main_loops);
fork-split(main_loops); fork-split(main_loops);
simpl!(main_loops); simpl!(main_loops);
fork-dim-merge(extract);
fork-tile[32, 0, false, true](extract);
dce(extract);
fork-split(extract);
simpl!(extract);
fork-dim-merge(compress);
fork-tile[32, 0, false, true](compress);
dce(compress);
fork-split(compress);
simpl!(compress);
gcm(*); gcm(*);
...@@ -48,22 +48,6 @@ pub fn srad_harness(args: SRADInputs) { ...@@ -48,22 +48,6 @@ pub fn srad_harness(args: SRADInputs) {
let image = resize(&image_ori, image_ori_rows, image_ori_cols, nrows, ncols); let image = resize(&image_ori, image_ori_rows, image_ori_cols, nrows, ncols);
let mut image_h = HerculesMutBox::from(image.clone()); let mut image_h = HerculesMutBox::from(image.clone());
let mut iN = (0..nrows).map(|i| i as i32 - 1).collect::<Vec<_>>();
let mut iS = (0..nrows).map(|i| i as i32 + 1).collect::<Vec<_>>();
let mut jW = (0..ncols).map(|j| j as i32 - 1).collect::<Vec<_>>();
let mut jE = (0..ncols).map(|j| j as i32 + 1).collect::<Vec<_>>();
// Fix boundary conditions
iN[0] = 0;
iS[nrows - 1] = (nrows - 1) as i32;
jW[0] = 0;
jE[ncols - 1] = (ncols - 1) as i32;
let iN_h = HerculesImmBox::from(iN.as_slice());
let iS_h = HerculesImmBox::from(iS.as_slice());
let jW_h = HerculesImmBox::from(jW.as_slice());
let jE_h = HerculesImmBox::from(jE.as_slice());
let mut runner = runner!(srad); let mut runner = runner!(srad);
let result: Vec<f32> = HerculesMutBox::from( let result: Vec<f32> = HerculesMutBox::from(
runner runner
...@@ -72,10 +56,6 @@ pub fn srad_harness(args: SRADInputs) { ...@@ -72,10 +56,6 @@ pub fn srad_harness(args: SRADInputs) {
ncols as u64, ncols as u64,
niter as u64, niter as u64,
image_h.to(), image_h.to(),
iN_h.to(),
iS_h.to(),
jW_h.to(),
jE_h.to(),
max, max,
lambda, lambda,
) )
...@@ -90,18 +70,7 @@ pub fn srad_harness(args: SRADInputs) { ...@@ -90,18 +70,7 @@ pub fn srad_harness(args: SRADInputs) {
if verify { if verify {
let mut rust_result = image; let mut rust_result = image;
rust_srad::srad( rust_srad::srad(nrows, ncols, niter, &mut rust_result, max, lambda);
nrows,
ncols,
niter,
&mut rust_result,
&iN,
&iS,
&jW,
&jE,
max,
lambda,
);
if let Some(output) = output_verify { if let Some(output) = output_verify {
write_graphics(output, &rust_result, nrows, ncols, max); write_graphics(output, &rust_result, nrows, ncols, max);
......
pub fn srad( pub fn srad(nrows: usize, ncols: usize, niter: usize, image: &mut Vec<f32>, max: f32, lambda: f32) {
nrows: usize,
ncols: usize,
niter: usize,
image: &mut Vec<f32>,
iN: &[i32],
iS: &[i32],
jW: &[i32],
jE: &[i32],
max: f32,
lambda: f32,
) {
let nelems = nrows * ncols; let nelems = nrows * ncols;
// EXTRACT // EXTRACT
...@@ -44,11 +33,15 @@ pub fn srad( ...@@ -44,11 +33,15 @@ pub fn srad(
for i in 0..nrows { for i in 0..nrows {
let k = i + nrows * j; let k = i + nrows * j;
let Jc = image[k]; let Jc = image[k];
let iN = std::cmp::max(i, 1) - 1;
let iS = std::cmp::min(i, nrows - 2) + 1;
let jW = std::cmp::max(j, 1) - 1;
let jE = std::cmp::min(j, ncols - 2) + 1;
dN[k] = image[iN[i] as usize + nrows * j] - Jc; dN[k] = image[iN as usize + nrows * j] - Jc;
dS[k] = image[iS[i] as usize + nrows * j] - Jc; dS[k] = image[iS as usize + nrows * j] - Jc;
dW[k] = image[i + nrows * jW[j] as usize] - Jc; dW[k] = image[i + nrows * jW as usize] - Jc;
dE[k] = image[i + nrows * jE[j] as usize] - Jc; dE[k] = image[i + nrows * jE as usize] - Jc;
let G2 = let G2 =
(dN[k] * dN[k] + dS[k] * dS[k] + dW[k] * dW[k] + dE[k] * dE[k]) / (Jc * Jc); (dN[k] * dN[k] + dS[k] * dS[k] + dW[k] * dW[k] + dE[k] * dE[k]) / (Jc * Jc);
...@@ -72,11 +65,13 @@ pub fn srad( ...@@ -72,11 +65,13 @@ pub fn srad(
for j in 0..ncols { for j in 0..ncols {
for i in 0..nrows { for i in 0..nrows {
let k = i + nrows * j; let k = i + nrows * j;
let iS = std::cmp::min(i, nrows - 2) + 1;
let jE = std::cmp::min(j, ncols - 2) + 1;
let cN = c[k]; let cN = c[k];
let cS = c[iS[i] as usize + nrows * j]; let cS = c[iS as usize + nrows * j];
let cW = c[k]; let cW = c[k];
let cE = c[i + nrows * jE[j] as usize]; let cE = c[i + nrows * jE as usize];
let D = cN * dN[k] + cS * dS[k] + cW * dW[k] + cE * dE[k]; let D = cN * dN[k] + cS * dS[k] + cW * dW[k] + cE * dE[k];
......
...@@ -21,10 +21,6 @@ fn compress<nrows, ncols: usize>(inout image: f32[ncols, nrows], max: f32) { ...@@ -21,10 +21,6 @@ fn compress<nrows, ncols: usize>(inout image: f32[ncols, nrows], max: f32) {
fn srad<nrows, ncols: usize>( fn srad<nrows, ncols: usize>(
niter: usize, niter: usize,
inout image: f32[ncols, nrows], inout image: f32[ncols, nrows],
iN: i32[nrows],
iS: i32[nrows],
jW: i32[ncols],
jE: i32[ncols],
max: f32, max: f32,
lambda: f32, lambda: f32,
) { ) {
...@@ -50,20 +46,25 @@ fn srad<nrows, ncols: usize>( ...@@ -50,20 +46,25 @@ fn srad<nrows, ncols: usize>(
let varROI = (sum2 / nelems as f32) - meanROI * meanROI; let varROI = (sum2 / nelems as f32) - meanROI * meanROI;
let q0sqr = varROI / (meanROI * meanROI); let q0sqr = varROI / (meanROI * meanROI);
@dirs let dN : f32[ncols, nrows]; @scratch let dN : f32[ncols, nrows];
@dirs let dS : f32[ncols, nrows]; @scratch let dS : f32[ncols, nrows];
@dirs let dE : f32[ncols, nrows]; @scratch let dE : f32[ncols, nrows];
@dirs let dW : f32[ncols, nrows]; @scratch let dW : f32[ncols, nrows];
let c : f32[ncols, nrows]; @scratch let c : f32[ncols, nrows];
@loop2 for j in 0..ncols { @loop2 for j in 0..ncols {
for i in 0..nrows { for i in 0..nrows {
let Jc = image[j, i]; let Jc = image[j, i];
dN[j, i] = image[j, iN[i] as u64] - Jc; let iN = max!(i, 1) - 1;
dS[j, i] = image[j, iS[i] as u64] - Jc; let iS = min!(i, nrows - 2) + 1;
dW[j, i] = image[jW[j] as u64, i] - Jc; let jW = max!(j, 1) - 1;
dE[j, i] = image[jE[j] as u64, i] - Jc; let jE = min!(j, ncols - 2) + 1;
dN[j, i] = image[j, iN as u64] - Jc;
dS[j, i] = image[j, iS as u64] - Jc;
dW[j, i] = image[jW as u64, i] - Jc;
dE[j, i] = image[jE as u64, i] - Jc;
let G2 = (dN[j, i] * dN[j, i] + dS[j, i] * dS[j, i] let G2 = (dN[j, i] * dN[j, i] + dS[j, i] * dS[j, i]
+ dW[j, i] * dW[j, i] + dE[j, i] * dE[j, i]) / (Jc * Jc); + dW[j, i] * dW[j, i] + dE[j, i] * dE[j, i]) / (Jc * Jc);
...@@ -85,10 +86,13 @@ fn srad<nrows, ncols: usize>( ...@@ -85,10 +86,13 @@ fn srad<nrows, ncols: usize>(
@loop3 for j in 0..ncols { @loop3 for j in 0..ncols {
for i in 0..nrows { for i in 0..nrows {
let iS = min!(i, nrows - 2) + 1;
let jE = min!(j, ncols - 2) + 1;
let cN = c[j, i]; let cN = c[j, i];
let cS = c[j, iS[i] as u64]; let cS = c[j, iS as u64];
let cW = c[j, i]; let cW = c[j, i];
let cE = c[jE[j] as u64, i]; let cE = c[jE as u64, i];
let D = cN * dN[j, i] + cS * dS[j, i] + cW * dW[j, i] + cE * dE[j, i]; let D = cN * dN[j, i] + cS * dS[j, i] + cW * dW[j, i] + cE * dE[j, i];
image[j, i] = image[j, i] + 0.25 * lambda * D; image[j, i] = image[j, i] + 0.25 * lambda * D;
......
...@@ -22,6 +22,7 @@ pub enum ScheduleCompilerError { ...@@ -22,6 +22,7 @@ pub enum ScheduleCompilerError {
actual: usize, actual: usize,
loc: Location, loc: Location,
}, },
SemanticError(String, Location),
} }
impl fmt::Display for ScheduleCompilerError { impl fmt::Display for ScheduleCompilerError {
...@@ -46,6 +47,11 @@ impl fmt::Display for ScheduleCompilerError { ...@@ -46,6 +47,11 @@ impl fmt::Display for ScheduleCompilerError {
"({}, {}) -- ({}, {}): Expected {} arguments, found {}", "({}, {}) -- ({}, {}): Expected {} arguments, found {}",
loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual
), ),
ScheduleCompilerError::SemanticError(msg, loc) => write!(
f,
"({}, {}) -- ({}, {}): {}",
loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, msg,
),
} }
} }
} }
...@@ -76,6 +82,8 @@ enum Appliable { ...@@ -76,6 +82,8 @@ enum Appliable {
// DeleteUncalled requires special handling because it changes FunctionIDs, so it is not // DeleteUncalled requires special handling because it changes FunctionIDs, so it is not
// treated like a pass // treated like a pass
DeleteUncalled, DeleteUncalled,
// Test whether a feature is enabled
Feature,
Schedule(Schedule), Schedule(Schedule),
Device(Device), Device(Device),
} }
...@@ -85,6 +93,8 @@ impl Appliable { ...@@ -85,6 +93,8 @@ impl Appliable {
fn is_valid_num_args(&self, num: usize) -> bool { fn is_valid_num_args(&self, num: usize) -> bool {
match self { match self {
Appliable::Pass(pass) => pass.is_valid_num_args(num), Appliable::Pass(pass) => pass.is_valid_num_args(num),
// Testing whether a feature is enabled takes the feature instead of a selection, so it
// has 0 arguments
// Delete uncalled, Schedules, and devices do not take arguments // Delete uncalled, Schedules, and devices do not take arguments
_ => num == 0, _ => num == 0,
} }
...@@ -158,6 +168,8 @@ impl FromStr for Appliable { ...@@ -158,6 +168,8 @@ impl FromStr for Appliable {
"serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)),
"write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)),
"feature" => Ok(Appliable::Feature),
"print" => Ok(Appliable::Pass(ir::Pass::Print)), "print" => Ok(Appliable::Pass(ir::Pass::Print)),
"cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)),
...@@ -275,6 +287,35 @@ fn compile_stmt( ...@@ -275,6 +287,35 @@ fn compile_stmt(
limit, limit,
}]) }])
} }
parser::Stmt::IfThenElse {
span: _,
cond,
thn,
els,
} => {
let cond = compile_exp_as_expr(cond, lexer, macrostab, macros)?;
macros.open_scope();
let thn = ir::ScheduleStmt::Block {
body: compile_ops_as_block(*thn, lexer, macrostab, macros)?,
};
macros.close_scope();
macros.open_scope();
let els = match els {
Some(els) => ir::ScheduleStmt::Block {
body: compile_ops_as_block(*els, lexer, macrostab, macros)?,
},
None => ir::ScheduleStmt::Block { body: vec![] },
};
macros.close_scope();
Ok(vec![ir::ScheduleStmt::IfThenElse {
cond,
thn: Box::new(thn),
els: Box::new(els),
}])
}
parser::Stmt::MacroDecl { span: _, def } => { parser::Stmt::MacroDecl { span: _, def } => {
let parser::MacroDecl { let parser::MacroDecl {
name, name,
...@@ -380,6 +421,17 @@ fn compile_expr( ...@@ -380,6 +421,17 @@ fn compile_expr(
on: selection, on: selection,
})) }))
} }
Appliable::Feature => match selection {
ir::Selector::Selection(mut args) if args.len() == 1 => {
Ok(ExprResult::Expr(ir::ScheduleExp::Feature {
feature: Box::new(args.pop().unwrap()),
}))
}
_ => Err(ScheduleCompilerError::SemanticError(
"feature requires exactly one argument as its selection".to_string(),
lexer.line_col(span),
)),
},
Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule { Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule {
sched, sched,
on: selection, on: selection,
......
...@@ -121,6 +121,9 @@ pub enum ScheduleExp { ...@@ -121,6 +121,9 @@ pub enum ScheduleExp {
DeleteUncalled { DeleteUncalled {
on: Selector, on: Selector,
}, },
Feature {
feature: Box<ScheduleExp>,
},
Record { Record {
fields: Vec<(String, ScheduleExp)>, fields: Vec<(String, ScheduleExp)>,
}, },
...@@ -180,4 +183,9 @@ pub enum ScheduleStmt { ...@@ -180,4 +183,9 @@ pub enum ScheduleStmt {
device: Device, device: Device,
on: Selector, on: Selector,
}, },
IfThenElse {
cond: ScheduleExp,
thn: Box<ScheduleStmt>,
els: Box<ScheduleStmt>,
},
} }
...@@ -20,12 +20,15 @@ ...@@ -20,12 +20,15 @@
\. "." \. "."
apply "apply" apply "apply"
else "else"
fixpoint "fixpoint" fixpoint "fixpoint"
if "if"
let "let" let "let"
macro "macro_keyword" macro "macro_keyword"
on "on" on "on"
set "set" set "set"
target "target" target "target"
then "then"
true "true" true "true"
false "false" false "false"
......