diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 7fef7487a45363ad314567d95413a849fe700e41..a7621d9d3fb33d848a8d4a2ec6a2e52fd9284d41 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -30,6 +30,8 @@ import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
+import numpy as np
+
 msglogger = logging.getLogger()
 
 __all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
@@ -146,6 +148,30 @@ class ActivationStatsCollector(object):
         raise NotImplementedError
 
 
+class WeightedAverageValueMeter(AverageValueMeter):
+    """
+    A correction to torchnet's AverageValueMeter, which does not take the
+    batch size into account when computing the running std.
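+
+    For example, with hypothetical values (each add() call passes a batch
+    summary and the number of elements it was computed over):
+
+        >>> m = WeightedAverageValueMeter()
+        >>> m.add(0.5, n=100)  # batch mean 0.5, over 100 elements
+        >>> m.add(0.7, n=300)  # batch mean 0.7, over 300 elements
+        >>> round(m.mean, 2)   # (0.5 * 100 + 0.7 * 300) / 400
+        0.65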
+    """
+    def add(self, value, n=1):
+        # Validate the weight before mutating any state
+        if n <= 0:
+            raise ValueError("Cannot use a non-positive weight for the running stat.")
+        self.sum += value * n
+        if self.n == 0:
+            self.mean = 0.0 + value  # This is to force a copy in torch/numpy
+            self.std = np.inf
+            self.mean_old = self.mean
+            self.m_s = 0.0
+        else:
+            self.mean = self.mean_old + n * (value - self.mean_old) / float(self.n + n)
+            self.m_s += n * (value - self.mean_old) * (value - self.mean)
+            self.mean_old = self.mean
+            self.std = np.sqrt(self.m_s / (self.n + n - 1.0))
+        self.var = self.std**2
+
+        self.n += n
+
+
 class SummaryActivationStatsCollector(ActivationStatsCollector):
     """This class collects activiations statistical summaries.
 
@@ -165,7 +191,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
         This is a callback from the forward() of 'module'.
         """
         try:
-            getattr(module, self.stat_name).add(self.summary_fn(output.data))
+            getattr(module, self.stat_name).add(self.summary_fn(output.data), output.data.numel())
         except RuntimeError as e:
             if "The expanded size of the tensor" in e.args[0]:
                 raise ValueError("ActivationStatsCollector: a module ({} - {}) was encountered twice during model.apply().\n"
@@ -181,7 +207,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
 
     def _start_counter(self, module):
         if not hasattr(module, self.stat_name):
-            setattr(module, self.stat_name, AverageValueMeter())
+            setattr(module, self.stat_name, WeightedAverageValueMeter())
             # Assign a name to this summary
             if hasattr(module, 'distiller_name'):
                 getattr(module, self.stat_name).name = '_'.join((self.stat_name, module.distiller_name))
@@ -323,6 +349,7 @@ class _QuantStatsRecord(object):
         for stat_name in ['avg_min', 'avg_max', 'mean', 'std', 'b']:
             records[stat_name] = 0
         records['shape'] = ''
+        records['total_numel'] = 0
         return records
 
     def __init__(self):
@@ -395,7 +422,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
 
         self.batch_idx = 0
         self.inplace_runtime_check = inplace_runtime_check
-        self.collecting_laplace = False
+        self.collecting_second_pass = False
 
         if disable_inplace_attrs:
             if not inplace_attr_names:
@@ -425,45 +452,65 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
                                    'Please collect the required statistics using `collector.start()` and evaluating'
                                    ' the model for enough batches.' % name)
 
-    def start_laplace(self):
+    def start_second_pass(self):
         self._check_required_stats()
-        self.collecting_laplace = True
+        self.collecting_second_pass = True
         # reset batch_idx for all leaf modules
         for module in self.model.modules():
             if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
                 continue
             module.batch_idx = 0
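+            # Reset element counts so the second-pass running means start from scratch.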
+            for record in module.quant_stats.inputs:
+                record['total_numel'] = 0
+            module.quant_stats.output['total_numel'] = 0
 
-    def stop_laplace(self):
-        self.collecting_laplace = False
+    def stop_second_pass(self):
+        self.collecting_second_pass = False
 
     def _activation_stats_cb(self, module, inputs, output):
-        def update_mean(old_mean, new_val):
-            return old_mean + (new_val - old_mean) / module.batch_idx
-
-        def update_std(values, old_std, old_mean, new_mean):
-            # See here:
-            # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
-            numel = values.numel() if isinstance(values, torch.Tensor) else values.size
-            total_values_so_far = numel * (module.batch_idx - 1)
-            M = (old_std ** 2) * (total_values_so_far - 1)
-            mean_diffs = (values - old_mean) * (values - new_mean)
-            M += mean_diffs.sum()
-            return sqrt((M / (total_values_so_far + numel - 1)).item())
-
-        def update_b(values, old_b, mean):
+        """
+        A callback for updating the required statistics for quantization in a module.
+        """
+        def update_running_mean(values, prev_mean, total_values_so_far):
+            """
+            Updates a running mean with a new tensor of values.
+            Args:
+                values (torch.Tensor): the new batch of values
+                prev_mean (float): the previous running mean
+                total_values_so_far (int): the number of values seen so far
+            """
+            curr_numel = values.numel()
+            prev_numel = total_values_so_far
+            return (prev_numel * prev_mean + values.sum().item()) / (prev_numel + curr_numel)
+
+        def update_std(values, prev_std, mean, total_values_so_far):
+            """
+            Updates a running standard deviation around a fixed, pre-computed mean.
+            """
+            prev_variance = prev_std ** 2
+            curr_sqr_dists = (values - mean) ** 2
+            new_variance = update_running_mean(curr_sqr_dists, prev_variance, total_values_so_far)
+            return sqrt(new_variance)
+
+        def update_b(values, previous_b, mean, total_values_so_far):
             """
-            Updates the 'b' parameter of Laplace Distribution.
+            Updates the 'b' (scale) parameter of the Laplace distribution.
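+            For a Laplace distribution with known mean, the maximum-likelihood
+            estimate of 'b' is the mean absolute deviation from that mean.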
             """
-            current_b = (values - mean).abs().mean().item()
-            return old_b + (current_b - old_b) / module.batch_idx
+            curr_abs_dists = (values - mean).abs()
+            return update_running_mean(curr_abs_dists, previous_b, total_values_so_far)
 
         def update_record(record, tensor):
+            if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
+                # Mean function only works for float tensors
+                tensor = tensor.to(torch.float32)
             if not tensor.is_contiguous():
                 tensor = tensor.contiguous()
             act = tensor.view(tensor.size(0), -1)
-            if self.collecting_laplace:
-                record['b'] = update_b(act, record['b'], record['mean'])
+            numel = act.numel()
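+            # Second pass: only 'b' and 'std' are updated, around the mean fixed in pass 1.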
+            if self.collecting_second_pass:
+                record['b'] = update_b(act, record['b'], record['mean'], record['total_numel'])
+                record['std'] = update_std(act, record['std'], record['mean'], record['total_numel'])
+                record['total_numel'] += numel
                 return
 
             # In the general case, the average min/max that we're collecting are averages over the per-sample
@@ -472,23 +519,17 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
-            # But - If each sample contains just a single value, then such a per-sample calculation we'll result in
+            # But - if each sample contains just a single value, then such a per-sample calculation will result in
             # avg_min = avg_max. So in that case we "revert" to calculating "global" values, for the whole batch,
             # instead of per-sample values
-            dim = 0 if act.numel() == act.shape[0] else 1
+            dim = 0 if numel == act.shape[0] else 1
 
             min_per_sample = act.min(dim=dim)[0]
             max_per_sample = act.max(dim=dim)[0]
             record['min'] = min(record['min'], min_per_sample.min().item())
             record['max'] = max(record['max'], max_per_sample.max().item())
-            try:
-                record['avg_min'] = update_mean(record['avg_min'], min_per_sample.mean().item())
-                record['avg_max'] = update_mean(record['avg_max'], max_per_sample.mean().item())
-                new_mean = update_mean(record['mean'], act.mean().item())
-                record['std'] = update_std(tensor, record['std'], record['mean'], new_mean)
-            except RuntimeError:
-                record['avg_min'] = update_mean(record['avg_min'], min_per_sample.cpu().numpy().mean().item(0))
-                record['avg_max'] = update_mean(record['avg_max'], max_per_sample.cpu().numpy().mean().item(0))
-                new_mean = update_mean(record['mean'], act.cpu().numpy().mean().item(0))
-                record['std'] = update_std(tensor.cpu().numpy(), record['std'], record['mean'], new_mean)
+            record['avg_min'] = update_running_mean(min_per_sample, record['avg_min'], record['total_numel'])
+            record['avg_max'] = update_running_mean(max_per_sample, record['avg_max'], record['total_numel'])
+            new_mean = update_running_mean(act, record['mean'], record['total_numel'])
             record['mean'] = new_mean
+            record['total_numel'] += numel
 
             if not record['shape']:
                 record['shape'] = distiller.size2str(tensor)
@@ -742,13 +783,13 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run
                                                            disable_inplace_attrs=disable_inplace_attrs,
                                                            inplace_attr_names=inplace_attr_names)
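+    # Two-pass collection: pass 1 gathers min/max, avg_min/avg_max and the running
+    # mean; pass 2 re-evaluates the model to compute std and the Laplace 'b' around
+    # that now-fixed mean.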
     with collector_context(quant_stats_collector, modules_to_collect):
-        msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean, std')
+        msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean')
         test_fn(model=model)
-        # Collect Laplace distribution stats:
+        # Second pass - stats that depend on the collected mean (std and Laplace 'b'):
-        msglogger.info('Pass 2: Collecting b parameter')
-        quant_stats_collector.start_laplace()
+        msglogger.info('Pass 2: Collecting b, std parameters')
+        quant_stats_collector.start_second_pass()
         test_fn(model=model)
-        quant_stats_collector.stop_laplace()
+        quant_stats_collector.stop_second_pass()
 
     msglogger.info('Stats collection complete')
     if save_dir is not None: