Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
Commits on Source (4)
Showing
with 420 additions and 132 deletions
......@@ -354,6 +354,7 @@ impl GPUContext<'_> {
write!(
w,
"
#define _CG_ABI_EXPERIMENTAL
#include <assert.h>
#include <stdio.h>
#include <stddef.h>
......@@ -561,8 +562,9 @@ namespace cg = cooperative_groups;
* and writes.
*/
fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
    // Declare the `__shared__` `block_tile_memory<1024>` object that is handed
    // to `this_thread_block` below.
    write!(w, "\t__shared__ cg::experimental::block_tile_memory<1024> block_sync_shared;\n")?;
    // Handle for the whole grid.
    write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
    // Declare `block` exactly once, using the experimental constructor that
    // takes `block_sync_shared`. Emitting the plain `cg::this_thread_block()`
    // declaration as well would redeclare `block` in the generated kernel.
    write!(w, "\tcg::thread_block block = cg::experimental::this_thread_block(block_sync_shared);\n")?;
    Ok(())
}
......@@ -1294,7 +1296,7 @@ namespace cg = cooperative_groups;
}
if !is_primitive && state != KernelState::OutBlock {
write!(w, "{}}}\n", tabs)?;
write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
//write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
*num_tabs -= 1;
}
if !is_primitive && state == KernelState::OutBlock {
......@@ -1311,6 +1313,7 @@ namespace cg = cooperative_groups;
}
if !is_primitive
&& (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
&& !self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant)
{
let data_size = self.get_size(self.typing[id.idx()], None);
write!(
......@@ -1321,6 +1324,7 @@ namespace cg = cooperative_groups;
write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?;
write!(w, "{}}}\n", tabs)?;
write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
//write!(w, "__syncthreads\n")?;
}
}
// Dynamic constants emitted at top
......@@ -1595,7 +1599,7 @@ namespace cg = cooperative_groups;
write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?;
write!(w, "{}}}\n", tabs)?;
}
write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
//write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
let collect_variable = self.get_value(*collect, false, false);
write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
}
......@@ -1705,20 +1709,20 @@ namespace cg = cooperative_groups;
};
write!(
thread_block_tiles,
"\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
"\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
use_thread_per_id, cg_tile, use_thread_per_id
)?;
let cg_tile_use = self.get_cg_tile(id, CGType::Use);
write!(
thread_block_tiles,
"\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
"\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
use_thread_quota, cg_tile_use, use_thread_quota
)?;
let available_thread_quota = available_thread_quota.unwrap();
let cg_tile_available = self.get_cg_tile(id, CGType::Available);
write!(
thread_block_tiles,
"\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
"\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
available_thread_quota, cg_tile_available, available_thread_quota
)?;
if parallel_factor.is_none() {
......@@ -1781,6 +1785,7 @@ namespace cg = cooperative_groups;
let fork = self.join_fork_map.get(&id).unwrap();
let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
//write!(w_term, "\t__syncthreads;\n")?;
}
// If the Fork was parallelized, each thread or UsedPerId tile of
// threads only runs one ThreadID, so we can jump straight to the
......
......@@ -319,12 +319,12 @@ pub fn fork_fission<'a>(
.collect();
let mut created_forks = Vec::new();
// This does the reduction fission
// This does the reduction fission
for fork in forks {
let join = fork_join_map[&fork.0];
// FIXME: Don't make multiple forks for reduces that are in cycles with each other.
// FIXME: Don't make multiple forks for reduces that are in cycles with each other.
let reduce_partition = default_reduce_partition(editor, fork.0, join);
if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
......@@ -332,14 +332,19 @@ pub fn fork_fission<'a>(
}
if editor.is_mutable(fork.0) {
created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
created_forks = fork_reduce_fission_helper(
editor,
fork_join_map,
reduce_partition,
nodes_in_fork_joins,
fork.0,
);
if created_forks.is_empty() {
continue;
} else {
return created_forks;
}
}
}
created_forks
......@@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>(
let mut new_forks = Vec::new();
let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
let mut new_control_pred: NodeID = editor
.get_uses(fork)
.filter(|n| editor.node(n).is_control())
.next()
.unwrap();
let mut new_fork = NodeID::new(0);
let mut new_join = NodeID::new(0);
let subgraph = &nodes_in_fork_joins[&fork];
let subgraph = &nodes_in_fork_joins[&fork];
// Gets everything between fork & join that this reduce needs. (ALL CONTROL)
editor.edit(|mut edit| {
for reduce in reduce_partition {
......@@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>(
new_fork = mapping[&fork];
new_forks.push(new_fork);
new_join = mapping[&join];
// Attach new_fork after control_pred
let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
......@@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>(
// Replace uses of reduce
edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
new_control_pred = new_join;
};
}
// Replace original join w/ new final join
edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
......@@ -1502,6 +1511,10 @@ fn fork_fusion(
* element. This aids in parallelizing outer loops. Looks only at reduces with
* the monoid reduce schedule, since that indicates a particular structure which
* is annoying to check for again.
*
* Looks for would-be monoid reduces, if not for a gate on the reduction.
* Partially predicate the gated reduction to allow for a proper monoid
* reduction.
*/
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
for id in editor.node_ids() {
......@@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue;
};
let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
match nodes[reduct.idx()] {
Node::Binary {
......@@ -1520,7 +1533,8 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: _,
right: _,
} if (op == BinaryOperator::Add || op == BinaryOperator::Or)
&& !is_zero(editor, init) =>
&& !is_zero(editor, init)
&& !is_false(editor, init) =>
{
editor.edit(|mut edit| {
let zero = edit.add_zero_constant(typing[init.idx()]);
......@@ -1532,7 +1546,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: init,
right: id,
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1543,7 +1557,8 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: _,
right: _,
} if (op == BinaryOperator::Mul || op == BinaryOperator::And)
&& !is_one(editor, init) =>
&& !is_one(editor, init)
&& !is_true(editor, init) =>
{
editor.edit(|mut edit| {
let one = edit.add_one_constant(typing[init.idx()]);
......@@ -1555,7 +1570,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: init,
right: id,
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1574,7 +1589,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
intrinsic: Intrinsic::Max,
args: Box::new([init, id]),
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1593,7 +1608,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
intrinsic: Intrinsic::Min,
args: Box::new([init, id]),
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1602,6 +1617,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
_ => {}
}
}
for id in editor.node_ids() {
if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
continue;
}
let nodes = &editor.func().nodes;
let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue;
};
if let Node::Phi {
control: phi_control,
ref data,
} = nodes[reduct.idx()]
&& data.len() == 2
&& data.contains(&id)
&& let other = *data
.into_iter()
.filter(|other| **other != id)
.next()
.unwrap()
&& let Node::Binary {
op: BinaryOperator::Add,
left,
right,
} = nodes[other.idx()]
&& ((left == id) ^ (right == id))
{
let gated_input = if left == id { right } else { left };
let data = data.clone();
editor.edit(|mut edit| {
let zero = edit.add_zero_constant(typing[id.idx()]);
let zero = edit.add_node(Node::Constant { id: zero });
let phi = edit.add_node(Node::Phi {
control: phi_control,
data: data
.iter()
.map(|phi_use| if *phi_use == id { zero } else { gated_input })
.collect(),
});
let new_reduce_id = NodeID::new(edit.num_node_ids());
let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
let new_reduce = Node::Reduce {
control,
init,
reduct: new_reduct_id,
};
let new_add = Node::Binary {
op: BinaryOperator::Add,
left: new_reduce_id,
right: phi,
};
let new_reduce = edit.add_node(new_reduce);
edit.add_node(new_add);
edit = edit.replace_all_uses(id, new_reduce)?;
edit = edit.delete_node(id)?;
Ok(edit)
});
}
}
}
/*
......
......@@ -212,7 +212,8 @@ fn preliminary_fixups(
let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
// Replace uses of the reduce in its cycle with the init.
let success = editor.edit(|edit| {
let success = editor.edit(|mut edit| {
edit = edit.add_schedule(init, Schedule::ParallelReduce)?;
edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id))
});
assert!(success);
......@@ -870,7 +871,7 @@ fn spill_clones(
// Step 2: filter edges (A, B) to just see edges where A uses B and A
// mutates B. These are the edges that may require a spill.
let mut spill_edges = edges.into_iter().filter(|(a, b)| {
mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
(mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
|| (get_uses(&editor.func().nodes[a.idx()])
.as_ref()
.into_iter()
......@@ -890,7 +891,14 @@ fn spill_clones(
data.contains(b)
&& editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
})
.unwrap_or(false))
.unwrap_or(false)))
&& !editor.func().nodes[a.idx()]
.try_write()
.map(|(collect, _, _)| {
collect == *b
&& editor.func().schedules[b.idx()].contains(&Schedule::ParallelReduce)
})
.unwrap_or(false)
});
// Step 3: if there is a spill edge, spill it and return true. Otherwise,
......
......@@ -126,11 +126,24 @@ fn remove_useless_fork_joins(
// Third, get rid of fork-joins.
for (fork, join) in fork_join_map {
if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
if editor.get_users(*join).len() == 1 {
let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
let tids: Vec<_> = editor
.get_users(*fork)
.filter(|id| editor.func().nodes[id.idx()].is_thread_id())
.collect();
editor.edit(|mut edit| {
if !tids.is_empty() {
let u64_ty = edit.add_type(Type::UnsignedInteger64);
let zero = edit.add_zero_constant(u64_ty);
let zero = edit.add_node(Node::Constant { id: zero });
for tid in tids {
edit = edit.replace_all_uses(tid, zero)?;
edit = edit.delete_node(tid)?;
}
}
edit = edit.replace_all_uses(*join, join_use)?;
edit = edit.replace_all_uses(*fork, fork_use)?;
edit = edit.delete_node(*fork)?;
......
......@@ -598,6 +598,24 @@ pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
|| nodes[id.idx()].is_undef()
}
/// Reports whether node `id` may be treated as the boolean constant `false`.
///
/// Returns `true` when the node is a constant whose value answers
/// `is_false()`, or when the node is undef. Returns `false` in every other
/// case.
pub fn is_false(editor: &FunctionEditor, id: NodeID) -> bool {
    let nodes = &editor.func().nodes;
    let node = &nodes[id.idx()];
    // A constant node qualifies when its constant value is `false`.
    if let Some(constant) = node.try_constant() {
        if editor.get_constant(constant).is_false() {
            return true;
        }
    }
    // Otherwise only an undef node is accepted.
    node.is_undef()
}
/// Reports whether node `id` may be treated as the boolean constant `true`.
///
/// Returns `true` when the node is a constant whose value answers
/// `is_true()`, or when the node is undef. Returns `false` in every other
/// case.
pub fn is_true(editor: &FunctionEditor, id: NodeID) -> bool {
    let nodes = &editor.func().nodes;
    let node = &nodes[id.idx()];
    // A constant node qualifies when its constant value is `true`.
    if let Some(constant) = node.try_constant() {
        if editor.get_constant(constant).is_true() {
            return true;
        }
    }
    // Otherwise only an undef node is accepted.
    node.is_undef()
}
pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool {
let nodes = &editor.func().nodes;
nodes[id.idx()]
......
......@@ -6,10 +6,9 @@ fn squash(x: f32) -> f32 {
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
@res let result : f32[m + 1];
result[0] = 1.0;
@outer_loop for j in 1..=m {
let sum = 0.0;
@inner_loop for k in 0..=n {
let sum = weights[0, j] * vals[0];
@inner_loop for k in 1..=n {
sum += weights[k, j] * vals[k];
}
result[j] = squash(sum);
......@@ -19,13 +18,16 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
}
fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
let errsum = 0.0;
let delta : f32[n + 1];
for j in 1..=n {
@loop1 @res let delta : f32[n + 1];
@loop1 delta[0] = 0.0;
@loop1 for j in 1..=n {
let a = actual[j];
let t = target[j];
delta[j] = a * (1.0 - a) * (t - a);
}
let errsum = 0.0;
@loop2 for j in 1..=n {
errsum += abs!(delta[j]);
}
......@@ -37,10 +39,9 @@ fn hidden_error<hidden_n, output_n: usize>(
hidden_weights: f32[hidden_n + 1, output_n + 1],
hidden_vals: f32[hidden_n + 1],
) -> f32, f32[hidden_n + 1] {
let errsum = 0.0;
let delta : f32[hidden_n + 1];
for j in 1..=hidden_n {
@loop1 @res let delta : f32[hidden_n + 1];
@loop1 delta[0] = 0.0;
@loop1 for j in 1..=hidden_n {
let h = hidden_vals[j];
let sum = 0.0;
......@@ -49,6 +50,10 @@ fn hidden_error<hidden_n, output_n: usize>(
}
delta[j] = h * (1.0 - h) * sum;
}
let errsum = 0.0;
@loop2 for j in 1..=hidden_n {
errsum += abs!(delta[j]);
}
......@@ -64,8 +69,8 @@ fn adjust_weights<n, m: usize>(
weights: f32[n + 1, m + 1],
prev_weights: f32[n + 1, m + 1]
) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
for j in 1..=m {
for k in 0..=n {
@outer_loop for j in 1..=m {
@inner_loop for k in 0..=n {
let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
weights[k, j] += new_dw;
prev_weights[k, j] = new_dw;
......@@ -86,15 +91,15 @@ fn backprop<input_n, hidden_n, output_n: usize>(
) -> f32, f32,
f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
@forward_input let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
@forward_hidden let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
let out_err, out_delta = output_error::<output_n>(target, output_vals);
let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
@output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
@hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
let hidden_weights, hidden_prev_weights
@adjust_hidden let hidden_weights, hidden_prev_weights
= adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
let input_weights, input_prev_weights
@adjust_input let input_weights, input_prev_weights
= adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
......
......@@ -12,7 +12,7 @@ simpl!(*);
inline(layer_forward);
delete-uncalled(*);
no-memset(layer_forward@res);
no-memset(layer_forward@res, output_error@res, hidden_error@res);
lift-dc-math(*);
loop-bound-canon(*);
simpl!(*);
......@@ -25,7 +25,42 @@ fixpoint {
}
reduce-slf(*);
simpl!(*);
fork-interchange[0, 1](adjust_weights);
simpl!(*);
infer-schedules(*);
// The first call to layer_forward can be parallelized by 16 (the size of the
// hidden layer) and the second can't be parallelized at all (the size of the
// output layer is 1)
inline(backprop@forward_input, backprop@forward_hidden);
let forward_input = outline(backprop@forward_input);
let forward_hidden = outline(backprop@forward_hidden);
fork-tile[16, 0, false, true](forward_input@outer_loop \ forward_input@inner_loop);
let (outer, inner) = fork-reshape[[1], [0]](forward_input@outer_loop \ forward_input@inner_loop);
let forward_input = outline(inner);
inline(backprop@forward_input);
// The first call to adjust_weights has total loop dimensions of 1 * 17, so not
// worth parallelizing (given that the body is trivial)
// The second call to adjust_weights has a total dimension of 16 * (input + 1)
// which is worth parallelizing, we'll do it by 16
inline(backprop@adjust_hidden, backprop@adjust_input);
let adjust_hidden = outline(backprop@adjust_hidden);
let adjust_input = outline(backprop@adjust_input);
fork-tile[16, 0, false, true](adjust_input);
let (outer, inner) = fork-reshape[[1], [0, 2]](adjust_input);
let adjust_input = outline(inner);
inline(backprop@adjust_input);
delete-uncalled(*);
const-inline(*);
simpl!(*);
fork-split(*);
unforkify(*);
unforkify(output_error, hidden_error, adjust_hidden, adjust_input, forward_hidden, forward_input);
simpl!(*);
gcm(*);
gvn(*);
dce(*);
// Macro that applies a fixed sequence of passes to the selection X, in this
// order: ccp, simplify-cfg, lift-dc-math, gvn, phi-elim, dce, infer-schedules.
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
no-memset(layer_forward@res, output_error@res, hidden_error@res);
phi-elim(*);
dce(*);
crc(*);
dce(*);
slf(*);
dce(*);
let output_loop1 = outline(output_error@loop1);
let output_loop2 = outline(output_error@loop2);
let hidden_loop1 = outline(hidden_error@loop1);
let hidden_loop2 = outline(hidden_error@loop2);
simpl!(*);
inline(layer_forward, backprop@output_error, backprop@hidden_error);
delete-uncalled(*);
gpu(layer_forward, output_loop1, output_loop2, hidden_loop1, hidden_loop2, adjust_weights);
const-inline(*);
let auto = auto-outline(backprop);
gpu(auto.backprop);
lift-dc-math(*);
loop-bound-canon(*);
simpl!(*);
lift-dc-math(*);
slf(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
reduce-slf(*);
simpl!(*);
inline(auto.backprop);
inline(auto.backprop);
delete-uncalled(*);
fork-extend[32](layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
fork-tile[32, 0, false, true](layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
let out = fork-split(layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
let fission = fork-fission[out._1_layer_forward.fj0](layer_forward);
simpl!(layer_forward);
sroa[true](*);
dce(*);
float-collections(*);
reuse-products(*);
dce(*);
fork-dim-merge(adjust_weights);
simpl!(adjust_weights);
fork-extend[32](adjust_weights);
fork-tile[32, 0, false, true](adjust_weights);
fork-split(adjust_weights);
simpl!(adjust_weights);
gcm(*);
type Node = struct { edge_start: u32; num_edges: u32; };
type StopProd = struct { stop: bool; };
fn make_stop_prod() -> StopProd {
let ret : StopProd;
ret.stop = true;
return ret;
}
#[entry]
fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] {
......@@ -23,8 +30,6 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
let updated: bool[n];
while !stop {
stop = true;
@loop1 for i in 0..n {
if mask[i] {
mask[i] = false;
......@@ -42,14 +47,16 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
}
}
@make let stop_prod = make_stop_prod();
@loop2 for i in 0..n {
stop = stop && !updated[i];
if updated[i] {
mask[i] = true;
visited[i] = true;
updated[i] = false;
stop_prod.stop = updated[i];
}
}
stop = stop_prod.stop;
}
return cost;
......
......@@ -10,13 +10,14 @@ macro simpl!(X) {
phi-elim(bfs);
no-memset(bfs@cost);
outline(bfs@cost_init);
let loop1 = outline(bfs@loop1);
let loop2 = outline(bfs@loop2);
let init = outline(bfs@cost_init);
let traverse = outline(bfs@loop1);
let collect = outline(bfs@loop2);
simpl!(*);
predication(*);
const-inline(*);
loop-bound-canon(*);
simpl!(*);
fixpoint {
forkify(*);
......@@ -25,6 +26,37 @@ fixpoint {
simpl!(*);
predication(*);
simpl!(*);
reduce-slf(*);
simpl!(*);
slf(*);
simpl!(*);
fixpoint {
forkify(collect);
fork-guard-elim(collect);
}
simpl!(collect);
parallel-fork(traverse, collect);
parallel-reduce(traverse, collect);
unforkify(*);
fork-tile[32, 0, false, true](traverse, collect);
let (outer, inner) = fork-reshape[[1], [0]](traverse);
let traverse_body = outline(inner);
let (outer, inner) = fork-reshape[[1], [0]](collect);
let collect_body = outline(inner);
let init_body = init;
// Following code seems to generate breaking RT code
//fork-tile[32, 0, false, true](init);
//let (outer, inner) = fork-reshape[[1], [0]](init);
//let init_body = outline(inner);
//inline(bfs@cost_init);
inline(bfs@loop1, bfs@loop2);
delete-uncalled(*);
const-inline(*);
unforkify(init_body, traverse_body, collect_body);
simpl!(*);
gcm(*);
......@@ -10,14 +10,17 @@ macro simpl!(X) {
phi-elim(bfs);
no-memset(bfs@cost);
let cost_init = outline(bfs@cost_init);
let loop1 = outline(bfs@loop1);
let loop2 = outline(bfs@loop2);
gpu(loop1, loop2);
let init = outline(bfs@cost_init);
let traverse = outline(bfs@loop1);
let collect = outline(bfs@loop2);
parallel-reduce(traverse, collect);
no-memset(make_stop_prod);
gpu(traverse, make_stop_prod, collect);
simpl!(*);
predication(*);
const-inline(*);
loop-bound-canon(*);
simpl!(*);
fixpoint {
forkify(*);
......@@ -26,14 +29,17 @@ fixpoint {
simpl!(*);
predication(*);
simpl!(*);
unforkify(cost_init);
parallel-reduce(loop1);
forkify(*);
fork-guard-elim(*);
simpl!(*);
predication(*);
reduce-slf(*);
simpl!(*);
fixpoint {
forkify(collect);
fork-guard-elim(collect);
}
simpl!(collect);
fork-tile[1024, 0, false, true](traverse, collect);
fork-split(traverse, collect);
unforkify(init);
gcm(*);
......@@ -19,6 +19,7 @@ pub struct BFSInputs {
fn run_bfs(nodes: &[Node], source: u32, edges: &[u32]) -> Vec<i32> {
let n = nodes.len() as u64;
let m = edges.len() as u64;
println!("Running with {} nodes and {} edges.", n, m);
let nodes = HerculesImmBox::from(nodes);
let edges = HerculesImmBox::from(edges);
......
......@@ -8,6 +8,7 @@ fn main() {
}
#[test]
#[ignore]
fn bfs_test_4096() {
bfs_harness(BFSInputs {
input: "data/graph4096.txt".to_string(),
......
......@@ -19,6 +19,8 @@ pub struct CFDInputs {
pub block_size: usize,
#[clap(short = None, long = Some("pre-euler"))]
pub pre_euler: bool,
#[clap(short, long)]
pub verify: bool,
}
fn run_euler(
......@@ -219,6 +221,7 @@ pub fn cfd_harness(args: CFDInputs) {
iterations,
block_size,
pre_euler,
verify,
} = args;
let FarFieldConditions {
......@@ -268,37 +271,39 @@ pub fn cfd_harness(args: CFDInputs) {
&ff_fc_momentum_z,
)
};
let res_rust = if pre_euler {
rust_cfd::pre_euler(
nelr,
iterations,
variables,
areas.as_slice(),
elements_surrounding_elements.as_slice(),
&normals,
&ff_variable,
&ff_fc_density_energy,
&ff_fc_momentum_x,
&ff_fc_momentum_y,
&ff_fc_momentum_z,
)
} else {
rust_cfd::euler(
nelr,
iterations,
variables,
areas.as_slice(),
elements_surrounding_elements.as_slice(),
&normals,
&ff_variable,
&ff_fc_density_energy,
&ff_fc_momentum_x,
&ff_fc_momentum_y,
&ff_fc_momentum_z,
)
};
if verify {
let res_rust = if pre_euler {
rust_cfd::pre_euler(
nelr,
iterations,
variables,
areas.as_slice(),
elements_surrounding_elements.as_slice(),
&normals,
&ff_variable,
&ff_fc_density_energy,
&ff_fc_momentum_x,
&ff_fc_momentum_y,
&ff_fc_momentum_z,
)
} else {
rust_cfd::euler(
nelr,
iterations,
variables,
areas.as_slice(),
elements_surrounding_elements.as_slice(),
&normals,
&ff_variable,
&ff_fc_density_energy,
&ff_fc_momentum_x,
&ff_fc_momentum_y,
&ff_fc_momentum_z,
)
};
if !compare_floats(&res_juno, &res_rust) {
panic!("Mismatch in results");
if !compare_floats(&res_juno, &res_rust) {
panic!("Mismatch in results");
}
}
}
......@@ -14,6 +14,7 @@ fn test_euler() {
iterations: 1,
block_size: 16,
pre_euler: false,
verify: true,
});
}
......@@ -24,5 +25,6 @@ fn test_pre_euler() {
iterations: 1,
block_size: 16,
pre_euler: true,
verify: true,
});
}
......@@ -40,10 +40,15 @@ let split = fork-split(loop2);
let loop2_body = outline(split.srad_1.fj1);
simpl!(loop2, loop2_body);
inline(srad@loop2);
fork-tile[32, 0, false, false](loop3);
let split = fork-split(loop3);
let loop3_body = outline(split.srad_2.fj1);
simpl!(loop3, loop3_body);
inline(srad@loop2, srad@loop3);
delete-uncalled(*);
fork-split(extract, compress, loop1, loop2_body, loop3);
unforkify(extract, compress, loop1, loop2_body, loop3);
fork-split(extract, compress, loop1, loop2_body, loop3_body);
unforkify(extract, compress, loop1, loop2_body, loop3_body);
gcm(*);
......@@ -551,7 +551,14 @@ fn compile_expr(
}
Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result }))
}
parser::Expr::SetOp {
parser::Expr::UnaryOp { span: _, op, exp } => {
let exp = compile_exp_as_expr(*exp, lexer, macrostab, macros)?;
Ok(ExprResult::Expr(ir::ScheduleExp::UnaryOp {
op,
exp: Box::new(exp),
}))
}
parser::Expr::BinaryOp {
span: _,
op,
lhs,
......@@ -559,7 +566,7 @@ fn compile_expr(
} => {
let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?;
Ok(ExprResult::Expr(ir::ScheduleExp::SetOp {
Ok(ExprResult::Expr(ir::ScheduleExp::BinaryOp {
op,
lhs: Box::new(lhs),
rhs: Box::new(rhs),
......
......@@ -18,7 +18,7 @@ pub enum Pass {
ForkDimMerge,
ForkExtend,
ForkFissionBufferize,
ForkFission,
ForkFission,
ForkFusion,
ForkGuardElim,
ForkInterchange,
......@@ -134,8 +134,12 @@ pub enum ScheduleExp {
body: Vec<ScheduleStmt>,
res: Box<ScheduleExp>,
},
SetOp {
op: parser::SetOp,
UnaryOp {
op: parser::UnaryOp,
exp: Box<ScheduleExp>,
},
BinaryOp {
op: parser::BinaryOp,
lhs: Box<ScheduleExp>,
rhs: Box<ScheduleExp>,
},
......
......@@ -46,6 +46,10 @@ false "false"
\| "|"
& "&"
! "!"
\|\| "||"
&& "&&"
panic[\t \n\r]+after "panic_after"
print[\t \n\r]+iter "print_iter"
stop[\t \n\r]+after "stop_after"
......
......@@ -6,6 +6,9 @@
%left '\\'
%left '|'
%left '&'
%left '||'
%left '&&'
%right '!'
%left '.' '@'
%%
......@@ -27,14 +30,23 @@ Stmt -> Stmt
{ Stmt::ExprStmt { span: $span, exp: $1 } }
| 'fixpoint' FixpointLimit '{' Schedule '}'
{ Stmt::Fixpoint { span: $span, limit: $2, body: Box::new($4) } }
| 'if' Expr '{' Schedule '}'
{ Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: None } }
| 'if' Expr '{' Schedule '}' 'else' '{' Schedule '}'
{ Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: Some(Box::new($8)) } }
| 'if' Expr '{' Schedule '}' ElseStmt
{ Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: $6 } }
| MacroDecl
{ Stmt::MacroDecl { span: $span, def: $1 } }
;
ElseStmt -> Option<Box<OperationList>>
: { None }
| 'else' '{' Schedule '}'
{ Some(Box::new($3)) }
| 'else' 'if' Expr '{' Schedule '}' ElseStmt
{ Some(Box::new(OperationList::ConsStmt(
Stmt::IfThenElse { span: $span, cond: $3, thn: Box::new($5), els: $7 },
Box::new(OperationList::NilStmt()),
))) }
;
FixpointLimit -> FixpointLimit
: { FixpointLimit::NoLimit { span: $span } }
| 'stop_after' 'INT'
......@@ -75,11 +87,17 @@ Expr -> Expr
| '<' Fields '>'
{ Expr::Record { span: $span, fields: rev($2) } }
| Expr '\\' Expr
{ Expr::SetOp { span: $span, op: SetOp::Difference, lhs: Box::new($1), rhs: Box::new($3) } }
{ Expr::BinaryOp { span: $span, op: BinaryOp::Difference, lhs: Box::new($1), rhs: Box::new($3) } }
| Expr '|' Expr
{ Expr::SetOp { span: $span, op: SetOp::Union, lhs: Box::new($1), rhs: Box::new($3) } }
{ Expr::BinaryOp { span: $span, op: BinaryOp::Union, lhs: Box::new($1), rhs: Box::new($3) } }
| Expr '&' Expr
{ Expr::SetOp { span: $span, op: SetOp::Intersection, lhs: Box::new($1), rhs: Box::new($3) } }
{ Expr::BinaryOp { span: $span, op: BinaryOp::Intersection, lhs: Box::new($1), rhs: Box::new($3) } }
| '!' Expr
{ Expr::UnaryOp { span: $span, op: UnaryOp::Not, exp: Box::new($2) } }
| Expr '||' Expr
{ Expr::BinaryOp { span: $span, op: BinaryOp::Or, lhs: Box::new($1), rhs: Box::new($3) } }
| Expr '&&' Expr
{ Expr::BinaryOp { span: $span, op: BinaryOp::And, lhs: Box::new($1), rhs: Box::new($3) } }
;
Args -> Vec<Expr>
......@@ -179,10 +197,17 @@ pub enum FixpointLimit {
}
#[derive(Copy, Clone, Debug)]
pub enum SetOp {
/// Unary operators that can appear in a parsed `Expr::UnaryOp`.
pub enum UnaryOp {
/// Built by the `'!' Expr` grammar rule.
Not,
}
/// Binary operators that can appear in a parsed `Expr::BinaryOp`.
#[derive(Copy, Clone, Debug)]
pub enum BinaryOp {
/// Built by the `Expr '\\' Expr` grammar rule.
Difference,
/// Built by the `Expr '|' Expr` grammar rule.
Union,
/// Built by the `Expr '&' Expr` grammar rule.
Intersection,
/// Built by the `Expr '||' Expr` grammar rule.
Or,
/// Built by the `Expr '&&' Expr` grammar rule.
And,
}
pub enum Expr {
......@@ -195,7 +220,8 @@ pub enum Expr {
Field { span: Span, lhs: Box<Expr>, field: Span },
BlockExpr { span: Span, body: Box<OperationList> },
Record { span: Span, fields: Vec<(Span, Expr)> },
SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> },
UnaryOp { span: Span, op: UnaryOp, exp: Box<Expr> },
BinaryOp { span: Span, op: BinaryOp, lhs: Box<Expr>, rhs: Box<Expr> },
Tuple { span: Span, exps: Vec<Expr> },
TupleField { span: Span, lhs: Box<Expr>, field: Span },
}
......