def model_sparsity(model, param_dims=[2, 4]):
    """Return the sparsity of the model's weights as a percentage in [0..100].

    Thin wrapper around model_params_stats() that keeps only the sparsity.

    Args:
        model: object exposing state_dict(), whose items are (name, tensor).
        param_dims: only tensors whose dim() is in this collection are counted.
    """
    sparsity, _, _ = model_params_stats(model, param_dims)
    return sparsity


def model_params_size(model, param_dims=[2, 4]):
    """Return the number of non-zero weights in the model.

    Note that, despite the function name, this is the non-zero count
    (params_nnz_cnt from model_params_stats()), not the total weight count.

    Args:
        model: object exposing state_dict(), whose items are (name, tensor).
        param_dims: only tensors whose dim() is in this collection are counted.
    """
    _, _, params_nnz_cnt = model_params_stats(model, param_dims)
    return params_nnz_cnt


def model_params_stats(model, param_dims=[2, 4]):
    """Return the model sparsity, weights count, and non-zero weights count.

    Only state_dict() entries whose dim() is in param_dims and whose name
    contains 'weight' or 'bias' are taken into account.

    Args:
        model: object exposing state_dict(), whose items are (name, tensor).
        param_dims: only tensors whose dim() is in this collection are counted.

    Returns:
        A tuple (sparsity, params_cnt, params_nnz_cnt):
        sparsity - the model weights sparsity (in percent); 0.0 when no
                   tensor matches the filter.
        params_cnt - the number of weights in the entire model (incl. zeros)
        params_nnz_cnt - the number of weights in the entire model, excluding
                         zeros. nnz stands for non-zeros.
    """
    params_cnt = 0
    params_nnz_cnt = 0
    for name, param in model.state_dict().items():
        if param.dim() in param_dims and any(kind in name for kind in ('weight', 'bias')):
            # density() is a helper defined elsewhere in this module;
            # presumably the fraction of non-zero elements - TODO confirm.
            _density = density(param)
            numel = param.numel()
            params_cnt += numel
            params_nnz_cnt += numel * _density
    if params_cnt == 0:
        # Nothing matched the filter: avoid dividing by zero and report the
        # (empty) selection as not sparse.
        sparsity = 0.0
    else:
        sparsity = (1 - params_nnz_cnt / params_cnt) * 100
    return sparsity, params_cnt, params_nnz_cnt