From 32c01c2810f10bf6f11a65e4c82d7d031341e206 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 15 May 2018 17:12:12 +0300
Subject: [PATCH] PyTorch 0.4 improvement to SummaryGraph

PyTorch 0.4 now fully supports the ONNX export features that are needed
in order to create a SummaryGraph, which is sort of a "shadow graph" for
PyTorch models.

The big advantage of SummaryGraph is that it gives us information about
the connectivity of nodes.  With connectivity information we can compute
per-node MACs (compute) and BW (memory footprint), and better yet, we can
remove channels, filters, and layers (more on this in future commits).

In this commit we (1) replace the long, verbose ONNX node names with
PyTorch names; and (2) move the MAC, BW (footprint) and arithmetic
intensity attributes from the Jupyter notebook to the SummaryGraph.
---
 apputils/model_summaries.py | 92 +++++++++++++++++++++++++++++++++++
 jupyter/experimental.ipynb  | 95 ++++++-------------------------------
 2 files changed, 106 insertions(+), 81 deletions(-)

diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py
index a650e54..12c3b6f 100755
--- a/apputils/model_summaries.py
+++ b/apputils/model_summaries.py
@@ -21,6 +21,8 @@ of PyTorch with stable support for the JIT tracer functionality we employ in thi
 code (it was built with a 4.x master branch).
 """
 
+import re
+import numpy as np
 import torch
 import torchvision
 from torch.autograd import Variable
@@ -28,6 +30,26 @@ import torch.jit as jit
 import pandas as pd
 from tabulate import tabulate
 
+def onnx_name_2_pytorch_name(name):
+    # Convert a layer's name from an ONNX name, to a PyTorch name
+    # For example:
+    #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1 ==> layer3.0.relu1
+
+    # First see if there's an instance identifier
+    instance = ''
+    if name.find('.')>0:
+        instance = name[name.find('.')+1 :]
+
+    # Next, split by square brackets
+    name_parts = re.findall('\[.*?\]', name)
+    name_parts = [part[1:-1] for part in name_parts]
+
+    #print(op['name'] , '.'.join(name_parts), op['type'])
+    name = '.'.join(name_parts) + instance
+    #if name == '':
+    #    name += 'stam'
+    return name
+
 
 class SummaryGraph(object):
     """We use Pytorch's JIT tracer to run a forward pass and generate a trace graph, which
@@ -89,6 +111,12 @@ class SummaryGraph(object):
                 same = [layer for layer in self.ops if layer['orig-name'] == op['orig-name']]
                 if len(same) > 0:
                     op['name'] += "." + str(len(same))
+
+                op['name'] = onnx_name_2_pytorch_name(op['name'])
+                #op['name'] = '\n'.join(op['name'], op['type'])
+                op['name'] += ("\n" + op['type'])
+
+                # print(op['name'])
                 self.ops.append(op)
 
                 for input_ in node.inputs():
@@ -101,6 +129,10 @@ class SummaryGraph(object):
 
                 op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
 
+        self.add_macs_attr()
+        self.add_footprint_attr()
+        self.add_arithmetic_intensity_attr()
+
 
     def __add_input(self, op, n):
         param = self.__add_param(n)
@@ -135,6 +167,66 @@ class SummaryGraph(object):
             return None
         return tensor
 
