diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques2.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques2.h
index a81ffe296233178126555bbb53babdcd4192a7bf..0a741316682324ca6270aab2066ebc9f0b48bcdf 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques2.h
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/approx_techniques2.h
@@ -2,6 +2,36 @@
 #include "tensor_utils.cu"
 
 
+//produces N COL MAJOR matrixes with H_out*W_out rows and reduced_filter_elem cols
+__global__ void convToGemmApproxHalf(__half * const __restrict__ output,
+				     const __half * const __restrict input, const int N, const int C,
+				     const int H, const int W, const int KH, const int KW, const int V_pad,
+				     const int H_pad, const int H_out, const int W_out, const int V_stride,
+				     const int H_stride, const int reduced_filter_elem,
+				     const int skip_every) {
+  const int tx = blockDim.x * blockIdx.x + threadIdx.x; //thread id
+  const int n = tx / (C * H_out * W_out); //output image number
+  const int c = tx % (C * H_out * W_out) / (H_out * W_out); //output chan number
+  const int h = tx % (H_out * W_out) / W_out; //output height index (row number)
+  const int w = tx % W_out; //output width index (col number)
+  const int inH = h * V_stride - V_pad; //input height index (row number)
+  const int inW = w * H_stride - H_pad; //input width index (col number)
+  if(n < N) { //is thread id within bounds?
+    for(int i = 0; i < KH; i++) {
+      for(int j = 0; j < KW; j++) {
+	const int filter_elem_num = (c * KH + i) * KW + j; //index of this filter element
+	if(filter_elem_num % skip_every != skip_every-1) { //are we including this filter element?
+	  const int output_col = filter_elem_num - (filter_elem_num/skip_every); //calculate output column, taking skipping into account
+	  if(inH + i >= 0 && inH + i < H && inW + j >= 0 && inW + j < W)
+	    output[((n * reduced_filter_elem + output_col) * H_out + h) * W_out + w] = input[((n * C + c) * H + (inH + i)) * W + (inW + j)];
+	  else
+	    output[((n * reduced_filter_elem + output_col) * H_out + h) * W_out + w] = 0;
+	}
+      }
+    }
+  }
+}
+
 
 //This skips every xth row
 //H_eff is the number of rows calculated exactly
@@ -350,3 +380,477 @@ void* tensorConvPerfCuda(void* input_ptr, void* filter_ptr,
   
   return new_output;
 }
+
+__global__
+void convToGemmPerfRowHalf(__half * const __restrict__ output,
+			   const __half * const __restrict input, const int N, const int C,
+			   const int H, const int W, const int KH, const int KW, const int V_pad,
+			   const int H_pad, const int H_out, const int W_out, const int V_stride,
+			   const int H_stride, const int x, const int start, const int H_eff){
+
+  const int tx = blockDim.x * blockIdx.x + threadIdx.x; //thread id
+  const int n = tx / (C * H_eff * W_out); //output image number
+  const int c = tx % (C * H_eff * W_out) / (H_eff * W_out); //output chan number
+  const int h = tx % (H_eff * W_out) / W_out; //output height index (row number)
+  const int w = tx % W_out; //output width index (col number)
+  int past_start = (h % (x - 1) >= (x - 1 - start));
+  const int inH = (h / (x - 1) * x + h % (x-1) +
+		   past_start) * V_stride - V_pad; //input height index (row number)
+  const int inW = w * H_stride - H_pad; //input width index (col number)
+  if(n < N) { //is thread id within bounds?
+    for(int i = 0; i < KH; i++) {
+      for(int j = 0; j < KW; j++) {
+	const int filter_elem_num = (c * KH + i) * KW + j; //index of this filter element
+
+	if(inH + i >= 0 && inH + i < H && inW + j >= 0 && inW + j < W)
+	  output[((filter_elem_num * N + n) * H_eff + h) * W_out + w] =
+	    input[((n * C + c) * H + (inH + i)) * W + (inW + j)];
+	else
+	  output[((filter_elem_num * N + n) * H_eff + h) * W_out + w] = 0;
+
+      }
+    }
+  }
+
+}
+
+
+//For use in tensorConvPerfCuda
+//Interpolates every xth row starting from x - 1 - start
+//N is total number of elements in final output array
+__global__
+void approxInterpolateRowHalf(int N, int old_h, int b, int c, int h, int w,
+			      __half *old_data, __half *new_data, int x, int start){
+
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+
+  for(int i = index; i < N; i += stride){
+    int col = ((i % (c * h * w)) % (h * w)) % w;
+    int row = ((i % (c * h * w)) % (h * w)) / w;
+    int ch = (i % (c * h * w)) / (h * w);
+    int n = i / (c * h * w);
+    int past_start = ((row % x) >= (x - 1 - start));
+
+    if(row == h-1)
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	old_data[ch * (b * old_h * w) + n * (old_h * w) + (old_h - 1) * (w) + col];
+    else if (row == 0)
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	old_data[ch * (b * old_h * w) + n * (old_h * w) + 0 * (w) + col];
+    else if(row % x == x - 1 - start){
+      int past_startO = ((row - 1) % x) > (x - 1 - start);
+      int oldIdx1 = ch * (b * old_h * w) + n * (old_h * w) +
+	((x-1) * ((row - 1) / x) + (row-1) % x - past_startO) * (w) + col;
+
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	__hdiv(__hadd(old_data[oldIdx1], old_data[oldIdx1 + 1 * w]), 2);
+    }
+    else
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	old_data[ch * (b * old_h * w) + n * (old_h * w) +
+		 ((x-1) * (row / x) + row % x - past_start )  * (w) + col];
+
+
+  }
+
+}
+
+
+//This skips every xth row
+//W_eff is the number of cols calculated exactly
+__global__
+void convToGemmPerfColHalf(__half * const __restrict__ output,
+			   const __half * const __restrict input, const int N, const int C,
+			   const int H, const int W, const int KH, const int KW, const int V_pad,
+			   const int H_pad, const int H_out, const int W_out, const int V_stride,
+			   const int H_stride, const int x, const int start, const int W_eff){
+
+  const int tx = blockDim.x * blockIdx.x + threadIdx.x; //thread id
+  const int n = tx / (C * H_out * W_eff); //output image number
+  const int c = tx % (C * H_out * W_eff) / (H_out * W_eff); //output chan number
+  const int h = tx % (H_out * W_eff) / W_eff; //output height index (row number)
+  const int w = tx % W_eff; //output width index (col number)
+  int past_start = (w % (x - 1)) >= (x - 1 - start);
+  const int inH = h * V_stride - V_pad; //input height index (row number)
+  const int inW = (w / (x - 1) * x + w % (x-1) +
+		   past_start) * H_stride - H_pad; //input width index (col number)
+  if(n < N) { //is thread id within bounds?
+    for(int i = 0; i < KH; i++) {
+      for(int j = 0; j < KW; j++) {
+	const int filter_elem_num = (c * KH + i) * KW + j; //index of this filter element
+
+	if(inH + i >= 0 && inH + i < H && inW + j >= 0 && inW + j < W)
+	  output[((filter_elem_num * N + n) * H_out + h) * W_eff + w] =
+	    input[((n * C + c) * H + (inH + i)) * W + (inW + j)];
+	else
+	  output[((filter_elem_num * N + n) * H_out + h) * W_eff + w] = 0;
+
+      }
+    }
+  }
+
+}
+
+
+//For use in tensorConvPerfCuda
+//Interpolates every xth col starting from x - 1 - start
+//N is total number of elements in final output array
+__global__
+void approxInterpolateColHalf(int N, int old_w, int b, int c, int h, int w,
+			      __half *old_data, __half *new_data, int x, int start){
+
+
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+
+  for(int i = index; i < N; i += stride){
+    int col = ((i % (c * h * w)) % (h * w)) % w;
+    int row = ((i % (c * h * w)) % (h * w)) / w;
+    int ch = (i % (c * h * w)) / (h * w);
+    int n = i / (c * h * w);
+    int past_start = ((col % x) >= (x - 1 - start));
+
+    if(col == w-1)
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	old_data[ch * (b * h * old_w) + n * (h * old_w) + row * (old_w) + old_w - 1];
+    else if (col == 0)
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	old_data[ch * (b * h * old_w) + n * (h * old_w) + row * (old_w)];
+    else if(col % x == x - 1 - start){
+      int past_startO = ((col - 1) % x) > (x - 1 - start);
+      int oldIdx1 = ch * (b * h * old_w) + n * (h * old_w) + row * old_w +
+	((x-1) * ((col - 1) / x) + (col-1) % x - past_startO);
+
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	__hdiv(__hadd(old_data[oldIdx1], old_data[oldIdx1 + 1]), 2);
+    }
+    else
+      new_data[n * (c * h * w) + ch * (h * w) + row * (w) + col] =
+	old_data[ch * (b * h * old_w) + n * (h * old_w) + row * old_w +
+		 ((x-1) * (col / x) + col % x - past_start)];
+
+  } 
+}
+
+__global__
+void switchMatrix(int N, int n, int c, int h, int w, __half *old_data, __half *new_data){
+
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if(i < N){
+    int col = ((i % (c * h * w)) % (h * w)) % w;
+    int row = ((i % (c * h * w)) % (h * w)) / w;
+    int ch = (i % (c * h * w)) / (h * w);
+    int n_new = i / (c * h * w);
+
+    new_data[((n_new * c + ch) * h + row ) * w + col] =
+      old_data[((ch * n + n_new) * h + row ) * w + col];
+  }
+
+}
+						
+
+__global__
+void createNewFilter(__half *new_filter, __half *old_filter,
+		     int newFilterSize, int oldFilterSize){
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+
+  for(int i = index; i < newFilterSize; i += stride){
+    new_filter[i] = old_filter[i % oldFilterSize];
+  }
+}
+
+__global__
+void createBatches(int n, const __half * matA[], const __half * matB[], __half * matC[],
+		   __half * convData, __half * newFilter, __half * output,
+		   int aStride, int bStride, int cStride){
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+
+  for(int i = index; i < n; i += stride){
+    matA[i] = &convData[i * aStride];
+    matB[i] = &newFilter[i * bStride];
+    matC[i] = &output[i * cStride];
+  }
+}
+
+//produces N COL MAJOR matrixes with H_out*W_out rows and reduced_filter_elem cols
+__global__ void convToGemmApproxHalfN(__half * const __restrict__ output,
+				     const __half * const __restrict input, const int N, const int C,
+				     const int H, const int W, const int KH, const int KW, const int V_pad,
+				     const int H_pad, const int H_out, const int W_out, const int V_stride,
+				     const int H_stride, const int reduced_filter_elem,
+				     const int skip_every) {
+  const int tx = blockDim.x * blockIdx.x + threadIdx.x; //thread id
+  const int n = tx / (C * H_out * W_out); //output image number
+  const int c = tx % (C * H_out * W_out) / (H_out * W_out); //output chan number
+  const int h = tx % (H_out * W_out) / W_out; //output height index (row number)
+  const int w = tx % W_out; //output width index (col number)
+  const int inH = h * V_stride - V_pad; //input height index (row number)
+  const int inW = w * H_stride - H_pad; //input width index (col number)
+  if(n < N) { //is thread id within bounds?
+    for(int i = 0; i < KH; i++) {
+      for(int j = 0; j < KW; j++) {
+	const int filter_elem_num = (c * KH + i) * KW + j; //index of this filter element
+	const int output_col = filter_elem_num; //calculate output column, taking skipping into account
+	if(inH + i >= 0 && inH + i < H && inW + j >= 0 && inW + j < W)
+	  output[((output_col * N + n) * H_out + h) * W_out + w] =
+	    input[((n * C + c) * H + (inH + i)) * W + (inW + j)];
+	else
+	  output[((output_col * N + n) * H_out + h) * W_out + w] = 0;
+
+      }
+    }
+  }
+}
+
+//start has to be less than row or less than col
+//row and col have to be >= 0
+//row = col = 1 means no perforation
+void* tensorConvPerfCudaHalf(void* input_ptr, void* filter_ptr,
+			     int vertical_pad, int horizontal_pad, int vertical_stride,
+			     int horizontal_stride, int conv_mode, int conv_groups,
+			     int row, int col, int start){
+
+  INFO("*** TensorConvolution half perforation \n");
+
+  Tensor* input = (Tensor*)input_ptr;
+  Tensor* filter = (Tensor*)filter_ptr;
+  //FIXME: Current hack to preserve backward compatibilty
+  if (conv_groups == 0) {
+    conv_groups = 1;
+  }
+
+  profileEvent("F2H_start");
+
+  hostToDeviceCopy(input);
+  hostToDeviceCopy(filter);
+
+  convertToFP16(input);
+  convertToFP16(filter);
+
+  /******* END OF INPUT DATA CONVERSIONS*/
+  profileEvent("F2H_end");
+
+  profileEvent("Conv");
+
+  Tensor* output_half;
+  int n, c, h, w; // output dimensions
+  n = input->dims.dim_sizes[0];
+  c = filter->dims.dim_sizes[0]; //number of filters
+  const int KH = filter->dims.dim_sizes[2];
+  const int KW = filter->dims.dim_sizes[3];
+
+  h = (2 * vertical_pad + input->dims.dim_sizes[2] - KH) / vertical_stride + 1;
+  int h_eff = h - h / row;
+  if(h % row > row - 1 - start)
+    h_eff = h_eff - 1;
+
+  w = (2 * horizontal_pad + input->dims.dim_sizes[3] - KW) / horizontal_stride + 1;
+  int w_eff = w - w / col;
+  if(w % col > col - 1 - start)
+    w_eff = w_eff - 1;
+
+
+  Tensor *new_output;
+  if(row > 1){
+    output_half = (Tensor*)create4DTensor((cudnnDataType_t) half_type, CUDNN_TENSOR_NCHW,
+					  n, c, h_eff, w);
+
+    // NOTE: Changing output tensor placement from host to device
+    changeTensorPlacement(output_half, DEVICE);
+    // NOTE: Necessary to insert the above call for every output tensor
+    //total number of filter elem
+    const int num_filter_elem = KH * KW * input->dims.dim_sizes[1];
+
+    __half * convData;
+    int convDataSize = sizeof(__half) * n * num_filter_elem * h_eff * w;
+    checkCudaErrors(cudaMalloc(&convData, convDataSize));
+
+    const int blockSize = 256;
+    const int gridSize = (n * input->dims.dim_sizes[1] * h_eff * w + blockSize - 1) / blockSize;
+
+    convToGemmPerfRowHalf<<<gridSize, blockSize>>>(convData, (__half *)input->gpu_half_data, n,
+						   input->dims.dim_sizes[1], input->dims.dim_sizes[2],
+						   input->dims.dim_sizes[3], KH, KW, vertical_pad,
+						   horizontal_pad, h, w,
+						   vertical_stride, horizontal_stride, row, start, h_eff);
+
+
+    checkCudaErrors(cudaDeviceSynchronize());
+
+    const __half alf = approx_float_to_half(1.0);
+    const __half bet = approx_float_to_half(0.0);
+    const __half *alpha_half = &alf;
+    const __half *beta_half = &bet;
+
+    checkCudaErrors(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
+				 n * h_eff * w, c, num_filter_elem,
+				 alpha_half,
+				 convData, CUDA_R_16F, n * h_eff * w,
+				 (__half*) filter->gpu_half_data, CUDA_R_16F, num_filter_elem,
+				 beta_half,
+				 (__half*) output_half->gpu_half_data, CUDA_R_16F, n * h_eff * w,
+				 CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP) );
+
+    
+    new_output = (Tensor*)create4DTensor((cudnnDataType_t) half_type,
+					 CUDNN_TENSOR_NCHW, n, c, h, w);
+
+    // NOTE: Changing output tensor placement from host to device
+    changeTensorPlacement(new_output, DEVICE);
+
+    //interpolate
+    int numBlocks = (n * c * h * w  + 255) / 256;
+    approxInterpolateRowHalf<<<numBlocks,256>>>(n * c * h * w, h_eff, n, c, h, w,
+						(__half *)output_half->gpu_half_data,
+						(__half *)new_output->gpu_half_data,
+						row, start);
+    cudaDeviceSynchronize();
+
+    cudaFree(output_half);
+    cudaFree(convData);
+  }
+  else if(col > 1){
+    output_half = (Tensor*)create4DTensor((cudnnDataType_t) half_type,
+					  CUDNN_TENSOR_NCHW, n, c, h, w_eff);
+
+    // NOTE: Changing output tensor placement from host to device
+    changeTensorPlacement(output_half, DEVICE);
+    // NOTE: Necessary to insert the above call for every output tensor
+    //total number of filter elem
+    const int num_filter_elem = KH * KW * input->dims.dim_sizes[1];
+
+    __half * convData;
+    int convDataSize = sizeof(__half) * n * num_filter_elem * h * w_eff;
+    checkCudaErrors(cudaMalloc(&convData, convDataSize));
+
+    const int blockSize = 256;
+    const int gridSize = (n * input->dims.dim_sizes[1] * h * w_eff + blockSize - 1) / blockSize;
+
+    convToGemmPerfColHalf<<<gridSize, blockSize>>>(convData, (__half *)input->gpu_half_data, n,
+						   input->dims.dim_sizes[1], input->dims.dim_sizes[2],
+						   input->dims.dim_sizes[3], KH, KW, vertical_pad,
+						   horizontal_pad, h, w,
+						   vertical_stride, horizontal_stride, col, start, w_eff);
+
+
+    checkCudaErrors(cudaDeviceSynchronize());
+
+    const __half alf = approx_float_to_half(1.0);
+    const __half bet = approx_float_to_half(0.0);
+    const __half *alpha_half = &alf;
+    const __half *beta_half = &bet;
+
+    
+    checkCudaErrors(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
+				 n * h * w_eff, c, num_filter_elem,
+				 alpha_half,
+				 convData, CUDA_R_16F, n * h * w_eff,
+				 (__half*) filter->gpu_half_data, CUDA_R_16F, num_filter_elem,
+				 beta_half,
+				 (__half*) output_half->gpu_half_data, CUDA_R_16F, n * h * w_eff,
+				 CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP) );
+
+    
+    new_output = (Tensor*)create4DTensor((cudnnDataType_t) half_type,
+					 CUDNN_TENSOR_NCHW, n, c, h, w);
+
+    // NOTE: Changing output tensor placement from host to device
+    changeTensorPlacement(new_output, DEVICE);
+
+    //interpolate
+    int numBlocks = (n * c * h * w  + 255) / 256;
+    approxInterpolateColHalf<<<numBlocks,256>>>(n * c * h * w, w_eff, n, c, h, w,
+						(__half *)output_half->gpu_half_data,
+						(__half *)new_output->gpu_half_data,
+						col, start);
+    
+    cudaDeviceSynchronize();
+
+    cudaFree(output_half);
+    cudaFree(convData);
+
+  }
+  else{
+    output_half = (Tensor*)create4DTensor((cudnnDataType_t) half_type,
+					  CUDNN_TENSOR_NCHW, c, n, h, w);
+
+    // NOTE: Changing output tensor placement from host to device
+    changeTensorPlacement(output_half, DEVICE);
+    // NOTE: Necessary to insert the above call for every output tensor
+    //total number of filter elem
+    const int num_filter_elem = KH * KW * input->dims.dim_sizes[1];
+
+    __half * convData;
+    int convDataSize = sizeof(__half) * n * num_filter_elem * h * w;
+    checkCudaErrors(cudaMalloc(&convData, convDataSize));
+
+    const int blockSize = 256;
+    const int gridSize = (n * input->dims.dim_sizes[1] * h * w + blockSize - 1) / blockSize;
+    convToGemmApproxHalfN<<<gridSize, blockSize>>>(convData, (__half *)input->gpu_half_data, n,
+						  input->dims.dim_sizes[1], input->dims.dim_sizes[2],
+						  input->dims.dim_sizes[3], KH, KW, vertical_pad, horizontal_pad, h, w,
+						  vertical_stride, horizontal_stride, num_filter_elem, c * h * w);
+    checkCudaErrors(cudaDeviceSynchronize());
+    //Do the matrix multiplication. Want to multiply convData by filter->gpu_data[f * chan * KH * KW]
+    const __half alf = approx_float_to_half(1.0);
+    const __half bet = approx_float_to_half(0.0);
+    const __half *alpha_half = &alf;
+    const __half *beta_half = &bet;
+
+    checkCudaErrors(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
+				 n * h * w, c, num_filter_elem,
+				 alpha_half,
+				 convData, CUDA_R_16F, n * h * w,
+				 (__half*) filter->gpu_half_data, CUDA_R_16F, num_filter_elem,
+				 beta_half,
+				 (__half*) output_half->gpu_half_data, CUDA_R_16F, n * h * w,
+				 CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP) );
+
+
+
+    // profileEvent("gemm_end", true);
+    new_output = (Tensor*)create4DTensor((cudnnDataType_t) half_type,
+					  CUDNN_TENSOR_NCHW, n, c, h, w);
+    changeTensorPlacement(new_output, DEVICE);
+
+    
+    int numBlocks = (n * c * h * w  + 255) / 256;
+    switchMatrix<<<numBlocks,256>>>(n * c * h * w, n, c, h, w,
+				    (__half *)output_half->gpu_half_data,
+				    (__half *)new_output->gpu_half_data);
+
+    checkCudaErrors(cudaDeviceSynchronize());
+    
+    cudaFree(convData);
+    cudaFree(output_half);
+  }
+
+  profileEvent("Conv_end", true);
+
+  profileEvent("H2F_start");
+
+  convertToFP32(new_output);
+
+  profileEvent("H2F_end");
+
+
+  #ifdef ERROR_INJECTION_ENABLED
+  if (op_counter >= total_ops) {
+    ERROR("No accuracy flag found \n");
+  }
+  int op_acc = op_accuracies[op_counter];
+  // Skip errorInjection if explicitly requested
+  if (skip_tensors.find(op_counter) != skip_tensors.end()) {
+    op_acc = 0;
+  }
+  void* error_norms = tensorAddError(output, op_acc);
+  add_norms(error_norms, "tensorConv", op_acc);
+  add_conv_overheads(input, filter, vertical_stride, horizontal_stride, op_acc);
+  op_counter++;
+  #endif
+  return new_output;
+}
+