Skip to content
Snippets Groups Projects
Commit 63ab0581 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Handling BatchNorm in PromiseBackend

parent 73c2b68a
No related branches found
No related tags found
No related merge requests found
......@@ -354,23 +354,6 @@ class TensorRtTranslator:
return False
#def genActivationCall(self, input_var, output_var, activation_type):
# func_name = ""
# if activation_type == "tanh":
# func_name = "Tanh"
# if activation_type == "relu":
# func_name = "Relu"
# if activation_type == "softmax":
# func_name = "Softmax"
# inst_str = "void* " + output_var + " = "
# inst_str += "tensor" + func_name + "(" + input_var + "); \n"
# print ("***** inst_str = ", inst_str, "\n")
# return inst_str
......@@ -460,7 +443,14 @@ class TensorRtTranslator:
input_var_name = self.getSingleInputName(cur_node)
inst_str = "void* " + out_var_name1 + " = "
inst_str += "tensorBatchNormalization(" + input_var_name + "); \n"
inst_str += "tensorBatchNorm(" + input_var_name + ", "
inst_str += cur_node.layer_name + "_gamma, "
inst_str += cur_node.layer_name + "_beta, "
inst_str += cur_node.layer_name + "_mean, "
inst_str += cur_node.layer_name + "_variance, "
inst_str += str(cur_node.epsilon)
inst_str += "); \n"
self.program_str += inst_str
......@@ -546,7 +536,6 @@ class TensorRtTranslator:
H = weights.shape[1]
W = weights.shape[0]
#unique_file_name = "conv" + str(layer_count) + ".bin"
unique_file_name = w_name + ".bin"
dumpConvWeights(prefix + unique_file_name, weights, N, C, H, W)
......@@ -574,7 +563,6 @@ class TensorRtTranslator:
self.filter_names[b_name] = 1
print (bias_weights.shape, b_name)
#unique_file_name = "conv_bias" + str(layer_count) + ".bin"
unique_file_name = b_name + ".bin"
dumpFcBias(prefix + unique_file_name, bias_weights, bias_weights.shape[0])
......@@ -599,7 +587,6 @@ class TensorRtTranslator:
H = weights.shape[0]
W = weights.shape[1]
#unique_file_name = "fc" + str(layer_count) + ".bin"
unique_file_name = w_name + ".bin"
dumpFcWeights(prefix + unique_file_name, weights, H, W)
......@@ -760,7 +747,6 @@ class TensorRtTranslator:
self.input_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + ","
self.input_str += str(H) + "," + str(W) + "); \n"
#self.weight_str += self.input_str
# Adding input to the filter map
self.filter_names["input"] = 1
......
......@@ -102,6 +102,12 @@ class State:
return True
return False
def isBatchNorm(self):
if "batchnorm" in self.op_string:
return True
return False
def isPool(self):
if "pool" in self.op_string and self.num_ops == 1:
......@@ -627,6 +633,31 @@ class PromiseRtTranslator:
state.clear()
def genBatchNormLayer(self, state):
first_op = state.getFirstOp()
last_op = state.getFirstOp()
input_var = self.getSingleInputName(first_op)
output_var = self.getVariableName(last_op)
promise_layer_str = "void* " + output_var + " = "
promise_layer_str += "tensorBatchNorm(" + input_var + ", "
promise_layer_str += first_op.layer_name + "_gamma, "
promise_layer_str += first_op.layer_name + "_beta, "
promise_layer_str += first_op.layer_name + "_mean, "
promise_layer_str += first_op.layer_name + "_variance, "
promise_layer_str += str(first_op.epsilon)
promise_layer_str += "); \n"
self.program_str += promise_layer_str
self.appendLayerString("BatchNorm", state)
state.clear()
def genSoftmaxLayer(self, state):
......@@ -744,7 +775,10 @@ class PromiseRtTranslator:
elif state.isDepthwiseConv():
self.genDepthwiseConvLayer(state)
elif state.isBatchNorm():
self.genBatchNormLayer(state)
elif state.isPool():
self.genPoolLayer(state)
......@@ -812,6 +846,26 @@ class PromiseRtTranslator:
self.traverseSuccessors(cur_node, state)
def handle_batchnorm(self, cur_node, state):
if not self.shouldVisit(cur_node):
return
layer_name = cur_node.layer_name
print ("handle_batchnorm", layer_name)
self.visited_nodes[layer_name] = True
self.genPreviousLayer(state)
state.add(cur_node, "batchnorm")
self.genBatchNormLayer(state)
self.traverseSuccessors(cur_node, state)
def handle_add(self, cur_node, state):
......@@ -907,6 +961,9 @@ class PromiseRtTranslator:
if layer_type == "DepthwiseConv2D":
self.handle_depthwise_conv(output_node, state)
if layer_type == "BatchNormalization":
self.handle_batchnorm(output_node, state)
if layer_type == "Dense":
self.handle_dense(output_node, state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment