From e76923405e879bddd94888edd89a8088359622cd Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 20 Feb 2025 15:20:59 -0600
Subject: [PATCH] Lower intrinsics in RT backend

---
 hercules_cg/src/rt.rs                   | 37 +++++++++++++++++++++++--
 juno_samples/dot/src/cpu.sch            |  2 +-
 juno_samples/edge_detection/src/cpu.sch | 15 ++++++++--
 juno_scheduler/src/pm.rs                | 10 ++-----
 4 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index d3013239..7cbb43ad 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -489,8 +489,24 @@ impl<'a> RTContext<'a> {
                     Constant::UnsignedInteger16(val) => write!(block, "{}u16", val)?,
                     Constant::UnsignedInteger32(val) => write!(block, "{}u32", val)?,
                     Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?,
-                    Constant::Float32(val) => write!(block, "{}f32", val)?,
-                    Constant::Float64(val) => write!(block, "{}f64", val)?,
+                    Constant::Float32(val) => {
+                        if val == f32::INFINITY {
+                            write!(block, "f32::INFINITY")?
+                        } else if val == f32::NEG_INFINITY {
+                            write!(block, "f32::NEG_INFINITY")?
+                        } else {
+                            write!(block, "{}f32", val)?
+                        }
+                    }
+                    Constant::Float64(val) => {
+                        if val == f64::INFINITY {
+                            write!(block, "f64::INFINITY")?
+                        } else if val == f64::NEG_INFINITY {
+                            write!(block, "f64::NEG_INFINITY")?
+                        } else {
+                            write!(block, "{}f64", val)?
+                        }
+                    }
                     Constant::Product(ty, _)
                     | Constant::Summation(ty, _, _)
                     | Constant::Array(ty) => {
@@ -628,6 +644,23 @@ impl<'a> RTContext<'a> {
                 }
                 write!(block, "){};", postfix)?;
             }
+            Node::IntrinsicCall {
+                intrinsic,
+                ref args,
+            } => {
+                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                write!(
+                    block,
+                    "{} = {}::{}(",
+                    self.get_value(id, bb, true),
+                    self.get_type(self.typing[id.idx()]),
+                    intrinsic.lower_case_name(),
+                )?;
+                for arg in args {
+                    write!(block, "{}, ", self.get_value(*arg, bb, false))?;
+                }
+                write!(block, ");")?;
+            }
             Node::LibraryCall {
                 library_function,
                 ref args,
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index 1f8953d9..5c763772 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -24,7 +24,7 @@ dce(*);
 let fission_out = fork-fission[out@loop](dot);
 simplify-cfg(dot);
 dce(dot);
-unforkify(fission_out.dot.fj_loop_bottom);
+unforkify(fission_out.dot.fj_bottom);
 ccp(dot);
 simplify-cfg(dot);
 gvn(dot);
diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch
index d08e86e6..ead722ce 100644
--- a/juno_samples/edge_detection/src/cpu.sch
+++ b/juno_samples/edge_detection/src/cpu.sch
@@ -58,8 +58,17 @@ fixpoint {
   fork-coalesce(max_gradient);
 }
 simpl!(max_gradient);
+fork-dim-merge(max_gradient);
+simpl!(max_gradient);
+fork-tile[8, 0, false, false](max_gradient);
+let split = fork-split(max_gradient);
 clean-monoid-reduces(max_gradient);
-xdot[true](max_gradient);
+let out = outline(split._4_max_gradient.fj1);
+simpl!(max_gradient, out);
+unforkify(out);
+let out = fork-fission[split._4_max_gradient.fj0](max_gradient);
+simpl!(max_gradient);
+unforkify(out._4_max_gradient.fj_bottom);
 
 no-memset(reject_zero_crossings@res);
 fixpoint {
@@ -72,8 +81,8 @@ simpl!(reject_zero_crossings);
 
 async-call(edge_detection@le, edge_detection@zc);
 
-fork-split(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings);
-unforkify(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings);
+fork-split(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings);
+unforkify(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings);
 
 simpl!(*);
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 6931ce2e..392273d3 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2560,7 +2560,7 @@ fn run_pass(
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
             let mut new_fork_joins = HashMap::new();
 
-            let fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone();
+            let _fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone();
 
             for (mut func, created_fork_joins) in
                 build_editors(pm).into_iter().zip(created_fork_joins)
@@ -2583,13 +2583,9 @@ fn run_pass(
                 // level of the split fork-joins being referred to.
                 let mut func_record = HashMap::new();
                 for (idx, label) in labels {
-                    let fmt = if idx % 2 == 0 {
-                        format!("fj_{}_top", fork_label_name)
-                    } else {
-                        format!("fj_{}_bottom", fork_label_name)
-                    };
+                    let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" };
                     func_record.insert(
-                        fmt,
+                        fmt.to_string(),
                         Value::Label {
                             labels: vec![LabelInfo {
                                 func: func_id,
-- 
GitLab