Commit 50b1c0e7 authored by rarbore2

More misc. rodinia opts

parent cc694f57
Merge request !206: More misc. rodinia opts
@@ -938,7 +938,7 @@ impl<'a> RTContext<'a> {
         let dst_device = self.node_colors.0[&collect];
         write!(
             block,
-            "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});",
+            "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {} as usize);",
             src_device.name(),
             dst_device.name(),
             self.get_value(collect, bb, false),
......
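For context, the hunk above makes the emitted runtime call cast its byte-count argument to `usize`. A minimal sketch of why, assuming a copy helper whose size parameter is `usize` (the helper below is a hypothetical stand-in, not the actual hercules_rt signature):

// Hypothetical stand-in for the hercules_rt copy helper; the real signature
// is assumed, not copied from the source.
unsafe fn __copy_cpu_to_cuda(_dst: *mut u8, _src: *const u8, _size: usize) {}

fn main() {
    // Size expressions built from dynamic constants are often u64 in the
    // generated Rust, so passing them verbatim would be a type mismatch.
    let size: u64 = 4096;
    let (dst, src) = (std::ptr::null_mut::<u8>(), std::ptr::null::<u8>());
    unsafe {
        // `size` alone would fail to compile here; `size as usize` does not.
        __copy_cpu_to_cuda(dst, src, size as usize);
    }
}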
......@@ -879,8 +879,15 @@ fn spill_clones(
                 || editor.func().nodes[a.idx()].is_reduce())
                 && !editor.func().nodes[a.idx()]
                     .try_reduce()
-                    .map(|(_, init, _)| {
-                        init == *b
+                    .map(|(_, init, reduct)| {
+                        (init == *b || reduct == *b)
+                            && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
                     })
+                    .unwrap_or(false)
+                && !editor.func().nodes[a.idx()]
+                    .try_phi()
+                    .map(|(_, data)| {
+                        data.contains(b)
+                            && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
+                    })
                     .unwrap_or(false))
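To make the new exemption concrete, here is a toy restatement over a simplified IR: the def-use edge from `b` into `a` may be skipped when spilling clones if `a` is a reduce whose init or reduct is `b`, or a phi with `b` among its inputs, provided `a` carries the ParallelReduce schedule. The `Node` and `Schedule` types below are assumed stand-ins, not the real IR definitions.

#[derive(PartialEq)]
enum Schedule {
    ParallelReduce,
}

enum Node {
    Reduce { init: usize, reduct: usize },
    Phi { data: Vec<usize> },
}

// The edge from `b` into `a` is exempt from spilling when `a` is a
// ParallelReduce-scheduled reduce whose init or reduct is `b`, or a
// ParallelReduce-scheduled phi with `b` among its data inputs.
fn exempt_from_spill(a: &Node, a_schedules: &[Schedule], b: usize) -> bool {
    let parallel = a_schedules.contains(&Schedule::ParallelReduce);
    match a {
        Node::Reduce { init, reduct } => parallel && (*init == b || *reduct == b),
        Node::Phi { data } => parallel && data.contains(&b),
    }
}

fn main() {
    let reduce = Node::Reduce { init: 3, reduct: 7 };
    assert!(exempt_from_spill(&reduce, &[Schedule::ParallelReduce], 7));
    assert!(!exempt_from_spill(&reduce, &[], 7)); // no schedule, no exemption
    let phi = Node::Phi { data: vec![1, 2] };
    assert!(exempt_from_spill(&phi, &[Schedule::ParallelReduce], 2));
}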
......@@ -1302,39 +1309,53 @@ enum UTerm {
     Device(Device),
 }
 
-fn unify(
-    mut equations: VecDeque<(UTerm, UTerm)>,
-) -> Result<BTreeMap<NodeID, Device>, BTreeMap<NodeID, Device>> {
+fn unify(mut equations: VecDeque<(UTerm, UTerm)>) -> Result<BTreeMap<NodeID, Device>, NodeID> {
     let mut theta = BTreeMap::new();
+    // First, assign devices to nodes when a rule directly says to.
+    for _ in 0..equations.len() {
+        let (l, r) = equations.pop_front().unwrap();
+        match (l, r) {
+            (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => {
+                if let Some(old_d) = theta.insert(n, d)
+                    && old_d != d
+                {
+                    return Err(n);
+                }
+            }
+            _ => equations.push_back((l, r)),
+        }
+    }
+    // Second, iterate the rest of the rules until...
+    // 1. The rules are exhausted. All the nodes have device assignments.
+    // 2. No progress is being made. Some nodes may not have device assignments.
+    // 3. An inconsistency has been found. The inconsistency is returned.
     let mut no_progress_iters = 0;
     while no_progress_iters <= equations.len()
         && let Some((l, r)) = equations.pop_front()
     {
-        match (l, r) {
-            (UTerm::Node(_), UTerm::Node(_)) => {
-                if l != r {
-                    equations.push_back((l, r));
-                }
-                no_progress_iters += 1;
-            }
-            (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => {
-                theta.insert(n, d);
-                for (l, r) in equations.iter_mut() {
-                    if *l == UTerm::Node(n) {
-                        *l = UTerm::Device(d);
-                    }
-                    if *r == UTerm::Node(n) {
-                        *r = UTerm::Device(d);
-                    }
-                }
-                no_progress_iters = 0;
-            }
-            (UTerm::Device(d1), UTerm::Device(d2)) if d1 == d2 => {
-                no_progress_iters = 0;
-            }
-            _ => {
-                return Err(theta);
+        let (UTerm::Node(l), UTerm::Node(r)) = (l, r) else {
+            panic!();
+        };
+        match (theta.get(&l), theta.get(&r)) {
+            (Some(ld), Some(rd)) => {
+                if ld != rd {
+                    return Err(l);
+                } else {
+                    no_progress_iters = 0;
+                }
+            }
+            (Some(d), None) | (None, Some(d)) => {
+                let d = *d;
+                theta.insert(l, d);
+                theta.insert(r, d);
+                no_progress_iters = 0;
+            }
+            (None, None) => {
+                equations.push_back((UTerm::Node(l), UTerm::Node(r)));
+                no_progress_iters += 1;
+            }
        }
     }
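As a self-contained illustration of the reworked algorithm, here is a runnable sketch using plain `usize` node ids and a toy `Device` enum in place of the compiler's types. The two passes mirror the structure above: pass 1 consumes direct node = device rules, pass 2 propagates device assignments along node = node rules until fixpoint or stall.

use std::collections::{BTreeMap, VecDeque};

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Device { CPU, CUDA }

#[derive(Clone, Copy, PartialEq, Eq)]
enum UTerm { Node(usize), Device(Device) }

fn unify(mut equations: VecDeque<(UTerm, UTerm)>) -> Result<BTreeMap<usize, Device>, usize> {
    let mut theta = BTreeMap::new();
    // Pass 1: consume node = device rules directly; a conflicting
    // reassignment is an inconsistency rooted at that node.
    for _ in 0..equations.len() {
        let (l, r) = equations.pop_front().unwrap();
        match (l, r) {
            (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => {
                if let Some(old) = theta.insert(n, d) {
                    if old != d { return Err(n); }
                }
            }
            _ => equations.push_back((l, r)),
        }
    }
    // Pass 2: propagate along node = node rules until fixpoint or stall.
    // This sketch assumes only node = node rules remain (the real code
    // panics otherwise).
    let mut stalled = 0;
    while stalled <= equations.len() {
        let Some((UTerm::Node(l), UTerm::Node(r))) = equations.pop_front() else { break };
        match (theta.get(&l).copied(), theta.get(&r).copied()) {
            (Some(ld), Some(rd)) if ld != rd => return Err(l),
            (Some(_), Some(_)) => stalled = 0,
            (Some(d), None) | (None, Some(d)) => {
                theta.insert(l, d);
                theta.insert(r, d);
                stalled = 0;
            }
            (None, None) => {
                equations.push_back((UTerm::Node(l), UTerm::Node(r)));
                stalled += 1;
            }
        }
    }
    Ok(theta)
}

fn main() {
    // node 0 = CPU, node 1 = node 0, node 2 = CUDA => {0: CPU, 1: CPU, 2: CUDA}
    let eqs = VecDeque::from([
        (UTerm::Node(0), UTerm::Device(Device::CPU)),
        (UTerm::Node(1), UTerm::Node(0)),
        (UTerm::Node(2), UTerm::Device(Device::CUDA)),
    ]);
    println!("{:?}", unify(eqs));
}

Note how the error type matches the change above: instead of handing back a partial substitution, an inconsistency now names the offending node, which is what the `Err(id)` handler further down consumes.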
......@@ -1377,8 +1398,8 @@ fn color_nodes(
             } if !editor.get_type(typing[id.idx()]).is_primitive() => {
                 // Every input to a phi needs to be on the same device. The
                 // phi itself is also on this device.
-                for (l, r) in zip(data.into_iter(), data.into_iter().skip(1).chain(once(&id))) {
-                    equations.push((UTerm::Node(*l), UTerm::Node(*r)));
+                for data in data {
+                    equations.push((UTerm::Node(*data), UTerm::Node(id)));
                 }
             }
             Node::Reduce {
......@@ -1394,7 +1415,7 @@ fn color_nodes(
             } if !editor.get_type(typing[id.idx()]).is_primitive() => {
                 // Every input to the reduce, and the reduce itself, are on
                 // the same device.
-                equations.push((UTerm::Node(first), UTerm::Node(second)));
+                equations.push((UTerm::Node(first), UTerm::Node(id)));
                 equations.push((UTerm::Node(second), UTerm::Node(id)));
             }
             Node::Constant { id: _ }
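The effect of these two hunks is easier to see as data: instead of chaining consecutive inputs together, every input is now equated directly with the phi (or reduce) itself, so a failed equation names the node that owns it. A tiny sketch, where `NodeID` is a stand-in newtype:

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct NodeID(usize);

// Star shape: tie each input directly to the phi itself.
fn phi_equations(id: NodeID, data: &[NodeID]) -> Vec<(NodeID, NodeID)> {
    data.iter().map(|d| (*d, id)).collect()
}

fn main() {
    let (a, b, c, phi) = (NodeID(0), NodeID(1), NodeID(2), NodeID(3));
    // Previously a chain a=b, b=c, c=phi; now each input pairs with the phi:
    assert_eq!(phi_equations(phi, &[a, b, c]), vec![(a, phi), (b, phi), (c, phi)]);
}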
......@@ -1533,12 +1554,11 @@ fn color_nodes(
                 }
                 Some(func_colors)
             }
-            Err(progress) => {
+            Err(id) => {
                 // If unification failed, then there's some node using a node in
-                // `progress` that's expecting a different type than what it got.
-                // Pick one and add potentially inter-device copies on each def-use
-                // edge. We'll clean these up later.
-                let (id, _) = progress.into_iter().next().unwrap();
+                // that's expecting a different type than what it got. Add
+                // potentially inter-device copies on each def-use edge. We'll clean
+                // these up later.
                 let users: Vec<_> = editor.get_users(id).collect();
                 let success = editor.edit(|mut edit| {
                     let cons = edit.add_zero_constant(typing[id.idx()]);
......
 use std::collections::{BTreeSet, HashMap, HashSet};
+use std::iter::once;
 
 use hercules_ir::def_use::*;
 use hercules_ir::ir::*;
 
 use crate::*;
......@@ -42,6 +42,10 @@ pub fn infer_parallel_reduce(
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
 ) {
+    let join_fork_map: HashMap<_, _> = fork_join_map
+        .into_iter()
+        .map(|(fork, join)| (*join, *fork))
+        .collect();
     for id in editor.node_ids() {
         let func = editor.func();
         if !func.nodes[id.idx()].is_reduce() {
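The added prologue just inverts the fork-to-join map once so the loop body can look up a join's fork in O(1). A minimal sketch, with `u32` standing in for `NodeID`:

use std::collections::HashMap;

// Build join -> fork from fork -> join; assumes each join has exactly one fork.
fn invert(fork_join_map: &HashMap<u32, u32>) -> HashMap<u32, u32> {
    fork_join_map.iter().map(|(fork, join)| (*join, *fork)).collect()
}

fn main() {
    let fork_join: HashMap<u32, u32> = [(1, 2), (3, 4)].into_iter().collect();
    assert_eq!(invert(&fork_join)[&4], 3);
}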
......@@ -98,40 +102,11 @@
             && *collect == last_reduce
             && !reduce_cycles[&last_reduce].contains(data)
         {
-            // If there is a Write-Reduce tight cycle, get the position indices.
-            let positions = indices
-                .iter()
-                .filter_map(|index| {
-                    if let Index::Position(indices) = index {
-                        Some(indices)
-                    } else {
-                        None
-                    }
-                })
-                .flat_map(|pos| pos.iter());
-
-            // Get the Forks corresponding to uses of bare ThreadIDs.
-            let fork_thread_id_pairs = positions.filter_map(|id| {
-                if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
-                    Some((control, dimension))
-                } else {
-                    None
-                }
-            });
-            let mut forks = HashMap::<NodeID, Vec<usize>>::new();
-            for (fork, dim) in fork_thread_id_pairs {
-                forks.entry(fork).or_default().push(dim);
-            }
-
-            // Check if one of the Forks correspond to the Join associated with
-            // the Reduce being considered, and has all of its dimensions
-            // represented in the indexing.
-            let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| {
-                rep_dims.sort();
-                rep_dims.dedup();
-                fork_join_map[&id] == first_control.unwrap()
-                    && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
-            });
+            let is_parallel = indices_parallel_over_forks(
+                editor,
+                indices,
+                once(join_fork_map[&first_control.unwrap()]),
+            );
 
             if is_parallel {
                 editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
......@@ -145,6 +120,7 @@
  * operands must be the Reduce node, and all other operands must not be in the
  * Reduce node's cycle.
  */
+#[rustfmt::skip]
 pub fn infer_monoid_reduce(
     editor: &mut FunctionEditor,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
......
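`#[rustfmt::skip]` is a standard rustfmt attribute that exempts the item below it from reformatting; presumably the pass's match arms read better hand-aligned. For example:

// rustfmt leaves the array's manual alignment intact under this attribute.
#[rustfmt::skip]
fn identity3() -> [i32; 9] {
    [1, 0, 0,
     0, 1, 0,
     0, 0, 1]
}

fn main() {
    assert_eq!(identity3()[4], 1);
}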
......@@ -532,6 +532,24 @@ where
         let fork_thread_id_pairs = node_indices(indices).filter_map(|id| {
             if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
                 Some((control, dimension))
+            } else if let Node::Binary {
+                op: BinaryOperator::Add,
+                left: tid,
+                right: cons,
+            } = nodes[id.idx()]
+                && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
+                && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
+            {
+                Some((control, dimension))
+            } else if let Node::Binary {
+                op: BinaryOperator::Add,
+                left: cons,
+                right: tid,
+            } = nodes[id.idx()]
+                && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
+                && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
+            {
+                Some((control, dimension))
             } else {
                 None
             }
......
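This hunk teaches the index analysis to treat `tid + constant` (in either operand order of the commutative add) like a bare `ThreadID`. A toy mirror of that pattern over a simplified node enum; the names here are illustrative, not the real IR:

#[derive(Clone, Copy)]
enum Node {
    ThreadID { control: usize, dimension: usize },
    Constant,
    Add { left: usize, right: usize },
}

// Recognize a bare ThreadID, or ThreadID + constant with either operand order.
fn thread_id_of(nodes: &[Node], id: usize) -> Option<(usize, usize)> {
    match nodes[id] {
        Node::ThreadID { control, dimension } => Some((control, dimension)),
        Node::Add { left, right } => {
            // Try both operand orders: tid + cons and cons + tid.
            for (tid, cons) in [(left, right), (right, left)] {
                if let (Node::ThreadID { control, dimension }, Node::Constant) =
                    (nodes[tid], nodes[cons])
                {
                    return Some((control, dimension));
                }
            }
            None
        }
        Node::Constant => None,
    }
}

fn main() {
    // nodes[2] = nodes[1] + nodes[0], where nodes[1] is ThreadID(5, 0).
    let nodes = [
        Node::Constant,
        Node::ThreadID { control: 5, dimension: 0 },
        Node::Add { left: 1, right: 0 },
    ];
    assert_eq!(thread_id_of(&nodes, 2), Some((5, 0)));
}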
......@@ -120,6 +120,7 @@ simpl!(fuse4);
 //fork-tile[2, 0, false, true](fuse4@channel_loop);
 //fork-split(fuse4@channel_loop);
 //clean-monoid-reduces(fuse4);
+unforkify(fuse4@channel_loop);
 no-memset(fuse5@res1);
 no-memset(fuse5@res2);
......
......@@ -86,7 +86,7 @@ fixpoint {
 simpl!(max_gradient);
 fork-dim-merge(max_gradient);
 simpl!(max_gradient);
-fork-tile[8, 0, false, false](max_gradient);
+fork-tile[16, 0, false, false](max_gradient);
 let split = fork-split(max_gradient);
 clean-monoid-reduces(max_gradient);
 let out = outline(split._4_max_gradient.fj1);
......@@ -104,11 +104,18 @@
 }
 predication(reject_zero_crossings);
 simpl!(reject_zero_crossings);
-fork-tile[4, 1, false, false](reject_zero_crossings);
+fork-tile[4, 0, false, false](reject_zero_crossings);
+fork-interchange[1, 2](reject_zero_crossings);
+let split = fork-split(reject_zero_crossings);
+let reject_zero_crossings_body = outline(split._5_reject_zero_crossings.fj2);
+fork-coalesce(reject_zero_crossings, reject_zero_crossings_body);
+simpl!(reject_zero_crossings, reject_zero_crossings_body);
 
 async-call(edge_detection@le, edge_detection@zc);
 
-fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings);
-unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings);
+fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings_body);
+unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings_body);
 simpl!(*);
......
......@@ -7,9 +7,9 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
     @res let result : f32[m + 1];
     result[0] = 1.0;
-    for j in 1..=m {
+    @outer_loop for j in 1..=m {
         let sum = 0.0;
-        for k in 0..=n {
+        @inner_loop for k in 0..=n {
             sum += weights[k, j] * vals[k];
         }
         result[j] = squash(sum);
......
......@@ -15,20 +15,17 @@ delete-uncalled(*);
 no-memset(layer_forward@res);
 lift-dc-math(*);
-loop-bound-canon(*);
-dce(*);
+simpl!(*);
+lift-dc-math(*);
+slf(*);
 fixpoint {
   forkify(*);
   fork-guard-elim(*);
+  fork-coalesce(*);
 }
+reduce-slf(*);
+simpl!(*);
 fork-split(*);
-gvn(*);
-phi-elim(*);
-dce(*);
 unforkify(*);
-gvn(*);
-phi-elim(*);
-dce(*);
 gcm(*);
 gvn(*);
 phi-elim(*);
 dce(*);
 macro simpl!(X) {
   ccp(X);
   simplify-cfg(X);
   lift-dc-math(X);
   gvn(X);
   phi-elim(X);
   dce(X);
   infer-schedules(X);
 }
 
-let outline = auto-outline(bfs);
-gpu(outline.bfs);
 phi-elim(bfs);
+no-memset(bfs@cost);
+let cost_init = outline(bfs@cost_init);
+let loop1 = outline(bfs@loop1);
+let loop2 = outline(bfs@loop2);
+gpu(loop1, loop2);
 
 ip-sroa(*);
 sroa(*);
 dce(*);
 gvn(*);
 phi-elim(*);
 dce(*);
 
 simpl!(*);
 predication(*);
 const-inline(*);
 simpl!(*);
 
 fixpoint {
   forkify(*);
   fork-guard-elim(*);
 }
 simpl!(*);
 predication(*);
 simpl!(*);
-//forkify(*);
-infer-schedules(*);
+unforkify(cost_init);
+parallel-reduce(loop1);
+forkify(*);
+fork-guard-elim(*);
+simpl!(*);
+predication(*);
+reduce-slf(*);
+simpl!(*);
 gcm(*);
 fixpoint {
   float-collections(*);
   dce(*);
   gcm(*);
 }
......@@ -29,6 +29,8 @@ fixpoint {
 }
 simpl!(*);
 fork-interchange[0, 1](loop1);
+reduce-slf(*);
+simpl!(*);
 fork-split(*);
 unforkify(*);
......