Skip to content
Snippets Groups Projects
Commit 78d20827 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Adding mising Benchmark.py (Keras frontend

parent 3f83b085
No related branches found
No related tags found
No related merge requests found
import sys
import os
import shutil
from keras.utils.np_utils import to_categorical
from keras.models import load_model
from frontend.approxhpvm_translator import translate_to_approxhpvm
from frontend.weight_utils import dumpCalibrationData
from frontend.weight_utils import reloadHPVMWeights
# Every CNN Benchmark must inherit from Benchmark class
# Defines common interfaces and virtual methods to be overridden by child classes
class Benchmark:
def __init__(self, name, reload_dir, keras_model_file, data_dir, src_dir, num_classes, batch_size=500):
self.name = name
self.reload_dir = reload_dir
self.keras_model_file = keras_model_file
self.data_dir = data_dir
self.src_dir = src_dir
self.num_classes = num_classes
self.batch_size = batch_size
def buildModel(self):
return
def data_preprocess(self):
return
def trainModel(self, X_train, y_train, X_test, y_test):
return
def inference(self):
return
# Compiles frontend generated sources
def compileSource(self, working_dir):
# set LLVM_SRC_ROOT
os.environ["CFLAGS"] = ""
os.environ["CXXFLAGS"] = ""
dest_file = working_dir + "CMakeLists.txt"
shutil.copy("cmake_template/CMakeLists.txt", dest_file)
# Cmake ../
# make
def printUsage(self):
print ("Usage: python ${benchmark.py} [hpvm_reload|train] [frontend] [compile]")
sys.exit(0)
def run(self, argv):
if len(argv) < 2:
self.printUsage()
print ("Build Model ...")
# Virtual method call implemented by each CNN
model = self.buildModel()
print ("Data Preprocess... \n")
# Virtual method call to preprocess test and train data
X_train, y_train, X_test, y_test, X_tuner, y_tuner = self.data_preprocess()
if argv[1] == "hpvm_reload":
print ("loading weights .....\n\n")
model = reloadHPVMWeights(model, self.reload_dir, self.keras_model_file)
elif argv[1] == "keras_reload":
model.load_weights(self.keras_model_file)
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
elif argv[1] == "train":
print ("Train Model ...")
model = self.trainModel(model, X_train, y_train, X_test, y_test)
else:
self.printUsage()
score = model.evaluate(X_test, to_categorical(y_test, self.num_classes), verbose=0)
print('Test accuracy2:', score[1])
if len(argv) > 2:
if argv[2] == "frontend":
# Main call to ApproxHPVM-Keras Frontend
working_dir = translate_to_approxhpvm(model,
self.data_dir, self.src_dir, ## "data/test_src/",
X_test, y_test,
X_tuner, y_tuner,
self.batch_size, # FIXIT
self.num_classes,
(argv[1] == "hpvm_reload")) # Do not redump HPVM weights if `hpvm_reload` used
if len(argv) > 3 and argv[3] == "compile":
self.compileSource(working_dir)
else:
self.printUsage()
if argv[2] == "keras_dump":
model.save_weights(self.keras_model_file)
elif len(argv) > 2:
self.printUsage()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment