From 40712a3cd23566fe5df7cd23a14909b21c645fc8 Mon Sep 17 00:00:00 2001
From: shingjan <yjshi03@gmail.com>
Date: Sun, 31 May 2020 15:47:25 -0500
Subject: [PATCH] some minor changes on main

---
 hpvm/projects/onnx/frontend/graph_builder.py |  8 ++++----
 hpvm/projects/onnx/frontend/graph_ir.py      | 12 +++++++++++
 hpvm/projects/onnx/frontend/main.py          | 21 ++++++++------------
 3 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index 7ffe32c538..0659fe3fcd 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -3,11 +3,10 @@ from onnx import numpy_helper
 from tensor import InputTensor, WeightTensor
 from graph_ir import *
 
-support_onnx_ops = {"DepthwiseConv" : [2],
+support_onnx_ops = {#"DepthwiseConv" : [2],
                "Conv" : [2], # only 2d supported here
                "MatMul" : None,
                "MaxPool": [2], # only 2d supported here
-               "Activation" : None,
                "BatchNormalization" : None,
                "Flatten" : None,
                "Add" : None,
@@ -187,6 +186,8 @@ class DFG(object):
             return SoftMaxNode(onnx_node)
         elif onnx_node.op_type == "Relu":
             return ReluNode(onnx_node)
+        elif onnx_node.op_type == "Tanh":
+            return TanhNode(onnx_node)
         elif onnx_node.op_type == "BatchNormalization":
             return BatchNormalizationNode(onnx_node)
         elif onnx_node.op_type == "Pad":
@@ -196,5 +197,4 @@ class DFG(object):
         elif onnx_node.op_type == "Flatten":
             return FlattenNode(onnx_node)
         else:
-            raise ValueError("Unsupported operator type!")
-            sys.exit("Unsupported operator type!")
+            raise ValueError("Unsupported operator type: {}!".format(onnx_node.op_type))
diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py
index 2cfed0c78f..64796514bc 100644
--- a/hpvm/projects/onnx/frontend/graph_ir.py
+++ b/hpvm/projects/onnx/frontend/graph_ir.py
@@ -193,6 +193,18 @@ class ReluNode(DFGNode):
         self.inst_str += "tensorRelu(" + mapped_input_name + "); \n"
         return self.inst_str
 
+class TanhNode(DFGNode):
+    def __init__(self, layer):
+        DFGNode.__init__(self, layer)
+
+    def codegen(self, tensors):
+        cur_node = self.onnx_node
+        mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        self.inst_str += "void* " + mapped_output_name + " = "
+        self.inst_str += "tensorTanh(" + mapped_input_name + "); \n"
+        return self.inst_str
+
 
 class BatchNormalizationNode(DFGNode):
     def __init__(self, layer):
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 258e1d23a5..b9fefdecd7 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -5,9 +5,9 @@ import onnx
 import glob
 #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
 
-onnx_file_dir = "../models/keras/lenet.onnx"
+onnx_file_dir = "../models/keras/alexnet.onnx"
 src_emit_dir = "./test_src"
-opset_version_default = 11
+opset_version_default = 10
 
 def check_version(model, new_version):
     try:
@@ -33,25 +33,20 @@ def check_version(model, new_version):
     return model
 
 def compile(model):
-    # TODO: make this in constant
-    # make a cmd option, default value -> constant
     weights_dir = src_emit_dir
-    # test_data_dir = '../models/mnist/test_data_set_0'
-    # converted_model = convert_version(model)
-    # model = check_version(model, 11)
+    model = check_version(model, opset_version_default)
     from graph_builder import GraphBuilder
     from graph_codegen import GraphCodeGen
     from hpvm_codegen import HpvmCodeGen
-    gBuilder = GraphBuilder(model, None, "float32", weights_dir)
-    #gCodegen = GraphCodeGen(gBuilder.build_graph(), weights_dir)
-    hCodegen = HpvmCodeGen(gBuilder.build_graph(), weights_dir)
-    #gCodegen.compile()
-    hCodegen.compile()
+    graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
+    #graphCodeGen = GraphCodeGen(gBuilder.build_graph(), weights_dir)
+    #graphCodeGen.compile()
+    hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
+    hpvmCodeGen.compile()
 
 def main():
     # TODO: Put it in args
     model = onnx.load(onnx_file_dir)
-    # model = onnx.load('../models/keras/vgg16_cifar10.onnx')
     compile(model)
 
 if __name__ == "__main__":
-- 
GitLab