Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
Commits on Source (6)
Showing
with 632 additions and 128 deletions
......@@ -16,6 +16,7 @@ static NUM_FILLER_REGS: AtomicUsize = AtomicUsize::new(0);
* LLVM bindings for Rust, and we are *not* writing any C++.
*/
pub fn cpu_codegen<W: Write>(
module_name: &str,
function: &Function,
types: &Vec<Type>,
constants: &Vec<Constant>,
......@@ -27,6 +28,7 @@ pub fn cpu_codegen<W: Write>(
w: &mut W,
) -> Result<(), Error> {
let ctx = CPUContext {
module_name,
function,
types,
constants,
......@@ -40,6 +42,7 @@ pub fn cpu_codegen<W: Write>(
}
struct CPUContext<'a> {
module_name: &'a str,
function: &'a Function,
types: &'a Vec<Type>,
constants: &'a Vec<Constant>,
......@@ -65,16 +68,18 @@ impl<'a> CPUContext<'a> {
if self.types[return_type.idx()].is_primitive() {
write!(
w,
"define dso_local {} @{}(",
"define dso_local {} @{}_{}(",
self.get_type(return_type),
self.function.name
self.module_name,
self.function.name,
)?;
} else {
write!(
w,
"define dso_local nonnull noundef {} @{}(",
"define dso_local nonnull noundef {} @{}_{}(",
self.get_type(return_type),
self.function.name
self.module_name,
self.function.name,
)?;
}
} else {
......@@ -89,7 +94,11 @@ impl<'a> CPUContext<'a> {
.collect::<Vec<_>>()
.join(", "),
)?;
write!(w, "define dso_local void @{}(", self.function.name,)?;
write!(
w,
"define dso_local void @{}_{}(",
self.module_name, self.function.name,
)?;
}
let mut first_param = true;
// The first parameter is a pointer to CPU backing memory, if it's
......
......@@ -14,6 +14,7 @@ use crate::*;
* of similarities with the CPU LLVM generation plus custom GPU parallelization.
*/
pub fn gpu_codegen<W: Write>(
module_name: &str,
function: &Function,
types: &Vec<Type>,
constants: &Vec<Constant>,
......@@ -170,6 +171,7 @@ pub fn gpu_codegen<W: Write>(
};
let ctx = GPUContext {
module_name,
function,
types,
constants,
......@@ -199,6 +201,7 @@ struct GPUKernelParams {
}
struct GPUContext<'a> {
module_name: &'a str,
function: &'a Function,
types: &'a Vec<Type>,
constants: &'a Vec<Constant>,
......@@ -395,8 +398,8 @@ namespace cg = cooperative_groups;
fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> {
write!(
w,
"__global__ void __launch_bounds__({}) {}_gpu(",
self.kernel_params.max_num_threads, self.function.name
"__global__ void __launch_bounds__({}) {}_{}_gpu(",
self.kernel_params.max_num_threads, self.module_name, self.function.name
)?;
let mut first_param = true;
// The first parameter is a pointer to GPU backing memory, if it's
......@@ -645,7 +648,7 @@ namespace cg = cooperative_groups;
} else {
write!(w, "{}", self.get_type(self.function.return_types[0], false))?;
}
write!(w, " {}(", self.function.name)?;
write!(w, " {}_{}(", self.module_name, self.function.name)?;
let mut first_param = true;
// The first parameter is a pointer to GPU backing memory, if it's
......@@ -721,8 +724,13 @@ namespace cg = cooperative_groups;
write!(w, "\tcudaError_t err;\n")?;
write!(
w,
"\t{}_gpu<<<{}, {}, {}>>>({});\n",
self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args
"\t{}_{}_gpu<<<{}, {}, {}>>>({});\n",
self.module_name,
self.function.name,
num_blocks,
num_threads,
dynamic_shared_offset,
pass_args
)?;
write!(w, "\terr = cudaGetLastError();\n")?;
write!(
......
......@@ -74,6 +74,7 @@ use crate::*;
* set some CUDA memory - the user can then take a CUDA reference to that box.
*/
pub fn rt_codegen<W: Write>(
module_name: &str,
func_id: FunctionID,
module: &Module,
def_use: &ImmutableDefUseMap,
......@@ -96,6 +97,7 @@ pub fn rt_codegen<W: Write>(
.map(|(fork, join)| (*join, *fork))
.collect();
let ctx = RTContext {
module_name,
func_id,
module,
def_use,
......@@ -117,6 +119,7 @@ pub fn rt_codegen<W: Write>(
}
struct RTContext<'a> {
module_name: &'a str,
func_id: FunctionID,
module: &'a Module,
def_use: &'a ImmutableDefUseMap,
......@@ -157,7 +160,8 @@ impl<'a> RTContext<'a> {
// Dump the function signature.
write!(
w,
"#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}(",
"#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}_{}(",
self.module_name,
func.name
)?;
let mut first_param = true;
......@@ -236,7 +240,7 @@ impl<'a> RTContext<'a> {
// Create the return struct
write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?;
// Call the device function
write!(w, "{}(", callee.name)?;
write!(w, "{}_{}(", self.module_name, callee.name)?;
if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()])
{
write!(w, "backing, ")?;
......@@ -672,8 +676,9 @@ impl<'a> RTContext<'a> {
};
write!(
block,
"{}{}(",
"{}{}_{}(",
prefix,
self.module_name,
self.module.functions[callee_id.idx()].name
)?;
for (device, (offset, size)) in self.backing_allocations[&self.func_id]
......@@ -1463,7 +1468,7 @@ impl<'a> RTContext<'a> {
}
// Call the wrapped function.
write!(w, "let ret = {}(", func.name)?;
write!(w, "let ret = {}_{}(", self.module_name, func.name)?;
for (device, _) in self.backing_allocations[&self.func_id].iter() {
write!(
w,
......@@ -1630,8 +1635,9 @@ impl<'a> RTContext<'a> {
let func = &self.module.functions[func_id.idx()];
write!(
w,
"{}fn {}(",
"{}fn {}_{}(",
if is_unsafe { "unsafe " } else { "" },
self.module_name,
func.name
)?;
let mut first_param = true;
......@@ -1667,7 +1673,7 @@ impl<'a> RTContext<'a> {
func_id: FunctionID,
) -> Result<(), Error> {
let func = &self.module.functions[func_id.idx()];
write!(w, "fn {}(", func.name)?;
write!(w, "fn {}_{}(", self.module_name, func.name)?;
let mut first_param = true;
if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
first_param = false;
......
......@@ -1048,9 +1048,20 @@ impl Constant {
}
}
/*
* Useful for GVN.
*/
/*
 * Returns true only when this constant is the boolean literal `false`; every
 * other constant (including non-boolean ones) yields false.
 */
pub fn is_false(&self) -> bool {
    matches!(self, Constant::Boolean(false))
}
/*
 * Returns true only when this constant is the boolean literal `true`; every
 * other constant (including non-boolean ones) yields false.
 */
pub fn is_true(&self) -> bool {
    matches!(self, Constant::Boolean(true))
}
pub fn is_zero(&self) -> bool {
match self {
Constant::Integer8(0) => true,
......
......@@ -880,6 +880,18 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
/*
 * Parameter types as seen by this edit: the pending override stored in
 * `updated_param_types` when one has been set, otherwise the parameter types
 * of the function currently held by the editor.
 */
pub fn get_param_types(&self) -> &Vec<TypeID> {
    match &self.updated_param_types {
        Some(pending) => pending,
        None => &self.editor.function.param_types,
    }
}
/*
 * Return types as seen by this edit: the pending override stored in
 * `updated_return_types` when one is present, otherwise the return types of
 * the function currently held by the editor.
 */
pub fn get_return_types(&self) -> &Vec<TypeID> {
    match &self.updated_return_types {
        Some(pending) => pending,
        None => &self.editor.function.return_types,
    }
}
/*
 * Record `tys` as the pending parameter types for this edit, replacing any
 * previously recorded override. `get_param_types` returns this value from now
 * on.
 */
pub fn set_param_types(&mut self, tys: Vec<TypeID>) {
    let pending = Some(tys);
    self.updated_param_types = pending;
}
......
use std::cell::Ref;
use std::collections::HashMap;
use hercules_ir::callgraph::*;
use hercules_ir::def_use::*;
use hercules_ir::ir::*;
use hercules_ir::*;
use crate::*;
......@@ -235,3 +234,216 @@ fn inline_func(
});
}
}
/*
 * What is known about the argument passed for one parameter across the call
 * sites examined so far. Values are built per argument with `from_node` and
 * combined with `meet`.
 */
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
enum ParameterLattice {
// No constraint yet. `from_node` yields this for `Node::Undef` arguments, and
// `meet` treats it as the identity element.
Top,
// The argument is the `Node::Constant` with this constant ID.
Constant(ConstantID),
// The argument is the `Node::DynamicConstant` with this dynamic constant ID.
// The function ID is the one handed to `from_node` alongside the node.
DynamicConstant(DynamicConstantID, FunctionID),
// The argument is some other kind of node, or call sites disagree. `meet`
// treats this as absorbing once neither side is `Top`.
Bottom,
}
impl ParameterLattice {
fn from_node(node: &Node, func_id: FunctionID) -> Self {
use ParameterLattice::*;
match node {
Node::Undef { ty: _ } => Top,
Node::Constant { id } => Constant(*id),
Node::DynamicConstant { id } => DynamicConstant(*id, func_id),
_ => Bottom,
}
}
fn meet(&mut self, b: Self, cons: Ref<'_, Vec<Constant>>, dcs: Ref<'_, Vec<DynamicConstant>>) {
use ParameterLattice::*;
*self = match (*self, b) {
(Top, b) => b,
(a, Top) => a,
(Bottom, _) | (_, Bottom) => Bottom,
(Constant(id_a), Constant(id_b)) => {
if id_a == id_b {
Constant(id_a)
} else {
Bottom
}
}
(DynamicConstant(dc_a, f_a), DynamicConstant(dc_b, f_b)) => {
if dc_a == dc_b && f_a == f_b {
DynamicConstant(dc_a, f_a)
} else if let (
ir::DynamicConstant::Constant(dcv_a),
ir::DynamicConstant::Constant(dcv_b),
) = (&dcs[dc_a.idx()], &dcs[dc_b.idx()])
&& *dcv_a == *dcv_b
{
DynamicConstant(dc_a, f_a)
} else {
Bottom
}
}
(DynamicConstant(dc, _), Constant(con)) | (Constant(con), DynamicConstant(dc, _)) => {
match (&cons[con.idx()], &dcs[dc.idx()]) {
(ir::Constant::UnsignedInteger64(conv), ir::DynamicConstant::Constant(dcv))
if *conv as usize == *dcv =>
{
Constant(con)
}
_ => Bottom,
}
}
}
}
}
/*
 * Top level function to inline constant parameters and constant dynamic
 * constant parameters. Identifies functions that are:
 *
 * 1. Not marked as entry (functions without any caller are skipped as well).
 * 2. At every call site, a particular parameter is always a specific constant
 *    or dynamic constant.
 *
 * These functions can have that constant "inlined" - the parameter is removed
 * and all uses of the parameter become uses of the constant directly.
 *
 * `editors` is indexed by function ID. When `inline_collections` is false,
 * only parameters whose type is primitive are inlined. If any call site of a
 * function is not mutable in its caller, that function is left untouched.
 *
 * Panics if a caller's call node cannot be rewritten after the callee's
 * parameters have already been removed.
 */
pub fn const_inline(
    editors: &mut [FunctionEditor],
    callgraph: &CallGraph,
    inline_collections: bool,
) {
    // Run const inlining on each function, starting at the most shallow
    // function first, since we want to propagate constants down the call graph.
    for func_id in callgraph.topo().into_iter().rev() {
        let func = editors[func_id.idx()].func();
        if func.entry || callgraph.num_callers(func_id) == 0 {
            continue;
        }

        // Figure out what we know about the parameters to this function.
        let mut param_lattice = vec![ParameterLattice::Top; func.param_types.len()];
        let mut callers = vec![];
        for caller in callgraph.get_callers(func_id) {
            let editor = &editors[caller.idx()];
            let nodes = &editor.func().nodes;
            for id in editor.node_ids() {
                if let Some((_, callee, _, args)) = nodes[id.idx()].try_call()
                    && callee == func_id
                {
                    if editor.is_mutable(id) {
                        for (idx, id) in args.into_iter().enumerate() {
                            let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
                            param_lattice[idx].meet(
                                lattice,
                                editor.get_constants(),
                                editor.get_dynamic_constants(),
                            );
                        }
                    } else {
                        // If we can't modify the call node in the caller, then
                        // we can't perform the inlining. Bottom is absorbing in
                        // `meet`, so later call sites cannot undo this.
                        param_lattice = vec![ParameterLattice::Bottom; func.param_types.len()];
                    }
                    callers.push((caller, id));
                }
            }
        }
        if param_lattice.iter().all(|v| *v == ParameterLattice::Bottom) {
            continue;
        }

        // Replace the arguments. First collect every parameter node, grouped
        // by the parameter index it reads.
        let editor = &mut editors[func_id.idx()];
        let mut param_idx_to_ids: HashMap<usize, Vec<NodeID>> = HashMap::new();
        for id in editor.node_ids() {
            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
                param_idx_to_ids.entry(idx).or_default().push(id);
            }
        }
        let mut params_to_remove = vec![];
        let success = editor.edit(|mut edit| {
            let mut param_tys = edit.get_param_types().clone();
            // Number of parameters removed so far; `idx - decrement_index_by`
            // is the position of original parameter `idx` in `param_tys`.
            let mut decrement_index_by = 0;
            for idx in 0..param_tys.len() {
                if (inline_collections
                    || edit
                        .get_type(param_tys[idx - decrement_index_by])
                        .is_primitive())
                    && let Some(node) = match param_lattice[idx] {
                        ParameterLattice::Top => Some(Node::Undef {
                            ty: param_tys[idx - decrement_index_by],
                        }),
                        ParameterLattice::Constant(id) => Some(Node::Constant { id }),
                        ParameterLattice::DynamicConstant(id, _) => {
                            // Rust moment: the lookup result is bound first,
                            // presumably so its borrow of `edit` ends before
                            // `add_dynamic_constant` is called.
                            let maybe_cons = edit.get_dynamic_constant(id).try_constant();
                            if let Some(val) = maybe_cons {
                                Some(Node::DynamicConstant {
                                    id: edit.add_dynamic_constant(DynamicConstant::Constant(val)),
                                })
                            } else {
                                // Only literal dynamic constants are inlined.
                                None
                            }
                        }
                        _ => None,
                    }
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // Inline: every node reading this parameter now reads the
                    // constant, and the parameter itself is dropped.
                    let node = edit.add_node(node);
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                    param_tys.remove(idx - decrement_index_by);
                    params_to_remove.push(idx);
                    decrement_index_by += 1;
                } else if decrement_index_by != 0
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // Kept parameter that follows a removed one: renumber it.
                    let node = edit.add_node(Node::Parameter {
                        index: idx - decrement_index_by,
                    });
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                }
            }
            edit.set_param_types(param_tys);
            Ok(edit)
        });
        // `Vec::remove` shifts later elements left, so arguments are removed
        // from the highest index down to keep the recorded indices valid.
        params_to_remove.reverse();

        // Update callers. Nothing needs rewriting when no parameter was
        // removed; rebuilding identical call nodes would only churn the
        // callers.
        if success && !params_to_remove.is_empty() {
            for (caller, call) in callers {
                let editor = &mut editors[caller.idx()];
                let success = editor.edit(|mut edit| {
                    let Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args,
                    } = edit.get_node(call).clone()
                    else {
                        panic!("const_inline: recorded call site is not a call node");
                    };
                    let mut args = args.into_vec();
                    for idx in params_to_remove.iter() {
                        args.remove(*idx);
                    }
                    let node = edit.add_node(Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args: args.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(call, node)?;
                    edit = edit.delete_node(call)?;
                    Ok(edit)
                });
                // The callee has already lost the parameters, so a caller that
                // cannot be rewritten would leave the module inconsistent.
                assert!(success);
            }
        }
    }
}
......@@ -73,6 +73,7 @@ pub fn canonicalize_single_loop_bounds(
.into_iter()
.partition(|f| loop_bound_iv_phis.contains(&f.phi()));
// Assume there is only one loop bound iv.
if loop_bound_ivs.len() != 1 {
return false;
......@@ -93,9 +94,6 @@ pub fn canonicalize_single_loop_bounds(
return false;
};
let Some(final_value) = final_value else {
return false;
};
let Some(loop_pred) = editor
.get_uses(l.header)
......@@ -109,8 +107,23 @@ pub fn canonicalize_single_loop_bounds(
// (init_id, bound_id, binop node, if node).
// FIXME: This is not always correct, depends on lots of things about the loop IV.
let loop_bound_dc = match *editor.node(condition_node) {
Node::Binary { left, right, op } => match op {
BinaryOperator::LT => right,
BinaryOperator::LTE => right,
BinaryOperator::GT => {return false}
BinaryOperator::GTE => {return false}
BinaryOperator::EQ => {return false}
BinaryOperator::NE => {return false}
_ => {return false}
},
_ => {return false}
};
// FIXME: This is quite fragile.
let guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| {
let mut guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| {
let Node::ControlProjection {
control,
selection: _,
......@@ -119,7 +132,7 @@ pub fn canonicalize_single_loop_bounds(
return None;
};
let Node::If { control, cond } = editor.node(control) else {
let Node::If { cond, ..} = editor.node(control) else {
return None;
};
......@@ -129,7 +142,7 @@ pub fn canonicalize_single_loop_bounds(
let Node::Binary {
left: _,
right: _,
right: r,
op: loop_op,
} = editor.node(condition_node)
else {
......@@ -144,7 +157,7 @@ pub fn canonicalize_single_loop_bounds(
return None;
}
if right != final_value {
if right != r {
return None;
}
......@@ -169,7 +182,7 @@ pub fn canonicalize_single_loop_bounds(
// We are assuming this is a simple loop bound (i.e., only one induction variable is involved).
let Node::DynamicConstant {
id: loop_bound_dc_id,
} = *editor.node(final_value)
} = *editor.node(loop_bound_dc)
else {
return false;
};
......@@ -177,9 +190,9 @@ pub fn canonicalize_single_loop_bounds(
// We need to do 4 (5) things, which are mostly separate.
// 0) Make the update into addition.
// 1) Make the update a positive value.
// 2) Transform the condition into a `<`
// 3) Adjust update to be 1 (and bounds).
// 1) Adjust update to be 1 (and bounds).
// 2) Make the update a positive value. / Transform the condition into a `<`
// - Are these separate?
// 4) Change init to start from 0.
// 5) Find some way to get fork-guard-elim to work with the new fork.
......@@ -198,7 +211,13 @@ pub fn canonicalize_single_loop_bounds(
return false;
}
}
BinaryOperator::LTE => todo!(),
BinaryOperator::LTE => {
if left == *update_expression && editor.node(right).is_dynamic_constant() {
right
} else {
return false;
}
}
BinaryOperator::GT => todo!(),
BinaryOperator::GTE => todo!(),
BinaryOperator::EQ => todo!(),
......@@ -211,8 +230,10 @@ pub fn canonicalize_single_loop_bounds(
_ => return false,
};
let condition_node_data = editor.node(condition_node).clone();
let Node::DynamicConstant {
id: bound_node_dc_id,
id: mut bound_node_dc_id,
} = *editor.node(dc_bound_node)
else {
return false;
......@@ -220,7 +241,56 @@ pub fn canonicalize_single_loop_bounds(
// If increment is negative (how in the world do we know that...)
// Increment can be DefinitelyPositive, Unknown, DefinitelyNegative.
let misc_guard_thing: Option<Node> = if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
Some(editor.node(binop_node).clone())
} else {
None
};
let mut condition_node = condition_node;
let result = editor.edit(|mut edit| {
// 2) Transform the condition into a < (from <=)
if let Node::Binary { left, right, op } = condition_node_data {
if BinaryOperator::LTE == op && left == *update_expression {
// Change the condition into <
let new_bop = edit.add_node(Node::Binary { left, right, op: BinaryOperator::LT });
// Change the bound dc to be bound_dc + 1
let one = DynamicConstant::Constant(1);
let one = edit.add_dynamic_constant(one);
let tmp = DynamicConstant::add(bound_node_dc_id, one);
let new_condition_dc = edit.add_dynamic_constant(tmp);
let new_dc_bound_node = edit.add_node(Node::DynamicConstant { id: new_condition_dc });
// // 5) Change loop guard:
guard_info = if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
// Change binop node
let Some(Node::Binary { left, right, op }) = misc_guard_thing else {unreachable!()};
let blah = edit.add_node(Node::DynamicConstant { id: new_condition_dc});
// FIXME: Don't assume that right is the loop bound in the guard.
let new_binop_node = edit.add_node(Node::Binary { left, right: blah, op: BinaryOperator::LT });
edit = edit.replace_all_uses_where(binop_node, new_binop_node, |usee| *usee == if_node)?;
Some((init_id, bound_id, new_binop_node, if_node))
} else {guard_info};
edit = edit.replace_all_uses_where(dc_bound_node, new_dc_bound_node, |usee| *usee == new_bop)?;
edit = edit.replace_all_uses(condition_node, new_bop)?;
// Change loop condition
dc_bound_node = new_dc_bound_node;
bound_node_dc_id = new_condition_dc;
condition_node = new_bop;
}
};
Ok(edit)
});
let update_expr_users: Vec<_> = editor
.get_users(*update_expression)
.filter(|node| *node != iv.phi() && *node != condition_node)
......@@ -241,34 +311,23 @@ pub fn canonicalize_single_loop_bounds(
let new_init = edit.add_node(new_init);
edit = edit.replace_all_uses_where(*initializer, new_init, |usee| *usee == iv.phi())?;
let new_condition_id = DynamicConstant::sub(bound_node_dc_id, init_dc_id);
let new_condition = Node::DynamicConstant {
id: edit.add_dynamic_constant(new_condition_id),
let new_condition_dc = DynamicConstant::sub(bound_node_dc_id, init_dc_id);
let new_condition_dc_id = Node::DynamicConstant {
id: edit.add_dynamic_constant(new_condition_dc),
};
let new_condition = edit.add_node(new_condition);
let new_condition_dc = edit.add_node(new_condition_dc_id);
edit = edit
.replace_all_uses_where(dc_bound_node, new_condition, |usee| *usee == condition_node)?;
.replace_all_uses_where(dc_bound_node, new_condition_dc, |usee| *usee == condition_node)?;
// Change loop guard:
// 5) Change loop guard:
if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
edit = edit.replace_all_uses_where(init_id, new_init, |usee| *usee == binop_node)?;
edit =
edit.replace_all_uses_where(bound_id, new_condition, |usee| *usee == binop_node)?;
edit.replace_all_uses_where(bound_id, new_condition_dc, |usee| *usee == binop_node)?;
}
// for user in update_expr_users {
// let new_user = Node::Binary {
// left: user,
// right: *initializer,
// op: BinaryOperator::Add,
// };
// let new_user = edit.add_node(new_user);
// edit = edit.replace_all_uses(user, new_user)?;
// }
// for
// Add the offset back to users of the IV update expression
// 4) Add the offset back to users of the IV update expression
let new_user = Node::Binary {
left: *update_expression,
right: *initializer,
......
......@@ -136,6 +136,77 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
bad_branches.insert(branch);
}
}
// Do a quick and dirty rewrite of selects with a constant boolean arm:
// select(a, b, false) to a && b, select(a, true, b) to a || b,
// select(a, false, b) to !a && b, and select(a, b, true) to !a || b.
for id in editor.node_ids() {
let nodes = &editor.func().nodes;
if let Node::Ternary {
op: TernaryOperator::Select,
first,
second,
third,
} = nodes[id.idx()]
{
if let Some(cons) = nodes[second.idx()].try_constant()
&& editor.get_constant(cons).is_false()
{
editor.edit(|mut edit| {
let inv = edit.add_node(Node::Unary {
op: UnaryOperator::Not,
input: first,
});
let node = edit.add_node(Node::Binary {
op: BinaryOperator::And,
left: inv,
right: third,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
} else if let Some(cons) = nodes[third.idx()].try_constant()
&& editor.get_constant(cons).is_false()
{
editor.edit(|mut edit| {
let node = edit.add_node(Node::Binary {
op: BinaryOperator::And,
left: first,
right: second,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
} else if let Some(cons) = nodes[second.idx()].try_constant()
&& editor.get_constant(cons).is_true()
{
editor.edit(|mut edit| {
let node = edit.add_node(Node::Binary {
op: BinaryOperator::Or,
left: first,
right: third,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
} else if let Some(cons) = nodes[third.idx()].try_constant()
&& editor.get_constant(cons).is_true()
{
editor.edit(|mut edit| {
let inv = edit.add_node(Node::Unary {
op: UnaryOperator::Not,
input: first,
});
let node = edit.add_node(Node::Binary {
op: BinaryOperator::Or,
left: inv,
right: second,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
}
}
}
}
/*
......
......@@ -69,6 +69,26 @@ pub fn infer_parallel_reduce(
chain_id = reduct;
}
// If the use is a phi that uses the reduce and a write, then we might
// want to parallelize this still. Set the chain ID to the write.
if let Node::Phi {
control: _,
ref data,
} = func.nodes[chain_id.idx()]
&& data.len()
== data
.into_iter()
.filter(|phi_use| **phi_use == last_reduce)
.count()
+ 1
{
chain_id = *data
.into_iter()
.filter(|phi_use| **phi_use != last_reduce)
.next()
.unwrap();
}
// Check for a Write-Reduce tight cycle.
if let Node::Write {
collect,
......@@ -130,12 +150,13 @@ pub fn infer_monoid_reduce(
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
let is_binop_monoid = |op| {
matches!(
op,
BinaryOperator::Add | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And
)
op == BinaryOperator::Add
|| op == BinaryOperator::Mul
|| op == BinaryOperator::Or
|| op == BinaryOperator::And
};
let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);
let is_intrinsic_monoid =
|intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min;
for id in editor.node_ids() {
let func = editor.func();
......
......@@ -4,7 +4,7 @@ fn squash(x: f32) -> f32 {
}
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
let result : f32[m + 1];
@res let result : f32[m + 1];
result[0] = 1.0;
for j in 1..=m {
......
gvn(*);
dce(*);
phi-elim(*);
dce(*);
crc(*);
dce(*);
slf(*);
dce(*);
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
let auto = auto-outline(backprop);
cpu(auto.backprop);
inline(auto.backprop);
inline(auto.backprop);
simpl!(*);
inline(layer_forward);
delete-uncalled(*);
sroa[true](*);
dce(*);
float-collections(*);
reuse-products(*);
no-memset(layer_forward@res);
lift-dc-math(*);
loop-bound-canon(*);
dce(*);
lift-dc-math(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
fork-split(*);
gvn(*);
phi-elim(*);
dce(*);
unforkify(*);
gvn(*);
phi-elim(*);
dce(*);
gcm(*);
......@@ -13,6 +13,8 @@ fn main() {
JunoCompiler::new()
.file_in_src("bfs.jn")
.unwrap()
.schedule_in_src("cpu.sch")
.unwrap()
.build()
.unwrap();
}
......@@ -13,8 +13,8 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
let visited: bool[n];
visited[source as u64] = true;
let cost: i32[n];
for i in 0..n {
@cost @cost_init let cost: i32[n];
@cost_init for i in 0..n {
cost[i] = -1;
}
cost[source as u64] = 0;
......@@ -25,7 +25,7 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
while !stop {
stop = true;
for i in 0..n {
@loop1 for i in 0..n {
if mask[i] {
mask[i] = false;
......@@ -42,11 +42,11 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
}
}
for i in 0..n {
@loop2 for i in 0..n {
stop = stop && !updated[i];
if updated[i] {
mask[i] = true;
visited[i] = true;
stop = false;
updated[i] = false;
}
}
......
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
phi-elim(bfs);
no-memset(bfs@cost);
outline(bfs@cost_init);
let loop1 = outline(bfs@loop1);
let loop2 = outline(bfs@loop2);
simpl!(*);
predication(*);
const-inline(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
}
simpl!(*);
predication(*);
simpl!(*);
unforkify(*);
gcm(*);
gvn(*);
dce(*);
phi-elim(*);
dce(*);
crc(*);
dce(*);
slf(*);
dce(*);
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
crc(X);
slf(X);
dce(X);
infer-schedules(X);
}
let auto = auto-outline(euler);
cpu(auto.euler);
inline(auto.euler);
inline(auto.euler);
simpl!(*);
inline(compute_step_factor, compute_flux, compute_flux_contribution, time_step);
delete-uncalled(*);
simpl!(*);
ip-sroa[false](*);
sroa[false](*);
predication(*);
const-inline(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
}
simpl!(*);
no-memset(compute_step_factor@res, compute_flux@res, copy_vars@res);
parallel-reduce(time_step, copy_vars, compute_flux@outer_loop \ compute_flux@inner_loop);
sroa[false](auto.euler);
dce(*);
float-collections(*);
dce(*);
unforkify(*);
gcm(*);
gvn(*);
dce(*);
phi-elim(*);
dce(*);
crc(*);
dce(*);
slf(*);
dce(*);
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
crc(X);
slf(X);
dce(X);
infer-schedules(X);
}
let auto = auto-outline(pre_euler);
cpu(auto.pre_euler);
inline(auto.pre_euler);
inline(auto.pre_euler);
simpl!(*);
inline(compute_step_factor, compute_flux, compute_flux_contributions, compute_flux_contribution, time_step);
delete-uncalled(*);
simpl!(*);
ip-sroa[false](*);
sroa[false](*);
predication(*);
const-inline(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
}
simpl!(*);
sroa[false](auto.pre_euler);
dce(*);
float-collections(*);
dce(*);
unforkify(*);
gcm(*);
......@@ -47,7 +47,7 @@ fn compute_speed_of_sound(density: f32, pressure: f32) -> f32 {
}
fn compute_step_factor<nelr: usize>(variables: Variables::<nelr>, areas: f32[nelr]) -> f32[nelr] {
let step_factors : f32[nelr];
@res let step_factors : f32[nelr];
for i in 0..nelr {
let density = variables.density[i];
......@@ -106,9 +106,9 @@ fn compute_flux<nelr: usize>(
ff_flux_contribution_momentum_z: float3,
) -> Variables::<nelr> {
const smoothing_coefficient : f32 = 0.2;
let fluxes: Variables::<nelr>;
@res let fluxes: Variables::<nelr>;
for i in 0..nelr {
@outer_loop for i in 0..nelr {
let density_i = variables.density[i];
let momentum_i = float3 { x: variables.momentum.x[i],
......@@ -131,7 +131,7 @@ fn compute_flux<nelr: usize>(
let flux_i_momentum = float3 { x: 0.0, y: 0.0, z: 0.0 };
let flux_i_density_energy : f32 = 0.0;
for j in 0..NNB {
@inner_loop for j in 0..NNB {
let nb = elements_surrounding_elements[j, i];
let normal = float3 {
x: normals.x[j, i],
......@@ -249,7 +249,7 @@ fn time_step<nelr: usize>(
}
fn copy_vars<nelr: usize>(variables: Variables::<nelr>) -> Variables::<nelr> {
let result : Variables::<nelr>;
@res let result : Variables::<nelr>;
for i in 0..nelr {
result.density[i] = variables.density[i];
......
......@@ -13,8 +13,8 @@ fn srad_bench(c: &mut Criterion) {
let mut r = runner!(srad);
let niter = 100;
let lambda = 0.5;
let nrows = 502;
let ncols = 458;
let nrows = 512;
let ncols = 512;
let image = "data/image.pgm".to_string();
let Image {
image: image_ori,
......
......@@ -13,6 +13,8 @@ fn main() {
JunoCompiler::new()
.file_in_src("srad.jn")
.unwrap()
.schedule_in_src("cpu.sch")
.unwrap()
.build()
.unwrap();
}
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
phi-elim(*);
let loop1 = outline(srad@loop1);
let loop2 = outline(srad@loop2);
let loop3 = outline(srad@loop3);
simpl!(*);
const-inline(*);
crc(*);
slf(*);
write-predication(*);
simpl!(*);
predication(*);
simpl!(*);
predication(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
simpl!(*);
fork-interchange[0, 1](loop1);
fork-split(*);
unforkify(*);
gcm(*);