From 957e6777885d3a1a4fe3edbfe42f892f5a062188 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 10 May 2018 18:25:03 +0300
Subject: [PATCH] pytorch 0.4: adjustments to API changes

Various small changes to accommodate the changes in the semantics and
syntax of the PyTorch 0.4 API.

Note that distiller.model_performance_summary() currently returns wrong
results for graphs that contain torch.nn.DataParallel layers.
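
For reference, a minimal sketch of the 0.4 idioms this patch adopts
(illustrative only, not part of the patch; `t` stands for any float
tensor):

    import torch

    t = torch.randn(3, 3)
    m = t.mean()        # 0.4: reductions return a 0-dim tensor, not a float
    s = m.item()        # .item() extracts the Python scalar

    # Variable and Tensor are merged in 0.4, and requires_grad defaults
    # to False, so a plain tensor works as a dummy input:
    dummy_input = torch.randn(1, 3, 224, 224)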
---
 distiller/__init__.py        | 12 ++++++------
 distiller/model_summaries.py | 17 +++++++++++------
 jupyter/model_summary.ipynb  |  2 +-
 3 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/distiller/__init__.py b/distiller/__init__.py
index c48eaea..c63517d 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -37,18 +37,18 @@ del thinning
 # Distiller version
 __version__ = "0.1.0"
 
-def model_find_param_name(model, tensor_to_find):
-    """Look up the name of a model tensor.
+def model_find_param_name(model, param_to_find):
+    """Look up the name of a model parameter.
 
     Arguments:
         model: the model to search
-        tensor_to_find: the tensors who's name we want to look up
+        param_to_find: the parameter whose name we want to look up
 
     Returns:
-        The parameter name (string) or None, if the paramter was not found.
+        The parameter name (string), or None if the parameter was not found.
     """
-    for name, tensor  in model.state_dict().items():
-        if tensor is tensor_to_find:
+    for name, param in model.named_parameters():
+        if param is param_to_find:
             return name
     return None
 
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 943aed5..376e91e 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -43,7 +43,7 @@ def model_summary(model, optimizer, what, dataset=None):
         distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger])
     elif what == 'compute':
         if dataset == 'imagenet':
-            dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
+            dummy_input = Variable(torch.randn(1, 3, 224, 224))
         elif dataset == 'cifar10':
             dummy_input = Variable(torch.randn(1, 3, 32, 32))
         else:
@@ -101,9 +101,9 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
                 distiller.sparsity_2D(param)*100,
                 distiller.sparsity_3D(param)*100,
                 (1-_density)*100,
-                param.std(),
-                param.mean(),
-                param.abs().mean()
+                param.std().item(),
+                param.mean().item(),
+                param.abs().mean().item()
             ])
 
     total_sparsity = (1 - sparse_params_size/params_size)*100
@@ -158,7 +158,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None
     in_features_shape = input[0].size()
     out_features_shape = output.size()
 
-    param_name = distiller.model_find_param_name(model, self.weight.data)
+    param_name = distiller.model_find_param_name(model, self.weight)
     if param_name is None:
         return
     mod_name = param_name[:param_name.find(".weight")]
@@ -168,8 +168,13 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None
                               distiller.size_to_str(out_features_shape), distiller.volume(output),
                               weights_vol, int(macs)])
 
+
 def model_performance_summary(model, dummy_input, batch_size=1):
-    """Collect performance data"""
+    """Collect performance data
+
+    warning: in PyTorch 0.4 this function does not return correct values when
+    the graph contains torch.nn.DataParallel layers.
+    """
     def install_perf_collector(m):
         if isinstance(m, torch.nn.Conv2d):
             hook_handles.append(m.register_forward_hook(
diff --git a/jupyter/model_summary.ipynb b/jupyter/model_summary.ipynb
index 179437d..c15da35 100644
--- a/jupyter/model_summary.ipynb
+++ b/jupyter/model_summary.ipynb
@@ -78,7 +78,7 @@
    "outputs": [],
    "source": [
     "dataset = 'imagenet'\n",
-    "dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)\n",
+    "dummy_input = torch.randn(1, 3, 224, 224)\n",
     "arch = 'resnet18'\n",
     "#arch = 'alexnet'\n",
     "checkpoint_file = None \n",
-- 
GitLab