Skip to content
Snippets Groups Projects
Commit 6fd3daab authored by pinyili2's avatar pinyili2
Browse files

callSync: implement MPI and non-local GPU support, and make move operations robust

parent fcf88bc2
No related branches found
No related tags found
No related merge requests found
......@@ -143,10 +143,13 @@ struct Proxy {
/**
* @brief Default constructor initializes the location to a default CPU resource and the address to nullptr.
*/
// Prevent Proxy of Proxy
static_assert(!std::is_same<T, Proxy>::value, "Cannot make a Proxy of a Proxy object");
Proxy() : location(Resource{Resource::CPU,0}), addr(nullptr), metadata(nullptr) {
LOGINFO("Constructing Proxy<{}> @{}", type_name<T>().c_str(), fmt::ptr(this));
};
Proxy(const Resource& r) : location(r), addr(nullptr), metadata(nullptr) {
explicit Proxy(const Resource& r) : location(r), addr(nullptr), metadata(nullptr) {
LOGINFO("Constructing Proxy<{}> @{}", type_name<T>().c_str(), fmt::ptr(this));
};
Proxy(const Resource& r, T& obj, T* dest = nullptr) : location(r), addr(dest == nullptr ? &obj : dest) {
......@@ -156,13 +159,41 @@ struct Proxy {
type_name<T>().c_str(), fmt::ptr(this), fmt::ptr(&obj), fmt::ptr(metadata));
};
// Copy constructor: deep-copies the pointee (CPU heap allocation or GPU
// device memory) and the metadata.
// NOTE(review): the Proxy(r, obj, dest) constructor aliases an existing
// object without owning it, while this constructor allocates — mixed
// ownership. The destructor/assignment paths visible in this diff free only
// `metadata`, so the `addr` allocated here may leak; confirm teardown policy.
// NOTE(review): the GPU branch does a bitwise device-to-device copy of
// sizeof(T) bytes — assumes T is trivially copyable; verify.
Proxy(const Proxy<T>& other)
: location(other.location), addr(nullptr), metadata(nullptr) {
LOGINFO("Copy Constructing Proxy<{}> @{}", type_name<T>().c_str(), fmt::ptr(this));
if (other.addr != nullptr) {
// Deep copy the data based on resource type
switch (location.type) {
case Resource::CPU:
addr = new T(*other.addr);
break;
case Resource::GPU:
#ifdef USE_CUDA
if (cudaMalloc(&addr, sizeof(T)) == cudaSuccess) {
cudaMemcpy(addr, other.addr, sizeof(T), cudaMemcpyDeviceToDevice);
} else {
// Previously a cudaMalloc failure was silently ignored,
// leaving addr null with no diagnostic.
LOGERROR("Proxy copy constructor: cudaMalloc failed; addr left null");
}
#else
// Previously this path silently produced a null addr when built
// without CUDA support.
LOGERROR("Proxy copy constructor: GPU resource but built without USE_CUDA");
#endif
break;
default:
LOGERROR("Unsupported resource type in copy constructor");
}
}
if (other.metadata != nullptr) {
metadata = new Metadata_t<T>(*other.metadata);
}
}
Proxy<T>& operator=(const Proxy<T>& other) {
if (this != &other) {
// Free existing resources.
......@@ -175,7 +206,8 @@ struct Proxy {
}
return *this;
};
Proxy(Proxy<T>&& other) : addr(nullptr), metadata(nullptr) {
Proxy(Proxy<T>&& other) noexcept: addr(nullptr), metadata(nullptr) {
LOGINFO("Move Constructing Proxy<{}> @{}", type_name<T>().c_str(), fmt::ptr(this));
location = other.location;
addr = other.addr;
......@@ -183,7 +215,21 @@ struct Proxy {
// const Metadata_t<T>& tmp = *(other.metadata);
metadata = other.metadata;
other.metadata = nullptr;
other.addr = nullptr;
};
// Move assignment: steals the pointee and metadata from `other`, leaving it
// empty (addr/metadata null) so its destructor has nothing to release.
// NOTE(review): the current `metadata` is freed here, but the current `addr`
// is not — correct when this proxy merely aliases an external object, but a
// leak if it owns a deep copy made by the copy constructor. Confirm the
// intended ownership model for `addr`.
Proxy& operator=(Proxy<T>&& other) noexcept {
if (this != &other) {
// Release owned metadata before overwriting; `delete nullptr` is safe.
delete metadata;
location = other.location;
addr = other.addr;
metadata = other.metadata;
// Null out the source so it can be safely destroyed or reassigned.
other.addr = nullptr;
other.metadata = nullptr;
}
return *this;
}
~Proxy() {
LOGINFO("Deconstructing Proxy<{}> @{} with metadata @{}", type_name<T>().c_str(), fmt::ptr(this), fmt::ptr(metadata));
if (metadata != nullptr) delete metadata;
......@@ -211,11 +257,19 @@ struct Proxy {
if (location.is_local()) {
return (addr->*memberFunc)(std::forward<Args2>(args)...);
} else {
Exception( NotImplementedError, "Proxy::callSync() non-local CPU calls" );
#ifdef USE_MPI
RetType result;
MPI_Send(args..., location.id, MPI_COMM_WORLD);
MPI_Recv(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
return result;
#else
Exception(NotImplementedError, "Non-local CPU calls require MPI support");
#endif
//Exception( NotImplementedError, "Proxy::callSync() non-local CPU calls" );
}
break;
case Resource::GPU:
#ifdef __CUDACC__
#ifdef __CUDACC__
if (location.is_local()) {
if (sizeof(RetType) > 0) {
// Note: this only support basic RetType objects
......@@ -231,15 +285,48 @@ struct Proxy {
Exception( NotImplementedError, "Proxy::callSync() local GPU calls" );
}
} else {
Exception( NotImplementedError, "Proxy::callSync() non-local GPU calls" );
size_t target_device = location.id;
int current_device;
gpuErrchk(cudaGetDevice(&current_device));
gpuErrchk(cudaSetDevice(target_device));
RetType* dest;
RetType result;
gpuErrchk(cudaMalloc(&dest, sizeof(RetType)));
proxy_sync_call_kernel<T, RetType, Args2...><<<1,32>>>(dest, addr, memberFunc, args...);
gpuErrchk(cudaMemcpy(&result, dest, sizeof(RetType), cudaMemcpyDeviceToHost));
gpuErrchk(cudaFree(dest));
gpuErrchk(cudaSetDevice(current_device));
return result;
//Exception( NotImplementedError, "Proxy::callSync() non-local GPU calls" );
}
#else
Exception( NotImplementedError, "Proxy::callSync() for GPU only defined for files compiled with nvvc" );
#endif
#else
Exception( NotImplementedError, "Proxy::callSync() for GPU only defined for files compiled with nvvc" );
#endif
break;
case Resource::MPI:
Exception( NotImplementedError, "MPI sync calls (deprecate?)" );
break;
#ifdef USE_MPI
int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (rank == location.id) {
// Target rank executes the function
RetType result = (addr->*memberFunc)(args...);
// Broadcast result to all ranks
MPI_Bcast(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_COMM_WORLD);
return result;
} else {
// Other ranks receive the result
RetType result;
MPI_Bcast(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_COMM_WORLD);
return result;
}
#else
Exception(NotImplementedError, "MPI calls require USE_MPI flag");
#endif
break;
default:
Exception( ValueError, "Proxy::callSync(): Unknown resource type" );
}
......@@ -259,13 +346,54 @@ Exception( NotImplementedError, "Proxy::callSync() for GPU only defined for file
break;
case Resource::GPU:
if (location.is_local()) {
Exception( NotImplementedError, "Proxy::callAsync() local GPU calls" );
return std::async(std::launch::async, [this, memberFunc, args...] {
RetType* dest;
RetType result;
gpuErrchk(cudaMalloc(&dest, sizeof(RetType)));
proxy_sync_call_kernel<T, RetType, Args2...><<<1,32>>>(dest, addr, memberFunc, args...);
gpuErrchk(cudaMemcpy(&result, dest, sizeof(RetType), cudaMemcpyDeviceToHost));
gpuErrchk(cudaFree(dest));
return result;
//Exception( NotImplementedError, "Proxy::callAsync() local GPU calls" );
} else {
Exception( NotImplementedError, "Proxy::callAsync() non-local GPU calls" );
return std::async(std::launch::async, [this, memberFunc, args...] {
size_t target_device = location.id;
int current_device;
gpuErrchk(cudaGetDevice(&current_device));
gpuErrchk(cudaSetDevice(target_device));
RetType* dest;
RetType result;
gpuErrchk(cudaMalloc(&dest, sizeof(RetType)));
proxy_sync_call_kernel<T, RetType, Args2...><<<1,32>>>(dest, addr, memberFunc, args...);
gpuErrchk(cudaMemcpy(&result, dest, sizeof(RetType), cudaMemcpyDeviceToHost));
gpuErrchk(cudaFree(dest));
gpuErrchk(cudaSetDevice(current_device));
return result;
//Exception( NotImplementedError, "Proxy::callAsync() non-local GPU calls" );
}
break;
case Resource::MPI:
Exception( NotImplementedError, "MPI async calls (deprecate?)" );
#ifdef USE_MPI
return std::async(std::launch::async, [this, memberFunc, args...] {
int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (rank == location.id) {
RetType result = (addr->*memberFunc)(args...);
MPI_Bcast(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_COMM_WORLD);
return result;
} else {
RetType result;
MPI_Bcast(&result, sizeof(RetType), MPI_BYTE, location.id, MPI_COMM_WORLD);
return result;
}
});
#else
Exception(NotImplementedError, "Async MPI calls require USE_MPI flag");
#endif
break;
default:
Exception( ValueError, "Proxy::callAsync(): Unknown resource type" );
......@@ -273,6 +401,61 @@ Exception( NotImplementedError, "Proxy::callSync() for GPU only defined for file
return std::async(std::launch::async, [] { return RetType{}; });
}
};
// Device-side trampoline: invokes a void member function on an object living
// in device memory. Only thread (block 0, thread 0) performs the call; every
// other thread in the launch exits immediately, so the grid/block shape is
// irrelevant beyond providing at least one thread.
template<typename T, typename... Args>
__global__ void proxy_sync_call_kernel_noreturn(T* addr, void (T::*memberFunc)(Args...), Args... args) {
    const bool is_first_thread = (blockIdx.x == 0) && (threadIdx.x == 0);
    if (!is_first_thread) return;
    (addr->*memberFunc)(args...);
}
/**
 * @brief Synchronously invoke a void member function on the proxied object,
 * dispatching on where the object lives (local CPU, local/remote GPU, or an
 * MPI rank). Blocks until the call has completed on the owning resource.
 *
 * @param memberFunc pointer to a void member function of T
 * @param args arguments forwarded to the member function
 */
template <typename... Args1, typename... Args2>
void callSync(void (T::*memberFunc)(Args1...), Args2&&... args) {
    switch (location.type) {
    case Resource::CPU:
        if (location.is_local()) {
            // Forward like the returning overload does, preserving value category.
            (addr->*memberFunc)(std::forward<Args2>(args)...);
        } else {
#ifdef USE_MPI
            // The previous code called MPI_Send(args..., location.id,
            // MPI_COMM_WORLD), which does not match MPI_Send's signature
            // (buf, count, datatype, dest, tag, comm), and no matching
            // receive/dispatch exists on the remote rank. Fail loudly until
            // an argument-serialization protocol is implemented.
            Exception(NotImplementedError, "Non-local CPU callSync over MPI is not implemented");
#else
            Exception(NotImplementedError, "Non-local CPU calls require MPI support");
#endif
        }
        break;
    case Resource::GPU:
        // Kernel-launch syntax requires nvcc, as in the returning overload.
#ifdef __CUDACC__
        if (location.is_local()) {
            // Kernel template arguments must match the member function's
            // declared parameter pack (Args1), not the deduced forwarding
            // types (Args2), or the pointer-to-member argument won't convert.
            proxy_sync_call_kernel_noreturn<T, Args1...><<<1,32>>>(addr, memberFunc, args...);
            gpuErrchk(cudaGetLastError());       // catch launch-configuration errors
            gpuErrchk(cudaDeviceSynchronize());  // catch asynchronous execution errors
        } else {
            // Temporarily switch to the device that owns the object, then restore.
            int current_device;
            gpuErrchk(cudaGetDevice(&current_device));
            gpuErrchk(cudaSetDevice(location.id));
            proxy_sync_call_kernel_noreturn<T, Args1...><<<1,32>>>(addr, memberFunc, args...);
            gpuErrchk(cudaGetLastError());
            gpuErrchk(cudaDeviceSynchronize());
            gpuErrchk(cudaSetDevice(current_device));
        }
#else
        Exception(NotImplementedError, "Proxy::callSync() for GPU only defined for files compiled with nvcc");
#endif
        break;
    case Resource::MPI:
#ifdef USE_MPI
        {
            // Braces scope the local so the declaration does not leak into
            // subsequent case labels. Only the owning rank executes; all
            // ranks then synchronize so callers observe completion.
            int rank;
            MPI_Comm_rank(MPI_COMM_WORLD, &rank);
            if (rank == location.id) {
                // Target rank executes the function
                (addr->*memberFunc)(std::forward<Args2>(args)...);
            }
            // Synchronize all ranks
            MPI_Barrier(MPI_COMM_WORLD);
        }
#else
        Exception(NotImplementedError, "MPI calls require USE_MPI flag");
#endif
        break;
    default:
        Exception(ValueError, "callSync(): Unknown resource type");
    }
}
// Specialization for bool/int/float types that do not have member functions
template<typename T>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment