From 839c433ab9b389a7e113e096cba8aa274976aea4 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 7 Mar 2019 00:43:46 +0200
Subject: [PATCH] Utils: added model_params_stats

This is a utility function that returns some statistics about a
model's parameters (model_sparsity, params_cnt, params_nnz_cnt).

This file is required for the previous commit (and was accidentally
left out)
---
 distiller/utils.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/distiller/utils.py b/distiller/utils.py
index e7dcd99..e2f624a 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -341,15 +341,35 @@ def density_rows(tensor, transposed=True):
 
 
 def model_sparsity(model, param_dims=[2, 4]):
-    params_size = 0
-    sparse_params_size = 0
+    """Returns the model sparsity as a fraction in [0..1]"""
+    sparsity, _, _ = model_params_stats(model, param_dims)
+    return sparsity
+
+
+def model_params_size(model, param_dims=[2, 4]):
+    """Returns the model sparsity as a fraction in [0..1]"""
+    _, _, sparse_params_cnt = model_params_stats(model, param_dims)
+    return sparse_params_cnt
+
+
+def model_params_stats(model, param_dims=[2, 4]):
+    """Returns the model sparsity, weights count, and the count of weights in the sparse model.
+
+    Returns:
+        model_sparsity - the model weights sparsity (in percent)
+        params_cnt - the number of weights in the entire model (incl. zeros)
+        params_nnz_cnt - the number of weights in the entire model, excluding zeros.
+                         nnz stands for non-zeros.
+    """
+    params_cnt = 0
+    params_nnz_cnt = 0
     for name, param in model.state_dict().items():
         if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
             _density = density(param)
-            params_size += torch.numel(param)
-            sparse_params_size += param.numel() * _density
-    total_sparsity = (1 - sparse_params_size/params_size)*100
-    return total_sparsity
+            params_cnt += torch.numel(param)
+            params_nnz_cnt += param.numel() * _density
+    model_sparsity = (1 - params_nnz_cnt/params_cnt)*100
+    return model_sparsity, params_cnt, params_nnz_cnt
 
 
 def norm_filters(weights, p=1):
-- 
GitLab