From 78d20827fa0d0a36d0827a1476260c2b6f7d6e84 Mon Sep 17 00:00:00 2001
From: Hashim Sharif <hsharif3@miranda.cs.illinois.edu>
Date: Sun, 31 Jan 2021 01:23:28 -0600
Subject: [PATCH] Adding mising Benchmark.py (Keras frontend

---
 hpvm/projects/keras/src/Benchmark.py | 122 +++++++++++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 hpvm/projects/keras/src/Benchmark.py

diff --git a/hpvm/projects/keras/src/Benchmark.py b/hpvm/projects/keras/src/Benchmark.py
new file mode 100644
index 0000000000..3610b2e9a5
--- /dev/null
+++ b/hpvm/projects/keras/src/Benchmark.py
@@ -0,0 +1,122 @@
+
+
+import sys
+import os
+import shutil
+from keras.utils.np_utils import to_categorical
+from keras.models import load_model
+from frontend.approxhpvm_translator import translate_to_approxhpvm
+from frontend.weight_utils import dumpCalibrationData
+from frontend.weight_utils import reloadHPVMWeights
+
+
+# Every CNN Benchmark must inherit from Benchmark class
+# Defines common interfaces and virtual methods to be overridden by child classes
+class Benchmark:
+
+    def __init__(self, name, reload_dir, keras_model_file, data_dir, src_dir, num_classes, batch_size=500):
+        self.name = name
+        self.reload_dir = reload_dir
+        self.keras_model_file = keras_model_file
+        self.data_dir = data_dir
+        self.src_dir = src_dir
+        self.num_classes = num_classes
+        self.batch_size = batch_size
+        
+        
+    def buildModel(self):
+        return
+
+    def data_preprocess(self):
+        return
+    
+    def trainModel(self, X_train, y_train, X_test, y_test):
+        return
+
+    def inference(self):
+        return
+
+
+    # Compiles frontend generated sources
+    def compileSource(self, working_dir):
+
+        # set LLVM_SRC_ROOT
+        os.environ["CFLAGS"] = ""
+        os.environ["CXXFLAGS"] = ""
+
+        dest_file = working_dir + "CMakeLists.txt"
+        shutil.copy("cmake_template/CMakeLists.txt", dest_file)
+
+        # Cmake ../
+        # make
+
+
+    def printUsage(self):
+
+        print ("Usage: python ${benchmark.py} [hpvm_reload|train] [frontend] [compile]")
+        sys.exit(0)
+
+        
+    def run(self, argv):
+
+      if len(argv) < 2:
+          self.printUsage()
+          
+      print ("Build Model ...")
+      # Virtual method call implemented by each CNN
+      model = self.buildModel()
+
+      print ("Data Preprocess... \n")
+      # Virtual method call to preprocess test and train data 
+      X_train, y_train, X_test, y_test, X_tuner, y_tuner = self.data_preprocess()   
+
+      if argv[1] == "hpvm_reload":
+        print ("loading weights .....\n\n")  
+        model = reloadHPVMWeights(model, self.reload_dir, self.keras_model_file)
+
+      elif argv[1] == "keras_reload":
+        model.load_weights(self.keras_model_file)
+        model.compile(loss='categorical_crossentropy',
+                    optimizer='adam',
+                    metrics=['accuracy'])   
+
+      elif argv[1] == "train":
+        print ("Train Model ...")
+        model = self.trainModel(model, X_train, y_train, X_test, y_test)
+      else:
+          self.printUsage()
+
+          
+      score = model.evaluate(X_test, to_categorical(y_test, self.num_classes), verbose=0)
+      print('Test accuracy2:', score[1])
+
+
+      if len(argv) > 2:
+        if argv[2] == "frontend":
+          
+          # Main call to ApproxHPVM-Keras Frontend
+          working_dir = translate_to_approxhpvm(model,
+                                                self.data_dir, self.src_dir,  ##  "data/test_src/", 
+                                                X_test, y_test,
+                                                X_tuner, y_tuner,
+                                                self.batch_size, # FIXIT
+                                                self.num_classes,
+                                                (argv[1] == "hpvm_reload")) # Do not redump HPVM weights if `hpvm_reload` used
+
+          if len(argv) > 3 and argv[3] == "compile":
+            self.compileSource(working_dir)
+          else:
+            self.printUsage()
+
+
+        if argv[2] == "keras_dump":
+          model.save_weights(self.keras_model_file)
+
+          
+      elif len(argv) > 2:
+        self.printUsage()
+            
+
+    
+
+        
-- 
GitLab