diff --git a/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp b/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp index cb6a2967bbce1117cfda470179bc7cf3ba5d6fce..1526989e82f72e2d9805714bfaae2277ebca730d 100644 --- a/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp +++ b/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp @@ -90,6 +90,7 @@ namespace { //Functions std::string getKernelsModuleName(Module &M); + void fixValueAddrspace(Value* V, unsigned addrspace); void changeArgAddrspace(Function* F, unsigned i); void addCLMetadata(Function* F); void writeKernelsModule(); @@ -841,7 +842,6 @@ namespace { re = IItoRemove.rend(); ri != re; ++ri) (*ri)->eraseFromParent(); - changeArgAddrspace(F_nvptx, 1); addCLMetadata(F_nvptx); DEBUG(errs() << KernelM); @@ -874,7 +874,39 @@ namespace { return mid.append("_kernels.ll"); } - void CodeGenTraversal::changeArgAddrspace(Function* F, unsigned i) { + void CodeGenTraversal::fixValueAddrspace(Value* V, unsigned addrspace) { + assert(isa<PointerType>(V->getType()) + && "Value should be of Pointer Type!"); + PointerType* OldTy = cast<PointerType>(V->getType()); + PointerType* NewTy = PointerType::get(OldTy->getElementType(), addrspace); + V->mutateType(NewTy); + for(Value::use_iterator ui = V->use_begin(), ue = V->use_end(); ui != ue; ui++) { + // Change all uses producing pointer type in same address space to new + // addressspace. + if(PointerType* PTy = dyn_cast<PointerType>(ui->getType())) { + if(PTy->getAddressSpace() == OldTy->getAddressSpace()) { + fixValueAddrspace(*ui, addrspace); + } + } + } + } + + void CodeGenTraversal::changeArgAddrspace(Function* F, unsigned addrspace) { + std::vector<Type*> ArgTypes; + for(auto& arg: F->getArgumentList()) { + DEBUG(errs() << arg << "\n"); + if(PointerType* argTy = dyn_cast<PointerType>(arg.getType())) { + if(argTy->getAddressSpace() == 0) { + fixValueAddrspace(&arg, addrspace); + } + } + ArgTypes.push_back(arg.getType()); + } + FunctionType* FTy = FunctionType::get(F->getReturnType(), ArgTypes, false); + PointerType* PTy = FTy->getPointerTo(cast<PointerType>(F->getType())->getAddressSpace()); + + F->mutateType(PTy); + DEBUG(errs() << *F->getFunctionType() << "\n" <<*F << "\n"); } void CodeGenTraversal::addCLMetadata(Function* F) {