Make the parameter-name filter of the model statistics helpers configurable.

Summary of the change to distiller/utils.py:
  * model_sparsity, model_params_size, model_params_stats and model_numel each
    gain a trailing `param_types` argument defaulting to ['weight', 'bias'].
  * model_params_stats and model_numel now test `any(type in name for type in
    param_types)` instead of a hard-coded ['weight', 'bias'] list; the default
    equals the old literal, so callers that do not pass it are unaffected.
  * model_sparsity and model_params_size forward `param_types` to
    model_params_stats.
  * The model_params_size docstring, previously a copy of model_sparsity's,
    now describes what the function returns.

NOTE(review): the new default is a list literal (as `param_dims` already was),
so it is shared across calls; nothing visible in these hunks mutates it, but a
tuple default would make that explicit. The generator variable `type` also
shadows the builtin inside the expression - pre-existing, not introduced here.

diff --git a/distiller/utils.py b/distiller/utils.py
index 084eb56069e8a3c8bcda90e7a4732b9f8a575dd4..f00e333f03ad4feecfe0ee4659d4549873ec3dbd 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -355,19 +355,19 @@ def density_rows(tensor, transposed=True):
     return 1 - sparsity_rows(tensor, transposed)
 
 
-def model_sparsity(model, param_dims=[2, 4]):
+def model_sparsity(model, param_dims=[2, 4], param_types=['weight', 'bias']):
     """Returns the model sparsity as a fraction in [0..1]"""
-    sparsity, _, _ = model_params_stats(model, param_dims)
+    sparsity, _, _ = model_params_stats(model, param_dims, param_types)
     return sparsity
 
 
-def model_params_size(model, param_dims=[2, 4]):
-    """Returns the model sparsity as a fraction in [0..1]"""
-    _, _, sparse_params_cnt = model_params_stats(model, param_dims)
+def model_params_size(model, param_dims=[2, 4], param_types=['weight', 'bias']):
+    """Returns the size of the model parameters, w/o counting zero coefficients"""
+    _, _, sparse_params_cnt = model_params_stats(model, param_dims, param_types)
     return sparse_params_cnt
 
 
-def model_params_stats(model, param_dims=[2, 4]):
+def model_params_stats(model, param_dims=[2, 4], param_types=['weight', 'bias']):
     """Returns the model sparsity, weights count, and the count of weights in the sparse model.
 
     Returns:
@@ -379,7 +379,7 @@ def model_params_stats(model, param_dims=[2, 4]):
     params_cnt = 0
     params_nnz_cnt = 0
     for name, param in model.state_dict().items():
-        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+        if param.dim() in param_dims and any(type in name for type in param_types):
             _density = density(param)
             params_cnt += torch.numel(param)
             params_nnz_cnt += param.numel() * _density
@@ -399,12 +399,12 @@ def norm_filters(weights, p=1):
     return weights.view(weights.size(0), -1).norm(p=p, dim=1)
 
 
-def model_numel(model, param_dims=[2, 4]):
+def model_numel(model, param_dims=[2, 4], param_types=['weight', 'bias']):
     """Count the number elements in a model's parameter tensors"""
     total_numel = 0
     for name, param in model.state_dict().items():
         # Extract just the actual parameter's name, which in this context we treat as its "type"
-        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+        if param.dim() in param_dims and any(type in name for type in param_types):
             total_numel += torch.numel(param)
     return total_numel
 