Skip to content
Snippets Groups Projects
Commit 2a3e6579 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Only starting to cleanup hpvm-tensor-rt unit tests

parent daf46156
No related branches found
No related tags found
No related merge requests found
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <unistd.h> #include <unistd.h>
#include "tensor_runtime.h" #include "tensor_runtime.h"
#include "utils.h" #include "utils.h"
#include "tensor_custom_ops_cpu.h" #include "tensor_custom_ops_cpu.h"
...@@ -59,40 +57,10 @@ void testTensorHgemm(){ ...@@ -59,40 +57,10 @@ void testTensorHgemm(){
void* output = tensorHalfGemm(lhs, rhs); void* output = tensorHalfGemm(lhs, rhs);
printTensorValues(output); printTensorValues(output);
void* bias_ptr = create4DTensor(CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 3, 1, 1); // TODO: Add result comparator - Make a generic result comparator
struct Tensor* bias = (struct Tensor*) bias_ptr;
fillTensorWithOnes(bias);
float* bias_arr = (float*) bias->host_data;
for(int i = 0; i < bias->num_elems; i++){
bias_arr[i] = i + 1;
}
void* output2 = tensorAdd(output, bias);
printTensorValues(output2);
} }
void testTensorHgemm2(){
printf("***** TensorHgemm ***** \n\n");
void* lhs_ptr = create4DTensor(CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW,
10000, 800, 1, 1);
struct Tensor* lhs = (struct Tensor*) lhs_ptr;
float* data_arr = (float*) lhs->host_data;
for(int i = 0; i < lhs->num_elems; i++){
data_arr[i] = (i / 4) + 1;
}
void* rhs = create4DTensor(CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
1, 1, 800, 800);
fillTensorWithOnes(rhs);
void* output = tensorHalfGemm(lhs, rhs);
//printTensorValues(output);
}
void testTensorSgemm2(){ void testTensorSgemm2(){
...@@ -1209,7 +1177,8 @@ int main(){ ...@@ -1209,7 +1177,8 @@ int main(){
startProfiling(); startProfiling();
//testTensorHgemm2(); testTensorHgemm();
//testTensorSgemm2(); //testTensorSgemm2();
//testTensorConv(); //testTensorConv();
//testTensorError(); //testTensorError();
...@@ -1250,9 +1219,7 @@ int main(){ ...@@ -1250,9 +1219,7 @@ int main(){
*************/ *************/
testNewTensorOps(); //testNewTensorOps();
//testQuantization(); //testQuantization();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment