From da71c5c7808f6460e9bfdcea8937af6be9fe56b1 Mon Sep 17 00:00:00 2001
From: Hashim Sharif <hsharif3@miranda.cs.illinois.edu>
Date: Mon, 5 Apr 2021 15:30:09 -0500
Subject: [PATCH] Adding argparse options to Keras frontend test script

---
 .../dnn_benchmarks/keras/test_benchmarks.py   | 38 +++++++++++++------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py b/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py
index 9be86c8a05..bda3d4186f 100644
--- a/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py
+++ b/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py
@@ -3,8 +3,10 @@
 import os
 import sys
 import subprocess
+import argparse
 from Config import *
 
+
 class Benchmark:
 
     def __init__(self, binary_path, output_dir, test_accuracy):
@@ -50,10 +52,15 @@ class Benchmark:
         return test_success
 
 
-    def runHPVM(self):
+    def runHPVM(self, weights_dump):
 
-        # Test Bechmark accuracy with pretrained weights (hpvm_relaod)
-        run_cmd = "python3 " + self.binary_path + " keras_reload frontend compile compile_tuner"
+        if weights_dump:
+          # Test Benchmark with Keras weight dumping
+          run_cmd = "python3 " + self.binary_path + " keras_reload frontend compile compile_tuner"
+        else:
+          # Test Benchmark accuracy with pretrained weights (hpvm_relaod)
+          run_cmd = "python3 " + self.binary_path + " hpvm_reload frontend compile compile_tuner"
+          
         try:
             subprocess.call(run_cmd, shell=True)
         except:
@@ -120,10 +127,10 @@ class BenchmarkTests:
                 self.passed_tests.append(benchmark.getPath())
 
 
-    def runHPVMTests(self):
+    def runHPVMTests(self, weights_dump):
 
         for benchmark in self.benchmarks:
-            test_success = benchmark.runHPVM()
+            test_success = benchmark.runHPVM(weights_dump)
 
             if not test_success:
                 self.failed_hpvm_tests.append(benchmark.getPath())
@@ -173,10 +180,21 @@ class BenchmarkTests:
             
 if __name__ == "__main__":
 
-    if len(sys.argv) < 2:
-        print ("Usage: python3 test_dnnbenchmarks.py ${work_dir}")
 
-    work_dir = sys.argv[1]
+    parser = argparse.ArgumentParser(description='Process some integers.')
+
+    parser.add_argument('--work-dir', type=str,
+                        help='working dir for dumping frontend generated files')
+
+    parser.add_argument('--dump-weights', action="store_true", help='dump h5 weights to bin (default: False)')
+    args = parser.parse_args()
+
+    work_dir = args.work_dir
+    dump_weights = args.dump_weights
+
+    #print (dump_weights)
+    #sys.exit(0)
+    
     if os.path.exists(work_dir):
         print ("Work Directory Exists. Delete it or use a different work directory.")
         sys.exit(0)
@@ -210,12 +228,10 @@ if __name__ == "__main__":
     #testMgr.runKerasTests()
     #testMgr.printKerasSummary()
     
-    testMgr.runHPVMTests()
+    testMgr.runHPVMTests(dump_weights)
     tests_passed = testMgr.printHPVMSummary()
 
     if not tests_passed:
         sys.exit(-1)
     
-    #testMgr.runKerasTests()
-    #testMgr.printKerasSummary()
  
-- 
GitLab