diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 59bcfb3429666eceadc7df1a646768f7404cccad..c61fbf6605940f09373db8cedcc18c49ebe4ffd8 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -880,6 +880,10 @@ def acts_quant_stats_collection(model, criterion, loggers, args, test_loader=Non if test_loader is None: tmp_args = copy.deepcopy(args) tmp_args.effective_test_size = tmp_args.qe_calibration + # Batch size 256 causes out-of-memory errors on some models (due to extra space taken by + # stats calculations). Limiting to 128 for now. + # TODO: Come up with "smarter" limitation? + tmp_args.batch_size = min(128, tmp_args.batch_size) test_loader = load_data(tmp_args, fixed_subset=True, load_train=False, load_val=False) test_fn = partial(test, test_loader=test_loader, criterion=criterion, loggers=loggers, args=args, activations_collectors=None) diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index f561a2ba19288730a6cb98eb61e2172d4a145591..27dcafc081c9c53dce53c09ca7acc57ec6d74953 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -527,7 +527,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): """ Updates the 'b' parameter of Laplace Distribution. """ - curr_abs_dists = (values - mean).abs() + curr_abs_dists = (values - mean).abs_() return update_running_mean(curr_abs_dists, previous_b, total_values_so_far) def update_record(record, tensor): @@ -566,9 +566,10 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): record['shape'] = distiller.size2str(tensor) if self.inplace_runtime_check and any([id(input) == id(output) for input in inputs]): - raise RuntimeError('Inplace operation detected, meaning inputs stats are overridden by output stats. ' - 'You can either disable this check or make sure no in-place operations occur. ' - 'See QuantCalibrationStatsCollector class documentation for more info.') + if not isinstance(module, torch.nn.modules.dropout._DropoutNd): + raise RuntimeError('Inplace operation detected, meaning inputs stats are overridden by output stats. ' + 'You can either disable this check or make sure no in-place operations occur. ' + 'See QuantCalibrationStatsCollector class documentation for more info.') module.batch_idx += 1