Skip to content
Snippets Groups Projects
Unverified Commit e7c7d94f authored by Guy Jacob's avatar Guy Jacob Committed by GitHub
Browse files

Changes to weights sparsity summary (#20)

* More strict and explicit check for the parameter's type in
  weights_sparsity_summary
* Expose 'param_dims' in weights_sparsity_tbl_summary as well
* Some PEP8 related fixes
parent 71178e60
No related branches found
No related tags found
No related merge requests found
...@@ -29,13 +29,15 @@ import torch ...@@ -29,13 +29,15 @@ import torch
from torch.autograd import Variable from torch.autograd import Variable
import torch.optim import torch.optim
import distiller import distiller
from .data_loggers import PythonLogger, CsvLogger
msglogger = logging.getLogger() msglogger = logging.getLogger()
__all__ = ['model_summary', __all__ = ['model_summary',
'weights_sparsity_summary', 'weights_sparsity_tbl_summary', 'weights_sparsity_summary', 'weights_sparsity_tbl_summary',
'model_performance_summary', 'model_performance_tbl_summary'] 'model_performance_summary', 'model_performance_tbl_summary']
from .data_loggers import PythonLogger, CsvLogger
def model_summary(model, what, dataset=None): def model_summary(model, what, dataset=None):
if what == 'sparsity': if what == 'sparsity':
pylogger = PythonLogger(msglogger) pylogger = PythonLogger(msglogger)
...@@ -71,7 +73,7 @@ def model_summary(model, what, dataset=None): ...@@ -71,7 +73,7 @@ def model_summary(model, what, dataset=None):
print(tabulate(nodes, headers=['Name', 'Type'])) print(tabulate(nodes, headers=['Name', 'Type']))
def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4]): def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
df = pd.DataFrame(columns=['Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)', df = pd.DataFrame(columns=['Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)',
'Cols (%)','Rows (%)', 'Ch (%)', '2D (%)', '3D (%)', 'Cols (%)','Rows (%)', 'Ch (%)', '2D (%)', '3D (%)',
...@@ -79,8 +81,11 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4 ...@@ -79,8 +81,11 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
pd.set_option('precision', 2) pd.set_option('precision', 2)
params_size = 0 params_size = 0
sparse_params_size = 0 sparse_params_size = 0
summary_param_types = ['weight', 'bias']
for name, param in model.state_dict().items(): for name, param in model.state_dict().items():
if (param.dim() in param_dims) and any(type in name for type in ['weight', 'bias']): # Extract just the actual parameter's name, which in this context we treat as its "type"
curr_param_type = name.split('.')[-1]
if param.dim() in param_dims and curr_param_type in summary_param_types:
_density = distiller.density(param) _density = distiller.density(param)
params_size += torch.numel(param) params_size += torch.numel(param)
sparse_params_size += param.numel() * _density sparse_params_size += param.numel() * _density
...@@ -115,13 +120,15 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4 ...@@ -115,13 +120,15 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
return df, total_sparsity return df, total_sparsity
return df return df
def weights_sparsity_tbl_summary(model, return_total_sparsity=False):
df, total_sparsity = weights_sparsity_summary(model, return_total_sparsity=True) def weights_sparsity_tbl_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
df, total_sparsity = weights_sparsity_summary(model, return_total_sparsity=True, param_dims=param_dims)
t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f") t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
if return_total_sparsity: if return_total_sparsity:
return t, total_sparsity return t, total_sparsity
return t return t
# Performance data collection code follows from here down # Performance data collection code follows from here down
def conv_visitor(self, input, output, df, model, memo): def conv_visitor(self, input, output, df, model, memo):
...@@ -138,6 +145,7 @@ def conv_visitor(self, input, output, df, model, memo): ...@@ -138,6 +145,7 @@ def conv_visitor(self, input, output, df, model, memo):
attrs = 'k=' + '('+(', ').join(['%d' % v for v in self.kernel_size])+')' attrs = 'k=' + '('+(', ').join(['%d' % v for v in self.kernel_size])+')'
module_visitor(self, input, output, df, model, weights_vol, macs, attrs) module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
def fc_visitor(self, input, output, df, model, memo): def fc_visitor(self, input, output, df, model, memo):
assert isinstance(self, torch.nn.Linear) assert isinstance(self, torch.nn.Linear)
if self in memo: if self in memo:
...@@ -188,6 +196,7 @@ def model_performance_summary(model, dummy_input, batch_size=1): ...@@ -188,6 +196,7 @@ def model_performance_summary(model, dummy_input, batch_size=1):
return df return df
def model_performance_tbl_summary(model, dummy_input, batch_size): def model_performance_tbl_summary(model, dummy_input, batch_size):
df = model_performance_summary(model, dummy_input, batch_size) df = model_performance_summary(model, dummy_input, batch_size)
t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f") t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment