From 1bd470b1e9ac391aecf6224071875ec290fea0e1 Mon Sep 17 00:00:00 2001
From: Yasmin Sarita <ysarita2@tyler.cs.illinois.edu>
Date: Sat, 27 Jul 2019 11:04:20 -0500
Subject: [PATCH] fp16 depthwise conv
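
Add an FP16 (__half) variant of the depthwise convolution kernel,
depthwise_convNew8_half, and use it from tensorHalfConvCutlass: inputs and
filters are converted to half precision on the GPU, the custom kernel handles
the grouped (depthwise) case, cuDNN with CUDNN_DATA_HALF compute handles
regular convolutions, and the result is converted back to FP32. Also renames
the profiled convolution events ("tensorConv" -> "Conv"), always uses a block
size of 128 for the FP32 depthwise kernel, and changes its dispatch condition
from conv_groups > 1 to conv_groups > 32.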

---
 .../include/approx_techniques.h               | 295 +++++++++++++++++-
 1 file changed, 286 insertions(+), 9 deletions(-)

diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques.h
index 6701760584..e81a78860b 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques.h
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques.h
@@ -317,6 +317,91 @@ __global__ void depthwise_convNew8(float* const __restrict__ y,
        #undef x4d
 }
 
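+// Half-precision variant of depthwise_convNew8. Each thread computes one
+// output pixel of one channel and accumulates the depthwise convolution for
+// up to 8 consecutive batch images using half-precision FMAs (__hfma).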
+__global__ void depthwise_convNew8_half(__half* const __restrict__ y,
+	const __half* const __restrict__ x,
+	const __half* const __restrict__ w,
+	const int B, const int M,
+	const int H, const int W, const int KH,
+	const int KW, const int H_out, const int W_out,
+	const int H_pad, const int W_pad,
+	const int H_stride, const int W_stride)
+{
+
+        #define y4d(i3, i2, i1, i0) y[(i3) * (M * H_out * W_out) + (i2) * (H_out * W_out) + (i1) * (W_out) + i0]
+        #define x4d(i3, i2, i1, i0) x[(i3) * (M * H * W) + (i2) * (H * W) + (i1) * (W) + i0]
+
+	const int num = 8;
+
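+	// blockIdx.x selects a group of 8 batch images; the flattened
+	// (blockIdx.y, threadIdx.x) index yields the channel m and, inside
+	// the guard below, the output pixel tx within that channel.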
+	const int b = num * blockIdx.x;
+	const int m = (blockIdx.y * blockDim.x  + threadIdx.x)/ (H_out * W_out);
+	
+	if(m < M){
+	const int tx = (blockIdx.y * blockDim.x  + threadIdx.x) % (H_out * W_out);
+
+	const int start_h = (tx / W_out) * H_stride - H_pad;
+	const int start_w = (tx % W_out) * W_stride - W_pad;
+
+	__half c0 = 0;
+	__half c1 = 0;
+	__half c2 = 0;
+	__half c3 = 0;
+	__half c4 = 0;
+	__half c5 = 0;
+	__half c6 = 0;
+	__half c7 = 0;
+	
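+	// Depthwise: each channel m has its own KH x KW filter.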
+	const __half* weights = &w[m * KH * KW];
+
+	for (int k = 0; k < KH * KW; k++) {
+		int p = k / KW;
+		int q = k % KW;
+
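+		// Skip filter taps that fall in the zero-padding border.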
+		if (start_h + p > -1 && start_h + p < H &&
+				start_w + q > -1 && start_w + q < W) {
+
+		  c0 = __hfma(x4d(b, m, start_h + p, start_w + q), weights[k], c0);
+		  if(b + 1 < B)
+		    c1 = __hfma(x4d(b + 1, m, start_h + p, start_w + q), weights[k], c1);
+		  if(b + 2 < B)
+		    c2 = __hfma(x4d(b + 2, m, start_h + p, start_w + q), weights[k], c2);
+		  if(b + 3 < B)
+		    c3 = __hfma(x4d(b + 3, m, start_h + p, start_w + q), weights[k], c3);
+		  if(b + 4 < B)
+		    c4 = __hfma(x4d(b + 4, m, start_h + p, start_w + q), weights[k], c4);
+		  if(b + 5 < B)
+		    c5 = __hfma(x4d(b + 5, m, start_h + p, start_w + q), weights[k], c5);
+		  if(b + 6 < B)
+		    c6 = __hfma(x4d(b + 6, m, start_h + p, start_w + q), weights[k], c6);
+		  if(b + 7 < B)
+		    c7 = __hfma(x4d(b + 7, m, start_h + p, start_w + q), weights[k], c7);
+		}
+	}
+
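+	// Write the accumulated outputs, guarding the final partial group of images.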
+	y4d(b, m, 0, tx) = c0;
+	if(b + 1 < B)
+	  y4d(b + 1, m, 0, tx) = c1;
+	if(b + 2 < B)
+	  y4d(b + 2, m, 0, tx) = c2;
+	if(b + 3 < B)
+	  y4d(b + 3, m, 0, tx) = c3;
+	if(b + 4 < B)
+	  y4d(b + 4, m, 0, tx) = c4;
+	if(b + 5 < B)
+	  y4d(b + 5, m, 0, tx) = c5;
+	if(b + 6 < B)
+	  y4d(b + 6, m, 0, tx) = c6;
+	if(b + 7 < B)
+	  y4d(b + 7, m, 0, tx) = c7;
+	}
+
+        #undef y4d
+        #undef x4d
+}
+
+
+
 __global__ void depthwise_convNew12(float* const __restrict__ y,
 	const float* const __restrict__ x,
 	const float* const __restrict__ w,
@@ -430,7 +515,7 @@ void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
 	llvm_hpvm_initTensorRt(0);
 
 	INFO("*** TensorConvolution \n");
-	profileEvent("tensorConv");
+	profileEvent("Conv");
 
 	Tensor* input = (Tensor*)input_ptr;
 	Tensor* filter = (Tensor*)filter_ptr;
@@ -442,8 +527,8 @@ void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
 
 	Tensor* output;
 
-	//if(conv_groups < 0) {
-	if (conv_groups > 1) {
+	
+	if (conv_groups > 32) {
 		// TODO: Support other cases;  
 		hostToDeviceCopy(input);
 		hostToDeviceCopy(filter);
@@ -493,10 +578,7 @@ void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
 		*/
 		
 		int blockSize;
-		if(h * w > 1023)
-		  blockSize = 256;
-		else
-		  blockSize = 128;
+		blockSize = 128;
 		
 		dim3 grid(((n + 7)/ 8), (c * h * w + blockSize - 1)/ blockSize);
 		dim3 block(blockSize);
@@ -617,7 +699,7 @@ void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
 	}
 
 	cudaDeviceSynchronize();
-	profileEvent("tensorConv_end", true);
+	profileEvent("Conv_end", true);
 
 
         #ifdef ERROR_INJECTION_ENABLED
@@ -646,11 +728,206 @@ void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
 
 }
 
-void* tensorHalfConvCutlass(void* input, void* filter,
+void* tensorHalfConvCutlass(void* input_ptr, void* filter_ptr,
 			    int vertical_pad, int horizontal_pad,
 			    int vertical_stride, int horizontal_stride,
 			    int conv_mode, int conv_groups){
 
+  INFO("*** TensorHConvolution \n");
+  profileEvent("#Conv");
+
+  Tensor* input = (Tensor*) input_ptr;
+  Tensor* filter = (Tensor*) filter_ptr;
+
+  cudnnConvolutionDescriptor_t convDesc;
+  cudnnConvolutionFwdAlgo_t convAlgo;
+  cudnnConvolutionMode_t mode;
+  if(conv_mode == 0)
+    mode = CUDNN_CONVOLUTION;
+  else if(conv_mode == 1)
+    mode = CUDNN_CROSS_CORRELATION;
+
+  // FIXIT: Need to be more aware of the implications of alpha and beta
+  float alpha = 1.0f, beta = 0.0f;
+  // NOTE: compute in half precision
+  cudnnDataType_t computeType = CUDNN_DATA_HALF;
+
+  // NOTE: Moving inputs to GPU global memory
+  hostToDeviceCopy(input);
+  hostToDeviceCopy(filter);
+
+
+  /***** CONVERSIONS from FP32 to FP16 - on the GPU */
+  size_t* input_dims = input->dims.dim_sizes;
+  size_t* filter_dims = filter->dims.dim_sizes;
+
+
+  profileEvent("F2H_start");
+
+  Tensor* input_half = (Tensor*) create4DTensor(CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW,
+						input_dims[0], input_dims[1],
+						input_dims[2], input_dims[3]);
+
+
+  changeTensorPlacement(input_half, DEVICE);
+  Tensor* filter_half = (Tensor*) create4DTensor(CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW,
+						 filter_dims[0], filter_dims[1],
+						 filter_dims[2], filter_dims[3]);
+
+  
+  changeTensorPlacement(filter_half, DEVICE);
+
+
+  f2h((float*) input->gpu_data, input->num_elems, (half*) input_half->gpu_data);
+  f2h((float*) filter->gpu_data, filter->num_elems, (half*) filter_half->gpu_data);
+
+
+  /******* END OF INPUT DATA CONVERSIONS*/
+  profileEvent("F2H_end");
+
+  Tensor *output;
+  Tensor *output_half;
+  
+
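+  // Grouped (depthwise) convolution: launch the hand-written FP16 kernel.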
+  if(conv_groups > 1){
+    int n = input->dims.dim_sizes[0];
+    int c = input->dims.dim_sizes[1];
+    const int KH = filter->dims.dim_sizes[2];
+    const int KW = filter->dims.dim_sizes[3];
+    int h = (2 * vertical_pad + input->dims.dim_sizes[2] - KH) / vertical_stride + 1;
+    int w = (2 * horizontal_pad + input->dims.dim_sizes[3] - KW) / horizontal_stride + 1;
+
+    
+    DEBUG("**Output Tensor Dims, n = %d, c = %d, h = %d, w = %d \n", n, c, h, w);
+
+
+    output = (Tensor*) create4DTensor((cudnnDataType_t) input->data_type,
+				      CUDNN_TENSOR_NCHW, n, c, h, w);
+    // FIXIT: more checks for data types needed
+    output_half = (Tensor*) create4DTensor(CUDNN_DATA_HALF,
+					   CUDNN_TENSOR_NCHW, n, c, h, w);
+
+
+  
+    // NOTE: Changing output tensor placement from host to device
+    changeTensorPlacement(output, DEVICE);
+    // NOTE: Necessary to insert the above call for every output tensor
+
+    int blockSize;
+    blockSize = 128;
+		
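+    // grid.x covers the batch in groups of 8 images; grid.y covers all
+    // channel x output-pixel positions at blockSize threads per block.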
+    dim3 grid(((n + 7)/ 8), (c * h * w + blockSize - 1)/ blockSize);
+    dim3 block(blockSize);
+    depthwise_convNew8_half<<<grid, block>>> ((__half*)output_half->gpu_data,
+			(__half*)input_half->gpu_data, (__half*)filter_half->gpu_data,
+			input->dims.dim_sizes[0], input->dims.dim_sizes[1], input->dims.dim_sizes[2], input->dims.dim_sizes[3],
+			KH, KW, h, w, vertical_pad, horizontal_pad, vertical_stride, horizontal_stride);
+    cudaDeviceSynchronize();
+
+    
+  }
+  else{
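+  // Regular convolution: use cuDNN with half-precision compute.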
+  checkCUDNN(cudnnCreateConvolutionDescriptor(&convDesc));
+
+  //FIXME: Current hack to preserve backward compatibilty
+  if(conv_groups == 0){
+    conv_groups = 1;
+  }
+  
+  // NOTE: Adding support for grouped convolution
+  checkCUDNN(cudnnSetConvolutionGroupCount(convDesc, conv_groups));
+
+  
+  // FIXIT: Think if upscaling values need to be configurable?
+  // IMP-FIXIT:  CUDNN Cross correlation is only used in the Lenet context
+  // IMP-FIXIT: Either make mode configurable OR see if CUDNN_CONVOLUTION MODE should be used?
+  checkCUDNN(cudnnSetConvolution2dDescriptor(convDesc,
+					     vertical_pad, horizontal_pad, // conv padding
+					     vertical_stride, horizontal_stride, // conv strides
+					     1, 1, // upscaling values
+					     mode, // mode is configurable
+					     computeType)); // defines compute precision
+
+  int n, c, h, w; // output dimensions
+  // Find dimension of convolution output
+  checkCUDNN(cudnnGetConvolution2dForwardOutputDim(convDesc,
+						   input->tensor_desc,
+						   filter->filter_desc,
+						   &n, &c, &h, &w));
+  DEBUG("**Output Tensor Dims, n = %d, c = %d, h = %d, w = %d \n", n, c, h, w);
+
+
+  output = (Tensor*) create4DTensor((cudnnDataType_t) input->data_type,
+					    CUDNN_TENSOR_NCHW, n, c, h, w);
+  // FIXIT: more checks for data types needed
+  output_half = (Tensor*) create4DTensor(CUDNN_DATA_HALF,
+						 CUDNN_TENSOR_NCHW, n, c, h, w);
+
+
+  
+  // NOTE: Changing output tensor placement from host to device
+  changeTensorPlacement(output, DEVICE);
+  // NOTE: Necessary to insert the above call for every output tensor
+
+  DEBUG("tensor->data_type = %d, tensor->data_format = %d, N = %d, H = %d, W = %d, C = %d \n",
+	output->data_type, output->data_format, output->dims.dim_sizes[0], output->dims.dim_sizes[1],
+	output->dims.dim_sizes[2], output->dims.dim_sizes[3]);
+
+  if(convDesc == NULL || input->tensor_desc == NULL ||
+     filter->filter_desc == NULL || output->tensor_desc == NULL)
+    ERROR("NULL descriptor! \n");
+
+
+  // NOTE: The following algo works with TRUE half precision
+  convAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+  //convAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+
+  
+  size_t workspace_size;
+  checkCUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle,
+						     input_half->tensor_desc,
+						     filter_half->filter_desc,
+						     convDesc,
+						     output_half->tensor_desc,
+						     convAlgo,
+						     &workspace_size));
+
+  // Allocating memory for the convolution workspace
+  DEBUG("workspace size = %zu \n", workspace_size);
+  void* workspace;
+  checkCudaErrors(cudaMalloc(&workspace, workspace_size));
+
+
+
+
+  checkCUDNN(cudnnConvolutionForward(cudnnHandle,
+				     &alpha,
+				     input_half->tensor_desc,
+				     input_half->gpu_data,
+				     filter_half->filter_desc,
+				     filter_half->gpu_data,
+				     convDesc, convAlgo, workspace, workspace_size,
+				     &beta,
+				     output_half->tensor_desc,
+				     output_half->gpu_data));
+
+  // Release the cuDNN workspace and convolution descriptor created for this call
+  checkCudaErrors(cudaFree(workspace));
+  checkCUDNN(cudnnDestroyConvolutionDescriptor(convDesc));
+  }
+  profileEvent("H2F_start");
+
+  // NOTE: Transforming half precision output to single precision
+  h2f((half*) output_half->gpu_data, output->num_elems, (float*) output->gpu_data);
+
+  profileEvent("H2F_end");
+
+  profileEvent("#Conv_end");
+
+
+  freeTensor(input_half);
+  freeTensor(filter_half);
+  freeTensor(output_half);
+
+  return output;
+
 }
 
 
-- 
GitLab