From 3adec0d65e6c765d6b9a61436c6df9597d602749 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 20 Feb 2025 14:55:30 -0600 Subject: [PATCH] Add min/max to clean monoid reduce --- hercules_ir/src/ir.rs | 52 ++++++++++++++++++++----- hercules_opt/src/editor.rs | 34 ++++++++++++++-- hercules_opt/src/fork_transforms.rs | 38 ++++++++++++++++++ hercules_opt/src/utils.rs | 26 +++++++++++++ juno_samples/edge_detection/src/cpu.sch | 2 + 5 files changed, 138 insertions(+), 14 deletions(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index bf9698b3..f91efe58 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1050,6 +1050,38 @@ impl Constant { _ => false, } } + + pub fn is_largest(&self) -> bool { + match self { + Constant::Integer8(i8::MAX) => true, + Constant::Integer16(i16::MAX) => true, + Constant::Integer32(i32::MAX) => true, + Constant::Integer64(i64::MAX) => true, + Constant::UnsignedInteger8(u8::MAX) => true, + Constant::UnsignedInteger16(u16::MAX) => true, + Constant::UnsignedInteger32(u32::MAX) => true, + Constant::UnsignedInteger64(u64::MAX) => true, + Constant::Float32(ord) => *ord == OrderedFloat::<f32>(f32::INFINITY), + Constant::Float64(ord) => *ord == OrderedFloat::<f64>(f64::INFINITY), + _ => false, + } + } + + pub fn is_smallest(&self) -> bool { + match self { + Constant::Integer8(i8::MIN) => true, + Constant::Integer16(i16::MIN) => true, + Constant::Integer32(i32::MIN) => true, + Constant::Integer64(i64::MIN) => true, + Constant::UnsignedInteger8(u8::MIN) => true, + Constant::UnsignedInteger16(u16::MIN) => true, + Constant::UnsignedInteger32(u32::MIN) => true, + Constant::UnsignedInteger64(u64::MIN) => true, + Constant::Float32(ord) => *ord == OrderedFloat::<f32>(f32::NEG_INFINITY), + Constant::Float64(ord) => *ord == OrderedFloat::<f64>(f64::NEG_INFINITY), + _ => false, + } + } } impl DynamicConstant { @@ -1098,19 +1130,19 @@ impl DynamicConstant { } pub fn is_zero(&self) -> bool { - if *self == DynamicConstant::Constant(0) { - true - } else { - false - } + *self == DynamicConstant::Constant(0) } pub fn is_one(&self) -> bool { - if *self == DynamicConstant::Constant(1) { - true - } else { - false - } + *self == DynamicConstant::Constant(1) + } + + pub fn is_largest(&self) -> bool { + *self == DynamicConstant::Constant(usize::MAX) + } + + pub fn is_smallest(&self) -> bool { + *self == DynamicConstant::Constant(usize::MIN) } pub fn try_parameter(&self) -> Option<usize> { diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index b33dc956..57fe2042 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -795,22 +795,48 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { self.add_constant(constant_to_construct) } - pub fn add_pos_inf_constant(&mut self, id: TypeID) -> ConstantID { + pub fn add_largest_constant(&mut self, id: TypeID) -> ConstantID { let ty = self.get_type(id).clone(); let constant_to_construct = match ty { + Type::Boolean => Constant::Boolean(true), + Type::Integer8 => Constant::Integer8(i8::MAX), + Type::Integer16 => Constant::Integer16(i16::MAX), + Type::Integer32 => Constant::Integer32(i32::MAX), + Type::Integer64 => Constant::Integer64(i64::MAX), + Type::UnsignedInteger8 => Constant::UnsignedInteger8(u8::MAX), + Type::UnsignedInteger16 => Constant::UnsignedInteger16(u16::MAX), + Type::UnsignedInteger32 => Constant::UnsignedInteger32(u32::MAX), + Type::UnsignedInteger64 => Constant::UnsignedInteger64(u64::MAX), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(f32::INFINITY)), Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(f64::INFINITY)), - _ => panic!(), + Type::Control => panic!("PANIC: Can't create largest constant for the control type."), + Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { + panic!("PANIC: Can't create largest constant of a collection type.") + } }; self.add_constant(constant_to_construct) } - pub fn add_neg_inf_constant(&mut self, id: TypeID) -> ConstantID { + pub fn add_smallest_constant(&mut self, id: TypeID) -> ConstantID { let ty = self.get_type(id).clone(); let constant_to_construct = match ty { + Type::Boolean => Constant::Boolean(true), + Type::Integer8 => Constant::Integer8(i8::MIN), + Type::Integer16 => Constant::Integer16(i16::MIN), + Type::Integer32 => Constant::Integer32(i32::MIN), + Type::Integer64 => Constant::Integer64(i64::MIN), + Type::UnsignedInteger8 => Constant::UnsignedInteger8(u8::MIN), + Type::UnsignedInteger16 => Constant::UnsignedInteger16(u16::MIN), + Type::UnsignedInteger32 => Constant::UnsignedInteger32(u32::MIN), + Type::UnsignedInteger64 => Constant::UnsignedInteger64(u64::MIN), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(f32::NEG_INFINITY)), Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(f64::NEG_INFINITY)), - _ => panic!(), + Type::Control => panic!("PANIC: Can't create smallest constant for the control type."), + Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { + panic!("PANIC: Can't create smallest constant of a collection type.") + } }; self.add_constant(constant_to_construct) } diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 283734a0..e635b3c0 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1556,6 +1556,44 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) }); } + Node::IntrinsicCall { + intrinsic: Intrinsic::Max, + args: _, + } if !is_smallest(editor, init) => { + editor.edit(|mut edit| { + let smallest = edit.add_smallest_constant(typing[init.idx()]); + let smallest = edit.add_node(Node::Constant { id: smallest }); + edit.sub_edit(id, smallest); + edit = edit.replace_all_uses_where(init, smallest, |u| *u == id)?; + let final_op = edit.add_node(Node::IntrinsicCall { + intrinsic: Intrinsic::Max, + args: Box::new([init, id]), + }); + for u in out_uses { + edit.sub_edit(u, final_op); + } + edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) + }); + } + Node::IntrinsicCall { + intrinsic: Intrinsic::Min, + args: _, + } if !is_largest(editor, init) => { + editor.edit(|mut edit| { + let largest = edit.add_largest_constant(typing[init.idx()]); + let largest = edit.add_node(Node::Constant { id: largest }); + edit.sub_edit(id, largest); + edit = edit.replace_all_uses_where(init, largest, |u| *u == id)?; + let final_op = edit.add_node(Node::IntrinsicCall { + intrinsic: Intrinsic::Min, + args: Box::new([init, id]), + }); + for u in out_uses { + edit.sub_edit(u, final_op); + } + edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) + }); + } _ => {} } } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 1806d5c7..793fe9fa 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -567,3 +567,29 @@ pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool { .unwrap_or(false) || nodes[id.idx()].is_undef() } + +pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool { + let nodes = &editor.func().nodes; + nodes[id.idx()] + .try_constant() + .map(|id| editor.get_constant(id).is_largest()) + .unwrap_or(false) + || nodes[id.idx()] + .try_dynamic_constant() + .map(|id| editor.get_dynamic_constant(id).is_largest()) + .unwrap_or(false) + || nodes[id.idx()].is_undef() +} + +pub fn is_smallest(editor: &FunctionEditor, id: NodeID) -> bool { + let nodes = &editor.func().nodes; + nodes[id.idx()] + .try_constant() + .map(|id| editor.get_constant(id).is_smallest()) + .unwrap_or(false) + || nodes[id.idx()] + .try_dynamic_constant() + .map(|id| editor.get_dynamic_constant(id).is_smallest()) + .unwrap_or(false) + || nodes[id.idx()].is_undef() +} diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch index 3c3d09b3..d08e86e6 100644 --- a/juno_samples/edge_detection/src/cpu.sch +++ b/juno_samples/edge_detection/src/cpu.sch @@ -58,6 +58,8 @@ fixpoint { fork-coalesce(max_gradient); } simpl!(max_gradient); +clean-monoid-reduces(max_gradient); +xdot[true](max_gradient); no-memset(reject_zero_crossings@res); fixpoint { -- GitLab