From 38b4d2635bd2a19ca32175293f72daf1376be15c Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Sun, 31 Jan 2021 07:18:30 -0600
Subject: [PATCH] Add tensor shape analysis to the graph IR

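Each DFG node now records the shapes of its inputs and of its output:

- The graph builder attaches a "shape" attribute to tensor nodes as
  well as ONNX nodes in the NetworkX graph, and _emit_node() forwards
  the predecessors' shapes and the node's own output shape into the
  emitted DFG nodes (for Conv / MatMul / Gemm, only the first 2
  arguments' shapes are kept).
- WeightTensor normalizes its initializer to a 4-D shape before
  handing it to the TensorNode constructor.
- node_to_shape() now also records the shape of the model's unique
  output, which is not part of the graph's value_info.
- codegen_hpvm reads weight shapes from the new output_shape field,
  and DFGNode.__init__() prints each node's input -> output shapes as
  the graph is built.

A rough sketch of the new node interface (the shapes and values below
are made up, for illustration only):

    from torch2hpvm.graph_ir import Conv2DNode

    conv = Conv2DNode(
        "conv_1",
        input_shapes=[[1, 3, 32, 32], [16, 3, 3, 3]],  # data, weight
        output_shape=[1, 16, 32, 32],
        pads=[1, 1, 1, 1],
        strides=[1, 1],
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
    )
    # The base constructor logs the inferred sizes:
    # conv_1: [1, 3, 32, 32] x [16, 3, 3, 3] -> [1, 16, 32, 32]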
---
 .../torch2hpvm/torch2hpvm/codegen_hpvm.py     |  2 +-
 .../torch2hpvm/torch2hpvm/graph_builder.py    | 37 +++++++----
 .../torch2hpvm/torch2hpvm/graph_ir.py         | 63 +++++++++++++------
 .../torch2hpvm/torch2hpvm/onnx_attr.py        | 16 +++--
 4 files changed, 80 insertions(+), 38 deletions(-)

diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py b/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py
index 65e73e9a9f..439c5af5d7 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py
@@ -52,7 +52,7 @@ class CodeGen:
         for weight in weights:
             name = cls.make_c_identifier(weight.name)
             file_path = f"{weight.new_name}_path.bin"
-            ret.append({"name": name, "shape": weight.shape, "filename": file_path})
+            ret.append({"name": name, "shape": weight.output_shape, "filename": file_path})
         return ret
 
     @staticmethod
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py
index f40e604d93..255cef0047 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py
@@ -6,7 +6,7 @@ import networkx as nx
 import onnx
 
 from . import graph_ir as g
-from .onnx_attr import node_attr_to_dict, node_to_shape
+from .onnx_attr import get_node_shape, node_attr_to_dict, node_to_shape
 
 PathLike = Union[str, Path]
 GraphT = onnx.GraphProto
@@ -109,10 +109,12 @@ class DFG(object):
 
         ret_graph = nx.DiGraph()
         onnx_defs, onnx_uses = def_use(graph.node)
-        tensors = extract_tensors_from_graph(graph)
         node_shape = node_to_shape(graph)
         node_and_attr = [(n, {"shape": shape}) for n, shape in node_shape.items()]
         ret_graph.add_nodes_from(node_and_attr)
+        tensors = extract_tensors_from_graph(graph)
+        tensor_and_attr = [(t, {"shape": t.output_shape}) for t in tensors.values()]
+        ret_graph.add_nodes_from(tensor_and_attr)
         for onnx_value_name, use_nodes in onnx_uses.items():
             def_node = onnx_defs.get(onnx_value_name)
             if def_node is None:
@@ -166,36 +168,47 @@ class DFG(object):
 
     @staticmethod
     def _emit_node(in_graph: nx.DiGraph, node: NodeT) -> Optional[EmitNodeT]:
+        output_shape = in_graph.nodes[node].get("shape")
         predec = sorted_inputs(in_graph, node)
+        predec_shapes = [in_graph.nodes[n].get("shape") for n in predec]
         if isinstance(node, g.DFGNode):
             # Directly add node into return graph.
             return node
-
         attrs = node_attr_to_dict(node)
+        attrs["input_shapes"] = predec_shapes
+        attrs["output_shape"] = output_shape
+
         if node.op_type == "Conv":
-            weight_tensor = predec[1]
-            assert isinstance(weight_tensor, g.WeightTensor)
-            if len(weight_tensor.shape) != 4:
-                return None  # Only supports 2D conv
+            if not isinstance(predec[1], g.WeightTensor) or len(predec_shapes[1]) != 4:
+                return None  # Only 2D conv with a constant weight operand is supported
+            # Only pass in the first 2 arguments' shapes
+            attrs["input_shapes"] = predec_shapes[:2]
             conv_node = g.Conv2DNode(node.name, **attrs)
             if len(predec) == 2:
                 return conv_node
             # Split into conv followed by an addition
-            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
+            bias_node = g.BiasAddNode(
+                f"Bias_{node.name.split('_')[-1]}", [output_shape], output_shape
+            )
             return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
         if node.op_type in ("MatMul", "Gemm"):
+            attrs["input_shapes"] = predec_shapes[:2]
             mul_node = g.MatMulNode(node.name, **attrs)
             if node.op_type == "Gemm":
                 mul_node.gemm_transpose(node, predec)
             if len(predec) == 2:
                 return mul_node
             # Split into mul followed by an addition
-            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
+            bias_node = g.BiasAddNode(
+                f"Bias_{node.name.split('_')[-1]}", [output_shape], output_shape
+            )
             return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec)
         if node.op_type == "GlobalAveragePool":
             input0_shape = in_graph.nodes[predec[0]]["shape"]
             _, _, h, w = input0_shape
-            return g.AveragePool2DNode(node.name, [1, 1], (h, w), [0, 0, 0, 0])
+            return g.AveragePool2DNode(
+                node.name, predec_shapes, output_shape, [1, 1], (h, w), [0, 0, 0, 0]
+            )
         one_to_one_nodes = {
             "MaxPool": g.MaxPool2DNode,
             "AveragePool": g.AveragePool2DNode,
@@ -303,7 +316,9 @@ def extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, g.TensorNode]:
     for input_ in onnx_graph.input:
         if input_.name in tensors:
             continue
-        tensors[input_.name] = g.InputTensor(input_, f"input_{input_cnt}")
+        tensors[input_.name] = g.InputTensor(
+            input_, get_node_shape(input_), f"input_{input_cnt}"
+        )
         input_cnt += 1
     return tensors
 
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py
index 9eb2ebf5f6..58a8e85691 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py
@@ -1,11 +1,14 @@
 import abc
 from os import PathLike
-from typing import List, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple
 
+import numpy as np
 import onnx
 
 from .onnx_attr import node_attr_to_dict
 
+ShapeT = Optional[List[int]]
+
 
 class DFGNode(abc.ABC):
     """Abstract node that represents 1 instruction in HPVM.
@@ -15,8 +18,18 @@ class DFGNode(abc.ABC):
 
     op_type = ""
 
-    def __init__(self, name: str, **kwargs):
+    def __init__(
+        self, name: str, input_shapes: Sequence[ShapeT], output_shape: ShapeT, **kwargs
+    ):
         self.name = name
+        self.input_shapes = input_shapes
+        self.output_shape = output_shape
+        sin = " x ".join(str(sh) if sh else "??" for sh in input_shapes)
+        sout = output_shape if output_shape else "??"
+        if sin:
+            print(f"{name}: {sin} -> {sout})")
+        else:
+            print(f"{name}: {sout}")
 
     def codegen(self) -> Tuple[str, list]:
         return "", []
@@ -34,10 +47,10 @@ class TensorNode(DFGNode, abc.ABC):
     This is akin to Value class in LLVM, but in a different place on the
     inheritance tree."""
 
-    def __init__(self, proto: onnx.TensorProto, new_name: str):
+    def __init__(self, proto: onnx.TensorProto, shape: ShapeT, new_name: str):
         if not proto.name.strip():
             raise ValueError("Tensor's name is required.")
-        super().__init__(proto.name)
+        super().__init__(proto.name, [], shape)
         self.new_name = new_name
 
     def __str__(self):
@@ -55,8 +68,8 @@ class InputTensor(TensorNode):
 
     op_type = "InputTensor"
 
-    def __init__(self, input_proto: onnx.TensorProto, new_name: str):
-        super().__init__(input_proto, new_name)
+    def __init__(self, input_proto: onnx.TensorProto, shape: ShapeT, new_name: str):
+        super().__init__(input_proto, shape, new_name)
         # get type of input tensor
         tensor_type = input_proto.type.tensor_type
         # check if it has a shape:
@@ -76,18 +89,17 @@ class WeightTensor(TensorNode):
     def __init__(self, weight_proto: onnx.TensorProto, new_name: str):
         from onnx import numpy_helper
 
-        super().__init__(weight_proto, new_name)
-        self.shape = []
         self.input_data = numpy_helper.to_array(weight_proto)
         sh = self.input_data.shape
         if len(sh) == 1:
-            self.shape = [1, sh[0], 1, 1]
+            shape = [1, sh[0], 1, 1]
         elif len(sh) == 2:
-            self.shape = [1, 1, sh[0], sh[1]]
+            shape = [1, 1, sh[0], sh[1]]
         elif len(sh) == 4:
-            self.shape = [sh[0], sh[1], sh[2], sh[3]]
+            shape = [sh[0], sh[1], sh[2], sh[3]]
         else:
-            self.shape = [1] * 4
+            shape = [1] * 4
+        super().__init__(weight_proto, shape, new_name)
 
     def dump_weight(self, file_name: PathLike):
         self.input_data.tofile(file_name)
@@ -96,7 +108,7 @@ class WeightTensor(TensorNode):
         if len(self.input_data.shape) != 2:
             raise ValueError("Can only transpose 2D array")
         self.input_data = self.input_data.T
-        self.shape[3], self.shape[2] = self.shape[2:]
+        self.output_shape[3], self.output_shape[2] = self.output_shape[2:]
 
 
 class Conv2DNode(DFGNode):
@@ -105,13 +117,15 @@ class Conv2DNode(DFGNode):
     def __init__(
         self,
         name: str,
+        input_shapes: Tuple[ShapeT, ShapeT],
+        output_shape: ShapeT,
         pads: Sequence[int],
         strides: Sequence[int],
         dilations: Sequence[int],
         group: int,
         kernel_shape,
     ):
-        super().__init__(name)
+        super().__init__(name, input_shapes, output_shape)
         assert len(pads) == 4, "2D convolution must have 4 padding values"
         if any(p != pads[0] for p in pads[1:]):
             raise ValueError("Convolution with different padding is unsupported")
@@ -120,18 +134,18 @@ class Conv2DNode(DFGNode):
         if group != 1:
             raise ValueError("Group > 1 is unsupported")
         self.pads = pads[0]
-        self.strides = sh, sw = strides
+        self.sh, self.sw = strides
 
     def codegen(self):
         return (
             "tensorConvolution",
-            [self.pads, self.pads, *self.strides],
+            [self.pads, self.pads, self.sh, self.sw],
         )
 
     def hpvm_codegen(self):
         return (
             "__hpvm__tensor_convolution",
-            [self.pads, self.pads, *self.strides],
+            [self.pads, self.pads, self.sh, self.sw],
         )
 
 
@@ -143,12 +157,14 @@ class _Pool2DNode(DFGNode, abc.ABC):
     def __init__(
         self,
         name: str,
+        input_shapes: Sequence[ShapeT],
+        output_shape: ShapeT,
         strides: Sequence[int],
         kernel_shape: Sequence[int],
         pads: Sequence[int],
         ceil_mode: int = 0,
     ):
-        super().__init__(name)
+        super().__init__(name, input_shapes, output_shape)
         self.strides = sh, sw = strides
         self.kernel_shape = kernel_shape
         pt, pb, pl, pr = pads
@@ -268,8 +284,15 @@ class TanhNode(DFGNode):
 class BatchNormalizationNode(DFGNode):
     op_type = "BN"
 
-    def __init__(self, name: str, epsilon: float, axis: int):
-        super().__init__(name)
+    def __init__(
+        self,
+        name: str,
+        input_shapes: Sequence[ShapeT],
+        output_shape: ShapeT,
+        epsilon: float,
+        axis: int,
+    ):
+        super().__init__(name, input_shapes, output_shape)
         self.epsilon = epsilon
 
     def codegen(self):
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/onnx_attr.py b/hpvm/projects/torch2hpvm/torch2hpvm/onnx_attr.py
index a43961b024..2186b7e0cd 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/onnx_attr.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/onnx_attr.py
@@ -83,16 +83,20 @@ def node_attr_to_dict(onnx_node: NodeProto):
     return {attr.name: parse_node_attr(attr) for attr in onnx_node.attribute}
 
 
-def node_to_shape(onnx_graph: GraphProto) -> Dict[NodeProto, Optional[List[int]]]:
-    def parse_shape(shape: TensorShapeProto) -> List[int]:
-        return [dim.dim_value for dim in shape.dim]
+def get_node_shape(node: NodeProto) -> List[int]:
+    return [dim.dim_value for dim in node.type.tensor_type.shape.dim]
+
 
+def node_to_shape(onnx_graph: GraphProto) -> Dict[NodeProto, Optional[List[int]]]:
     def unique_output_name(node: NodeProto) -> str:
         if len(node.output) != 1:
             raise ValueError(f"Node {node} has more than 1 outputs")
         return node.output[0]
 
-    out_name_to_shape = {
-        vi.name: parse_shape(vi.type.tensor_type.shape) for vi in onnx_graph.value_info
-    }
+    out_name_to_shape = {vi.name: get_node_shape(vi) for vi in onnx_graph.value_info}
+    # Add model's output shape into out_name_to_shape
+    if len(onnx_graph.output) != 1:
+        raise ValueError("Model doesn't have unique output")
+    model_output = onnx_graph.output[0]
+    out_name_to_shape[model_output.name] = get_node_shape(model_output)
     return {n: out_name_to_shape.get(unique_output_name(n)) for n in onnx_graph.node}
-- 
GitLab