Skip to content
Snippets Groups Projects
Commit 9bc2ca73 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Grouped Conv support in tensorHalfConvolution

parent a9f49265
No related branches found
No related tags found
No related merge requests found
......@@ -149,7 +149,7 @@ void* tensorHalfGemmGPU(void* lhs_ptr, void* rhs_ptr){
void* tensorHalfConvolution(void* input_ptr, void* filter_ptr,
int vertical_pad, int horizontal_pad,
int vertical_stride, int horizontal_stride,
int conv_mode, int compute_precision){
int conv_mode, int conv_groups){
INFO("*** TensorHConvolution \n");
profileEvent("tensorHalfConv");
......@@ -200,8 +200,17 @@ void* tensorHalfConvolution(void* input_ptr, void* filter_ptr,
/******* END OF INPUT DATA CONVERSIONS*/
checkCUDNN(cudnnCreateConvolutionDescriptor(&convDesc));
//FIXME: Current hack to preserve backward compatibilty
if(conv_groups == 0){
conv_groups = 1;
}
// NOTE: Adding support for grouped convolution
checkCUDNN(cudnnSetConvolutionGroupCount(convDesc, conv_groups));
// FIXIT: Think if upscaling values need to be configurable?
// IMP-FIXIT: CUDNN Cross correlation is only used in the Lenet context
// IMP-FIXIT: Either make mode configurable OR see if CUDNN_CONVOLUTION MODE should be used?
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment