Skip to content
Snippets Groups Projects
Commit cbf1a739 authored by Neta Zmora's avatar Neta Zmora
Browse files

distiller/utils.py: enhance model stats functions interface

Added the ability to configure the parameter types (weight/bias)
of parameters for which we collect statistics.
parent 3e296f49
No related branches found
No related tags found
No related merge requests found
...@@ -355,19 +355,19 @@ def density_rows(tensor, transposed=True): ...@@ -355,19 +355,19 @@ def density_rows(tensor, transposed=True):
return 1 - sparsity_rows(tensor, transposed) return 1 - sparsity_rows(tensor, transposed)
def model_sparsity(model, param_dims=[2, 4]): def model_sparsity(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Returns the model sparsity as a fraction in [0..1]""" """Returns the model sparsity as a fraction in [0..1]"""
sparsity, _, _ = model_params_stats(model, param_dims) sparsity, _, _ = model_params_stats(model, param_dims, param_types)
return sparsity return sparsity
def model_params_size(model, param_dims=[2, 4]): def model_params_size(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Returns the model sparsity as a fraction in [0..1]""" """Returns the size of the model parameters, w/o counting zero coefficients"""
_, _, sparse_params_cnt = model_params_stats(model, param_dims) _, _, sparse_params_cnt = model_params_stats(model, param_dims, param_types)
return sparse_params_cnt return sparse_params_cnt
def model_params_stats(model, param_dims=[2, 4]): def model_params_stats(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Returns the model sparsity, weights count, and the count of weights in the sparse model. """Returns the model sparsity, weights count, and the count of weights in the sparse model.
Returns: Returns:
...@@ -379,7 +379,7 @@ def model_params_stats(model, param_dims=[2, 4]): ...@@ -379,7 +379,7 @@ def model_params_stats(model, param_dims=[2, 4]):
params_cnt = 0 params_cnt = 0
params_nnz_cnt = 0 params_nnz_cnt = 0
for name, param in model.state_dict().items(): for name, param in model.state_dict().items():
if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): if param.dim() in param_dims and any(type in name for type in param_types):
_density = density(param) _density = density(param)
params_cnt += torch.numel(param) params_cnt += torch.numel(param)
params_nnz_cnt += param.numel() * _density params_nnz_cnt += param.numel() * _density
...@@ -399,12 +399,12 @@ def norm_filters(weights, p=1): ...@@ -399,12 +399,12 @@ def norm_filters(weights, p=1):
return weights.view(weights.size(0), -1).norm(p=p, dim=1) return weights.view(weights.size(0), -1).norm(p=p, dim=1)
def model_numel(model, param_dims=[2, 4]): def model_numel(model, param_dims=[2, 4], param_types=['weight', 'bias']):
"""Count the number elements in a model's parameter tensors""" """Count the number elements in a model's parameter tensors"""
total_numel = 0 total_numel = 0
for name, param in model.state_dict().items(): for name, param in model.state_dict().items():
# Extract just the actual parameter's name, which in this context we treat as its "type" # Extract just the actual parameter's name, which in this context we treat as its "type"
if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): if param.dim() in param_dims and any(type in name for type in param_types):
total_numel += torch.numel(param) total_numel += torch.numel(param)
return total_numel return total_numel
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment