Skip to content
Snippets Groups Projects
Commit 75ea6367 authored by Yury's avatar Yury
Browse files

ignore dropout during stat collection

parent 3b5fce47
No related tags found
No related merge requests found
...@@ -358,6 +358,9 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): ...@@ -358,6 +358,9 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
setattr(m, n, False) setattr(m, n, False)
def _activation_stats_cb(self, module, inputs, output): def _activation_stats_cb(self, module, inputs, output):
if isinstance(module, torch.nn.Dropout):
return
def update_mean(old_mean, new_val): def update_mean(old_mean, new_val):
return old_mean + (new_val - old_mean) / module.batch_idx return old_mean + (new_val - old_mean) / module.batch_idx
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment