diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py
index a650e54772164357816a15f33f13a842a8286874..12c3b6fb8ca0d289f6e8ca07581fb1693bf34c57 100755
--- a/apputils/model_summaries.py
+++ b/apputils/model_summaries.py
@@ -21,6 +21,8 @@ of PyTorch with stable support for the JIT tracer functionality we employ in thi
 code (it was built with a 4.x master branch).
 """
 
+import re
+import numpy as np
 import torch
 import torchvision
 from torch.autograd import Variable
@@ -28,6 +30,26 @@ import torch.jit as jit
 import pandas as pd
 from tabulate import tabulate
 
+def onnx_name_2_pytorch_name(name):
+    # Convert a layer's name from an ONNX name, to a PyTorch name
+    # For example:
+    #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1 ==> layer3.0.relu.1
+
+    # First see if there's an instance identifier
+    instance = ''
+    if name.find('.')>0:
+        instance = name[name.find('.')+1 :]
+
+    # Next, split by square brackets
+    name_parts = re.findall('\[.*?\]', name)
+    name_parts = [part[1:-1] for part in name_parts]
+
+    #print(op['name'] , '.'.join(name_parts), op['type'])
+    name = '.'.join(name_parts) + instance
+    #if name == '':
+    #    name += 'stam'
+    return name
+
 
 class SummaryGraph(object):
     """We use Pytorch's JIT tracer to run a forward pass and generate a trace graph, which
@@ -89,6 +111,12 @@ class SummaryGraph(object):
                 same = [layer for layer in self.ops if layer['orig-name'] == op['orig-name']]
                 if len(same) > 0:
                     op['name'] += "." + str(len(same))
+
+                op['name'] = onnx_name_2_pytorch_name(op['name'])
+                #op['name'] = '\n'.join(op['name'], op['type'])
+                op['name'] += ("\n" + op['type'])
+
+                # print(op['name'])
                 self.ops.append(op)
 
                 for input_ in node.inputs():
@@ -101,6 +129,10 @@ class SummaryGraph(object):
 
                 op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
 
+        self.add_macs_attr()
+        self.add_footprint_attr()
+        self.add_arithmetic_intensity_attr()
+
 
     def __add_input(self, op, n):
         param = self.__add_param(n)
@@ -135,6 +167,66 @@ class SummaryGraph(object):
             return None
         return tensor
 
+    def param_shape(self, param_id):
+        return self.params[param_id]['shape']
+
+    @staticmethod
+    def volume(dims):
+        return np.prod(dims)
+
+    def param_volume(self, param_id):
+        return SummaryGraph.volume(self.param_shape(param_id))
+
+    def add_macs_attr(self):
+        for op in self.ops:
+            op['attrs']['MACs'] = 0
+            if op['type'] == 'Conv':
+                conv_out = op['outputs'][0]
+                conv_in =  op['inputs'][0]
+                conv_w = op['attrs']['kernel_shape']
+                ofm_vol = self.param_volume(conv_out)
+                # MACs = volume(OFM) * (#IFM * K^2)
+                op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1]
+            elif op['type'] == 'Gemm':
+                conv_out =  op['outputs'][0]
+                conv_in =  op['inputs'][0]
+                n_ifm = self.param_shape(conv_in)[1]
+                n_ofm = self.param_shape(conv_out)[1]
+                # MACs = #IFM * #OFM
+                op['attrs']['MACs'] = n_ofm * n_ifm
+
+    def add_footprint_attr(self):
+        for op in self.ops:
+            op['attrs']['footprint'] = 0
+            if op['type'] in ['Conv', 'Gemm', 'MaxPool']:
+                conv_out = op['outputs'][0]
+                conv_in =  op['inputs'][0]
+                ofm_vol = self.param_volume(conv_out)
+                ifm_vol = self.param_volume(conv_in)
+                if op['type'] == 'Conv' or op['type'] == 'Gemm':
+                    conv_w = op['inputs'][1]
+                    weights_vol = self.param_volume(conv_w)
+                    #print(ofm_vol , ifm_vol , weights_vol)
+                    op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
+                    op['attrs']['fm_vol'] = ofm_vol + ifm_vol
+                    op['attrs']['weights_vol'] = weights_vol
+                elif op['type'] == 'MaxPool':
+                    op['attrs']['footprint'] = ofm_vol + ifm_vol
+
+    def add_arithmetic_intensity_attr(self):
+        for op in self.ops:
+            if op['attrs']['footprint'] == 0:
+                op['attrs']['ai'] = 0
+            else:
+                # integers are enough, and note that we also round up
+                op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint']) // op['attrs']['footprint'])
+
+    def get_attr(self, attr, f = lambda op: True):
+        return [op['attrs'][attr] for op in self.ops if attr in op['attrs'] and f(op)]
+
+    def get_ops(self, attr, f = lambda op: True):
+        return [op for op in self.ops if attr in op['attrs'] and f(op)]
+
 
 def attributes_summary(sgraph, ignore_attrs):
     """Generate a summary of a graph's attributes.
diff --git a/jupyter/experimental.ipynb b/jupyter/experimental.ipynb
index d73e8768a8beb634e090993ff78813620ae5d03b..aefbdfd9a46bd20e9ada73a89f8767b750a50888 100644
--- a/jupyter/experimental.ipynb
+++ b/jupyter/experimental.ipynb
@@ -73,7 +73,7 @@
    "outputs": [],
    "source": [
     "dataset = 'imagenet'\n",
-    "dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)\n",
+    "dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)\n",
     "arch = 'resnet18'\n",
     "#arch = 'alexnet'\n",
     "checkpoint_file = None \n",
@@ -145,75 +145,15 @@
    },
    "outputs": [],
    "source": [
-    "def volume(dims):\n",
-    "    vol = 1\n",
-    "    for d in range(len(dims)): vol *= dims[d]\n",
-    "    return vol\n",
     "\n",
-    "def param_shape(sgraph, param_id):\n",
-    "    return sgraph.params[param_id]['shape']\n",
-    "\n",
-    "def param_volume(sgraph, param_id):\n",
-    "    return volume(param_shape(sgraph, param_id))\n",
-    "    \n",
-    "def add_macs_attr(sgraph):                           \n",
-    "    for op in sgraph.ops:\n",
-    "        op['attrs']['MACs'] = 0\n",
-    "        if op['type'] == 'Conv':\n",
-    "            conv_out = op['outputs'][0]\n",
-    "            conv_in =  op['inputs'][0]\n",
-    "            conv_w = op['attrs']['kernel_shape']\n",
-    "            ofm_vol = param_volume(sgraph, conv_out)\n",
-    "            # MACs = volume(OFM) * (#IFM * K^2)\n",
-    "            op['attrs']['MACs'] = ofm_vol * volume(conv_w) * sgraph.params[conv_in]['shape'][1]\n",
-    "        elif op['type'] == 'Gemm':\n",
-    "            conv_out =  op['outputs'][0]\n",
-    "            conv_in =  op['inputs'][0]\n",
-    "            n_ifm = param_shape(sgraph, conv_in)[1]\n",
-    "            n_ofm = param_shape(sgraph, conv_out)[1]\n",
-    "            # MACs = #IFM * #OFM\n",
-    "            op['attrs']['MACs'] = n_ofm * n_ifm            \n",
-    "\n",
-    "def add_footprint_attr(sgraph):                           \n",
-    "    for op in sgraph.ops:\n",
-    "        op['attrs']['footprint'] = 0\n",
-    "        if op['type'] in ['Conv', 'Gemm', 'MaxPool']:\n",
-    "            conv_out = op['outputs'][0]\n",
-    "            conv_in =  op['inputs'][0]\n",
-    "            ofm_vol = param_volume(sgraph, conv_out)\n",
-    "            ifm_vol = param_volume(sgraph, conv_in)            \n",
-    "            if op['type'] == 'Conv' or op['type'] == 'Gemm':\n",
-    "                conv_w = op['inputs'][1]\n",
-    "                weights_vol = param_volume(sgraph, conv_w)\n",
-    "                #print(ofm_vol , ifm_vol , weights_vol)\n",
-    "                op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol\n",
-    "                op['attrs']['fm_vol'] = ofm_vol + ifm_vol\n",
-    "                op['attrs']['weights_vol'] = weights_vol\n",
-    "            elif op['type'] == 'MaxPool':\n",
-    "                op['attrs']['footprint'] = ofm_vol + ifm_vol\n",
-    "\n",
-    "def add_arithmetic_intensity_attr(sgraph):                           \n",
-    "    for op in sgraph.ops:\n",
-    "        if op['attrs']['footprint'] == 0:\n",
-    "            op['attrs']['ai'] = 0\n",
-    "        else:\n",
-    "            # integers are enough, and note that we also round up\n",
-    "            op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint'])  // op['attrs']['footprint']) \n",
-    "\n",
-    "def get_attr(sgraph, attr, f = lambda op: True):\n",
-    "    return [op['attrs'][attr] for op in sgraph.ops if attr in op['attrs'] and f(op)]\n",
-    "\n",
-    "def get_ops(sgraph, attr, f = lambda op: True):\n",
-    "    return [op for op in sgraph.ops if attr in op['attrs'] and f(op)]\n",
-    "        \n",
-    "add_macs_attr(g)\n",
-    "add_footprint_attr(g)\n",
-    "add_arithmetic_intensity_attr(g)\n",
+    "# add_macs_attr(g)\n",
+    "# add_footprint_attr(g)\n",
+    "# add_arithmetic_intensity_attr(g)\n",
     "ignore_attrs = ['group', 'is_test', 'consumed_inputs', 'alpha', 'beta', 'MACs', 'footprint', 'ai', 'fm_vol', 'weights_vol']\n",
     "df = attributes_summary(g, ignore_attrs)\n",
-    "df['MAC'] = get_attr(g, 'MACs')\n",
-    "df['BW'] = get_attr(g, 'footprint')\n",
-    "df['AI'] = get_attr(g, 'ai')\n",
+    "df['MAC'] = g.get_attr('MACs')\n",
+    "df['BW'] = g.get_attr('footprint')\n",
+    "df['AI'] = g.get_attr('ai')\n",
     "#df = df.assign([5]*len(df)).values\n",
     "\n",
     "qgrid.show_grid(df)"
@@ -252,8 +192,8 @@
     "\n",
     "sgraph = g\n",
     "names = [op['name'] for op in sgraph.ops]\n",
-    "setA = get_attr(g, 'fm_vol') \n",
-    "setB = get_attr(g, 'weights_vol') \n",
+    "setA = g.get_attr('fm_vol') \n",
+    "setB = g.get_attr('weights_vol') \n",
     "plot_bars(None, setA, 'Feature maps', setB, 'Weights', names, 'Weights footprint vs. feature-maps footprint\\n(Normalized)')"
    ]
   },
@@ -264,7 +204,7 @@
    "outputs": [],
    "source": [
     "names = [op['name'] for op in sgraph.ops if 'MACs' in op['attrs'] and op['attrs']['MACs']>0]\n",
-    "macs = get_attr(g, 'MACs', lambda op: op['attrs']['MACs']>0)\n",
+    "macs = g.get_attr('MACs', lambda op: op['attrs']['MACs']>0)\n",
     "\n",
     "y_pos = np.arange(len(names))\n",
     "fig, ax = plt.subplots(figsize=(20,10))\n",
@@ -284,7 +224,7 @@
     "for location in ['right', 'left', 'top', 'bottom']:\n",
     "    ax.spines[location].set_visible(False) \n",
     "\n",
-    "ops = get_ops(g, 'MACs', lambda op: op['attrs']['MACs']>0)\n",
+    "ops = g.get_ops('MACs', lambda op: op['attrs']['MACs']>0)\n",
     "for bar,op in zip(barlist, ops):\n",
     "    kernel = op['attrs'].get('kernel_shape', None)\n",
     "    if str(kernel) == '[7, 7]':\n",
@@ -303,12 +243,12 @@
    "source": [
     "ops = sgraph.ops\n",
     "positive_mac = lambda op: op['attrs']['MACs']>0\n",
-    "names = get_attr(g, 'name', positive_mac) \n",
+    "names = g.get_attr('name', positive_mac) \n",
     "\n",
-    "macs = get_attr(g, 'MACs', positive_mac)\n",
+    "macs = g.get_attr('MACs', positive_mac)\n",
     "norm_macs = [float(i)/np.sum(macs) for i in macs]\n",
     "\n",
-    "footprint = get_attr(g, 'footprint', positive_mac)\n",
+    "footprint = g.get_attr('footprint', positive_mac)\n",
     "norm_footprint = [float(i)/np.sum(footprint) for i in footprint]\n",
     "\n",
     "plot_bars(None, norm_macs, 'MACs', norm_footprint, 'footprint', names, \"MACs vs footprint\")\n",
@@ -359,13 +299,6 @@
     "\n",
     "\n"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {