From 5a80468b48112e24223c51eb29eebd1fe83b547a Mon Sep 17 00:00:00 2001 From: Hashim Sharif <hsharif3@miranda.cs.illinois.edu> Date: Fri, 25 Jun 2021 02:14:48 -0500 Subject: [PATCH] ReadTrainedWeights filename extraction fixed -- still not working Pass (close) --- .../Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp | 51 +++++++++---------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp b/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp index fced0b3ccd..e8f47ff5a0 100644 --- a/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp +++ b/hpvm/lib/Transforms/HPVM2NVDLA/HPVM2NVDLAPass.cpp @@ -287,40 +287,35 @@ void CGT_NVDLA::initRuntimeAPI() { // Nothing to do here! } +// NOTE: TensorPtr is a pointer to the readTrainedWeights call instruction in LLVM/HPVM-C Weights CGT_NVDLA::readTrainedWeights(User *TensorPtr, - int dim1_size, int dim2_size, - int dim3_size, int dim4_size) { + int dim1_size, int dim2_size, + int dim3_size, int dim4_size) { + DEBUG(errs() << "READ TRAINED WEIGHTS\n"); // Get weights file name User *MemcpyPtr = dyn_cast<User>(TensorPtr->getOperand(0)); DEBUG(MemcpyPtr->print(errs())); DEBUG(errs() << "\n"); - while(!dyn_cast<AllocaInst>(MemcpyPtr)) { - MemcpyPtr = dyn_cast<User>(MemcpyPtr->getOperand(0)); - } - User *MemcpyArg = nullptr; - for(User *U: MemcpyPtr->users()) { - DEBUG(U->print(errs())); - DEBUG(errs() << "\n"); - if(auto *BCO = dyn_cast<BitCastOperator>(U)) { - for(User *CU: BCO->users()) { - if(auto *CI = dyn_cast<CallInst>(CU)) { - CI->getCalledFunction()->getName().contains(StringRef("memcpy")); - MemcpyArg = dyn_cast<User>(CI->getOperand(1)); - break; - } - } - if(MemcpyArg) - break; - } - } - assert(MemcpyArg && "File name not found."); - auto *WeightFileName = dyn_cast<GlobalVariable>(MemcpyArg->getOperand(0)); - assert(WeightFileName && "Weight file name must be a global variable."); - auto* CDA = dyn_cast<ConstantDataArray>(WeightFileName->getInitializer()); - assert(CDA && "Weight file name must be a constant array."); - const auto &file_name = std::string(CDA->getAsString()); - + errs()<<" MemcpyPtr = "<< *MemcpyPtr <<" \n"; + + // Look through bitcast instructions and geps. + Value* V = MemcpyPtr->stripPointerCasts(); + errs()<< *V << ": " << *V->getType() <<"\n"; + const GlobalVariable *GV = dyn_cast<GlobalVariable>(V); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + errs() << "ERROR: NOT GV ---- \n" ; + else + errs() << "GLOBAL VARIABLE = \n" << *GV; + + const ConstantDataArray *Array = + dyn_cast<ConstantDataArray>(GV->getInitializer()); + + + const auto &file_name = std::string(Array->getAsString()); + errs() << "filename = " << file_name <<"\n"; + + // Read the weights file int num_elem = dim1_size * dim2_size * dim3_size * dim4_size; int size_in_bytes = sizeof(float16) * num_elem; -- GitLab