Skip to content
Snippets Groups Projects
Commit dda0bb4e authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Handling visc.node.id appropriately in Fusion Pass

parent be677ed3
No related branches found
No related tags found
No related merge requests found
......@@ -78,17 +78,36 @@ static bool isIncomingEdgeArgument(unsigned argno,
return false;
}
// Check that this is a valid HPVM Tensor Node (starts with an HPVM intrinsic)
// Return the node intrinsic function
static IntrinsicInst *isValidHPVMTensorNode(DFNode *N) {
Function *F = N->getFuncPointer();
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*(inst_begin(F)));
assert(II &&
"HPVM tensor intrinsic expected as first instruction of HPVM tensor node\n");
assert((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor") &&
"Only HPVM tensor intrinsics allowed in ApproxHPVM leaf nodes\n");
//IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*(inst_begin(F)));
IntrinsicInst *II;
for (auto I = inst_begin(F), E = inst_end(F); I != E; I++){
if(dyn_cast<IntrinsicInst>(&*I)){
II = dyn_cast<IntrinsicInst>(&*I);
if ((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor")){
errs()<<"**** WATCH *** " << *II << "\n\n\n";
}
}
}
//assert(II &&
// "HPVM tensor intrinsic expected as first instruction of HPVM tensor node\n");
//assert((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor") &&
// "Only HPVM tensor intrinsics allowed in ApproxHPVM leaf nodes\n");
return II;
}
// Returns the next node in a node sequence, or NULL if it does not exist.
// We consider two nodes a sequence if SrcN has a single successor, DstN,
......@@ -340,9 +359,9 @@ Function* FuseHPVMTensorNodes::createEmptyDFNodeFunction(IntrinsicInst* II1,
the body of the fused function instead *
* OutVs: This maps the output struct field index to the stored value */
void FuseHPVMTensorNodes::inlineFirstNodeFunction(Module &M, Function *F1,
Function *Ffused,
ValueMap<Value*, Value*> &VMap,
std::vector<Value*> &OutVs) {
Function *Ffused,
ValueMap<Value*, Value*> &VMap,
std::vector<Value*> &OutVs) {
ReturnInst *RI = cast<ReturnInst>(Ffused->getEntryBlock().getTerminator());
......@@ -356,8 +375,9 @@ void FuseHPVMTensorNodes::inlineFirstNodeFunction(Module &M, Function *F1,
}
IntrinsicInst* II = dyn_cast<IntrinsicInst>(I);
assert((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor")
&& "Only HPVM tensor intrinsics allowed in ApproxHPVM leaf nodes\n");
assert ( ((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor")
|| (II->getCalledFunction()->getName()).startswith("llvm.visc.node.id") )
&& "Only HPVM tensor intrinsics allowed in ApproxHPVM leaf nodes\n");
std::vector<Value*> Args;
for(unsigned i = 0; i < II->getNumArgOperands(); i++) {
......@@ -370,6 +390,7 @@ void FuseHPVMTensorNodes::inlineFirstNodeFunction(Module &M, Function *F1,
Args.push_back(VMap[V]);
}
}
Function *F = Intrinsic::getDeclaration(&M, II->getIntrinsicID());
CallInst* CI =
CallInst::Create(F, Args,
......@@ -409,9 +430,14 @@ void FuseHPVMTensorNodes::inlineSecondNodeFunction(Module &M, Function *F2,
Instruction *I = &(*f2_i);
if ((BuildDFG::isViscIntrinsic(I))) {
IntrinsicInst* II = dyn_cast<IntrinsicInst>(I);
assert((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor")
assert( ((II->getCalledFunction()->getName()).startswith("llvm.visc.tensor")
|| (II->getCalledFunction()->getName()).startswith("llvm.visc.node.id"))
&& "Only HPVM tensor intrinsics allowed in ApproxHPVM leaf nodes\n");
if ( (II->getCalledFunction()->getName()).startswith("llvm.visc.node.id")) {
continue; // Skip adding visc.node.id calls in nodes other than first node
}
std::vector<Value*> Args;
for(unsigned i = 0; i < II->getNumArgOperands(); i++) {
Value *V = II->getArgOperand(i);
......@@ -506,10 +532,11 @@ Function* FuseHPVMTensorNodes::createLeafDFNodeFunction(IntrinsicInst* II1,
++fused_arg_it;
}
// for(const auto& v: FusedValueMap) {
// errs() << "key = " << *(v.first) << "\t";
// errs() << "value = " << *(v.second) << "\n";
// }
// for(const auto& v: FusedValueMap) {
// errs() << "key = " << *(v.first) << "\t";
// errs() << "value = " << *(v.second) << "\n";
// }
// Invoke function that inlines F1 into Ffused, using and updating mappings
inlineFirstNodeFunction(M, F1, Ffused, FusedValueMap, OutValues);
......@@ -670,6 +697,7 @@ void FuseHPVMTensorNodes::updateParentNodeFunction(IntrinsicInst* II1,
DEBUG(errs() << "Erasing: " << **ib << "\n");
(*ib)->eraseFromParent();
}
II2->replaceAllUsesWith(IInew);
II2->eraseFromParent();
......@@ -792,6 +820,7 @@ void FindFusionTargetsTraversal::codeGen(DFInternalNode *N) {
return;
}
void FindFusionTargetsTraversal::codeGen(DFLeafNode *N) {
DEBUG(errs() << "Inside leaf node: "
<< N->getFuncPointer()->getName() << "\n");
......@@ -802,9 +831,9 @@ void FindFusionTargetsTraversal::codeGen(DFLeafNode *N) {
return;
}
// if(N->getTargetHint() != visc::PROMISE_TARGET) {
if(!preferredTargetIncludes(N, visc::PROMISE_TARGET)) {
// Only fuse if we plan to target PROMISE
// Only fuse if we plan to target PROMISE/Layers API
// The CUDNN backend would be able to generate calls for the fused node,
// but not the other way around
DEBUG(errs() << "No PROMISE hint. Skipping node: "
......@@ -820,6 +849,14 @@ void FindFusionTargetsTraversal::codeGen(DFLeafNode *N) {
std::vector<IntrinsicInst*> CurrentNodeSequence;
switch(II->getIntrinsicID()) {
/*case Intrinsic::visc_node_id:
{ // Found beginning of pattern conv-bias-activation-pooling.
}
break;
*/
case Intrinsic::visc_tensor_convolution:
{ // Found beginning of pattern conv-bias-activation-pooling.
// Look for the rest
......@@ -931,9 +968,9 @@ void FindFusionTargetsTraversal::codeGen(DFLeafNode *N) {
}
bool FuseHPVMTensorNodesWrapper::runOnModule(Module &M) {
errs() << "\nFUSE HPVM TENSOR NODES PASS\n";
// Get the BuildDFG Analysis Results:
errs() << "\nFUSE HPVM TENSOR NODES PASS\n";
// Get the BuildDFG Analysis Results:
// - Dataflow graph
BuildDFG &DFG = getAnalysis<BuildDFG>();
......@@ -952,7 +989,7 @@ bool FuseHPVMTensorNodesWrapper::runOnModule(Module &M) {
FuseHPVMTensorNodes::FusionTargets &FTs = FTTVisitor->getFusionTargets();
FuseHPVMTensorNodes Fuse;
// Fuse.printFusionTargets(FTs);
// Fuse.printFusionTargets(FTs);
Fuse.run(M, FTs);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment