diff --git a/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc b/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc index 29733fd58b8f9767acf6656d4351c98ae4e03661..a62ac427576f5211f20e28c759cc6104fee4d943 100644 --- a/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc +++ b/llvm/projects/hpvm-tensor-rt/dnn_sources/src/test_ops.cc @@ -803,45 +803,11 @@ void testSampling_3_3(){ testSamplingCalls(input, filter, 1, 1, 2, 2, 4); - - /* - //printTensorValues(input); - - void* res = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1); - - printTensorValues(res); - - - void* res2 = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 2, 1); - - printTensorValues(res2); - - - void* res2_sim = tensorConvSampSim(input, filter, 0, 0, 1, 1, 1, 1, 2, 0); - - printTensorValues(res2_sim); - - - void* res3 = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 2, 0); - - printTensorValues(res3); - - - void* res4 = tensorConvApprox(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 4, 0); - - printTensorValues(res4); - - - void* res4_half = tensorConvApproxHalf2(input, filter, 0, 0, 1, 1, 1, 1, 1, 1, 4, 0); - - convertToFP32((struct Tensor*) res4_half); + +} - printTensorValues(res4_half); - */ - -} @@ -943,6 +909,39 @@ void testSampling3(){ +/**** Tests Sample for a sample 1 * 1 Filter */ +void testSampling_1_1(){ + + + Tensor* input = (Tensor*) create4DTensor(CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 1, 9, 2, 2); + fillTensorWithVal(input, 2); + //fillWithOnesAndTwos(input); + + Tensor* filter = (Tensor*) create4DTensor(CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 4, 9, 1, 1); + fillTensorWithVal(filter, 2); + + + testSamplingCalls(input, filter, 0, 0, 1, 1, 2); + + testSamplingCalls(input, filter, 0, 0, 1, 1, 3); + + testSamplingCalls(input, filter, 0, 0, 1, 1, 4); + + + + testSamplingCalls(input, filter, 1, 1, 1, 1, 2); + + testSamplingCalls(input, filter, 1, 1, 1, 1, 3); + + testSamplingCalls(input, filter, 1, 1, 1, 1, 4); + + +} + + + + + @@ -990,7 +989,12 @@ int main(){ //testSampling3(); + testSampling_3_3(); + + + //- testSampling_1_1(); + stopProfiling();