Skip to content
Snippets Groups Projects
Commit 8111a941 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Fixing Softmax handling

parent a5f68c94
No related branches found
No related tags found
No related merge requests found
......@@ -351,6 +351,8 @@ class TensorRtTranslator:
inst_str = "void* " + output_var + " = "
inst_str += "tensor" + func_name + "(" + input_var + "); \n"
print ("***** inst_str = ", inst_str, "\n")
return inst_str
......@@ -392,7 +394,6 @@ class TensorRtTranslator:
else:
inst_str += "1); \n"
self.program_str += inst_str
......@@ -421,21 +422,22 @@ class TensorRtTranslator:
# NOTE: Changing output variable
out_var_name1 = out_var_name2
if self.hasActivation(cur_node):
activation_type = cur_node.activation_type
out_var_name3 = self.getVariableName(cur_node)
inst_str = self.genActivationCall(out_var_name1, out_var_name3, activation_type)
self.program_str += inst_str
if layer_type == "Activation":
input_var_name = self.getSingleInputName(cur_node)
inst_str = self.genActivationCall(input_var_name, out_var_name1, cur_node.activation_type)
self.program_str += inst_str
if self.hasActivation(cur_node) and layer_type != "Activation":
activation_type = cur_node.activation_type
out_var_name3 = self.getVariableName(cur_node)
inst_str = self.genActivationCall(out_var_name1, out_var_name3, activation_type)
self.program_str += inst_str
if layer_type == "BatchNormalization":
input_var_name = self.getSingleInputName(cur_node)
......@@ -482,7 +484,9 @@ class TensorRtTranslator:
# Skip visited nodes
if cur_node.layer_name in visited_nodes:
return
print ("-visiting = ", cur_node.layer_name, "\n")
if dfg.predVisited(cur_node, visited_nodes):
visited_nodes[cur_node.layer_name] = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment