//=== GenVISC.cpp - Implements "Hierarchical Dataflow Graph Builder Pass" ===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "genvisc"
#include "GenVISC/GenVISC.h"

#include "SupportVISC/VISCHint.h"
#include "SupportVISC/VISCUtils.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#include <algorithm>

// TIMER(X): execute statement X only when the -visc-timers-gen option
// (VISCTimer) is set. Every piece of timer instrumentation emitted by this
// pass is wrapped in it. The do { ... } while (0) wrapper lets TIMER(...)
// be used as a single statement followed by a semicolon.
#define TIMER(X)                                                               \
  do {                                                                         \
    if (VISCTimer) {                                                           \
      X;                                                                       \
    }                                                                          \
  } while (0)

using namespace llvm;
using namespace viscUtils;

// VISC Command line option to use timer or not.
// -visc-timers-gen enables the timer instrumentation guarded by TIMER(...).
// No cl::init is given, so the option defaults to false (timers off).
static cl::opt<bool> VISCTimer("visc-timers-gen",
                               cl::desc("Enable GenVISC timer"));

namespace genvisc {

// Helper Functions
//
// Forward declarations. Because these are `static`, their definitions must
// appear later in this translation unit (they are outside the section shown
// here).

// Used to build the timer-ID operand of llvm_visc_switchToTimer calls;
// presumably a ConstantInt encoding the given visc_TimerID (confirm against
// the definition).
static inline ConstantInt *getTimerID(Module &, enum visc_TimerID);
// Applied to node functions before their address is passed to the launch /
// createNode intrinsics; returns the function to use in place of F
// (NOTE(review): assumed to rewrite F's return type into a struct).
static Function *transformReturnTypeToStruct(Function *F);
// Used to compute the new return type of a function whose __visc__return
// calls have been lowered into struct-returning `ret` instructions.
static Type *getReturnTypeFromReturnInst(Function *F);

// Check if the dummy function call is a __visc__node call.
//
// IS_VISC_CALL(callName) defines
//   static bool isVISCCall_callName(Instruction *I)
// which returns true iff I is a CallInst whose called value, after
// stripping pointer casts, is named "__visc__callName". Indirect calls
// through unnamed values never match, because their name is empty.
#define IS_VISC_CALL(callName)                                                 \
  static bool isVISCCall_##callName(Instruction *I) {                          \
    if (!isa<CallInst>(I))                                                     \
      return false;                                                            \
    CallInst *CI = cast<CallInst>(I);                                          \
    return (CI->getCalledValue()->stripPointerCasts()->getName())              \
        .equals("__visc__" #callName);                                         \
  }

// Replace the dummy __visc__* call I with a call to intrinsic IntrinsicID.
// The new call is inserted immediately before I, takes the same operands,
// and receives every use of I. I itself is left in place; when Erase is
// non-null, I is appended to it so the caller can erase it later.
//
// Overloaded intrinsics are instantiated with the types of the call's
// operands. For non-overloaded intrinsics the operand count must match the
// intrinsic's parameter count. A pointer operand whose type differs from the
// intrinsic's pointer parameter type is pointer-cast to that type, with the
// cast inserted before I.
static void ReplaceCallWithIntrinsic(Instruction *I, Intrinsic::ID IntrinsicID,
                                     std::vector<Instruction *> *Erase) {
  assert(isa<CallInst>(I) && "Expecting CallInst");
  CallInst *Call = cast<CallInst>(I);
  DEBUG(errs() << "Found call: " << *Call << "\n");

  Module *Mod = Call->getParent()->getParent()->getParent();
  const unsigned NumArgs = Call->getNumArgOperands();
  std::vector<Value *> CallArgs;
  Function *Intr = nullptr;

  if (!Intrinsic::isOverloaded(IntrinsicID)) {
    // Fixed signature: adapt operands to the declared parameter types.
    Intr = Intrinsic::getDeclaration(Mod, IntrinsicID);
    FunctionType *IntrTy = Intr->getFunctionType();
    DEBUG(errs() << *Intr << "\n");

    assert(NumArgs == IntrTy->getNumParams() &&
           "Number of arguments of call do not match with Intrinsic");
    for (unsigned Idx = 0; Idx != NumArgs; ++Idx) {
      Value *Arg = Call->getArgOperand(Idx);
      Type *ParamTy = IntrTy->getParamType(Idx);
      assert((Arg->getType() == ParamTy ||
              (Arg->getType()->isPointerTy() && ParamTy->isPointerTy())) &&
             "Dummy function call argument does not match with Intrinsic "
             "argument!");
      // Only pointer operands can differ in type; bridge them with a cast.
      CallArgs.push_back(
          Arg->getType() == ParamTy
              ? Arg
              : CastInst::CreatePointerCast(Arg, ParamTy, "", Call));
    }
  } else {
    // Overloaded: the declaration is specialised on the operand types.
    std::vector<Type *> OverloadTys;
    for (unsigned Idx = 0; Idx != NumArgs; ++Idx) {
      Value *Arg = Call->getArgOperand(Idx);
      OverloadTys.push_back(Arg->getType());
      CallArgs.push_back(Arg);
    }
    Intr = Intrinsic::getDeclaration(Mod, IntrinsicID, OverloadTys);
    DEBUG(errs() << *Intr << "\n");
  }

  // Void-typed values cannot carry a name.
  const bool ReturnsVoid = Intr->getReturnType()->isVoidTy();
  CallInst *Replacement = CallInst::Create(
      Intr, CallArgs, ReturnsVoid ? "" : Call->getName(), Call);
  DEBUG(errs() << "\tSubstitute with: " << *Replacement << "\n");

  Call->replaceAllUsesWith(Replacement);
  if (Erase)
    Erase->push_back(Call);
}

// Instantiate one isVISCCall_<name> predicate for every __visc__<name> dummy
// function that runOnModule recognises.
IS_VISC_CALL(launch) /* Exists but not required */
IS_VISC_CALL(edge)   /* Exists but not required */
IS_VISC_CALL(createNodeND)
// The fixed-dimension createNode variants are not matched separately:
// runOnModule maps __visc__createNodeND onto visc_createNode/1D/2D/3D
// according to its dimension-count operand.
// IS_VISC_CALL(createNode)
// IS_VISC_CALL(createNode1D)
// IS_VISC_CALL(createNode2D)
// IS_VISC_CALL(createNode3D)
IS_VISC_CALL(bindIn)
IS_VISC_CALL(bindOut)
IS_VISC_CALL(push)
IS_VISC_CALL(pop)
IS_VISC_CALL(getNode)
IS_VISC_CALL(getParentNode)
IS_VISC_CALL(barrier)
IS_VISC_CALL(malloc)
// The stray space is harmless: `#` stringizes the argument as "return".
IS_VISC_CALL(return )
IS_VISC_CALL(getNodeInstanceID_x)
IS_VISC_CALL(getNodeInstanceID_y)
IS_VISC_CALL(getNodeInstanceID_z)
IS_VISC_CALL(getNumNodeInstances_x)
IS_VISC_CALL(getNumNodeInstances_y)
IS_VISC_CALL(getNumNodeInstances_z)
// Atomics
IS_VISC_CALL(atomic_cmpxchg)
IS_VISC_CALL(atomic_add)
IS_VISC_CALL(atomic_sub)
IS_VISC_CALL(atomic_xchg)
IS_VISC_CALL(atomic_inc)
IS_VISC_CALL(atomic_dec)
IS_VISC_CALL(atomic_min)
IS_VISC_CALL(atomic_max)
IS_VISC_CALL(atomic_umin)
IS_VISC_CALL(atomic_umax)
IS_VISC_CALL(atomic_and)
IS_VISC_CALL(atomic_or)
IS_VISC_CALL(atomic_xor)
// Misc Fn
IS_VISC_CALL(floor)
IS_VISC_CALL(rsqrt)
IS_VISC_CALL(sqrt)
IS_VISC_CALL(sin)
IS_VISC_CALL(cos)

// Runtime / host-side API calls
IS_VISC_CALL(init)
IS_VISC_CALL(cleanup)
IS_VISC_CALL(wait)
IS_VISC_CALL(trackMemory)
IS_VISC_CALL(untrackMemory)
IS_VISC_CALL(requestMemory)
IS_VISC_CALL(attributes)
IS_VISC_CALL(hint)

// Return the integer held by V, which must be a ConstantInt (checked by an
// assertion). The value is zero-extended and then truncated to `unsigned`.
static unsigned getNumericValue(Value *V) {
  assert(isa<ConstantInt>(V) && "Value indicating the number of arguments should be a constant integer");
  const ConstantInt *Count = cast<ConstantInt>(V);
  return static_cast<unsigned>(Count->getZExtValue());
}

// Take the __visc__return instruction and generate code for combining the
// values being returned into a struct and returning it.
//
// Operand 0 of CI is the number of returned values. The remaining operands
// are the values themselves, in output order. A new packed struct type named
// "struct.out.<function name>" is created from their types. The values are
// packed with a chain of insertvalue instructions inserted before CI, and the
// last one is returned. With zero values there is nothing to insert, so an
// undef of the (empty) struct type is returned instead.
static Value *genCodeForReturn(CallInst *CI) {
  LLVMContext &Ctx = CI->getContext();
  assert(isVISCCall_return(CI) && "__visc__return instruction expected!");

  // Parse the dummy function call here
  assert(CI->getNumArgOperands() > 0 &&
         "Too few arguments for __visc_return call!\n");
  unsigned numRetVals = getNumericValue(CI->getArgOperand(0));

  assert(CI->getNumArgOperands() - 1 == numRetVals &&
         "Too few arguments for __visc_return call!\n");
  DEBUG(errs() << "\tNum of return values = " << numRetVals << "\n");

  std::vector<Type *> ArgTypes;
  for (unsigned i = 1; i < CI->getNumArgOperands(); i++) {
    ArgTypes.push_back(CI->getArgOperand(i)->getType());
  }
  // Materialise the name right away. A Twine kept in a local variable would
  // point at the temporary StringRef returned by getName(), which is
  // destroyed at the end of the full-expression.
  const std::string outTyName =
      ("struct.out." + CI->getParent()->getParent()->getName()).str();
  StructType *RetTy = StructType::create(Ctx, ArgTypes, outTyName, true);

  DEBUG(errs() << "Code generation for return:\n");
  // Nothing is returned: operand 1 does not exist, so return the empty
  // struct as-is instead of reading past the operand list.
  if (ArgTypes.empty())
    return UndefValue::get(RetTy);

  InsertValueInst *IV = InsertValueInst::Create(
      UndefValue::get(RetTy), CI->getArgOperand(1), 0, "returnStruct", CI);
  DEBUG(errs() << *IV << "\n");

  // Operand i (i >= 2) goes into struct field i - 1.
  for (unsigned i = 2; i < CI->getNumArgOperands(); i++) {
    IV = InsertValueInst::Create(IV, CI->getArgOperand(i), i - 1, IV->getName(),
                                 CI);
    DEBUG(errs() << *IV << "\n");
  }

  return IV;
}

// Analyse the __visc__attributes call CI (made inside F) and add the VISC
// In/Out attributes to the listed parameters of F.
//
// Operand layout of CI:
//   numIn, in_1 .. in_numIn, numOut, out_1 .. out_numOut
// numIn and numOut must be constant integers. Every listed value must be a
// formal argument of F; anything else is rejected with llvm_unreachable.
static void handleVISCAttributes(Function *F, CallInst *CI) {
  DEBUG(errs() << "Kernel before adding In/Out VISC attributes:\n"
               << *F << "\n");

  // Parse one "count, arg_1 .. arg_count" group starting at operand Offset
  // and add attribute Kind to each listed argument of F. CountLabel prefixes
  // the debug print of the count. Returns the operand index just past the
  // group.
  auto processGroup = [F, CI](unsigned Offset, Attribute::AttrKind Kind,
                              const char *CountLabel) -> unsigned {
    assert(CI->getNumArgOperands() > Offset &&
           "Too few arguments for __visc__attributes call!");
    unsigned NumPtrs = getNumericValue(CI->getArgOperand(Offset));
    DEBUG(errs() << CountLabel << NumPtrs << "\n");
    // Every listed argument must exist. The original code indexed past the
    // operand list when the count was too large.
    assert(CI->getNumArgOperands() > Offset + NumPtrs &&
           "Too few arguments for __visc__attributes call!");

    for (unsigned i = Offset + 1; i < Offset + 1 + NumPtrs; i++) {
      Value *V = CI->getArgOperand(i);
      if (Argument *arg = dyn_cast<Argument>(V)) {
        // Attribute index 0 is the return value, so parameter N is N + 1.
        F->addAttribute(1 + arg->getArgNo(), Kind);
      } else {
        DEBUG(errs() << "Invalid argument to __visc__attribute: " << *V
                     << "\n");
        llvm_unreachable(
            "Only pointer arguments can be passed to __visc__attributes call");
      }
    }
    return Offset + 1 + NumPtrs;
  };

  // In pointers start at operand 0; Out pointers follow immediately after.
  unsigned OutOffset =
      processGroup(0, Attribute::In, "\tNum of in pointers = ");
  processGroup(OutOffset, Attribute::Out, "\tNum of out Pointers = ");

  DEBUG(errs() << "Kernel after adding In/Out VISC attributes:\n"
               << *F << "\n");
}

// Public Functions of GenVISC pass

// Entry point of the GenVISC pass. It performs three steps.
//  1. Load the visc-rt runtime module from
//     $LLVM_SRC_ROOT/tools/hpvm/projects/visc-rt/visc-rt.ll. Then declare the
//     timer API functions (llvm_visc_initializeTimerSet,
//     llvm_visc_switchToTimer, llvm_visc_printTimerSet) in M with the
//     signatures they have in that module.
//  2. Insert timer initialisation before the single use of __visc__init, and
//     timer printing before the single use of __visc__cleanup.
//  3. In every function, replace each recognised __visc__* dummy call with
//     the corresponding intrinsic and erase the dummy calls. Functions that
//     use __visc__bindOut or __visc__return are rebuilt with a struct return
//     type describing the node's outputs.
// Missing runtime pieces are reported with report_fatal_error.
// Returns true: the module is always modified (at least the runtime
// declarations are added).
bool GenVISC::runOnModule(Module &M) {
  DEBUG(errs() << "\nGENVISC PASS\n");
  this->M = &M;

  // Load Runtime API Module
  SMDiagnostic Err;

  const char *LLVM_SRC_ROOT = getenv("LLVM_SRC_ROOT");
  // An assert would vanish in release builds and leave a null string to be
  // dereferenced below.
  if (LLVM_SRC_ROOT == nullptr)
    report_fatal_error("Define LLVM_SRC_ROOT environment variable!");

  // Keep the paths as std::string: a Twine stored in a local variable can
  // outlive the temporaries it refers to.
  const std::string llvmSrcRoot = LLVM_SRC_ROOT;
  const std::string runtimeAPI =
      llvmSrcRoot + "/tools/hpvm/projects/visc-rt/visc-rt.ll";
  DEBUG(errs() << llvmSrcRoot << "\n");

  std::unique_ptr<Module> runtimeModule =
      parseIRFile(runtimeAPI, Err, M.getContext());

  // The runtime module is dereferenced below, so failing to load it is
  // fatal rather than just a debug message.
  if (runtimeModule == nullptr) {
    DEBUG(errs() << Err.getMessage() << " " << runtimeAPI << "\n");
    report_fatal_error("Could not load visc-rt API module " + runtimeAPI);
  }
  DEBUG(errs() << "Successfully loaded visc-rt API module\n");

  // Declare the runtime API function Name in M, with the signature it has in
  // the runtime module.
  auto declareRuntimeFunction = [&](StringRef Name) {
    Function *RuntimeF = runtimeModule->getFunction(Name);
    if (RuntimeF == nullptr)
      report_fatal_error("visc-rt API module does not define " + Name);
    return M.getOrInsertFunction(Name, RuntimeF->getFunctionType());
  };

  llvm_visc_initializeTimerSet =
      declareRuntimeFunction("llvm_visc_initializeTimerSet");
  llvm_visc_switchToTimer = declareRuntimeFunction("llvm_visc_switchToTimer");
  llvm_visc_printTimerSet = declareRuntimeFunction("llvm_visc_printTimerSet");

  // Insert init context in main
  DEBUG(errs() << "Locate __visc__init()\n");
  Function *VI = M.getFunction("__visc__init");
  if (VI == nullptr)
    report_fatal_error("__visc__init() not found in module");
  assert(VI->getNumUses() == 1 && "__visc__init should only be used once");
  Instruction *InitCall = cast<Instruction>(*VI->user_begin());

  DEBUG(errs() << "Initialize Timer Set\n");
  initializeTimerSet(InitCall);
  switchToTimer(visc_TimerID_NONE, InitCall);

  // Insert print instruction at visc exit
  DEBUG(errs() << "Locate __visc__cleanup()\n");
  Function *VC = M.getFunction("__visc__cleanup");
  if (VC == nullptr)
    report_fatal_error("__visc__cleanup() not found in module");
  assert(VC->getNumUses() == 1 && "__visc__cleanup should only be used once");
  Instruction *CleanupCall = cast<Instruction>(*VC->user_begin());
  printTimerSet(CleanupCall);

  DEBUG(errs() << "-------- Searching for launch sites ----------\n");

  // Dummy calls that have been replaced. They are erased only after the
  // whole function has been scanned, so the instruction iteration is not
  // disturbed.
  std::vector<Instruction *> toBeErased;
  // Snapshot of the functions: the loop below adds new functions to M.
  std::vector<Function *> functions;

  for (auto &F : M)
    functions.push_back(&F);

  // __visc__* calls that map one-to-one onto an intrinsic with the same
  // operands. The predicates test distinct callee names, so at most one of
  // them (and of the special cases below) matches a given call.
  struct DirectReplacement {
    bool (*Matches)(Instruction *);
    Intrinsic::ID ID;
  };
  static const DirectReplacement DirectReplacements[] = {
      {isVISCCall_init, Intrinsic::visc_init},
      {isVISCCall_cleanup, Intrinsic::visc_cleanup},
      {isVISCCall_wait, Intrinsic::visc_wait},
      {isVISCCall_trackMemory, Intrinsic::visc_trackMemory},
      {isVISCCall_untrackMemory, Intrinsic::visc_untrackMemory},
      {isVISCCall_requestMemory, Intrinsic::visc_requestMemory},
      {isVISCCall_push, Intrinsic::visc_push},
      {isVISCCall_pop, Intrinsic::visc_pop},
      {isVISCCall_getNode, Intrinsic::visc_getNode},
      {isVISCCall_getParentNode, Intrinsic::visc_getParentNode},
      {isVISCCall_barrier, Intrinsic::visc_barrier},
      {isVISCCall_malloc, Intrinsic::visc_malloc},
      {isVISCCall_getNodeInstanceID_x, Intrinsic::visc_getNodeInstanceID_x},
      {isVISCCall_getNodeInstanceID_y, Intrinsic::visc_getNodeInstanceID_y},
      {isVISCCall_getNodeInstanceID_z, Intrinsic::visc_getNodeInstanceID_z},
      {isVISCCall_getNumNodeInstances_x,
       Intrinsic::visc_getNumNodeInstances_x},
      {isVISCCall_getNumNodeInstances_y,
       Intrinsic::visc_getNumNodeInstances_y},
      {isVISCCall_getNumNodeInstances_z,
       Intrinsic::visc_getNumNodeInstances_z},
      {isVISCCall_atomic_cmpxchg, Intrinsic::visc_atomic_cmpxchg},
      {isVISCCall_atomic_add, Intrinsic::visc_atomic_add},
      {isVISCCall_atomic_sub, Intrinsic::visc_atomic_sub},
      {isVISCCall_atomic_xchg, Intrinsic::visc_atomic_xchg},
      {isVISCCall_atomic_inc, Intrinsic::visc_atomic_inc},
      {isVISCCall_atomic_dec, Intrinsic::visc_atomic_dec},
      {isVISCCall_atomic_min, Intrinsic::visc_atomic_min},
      {isVISCCall_atomic_umin, Intrinsic::visc_atomic_umin},
      {isVISCCall_atomic_max, Intrinsic::visc_atomic_max},
      {isVISCCall_atomic_umax, Intrinsic::visc_atomic_umax},
      {isVISCCall_atomic_and, Intrinsic::visc_atomic_and},
      {isVISCCall_atomic_or, Intrinsic::visc_atomic_or},
      {isVISCCall_atomic_xor, Intrinsic::visc_atomic_xor},
      {isVISCCall_floor, Intrinsic::floor},
      {isVISCCall_rsqrt, Intrinsic::nvvm_rsqrt_approx_f},
      {isVISCCall_sqrt, Intrinsic::sqrt},
      {isVISCCall_sin, Intrinsic::sin},
      {isVISCCall_cos, Intrinsic::cos},
  };

  // Iterate over all functions in the module
  for (Function *f : functions) {
    DEBUG(errs() << "Function: " << f->getName() << "\n");

    // Element types of the struct this function must return. With
    // __visc__bindOut, slot i holds the type bound to output i. With
    // __visc__return, it holds the element types of the returned struct.
    std::vector<Type *> FRetTypes;

    // Why (if at all) the return type has to be rewritten. A function may
    // use __visc__bindOut or __visc__return, but not both.
    enum mutateTypeCause {
      mtc_None,
      mtc_BIND,
      mtc_RETURN,
      mtc_NUM_CAUSES
    } bind;
    bind = mutateTypeCause::mtc_None;

    // Iterate over all the instructions in this function
    for (inst_iterator i = inst_begin(f), e = inst_end(f); i != e; ++i) {
      Instruction *I = &*i; // Grab pointer to Instruction
      // If not a call instruction, move to next instruction
      if (!isa<CallInst>(I))
        continue;

      CallInst *CI = cast<CallInst>(I);
      LLVMContext &Ctx = CI->getContext();

      for (const DirectReplacement &R : DirectReplacements) {
        if (R.Matches(I)) {
          ReplaceCallWithIntrinsic(I, R.ID, &toBeErased);
          break;
        }
      }

      if (isVISCCall_hint(I)) {
        assert(isa<ConstantInt>(CI->getArgOperand(0)) &&
               "Argument to hint must be constant integer!");
        ConstantInt *hint = cast<ConstantInt>(CI->getArgOperand(0));

        visc::Target t = static_cast<visc::Target>(hint->getZExtValue());
        addHint(CI->getParent()->getParent(), t);
        DEBUG(errs() << "Found visc hint call: " << *CI << "\n");
        toBeErased.push_back(CI);
      }
      if (isVISCCall_launch(I)) {
        Function *LaunchF =
            Intrinsic::getDeclaration(&M, Intrinsic::visc_launch);
        DEBUG(errs() << *LaunchF << "\n");
        // Get i8* cast to function pointer
        Function *graphFunc = cast<Function>(CI->getArgOperand(1));
        graphFunc = transformReturnTypeToStruct(graphFunc);
        Constant *F =
            ConstantExpr::getPointerCast(graphFunc, Type::getInt8PtrTy(Ctx));
        assert(
            F &&
            "Function invoked by VISC launch has to be define and constant.");

        ConstantInt *Op = dyn_cast<ConstantInt>(CI->getArgOperand(0));
        assert(Op && "VISC launch's streaming argument is a constant value.");
        Value *isStreaming = Op->isZero() ? ConstantInt::getFalse(Ctx)
                                          : ConstantInt::getTrue(Ctx);

        auto *ArgTy = dyn_cast<PointerType>(CI->getArgOperand(2)->getType());
        assert(ArgTy && "VISC launch argument should be pointer type.");
        Value *Arg = CI->getArgOperand(2);
        if (!ArgTy->getElementType()->isIntegerTy(8))
          Arg = BitCastInst::CreatePointerCast(CI->getArgOperand(2),
                                               Type::getInt8PtrTy(Ctx), "", CI);
        Value *LaunchArgs[] = {F, Arg, isStreaming};
        CallInst *LaunchInst = CallInst::Create(
            LaunchF, ArrayRef<Value *>(LaunchArgs, 3), "graphID", CI);
        DEBUG(errs() << "Found visc launch call: " << *CI << "\n");
        DEBUG(errs() << "\tSubstitute with: " << *LaunchInst << "\n");
        CI->replaceAllUsesWith(LaunchInst);
        toBeErased.push_back(CI);
      }
      if (isVISCCall_createNodeND(I)) {
        // Operands: numDims, node function, then one i64 per dimension.
        // Check for both fixed operands before subtracting below, so the
        // unsigned count cannot underflow.
        assert(CI->getNumArgOperands() >= 2 &&
               "Too few arguments for __visc__createNodeND call");
        unsigned numDims = getNumericValue(CI->getArgOperand(0));
        // We need as many dimension arguments as there are dimensions
        assert(CI->getNumArgOperands() - 2 == numDims &&
               "Too few arguments for __visc_createNodeND call!\n");

        Function *CreateNodeF = nullptr;
        switch (numDims) {
        case 0:
          CreateNodeF =
              Intrinsic::getDeclaration(&M, Intrinsic::visc_createNode);
          break;
        case 1:
          CreateNodeF =
              Intrinsic::getDeclaration(&M, Intrinsic::visc_createNode1D);
          break;
        case 2:
          CreateNodeF =
              Intrinsic::getDeclaration(&M, Intrinsic::visc_createNode2D);
          break;
        case 3:
          CreateNodeF =
              Intrinsic::getDeclaration(&M, Intrinsic::visc_createNode3D);
          break;
        default:
          llvm_unreachable("Unsupported number of dimensions\n");
          break;
        }
        DEBUG(errs() << *CreateNodeF << "\n");
        DEBUG(errs() << *I << "\n");
        DEBUG(errs() << "in " << I->getParent()->getParent()->getName()
                     << "\n");

        // Get i8* cast to function pointer
        Function *graphFunc = cast<Function>(CI->getArgOperand(1));
        graphFunc = transformReturnTypeToStruct(graphFunc);
        Constant *F =
            ConstantExpr::getPointerCast(graphFunc, Type::getInt8PtrTy(Ctx));

        CallInst *CreateNodeInst = nullptr;
        switch (numDims) {
        case 0:
          CreateNodeInst = CallInst::Create(CreateNodeF, ArrayRef<Value *>(F),
                                            graphFunc->getName() + ".node", CI);
          break;
        case 1: {
          assert((CI->getArgOperand(2)->getType() == Type::getInt64Ty(Ctx)) &&
                 "CreateNodeND dimension argument, 2, expected to be i64\n");
          Value *CreateNodeArgs[] = {F, CI->getArgOperand(2)};
          CreateNodeInst = CallInst::Create(
              CreateNodeF, ArrayRef<Value *>(CreateNodeArgs, 2),
              graphFunc->getName() + ".node", CI);
        } break;
        case 2: {
          assert((CI->getArgOperand(2)->getType() == Type::getInt64Ty(Ctx)) &&
                 "CreateNodeND dimension argument, 2, expected to be i64\n");
          assert((CI->getArgOperand(3)->getType() == Type::getInt64Ty(Ctx)) &&
                 "CreateNodeND dimension argument, 3, expected to be i64\n");
          Value *CreateNodeArgs[] = {F, CI->getArgOperand(2),
                                     CI->getArgOperand(3)};
          CreateNodeInst = CallInst::Create(
              CreateNodeF, ArrayRef<Value *>(CreateNodeArgs, 3),
              graphFunc->getName() + ".node", CI);
        } break;
        case 3: {
          assert((CI->getArgOperand(2)->getType() == Type::getInt64Ty(Ctx)) &&
                 "CreateNodeND dimension argument, 2, expected to be i64\n");
          assert((CI->getArgOperand(3)->getType() == Type::getInt64Ty(Ctx)) &&
                 "CreateNodeND dimension argument, 3, expected to be i64\n");
          assert((CI->getArgOperand(4)->getType() == Type::getInt64Ty(Ctx)) &&
                 "CreateNodeND dimension argument, 4, expected to be i64\n");
          Value *CreateNodeArgs[] = {F, CI->getArgOperand(2),
                                     CI->getArgOperand(3),
                                     CI->getArgOperand(4)};
          CreateNodeInst = CallInst::Create(
              CreateNodeF, ArrayRef<Value *>(CreateNodeArgs, 4),
              graphFunc->getName() + ".node", CI);
        } break;
        default:
          llvm_unreachable(
              "Impossible path: number of dimensions is 0, 1, 2, 3\n");
          break;
        }

        DEBUG(errs() << "Found visc createNode call: " << *CI << "\n");
        DEBUG(errs() << "\tSubstitute with: " << *CreateNodeInst << "\n");
        CI->replaceAllUsesWith(CreateNodeInst);
        toBeErased.push_back(CI);
      }

      if (isVISCCall_edge(I)) {
        Function *EdgeF =
            Intrinsic::getDeclaration(&M, Intrinsic::visc_createEdge);
        DEBUG(errs() << *EdgeF << "\n");
        // dyn_cast, so the assertion below can actually observe failure.
        ConstantInt *Op = dyn_cast<ConstantInt>(CI->getArgOperand(5));
        ConstantInt *EdgeTypeOp = dyn_cast<ConstantInt>(CI->getArgOperand(2));
        assert(Op && EdgeTypeOp &&
               "Arguments of CreateEdge are not constant integers.");
        Value *isStreaming = Op->isZero() ? ConstantInt::getFalse(Ctx)
                                          : ConstantInt::getTrue(Ctx);
        Value *isAllToAll = EdgeTypeOp->isZero() ? ConstantInt::getFalse(Ctx)
                                                 : ConstantInt::getTrue(Ctx);
        Value *EdgeArgs[] = {CI->getArgOperand(0), CI->getArgOperand(1),
                             isAllToAll,           CI->getArgOperand(3),
                             CI->getArgOperand(4), isStreaming};
        CallInst *EdgeInst = CallInst::Create(
            EdgeF, ArrayRef<Value *>(EdgeArgs, 6), "output", CI);
        DEBUG(errs() << "Found visc edge call: " << *CI << "\n");
        DEBUG(errs() << "\tSubstitute with: " << *EdgeInst << "\n");
        CI->replaceAllUsesWith(EdgeInst);
        toBeErased.push_back(CI);
      }
      if (isVISCCall_bindIn(I)) {
        Function *BindInF =
            Intrinsic::getDeclaration(&M, Intrinsic::visc_bind_input);
        DEBUG(errs() << *BindInF << "\n");
        // Check if this is a streaming bind or not
        ConstantInt *Op = dyn_cast<ConstantInt>(CI->getArgOperand(3));
        assert(Op && "Streaming argument for bind in intrinsic should be a "
                     "constant integer.");
        Value *isStreaming = Op->isZero() ? ConstantInt::getFalse(Ctx)
                                          : ConstantInt::getTrue(Ctx);
        Value *BindInArgs[] = {CI->getArgOperand(0), CI->getArgOperand(1),
                               CI->getArgOperand(2), isStreaming};
        CallInst *BindInInst =
            CallInst::Create(BindInF, ArrayRef<Value *>(BindInArgs, 4), "", CI);
        DEBUG(errs() << "Found visc bindIn call: " << *CI << "\n");
        DEBUG(errs() << "\tSubstitute with: " << *BindInInst << "\n");
        CI->replaceAllUsesWith(BindInInst);
        toBeErased.push_back(CI);
      }
      if (isVISCCall_bindOut(I)) {
        Function *BindOutF =
            Intrinsic::getDeclaration(&M, Intrinsic::visc_bind_output);
        DEBUG(errs() << *BindOutF << "\n");
        // Check if this is a streaming bind or not
        ConstantInt *Op = dyn_cast<ConstantInt>(CI->getArgOperand(3));
        assert(Op && "Streaming argument for bind out intrinsic should be a "
                     "constant integer.");
        Value *isStreaming = Op->isZero() ? ConstantInt::getFalse(Ctx)
                                          : ConstantInt::getTrue(Ctx);
        Value *BindOutArgs[] = {CI->getArgOperand(0), CI->getArgOperand(1),
                                CI->getArgOperand(2), isStreaming};
        CallInst *BindOutInst = CallInst::Create(
            BindOutF, ArrayRef<Value *>(BindOutArgs, 4), "", CI);
        DEBUG(errs() << "Found visc bindOut call: " << *CI << "\n");
        DEBUG(errs() << "\tSubstitute with: " << *BindOutInst << "\n");

        DEBUG(errs() << "Fixing the return type of the function\n");
        // FIXME: What if the child node function has not been visited already.
        // i.e., it's return type has not been fixed.
        Function *F = I->getParent()->getParent();
        DEBUG(errs() << F->getName() << "\n");
        IntrinsicInst *NodeIntrinsic =
            dyn_cast<IntrinsicInst>(CI->getArgOperand(0));
        assert(NodeIntrinsic &&
               "Instruction value in bind out is not a create node intrinsic.");
        DEBUG(errs() << "Node intrinsic: " << *NodeIntrinsic << "\n");
        assert(
            (NodeIntrinsic->getIntrinsicID() == Intrinsic::visc_createNode ||
             NodeIntrinsic->getIntrinsicID() == Intrinsic::visc_createNode1D ||
             NodeIntrinsic->getIntrinsicID() == Intrinsic::visc_createNode2D ||
             NodeIntrinsic->getIntrinsicID() == Intrinsic::visc_createNode3D) &&
            "Instruction value in bind out is not a create node intrinsic.");
        Function *ChildF = cast<Function>(
            NodeIntrinsic->getArgOperand(0)->stripPointerCasts());
        DEBUG(errs() << ChildF->getName() << "\n");
        int srcpos = cast<ConstantInt>(CI->getArgOperand(1))->getSExtValue();
        int destpos = cast<ConstantInt>(CI->getArgOperand(2))->getSExtValue();
        StructType *ChildReturnTy = cast<StructType>(ChildF->getReturnType());
        assert(srcpos >= 0 &&
               static_cast<unsigned>(srcpos) <
                   ChildReturnTy->getNumElements() &&
               "bind out source position is not an output of the child node");
        assert(destpos >= 0 && "bind out destination position is negative");

        Type *ReturnType = F->getReturnType();
        DEBUG(errs() << *ReturnType << "\n");
        assert((ReturnType->isVoidTy() || isa<StructType>(ReturnType)) &&
               "Return type should either be a struct or void type!");
        (void)ReturnType; // Only used by assertions and debug output.

        // Outputs may be bound in any order. Grow the list as needed and
        // fill this output's slot. Inserting at begin() + destpos was out of
        // bounds whenever a higher position was bound first.
        if (FRetTypes.size() <= static_cast<size_t>(destpos))
          FRetTypes.resize(destpos + 1, nullptr);
        assert(FRetTypes[destpos] == nullptr &&
               "Output position bound by more than one bind out");
        FRetTypes[destpos] = ChildReturnTy->getElementType(srcpos);
        assert(((bind == mutateTypeCause::mtc_BIND) ||
                (bind == mutateTypeCause::mtc_None)) &&
               "Both bind_out and visc_return detected");
        bind = mutateTypeCause::mtc_BIND;

        CI->replaceAllUsesWith(BindOutInst);
        toBeErased.push_back(CI);
      }
      if (isVISCCall_attributes(I)) {
        Function *F = CI->getParent()->getParent();
        handleVISCAttributes(F, CI);
        toBeErased.push_back(CI);
      }
      if (isVISCCall_return(I)) {
        DEBUG(errs() << "Function before visc return processing\n"
                     << *I->getParent()->getParent() << "\n");
        // The operands to this call are the values to be returned by the node
        Value *ReturnVal = genCodeForReturn(CI);
        DEBUG(errs() << *ReturnVal << "\n");
        Type *ReturnType = ReturnVal->getType();
        assert(isa<StructType>(ReturnType) &&
               "Return type should be a struct type!");
        StructType *ReturnStructTy = cast<StructType>(ReturnType);

        assert(((bind == mutateTypeCause::mtc_RETURN) ||
                (bind == mutateTypeCause::mtc_None)) &&
               "Both bind_out and visc_return detected");

        if (bind == mutateTypeCause::mtc_None) {
          // If this is None, this is the first __visc__return
          // instruction we have come upon. Place the element types of the
          // returned struct in the return type vector.
          bind = mutateTypeCause::mtc_RETURN;
          for (unsigned i = 0; i < ReturnStructTy->getNumElements(); i++)
            FRetTypes.push_back(ReturnStructTy->getElementType(i));
        } else { // bind == mutateTypeCause::mtc_RETURN
          // A later __visc__return must return the same element types.
          // Compare element lists: FRetTypes holds element types, not a
          // struct type, and each return builds its own struct type.
          assert(ReturnStructTy->elements().equals(FRetTypes) &&
                 "Multiple returns with mismatching types");
        }

        ReturnInst *RetInst = ReturnInst::Create(Ctx, ReturnVal);
        DEBUG(errs() << "Found visc return call: " << *CI << "\n");
        Instruction *oldReturn = CI->getParent()->getTerminator();
        assert(isa<ReturnInst>(oldReturn) &&
               "Expecting a return to be the terminator of this BB!");
        DEBUG(errs() << "Found return statement of BB: " << *oldReturn << "\n");
        DEBUG(errs() << "\tSubstitute return with: " << *RetInst << "\n");
        toBeErased.push_back(CI);
        ReplaceInstWithInst(oldReturn, RetInst);
        DEBUG(errs() << "Function after visc return processing\n"
                     << *I->getParent()->getParent() << "\n");
      }
    }

    // Erase the __visc__node calls
    DEBUG(errs() << "Erase " << toBeErased.size() << " Statements:\n");
    for (auto *Dead : toBeErased) {
      DEBUG(errs() << *Dead << "\n");
    }
    while (!toBeErased.empty()) {
      Instruction *Dead = toBeErased.back();
      DEBUG(errs() << "\tErasing " << *Dead << "\n");
      Dead->eraseFromParent();
      toBeErased.pop_back();
    }

    if (bind == mutateTypeCause::mtc_BIND ||
        bind == mutateTypeCause::mtc_RETURN) {
      DEBUG(errs() << "Function before fixing return type\n" << *f << "\n");
      // Argument type list.
      std::vector<Type *> FArgTypes;
      for (Function::const_arg_iterator ai = f->arg_begin(), ae = f->arg_end();
           ai != ae; ++ai) {
        FArgTypes.push_back(ai->getType());
      }

      // Find new return type of function
      Type *NewReturnTy;
      if (bind == mutateTypeCause::mtc_BIND) {
        // Every output position up to the highest bound one must be bound.
        assert(std::find(FRetTypes.begin(), FRetTypes.end(), nullptr) ==
                   FRetTypes.end() &&
               "Output position not bound by any bind out");
        NewReturnTy =
            StructType::create(f->getContext(), FRetTypes,
                               ("struct.out." + f->getName()).str(), true);
      } else {
        NewReturnTy = getReturnTypeFromReturnInst(f);
        assert(NewReturnTy->isStructTy() && "Expecting a struct type!");
      }

      FunctionType *FTy =
          FunctionType::get(NewReturnTy, FArgTypes, f->isVarArg());

      // Change the function type
      Function *newF = cloneFunction(f, FTy, false);
      DEBUG(errs() << *newF << "\n");

      if (bind == mutateTypeCause::mtc_BIND) {
        // This is certainly an internal node, and hence just one BB with one
        // return terminator instruction. Change return statement
        ReturnInst *RI =
            cast<ReturnInst>(newF->getEntryBlock().getTerminator());
        ReturnInst *newRI = ReturnInst::Create(newF->getContext(),
                                               UndefValue::get(NewReturnTy));
        ReplaceInstWithInst(RI, newRI);
      }
      // For mtc_RETURN the struct-returning `ret`s were already created while
      // lowering the __visc__return calls; nothing else to patch.
      replaceNodeFunctionInIR(*f->getParent(), f, newF);
      DEBUG(errs() << "Function after fixing return type\n" << *newF << "\n");
    }
  }
  // The pass always modifies the module (it at least declares the runtime
  // API), so report the change to the pass manager.
  return true;
}

// Generate Code for declaring a constant string [L x i8] and return a pointer
// to the start of it.
Value *GenVISC::getStringPointer(const Twine &S, Instruction *IB,
                                 const Twine &Name) {
  // Emit an internal, constant global named `Name` whose initializer is the
  // text of S as an i8 array. AddNull is true, so the array is NUL-terminated.
  // Then emit (before IB) a GEP named `Name + "Ptr"` that points at its first
  // element, and return that GEP.
  LLVMContext &Ctx = M->getContext();
  Constant *StrInit = ConstantDataArray::getString(Ctx, S.str(), true);
  GlobalVariable *StrGV = new GlobalVariable(
      *M, StrInit->getType(), /*isConstant=*/true,
      GlobalValue::InternalLinkage, StrInit, Name);

  // Indices {0, 0}: step through the global's pointer, then to element 0.
  Constant *ZeroIdx = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *Indices[] = {ZeroIdx, ZeroIdx};
  return GetElementPtrInst::Create(nullptr, StrGV,
                                   ArrayRef<Value *>(Indices, 2),
                                   Name + "Ptr", IB);
}

// Set up the GenVISC timer set. This only happens when -visc-timers-gen is
// given.
//
// Steps:
//  1. Create the null-initialized, common-linkage i8* global
//     "viscTimerSet_GenVISC" and store it in TimerSet.
//  2. Before InsertBefore, emit a call to llvm_visc_initializeTimerSet.
//  3. Store that call's result into the global.
//
// The whole body, including the DEBUG output, is guarded by a single check.
// The DEBUG statements dereference TimerSet, the call and the store, and
// those only exist when timers are enabled.
void GenVISC::initializeTimerSet(Instruction *InsertBefore) {
  if (!VISCTimer)
    return;

  Type *Int8PtrTy = Type::getInt8PtrTy(M->getContext());
  TimerSet = new GlobalVariable(*M, Int8PtrTy, false,
                                GlobalValue::CommonLinkage,
                                Constant::getNullValue(Int8PtrTy),
                                "viscTimerSet_GenVISC");
  DEBUG(errs() << "Inserting GV: " << *TimerSet->getType() << *TimerSet
               << "\n");

  Value *TimerSetAddr = CallInst::Create(llvm_visc_initializeTimerSet, None,
                                         "", InsertBefore);
  DEBUG(errs() << "TimerSetAddress = " << *TimerSetAddr << "\n");

  StoreInst *SI = new StoreInst(TimerSetAddr, TimerSet, InsertBefore);
  (void)SI; // Only read by DEBUG, which compiles away in release builds.
  DEBUG(errs() << "Store Timer Address in Global variable: " << *SI << "\n");
}

void GenVISC::switchToTimer(enum visc_TimerID timer,
                            Instruction *InsertBefore) {
  Value *switchArgs[] = {TimerSet, getTimerID(*M, timer)};
  TIMER(CallInst::Create(llvm_visc_switchToTimer,
                         ArrayRef<Value *>(switchArgs, 2), "", InsertBefore));
}

// Emit the timer report. This only happens when -visc-timers-gen is given.
//
// Steps, all placed before InsertBefore:
//  1. Create a constant string "GenVISC_Timer" via getStringPointer.
//  2. Emit a call to llvm_visc_printTimerSet(TimerSet, <that string>).
//
// TimerSet is expected to have been created by initializeTimerSet (assumed
// from the surrounding code; confirm at call sites).
//
// The early return replaces the per-statement TIMER wrappers. Previously,
// with timers off, the never-assigned TimerName was still read into the
// argument array, which is undefined behaviour.
void GenVISC::printTimerSet(Instruction *InsertBefore) {
  if (!VISCTimer)
    return;

  Value *TimerName = getStringPointer("GenVISC_Timer", InsertBefore);
  Value *printArgs[] = {TimerSet, TimerName};
  CallInst::Create(llvm_visc_printTimerSet, ArrayRef<Value *>(printArgs, 2),
                   "", InsertBefore);
}

// Materialize the timer enumerator as an i32 constant in M's context. This is
// the timer-id operand that switchToTimer passes to llvm_visc_switchToTimer.
static inline ConstantInt *getTimerID(Module &M, enum visc_TimerID timer) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  return ConstantInt::get(I32Ty, static_cast<uint64_t>(timer));
}

// Give F a struct return type.
//  - If F already returns a struct, F itself is returned unchanged.
//  - Otherwise F must return void (asserted). A clone of F is made with the
//    same parameters and a newly created, empty, packed struct type
//    ("emptyStruct") as its return type.
//  - In the clone, every `ret void` is rewritten to `ret %emptyStruct undef`.
//  - F is then swapped for the clone via replaceNodeFunctionInIR.
//
// Returns the function that should be used from now on (F or the clone).
static Function *transformReturnTypeToStruct(Function *F) {
  DEBUG(errs() << "Transforming return type of function to Struct: "
               << F->getName() << "\n");

  Type *OldRetTy = F->getReturnType();
  if (isa<StructType>(OldRetTy)) {
    DEBUG(errs() << "Return type is already a Struct: " << F->getName() << ": "
                 << *OldRetTy << "\n");
    return F;
  }

  assert(OldRetTy->isVoidTy() &&
         "Unhandled case - Only void return type handled\n");

  // Only the return type changes; the parameter list is reused as is.
  StructType *EmptyStructTy =
      StructType::create(F->getContext(), None, "emptyStruct", true);
  FunctionType *NewFnTy = FunctionType::get(
      EmptyStructTy, F->getFunctionType()->params(), F->isVarArg());

  // cloneFunction fills RetInsts with return instructions (presumably those
  // of the clone, since they are rewritten in the clone's context below).
  SmallVector<ReturnInst *, 8> RetInsts;
  Function *Clone = cloneFunction(F, NewFnTy, false, &RetInsts);
  Value *UndefRet = UndefValue::get(EmptyStructTy);
  for (ReturnInst *OldRet : RetInsts) {
    DEBUG(errs() << "Found return inst: " << *OldRet << "\n");
    ReplaceInstWithInst(OldRet,
                        ReturnInst::Create(Clone->getContext(), UndefRet));
  }

  replaceNodeFunctionInIR(*F->getParent(), F, Clone);
  return Clone;
}

// Return the type of the value carried by the first value-returning `ret`
// found while walking F's basic blocks in order.
//
// Blocks without a terminator and `ret void` instructions are skipped.
// Neither has a value type to report, and `ret void` would otherwise mean
// dereferencing a null getReturnValue().
//
// Returns nullptr when no such return exists, so control never falls off
// the end of this non-void function. Callers must check for nullptr.
static Type *getReturnTypeFromReturnInst(Function *F) {
  for (BasicBlock &BB : *F) {
    ReturnInst *RI = dyn_cast_or_null<ReturnInst>(BB.getTerminator());
    if (!RI)
      continue;
    Value *RetVal = RI->getReturnValue();
    if (!RetVal)
      continue;
    DEBUG(errs() << "Return type value: " << *RetVal->getType() << "\n");
    return RetVal->getType();
  }
  DEBUG(errs() << "No value-returning return instruction in: " << F->getName()
               << "\n");
  return nullptr;
}

char genvisc::GenVISC::ID = 0;
static RegisterPass<genvisc::GenVISC>
    X("genvisc",
      "Pass to generate VISC IR from LLVM IR (with dummy function calls)",
      false, false);

} // End of namespace genvisc