Commit 63ab0581 authored by Hashim Sharif

Handling BatchNorm in PromiseBackend

parent 73c2b68a
@@ -354,23 +354,6 @@ class TensorRtTranslator:
     return False
 
-  #def genActivationCall(self, input_var, output_var, activation_type):
-  #  func_name = ""
-  #  if activation_type == "tanh":
-  #    func_name = "Tanh"
-  #  if activation_type == "relu":
-  #    func_name = "Relu"
-  #  if activation_type == "softmax":
-  #    func_name = "Softmax"
-  #  inst_str = "void* " + output_var + " = "
-  #  inst_str += "tensor" + func_name + "(" + input_var + "); \n"
-  #  print ("***** inst_str = ", inst_str, "\n")
-  #  return inst_str
@@ -460,7 +443,14 @@ class TensorRtTranslator:
     input_var_name = self.getSingleInputName(cur_node)
 
     inst_str = "void* " + out_var_name1 + " = "
-    inst_str += "tensorBatchNormalization(" + input_var_name + "); \n"
+    inst_str += "tensorBatchNorm(" + input_var_name + ", "
+    inst_str += cur_node.layer_name + "_gamma, "
+    inst_str += cur_node.layer_name + "_beta, "
+    inst_str += cur_node.layer_name + "_mean, "
+    inst_str += cur_node.layer_name + "_variance, "
+    inst_str += str(cur_node.epsilon)
+    inst_str += "); \n"
     self.program_str += inst_str
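Note: the emitted call now passes the per-layer BatchNorm parameters explicitly instead of only the input tensor. Assuming a Keras layer named batch_normalization_1 (hypothetical), input variable t3, output variable t4, and the Keras default epsilon of 0.001, the concatenated inst_str would read roughly:

    void* t4 = tensorBatchNorm(t3, batch_normalization_1_gamma, batch_normalization_1_beta, batch_normalization_1_mean, batch_normalization_1_variance, 0.001);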
@@ -546,7 +536,6 @@ class TensorRtTranslator:
     H = weights.shape[1]
     W = weights.shape[0]
 
-    #unique_file_name = "conv" + str(layer_count) + ".bin"
     unique_file_name = w_name + ".bin"
     dumpConvWeights(prefix + unique_file_name, weights, N, C, H, W)
@@ -574,7 +563,6 @@ class TensorRtTranslator:
     self.filter_names[b_name] = 1
     print (bias_weights.shape, b_name)
 
-    #unique_file_name = "conv_bias" + str(layer_count) + ".bin"
     unique_file_name = b_name + ".bin"
     dumpFcBias(prefix + unique_file_name, bias_weights, bias_weights.shape[0])
@@ -599,7 +587,6 @@ class TensorRtTranslator:
     H = weights.shape[0]
     W = weights.shape[1]
 
-    #unique_file_name = "fc" + str(layer_count) + ".bin"
     unique_file_name = w_name + ".bin"
     dumpFcWeights(prefix + unique_file_name, weights, H, W)
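Note: the three hunks above make the same change: the dumped weight file is named after the weight's variable name (w_name or b_name) rather than a running layer counter, presumably so the .bin files line up with the identifiers used in the generated program. A sketch with a hypothetical weight variable conv2d_1_w:

    unique_file_name = "conv2d_1_w" + ".bin"  # counter scheme would have produced e.g. "conv1.bin"
    dumpConvWeights(prefix + unique_file_name, weights, N, C, H, W)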
@@ -760,7 +747,6 @@ class TensorRtTranslator:
     self.input_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + ","
     self.input_str += str(H) + "," + str(W) + "); \n"
-    #self.weight_str += self.input_str
 
     # Adding input to the filter map
     self.filter_names["input"] = 1
@@ -102,6 +102,12 @@ class State:
       return True
     return False
 
+  def isBatchNorm(self):
+    if "batchnorm" in self.op_string:
+      return True
+    return False
+
   def isPool(self):
     if "pool" in self.op_string and self.num_ops == 1:
@@ -627,6 +633,31 @@ class PromiseRtTranslator:
     state.clear()
 
+  def genBatchNormLayer(self, state):
+    first_op = state.getFirstOp()
+    last_op = state.getLastOp()
+
+    input_var = self.getSingleInputName(first_op)
+    output_var = self.getVariableName(last_op)
+
+    promise_layer_str = "void* " + output_var + " = "
+    promise_layer_str += "tensorBatchNorm(" + input_var + ", "
+    promise_layer_str += first_op.layer_name + "_gamma, "
+    promise_layer_str += first_op.layer_name + "_beta, "
+    promise_layer_str += first_op.layer_name + "_mean, "
+    promise_layer_str += first_op.layer_name + "_variance, "
+    promise_layer_str += str(first_op.epsilon)
+    promise_layer_str += "); \n"
+
+    self.program_str += promise_layer_str
+
+    self.appendLayerString("BatchNorm", state)
+
+    state.clear()
+
   def genSoftmaxLayer(self, state):
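For context, batch normalization at inference time applies a fixed per-channel affine transform using the stored statistics. A minimal NumPy sketch of what a tensorBatchNorm runtime call is presumably expected to compute (NCHW layout assumed; this helper is illustrative, not the runtime's actual implementation):

    import numpy as np

    def batchnorm_inference(x, gamma, beta, mean, variance, epsilon=0.001):
        # x: (N, C, H, W); gamma, beta, mean, variance: length-C per-channel vectors
        g = gamma.reshape(1, -1, 1, 1)
        b = beta.reshape(1, -1, 1, 1)
        m = mean.reshape(1, -1, 1, 1)
        v = variance.reshape(1, -1, 1, 1)
        # Normalize with the stored statistics, then scale and shift
        return g * (x - m) / np.sqrt(v + epsilon) + b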
@@ -744,7 +775,10 @@ class PromiseRtTranslator:
     elif state.isDepthwiseConv():
       self.genDepthwiseConvLayer(state)
 
+    elif state.isBatchNorm():
+      self.genBatchNormLayer(state)
+
     elif state.isPool():
       self.genPoolLayer(state)
@@ -812,6 +846,26 @@ class PromiseRtTranslator:
     self.traverseSuccessors(cur_node, state)
 
+  def handle_batchnorm(self, cur_node, state):
+    if not self.shouldVisit(cur_node):
+      return
+
+    layer_name = cur_node.layer_name
+    print ("handle_batchnorm", layer_name)
+    self.visited_nodes[layer_name] = True
+
+    self.genPreviousLayer(state)
+
+    state.add(cur_node, "batchnorm")
+    self.genBatchNormLayer(state)
+
+    self.traverseSuccessors(cur_node, state)
+
   def handle_add(self, cur_node, state):
@@ -907,6 +961,9 @@ class PromiseRtTranslator:
     if layer_type == "DepthwiseConv2D":
       self.handle_depthwise_conv(output_node, state)
 
+    if layer_type == "BatchNormalization":
+      self.handle_batchnorm(output_node, state)
+
     if layer_type == "Dense":
       self.handle_dense(output_node, state)