Skip to content
Snippets Groups Projects
Commit 9bbc47d2 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

First full json file gen

parent 0b688f37
No related branches found
No related tags found
No related merge requests found
...@@ -233,8 +233,8 @@ class TensorRtTranslator: ...@@ -233,8 +233,8 @@ class TensorRtTranslator:
flops = H_d * W_d * K_d flops = H_d * W_d * K_d
DEBUG ("conv_flops = ", flops) DEBUG ("conv_flops = ", flops)
self.json_str += "convolution_" + str(self.op_count) + " : " + str(flops) + ", \n" self.json_str += "\"convolution_" + str(self.op_count) + "\" : " + str(flops) + ", \n"
self.knobs_str += "convolution_" + str(self.op_count) + " : [" + conv_knobs + "], \n" self.knobs_str += "\"convolution_" + str(self.op_count) + "\" : [" + conv_knobs + "], \n"
self.op_count += 1 self.op_count += 1
self.cur_height = self.cur_height / strides[0] self.cur_height = self.cur_height / strides[0]
...@@ -248,8 +248,8 @@ class TensorRtTranslator: ...@@ -248,8 +248,8 @@ class TensorRtTranslator:
flops = weights.shape[0] * weights.shape[1] flops = weights.shape[0] * weights.shape[1]
DEBUG ("dense_flops = ", flops) DEBUG ("dense_flops = ", flops)
self.json_str += "linear_" + str(self.op_count) + " : " + str(flops) + "\n" self.json_str += "\"linear_" + str(self.op_count) + "\" : " + str(flops) + "\n"
self.knobs_str += "linear_" + str(self.op_count) + " : [" + baseline_knobs + "], \n" self.knobs_str += "\"linear_" + str(self.op_count) + "\" : [" + baseline_knobs + "], \n"
self.op_count += 1 self.op_count += 1
self.cur_height = 1 self.cur_height = 1
...@@ -268,8 +268,8 @@ class TensorRtTranslator: ...@@ -268,8 +268,8 @@ class TensorRtTranslator:
def addBaselineKnob(self, op_name): def addBaselineKnob(self, op_name):
self.json_str += op_name + "_" + str(self.op_count) + " : 0, \n" self.json_str += "\"" + op_name + "_" + str(self.op_count) + "\" : 0, \n"
self.knobs_str += op_name + "_" + str(self.op_count) + " : [" + baseline_knobs + "], \n" self.knobs_str += "\"" + op_name + "_" + str(self.op_count) + "\" : [" + baseline_knobs + "], \n"
self.op_count += 1 self.op_count += 1
...@@ -949,8 +949,27 @@ class TensorRtTranslator: ...@@ -949,8 +949,27 @@ class TensorRtTranslator:
def dumpJsonFile(self, dir_prefix): def dumpJsonFile(self, dir_prefix):
f = open(dir_prefix + "/tuner.json", "w+") f = open(dir_prefix + "/tuner.json", "w+")
f.write(self.json_str) f.write("{ \n\n")
f.write(self.knobs_str)
op_cost_str = " \"op_cost\" = { \n"
op_cost_str += self.json_str[:-3]
#f.write(self.json_str)
op_cost_str += "\n } \n\n"
f.write(op_cost_str)
knobs_speedup_str = "\n \"knob_speedup\" : { \n"
for key in knobs_speedups:
knobs_speedup_str += "\"" + str(key) + "\" : " + str(knobs_speedups[key]) + ", \n"
f.write(knobs_speedup_str[:-3] + "\n} \n\n")
layer_knobs_str = " \"op_knobs\" = { \n"
layer_knobs_str += self.knobs_str[:-3]
layer_knobs_str += " \n\n } \n\n"
f.write(layer_knobs_str)
f.write("\n\n}")
f.close() f.close()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment