Skip to content
Snippets Groups Projects
Commit 59e3b7e3 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Adding tensorSelect CPU version

parent 299157d4
No related branches found
No related tags found
No related merge requests found
......@@ -1099,7 +1099,7 @@ void testSampling_1_1(){
void testTensorArgMax(){
void* testTensorArgMax(){
Tensor* input = (Tensor*) create4DTensor(CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 4, 3, 1, 1);
......@@ -1126,15 +1126,24 @@ void testTensorArgMax(){
host_ptr[11] = 8;
void* argmax_out = tensorArgMax(input);
// Expect Output of call below to be:
// 1 2 2 0
printTensorValues(argmax_out);
return argmax_out;
}
void testTensorSelect(){
void* testTensorSelect(void* argmax_out){
void* select_out = tensorSelect(argmax_out, 2);
printf ("***** tensorSelect output \n");
printTensorValues(select_out);
return select_out;
}
......@@ -1148,8 +1157,8 @@ void testTensorContract(){
void testNewTensorOps(){
testTensorArgMax();
testTensorSelect();
void* argmax_out = testTensorArgMax();
testTensorSelect(argmax_out);
testTensorContract();
}
......
#include "tensor.h"
#include <stdlib.h>
void* tensorArgMax(Tensor* input_ptr){
void* tensorArgMax(void* input_ptr){
Tensor* input = (Tensor*) input_ptr;
float* host_ptr = (float*) input->host_data;
......@@ -39,3 +40,35 @@ void* tensorArgMax(Tensor* input_ptr){
return output;
}
void* tensorSelect(void* input_ptr, float target_value){
Tensor* input = (Tensor*) input_ptr;
float* host_ptr = (float*) input->host_data;
int batch_size = input->dims.dim_sizes[0];
int channels = input->dims.dim_sizes[1];
if (channels != 1){
printf("* Channels dimension must be 1 \n");
abort();
}
Tensor* output = (Tensor *) create4DTensor(0, 0, batch_size, 1, 1, 1);
changeTensorPlacement(output, HOST);
float* out_ptr = (float*) output->host_data;
for(int i = 0; i < batch_size; i++){
if (host_ptr[i] == target_value){
out_ptr[i] = 1;
}
else{
out_ptr[i] = 0;
}
}
return output;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment