Skip to content
Snippets Groups Projects
Commit 839c433a authored by Neta Zmora's avatar Neta Zmora
Browse files

Utils: added model_params_stats

This is a utility function that returns some statistics about a
model's parameters (model_sparsity, params_cnt, params_nnz_cnt).

This file is required for the previous commit (and was accidentally
left out)
parent 9cb0dd68
No related branches found
No related tags found
No related merge requests found
...@@ -341,15 +341,35 @@ def density_rows(tensor, transposed=True): ...@@ -341,15 +341,35 @@ def density_rows(tensor, transposed=True):
def model_sparsity(model, param_dims=[2, 4]): def model_sparsity(model, param_dims=[2, 4]):
params_size = 0 """Returns the model sparsity as a fraction in [0..1]"""
sparse_params_size = 0 sparsity, _, _ = model_params_stats(model, param_dims)
return sparsity
def model_params_size(model, param_dims=[2, 4]):
"""Returns the model sparsity as a fraction in [0..1]"""
_, _, sparse_params_cnt = model_params_stats(model, param_dims)
return sparse_params_cnt
def model_params_stats(model, param_dims=[2, 4]):
"""Returns the model sparsity, weights count, and the count of weights in the sparse model.
Returns:
model_sparsity - the model weights sparsity (in percent)
params_cnt - the number of weights in the entire model (incl. zeros)
params_nnz_cnt - the number of weights in the entire model, excluding zeros.
nnz stands for non-zeros.
"""
params_cnt = 0
params_nnz_cnt = 0
for name, param in model.state_dict().items(): for name, param in model.state_dict().items():
if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
_density = density(param) _density = density(param)
params_size += torch.numel(param) params_cnt += torch.numel(param)
sparse_params_size += param.numel() * _density params_nnz_cnt += param.numel() * _density
total_sparsity = (1 - sparse_params_size/params_size)*100 model_sparsity = (1 - params_nnz_cnt/params_cnt)*100
return total_sparsity return model_sparsity, params_cnt, params_nnz_cnt
def norm_filters(weights, p=1): def norm_filters(weights, p=1):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment