Compare revisions: llvm/hercules
Commits on Source (4)
Showing 420 additions and 132 deletions
@@ -354,6 +354,7 @@ impl GPUContext<'_> {
         write!(
             w,
             "
+#define _CG_ABI_EXPERIMENTAL
 #include <assert.h>
 #include <stdio.h>
 #include <stddef.h>
@@ -561,8 +562,9 @@ namespace cg = cooperative_groups;
      * and writes.
      */
     fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
+        write!(w, "\t__shared__ cg::experimental::block_tile_memory<1024> block_sync_shared;\n")?;
         write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
-        write!(w, "\tcg::thread_block block = cg::this_thread_block();\n")?;
+        write!(w, "\tcg::thread_block block = cg::experimental::this_thread_block(block_sync_shared);\n")?;
         Ok(())
     }
@@ -1294,7 +1296,7 @@ namespace cg = cooperative_groups;
                 }
                 if !is_primitive && state != KernelState::OutBlock {
                     write!(w, "{}}}\n", tabs)?;
-                    write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
+                    //write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                     *num_tabs -= 1;
                 }
                 if !is_primitive && state == KernelState::OutBlock {
@@ -1311,6 +1313,7 @@ namespace cg = cooperative_groups;
                 }
                 if !is_primitive
                     && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
+                    && !self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant)
                 {
                     let data_size = self.get_size(self.typing[id.idx()], None);
                     write!(
@@ -1321,6 +1324,7 @@ namespace cg = cooperative_groups;
                     write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?;
                     write!(w, "{}}}\n", tabs)?;
                     write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
+                    //write!(w, "__syncthreads\n")?;
                 }
             }
             // Dynamic constants emitted at top
@@ -1595,7 +1599,7 @@ namespace cg = cooperative_groups;
                     write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?;
                     write!(w, "{}}}\n", tabs)?;
                 }
-                write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
+                //write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                 let collect_variable = self.get_value(*collect, false, false);
                 write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
             }
@@ -1705,20 +1709,20 @@ namespace cg = cooperative_groups;
                 };
                 write!(
                     thread_block_tiles,
-                    "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
+                    "\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
                     use_thread_per_id, cg_tile, use_thread_per_id
                 )?;
                 let cg_tile_use = self.get_cg_tile(id, CGType::Use);
                 write!(
                     thread_block_tiles,
-                    "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
+                    "\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
                     use_thread_quota, cg_tile_use, use_thread_quota
                 )?;
                 let available_thread_quota = available_thread_quota.unwrap();
                 let cg_tile_available = self.get_cg_tile(id, CGType::Available);
                 write!(
                     thread_block_tiles,
-                    "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
+                    "\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
                     available_thread_quota, cg_tile_available, available_thread_quota
                 )?;
                 if parallel_factor.is_none() {
@@ -1781,6 +1785,7 @@ namespace cg = cooperative_groups;
             let fork = self.join_fork_map.get(&id).unwrap();
             let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
             write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
+            //write!(w_term, "\t__syncthreads;\n")?;
         }
         // If the Fork was parallelized, each thread or UsedPerId tile of
         // threads only runs one ThreadID, so we can jump straight to the
...
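The cooperative-groups edits above move the generated CUDA onto NVIDIA's experimental cooperative-groups ABI (gated by _CG_ABI_EXPERIMENTAL). A minimal sketch of the kernel preamble the new write! calls emit, assembled from the strings in these hunks; the kernel name and the tile width 32 are placeholders, since the real widths come from the per-fork thread quotas:

#define _CG_ABI_EXPERIMENTAL
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void example_kernel() {
    // Shared scratch the experimental ABI requires for block-wide tiles.
    __shared__ cg::experimental::block_tile_memory<1024> block_sync_shared;
    cg::grid_group grid = cg::this_grid();
    cg::thread_block block = cg::experimental::this_thread_block(block_sync_shared);
    // Per-fork tiles are now carved out of the block through the experimental API as well.
    cg::thread_block_tile<32> cg_tile = cg::experimental::tiled_partition<32>(block);
    cg_tile.sync();
}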
@@ -319,12 +319,12 @@ pub fn fork_fission<'a>(
         .collect();

     let mut created_forks = Vec::new();

     // This does the reduction fission
     for fork in forks {
         let join = fork_join_map[&fork.0];

         // FIXME: Don't make multiple forks for reduces that are in cycles with each other.
         let reduce_partition = default_reduce_partition(editor, fork.0, join);

         if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
@@ -332,14 +332,19 @@ pub fn fork_fission<'a>(
         }

         if editor.is_mutable(fork.0) {
-            created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
+            created_forks = fork_reduce_fission_helper(
+                editor,
+                fork_join_map,
+                reduce_partition,
+                nodes_in_fork_joins,
+                fork.0,
+            );
             if created_forks.is_empty() {
                 continue;
             } else {
                 return created_forks;
             }
         }
     }

     created_forks
@@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>(
     let mut new_forks = Vec::new();

-    let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
+    let mut new_control_pred: NodeID = editor
+        .get_uses(fork)
+        .filter(|n| editor.node(n).is_control())
+        .next()
+        .unwrap();

     let mut new_fork = NodeID::new(0);
     let mut new_join = NodeID::new(0);

     let subgraph = &nodes_in_fork_joins[&fork];

     // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
     editor.edit(|mut edit| {
         for reduce in reduce_partition {
@@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>(
             new_fork = mapping[&fork];
             new_forks.push(new_fork);
             new_join = mapping[&join];

             // Atttach new_fork after control_pred
             let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
             edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
@@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>(
             // Replace uses of reduce
             edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
             new_control_pred = new_join;
-        };
+        }

         // Replace original join w/ new final join
         edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
@@ -1502,6 +1511,10 @@ fn fork_fusion(
  * element. This aides in parallelizing outer loops. Looks only at reduces with
  * the monoid reduce schedule, since that indicates a particular structure which
  * is annoying to check for again.
+ *
+ * Looks for would-be monoid reduces, if not for a gate on the reduction.
+ * Partially predicate the gated reduction to allow for a proper monoid
+ * reduction.
  */
 pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
     for id in editor.node_ids() {
@@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
         let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
             continue;
         };
-        let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
+        let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();

         match nodes[reduct.idx()] {
             Node::Binary {
@@ -1520,7 +1533,8 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                 left: _,
                 right: _,
             } if (op == BinaryOperator::Add || op == BinaryOperator::Or)
-                && !is_zero(editor, init) =>
+                && !is_zero(editor, init)
+                && !is_false(editor, init) =>
             {
                 editor.edit(|mut edit| {
                     let zero = edit.add_zero_constant(typing[init.idx()]);
@@ -1532,7 +1546,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         left: init,
                         right: id,
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1543,7 +1557,8 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                 left: _,
                 right: _,
             } if (op == BinaryOperator::Mul || op == BinaryOperator::And)
-                && !is_one(editor, init) =>
+                && !is_one(editor, init)
+                && !is_true(editor, init) =>
             {
                 editor.edit(|mut edit| {
                     let one = edit.add_one_constant(typing[init.idx()]);
@@ -1555,7 +1570,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         left: init,
                         right: id,
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1574,7 +1589,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         intrinsic: Intrinsic::Max,
                         args: Box::new([init, id]),
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1593,7 +1608,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         intrinsic: Intrinsic::Min,
                         args: Box::new([init, id]),
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1602,6 +1617,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
             _ => {}
         }
     }
+
+    for id in editor.node_ids() {
+        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
+            continue;
+        }
+        let nodes = &editor.func().nodes;
+        let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
+            continue;
+        };
+        if let Node::Phi {
+            control: phi_control,
+            ref data,
+        } = nodes[reduct.idx()]
+            && data.len() == 2
+            && data.contains(&id)
+            && let other = *data
+                .into_iter()
+                .filter(|other| **other != id)
+                .next()
+                .unwrap()
+            && let Node::Binary {
+                op: BinaryOperator::Add,
+                left,
+                right,
+            } = nodes[other.idx()]
+            && ((left == id) ^ (right == id))
+        {
+            let gated_input = if left == id { right } else { left };
+            let data = data.clone();
+            editor.edit(|mut edit| {
+                let zero = edit.add_zero_constant(typing[id.idx()]);
+                let zero = edit.add_node(Node::Constant { id: zero });
+                let phi = edit.add_node(Node::Phi {
+                    control: phi_control,
+                    data: data
+                        .iter()
+                        .map(|phi_use| if *phi_use == id { zero } else { gated_input })
+                        .collect(),
+                });
+                let new_reduce_id = NodeID::new(edit.num_node_ids());
+                let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
+                let new_reduce = Node::Reduce {
+                    control,
+                    init,
+                    reduct: new_reduct_id,
+                };
+                let new_add = Node::Binary {
+                    op: BinaryOperator::Add,
+                    left: new_reduce_id,
+                    right: phi,
+                };
+                let new_reduce = edit.add_node(new_reduce);
+                edit.add_node(new_add);
+                edit = edit.replace_all_uses(id, new_reduce)?;
+                edit = edit.delete_node(id)?;
+                Ok(edit)
+            });
+        }
+    }
 }

 /*
...
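The new second pass in clean_monoid_reduces handles reductions that would be monoids except that the accumulation is guarded by a condition. In source-level terms the rewrite looks roughly like the following C sketch (names and types are illustrative, not taken from the compiler):

// Before: the add only happens on some iterations, so each step is a
// data-dependent choice between "old value" and "old value + x[i]",
// which is not a clean monoid reduction.
float gated_sum(const float *x, const int *cond, int n, float init) {
    float sum = init;
    for (int i = 0; i < n; i++) {
        if (cond[i]) {
            sum = sum + x[i];
        }
    }
    return sum;
}

// After partial predication: the gate moves onto the value being added,
// the accumulation becomes an unconditional add with identity 0, and the
// loop is a proper monoid reduce that later passes can parallelize.
float predicated_sum(const float *x, const int *cond, int n, float init) {
    float sum = init;
    for (int i = 0; i < n; i++) {
        float contrib = cond[i] ? x[i] : 0.0f;
        sum = sum + contrib;
    }
    return sum;
}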
@@ -212,7 +212,8 @@ fn preliminary_fixups(
         let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
         // Replace uses of the reduce in its cycle with the init.
-        let success = editor.edit(|edit| {
+        let success = editor.edit(|mut edit| {
+            edit = edit.add_schedule(init, Schedule::ParallelReduce)?;
             edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id))
         });
         assert!(success);
@@ -870,7 +871,7 @@ fn spill_clones(
     // Step 2: filter edges (A, B) to just see edges where A uses B and A
     // mutates B. These are the edges that may require a spill.
     let mut spill_edges = edges.into_iter().filter(|(a, b)| {
-        mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
+        (mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
             || (get_uses(&editor.func().nodes[a.idx()])
                 .as_ref()
                 .into_iter()
@@ -890,7 +891,14 @@ fn spill_clones(
                     data.contains(b)
                         && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
                 })
-                .unwrap_or(false))
+                .unwrap_or(false)))
+            && !editor.func().nodes[a.idx()]
+                .try_write()
+                .map(|(collect, _, _)| {
+                    collect == *b
+                        && editor.func().schedules[b.idx()].contains(&Schedule::ParallelReduce)
+                })
+                .unwrap_or(false)
     });

     // Step 3: if there is a spill edge, spill it and return true. Otherwise,
...
@@ -126,11 +126,24 @@ fn remove_useless_fork_joins(
     // Third, get rid of fork-joins.
     for (fork, join) in fork_join_map {
-        if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
+        if editor.get_users(*join).len() == 1 {
             let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
             let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
+            let tids: Vec<_> = editor
+                .get_users(*fork)
+                .filter(|id| editor.func().nodes[id.idx()].is_thread_id())
+                .collect();
             editor.edit(|mut edit| {
+                if !tids.is_empty() {
+                    let u64_ty = edit.add_type(Type::UnsignedInteger64);
+                    let zero = edit.add_zero_constant(u64_ty);
+                    let zero = edit.add_node(Node::Constant { id: zero });
+                    for tid in tids {
+                        edit = edit.replace_all_uses(tid, zero)?;
+                        edit = edit.delete_node(tid)?;
+                    }
+                }
                 edit = edit.replace_all_uses(*join, join_use)?;
                 edit = edit.replace_all_uses(*fork, fork_use)?;
                 edit = edit.delete_node(*fork)?;
...
@@ -598,6 +598,24 @@ pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
         || nodes[id.idx()].is_undef()
 }

+pub fn is_false(editor: &FunctionEditor, id: NodeID) -> bool {
+    let nodes = &editor.func().nodes;
+    nodes[id.idx()]
+        .try_constant()
+        .map(|id| editor.get_constant(id).is_false())
+        .unwrap_or(false)
+        || nodes[id.idx()].is_undef()
+}
+
+pub fn is_true(editor: &FunctionEditor, id: NodeID) -> bool {
+    let nodes = &editor.func().nodes;
+    nodes[id.idx()]
+        .try_constant()
+        .map(|id| editor.get_constant(id).is_true())
+        .unwrap_or(false)
+        || nodes[id.idx()].is_undef()
+}
+
 pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool {
     let nodes = &editor.func().nodes;
     nodes[id.idx()]
...
@@ -6,10 +6,9 @@ fn squash(x: f32) -> f32 {
 fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
   @res let result : f32[m + 1];
   result[0] = 1.0;
   @outer_loop for j in 1..=m {
-    let sum = 0.0;
-    @inner_loop for k in 0..=n {
+    let sum = weights[0, j] * vals[0];
+    @inner_loop for k in 1..=n {
       sum += weights[k, j] * vals[k];
     }
     result[j] = squash(sum);
@@ -19,13 +18,16 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
 }

 fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] {
-  let errsum = 0.0;
-  let delta : f32[n + 1];
-  for j in 1..=n {
+  @loop1 @res let delta : f32[n + 1];
+  @loop1 delta[0] = 0.0;
+  @loop1 for j in 1..=n {
     let a = actual[j];
     let t = target[j];
     delta[j] = a * (1.0 - a) * (t - a);
+  }
+  let errsum = 0.0;
+  @loop2 for j in 1..=n {
     errsum += abs!(delta[j]);
   }
@@ -37,10 +39,9 @@ fn hidden_error<hidden_n, output_n: usize>(
   hidden_weights: f32[hidden_n + 1, output_n + 1],
   hidden_vals: f32[hidden_n + 1],
 ) -> f32, f32[hidden_n + 1] {
-  let errsum = 0.0;
-  let delta : f32[hidden_n + 1];
-  for j in 1..=hidden_n {
+  @loop1 @res let delta : f32[hidden_n + 1];
+  @loop1 delta[0] = 0.0;
+  @loop1 for j in 1..=hidden_n {
     let h = hidden_vals[j];
     let sum = 0.0;
@@ -49,6 +50,10 @@ fn hidden_error<hidden_n, output_n: usize>(
     }
     delta[j] = h * (1.0 - h) * sum;
+  }
+  let errsum = 0.0;
+  @loop2 for j in 1..=hidden_n {
     errsum += abs!(delta[j]);
   }
@@ -64,8 +69,8 @@ fn adjust_weights<n, m: usize>(
   weights: f32[n + 1, m + 1],
   prev_weights: f32[n + 1, m + 1]
 ) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
-  for j in 1..=m {
-    for k in 0..=n {
+  @outer_loop for j in 1..=m {
+    @inner_loop for k in 0..=n {
       let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
       weights[k, j] += new_dw;
       prev_weights[k, j] = new_dw;
@@ -86,15 +91,15 @@ fn backprop<input_n, hidden_n, output_n: usize>(
 ) -> f32, f32,
      f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
      f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
-  let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
-  let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
-  let out_err, out_delta = output_error::<output_n>(target, output_vals);
-  let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
-  let hidden_weights, hidden_prev_weights
+  @forward_input let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
+  @forward_hidden let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
+  @output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
+  @hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
+  @adjust_hidden let hidden_weights, hidden_prev_weights
     = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
-  let input_weights, input_prev_weights
+  @adjust_input let input_weights, input_prev_weights
     = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
   return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
...
@@ -12,7 +12,7 @@ simpl!(*);
 inline(layer_forward);
 delete-uncalled(*);
-no-memset(layer_forward@res);
+no-memset(layer_forward@res, output_error@res, hidden_error@res);
 lift-dc-math(*);
 loop-bound-canon(*);
 simpl!(*);
@@ -25,7 +25,42 @@ fixpoint {
 }
 reduce-slf(*);
 simpl!(*);
+fork-interchange[0, 1](adjust_weights);
+simpl!(*);
+infer-schedules(*);
+
+// The first call to layer_forward can be parallelized by 16 (the size of the
+// hidden layer) and the second can't be parallelized at all (the size of the
+// output layer is 1)
+inline(backprop@forward_input, backprop@forward_hidden);
+let forward_input = outline(backprop@forward_input);
+let forward_hidden = outline(backprop@forward_hidden);
+fork-tile[16, 0, false, true](forward_input@outer_loop \ forward_input@inner_loop);
+let (outer, inner) = fork-reshape[[1], [0]](forward_input@outer_loop \ forward_input@inner_loop);
+let forward_input = outline(inner);
+inline(backprop@forward_input);
+
+// The first call to adjust_weights has total loop dimensions of 1 * 17, so not
+// worth parallelizing (given that the body is trivial)
+// The second call to adjust_weights has a total dimension of 16 * (input + 1)
+// which is worth parallelizing, we'll do it by 16
+inline(backprop@adjust_hidden, backprop@adjust_input);
+let adjust_hidden = outline(backprop@adjust_hidden);
+let adjust_input = outline(backprop@adjust_input);
+fork-tile[16, 0, false, true](adjust_input);
+let (outer, inner) = fork-reshape[[1], [0, 2]](adjust_input);
+let adjust_input = outline(inner);
+inline(backprop@adjust_input);
+
+delete-uncalled(*);
+const-inline(*);
+simpl!(*);
+
 fork-split(*);
-unforkify(*);
+unforkify(output_error, hidden_error, adjust_hidden, adjust_input, forward_hidden, forward_input);
+simpl!(*);
 gcm(*);
...

-gvn(*);
-dce(*);
+macro simpl!(X) {
+  ccp(X);
+  simplify-cfg(X);
+  lift-dc-math(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  infer-schedules(X);
+}
+
+no-memset(layer_forward@res, output_error@res, hidden_error@res);
 phi-elim(*);
-dce(*);
-crc(*);
-dce(*);
-slf(*);
-dce(*);
-let auto = auto-outline(backprop);
-gpu(auto.backprop);
-inline(auto.backprop);
-inline(auto.backprop);
-delete-uncalled(*);
-sroa[true](*);
-dce(*);
-float-collections(*);
-reuse-products(*);
-dce(*);
+let output_loop1 = outline(output_error@loop1);
+let output_loop2 = outline(output_error@loop2);
+let hidden_loop1 = outline(hidden_error@loop1);
+let hidden_loop2 = outline(hidden_error@loop2);
+simpl!(*);
+inline(layer_forward, backprop@output_error, backprop@hidden_error);
+delete-uncalled(*);
+gpu(layer_forward, output_loop1, output_loop2, hidden_loop1, hidden_loop2, adjust_weights);
+const-inline(*);
+lift-dc-math(*);
+loop-bound-canon(*);
+simpl!(*);
+lift-dc-math(*);
+slf(*);
+fixpoint {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+reduce-slf(*);
+simpl!(*);
+fork-extend[32](layer_forward@inner_loop);
+clean-monoid-reduces(layer_forward);
+simpl!(layer_forward);
+fork-tile[32, 0, false, true](layer_forward@inner_loop);
+clean-monoid-reduces(layer_forward);
+let out = fork-split(layer_forward@inner_loop);
+clean-monoid-reduces(layer_forward);
+simpl!(layer_forward);
+let fission = fork-fission[out._1_layer_forward.fj0](layer_forward);
+simpl!(layer_forward);
+fork-dim-merge(adjust_weights);
+simpl!(adjust_weights);
+fork-extend[32](adjust_weights);
+fork-tile[32, 0, false, true](adjust_weights);
+fork-split(adjust_weights);
+simpl!(adjust_weights);
 gcm(*);
...

 type Node = struct { edge_start: u32; num_edges: u32; };

+type StopProd = struct { stop: bool; };
+
+fn make_stop_prod() -> StopProd {
+  let ret : StopProd;
+  ret.stop = true;
+  return ret;
+}
+
 #[entry]
 fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] {
@@ -23,8 +30,6 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
   let updated: bool[n];

   while !stop {
-    stop = true;
     @loop1 for i in 0..n {
       if mask[i] {
         mask[i] = false;
@@ -42,14 +47,16 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
       }
     }

+    @make let stop_prod = make_stop_prod();
     @loop2 for i in 0..n {
-      stop = stop && !updated[i];
       if updated[i] {
         mask[i] = true;
         visited[i] = true;
         updated[i] = false;
+        stop_prod.stop = updated[i];
       }
     }
+    stop = stop_prod.stop;
   }

   return cost;
...
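The @loop2 change above replaces the running stop = stop && !updated[i] with a single-field product (StopProd) that every updating iteration overwrites with the same value, so the schedules can later treat the flag update as a parallel reduce. A rough C sketch of why this helps; the Juno version routes the flag through make_stop_prod rather than a plain bool:

// Before: each iteration reads the previous value of `stop`, so the loop
// carries a genuine &&-reduction from one iteration to the next.
void collect_before(bool *mask, bool *visited, bool *updated, int n, bool *stop_out) {
    bool stop = true;
    for (int i = 0; i < n; i++) {
        stop = stop && !updated[i];
        if (updated[i]) {
            mask[i] = true;
            visited[i] = true;
            updated[i] = false;
        }
    }
    *stop_out = stop;
}

// After: only iterations that saw an update touch the flag, and they all store
// the same constant, so the stores commute and the body no longer depends on
// the order in which iterations run.
void collect_after(bool *mask, bool *visited, bool *updated, int n, bool *stop_out) {
    bool stop = true;
    for (int i = 0; i < n; i++) {
        if (updated[i]) {
            mask[i] = true;
            visited[i] = true;
            updated[i] = false;
            stop = false;  // corresponds to stop_prod.stop = updated[i] after the clear
        }
    }
    *stop_out = stop;
}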
@@ -10,13 +10,14 @@ macro simpl!(X) {
 phi-elim(bfs);
 no-memset(bfs@cost);
-outline(bfs@cost_init);
-let loop1 = outline(bfs@loop1);
-let loop2 = outline(bfs@loop2);
+let init = outline(bfs@cost_init);
+let traverse = outline(bfs@loop1);
+let collect = outline(bfs@loop2);
 simpl!(*);
 predication(*);
 const-inline(*);
+loop-bound-canon(*);
 simpl!(*);
 fixpoint {
   forkify(*);
@@ -25,6 +26,37 @@ fixpoint {
 simpl!(*);
 predication(*);
 simpl!(*);
-unforkify(*);
+reduce-slf(*);
+simpl!(*);
+slf(*);
+simpl!(*);
+
+fixpoint {
+  forkify(collect);
+  fork-guard-elim(collect);
+}
+simpl!(collect);
+
+parallel-fork(traverse, collect);
+parallel-reduce(traverse, collect);
+fork-tile[32, 0, false, true](traverse, collect);
+let (outer, inner) = fork-reshape[[1], [0]](traverse);
+let traverse_body = outline(inner);
+let (outer, inner) = fork-reshape[[1], [0]](collect);
+let collect_body = outline(inner);
+let init_body = init;
+// Following code seems to generate breaking RT code
+//fork-tile[32, 0, false, true](init);
+//let (outer, inner) = fork-reshape[[1], [0]](init);
+//let init_body = outline(inner);
+//inline(bfs@cost_init);
+inline(bfs@loop1, bfs@loop2);
+delete-uncalled(*);
+const-inline(*);
+unforkify(init_body, traverse_body, collect_body);
+simpl!(*);
 gcm(*);
@@ -10,14 +10,17 @@ macro simpl!(X) {
 phi-elim(bfs);
 no-memset(bfs@cost);
-let cost_init = outline(bfs@cost_init);
-let loop1 = outline(bfs@loop1);
-let loop2 = outline(bfs@loop2);
-gpu(loop1, loop2);
+let init = outline(bfs@cost_init);
+let traverse = outline(bfs@loop1);
+let collect = outline(bfs@loop2);
+parallel-reduce(traverse, collect);
+no-memset(make_stop_prod);
+gpu(traverse, make_stop_prod, collect);
 simpl!(*);
 predication(*);
 const-inline(*);
+loop-bound-canon(*);
 simpl!(*);
 fixpoint {
   forkify(*);
@@ -26,14 +29,17 @@ fixpoint {
 simpl!(*);
 predication(*);
 simpl!(*);
-unforkify(cost_init);
-parallel-reduce(loop1);
-forkify(*);
-fork-guard-elim(*);
-simpl!(*);
-predication(*);
 reduce-slf(*);
 simpl!(*);
+fixpoint {
+  forkify(collect);
+  fork-guard-elim(collect);
+}
+simpl!(collect);
+fork-tile[1024, 0, false, true](traverse, collect);
+fork-split(traverse, collect);
+unforkify(init);
 gcm(*);
...
@@ -19,6 +19,7 @@ pub struct BFSInputs {
 fn run_bfs(nodes: &[Node], source: u32, edges: &[u32]) -> Vec<i32> {
     let n = nodes.len() as u64;
     let m = edges.len() as u64;
+    println!("Running with {} nodes and {} edges.", n, m);
     let nodes = HerculesImmBox::from(nodes);
     let edges = HerculesImmBox::from(edges);
...
@@ -8,6 +8,7 @@ fn main() {
 }

 #[test]
+#[ignore]
 fn bfs_test_4096() {
     bfs_harness(BFSInputs {
         input: "data/graph4096.txt".to_string(),
...
@@ -19,6 +19,8 @@ pub struct CFDInputs {
     pub block_size: usize,
     #[clap(short = None, long = Some("pre-euler"))]
     pub pre_euler: bool,
+    #[clap(short, long)]
+    pub verify: bool,
 }

 fn run_euler(
@@ -219,6 +221,7 @@ pub fn cfd_harness(args: CFDInputs) {
         iterations,
         block_size,
         pre_euler,
+        verify,
     } = args;

     let FarFieldConditions {
@@ -268,37 +271,39 @@ pub fn cfd_harness(args: CFDInputs) {
             &ff_fc_momentum_z,
         )
     };

+    if verify {
         let res_rust = if pre_euler {
             rust_cfd::pre_euler(
                 nelr,
                 iterations,
                 variables,
                 areas.as_slice(),
                 elements_surrounding_elements.as_slice(),
                 &normals,
                 &ff_variable,
                 &ff_fc_density_energy,
                 &ff_fc_momentum_x,
                 &ff_fc_momentum_y,
                 &ff_fc_momentum_z,
             )
         } else {
             rust_cfd::euler(
                 nelr,
                 iterations,
                 variables,
                 areas.as_slice(),
                 elements_surrounding_elements.as_slice(),
                 &normals,
                 &ff_variable,
                 &ff_fc_density_energy,
                 &ff_fc_momentum_x,
                 &ff_fc_momentum_y,
                 &ff_fc_momentum_z,
             )
         };

         if !compare_floats(&res_juno, &res_rust) {
             panic!("Mismatch in results");
         }
+    }
 }

...

@@ -14,6 +14,7 @@ fn test_euler() {
         iterations: 1,
         block_size: 16,
         pre_euler: false,
+        verify: true,
     });
 }
@@ -24,5 +25,6 @@ fn test_pre_euler() {
         iterations: 1,
         block_size: 16,
         pre_euler: true,
+        verify: true,
     });
 }
...

@@ -40,10 +40,15 @@ let split = fork-split(loop2);
 let loop2_body = outline(split.srad_1.fj1);
 simpl!(loop2, loop2_body);

-inline(srad@loop2);
+fork-tile[32, 0, false, false](loop3);
+let split = fork-split(loop3);
+let loop3_body = outline(split.srad_2.fj1);
+simpl!(loop3, loop3_body);
+
+inline(srad@loop2, srad@loop3);
 delete-uncalled(*);

-fork-split(extract, compress, loop1, loop2_body, loop3);
-unforkify(extract, compress, loop1, loop2_body, loop3);
+fork-split(extract, compress, loop1, loop2_body, loop3_body);
+unforkify(extract, compress, loop1, loop2_body, loop3_body);

 gcm(*);

...
@@ -551,7 +551,14 @@ fn compile_expr(
             }
             Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result }))
         }
-        parser::Expr::SetOp {
+        parser::Expr::UnaryOp { span: _, op, exp } => {
+            let exp = compile_exp_as_expr(*exp, lexer, macrostab, macros)?;
+            Ok(ExprResult::Expr(ir::ScheduleExp::UnaryOp {
+                op,
+                exp: Box::new(exp),
+            }))
+        }
+        parser::Expr::BinaryOp {
             span: _,
             op,
             lhs,
@@ -559,7 +566,7 @@ fn compile_expr(
         } => {
             let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
             let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?;
-            Ok(ExprResult::Expr(ir::ScheduleExp::SetOp {
+            Ok(ExprResult::Expr(ir::ScheduleExp::BinaryOp {
                 op,
                 lhs: Box::new(lhs),
                 rhs: Box::new(rhs),
...
@@ -18,7 +18,7 @@ pub enum Pass {
     ForkDimMerge,
     ForkExtend,
     ForkFissionBufferize,
     ForkFission,
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
@@ -134,8 +134,12 @@ pub enum ScheduleExp {
         body: Vec<ScheduleStmt>,
         res: Box<ScheduleExp>,
     },
-    SetOp {
-        op: parser::SetOp,
+    UnaryOp {
+        op: parser::UnaryOp,
+        exp: Box<ScheduleExp>,
+    },
+    BinaryOp {
+        op: parser::BinaryOp,
         lhs: Box<ScheduleExp>,
         rhs: Box<ScheduleExp>,
     },
...
@@ -46,6 +46,10 @@ false "false"
 \| "|"
 & "&"
+! "!"
+\|\| "||"
+&& "&&"
 panic[\t \n\r]+after "panic_after"
 print[\t \n\r]+iter "print_iter"
 stop[\t \n\r]+after "stop_after"
...
@@ -6,6 +6,9 @@
 %left '\\'
 %left '|'
 %left '&'
+%left '||'
+%left '&&'
+%right '!'
 %left '.' '@'

 %%
@@ -27,14 +30,23 @@ Stmt -> Stmt
       { Stmt::ExprStmt { span: $span, exp: $1 } }
   | 'fixpoint' FixpointLimit '{' Schedule '}'
       { Stmt::Fixpoint { span: $span, limit: $2, body: Box::new($4) } }
-  | 'if' Expr '{' Schedule '}'
-      { Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: None } }
-  | 'if' Expr '{' Schedule '}' 'else' '{' Schedule '}'
-      { Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: Some(Box::new($8)) } }
+  | 'if' Expr '{' Schedule '}' ElseStmt
+      { Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: $6 } }
   | MacroDecl
       { Stmt::MacroDecl { span: $span, def: $1 } }
   ;

+ElseStmt -> Option<Box<OperationList>>
+  :   { None }
+  | 'else' '{' Schedule '}'
+      { Some(Box::new($3)) }
+  | 'else' 'if' Expr '{' Schedule '}' ElseStmt
+      { Some(Box::new(OperationList::ConsStmt(
+          Stmt::IfThenElse { span: $span, cond: $3, thn: Box::new($5), els: $7 },
+          Box::new(OperationList::NilStmt()),
+        ))) }
+  ;
+
 FixpointLimit -> FixpointLimit
   :   { FixpointLimit::NoLimit { span: $span } }
   | 'stop_after' 'INT'
@@ -75,11 +87,17 @@ Expr -> Expr
   | '<' Fields '>'
       { Expr::Record { span: $span, fields: rev($2) } }
   | Expr '\\' Expr
-      { Expr::SetOp { span: $span, op: SetOp::Difference, lhs: Box::new($1), rhs: Box::new($3) } }
+      { Expr::BinaryOp { span: $span, op: BinaryOp::Difference, lhs: Box::new($1), rhs: Box::new($3) } }
   | Expr '|' Expr
-      { Expr::SetOp { span: $span, op: SetOp::Union, lhs: Box::new($1), rhs: Box::new($3) } }
+      { Expr::BinaryOp { span: $span, op: BinaryOp::Union, lhs: Box::new($1), rhs: Box::new($3) } }
   | Expr '&' Expr
-      { Expr::SetOp { span: $span, op: SetOp::Intersection, lhs: Box::new($1), rhs: Box::new($3) } }
+      { Expr::BinaryOp { span: $span, op: BinaryOp::Intersection, lhs: Box::new($1), rhs: Box::new($3) } }
+  | '!' Expr
+      { Expr::UnaryOp { span: $span, op: UnaryOp::Not, exp: Box::new($2) } }
+  | Expr '||' Expr
+      { Expr::BinaryOp { span: $span, op: BinaryOp::Or, lhs: Box::new($1), rhs: Box::new($3) } }
+  | Expr '&&' Expr
+      { Expr::BinaryOp { span: $span, op: BinaryOp::And, lhs: Box::new($1), rhs: Box::new($3) } }
   ;

 Args -> Vec<Expr>
@@ -179,10 +197,17 @@ pub enum FixpointLimit {
 }

 #[derive(Copy, Clone, Debug)]
-pub enum SetOp {
+pub enum UnaryOp {
+    Not,
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum BinaryOp {
     Difference,
     Union,
     Intersection,
+    Or,
+    And,
 }

 pub enum Expr {
@@ -195,7 +220,8 @@ pub enum Expr {
     Field { span: Span, lhs: Box<Expr>, field: Span },
     BlockExpr { span: Span, body: Box<OperationList> },
     Record { span: Span, fields: Vec<(Span, Expr)> },
-    SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> },
+    UnaryOp { span: Span, op: UnaryOp, exp: Box<Expr> },
+    BinaryOp { span: Span, op: BinaryOp, lhs: Box<Expr>, rhs: Box<Expr> },
     Tuple { span: Span, exps: Vec<Expr> },
     TupleField { span: Span, lhs: Box<Expr>, field: Span },
 }
...