From ecade1b2573fdf0bed3ebe38ddd1cb03beb6cb48 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 13 Jun 2018 10:17:20 +0300
Subject: [PATCH] ModelSummary: adapt sparsity accounting to correctly account for "weight tying"

In language models, we might use "weight tying", which means that the same
weights tensor is used in several different places. If tying is used, we'd
like to log the tensor information, but exclude it from the total sparsity
calculation.
---
 distiller/model_summaries.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index b9edddc..bc9ee91 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -96,11 +96,17 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
     pd.set_option('precision', 2)
     params_size = 0
     sparse_params_size = 0
+    # In language models, we might use "weight tying", which means that the same
+    # weights tensor is used in several different places. If tying is used, we'd like
+    # to log the tensor information, but exclude it from the total sparsity calculation.
+    seen_params = []
     for name, param in model.state_dict().items():
         if (param.dim() in param_dims) and any(type in name for type in ['weight', 'bias']):
             _density = distiller.density(param)
-            params_size += torch.numel(param)
-            sparse_params_size += param.numel() * _density
+            if name not in seen_params:
+                params_size += torch.numel(param)
+                sparse_params_size += param.numel() * _density
+                seen_params.append(name)
             df.loc[len(df.index)] = ([
                 name,
                 distiller.size_to_str(param.size()),
-- 
GitLab
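
Note (not part of the patch above): a minimal sketch, assuming a toy PyTorch module, of how
weight tying surfaces in a model's state_dict and why naively summing numel() over all entries
would double-count the shared tensor. The TiedLM class, its names, and its sizes are hypothetical
and for illustration only; they do not appear in distiller.

    import torch
    import torch.nn as nn

    class TiedLM(nn.Module):
        """Toy language model whose input and output embeddings share one weights tensor."""
        def __init__(self, vocab=10, dim=4):
            super().__init__()
            self.encoder = nn.Embedding(vocab, dim)           # weight shape: (vocab, dim)
            self.decoder = nn.Linear(dim, vocab, bias=False)  # weight shape: (vocab, dim)
            self.decoder.weight = self.encoder.weight         # weight tying

    model = TiedLM()
    sd = model.state_dict()
    # The tied tensor appears under both names, so a per-entry sum of numel()
    # would count these weights twice when computing total model size/sparsity.
    print(list(sd.keys()))
    print(sd['encoder.weight'].data_ptr() == sd['decoder.weight'].data_ptr())  # True: same storage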