diff --git a/llvm/projects/soc_simulator/src/driver.py b/llvm/projects/soc_simulator/src/driver.py index aa20a7e7a22743399d5281889542a79c8adb4397..c747f47d998edaef603a4fc4f38fccb9a4207ea6 100644 --- a/llvm/projects/soc_simulator/src/driver.py +++ b/llvm/projects/soc_simulator/src/driver.py @@ -1,12 +1,123 @@ # Python driver -- ported from Perl driver (driver.pl) +from collections import defaultdict +import os import sys -# Read layer info +def build_nested_default_dict(): + return defaultdict(build_nested_default_dict) -# Read tensor info +tensor_layers = defaultdict(build_nested_default_dict) -# Run simulations on promise +def is_conv(operation_name): + return operation_name.startswith("Conv") + +def is_nml(operation_name): + return operation_name.startswith("NML") + +def is_fc(operation_name): + return operation_name.startswith("FC") + +def parse_tensor_layer_file(layer_filename): + ''' + Convs: Layer name, N, Cin, H, W, Cout, Kh, Kw, Sh, Sw + FCs: Layer name, Rows_A, Cols_A, Rows_B, Cols_B + NMLs (No Man Lands):Â Â NML<number>Â (edited)Â + ''' + if not os.path.isfile(layer_filename): + print("ERROR: %s was not found." % layer_filename) + exit(1) + + layer_file = open(layer_filename, "r") + for line in layer_file: + layer_data = line.strip().split(',') + layer_name = layer_data[0] + + if is_conv(layer_name): + tensor_layers[layer_name]["N"] = layer_data[1] + tensor_layers[layer_name]["Cin"] = layer_data[2] + tensor_layers[layer_name]["H"] = layer_data[3] + tensor_layers[layer_name]["W"] = layer_data[4] + tensor_layers[layer_name]["Cout"] = layer_data[5] + tensor_layers[layer_name]["Kh"] = layer_data[6] + tensor_layers[layer_name]["Kw"] = layer_data[7] + tensor_layers[layer_name]["Sh"] = layer_data[8] + tensor_layers[layer_name]["Sw"] = layer_data[9] + + elif is_fc(layer_name): + tensor_layers[layer_name]["RA"] = layer_data[1] + tensor_layers[layer_name]["CA"] = layer_data[2] + tensor_layers[layer_name]["RB"] = layer_data[3] + tensor_layers[layer_name]["CB"] = layer_data[4] + + elif not is_nml(layer_name): # TODO should we store data for NMLs? + print("ERROR: Invalid layer name %s" % layer_name) + exit(1) + + layer_file.close() + +# should this be a nested dict of dicts? +# [layer_name][operation_name][cols] +tensor_table = defaultdict(build_nested_default_dict) + +def parse_tensor_table(table_filename): + if not os.path.isfile(table_filename): + print("ERROR: %s was not found." % table_filename) + exit(1) + + table_file = open(table_filename, "r") + + line = table_file.readline().strip() + + while line: + # Line here MUST be a header or there's a bug + # Get the description of the layer + assert(line.startswith("**")) + header_contents = line.split(' ')[1:] + layer_name = header_contents[0] + num_ops = int(header_contents[1]) + col_names = header_contents[2:] + + # Go through all operations in the layer + for op_count in range(num_ops): + line = table_file.readline().strip() + op_data = line.split(' ') + op_name = op_data[0] + + # Number of data items (#s) needs to match up with the # of cols + assert(len(op_data) - 1 == len(col_names)) + + # Go through all data items (each col element) per operation + for i in range(len(col_names)): + tensor_table[layer_name][op_name][col_names[i]] = op_data[i + 1] + + line = table_file.readline().strip() + + table_file.close() + + +def run_simulations(): + # open configuration file + # open results file + # read through each line in the configuration file + # for each config file line --> parse the comma separated voltage swing levels + # recall: each line = a configuration that works + # for each level + # if promise --> promise runs an entire layer + # quantize, no patching and unpatching + # run on promise + # output the total time and energy + # else + # for each sublevel (separated by spaces) + # quantize + # run + # keep track of total time and energy --> update as needed + # output the total time and energy + +# quantization: we always have smart dma +# need to search stuff up +# $layer = a map of elements +# stores the layer name, then if __name__ == "__main__": if len(sys.argv) != 4):