diff --git a/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py b/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py index 9be86c8a056b9a29ea05d209ec2c86c6dd0e56b4..bda3d4186f6d15e31ec0fa6ee547ca0dd0b29968 100644 --- a/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py +++ b/hpvm/test/dnn_benchmarks/keras/test_benchmarks.py @@ -3,8 +3,10 @@ import os import sys import subprocess +import argparse from Config import * + class Benchmark: def __init__(self, binary_path, output_dir, test_accuracy): @@ -50,10 +52,15 @@ class Benchmark: return test_success - def runHPVM(self): + def runHPVM(self, weights_dump): - # Test Bechmark accuracy with pretrained weights (hpvm_relaod) - run_cmd = "python3 " + self.binary_path + " keras_reload frontend compile compile_tuner" + if weights_dump: + # Test Benchmark with Keras weight dumping + run_cmd = "python3 " + self.binary_path + " keras_reload frontend compile compile_tuner" + else: + # Test Benchmark accuracy with pretrained weights (hpvm_relaod) + run_cmd = "python3 " + self.binary_path + " hpvm_reload frontend compile compile_tuner" + try: subprocess.call(run_cmd, shell=True) except: @@ -120,10 +127,10 @@ class BenchmarkTests: self.passed_tests.append(benchmark.getPath()) - def runHPVMTests(self): + def runHPVMTests(self, weights_dump): for benchmark in self.benchmarks: - test_success = benchmark.runHPVM() + test_success = benchmark.runHPVM(weights_dump) if not test_success: self.failed_hpvm_tests.append(benchmark.getPath()) @@ -173,10 +180,21 @@ class BenchmarkTests: if __name__ == "__main__": - if len(sys.argv) < 2: - print ("Usage: python3 test_dnnbenchmarks.py ${work_dir}") - work_dir = sys.argv[1] + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--work-dir', type=str, + help='working dir for dumping frontend generated files') + + parser.add_argument('--dump-weights', action="store_true", help='dump h5 weights to bin (default: False)') + args = parser.parse_args() + + work_dir = args.work_dir + dump_weights = args.dump_weights + + #print (dump_weights) + #sys.exit(0) + if os.path.exists(work_dir): print ("Work Directory Exists. Delete it or use a different work directory.") sys.exit(0) @@ -210,12 +228,10 @@ if __name__ == "__main__": #testMgr.runKerasTests() #testMgr.printKerasSummary() - testMgr.runHPVMTests() + testMgr.runHPVMTests(dump_weights) tests_passed = testMgr.printHPVMSummary() if not tests_passed: sys.exit(-1) - #testMgr.runKerasTests() - #testMgr.printKerasSummary()