+    def param_shape(self, param_id):
+        return self.params[param_id]['shape']
+
+    @staticmethod
+    def volume(dims):
+        return np.prod(dims)
+
+    def param_volume(self, param_id):
+        return SummaryGraph.volume(self.param_shape(param_id))
+
+    def add_macs_attr(self):
+        for op in self.ops:
+            op['attrs']['MACs'] = 0
+            if op['type'] == 'Conv':
+                conv_out = op['outputs'][0]
+                conv_in =  op['inputs'][0]
+                conv_w = op['attrs']['kernel_shape']
+                ofm_vol = self.param_volume(conv_out)
+                # MACs = volume(OFM) * (#IFM * K^2)
+                op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1]
+            elif op['type'] == 'Gemm':
+                conv_out =  op['outputs'][0]
+                conv_in =  op['inputs'][0]
+                n_ifm = self.param_shape(conv_in)[1]
+                n_ofm = self.param_shape(conv_out)[1]
+                # MACs = #IFM * #OFM
+                op['attrs']['MACs'] = n_ofm * n_ifm
+
+    def add_footprint_attr(self):
+        for op in self.ops:
+            op['attrs']['footprint'] = 0
+            if op['type'] in ['Conv', 'Gemm', 'MaxPool']:
+                conv_out = op['outputs'][0]
+                conv_in =  op['inputs'][0]
+                ofm_vol = self.param_volume(conv_out)
+                ifm_vol = self.param_volume(conv_in)
+                if op['type'] == 'Conv' or op['type'] == 'Gemm':
+                    conv_w = op['inputs'][1]
+                    weights_vol = self.param_volume(conv_w)
+                    #print(ofm_vol , ifm_vol , weights_vol)
+                    op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
+                    op['attrs']['fm_vol'] = ofm_vol + ifm_vol
+                    op['attrs']['weights_vol'] = weights_vol
+                elif op['type'] == 'MaxPool':
+                    op['attrs']['footprint'] = ofm_vol + ifm_vol
+
+    def add_arithmetic_intensity_attr(self):
+        for op in self.ops:
+            if op['attrs']['footprint'] == 0:
+                op['attrs']['ai'] = 0
+            else:
+                # whole numbers are enough: adding footprint/2 rounds to nearest (result is a float)
+                op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint']) // op['attrs']['footprint'])
+
+    def get_attr(self, attr, f = lambda op: True):
+        return [op['attrs'][attr] for op in self.ops if attr in op['attrs'] and f(op)]
+
+    def get_ops(self, attr, f = lambda op: True):
+        return [op for op in self.ops if attr in op['attrs'] and f(op)]
+
 
 def attributes_summary(sgraph, ignore_attrs):
     """Generate a summary of a graph's attributes.
diff --git a/jupyter/experimental.ipynb b/jupyter/experimental.ipynb
index d73e876..aefbdfd 100644
--- a/jupyter/experimental.ipynb
+++ b/jupyter/experimental.ipynb
@@ -73,7 +73,7 @@
    "outputs": [],
    "source": [
     "dataset = 'imagenet'\n",
-    "dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)\n",
+    "dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)\n",
     "arch = 'resnet18'\n",
     "#arch = 'alexnet'\n",
     "checkpoint_file = None \n",
@@ -145,75 +145,15 @@
    },
    "outputs": [],
    "source": [
-    "def volume(dims):\n",
-    "    vol = 1\n",
-    "    for d in range(len(dims)): vol *= dims[d]\n",
-    "    return vol\n",
     "\n",
-    "def param_shape(sgraph, param_id):\n",
-    "    return sgraph.params[param_id]['shape']\n",
-    "\n",
-    "def param_volume(sgraph, param_id):\n",
-    "    return volume(param_shape(sgraph, param_id))\n",
-    "    \n",
-    "def add_macs_attr(sgraph):                           \n",
-    "    for op in sgraph.ops:\n",
-    "        op['attrs']['MACs'] = 0\n",
-    "        if op['type'] == 'Conv':\n",
-    "            conv_out = op['outputs'][0]\n",
-    "            conv_in =  op['inputs'][0]\n",
-    "            conv_w = op['attrs']['kernel_shape']\n",
-    "            ofm_vol = param_volume(sgraph, conv_out)\n",
-    "            # MACs = volume(OFM) * (#IFM * K^2)\n",
-    "            op['attrs']['MACs'] = ofm_vol * volume(conv_w) * sgraph.params[conv_in]['shape'][1]\n",
-    "        elif op['type'] == 'Gemm':\n",
-    "            conv_out =  op['outputs'][0]\n",
-    "            conv_in =  op['inputs'][0]\n",
-    "            n_ifm = param_shape(sgraph, conv_in)[1]\n",
-    "            n_ofm = param_shape(sgraph, conv_out)[1]\n",
-    "            # MACs = #IFM * #OFM\n",
-    "            op['attrs']['MACs'] = n_ofm * n_ifm            \n",
-    "\n",
-    "def add_footprint_attr(sgraph):                           \n",
-    "    for op in sgraph.ops:\n",
-    "        op['attrs']['footprint'] = 0\n",
-    "        if op['type'] in ['Conv', 'Gemm', 'MaxPool']:\n",
-    "            conv_out = op['outputs'][0]\n",
-    "            conv_in =  op['inputs'][0]\n",
-    "            ofm_vol = param_volume(sgraph, conv_out)\n",
-    "            ifm_vol = param_volume(sgraph, conv_in)            \n",
-    "            if op['type'] == 'Conv' or op['type'] == 'Gemm':\n",
-    "                conv_w = op['inputs'][1]\n",
-    "                weights_vol = param_volume(sgraph, conv_w)\n",
-    "                #print(ofm_vol , ifm_vol , weights_vol)\n",
-    "                op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol\n",
-    "                op['attrs']['fm_vol'] = ofm_vol + ifm_vol\n",
-    "                op['attrs']['weights_vol'] = weights_vol\n",
-    "            elif op['type'] == 'MaxPool':\n",
-    "                op['attrs']['footprint'] = ofm_vol + ifm_vol\n",
-    "\n",
-    "def add_arithmetic_intensity_attr(sgraph):                           \n",
-    "    for op in sgraph.ops:\n",
-    "        if op['attrs']['footprint'] == 0:\n",
-    "            op['attrs']['ai'] = 0\n",
-    "        else:\n",
-    "            # integers are enough, and note that we also round up\n",
-    "            op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint'])  // op['attrs']['footprint']) \n",
-    "\n",
-    "def get_attr(sgraph, attr, f = lambda op: True):\n",
-    "    return [op['attrs'][attr] for op in sgraph.ops if attr in op['attrs'] and f(op)]\n",
-    "\n",
-    "def get_ops(sgraph, attr, f = lambda op: True):\n",
-    "    return [op for op in sgraph.ops if attr in op['attrs'] and f(op)]\n",
-    "        \n",
-    "add_macs_attr(g)\n",
-    "add_footprint_attr(g)\n",
-    "add_arithmetic_intensity_attr(g)\n",
+    "# add_macs_attr(g)\n",
+    "# add_footprint_attr(g)\n",
+    "# add_arithmetic_intensity_attr(g)\n",
     "ignore_attrs = ['group', 'is_test', 'consumed_inputs', 'alpha', 'beta', 'MACs', 'footprint', 'ai', 'fm_vol', 'weights_vol']\n",
     "df = attributes_summary(g, ignore_attrs)\n",
-    "df['MAC'] = get_attr(g, 'MACs')\n",
-    "df['BW'] = get_attr(g, 'footprint')\n",
-    "df['AI'] = get_attr(g, 'ai')\n",
+    "df['MAC'] = g.get_attr('MACs')\n",
+    "df['BW'] = g.get_attr('footprint')\n",
+    "df['AI'] = g.get_attr('ai')\n",
     "#df = df.assign([5]*len(df)).values\n",
     "\n",
     "qgrid.show_grid(df)"
@@ -252,8 +192,8 @@
     "\n",
     "sgraph = g\n",
     "names = [op['name'] for op in sgraph.ops]\n",
-    "setA = get_attr(g, 'fm_vol') \n",
-    "setB = get_attr(g, 'weights_vol') \n",
+    "setA = g.get_attr('fm_vol') \n",
+    "setB = g.get_attr('weights_vol') \n",
     "plot_bars(None, setA, 'Feature maps', setB, 'Weights', names, 'Weights footprint vs. feature-maps footprint\\n(Normalized)')"
    ]
   },
@@ -264,7 +204,7 @@
    "outputs": [],
    "source": [
     "names = [op['name'] for op in sgraph.ops if 'MACs' in op['attrs'] and op['attrs']['MACs']>0]\n",
-    "macs = get_attr(g, 'MACs', lambda op: op['attrs']['MACs']>0)\n",
+    "macs = g.get_attr('MACs', lambda op: op['attrs']['MACs']>0)\n",
     "\n",
     "y_pos = np.arange(len(names))\n",
     "fig, ax = plt.subplots(figsize=(20,10))\n",
@@ -284,7 +224,7 @@
     "for location in ['right', 'left', 'top', 'bottom']:\n",
     "    ax.spines[location].set_visible(False) \n",
     "\n",
-    "ops = get_ops(g, 'MACs', lambda op: op['attrs']['MACs']>0)\n",
+    "ops = g.get_ops('MACs', lambda op: op['attrs']['MACs']>0)\n",
     "for bar,op in zip(barlist, ops):\n",
     "    kernel = op['attrs'].get('kernel_shape', None)\n",
     "    if str(kernel) == '[7, 7]':\n",
@@ -303,12 +243,12 @@
    "source": [
     "ops = sgraph.ops\n",
     "positive_mac = lambda op: op['attrs']['MACs']>0\n",
-    "names = get_attr(g, 'name', positive_mac) \n",
+    "names = g.get_attr('name', positive_mac) \n",
     "\n",
-    "macs = get_attr(g, 'MACs', positive_mac)\n",
+    "macs = g.get_attr('MACs', positive_mac)\n",
     "norm_macs = [float(i)/np.sum(macs) for i in macs]\n",
     "\n",
-    "footprint = get_attr(g, 'footprint', positive_mac)\n",
+    "footprint = g.get_attr('footprint', positive_mac)\n",
     "norm_footprint = [float(i)/np.sum(footprint) for i in footprint]\n",
     "\n",
     "plot_bars(None, norm_macs, 'MACs', norm_footprint, 'footprint', names, \"MACs vs footprint\")\n",
@@ -359,13 +299,6 @@
     "\n",
     "\n"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
-- 
GitLab