Skip to content
Snippets Groups Projects
Commit 8111a941 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Fixing Softmax handling

parent a5f68c94
No related branches found
No related tags found
No related merge requests found
...@@ -351,6 +351,8 @@ class TensorRtTranslator: ...@@ -351,6 +351,8 @@ class TensorRtTranslator:
inst_str = "void* " + output_var + " = " inst_str = "void* " + output_var + " = "
inst_str += "tensor" + func_name + "(" + input_var + "); \n" inst_str += "tensor" + func_name + "(" + input_var + "); \n"
print ("***** inst_str = ", inst_str, "\n")
return inst_str return inst_str
...@@ -392,7 +394,6 @@ class TensorRtTranslator: ...@@ -392,7 +394,6 @@ class TensorRtTranslator:
else: else:
inst_str += "1); \n" inst_str += "1); \n"
self.program_str += inst_str self.program_str += inst_str
...@@ -421,21 +422,22 @@ class TensorRtTranslator: ...@@ -421,21 +422,22 @@ class TensorRtTranslator:
# NOTE: Changing output variable # NOTE: Changing output variable
out_var_name1 = out_var_name2 out_var_name1 = out_var_name2
if self.hasActivation(cur_node):
activation_type = cur_node.activation_type
out_var_name3 = self.getVariableName(cur_node)
inst_str = self.genActivationCall(out_var_name1, out_var_name3, activation_type)
self.program_str += inst_str
if layer_type == "Activation": if layer_type == "Activation":
input_var_name = self.getSingleInputName(cur_node) input_var_name = self.getSingleInputName(cur_node)
inst_str = self.genActivationCall(input_var_name, out_var_name1, cur_node.activation_type) inst_str = self.genActivationCall(input_var_name, out_var_name1, cur_node.activation_type)
self.program_str += inst_str
if self.hasActivation(cur_node) and layer_type != "Activation":
activation_type = cur_node.activation_type
out_var_name3 = self.getVariableName(cur_node)
inst_str = self.genActivationCall(out_var_name1, out_var_name3, activation_type)
self.program_str += inst_str self.program_str += inst_str
if layer_type == "BatchNormalization": if layer_type == "BatchNormalization":
input_var_name = self.getSingleInputName(cur_node) input_var_name = self.getSingleInputName(cur_node)
...@@ -482,7 +484,9 @@ class TensorRtTranslator: ...@@ -482,7 +484,9 @@ class TensorRtTranslator:
# Skip visited nodes # Skip visited nodes
if cur_node.layer_name in visited_nodes: if cur_node.layer_name in visited_nodes:
return return
print ("-visiting = ", cur_node.layer_name, "\n")
if dfg.predVisited(cur_node, visited_nodes): if dfg.predVisited(cur_node, visited_nodes):
visited_nodes[cur_node.layer_name] = True visited_nodes[cur_node.layer_name] = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment