From 718f6c6a2cd0e054b33442ca8a17e6e43908e0d9 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 18 Feb 2024 09:57:47 -0600
Subject: [PATCH] Add ternary node

---
 hercules_cg/src/cpu_beta.rs  | 19 ++++++++++++
 hercules_ir/src/build.rs     | 15 ++++++++++
 hercules_ir/src/def_use.rs   | 12 ++++++++
 hercules_ir/src/ir.rs        | 37 ++++++++++++++++++++++++
 hercules_ir/src/parse.rs     | 22 ++++++++++++++
 hercules_ir/src/typecheck.rs | 31 ++++++++++++++++++++
 hercules_opt/src/ccp.rs      | 56 ++++++++++++++++++++++++++++++++++++
 7 files changed, 192 insertions(+)

diff --git a/hercules_cg/src/cpu_beta.rs b/hercules_cg/src/cpu_beta.rs
index 2c5897d2..9a1c6741 100644
--- a/hercules_cg/src/cpu_beta.rs
+++ b/hercules_cg/src/cpu_beta.rs
@@ -726,6 +726,25 @@ fn emit_llvm_for_node(
                 virtual_register(right),
             );
         }
+        Node::Ternary {
+            first,
+            second,
+            third,
+            op,
+        } => {
+            let opcode = match op {
+                TernaryOperator::Select => "select",
+            };
+
+            llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!(
+                "  {} = {} {}, {}, {}\n",
+                virtual_register(id),
+                opcode,
+                normal_value(first),
+                normal_value(second),
+                normal_value(third),
+            );
+        }
         Node::Read {
             collect,
             ref indices,
diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs
index 3c37ecb8..4345f712 100644
--- a/hercules_ir/src/build.rs
+++ b/hercules_ir/src/build.rs
@@ -498,6 +498,21 @@ impl NodeBuilder {
         self.node = Node::Binary { left, right, op };
     }
 
+    pub fn build_ternary(
+        &mut self,
+        first: NodeID,
+        second: NodeID,
+        third: NodeID,
+        op: TernaryOperator,
+    ) {
+        self.node = Node::Ternary {
+            first,
+            second,
+            third,
+            op,
+        };
+    }
+
     pub fn build_call(
         &mut self,
         function: FunctionID,
diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index a8fd5c17..0116bb4e 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -152,6 +152,12 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> {
         Node::DynamicConstant { id: _ } => NodeUses::One([NodeID::new(0)]),
         Node::Unary { input, op: _ } => NodeUses::One([*input]),
         Node::Binary { left, right, op: _ } => NodeUses::Two([*left, *right]),
+        Node::Ternary {
+            first,
+            second,
+            third,
+            op: _,
+        } => NodeUses::Three([*first, *second, *third]),
         Node::Call {
             function: _,
             dynamic_constants: _,
@@ -227,6 +233,12 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
         Node::DynamicConstant { id: _ } => NodeUsesMut::Zero,
         Node::Unary { input, op: _ } => NodeUsesMut::One([input]),
         Node::Binary { left, right, op: _ } => NodeUsesMut::Two([left, right]),
+        Node::Ternary {
+            first,
+            second,
+            third,
+            op: _,
+        } => NodeUsesMut::Three([first, second, third]),
         Node::Call {
             function: _,
             dynamic_constants: _,
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 8e5997bf..b3df19d9 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -199,6 +199,12 @@ pub enum Node {
         right: NodeID,
         op: BinaryOperator,
     },
+    Ternary {
+        first: NodeID,
+        second: NodeID,
+        third: NodeID,
+        op: TernaryOperator,
+    },
     Call {
         function: FunctionID,
         dynamic_constants: Box<[DynamicConstantID]>,
@@ -241,6 +247,11 @@ pub enum BinaryOperator {
     RSh,
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum TernaryOperator {
+    Select,
+}
+
 impl Module {
     /*
      * There are many transformations that need to iterate over the functions
@@ -889,6 +900,12 @@ impl Node {
                 right: _,
                 op,
             } => op.upper_case_name(),
+            Node::Ternary {
+                first: _,
+                second: _,
+                third: _,
+                op,
+            } => op.upper_case_name(),
             Node::Call {
                 function: _,
                 dynamic_constants: _,
@@ -943,6 +960,12 @@ impl Node {
                 right: _,
                 op,
             } => op.lower_case_name(),
+            Node::Ternary {
+                first: _,
+                second: _,
+                third: _,
+                op,
+            } => op.lower_case_name(),
             Node::Call {
                 function: _,
                 dynamic_constants: _,
@@ -1037,6 +1060,20 @@ impl BinaryOperator {
     }
 }
 
+impl TernaryOperator {
+    pub fn upper_case_name(&self) -> &'static str {
+        match self {
+            TernaryOperator::Select => "Select",
+        }
+    }
+
+    pub fn lower_case_name(&self) -> &'static str {
+        match self {
+            TernaryOperator::Select => "select",
+        }
+    }
+}
+
 /*
  * Rust things to make newtyped IDs usable.
  */
diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs
index a00da548..63107c81 100644
--- a/hercules_ir/src/parse.rs
+++ b/hercules_ir/src/parse.rs
@@ -308,6 +308,7 @@ fn parse_node<'a>(
         "xor" => parse_binary(ir_text, context, BinaryOperator::Xor)?,
         "lsh" => parse_binary(ir_text, context, BinaryOperator::LSh)?,
         "rsh" => parse_binary(ir_text, context, BinaryOperator::RSh)?,
+        "select" => parse_ternary(ir_text, context, TernaryOperator::Select)?,
         "call" => parse_call(ir_text, context)?,
         "read" => parse_read(ir_text, context)?,
         "write" => parse_write(ir_text, context)?,
@@ -485,6 +486,27 @@ fn parse_binary<'a>(
     Ok((ir_text, Node::Binary { left, right, op }))
 }
 
+fn parse_ternary<'a>(
+    ir_text: &'a str,
+    context: &RefCell<Context<'a>>,
+    op: TernaryOperator,
+) -> nom::IResult<&'a str, Node> {
+    let (ir_text, (first, second, third)) =
+        parse_tuple3(parse_identifier, parse_identifier, parse_identifier)(ir_text)?;
+    let first = context.borrow_mut().get_node_id(first);
+    let second = context.borrow_mut().get_node_id(second);
+    let third = context.borrow_mut().get_node_id(third);
+    Ok((
+        ir_text,
+        Node::Ternary {
+            first,
+            second,
+            third,
+            op,
+        },
+    ))
+}
+
 fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> {
     // Call nodes are a bit complicated because they 1. optionally take dynamic
     // constants as "arguments" (though these are specified between <>), 2.
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index eabe45d7..ff631bd7 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -701,6 +701,37 @@ fn typeflow(
 
             input_ty.clone()
         }
+        Node::Ternary {
+            first: _,
+            second: _,
+            third: _,
+            op,
+        } => {
+            if inputs.len() != 3 {
+                return Error(String::from("Ternary node must have exactly three inputs."));
+            }
+
+            if let Concrete(id) = inputs[0] {
+                match op {
+                    TernaryOperator::Select => {
+                        if !types[id.idx()].is_bool() {
+                            return Error(String::from(
+                                "Select ternary node input cannot have non-bool condition input.",
+                            ));
+                        }
+
+                        let data_ty = TypeSemilattice::meet(inputs[1], inputs[2]);
+                        if let Concrete(data_id) = data_ty {
+                            return Concrete(data_id);
+                        } else {
+                            return data_ty;
+                        }
+                    }
+                }
+            }
+
+            Error(String::from("Unhandled ternary types."))
+        }
         Node::Call {
             function: callee_id,
             dynamic_constants: dc_args,
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index 1381a2f1..8999730f 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -665,6 +665,62 @@ fn ccp_flow_function(
                 constant: new_constant,
             }
         }
+        Node::Ternary {
+            first,
+            second,
+            third,
+            op,
+        } => {
+            let CCPLattice {
+                reachability: ref first_reachability,
+                constant: ref first_constant,
+            } = inputs[first.idx()];
+            let CCPLattice {
+                reachability: ref second_reachability,
+                constant: ref second_constant,
+            } = inputs[second.idx()];
+            let CCPLattice {
+                reachability: ref third_reachability,
+                constant: ref third_constant,
+            } = inputs[third.idx()];
+
+            let new_constant = if let (
+                ConstantLattice::Constant(first_cons),
+                ConstantLattice::Constant(second_cons),
+                ConstantLattice::Constant(third_cons),
+            ) = (first_constant, second_constant, third_constant)
+            {
+                let new_cons = match(op, first_cons, second_cons, third_cons) {
+                    (TernaryOperator::Select, Constant::Boolean(first_val), second_val, third_val) => if *first_val {second_val.clone()} else {third_val.clone()},
+                    _ => panic!("Unsupported combination of ternary operation and constant values. Did typechecking succeed?")
+                };
+                ConstantLattice::Constant(new_cons)
+            } else if (first_constant.is_top()
+                && !second_constant.is_bottom()
+                && !third_constant.is_bottom())
+                || (!first_constant.is_bottom()
+                    && second_constant.is_top()
+                    && !first_constant.is_bottom())
+                || (!first_constant.is_bottom()
+                    && !second_constant.is_bottom()
+                    && third_constant.is_top())
+            {
+                ConstantLattice::top()
+            } else {
+                ConstantLattice::meet(
+                    first_constant,
+                    &ConstantLattice::meet(second_constant, third_constant),
+                )
+            };
+
+            CCPLattice {
+                reachability: ReachabilityLattice::meet(
+                    first_reachability,
+                    &ReachabilityLattice::meet(second_reachability, third_reachability),
+                ),
+                constant: new_constant,
+            }
+        }
         // Call nodes are uninterpretable.
         Node::Call {
             function: _,
-- 
GitLab