diff --git a/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp b/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp index 425cabd0b485db37c9ea1481e05905f43aef8523..dfc047a354742d5d9b01c192159d478be027b0cd 100644 --- a/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp +++ b/llvm/lib/Transforms/DFG2LLVM_NVPTX/DFG2LLVM_NVPTX.cpp @@ -527,22 +527,26 @@ void CodeGenTraversal::insertRuntimeCalls(DFInternalNode* N, const Twine& FileNa } DEBUG(errs() << "Setup output edges of node and insert visc api\n"); - // Set output + // Set output if struct is not an empty struct StructType* OutputTy = C->getOutputType(); - unsigned outputIndex = CF->getFunctionType()->getNumParams(); - Value* outputSize = ConstantExpr::getSizeOf(OutputTy); - Value* setOutputArgs[] = {GraphID, - Constant::getNullValue(Type::getInt8PtrTy(M.getContext())), - ConstantInt::get(Type::getInt32Ty(M.getContext()),outputIndex), - ConstantExpr::getSizeOf(OutputTy), - False, - True - }; - - CallInst* d_Output = CallInst::Create(llvm_visc_ptx_argument_ptr, - ArrayRef<Value*>(setOutputArgs, 6), - "d_output."+CF->getName(), - RI); + Value *outputSize, *d_Output; + if(!OutputTy->isEmptyTy()) { + // Not an empty struct + unsigned outputIndex = CF->getFunctionType()->getNumParams(); + outputSize = ConstantExpr::getSizeOf(OutputTy); + Value* setOutputArgs[] = {GraphID, + Constant::getNullValue(Type::getInt8PtrTy(M.getContext())), + ConstantInt::get(Type::getInt32Ty(M.getContext()),outputIndex), + ConstantExpr::getSizeOf(OutputTy), + False, + True + }; + + d_Output = CallInst::Create(llvm_visc_ptx_argument_ptr, + ArrayRef<Value*>(setOutputArgs, 6), + "d_output."+CF->getName(), + RI); + } // Enqueue kernel // Need work dim, localworksize, globalworksize @@ -568,21 +572,23 @@ void CodeGenTraversal::insertRuntimeCalls(DFInternalNode* N, const Twine& FileNa "", RI); - // Read Output Struct - Value* GetOutputArgs[] = {GraphID, - Constant::getNullValue(Type::getInt8PtrTy(M.getContext())), - d_Output, - outputSize - }; - CallInst* h_Output = CallInst::Create(llvm_visc_ptx_getOutput, - ArrayRef<Value*>(GetOutputArgs, 4), - "h_output."+CF->getName()+".addr", - RI); - // Read each device pointer listed in output struct - // Load the output struct - CastInst* BI = BitCastInst::CreatePointerCast(h_Output, CF->getReturnType()->getPointerTo(), "output.ptr", RI); - Value* KernelOutput = new LoadInst(BI, "output."+CF->getName(), RI); - OutputMap[C] = KernelOutput; + // Read Output Struct if not empty + if(!OutputTy->isEmptyTy()) { + Value* GetOutputArgs[] = {GraphID, + Constant::getNullValue(Type::getInt8PtrTy(M.getContext())), + d_Output, + outputSize + }; + CallInst* h_Output = CallInst::Create(llvm_visc_ptx_getOutput, + ArrayRef<Value*>(GetOutputArgs, 4), + "h_output."+CF->getName()+".addr", + RI); + // Read each device pointer listed in output struct + // Load the output struct + CastInst* BI = BitCastInst::CreatePointerCast(h_Output, CF->getReturnType()->getPointerTo(), "output.ptr", RI); + Value* KernelOutput = new LoadInst(BI, "output."+CF->getName(), RI); + OutputMap[C] = KernelOutput; + } // Read all the pointer arguments which had side effects i.e., had out // attribute @@ -1160,7 +1166,7 @@ void CodeGenTraversal::transformFunctionToVoid(Function* F) { // Check for { } return struct, which means that the function returns void - if (FRetTy->getNumElements() == 0) { + if (FRetTy->isEmptyTy()) { DEBUG(errs() << "\tFunction output struct is void\n"); DEBUG(errs() << "\tNo parameters added\n"); @@ -1172,17 +1178,45 @@ void CodeGenTraversal::transformFunctionToVoid(Function* F) { (*i)->eraseFromParent(); } DEBUG(errs() << "\tChanged return statements to return void\n"); - - return; } + else { + // The struct has return values, thus needs to be converted to parameter - // The struct has return values, thus needs to be converted to parameter + int initialNumParams = F->arg_size(); - int initialNumParams = F->arg_size(); + Type* ArgType = FRetTy->getPointerTo(GENERIC_ADDRSPACE); + new Argument(ArgType, "ret_struct_ptr", F); + DEBUG(errs() << "\tCreated parameter\n"); - Type* ArgType = FRetTy->getPointerTo(GENERIC_ADDRSPACE); - new Argument(ArgType, "ret_struct_ptr", F); - DEBUG(errs() << "\tCreated parameter\n"); + // Find where the new parameter is in the header + Function::arg_iterator ai, ae; + int check = 0; + for (ai = F->arg_begin(), ae = F->arg_end(); + ai != ae; ++ai) { + if (ai->getName().equals("ret_struct_ptr")) break; + check++; + } + + // DEBUG(errs() << "\tcheck = " << check << "\tinitialNumParams = " << initialNumParams << "\n"); + assert(check == initialNumParams); + + DEBUG(errs() << "\tReplacing Return statements\n"); + // Replace return statements with extractValue and store instructions + for (std::vector<ReturnInst *>::iterator rii = RItoRemove.begin(), + rie = RItoRemove.end(); rii != rie; ++rii) { + ReturnInst* RI = (*rii); + Value* RetVal = RI->getReturnValue(); + // assert(RetVal && "Return value should not be null at this point"); + // StructType* RetType = cast<StructType>(RetVal->getType()); + // assert(RetType && "Return type is not a struct"); + + new StoreInst(RetVal, &(*ai), RI); + ReturnInst::Create((F->getContext()), 0, RI); + RI->eraseFromParent(); + + } + } + DEBUG(errs() << "\tReplaced return statements\n"); // Create the argument type list with the added argument's type std::vector<Type*> ArgTypes; @@ -1191,36 +1225,6 @@ void CodeGenTraversal::transformFunctionToVoid(Function* F) { ArgTypes.push_back(ai->getType()); } - // Find where the new parameter is in the header - Function::arg_iterator ai, ae; - int check = 0; - for (ai = F->arg_begin(), ae = F->arg_end(); - ai != ae; ++ai) { - if (ai->getName().equals("ret_struct_ptr")) break; - check++; - } - -// DEBUG(errs() << "\tcheck = " << check << "\tinitialNumParams = " << initialNumParams << "\n"); - assert(check == initialNumParams); - - DEBUG(errs() << "\tReplacing Return statements\n"); - // Replace return statements with extractValue and store instructions - for (std::vector<ReturnInst *>::iterator rii = RItoRemove.begin(), - rie = RItoRemove.end(); rii != rie; ++rii) { - ReturnInst* RI = (*rii); - Value* RetVal = RI->getReturnValue(); - // assert(RetVal && "Return value should not be null at this point"); - // StructType* RetType = cast<StructType>(RetVal->getType()); - // assert(RetType && "Return type is not a struct"); - - new StoreInst(RetVal, &(*ai), RI); - ReturnInst::Create((F->getContext()), 0, RI); - RI->eraseFromParent(); - - } - - DEBUG(errs() << "\tReplaced return statements\n"); - // Adding new arguments to the function argument list, would not change the // function type. We need to change the type of this function to reflect the // added arguments diff --git a/llvm/test/VISC/MatrixMultiplication/Makefile b/llvm/test/VISC/MatrixMultiplication/Makefile index 9ca473645d1baa081e5dd34e257564a175245e40..e1ece78f982d2f363852ec0006b667bbdfd1b585 100644 --- a/llvm/test/VISC/MatrixMultiplication/Makefile +++ b/llvm/test/VISC/MatrixMultiplication/Makefile @@ -27,4 +27,4 @@ $(HOST:%=%.bin):%.bin:%.c $(LLVM_CC) -O3 -lOpenCL -I /usr/local/cuda/include $< -o $@ clean : - rm -f $(HOST).ll $(KERNELS).ll *.bc *.s *.bin *.kernels.ll + rm -f $(HOST).ll $(KERNELS).ll *.bc *.s *.bin *.kernels.ll DataflowGraph.dot* diff --git a/llvm/test/VISC/MatrixMultiplication/visc_gemm.ll b/llvm/test/VISC/MatrixMultiplication/visc_gemm.ll index fe43d79625d2c11e10daa2f434d4d52984bfdc7e..b8191de623e10744ccf1672391f3ccf2a84f55e4 100644 --- a/llvm/test/VISC/MatrixMultiplication/visc_gemm.ll +++ b/llvm/test/VISC/MatrixMultiplication/visc_gemm.ll @@ -131,7 +131,7 @@ declare i32 @printf(i8* nocapture, ...) #1 ; --------------- VISC Intrinsics --------------- ; Return Type of VISC Compute Matrix Mul -%rtype = type {float*, i64} +%rtype = type {} %struct.arg = type <{ float*, i64, float*, i64, float*, i64, i32, i32, i32, %rtype }> ; Function Attrs: nounwind @@ -178,7 +178,7 @@ declare void @llvm.visc.bind.output(i8*, i32, i32) ; ----------------- VISC intrinsics end ------------------ ; Function Attrs: nounwind uwtable -define %rtype @matrixMul(float* nocapture %A, i64 %bytes_A, float* nocapture %B, i64 %bytes_B, float* %C, i64 %bytes_C, i32 %k, i32 %n, i32 %m) #0 { +define %rtype @matrixMul(float* nocapture in %A, i64 %bytes_A, float* nocapture in %B, i64 %bytes_B, float* out %C, i64 %bytes_C, i32 %k, i32 %n, i32 %m) #0 { entry: ;%puts = tail call i32 @puts(i8* getelementptr inbounds ([17 x i8]* @str, i64 0, i64 0)) @@ -231,13 +231,11 @@ for.end: ; preds = %for.body, %entry store float %res.0.lcssa, float* %arrayidx19, align 4, !tbaa !0 ;%puts42 = tail call i32 @puts(i8* getelementptr inbounds ([20 x i8]* @str11, i64 0, i64 0)) ;%puts43 = tail call i32 @puts(i8* getelementptr inbounds ([17 x i8]* @str12, i64 0, i64 0)) - %.fca.0.insert = insertvalue %rtype undef, float* %C, 0 - %.fca.1.insert = insertvalue %rtype %.fca.0.insert, i64 %bytes_C, 1 - ret %rtype %.fca.1.insert + ret %rtype undef } ; ----------------- VISC SGEMM root node ---------------- -define %rtype @MatrixMulRoot(float* %h_A, i64 %bytes_A, float* %h_B, i64 %bytes_B, float* %h_C, i64 %bytes_C, i32 %WA, i32 %WB, i32 %HA) { +define %rtype @MatrixMulRoot(float* in %h_A, i64 %bytes_A, float* in %h_B, i64 %bytes_B, float* out %h_C, i64 %bytes_C, i32 %WA, i32 %WB, i32 %HA) { %kernel = call i8* @llvm.visc.createNode2D(i8* bitcast (%rtype (float*, i64, float*, i64, float*, i64, i32, i32, i32)* @matrixMul to i8*), i32 %WB, i32 %HA) ; Bind Inputs call void @llvm.visc.bind.input(i8* %kernel, i32 0, i32 0); h_A @@ -250,9 +248,7 @@ define %rtype @MatrixMulRoot(float* %h_A, i64 %bytes_A, float* %h_B, i64 %bytes_ call void @llvm.visc.bind.input(i8* %kernel, i32 7, i32 7); WB = WC = n call void @llvm.visc.bind.input(i8* %kernel, i32 8, i32 8); HA = HC = m ; Bind Outputs - call void @llvm.visc.bind.output(i8* %kernel, i32 0, i32 0); d_C - call void @llvm.visc.bind.output(i8* %kernel, i32 1, i32 1); bytes_C - ret %rtype zeroinitializer + ret %rtype undef } ; Function Attrs: noinline nounwind uwtable @@ -368,8 +364,7 @@ randomInit.exit41: ; preds = %for.body.i40 %out = load %rtype* %out.addr ; -------------------------------- Completed VISC Launch Call -------------------------------- - %3 = extractvalue %rtype %out, 0 - %call14 = tail call i32 @checkResults(float* %0, float* %1, float* %3) + %call14 = tail call i32 @checkResults(float* %0, float* %1, float* %2) %tobool = icmp eq i32 %call14, 0 br i1 %tobool, label %if.else, label %if.then diff --git a/llvm/test/VISC/MatrixMultiplication/visc_gemm_ptx.ll b/llvm/test/VISC/MatrixMultiplication/visc_gemm_ptx.ll index c76076a8726fcd1ff9fa9d355b1a0153e9c64e16..f0d0837f718fe37b92a5a8cd69470c7ae37cfa17 100644 --- a/llvm/test/VISC/MatrixMultiplication/visc_gemm_ptx.ll +++ b/llvm/test/VISC/MatrixMultiplication/visc_gemm_ptx.ll @@ -136,7 +136,7 @@ declare i32 @printf(i8* nocapture, ...) #1 ; --------------- VISC Intrinsics --------------- ; Return Type of VISC Compute Matrix Mul -%rtype = type {i64} +%rtype = type {} %struct.arg = type <{ float*, i64, float*, i64, float*, i64, i32, i32, i32, %rtype }> ; Function Attrs: nounwind @@ -236,8 +236,7 @@ for.end: ; preds = %for.body, %entry store float %res.0.lcssa, float* %arrayidx19, align 4, !tbaa !0 ;%puts42 = tail call i32 @puts(i8* getelementptr inbounds ([20 x i8]* @str11, i64 0, i64 0)) ;%puts43 = tail call i32 @puts(i8* getelementptr inbounds ([17 x i8]* @str12, i64 0, i64 0)) - %.fca.1.insert = insertvalue %rtype undef, i64 %bytes_C, 0 - ret %rtype %.fca.1.insert + ret %rtype undef } ; ----------------- VISC SGEMM root node ---------------- @@ -254,8 +253,7 @@ define %rtype @MatrixMulRoot(float* %h_A, i64 %bytes_A, float* %h_B, i64 %bytes_ call void @llvm.visc.bind.input(i8* %kernel, i32 7, i32 7); WB = WC = n call void @llvm.visc.bind.input(i8* %kernel, i32 8, i32 8); HA = HC = m ; Bind Outputs - call void @llvm.visc.bind.output(i8* %kernel, i32 0, i32 0); bytes_C - ret %rtype zeroinitializer + ret %rtype undef } ; Function Attrs: noinline nounwind uwtable @@ -371,7 +369,6 @@ randomInit.exit41: ; preds = %for.body.i40 %out = load %rtype* %out.addr ; -------------------------------- Completed VISC Launch Call -------------------------------- - %3 = extractvalue %rtype %out, 0 %call14 = tail call i32 @checkResults(float* %0, float* %1, float* %2) %tobool = icmp eq i32 %call14, 0 br i1 %tobool, label %if.else, label %if.then