Skip to content
Snippets Groups Projects
Commit 0a3b3254 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Handling DepthwiseConv2D for Promise API translation

parent 292db054
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import numpy as np import numpy as np
import sys import sys
from keras import backend as K from keras import backend as K
from frontend.utils import *
class State: class State:
...@@ -599,13 +599,29 @@ class PromiseRtTranslator: ...@@ -599,13 +599,29 @@ class PromiseRtTranslator:
# FIX: ADD code for TensorAdd and ACTIVATION # FIX: ADD code for TensorAdd and ACTIVATION
# TODO: ADD code for TensorAdd and ACTIVATION # TODO: ADD code for TensorAdd and ACTIVATION
print (promise_layer_str) input_var = output_var
if nodeHasBias(conv_op):
output_var2 = self.getVariableName(conv_op)
promise_layer_str += "void* " + output_var2 + " = "
promise_layer_str += "tensorAdd(" + input_var + ", "
promise_layer_str += conv_op.layer_name + "_b"
promise_layer_str += "); \n"
# Update variable that holds input for next operation
input_var = output_var2
if nodeHasActivation(conv_op):
activation_type = conv_op.activation_type
output_var = self.getVariableName(conv_op)
promise_layer_str += genActivationCallStr(input_var, output_var, activation_type)
print (promise_layer_str)
self.program_str += promise_layer_str self.program_str += promise_layer_str
self.appendLayerString("DepthwiseConv", state) self.appendLayerString("DepthwiseConv", state)
state.clear() state.clear()
......
...@@ -17,3 +17,25 @@ def nodeHasActivation(cur_node): ...@@ -17,3 +17,25 @@ def nodeHasActivation(cur_node):
return True return True
else: else:
return False return False
def genActivationCallStr(input_var, output_var, activation_type):
func_name = ""
if activation_type == "tanh":
func_name = "Tanh"
if activation_type == "relu":
func_name = "Relu"
if activation_type == "softmax":
func_name = "Softmax"
inst_str = "void* " + output_var + " = "
inst_str += "tensor" + func_name + "(" + input_var + "); \n"
print ("***** inst_str = ", inst_str, "\n")
return inst_str
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment