From 2b83b5f50b03de5c626364ffeef3b20535290c89 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Wed, 6 Nov 2019 18:48:01 -0600
Subject: [PATCH] Half functions for half precision

---
Reviewer notes (ignored by git-am, not part of the commit message):
 * havg3 converts its divisor with __float2half instead of relying on an
   implicit float -> half conversion.
 * find_half_function is table-driven; the set of mappings is unchanged.
 * The commit id above and the `index` line for device_math.cu predate these
   edits and must be regenerated before applying with --3way.
 * Context-line indentation was restored by hand; verify it against the tree.

 .gitignore                                 |  2 +-
 .../include/approxhpvm_img_runtime_utils.h | 20 ++++++++---
 .../tensor_runtime/include/device_math.h   | 13 ++++---
 .../tensor_runtime/src/device_math.cu      | 36 ++++++++++++++++++-
 4 files changed, 60 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index a39f1eeab2..ec301811d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,5 +34,5 @@ llvm/test/VISC/parboil/benchmarks/*/run
 llvm/test/VISC/parboil/benchmarks/*/build
 llvm/build
 llvm/install
-build/
+build*/
 install/
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approxhpvm_img_runtime_utils.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approxhpvm_img_runtime_utils.h
index 413e2f2bbf..d19e26728a 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approxhpvm_img_runtime_utils.h
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approxhpvm_img_runtime_utils.h
@@ -63,8 +63,11 @@ void *handleTensorReduceApproximationTuples(
     }
     case GPUNodeConfiguration::APPROX::FP16: {
       void *t_out;
+      void *half_func = find_half_function(func);
+      if (half_func == nullptr)
+        throw std::runtime_error("Half version of function does not exist");
       RC->resume_profiler();
-      t_out = tensorReduce(input, axis, func, true);
+      t_out = tensorReduce(input, axis, half_func, true);
       RC->pause_profiler();
       std::pair<double, double> pinfo = RC->get_time_energy();
       RC->reset_profiler();
@@ -155,8 +158,11 @@ void *handleTensorMap1ApproximationTuples(
     }
     case GPUNodeConfiguration::APPROX::FP16: {
       void *t_out;
+      void *half_func = find_half_function(func);
+      if (half_func == nullptr)
+        throw std::runtime_error("Half version of function does not exist");
       RC->resume_profiler();
-      t_out = tensorMap1(func, input, true);
+      t_out = tensorMap1(half_func, input, true);
       RC->pause_profiler();
       std::pair<double, double> pinfo = RC->get_time_energy();
       RC->reset_profiler();
@@ -200,8 +206,11 @@ void *handleTensorMap2ApproximationTuples(
     }
     case GPUNodeConfiguration::APPROX::FP16: {
       void *t_out;
+      void *half_func = find_half_function(func);
+      if (half_func == nullptr)
+        throw std::runtime_error("Half version of function does not exist");
       RC->resume_profiler();
-      t_out = tensorMap2(func, input1, input2, true);
+      t_out = tensorMap2(half_func, input1, input2, true);
       RC->pause_profiler();
       std::pair<double, double> pinfo = RC->get_time_energy();
       RC->reset_profiler();
@@ -245,8 +254,11 @@ void *handleTensorMap3ApproximationTuples(
     }
     case GPUNodeConfiguration::APPROX::FP16: {
       void *t_out;
+      void *half_func = find_half_function(func);
+      if (half_func == nullptr)
+        throw std::runtime_error("Half version of function does not exist");
       RC->resume_profiler();
-      t_out = tensorMap3(func, input1, input2, input3, true);
+      t_out = tensorMap3(half_func, input1, input2, input3, true);
       RC->pause_profiler();
       std::pair<double, double> pinfo = RC->get_time_energy();
       RC->reset_profiler();
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/device_math.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/device_math.h
index 353e966951..9398730d7e 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/device_math.h
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/device_math.h
@@ -1,3 +1,6 @@
+#ifndef DEVICE_MATH_H
+#define DEVICE_MATH_H
+
 #include <device_launch_parameters.h>
 
 namespace device {
@@ -25,13 +28,13 @@ extern void *sqrt_ptrptr;
 extern void *fmax_ptrptr;
 extern void *fmin_ptrptr;
 extern void *favg3_ptrptr;
-
-extern void *hhypot_ptrptr;
-extern void *hadd_ptrptr;
-extern void *hmax_ptrptr;
-extern void *hdiv_ptrptr;
 } // namespace device
 
 #define DEF_FUNC(func)                                                         \
   __device__ void *func##_ptr = (void *)func;                                  \
   void *func##_ptrptr = (void *)&func##_ptr;
+
+
+void *find_half_function(void *known_float_func);
+
+#endif
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/device_math.cu b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/device_math.cu
index f20de26c0a..6230df2a91 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/device_math.cu
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/device_math.cu
@@ -163,6 +163,15 @@ __device__ half hhypot(half x, half y) { return hsqrt(x * x + y * y); }
 
 __device__ half hmax(half x, half y) { return x >= y ? x : y; }
 
+// Half-precision function that find_half_function() returns for
+// device::favg3_ptrptr; computes x / 3. The divisor is converted explicitly
+// so this does not rely on an implicit float -> half conversion.
+// NOTE(review): `y` is never read -- confirm this parameter list against the
+// float version, since the function is exported only as a void pointer.
+__device__ half havg3(half x, half y) {
+  return __hdiv(x, __float2half(3.0f));
+}
+
 __device__ void *hypot_ptr = (void *)device::hypot;
 __device__ void *atan2_ptr = (void *)device::atan2;
 __device__ void *fadd_ptr = (void *)device::add;
@@ -178,6 +187,7 @@ __device__ void *hhypot_ptr = (void *)hhypot;
 __device__ void *hadd_ptr = (void *)(half(*)(half, half))__hadd;
 __device__ void *hmax_ptr = (void *)(half(*)(half, half))hmax;
 __device__ void *hdiv_ptr = (void *)__hdiv;
+__device__ void *havg3_ptr = (void *)havg3;
 
 namespace device {
 void *fhypot_ptrptr = (void *)&hypot_ptr;
@@ -191,8 +201,32 @@ void *fmax_ptrptr = (void *)&fmax_ptr;
 void *fmin_ptrptr = (void *)&fmin_ptr;
 void *favg3_ptrptr = (void *)&favg3_ptr;
 
+} // namespace device
+
 void *hhypot_ptrptr = (void *)&hhypot_ptr;
 void *hadd_ptrptr = (void *)&hadd_ptr;
 void *hmax_ptrptr = (void *)&hmax_ptr;
 void *hdiv_ptrptr = (void *)&hdiv_ptr;
-} // namespace device
+void *havg3_ptrptr = (void *)&havg3_ptr;
+
+// Maps a float device-function handle (one of the device::*_ptrptr values)
+// to the handle of its half-precision version. Returns nullptr when
+// `known_float_func` is not a known handle or has no half version.
+void *find_half_function(void *known_float_func) {
+  // atan2, fsub, fmul, sqrt and fmin have no half version yet, so they are
+  // deliberately left out and fall through to the nullptr below.
+  const struct {
+    void *float_func;
+    void *half_func;
+  } half_versions[] = {
+      {device::fhypot_ptrptr, hhypot_ptrptr},
+      {device::fadd_ptrptr, hadd_ptrptr},
+      {device::fdiv_ptrptr, hdiv_ptrptr},
+      {device::fmax_ptrptr, hmax_ptrptr},
+      {device::favg3_ptrptr, havg3_ptrptr},
+  };
+  for (const auto &entry : half_versions)
+    if (known_float_func == entry.float_func)
+      return entry.half_func;
+  return nullptr;
+}
-- 
GitLab