From f8df3020fa7067d1cec12675f325610ef5e77fdd Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 22 Jul 2018 14:51:46 +0300
Subject: [PATCH] Utils: added a utilitiy function to count the number of
 elements in a model

---
 distiller/utils.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/distiller/utils.py b/distiller/utils.py
index 073357d..f803d17 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -229,6 +229,16 @@ def density_rows(tensor):
     return 1 - sparsity_rows(tensor)
 
 
+def model_numel(model, param_dims=[2, 4]):
+    """Count the number elements in a model's parameter tensors"""
+    total_numel = 0
+    for name, param in model.state_dict().items():
+        # Extract just the actual parameter's name, which in this context we treat as its "type"
+        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+            total_numel += torch.numel(param)
+    return total_numel
+
+
 def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers):
     """Log information about the training progress, and the distribution of the weight tensors.
 
-- 
GitLab