diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index d3f98cb79eaefb0d7f13434342d6a68741155cf1..f561a2ba19288730a6cb98eb61e2172d4a145591 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -30,6 +30,7 @@ import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
+from distiller.quantization.range_linear import is_post_train_quant_wrapper
 import numpy as np
 
 msglogger = logging.getLogger()
@@ -78,6 +79,13 @@ class ActivationStatsCollector(object):
         # a unique, human-readable name per layer.
         distiller.utils.assign_layer_fq_names(model)
 
+        # Currently this is internal, and its only purpose is to enable skipping collection
+        # for wrapped modules inside post-training quantization wrapper classes.
+        # When doing PTQ, the outputs of these wrapped modules are actually intermediate results
+        # which are not relevant for tracking.
+        self._dont_collect_list = [module.wrapped_module.distiller_name for module in model.modules() if
+                                   is_post_train_quant_wrapper(module)]
+
     def value(self):
         """Return a dictionary containing {layer_name: statistic}"""
         activation_stats = OrderedDict()
@@ -104,10 +112,7 @@ class ActivationStatsCollector(object):
 
         Eligible modules are currently filtered by their class type.
         """
-        if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
-            return
-        register_all_class_types = not self.classes
-        if register_all_class_types or isinstance(module, tuple(self.classes)):
+        if self._should_collect(module):
             self.fwd_hook_handles.append(module.register_forward_hook(self._activation_stats_cb))
             self._start_counter(module)
 
@@ -147,6 +152,24 @@ class ActivationStatsCollector(object):
         """Handle new activations - this is subclass-specific code"""
         raise NotImplementedError
 
+    def _should_collect(self, module):
+        if module.distiller_name in self._dont_collect_list:
+            return False
+        # In general, we only collect stats for "leaf" modules.
+        # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
+        # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
+        # NOT leaf modules - but we still want to track them.
+        if distiller.has_children(module) and not is_post_train_quant_wrapper(module):
+            return False
+        if isinstance(module, torch.nn.Identity):
+            return False
+
+        register_all_class_types = not self.classes
+        if register_all_class_types or isinstance(module, tuple(self.classes)):
+            return True
+
+        return False
+
 
 class WeightedAverageValueMeter(AverageValueMeter):
     """
@@ -173,7 +196,7 @@ class WeightedAverageValueMeter(AverageValueMeter):
 
 
 class SummaryActivationStatsCollector(ActivationStatsCollector):
-    """This class collects activiations statistical summaries.
+    """This class collects activations statistical summaries.
 
     This Collector computes the mean of some statistic of the activation.  It is rather
     light-weight and quicker than collecting a record per activation.
@@ -445,7 +468,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
         Check whether the required statistics were collected to allow collecting laplace distribution stats.
         """
         for name, module in self.model.named_modules():
-            if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
+            if not self._should_collect(module):
                 continue
             if not hasattr(module, 'quant_stats'):
                 raise RuntimeError('Collection of Laplace distribution statistics is '
@@ -463,9 +486,9 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
     def start_second_pass(self):
         self._check_required_stats()
         self.collecting_second_pass = True
-        # reset batch_idx for all leaf modules
+        # reset batch_idx for all tracked modules
         for module in self.model.modules():
-            if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
+            if not self._should_collect(module):
                 continue
             module.batch_idx = 0
             for record in module.quant_stats.inputs:
@@ -571,8 +594,6 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
             module.batch_idx = 0
 
     def _collect_activations_stats(self, module, activation_stats, name=''):
-        if distiller.utils.has_children(module):
-            return
         if not hasattr(module, 'quant_stats'):
             return
 
@@ -688,8 +709,6 @@ class ActivationHistogramsCollector(ActivationStatsCollector):
             self._reset(module)
 
     def _collect_activations_stats(self, module, activation_stats, name=''):
-        if distiller.utils.has_children(module):
-            return
         if not hasattr(module, 'output_hist'):
             return
 
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index a5e03b4b3f131c18aa210cd0f5418e8e4297e068..87f85466a54ef1f67b1f435edaf8c62e57d3384b 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -674,7 +674,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             self.is_simulated_quant_weight_shifted.fill_(True)  # i.e. is_simulated_quant_weight_shifted = True
 
         input_q += input_q.quant_metadata.zero_point
-        accum = self.wrapped_module.forward(input_q)
+        accum = self.wrapped_module(input_q)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
 
         return accum
@@ -749,8 +749,8 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
 
     def quantized_forward(self, input0_q, input1_q):
         self.accum_scale = input0_q.quant_metadata.scale * input1_q.quant_metadata.scale
-        accum = self.wrapped_module.forward(input0_q + input0_q.quant_metadata.zero_point,
-                                            input1_q + input1_q.quant_metadata.zero_point)
+        accum = self.wrapped_module(input0_q + input0_q.quant_metadata.zero_point,
+                                    input1_q + input1_q.quant_metadata.zero_point)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
         return accum
 
@@ -994,8 +994,13 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
         return output_scale, output_zero_point
 
 
-def _is_range_linear_wrapper(module):
-    return isinstance(module, (RangeLinearEmbeddingWrapper, RangeLinearQuantWrapper))
+_ptq_wrappers_int_only = (RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper)
+_ptq_wrappers_all = _ptq_wrappers_int_only + (FPWrapper,)
+
+
+def is_post_train_quant_wrapper(module, include_fpwrapper=True):
+    types = _ptq_wrappers_all if include_fpwrapper else _ptq_wrappers_int_only
+    return isinstance(module, types)
 
 
 class PostTrainLinearQuantizer(Quantizer):
@@ -1232,7 +1237,7 @@ class PostTrainLinearQuantizer(Quantizer):
 
     def named_acts_quant_params(self):
         for module_name, module in self.model.named_modules():
-            if _is_range_linear_wrapper(module):
+            if is_post_train_quant_wrapper(module, include_fpwrapper=False):
                 for buff_name, buff in module.named_acts_quant_params():
                     full_buff_name = "%s.%s" % (module_name, buff_name)
                     yield full_buff_name, buff