From 865991779273b4b976beab65762a80ca8a6b51c1 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Mon, 19 Feb 2024 19:41:13 -0600
Subject: [PATCH] Basic IR schedules framework

---
 Cargo.lock                              |   1 +
 hercules_ir/src/ir.rs                   |   3 +-
 hercules_ir/src/lib.rs                  |   4 +-
 hercules_ir/src/schedule.rs             | 305 ++++++++++++++++++++++++
 hercules_tools/hercules_dot/Cargo.toml  |   1 +
 hercules_tools/hercules_dot/src/dot.rs  | 165 ++++++++-----
 hercules_tools/hercules_dot/src/main.rs |  31 ++-
 7 files changed, 451 insertions(+), 59 deletions(-)
 create mode 100644 hercules_ir/src/schedule.rs

diff --git a/Cargo.lock b/Cargo.lock
index 5cc9664f..2fb56d26 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -177,6 +177,7 @@ name = "hercules_dot"
 version = "0.1.0"
 dependencies = [
  "clap",
+ "hercules_cg",
  "hercules_ir",
  "hercules_opt",
  "rand",
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index b3df19d9..2657447e 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -857,10 +857,10 @@ impl Node {
         self.is_start()
             || self.is_region()
             || self.is_if()
+            || self.is_match()
             || self.is_fork()
             || self.is_join()
             || self.is_return()
-            || self.is_return()
     }
 
     pub fn upper_case_name(&self) -> &'static str {
@@ -1078,6 +1078,7 @@ impl TernaryOperator {
  * Rust things to make newtyped IDs usable.
  */
 
+#[macro_export]
 macro_rules! define_id_type {
     ($x: ident) => {
         #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs
index 9606c0b2..2b8bac06 100644
--- a/hercules_ir/src/lib.rs
+++ b/hercules_ir/src/lib.rs
@@ -1,4 +1,4 @@
-#![feature(coroutines, coroutine_trait)]
+#![feature(coroutines, coroutine_trait, let_chains)]
 
 pub mod build;
 pub mod dataflow;
@@ -7,6 +7,7 @@ pub mod dom;
 pub mod ir;
 pub mod loops;
 pub mod parse;
+pub mod schedule;
 pub mod subgraph;
 pub mod typecheck;
 pub mod verify;
@@ -18,6 +19,7 @@ pub use crate::dom::*;
 pub use crate::ir::*;
 pub use crate::loops::*;
 pub use crate::parse::*;
+pub use crate::schedule::*;
 pub use crate::subgraph::*;
 pub use crate::typecheck::*;
 pub use crate::verify::*;
diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs
new file mode 100644
index 00000000..2151b132
--- /dev/null
+++ b/hercules_ir/src/schedule.rs
@@ -0,0 +1,305 @@
+use std::collections::HashMap;
+use std::iter::zip;
+
+use crate::*;
+
+/*
+ * An individual schedule is a single "directive" for the compiler to take into
+ * consideration at some point during the compilation pipeline. Each schedule is
+ * associated with a single node.
+ */
+#[derive(Debug, Clone)]
+pub enum Schedule {
+    ParallelReduce,
+    Vectorize,
+}
+
+/*
+ * The authoritative enumeration of supported devices. Technically, a device
+ * refers to a specific backend, so different "devices" may refer to the same
+ * "kind" of hardware.
+ */
+#[derive(Debug, Clone)]
+pub enum Device {
+    CPU,
+    GPU,
+}
+
+/*
+ * A plan is a set of schedules associated with each node, plus partitioning
+ * information. Partitioning information is a mapping from node ID to partition
+ * ID. A plan's scope is a single function.
+ */
+#[derive(Debug, Clone)]
+pub struct Plan {
+    pub schedules: Vec<Vec<Schedule>>,
+    pub partitions: Vec<PartitionID>,
+    pub partition_devices: Vec<Device>,
+    pub num_partitions: usize,
+}
+
+define_id_type!(PartitionID);
+
+impl Plan {
+    /*
+     * Invert stored map from node to partition to map from partition to nodes.
+     */
+    pub fn invert_partition_map(&self) -> Vec<Vec<NodeID>> {
+        let mut map = vec![vec![]; self.num_partitions];
+
+        for idx in 0..self.partitions.len() {
+            map[self.partitions[idx].idx()].push(NodeID::new(idx));
+        }
+
+        map
+    }
+}
+
+/*
+ * A "default" plan should be available, where few schedules are used and
+ * conservative partitioning is enacted. Only schedules that can be proven safe
+ * by the compiler should be included.
+ */
+pub fn default_plan(
+    function: &Function,
+    reverse_postorder: &Vec<NodeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    bbs: &Vec<NodeID>,
+) -> Plan {
+    // Start by creating a completely bare-bones plan doing nothing interesting.
+    let mut plan = Plan {
+        schedules: vec![vec![]; function.nodes.len()],
+        partitions: vec![PartitionID::new(0); function.nodes.len()],
+        partition_devices: vec![Device::CPU; 1],
+        num_partitions: 0,
+    };
+
+    // Infer schedules.
+    infer_parallel_reduce(function, fork_join_map, &mut plan);
+    infer_vectorize(function, fork_join_map, &mut plan);
+
+    // Infer a partitioning.
+    partition_out_forks(function, reverse_postorder, fork_join_map, bbs, &mut plan);
+    place_fork_partitions_on_gpu(function, &mut plan);
+
+    plan
+}
+
+/*
+ * Infer parallel reductions consisting of a simple cycle between a Reduce node
+ * and a Write node, where an index of the Write is a position index using the
+ * ThreadID node attached to the corresponding Fork. This procedure also adds
+ * the ParallelReduce schedule to Reduce nodes reducing over a parallelized
+ * Reduce, as long as the base Write node also has a position index that is the
+ * ThreadID of the outer fork. In other words, the complete Reduce chain is
+ * annotated with ParallelReduce, as long as each ThreadID appears in the
+ * positional indexing of the Write.
+ */
+pub fn infer_parallel_reduce(
+    function: &Function,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    plan: &mut Plan,
+) {
+    for id in (0..function.nodes.len())
+        .map(NodeID::new)
+        .filter(|id| function.nodes[id.idx()].is_reduce())
+    {
+        let mut first_control = None;
+        let mut last_reduce = id;
+        let mut chain_id = id;
+
+        // Walk down Reduce chain until we reach the Reduce potentially looping
+        // with the Write. Note the control node of the first Reduce, since this
+        // will tell us which Thread ID to look for in the Write.
+        while let Node::Reduce {
+            control,
+            init: _,
+            reduct,
+        } = function.nodes[chain_id.idx()]
+        {
+            if first_control.is_none() {
+                first_control = Some(control);
+            }
+
+            last_reduce = chain_id;
+            chain_id = reduct;
+        }
+
+        // Check for a Write-Reduce tight cycle.
+        if let Node::Write {
+            collect,
+            data: _,
+            indices,
+        } = &function.nodes[chain_id.idx()]
+            && *collect == last_reduce
+        {
+            // If there is a Write-Reduce tight cycle, get the position indices.
+            let positions = indices
+                .iter()
+                .filter_map(|index| {
+                    if let Index::Position(indices) = index {
+                        Some(indices)
+                    } else {
+                        None
+                    }
+                })
+                .flat_map(|pos| pos.iter());
+
+            // Get the Forks corresponding to uses of bare ThreadIDs.
+            let mut forks = positions.filter_map(|id| {
+                if let Node::ThreadID { control } = function.nodes[id.idx()] {
+                    Some(control)
+                } else {
+                    None
+                }
+            });
+
+            // Check if any of the Forks correspond to the Join associated with
+            // the Reduce being considered.
+            let is_parallel = forks.any(|id| fork_join_map[&id] == first_control.unwrap());
+
+            if is_parallel {
+                plan.schedules[id.idx()].push(Schedule::ParallelReduce);
+            }
+        }
+    }
+}
+
+/*
+ * Infer vectorizable fork-joins. Just check that there are no control nodes
+ * between a fork and its join.
+ */
+pub fn infer_vectorize(
+    function: &Function,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    plan: &mut Plan,
+) {
+    for id in (0..function.nodes.len())
+        .map(NodeID::new)
+        .filter(|id| function.nodes[id.idx()].is_join())
+    {
+        let u = get_uses(&function.nodes[id.idx()]).as_ref()[0];
+        if let Some(join) = fork_join_map.get(&u)
+            && *join == id
+        {
+            plan.schedules[u.idx()].push(Schedule::Vectorize);
+        }
+    }
+}
+
+/*
+ * Create partitions corresponding to fork-join nests. Also, split the "top-
+ * level" partition into sub-partitions that are connected graphs. Place data
+ * nodes using ThreadID or Reduce nodes in the corresponding fork-join nest's
+ * partition.
+ */
+pub fn partition_out_forks(
+    function: &Function,
+    reverse_postorder: &Vec<NodeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    bbs: &Vec<NodeID>,
+    plan: &mut Plan,
+) {
+    impl Semilattice for NodeID {
+        fn meet(a: &Self, b: &Self) -> Self {
+            if a.idx() < b.idx() {
+                *a
+            } else {
+                *b
+            }
+        }
+
+        fn bottom() -> Self {
+            NodeID::new(0)
+        }
+
+        fn top() -> Self {
+            NodeID::new(!0)
+        }
+    }
+
+    // Step 1: do dataflow analysis over control nodes to identify a
+    // representative node for each partition. Each fork not taking as input the
+    // ID of another fork node introduces its own ID as a partition
+    // representative. Each join node propagates the fork ID if it's not the
+    // fork pointing to the join in the fork join map - otherwise, it introduces
+    // its user as a representative node ID for a partition. A region node taking
+    // multiple node IDs as input belongs to the partition with the smaller
+    // representative node ID.
+    let mut representatives = forward_dataflow(
+        function,
+        reverse_postorder,
+        |inputs: &[&NodeID], node_id: NodeID| match function.nodes[node_id.idx()] {
+            Node::Start => NodeID::new(0),
+            Node::Fork {
+                control: _,
+                factor: _,
+            } => {
+                // Start a partition if the preceding partition isn't a fork
+                // partition. Otherwise, be part of the parent fork partition.
+                if *inputs[0] != NodeID::top() && function.nodes[inputs[0].idx()].is_fork() {
+                    inputs[0].clone()
+                } else {
+                    node_id
+                }
+            }
+            Node::Join { control: _ } => inputs[0].clone(),
+            _ => {
+                // If the previous node is a join and terminates a fork's
+                // partition, then start a new partition here. Otherwise, just
+                // meet over the input lattice values. Set all data nodes to be
+                // in the !0 partition.
+                if !function.nodes[node_id.idx()].is_control() {
+                    NodeID::top()
+                } else if zip(inputs, get_uses(&function.nodes[node_id.idx()]).as_ref())
+                    .any(|(part_id, pred_id)| fork_join_map.get(part_id) == Some(pred_id))
+                {
+                    node_id
+                } else {
+                    inputs
+                        .iter()
+                        .filter(|id| {
+                            ***id != NodeID::top() && function.nodes[id.idx()].is_control()
+                        })
+                        .fold(NodeID::top(), |a, b| NodeID::meet(&a, b))
+                }
+            }
+        },
+    );
+
+    // Step 2: assign data nodes to the partitions of the control nodes they are
+    // assigned to by GCM.
+    for idx in 0..function.nodes.len() {
+        if !function.nodes[idx].is_control() {
+            representatives[idx] = representatives[bbs[idx].idx()];
+        }
+    }
+
+    // Step 3: deduplicate representative node IDs.
+    let mut representative_to_partition_ids = HashMap::new();
+    for rep in &representatives {
+        if !representative_to_partition_ids.contains_key(rep) {
+            representative_to_partition_ids
+                .insert(rep, PartitionID::new(representative_to_partition_ids.len()));
+        }
+    }
+
+    // Step 4: update plan.
+    plan.num_partitions = representative_to_partition_ids.len();
+    for id in (0..function.nodes.len()).map(NodeID::new) {
+        plan.partitions[id.idx()] = representative_to_partition_ids[&representatives[id.idx()]];
+    }
+
+    plan.partition_devices = vec![Device::CPU; plan.num_partitions];
+}
+
+/*
+ * Set the device for all partitions containing a fork to the GPU.
+ */
+pub fn place_fork_partitions_on_gpu(function: &Function, plan: &mut Plan) {
+    for idx in 0..function.nodes.len() {
+        if function.nodes[idx].is_fork() {
+            plan.partition_devices[plan.partitions[idx].idx()] = Device::GPU;
+        }
+    }
+}
diff --git a/hercules_tools/hercules_dot/Cargo.toml b/hercules_tools/hercules_dot/Cargo.toml
index f2b42c0e..078baea9 100644
--- a/hercules_tools/hercules_dot/Cargo.toml
+++ b/hercules_tools/hercules_dot/Cargo.toml
@@ -7,4 +7,5 @@ authors = ["Russel Arbore <rarbore2@illinois.edu>"]
 clap = { version = "*", features = ["derive"] }
 hercules_ir = { path = "../../hercules_ir" }
 hercules_opt = { path = "../../hercules_opt" }
+hercules_cg = { path = "../../hercules_cg" }
 rand = "*"
diff --git a/hercules_tools/hercules_dot/src/dot.rs b/hercules_tools/hercules_dot/src/dot.rs
index 70f36831..751bd7c8 100644
--- a/hercules_tools/hercules_dot/src/dot.rs
+++ b/hercules_tools/hercules_dot/src/dot.rs
@@ -15,6 +15,7 @@ pub fn write_dot<W: Write>(
     typing: &ModuleTyping,
     doms: &Vec<DomTree>,
     fork_join_maps: &Vec<HashMap<NodeID, NodeID>>,
+    plans: &Vec<Plan>,
     w: &mut W,
 ) -> std::fmt::Result {
     write_digraph_header(w)?;
@@ -22,67 +23,85 @@ pub fn write_dot<W: Write>(
     for function_id in (0..module.functions.len()).map(FunctionID::new) {
         let function = &module.functions[function_id.idx()];
         let reverse_postorder = &reverse_postorders[function_id.idx()];
+        let plan = &plans[function_id.idx()];
         let mut reverse_postorder_node_numbers = vec![0; function.nodes.len()];
         for (idx, id) in reverse_postorder.iter().enumerate() {
             reverse_postorder_node_numbers[id.idx()] = idx;
         }
         write_subgraph_header(function_id, module, w)?;
 
+        let partition_to_node_map = plan.invert_partition_map();
+
         // Step 1: draw IR graph itself. This includes all IR nodes and all edges
         // between IR nodes.
-        for node_id in (0..function.nodes.len()).map(NodeID::new) {
-            let node = &function.nodes[node_id.idx()];
-            let dst_ty = &module.types[typing[function_id.idx()][node_id.idx()].idx()];
-            let dst_strictly_control = node.is_strictly_control();
-            let dst_control = dst_ty.is_control() || dst_strictly_control;
-
-            // Control nodes are dark red, data nodes are dark blue.
-            let color = if dst_control { "darkred" } else { "darkblue" };
-
-            write_node(node_id, function_id, color, module, w)?;
-
-            for u in def_use::get_uses(&node).as_ref() {
-                let src_ty = &module.types[typing[function_id.idx()][u.idx()].idx()];
-                let src_strictly_control = function.nodes[u.idx()].is_strictly_control();
-                let src_control = src_ty.is_control() || src_strictly_control;
+        for partition_idx in 0..plan.num_partitions {
+            let partition_color = match plan.partition_devices[partition_idx] {
+                Device::CPU => "lightblue",
+                Device::GPU => "darkseagreen",
+            };
+            write_partition_header(function_id, partition_idx, module, partition_color, w)?;
+            for node_id in &partition_to_node_map[partition_idx] {
+                let node = &function.nodes[node_id.idx()];
+                let dst_ty = &module.types[typing[function_id.idx()][node_id.idx()].idx()];
+                let dst_strictly_control = node.is_strictly_control();
+                let dst_control = dst_ty.is_control() || dst_strictly_control;
 
-                // An edge between control nodes is dashed. An edge between data
-                // nodes is filled. An edge between a control node and a data
-                // node is dotted.
-                let style = if dst_control && src_control {
-                    "dashed"
-                } else if !dst_control && !src_control {
-                    ""
-                } else {
-                    "dotted"
-                };
+                // Control nodes are dark red, data nodes are dark blue.
+                let color = if dst_control { "darkred" } else { "darkblue" };
 
-                // To have a consistent layout, we will add "back edges" in the
-                // IR graph as backward facing edges in the graphviz output, so
-                // that they don't mess up the layout. There isn't necessarily a
-                // precise definition of a "back edge" in Hercules IR. I've
-                // found what makes for the most clear output graphs is treating
-                // edges to phi nodes as back edges when the phi node appears
-                // before the use in the reverse postorder, and treating a
-                // control edge a back edge when the destination appears before
-                // the source in the reverse postorder.
-                let is_back_edge = reverse_postorder_node_numbers[node_id.idx()]
-                    < reverse_postorder_node_numbers[u.idx()]
-                    && (node.is_phi()
-                        || (function.nodes[node_id.idx()].is_control()
-                            && function.nodes[u.idx()].is_control()));
-                write_edge(
-                    node_id,
+                write_node(
+                    *node_id,
                     function_id,
-                    *u,
-                    function_id,
-                    !is_back_edge,
-                    "black",
-                    style,
+                    color,
                     module,
+                    &plan.schedules[node_id.idx()],
                     w,
                 )?;
+
+                for u in def_use::get_uses(&node).as_ref() {
+                    let src_ty = &module.types[typing[function_id.idx()][u.idx()].idx()];
+                    let src_strictly_control = function.nodes[u.idx()].is_strictly_control();
+                    let src_control = src_ty.is_control() || src_strictly_control;
+
+                    // An edge between control nodes is dashed. An edge between data
+                    // nodes is filled. An edge between a control node and a data
+                    // node is dotted.
+                    let style = if dst_control && src_control {
+                        "dashed"
+                    } else if !dst_control && !src_control {
+                        ""
+                    } else {
+                        "dotted"
+                    };
+
+                    // To have a consistent layout, we will add "back edges" in the
+                    // IR graph as backward facing edges in the graphviz output, so
+                    // that they don't mess up the layout. There isn't necessarily a
+                    // precise definition of a "back edge" in Hercules IR. I've
+                    // found what makes for the most clear output graphs is treating
+                    // edges to phi nodes as back edges when the phi node appears
+                    // before the use in the reverse postorder, and treating a
+                    // control edge a back edge when the destination appears before
+                    // the source in the reverse postorder.
+                    let is_back_edge = reverse_postorder_node_numbers[node_id.idx()]
+                        < reverse_postorder_node_numbers[u.idx()]
+                        && (node.is_phi()
+                            || (function.nodes[node_id.idx()].is_control()
+                                && function.nodes[u.idx()].is_control()));
+                    write_edge(
+                        *node_id,
+                        function_id,
+                        *u,
+                        function_id,
+                        !is_back_edge,
+                        "black",
+                        style,
+                        module,
+                        w,
+                    )?;
+                }
             }
+            write_graph_footer(w)?;
         }
 
         // Step 2: draw dominance edges in dark green. Don't draw post dominance
@@ -154,6 +173,22 @@ fn write_subgraph_header<W: Write>(
     Ok(())
 }
 
+fn write_partition_header<W: Write>(
+    function_id: FunctionID,
+    partition_idx: usize,
+    module: &Module,
+    color: &str,
+    w: &mut W,
+) -> std::fmt::Result {
+    let function = &module.functions[function_id.idx()];
+    write!(w, "subgraph {}_{} {{\n", function.name, partition_idx)?;
+    write!(w, "label=\"\"\n")?;
+    write!(w, "style=rounded\n")?;
+    write!(w, "bgcolor={}\n", color)?;
+    write!(w, "cluster=true\n")?;
+    Ok(())
+}
+
 fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result {
     write!(w, "}}\n")?;
     Ok(())
@@ -164,6 +199,7 @@ fn write_node<W: Write>(
     function_id: FunctionID,
     color: &str,
     module: &Module,
+    schedules: &Vec<Schedule>,
     w: &mut W,
 ) -> std::fmt::Result {
     let node = &module.functions[function_id.idx()].nodes[node_id.idx()];
@@ -222,16 +258,33 @@ fn write_node<W: Write>(
     } else {
         format!("{} ({})", node.lower_case_name(), suffix)
     };
-    write!(
-        w,
-        "{}_{}_{} [xlabel={}, label=\"{}\", color={}];\n",
-        node.lower_case_name(),
-        function_id.idx(),
-        node_id.idx(),
-        node_id.idx(),
-        label,
-        color
-    )?;
+
+    let mut iter = schedules.into_iter();
+    if let Some(first) = iter.next() {
+        let subtitle = iter.fold(format!("{:?}", first), |b, i| format!("{}, {:?}", b, i));
+        write!(
+            w,
+            "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n",
+            node.lower_case_name(),
+            function_id.idx(),
+            node_id.idx(),
+            node_id.idx(),
+            label,
+            subtitle,
+            color
+        )?;
+    } else {
+        write!(
+            w,
+            "{}_{}_{} [xlabel={}, label=\"{}\", color={}];\n",
+            node.lower_case_name(),
+            function_id.idx(),
+            node_id.idx(),
+            node_id.idx(),
+            label,
+            color
+        )?;
+    }
     Ok(())
 }
 
diff --git a/hercules_tools/hercules_dot/src/main.rs b/hercules_tools/hercules_dot/src/main.rs
index 198a73b7..4c80b17d 100644
--- a/hercules_tools/hercules_dot/src/main.rs
+++ b/hercules_tools/hercules_dot/src/main.rs
@@ -58,10 +58,37 @@ fn main() {
             (function, (types, constants, dynamic_constants))
         },
     );
-    let (_def_uses, reverse_postorders, typing, _subgraphs, doms, _postdoms, fork_join_maps) =
+    let (def_uses, reverse_postorders, typing, subgraphs, doms, _postdoms, fork_join_maps) =
         hercules_ir::verify::verify(&mut module)
             .expect("PANIC: Failed to verify Hercules IR module.");
 
+    let plans: Vec<_> = module
+        .functions
+        .iter()
+        .enumerate()
+        .map(|(idx, function)| {
+            hercules_ir::schedule::default_plan(
+                function,
+                &reverse_postorders[idx],
+                &fork_join_maps[idx],
+                &hercules_cg::gcm::gcm(
+                    function,
+                    &def_uses[idx],
+                    &reverse_postorders[idx],
+                    &subgraphs[idx],
+                    &doms[idx],
+                    &fork_join_maps[idx],
+                    &hercules_cg::antideps::array_antideps(
+                        function,
+                        &def_uses[idx],
+                        &module.types,
+                        &typing[idx],
+                    ),
+                ),
+            )
+        })
+        .collect();
+
     if args.output.is_empty() {
         let mut tmp_path = temp_dir();
         let mut rng = rand::thread_rng();
@@ -75,6 +102,7 @@ fn main() {
             &typing,
             &doms,
             &fork_join_maps,
+            &plans,
             &mut contents,
         )
         .expect("PANIC: Unable to generate output file contents.");
@@ -93,6 +121,7 @@ fn main() {
             &typing,
             &doms,
             &fork_join_maps,
+            &plans,
             &mut contents,
         )
         .expect("PANIC: Unable to generate output file contents.");
-- 
GitLab