diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 4cf9b51a3e1ef18213a18ab039e00b1399aac21c..841c6f44dc004fd333410f77d6c1b60e34dbda62 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -811,3 +811,19 @@ where self.as_cuda_ref() } } + +pub trait HerculesRefInto<'a> { + fn to(&'a self) -> HerculesCPURef<'a>; +} + +impl<'a, T> HerculesRefInto<'a> for &'a [T] { + fn to(&'a self) -> HerculesCPURef<'a> { + HerculesCPURef::from_slice(self) + } +} + +impl<'a, T> HerculesRefInto<'a> for Box<[T]> { + fn to(&'a self) -> HerculesCPURef<'a> { + HerculesCPURef::from_slice(self) + } +} diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 2eb2804b33003b9383a981ac652d0e71142ba2a5..3cb7d7f0f35e847b7c88bf4ee8dc2b6536d30647 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -2,9 +2,9 @@ use rand::random; +use hercules_rt::{runner, HerculesRefInto}; #[cfg(feature = "cuda")] -use hercules_rt::CUDABox; -use hercules_rt::{runner, HerculesCPURef}; +use hercules_rt::{CUDABox, HerculesCPURef}; juno_build::juno!("matmul"); @@ -25,12 +25,8 @@ fn main() { } #[cfg(not(feature = "cuda"))] { - let a = HerculesCPURef::from_slice(&a); - let b = HerculesCPURef::from_slice(&b); let mut r = runner!(matmul); - let c = r - .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) - .await; + let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; assert_eq!(c.as_slice::<i32>(), &*correct_c); } #[cfg(feature = "cuda")]