Commit 2b6b7251 authored by Guy Jacob, committed by GitHub

Add functionality to log values of buffers in a model (#220)

* In all logger types (PythonLogger, TensorBoardLogger, CsvLogger)
* Exact behavior varies per logger type and is documented in the code.
* To enable this in CsvLogger, changed its API to take a file name prefix
  (optionally empty) instead of the full file name, and to use a hard-coded
  name for logging weights sparsity.
* Also fixed the signature of log_training_progress in the base DataLogger
  class to match the signature used in the sub-classes.
parent 1a8c6bb8
distiller/data_loggers/logger.py

@@ -32,6 +32,8 @@ from distiller.utils import density, sparsity, sparsity_2D, size_to_str, to_np,
 from .tbbackend import TBBackend
 import csv
 import logging
+from contextlib import ExitStack
+import os

 msglogger = logging.getLogger()

 __all__ = ['PythonLogger', 'TensorBoardLogger', 'CsvLogger']
@@ -47,18 +49,21 @@ class DataLogger(object):
     def __init__(self):
         pass

-    def log_training_progress(self, model, epoch, i, set_size, batch_time, data_time, classerr, losses, print_freq, collectors):
-        raise NotImplementedError
+    def log_training_progress(self, stats_dict, epoch, completed, total, freq):
+        pass

     def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch):
-        raise NotImplementedError
+        pass

     def log_weights_sparsity(self, model, epoch):
-        raise NotImplementedError
+        pass

     def log_weights_distribution(self, named_params, steps_completed):
         pass

+    def log_model_buffers(self, model, buffer_names, tag_prefix, epoch, completed, total, freq):
+        pass
+

 class PythonLogger(DataLogger):
     def __init__(self, logger):
@@ -90,6 +95,32 @@ class PythonLogger(DataLogger):
             msglogger.info("\nParameters:\n" + str(t))
             msglogger.info('Total sparsity: {:0.2f}\n'.format(total))

+    def log_model_buffers(self, model, buffer_names, tag_prefix, epoch, completed, total, freq):
+        """Logs values of model buffers.
+
+        Notes:
+            1. Each buffer provided in 'buffer_names' is displayed in a separate table.
+            2. Within each table, each value is displayed in a separate column.
+        """
+        datas = {name: [] for name in buffer_names}
+        maxlens = {name: 0 for name in buffer_names}
+        for n, m in model.named_modules():
+            for buffer_name in buffer_names:
+                try:
+                    p = getattr(m, buffer_name)
+                except AttributeError:
+                    continue
+                data = datas[buffer_name]
+                values = p if isinstance(p, (list, torch.nn.ParameterList)) else p.view(-1).tolist()
+                data.append([distiller.normalize_module_name(n) + '.' + buffer_name, *values])
+                maxlens[buffer_name] = max(maxlens[buffer_name], len(values))
+
+        for name in buffer_names:
+            if datas[name]:
+                headers = ['Layer'] + ['Val_' + str(i) for i in range(maxlens[name])]
+                t = tabulate.tabulate(datas[name], headers=headers, tablefmt='psql', floatfmt='.4f')
+                msglogger.info('\n' + name.upper() + ': (Epoch {0}, Step {1})\n'.format(epoch, completed) + t)
+

 class TensorBoardLogger(DataLogger):
     def __init__(self, logdir):
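For orientation, here is a minimal sketch of driving the PythonLogger method added above. The ScaledLinear module and its 'scale' buffer are invented for illustration, and the import path assumes the distiller package exports at the time of this commit:

    import logging
    import torch
    import torch.nn as nn
    from distiller.data_loggers import PythonLogger

    class ScaledLinear(nn.Linear):
        def __init__(self, in_features, out_features):
            super(ScaledLinear, self).__init__(in_features, out_features)
            # Non-parameter state we want to track during training
            self.register_buffer('scale', torch.ones(3))

    model = nn.Sequential(ScaledLinear(10, 10), ScaledLinear(10, 10))
    logging.basicConfig(level=logging.INFO)
    pylogger = PythonLogger(logging.getLogger())

    # Prints one table named SCALE: a row per layer, a 'Val_i' column per value
    pylogger.log_model_buffers(model, buffer_names=['scale'], tag_prefix='',
                               epoch=0, completed=100, total=500, freq=10)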
@@ -158,14 +189,55 @@ class TensorBoardLogger(DataLogger):
             self.tblogger.histogram_summary(tag+'/grad', to_np(value.grad), steps_completed)
         self.tblogger.sync_to_file()
+    def log_model_buffers(self, model, buffer_names, tag_prefix, epoch, completed, total, freq):
+        """Logs values of model buffers.
+
+        Notes:
+            1. Buffers are logged separately per layer (i.e. module) within the model.
+            2. All values in a single buffer are logged such that they will be displayed on the same graph in
+               TensorBoard.
+            3. Similarly, if multiple buffers are provided in 'buffer_names', all are presented on the same graph.
+               If this is undesirable, call the function separately for each buffer.
+            4. USE WITH CAUTION: While sometimes desirable, displaying multiple distinct values in a single
+               graph isn't well supported in TensorBoard. It is achieved using a work-around, which slows
+               down TensorBoard loading time considerably as the number of distinct values increases.
+               Therefore, while the number of buffers is not limited, this function is only meant for use
+               with a very small number of buffers and/or values, e.g. 2-5.
+        """
+        for module_name, module in model.named_modules():
+            if distiller.has_children(module):
+                continue
+
+            sd = module.state_dict()
+            values = []
+            for buf_name in buffer_names:
+                try:
+                    values += sd[buf_name].view(-1).tolist()
+                except KeyError:
+                    continue
+
+            if values:
+                tag = '/'.join([tag_prefix, module_name])
+                self.tblogger.list_summary(tag, values, total * epoch + completed, len(values) > 1)
+        self.tblogger.sync_to_file()
+

 class CsvLogger(DataLogger):
-    def __init__(self, fname):
+    def __init__(self, fname_prefix='', logdir=''):
         super(CsvLogger, self).__init__()
-        self.fname = fname
+        self.logdir = logdir
+        self.fname_prefix = fname_prefix
+
+    def get_fname(self, postfix):
+        fname = postfix + '.csv'
+        if self.fname_prefix:
+            fname = self.fname_prefix + '_' + fname
+        return os.path.join(self.logdir, fname)

     def log_weights_sparsity(self, model, epoch):
-        with open(self.fname, 'w') as csv_file:
+        fname = self.get_fname('weights_sparsity')
+        with open(fname, 'w') as csv_file:
             params_size = 0
             sparse_params_size = 0
@@ -182,3 +254,39 @@ class CsvLogger(DataLogger):
                                      torch.numel(param),
                                      int(_density * param.numel()),
                                      (1-_density)*100])
+
+    def log_model_buffers(self, model, buffer_names, tag_prefix, epoch, completed, total, freq):
+        """Logs values of model buffers.
+
+        Notes:
+            1. Each buffer provided is logged in a separate CSV file.
+            2. Each CSV file is continuously updated during the run.
+            3. In each call, a line is appended for each layer (i.e. module) containing the named buffers.
+        """
+        with ExitStack() as stack:
+            files = {}
+            writers = {}
+            for buf_name in buffer_names:
+                fname = self.get_fname(buf_name)
+                new = not os.path.isfile(fname)
+                files[buf_name] = stack.enter_context(open(fname, 'a'))
+                writer = csv.writer(files[buf_name])
+                if new:
+                    writer.writerow(['Layer', 'Epoch', 'Step', 'Total', 'Values'])
+                writers[buf_name] = writer
+
+            for n, m in model.named_modules():
+                for buffer_name in buffer_names:
+                    try:
+                        p = getattr(m, buffer_name)
+                    except AttributeError:
+                        continue
+                    writer = writers[buffer_name]
+                    if isinstance(p, (list, torch.nn.ParameterList)):
+                        values = []
+                        for v in p:
+                            values += v.view(-1).tolist()
+                    else:
+                        values = p.view(-1).tolist()
+                    writer.writerow([distiller.normalize_module_name(n) + '.' + buffer_name,
+                                     epoch, completed, int(total)] + values)
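To illustrate the renamed files, a short sketch of how get_fname composes paths and where the per-buffer CSVs land (the directory, prefix, and 'clip_val' buffer name are arbitrary examples):

    from distiller.data_loggers import CsvLogger

    csvlogger = CsvLogger(fname_prefix='exp1', logdir='logs')
    csvlogger.get_fname('weights_sparsity')    # -> 'logs/exp1_weights_sparsity.csv'
    CsvLogger().get_fname('weights_sparsity')  # -> 'weights_sparsity.csv' (the old API took the full name)

    # Appends one row per layer holding 'clip_val' to 'logs/exp1_clip_val.csv',
    # creating the file with a 'Layer, Epoch, Step, Total, Values' header first:
    csvlogger.log_model_buffers(model, buffer_names=['clip_val'], tag_prefix='',
                                epoch=0, completed=100, total=500, freq=10)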
distiller/model_summaries.py

@@ -46,7 +46,7 @@ __all__ = ['model_summary',
 def model_summary(model, what, dataset=None):
     if what == 'sparsity':
         pylogger = PythonLogger(msglogger)
-        csvlogger = CsvLogger('weights.csv')
+        csvlogger = CsvLogger()
         distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger])
     elif what == 'compute':
         if dataset == 'imagenet':
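One visible effect of this change, assuming model_summary is importable as in the sample apps:

    # With the CsvLogger() defaults (empty prefix, empty logdir), the sparsity
    # summary now writes 'weights_sparsity.csv' in the current working
    # directory, instead of the old hard-coded 'weights.csv':
    model_summary(model, what='sparsity')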
distiller/utils.py

@@ -88,7 +88,7 @@ def assign_layer_fq_names(container, name=None):
     """Assign human-readable names to the modules (layers).

     Sometimes we need to access modules by their names, and we'd like to use
-    fully-qualified names for convinience.
+    fully-qualified names for convenience.
     """
     for name, module in container.named_modules():
         module.distiller_name = name
@@ -518,6 +518,32 @@ def log_weights_sparsity(model, epoch, loggers):
         logger.log_weights_sparsity(model, epoch)

+def log_model_buffers(model, buffer_names, tag_prefix, epoch, steps_completed, total_steps, log_freq, loggers=()):
+    """Log values of model buffers.
+
+    'buffer_names' is a list of buffers to be logged (which do not necessarily exist
+    in all layers in the model).
+
+    USE WITH CARE:
+        * This function logs each value within the buffers. As such, while any buffer can be passed,
+          it is not really intended for big buffers such as model weights.
+        * Special attention is needed when using this functionality with TensorBoardLogger, as it could
+          significantly slow down the load time of TensorBoard. Please see the documentation of
+          'log_model_buffers' in that class.
+
+    Args:
+        model: Model containing buffers to be logged
+        buffer_names: Names of buffers to be logged. Expected to be a list
+        tag_prefix: Prefix to be used before the buffer name by the logger
+        epoch: The current epoch
+        steps_completed: The current step in the epoch
+        total_steps: The total number of training steps taken so far
+        log_freq: The number of steps between logging records
+        loggers: An iterable of loggers to send the log info to
+    """
+    for logger in loggers:
+        logger.log_model_buffers(model, buffer_names, tag_prefix, epoch, steps_completed, total_steps, log_freq)

 def has_children(module):
     try:
         next(module.children())
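Finally, a sketch of calling the new top-level helper from a training loop (the logger instances, the 'clip_val' buffer name, and the log_freq gating are assumptions about typical usage, not code from this commit):

    from distiller.utils import log_model_buffers

    total_steps = len(train_loader)
    for step, (inputs, target) in enumerate(train_loader):
        ...  # forward/backward/optimizer step goes here
        steps_completed = step + 1
        if steps_completed % log_freq == 0:
            log_model_buffers(model, ['clip_val'], 'train', epoch,
                              steps_completed, total_steps, log_freq,
                              loggers=[pylogger, tflogger, csvlogger])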