From 54acf3e288890bbe95d0bf003f7ecc1919348769 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 4 Mar 2025 15:26:41 -0600
Subject: [PATCH] fix gpu backend to emit namespaces properly across cuda
 versions

---
 hercules_cg/src/gpu.rs | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index dd87acbe..4069cb02 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -354,7 +354,6 @@ impl GPUContext<'_> {
         write!(
             w,
             "
-#define _CG_ABI_EXPERIMENTAL
 #include <assert.h>
 #include <stdio.h>
 #include <stddef.h>
@@ -362,8 +361,23 @@ impl GPUContext<'_> {
 #include <cuda_runtime.h>
 #include <math_constants.h>
 #include <mma.h>
+
+#if (CUDA_VERSION >= 12000)
+#else
+#define _CG_ABI_EXPERIMENTAL
+#endif
+
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
+
+#if (CUDA_VERSION >= 12000)
+namespace cg = cooperative_groups;
+namespace cge = cooperative_groups;
+#else
+namespace cg = cooperative_groups;
+namespace cge = cooperative_groups::experimental;
+#endif
+
 #include <cuda_bf16.h>
 namespace cg = cooperative_groups;
 
@@ -564,12 +578,12 @@ namespace cg = cooperative_groups;
     fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
         write!(
             w,
-            "\t__shared__ cg::experimental::block_tile_memory<1024> block_sync_shared;\n"
+            "\t__shared__ cge::block_tile_memory<1024> block_sync_shared;\n"
         )?;
         write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
         write!(
             w,
-            "\tcg::thread_block block = cg::experimental::this_thread_block(block_sync_shared);\n"
+            "\tcg::thread_block block = cge::this_thread_block(block_sync_shared);\n"
         )?;
         Ok(())
     }
@@ -1715,20 +1729,20 @@ namespace cg = cooperative_groups;
                     };
                     write!(
                         thread_block_tiles,
-                        "\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
+                        "\tcg::thread_block_tile<{}> {} = cge::tiled_partition<{}>(block);\n",
                         use_thread_per_id, cg_tile, use_thread_per_id
                     )?;
                     let cg_tile_use = self.get_cg_tile(id, CGType::Use);
                     write!(
                         thread_block_tiles,
-                        "\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
+                        "\tcg::thread_block_tile<{}> {} = cge::tiled_partition<{}>(block);\n",
                         use_thread_quota, cg_tile_use, use_thread_quota
                     )?;
                     let available_thread_quota = available_thread_quota.unwrap();
                     let cg_tile_available = self.get_cg_tile(id, CGType::Available);
                     write!(
                         thread_block_tiles,
-                        "\tcg::thread_block_tile<{}> {} = cg::experimental::tiled_partition<{}>(block);\n",
+                        "\tcg::thread_block_tile<{}> {} = cge::tiled_partition<{}>(block);\n",
                         available_thread_quota, cg_tile_available, available_thread_quota
                     )?;
                     if parallel_factor.is_none() {
-- 
GitLab