Skip to content
Snippets Groups Projects
Commit 7556bd8e authored by pinyili2's avatar pinyili2
Browse files

implement callASync

parent 6fd3daab
No related branches found
No related tags found
No related merge requests found
......@@ -456,6 +456,82 @@ void callSync(void (T::*memberFunc)(Args1...), Args2&&... args) {
Exception(ValueError, "callSync(): Unknown resource type");
}
}
template <typename RetType, typename... Args1, typename... Args2>
std::future<RetType> callAsync(RetType (T::*memberFunc)(Args1...), Args2&&... args) {
switch (location.type) {
case Resource::CPU:
if (location.is_local()) {
return std::async(std::launch::async, [this, memberFunc, args...] {
return (addr->*memberFunc)(args...);
});
} else {
#ifdef USE_MPI
return std::async(std::launch::async, [this, memberFunc, args...] {
RetType result;
MPI_Send(args..., location.id, MPI_COMM_WORLD);
MPI_Recv(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
return result;
});
#else
Exception(NotImplementedError, "Non-local CPU calls require MPI support");
#endif
}
break;
case Resource::GPU:
if (location.is_local()) {
return std::async(std::launch::async, [this, memberFunc, args...] {
RetType* dest;
RetType result;
gpuErrchk(cudaMalloc(&dest, sizeof(RetType)));
proxy_sync_call_kernel<T, RetType, Args2...><<<1,32>>>(dest, addr, memberFunc, args...);
gpuErrchk(cudaMemcpy(&result, dest, sizeof(RetType), cudaMemcpyDeviceToHost));
gpuErrchk(cudaFree(dest));
return result;
});
} else {
return std::async(std::launch::async, [this, memberFunc, args...] {
int current_device;
gpuErrchk(cudaGetDevice(&current_device));
gpuErrchk(cudaSetDevice(location.id));
RetType* dest;
RetType result;
gpuErrchk(cudaMalloc(&dest, sizeof(RetType)));
proxy_sync_call_kernel<T, RetType, Args2...><<<1,32>>>(dest, addr, memberFunc, args...);
gpuErrchk(cudaMemcpy(&result, dest, sizeof(RetType), cudaMemcpyDeviceToHost));
gpuErrchk(cudaFree(dest));
gpuErrchk(cudaSetDevice(current_device));
return result;
});
}
break;
case Resource::MPI:
#ifdef USE_MPI
return std::async(std::launch::async, [this, memberFunc, args...] {
int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (rank == location.id) {
RetType result = (addr->*memberFunc)(args...);
MPI_Bcast(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_COMM_WORLD);
return result;
} else {
RetType result;
MPI_Bcast(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_COMM_WORLD);
return result;
}
});
#else
Exception(NotImplementedError, "Async MPI calls require USE_MPI flag");
#endif
break;
default:
Exception(ValueError, "callAsync(): Unknown resource type");
}
return std::async(std::launch::async, [] { return RetType{}; });
}
// Specialization for bool/int/float types that do not have member functions
template<typename T>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment