Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
Commits on Source (6)
Showing
with 632 additions and 128 deletions
...@@ -16,6 +16,7 @@ static NUM_FILLER_REGS: AtomicUsize = AtomicUsize::new(0); ...@@ -16,6 +16,7 @@ static NUM_FILLER_REGS: AtomicUsize = AtomicUsize::new(0);
* LLVM bindings for Rust, and we are *not* writing any C++. * LLVM bindings for Rust, and we are *not* writing any C++.
*/ */
pub fn cpu_codegen<W: Write>( pub fn cpu_codegen<W: Write>(
module_name: &str,
function: &Function, function: &Function,
types: &Vec<Type>, types: &Vec<Type>,
constants: &Vec<Constant>, constants: &Vec<Constant>,
...@@ -27,6 +28,7 @@ pub fn cpu_codegen<W: Write>( ...@@ -27,6 +28,7 @@ pub fn cpu_codegen<W: Write>(
w: &mut W, w: &mut W,
) -> Result<(), Error> { ) -> Result<(), Error> {
let ctx = CPUContext { let ctx = CPUContext {
module_name,
function, function,
types, types,
constants, constants,
...@@ -40,6 +42,7 @@ pub fn cpu_codegen<W: Write>( ...@@ -40,6 +42,7 @@ pub fn cpu_codegen<W: Write>(
} }
struct CPUContext<'a> { struct CPUContext<'a> {
module_name: &'a str,
function: &'a Function, function: &'a Function,
types: &'a Vec<Type>, types: &'a Vec<Type>,
constants: &'a Vec<Constant>, constants: &'a Vec<Constant>,
...@@ -65,16 +68,18 @@ impl<'a> CPUContext<'a> { ...@@ -65,16 +68,18 @@ impl<'a> CPUContext<'a> {
if self.types[return_type.idx()].is_primitive() { if self.types[return_type.idx()].is_primitive() {
write!( write!(
w, w,
"define dso_local {} @{}(", "define dso_local {} @{}_{}(",
self.get_type(return_type), self.get_type(return_type),
self.function.name self.module_name,
self.function.name,
)?; )?;
} else { } else {
write!( write!(
w, w,
"define dso_local nonnull noundef {} @{}(", "define dso_local nonnull noundef {} @{}_{}(",
self.get_type(return_type), self.get_type(return_type),
self.function.name self.module_name,
self.function.name,
)?; )?;
} }
} else { } else {
...@@ -89,7 +94,11 @@ impl<'a> CPUContext<'a> { ...@@ -89,7 +94,11 @@ impl<'a> CPUContext<'a> {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
)?; )?;
write!(w, "define dso_local void @{}(", self.function.name,)?; write!(
w,
"define dso_local void @{}_{}(",
self.module_name, self.function.name,
)?;
} }
let mut first_param = true; let mut first_param = true;
// The first parameter is a pointer to CPU backing memory, if it's // The first parameter is a pointer to CPU backing memory, if it's
......
...@@ -14,6 +14,7 @@ use crate::*; ...@@ -14,6 +14,7 @@ use crate::*;
* of similarities with the CPU LLVM generation plus custom GPU parallelization. * of similarities with the CPU LLVM generation plus custom GPU parallelization.
*/ */
pub fn gpu_codegen<W: Write>( pub fn gpu_codegen<W: Write>(
module_name: &str,
function: &Function, function: &Function,
types: &Vec<Type>, types: &Vec<Type>,
constants: &Vec<Constant>, constants: &Vec<Constant>,
...@@ -170,6 +171,7 @@ pub fn gpu_codegen<W: Write>( ...@@ -170,6 +171,7 @@ pub fn gpu_codegen<W: Write>(
}; };
let ctx = GPUContext { let ctx = GPUContext {
module_name,
function, function,
types, types,
constants, constants,
...@@ -199,6 +201,7 @@ struct GPUKernelParams { ...@@ -199,6 +201,7 @@ struct GPUKernelParams {
} }
struct GPUContext<'a> { struct GPUContext<'a> {
module_name: &'a str,
function: &'a Function, function: &'a Function,
types: &'a Vec<Type>, types: &'a Vec<Type>,
constants: &'a Vec<Constant>, constants: &'a Vec<Constant>,
...@@ -395,8 +398,8 @@ namespace cg = cooperative_groups; ...@@ -395,8 +398,8 @@ namespace cg = cooperative_groups;
fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> { fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> {
write!( write!(
w, w,
"__global__ void __launch_bounds__({}) {}_gpu(", "__global__ void __launch_bounds__({}) {}_{}_gpu(",
self.kernel_params.max_num_threads, self.function.name self.kernel_params.max_num_threads, self.module_name, self.function.name
)?; )?;
let mut first_param = true; let mut first_param = true;
// The first parameter is a pointer to GPU backing memory, if it's // The first parameter is a pointer to GPU backing memory, if it's
...@@ -645,7 +648,7 @@ namespace cg = cooperative_groups; ...@@ -645,7 +648,7 @@ namespace cg = cooperative_groups;
} else { } else {
write!(w, "{}", self.get_type(self.function.return_types[0], false))?; write!(w, "{}", self.get_type(self.function.return_types[0], false))?;
} }
write!(w, " {}(", self.function.name)?; write!(w, " {}_{}(", self.module_name, self.function.name)?;
let mut first_param = true; let mut first_param = true;
// The first parameter is a pointer to GPU backing memory, if it's // The first parameter is a pointer to GPU backing memory, if it's
...@@ -721,8 +724,13 @@ namespace cg = cooperative_groups; ...@@ -721,8 +724,13 @@ namespace cg = cooperative_groups;
write!(w, "\tcudaError_t err;\n")?; write!(w, "\tcudaError_t err;\n")?;
write!( write!(
w, w,
"\t{}_gpu<<<{}, {}, {}>>>({});\n", "\t{}_{}_gpu<<<{}, {}, {}>>>({});\n",
self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args self.module_name,
self.function.name,
num_blocks,
num_threads,
dynamic_shared_offset,
pass_args
)?; )?;
write!(w, "\terr = cudaGetLastError();\n")?; write!(w, "\terr = cudaGetLastError();\n")?;
write!( write!(
......
...@@ -74,6 +74,7 @@ use crate::*; ...@@ -74,6 +74,7 @@ use crate::*;
* set some CUDA memory - the user can then take a CUDA reference to that box. * set some CUDA memory - the user can then take a CUDA reference to that box.
*/ */
pub fn rt_codegen<W: Write>( pub fn rt_codegen<W: Write>(
module_name: &str,
func_id: FunctionID, func_id: FunctionID,
module: &Module, module: &Module,
def_use: &ImmutableDefUseMap, def_use: &ImmutableDefUseMap,
...@@ -96,6 +97,7 @@ pub fn rt_codegen<W: Write>( ...@@ -96,6 +97,7 @@ pub fn rt_codegen<W: Write>(
.map(|(fork, join)| (*join, *fork)) .map(|(fork, join)| (*join, *fork))
.collect(); .collect();
let ctx = RTContext { let ctx = RTContext {
module_name,
func_id, func_id,
module, module,
def_use, def_use,
...@@ -117,6 +119,7 @@ pub fn rt_codegen<W: Write>( ...@@ -117,6 +119,7 @@ pub fn rt_codegen<W: Write>(
} }
struct RTContext<'a> { struct RTContext<'a> {
module_name: &'a str,
func_id: FunctionID, func_id: FunctionID,
module: &'a Module, module: &'a Module,
def_use: &'a ImmutableDefUseMap, def_use: &'a ImmutableDefUseMap,
...@@ -157,7 +160,8 @@ impl<'a> RTContext<'a> { ...@@ -157,7 +160,8 @@ impl<'a> RTContext<'a> {
// Dump the function signature. // Dump the function signature.
write!( write!(
w, w,
"#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}(", "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}_{}(",
self.module_name,
func.name func.name
)?; )?;
let mut first_param = true; let mut first_param = true;
...@@ -236,7 +240,7 @@ impl<'a> RTContext<'a> { ...@@ -236,7 +240,7 @@ impl<'a> RTContext<'a> {
// Create the return struct // Create the return struct
write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?; write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?;
// Call the device function // Call the device function
write!(w, "{}(", callee.name)?; write!(w, "{}_{}(", self.module_name, callee.name)?;
if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()])
{ {
write!(w, "backing, ")?; write!(w, "backing, ")?;
...@@ -672,8 +676,9 @@ impl<'a> RTContext<'a> { ...@@ -672,8 +676,9 @@ impl<'a> RTContext<'a> {
}; };
write!( write!(
block, block,
"{}{}(", "{}{}_{}(",
prefix, prefix,
self.module_name,
self.module.functions[callee_id.idx()].name self.module.functions[callee_id.idx()].name
)?; )?;
for (device, (offset, size)) in self.backing_allocations[&self.func_id] for (device, (offset, size)) in self.backing_allocations[&self.func_id]
...@@ -1463,7 +1468,7 @@ impl<'a> RTContext<'a> { ...@@ -1463,7 +1468,7 @@ impl<'a> RTContext<'a> {
} }
// Call the wrapped function. // Call the wrapped function.
write!(w, "let ret = {}(", func.name)?; write!(w, "let ret = {}_{}(", self.module_name, func.name)?;
for (device, _) in self.backing_allocations[&self.func_id].iter() { for (device, _) in self.backing_allocations[&self.func_id].iter() {
write!( write!(
w, w,
...@@ -1630,8 +1635,9 @@ impl<'a> RTContext<'a> { ...@@ -1630,8 +1635,9 @@ impl<'a> RTContext<'a> {
let func = &self.module.functions[func_id.idx()]; let func = &self.module.functions[func_id.idx()];
write!( write!(
w, w,
"{}fn {}(", "{}fn {}_{}(",
if is_unsafe { "unsafe " } else { "" }, if is_unsafe { "unsafe " } else { "" },
self.module_name,
func.name func.name
)?; )?;
let mut first_param = true; let mut first_param = true;
...@@ -1667,7 +1673,7 @@ impl<'a> RTContext<'a> { ...@@ -1667,7 +1673,7 @@ impl<'a> RTContext<'a> {
func_id: FunctionID, func_id: FunctionID,
) -> Result<(), Error> { ) -> Result<(), Error> {
let func = &self.module.functions[func_id.idx()]; let func = &self.module.functions[func_id.idx()];
write!(w, "fn {}(", func.name)?; write!(w, "fn {}_{}(", self.module_name, func.name)?;
let mut first_param = true; let mut first_param = true;
if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
first_param = false; first_param = false;
......
...@@ -1048,9 +1048,20 @@ impl Constant { ...@@ -1048,9 +1048,20 @@ impl Constant {
} }
} }
/* pub fn is_false(&self) -> bool {
* Useful for GVN. match self {
*/ Constant::Boolean(false) => true,
_ => false,
}
}
pub fn is_true(&self) -> bool {
match self {
Constant::Boolean(true) => true,
_ => false,
}
}
pub fn is_zero(&self) -> bool { pub fn is_zero(&self) -> bool {
match self { match self {
Constant::Integer8(0) => true, Constant::Integer8(0) => true,
......
...@@ -880,6 +880,18 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { ...@@ -880,6 +880,18 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
} }
} }
/// Returns the parameter types as seen by this in-progress edit: the pending
/// replacement list if one has been recorded in `updated_param_types`,
/// otherwise the parameter types of the function being edited.
pub fn get_param_types(&self) -> &Vec<TypeID> {
    match &self.updated_param_types {
        Some(pending) => pending,
        None => &self.editor.function.param_types,
    }
}
/// Returns the return types as seen by this in-progress edit: the pending
/// replacement list if one has been recorded in `updated_return_types`,
/// otherwise the return types of the function being edited.
pub fn get_return_types(&self) -> &Vec<TypeID> {
    match &self.updated_return_types {
        Some(pending) => pending,
        None => &self.editor.function.return_types,
    }
}
pub fn set_param_types(&mut self, tys: Vec<TypeID>) { pub fn set_param_types(&mut self, tys: Vec<TypeID>) {
self.updated_param_types = Some(tys); self.updated_param_types = Some(tys);
} }
......
use std::cell::Ref;
use std::collections::HashMap; use std::collections::HashMap;
use hercules_ir::callgraph::*; use hercules_ir::*;
use hercules_ir::def_use::*;
use hercules_ir::ir::*;
use crate::*; use crate::*;
...@@ -235,3 +234,216 @@ fn inline_func( ...@@ -235,3 +234,216 @@ fn inline_func(
}); });
} }
} }
/// What is known about the argument passed for one parameter across the call
/// sites examined so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParameterLattice {
    /// Identity element of `meet`; this is what an `Undef` argument maps to.
    Top,
    /// Every call site seen so far passes this constant.
    Constant(ConstantID),
    /// Every call site seen so far passes this dynamic constant. The second
    /// field is the function ID that was supplied to `from_node`.
    DynamicConstant(DynamicConstantID, FunctionID),
    /// Absorbing element of `meet`: the argument is not a single known value.
    Bottom,
}
impl ParameterLattice {
fn from_node(node: &Node, func_id: FunctionID) -> Self {
use ParameterLattice::*;
match node {
Node::Undef { ty: _ } => Top,
Node::Constant { id } => Constant(*id),
Node::DynamicConstant { id } => DynamicConstant(*id, func_id),
_ => Bottom,
}
}
fn meet(&mut self, b: Self, cons: Ref<'_, Vec<Constant>>, dcs: Ref<'_, Vec<DynamicConstant>>) {
use ParameterLattice::*;
*self = match (*self, b) {
(Top, b) => b,
(a, Top) => a,
(Bottom, _) | (_, Bottom) => Bottom,
(Constant(id_a), Constant(id_b)) => {
if id_a == id_b {
Constant(id_a)
} else {
Bottom
}
}
(DynamicConstant(dc_a, f_a), DynamicConstant(dc_b, f_b)) => {
if dc_a == dc_b && f_a == f_b {
DynamicConstant(dc_a, f_a)
} else if let (
ir::DynamicConstant::Constant(dcv_a),
ir::DynamicConstant::Constant(dcv_b),
) = (&dcs[dc_a.idx()], &dcs[dc_b.idx()])
&& *dcv_a == *dcv_b
{
DynamicConstant(dc_a, f_a)
} else {
Bottom
}
}
(DynamicConstant(dc, _), Constant(con)) | (Constant(con), DynamicConstant(dc, _)) => {
match (&cons[con.idx()], &dcs[dc.idx()]) {
(ir::Constant::UnsignedInteger64(conv), ir::DynamicConstant::Constant(dcv))
if *conv as usize == *dcv =>
{
Constant(con)
}
_ => Bottom,
}
}
}
}
}
/*
 * Top level function to inline constant parameters and constant dynamic
 * constant parameters. Identifies functions that are:
 *
 * 1. Not marked as entry.
 * 2. At every call site, a particular parameter is always a specific constant
 *    or dynamic constant.
 *
 * These functions can have that constant "inlined" - the parameter is removed
 * and all uses of the parameter become uses of the constant directly.
 *
 * `editors` is indexed by function ID. When `inline_collections` is false,
 * only parameters whose type is primitive are considered for inlining.
 */
pub fn const_inline(
    editors: &mut [FunctionEditor],
    callgraph: &CallGraph,
    inline_collections: bool,
) {
    // Run const inlining on each function, starting at the most shallow
    // function first, since we want to propagate constants down the call graph.
    for func_id in callgraph.topo().into_iter().rev() {
        let func = editors[func_id.idx()].func();
        if func.entry || callgraph.num_callers(func_id) == 0 {
            continue;
        }

        // Figure out what we know about the parameters to this function by
        // meeting the argument at every call site, and remember each call site
        // so it can be rewritten afterwards.
        let mut param_lattice = vec![ParameterLattice::Top; func.param_types.len()];
        let mut callers = vec![];
        for caller in callgraph.get_callers(func_id) {
            let editor = &editors[caller.idx()];
            let nodes = &editor.func().nodes;
            for id in editor.node_ids() {
                if let Some((_, callee, _, args)) = nodes[id.idx()].try_call()
                    && callee == func_id
                {
                    if editor.is_mutable(id) {
                        for (idx, id) in args.into_iter().enumerate() {
                            // NOTE(review): the function ID recorded here is
                            // the callee, which is the same for every call
                            // site of this function - confirm whether the
                            // caller was intended.
                            let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
                            param_lattice[idx].meet(
                                lattice,
                                editor.get_constants(),
                                editor.get_dynamic_constants(),
                            );
                        }
                    } else {
                        // If we can't modify the call node in the caller, then
                        // we can't perform the inlining.
                        param_lattice = vec![ParameterLattice::Bottom; func.param_types.len()];
                    }
                    callers.push((caller, id));
                }
            }
        }
        if param_lattice.iter().all(|v| *v == ParameterLattice::Bottom) {
            continue;
        }

        // Replace the arguments. First collect the parameter nodes of the
        // callee, grouped by parameter index.
        let editor = &mut editors[func_id.idx()];
        let mut param_idx_to_ids: HashMap<usize, Vec<NodeID>> = HashMap::new();
        for id in editor.node_ids() {
            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
                param_idx_to_ids.entry(idx).or_default().push(id);
            }
        }

        // Original indices of the parameters that were inlined, in ascending
        // order while the edit below runs.
        let mut params_to_remove = vec![];
        let success = editor.edit(|mut edit| {
            let mut param_tys = edit.get_param_types().clone();
            // Number of parameters removed so far; `idx - decrement_index_by`
            // is the position of original parameter `idx` in `param_tys`.
            let mut decrement_index_by = 0;
            for idx in 0..param_tys.len() {
                if (inline_collections
                    || edit
                        .get_type(param_tys[idx - decrement_index_by])
                        .is_primitive())
                    && let Some(node) = match param_lattice[idx] {
                        ParameterLattice::Top => Some(Node::Undef {
                            ty: param_tys[idx - decrement_index_by],
                        }),
                        ParameterLattice::Constant(id) => Some(Node::Constant { id }),
                        ParameterLattice::DynamicConstant(id, _) => {
                            // Only dynamic constants that are a plain literal
                            // are inlined. The value is extracted before
                            // `add_dynamic_constant` is called, presumably so
                            // the borrow of `edit` ends first.
                            let maybe_cons = edit.get_dynamic_constant(id).try_constant();
                            if let Some(val) = maybe_cons {
                                Some(Node::DynamicConstant {
                                    id: edit.add_dynamic_constant(DynamicConstant::Constant(val)),
                                })
                            } else {
                                None
                            }
                        }
                        _ => None,
                    }
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // Substitute the known value for every node of this
                    // parameter and drop the parameter from the signature.
                    let node = edit.add_node(node);
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                    param_tys.remove(idx - decrement_index_by);
                    params_to_remove.push(idx);
                    decrement_index_by += 1;
                } else if decrement_index_by != 0
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // This parameter is kept, but earlier ones were removed,
                    // so its nodes must be renumbered to the shifted index.
                    let node = edit.add_node(Node::Parameter {
                        index: idx - decrement_index_by,
                    });
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                }
            }
            edit.set_param_types(param_tys);
            Ok(edit)
        });
        // Remove arguments from the highest index down so that the remaining
        // indices stay valid while removing.
        params_to_remove.reverse();

        // Update callers. Skip the rewrite entirely when no parameter was
        // actually removed: the new call nodes would be identical to the old
        // ones and would only churn the callers.
        if success && !params_to_remove.is_empty() {
            for (caller, call) in callers {
                let editor = &mut editors[caller.idx()];
                let success = editor.edit(|mut edit| {
                    let Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args,
                    } = edit.get_node(call).clone()
                    else {
                        panic!("const_inline: recorded call site is not a call node");
                    };
                    let mut args = args.into_vec();
                    for idx in params_to_remove.iter() {
                        args.remove(*idx);
                    }
                    let node = edit.add_node(Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args: args.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(call, node)?;
                    edit = edit.delete_node(call)?;
                    Ok(edit)
                });
                // Every recorded call site was checked to be mutable above, so
                // the rewrite is expected to succeed.
                assert!(success);
            }
        }
    }
}
...@@ -73,6 +73,7 @@ pub fn canonicalize_single_loop_bounds( ...@@ -73,6 +73,7 @@ pub fn canonicalize_single_loop_bounds(
.into_iter() .into_iter()
.partition(|f| loop_bound_iv_phis.contains(&f.phi())); .partition(|f| loop_bound_iv_phis.contains(&f.phi()));
// Assume there is only one loop bound iv. // Assume there is only one loop bound iv.
if loop_bound_ivs.len() != 1 { if loop_bound_ivs.len() != 1 {
return false; return false;
...@@ -93,9 +94,6 @@ pub fn canonicalize_single_loop_bounds( ...@@ -93,9 +94,6 @@ pub fn canonicalize_single_loop_bounds(
return false; return false;
}; };
let Some(final_value) = final_value else {
return false;
};
let Some(loop_pred) = editor let Some(loop_pred) = editor
.get_uses(l.header) .get_uses(l.header)
...@@ -109,8 +107,23 @@ pub fn canonicalize_single_loop_bounds( ...@@ -109,8 +107,23 @@ pub fn canonicalize_single_loop_bounds(
// (init_id, bound_id, binop node, if node). // (init_id, bound_id, binop node, if node).
// FIXME: This is not always correct, depends on lots of things about the loop IV.
let loop_bound_dc = match *editor.node(condition_node) {
Node::Binary { left, right, op } => match op {
BinaryOperator::LT => right,
BinaryOperator::LTE => right,
BinaryOperator::GT => {return false}
BinaryOperator::GTE => {return false}
BinaryOperator::EQ => {return false}
BinaryOperator::NE => {return false}
_ => {return false}
},
_ => {return false}
};
// FIXME: This is quite fragile. // FIXME: This is quite fragile.
let guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| { let mut guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| {
let Node::ControlProjection { let Node::ControlProjection {
control, control,
selection: _, selection: _,
...@@ -119,7 +132,7 @@ pub fn canonicalize_single_loop_bounds( ...@@ -119,7 +132,7 @@ pub fn canonicalize_single_loop_bounds(
return None; return None;
}; };
let Node::If { control, cond } = editor.node(control) else { let Node::If { cond, ..} = editor.node(control) else {
return None; return None;
}; };
...@@ -129,7 +142,7 @@ pub fn canonicalize_single_loop_bounds( ...@@ -129,7 +142,7 @@ pub fn canonicalize_single_loop_bounds(
let Node::Binary { let Node::Binary {
left: _, left: _,
right: _, right: r,
op: loop_op, op: loop_op,
} = editor.node(condition_node) } = editor.node(condition_node)
else { else {
...@@ -144,7 +157,7 @@ pub fn canonicalize_single_loop_bounds( ...@@ -144,7 +157,7 @@ pub fn canonicalize_single_loop_bounds(
return None; return None;
} }
if right != final_value { if right != r {
return None; return None;
} }
...@@ -169,7 +182,7 @@ pub fn canonicalize_single_loop_bounds( ...@@ -169,7 +182,7 @@ pub fn canonicalize_single_loop_bounds(
// We are assuming this is a simple loop bound (i.e only one induction variable involved), so that . // We are assuming this is a simple loop bound (i.e only one induction variable involved), so that .
let Node::DynamicConstant { let Node::DynamicConstant {
id: loop_bound_dc_id, id: loop_bound_dc_id,
} = *editor.node(final_value) } = *editor.node(loop_bound_dc)
else { else {
return false; return false;
}; };
...@@ -177,9 +190,9 @@ pub fn canonicalize_single_loop_bounds( ...@@ -177,9 +190,9 @@ pub fn canonicalize_single_loop_bounds(
// We need to do 4 (5) things, which are mostly separate. // We need to do 4 (5) things, which are mostly separate.
// 0) Make the update into addition. // 0) Make the update into addition.
// 1) Make the update a positive value. // 1) Adjust update to be 1 (and bounds).
// 2) Transform the condition into a `<` // 2) Make the update a positive value. / Transform the condition into a `<`
// 3) Adjust update to be 1 (and bounds). // - Are these separate?
// 4) Change init to start from 0. // 4) Change init to start from 0.
// 5) Find some way to get fork-guard-elim to work with the new fork. // 5) Find some way to get fork-guard-elim to work with the new fork.
...@@ -198,7 +211,13 @@ pub fn canonicalize_single_loop_bounds( ...@@ -198,7 +211,13 @@ pub fn canonicalize_single_loop_bounds(
return false; return false;
} }
} }
BinaryOperator::LTE => todo!(), BinaryOperator::LTE => {
if left == *update_expression && editor.node(right).is_dynamic_constant() {
right
} else {
return false;
}
}
BinaryOperator::GT => todo!(), BinaryOperator::GT => todo!(),
BinaryOperator::GTE => todo!(), BinaryOperator::GTE => todo!(),
BinaryOperator::EQ => todo!(), BinaryOperator::EQ => todo!(),
...@@ -211,8 +230,10 @@ pub fn canonicalize_single_loop_bounds( ...@@ -211,8 +230,10 @@ pub fn canonicalize_single_loop_bounds(
_ => return false, _ => return false,
}; };
let condition_node_data = editor.node(condition_node).clone();
let Node::DynamicConstant { let Node::DynamicConstant {
id: bound_node_dc_id, id: mut bound_node_dc_id,
} = *editor.node(dc_bound_node) } = *editor.node(dc_bound_node)
else { else {
return false; return false;
...@@ -220,7 +241,56 @@ pub fn canonicalize_single_loop_bounds( ...@@ -220,7 +241,56 @@ pub fn canonicalize_single_loop_bounds(
// If increment is negative (how in the world do we know that...) // If increment is negative (how in the world do we know that...)
// Increment can be DefinitelyPositive, Unknown, DefinitelyNegative. // Increment can be DefinitelyPositive, Unknown, DefinitelyNegative.
let misc_guard_thing: Option<Node> = if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
Some(editor.node(binop_node).clone())
} else {
None
};
let mut condition_node = condition_node;
let result = editor.edit(|mut edit| {
// 2) Transform the condition into a < (from <=)
if let Node::Binary { left, right, op } = condition_node_data {
if BinaryOperator::LTE == op && left == *update_expression {
// Change the condition into <
let new_bop = edit.add_node(Node::Binary { left, right, op: BinaryOperator::LT });
// Change the bound dc to be bound_dc + 1
let one = DynamicConstant::Constant(1);
let one = edit.add_dynamic_constant(one);
let tmp = DynamicConstant::add(bound_node_dc_id, one);
let new_condition_dc = edit.add_dynamic_constant(tmp);
let new_dc_bound_node = edit.add_node(Node::DynamicConstant { id: new_condition_dc });
// // 5) Change loop guard:
guard_info = if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
// Change binop node
let Some(Node::Binary { left, right, op }) = misc_guard_thing else {unreachable!()};
let blah = edit.add_node(Node::DynamicConstant { id: new_condition_dc});
// FIXME: Don't assume that right is the loop bound in the guard.
let new_binop_node = edit.add_node(Node::Binary { left, right: blah, op: BinaryOperator::LT });
edit = edit.replace_all_uses_where(binop_node, new_binop_node, |usee| *usee == if_node)?;
Some((init_id, bound_id, new_binop_node, if_node))
} else {guard_info};
edit = edit.replace_all_uses_where(dc_bound_node, new_dc_bound_node, |usee| *usee == new_bop)?;
edit = edit.replace_all_uses(condition_node, new_bop)?;
// Change loop condition
dc_bound_node = new_dc_bound_node;
bound_node_dc_id = new_condition_dc;
condition_node = new_bop;
}
};
Ok(edit)
});
let update_expr_users: Vec<_> = editor let update_expr_users: Vec<_> = editor
.get_users(*update_expression) .get_users(*update_expression)
.filter(|node| *node != iv.phi() && *node != condition_node) .filter(|node| *node != iv.phi() && *node != condition_node)
...@@ -241,34 +311,23 @@ pub fn canonicalize_single_loop_bounds( ...@@ -241,34 +311,23 @@ pub fn canonicalize_single_loop_bounds(
let new_init = edit.add_node(new_init); let new_init = edit.add_node(new_init);
edit = edit.replace_all_uses_where(*initializer, new_init, |usee| *usee == iv.phi())?; edit = edit.replace_all_uses_where(*initializer, new_init, |usee| *usee == iv.phi())?;
let new_condition_id = DynamicConstant::sub(bound_node_dc_id, init_dc_id); let new_condition_dc = DynamicConstant::sub(bound_node_dc_id, init_dc_id);
let new_condition = Node::DynamicConstant { let new_condition_dc_id = Node::DynamicConstant {
id: edit.add_dynamic_constant(new_condition_id), id: edit.add_dynamic_constant(new_condition_dc),
}; };
let new_condition = edit.add_node(new_condition); let new_condition_dc = edit.add_node(new_condition_dc_id);
edit = edit edit = edit
.replace_all_uses_where(dc_bound_node, new_condition, |usee| *usee == condition_node)?; .replace_all_uses_where(dc_bound_node, new_condition_dc, |usee| *usee == condition_node)?;
// Change loop guard: // 5) Change loop guard:
if let Some((init_id, bound_id, binop_node, if_node)) = guard_info { if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
edit = edit.replace_all_uses_where(init_id, new_init, |usee| *usee == binop_node)?; edit = edit.replace_all_uses_where(init_id, new_init, |usee| *usee == binop_node)?;
edit = edit =
edit.replace_all_uses_where(bound_id, new_condition, |usee| *usee == binop_node)?; edit.replace_all_uses_where(bound_id, new_condition_dc, |usee| *usee == binop_node)?;
} }
// for user in update_expr_users { // 4) Add the offset back to users of the IV update expression
// let new_user = Node::Binary {
// left: user,
// right: *initializer,
// op: BinaryOperator::Add,
// };
// let new_user = edit.add_node(new_user);
// edit = edit.replace_all_uses(user, new_user)?;
// }
// for
// Add the offset back to users of the IV update expression
let new_user = Node::Binary { let new_user = Node::Binary {
left: *update_expression, left: *update_expression,
right: *initializer, right: *initializer,
......
...@@ -136,6 +136,77 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -136,6 +136,77 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
bad_branches.insert(branch); bad_branches.insert(branch);
} }
} }
// Do a quick and dirty rewrite to convert select(a, b, false) to a && b and
// select(a, b, true) to a || b.
for id in editor.node_ids() {
let nodes = &editor.func().nodes;
if let Node::Ternary {
op: TernaryOperator::Select,
first,
second,
third,
} = nodes[id.idx()]
{
if let Some(cons) = nodes[second.idx()].try_constant()
&& editor.get_constant(cons).is_false()
{
editor.edit(|mut edit| {
let inv = edit.add_node(Node::Unary {
op: UnaryOperator::Not,
input: first,
});
let node = edit.add_node(Node::Binary {
op: BinaryOperator::And,
left: inv,
right: third,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
} else if let Some(cons) = nodes[third.idx()].try_constant()
&& editor.get_constant(cons).is_false()
{
editor.edit(|mut edit| {
let node = edit.add_node(Node::Binary {
op: BinaryOperator::And,
left: first,
right: second,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
} else if let Some(cons) = nodes[second.idx()].try_constant()
&& editor.get_constant(cons).is_true()
{
editor.edit(|mut edit| {
let node = edit.add_node(Node::Binary {
op: BinaryOperator::Or,
left: first,
right: third,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
} else if let Some(cons) = nodes[third.idx()].try_constant()
&& editor.get_constant(cons).is_true()
{
editor.edit(|mut edit| {
let inv = edit.add_node(Node::Unary {
op: UnaryOperator::Not,
input: first,
});
let node = edit.add_node(Node::Binary {
op: BinaryOperator::Or,
left: inv,
right: second,
});
edit = edit.replace_all_uses(id, node)?;
edit.delete_node(id)
});
}
}
}
} }
/* /*
......
...@@ -69,6 +69,26 @@ pub fn infer_parallel_reduce( ...@@ -69,6 +69,26 @@ pub fn infer_parallel_reduce(
chain_id = reduct; chain_id = reduct;
} }
// If the use is a phi that uses the reduce and a write, then we might
// want to parallelize this still. Set the chain ID to the write.
if let Node::Phi {
control: _,
ref data,
} = func.nodes[chain_id.idx()]
&& data.len()
== data
.into_iter()
.filter(|phi_use| **phi_use == last_reduce)
.count()
+ 1
{
chain_id = *data
.into_iter()
.filter(|phi_use| **phi_use != last_reduce)
.next()
.unwrap();
}
// Check for a Write-Reduce tight cycle. // Check for a Write-Reduce tight cycle.
if let Node::Write { if let Node::Write {
collect, collect,
...@@ -130,12 +150,13 @@ pub fn infer_monoid_reduce( ...@@ -130,12 +150,13 @@ pub fn infer_monoid_reduce(
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) { ) {
let is_binop_monoid = |op| { let is_binop_monoid = |op| {
matches!( op == BinaryOperator::Add
op, || op == BinaryOperator::Mul
BinaryOperator::Add | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And || op == BinaryOperator::Or
) || op == BinaryOperator::And
}; };
let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); let is_intrinsic_monoid =
|intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min;
for id in editor.node_ids() { for id in editor.node_ids() {
let func = editor.func(); let func = editor.func();
......
...@@ -4,7 +4,7 @@ fn squash(x: f32) -> f32 { ...@@ -4,7 +4,7 @@ fn squash(x: f32) -> f32 {
} }
fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] { fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f32[m + 1] {
let result : f32[m + 1]; @res let result : f32[m + 1];
result[0] = 1.0; result[0] = 1.0;
for j in 1..=m { for j in 1..=m {
......
gvn(*); macro simpl!(X) {
dce(*); ccp(X);
phi-elim(*); simplify-cfg(X);
dce(*); lift-dc-math(X);
crc(*); gvn(X);
dce(*); phi-elim(X);
slf(*); dce(X);
dce(*); infer-schedules(X);
}
let auto = auto-outline(backprop); simpl!(*);
cpu(auto.backprop); inline(layer_forward);
inline(auto.backprop);
inline(auto.backprop);
delete-uncalled(*); delete-uncalled(*);
sroa[true](*); no-memset(layer_forward@res);
dce(*); lift-dc-math(*);
float-collections(*); loop-bound-canon(*);
reuse-products(*);
dce(*); dce(*);
lift-dc-math(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
fork-split(*);
gvn(*);
phi-elim(*);
dce(*);
unforkify(*);
gvn(*);
phi-elim(*);
dce(*);
gcm(*); gcm(*);
...@@ -13,6 +13,8 @@ fn main() { ...@@ -13,6 +13,8 @@ fn main() {
JunoCompiler::new() JunoCompiler::new()
.file_in_src("bfs.jn") .file_in_src("bfs.jn")
.unwrap() .unwrap()
.schedule_in_src("cpu.sch")
.unwrap()
.build() .build()
.unwrap(); .unwrap();
} }
...@@ -13,8 +13,8 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] ...@@ -13,8 +13,8 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
let visited: bool[n]; let visited: bool[n];
visited[source as u64] = true; visited[source as u64] = true;
let cost: i32[n]; @cost @cost_init let cost: i32[n];
for i in 0..n { @cost_init for i in 0..n {
cost[i] = -1; cost[i] = -1;
} }
cost[source as u64] = 0; cost[source as u64] = 0;
...@@ -25,7 +25,7 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] ...@@ -25,7 +25,7 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
while !stop { while !stop {
stop = true; stop = true;
for i in 0..n { @loop1 for i in 0..n {
if mask[i] { if mask[i] {
mask[i] = false; mask[i] = false;
...@@ -42,11 +42,11 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] ...@@ -42,11 +42,11 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n]
} }
} }
for i in 0..n { @loop2 for i in 0..n {
stop = stop && !updated[i];
if updated[i] { if updated[i] {
mask[i] = true; mask[i] = true;
visited[i] = true; visited[i] = true;
stop = false;
updated[i] = false; updated[i] = false;
} }
} }
......
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
phi-elim(bfs);
no-memset(bfs@cost);
outline(bfs@cost_init);
let loop1 = outline(bfs@loop1);
let loop2 = outline(bfs@loop2);
simpl!(*);
predication(*);
const-inline(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
}
simpl!(*);
predication(*);
simpl!(*);
unforkify(*);
gcm(*);
gvn(*); macro simpl!(X) {
dce(*); ccp(X);
phi-elim(*); simplify-cfg(X);
dce(*); lift-dc-math(X);
crc(*); gvn(X);
dce(*); phi-elim(X);
slf(*); crc(X);
dce(*); slf(X);
dce(X);
infer-schedules(X);
}
let auto = auto-outline(euler); simpl!(*);
cpu(auto.euler); inline(compute_step_factor, compute_flux, compute_flux_contribution, time_step);
inline(auto.euler);
inline(auto.euler);
delete-uncalled(*); delete-uncalled(*);
simpl!(*);
ip-sroa[false](*);
sroa[false](*);
predication(*);
const-inline(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
}
simpl!(*);
no-memset(compute_step_factor@res, compute_flux@res, copy_vars@res);
parallel-reduce(time_step, copy_vars, compute_flux@outer_loop \ compute_flux@inner_loop);
sroa[false](auto.euler); unforkify(*);
dce(*);
float-collections(*);
dce(*);
gcm(*); gcm(*);
gvn(*); macro simpl!(X) {
dce(*); ccp(X);
phi-elim(*); simplify-cfg(X);
dce(*); lift-dc-math(X);
crc(*); gvn(X);
dce(*); phi-elim(X);
slf(*); crc(X);
dce(*); slf(X);
dce(X);
infer-schedules(X);
}
let auto = auto-outline(pre_euler); simpl!(*);
cpu(auto.pre_euler); inline(compute_step_factor, compute_flux, compute_flux_contributions, compute_flux_contribution, time_step);
inline(auto.pre_euler);
inline(auto.pre_euler);
delete-uncalled(*); delete-uncalled(*);
simpl!(*);
ip-sroa[false](*);
sroa[false](*);
predication(*);
const-inline(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
}
simpl!(*);
sroa[false](auto.pre_euler); unforkify(*);
dce(*);
float-collections(*);
dce(*);
gcm(*); gcm(*);
...@@ -47,7 +47,7 @@ fn compute_speed_of_sound(density: f32, pressure: f32) -> f32 { ...@@ -47,7 +47,7 @@ fn compute_speed_of_sound(density: f32, pressure: f32) -> f32 {
} }
fn compute_step_factor<nelr: usize>(variables: Variables::<nelr>, areas: f32[nelr]) -> f32[nelr] { fn compute_step_factor<nelr: usize>(variables: Variables::<nelr>, areas: f32[nelr]) -> f32[nelr] {
let step_factors : f32[nelr]; @res let step_factors : f32[nelr];
for i in 0..nelr { for i in 0..nelr {
let density = variables.density[i]; let density = variables.density[i];
...@@ -106,9 +106,9 @@ fn compute_flux<nelr: usize>( ...@@ -106,9 +106,9 @@ fn compute_flux<nelr: usize>(
ff_flux_contribution_momentum_z: float3, ff_flux_contribution_momentum_z: float3,
) -> Variables::<nelr> { ) -> Variables::<nelr> {
const smoothing_coefficient : f32 = 0.2; const smoothing_coefficient : f32 = 0.2;
let fluxes: Variables::<nelr>; @res let fluxes: Variables::<nelr>;
for i in 0..nelr { @outer_loop for i in 0..nelr {
let density_i = variables.density[i]; let density_i = variables.density[i];
let momentum_i = float3 { x: variables.momentum.x[i], let momentum_i = float3 { x: variables.momentum.x[i],
...@@ -131,7 +131,7 @@ fn compute_flux<nelr: usize>( ...@@ -131,7 +131,7 @@ fn compute_flux<nelr: usize>(
let flux_i_momentum = float3 { x: 0.0, y: 0.0, z: 0.0 }; let flux_i_momentum = float3 { x: 0.0, y: 0.0, z: 0.0 };
let flux_i_density_energy : f32 = 0.0; let flux_i_density_energy : f32 = 0.0;
for j in 0..NNB { @inner_loop for j in 0..NNB {
let nb = elements_surrounding_elements[j, i]; let nb = elements_surrounding_elements[j, i];
let normal = float3 { let normal = float3 {
x: normals.x[j, i], x: normals.x[j, i],
...@@ -249,7 +249,7 @@ fn time_step<nelr: usize>( ...@@ -249,7 +249,7 @@ fn time_step<nelr: usize>(
} }
fn copy_vars<nelr: usize>(variables: Variables::<nelr>) -> Variables::<nelr> { fn copy_vars<nelr: usize>(variables: Variables::<nelr>) -> Variables::<nelr> {
let result : Variables::<nelr>; @res let result : Variables::<nelr>;
for i in 0..nelr { for i in 0..nelr {
result.density[i] = variables.density[i]; result.density[i] = variables.density[i];
......
...@@ -13,8 +13,8 @@ fn srad_bench(c: &mut Criterion) { ...@@ -13,8 +13,8 @@ fn srad_bench(c: &mut Criterion) {
let mut r = runner!(srad); let mut r = runner!(srad);
let niter = 100; let niter = 100;
let lambda = 0.5; let lambda = 0.5;
let nrows = 502; let nrows = 512;
let ncols = 458; let ncols = 512;
let image = "data/image.pgm".to_string(); let image = "data/image.pgm".to_string();
let Image { let Image {
image: image_ori, image: image_ori,
......
...@@ -13,6 +13,8 @@ fn main() { ...@@ -13,6 +13,8 @@ fn main() {
JunoCompiler::new() JunoCompiler::new()
.file_in_src("srad.jn") .file_in_src("srad.jn")
.unwrap() .unwrap()
.schedule_in_src("cpu.sch")
.unwrap()
.build() .build()
.unwrap(); .unwrap();
} }
macro simpl!(X) {
ccp(X);
simplify-cfg(X);
lift-dc-math(X);
gvn(X);
phi-elim(X);
dce(X);
infer-schedules(X);
}
phi-elim(*);
let loop1 = outline(srad@loop1);
let loop2 = outline(srad@loop2);
let loop3 = outline(srad@loop3);
simpl!(*);
const-inline(*);
crc(*);
slf(*);
write-predication(*);
simpl!(*);
predication(*);
simpl!(*);
predication(*);
simpl!(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
}
simpl!(*);
fork-interchange[0, 1](loop1);
fork-split(*);
unforkify(*);
gcm(*);