From 63ab05818f9415f2ed29faf72e91df73b64c0cc6 Mon Sep 17 00:00:00 2001
From: Hashim Sharif <hsharif3@tyler.cs.illinois.edu>
Date: Sun, 14 Jul 2019 00:51:42 -0500
Subject: [PATCH] Handling BatchNorm in PromiseBackend

---
 .../keras/frontend/approxhpvm_translator.py   | 30 +++-------
 .../keras/frontend/promise_translator.py      | 59 ++++++++++++++++++-
 2 files changed, 66 insertions(+), 23 deletions(-)

diff --git a/llvm/projects/keras/frontend/approxhpvm_translator.py b/llvm/projects/keras/frontend/approxhpvm_translator.py
index 603431cdaa..87f53ca1be 100644
--- a/llvm/projects/keras/frontend/approxhpvm_translator.py
+++ b/llvm/projects/keras/frontend/approxhpvm_translator.py
@@ -354,23 +354,6 @@ class TensorRtTranslator:
     return False
 
 
-  #def genActivationCall(self, input_var, output_var, activation_type):
- 
-  #  func_name = ""
-  #  if activation_type == "tanh":
-  #    func_name = "Tanh"
-
-  #  if activation_type == "relu":
-  #    func_name = "Relu"
-
-  #  if activation_type == "softmax":
-  #    func_name = "Softmax"
-
-  #  inst_str = "void* " + output_var + " = "
-  #  inst_str += "tensor" + func_name + "(" + input_var + "); \n"
-  #  print ("***** inst_str = ", inst_str, "\n")
-    
-  #  return inst_str
 
   
       
@@ -460,7 +443,14 @@ class TensorRtTranslator:
       input_var_name = self.getSingleInputName(cur_node)
 
       inst_str = "void* " + out_var_name1 + " = "
-      inst_str += "tensorBatchNormalization(" + input_var_name + "); \n"
+      inst_str += "tensorBatchNorm(" + input_var_name + ", "
+      inst_str += cur_node.layer_name + "_gamma, "
+      inst_str += cur_node.layer_name + "_beta, "
+      inst_str += cur_node.layer_name + "_mean, "
+      inst_str += cur_node.layer_name + "_variance, "
+      inst_str += str(cur_node.epsilon)
+      inst_str += "); \n"
+      
       self.program_str += inst_str
       
       
@@ -546,7 +536,6 @@ class TensorRtTranslator:
         H = weights.shape[1]
         W = weights.shape[0]
 
-        #unique_file_name = "conv" + str(layer_count) + ".bin"
         unique_file_name = w_name + ".bin"
         dumpConvWeights(prefix + unique_file_name, weights, N, C, H, W)
 
@@ -574,7 +563,6 @@ class TensorRtTranslator:
           self.filter_names[b_name] = 1
           print (bias_weights.shape, b_name)
 
-          #unique_file_name = "conv_bias" + str(layer_count) + ".bin"
           unique_file_name = b_name + ".bin"
           dumpFcBias(prefix + unique_file_name, bias_weights, bias_weights.shape[0])
 
@@ -599,7 +587,6 @@ class TensorRtTranslator:
         H = weights.shape[0]
         W = weights.shape[1]
 
-        #unique_file_name = "fc" + str(layer_count) + ".bin"
         unique_file_name = w_name + ".bin"
         dumpFcWeights(prefix + unique_file_name, weights, H, W)
 
@@ -760,7 +747,6 @@ class TensorRtTranslator:
     self.input_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + ","
     self.input_str += str(H) + "," + str(W) + "); \n"
 
-    #self.weight_str += self.input_str
     
     # Adding input to the filter map
     self.filter_names["input"] = 1
diff --git a/llvm/projects/keras/frontend/promise_translator.py b/llvm/projects/keras/frontend/promise_translator.py
index f5a269f1af..670d6918f9 100644
--- a/llvm/projects/keras/frontend/promise_translator.py
+++ b/llvm/projects/keras/frontend/promise_translator.py
@@ -102,6 +102,12 @@ class State:
       return True
     return False
 
+  
+  def isBatchNorm(self):
+    if "batchnorm" in self.op_string:
+      return True
+    return False
+
 
   def isPool(self):
     if "pool" in self.op_string and self.num_ops == 1:
@@ -627,6 +633,31 @@ class PromiseRtTranslator:
     state.clear()
 
 
+
+  def genBatchNormLayer(self, state):
+
+    first_op = state.getFirstOp()
+    last_op = state.getFirstOp()
+
+    input_var = self.getSingleInputName(first_op)
+    output_var = self.getVariableName(last_op)
+
+    promise_layer_str = "void* " + output_var + " = "
+    promise_layer_str += "tensorBatchNorm(" + input_var + ", "
+    promise_layer_str += first_op.layer_name + "_gamma, "
+    promise_layer_str += first_op.layer_name + "_beta, "
+    promise_layer_str += first_op.layer_name + "_mean, "
+    promise_layer_str += first_op.layer_name + "_variance, "
+    promise_layer_str += str(first_op.epsilon)
+    promise_layer_str += "); \n"
+
+    self.program_str += promise_layer_str
+
+    self.appendLayerString("BatchNorm", state)
+
+    state.clear()
+
+    
     
 
   def genSoftmaxLayer(self, state):
@@ -744,7 +775,10 @@ class PromiseRtTranslator:
 
     elif state.isDepthwiseConv():
       self.genDepthwiseConvLayer(state)
-
+      
+    elif state.isBatchNorm():
+      self.genBatchNormLayer(state)
+      
     elif state.isPool():
       self.genPoolLayer(state)
 
@@ -812,6 +846,26 @@ class PromiseRtTranslator:
     
     self.traverseSuccessors(cur_node, state)
 
+
+
+  def handle_batchnorm(self, cur_node, state):
+    if not self.shouldVisit(cur_node):
+      return  
+
+    layer_name = cur_node.layer_name
+    print ("handle_batchnorm", layer_name)
+    self.visited_nodes[layer_name] = True
+
+    self.genPreviousLayer(state)
+
+    state.add(cur_node, "batchnorm")
+    
+    self.genBatchNormLayer(state)    
+    
+    self.traverseSuccessors(cur_node, state)
+
+
+
     
     
   def handle_add(self, cur_node, state):
@@ -907,6 +961,9 @@ class PromiseRtTranslator:
     if layer_type == "DepthwiseConv2D":
       self.handle_depthwise_conv(output_node, state)
 
+    if layer_type == "BatchNormalization":
+      self.handle_batchnorm(output_node, state)
+
     if layer_type == "Dense":
       self.handle_dense(output_node, state)
 
-- 
GitLab