diff --git a/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp b/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp index fced0b3ccd389305e13e7eb23a2f60b9c492e0f1..e8f47ff5a0b357c76edb3a1d6022e84ac3841197 100644 --- a/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp +++ b/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp @@ -287,40 +287,35 @@ void CGT_NVDLA::initRuntimeAPI() { // Nothing to do here! } +// NOTE: TensorPtr is a pointer to the readTrainedWeights call instruction in LLVM/HPVM-C Weights CGT_NVDLA::readTrainedWeights(User *TensorPtr, - int dim1_size, int dim2_size, - int dim3_size, int dim4_size) { + int dim1_size, int dim2_size, + int dim3_size, int dim4_size) { + DEBUG(errs() << "READ TRAINED WEIGHTS\n"); // Get weights file name User *MemcpyPtr = dyn_cast<User>(TensorPtr->getOperand(0)); DEBUG(MemcpyPtr->print(errs())); DEBUG(errs() << "\n"); - while(!dyn_cast<AllocaInst>(MemcpyPtr)) { - MemcpyPtr = dyn_cast<User>(MemcpyPtr->getOperand(0)); - } - User *MemcpyArg = nullptr; - for(User *U: MemcpyPtr->users()) { - DEBUG(U->print(errs())); - DEBUG(errs() << "\n"); - if(auto *BCO = dyn_cast<BitCastOperator>(U)) { - for(User *CU: BCO->users()) { - if(auto *CI = dyn_cast<CallInst>(CU)) { - CI->getCalledFunction()->getName().contains(StringRef("memcpy")); - MemcpyArg = dyn_cast<User>(CI->getOperand(1)); - break; - } - } - if(MemcpyArg) - break; - } - } - assert(MemcpyArg && "File name not found."); - auto *WeightFileName = dyn_cast<GlobalVariable>(MemcpyArg->getOperand(0)); - assert(WeightFileName && "Weight file name must be a global variable."); - auto* CDA = dyn_cast<ConstantDataArray>(WeightFileName->getInitializer()); - assert(CDA && "Weight file name must be a constant array."); - const auto &file_name = std::string(CDA->getAsString()); - + errs()<<" MemcpyPtr = "<< *MemcpyPtr <<" \n"; + + // Look through bitcast instructions and geps. + Value* V = MemcpyPtr->stripPointerCasts(); + errs()<< *V << ": " << *V->getType() <<"\n"; + const GlobalVariable *GV = dyn_cast<GlobalVariable>(V); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + errs() << "ERROR: NOT GV ---- \n" ; + else + errs() << "GLOBAL VARIABLE = \n" << *GV; + + const ConstantDataArray *Array = + dyn_cast<ConstantDataArray>(GV->getInitializer()); + + + const auto &file_name = std::string(Array->getAsString()); + errs() << "filename = " << file_name <<"\n"; + + // Read the weights file int num_elem = dim1_size * dim2_size * dim3_size * dim4_size; int size_in_bytes = sizeof(float16) * num_elem;