From 8c5de42cf08e8d32b966a570c207039cb799a528 Mon Sep 17 00:00:00 2001
From: Bar <elhararb@gmail.com>
Date: Thu, 18 Apr 2019 12:13:48 +0300
Subject: [PATCH] Remove single worker limitation in deterministic mode (#227)

Also:
* The single-worker limitation is no longer needed; the underlying
  issue was fixed in PyTorch v0.4.0
  (https://github.com/pytorch/pytorch/pull/4640)
* compress_classifier.py: if run in evaluation mode (--evaluate),
  enable deterministic mode.
* Call utils.set_deterministic when the data loaders are created if
  the deterministic argument is set (don't assume the user calls it
  elsewhere)
* Disable CUDNN benchmark mode in utils.set_deterministic
  (https://pytorch.org/docs/stable/notes/randomness.html#cudnn);
  see the sketch below.
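
For reference, a minimal sketch of the deterministic setup this patch
converges on. The set_deterministic body mirrors the diff below; the
worker-init function is illustrative only (the real
__deterministic_worker_init_fn in distiller/apputils/data_loaders.py
may differ):

    import random
    import numpy as np
    import torch

    def set_deterministic():
        # Seed all RNGs with a well-known constant, for repeatability
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        # Force deterministic CUDNN kernels; benchmark mode may select
        # non-deterministic algorithms, so it must be disabled too
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def _worker_init_fn(worker_id, seed=0):
        # Hypothetical per-worker seeding: each DataLoader worker
        # re-seeds its own RNGs, so multi-worker loading stays
        # reproducible across runs
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)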
---
 distiller/apputils/data_loaders.py            | 11 ++++++++++-
 distiller/utils.py                            | 19 ++++++++++++-------
 .../compress_classifier.py                    | 15 +++++++--------
 examples/classifier_compression/parser.py     |  2 +-
 tests/full_flow_tests.py                      |  8 ++++----
 5 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 96802af..a6a667d 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -19,6 +19,7 @@
 This code will help with the image classification datasets: ImageNet and CIFAR10
 
 """
+import logging
 import os
 import torch
 import torchvision.transforms as transforms
@@ -26,6 +27,11 @@ import torchvision.datasets as datasets
 from torch.utils.data.sampler import Sampler
 import numpy as np
 
+import distiller
+
+
+msglogger = logging.getLogger()
+
 DATASETS_NAMES = ['imagenet', 'cifar10']
 
 
@@ -170,7 +176,10 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_
                      effective_train_size=1., effective_valid_size=1., effective_test_size=1.):
     train_dataset, test_dataset = datasets_fn(data_dir)
 
-    worker_init_fn = __deterministic_worker_init_fn if deterministic else None
+    worker_init_fn = None
+    if deterministic:
+        distiller.set_deterministic()
+        worker_init_fn = __deterministic_worker_init_fn
 
     num_train = len(train_dataset)
     indices = list(range(num_train))
diff --git a/distiller/utils.py b/distiller/utils.py
index 99557f5..3f8b825 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -19,18 +19,22 @@
 This module contains various tensor sparsity/density measurement functions, together
 with some random helper functions.
 """
-import inspect
+import argparse
+from collections import OrderedDict
+from copy import deepcopy
+import logging
+import operator
+import random
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
-import random
-from copy import deepcopy
 import yaml
-from collections import OrderedDict
-import argparse
-import operator
+
+import inspect
+
+msglogger = logging.getLogger()
 
 
 def model_device(model):
@@ -584,10 +588,12 @@ def make_non_parallel_copy(model):
 
 
 def set_deterministic():
+    msglogger.debug('set_deterministic was called')
     torch.manual_seed(0)
     random.seed(0)
     np.random.seed(0)
     torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 
 def yaml_ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict):
@@ -623,7 +629,6 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max=
     return checker
 
 
-
 def filter_kwargs(dict_to_filter, function_to_call):
     """Utility to check which arguments in the passed dictionary exist in a function's signature
 
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 28a715e..ad499a7 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -103,18 +103,17 @@ def main():
     start_epoch = 0
     ending_epoch = args.epochs
     perf_scores_history = []
+
+    if args.evaluate:
+        args.deterministic = True
     if args.deterministic:
         # Experiment reproducibility is sometimes important.  Pete Warden expounded on this
         # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
-        # In Pytorch, support for deterministic execution is still a bit clunky.
-        if args.workers > 1:
-            raise ValueError('ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1')
-        # Use a well-known seed, for repeatability of experiments
-        distiller.set_deterministic()
+        distiller.set_deterministic()  # Use a well-known seed, for repeatability of experiments
     else:
-        # This issue: https://github.com/pytorch/pytorch/issues/3659
-        # Implies that cudnn.benchmark should respect cudnn.deterministic, but empirically we see that
-        # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
+        # Turn on CUDNN benchmark mode for best performance. This is usually "safe" for image
+        # classification models, as the input sizes don't change during the run.
+        # See here: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
         cudnn.benchmark = True
 
     if args.cpu or not torch.cuda.is_available():
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index 2783b29..2deddcf 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -69,7 +69,7 @@ def get_parser():
                         help='Flag to override optimizer if resumed from checkpoint. This will reset epochs count.')
 
     parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                        help='evaluate model on validation set')
+                        help='evaluate model on test set')
     parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
                         help='collect activation statistics on phases: train, valid, and/or test'
                         ' (WARNING: this slows down training)')
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index 2a35dd7..4a27458 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -115,16 +115,16 @@ def collateral_checker(log, *collateral_list):
 TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args'])
 
 test_configs = [
-    TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [48.220, 92.930]),
+    TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.610, 92.080]),
     TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
-               DS_CIFAR, accuracy_checker, [91.640, 99.610]),
+               DS_CIFAR, accuracy_checker, [91.710, 99.610]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
-               DS_CIFAR, accuracy_checker, [54.390, 94.280]),
+               DS_CIFAR, accuracy_checker, [54.590, 94.810]),
     TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
-               DS_CIFAR, collateral_checker, [('sensitivity.csv', 3165), ('sensitivity.png', 96158)])
+               DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96158)])
 ]
 
 
-- 
GitLab