From 0c2f1134ea9294ccb7e6c044b810a4a10f1631f5 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Mon, 26 Aug 2024 17:46:35 -0500
Subject: [PATCH] Misc. improvements

---
 Cargo.lock                  |  82 +++-----
 hercules_cg/src/common.rs   |   1 +
 hercules_cg/src/cpu.rs      |  30 ++-
 hercules_cg/src/top.rs      |  26 ++-
 hercules_ir/src/dom.rs      |   2 +-
 hercules_ir/src/ir.rs       |  64 ++++++-
 hercules_ir/src/manifest.rs |   5 +-
 hercules_ir/src/schedule.rs |   1 +
 hercules_opt/src/lib.rs     |   8 +-
 hercules_opt/src/pass.rs    |  26 +++
 hercules_opt/src/sroa.rs    | 369 ++++++++++++++++++++++++++++++++++++
 hercules_rt/src/elf.rs      |   1 -
 hercules_rt_proc/src/lib.rs |  89 +++++++--
 juno_frontend/src/main.rs   |  78 ++++++--
 14 files changed, 670 insertions(+), 112 deletions(-)
 create mode 100644 hercules_opt/src/sroa.rs

diff --git a/Cargo.lock b/Cargo.lock
index 644b6f7c..094512f5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -84,7 +84,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
 dependencies = [
  "concurrent-queue",
- "event-listener-strategy 0.5.2",
+ "event-listener-strategy",
  "futures-core",
  "pin-project-lite",
 ]
@@ -110,8 +110,8 @@ checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c"
 dependencies = [
  "async-channel 2.3.1",
  "async-executor",
- "async-io 2.3.2",
- "async-lock 3.3.0",
+ "async-io 2.3.3",
+ "async-lock 3.4.0",
  "blocking",
  "futures-lite 2.3.0",
  "once_cell",
@@ -139,17 +139,17 @@ dependencies = [
 
 [[package]]
 name = "async-io"
-version = "2.3.2"
+version = "2.3.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcccb0f599cfa2f8ace422d3555572f47424da5648a4382a9dd0310ff8210884"
+checksum = "0d6baa8f0178795da0e71bc42c9e5d13261aac7ee549853162e66a241ba17964"
 dependencies = [
- "async-lock 3.3.0",
+ "async-lock 3.4.0",
  "cfg-if",
  "concurrent-queue",
  "futures-io",
  "futures-lite 2.3.0",
  "parking",
- "polling 3.7.0",
+ "polling 3.7.1",
  "rustix 0.38.34",
  "slab",
  "tracing",
@@ -167,12 +167,12 @@ dependencies = [
 
 [[package]]
 name = "async-lock"
-version = "3.3.0"
+version = "3.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b"
+checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18"
 dependencies = [
- "event-listener 4.0.3",
- "event-listener-strategy 0.4.0",
+ "event-listener 5.3.1",
+ "event-listener-strategy",
  "pin-project-lite",
 ]
 
@@ -273,12 +273,11 @@ dependencies = [
 
 [[package]]
 name = "blocking"
-version = "1.6.0"
+version = "1.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "495f7104e962b7356f0aeb34247aca1fe7d2e783b346582db7f2904cb5717e88"
+checksum = "703f41c54fc768e63e091340b424302bb1c29ef4aa0c7f10fe849dfb114d29ea"
 dependencies = [
  "async-channel 2.3.1",
- "async-lock 3.3.0",
  "async-task",
  "futures-io",
  "futures-lite 2.3.0",
@@ -311,9 +310,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
 [[package]]
 name = "cfgrammar"
-version = "0.13.5"
+version = "0.13.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "163348850b1cd34fa99ef1592b5d598ea7e6752f18aff2125b67537e887edb36"
+checksum = "ec07af28018dd8b4b52e49eb6e57268b19dda0996d4824889eb07ee0ef67378c"
 dependencies = [
  "indexmap",
  "lazy_static",
@@ -435,43 +434,22 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
 
 [[package]]
 name = "event-listener"
-version = "4.0.3"
+version = "5.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e"
-dependencies = [
- "concurrent-queue",
- "parking",
- "pin-project-lite",
-]
-
-[[package]]
-name = "event-listener"
-version = "5.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24"
+checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
 dependencies = [
  "concurrent-queue",
  "parking",
  "pin-project-lite",
 ]
 
-[[package]]
-name = "event-listener-strategy"
-version = "0.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3"
-dependencies = [
- "event-listener 4.0.3",
- "pin-project-lite",
-]
-
 [[package]]
 name = "event-listener-strategy"
 version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
 dependencies = [
- "event-listener 5.3.0",
+ "event-listener 5.3.1",
  "pin-project-lite",
 ]
 
@@ -848,9 +826,9 @@ dependencies = [
 
 [[package]]
 name = "lrlex"
-version = "0.13.5"
+version = "0.13.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "77ff18e1bd3ed77d7bc2800a0f8b0e922a3c7ba525505be8bab9cf45dfc4984b"
+checksum = "c65e01ebaccc77218ed6fa4f0053daa2124bce4e25a5e83aae0f7ccfc9cbfccb"
 dependencies = [
  "cfgrammar",
  "getopts",
@@ -866,9 +844,9 @@ dependencies = [
 
 [[package]]
 name = "lrpar"
-version = "0.13.5"
+version = "0.13.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "efea5a41b9988b5ae41ea9b2375a52cfa0e483f0210357209caa8d361a24a368"
+checksum = "2a4b858180a332aec09d10479a070802b13081077eb94010744bc4e3a11d9768"
 dependencies = [
  "bincode",
  "cactus",
@@ -888,9 +866,9 @@ dependencies = [
 
 [[package]]
 name = "lrtable"
-version = "0.13.5"
+version = "0.13.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff5668c3bfd279ed24d5b0d24568c48dc993f9beabd51f74d1865a78c1d206ab"
+checksum = "4fcefc5628209d1b1f4b2cd0bcefd0e50be80bdf178e886cb07317f5ce4f2856"
 dependencies = [
  "cfgrammar",
  "fnv",
@@ -1063,9 +1041,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
 
 [[package]]
 name = "piper"
-version = "0.2.2"
+version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "464db0c665917b13ebb5d453ccdec4add5658ee1adc7affc7677615356a8afaf"
+checksum = "ae1d5c74c9876f070d3e8fd503d748c7d974c3e48da8f41350fa5222ef9b4391"
 dependencies = [
  "atomic-waker",
  "fastrand 2.1.0",
@@ -1090,9 +1068,9 @@ dependencies = [
 
 [[package]]
 name = "polling"
-version = "3.7.0"
+version = "3.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "645493cf344456ef24219d02a768cf1fb92ddf8c92161679ae3d91b91a637be3"
+checksum = "5e6a007746f34ed64099e88783b0ae369eaa3da6392868ba262e2af9b8fbaea1"
 dependencies = [
  "cfg-if",
  "concurrent-queue",
@@ -1129,9 +1107,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.84"
+version = "1.0.85"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6"
+checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
 dependencies = [
  "unicode-ident",
 ]
diff --git a/hercules_cg/src/common.rs b/hercules_cg/src/common.rs
index e3e1a3c4..45061aad 100644
--- a/hercules_cg/src/common.rs
+++ b/hercules_cg/src/common.rs
@@ -29,6 +29,7 @@ pub(crate) struct FunctionContext<'a> {
     pub(crate) llvm_types: &'a Vec<String>,
     pub(crate) llvm_constants: &'a Vec<String>,
     pub(crate) llvm_dynamic_constants: &'a Vec<String>,
+    pub(crate) type_sizes_aligns: &'a Vec<(Option<usize>, usize)>,
     pub(crate) partitions_inverted_map: Vec<Vec<NodeID>>,
 }
 
diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 134e56a8..b806bd50 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -195,8 +195,25 @@ impl<'a> FunctionContext<'a> {
         }
 
         // Step 5: emit the now completed basic blocks, in order. Emit a dummy
-        // header block to unconditionally jump to the "top" basic block.
-        write!(w, "bb_header:\n  br label %bb_{}\n", top_node.idx())?;
+        // header block to unconditionally jump to the "top" basic block. Also
+        // emit allocas for compile-time known sized constants. TODO: only emit
+        // used constants, not all the constants in the module. TODO: emit sum
+        // constants.
+        write!(w, "bb_header:\n")?;
+        for cons_id in (0..self.constants.len()).map(ConstantID::new) {
+            if let Some(ty_id) = self.constants[cons_id.idx()].try_product_type(&self.types) {
+                if let (Some(size), align) = self.type_sizes_aligns[ty_id.idx()] {
+                    write!(
+                        w,
+                        "  %cons.{} = alloca i8, i32 {}, align {}\n",
+                        cons_id.idx(),
+                        size,
+                        align
+                    )?;
+                }
+            }
+        }
+        write!(w, "  br label %bb_{}\n", top_node.idx())?;
         for id in partition_context.reverse_postorder {
             if self.bbs[id.idx()] == id {
                 write!(
@@ -226,6 +243,9 @@ impl<'a> PartitionContext<'a> {
         // pointers to some memory at the LLVM IR level. This memory is passed
         // in as a parameter for anything involving arrays, and is alloca-ed for
         // product and summation types.
+        // TODO: actually do this ^ for products. Right now, products are still
+        // done at the LLVM struct level w/ GEP and so on. Apologies for anyone
+        // else reading this comment.
         let mut generate_index_code = |collect: NodeID, indices: &[Index]| -> std::fmt::Result {
             // Step 1: calculate the list of collection types corresponding to
             // each index.
@@ -277,7 +297,7 @@ impl<'a> PartitionContext<'a> {
                             &self.function.llvm_types[collection_ty_ids[idx].idx()];
                         write!(
                             bb.data,
-                            "  %index{}.{}.stride.ptrhack = getelementptr {}, ptr null, i64 1\n  %index{}.{}.stride = ptrtoint ptr %index{}.{}.stride.ptrhack to i64\n  %index{}.{}.offset.ptrhack = getelementptr {}, ptr null, i64 0, i64 {}\n  %index{}.{}.offset = ptrtoint ptr %index{}.{}.offset.ptrhack to i64\n",
+                            "  %index{}.{}.stride.ptrhack = getelementptr {}, ptr null, i64 1\n  %index{}.{}.stride = ptrtoint ptr %index{}.{}.stride.ptrhack to i64\n  %index{}.{}.offset.ptrhack = getelementptr {}, ptr null, i64 0, i32 {}\n  %index{}.{}.offset = ptrtoint ptr %index{}.{}.offset.ptrhack to i64\n",
                             id.idx(), idx,
                             product_llvm_ty,
                             id.idx(), idx,
@@ -344,8 +364,8 @@ impl<'a> PartitionContext<'a> {
             }
 
             // Step 3: emit the getelementptr using the total collection offset.
-            write!(bb.data, "  %index{} = getelementptr i8, ", id.idx(),)?;
-            self.cpu_emit_use_of_node(collect, Some(id), true, &mut bb.data)?;
+            write!(bb.data, "  %index{} = getelementptr i8, ptr ", id.idx(),)?;
+            self.cpu_emit_use_of_node(collect, Some(id), false, &mut bb.data)?;
             write!(bb.data, ", i64 %index{}.0.total_offset\n", id.idx())?;
 
             Ok(())
diff --git a/hercules_cg/src/top.rs b/hercules_cg/src/top.rs
index 4e029ccf..2da69355 100644
--- a/hercules_cg/src/top.rs
+++ b/hercules_cg/src/top.rs
@@ -29,6 +29,15 @@ pub fn codegen<W: Write>(
     let llvm_types = generate_type_strings(module);
     let llvm_constants = generate_constant_strings(module);
     let llvm_dynamic_constants = generate_dynamic_constant_strings(module);
+    let type_sizes_aligns = (0..module.types.len())
+        .map(|idx| {
+            if module.types[idx].is_control() {
+                (None, 0)
+            } else {
+                type_size_and_alignment(module, TypeID::new(idx))
+            }
+        })
+        .collect();
 
     // Generate a dummy uninitialized global - this is needed so that there'll
     // be a non-empty .bss section in the ELF object file.
@@ -55,6 +64,7 @@ pub fn codegen<W: Write>(
             llvm_types: &llvm_types,
             llvm_constants: &llvm_constants,
             llvm_dynamic_constants: &llvm_dynamic_constants,
+            type_sizes_aligns: &type_sizes_aligns,
             partitions_inverted_map: plans[function_idx].invert_partition_map(),
         };
 
@@ -65,15 +75,7 @@ pub fn codegen<W: Write>(
     Ok(ModuleManifest {
         functions: manifests,
         types: module.types.clone(),
-        type_sizes_aligns: (0..module.types.len())
-            .map(|idx| {
-                if module.types[idx].is_control() {
-                    (None, 0)
-                } else {
-                    type_size_and_alignment(module, TypeID::new(idx))
-                }
-            })
-            .collect(),
+        type_sizes_aligns,
         dynamic_constants: module.dynamic_constants.clone(),
         // Get the types of all of the constants. This requires collecting over
         // all of the functions, since the calculated types of constants may be
@@ -185,6 +187,12 @@ impl<'a> FunctionContext<'a> {
             param_types: self.function.param_types.clone(),
             return_type: self.function.return_type,
             typing: self.typing.clone(),
+            used_constants: self
+                .function
+                .nodes
+                .iter()
+                .filter_map(|node| node.try_constant())
+                .collect(),
             num_dynamic_constant_parameters: self.function.num_dynamic_constants,
             partitions: manifests,
             // TODO: populate dynamic constant rules.
diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs
index 67fa231e..622f3a3e 100644
--- a/hercules_ir/src/dom.rs
+++ b/hercules_ir/src/dom.rs
@@ -244,7 +244,7 @@ fn preorder(subgraph: &Subgraph, root: NodeID) -> (Vec<NodeID>, HashMap<NodeID,
     let parents = HashMap::new();
 
     // Order and parents are threaded through arguments / return pair of
-    // reverse_postorder_helper for ownership reasons.
+    // preorder_helper for ownership reasons.
     preorder_helper(root, None, subgraph, order, parents)
 }
 
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 6f44bff7..e25f4b7c 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -659,6 +659,14 @@ impl Type {
         self.is_bool() || self.is_fixed() || self.is_float()
     }
 
+    pub fn is_product(&self) -> bool {
+        if let Type::Product(_) = self {
+            true
+        } else {
+            false
+        }
+    }
+
     pub fn is_array(&self) -> bool {
         if let Type::Array(_, _) = self {
             true
@@ -700,16 +708,66 @@ impl Constant {
         }
     }
 
+    // A zero constant may need to return constants that don't exist yet, so we
+    // need mutable access to the constants array.
+    pub fn try_product_fields(
+        &self,
+        types: &[Type],
+        constants: &mut Vec<Constant>,
+    ) -> Option<Vec<ConstantID>> {
+        match self {
+            Constant::Product(_, fields) => Some(fields.iter().map(|x| *x).collect()),
+            Constant::Zero(ty) => match types[ty.idx()] {
+                Type::Product(ref fields) => Some(
+                    fields
+                        .iter()
+                        .map(|field_ty| {
+                            let field_constant = Constant::Zero(*field_ty);
+                            if let Some(idx) = constants
+                                .iter()
+                                .position(|constant| *constant == field_constant)
+                            {
+                                ConstantID::new(idx)
+                            } else {
+                                let id = ConstantID::new(constants.len());
+                                constants.push(field_constant);
+                                id
+                            }
+                        })
+                        .collect(),
+                ),
+                _ => None,
+            },
+            _ => None,
+        }
+    }
+
     pub fn try_array_type(&self, types: &[Type]) -> Option<TypeID> {
-        // Need types, since zero initializer may be for a collection type, ro
+        // Need types, since zero initializer may be for a collection type, or
         // not.
         match self {
             Constant::Array(ty, _) => Some(*ty),
             Constant::Zero(ty) => {
-                if types[ty.idx()].is_primitive() {
-                    None
+                if types[ty.idx()].is_array() {
+                    Some(*ty)
                 } else {
+                    None
+                }
+            }
+            _ => None,
+        }
+    }
+
+    pub fn try_product_type(&self, types: &[Type]) -> Option<TypeID> {
+        // Need types, since zero initializer may be for a collection type, or
+        // not.
+        match self {
+            Constant::Product(ty, _) => Some(*ty),
+            Constant::Zero(ty) => {
+                if types[ty.idx()].is_product() {
                     Some(*ty)
+                } else {
+                    None
                 }
             }
             _ => None,
diff --git a/hercules_ir/src/manifest.rs b/hercules_ir/src/manifest.rs
index 8b2914c8..f2bbd03a 100644
--- a/hercules_ir/src/manifest.rs
+++ b/hercules_ir/src/manifest.rs
@@ -38,7 +38,8 @@ pub struct ModuleManifest {
  * embed_constant calculates the byte representation of a constant. For zero
  * constants, we avoid storing the actual zero bytes, and optionally store the
  * size - zero constant arrays may have dynamic constant dimensions unknown at
- * compile time.
+ * compile time. The second usize in each variant is the alignment the bytes
+ * must be at to be used properly.
  */
 #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub enum ConstantBytes {
@@ -56,6 +57,8 @@ pub struct FunctionManifest {
     // Types of all of the nodes in this function. Used for figuring out the
     // type of partition data inputs and outputs.
     pub typing: Vec<TypeID>,
+    // IDs of constants that are actually used in this function.
+    pub used_constants: Vec<ConstantID>,
     // Number of dynamic constant parameters that need to provided.
     pub num_dynamic_constant_parameters: u32,
     // Manifests for constituent partitions.
diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs
index f8c35276..3a240a25 100644
--- a/hercules_ir/src/schedule.rs
+++ b/hercules_ir/src/schedule.rs
@@ -319,6 +319,7 @@ pub fn partition_out_forks(
     bbs: &Vec<NodeID>,
     plan: &mut Plan,
 ) {
+    #[allow(non_local_definitions)]
     impl Semilattice for NodeID {
         fn meet(a: &Self, b: &Self) -> Self {
             if a.idx() < b.idx() {
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index f336e8d9..5e53d532 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -5,15 +5,17 @@ pub mod dce;
 pub mod fork_guard_elim;
 pub mod forkify;
 pub mod gvn;
-pub mod phi_elim;
 pub mod pass;
+pub mod phi_elim;
 pub mod pred;
+pub mod sroa;
 
 pub use crate::ccp::*;
 pub use crate::dce::*;
+pub use crate::fork_guard_elim::*;
 pub use crate::forkify::*;
 pub use crate::gvn::*;
-pub use crate::phi_elim::*;
-pub use crate::fork_guard_elim::*;
 pub use crate::pass::*;
+pub use crate::phi_elim::*;
 pub use crate::pred::*;
+pub use crate::sroa::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index bb602606..f2de2741 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -29,6 +29,7 @@ pub enum Pass {
     PhiElim,
     ForkGuardElim,
     Predication,
+    SROA,
     Verify,
     // Parameterized over whether analyses that aid visualization are necessary.
     // Useful to set to false if displaying a potentially broken module.
@@ -364,6 +365,31 @@ impl PassManager {
                         )
                     }
                 }
+                Pass::SROA => {
+                    println!("{:?}", self.module.functions[0].nodes);
+                    println!("{:?}", self.module.constants);
+                    for ty_id in (0..self.module.types.len()).map(TypeID::new) {
+                        let mut str_ty = "".to_string();
+                        self.module.write_type(ty_id, &mut str_ty).unwrap();
+                        println!("{}: {}", ty_id.idx(), str_ty);
+                    }
+                    self.make_def_uses();
+                    self.make_reverse_postorders();
+                    self.make_typing();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
+                    let typing = self.typing.as_ref().unwrap();
+                    for idx in 0..self.module.functions.len() {
+                        sroa(
+                            &mut self.module.functions[idx],
+                            &def_uses[idx],
+                            &reverse_postorders[idx],
+                            &typing[idx],
+                            &self.module.types,
+                            &mut self.module.constants,
+                        );
+                    }
+                }
                 Pass::Verify => {
                     let (
                         def_uses,
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
new file mode 100644
index 00000000..ef3eceab
--- /dev/null
+++ b/hercules_opt/src/sroa.rs
@@ -0,0 +1,369 @@
+extern crate bitvec;
+extern crate hercules_ir;
+
+use std::collections::HashMap;
+use std::iter::zip;
+
+use self::bitvec::prelude::*;
+
+use self::hercules_ir::dataflow::*;
+use self::hercules_ir::def_use::*;
+use self::hercules_ir::ir::*;
+
+/*
+ * Top level function to run SROA, intraprocedurally. Product values can be used
+ * and created by a relatively small number of nodes. Here are *all* of them:
+ *
+ * - Phi: can merge SSA values of products - these get broken up into phis on
+ *   the individual fields
+ *
+ * - Reduce: similarly to phis, reduce nodes can cycle product values through
+ *   reduction loops - these get broken up into reduces on the fields
+ *
+ * + Return: can return a product - these are untouched, and are the sinks for
+ *   unbroken product values
+ *
+ * + Parameter: can introduce a product - these are untouched, and are the
+ *   sources for unbroken product values
+ *
+ * - Constant: can introduce a product - these are broken up into constants for
+ *   the individual fields
+ *
+ * - Ternary: the select ternary operator can select between products - these
+ *   are broken up into ternary nodes for the individual fields
+ *
+ * + Call: the call node can use a product value as an argument to another
+ *   function, and can produce a product value as a result - these are
+ *   untouched, and are the sink and source for unbroken product values
+ *
+ * - Read: the read node reads primitive fields from product values - these get
+ *   replaced by a direct use of the field value from the broken product value,
+ *   but are retained when the product value is unbroken
+ *
+ * - Write: the write node writes primitive fields in product values - these get
+ *   replaced by a direct def of the field value from the broken product value,
+ *   but are retained when the product value is unbroken
+ *
+ * The nodes above with the list marker "+" are retained for maintaining API/ABI
+ * compatibility with other Hercules functions and the host code. These are
+ * called "sink" or "source" nodes in comments below.
+ */
+pub fn sroa(
+    function: &mut Function,
+    def_use: &ImmutableDefUseMap,
+    reverse_postorder: &Vec<NodeID>,
+    typing: &Vec<TypeID>,
+    types: &Vec<Type>,
+    constants: &mut Vec<Constant>,
+) {
+    // Determine which sources of product values we want to try breaking up. We
+    // can determine easily on the source side if a node produces a product that
+    // shouldn't be broken up by just examining the node type. However, the way
+    // that products are used is also important for determining if the product
+    // can be broken up. We backward dataflow this info to the sources of
+    // product values.
+    #[derive(PartialEq, Eq, Clone, Debug)]
+    enum ProductUseLattice {
+        // The product value used by this node is eventually used by a sink.
+        UsedBySink,
+        // This node uses multiple product values - the stored node ID indicates
+        // which is eventually used by a sink. This lattice value is produced by
+        // read and write nodes implementing partial indexing.
+        SpecificUsedBySink(NodeID),
+        // This node doesn't use a product node, or the product node it does use
+        // is not in turn used by a sink.
+        UnusedBySink,
+    }
+
+    impl Semilattice for ProductUseLattice {
+        fn meet(a: &Self, b: &Self) -> Self {
+            match (a, b) {
+                (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink,
+                (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => {
+                    if id1 == id2 {
+                        Self::SpecificUsedBySink(*id1)
+                    } else {
+                        Self::UsedBySink
+                    }
+                }
+                (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => {
+                    Self::SpecificUsedBySink(*id)
+                }
+                _ => Self::UnusedBySink,
+            }
+        }
+
+        fn bottom() -> Self {
+            Self::UsedBySink
+        }
+
+        fn top() -> Self {
+            Self::UnusedBySink
+        }
+    }
+
+    // Run dataflow analysis to find which product values are used by a sink.
+    let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| {
+        match function.nodes[id.idx()] {
+            Node::Return {
+                control: _,
+                data: _,
+            } => {
+                if types[typing[id.idx()].idx()].is_product() {
+                    ProductUseLattice::UsedBySink
+                } else {
+                    ProductUseLattice::UnusedBySink
+                }
+            }
+            Node::Call {
+                function: _,
+                dynamic_constants: _,
+                args: _,
+            } => todo!(),
+            // For reads and writes, we only want to propagate the use of the
+            // product to the collect input of the node.
+            Node::Read {
+                collect,
+                indices: _,
+            }
+            | Node::Write {
+                collect,
+                data: _,
+                indices: _,
+            } => {
+                let meet = succ_outs
+                    .iter()
+                    .fold(ProductUseLattice::top(), |acc, latt| {
+                        ProductUseLattice::meet(&acc, latt)
+                    });
+                if meet == ProductUseLattice::UnusedBySink {
+                    ProductUseLattice::UnusedBySink
+                } else {
+                    ProductUseLattice::SpecificUsedBySink(collect)
+                }
+            }
+            // For non-sink nodes.
+            _ => {
+                if function.nodes[id.idx()].is_control() {
+                    return ProductUseLattice::UnusedBySink;
+                }
+                let meet = succ_outs
+                    .iter()
+                    .fold(ProductUseLattice::top(), |acc, latt| {
+                        ProductUseLattice::meet(&acc, latt)
+                    });
+                if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet {
+                    if meet_id == id {
+                        ProductUseLattice::UsedBySink
+                    } else {
+                        ProductUseLattice::UnusedBySink
+                    }
+                } else {
+                    meet
+                }
+            }
+        }
+    });
+
+    // Only product values introduced as constants can be replaced by scalars.
+    let to_sroa: Vec<(NodeID, ConstantID)> = product_uses
+        .into_iter()
+        .enumerate()
+        .filter_map(|(node_idx, product_use)| {
+            if ProductUseLattice::UnusedBySink == product_use
+                && types[typing[node_idx].idx()].is_product()
+            {
+                function.nodes[node_idx]
+                    .try_constant()
+                    .map(|cons_id| (NodeID::new(node_idx), cons_id))
+            } else {
+                None
+            }
+        })
+        .collect();
+    println!("{:?}", to_sroa);
+
+    // Perform SROA. TODO: repair def-use when there are multiple product
+    // constants to SROA away.
+    assert!(to_sroa.len() < 2);
+    for (constant_node_id, constant_id) in to_sroa {
+        // Get the field constants to replace the product constant with.
+        let product_constant = constants[constant_id.idx()].clone();
+        let constant_fields = product_constant
+            .try_product_fields(types, constants)
+            .unwrap();
+        println!("{:?}", constant_fields);
+
+        // DFS to find all data nodes that use the product constant.
+        let to_replace = sroa_dfs(constant_node_id, function, def_use);
+        println!("{:?}", to_replace);
+
+        // Assemble a mapping from old node IDs acting on the product constant
+        // to new node IDs operating on the field constants.
+        let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace
+            .iter()
+            .map(|old_id| match function.nodes[old_id.idx()] {
+                Node::Phi {
+                    control: _,
+                    data: _,
+                }
+                | Node::Reduce {
+                    control: _,
+                    init: _,
+                    reduct: _,
+                }
+                | Node::Constant { id: _ }
+                | Node::Ternary {
+                    op: _,
+                    first: _,
+                    second: _,
+                    third: _,
+                }
+                | Node::Write {
+                    collect: _,
+                    data: _,
+                    indices: _,
+                } => {
+                    let new_ids = (0..constant_fields.len())
+                        .map(|_| {
+                            let id = NodeID::new(function.nodes.len());
+                            function.nodes.push(Node::Start);
+                            id
+                        })
+                        .collect();
+                    (*old_id, new_ids)
+                }
+                Node::Read {
+                    collect: _,
+                    indices: _,
+                } => (*old_id, vec![]),
+                _ => panic!("PANIC: Invalid node using a constant product found during SROA."),
+            })
+            .collect();
+
+        // Replace the old nodes with the new nodes. Since we've already
+        // allocated the node IDs, at this point we can iterate through the to-
+        // replace nodes in an arbitrary order.
+        for (old_id, new_ids) in &old_to_new_id_map {
+            // First, add the new nodes to the node list.
+            let node = function.nodes[old_id.idx()].clone();
+            match node {
+                // Replace the original constant with constants for each field.
+                Node::Constant { id: _ } => {
+                    for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) {
+                        function.nodes[new_id.idx()] = Node::Constant { id: *field_id };
+                    }
+                }
+                // Replace writes using the constant as the data use with a
+                // series of writes writing the individual constant fields. TODO:
+                // handle the case where the constant is the collect use of the
+                // write node.
+                Node::Write {
+                    collect,
+                    data,
+                    ref indices,
+                } => {
+                    // Create the write chain.
+                    assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet.");
+                    let mut collect_def = collect;
+                    for (idx, (new_id, new_data_def)) in
+                        zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate()
+                    {
+                        let mut new_indices = indices.clone().into_vec();
+                        new_indices.push(Index::Field(idx));
+                        function.nodes[new_id.idx()] = Node::Write {
+                            collect: collect_def,
+                            data: *new_data_def,
+                            indices: new_indices.into_boxed_slice(),
+                        };
+                        collect_def = *new_id;
+                    }
+
+                    // Replace uses of the old write with the new write.
+                    for user in def_use.get_users(*old_id) {
+                        get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def);
+                    }
+                }
+                _ => todo!(),
+            }
+
+            // Delete the old node.
+            function.nodes[old_id.idx()] = Node::Start;
+        }
+    }
+}
+
+fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> {
+    // Initialize order vector and bitset for tracking which nodes have been
+    // visited.
+    let order = Vec::with_capacity(def_uses.num_nodes());
+    let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()];
+
+    // Order and visited are threaded through arguments / return pair of
+    // sroa_dfs_helper for ownership reasons.
+    let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited);
+    order
+}
+
+fn sroa_dfs_helper(
+    node: NodeID,
+    def: NodeID,
+    function: &Function,
+    def_uses: &ImmutableDefUseMap,
+    mut order: Vec<NodeID>,
+    mut visited: BitVec<u8, Lsb0>,
+) -> (Vec<NodeID>, BitVec<u8, Lsb0>) {
+    if visited[node.idx()] {
+        // If already visited, return early.
+        (order, visited)
+    } else {
+        // Set visited to true.
+        visited.set(node.idx(), true);
+
+        // Before iterating users, push this node.
+        order.push(node);
+        match function.nodes[node.idx()] {
+            Node::Phi {
+                control: _,
+                data: _,
+            }
+            | Node::Reduce {
+                control: _,
+                init: _,
+                reduct: _,
+            }
+            | Node::Constant { id: _ }
+            | Node::Ternary {
+                op: _,
+                first: _,
+                second: _,
+                third: _,
+            } => {}
+            Node::Read {
+                collect,
+                indices: _,
+            } => {
+                assert_eq!(def, collect);
+                return (order, visited);
+            }
+            Node::Write {
+                collect,
+                data,
+                indices: _,
+            } => {
+                if def == data {
+                    return (order, visited);
+                }
+                assert_eq!(def, collect);
+            }
+            _ => panic!("PANIC: Invalid node using a constant product found during SROA."),
+        }
+
+        // Iterate over users, if we shouldn't stop here.
+        for user in def_uses.get_users(node) {
+            (order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited);
+        }
+
+        (order, visited)
+    }
+}
diff --git a/hercules_rt/src/elf.rs b/hercules_rt/src/elf.rs
index 9fc9dc3b..4f2a196c 100644
--- a/hercules_rt/src/elf.rs
+++ b/hercules_rt/src/elf.rs
@@ -12,7 +12,6 @@ use self::libc::*;
  * The libc crate doesn't have everything from elf.h, so these things need to be
  * manually defined.
  */
-
 #[repr(C)]
 #[derive(Debug)]
 struct Elf64_Rela {
diff --git a/hercules_rt_proc/src/lib.rs b/hercules_rt_proc/src/lib.rs
index 1e600cb7..4f4723db 100644
--- a/hercules_rt_proc/src/lib.rs
+++ b/hercules_rt_proc/src/lib.rs
@@ -5,6 +5,7 @@ extern crate hercules_ir;
 extern crate postcard;
 extern crate proc_macro;
 
+use std::collections::HashSet;
 use std::ffi::OsStr;
 use std::fmt::Write;
 use std::fs::File;
@@ -40,13 +41,9 @@ fn generate_type_string(ty: &Type, rust_types: &Vec<String>) -> String {
         Type::UnsignedInteger64 => "u64".to_string(),
         Type::Float32 => "f32".to_string(),
         Type::Float64 => "f64".to_string(),
-        Type::Product(fields) => {
-            fields
-                .iter()
-                .map(|field_id| &rust_types[field_id.idx()] as &str)
-                .fold("(".to_string(), |acc, field| acc + field + ",")
-                + ")"
-        }
+        Type::Product(fields) => fields.iter().fold("Prod".to_string(), |acc, field| {
+            format!("{}_{}", acc, field.idx())
+        }),
         Type::Summation(_) => todo!(),
         Type::Array(elem, _) => format!("*mut {}", &rust_types[elem.idx()]),
     }
@@ -135,6 +132,34 @@ fn codegen(manifest: &ModuleManifest, elf: &[u8]) -> Result<String, anyhow::Erro
             }
         )?;
 
+        // In order to get repr(C), we need to define named structs to attach
+        // that attribute to. This unfortunately means we can't use anonymous
+        // products, which would be much nicer than this nonsense.
+        let mut already_seen = HashSet::new();
+        let mut emit_product_type_def =
+            |ty: &Type, rust_code: &mut String| -> Result<(), anyhow::Error> {
+                match ty {
+                    Type::Product(ref fields) => {
+                        if !already_seen.contains(ty) {
+                            write!(
+                                rust_code,
+                                "#[repr(C)] struct {}({});",
+                                generate_type_string(&ty, &rust_types),
+                                fields.iter().fold("".to_string(), |acc, field| {
+                                    acc + &rust_types[field.idx()] + ","
+                                })
+                            )?;
+                            already_seen.insert(ty.clone());
+                        }
+                    }
+                    _ => {}
+                }
+                Ok(())
+            };
+        for id in types_bottom_up(&manifest.types) {
+            emit_product_type_def(&manifest.types[id.idx()], &mut rust_code)?;
+        }
+
         // Load the ELF object, and cast the appropriate pointers.
         write!(
             rust_code,
@@ -172,13 +197,17 @@ fn codegen(manifest: &ModuleManifest, elf: &[u8]) -> Result<String, anyhow::Erro
                 partition
                     .outputs
                     .iter()
-                    .map(|input| match input {
+                    .map(|output| match output {
                         PartitionOutput::DataOutput(node_id) => function.typing[node_id.idx()],
                         PartitionOutput::ControlIndicator => u64_ty_id,
                     })
                     .collect(),
             );
 
+            // The product type outputs of partitions might not already exist in
+            // the list of types in the module.
+            emit_product_type_def(&output_type, &mut rust_code)?;
+
             // Get the pointer for the partition function, and cast it to the
             // correct function pointer type.
             write!(
@@ -215,9 +244,13 @@ fn codegen(manifest: &ModuleManifest, elf: &[u8]) -> Result<String, anyhow::Erro
         // Declare all of the array constant memories. We declare them as Vecs
         // to allocate the memories. We emit multiplications of the dynamic
         // constant dimensions to allocate the whole memory as one contiguous
-        // range. TODO: emit only the array constants actually used in this
-        // function.
-        for (arr_cons_num, arr_cons_id) in manifest.array_cons_ids.iter().enumerate() {
+        // range.
+        for (arr_cons_num, arr_cons_id) in manifest
+            .array_cons_ids
+            .iter()
+            .filter(|id| function.used_constants.contains(id))
+            .enumerate()
+        {
             let arr_ty_id = manifest
                 .constant_types
                 .iter()
@@ -225,15 +258,31 @@ fn codegen(manifest: &ModuleManifest, elf: &[u8]) -> Result<String, anyhow::Erro
                 .next()
                 .expect("PANIC: Couldn't find type of array constant in manifest.")
                 .1;
-            write!(
-                rust_code,
-                "let mut arr_cons_{}_vec = vec![0u8; {} as usize];let mut arr_cons_{} = arr_cons_{}_vec.as_mut_ptr();arr_cons_{}_vec.leak();",
-                arr_cons_num,
-                emit_type_size_expression(arr_ty_id, manifest),
-                arr_cons_num,
-                arr_cons_num,
-                arr_cons_num,
-            )?;
+            if let ConstantBytes::NonZero(ref bytes, _) = manifest.array_constants[arr_cons_num] {
+                // Initialize the vector from the non-zero constant bytes of the
+                // array.
+                write!(
+                    rust_code,
+                    "let mut arr_cons_{}_vec = Vec::from({});let mut arr_cons_{} = arr_cons_{}_vec.as_mut_ptr();arr_cons_{}_vec.leak();",
+                    arr_cons_num,
+                    Literal::byte_string(bytes),
+                    arr_cons_num,
+                    arr_cons_num,
+                    arr_cons_num,
+                )?;
+            } else {
+                // The array is all zeros, so create a vector of zeros with a
+                // size that may be known only at runtime.
+                write!(
+                    rust_code,
+                    "let mut arr_cons_{}_vec = vec![0u8; {} as usize];let mut arr_cons_{} = arr_cons_{}_vec.as_mut_ptr();arr_cons_{}_vec.leak();",
+                    arr_cons_num,
+                    emit_type_size_expression(arr_ty_id, manifest),
+                    arr_cons_num,
+                    arr_cons_num,
+                    arr_cons_num,
+                )?;
+            }
         }
 
         // The core executor is a Rust loop. We literally run a "control token"
diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs
index e5903ffc..72855c0d 100644
--- a/juno_frontend/src/main.rs
+++ b/juno_frontend/src/main.rs
@@ -13,41 +13,85 @@ mod types;
 
 use codegen::*;
 
+use std::path::PathBuf;
+
 extern crate hercules_ir;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
-    src_file : String,
+    src_file: String,
+    #[clap(short, long)]
+    verify: bool,
+    #[clap(long = "verify-all")]
+    verify_all: bool,
+    #[arg(short, long = "x-dot")]
+    x_dot: bool,
+    #[arg(short, long, value_name = "OUTPUT")]
+    output: Option<String>,
+}
+
+macro_rules! add_verified_pass {
+    ($pm:ident, $args:ident, $pass:ident) => {
+        $pm.add_pass(hercules_opt::pass::Pass::$pass);
+        if $args.verify || $args.verify_all {
+            $pm.add_pass(hercules_opt::pass::Pass::Verify);
+        }
+    };
+}
+macro_rules! add_pass {
+    ($pm:ident, $args:ident, $pass:ident) => {
+        $pm.add_pass(hercules_opt::pass::Pass::$pass);
+        if $args.verify_all {
+            $pm.add_pass(hercules_opt::pass::Pass::Verify);
+        }
+    };
 }
 
 fn main() {
     let args = Cli::parse();
+    let src_file = args.src_file.clone();
     let prg = semant::parse_and_analyze(args.src_file);
     match prg {
         Ok(prg) => {
             let module = codegen_program(prg);
 
             let mut pm = hercules_opt::pass::PassManager::new(module);
-            pm.add_pass(hercules_opt::pass::Pass::Verify);
-            pm.add_pass(hercules_opt::pass::Pass::PhiElim);
-            pm.add_pass(hercules_opt::pass::Pass::Verify);
-            pm.add_pass(hercules_opt::pass::Pass::CCP);
-            pm.add_pass(hercules_opt::pass::Pass::DCE);
-            pm.add_pass(hercules_opt::pass::Pass::GVN);
-            pm.add_pass(hercules_opt::pass::Pass::DCE);
-            pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
-            pm.add_pass(hercules_opt::pass::Pass::Forkify);
-            pm.add_pass(hercules_opt::pass::Pass::ForkGuardElim);
-            pm.add_pass(hercules_opt::pass::Pass::DCE);
-            pm.add_pass(hercules_opt::pass::Pass::Verify);
-            pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
+            if args.verify || args.verify_all {
+                pm.add_pass(hercules_opt::pass::Pass::Verify);
+            }
+            add_verified_pass!(pm, args, PhiElim);
+            add_pass!(pm, args, CCP);
+            add_pass!(pm, args, DCE);
+            add_pass!(pm, args, GVN);
+            add_pass!(pm, args, DCE);
+            add_pass!(pm, args, SROA);
+            if args.x_dot {
+                pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
+            }
+            add_pass!(pm, args, Forkify);
+            add_pass!(pm, args, ForkGuardElim);
+            add_verified_pass!(pm, args, DCE);
+            if args.x_dot {
+                pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
+            }
+            match args.output {
+                Some(file) => pm.add_pass(hercules_opt::pass::Pass::Codegen(file)),
+                None => {
+                    let mut path = PathBuf::from(src_file);
+                    path.set_extension("hbin");
+                    println!("{:?}", path);
+                    pm.add_pass(hercules_opt::pass::Pass::Codegen(
+                        path.to_str().unwrap().to_string(),
+                    ));
+                }
+            }
             let _ = pm.run_passes();
-        },
+        }
         Err(errs) => {
-            for err in errs{
+            for err in errs {
                 eprintln!("{}", err);
             }
-        },
+        }
     }
 }
-- 
GitLab