Skip to content
Snippets Groups Projects
Commit de4a1143 authored by Elizabeth's avatar Elizabeth
Browse files

Cleaned up code

parent 9b49fd62
No related branches found
No related tags found
No related merge requests found
......@@ -4,11 +4,23 @@ import subprocess
import sys
class Driver:
fp16_swing = 8
class ApproxTypes:
FP16 = 0
FP32 = 1
PROMISE = 2
results_time_key = "Time"
results_energy_key = "Energy"
def driver(self):
self.parse_tensor_layer_file()
self.parse_tensor_table()
self.run_simulations()
self.display_results()
self.__parse_tensor_layer_file()
self.__parse_tensor_table()
self.__run_simulations()
self.__display_results()
def __init__(self, layer_filename, table_filename, config_filename, results_filename):
self.__layer_filename = layer_filename
......@@ -29,19 +41,23 @@ class Driver:
self.__aggregate_results = defaultdict(lambda: defaultdict(float))
self.__config_count = 0
@staticmethod
def is_conv(operation_name):
return operation_name.startswith("Conv")
@staticmethod
def is_nml(operation_name):
return operation_name.startswith("NML")
@staticmethod
def is_fc(operation_name):
return operation_name.startswith("FC")
def parse_tensor_layer_file(self):
def __parse_tensor_layer_file(self):
if not os.path.isfile(self.__layer_filename):
print("ERROR: %s was not found." % self.__layer_filename)
exit(1)
......@@ -78,7 +94,8 @@ class Driver:
self.__tensor_layers.append(tensor_layer)
layer_file.close()
def parse_tensor_table(self):
def __parse_tensor_table(self):
if not os.path.isfile(self.__table_filename):
print("ERROR: %s was not found." % self.__table_filename)
exit(1)
......@@ -119,17 +136,12 @@ class Driver:
line = table_file.readline().strip()
table_file.close()
fp16_swing = 8
class ApproxTypes:
FP16 = 0
FP32 = 1
PROMISE = 2
@staticmethod
def is_promise(config_layer):
return float(config_layer.split(' ')[0]) < Driver.fp16_swing
def __quantize(self, curr_layer, prev_layer, h2f_f2h_operation_ind, layer_data):
if curr_layer == prev_layer or curr_layer == Driver.ApproxTypes.PROMISE \
or prev_layer == Driver.ApproxTypes.PROMISE: # No quantization needed
......@@ -153,6 +165,7 @@ class Driver:
print("Quantization: (%f, %f)" % (time, energy))
return (time, energy)
def __run_promise_simulation(self, swing, layer_data):
layer_name = layer_data["Name"]
patch_factor = 1
......@@ -183,6 +196,7 @@ class Driver:
print("PROMISE: (%s, %s)" % (total_time_energy[0], total_time_energy[1]))
return float(total_time_energy[0]), float(total_time_energy[1])
def __run_gpu_simulation(self, curr_layer, layer_name, tensor_ind):
tensor_info = self.__tensor_table[layer_name][tensor_ind]
if curr_layer == Driver.ApproxTypes.FP32:
......@@ -194,11 +208,8 @@ class Driver:
print("GPU: (%f, %f)" % (conversion_time, conversion_energy))
return (conversion_time, conversion_energy)
# Default dict of default dicts
results_time_key = "Time"
results_energy_key = "Energy"
def run_simulations(self):
def __run_simulations(self):
print("run sim")
if not os.path.isfile(self.__config_filename):
print("ERROR: %s was not found" % self.__config_filename)
......@@ -252,7 +263,8 @@ class Driver:
print("\n")
config_file.close()
def display_results(self):
def __display_results(self):
results_file = open(self.__results_filename, "w")
attributes_to_print = [Driver.results_time_key, Driver.results_energy_key]
......@@ -268,15 +280,18 @@ class Driver:
for config_ind in range(self.__config_count):
results_file.write("c%d" % config_ind)
time_or_energy_val = self.__aggregate_results[attribute][config_ind]
results_file.write(",%f" % time_or_energy_val)
results_file.write(",%f\n" % (baseline_val / (time_or_energy_val + 0.0001)))
# Using repr to keep all decimal digits when writing to file
results_file.write(",%s" % repr(time_or_energy_val))
results_file.write(",%s\n" % repr(baseline_val / (time_or_energy_val + 0.0001)))
if not best_result or time_or_energy_val < best_result:
best_result = time_or_energy_val
best_config = config_ind
results_file.write("\nc%d,%f\n\n" % (best_config, self.__aggregate_results[attribute][best_config]))
results_file.write("\nc%d,%s\n\n" % (best_config, repr(self.__aggregate_results[attribute][best_config])))
results_file.close()
if __name__ == "__main__":
if len(sys.argv) != 5:
print("Usage: python driver.py <layer info> <tensor info> <configurations> <results file>")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment