From e76923405e879bddd94888edd89a8088359622cd Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 20 Feb 2025 15:20:59 -0600 Subject: [PATCH] Lower intrinsics in RT backend --- hercules_cg/src/rt.rs | 37 +++++++++++++++++++++++-- juno_samples/dot/src/cpu.sch | 2 +- juno_samples/edge_detection/src/cpu.sch | 15 ++++++++-- juno_scheduler/src/pm.rs | 10 ++----- 4 files changed, 51 insertions(+), 13 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index d3013239..7cbb43ad 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -489,8 +489,24 @@ impl<'a> RTContext<'a> { Constant::UnsignedInteger16(val) => write!(block, "{}u16", val)?, Constant::UnsignedInteger32(val) => write!(block, "{}u32", val)?, Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?, - Constant::Float32(val) => write!(block, "{}f32", val)?, - Constant::Float64(val) => write!(block, "{}f64", val)?, + Constant::Float32(val) => { + if val == f32::INFINITY { + write!(block, "f32::INFINITY")? + } else if val == f32::NEG_INFINITY { + write!(block, "f32::NEG_INFINITY")? + } else { + write!(block, "{}f32", val)? + } + } + Constant::Float64(val) => { + if val == f64::INFINITY { + write!(block, "f64::INFINITY")? + } else if val == f64::NEG_INFINITY { + write!(block, "f64::NEG_INFINITY")? + } else { + write!(block, "{}f64", val)? + } + } Constant::Product(ty, _) | Constant::Summation(ty, _, _) | Constant::Array(ty) => { @@ -628,6 +644,23 @@ impl<'a> RTContext<'a> { } write!(block, "){};", postfix)?; } + Node::IntrinsicCall { + intrinsic, + ref args, + } => { + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!( + block, + "{} = {}::{}(", + self.get_value(id, bb, true), + self.get_type(self.typing[id.idx()]), + intrinsic.lower_case_name(), + )?; + for arg in args { + write!(block, "{}, ", self.get_value(*arg, bb, false))?; + } + write!(block, ");")?; + } Node::LibraryCall { library_function, ref args, diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index 1f8953d9..5c763772 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -24,7 +24,7 @@ dce(*); let fission_out = fork-fission[out@loop](dot); simplify-cfg(dot); dce(dot); -unforkify(fission_out.dot.fj_loop_bottom); +unforkify(fission_out.dot.fj_bottom); ccp(dot); simplify-cfg(dot); gvn(dot); diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch index d08e86e6..ead722ce 100644 --- a/juno_samples/edge_detection/src/cpu.sch +++ b/juno_samples/edge_detection/src/cpu.sch @@ -58,8 +58,17 @@ fixpoint { fork-coalesce(max_gradient); } simpl!(max_gradient); +fork-dim-merge(max_gradient); +simpl!(max_gradient); +fork-tile[8, 0, false, false](max_gradient); +let split = fork-split(max_gradient); clean-monoid-reduces(max_gradient); -xdot[true](max_gradient); +let out = outline(split._4_max_gradient.fj1); +simpl!(max_gradient, out); +unforkify(out); +let out = fork-fission[split._4_max_gradient.fj0](max_gradient); +simpl!(max_gradient); +unforkify(out._4_max_gradient.fj_bottom); no-memset(reject_zero_crossings@res); fixpoint { @@ -72,8 +81,8 @@ simpl!(reject_zero_crossings); async-call(edge_detection@le, edge_detection@zc); -fork-split(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings); -unforkify(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings); +fork-split(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings); +unforkify(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings); simpl!(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 6931ce2e..392273d3 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2560,7 +2560,7 @@ fn run_pass( let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let mut new_fork_joins = HashMap::new(); - let fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone(); + let _fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone(); for (mut func, created_fork_joins) in build_editors(pm).into_iter().zip(created_fork_joins) @@ -2583,13 +2583,9 @@ fn run_pass( // level of the split fork-joins being referred to. let mut func_record = HashMap::new(); for (idx, label) in labels { - let fmt = if idx % 2 == 0 { - format!("fj_{}_top", fork_label_name) - } else { - format!("fj_{}_bottom", fork_label_name) - }; + let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" }; func_record.insert( - fmt, + fmt.to_string(), Value::Label { labels: vec![LabelInfo { func: func_id, -- GitLab