//=== ReplaceApproxHPVMIntrinsicsWithFCalls.cpp ===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#define ENABLE_ASSERTS

#define DEBUG_TYPE "REPLACE_APPROXHPVM_INTRINSICS_WITH_FCALLS"

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/IR/Attributes.h"
#include "llvm-c/Core.h"

#include "SupportHPVM/DFG2LLVM.h"
#include "InPlaceDFG/InPlaceDFGAnalysis.h"

#include <sstream>

using namespace llvm;
using namespace builddfg;
using namespace dfg2llvm;

// TODO: In-place analysis is still needed as long as the replacement calls
// preserve the same in-place interface as the intrinsics they replace.
using namespace inplacedfg;
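// Overview: this pass visits every dataflow graph rooted in the module and,
// for each CPU-targeted leaf node, rewrites the llvm.hpvm.tensor.* intrinsics
// in its body into calls to CPU tensor-runtime functions (tensorGemmCPU,
// tensorAddCPU, ...). Declarations for those functions are pulled from the
// runtime module loaded in initRuntimeAPI(), and the original intrinsic calls
// are erased once the traversal is done.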

namespace {
// Helper class declarations

// Replace ApproxHPVM intrinsics with LLVM function calls, so that the result
// can then go through the regular CPU backend code generation.

struct DFG2LLVM_ReplaceApproxHPVMIntrinsicsWithFCalls : public DFG2LLVM {
  static char ID; // Pass identification, replacement for typeid
  DFG2LLVM_ReplaceApproxHPVMIntrinsicsWithFCalls() : DFG2LLVM(ID) {}

public:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<BuildDFG>();
    AU.addRequired<InPlaceDFGAnalysisWrapper>();
    AU.addPreserved<BuildDFG>();
    AU.addPreserved<InPlaceDFGAnalysisWrapper>();
  }

  bool runOnModule(Module &M) override;
};

// Visitor for Code generation traversal (tree traversal for now)
class CGT_ReplaceApproxHPVMIntrinsicsWithFCalls : public CodeGenTraversal {

private:
  // Member variables
  InPlaceDFGAnalysis::InPlaceDFGParameter *IPP;

  // VISC Runtime API and Tensor runtime API
  /* TODO: TensorRt is likely not needed here, since LLVM implementations are
   linked in directly; the init and cleanup calls (and related code) could be
   removed, but they are left in until that is verified. */
  FunctionCallee llvm_hpvm_initTensorRt;
  FunctionCallee llvm_hpvm_cleanupTensorRt;
  //  Constant* hpvm_request_tensor; DONE: request tensor will not be used

  // Functions
  bool isValidOperandForInPlaceOperation(Value *Op, Function *Fgen, DFNode *N);

  // Virtual Functions
  void init();
  void initRuntimeAPI();
  void codeGen(DFInternalNode *N);
  void codeGen(DFLeafNode *N);

public:
  // Constructor
  CGT_ReplaceApproxHPVMIntrinsicsWithFCalls(
      Module &_M, BuildDFG &_DFG, InPlaceDFGAnalysis::InPlaceDFGParameter &_IPP)
      : CodeGenTraversal(_M, _DFG), IPP(&_IPP) {
    initRuntimeAPI();
  }
};

bool CGT_ReplaceApproxHPVMIntrinsicsWithFCalls::
    isValidOperandForInPlaceOperation(Value *Op, Function *Fgen, DFNode *N) {
  // We only expect the if branch to be taken
  if (Argument *Arg = dyn_cast<Argument>(Op)) {
    DEBUG(errs() << *Arg << "\t: argument, candidate for in place\n");
    assert((Arg->getParent() == Fgen) &&
           "Extra Parameter in body of Function\n");
    // Candidate parameter is a function argument
    // In this case, consult the result of in place analysis
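    // IPP maps each DFNode to one flag per formal argument; a true entry at
    // the argument's position means the argument may be overwritten in place
    // (interface inferred from the IPP->at(N)[pos] lookup below).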
    // Find position in arg list
    unsigned pos = Arg->getArgNo();
    // If this parameter cannot be used for in place operation
    // code gen cannot continue
    if (IPP->at(N)[pos]) {
      DEBUG(errs() << *Arg << "\t: argument, suitable for in place\n");
      return true;
    } else {
      DEBUG(errs() << *Arg << "\t: argument, not suitable for in place\n");
      return false;
    }
  } else {
    // If it is not an argument, then it needs to be the result of
    // another intrinsic. These are new objects that are allocated,
    // and consumed by next intrinsic. Alternatively, the intrinsic
    // could have been replaced by a call to an LLVM function.
    // We do not expect a merge pass to have run before the replacement pass,
    // therefore we do not expect to go in the else branch.
    DEBUG(errs() << *Op << "\t: Test for result of intrinsic operation\n");
    if (dyn_cast<IntrinsicInst>(Op)) {
      DEBUG(errs() << *Op << "\t: local, suitable for in place\n");
      return true;
    } else if (CallInst *CI = dyn_cast<CallInst>(Op)) {
      // A direct call to a tensor runtime function (inserted by an earlier
      // replacement) also produces a fresh local result.
      Function *Callee = CI->getCalledFunction();
      return Callee && Callee->getName().startswith("tensor");
    } else {
      DEBUG(errs() << *Op << "\t: local, not suitable for in place\n");
      return false;
    }
  }
}

void CGT_ReplaceApproxHPVMIntrinsicsWithFCalls::init() {}
// Initialize the VISC runtime API. This makes it easier to insert these calls
void CGT_ReplaceApproxHPVMIntrinsicsWithFCalls::initRuntimeAPI() {

  // Load Runtime API Module
  SMDiagnostic Err;
  runtimeModule = parseIRFile(TENSOR_RT_LL, Err, M.getContext());
  if (runtimeModule == nullptr) {
    DEBUG(errs() << Err.getMessage());
    assert(false && "Failed to load the tensor runtime API module\n");
  } else
    DEBUG(errs() << "Successfully loaded hpvm-tensor-rt API module\n");

  // Get or insert global declarations for
  // - initialization
  // - cleanup
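  // DECLARE(F) (defined in SupportHPVM/DFG2LLVM.h) is assumed to look up F's
  // function type in the loaded runtime module and getOrInsertFunction() a
  // declaration of the same name into M, binding the FunctionCallee variable
  // of the same name. This sketch of its behavior is inferred from its uses
  // in this file.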
  DECLARE(llvm_hpvm_initTensorRt);
  DECLARE(llvm_hpvm_cleanupTensorRt);
  //  DECLARE(hpvm_request_tensor);

  // Find llvm.hpvm.init and llvm.hpvm.cleanup calls, and add placeholder
  // methods for initialization and cleanup of the hpvm tensor runtime

  Function *VI = M.getFunction("llvm.hpvm.init");
  assert(VI && VI->getNumUses() == 1 &&
         "__hpvm__init should only be used once\n");
  InitCall = cast<Instruction>(*VI->user_begin());
  CallInst::Create(
      llvm_hpvm_initTensorRt,
      ArrayRef<Value *>(ConstantInt::get(Type::getInt32Ty(M.getContext()), 0)),
      "", InitCall);

  Function *VC = M.getFunction("llvm.hpvm.cleanup");
  assert(VC && VC->getNumUses() == 1 &&
         "__hpvm__cleanup should only be used once\n");
  CleanupCall = cast<Instruction>(*VC->user_begin());
  CallInst::Create(llvm_hpvm_cleanupTensorRt, ArrayRef<Value *>(), "",
                   CleanupCall);
}

void CGT_ReplaceApproxHPVMIntrinsicsWithFCalls::codeGen(DFInternalNode *N) {
  errs() << "Inside node: " << N->getFuncPointer()->getName() << "\n";
  errs() << "Skipping internal node\n";
}

void CGT_ReplaceApproxHPVMIntrinsicsWithFCalls::codeGen(DFLeafNode *N) {

  // Skip if it is a dummy node
  if (N->isDummyNode()) {
    DEBUG(errs() << "Skipping dummy node\n");
    return;
  }

  // Abort if it is an allocation node
  if (N->isAllocationNode()) {
    assert(false && "Allocation Node not expected in ApproxHPVM");
    return;
  }

  // Search for intrinsic only if it has the right hint
  if (!checkPreferredTarget(N, hpvm::CPU_TARGET)) {
    errs() << "Skipping node: " << N->getFuncPointer()->getName() << "\n";
    return;
  }

  // Get the function associated with the dataflow node
  Function *F = N->getFuncPointer();
  errs() << "function name = " << F->getName() << "\n";

  std::vector<IntrinsicInst *> IItoRemove;
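  // Replaced intrinsics are only collected here and erased after the walk,
  // so the inst_iterator below is never invalidated mid-traversal.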

  for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) {
    Instruction *I = &(*i);
    if (BuildDFG::isHPVMIntrinsic(I)) {
      IntrinsicInst *II = cast<IntrinsicInst>(I);
      assert(
          (II->getCalledFunction()->getName()).startswith("llvm.hpvm.tensor") &&
          "Only HPVM tensor intrinsics allowed in ApproxHPVM leaf nodes\n");
      /********************* Handle VISC Tensor intrinsics ********************/
      // We replace them with calls to functions with implementations at the
      // LLVM level
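      // Conceptually (illustrative IR only, not taken from a real module):
      //   %r = call i8* @llvm.hpvm.tensor.mul(i8* %a, i8* %b)
      // becomes
      //   %r = call i8* @tensorGemmCPU(i8* %a, i8* %b)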
      switch (II->getIntrinsicID()) {

      case Intrinsic::hpvm_tensor_convolution: { /* llvm.hpvm.tensor.convolution */
        DEBUG(errs() << F->getName() << "\t: Handling tensor convolution \n");

        // Argument list for the runtime call
        std::vector<Value *> Args;
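        // Operands 0-5 are assumed to be: input, filter, vertical pad,
        // horizontal pad, vertical stride, horizontal stride (by analogy
        // with the pooling argument list below).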
        Args.push_back(II->getOperand(0));
        Args.push_back(II->getOperand(1));
        Args.push_back(II->getOperand(2));
        Args.push_back(II->getOperand(3));
        Args.push_back(II->getOperand(4));
        Args.push_back(II->getOperand(5));

        Constant *conv_mode =
            ConstantInt::get(Type::getInt32Ty(M.getContext()), 1);
        Constant *conv_precision =
            ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
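        // Fixed runtime flags; presumably conv_mode 1 selects the default
        // (cross-correlation) convolution and conv_precision 0 selects full
        // precision. The exact encoding is defined by the tensor runtime.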

        Args.push_back(conv_mode);
        Args.push_back(conv_precision);

        // Create function call
        FunctionCallee tensorConvolutionCPU;
        DECLARE(tensorConvolutionCPU);

        CallInst *CI = CallInst::Create(tensorConvolutionCPU, Args, "", II);
        // We can replace the call to hpvm.tensor.convolution with the LLVM call
        II->replaceAllUsesWith(CI);

        // Mark to remove at the end
        IItoRemove.push_back(II);
      } break;

      case Intrinsic::hpvm_tensor_mul: { /* llvm.hpvm.tensor.mul */
        DEBUG(errs() << F->getName() << "\t: Handling tensor mul\n");

        // Argument list for the runtime call
        std::vector<Value *> Args;
        Args.push_back(II->getOperand(0));
        Args.push_back(II->getOperand(1));

        // Create function call
        FunctionCallee tensorGemmCPU;
        DECLARE(tensorGemmCPU);

        CallInst *CI = CallInst::Create(tensorGemmCPU, Args, "", II);
        // We can replace the call to hpvm.tensor.mul with the LLVM call
        II->replaceAllUsesWith(CI);

        // Mark to remove at the end
        IItoRemove.push_back(II);
      } break;

      case Intrinsic::hpvm_tensor_add: { /* llvm.hpvm.tensor.add */
        DEBUG(errs() << F->getName() << "\t: Handling tensor add\n");
        // Tensor add(a,b) is in place for argument a.
        Value *Op = II->getOperand(0);

        // Test the intrinsic operand for in place operation.
        bool inplace = isValidOperandForInPlaceOperation(Op, F, N);
        // Code generation cannot continue if this is false, because the target
        // only provides an in place operation
        // FIXME: the in-place check is currently disabled; re-enable this
        // assert once the in-place analysis results are trusted here:
        // assert(inplace &&
        //        "Operand not valid for in place operation. Code gen aborted.\n");
        (void)inplace; // Suppress unused-variable warning while disabled.

        // Argument list for the runtime call
        std::vector<Value *> Args;
        Args.push_back(II->getOperand(0));
        Args.push_back(II->getOperand(1));

        // Create function call
        FunctionCallee tensorAddCPU;
        DECLARE(tensorAddCPU);
        CallInst::Create(tensorAddCPU, Args, "", II);
        // We can replace the call to hpvm.tensor.add with the 1st argument
        // that, due to in place operation, now contains the result
        II->replaceAllUsesWith(II->getOperand(0));

        // Mark to remove at the end
        IItoRemove.push_back(II);
      } break;

      case Intrinsic::hpvm_tensor_pool_max:
      case Intrinsic::hpvm_tensor_pool_mean: { /* llvm.hpvm.tensor.pool.{max,mean} */
        DEBUG(errs() << F->getName() << "\t: Handling tensor pooling\n");
        // The pooling input (argument a) must be valid for an in-place
        // operation.
        Value *Op = II->getOperand(0);

        // Test the intrinsic operand for in place operation.
        bool inplace = isValidOperandForInPlaceOperation(Op, F, N);
        // Code generation cannot continue if this is false, because the target
        // only provides an in place operation
        assert(inplace &&
               "Operand not valid for in place operation. Code gen aborted.\n");

        // Argument list - tensorPooling(input, poolFunction, window_height,
        //                               window_width, vertical_pad,
        //                               horizontal_pad, vertical_stride,
        //                               horizontal_stride)
        std::vector<Value *> Args;
        Args.push_back(II->getOperand(0));

        int pool_type = 0;
        if (II->getIntrinsicID() == Intrinsic::hpvm_tensor_pool_max) {
          pool_type = 0;
        } else if (II->getIntrinsicID() == Intrinsic::hpvm_tensor_pool_mean) {
          pool_type = 1;
        }

        Constant *constPoolType =
            ConstantInt::get(Type::getInt32Ty(M.getContext()), pool_type);
        Args.push_back(constPoolType); // Pool type ID: 0 = max, 1 = mean
        Args.push_back(II->getOperand(1));
        Args.push_back(II->getOperand(2));
        Args.push_back(II->getOperand(3));
        Args.push_back(II->getOperand(4));
        Args.push_back(II->getOperand(5));
        Args.push_back(II->getOperand(6));

        // Create function call
        FunctionCallee tensorPoolingCPU;
        DECLARE(tensorPoolingCPU);
        CallInst *CI = CallInst::Create(tensorPoolingCPU, Args, "", II);

        // Replacing intrinsic result uses with the result of the LLVM call
        II->replaceAllUsesWith(CI);

        // Mark to remove at the end
        IItoRemove.push_back(II);
      } break;

      case Intrinsic::hpvm_tensor_relu:
      case Intrinsic::hpvm_tensor_clipped_relu:
      case Intrinsic::hpvm_tensor_tanh: { /* llvm.hpvm.tensor.{relu,clipped.relu,tanh} */
        DEBUG(errs() << F->getName()
                     << "\t: Handling tensor activation functions\n");
        // Each activation function f(a) is in place for argument a.
        Value *Op = II->getOperand(0);

        // Test the intrinsic operand for in place operation.
        bool inplace = isValidOperandForInPlaceOperation(Op, F, N);
        // Code generation cannot continue if this is false, because the target
        // only provides an in place operation
        assert(inplace &&
               "Operand not valid for in place operation. Code gen aborted.\n");

        // Argument list for the runtime call
        std::vector<Value *> Args;
        Args.push_back(II->getOperand(0));

        if (II->getIntrinsicID() == Intrinsic::hpvm_tensor_relu) {
          // Create function call
          FunctionCallee tensorReluCPU;
          DECLARE(tensorReluCPU);
          CallInst::Create(tensorReluCPU, Args, "", II);
        } else if (II->getIntrinsicID() ==
                   Intrinsic::hpvm_tensor_clipped_relu) {
          // Create function call (clipped ReLU maps to tensorRelu2CPU)
          FunctionCallee tensorRelu2CPU;
          DECLARE(tensorRelu2CPU);
          CallInst::Create(tensorRelu2CPU, Args, "", II);
        } else if (II->getIntrinsicID() == Intrinsic::hpvm_tensor_tanh) {
          // Create function call
          FunctionCallee tensorTanhCPU;
          DECLARE(tensorTanhCPU);
          CallInst::Create(tensorTanhCPU, Args, "", II);
        }

        // We can replace the call to the activation intrinsic with the 1st
        // argument that, due to in place operation, now contains the result
        II->replaceAllUsesWith(II->getOperand(0));

        // Mark to remove at the end
        IItoRemove.push_back(II);
      } break;

      case Intrinsic::hpvm_tensor_softmax: { /* llvm.hpvm.tensor.softmax */
        DEBUG(errs() << F->getName() << "\t: Handling tensor softmax\n");
        // Tensor softmax(a) is in place for argument a.
        Value *Op = II->getOperand(0);

        // Test the intrinsic operand for in place operation.
        bool inplace = isValidOperandForInPlaceOperation(Op, F, N);
        // Code generation cannot continue if this is false, because the target
        // only provides an in place operation
        assert(inplace &&
               "Operand not valid for in place operation. Code gen aborted.\n");

        // Argument list for the runtime call
        std::vector<Value *> Args;
        Args.push_back(II->getOperand(0));

        // Create function call
        FunctionCallee tensorSoftmaxCPU;
        DECLARE(tensorSoftmaxCPU);
        CallInst::Create(tensorSoftmaxCPU, Args, "", II);
        // We can replace the call to hpvm.tensor.softmax with the 1st argument
        // that, due to in place operation, now contains the result
        II->replaceAllUsesWith(II->getOperand(0));

        // Mark to remove at the end
        IItoRemove.push_back(II);
      } break;

      default:
        llvm_unreachable("Unknown HPVM tensor intrinsic!");
      }
    }
  }

  // We need to do this explicitly: DCE pass may not remove them.
  // Traverse the vector backwards, otherwise definitions are deleted while
  // their subsequent uses are still around.
  for (std::vector<IntrinsicInst *>::reverse_iterator ri = IItoRemove.rbegin(),
                                                      re = IItoRemove.rend();
       ri != re; ++ri) {
    DEBUG(errs() << "Erasing: " << **ri << "\n");
    errs() << "Erasing: " << **ri << "\n";
    (*ri)->eraseFromParent();
  }

  return;
}

bool DFG2LLVM_ReplaceApproxHPVMIntrinsicsWithFCalls::runOnModule(Module &M) {
  errs() << "\nDFG2LLVM_ReplaceApproxHPVMIntrinsicsWithFCalls PASS\n";

  // Get the BuildDFG Analysis Results:
  // - Dataflow graph
  BuildDFG &DFG = getAnalysis<BuildDFG>();

  // Get the In Place Analysis Results
  InPlaceDFGAnalysis::InPlaceDFGParameter IPP =
      (getAnalysis<InPlaceDFGAnalysisWrapper>()).getIPP();
  // Print results
  printInPlaceDFGParameter(IPP);

  std::vector<DFInternalNode *> Roots = DFG.getRoots();

  // Visitor for Code Generation Graph Traversal
  CGT_ReplaceApproxHPVMIntrinsicsWithFCalls *CGTVisitor =
      new CGT_ReplaceApproxHPVMIntrinsicsWithFCalls(M, DFG, IPP);

  // Iterate over all the DFGs and produce code for each one of them
  for (auto rootNode : Roots) {
    // Initiate code generation for root DFNode
    CGTVisitor->visit(rootNode);
  }

  // TODO: Edit module epilogue to remove the VISC intrinsic declarations
  delete CGTVisitor;

  return true;
}

} // End of namespace

char DFG2LLVM_ReplaceApproxHPVMIntrinsicsWithFCalls::ID = 0;
static RegisterPass<DFG2LLVM_ReplaceApproxHPVMIntrinsicsWithFCalls>
    X("replace-intrinsics",
      "Replace ApproxHPVM intrinsics with LLVM calls",
      false /* CFGOnly: the pass looks at more than the CFG */,
      true /* is_analysis */);
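
// Usage sketch (assumes the pass is built as a loadable plugin; the library
// name below is hypothetical):
//   opt -load LLVMReplaceIntrinsics.so -replace-intrinsics -S in.ll -o out.ll
// The pass name "replace-intrinsics" comes from the RegisterPass entry above.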