From a55ce6342e805467a3366b9fcfe16b163a0e636b Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Sat, 23 Nov 2024 00:49:41 -0600
Subject: [PATCH] Fix schedule gen to panic only if non-root array
 types/constants are used

---
 Cargo.lock                         |  10 ++
 Cargo.toml                         |   2 +-
 hercules_cg/src/sched_gen.rs       | 234 +++++++++++++++--------------
 juno_samples/matmul/src/main.rs    |  13 +-
 juno_samples/matmul/src/matmul.sch |   2 +-
 5 files changed, 141 insertions(+), 120 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 440cdd3c..37216ee3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -729,6 +729,16 @@ dependencies = [
  "phf",
 ]
 
+[[package]]
+name = "juno_matmul"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_scheduler"
 version = "0.0.1"
diff --git a/Cargo.toml b/Cargo.toml
index 8fa3079b..c6b2468f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,6 @@ members = [
 	"juno_scheduler",
 	"juno_build",
 
-	#"juno_samples/matmul",
+	"juno_samples/matmul",
 	"juno_samples/simple3",
 ]
diff --git a/hercules_cg/src/sched_gen.rs b/hercules_cg/src/sched_gen.rs
index f065340c..e5378b9c 100644
--- a/hercules_cg/src/sched_gen.rs
+++ b/hercules_cg/src/sched_gen.rs
@@ -24,9 +24,7 @@ pub fn sched_compile(
     bbs: &Vec<Vec<NodeID>>,
     plans: &Vec<Plan>,
 ) -> SModule {
-    verify_types_well_formed_for_sched_ir(&module.types);
     let stypes = convert_to_sched_ir_types(&module.types);
-    verify_constants_well_formed_for_sched_ir(&module.constants);
     let sconstants = convert_to_sched_ir_constants(&module.constants);
     let function_names: HashMap<FunctionID, String> = module
         .functions
@@ -67,111 +65,99 @@ pub fn sched_compile(
     }
 }
 
-/*
- * Checks for the following conditions:
- * 1. Array types are only ever root types.
- * 2. Summation types do not exist - TODO: properly support summation types.
- */
-fn verify_types_well_formed_for_sched_ir(types: &Vec<Type>) {
-    for ty in types.iter() {
-        match ty {
-            Type::Product(fields) => {
-                let any_array = fields
-                    .iter()
-                    .any(|field_ty| types[field_ty.idx()].is_array());
-                assert!(!any_array, "PANIC: Found non-root array type.");
-            }
-            Type::Summation(_) => panic!("PANIC: Can't lower summations to schedule IR yet."),
-            Type::Array(elem, _) => assert!(
-                !types[elem.idx()].is_array(),
-                "PANIC: Found non-root array type."
-            ),
-            _ => {}
-        }
-    }
-}
-
-fn convert_to_sched_ir_types(types: &Vec<Type>) -> Vec<SType> {
-    let mut stypes = vec![SType::Boolean; types.len()];
+fn convert_to_sched_ir_types(types: &Vec<Type>) -> Vec<Option<SType>> {
+    let mut stypes = vec![None; types.len()];
 
     for id in types_bottom_up(types) {
         stypes[id.idx()] = match &types[id.idx()] {
-            Type::Control => SType::Boolean,
-            Type::Boolean => SType::Boolean,
-            Type::Integer8 => SType::Integer8,
-            Type::Integer16 => SType::Integer16,
-            Type::Integer32 => SType::Integer32,
-            Type::Integer64 => SType::Integer64,
-            Type::UnsignedInteger8 => SType::UnsignedInteger8,
-            Type::UnsignedInteger16 => SType::UnsignedInteger16,
-            Type::UnsignedInteger32 => SType::UnsignedInteger32,
-            Type::UnsignedInteger64 => SType::UnsignedInteger64,
-            Type::Float32 => SType::Float32,
-            Type::Float64 => SType::Float64,
+            Type::Control => None,
+            Type::Boolean => Some(SType::Boolean),
+            Type::Integer8 => Some(SType::Integer8),
+            Type::Integer16 => Some(SType::Integer16),
+            Type::Integer32 => Some(SType::Integer32),
+            Type::Integer64 => Some(SType::Integer64),
+            Type::UnsignedInteger8 => Some(SType::UnsignedInteger8),
+            Type::UnsignedInteger16 => Some(SType::UnsignedInteger16),
+            Type::UnsignedInteger32 => Some(SType::UnsignedInteger32),
+            Type::UnsignedInteger64 => Some(SType::UnsignedInteger64),
+            Type::Float32 => Some(SType::Float32),
+            Type::Float64 => Some(SType::Float64),
             Type::Product(fields) => {
-                SType::Product(fields.iter().map(|id| stypes[id.idx()].clone()).collect())
+                let mut typs = vec![];
+                let mut res_none = false;
+                for id in fields {
+                    if types[id.idx()].is_array() {
+                        res_none = true;
+                        break;
+                    } else {
+                        match &stypes[id.idx()] {
+                            None => {
+                                res_none = true;
+                                break;
+                            }
+                            Some(t) => typs.push(t.clone()),
+                        }
+                    }
+                }
+                if res_none {
+                    None
+                } else {
+                    Some(SType::Product(typs.into()))
+                }
             }
             Type::Summation(_) => todo!(),
-            Type::Array(elem_ty, _) => SType::ArrayRef(Box::new(stypes[elem_ty.idx()].clone())),
+            Type::Array(elem_ty, _) => match &stypes[elem_ty.idx()] {
+                None => None,
+                Some(t) => Some(SType::ArrayRef(Box::new(t.clone()))),
+            },
         };
     }
 
     stypes
 }
 
-/*
- * Checks for the following conditions:
- * 1. Array constants are only ever root constants.
- * 2. Summation constants do not exist - TODO: properly support summation
- *    constants.
- */
-fn verify_constants_well_formed_for_sched_ir(constants: &Vec<Constant>) {
-    for ty in constants.iter() {
-        match ty {
-            Constant::Product(_, fields) => {
-                let any_array = fields
-                    .iter()
-                    .any(|field_ty| constants[field_ty.idx()].is_array());
-                assert!(!any_array, "PANIC: Found non-root array constant.");
-            }
-            Constant::Summation(_, _, _) => {
-                panic!("PANIC: Can't lower summations to schedule IR yet.")
-            }
-            // We don't need to check array constants explicitly, since they
-            // explicitly store their type, and nothing else - if an array
-            // constant were invalid by condition #1, then its corresponding
-            // type would be invalid, which would get caught by
-            // `verify_types_well_formed_for_sched_ir`.
-            _ => {}
-        }
-    }
-}
-
-fn convert_to_sched_ir_constants(constants: &Vec<Constant>) -> Vec<SConstant> {
-    let mut sconstants = vec![SConstant::Boolean(false); constants.len()];
+fn convert_to_sched_ir_constants(constants: &Vec<Constant>) -> Vec<Option<SConstant>> {
+    let mut sconstants = vec![None; constants.len()];
 
     for id in constants_bottom_up(constants) {
         sconstants[id.idx()] = match &constants[id.idx()] {
-            Constant::Boolean(val) => SConstant::Boolean(*val),
-            Constant::Integer8(val) => SConstant::Integer8(*val),
-            Constant::Integer16(val) => SConstant::Integer16(*val),
-            Constant::Integer32(val) => SConstant::Integer32(*val),
-            Constant::Integer64(val) => SConstant::Integer64(*val),
-            Constant::UnsignedInteger8(val) => SConstant::UnsignedInteger8(*val),
-            Constant::UnsignedInteger16(val) => SConstant::UnsignedInteger16(*val),
-            Constant::UnsignedInteger32(val) => SConstant::UnsignedInteger32(*val),
-            Constant::UnsignedInteger64(val) => SConstant::UnsignedInteger64(*val),
-            Constant::Float32(val) => SConstant::Float32(*val),
-            Constant::Float64(val) => SConstant::Float64(*val),
-            Constant::Product(_, fields) => SConstant::Product(
-                fields
-                    .iter()
-                    .map(|id| sconstants[id.idx()].clone())
-                    .collect(),
-            ),
+            Constant::Boolean(val) => Some(SConstant::Boolean(*val)),
+            Constant::Integer8(val) => Some(SConstant::Integer8(*val)),
+            Constant::Integer16(val) => Some(SConstant::Integer16(*val)),
+            Constant::Integer32(val) => Some(SConstant::Integer32(*val)),
+            Constant::Integer64(val) => Some(SConstant::Integer64(*val)),
+            Constant::UnsignedInteger8(val) => Some(SConstant::UnsignedInteger8(*val)),
+            Constant::UnsignedInteger16(val) => Some(SConstant::UnsignedInteger16(*val)),
+            Constant::UnsignedInteger32(val) => Some(SConstant::UnsignedInteger32(*val)),
+            Constant::UnsignedInteger64(val) => Some(SConstant::UnsignedInteger64(*val)),
+            Constant::Float32(val) => Some(SConstant::Float32(*val)),
+            Constant::Float64(val) => Some(SConstant::Float64(*val)),
+            Constant::Product(_, fields) => {
+                let mut consts = vec![];
+                let mut res_none = false;
+                for id in fields {
+                    if constants[id.idx()].is_array() {
+                        res_none = true;
+                        break;
+                    } else {
+                        match &sconstants[id.idx()] {
+                            None => {
+                                res_none = true;
+                                break;
+                            }
+                            Some(c) => consts.push(c.clone()),
+                        }
+                    }
+                }
+                if res_none {
+                    None
+                } else {
+                    Some(SConstant::Product(consts.into()))
+                }
+            }
             Constant::Summation(_, _, _) => todo!(),
             // Array constants are never generated inline schedule IR.
-            Constant::Array(_) => SConstant::Boolean(false),
+            Constant::Array(_) => None,
         };
     }
 
@@ -195,8 +181,8 @@ struct FunctionContext<'a> {
     antideps: &'a Vec<(NodeID, NodeID)>,
     bbs: &'a Vec<NodeID>,
     plan: &'a Plan,
-    stypes: &'a Vec<SType>,
-    sconstants: &'a Vec<SConstant>,
+    stypes: &'a Vec<Option<SType>>,
+    sconstants: &'a Vec<Option<SConstant>>,
     function_names: &'a HashMap<FunctionID, String>,
 
     top_nodes: Vec<NodeID>,
@@ -222,8 +208,8 @@ impl<'a> FunctionContext<'a> {
         antideps: &'a Vec<(NodeID, NodeID)>,
         bbs: &'a Vec<NodeID>,
         plan: &'a Plan,
-        stypes: &'a Vec<SType>,
-        sconstants: &'a Vec<SConstant>,
+        stypes: &'a Vec<Option<SType>>,
+        sconstants: &'a Vec<Option<SConstant>>,
         function_names: &'a HashMap<FunctionID, String>,
     ) -> Self {
         let inverted_partition_map = plan.invert_partition_map();
@@ -348,7 +334,7 @@ impl<'a> FunctionContext<'a> {
                     parameters.extend(self.function.param_types.iter().enumerate().map(
                         |(param_idx, ty_id)| {
                             (
-                                self.stypes[ty_id.idx()].clone(),
+                                self.stypes[ty_id.idx()].clone().unwrap(),
                                 ParameterKind::HerculesParameter(param_idx),
                             )
                         },
@@ -356,7 +342,9 @@ impl<'a> FunctionContext<'a> {
                 } else {
                     parameters.extend(self.data_inputs[partition_idx].iter().map(|node_id| {
                         (
-                            self.stypes[self.typing[node_id.idx()].idx()].clone(),
+                            self.stypes[self.typing[node_id.idx()].idx()]
+                                .clone()
+                                .unwrap(),
                             ParameterKind::DataInput(*node_id),
                         )
                     }))
@@ -378,7 +366,9 @@ impl<'a> FunctionContext<'a> {
                                 .filter_map(|use_id| {
                                     if let Some(array_id) = array_node_to_array_id.get(use_id) {
                                         Some((
-                                            self.stypes[self.typing[use_id.idx()].idx()].clone(),
+                                            self.stypes[self.typing[use_id.idx()].idx()]
+                                                .clone()
+                                                .unwrap(),
                                             ParameterKind::ArrayConstant(*array_id),
                                         ))
                                     } else {
@@ -423,14 +413,18 @@ impl<'a> FunctionContext<'a> {
                 {
                     assert_eq!(successors.len(), 0);
                     returns.push((
-                        self.stypes[self.function.return_type.idx()].clone(),
+                        self.stypes[self.function.return_type.idx()]
+                            .clone()
+                            .unwrap(),
                         ReturnKind::HerculesReturn,
                     ));
                 } else {
                     assert!(successors.len() > 0);
                     returns.extend(self.data_outputs[partition_idx].iter().map(|node_id| {
                         (
-                            self.stypes[self.typing[node_id.idx()].idx()].clone(),
+                            self.stypes[self.typing[node_id.idx()].idx()]
+                                .clone()
+                                .unwrap(),
                             ReturnKind::DataOutput(*node_id),
                         )
                     }));
@@ -461,14 +455,16 @@ impl<'a> FunctionContext<'a> {
         param_types.extend(self.function.param_types.iter().enumerate().map(
             |(param_idx, ty_id)| {
                 (
-                    self.stypes[ty_id.idx()].clone(),
+                    self.stypes[ty_id.idx()].clone().unwrap(),
                     ParameterKind::HerculesParameter(param_idx),
                 )
             },
         ));
         param_types.extend(array_node_to_array_id.iter().map(|(node_id, array_id)| {
             (
-                self.stypes[self.typing[node_id.idx()].idx()].clone(),
+                self.stypes[self.typing[node_id.idx()].idx()]
+                    .clone()
+                    .unwrap(),
                 ParameterKind::ArrayConstant(*array_id),
             )
         }));
@@ -481,7 +477,9 @@ impl<'a> FunctionContext<'a> {
 
         // The return type is just the schedule IR type corresponding to the
         // Hercules function's return type.
-        let return_type = self.stypes[self.function.return_type.idx()].clone();
+        let return_type = self.stypes[self.function.return_type.idx()]
+            .clone()
+            .unwrap();
 
         let manifest = Manifest {
             param_types,
@@ -647,7 +645,7 @@ impl<'a> FunctionContext<'a> {
                                                 .unwrap(),
                                         )
                                     } else {
-                                        SValue::Constant(self.sconstants[id.idx()].clone())
+                                        SValue::Constant(self.sconstants[id.idx()].clone().unwrap())
                                     };
                                     Some((*use_id, svalue))
                                 } else {
@@ -998,7 +996,7 @@ impl<'a> FunctionContext<'a> {
                     block.insts.push(SInst::Phi { inputs });
                     block.virt_regs.push((
                         self_virt_reg(),
-                        self.stypes[self.typing[id.idx()].idx()].clone(),
+                        self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                     ));
                 }
             }
@@ -1031,7 +1029,7 @@ impl<'a> FunctionContext<'a> {
                 block.insts.push(SInst::ReductionVariable { number });
                 block.virt_regs.push((
                     self_virt_reg(),
-                    self.stypes[self.typing[id.idx()].idx()].clone(),
+                    self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                 ));
             }
 
@@ -1042,7 +1040,7 @@ impl<'a> FunctionContext<'a> {
                 });
                 block.virt_regs.push((
                     self_virt_reg(),
-                    self.stypes[self.typing[id.idx()].idx()].clone(),
+                    self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                 ));
             }
             Node::Binary { left, right, op } => {
@@ -1053,7 +1051,7 @@ impl<'a> FunctionContext<'a> {
                 });
                 block.virt_regs.push((
                     self_virt_reg(),
-                    self.stypes[self.typing[id.idx()].idx()].clone(),
+                    self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                 ));
             }
             Node::Ternary {
@@ -1070,7 +1068,7 @@ impl<'a> FunctionContext<'a> {
                 });
                 block.virt_regs.push((
                     self_virt_reg(),
-                    self.stypes[self.typing[id.idx()].idx()].clone(),
+                    self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                 ));
             }
 
@@ -1097,8 +1095,10 @@ impl<'a> FunctionContext<'a> {
                     // Array loads need the dynamic constant bounds for indexing
                     // math.
                     let bounds = lower_extents(collect, &mut block);
-                    let load_ty = if let SType::ArrayRef(elem_ty) =
-                        self.stypes[self.typing[collect.idx()].idx()].clone()
+                    let load_ty = if let SType::ArrayRef(elem_ty) = self.stypes
+                        [self.typing[collect.idx()].idx()]
+                    .clone()
+                    .unwrap()
                     {
                         *elem_ty
                     } else {
@@ -1129,7 +1129,7 @@ impl<'a> FunctionContext<'a> {
                     });
                     block.virt_regs.push((
                         self_virt_reg(),
-                        self.stypes[self.typing[id.idx()].idx()].clone(),
+                        self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                     ));
                 }
             }
@@ -1175,8 +1175,10 @@ impl<'a> FunctionContext<'a> {
 
                     // Load the product.
                     let load_virt_reg = self.make_virt_reg(partition_idx);
-                    let load_ty = if let SType::ArrayRef(elem_ty) =
-                        self.stypes[self.typing[collect.idx()].idx()].clone()
+                    let load_ty = if let SType::ArrayRef(elem_ty) = self.stypes
+                        [self.typing[collect.idx()].idx()]
+                    .clone()
+                    .unwrap()
                     {
                         *elem_ty
                     } else {
@@ -1227,7 +1229,7 @@ impl<'a> FunctionContext<'a> {
                     // they create a new product value.
                     block.virt_regs.push((
                         self_virt_reg(),
-                        self.stypes[self.typing[id.idx()].idx()].clone(),
+                        self.stypes[self.typing[id.idx()].idx()].clone().unwrap(),
                     ));
                 }
             }
@@ -1412,11 +1414,11 @@ impl<'a> FunctionContext<'a> {
     }
 }
 
-fn convert_unary_op(op: UnaryOperator, simple_ir_types: &[SType]) -> SUnaryOperator {
+fn convert_unary_op(op: UnaryOperator, simple_ir_types: &[Option<SType>]) -> SUnaryOperator {
     match op {
         UnaryOperator::Not => SUnaryOperator::Not,
         UnaryOperator::Neg => SUnaryOperator::Neg,
-        UnaryOperator::Cast(ty) => SUnaryOperator::Cast(simple_ir_types[ty.idx()].clone()),
+        UnaryOperator::Cast(ty) => SUnaryOperator::Cast(simple_ir_types[ty.idx()].clone().unwrap()),
     }
 }
 
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index e6a73a3f..6d1867ac 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -4,7 +4,7 @@ extern crate async_std;
 extern crate juno_build;
 extern crate hercules_rt;
 
-juno_build::juno!("matmul.jn");
+juno_build::juno!("matmul");
 
 fn main() {
     async_std::task::block_on(async {
@@ -12,8 +12,17 @@ fn main() {
         let mut b = vec![5.0, 6.0, 7.0, 8.0];
         let mut c = vec![0.0, 0.0, 0.0, 0.0];
         unsafe {
-            matmul(a.as_mut_prt(), b.as_mut_ptr(), c.as_mut_ptr(), 2, 2, 2).await;
+            matmul(a.as_mut_ptr(), b.as_mut_ptr(), c.as_mut_ptr(), 2, 2, 2).await;
         }
         println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]);
+        assert_eq!(c[0], 19.0);
+        assert_eq!(c[1], 22.0);
+        assert_eq!(c[2], 43.0);
+        assert_eq!(c[3], 50.0);
     });
 }
+
+#[test]
+fn matmul_test() {
+    main();
+}
diff --git a/juno_samples/matmul/src/matmul.sch b/juno_samples/matmul/src/matmul.sch
index bbc7ed0e..847a9121 100644
--- a/juno_samples/matmul/src/matmul.sch
+++ b/juno_samples/matmul/src/matmul.sch
@@ -1,5 +1,5 @@
 function matmul {
-  partition { @outer, @middle, @inner } on gpu
+  partition { @outer, @middle, @inner } on cpu //gpu
   partition @exit on cpu
 
   parallelize @outer
-- 
GitLab