From f8df3020fa7067d1cec12675f325610ef5e77fdd Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sun, 22 Jul 2018 14:51:46 +0300 Subject: [PATCH] Utils: added a utilitiy function to count the number of elements in a model --- distiller/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/distiller/utils.py b/distiller/utils.py index 073357d..f803d17 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -229,6 +229,16 @@ def density_rows(tensor): return 1 - sparsity_rows(tensor) +def model_numel(model, param_dims=[2, 4]): + """Count the number elements in a model's parameter tensors""" + total_numel = 0 + for name, param in model.state_dict().items(): + # Extract just the actual parameter's name, which in this context we treat as its "type" + if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): + total_numel += torch.numel(param) + return total_numel + + def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers): """Log information about the training progress, and the distribution of the weight tensors. -- GitLab