From a6a728e78018d7664454d7788929681677e99930 Mon Sep 17 00:00:00 2001
From: Maria Kotsifakou <kotsifa2@illinois.edu>
Date: Sat, 19 Aug 2017 15:29:50 -0500
Subject: [PATCH] Continuing previous commit (not all changes were pushed).

---
 .../Transforms/DFG2LLVM_X86/DFG2LLVM_X86.cpp   | 18 ++++++++++++++++++
 llvm/projects/visc-rt/policy.h                 |  9 +++------
 llvm/projects/visc-rt/visc-rt.cpp              |  4 ++--
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/DFG2LLVM_X86/DFG2LLVM_X86.cpp b/llvm/lib/Transforms/DFG2LLVM_X86/DFG2LLVM_X86.cpp
index cd64f28bcf..088aa85348 100644
--- a/llvm/lib/Transforms/DFG2LLVM_X86/DFG2LLVM_X86.cpp
+++ b/llvm/lib/Transforms/DFG2LLVM_X86/DFG2LLVM_X86.cpp
@@ -1681,6 +1681,15 @@ void CGT_X86::codeGen(DFInternalNode* N) {
       
       GenFuncCI = CallInst::Create(GF, GenFuncCallArgs, "", BBtrue);
       RI = ReturnInst::Create(M.getContext(), GenFuncCI, BBtrue);
+
+      if (DeviceAbstraction) {
+        // Prepare arguments and function for call to wait for device runtime call
+        std::vector<Value *> Args; // TODO: add the device type as argument?
+        Function *RTF =
+          cast<Function>(M.getOrInsertFunction("llvm_visc_deviceAbstraction_waitOnDeviceStatus",
+          runtimeModule->getFunction("llvm_visc_deviceAbstraction_waitOnDeviceStatus")->getFunctionType()));
+        CallInst *RTFInst = CallInst::Create(RTF, Args, "", GenFuncCI);
+      }
     }
 
     // Switch basic block pointers
@@ -1697,6 +1706,15 @@ void CGT_X86::codeGen(DFInternalNode* N) {
       
       GenFuncCI = CallInst::Create(SF, GenFuncCallArgs, "", BBtrue);
       RI = ReturnInst::Create(M.getContext(), GenFuncCI, BBtrue);
+
+      if (DeviceAbstraction) {
+        // Prepare arguments and function for call to wait for device runtime call
+        std::vector<Value *> Args; // TODO: add the device type as argument?
+        Function *RTF =
+          cast<Function>(M.getOrInsertFunction("llvm_visc_deviceAbstraction_waitOnDeviceStatus",
+          runtimeModule->getFunction("llvm_visc_deviceAbstraction_waitOnDeviceStatus")->getFunctionType()));
+        CallInst *RTFInst = CallInst::Create(RTF, Args, "", GenFuncCI);
+      }
     }
 
     RI = ReturnInst::Create(M.getContext(),
diff --git a/llvm/projects/visc-rt/policy.h b/llvm/projects/visc-rt/policy.h
index 1bc1e956ae..7fa3c27e3b 100644
--- a/llvm/projects/visc-rt/policy.h
+++ b/llvm/projects/visc-rt/policy.h
@@ -22,12 +22,9 @@ class NodePolicy : public Policy {
       "WrapperComputeMaxGradient_cloned",
       "WrapperRejectZeroCrossings_cloned",
     };
-    //for(int i = 0; i < 6; i++) {
-      if (!s.compare(NodeNames[4])) {
-        // if this is the kernel launch node 
-        std::cout << s << ": CPU" << "\n";
-        return 0;
-      }
+    //if (!s.compare(NodeNames[4])) {
+    //  std::cout << s << ": CPU" << "\n";
+    //  return 0;
     //}
     std::cout << s << ": GPU" << "\n";
     return 1;
diff --git a/llvm/projects/visc-rt/visc-rt.cpp b/llvm/projects/visc-rt/visc-rt.cpp
index ef8991ea6f..0b563191fd 100644
--- a/llvm/projects/visc-rt/visc-rt.cpp
+++ b/llvm/projects/visc-rt/visc-rt.cpp
@@ -71,9 +71,9 @@ static inline void checkErr(cl_int err, cl_int success, const char * name) {
 /************************* Policies *************************************/
 void llvm_visc_policy_init() {
   cout << "Initializing policy object ...\n";
-//  policy = new NodePolicy();
+  policy = new NodePolicy();
 //  policy = new IterationPolicy();
-  policy = new DeviceStatusPolicy();
+//  policy = new DeviceStatusPolicy();
   cout << "DONE: Initializing policy object.\n";
 }
 
-- 
GitLab