diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index b9edddcc28ebf3e778c63d883c754c8e45536fa9..bc9ee91afa1cc9ff778d0cca27813329c39e3bb1 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -96,11 +96,19 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
     pd.set_option('precision', 2)
     params_size = 0
     sparse_params_size = 0
+    # In language models, we might use "weight tying", which means that the same
+    # weights tensor is used in several different places. If tying is used, we'd like
+    # to log the tensor information, but exclude repeats from the total sparsity
+    # calculation. Tied copies appear under different names in the state_dict, so
+    # we identify them by their underlying storage (data_ptr) rather than by name.
+    seen_params = []
     for name, param in model.state_dict().items():
         if (param.dim() in param_dims) and any(type in name for type in ['weight', 'bias']):
             _density = distiller.density(param)
-            params_size += torch.numel(param)
-            sparse_params_size += param.numel() * _density
+            if param.data_ptr() not in seen_params:
+                params_size += torch.numel(param)
+                sparse_params_size += param.numel() * _density
+                seen_params.append(param.data_ptr())
             df.loc[len(df.index)] = ([
                 name,
                 distiller.size_to_str(param.size()),
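
For context, here is a minimal, self-contained sketch of why the deduplication is needed. The TinyLM module and all of its names are invented for illustration and are not part of Distiller: with weight tying, state_dict() lists the same underlying storage under two different names, so a naive sum over its entries double-counts the tied tensor.

import torch.nn as nn

class TinyLM(nn.Module):
    """Hypothetical language model that ties its decoder weights to its embedding."""
    def __init__(self, vocab_size=10, emb_size=4):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, emb_size)
        self.decoder = nn.Linear(emb_size, vocab_size, bias=False)
        self.decoder.weight = self.encoder.weight  # weight tying

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = TinyLM()

# state_dict() yields the tied tensor twice, under different names,
# but both entries share one underlying storage (same data_ptr).
for name, param in model.state_dict().items():
    print(name, tuple(param.shape), param.data_ptr())

# Counting each storage only once, as the patch does, avoids double-counting.
seen_params, params_size = [], 0
for name, param in model.state_dict().items():
    if param.data_ptr() not in seen_params:
        params_size += param.numel()
        seen_params.append(param.data_ptr())
print(params_size)  # 40 (the 10x4 matrix counted once), not 80

Note that deduplicating by name would be a no-op, since state_dict() keys are unique by construction; storage identity (data_ptr) is what tied copies actually share, which is why the patch tracks that instead.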