From 10cd1a853b7477ebb20c8d1afcf4b66a95bce3be Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Wed, 11 Dec 2019 10:54:37 +0200
Subject: [PATCH] Quantization stats collection minor updates

* Limit batch size to 128 when initiating from image classification app
* Don't raise inplace error in case of Dropout module
---
 distiller/apputils/image_classifier.py | 4 ++++
 distiller/data_loggers/collector.py    | 9 +++++----
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 59bcfb3..c61fbf6 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -880,6 +880,10 @@ def acts_quant_stats_collection(model, criterion, loggers, args, test_loader=Non
     if test_loader is None:
         tmp_args = copy.deepcopy(args)
         tmp_args.effective_test_size = tmp_args.qe_calibration
+        # Batch size 256 causes out-of-memory errors on some models (due to extra space taken by
+        # stats calculations). Limiting to 128 for now.
+        # TODO: Come up with "smarter" limitation?
+        tmp_args.batch_size = min(128, tmp_args.batch_size)
         test_loader = load_data(tmp_args, fixed_subset=True, load_train=False, load_val=False)
     test_fn = partial(test, test_loader=test_loader, criterion=criterion,
                       loggers=loggers, args=args, activations_collectors=None)
diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index f561a2b..27dcafc 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -527,7 +527,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
             """
             Updates the 'b' parameter of Laplace Distribution.
             """
-            curr_abs_dists = (values - mean).abs()
+            curr_abs_dists = (values - mean).abs_()
             return update_running_mean(curr_abs_dists, previous_b, total_values_so_far)
 
         def update_record(record, tensor):
@@ -566,9 +566,10 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
                 record['shape'] = distiller.size2str(tensor)
 
         if self.inplace_runtime_check and any([id(input) == id(output) for input in inputs]):
-            raise RuntimeError('Inplace operation detected, meaning inputs stats are overridden by output stats. '
-                               'You can either disable this check or make sure no in-place operations occur. '
-                               'See QuantCalibrationStatsCollector class documentation for more info.')
+            if not isinstance(module, torch.nn.modules.dropout._DropoutNd):
+                raise RuntimeError('Inplace operation detected, meaning inputs stats are overridden by output stats. '
+                                   'You can either disable this check or make sure no in-place operations occur. '
+                                   'See QuantCalibrationStatsCollector class documentation for more info.')
 
         module.batch_idx += 1
 
-- 
GitLab