diff --git a/distiller/utils.py b/distiller/utils.py
index 073357dba2857e67b097c7405247128a9718bd3a..f803d171b70378815998515c1687bd5eaa62dbb6 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -229,6 +229,16 @@ def density_rows(tensor):
     return 1 - sparsity_rows(tensor)
 
 
+def model_numel(model, param_dims=[2, 4]):
+    """Count the number elements in a model's parameter tensors"""
+    total_numel = 0
+    for name, param in model.state_dict().items():
+        # Extract just the actual parameter's name, which in this context we treat as its "type"
+        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+            total_numel += torch.numel(param)
+    return total_numel
+
+
 def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers):
     """Log information about the training progress, and the distribution of the weight tensors.