diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 65f7121576c8b7c5d703a140ad90ea6262005cd8..13370c4580a9365f0a9013b573e558133656d407 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -295,10 +295,19 @@ impl<'a> RTContext<'a> {
                 ref dynamic_constants,
                 ref args,
             } => {
-                match self.devices[callee_id.idx()] {
-                    Device::LLVM => {
+                let device = self.devices[callee_id.idx()];
+                match device {
+                    // The LLVM (CPU) and CUDA backends ensure that their device
+                    // functions have the same C interface, so one arm calls both.
+                    Device::LLVM | Device::CUDA => {
                         let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
 
+                        let device = match device {
+                            Device::LLVM => "cpu",
+                            Device::CUDA => "cuda",
+                            _ => panic!(),
+                        };
+
                         // First, get the raw pointers to collections that the
                         // device function takes as input.
                         let callee_objs = &self.collection_objects[&callee_id];
@@ -308,16 +317,18 @@ impl<'a> RTContext<'a> {
                                 if callee_objs.is_mutated(obj) {
                                     write!(
                                         block,
-                                        "                let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n",
+                                        "                let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n",
                                         idx,
-                                        self.get_value(*arg)
+                                        self.get_value(*arg),
+                                        device
                                     )?;
                                 } else {
                                     write!(
                                         block,
-                                        "                let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n",
+                                        "                let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n",
                                         idx,
-                                        self.get_value(*arg)
+                                        self.get_value(*arg),
+                                        device
                                     )?;
                                 }
                             } else {
@@ -401,7 +412,6 @@ impl<'a> RTContext<'a> {
                         }
                         write!(block, ").await;\n")?;
                     }
-                    _ => todo!(),
                 }
             }
             _ => panic!(
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 3c5ed0bdfa9c6a7e7c9c55c9da16782ff5cc7cde..4fd0cf0b11a55beb3320e4211b2c52f7c0e4e38d 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -329,7 +329,7 @@ pub enum Schedule {
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
 pub enum Device {
     LLVM,
-    NVVM,
+    CUDA,
     // Entry functions are lowered to async Rust code that calls device
     // functions (leaf nodes in the call graph), possibly concurrently.
     AsyncRust,