diff --git a/distiller/utils.py b/distiller/utils.py index 073357dba2857e67b097c7405247128a9718bd3a..f803d171b70378815998515c1687bd5eaa62dbb6 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -229,6 +229,16 @@ def density_rows(tensor): return 1 - sparsity_rows(tensor) +def model_numel(model, param_dims=[2, 4]): + """Count the number elements in a model's parameter tensors""" + total_numel = 0 + for name, param in model.state_dict().items(): + # Extract just the actual parameter's name, which in this context we treat as its "type" + if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): + total_numel += torch.numel(param) + return total_numel + + def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers): """Log information about the training progress, and the distribution of the weight tensors.