Commit 7ea016f0 authored by Hashim Sharif

BatchNorm handling support in Backends

parent 7027f5a2
@@ -276,7 +276,7 @@ let TargetPrefix = "visc" in {
                                 llvm_i32_ty], []>;
 
   /* Tensor group convolution intrinsic
-   * i8* llvm.visc.tensor.group.convolution(i8*, i8*, i32, i32, i32, i32);
+   * i8* llvm.visc.tensor.group.convolution(i8*, i8*, i32, i32, i32, i32, i32, i32);
   */
   def int_visc_tensor_group_convolution : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty,
                                                                     llvm_ptr_ty,
@@ -287,6 +287,16 @@ let TargetPrefix = "visc" in {
                                 llvm_i32_ty,
                                 llvm_i32_ty], []>;
 
+  /* Tensor BatchNorm intrinsic
+   * i8* llvm.visc.tensor.batchnorm(i8*, i8*, i8*, i8*, i8*, double);
+   */
+  def int_visc_tensor_batchnorm : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty,
+                                                            llvm_ptr_ty,
+                                                            llvm_ptr_ty,
+                                                            llvm_ptr_ty,
+                                                            llvm_ptr_ty,
+                                                            llvm_double_ty], []>;
+
   /* Tensor pool intrinsics: max, min, average
    * i8* llvm.visc.tensor.pool.max(i8*, i32, i32, i32, i32, i32, i32);
...
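For context, the new intrinsic returns an i8* tensor handle and takes five i8* tensor handles plus a double. Going by the standard batch-normalization formula, the pointer operands are presumably the input tensor and the per-channel gamma, beta, mean, and variance tensors, with epsilon as the double; those operand roles are an assumption, not spelled out in this commit. A minimal C++ sketch of the computation the intrinsic denotes, for an NCHW float tensor:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics for llvm.visc.tensor.batchnorm on an NCHW float
// tensor: y = gamma * (x - mean) / sqrt(var + epsilon) + beta, applied
// per channel. Operand roles are assumed, not confirmed by this commit.
void batchnormReference(std::vector<float>& x,           // input, updated in place
                        const std::vector<float>& gamma, // per-channel scale
                        const std::vector<float>& beta,  // per-channel shift
                        const std::vector<float>& mean,  // per-channel mean
                        const std::vector<float>& var,   // per-channel variance
                        double epsilon,
                        std::size_t N, std::size_t C,
                        std::size_t H, std::size_t W) {
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t c = 0; c < C; ++c) {
      const float scale = gamma[c] / std::sqrt(var[c] + epsilon);
      for (std::size_t i = 0; i < H * W; ++i) {
        const std::size_t idx = (n * C + c) * H * W + i;
        x[idx] = scale * (x[idx] - mean[c]) + beta[c];
      }
    }
}

The in-place update mirrors the comment in the backend hunk below, which notes that tensor batchnorm operates in place.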
@@ -337,6 +337,35 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
     }
     break;
 
+    case Intrinsic::visc_tensor_batchnorm:
+    { /* llvm.hpvm.tensor.batchnorm */
+      // Tensor batchnorm is an in-place operation.
+      // FIXME: Add a check for InPlace Analysis
+      DEBUG(errs() << F_cudnn->getName() << "\t: Handling tensor batch normalization\n");
+
+      // Argument list for the runtime call
+      std::vector<Value*> Args;
+      Args.push_back(II->getOperand(0));
+      Args.push_back(II->getOperand(1));
+      Args.push_back(II->getOperand(2));
+      Args.push_back(II->getOperand(3));
+      Args.push_back(II->getOperand(4));
+      Args.push_back(II->getOperand(5));
+
+      // Create the cudnn runtime function call
+      Constant* tensorBatchNorm;
+      DECLARE(tensorBatchNorm);
+      CallInst* CI = CallInst::Create(tensorBatchNorm, Args, "", II);
+
+      // Replace the hpvm.tensor.batchnorm call with the tensor-runtime call
+      II->replaceAllUsesWith(CI);
+
+      // Mark the intrinsic for removal at the end
+      IItoRemove.push_back(II);
+    }
+    break;
+
     case Intrinsic::visc_tensor_mul:
     { /* llvm.hpvm.tensor.mul */
...
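One detail worth noting in the hunk above: DECLARE(tensorBatchNorm) binds the Constant* to the runtime entry point before the CallInst is created. The macro body is defined elsewhere in this backend and is not part of this commit; a hedged sketch of what an equivalent expansion presumably looks like, with the FunctionType assumed to mirror the intrinsic's five-i8*-plus-double signature and "tensorBatchNorm" assumed to be the runtime symbol name:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Hypothetical equivalent of DECLARE(tensorBatchNorm): get-or-insert the
// runtime function with a type matching llvm.visc.tensor.batchnorm. Both
// the symbol name and the exact macro body are assumptions.
static Constant* declareTensorBatchNorm(Module& M) {
  LLVMContext& Ctx = M.getContext();
  Type* I8Ptr = Type::getInt8PtrTy(Ctx);
  FunctionType* FT =
      FunctionType::get(I8Ptr, // returns an i8* tensor handle
                        {I8Ptr, I8Ptr, I8Ptr, I8Ptr, I8Ptr,
                         Type::getDoubleTy(Ctx)},
                        /*isVarArg=*/false);
  // Older LLVM returns Constant* from getOrInsertFunction, which matches
  // the Constant* tensorBatchNorm declaration in the case body above.
  return M.getOrInsertFunction("tensorBatchNorm", FT);
}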
@@ -167,6 +167,7 @@ IS_VISC_CALL(hint)
 IS_VISC_CALL(tensor_mul)
 IS_VISC_CALL(tensor_convolution)
 IS_VISC_CALL(tensor_group_convolution)
+IS_VISC_CALL(tensor_batchnorm)
 IS_VISC_CALL(tensor_add)
 IS_VISC_CALL(tensor_pool_max)
 IS_VISC_CALL(tensor_pool_min)
@@ -1280,6 +1281,9 @@ bool GenVISC::runOnModule(Module &M) {
     if (isVISCCall_tensor_add(I)) {
       ReplaceCallWithIntrinsic(I, Intrinsic::visc_tensor_add, &toBeErased);
     }
+    if (isVISCCall_tensor_batchnorm(I)) {
+      ReplaceCallWithIntrinsic(I, Intrinsic::visc_tensor_batchnorm, &toBeErased);
+    }
     if (isVISCCall_tensor_mul(I)) {
       ReplaceCallWithIntrinsic(I, Intrinsic::visc_tensor_mul, &toBeErased);
     }
...
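At the source level, GenVISC lowers calls to the __visc__ wrapper functions into the corresponding intrinsics, so the two hunks above make batchnorm recognizable end to end: the front end rewrites the wrapper call into llvm.visc.tensor.batchnorm, and the cuDNN backend then emits the runtime call. A sketch of what a leaf-node function using the new operation might look like, following the pattern of the other tensor wrappers; the __visc__tensor_batchnorm prototype, the epsilon value, and the surrounding node boilerplate are assumed from that convention rather than shown in this commit:

// Hypothetical HPVM/VISC leaf node using the new wrapper. GenVISC matches
// the call by name (IS_VISC_CALL(tensor_batchnorm)) and rewrites it into
// the llvm.visc.tensor.batchnorm intrinsic.
void batchnorm_node(void* input, size_t bytes_input,
                    void* gamma, size_t bytes_gamma,
                    void* beta,  size_t bytes_beta,
                    void* mean,  size_t bytes_mean,
                    void* var,   size_t bytes_var) {
  __visc__hint(visc::CUDNN_TARGET); // route this node to the cuDNN backend
  __visc__attributes(5, input, gamma, beta, mean, var, 0);

  // The trailing double (0.001 here, assumed) maps to the intrinsic's
  // epsilon operand.
  void* result = __visc__tensor_batchnorm(input, gamma, beta, mean, var, 0.001);
  __visc__return(2, result, (size_t)0);
}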