diff --git a/distiller/utils.py b/distiller/utils.py
index 084eb56069e8a3c8bcda90e7a4732b9f8a575dd4..f00e333f03ad4feecfe0ee4659d4549873ec3dbd 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -355,19 +355,19 @@ def density_rows(tensor, transposed=True):
     return 1 - sparsity_rows(tensor, transposed)
 
 
-def model_sparsity(model, param_dims=[2, 4]):
+def model_sparsity(model, param_dims=[2, 4], param_types=['weight', 'bias']):
     """Returns the model sparsity as a fraction in [0..1]"""
-    sparsity, _, _ = model_params_stats(model, param_dims)
+    sparsity, _, _ = model_params_stats(model, param_dims, param_types)
     return sparsity
 
 
-def model_params_size(model, param_dims=[2, 4]):
-    """Returns the model sparsity as a fraction in [0..1]"""
-    _, _, sparse_params_cnt = model_params_stats(model, param_dims)
+def model_params_size(model, param_dims=[2, 4], param_types=['weight', 'bias']):
+    """Returns the size of the model parameters, w/o counting zero coefficients"""
+    _, _, sparse_params_cnt = model_params_stats(model, param_dims, param_types)
     return sparse_params_cnt
 
 
-def model_params_stats(model, param_dims=[2, 4]):
+def model_params_stats(model, param_dims=[2, 4], param_types=['weight', 'bias']):
     """Returns the model sparsity, weights count, and the count of weights in the sparse model.
 
     Returns:
@@ -379,7 +379,7 @@ def model_params_stats(model, param_dims=[2, 4]):
     params_cnt = 0
     params_nnz_cnt = 0
     for name, param in model.state_dict().items():
-        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+        if param.dim() in param_dims and any(type in name for type in param_types):
             _density = density(param)
             params_cnt += torch.numel(param)
             params_nnz_cnt += param.numel() * _density
@@ -399,12 +399,12 @@ def norm_filters(weights, p=1):
     return weights.view(weights.size(0), -1).norm(p=p, dim=1)
 
 
-def model_numel(model, param_dims=[2, 4]):
+def model_numel(model, param_dims=[2, 4], param_types=['weight', 'bias']):
     """Count the number elements in a model's parameter tensors"""
     total_numel = 0
     for name, param in model.state_dict().items():
         # Extract just the actual parameter's name, which in this context we treat as its "type"
-        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+        if param.dim() in param_dims and any(type in name for type in param_types):
             total_numel += torch.numel(param)
     return total_numel