Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
......@@ -294,6 +294,9 @@ pub enum Value {
Record {
fields: HashMap<String, Value>,
},
Tuple {
values: Vec<Value>,
},
Everything {},
Selection {
selection: Vec<Value>,
......@@ -371,6 +374,11 @@ impl Value {
"Expected code selection, found record".to_string(),
));
}
Value::Tuple { .. } => {
return Err(SchedulerError::SemanticError(
"Expected code selection, found tuple".to_string(),
));
}
Value::Integer { .. } => {
return Err(SchedulerError::SemanticError(
"Expected code selection, found integer".to_string(),
......@@ -1294,6 +1302,7 @@ fn interp_expr(
| Value::Integer { .. }
| Value::Boolean { .. }
| Value::String { .. }
| Value::Tuple { .. }
| Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())),
Value::JunoFunction { func } => {
match pm.labels.borrow().iter().position(|s| s == field) {
......@@ -1467,6 +1476,28 @@ fn interp_expr(
Ok((Value::Selection { selection: values }, changed))
}
},
ScheduleExp::Tuple { exprs } => {
let mut vals = vec![];
let mut changed = false;
for exp in exprs {
let (val, change) = interp_expr(pm, exp, stringtab, env, functions)?;
vals.push(val);
changed = changed || change;
}
Ok((Value::Tuple { values: vals }, changed))
}
ScheduleExp::TupleField { lhs, field } => {
let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?;
match val {
Value::Tuple { values } if *field < values.len() => {
Ok((vec_take(values, *field), changed))
}
_ => Err(SchedulerError::SemanticError(format!(
"No field at index {}",
field
))),
}
}
}
}
......@@ -1524,6 +1555,14 @@ fn update_value(
Some(Value::Record { fields: new_fields })
}
}
// For tuples, if we deleted values like we do for records this would mess up the indices
// which would behave very strangely. Instead if any field cannot be updated then we
// eliminate the entire value
Value::Tuple { values } => values
.into_iter()
.map(|v| update_value(v, func_idx, juno_func_idx))
.collect::<Option<Vec<_>>>()
.map(|values| Value::Tuple { values }),
Value::JunoFunction { func } => {
juno_func_idx[func.idx]
.clone()
......@@ -2897,6 +2936,247 @@ fn run_pass(
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::ForkReshape => {
let mut shape = vec![];
let mut loops = BTreeSet::new();
let mut fork_count = 0;
for arg in args {
let Value::Tuple { values } = arg else {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: "expected each argument to be a list of integers".to_string(),
});
};
let mut indices = vec![];
for val in values {
let Value::Integer { val: idx } = val else {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: "expected each argument to be a list of integers".to_string(),
});
};
indices.push(idx);
loops.insert(idx);
fork_count += 1;
}
shape.push(indices);
}
if loops != (0..fork_count).collect() {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error:
"expected forks to be numbered sequentially from 0 and used exactly once"
.to_string(),
});
}
let Some((nodes, func_id)) = selection_as_set(pm, selection) else {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: "must be applied to nodes in a single function".to_string(),
});
};
let func = func_id.idx();
pm.make_def_uses();
pm.make_fork_join_maps();
pm.make_loops();
pm.make_reduce_cycles();
let def_uses = pm.def_uses.take().unwrap();
let mut fork_join_maps = pm.fork_join_maps.take().unwrap();
let loops = pm.loops.take().unwrap();
let reduce_cycles = pm.reduce_cycles.take().unwrap();
let def_use = &def_uses[func];
let fork_join_map = &mut fork_join_maps[func];
let loops = &loops[func];
let reduce_cycles = &reduce_cycles[func];
let mut editor = FunctionEditor::new(
&mut pm.functions[func],
func_id,
&pm.constants,
&pm.dynamic_constants,
&pm.types,
&pm.labels,
def_use,
);
// There should be exactly one fork nest in the selection and it should contain
// exactly fork_count forks (counting each dimension of each fork)
// We determine the loops (ordered top-down) that are contained in the selection
// (in particular the header is in the selection) and its a fork-join (the header
// is a fork)
let mut loops = loops
.bottom_up_loops()
.into_iter()
.rev()
.filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork());
let Some((top_fork_head, top_fork_body)) = loops.next() else {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: format!(
"expected {} forks found 0 in {}",
fork_count,
editor.func().name
),
});
};
// All the remaining forks need to be contained in the top fork body
let mut forks = vec![top_fork_head];
let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len();
for (head, _) in loops {
if !top_fork_body[head.idx()] {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: "selection includes multiple non-nested forks".to_string(),
});
} else {
forks.push(head);
num_dims += editor.node(head).try_fork().unwrap().1.len();
}
}
if num_dims != fork_count {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: format!(
"expected {} forks, found {} in {}",
fork_count, num_dims, pm.functions[func].name
),
});
}
// Now, we coalesce all of these forks into one so that we can interchange them
let mut forks = forks.into_iter();
let top_fork = forks.next().unwrap();
let mut cur_fork = top_fork;
for next_fork in forks {
let Some((new_fork, new_join)) =
fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map)
else {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: "failed to coalesce forks".to_string(),
});
};
cur_fork = new_fork;
fork_join_map.insert(new_fork, new_join);
}
let join = *fork_join_map.get(&cur_fork).unwrap();
// Now we have just one fork and we can perform the interchanges we need
// To do this, we track two maps: from original index to current index and from
// current index to original index
let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>();
let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>();
// Now, starting from the first (outermost) index we move the desired fork bound
// into place
for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() {
let cur_idx = orig_to_cur[*original_idx];
let swapping = cur_to_orig[idx];
// If the desired factor is already in the correct place, do nothing
if cur_idx == idx {
continue;
}
assert!(idx < cur_idx);
let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx)
else {
return Err(SchedulerError::PassError {
pass: "fork-reshape".to_string(),
error: "failed to interchange forks".to_string(),
});
};
cur_fork = fork_res;
// Update our maps
orig_to_cur[*original_idx] = idx;
orig_to_cur[swapping] = cur_idx;
cur_to_orig[idx] = *original_idx;
cur_to_orig[cur_idx] = swapping;
}
// Finally we split the fork into the desired pieces. We do this by first splitting
// the fork into individual forks and then coalesce the chunks together
// Not sure how split_fork could fail, so if it does panic is fine
let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap();
for (fork, join) in forks.iter().zip(joins.iter()) {
fork_join_map.insert(*fork, *join);
}
// Finally coalesce the chunks together
let mut fork_idx = 0;
let mut final_forks = vec![];
for chunk in shape.iter() {
let chunk_len = chunk.len();
let mut cur_fork = forks[fork_idx];
for i in 1..chunk_len {
let next_fork = forks[fork_idx + i];
// Again, not sure at this point how coalesce could fail, so panic if it
// does
let (new_fork, new_join) =
fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map)
.unwrap();
cur_fork = new_fork;
fork_join_map.insert(new_fork, new_join);
}
fork_idx += chunk_len;
final_forks.push(cur_fork);
}
// Label each fork and return the labels
// We've trashed our analyses at this point, so rerun them so that we can determine the
// nodes in each of the result fork-joins
pm.clear_analyses();
pm.make_def_uses();
pm.make_nodes_in_fork_joins();
let def_uses = pm.def_uses.take().unwrap();
let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
let def_use = &def_uses[func];
let nodes_in_fork_joins = &nodes_in_fork_joins[func];
let mut editor = FunctionEditor::new(
&mut pm.functions[func],
func_id,
&pm.constants,
&pm.dynamic_constants,
&pm.types,
&pm.labels,
def_use,
);
let labels = create_labels_for_node_sets(
&mut editor,
final_forks
.into_iter()
.map(|fork| nodes_in_fork_joins[&fork].iter().copied()),
)
.into_iter()
.map(|(_, label)| Value::Label {
labels: vec![LabelInfo {
func: func_id,
label,
}],
})
.collect();
result = Value::Tuple { values: labels };
changed = true;
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::WritePredication => {
assert!(args.is_empty());
for func in build_selection(pm, selection, false) {
......@@ -3047,3 +3327,7 @@ where
});
labels
}
fn vec_take<T>(mut v: Vec<T>, index: usize) -> T {
v.swap_remove(index)
}