From 839c433ab9b389a7e113e096cba8aa274976aea4 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 7 Mar 2019 00:43:46 +0200 Subject: [PATCH] Utils: added model_params_stats This is a utility function that returns some statistics about a model's parameters (model_sparsity, params_cnt, params_nnz_cnt). This file is required for the previous commit (and was accidentally left out) --- distiller/utils.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/distiller/utils.py b/distiller/utils.py index e7dcd99..e2f624a 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -341,15 +341,35 @@ def density_rows(tensor, transposed=True): def model_sparsity(model, param_dims=[2, 4]): - params_size = 0 - sparse_params_size = 0 + """Returns the model sparsity as a fraction in [0..1]""" + sparsity, _, _ = model_params_stats(model, param_dims) + return sparsity + + +def model_params_size(model, param_dims=[2, 4]): + """Returns the model sparsity as a fraction in [0..1]""" + _, _, sparse_params_cnt = model_params_stats(model, param_dims) + return sparse_params_cnt + + +def model_params_stats(model, param_dims=[2, 4]): + """Returns the model sparsity, weights count, and the count of weights in the sparse model. + + Returns: + model_sparsity - the model weights sparsity (in percent) + params_cnt - the number of weights in the entire model (incl. zeros) + params_nnz_cnt - the number of weights in the entire model, excluding zeros. + nnz stands for non-zeros. + """ + params_cnt = 0 + params_nnz_cnt = 0 for name, param in model.state_dict().items(): if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): _density = density(param) - params_size += torch.numel(param) - sparse_params_size += param.numel() * _density - total_sparsity = (1 - sparse_params_size/params_size)*100 - return total_sparsity + params_cnt += torch.numel(param) + params_nnz_cnt += param.numel() * _density + model_sparsity = (1 - params_nnz_cnt/params_cnt)*100 + return model_sparsity, params_cnt, params_nnz_cnt def norm_filters(weights, p=1): -- GitLab