Skip to content
Snippets Groups Projects
Commit 4c51dc68 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Handling tensor_pool_mean in TensorRT backend

parent f6ca608a
No related branches found
No related tags found
No related merge requests found
......@@ -227,7 +227,7 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
std::string FName(F->getName().data());
F_cudnn = CloneFunction(F, VMap);
F_cudnn->setName(FName + "_cudnn");
errs()<<"Cloned function name = "<<F_cudnn->getName()<<"\n";
errs()<<"Cloned function name2 = "<<F_cudnn->getName()<<"\n";
F_cudnn->removeFromParent();
M.getFunctionList().push_back(F_cudnn);
......@@ -337,8 +337,10 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
bool inplace = isValidOperandForInPlaceOperation(Op, F_cudnn, N);
// Code generation cannot continue if this is false, because the target
// only provides an in place operation
assert(inplace &&
"Operand not valid for in place operation. Code gen aborted.\n");
// FIXME: remove this comment - must check for in-place
//assert(inplace &&
// "Operand not valid for in place operation. Code gen aborted.\n");
// Argument list for the runtime call
std::vector<Value*> Args;
......@@ -358,6 +360,7 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
}
break;
case Intrinsic::visc_tensor_pool_max:
case Intrinsic::visc_tensor_pool_mean:
{ /* llvm.visc.tensor.pool.max / llvm.visc.tensor.pool.mean */
DEBUG(errs() << F_cudnn->getName() << "\t: Handling tensor_pool_max\n");
// Tensor pool(a) is in place for argument a.
......@@ -374,8 +377,17 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
// vertical_stride, horizontal_stride);
std::vector<Value*> Args;
Args.push_back(II->getOperand(0));
Constant* constZero = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
Args.push_back(constZero); // ID for max pool. Min/Avg have different IDs (non-zero)
int pool_type = 0;
if (II->getIntrinsicID() == Intrinsic::visc_tensor_pool_max){
pool_type = 0;
}
if (II->getIntrinsicID() == Intrinsic::visc_tensor_pool_mean){
pool_type = 1;
}
Constant* constPoolType = ConstantInt::get(Type::getInt32Ty(M.getContext()), pool_type);
Args.push_back(constPoolType); // pool-type ID: 0 = max pooling, 1 = mean (average) pooling
Args.push_back(II->getOperand(1));
Args.push_back(II->getOperand(2));
Args.push_back(II->getOperand(3));
......@@ -430,7 +442,9 @@ void CGT_CUDNN::codeGen(DFLeafNode* N) {
else if (II->getIntrinsicID() == Intrinsic::visc_tensor_tanh){
// Create cudnn runtime function call
Constant* tensorTanh;
errs()<<"tensorTanh Call = \n\n";
DECLARE(tensorTanh);
//errs()<<"tensorTanh Call = "<<*tensorTanh<<"\l";
CallInst::Create(tensorTanh, Args, "", II);
}
......
......@@ -96,6 +96,7 @@ void* __visc__tensor_add(void*, void*);
void* __visc__tensor_mul(void*, void*);
void* __visc__tensor_convolution(void*, void*, int, int, int, int);
void* __visc__tensor_pool_max(void*, int, int, int, int, int, int);
void* __visc__tensor_pool_mean(void*, int, int, int, int, int, int);
void* __visc__tensor_relu(void*);
void* __visc__tensor_tanh(void*);
void* __visc__tensor_softmax(void*);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment