From 54bf7312724cab4ae2353ff7736b54f4b77d47b4 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Wed, 14 Feb 2024 15:45:28 -0600
Subject: [PATCH] Add zero initializer constants

---
 Cargo.lock                              |  1 +
 hercules_cg/src/cpu_beta.rs             | 38 ++++++++++++++--
 hercules_ir/src/ir.rs                   | 21 ++++++++-
 hercules_ir/src/parse.rs                | 20 ++++++---
 hercules_ir/src/typecheck.rs            |  2 +
 hercules_opt/Cargo.toml                 |  1 +
 hercules_opt/src/ccp.rs                 | 59 ++++++++++++++++++-------
 hercules_samples/matmul/matmul.hir      |  3 +-
 hercules_samples/matmul/src/main.rs     |  4 +-
 hercules_tools/hercules_cpu/src/main.rs |  1 +
 hercules_tools/hercules_dot/src/main.rs |  1 +
 11 files changed, 122 insertions(+), 29 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index bfcc04d0..5cc9664f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -205,6 +205,7 @@ name = "hercules_opt"
 version = "0.1.0"
 dependencies = [
  "hercules_ir",
+ "ordered-float",
 ]
 
 [[package]]
diff --git a/hercules_cg/src/cpu_beta.rs b/hercules_cg/src/cpu_beta.rs
index e6c858fd..2c5897d2 100644
--- a/hercules_cg/src/cpu_beta.rs
+++ b/hercules_cg/src/cpu_beta.rs
@@ -106,7 +106,36 @@ pub fn cpu_beta_codegen<W: Write>(
 
     // Step 2: render constants into LLVM IR. This is done in a very similar
     // manner as types.
-    let mut llvm_constants = vec!["".to_string(); types.len()];
+    let mut llvm_constants = vec!["".to_string(); constants.len()];
+    fn render_zero_constant(cons_id: ConstantID, ty_id: TypeID, types: &Vec<Type>) -> String {
+        match &types[ty_id.idx()] {
+            Type::Control(_) => panic!(),
+            Type::Boolean => "false".to_string(),
+            Type::Integer8
+            | Type::Integer16
+            | Type::Integer32
+            | Type::Integer64
+            | Type::UnsignedInteger8
+            | Type::UnsignedInteger16
+            | Type::UnsignedInteger32
+            | Type::UnsignedInteger64 => "0".to_string(),
+            Type::Float32 | Type::Float64 => "0.0".to_string(),
+            Type::Product(fields) => {
+                let mut iter = fields.iter();
+                if let Some(first) = iter.next() {
+                    iter.fold(
+                        "{".to_string() + &render_zero_constant(cons_id, *first, types),
+                        |s, f| s + ", " + &render_zero_constant(cons_id, *f, types),
+                    ) + "}"
+                } else {
+                    "{}".to_string()
+                }
+            }
+            Type::Summation(_) => todo!(),
+            Type::Array(_, _) => format!("%arr.{}", cons_id.idx()),
+        }
+    }
+
     for id in module.constants_bottom_up() {
         match &constants[id.idx()] {
             Constant::Boolean(val) => {
@@ -151,6 +180,9 @@ pub fn cpu_beta_codegen<W: Write>(
             }
             Constant::Array(_, _) => llvm_constants[id.idx()] = format!("%arr.{}", id.idx()),
             Constant::Summation(_, _, _) => todo!(),
+            Constant::Zero(ty_id) => {
+                llvm_constants[id.idx()] = render_zero_constant(id, *ty_id, types)
+            }
         }
     }
 
@@ -190,8 +222,8 @@ pub fn cpu_beta_codegen<W: Write>(
             .chain((0..function.num_dynamic_constants).map(|idx| format!("i64 %dc{}", idx)))
             .chain(
                 (0..constants.len())
-                    .filter(|idx| constants[*idx].is_array())
-                    .map(|idx| format!("%arr.{}", idx)),
+                    .filter(|idx| module.is_array_constant(ConstantID::new(*idx)))
+                    .map(|idx| format!("ptr %arr.{}", idx)),
             );
         write!(w, "define {} @{}(", llvm_ret_type, function.name)?;
         if let Some(first) = llvm_params.next() {
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index ecc7eecf..8e5997bf 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -77,7 +77,9 @@ pub enum Type {
  * interning constants during IR construction). Product, summation, and array
  * constants all contain their own type. This is only strictly necessary for
  * summation types, but provides a nice mechanism for sanity checking for
- * product and array types as well.
+ * product and array types as well. There is also a zero initializer constant,
+ * which stores its own type as well. The zero value of a summation is defined
+ * as the zero value of the first variant.
  */
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum Constant {
@@ -95,6 +97,7 @@ pub enum Constant {
     Product(TypeID, Box<[ConstantID]>),
     Summation(TypeID, u32, ConstantID),
     Array(TypeID, Box<[ConstantID]>),
+    Zero(TypeID),
 }
 
 /*
@@ -374,6 +377,7 @@ impl Module {
                 }
                 write!(w, "]")
             }
+            Constant::Zero(_) => write!(w, "zero"),
         }?;
 
         Ok(())
@@ -509,6 +513,18 @@ impl Module {
             coroutine: Box::new(coroutine),
         }
     }
+
+    /*
+     * Unfortunately, determining if a constant is an array requires both
+     * knowledge of constants and types, due to zero initializer constants.
+     */
+    pub fn is_array_constant(&self, cons_id: ConstantID) -> bool {
+        if let Constant::Zero(ty_id) = self.constants[cons_id.idx()] {
+            self.types[ty_id.idx()].is_array()
+        } else {
+            self.constants[cons_id.idx()].is_strictly_array()
+        }
+    }
 }
 
 struct CoroutineIterator<G, I>
@@ -674,7 +690,7 @@ pub fn element_type(mut ty: TypeID, types: &Vec<Type>) -> TypeID {
 }
 
 impl Constant {
-    pub fn is_array(&self) -> bool {
+    pub fn is_strictly_array(&self) -> bool {
         if let Constant::Array(_, _) = self {
             true
         } else {
@@ -697,6 +713,7 @@ impl Constant {
             Constant::UnsignedInteger64(0) => true,
             Constant::Float32(ord) => *ord == ordered_float::OrderedFloat::<f32>(0.0),
             Constant::Float64(ord) => *ord == ordered_float::OrderedFloat::<f64>(0.0),
+            Constant::Zero(_) => true,
             _ => false,
         }
     }
diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs
index 3354cbb2..a00da548 100644
--- a/hercules_ir/src/parse.rs
+++ b/hercules_ir/src/parse.rs
@@ -876,7 +876,16 @@ fn parse_constant<'a>(
     ty: Type,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Constant> {
-    let (ir_text, constant) = match ty.clone() {
+    let ty_id = context.borrow_mut().get_type_id(ty.clone());
+    let (ir_text, maybe_constant) = nom::combinator::opt(nom::combinator::map(
+        nom::bytes::complete::tag("zero"),
+        |_| Constant::Zero(ty_id),
+    ))(ir_text)?;
+    if let Some(cons) = maybe_constant {
+        return Ok((ir_text, cons));
+    }
+
+    let (ir_text, constant) = match ty {
         // There are not control constants.
         Type::Control(_) => Err(nom::Err::Error(nom::error::Error {
             input: ir_text,
@@ -893,13 +902,13 @@ fn parse_constant<'a>(
         Type::UnsignedInteger64 => parse_unsigned_integer64(ir_text)?,
         Type::Float32 => parse_float32(ir_text)?,
         Type::Float64 => parse_float64(ir_text)?,
-        Type::Product(tys) => parse_product_constant(
+        Type::Product(ref tys) => parse_product_constant(
             ir_text,
             context.borrow_mut().get_type_id(ty.clone()),
             tys,
             context,
         )?,
-        Type::Summation(tys) => parse_summation_constant(
+        Type::Summation(ref tys) => parse_summation_constant(
             ir_text,
             context.borrow_mut().get_type_id(ty.clone()),
             tys,
@@ -912,7 +921,6 @@ fn parse_constant<'a>(
             context,
         )?,
     };
-    context.borrow_mut().get_type_id(ty);
     Ok((ir_text, constant))
 }
 
@@ -997,7 +1005,7 @@ fn parse_float64<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Constant> {
 fn parse_product_constant<'a>(
     ir_text: &'a str,
     prod_ty: TypeID,
-    tys: Box<[TypeID]>,
+    tys: &Box<[TypeID]>,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Constant> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
@@ -1032,7 +1040,7 @@ fn parse_product_constant<'a>(
 fn parse_summation_constant<'a>(
     ir_text: &'a str,
     sum_ty: TypeID,
-    tys: Box<[TypeID]>,
+    tys: &Box<[TypeID]>,
     context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Constant> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index 1a183411..eabe45d7 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -579,6 +579,8 @@ fn typeflow(
                         ))
                     }
                 }
+                // Zero constants need to store their type, and we trust it.
+                Constant::Zero(id) => Concrete(id),
             }
         }
         Node::DynamicConstant { id } => {
diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml
index 47bd9bd5..bc305405 100644
--- a/hercules_opt/Cargo.toml
+++ b/hercules_opt/Cargo.toml
@@ -4,4 +4,5 @@ version = "0.1.0"
 authors = ["Russel Arbore <rarbore2@illinois.edu>"]
 
 [dependencies]
+ordered-float = "*"
 hercules_ir = { path = "../hercules_ir" }
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index c061a877..1381a2f1 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -133,13 +133,14 @@ impl Semilattice for ConstantLattice {
  */
 pub fn ccp(
     function: &mut Function,
+    types: &Vec<Type>,
     constants: &mut Vec<Constant>,
     def_use: &ImmutableDefUseMap,
     reverse_postorder: &Vec<NodeID>,
 ) {
     // Step 1: run ccp analysis to understand the function.
     let result = dataflow_global(&function, reverse_postorder, |inputs, node_id| {
-        ccp_flow_function(inputs, node_id, &function, &constants)
+        ccp_flow_function(inputs, node_id, &function, &types, &constants)
     });
 
     // Step 2: update uses of constants. Any node that doesn't produce a
@@ -371,6 +372,7 @@ fn ccp_flow_function(
     inputs: &[CCPLattice],
     node_id: NodeID,
     function: &Function,
+    types: &Vec<Type>,
     old_constants: &Vec<Constant>,
 ) -> CCPLattice {
     let node = &function.nodes[node_id.idx()];
@@ -415,7 +417,7 @@ fn ccp_flow_function(
         // TODO: This should produce a constant zero if the dynamic constant for
         // for the corresponding fork is one.
         Node::ThreadID { control } => inputs[control.idx()].clone(),
-        // TODO: At least for now, collect nodes always produce unknown values.
+        // TODO: At least for now, reduce nodes always produce unknown values.
         Node::Reduce {
             control,
             init: _,
@@ -454,6 +456,7 @@ fn ccp_flow_function(
                     (UnaryOperator::Neg, Constant::Integer64(val)) => Constant::Integer64(-val),
                     (UnaryOperator::Neg, Constant::Float32(val)) => Constant::Float32(-val),
                     (UnaryOperator::Neg, Constant::Float64(val)) => Constant::Float64(-val),
+                    (UnaryOperator::Neg, Constant::Zero(id)) => Constant::Zero(*id),
                     _ => panic!("Unsupported combination of unary operation and constant value. Did typechecking succeed?")
                 };
                 ConstantLattice::Constant(new_cons)
@@ -482,6 +485,32 @@ fn ccp_flow_function(
                 ConstantLattice::Constant(right_cons),
             ) = (left_constant, right_constant)
             {
+                let type_to_zero_cons = |ty_id: TypeID| {
+                    match types[ty_id.idx()] {
+                    Type::Boolean => Constant::Boolean(false),
+                    Type::Integer8 => Constant::Integer8(0),
+                    Type::Integer16 => Constant::Integer16(0),
+                    Type::Integer32 => Constant::Integer32(0),
+                    Type::Integer64 => Constant::Integer64(0),
+                    Type::UnsignedInteger8 => Constant::UnsignedInteger8(0),
+                    Type::UnsignedInteger16 => Constant::UnsignedInteger16(0),
+                    Type::UnsignedInteger32 => Constant::UnsignedInteger32(0),
+                    Type::UnsignedInteger64 => Constant::UnsignedInteger64(0),
+                    Type::Float32 => Constant::Float32(ordered_float::OrderedFloat::<f32>(0.0)),
+                    Type::Float64 => Constant::Float64(ordered_float::OrderedFloat::<f64>(0.0)),
+                    _ => panic!("Unsupported combination of binary operation and constant values. Did typechecking succeed?")
+                }
+                };
+                let left_cons = if let Constant::Zero(id) = left_cons {
+                    type_to_zero_cons(*id)
+                } else {
+                    left_cons.clone()
+                };
+                let right_cons = if let Constant::Zero(id) = right_cons {
+                    type_to_zero_cons(*id)
+                } else {
+                    right_cons.clone()
+                };
                 let new_cons = match (op, left_cons, right_cons) {
                     (BinaryOperator::Add, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val + right_val),
                     (BinaryOperator::Add, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val + right_val),
@@ -491,8 +520,8 @@ fn ccp_flow_function(
                     (BinaryOperator::Add, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val + right_val),
                     (BinaryOperator::Add, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val + right_val),
                     (BinaryOperator::Add, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val + right_val),
-                    (BinaryOperator::Add, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val + *right_val),
-                    (BinaryOperator::Add, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val + *right_val),
+                    (BinaryOperator::Add, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val + right_val),
+                    (BinaryOperator::Add, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val + right_val),
                     (BinaryOperator::Sub, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val - right_val),
                     (BinaryOperator::Sub, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val - right_val),
                     (BinaryOperator::Sub, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val - right_val),
@@ -501,8 +530,8 @@ fn ccp_flow_function(
                     (BinaryOperator::Sub, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val - right_val),
                     (BinaryOperator::Sub, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val - right_val),
                     (BinaryOperator::Sub, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val - right_val),
-                    (BinaryOperator::Sub, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val - *right_val),
-                    (BinaryOperator::Sub, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val - *right_val),
+                    (BinaryOperator::Sub, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val - right_val),
+                    (BinaryOperator::Sub, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val - right_val),
                     (BinaryOperator::Mul, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val * right_val),
                     (BinaryOperator::Mul, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val * right_val),
                     (BinaryOperator::Mul, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val * right_val),
@@ -511,8 +540,8 @@ fn ccp_flow_function(
                     (BinaryOperator::Mul, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val * right_val),
                     (BinaryOperator::Mul, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val * right_val),
                     (BinaryOperator::Mul, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val * right_val),
-                    (BinaryOperator::Mul, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val * *right_val),
-                    (BinaryOperator::Mul, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val * *right_val),
+                    (BinaryOperator::Mul, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val * right_val),
+                    (BinaryOperator::Mul, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val * right_val),
                     (BinaryOperator::Div, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val / right_val),
                     (BinaryOperator::Div, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val / right_val),
                     (BinaryOperator::Div, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val / right_val),
@@ -521,8 +550,8 @@ fn ccp_flow_function(
                     (BinaryOperator::Div, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val / right_val),
                     (BinaryOperator::Div, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val / right_val),
                     (BinaryOperator::Div, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val / right_val),
-                    (BinaryOperator::Div, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val / *right_val),
-                    (BinaryOperator::Div, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val / *right_val),
+                    (BinaryOperator::Div, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val / right_val),
+                    (BinaryOperator::Div, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val / right_val),
                     (BinaryOperator::Rem, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val % right_val),
                     (BinaryOperator::Rem, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val % right_val),
                     (BinaryOperator::Rem, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val % right_val),
@@ -531,8 +560,8 @@ fn ccp_flow_function(
                     (BinaryOperator::Rem, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val % right_val),
                     (BinaryOperator::Rem, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val % right_val),
                     (BinaryOperator::Rem, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val % right_val),
-                    (BinaryOperator::Rem, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val % *right_val),
-                    (BinaryOperator::Rem, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val % *right_val),
+                    (BinaryOperator::Rem, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val % right_val),
+                    (BinaryOperator::Rem, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val % right_val),
                     (BinaryOperator::LT, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Boolean(left_val < right_val),
                     (BinaryOperator::LT, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Boolean(left_val < right_val),
                     (BinaryOperator::LT, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Boolean(left_val < right_val),
@@ -577,7 +606,7 @@ fn ccp_flow_function(
                     // need to unpack the constants.
                     (BinaryOperator::EQ, left_val, right_val) => Constant::Boolean(left_val == right_val),
                     (BinaryOperator::NE, left_val, right_val) => Constant::Boolean(left_val != right_val),
-                    (BinaryOperator::Or, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val || *right_val),
+                    (BinaryOperator::Or, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(left_val || right_val),
                     (BinaryOperator::Or, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val | right_val),
                     (BinaryOperator::Or, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val | right_val),
                     (BinaryOperator::Or, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val | right_val),
@@ -586,7 +615,7 @@ fn ccp_flow_function(
                     (BinaryOperator::Or, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val | right_val),
                     (BinaryOperator::Or, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val | right_val),
                     (BinaryOperator::Or, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val | right_val),
-                    (BinaryOperator::And, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val && *right_val),
+                    (BinaryOperator::And, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(left_val && right_val),
                     (BinaryOperator::And, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val & right_val),
                     (BinaryOperator::And, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val & right_val),
                     (BinaryOperator::And, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val & right_val),
@@ -595,7 +624,7 @@ fn ccp_flow_function(
                     (BinaryOperator::And, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val & right_val),
                     (BinaryOperator::And, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val & right_val),
                     (BinaryOperator::And, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val & right_val),
-                    (BinaryOperator::Xor, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val ^ *right_val),
+                    (BinaryOperator::Xor, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(left_val ^ right_val),
                     (BinaryOperator::Xor, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val ^ right_val),
                     (BinaryOperator::Xor, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val ^ right_val),
                     (BinaryOperator::Xor, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val ^ right_val),
diff --git a/hercules_samples/matmul/matmul.hir b/hercules_samples/matmul/matmul.hir
index 65e0cb10..2f0fb67a 100644
--- a/hercules_samples/matmul/matmul.hir
+++ b/hercules_samples/matmul/matmul.hir
@@ -1,4 +1,5 @@
-fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2), c: array(f32, #0, #2)) -> array(f32, #0, #2)
+fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
+  c = constant(array(f32, #0, #2), zero)
   i_ctrl = fork(start, #0)
   i_idx = thread_id(i_ctrl)
   j_ctrl = fork(i_ctrl, #2)
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 02aeb3a4..2a28cd80 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -10,10 +10,10 @@ fn main() {
         "matmul",
         *const f32,
         *const f32,
-        *mut f32,
         u64,
         u64,
         u64,
+        *mut f32,
         => *const f32
     );
 
@@ -24,10 +24,10 @@ fn main() {
         matmul(
             std::mem::transmute(a.as_ptr()),
             std::mem::transmute(b.as_ptr()),
-            std::mem::transmute(c.as_mut_ptr()),
             2,
             2,
             2,
+            std::mem::transmute(c.as_mut_ptr()),
         )
     };
     println!("{} {}\n{} {}", c[0][0], c[0][1], c[1][0], c[1][1]);
diff --git a/hercules_tools/hercules_cpu/src/main.rs b/hercules_tools/hercules_cpu/src/main.rs
index c6b077f9..209d8619 100644
--- a/hercules_tools/hercules_cpu/src/main.rs
+++ b/hercules_tools/hercules_cpu/src/main.rs
@@ -34,6 +34,7 @@ fn main() {
         |(mut function, id), (types, mut constants, dynamic_constants)| {
             hercules_opt::ccp::ccp(
                 &mut function,
+                &types,
                 &mut constants,
                 &def_uses[id.idx()],
                 &reverse_postorders[id.idx()],
diff --git a/hercules_tools/hercules_dot/src/main.rs b/hercules_tools/hercules_dot/src/main.rs
index a24285eb..198a73b7 100644
--- a/hercules_tools/hercules_dot/src/main.rs
+++ b/hercules_tools/hercules_dot/src/main.rs
@@ -42,6 +42,7 @@ fn main() {
         |(mut function, id), (types, mut constants, dynamic_constants)| {
             hercules_opt::ccp::ccp(
                 &mut function,
+                &types,
                 &mut constants,
                 &def_uses[id.idx()],
                 &reverse_postorders[id.idx()],
-- 
GitLab