Skip to content
Snippets Groups Projects
Unverified Commit e7c7d94f authored by Guy Jacob's avatar Guy Jacob Committed by GitHub
Browse files

Changes to weights sparsity summary (#20)

* More strict and explicit check for the parameter's type in
  weights_sparsity_summary
* Expose 'param_dims' in weights_sparsity_tbl_summary as well
* Some PEP8 related fixes
parent 71178e60
No related branches found
No related tags found
No related merge requests found
......@@ -29,13 +29,15 @@ import torch
from torch.autograd import Variable
import torch.optim
import distiller
from .data_loggers import PythonLogger, CsvLogger
msglogger = logging.getLogger()
__all__ = ['model_summary',
'weights_sparsity_summary', 'weights_sparsity_tbl_summary',
'model_performance_summary', 'model_performance_tbl_summary']
from .data_loggers import PythonLogger, CsvLogger
def model_summary(model, what, dataset=None):
if what == 'sparsity':
pylogger = PythonLogger(msglogger)
......@@ -71,7 +73,7 @@ def model_summary(model, what, dataset=None):
print(tabulate(nodes, headers=['Name', 'Type']))
def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4]):
def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
df = pd.DataFrame(columns=['Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)',
'Cols (%)','Rows (%)', 'Ch (%)', '2D (%)', '3D (%)',
......@@ -79,8 +81,11 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
pd.set_option('precision', 2)
params_size = 0
sparse_params_size = 0
summary_param_types = ['weight', 'bias']
for name, param in model.state_dict().items():
if (param.dim() in param_dims) and any(type in name for type in ['weight', 'bias']):
# Extract just the actual parameter's name, which in this context we treat as its "type"
curr_param_type = name.split('.')[-1]
if param.dim() in param_dims and curr_param_type in summary_param_types:
_density = distiller.density(param)
params_size += torch.numel(param)
sparse_params_size += param.numel() * _density
......@@ -115,13 +120,15 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
return df, total_sparsity
return df
def weights_sparsity_tbl_summary(model, return_total_sparsity=False):
df, total_sparsity = weights_sparsity_summary(model, return_total_sparsity=True)
def weights_sparsity_tbl_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
df, total_sparsity = weights_sparsity_summary(model, return_total_sparsity=True, param_dims=param_dims)
t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
if return_total_sparsity:
return t, total_sparsity
return t
# Performance data collection code follows from here down
def conv_visitor(self, input, output, df, model, memo):
......@@ -138,6 +145,7 @@ def conv_visitor(self, input, output, df, model, memo):
attrs = 'k=' + '('+(', ').join(['%d' % v for v in self.kernel_size])+')'
module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
def fc_visitor(self, input, output, df, model, memo):
assert isinstance(self, torch.nn.Linear)
if self in memo:
......@@ -188,6 +196,7 @@ def model_performance_summary(model, dummy_input, batch_size=1):
return df
def model_performance_tbl_summary(model, dummy_input, batch_size):
df = model_performance_summary(model, dummy_input, batch_size)
t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment