Skip to content
Snippets Groups Projects
Commit cbf1a739 authored by Neta Zmora's avatar Neta Zmora
Browse files

distiller/utils.py: enhance model stats functions interface

Added the ability to configure the parameter types (weight/bias)
of parameters for which we collect statistics.
parent 3e296f49
No related branches found
No related tags found
No related merge requests found
......@@ -355,19 +355,19 @@ def density_rows(tensor, transposed=True):
return 1 - sparsity_rows(tensor, transposed)
def model_sparsity(model, param_dims=[2, 4]):
def model_sparsity(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Returns the model sparsity as a fraction in [0..1]"""
sparsity, _, _ = model_params_stats(model, param_dims)
sparsity, _, _ = model_params_stats(model, param_dims, param_types)
return sparsity
def model_params_size(model, param_dims=[2, 4]):
"""Returns the model sparsity as a fraction in [0..1]"""
_, _, sparse_params_cnt = model_params_stats(model, param_dims)
def model_params_size(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Returns the size of the model parameters, w/o counting zero coefficients"""
_, _, sparse_params_cnt = model_params_stats(model, param_dims, param_types)
return sparse_params_cnt
def model_params_stats(model, param_dims=[2, 4]):
def model_params_stats(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Returns the model sparsity, weights count, and the count of weights in the sparse model.
Returns:
......@@ -379,7 +379,7 @@ def model_params_stats(model, param_dims=[2, 4]):
params_cnt = 0
params_nnz_cnt = 0
for name, param in model.state_dict().items():
if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
if param.dim() in param_dims and any(type in name for type in param_types):
_density = density(param)
params_cnt += torch.numel(param)
params_nnz_cnt += param.numel() * _density
......@@ -399,12 +399,12 @@ def norm_filters(weights, p=1):
return weights.view(weights.size(0), -1).norm(p=p, dim=1)
def model_numel(model, param_dims=[2, 4]):
def model_numel(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Count the number elements in a model's parameter tensors"""
total_numel = 0
for name, param in model.state_dict().items():
# Extract just the actual parameter's name, which in this context we treat as its "type"
if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
if param.dim() in param_dims and any(type in name for type in param_types):
total_numel += torch.numel(param)
return total_numel
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment