From 95a1e72644488e0a7197f0037801106b73214765 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 23 Jan 2025 21:35:24 -0600
Subject: [PATCH] Lift math on dynamic constants in IR to dynamic constant math

---
 hercules_opt/src/dce.rs          |  2 +-
 hercules_opt/src/lib.rs          |  2 +
 hercules_opt/src/lift_dc_math.rs | 90 ++++++++++++++++++++++++++++++++
 juno_scheduler/src/default.rs    |  4 ++
 juno_scheduler/src/ir.rs         |  1 +
 juno_scheduler/src/pm.rs         | 16 +++++-
 6 files changed, 112 insertions(+), 3 deletions(-)
 create mode 100644 hercules_opt/src/lift_dc_math.rs

diff --git a/hercules_opt/src/dce.rs b/hercules_opt/src/dce.rs
index 026672a3..6eec42e5 100644
--- a/hercules_opt/src/dce.rs
+++ b/hercules_opt/src/dce.rs
@@ -8,7 +8,7 @@ use crate::*;
  */
 pub fn dce(editor: &mut FunctionEditor) {
     // Create worklist (starts as all nodes).
-    let mut worklist: Vec<NodeID> = (0..editor.func().nodes.len()).map(NodeID::new).collect();
+    let mut worklist: Vec<NodeID> = editor.node_ids().collect();
 
     while let Some(work) = worklist.pop() {
         // If a node on the worklist is a start node, it is either *the* start
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 2c9d4372..0b10bdae 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -13,6 +13,7 @@ pub mod gcm;
 pub mod gvn;
 pub mod inline;
 pub mod interprocedural_sroa;
+pub mod lift_dc_math;
 pub mod outline;
 pub mod phi_elim;
 pub mod pred;
@@ -35,6 +36,7 @@ pub use crate::gcm::*;
 pub use crate::gvn::*;
 pub use crate::inline::*;
 pub use crate::interprocedural_sroa::*;
+pub use crate::lift_dc_math::*;
 pub use crate::outline::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
diff --git a/hercules_opt/src/lift_dc_math.rs b/hercules_opt/src/lift_dc_math.rs
new file mode 100644
index 00000000..afdb2120
--- /dev/null
+++ b/hercules_opt/src/lift_dc_math.rs
@@ -0,0 +1,90 @@
+use hercules_ir::ir::*;
+
+use crate::*;
+
+/*
+ * Lift math in IR nodes into dynamic constants.
+ */
+pub fn lift_dc_math(editor: &mut FunctionEditor) {
+    // Create worklist (starts as all nodes).
+    let mut worklist: Vec<NodeID> = editor.node_ids().collect();
+    while let Some(work) = worklist.pop() {
+        // Look for single nodes that can be converted to dynamic constants.
+        let users: Vec<_> = editor.get_users(work).collect();
+        let nodes = &editor.func().nodes;
+        let dc = match nodes[work.idx()] {
+            Node::Constant { id } => {
+                // Why do we need this weird crap? This is due to a limitation
+                // in Rust's lifetime rules w/ let guards.
+                let cons = if let Constant::UnsignedInteger64(cons) = *editor.get_constant(id) {
+                    cons
+                } else {
+                    continue;
+                };
+                DynamicConstant::Constant(cons as usize)
+            }
+            Node::DynamicConstant { id } => {
+                let Some(cons) = evaluate_dynamic_constant(id, &*editor.get_dynamic_constants())
+                else {
+                    continue;
+                };
+                DynamicConstant::Constant(cons)
+            }
+            Node::Binary { op, left, right } => {
+                let (left, right) = if let (
+                    Node::DynamicConstant { id: left },
+                    Node::DynamicConstant { id: right },
+                ) = (&nodes[left.idx()], &nodes[right.idx()])
+                {
+                    (*left, *right)
+                } else {
+                    continue;
+                };
+                match op {
+                    BinaryOperator::Add => DynamicConstant::Add(left, right),
+                    BinaryOperator::Sub => DynamicConstant::Sub(left, right),
+                    BinaryOperator::Mul => DynamicConstant::Mul(left, right),
+                    BinaryOperator::Div => DynamicConstant::Div(left, right),
+                    BinaryOperator::Rem => DynamicConstant::Rem(left, right),
+                    _ => {
+                        continue;
+                    }
+                }
+            }
+            Node::IntrinsicCall {
+                intrinsic,
+                ref args,
+            } => {
+                let (left, right) = if args.len() == 2
+                    && let (Node::DynamicConstant { id: left }, Node::DynamicConstant { id: right }) =
+                        (&nodes[args[0].idx()], &nodes[args[1].idx()])
+                {
+                    (*left, *right)
+                } else {
+                    continue;
+                };
+                match intrinsic {
+                    Intrinsic::Min => DynamicConstant::Min(left, right),
+                    Intrinsic::Max => DynamicConstant::Max(left, right),
+                    _ => {
+                        continue;
+                    }
+                }
+            }
+            _ => {
+                continue;
+            }
+        };
+
+        // Replace the node with the computed dynamic constant.
+        let success = editor.edit(|mut edit| {
+            let dc = edit.add_dynamic_constant(dc);
+            let node = edit.add_node(Node::DynamicConstant { id: dc });
+            edit = edit.replace_all_uses(work, node)?;
+            edit.delete_node(work)
+        });
+        if success {
+            worklist.extend(users);
+        }
+    }
+}
diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs
index 8274b81a..46b51b43 100644
--- a/juno_scheduler/src/default.rs
+++ b/juno_scheduler/src/default.rs
@@ -60,6 +60,10 @@ pub fn default_schedule() -> ScheduleStmt {
         DCE,
         GVN,
         DCE,
+        LiftDCMath,
+        DCE,
+        GVN,
+        DCE,
         /*Forkify,*/
         /*ForkGuardElim,*/
         DCE,
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 16f2de9b..381c3475 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -18,6 +18,7 @@ pub enum Pass {
     InferSchedules,
     Inline,
     InterproceduralSROA,
+    LiftDCMath,
     Outline,
     PhiElim,
     Predication,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 43fba4fd..72150570 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -6,8 +6,8 @@ use hercules_opt::FunctionEditor;
 use hercules_opt::{
     ccp, collapse_returns, crc, dce, dumb_outline, ensure_between_control_flow, float_collections,
     fork_split, gcm, gvn, infer_parallel_fork, infer_parallel_reduce, infer_tight_associative,
-    infer_vectorizable, inline, interprocedural_sroa, outline, phi_elim, predication, slf, sroa,
-    unforkify, write_predication,
+    infer_vectorizable, inline, interprocedural_sroa, lift_dc_math, outline, phi_elim, predication,
+    slf, sroa, unforkify, write_predication,
 };
 
 use tempfile::TempDir;
@@ -1339,6 +1339,18 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::LiftDCMath => {
+            assert!(args.is_empty());
+            for func in build_selection(pm, selection) {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                lift_dc_math(&mut func);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::Outline => {
             let Some((nodes, func)) = selection_as_set(pm, selection) else {
                 return Err(SchedulerError::PassError {
-- 
GitLab