From 10cd1a853b7477ebb20c8d1afcf4b66a95bce3be Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Wed, 11 Dec 2019 10:54:37 +0200 Subject: [PATCH] Quantization stats collection minor updates * Limit batch size to 128 when initiating from image classification app * Don't raise inplace error in case of Dropout module --- distiller/apputils/image_classifier.py | 4 ++++ distiller/data_loggers/collector.py | 9 +++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 59bcfb3..c61fbf6 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -880,6 +880,10 @@ def acts_quant_stats_collection(model, criterion, loggers, args, test_loader=Non if test_loader is None: tmp_args = copy.deepcopy(args) tmp_args.effective_test_size = tmp_args.qe_calibration + # Batch size 256 causes out-of-memory errors on some models (due to extra space taken by + # stats calculations). Limiting to 128 for now. + # TODO: Come up with "smarter" limitation? + tmp_args.batch_size = min(128, tmp_args.batch_size) test_loader = load_data(tmp_args, fixed_subset=True, load_train=False, load_val=False) test_fn = partial(test, test_loader=test_loader, criterion=criterion, loggers=loggers, args=args, activations_collectors=None) diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index f561a2b..27dcafc 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -527,7 +527,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): """ Updates the 'b' parameter of Laplace Distribution. """ - curr_abs_dists = (values - mean).abs() + curr_abs_dists = (values - mean).abs_() return update_running_mean(curr_abs_dists, previous_b, total_values_so_far) def update_record(record, tensor): @@ -566,9 +566,10 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): record['shape'] = distiller.size2str(tensor) if self.inplace_runtime_check and any([id(input) == id(output) for input in inputs]): - raise RuntimeError('Inplace operation detected, meaning inputs stats are overridden by output stats. ' - 'You can either disable this check or make sure no in-place operations occur. ' - 'See QuantCalibrationStatsCollector class documentation for more info.') + if not isinstance(module, torch.nn.modules.dropout._DropoutNd): + raise RuntimeError('Inplace operation detected, meaning inputs stats are overridden by output stats. ' + 'You can either disable this check or make sure no in-place operations occur. ' + 'See QuantCalibrationStatsCollector class documentation for more info.') module.batch_idx += 1 -- GitLab