From ce09ba57bb1e4fc3e0d45b3f2856d184ccfa16a5 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Thu, 3 Dec 2020 01:49:46 -0600
Subject: [PATCH] Put all graph-checking in the same place

---
 hpvm/projects/onnx/frontend/graph_builder.py | 238 +++++++------------
 hpvm/projects/onnx/frontend/main.py          |   7 +-
 hpvm/projects/onnx/frontend/utils.py         |  16 --
 3 files changed, 94 insertions(+), 167 deletions(-)

diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index fb5d97f4ac..9488c29717 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -1,188 +1,130 @@
-import sys
-from onnx import numpy_helper
+from os import PathLike
+from pathlib import Path
+
+import graph_ir as g
 from tensor import InputTensor, WeightTensor
-from graph_ir import *
-from utils import support_onnx_ops
+
 
 class GraphBuilder(object):
-    def __init__(self, model, shape, dtype, weight_dir):
+    def __init__(self, model, shape):
         self._check_model(model)
-        self._check_ops(model)
-        self.model = model
-        self.dtype = dtype
-        self.graph = model.graph
-        self.weight_dir = weight_dir
-        self.shape = shape if shape else self._build_shape()
-        self.tensors = dict()
+        self.shape = shape if shape else self._infer_shape(model.graph)
+        self.tensors = self._extract_tensors_from_graph(model.graph)
+        self.dfg = DFG(model.graph, self.tensors)
 
     ################################################
     # Aux functions for graph building
     ################################################
 
-    def _check_model(self, onnx_model):
-        try:
-            from onnx import checker, onnx_cpp2py_export
-            if hasattr(checker, 'check_model'):
-                # try use onnx's own model checker before converting any model
-                try:
-                    checker.check_model(onnx_model)
-                    print("onnx model is checked valid!")
-                except onnx_cpp2py_export.checker.ValidationError as e:
-                    import warnings
-                    warnings.warn(str(e))
-        except ImportError as e:
-            raise ImportError(
-                "Unable to import onnx.checker which is required {}".format(e))
+    @staticmethod
+    def _check_model(onnx_model):
+        import warnings
+        from onnx import checker, onnx_cpp2py_export
 
-    def _check_ops(self, model):
-        unsupport = dict()
-        for node in model.graph.node:
-            if node.op_type not in support_onnx_ops:
-                if node.op_type not in unsupport:
-                    unsupport[node.op_type] = 1
-                else:
-                    unsupport[node.op_type] += 1
-        if len(unsupport) != 0:
-            print(sorted(unsupport.items(), key=lambda x: x[1], reverse=True))
-            raise ValueError(
-                "Above operator(s) not currently supported! Compilation Aborted.")
+        if hasattr(checker, "check_model"):
+            # try to use ONNX's own model checker before converting any model
+            try:
+                checker.check_model(onnx_model)
+            except onnx_cpp2py_export.checker.ValidationError as e:
+                warnings.warn(str(e))
 
-    def _build_shape(self):
+    @staticmethod
+    def _infer_shape(onnx_graph):
         shape = {}
-        for input in self.graph.input:
+        for input in onnx_graph.input:
             # get type of input tensor
             tensor_type = input.type.tensor_type
             # check if it has a shape:
-            if (tensor_type.HasField("shape")):
+            if tensor_type.HasField("shape"):
                 shape[input.name] = tensor_type.shape
         return shape
 
-    def _parse_array(self, tensor_proto):
-        try:
-            from onnx.numpy_helper import to_array
-        except ImportError as e:
-            raise ImportError(
-                "Unable to import onnx which is required {}".format(e))
-        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
-        return np_array
-
-    def _parse_value_proto(self, value_proto):
-        """Parse ValueProto or raw str."""
-        try:
-            name = value_proto.name
-        except AttributeError:
-            name = value_proto
-        return name
-
-    def _parse_dtype(self, value_proto, dtype):
-        """Parse dtype."""
-        try:
-            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
-        except AttributeError:
-            return dtype
-
-    def _support_check(self, node):
-        op_name = node.op_type
-        #print(op_name)
-        if op_name not in support_onnx_ops:
-            return False
-        else:
-            if support_onnx_ops[op_name] == None:
-                return True
-            else:
-                #print(type(node.attribute))
-                for attr in node.attribute:
-                    # partially evaluate the kernel shape
-                    if attr.name == 'kernel_shape':
-                        # TODO: not assume all kernel shape is in INTS
-                        return len(attr.ints) in support_onnx_ops[op_name]
-                return False
-
-    ################################################
-    # Top level Graph Building functions
-    # return the compilation-ready graph
-    ################################################
-
-    def build_graph(self):
+    @staticmethod
+    def _extract_tensors_from_graph(onnx_graph):
+        tensors = {}
         # parse weight
         weight_cnt = 0
-        for weight_tensor in self.graph.initializer:
-            self.tensors[weight_tensor.name] = WeightTensor(weight_tensor)
-            self.tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt))
+        for weight_tensor in onnx_graph.initializer:
+            tensors[weight_tensor.name] = WeightTensor(weight_tensor)
+            tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt))
             weight_cnt += 1
         # parse input
         input_cnt = 0
-        for i in self.graph.input:
-            if i.name not in self.tensors:
-                self.tensors[i.name] = InputTensor(i.name)
-                self.tensors[i.name].set_mapped_name("input_" + str(input_cnt))
+        for i in onnx_graph.input:
+            if i.name not in tensors:
+                tensors[i.name] = InputTensor(i.name)
+                tensors[i.name].set_mapped_name("input_" + str(input_cnt))
                 input_cnt += 1
         # parse intermediate tensor
-        for node in self.graph.node:
-            op_name = node.op_type
-            #print("###############################")
-            if not self._support_check(node):
-                raise ValueError(
-                        "Operator not currently supported: `{0}`!".format(op_name))
-            #print("attribute: " + str(node.attribute))
-            #print("input: " + str(node.input))
-            #print("output: " + str(node.output))
-            #print("###############################")
+        for node in onnx_graph.node:
             for i in node.input:
-                if i not in self.tensors:
+                if i not in tensors:
                     raise ValueError(
-                        "Compilation Interrupted for missing input!`{0}`.".format(i))
+                        f"Compilation Interrupted for missing input!`{i}`."
+                    )
             for i in node.output:
-                if i not in self.tensors:
-                    self.tensors[i] = InputTensor(i)
-        # Dump weights
+                if i not in tensors:
+                    tensors[i] = InputTensor(i)
+        return tensors
+
+    ################################################
+    # Weight dumping: write all WeightTensors
+    # in this graph out to the given directory
+    ################################################
+
+    def dump_weights(self, output_dir: PathLike):
+        output_dir = Path(output_dir)
         for tensor in self.tensors.values():
-            if isinstance(tensor, WeightTensor):
-                print("Dump weight: {0}".format(tensor.name))
-                tensor.dump_weight(self.weight_dir + "/" + tensor.get_mapped_name() + "_path.bin")
-        return DFG(self.graph, self.tensors)
+            if not isinstance(tensor, WeightTensor):
+                continue
+            tensor.dump_weight(output_dir / (tensor.get_mapped_name() + "_path.bin"))
+
 
 class DFG(object):
     def __init__(self, graph, tensors):
-        self.graph = graph
         self.tensors = tensors
-        self.nodes = list()
-        self.build_dfg()
+        self.nodes = self.build_dfg(graph)
 
-    def build_dfg(self):
-        print("\n\n ****** Traversing and Printing DFG ******* \n\n")
-        for node in self.graph.node:
-            self.nodes.extend(self.emit_node(node))
+    def build_dfg(self, graph):
+        error_nodes, generated_nodes = [], []
+        for onnx_node in graph.node:
+            our_node = self.emit_node(onnx_node)
+            if our_node is None:
+                error_nodes.append(onnx_node)
+            else:
+                generated_nodes.extend(our_node)
+        if error_nodes:
+            error_repr = [f"{n.name}({n.op_type})" for n in error_nodes]
+            if len(error_nodes) > 10:  # cap at 10 to keep the error message readable
+                raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
+            else:
+                raise ValueError(f"Unsupported operators: {error_repr}")
+        return generated_nodes
 
     # This should be the place where partial evaluation happens
     def emit_node(self, onnx_node):
         if onnx_node.op_type == "Conv":
+            weight_tensor = self.tensors[onnx_node.input[1]]
+            assert isinstance(weight_tensor, WeightTensor)
+            if len(weight_tensor.shape) != 4:
+                return None  # Only supports 2D conv
             if len(onnx_node.input) == 2:
-                return [Conv2DNode(onnx_node)]
+                return [g.Conv2DNode(onnx_node)]
             else:
-                return [Conv2DNode(onnx_node), BiasAddNode(onnx_node)]
-        elif onnx_node.op_type == "MaxPool":
-            return [MaxPool2DNode(onnx_node)]
-        elif onnx_node.op_type == "AveragePool":
-            return [AveragePool2DNode(onnx_node)]
-        elif onnx_node.op_type == "MatMul":
-            return [MatMulNode(onnx_node)]
-        elif onnx_node.op_type == "Add":
-            return [AddNode(onnx_node)]
-        elif onnx_node.op_type == "Softmax":
-            return [SoftMaxNode(onnx_node)]
-        elif onnx_node.op_type == "Relu":
-            return [ReluNode(onnx_node)]
-        elif onnx_node.op_type == "Tanh":
-            return [TanhNode(onnx_node)]
-        elif onnx_node.op_type == "BatchNormalization":
-            return [BatchNormalizationNode(onnx_node)]
-        elif onnx_node.op_type == "Pad":
-            return [PadNode(onnx_node)]
-        elif onnx_node.op_type == "Identity":
-            return [IdentityNode(onnx_node)]
-        elif onnx_node.op_type == "Flatten":
-            return [FlattenNode(onnx_node)]
-        else:
-            raise ValueError("Unsupported operator type: {}!".format(onnx_node.op_type))
+                return [g.Conv2DNode(onnx_node), g.BiasAddNode(onnx_node)]
+        one_to_one_nodes = {
+            "MaxPool": g.MaxPool2DNode,
+            "AveragePool": g.AveragePool2DNode,
+            "MatMul": g.MatMulNode,
+            "Add": g.AddNode,
+            "Softmax": g.SoftMaxNode,
+            "Relu": g.ReluNode,
+            "Tanh": g.TanhNode,
+            "BatchNormalization": g.BatchNormalizationNode,
+            "Pad": g.PadNode,
+            "Identity": g.IdentityNode,
+            "Flatten": g.FlattenNode,
+        }
+        if onnx_node.op_type in one_to_one_nodes:
+            return [one_to_one_nodes[onnx_node.op_type](onnx_node)]
+        return None
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 683ca3bdae..7e3809d0eb 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -35,13 +35,14 @@ def compile(
 
     if opset_version is not None:
         model = check_version(model, opset_version)
-    graphBuilder = GraphBuilder(model, None, "float32", output_dir)
+    graphBuilder = GraphBuilder(model, None)
     if hpvmc:
-        hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), output_dir)
+        hpvmCodeGen = HpvmCodeGen(graphBuilder.dfg, output_dir)
         hpvmCodeGen.compile()
     else:
-        graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), output_dir, input_size)
+        graphCodeGen = GraphCodeGen(graphBuilder.dfg, output_dir, input_size)
         graphCodeGen.compile()
+    graphBuilder.dump_weights(output_dir)
 
 
 def parse_args():
diff --git a/hpvm/projects/onnx/frontend/utils.py b/hpvm/projects/onnx/frontend/utils.py
index 2a43557958..891f40fae7 100644
--- a/hpvm/projects/onnx/frontend/utils.py
+++ b/hpvm/projects/onnx/frontend/utils.py
@@ -1,20 +1,4 @@
 import numpy as np
-import struct
-import random
-
-support_onnx_ops = {#"DepthwiseConv" : [2],
-               "Conv" : [2], # only 2d supported here
-               "MatMul" : None,
-               "MaxPool": [2], # only 2d supported here
-               "BatchNormalization" : None,
-               "Flatten" : None,
-               "Add" : None,
-               "Relu" : None,
-               "Softmax" : None,
-               "Identity": None,
-               "Pad": None,
-               "AveragePool": None,
-               "Tanh": None}
 
 skip_layer = ["Identity", "Flatten", "Pad"]
 
-- 
GitLab