Skip to content
Snippets Groups Projects
Commit e45db4b3 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Adding backend support for DepthwiseConv2D

parent 8dcc7d65
No related branches found
No related tags found
No related merge requests found
......@@ -275,6 +275,19 @@ let TargetPrefix = "visc" in {
llvm_i32_ty,
llvm_i32_ty], []>;
/* Tensor group convolution intrinsic
 * i8* llvm.visc.tensor.group.convolution(i8*, i8*, i32, i32, i32, i32, i32, i32);
 * Two pointer operands (input tensor, filter tensor) followed by six i32
 * operands. Presumed layout, mirroring the regular convolution intrinsic:
 * pads/strides in operands 2-5, with the two trailing i32s being the
 * convolution mode and the group count -- the CUDNN backend hard-codes the
 * mode and forwards only the last operand; confirm against the runtime.
 */
def int_visc_tensor_group_convolution : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty,
llvm_ptr_ty,
llvm_i32_ty,
llvm_i32_ty,
llvm_i32_ty,
llvm_i32_ty,
llvm_i32_ty,
llvm_i32_ty], []>;
/* Tensor pool intrinsics: max, min, average
* i8* llvm.visc.tensor.pool.max(i8*, i32, i32, i32, i32, i32, i32);
* i8* llvm.visc.tensor.pool.min(i8*, i32, i32, i32, i32, i32, i32);
......
......@@ -304,6 +304,40 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
}
break;
case Intrinsic::visc_tensor_group_convolution:
{ /* llvm.visc.tensor.group.convolution */
// Group (depthwise) convolution. Not an in-place op: it lowers to the
// same tensorConvolution runtime entry point as regular convolution.
DEBUG(errs() << F_cudnn->getName() << "\t: Handling tensor convolution \n");
// Argument list for the runtime call.
// Operands 0-5 are forwarded verbatim: input, filter, and the four
// pad/stride i32s -- presumed order mirrors the regular convolution
// lowering; confirm against the tensorConvolution runtime signature.
std::vector<Value*> Args;
Args.push_back(II->getOperand(0));
Args.push_back(II->getOperand(1));
Args.push_back(II->getOperand(2));
Args.push_back(II->getOperand(3));
Args.push_back(II->getOperand(4));
Args.push_back(II->getOperand(5));
// conv_mode is hard-wired to 1; note operand 6 of the intrinsic is
// deliberately NOT forwarded -- it is replaced by this constant.
// Presumably 1 selects the grouped/depthwise mode in the runtime;
// verify against the runtime's convolution-mode enum.
Constant* conv_mode = ConstantInt::get(Type::getInt32Ty(M.getContext()), 1);
Args.push_back(conv_mode);
// Operand 7 (the last i32) is forwarded as-is; presumably the group
// count -- confirm against callers / the runtime signature.
Args.push_back(II->getOperand(7));
// Create cudnn runtime function call
Constant* tensorConvolution;
DECLARE(tensorConvolution);
CallInst* CI = CallInst::Create(tensorConvolution,
Args, "", II);
// Rewire all uses of the intrinsic to the runtime call's result.
II->replaceAllUsesWith(CI);
// Mark to remove at the end
IItoRemove.push_back(II);
}
break;
case Intrinsic::visc_tensor_mul:
{ /* llvm.hpvm.tensor.mul */
// Tensor mul is not in place.
......@@ -407,6 +441,7 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
IItoRemove.push_back(II);
}
break;
case Intrinsic::visc_tensor_relu:
case Intrinsic::visc_tensor_clipped_relu:
case Intrinsic::visc_tensor_tanh:
......
......@@ -166,6 +166,7 @@ IS_VISC_CALL(hint)
// Tensor Operators
// Each IS_VISC_CALL(name) invocation presumably expands to an
// isVISCCall_<name>(I) predicate used by runOnModule to recognize the
// corresponding __visc__<name> builtin call -- macro defined elsewhere
// in this file; confirm against the IS_VISC_CALL definition.
IS_VISC_CALL(tensor_mul)
IS_VISC_CALL(tensor_convolution)
IS_VISC_CALL(tensor_group_convolution)
IS_VISC_CALL(tensor_add)
IS_VISC_CALL(tensor_pool_max)
IS_VISC_CALL(tensor_pool_min)
......@@ -1273,6 +1274,9 @@ bool GenVISC::runOnModule(Module &M) {
// Map each recognized __visc__ tensor builtin call to its VISC
// intrinsic; ReplaceCallWithIntrinsic queues the original call in
// toBeErased for later removal.
if (isVISCCall_tensor_convolution(I)) {
ReplaceCallWithIntrinsic(I, Intrinsic::visc_tensor_convolution, &toBeErased);
}
// New in this commit: grouped (depthwise) convolution gets its own
// intrinsic rather than reusing visc_tensor_convolution.
if (isVISCCall_tensor_group_convolution(I)) {
ReplaceCallWithIntrinsic(I, Intrinsic::visc_tensor_group_convolution, &toBeErased);
}
if (isVISCCall_tensor_add(I)) {
ReplaceCallWithIntrinsic(I, Intrinsic::visc_tensor_add, &toBeErased);
}
......
......@@ -95,6 +95,7 @@ float __visc__cos(float);
void* __visc__tensor_add(void*, void*);
void* __visc__tensor_mul(void*, void*);
// (input, filter, then four ints -- presumably pads and strides;
// confirm against the backend lowering which forwards them verbatim)
void* __visc__tensor_convolution(void*, void*, int, int, int, int);
// Grouped (depthwise) convolution: same six leading arguments as
// __visc__tensor_convolution plus two trailing ints. The CUDNN backend
// hard-codes the convolution mode and forwards only the last int,
// presumably the group count -- verify against callers.
void* __visc__tensor_group_convolution(void*, void*, int, int, int, int, int, int);
void* __visc__tensor_pool_max(void*, int, int, int, int, int, int);
void* __visc__tensor_pool_mean(void*, int, int, int, int, int, int);
void* __visc__tensor_relu(void*);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment