From 86d33f5ab792a59957fc96e1a83d8ee3b82c3641 Mon Sep 17 00:00:00 2001
From: Hashim Sharif <hsharif3@tyler.cs.illinois.edu>
Date: Tue, 7 Apr 2020 16:37:51 -0500
Subject: [PATCH] Adding Sampling test for 1*1 filter

---
 .../dnn_sources/src/test_ops.cc               | 76 ++++++++++---------
 1 file changed, 40 insertions(+), 36 deletions(-)

diff --git a/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc b/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc
index 29733fd58b..a62ac42757 100644
--- a/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc
+++ b/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc
@@ -803,45 +803,11 @@ void testSampling_3_3(){
 
   testSamplingCalls(input, filter, 1, 1, 2, 2, 4);
 
-  
-  /*
-  //printTensorValues(input);
-
-  void* res = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1);
-  
-  printTensorValues(res);
-
-
-  void* res2 = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 2, 1);
-  
-  printTensorValues(res2);
-
-
-  void* res2_sim = tensorConvSampSim(input, filter, 0, 0, 1, 1, 1, 1, 2, 0);
-  
-  printTensorValues(res2_sim);
-
-  
-  void* res3 = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 2, 0);
-  
-  printTensorValues(res3);
-
-
-  void* res4 = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 4, 0);
-  
-  printTensorValues(res4);
-
-
-  void* res4_half = tensorConvApproxHalf2(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 4, 0);
-
-  convertToFP32((struct Tensor*) res4_half);
+    
+}
 
-  printTensorValues(res4_half);
 
-  */
 
-  
-}
 
 
 
@@ -943,6 +909,39 @@ void testSampling3(){
 
 
 
+/**** Tests Sample for a sample 1 * 1 Filter */
+void testSampling_1_1(){
+
+  
+  Tensor* input = (Tensor*) create4DTensor(CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 1, 9, 2, 2);
+  fillTensorWithVal(input, 2);
+  //fillWithOnesAndTwos(input);
+  
+  Tensor* filter = (Tensor*) create4DTensor(CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 4, 9, 1, 1);
+  fillTensorWithVal(filter, 2);
+  
+
+  testSamplingCalls(input, filter, 0, 0, 1, 1, 2);
+
+  testSamplingCalls(input, filter, 0, 0, 1, 1, 3);
+
+  testSamplingCalls(input, filter, 0, 0, 1, 1, 4);
+
+
+
+  testSamplingCalls(input, filter, 1, 1, 1, 1, 2);
+
+  testSamplingCalls(input, filter, 1, 1, 1, 1, 3);
+
+  testSamplingCalls(input, filter, 1, 1, 1, 1, 4);
+
+    
+}
+
+
+
+
+
 
 
 
@@ -990,7 +989,12 @@ int main(){
 
   //testSampling3();
   
+
   testSampling_3_3();
+
+  
+  //- testSampling_1_1();
+
     
   stopProfiling();
 
-- 
GitLab