diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 7f5b453ab426f1ce0ab220682ce6be89bf851305..197431507fd0c9246e07527d3f58081df25757b8 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -2,36 +2,20 @@ #[cfg(feature = "cuda")] use hercules_rt::CUDABox; -use hercules_rt::{runner, HerculesCPURef}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; juno_build::juno!("dot"); fn main() { async_std::task::block_on(async { - #[cfg(not(feature = "cuda"))] - { - let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; - let a = HerculesCPURef::from_slice(&a); - let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; - let b = HerculesCPURef::from_slice(&b); - let mut r = runner!(dot); - let c = r.run(8, a, b).await; - println!("{}", c); - assert_eq!(c, 70.0); - } - #[cfg(feature = "cuda")] - { - let mut a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; - let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); - let a = a_box.get_ref(); - let mut b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; - let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); - let b = b_box.get_ref(); - let mut r = runner!(dot); - let c = r.run(8, a, b).await; - println!("{}", c); - assert_eq!(c, 70.0); - } + let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; + let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; + let a = HerculesImmBox::from(&a as &[f32]); + let b = HerculesImmBox::from(&b as &[f32]); + let mut r = runner!(dot); + let c = r.run(8, a.to(), b.to()).await; + println!("{}", c); + assert_eq!(c, 70.0); }); }