diff --git a/llvm/projects/hpvm-tensor-rt/CMakeLists.txt b/llvm/projects/hpvm-tensor-rt/CMakeLists.txt index ac07b17c2372cb0fa6a11409b79b23685da13ffd..d980429560da047f4fdb7dab8eb228e3d1b3419f 100644 --- a/llvm/projects/hpvm-tensor-rt/CMakeLists.txt +++ b/llvm/projects/hpvm-tensor-rt/CMakeLists.txt @@ -4,17 +4,26 @@ project (cudnn-training) find_package(CUDA 6.5 REQUIRED) +set( + CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}; + -gencode;arch=compute_60,code=sm_60; + -gencode;arch=compute_60,code=compute_60; + -std=c++14 --expt-relaxed-constexpr # These are newly added +) + if (CMAKE_BUILD_TYPE STREQUAL "Debug") message("Debug mode") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=compute_60;-std=c++11;-g;-lineinfo;-Xcompiler;-ggdb;-lcurand) + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-g;-lineinfo;-Xcompiler;-ggdb;-lcurand) else() - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=compute_60;-std=c++11;-DNDEBUG;-Xcompiler;-DNDEBUG;-lcurand) + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-DNDEBUG;-Xcompiler;-DNDEBUG;-lcurand) endif() set(CUDA_PROPAGATE_HOST_FLAGS OFF) # Addresses a bug where code is not compiled as C++11 in non-CUDA code and older g++ versions -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -I/ " ) +# Edit: using c++14 now +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -I/ " ) add_definitions(-DNO_INJECTION) add_definitions(-DPROMISE_TUNER_ENABLED) @@ -43,9 +52,9 @@ cuda_add_cublas_to_target(tensor_runtime) cuda_add_library(tensor_cpu_runtime tensor_runtime/src/tensor_cpu_runtime.cc) if(USE_GFLAGS) - target_link_libraries(tensor_runtime gflags cudnn -lcurand) + target_link_libraries(tensor_runtime gflags cudnn cufft -lcurand) else() - target_link_libraries(tensor_runtime cudnn -lcurand) + target_link_libraries(tensor_runtime cudnn cufft -lcurand) endif() target_link_libraries(tensor_cpu_runtime) diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h index 33864fed94f3a86d065f2f166adbfc36127cc42d..813351da7b0b1fa53252c097bde3d2630efe012c 100644 --- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h @@ -9,6 +9,13 @@ #include "tensor.h" +#include <sstream> +#include <iostream> +#include <cstdarg> + +#include <cudnn.h> +#include <cublas_v2.h> +#include <cufft.h> #define FatalError(s) do { \ std::stringstream _where, _message; \ @@ -37,7 +44,13 @@ } \ } while(0) +void _checkCUBLAS(cublasStatus_t error, const char *file, int line); + +void _checkCUFFT(cufftResult error, const char *file, int line); + +#define checkCUBLAS(err) _checkCUBLAS(err, __FILE__, __LINE__) +#define checkCUFFT(err) _checkCUFFT(err, __FILE__, __LINE__) void INFO(const char* format, ...){ if(!LOG_INFO) // Don't print if logging info is disabled diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/broadcast.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..52b70e08ff6cf601c44acbdc07132ada2f629c58 --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/broadcast.h @@ -0,0 +1,74 @@ +#include <algorithm> +#include <array> +#include <cstddef> +#include <type_traits> + +#include "common.h" +#include "tensor.h" + +// 
TODO: don't accept N == 1 +template <size_t N, typename std::enable_if<N >= 1, int>::type = 0> +class BroadcastRemap { +public: + BroadcastRemap(const std::array<Tensor *, N> &tensors) + : out_sizes(), sizes() { + this->in_dims = tensors[0]->dims.num_dims; + for (size_t i = 0; i < N; i++) { + Tensor *t = tensors[i]; + this->sizes[i] = ::sizes(t); + if (this->in_dims != t->dims.num_dims) + throw std::runtime_error("Broadcast tensors have different dimensions"); + this->tail_stride[i] = 1; + } + fill_broadcast_dims(); + } + + std::vector<size_t> getDim() const { return this->out_sizes; } + + const size_t *getStrides() const { return tail_stride; } + +private: + void fill_broadcast_dims() { + // Simplified broadcasting rule: + // 1. Tensors must have the same dimension that is greater than 1. + // 2. Dimension size being 1 (instead of equal) is only allowed for each + // tensor for a continuous N dimensions starting from the last one. + + // Assume all this->in_dims are 1, and compute + // out_dims is reverse-constructed + if (this->in_dims < 1) + throw std::runtime_error( + "Broadcast tensors should have at least 1 dimension"); + bool broadcast_ended[N]{false}; + this->out_sizes.resize(this->in_dims, 1); + for (long i = this->in_dims - 1; i >= 0; i--) { + // First get tensors agree on dim size + for (size_t j = 0; j < N; j++) { + size_t this_size = this->sizes[j][i]; + if (this_size == 1) + continue; + if (this->out_sizes[i] != 1 && this->out_sizes[i] != this_size) + throw std::runtime_error("Dimension size mismatch"); + this->out_sizes[i] = this_size; + } + } + for (size_t j = 0; j < N; j++) + for (long i = this->in_dims - 1; i >= 0; i--) { + // Check for continuity, calculate stride size + size_t this_size = this->sizes[j][i]; + if (this_size != 1) { + // Broadcast cannot go on anymore + broadcast_ended[j] = true; + continue; + } + if (this->out_sizes[i] != this_size && broadcast_ended[j]) + throw std::runtime_error("Broadcast dims must be continuous"); + else + tail_stride[j] *= this->out_sizes[i]; + } + } + + size_t in_dims; + std::vector<size_t> out_sizes, sizes[N]; + size_t tail_stride[N]; +}; diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/common.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/common.h new file mode 100644 index 0000000000000000000000000000000000000000..520c1954b237039aee681dccc44acdb9b94dc443 --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/common.h @@ -0,0 +1,32 @@ +#ifndef IMAGE_PROCESSING_COMMON_H +#define IMAGE_PROCESSING_COMMON_H + +#include <cudnn.h> +#include <device_launch_parameters.h> +#include <vector> + +#include "debug.h" +#include "tensor.h" + +template <typename T> __host__ __device__ __forceinline__ T ceilDiv(T a, T b) { + return (a + b - 1) / b; +} + +template <typename T> __host__ T resolve_func_ptr(void *func_symbol_ptr) { + void *v_func_ptr = nullptr; + checkCudaErrors(cudaMemcpyFromSymbol( + &v_func_ptr, *(void **)func_symbol_ptr, sizeof(void *))); + return (T)v_func_ptr; +} + +std::vector<size_t> sizes(Tensor *t); + +std::vector<size_t> sizes(const Dimension &dim); + +size_t num_elems(const std::vector<size_t> &dim_sizes); + +size_t num_elems(const Dimension &dim); + +size_t num_elems(Tensor *t); + +#endif // IMAGE_PROCESSING_COMMON_H diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map.cuh b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map.cuh new file mode 100644 index 
0000000000000000000000000000000000000000..1375fc8476304af1e74bf6d6be9349cdd802e98d --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map.cuh @@ -0,0 +1,89 @@ +#ifndef RUNTIME_MAP_H +#define RUNTIME_MAP_H + +#include <array> +#include <cstddef> +#include <device_launch_parameters.h> +#include <type_traits> + +#include "broadcast.h" +#include "common.h" +#include "debug.h" +#include "map_typing.h" +#include "tensor.h" +#include "tensor_utils.cu" + +template <size_t N> void mapPrecheck(const std::array<Tensor *, N> &srcs) { + for (Tensor *src : srcs) { + if (src->dims.num_dims != 4 || src->data_format != CUDNN_TENSOR_NCHW) + throw std::runtime_error("Not supported"); // TODO: support this + } +} + +template <typename Scalar, size_t N> +__global__ void kernelMapBroadcast( + Scalar *target, unsigned num_rows, NTo1MapF<Scalar, N> n_ary_op, + Scalar **srcs, size_t *tail_strides) { + unsigned threadId = blockIdx.x * blockDim.x + threadIdx.x, + stride = gridDim.x * blockDim.x; + Scalar buf[N]; + for (unsigned row = threadId; row < num_rows; row += stride) { + for (unsigned i = 0; i < N; i++) { + unsigned j = (unsigned)__fdividef(row, tail_strides[i]); + buf[i] = srcs[i][j]; + } + target[row] = call_on_c_array<Scalar, Scalar, N>(n_ary_op, buf); + } +} + +template <typename Scalar, size_t N> +std::tuple<size_t *, Scalar **> make_cuda_params( + const BroadcastRemap<N> &br, const std::array<Tensor *, N> &srcs) { + std::array<Scalar *, N> gpu_datas; + std::transform(srcs.begin(), srcs.end(), gpu_datas.begin(), [](Tensor *t) { + hostToDeviceCopy(t); + return (Scalar *)t->gpu_data; + }); + size_t *cuda_strides; + Scalar **cuda_gpu_data; + cudaMalloc(&cuda_strides, N * sizeof(size_t)); + cudaMemcpy( + cuda_strides, br.getStrides(), N * sizeof(size_t), + cudaMemcpyHostToDevice); + cudaMalloc(&cuda_gpu_data, N * sizeof(Scalar *)); + cudaMemcpy( + cuda_gpu_data, gpu_datas.data(), N * sizeof(size_t), + cudaMemcpyHostToDevice); + return std::make_tuple(cuda_strides, cuda_gpu_data); +} + +template < + typename Scalar, size_t N, typename std::enable_if<N >= 1, int>::type = 0> +__host__ Tensor * +mapGeneral(void *host_func_ptr, const std::array<Tensor *, N> &srcs) { + mapPrecheck(srcs); + auto func_ptr = resolve_func_ptr<NTo1MapF<Scalar, N>>(host_func_ptr); + + auto br = BroadcastRemap<N>(srcs); + std::vector<size_t> dim_sizes = br.getDim(); + auto *target = (Tensor *)create4DTensor( + CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, dim_sizes[0], dim_sizes[1], + dim_sizes[2], dim_sizes[3]); + changeTensorPlacement(target, DEVICE); + + size_t *cuda_strides; + Scalar **gpu_data; + std::tie(cuda_strides, gpu_data) = make_cuda_params<Scalar, N>(br, srcs); + + unsigned n_elem = num_elems(dim_sizes); + unsigned max_threads = 512, max_grid = 1024; + unsigned threads = std::min(max_threads, n_elem); + unsigned grids = std::min(max_grid, ceilDiv(n_elem, threads)); + kernelMapBroadcast<Scalar, N><<<grids, threads>>>( + (Scalar *)target->gpu_data, n_elem, func_ptr, gpu_data, cuda_strides); + cudaDeviceSynchronize(); + checkCudaErrors(cudaGetLastError()); + return target; +} + +#endif diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map_typing.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map_typing.h new file mode 100644 index 0000000000000000000000000000000000000000..4cb93661384a572425361a053371c409606de164 --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map_typing.h @@ -0,0 +1,68 @@ +#ifndef 
IMAGE_PROCESSING_MAP_TYPING_H +#define IMAGE_PROCESSING_MAP_TYPING_H + +// Constructs type T (*)(T, T, T, T, ... <n_times>) from T and N +#include <cstddef> +#include <device_launch_parameters.h> +#include <tuple> +#include <utility> + +namespace { +template <class T, size_t> using Type = T; + +template <typename, template <typename...> typename, typename> struct _RepNType; + +template <typename T, template <typename...> typename W, size_t... Is> +struct _RepNType<T, W, std::index_sequence<Is...>> { + using type = W<Type<T, Is>...>; +}; + +template <typename T, template <typename...> typename W, size_t N> +using RepNType = typename _RepNType<T, W, std::make_index_sequence<N>>::type; + +template <typename Ret, typename... Args> using FuncPtrT = Ret (*)(Args...); + +template <typename Ret, typename Arg, size_t N> struct _NAToBFunc { + template <typename... Args> using Wrapper = FuncPtrT<Ret, Args...>; + + using type = RepNType<Arg, Wrapper, N>; +}; +} // namespace + +template <typename Ret, typename Arg, size_t N> +using NAToBF = typename _NAToBFunc<Ret, Arg, N>::type; + +template <typename Scalar, size_t N> using NTo1MapF = NAToBF<Scalar, Scalar, N>; + +template <typename T, size_t N> using RepNTuple = RepNType<T, std::tuple, N>; + +namespace { +template <typename TIterable, typename T, size_t... Is> +constexpr RepNTuple<T, sizeof...(Is)> +as_tuple(TIterable arr, std::index_sequence<Is...>) { + return std::make_tuple(arr[Is]...); +} + +template <typename Function, typename Tuple, size_t... I> +__device__ auto call(Function f, Tuple t, std::index_sequence<I...>) { + return f(std::get<I>(t)...); +} +} // namespace + +template <typename TIterable, typename T, size_t N> +constexpr RepNTuple<T, N> as_tuple(TIterable arr) { + return as_tuple<TIterable, T>(arr, std::make_index_sequence<N>{}); +} + +template <typename Function, typename Tuple> +__device__ auto call_on_tuple(Function f, Tuple t) { + static constexpr auto size = std::tuple_size<Tuple>::value; + return call(f, t, std::make_index_sequence<size>{}); +} + +template <typename Ret, typename T, size_t N> +__device__ Ret call_on_c_array(NAToBF<Ret, T, N> f, const T arr[N]) { + return call_on_tuple(f, as_tuple<const T *, T, N>(arr)); +} + +#endif // IMAGE_PROCESSING_MAP_TYPING_H diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/reduce.cuh b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bcd58f90bdf444266e511bb19d69dbb4d4d9bdcf --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/reduce.cuh @@ -0,0 +1,184 @@ +#include <device_launch_parameters.h> +#include <functional> +#include <numeric> + +#include "common.h" +#include "debug.h" +#include "tensor.h" +#include "tensor_utils.cu" + +// Between CUDA compute capability 1.0 and 7.5, +// Least "max # threads per block" is 512, so 512 is used to be compatible; +// at most 2048 threads per multiprocessor, where # of cores varies greatly +// among devices. Titan X has 3072 cores, Quadro P1000 has 640. A bit of +// over-subscription doesn't hurt. These numbers will keep us compatible even +// for 1.0 devices. +constexpr size_t NThreads = 512, MaxNBlocks = 2048 / NThreads * 3072; +constexpr size_t MaxBlocksPerDim = 65535; + +constexpr size_t AlongDimTh = 16, CrossDimTh = 32; + +/* + * Reduce along one dimension with a single thread. 
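+ * Elements of the reduced dimension are spaced num_irows apart in memory, so src is advanced by that stride after each of the dim_size accumulations into acc.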
+ */ +template <typename K, class BinOp> +__device__ void reduceAlongDim( + K *target, K *src, K init, BinOp binary_op, size_t num_irows, + size_t dim_size) { + K acc = init; + for (size_t col = 0; col < dim_size; ++col) { + acc = binary_op(acc, *src); + src += num_irows; + } + *target = acc; +} + +/* + * Parallel reduce a dimension of tensor to a scalar value. + * Use `n_along_dim_threads` threads to sweep along the dim to be reduced. + * Intermediate values are collected in a divide-and-conquer manner, + * with thread 0 finally writing back the result. + */ +template <typename K, class BinOp> +__device__ void parallelReduceAlongDim( + K *target, K *src, K *sbuf, K init, BinOp binary_op, size_t num_irows, + size_t dim_size, size_t along_dim_tid, size_t n_along_dim_threads) { + K acc = init; + // Sequential reduction within a thread. + for (size_t col = along_dim_tid; col < dim_size; col += n_along_dim_threads) { + acc = binary_op(acc, src[col * num_irows]); + } + + sbuf[along_dim_tid] = acc; + + __syncthreads(); + + // Reduce intermediate values to single value. + for (size_t s = AlongDimTh >> 1u; s > 0; s >>= 1u) { + if (along_dim_tid < s) { + K arg1 = sbuf[along_dim_tid]; + K arg2 = sbuf[along_dim_tid + s]; + K res = binary_op(arg1, arg2); + sbuf[along_dim_tid] = res; + } + __syncthreads(); + } + + if (along_dim_tid == 0) { + *target = sbuf[0]; + } + __syncthreads(); +} + +/* + * Reduce the whole tensor with parallelism only on output. + * The reduce axis is reduced sequentially. + * Block is 2D, thread is 1D; block.y covers outer rows, block.x * thread.x + * covers inner rows. + */ +template <typename K, class BinOp> +__global__ void kernelReduceDimSeq( + K *target_, K *src_, K init, BinOp binary_op, size_t num_irows, + size_t num_orows, size_t row_size, size_t approx_row_size) { + size_t start_orow = blockIdx.y, + start_irow = blockIdx.x * blockDim.x + threadIdx.x; + size_t orow_stride = gridDim.y, irow_stride = gridDim.x * blockDim.x; + for (size_t orow = start_orow; orow < num_orows; orow += orow_stride) { + for (size_t irow = start_irow; irow < num_irows; irow += irow_stride) { + K *src = src_ + orow * row_size * num_irows + irow; + K *target = target_ + orow * num_irows + irow; + reduceAlongDim(target, src, init, binary_op, num_irows, approx_row_size); + } + } +} + +/* + * Reduce the whole tensor with parallelism on output and reduce axis. + * I.e., the reduce axis is reduced parallel. + * Block is 2D, thread is 2D; + * thread.x covers reduce axis, + * block.x * thread.y covers inner rows, + * and block.y covers outer rows. 
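+ * Assumes blockDim == (AlongDimTh, CrossDimTh) = (16, 32): each thread.y row gets its own padded slice of shared memory, and the tree reduction in parallelReduceAlongDim relies on exactly AlongDimTh lanes along x.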
+ */ +template <typename K, class BinOp> +__global__ void kernelReduceDimParallel( + K *target_, K *src_, K init, BinOp binary_op, size_t num_irows, + size_t num_orows, size_t row_size, size_t approx_row_size) { + __shared__ K sbuf[CrossDimTh][AlongDimTh + 1]; // avoid bank conflict + size_t start_orow = blockIdx.y, + start_irow = blockIdx.x * blockDim.y + threadIdx.y; + size_t orow_stride = gridDim.y, irow_stride = gridDim.x * blockDim.y; + for (size_t orow = start_orow; orow < num_orows; orow += orow_stride) { + for (size_t irow = start_irow; irow < num_irows; irow += irow_stride) { + K *src = src_ + orow * row_size * num_irows + irow; + K *target = target_ + orow * num_irows + irow; + parallelReduceAlongDim( + target, src, sbuf[threadIdx.y], init, binary_op, num_irows, + approx_row_size, threadIdx.x, blockDim.x); + } + } +} + +template <typename Scalar> +__host__ Tensor *reduceDim( + Tensor *src, const Scalar &init, void *func, size_t axis, float skip_rate) { + // Copy input over + hostToDeviceCopy(src); + + // Cast function ptr to right type + using bin_float_op = NTo1MapF<Scalar, 2>; + auto func_ptr = resolve_func_ptr<bin_float_op>(func); + + // Prepare output + std::vector<size_t> in_sizes = sizes(src), out_sizes = in_sizes; + out_sizes[axis] = 1; + auto *target = (Tensor *)create4DTensor( + CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, out_sizes[0], out_sizes[1], out_sizes[2], out_sizes[3]); + changeTensorPlacement(target, DEVICE); + + // Calculate schedule parameters + size_t num_orows = std::accumulate( + in_sizes.begin(), in_sizes.begin() + axis, 1, std::multiplies<>()); + size_t row_size = in_sizes[axis]; + size_t num_irows = std::accumulate( + in_sizes.begin() + axis + 1, in_sizes.end(), 1, std::multiplies<>()); + size_t num_rows = num_irows * num_orows; + + // Calculate approximation parameters + if (skip_rate != 0.0f) + INFO("Approximation happening..."); + size_t approx_row_size = (size_t)((1 - skip_rate) * row_size); + + // If # of output entries is small, and row_size is enough for 16 threads, + // reduce in parallel. + // Remember if reducing dim in parallel, threads must be (16, 32). + if (num_rows < NThreads * MaxNBlocks && row_size >= AlongDimTh * 8) { + DEBUG( + "Reducing in parallel, row size = %lu, actually using %lu", row_size, + approx_row_size); + size_t grid_x = std::min(MaxBlocksPerDim, ceilDiv(num_irows, 32ul)); + size_t grid_y = std::min( + std::min(MaxBlocksPerDim, num_orows), ceilDiv(MaxNBlocks, grid_x)); + dim3 threads(AlongDimTh, CrossDimTh); + dim3 grid(grid_x, grid_y); + kernelReduceDimParallel<Scalar, bin_float_op><<<grid, threads>>>( + (Scalar *)target->gpu_data, (Scalar *)src->gpu_data, init, func_ptr, + num_irows, num_orows, row_size, approx_row_size); + } else { + DEBUG( + "Reducing sequentially, row size = %lu, actually using %lu", row_size, + approx_row_size); + // Reduce sequentially. 
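+ // One thread per inner row, at most NThreads per block; grid.x tiles the inner rows and grid.y the outer rows, keeping the total block count roughly within MaxNBlocks.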
+ size_t threads = std::min(NThreads, num_irows); + size_t grid_x = std::min(MaxBlocksPerDim, ceilDiv(num_irows, threads)); + size_t grid_y = std::min( + std::min(MaxBlocksPerDim, num_orows), ceilDiv(MaxNBlocks, grid_x)); + dim3 grid(grid_x, grid_y); + kernelReduceDimSeq<Scalar, bin_float_op><<<grid, threads>>>( + (Scalar *)target->gpu_data, (Scalar *)src->gpu_data, init, func_ptr, + num_irows, num_orows, row_size, approx_row_size); + } + cudaDeviceSynchronize(); + checkCudaErrors(cudaGetLastError()); + return target; +} diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/common.cpp b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c30ddf25e5aeffef8123897af863b520846b62ce --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/common.cpp @@ -0,0 +1,19 @@ +#include "functional/common.h" + +#include <numeric> +#include <functional> + +std::vector<size_t> sizes(const Dimension &dim) { + return std::vector<size_t>(dim.dim_sizes, dim.dim_sizes + dim.num_dims); +} + +std::vector<size_t> sizes(Tensor *t) { return sizes(t->dims); } + +size_t num_elems(const std::vector<size_t> &dim_sizes) { + return std::accumulate( + dim_sizes.begin(), dim_sizes.end(), 1, std::multiplies<>()); +} + +size_t num_elems(const Dimension &dim) { return num_elems(sizes(dim)); } + +size_t num_elems(Tensor *t) { return num_elems(sizes(t)); } diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/debug.cpp b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/debug.cpp new file mode 100644 index 0000000000000000000000000000000000000000..467336dd411ebc8d805bccf3430b74be98f4fec0 --- /dev/null +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/debug.cpp @@ -0,0 +1,130 @@ +#include <cstdarg> +#include <cstdio> +#include <stdexcept> +#include <cuda_runtime_api.h> +#include "debug.h" + +void throwError(const char *file, int line, const char *fmt, ...) { + char msg[2048]; + va_list args; + /* vasprintf not standard */ + /* vsnprintf: how to handle if does not exist? 
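+ (vsnprintf is standard in C99/C++11: it truncates to the buffer size and always NUL-terminates, so the fixed 2048-byte buffer is safe.)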
*/ + va_start(args, fmt); + int n = vsnprintf(msg, 2048, fmt, args); + va_end(args); + if (n < 2048) { + snprintf(msg + n, 2048 - n, " at %s:%d", file, line); + } + + throw std::runtime_error(msg); +} + +template<typename T, typename F> +void checkCompareFlag( + T err, T success_const, F get_err_str, const char *error_kind, const char *file, int line +) { + if (err != success_const) { + static int alreadyFailed = 0; + if (!alreadyFailed) { + fprintf( + stderr, "%s Error file=%s line=%i error=%i : %s\n", + error_kind, file, line, err, + get_err_str(err) + ); + alreadyFailed = 1; + } + throwError( + file, line, "%s Error error (%d) : %s", error_kind, err, + get_err_str(err) + ); + } +} + +void _checkCUDA(cudaError_t err, const char *file, int line) { + checkCompareFlag(err, cudaSuccess, cudaGetErrorString, "CUDA", file, line); +} + +void _checkWarnCUDA(cudaError_t err, const char *file, int line) { + if (err != cudaSuccess) { + fprintf(stderr, "CUDA Warning file=%s line=%i error=%i : %s\n", file, line, err, + cudaGetErrorString(err)); + } +} + +void _checkCUDNN(cudnnStatus_t error, const char *file, int line) { + checkCompareFlag(error, CUDNN_STATUS_SUCCESS, cudnnGetErrorString, "CUDNN", file, line); +} + +static const char *cublasGetErrorString(cublasStatus_t status) { + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "unknown error"; +} + +void _checkCUBLAS(cublasStatus_t error, const char *file, int line) { + checkCompareFlag(error, CUBLAS_STATUS_SUCCESS, cublasGetErrorString, "CUBLAS", file, line); +} + +static const char *cufftGetErrorString(cufftResult error) { + switch (error) { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + } + return "<unknown>"; +} + +void _checkCUFFT(cufftResult error, const char *file, int line) { + checkCompareFlag(error, CUFFT_SUCCESS, cufftGetErrorString, "CUFFT", file, 
line); +} diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu index 5776e5c385de79ed96255763a770a31d2b6d3df3..497d9ec38bfbff86afc21e290c29c589fa12fc7c 100644 --- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu +++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu @@ -1,16 +1,108 @@ -#include "../include/approxhpvm_img_runtime_utils.h" -#include "../include/debug.h" -#include "../include/img_tensor_runtime.h" +#include "approxhpvm_img_runtime_utils.h" +#include "debug.h" +#include "img_tensor_runtime.h" + +#include "functional/map.cuh" +#include "functional/reduce.cuh" +#include "tensor_utils.cu" + +#include <cufft.h> + +// FIXME: really just a hack to compile into a single .o +#include "common.cpp" +#include "debug.cpp" // *** Runtime implementation *** // -void *tensorFft(void *input) {} -void *tensorReduce(void *input, size_t axis, void *func) {} +void *tensorFft(void *input) { + // https://docs.nvidia.com/cuda/cufft/index.html#twod-complex-to-real-transforms + // Tensor checking + INFO("FFT"); + auto *t_input = (Tensor *)input; + if (t_input->data_type != CUDNN_DATA_FLOAT) + throw std::runtime_error("Only float32 is supported"); + int total_rank = t_input->dims.num_dims; + if (total_rank != 4) + throw std::runtime_error("Only 4-dim tensor supported"); + // Dimensions + size_t *all_dims = t_input->dims.dim_sizes; + int width = all_dims[2], height = all_dims[3]; + int fft_dim[2] = {width, height}; + int n_batch = int(all_dims[0]) * int(all_dims[1]); + // Prepare input data + hostToDeviceCopy(t_input); + auto *input_cuda = (cufftReal *)t_input->gpu_data; + // Define output data + // FIXME: make a flag for float2_; not CUDNN_DATA_FLOAT. 
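+ // Note: each R2C output element is a cufftComplex (two floats), i.e. every row holds height / 2 + 1 complex values, while the tensor below is allocated as CUDNN_DATA_FLOAT with a last dimension of height / 2 + 1; this element-size mismatch appears to be what the float2 FIXME above refers to.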
+ auto *out_tensor = (Tensor *)create4DTensor( + CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, all_dims[0], all_dims[1], width, + (height / 2 + 1)); + changeTensorPlacement(out_tensor, DEVICE); + auto *output_cuda = (cufftComplex *)out_tensor->gpu_data; + // Create a 2D FFT plan + cufftHandle plan; + checkCUFFT(cufftPlanMany( + &plan, 2, fft_dim, nullptr, 1, 0, nullptr, 1, 0, CUFFT_R2C, n_batch)); + // Execute the plan + checkCUFFT(cufftExecR2C(plan, input_cuda, output_cuda)); + // Wait for the device to finish + checkCudaErrors(cudaDeviceSynchronize()); + // Release memory + cufftDestroy(plan); + return out_tensor; +} + +void *tensorReduce(void *input, size_t axis, void *func) { + INFO("Reduce"); + auto *src = (Tensor *)input; + if (axis >= src->dims.num_dims) + throw std::runtime_error("Dimension out of range"); + if (src->dims.num_dims != 4 || src->data_format != CUDNN_TENSOR_NCHW) + throw std::runtime_error("Not supported"); + + // Skip 0% of samples (no approximation) + return reduceDim<float>(src, 0.0f, func, axis, 0.0f); +} + void *tensorReductionSamplingReduce( - void *input, size_t axis, void *func, int skip_level) {} -void *tensorProjectiveT(void *input, void *transformation) {} -void *tensorMap1(void *f, void *i) {} -void *tensorMap2(void *f2, void *i1, void *i2) {} -void *tensorMap3(void *f3, void *i1, void *i2, void *i3) {} + void *input, size_t axis, void *func, int skip_level) { + INFO("Reduce with sampling"); + auto *src = (Tensor *)input; + if (axis >= src->dims.num_dims) + throw std::runtime_error("Dimension out of range"); + if (src->dims.num_dims != 4 || src->data_format != CUDNN_TENSOR_NCHW) + throw std::runtime_error("Not supported"); + + switch (skip_level) { + case 0: + return reduceDim<float>(src, 0.0f, func, axis, 0.1f); + case 1: + return reduceDim<float>(src, 0.0f, func, axis, 0.2f); + case 2: + return reduceDim<float>(src, 0.0f, func, axis, 0.4f); + default: + throw std::runtime_error("Unsupported skip_level"); + } +} + +void *tensorProjectiveT(void *input, void *transformation) { + ERROR("ProjectiveT operation currently unsupported.\n"); + return nullptr; // not reached if ERROR() terminates; keeps this non-void function well-formed +} + +void *tensorMap1(void *f, void *i) { + INFO("Map1"); + auto *src = (Tensor *)i; + return mapGeneral<float, 1>(f, {src}); +} + +void *tensorMap2(void *f2, void *i1, void *i2) { + INFO("Map2"); + auto *src1 = (Tensor *)i1, *src2 = (Tensor *)i2; + return mapGeneral<float, 2>(f2, {src1, src2}); +} + +void *tensorMap3(void *f3, void *i1, void *i2, void *i3) { + INFO("Map3"); + auto *src1 = (Tensor *)i1, *src2 = (Tensor *)i2, *src3 = (Tensor *)i3; + return mapGeneral<float, 3>(f3, {src1, src2, src3}); +} // *** Wrapper API implementation *** //
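+
+// Illustrative usage sketch (hypothetical, not part of this patch; all names
+// below are invented). resolve_func_ptr() in functional/common.h reads a
+// device function pointer out of a __device__ variable via
+// cudaMemcpyFromSymbol, so the map/reduce entry points above would be driven
+// roughly like this:
+//
+//   __device__ float dev_add(float x, float y) { return x + y; }
+//   __device__ float (*dev_add_sym)(float, float) = dev_add;
+//
+//   void example(Tensor *a, Tensor *b) {
+//     void *sym = (void *)&dev_add_sym;            // shadow address of the __device__ pointer
+//     void *sum = tensorMap2((void *)&sym, a, b);  // elementwise a + b, with broadcasting
+//     void *tot = tensorReduce(sum, /*axis=*/3, (void *)&sym);  // reduce over the last dim
+//   }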