From f25d6c4d9db6b0e214c71804f033de37b87d6f53 Mon Sep 17 00:00:00 2001
From: Hashim Sharif <hsharif3@miranda.cs.illinois.edu>
Date: Fri, 4 Dec 2020 16:34:47 -0600
Subject: [PATCH] Adding more flag option checks to Benchmark class

---
 llvm/projects/keras/src/Benchmark.py | 28 +++++++++++++++++++++-------
 llvm/projects/keras/src/alexnet.py   |  7 +------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/llvm/projects/keras/src/Benchmark.py b/llvm/projects/keras/src/Benchmark.py
index 8053345548..a275d103e2 100644
--- a/llvm/projects/keras/src/Benchmark.py
+++ b/llvm/projects/keras/src/Benchmark.py
@@ -47,14 +47,19 @@ class Benchmark:
 
         # Cmake ../
         # make
-        
 
-    def run(self, argv):
 
-      if len(argv) < 2:
-        print ("Usage: python ${benchmark.py} [hpvm_reload|keras_reload|train] [frontend] [compile]")   
+    def printUsage(self):
+
+        print ("Usage: python ${benchmark.py} [hpvm_reload|keras_reload|train] [frontend] [compile]")
         sys.exit(0)
 
+        
+    def run(self, argv):
+
+      if len(argv) < 2:
+          self.printUsage()
+          
       # Virtual method call implemented by each CNN
       model = self.buildModel()
 
@@ -65,12 +70,16 @@ class Benchmark:
         print ("loading weights .....\n\n")  
         model = reloadHPVMWeights(model, self.reload_dir, self.keras_model_file, X_test, Y_test)
 
-      if argv[1] == "keras_reload":
+      elif argv[1] == "keras_reload":
         model = load_model(self.keras_model_file)
 
-      if argv[1] == "train":
+      elif argv[1] == "train":
         model = self.trainModel(model)
 
+      else:
+          self.printUsage()
+
+          
       score = model.evaluate(X_test, to_categorical(Y_test, self.num_classes), verbose=0)
       print('Test accuracy2:', score[1])
 
@@ -86,6 +95,11 @@ class Benchmark:
         
         if len(argv) > 3 and argv[3] == "compile":
           self.compileSource(working_dir)
-        
 
+        else:
+          self.printUsage()
+
+      elif len(argv) > 2:
+        self.printUsage()
+            
 
diff --git a/llvm/projects/keras/src/alexnet.py b/llvm/projects/keras/src/alexnet.py
index ae2c20493c..9bfe80a156 100644
--- a/llvm/projects/keras/src/alexnet.py
+++ b/llvm/projects/keras/src/alexnet.py
@@ -146,16 +146,11 @@ class AlexNet(Benchmark):
     
     X_train = X_train / 255.0
     X_test = X_test / 255.0
-
-    print(X_train, X_test)
-      
+     
     mean = np.mean(X_train,axis=(0,1,2,3))
     std = np.std(X_train,axis=(0,1,2,3))   
     X_train = (X_train-mean)/(std+1e-7)
     X_test = (X_test-mean)/(std+1e-7)  
-
-    print(X_train, X_test)
-  
     
     return X_train, Y_train, X_test, Y_test
 
-- 
GitLab