diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 908f3d667f8388d574ce1a4be52fcd8849e09773..322411469de789dfd9d0303e6a64bfa0f8d5f256 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -29,13 +29,15 @@ import torch
 from torch.autograd import Variable
 import torch.optim
 import distiller
+from .data_loggers import PythonLogger, CsvLogger
+
 msglogger = logging.getLogger()
 
 __all__ = ['model_summary', 'weights_sparsity_summary', 'weights_sparsity_tbl_summary',
            'model_performance_summary', 'model_performance_tbl_summary']
 
-from .data_loggers import PythonLogger, CsvLogger
+
 
 def model_summary(model, what, dataset=None):
     if what == 'sparsity':
         pylogger = PythonLogger(msglogger)
@@ -71,7 +73,7 @@ def model_summary(model, what, dataset=None):
         print(tabulate(nodes, headers=['Name', 'Type']))
 
 
-def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4]):
+def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
     df = pd.DataFrame(columns=['Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)',
                                'Cols (%)','Rows (%)', 'Ch (%)',
                                '2D (%)', '3D (%)',
@@ -79,8 +81,11 @@
     pd.set_option('precision', 2)
     params_size = 0
     sparse_params_size = 0
+    summary_param_types = ['weight', 'bias']
     for name, param in model.state_dict().items():
-        if (param.dim() in param_dims) and any(type in name for type in ['weight', 'bias']):
+        # Extract just the actual parameter's name, which in this context we treat as its "type"
+        curr_param_type = name.split('.')[-1]
+        if param.dim() in param_dims and curr_param_type in summary_param_types:
             _density = distiller.density(param)
             params_size += torch.numel(param)
             sparse_params_size += param.numel() * _density
@@ -115,13 +120,15 @@
         return df, total_sparsity
     return df
 
-def weights_sparsity_tbl_summary(model, return_total_sparsity=False):
-    df, total_sparsity = weights_sparsity_summary(model, return_total_sparsity=True)
+
+def weights_sparsity_tbl_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
+    df, total_sparsity = weights_sparsity_summary(model, return_total_sparsity=True, param_dims=param_dims)
     t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
     if return_total_sparsity:
         return t, total_sparsity
     return t
 
+
 # Performance data collection code follows from here down
 
 def conv_visitor(self, input, output, df, model, memo):
@@ -138,6 +145,7 @@ def conv_visitor(self, input, output, df, model, memo):
     attrs = 'k=' + '('+(', ').join(['%d' % v for v in self.kernel_size])+')'
     module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
 
+
 def fc_visitor(self, input, output, df, model, memo):
     assert isinstance(self, torch.nn.Linear)
     if self in memo:
@@ -188,6 +196,7 @@
     return df
 
+
 def model_performance_tbl_summary(model, dummy_input, batch_size):
     df = model_performance_summary(model, dummy_input, batch_size)
     t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")