From 5985d91e2aee9a201c3cd95e7d1d331a4893cc7e Mon Sep 17 00:00:00 2001
From: Peter Pao-Huang <ytp2@miranda.cs.illinois.edu>
Date: Mon, 12 Jul 2021 21:46:11 -0500
Subject: [PATCH] Change quantization stat collection to work for YOLO and
 other CNNs

---
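Notes:

collect_quant_stats() no longer takes a pre-bound test_fn; it now takes a
trainer callable plus a dataloader and invokes
trainer(model, test_dataloaders=dataloader) once per collection pass. The
test_dataloaders keyword matches the Trainer.test() signature of PyTorch
Lightning releases current as of this patch, so a Lightning-style runner is
the assumed caller. A minimal sketch of the new call site, assuming a
Lightning Trainer; the model/loader helpers are hypothetical placeholders,
not part of this patch:

    import pytorch_lightning as pl
    from distiller.data_loggers import collect_quant_stats

    model = build_yolo_module()         # hypothetical: returns a LightningModule
    calib_loader = make_calib_loader()  # hypothetical: evaluation DataLoader
    pl_trainer = pl.Trainer(gpus=1)

    # Each pass runs a full evaluation while the collector hooks record
    # per-layer stats (pass 1: min/max/avg_min/avg_max/mean, pass 2: b, std).
    collect_quant_stats(model, pl_trainer.test, calib_loader,
                        save_dir='./quant_stats')
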
 distiller/data_loggers/collector.py    |  6 +++---
 distiller/quantization/quantizer.py    |  3 +++
 distiller/quantization/range_linear.py |  8 +++-----
 3 files changed, 9 insertions(+), 8 deletions(-)

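On the range_linear.py side, save_per_layer_parameters() now takes an
explicit save_path (default './layer_quant_params.yaml') instead of joining
a save_dir derived from msglogger.logdir, and prepare_model() forwards the
path. A minimal sketch of the new flow, assuming the stats file written by
the collection step above; the file names and input shape are illustrative
placeholders:

    import torch
    from distiller.quantization import PostTrainLinearQuantizer

    # model: the same FP32 module used for stats collection above
    quantizer = PostTrainLinearQuantizer(
        model, bits_activations=8, bits_parameters=8,
        model_activation_stats='./quant_stats/acts_quantization_stats.yaml')
    dummy_input = torch.randn(1, 3, 416, 416)  # YOLO-sized example input
    # The per-layer parameters YAML now lands wherever the caller asks:
    quantizer.prepare_model(dummy_input,
                            save_path='./yolo_layer_quant_params.yaml')
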
diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index adcce1b..3ba3be0 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -846,7 +846,7 @@ class RawActivationsCollector(ActivationStatsCollector):
         return dir_name
 
 
-def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False,
+def collect_quant_stats(model, trainer, dataloader, save_dir=None, classes=None, inplace_runtime_check=False,
                         disable_inplace_attrs=False, inplace_attr_names=('inplace',),
                         modules_to_collect=None):
     """
@@ -877,11 +877,11 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run
                                                            inplace_attr_names=inplace_attr_names)
     with collector_context(quant_stats_collector, modules_to_collect):
         msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean')
-        test_fn(model=model)
+        trainer(model, test_dataloaders=dataloader)
         # Collect Laplace distribution stats:
         msglogger.info('Pass 2: Collecting b, std parameters')
         quant_stats_collector.start_second_pass()
-        test_fn(model=model)
+        trainer(model, test_dataloaders=dataloader)
         quant_stats_collector.stop_second_pass()
 
     msglogger.info('Stats collection complete')
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 1972c6b..633b176 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -111,6 +111,9 @@ class Quantizer(object):
     def __init__(self, model, optimizer=None,
                  bits_activations=None, bits_weights=None, bits_bias=None,
                  overrides=None, train_with_fp_copy=False):
+        print("BITS FOR BIAS: {}".format(bits_bias))
+        print("BITS FOR BIAS: {}".format(bits_bias))
+        print("BITS FOR WEIGHTS: {}".format(bits_weights))
         if overrides is None:
             overrides = OrderedDict()
         if not isinstance(overrides, OrderedDict):
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index f602fb6..a9a3051 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -1927,7 +1927,7 @@ class PostTrainLinearQuantizer(Quantizer):
                        inputs_quant_auto_fallback=True,
                        save_fp_weights=args.qe_save_fp_weights)
 
-    def save_per_layer_parameters(self, save_dir=''):
+    def save_per_layer_parameters(self, save_path='./layer_quant_params.yaml'):
         defaults = OrderedDict(self.model.quantizer_metadata['params'])
         defaults.pop('bits_activations')
         defaults.pop('bits_parameters')
@@ -1953,11 +1953,10 @@ class PostTrainLinearQuantizer(Quantizer):
                 if v.numel() == 1:
                     lqp_dict[k] = v.item()
 
-        save_path = os.path.join(save_dir, 'layer_quant_params.yaml')
         distiller.yaml_ordered_save(save_path, out)
         msglogger.info('Per-layer quantization parameters saved to ' + save_path)
 
-    def prepare_model(self, dummy_input=None):
+    def prepare_model(self, dummy_input=None, save_path='./layer_quant_params.yaml'):
         if not self.model_activation_stats:
             msglogger.warning("\nWARNING:\nNo stats file passed - Dynamic quantization will be used\n"
                               "At the moment, this mode isn't as fully featured as stats-based quantization, and "
@@ -1983,8 +1982,7 @@ class PostTrainLinearQuantizer(Quantizer):
                                           'input in order to perform certain optimizations')
         super(PostTrainLinearQuantizer, self).prepare_model(dummy_input)
 
-        save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.'
-        self.save_per_layer_parameters(save_dir)
+        self.save_per_layer_parameters(save_path)
 
     def _pre_prepare_model(self, dummy_input):
         if not self.has_bidi_distiller_lstm:
-- 
GitLab