diff --git a/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp b/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp index 14a284ef1b29ff3d027d3fa82e602daf028e5185..d08a3ae170df76df7df8f46a628d2022a0f568d1 100644 --- a/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp +++ b/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp @@ -81,19 +81,20 @@ namespace { // VISC Runtime API Module* runtimeModule; - Constant* llvm_visc_ptx_launch; - Constant* llvm_visc_ptx_wait; - Constant* llvm_visc_ptx_initContext; - Constant* llvm_visc_ptx_input_scalar; - Constant* llvm_visc_ptx_input_ptr; - Constant* llvm_visc_ptx_output_ptr; - Constant* llvm_visc_ptx_getOutput; - Constant* llvm_visc_ptx_executeNode; + Function* llvm_visc_ptx_launch; + Function* llvm_visc_ptx_wait; + Function* llvm_visc_ptx_initContext; + Function* llvm_visc_ptx_input_scalar; + Function* llvm_visc_ptx_input_ptr; + Function* llvm_visc_ptx_output_ptr; + Function* llvm_visc_ptx_getOutput; + Function* llvm_visc_ptx_executeNode; //Functions std::string getKernelsModuleName(Module &M); void fixValueAddrspace(Value* V, unsigned addrspace); + Value* getStringPointer(const Twine& S, Instruction* InsertBefore, const Twine& Name = ""); void changeArgAddrspace(Function* F, unsigned i); void addCLMetadata(Function* F); void writeKernelsModule(); @@ -112,6 +113,8 @@ namespace { // Constructor CodeGenTraversal(Module &_M, BuildDFG &_DFG) : M(_M), DFG(_DFG), KernelM(*CloneModule(&_M)) { + // Initialize Runtime API + initRuntimeAPI(); // Copying instead of creating new, in order to preserve required info (metadata) @@ -189,36 +192,36 @@ namespace { DEBUG(errs() << "Successfully loaded visc-rt API module\n"); // Get or insert the global declarations for launch/wait functions - llvm_visc_ptx_launch = M.getOrInsertFunction("llvm_visc_ptx_launch", - runtimeModule->getFunction("llvm_visc_ptx_launch")->getFunctionType()); + llvm_visc_ptx_launch = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_launch", + 
runtimeModule->getFunction("llvm_visc_ptx_launch")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_launch); - llvm_visc_ptx_wait = M.getOrInsertFunction("llvm_visc_ptx_wait", - runtimeModule->getFunction("llvm_visc_ptx_wait")->getFunctionType()); + llvm_visc_ptx_wait = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_wait", + runtimeModule->getFunction("llvm_visc_ptx_wait")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_wait); - llvm_visc_ptx_initContext = M.getOrInsertFunction("llvm_visc_ptx_initContext" , - runtimeModule->getFunction("llvm_visc_ptx_initContext")->getFunctionType()); + llvm_visc_ptx_initContext = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_initContext" , + runtimeModule->getFunction("llvm_visc_ptx_initContext")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_initContext); - llvm_visc_ptx_input_scalar = M.getOrInsertFunction("llvm_visc_ptx_input_scalar", - runtimeModule->getFunction("llvm_visc_ptx_input_scalar")->getFunctionType()); + llvm_visc_ptx_input_scalar = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_input_scalar", + runtimeModule->getFunction("llvm_visc_ptx_input_scalar")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_input_scalar); - llvm_visc_ptx_input_ptr = M.getOrInsertFunction("llvm_visc_ptx_input_ptr", - runtimeModule->getFunction("llvm_visc_ptx_input_ptr")->getFunctionType()); + llvm_visc_ptx_input_ptr = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_input_ptr", + runtimeModule->getFunction("llvm_visc_ptx_input_ptr")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_input_ptr); - llvm_visc_ptx_output_ptr = M.getOrInsertFunction("llvm_visc_ptx_output_ptr", - runtimeModule->getFunction("llvm_visc_ptx_output_ptr")->getFunctionType()); + llvm_visc_ptx_output_ptr = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_output_ptr", + runtimeModule->getFunction("llvm_visc_ptx_output_ptr")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_output_ptr); - llvm_visc_ptx_getOutput = 
M.getOrInsertFunction("llvm_visc_ptx_getOutput", - runtimeModule->getFunction("llvm_visc_ptx_getOutput")->getFunctionType()); + llvm_visc_ptx_getOutput = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_getOutput", + runtimeModule->getFunction("llvm_visc_ptx_getOutput")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_getOutput); - llvm_visc_ptx_executeNode = M.getOrInsertFunction("llvm_visc_ptx_executeNode", - runtimeModule->getFunction("llvm_visc_ptx_executeNode")->getFunctionType()); + llvm_visc_ptx_executeNode = cast<Function>(M.getOrInsertFunction("llvm_visc_ptx_executeNode", + runtimeModule->getFunction("llvm_visc_ptx_executeNode")->getFunctionType())); DEBUG(errs() << *llvm_visc_ptx_executeNode); } @@ -256,7 +259,6 @@ namespace { offset--; } arg = i; - DEBUG(errs() << *F); DEBUG(errs() << *arg <<"\n"); return arg; } @@ -300,6 +302,19 @@ namespace { return inputVal; } + // Generate code declaring a constant, null-terminated string [L x i8] as an internal global in M and return a pointer + // to the start of it: a GEP instruction (indices 0, 0) named Name+"Ptr", inserted before IB. Indices use M's own LLVMContext. + Value* CodeGenTraversal::getStringPointer(const Twine& S, Instruction* IB, const Twine& Name) { + Constant* SConstant = ConstantDataArray::getString(M.getContext(), S.str(), true); + Value* SGlobal = new GlobalVariable(M, SConstant->getType(), true, + GlobalValue::InternalLinkage, SConstant, Name); + Value* Zero = ConstantInt::get(Type::getInt64Ty(M.getContext()), 0); + Value* GEPArgs[] = {Zero, Zero}; + GetElementPtrInst* SPtr = GetElementPtrInst::Create(SGlobal, + ArrayRef<Value*>(GEPArgs, 2), Name+"Ptr", IB); + return SPtr; + } + // Generate Code to call the kernel // The plan is to replace the internal node with a leaf node. This method is // used to generate a function to associate with this leaf node. The function @@ -310,6 +325,8 @@ namespace { // function before and nothing else needs to be done for this leaf node. 
assert(N->getGenFunc() != NULL && "Code already generated for this node"); + DEBUG(errs() << "Generating kernel call code\n"); + Function* F = N->getFuncPointer(); @@ -344,7 +361,7 @@ namespace { // FIXME: Adding Index and Dim arguments are probably not required except // for consistency purpose (DFG2LLVM_X86 does assume that all leaf nodes do // have those arguments) - + // Add Index and Dim arguments except for the root node if(!N->isRoot()) addIdxDimArgs(F_X86); @@ -356,7 +373,7 @@ namespace { // and Exit dummy nodes). This child is the PTX kernel. This simplifies code // generation for kernel calls significantly. All the inputs to this child // node would either be constants or from the parent node N. - + assert(N->getChildGraph()->size() == 3 && "Node expected to have just one non-dummy node!"); @@ -373,22 +390,25 @@ namespace { Function* CF = C->getFuncPointer(); // Initialize context + DEBUG(errs() << "Initializing context" << "\n"); CallInst::Create(llvm_visc_ptx_initContext, None, "", RI); + DEBUG(errs() << "Initializing commandQ" << "\n"); // Initialize command queue - Constant* file = ConstantDataArray::get(M.getContext(), - ArrayRef<uint8_t>((uint8_t*)FileName.str().c_str(), FileName.str().length())); - - Constant* kernel = ConstantDataArray::get(M.getContext(), - ArrayRef<uint8_t>((uint8_t*)KernelName.str().c_str(), KernelName.str().length())); + Value* file = getStringPointer(FileName, RI, "Filename"); + Value* kernel = getStringPointer(KernelName, RI,"KernelName"); Value* LaunchInstArgs[] = {file, kernel}; + + DEBUG(errs() << "Inserting launch call" << "\n"); CallInst* GraphID = CallInst::Create(llvm_visc_ptx_launch, ArrayRef<Value*>(LaunchInstArgs, 2), "graph"+CF->getName(), RI); + DEBUG(errs() << *GraphID << "\n"); // Iterate over the required input edges of the node and use the visc-rt API // to set inputs + DEBUG(errs() << "Iterate over input edges of node and insert visc api\n"); for(unsigned i=0; i<CF->getFunctionType()->getNumParams(); i++) { 
Value* inputVal = getInValueAt(C, i, F_X86, RI); @@ -398,21 +418,35 @@ namespace { // type on target machine, but for pointers, the size of data would be the // next integer argument if(inputVal->getType()->isPointerTy()) { + Value* inputValI8Ptr = CastInst::CreatePointerCast(inputVal, + Type::getInt8PtrTy(M.getContext()), + inputVal->getName()+".i8ptr", + RI); // Pointer Input Value* inputSize = getInValueAt(C, i+1, F_X86, RI); - assert(inputSize->getType()->isIntegerTy() + assert(inputSize->getType() == Type::getInt64Ty(M.getContext()) && "Pointer type input must always be followed by size (integer type)"); Value* setInputArgs[] = {GraphID, - inputVal, + inputValI8Ptr, ConstantInt::get(Type::getInt32Ty(M.getContext()),i), - inputSize + inputSize }; CallInst::Create(llvm_visc_ptx_input_ptr, ArrayRef<Value*>(setInputArgs, 4), "", RI); } else { // Scalar Input + // Store the scalar value on stack and then pass the pointer to its + // location + AllocaInst* inputValPtr = new AllocaInst(inputVal->getType(), inputVal->getName()+".ptr", RI); + StoreInst* SI = new StoreInst(inputVal, inputValPtr, RI); + + Value* inputValI8Ptr = CastInst::CreatePointerCast(inputValPtr, + Type::getInt8PtrTy(M.getContext()), + inputVal->getName()+".i8ptr", + RI); + Value* setInputArgs[] = {GraphID, - inputVal, + inputValI8Ptr, ConstantInt::get(Type::getInt32Ty(M.getContext()),i), ConstantExpr::getSizeOf(inputVal->getType()) }; @@ -421,6 +455,7 @@ namespace { } } + DEBUG(errs() << "Setup output edges of node and insert visc api\n"); // Setup output // FIXME: Note - There is a tricky question. 
In X86 we do not need to care // about pointer inputs which modify data in memory implicitly (without @@ -472,24 +507,31 @@ namespace { RI); // Read each device pointer listed in output struct // Load the output struct - CastInst* BI = BitCastInst::CreatePointerCast(h_Output, CF->getReturnType(), "output.ptr", RI); + CastInst* BI = BitCastInst::CreatePointerCast(h_Output, CF->getReturnType()->getPointerTo(), "output.ptr", RI); Value* KernelOutput = new LoadInst(BI, "", RI); for(unsigned i=0; i < OutputTy->getNumElements(); i++) { Type* elemTy = OutputTy->getElementType(i); if(elemTy->isPointerTy()) { // Pointer type - assert(OutputTy->getElementType(i+1)->isIntegerTy() + assert(OutputTy->getElementType(i+1) == Type::getInt64Ty(M.getContext()) && "Every Pointer type must be followed by an integer"); ExtractValueInst* d_ptr = ExtractValueInst::Create(KernelOutput, ArrayRef<unsigned>(i), "", RI); + // Change d_ptr to i8* + CastInst* d_ptr_i8 = BitCastInst::CreatePointerCast(d_ptr, Type::getInt8PtrTy(M.getContext()), "", RI); ExtractValueInst* len = ExtractValueInst::Create(KernelOutput, ArrayRef<unsigned>(i+1), "", RI); // GetOutputPtr call Value* GetOutputArgs[] = {GraphID, - d_ptr, + d_ptr_i8, len}; - CallInst* h_ptr = CallInst::Create(llvm_visc_ptx_getOutput, + CallInst* h_ptr_i8 = CallInst::Create(llvm_visc_ptx_getOutput, ArrayRef<Value*>(GetOutputArgs, 3), "", RI); + // Change h_ptr to correct type + CastInst* h_ptr = CastInst::CreatePointerCast(h_ptr_i8, + cast<StructType>(KernelOutput->getType())->getElementType(i), + "", + RI); KernelOutput = InsertValueInst::Create(KernelOutput, h_ptr, ArrayRef<unsigned>(i), "", RI); } @@ -568,6 +610,9 @@ namespace { // Now the remaining nodes to be visited should be ignored KernelLaunchNode = NULL; writeKernelsModule(); + errs() << "Insert Runtime calls\n"; + insertRuntimeCalls(N, getKernelsModuleName(M), "matrixMul"); + } else { DEBUG(errs() << "Found intermediate node. Generating device code.\n"); //TODO