Commit 014fc8d3 authored by Hashim Sharif

Adding test script for Keras Accuracy Check

parent d4ef8765
import subprocess


class Benchmark:

    def __init__(self, binary_path, test_accuracy):
        self.binary_path = binary_path
        self.test_accuracy = test_accuracy
        self.epsilon = 0.05  # Slack allowed between measured and expected accuracy

    def getPath(self):
        return self.binary_path

    def readAccuracy(self):
        # The benchmark under test writes its final accuracy to this file
        with open("final_accuracy", "r") as f:
            acc_str = f.read()
        return float(acc_str)

    def run(self):
        # Run the benchmark with pretrained weights (hpvm_reload mode)
        run_cmd = "python " + self.binary_path + " hpvm_reload "
        try:
            subprocess.call(run_cmd, shell=True)
        except Exception:
            return False

        accuracy = self.readAccuracy()
        print("accuracy = ", accuracy, " test_accuracy = ", self.test_accuracy)

        if abs(self.test_accuracy - accuracy) < self.epsilon:
            print("Test for " + self.binary_path + " Passed")
            return True
        else:
            print("Test Failed for " + self.binary_path)
            return False
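
The harness and each benchmark script communicate through the final_accuracy file: the benchmark writes its measured accuracy there, and readAccuracy() above parses it back as a float. A minimal sketch of the writer side a benchmark is expected to implement (the score value here is an illustrative placeholder; the actual change to the Keras benchmark appears in the diff at the bottom of this commit):

# Sketch of the writer half of the final_accuracy handshake.
# `score` is a hypothetical accuracy in [0, 1], e.g. from model.evaluate().
score = 0.8956

with open("final_accuracy", "w") as f:
    f.write(str(score * 100))  # the harness expects a percentage as a bare float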
class BenchmarkTests:

    def __init__(self):
        self.benchmarks = []
        self.passed_tests = []
        self.failed_tests = []

    def addBenchmark(self, benchmark):
        self.benchmarks.append(benchmark)

    def runTests(self):
        for benchmark in self.benchmarks:
            if benchmark.run():
                self.passed_tests.append(benchmark.getPath())
            else:
                self.failed_tests.append(benchmark.getPath())

    def printSummary(self):
        total = len(self.benchmarks)

        print("Tests Passed = " + str(len(self.passed_tests)) + " / " + str(total))
        print("******* Passed Tests *******\n")
        for passed_test in self.passed_tests:
            print("Passed: " + passed_test)

        print("Tests Failed = " + str(len(self.failed_tests)) + " / " + str(total))
        print("******* Failed Tests *******\n")
        for failed_test in self.failed_tests:
            print("Failed: " + failed_test)
if __name__ == "__main__":

    testMgr = BenchmarkTests()

    AlexNet = Benchmark("src/alexnet.py", 79.28)
    AlexNet2 = Benchmark("src/alexnet2.py", 84.98)
    LeNet = Benchmark("src/lenet.py", 98.70)
    MobileNet = Benchmark("src/mobilenet_cifar10.py", 84.42)
    ResNet18 = Benchmark("src/resnet18_cifar10.py", 89.56)
    VGG16_cifar10 = Benchmark("src/vgg16_cifar10.py", 89.96)
    VGG16_cifar100 = Benchmark("src/vgg16_cifar100.py", 66.50)

    testMgr.addBenchmark(AlexNet)
    testMgr.addBenchmark(AlexNet2)
    testMgr.addBenchmark(LeNet)
    testMgr.addBenchmark(MobileNet)
    testMgr.addBenchmark(ResNet18)
    testMgr.addBenchmark(VGG16_cifar10)
    testMgr.addBenchmark(VGG16_cifar100)

    testMgr.runTests()
    testMgr.printSummary()
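
Extending the suite to a new model only requires constructing a Benchmark with the script path and its expected accuracy, then registering it before runTests(). A hypothetical example (the path and reference accuracy below are placeholders, not part of this commit):

# Hypothetical extra benchmark; the path and accuracy are illustrative only.
DenseNet = Benchmark("src/densenet_cifar10.py", 90.00)
testMgr.addBenchmark(DenseNet)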
@@ -90,6 +90,10 @@ class Benchmark:
         score = model.evaluate(X_test, to_categorical(y_test, self.num_classes), verbose=0)
         print('Test accuracy2:', score[1])
+
+        f = open("final_accuracy", "w+")
+        f.write(str(score[1] * 100))
+        f.close()

         if len(argv) > 2:
             if argv[2] == "frontend":
...
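
The added write could equivalently use a context manager, which closes the file even if the write raises; a minimal sketch of that variant (same behavior as the committed lines, shown only as an alternative):

# Equivalent to the added lines above, but exception-safe.
with open("final_accuracy", "w+") as f:
    f.write(str(score[1] * 100))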