// matmul.sch: optimization schedule for `matmul`, with a CUDA path
// (feature "cuda") and a CPU fallback path.
// General cleanup pipeline applied to the selection X.
// Runs gvn / phi-elim / dce, then the SROA passes (ip-sroa, sroa) followed by
// dce, then repeats the gvn / phi-elim / dce round on the result.
// Statement order is significant: each pass sees the program as left by the
// previous one, so do not reorder without re-checking the generated code.
macro optimize!(X) {
  gvn(X);
  phi-elim(X);
  dce(X);
  // Scalar-replacement passes (`ip-` presumably = interprocedural; confirm).
  // Presumably they can expose new redundancy/dead code, which is why the
  // cleanup passes are repeated afterwards.
  ip-sroa(X);
  sroa(X);
  dce(X);
  // Second cleanup round over whatever the SROA passes produced.
  gvn(X);
  phi-elim(X);
  dce(X);
}

// Last schedule step before code generation (it is the final call in both the
// CUDA and CPU paths of this file). Runs the full optimize! pipeline, then
// gcm, float-collections, dce, and gcm once more.
macro codegen-prep!(X) {
  optimize!(X);
  gcm(X);
  float-collections(X);
  // Clean up after float-collections, then re-run gcm on the final program.
  // NOTE(review): presumably gcm must be the last pass so codegen sees
  // up-to-date placement; confirm before reordering.
  dce(X);
  gcm(X);
}

// Alternate forkify and fork-guard-elim on X inside a fixpoint, i.e. repeat
// the pair until it stops changing the program. Per the pass names, this
// presumably turns loops into forks and removes the guards around them.
macro forkify!(X) {
  fixpoint {
    forkify(X);
    fork-guard-elim(X);
  }
}

// Calls fork-tile[n, 0, false, false] on every fork in X. This differs from
// fork-tile! only in the last flag (false here, true there). In this file n
// is used as a chunk count (fork-chunk![4] is described as producing 4-way
// chunks).
// NOTE(review): assumes fork-tile's arguments are [size, dim index, guarded,
// tile order] and that tile order = false makes n the outer (number-of-chunks)
// dimension; confirm against the pass definition.
macro fork-chunk![n](X) {
  fork-tile[n, 0, false, false](X);
}

// Calls fork-tile[n, 0, false, true] on every fork in X. This differs from
// fork-chunk! only in the last flag (true here, false there). In this file n
// is used as a tile size (fork-tile![16] is used for cache tiling).
// NOTE(review): assumes tile order = true makes n the inner (tile-size)
// dimension; confirm against the pass definition.
macro fork-tile![n](X) {
  fork-tile[n, 0, false, true](X);
}

// Run parallel-fork and then parallel-reduce on X. Per the pass names, this
// presumably marks the selected forks and their reductions as parallel.
macro parallelize!(X) {
  parallel-fork(X);
  parallel-reduce(X);
}

// Run fork-split and then unforkify on X.
// NOTE(review): presumably fork-split breaks multi-dimensional forks apart so
// unforkify can lower each one back to a sequential loop; confirm.
macro unforkify!(X) {
  fork-split(X);
  unforkify(X);
}

// Shared front end for both targets: clean up everything (`*`), then forkify.
optimize!(*);
forkify!(*);

if feature("cuda") {
  // CUDA path: iterate reduce-slf / slf / infer-schedules to a fixpoint, run
  // fork-coalesce, then outline the `matmul@outer` region and mark it for GPU.
  fixpoint {
    reduce-slf(*);
    slf(*);
    infer-schedules(*);
  }
  fork-coalesce(*);
  infer-schedules(*);
  dce(*);
  // NOTE(review): disabled pass left in place; delete it or document why it
  // is off.
  //rewrite(*);
  let out = outline(matmul@outer);
  gpu(out);
  fixpoint {
    simplify-cfg(*);
    dce(*);
  }

  // NOTE(review): codegen-prep! already begins with optimize!, so the
  // optimize! pipeline runs twice in a row here; confirm that is intended.
  optimize!(*);
  codegen-prep!(*);
} else {
  // CPU path.
  associative(matmul@outer);
  // Parallelize by computing output array as 16 chunks
  // (fork-chunk![4] splits each fork in `par`, i.e. the `matmul@outer` region
  // excluding `matmul@inner`, into 4 chunks. Presumably dims 0 and 2 are the
  // two chunk dimensions, so grouping [0, 2] into the outer fork gives
  // 4 x 4 = 16 parallel chunks. Confirm if the loop nest changes.)
  let par = matmul@outer \ matmul@inner;
  fork-chunk![4](par);
  let (outer, inner, _) = fork-reshape[[0, 2], [1], [3]](par);
  parallelize!(outer \ inner);

  // Move the per-chunk work into its own function and mark it for CPU.
  let body = outline(inner);
  cpu(body);

  // Tile for cache, assuming 64B cache lines
  // (64B / 16 elements = 4 bytes per element. NOTE(review): revisit the 16
  // if the matrix element type is not 4 bytes wide.)
  fork-tile![16](body);
  // Regroup the tiled dimensions: dims [0, 2, 4, 1, 3] form the outer fork
  // and dim 5 becomes its own innermost fork. This `outer` binding is unused.
  let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body);

  reduce-slf(inner);
  unforkify!(body);
  codegen-prep!(*);
}