diff --git a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/approx_techniques.cu b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/approx_techniques.cu
index b97e5beadb7822cce12bdf2ee4d16407cd0483c4..546eb8286390fd5ff8e9bb09bd628a03f14b9f38 100644
--- a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/approx_techniques.cu
+++ b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/approx_techniques.cu
@@ -2,10 +2,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-//  This file  consists of the custom implementation of software approximations
-// for tensor convolutions. The approximations implemented are feature sampling
-// and perforation for FP32 and FP16 compute precisions.
+//  This file contains our CUDA-based implementations of convolution approximations
 //
+//  *Supported Approximations: Perforated Convolutions, Filter Sampling
+//
+//  FP32 Convolution Routine:  `tensorConvApprox`
+//  FP16 Convolution Routine:  `tensorConvApproxHalf2`
+// 
+//  NOTE: These approximations are tuned for the NVIDIA Jetson TX2 device
+//
+//  Author: Akash Kothari
 //===----------------------------------------------------------------------===//
 
 #include "tensor_utils.h"
@@ -209,6 +215,7 @@ __global__ void convToGemmHalfInputNewIrregular2(
     const int W_out, const int V_stride, const int H_stride,
     const int reduced_filter_elem, const int skip_every,
     const int skip_offset) {
+  
   const int tx = blockDim.x * blockIdx.x + threadIdx.x; // thread id
   const int n = tx / (C * H_out * W_out);               // output image number
   const int c = tx % (C * H_out * W_out) / (H_out * W_out); // output chan
@@ -1172,10 +1179,9 @@ convToGemmApprox(float *const __restrict__ output,
 }
 
 /// This function serves as an API with the custom implementation of convolution
-/// with the perforation and filter sampling support. The compute precison is
-/// FP32. This routine is invoked by the tuner for tuning approximations for
-/// convolutions.
-///
+/// with the perforation and filter sampling support. The compute precision is FP32.
+/// NOTE: This routine is used only for correctness testing.
+/// NOTE: This is NOT the main approximation routine used by HPVM.
 void *tensorConvPerfCuda(void *input_ptr, void *filter_ptr, int vertical_pad,
                          int horizontal_pad, int vertical_stride,
                          int horizontal_stride, int conv_mode, int conv_groups,
@@ -1331,8 +1337,6 @@ void *tensorConvPerfCuda(void *input_ptr, void *filter_ptr, int vertical_pad,
         vertical_pad, horizontal_pad, h, w, vertical_stride, horizontal_stride,
         num_filter_elem, c * h * w);
     checkCudaErrors(cudaDeviceSynchronize());
-    // Do the matrix multiplication
-    // Want to multiply convData by filter->gpu_data[f * chan * KH * KW]
 
     float alpha = 1.0f, beta = 0.0f;
     checkCudaErrors(cublasSgemmStridedBatched(
@@ -1345,7 +1349,6 @@ void *tensorConvPerfCuda(void *input_ptr, void *filter_ptr, int vertical_pad,
     cudaFree(convData);
   }
 
-  // Event("Conv_end"); //, true);
   return new_output;
 }
 
@@ -1364,18 +1367,22 @@ __global__ void switchMatrixFull(int N, int n, int c, int h, int w,
   }
 }
 
-/// This function serves as an API with the custom implementation of convolution
-/// with the perforation and filter sampling support. The compute precison is
-/// FP32.
+  
+/*************   API for Approximate Convolution Implementations  ************/
+
+///  ** API for FP32 Convolution that supports Baseline (No Approx), Perforation, and Filter Sampling **
+/// - Arguments to control Approximation:
+///    `row`: Controls the fraction of rows skipped (Perforation) - (1/row * 100)% of rows skipped
+///    `col`: Controls the fraction of columns skipped (Perforation) - (1/col * 100)% of columns skipped
+///    `skip_every`: Controls the fraction of filter elements skipped (Filter Sampling) - (1/skip_every * 100)% of filter elements skipped
+///    `offset`: Controls the tensor index at which sampling/perforation starts
 ///
+///   For Baseline (non-approximated) convolution, pass `row = 1`, `col = 1`, `skip_every = 1`
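+///
+///   Example (illustrative sketch, not taken from the test suite; `input` and
+///   `filter` are assumed to be Tensor handles already created by the runtime):
+///
+///     // Perforate 50% of output rows (row = 2), no column perforation (col = 1),
+///     // no filter sampling (skip_every = 1), starting at tensor index 0:
+///     void *out = tensorConvApprox(input, filter,
+///                                  /*vertical_pad=*/1, /*horizontal_pad=*/1,
+///                                  /*vertical_stride=*/1, /*horizontal_stride=*/1,
+///                                  /*conv_mode=*/1, /*conv_groups=*/1,
+///                                  /*row=*/2, /*col=*/1, /*skip_every=*/1, /*offset=*/0);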
 void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
                        int horizontal_pad, int vertical_stride,
                        int horizontal_stride, int conv_mode, int conv_groups,
                        int row, int col, int skip_every, int offset) {
 
-  //////INFO("*** TensorConvolution approximation \n");
-  // Event("Conv");
-
   Tensor *input = (Tensor *)input_ptr;
   Tensor *filter = (Tensor *)filter_ptr;
   // FIXME: Current hack to preserve backward compatibilty
@@ -1386,36 +1393,22 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
   hostToDeviceCopy(input);
   hostToDeviceCopy(filter);
 
-  ////Event("H2F_start");
   convertToFP32(input);
   convertToFP32(filter);
-  ////Event("H2F_end");
 
   const int n = input->dims.dim_sizes[0];
   const int c = filter->dims.dim_sizes[0]; // number of filters
   const int KH = filter->dims.dim_sizes[2];
   const int KW = filter->dims.dim_sizes[3];
-  const int h =
-      (2 * vertical_pad + input->dims.dim_sizes[2] - KH) / vertical_stride + 1;
-  const int w =
-      (2 * horizontal_pad + input->dims.dim_sizes[3] - KW) / horizontal_stride +
-      1;
+  const int h = (2 * vertical_pad + input->dims.dim_sizes[2] - KH) / vertical_stride + 1;
+  const int w = (2 * horizontal_pad + input->dims.dim_sizes[3] - KW) / horizontal_stride + 1;
   const int num_filter_elem = KH * KW * input->dims.dim_sizes[1];
-
+  
   Tensor *new_output = (Tensor *)create4DTensor((cudnnDataType_t)float_type,
                                                 CUDNN_TENSOR_NCHW, n, c, h, w);
   // NOTE: Changing output tensor placement from host to device
   changeTensorPlacement(new_output, DEVICE);
-  ////INFO("batch: %d\n", n);
-  ////INFO("channels: %d\n", input->dims.dim_sizes[1]);
-  ////INFO("num_filters: %d\n", c);
-  ////INFO("kernel height: %d\n", KH);
-  ////INFO("kernel width: %d\n", KW);
-  ////INFO("num_filter_elem: %d\n", num_filter_elem);
-  ////INFO("vertical_stride: %d\n", vertical_stride);
-  ////INFO("horizontal_stride: %d\n", horizontal_stride);
-  ////INFO("output height: %d\n", h);
-  ////INFO("output width: %d\n", w);
+ 
   if (row > 1) {
     const int rem_row = (h - offset) % row > 0;
     const int h_eff = h - ((h - offset) / row) - rem_row;
@@ -1432,8 +1425,6 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
     checkCudaErrors(cudaMalloc(&convData, convDataSize));
 
     const int blockSize = 128;
-    ////INFO("n * input->dims.dim_sizes[1] * h_eff * w: %d\n", (n *
-    /// input->dims.dim_sizes[1] * h_eff * w));
     const int gridSize =
         (n * input->dims.dim_sizes[1] * h_eff * w + blockSize - 1) / blockSize;
     convToGemmPerfRow<<<gridSize, blockSize>>>(
@@ -1464,7 +1455,7 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
     const int w_eff = w - ((w - offset) / col) - rem_col;
 
     Tensor *output = (Tensor *)create4DTensor(
-        (cudnnDataType_t)float_type, // input->data_type,
+        (cudnnDataType_t)float_type, 
         CUDNN_TENSOR_NCHW, n, c, h, w_eff);
 
     // NOTE: Changing output tensor placement from host to device
@@ -1475,8 +1466,6 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
     checkCudaErrors(cudaMalloc(&convData, convDataSize));
 
     const int blockSize = 128;
-    ////INFO("n * input->dims.dim_sizes[1] * h * w_eff: %d\n", (n *
-    /// input->dims.dim_sizes[1] * h * w_eff));
     const int gridSize =
         (n * input->dims.dim_sizes[1] * h * w_eff + blockSize - 1) / blockSize;
 
@@ -1494,7 +1483,7 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
         (float *)filter->gpu_data, num_filter_elem, 0, &beta,
         (float *)output->gpu_data, h * w_eff, c * h * w_eff, n));
 
-    // interpolate
+    // Interpolate
     int blocksize = 128;
     int numBlocks = (n * c * h * w + blocksize - 1) / blocksize;
     approxInterpolateCol<<<numBlocks, blocksize>>>(
@@ -1518,16 +1507,13 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
         cudaMalloc(&reducedFilter, sizeof(float) * c * reduced_filter_elem));
 
     const int filtBlockSize = 128;
-    ////INFO("c * reduced_filter_elem: %d\n", (c * reduced_filter_elem));
-    const int filtGridSize =
-        (c * reduced_filter_elem + filtBlockSize - 1) / filtBlockSize;
+    const int filtGridSize = (c * reduced_filter_elem + filtBlockSize - 1) / filtBlockSize;
     const float fac = ((float)skip_every) / ((float)skip_every - 1);
-    //////INFO("fac: %f\n", fac);
     const int blockSize = 128;
-    //////INFO("n * h * w : %d\n", (n * h * w ));
     const int gridSize = (n * h * w + blockSize - 1) / blockSize;
+    
     if (!(KH * KW % skip_every)) {
-      // ////INFO("REGULAR FILTERING\n");
+
       createReducedFiltersFullRegular<<<filtGridSize, filtBlockSize>>>(
           reducedFilter, (float *)filter->gpu_data, c, num_filter_elem,
           reduced_filter_elem, input->dims.dim_sizes[1], skip_every, offset,
@@ -1538,8 +1524,8 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
           input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
           vertical_pad, horizontal_pad, h, w, vertical_stride,
           horizontal_stride, reduced_filter_elem, skip_every, offset);
-    } else {
-      // ////INFO("IRREGULAR FILTERING\n");
+    }
+    else {
       createReducedFiltersFullIrregular<<<filtGridSize, filtBlockSize>>>(
           reducedFilter, (float *)filter->gpu_data, c, num_filter_elem,
           reduced_filter_elem, skip_every, offset, fac);
@@ -1563,7 +1549,6 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
     cudaFree(reducedFilter);
   } else {
 
-    // INFO("FP32 BASELINE\n");
     Tensor *output = (Tensor *)create4DTensor((cudnnDataType_t)float_type,
                                               CUDNN_TENSOR_NCHW, n, c, h, w);
     changeTensorPlacement(output, DEVICE);
@@ -1575,25 +1560,17 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
     const int blockSize = 128;
     const int gridSize =
         (n * input->dims.dim_sizes[1] * h * w + blockSize - 1) / blockSize;
-    //////INFO("n * input->dims.dim_sizes[1] * h * w: %d\n", (n *
-    /// input->dims.dim_sizes[1] * h * w));
+
     convToGemmFullInput<<<gridSize, blockSize>>>(
         convData, (float *)input->gpu_data, n, input->dims.dim_sizes[1],
         input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
         vertical_pad, horizontal_pad, h, w, vertical_stride, horizontal_stride,
-        skip_every, offset); // num_filter_elem);
+        skip_every, offset); 
+    
     checkCudaErrors(cudaDeviceSynchronize());
 
     float alpha = 1.0f, beta = 0.0f;
-    /*
-    checkCudaErrors(cublasSgemmStridedBatched(cublasHandle,
-                                         CUBLAS_OP_N, CUBLAS_OP_N,
-                                           h * w, c, num_filter_elem,
-                                           &alpha,
-                                           convData, h * w, num_filter_elem * h
-    * w, (float *)filter->gpu_data, num_filter_elem, 0, &beta, (float
-    *)new_output->gpu_data, h * w, c * h * w, n));
-   */
+
     checkCudaErrors(cublasGemmEx(
         cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, n * h * w, c, num_filter_elem,
         &alpha, convData, CUDA_R_32F, n * h * w, (float *)filter->gpu_data,
@@ -1609,7 +1586,6 @@ void *tensorConvApprox(void *input_ptr, void *filter_ptr, int vertical_pad,
     cudaFree(convData);
   }
 
-  // Event("Conv_end");
   return new_output;
 }
 
@@ -1628,21 +1604,27 @@ __global__ void switchMatrixHalf(int N, int n, int c, int h, int w,
   }
 }
 
-/// This function serves as an API to custom implementation of the
-/// half-precision convolution with the perforation and filter sampling
-/// support.
+
+
+
+///  ** API for FP16 Convolution that supports Baseline (No Approx), Perforation, and Filter Sampling **
+/// - Arguments to control Approximation:
+///    `row`: Controls the fraction of rows skipped (Perforation) - (1/row * 100)% of rows skipped
+///    `col`: Controls the fraction of columns skipped (Perforation) - (1/col * 100)% of columns skipped
+///    `skip_every`: Controls the fraction of filter elements skipped (Filter Sampling) - (1/skip_every * 100)% of filter elements skipped
+///    `offset`: Controls the tensor index at which sampling/perforation starts
 ///
+///   For Baseline (non-approximated) convolution, pass `row = 1`, `col = 1`, `skip_every = 1`
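+///
+///   Example (illustrative sketch; `input` and `filter` are assumed to be Tensor
+///   handles already created by the runtime):
+///
+///     // Sample away 25% of filter elements (skip_every = 4), no perforation (row = col = 1):
+///     void *out = tensorConvApproxHalf2(input, filter,
+///                                       /*vertical_pad=*/1, /*horizontal_pad=*/1,
+///                                       /*vertical_stride=*/1, /*horizontal_stride=*/1,
+///                                       /*conv_mode=*/1, /*conv_groups=*/1,
+///                                       /*row=*/1, /*col=*/1, /*skip_every=*/4, /*offset=*/0);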
 void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
                             int horizontal_pad, int vertical_stride,
                             int horizontal_stride, int conv_mode,
                             int conv_groups, int row, int col, int skip_every,
                             int offset) {
 
-  // INFO("*** TensorConvolution half approximation \n");
-  // profileEvent("#Conv");
-
+ 
   Tensor *input = (Tensor *)input_ptr;
   Tensor *filter = (Tensor *)filter_ptr;
+
   // FIXME: Current hack to preserve backward compatibilty
   if (conv_groups == 0) {
     conv_groups = 1;
@@ -1670,18 +1652,7 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
   Tensor *new_output = (Tensor *)create4DTensor((cudnnDataType_t)half_type,
                                                 CUDNN_TENSOR_NCHW, n, c, h, w);
   changeTensorPlacement(new_output, DEVICE);
-  // INFO("batch: %d\n", n);
-  // INFO("channels: %d\n", input->dims.dim_sizes[1]);
-  // INFO("num_filters: %d\n", c);
-  // INFO("kernel height: %d\n", KH);
-  // INFO("kernel width: %d\n", KW);
-  // INFO("num_filter_elem: %d\n", num_filter_elem);
-  // INFO("num_filters * num_filter_elem: %d\n", c * num_filter_elem);
-  // INFO("vertical_stride: %d\n", vertical_stride);
-  // INFO("horizontal_stride: %d\n", horizontal_stride);
-  // INFO("output height: %d\n", h);
-  // INFO("output width: %d\n", w);
-  // INFO("skip_every: %d\n", skip_every);
+
   const __half alf = approx_float_to_half(1.0);
   const __half bet = approx_float_to_half(0.0);
   const __half *alpha_half = &alf;
@@ -1707,7 +1678,7 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
     const int numInterpolationBlocks =
         (n * c * h * w + interpolationBlocksize - 1) / interpolationBlocksize;
     if (h * w <= 64) {
-      // INFO("H *W <= 64\n");
+
       convToGemmPerfRowHalf2<<<numPatchBlocks, patchBlockSize>>>(
           convData, (__half *)input->gpu_half_data, n, input->dims.dim_sizes[1],
           input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
@@ -1730,7 +1701,7 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
       checkCudaErrors(cudaDeviceSynchronize());
 
     } else {
-      // INFO("H *W > 64\n");
+
       convToGemmPerfRowHalf<<<numPatchBlocks, patchBlockSize>>>(
           convData, (__half *)input->gpu_half_data, n, input->dims.dim_sizes[1],
           input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
@@ -1773,7 +1744,7 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
     const int numInterpolationBlocks =
         (n * c * h * w + interpolationBlocksize - 1) / interpolationBlocksize;
     if (h * w <= 64) {
-      // INFO("H *W <= 64\n");
+
       convToGemmPerfColHalf2<<<numPatchBlocks, patchBlockSize>>>(
           convData, (__half *)input->gpu_half_data, n, input->dims.dim_sizes[1],
           input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
@@ -1794,8 +1765,8 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
           (__half *)output_half->gpu_half_data,
           (__half *)new_output->gpu_half_data, col, offset);
       checkCudaErrors(cudaDeviceSynchronize());
-    } else {
-      // INFO("H *W > 64\n");
+    }
+    else {
       convToGemmPerfColHalf<<<numPatchBlocks, patchBlockSize>>>(
           convData, (__half *)input->gpu_half_data, n, input->dims.dim_sizes[1],
           input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
@@ -1836,18 +1807,15 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
         (c * reduced_filter_elem + filtBlockSize - 1) / filtBlockSize;
     const float fac = ((float)skip_every) / ((float)skip_every - 1);
     const int blockSize = 256;
-    // const int gridSize = (n * h * w + blockSize - 1) / blockSize;
-    // INFO("reduced_filter_elem: %d\n", (reduced_filter_elem));
-    // INFO("c * reduced_filter_elem: %d\n", (c * reduced_filter_elem));
+
     const __half alf = approx_float_to_half(1.0);
     const __half bet = approx_float_to_half(0.0);
     const __half *alpha_half = &alf;
     const __half *beta_half = &bet;
-    if (c * num_filter_elem <
-        500000) { // 250) {//c * reduced_filter_elem < 150000) {
+    if (c * num_filter_elem < 500000) {
       if (!(KH * KW % skip_every)) {
-        // INFO("---REGULAR FILTERING\n");
-        createReducedFiltersHalfRegular<<<filtGridSize, filtBlockSize>>>(
+
+        createReducedFiltersHalfRegular<<<filtGridSize, filtBlockSize>>>(
             reducedFilter, (__half *)filter->gpu_half_data, c, num_filter_elem,
             reduced_filter_elem, input->dims.dim_sizes[1], skip_every, offset,
             fac);
@@ -1862,16 +1830,16 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
             w, vertical_stride, horizontal_stride, reduced_filter_elem,
             skip_every, offset);
       } else {
-        // INFO("---IRREGULAR FILTERING\n");
-        createReducedFiltersHalfIrregular<<<filtGridSize, filtBlockSize>>>(
+
+        createReducedFiltersHalfIrregular<<<filtGridSize, filtBlockSize>>>(
             reducedFilter, (__half *)filter->gpu_half_data, c, num_filter_elem,
             reduced_filter_elem, skip_every, offset, fac);
         checkCudaErrors(cudaDeviceSynchronize());
 
         const int gridSize =
             (n * h * w * input->dims.dim_sizes[1] + blockSize - 1) / blockSize;
-        // convToGemmHalfInputIrregular
-        convToGemmHalfInputNewIrregular<<<gridSize, blockSize>>>(
+
+        convToGemmHalfInputNewIrregular<<<gridSize, blockSize>>>(
             convData, (__half *)input->gpu_half_data, n,
             input->dims.dim_sizes[1], input->dims.dim_sizes[2],
             input->dims.dim_sizes[3], KH, KW, vertical_pad, horizontal_pad, h,
@@ -1891,8 +1859,8 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
       changeTensorPlacement(output_half, DEVICE);
 
       if (!(KH * KW % skip_every)) {
-        // INFO("REGULAR FILTERING\n");
-        createReducedFiltersHalfRegular<<<filtGridSize, filtBlockSize>>>(
+
+        createReducedFiltersHalfRegular<<<filtGridSize, filtBlockSize>>>(
             reducedFilter, (__half *)filter->gpu_half_data, c, num_filter_elem,
             reduced_filter_elem, input->dims.dim_sizes[1], skip_every, offset,
             fac);
@@ -1907,8 +1875,8 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
             w, vertical_stride, horizontal_stride, reduced_filter_elem,
             skip_every, offset);
       } else {
-        // INFO("IRREGULAR FILTERING\n");
-        createReducedFiltersHalfIrregular<<<filtGridSize, filtBlockSize>>>(
+
+        createReducedFiltersHalfIrregular<<<filtGridSize, filtBlockSize>>>(
             reducedFilter, (__half *)filter->gpu_half_data, c, num_filter_elem,
             reduced_filter_elem, skip_every, offset, fac);
         checkCudaErrors(cudaDeviceSynchronize());
@@ -1943,7 +1911,7 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
     cudaFree(convData);
     cudaFree(reducedFilter);
   } else {
-    // INFO("FP16 BASELINE\n");
+
     Tensor *output = (Tensor *)create4DTensor((cudnnDataType_t)half_type,
                                               CUDNN_TENSOR_NCHW, n, c, h, w);
 
@@ -1955,7 +1923,7 @@ void *tensorConvApproxHalf2(void *input_ptr, void *filter_ptr, int vertical_pad,
     const int blockSize = 256;
     const int gridSize =
         (n * input->dims.dim_sizes[1] * h * w + blockSize - 1) / blockSize;
-    // convToGemmHalf
+
     convToGemmHalfInputNew<<<gridSize, blockSize>>>(
         convData, (__half *)input->gpu_half_data, n, input->dims.dim_sizes[1],
         input->dims.dim_sizes[2], input->dims.dim_sizes[3], KH, KW,
diff --git a/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py b/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py
index 41d8b4a8491ced3e8281164aaf7ea49d3ec68551..5c7fbf9a8c8c56e803f861e9db7f37a0a632bd5f 100644
--- a/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py
+++ b/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py
@@ -556,10 +556,6 @@ class HPVMTranslator:
     headers += "#include <cstring> \n"
     
     headers += "#include <" + HPVM_header +  "> \n"
-    #if LLVM_9_BRANCH:
-    #   headers += "#include \"config.h\" \n"
-    
-    headers += "#include <tensorTypes.h> \n"
     headers += "#include <tensorUtils.h> \n\n"
 
     self.file_header_str = headers
diff --git a/hpvm/test/dnn_benchmarks/keras/Benchmark.py b/hpvm/test/dnn_benchmarks/keras/Benchmark.py
index f3d8e9e6b2268618dc835e3d27374a8f7d738a86..ae0adc1a51d07788976d7b2d43a5e890ecb70adc 100644
--- a/hpvm/test/dnn_benchmarks/keras/Benchmark.py
+++ b/hpvm/test/dnn_benchmarks/keras/Benchmark.py
@@ -59,17 +59,18 @@ class Benchmark:
         except:
             print("""
 
-ERROR: Could not find hpvm-clang (HPVM compile script)!!
+            ERROR: Could not find hpvm-clang (HPVM compile script)!!
 
-hpvm-clang is installed to the python environment used when compiling HPVM.
-Please try rerunning 'make -j hpvm-clang'.""")
+            hpvm-clang is installed to the python environment used when compiling HPVM.
+            Please try rerunning 'make -j hpvm-clang' and make sure `hpvm-clang` is in your $PATH""")
+            
             sys.exit(1)
 
-
         try:
             subprocess.run([
                 "hpvm-clang", src_file, target_binary,
-                "-t", "tensor", "--conf-file", approx_conf_file
+                "-t", "tensor", "--conf-file", approx_conf_file,
+                "-fno-exceptions"
             ], check=True)
         except:
             print ("\n\n ERROR: HPVM Compilation Failed!! \n\n")
diff --git a/hpvm/test/dnn_benchmarks/keras/Config.py b/hpvm/test/dnn_benchmarks/keras/Config.py
index 99e696d632c50db4ae8098a2f4836ca994b672aa..851350a5ac2fe812c9f811dab4d79d6de41d815f 100644
--- a/hpvm/test/dnn_benchmarks/keras/Config.py
+++ b/hpvm/test/dnn_benchmarks/keras/Config.py
@@ -3,8 +3,10 @@ import pathlib
 
 
 # Path Relative to Model Params Directory
-abs_path = pathlib.Path(__file__).parent.absolute()
-MODEL_PARAMS_DIR = str(abs_path) + "/../../../../hpvm/test/dnn_benchmarks/model_params/"
+CUR_SRC_PATH = str(pathlib.Path(__file__).parent.absolute())
+MODEL_PARAMS_DIR = CUR_SRC_PATH + "/../../../../hpvm/test/dnn_benchmarks/model_params/"
+
+
 
 
 if __name__ == "__main__":
diff --git a/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py b/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py
index 2d4b8afab532d7632de9968236690bc63798fc1e..bda3d4186f6d15e31ec0fa6ee547ca0dd0b29968 100644
--- a/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py
+++ b/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py
@@ -3,25 +3,17 @@
 import os
 import sys
 import subprocess
-
-import site
-from pathlib import Path
-
-import torch
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Subset
-
-#site.addsitedir(Path(__file__).parent.parent.absolute().as_posix())
-#from predtuner import PipedBinaryApp, config_pylogger
-
+import argparse
+from Config import *
 
 
 class Benchmark:
 
-    def __init__(self, binary_path, test_accuracy):
+    def __init__(self, binary_path, output_dir, test_accuracy):
 
         self.binary_path = binary_path
         self.test_accuracy = test_accuracy
+        self.output_dir = output_dir
         self.epsilon = 0.05 # Adding some slack for accuracy difference
 
 
@@ -60,19 +52,28 @@ class Benchmark:
         return test_success
 
 
-    def runHPVM(self):
+    def runHPVM(self, weights_dump):
 
-        # Test Bechmark accuracy with pretrained weights (hpvm_relaod)
-        run_cmd = "python3 " + self.binary_path + " keras_reload frontend compile compile_tuner"
+        if weights_dump:
+          # Test Benchmark with Keras weight dumping
+          run_cmd = "python3 " + self.binary_path + " keras_reload frontend compile compile_tuner"
+        else:
+          # Test Benchmark accuracy with pretrained weights (hpvm_reload)
+          run_cmd = "python3 " + self.binary_path + " hpvm_reload frontend compile compile_tuner"
+          
         try:
             subprocess.call(run_cmd, shell=True)
         except:
             return False
 
-        working_dir = open("working_dir.txt").read()
+        #working_dir = open("working_dir.txt").read()
         cur_dir = os.getcwd()
-        
+
+        working_dir = self.output_dir 
         os.chdir(working_dir)
+
+        print ("Changed to working dir: ", os.getcwd())
+        
         binary_path =  "./HPVM_binary"
         
         try:
@@ -96,31 +97,6 @@ class Benchmark:
         return test_success
 
 
-"""    
-    def runApproxTuner(self):
-
-        working_dir = open("working_dir.txt").read()
-        cur_dir = os.getcwd()
-        
-        os.chdir(working_dir)
-        binary_path =  "./HPVM_tuner_binary"
-
-        full_binary_path = str(cur_dir) + "/" +  working_dir + "/" + binary_path
-        full_json_path = str(cur_dir) + "/" + working_dir + "/tuner.json"
-    
-        app = PipedBinaryApp("TestHPVMApp", full_binary_path, full_json_path)
-        # Tuning procedure is exactly the same as that for PyTorch DNN.
-        # Please refer to `./tune_vgg16_cifar10.py` for details.
-        tuner = app.get_tuner()
-        tuner.tune(5000, 3.0, 3.0, True, 50, cost_model="cost_linear", qos_model="qos_p1")
-
-        tuner.dump_configs("configs.json")
-        fig = tuner.plot_configs(show_qos_loss=True)
-        fig.savefig("configs.png", dpi=300)
-        app.dump_hpvm_configs(tuner.best_configs, "hpvm_confs.txt")
-
-        os.chdir(cur_dir)  # Change back to original working directory
-"""
             
         
 
@@ -151,10 +127,10 @@ class BenchmarkTests:
                 self.passed_tests.append(benchmark.getPath())
 
 
-    def runHPVMTests(self):
+    def runHPVMTests(self, weights_dump):
 
         for benchmark in self.benchmarks:
-            test_success = benchmark.runHPVM()
+            test_success = benchmark.runHPVM(weights_dump)
 
             if not test_success:
                 self.failed_hpvm_tests.append(benchmark.getPath())
@@ -178,6 +154,7 @@ class BenchmarkTests:
             print ("Failed: " + failed_test)
             
 
+    # Returns False if any of the tests failed
     def printHPVMSummary(self):
 
         failed_test_count = len(self.failed_hpvm_tests)
@@ -192,47 +169,69 @@ class BenchmarkTests:
         print ("****** Failed Tests *** \n")
         for failed_test in self.failed_hpvm_tests:
             print ("Failed: " + failed_test)
-            
 
-        
+        if failed_test_count > 0:
+            return False
+
+        return True
+
+
+    
             
 if __name__ == "__main__":
 
-    if len(sys.argv) < 2:
-        print ("Usage: python3 test_dnnbenchmarks.py ${work_dir}")
 
-    work_dir = sys.argv[1]
-    if not os.path.exists(work_dir):
-        os.mkdir(work_dir)
+    parser = argparse.ArgumentParser(description='Run the HPVM Keras DNN benchmark tests.')
+
+    parser.add_argument('--work-dir', type=str,
+                        help='working dir for dumping frontend generated files')
+
+    parser.add_argument('--dump-weights', action="store_true", help='dump h5 weights to bin (default: False)')
+    args = parser.parse_args()
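+
+    # Example invocation (illustrative; run from the keras benchmarks directory,
+    # and note that the work dir must not already exist):
+    #   python3 test_benchmarks.py --work-dir test_output --dump-weights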
+
+    work_dir = args.work_dir
+    dump_weights = args.dump_weights
+
+    if os.path.exists(work_dir):
+        print ("Work Directory Exists. Delete it or use a different work directory.")
+        sys.exit(1)
+        
+    os.mkdir(work_dir)
     os.chdir(work_dir)
     
     testMgr = BenchmarkTests()
-    AlexNet = Benchmark("../alexnet.py", 79.28)
-    AlexNet_ImageNet = Benchmark("../alexnet_imagenet.py", 56.30)
-    AlexNet2 = Benchmark("../alexnet2.py", 84.98)
-    LeNet = Benchmark("../lenet.py", 98.70)
-    MobileNet = Benchmark("../mobilenet_cifar10.py", 84.42)
-    ResNet18 = Benchmark("../resnet18_cifar10.py", 89.56)
-    ResNet50 = Benchmark("../resnet50_imagenet.py", 75.10)
-    VGG16_cifar10 = Benchmark("../vgg16_cifar10.py", 89.96)
-    VGG16_cifar100 = Benchmark("../vgg16_cifar100.py", 66.50)
-    VGG16_ImageNet = Benchmark("../vgg16_imagenet.py", 69.46)
-
-    testMgr.addBenchmark(AlexNet)
-    #testMgr.addBenchmark(AlexNet_ImageNet)
-    testMgr.addBenchmark(AlexNet2)
-    testMgr.addBenchmark(LeNet)
+    AlexNet = Benchmark(CUR_SRC_PATH + "/alexnet.py", "src/alexnet_cifar10_src_hpvm", 79.28)
+    AlexNet_ImageNet = Benchmark(CUR_SRC_PATH + "/alexnet_imagenet.py", "src/alexnet_imagenet_src", 56.30)
+    AlexNet2 = Benchmark(CUR_SRC_PATH + "/alexnet2.py", "src/alexnet2_cifar10_src", 84.98)
+    LeNet = Benchmark(CUR_SRC_PATH + "/lenet.py", "src/lenet_mnist_src", 98.70)
+    MobileNet = Benchmark(CUR_SRC_PATH + "/mobilenet_cifar10.py", "src/mobilenet_cifar10_src", 84.42)
+    ResNet18 = Benchmark(CUR_SRC_PATH + "/resnet18_cifar10.py", "src/resnet18_cifar10_src", 89.56)
+    ResNet50 = Benchmark(CUR_SRC_PATH + "/resnet50_imagenet.py", "src/resnet50_imagenet_src", 75.10)
+    VGG16_cifar10 = Benchmark(CUR_SRC_PATH + "/vgg16_cifar10.py", "src/vgg16_cifar10_src", 89.96)
+    VGG16_cifar100 = Benchmark(CUR_SRC_PATH + "/vgg16_cifar100.py", "src/vgg16_cifar100_src", 66.50)
+    VGG16_ImageNet = Benchmark(CUR_SRC_PATH + "/vgg16_imagenet.py", "src/vgg16_imagenet_src", 69.46)
+
+    #testMgr.addBenchmark(AlexNet)
+    #testMgr.addBenchmark(AlexNet2)
+    #testMgr.addBenchmark(LeNet)
     testMgr.addBenchmark(MobileNet)
     testMgr.addBenchmark(ResNet18)
     #testMgr.addBenchmark(ResNet50)
-    testMgr.addBenchmark(VGG16_cifar10)
+    #testMgr.addBenchmark(VGG16_cifar10)
     testMgr.addBenchmark(VGG16_cifar100)
     #testMgr.addBenchmark(VGG16_ImageNet)
-
+    testMgr.addBenchmark(AlexNet_ImageNet)
+  
     #testMgr.runKerasTests()
     #testMgr.printKerasSummary()
     
-    testMgr.runHPVMTests()
-    testMgr.printHPVMSummary()
+    testMgr.runHPVMTests(dump_weights)
+    tests_passed = testMgr.printHPVMSummary()
 
+    if not tests_passed:
+        sys.exit(-1)
     
+