diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index 96802afaafb24ae933712ded99d79ab4064b3fa1..a6a667dd49a6baa5a2e25dfc69596832f2079a1a 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -19,6 +19,7 @@ This code will help with the image classification datasets: ImageNet and CIFAR10 """ +import logging import os import torch import torchvision.transforms as transforms @@ -26,6 +27,11 @@ import torchvision.datasets as datasets from torch.utils.data.sampler import Sampler import numpy as np +import distiller + + +msglogger = logging.getLogger() + DATASETS_NAMES = ['imagenet', 'cifar10'] @@ -170,7 +176,10 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_ effective_train_size=1., effective_valid_size=1., effective_test_size=1.): train_dataset, test_dataset = datasets_fn(data_dir) - worker_init_fn = __deterministic_worker_init_fn if deterministic else None + worker_init_fn = None + if deterministic: + distiller.set_deterministic() + worker_init_fn = __deterministic_worker_init_fn num_train = len(train_dataset) indices = list(range(num_train)) diff --git a/distiller/utils.py b/distiller/utils.py index 99557f51b4149e745221d39f7c1e80bd3f47ca01..3f8b825416424a11791a6647c8b2d22e1296c47f 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -19,18 +19,22 @@ This module contains various tensor sparsity/density measurement functions, together with some random helper functions. """ -import inspect +import argparse +from collections import OrderedDict +from copy import deepcopy +import logging +import operator +import random import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn -import random -from copy import deepcopy import yaml -from collections import OrderedDict -import argparse -import operator + +import inspect + +msglogger = logging.getLogger() def model_device(model): @@ -584,10 +588,12 @@ def make_non_parallel_copy(model): def set_deterministic(): + msglogger.debug('set_deterministic is called') torch.manual_seed(0) random.seed(0) np.random.seed(0) torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False def yaml_ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict): @@ -623,7 +629,6 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max= return checker - def filter_kwargs(dict_to_filter, function_to_call): """Utility to check which arguments in the passed dictionary exist in a function's signature diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 28a715ea4a2ad3b1d2632a2ad8ca36cfe8df6a5e..ad499a75451551d58f94f730e31772b1c38b0712 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -103,18 +103,17 @@ def main(): start_epoch = 0 ending_epoch = args.epochs perf_scores_history = [] + + if args.evaluate: + args.deterministic = True if args.deterministic: # Experiment reproducibility is sometimes important. Pete Warden expounded about this # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/ - # In Pytorch, support for deterministic execution is still a bit clunky. - if args.workers > 1: - raise ValueError('ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1') - # Use a well-known seed, for repeatability of experiments - distiller.set_deterministic() + distiller.set_deterministic() # Use a well-known seed, for repeatability of experiments else: - # This issue: https://github.com/pytorch/pytorch/issues/3659 - # Implies that cudnn.benchmark should respect cudnn.deterministic, but empirically we see that - # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled. + # Turn on CUDNN benchmark mode for best performance. This is usually "safe" for image + # classification models, as the input sizes don't change during the run + # See here: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3 cudnn.benchmark = True if args.cpu or not torch.cuda.is_available(): diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py index 2783b29c311edad9a7f5da1132d7c0978055b2ac..2deddcff3216c8610eead8ea0a5b935e7f7157c4 100755 --- a/examples/classifier_compression/parser.py +++ b/examples/classifier_compression/parser.py @@ -69,7 +69,7 @@ def get_parser(): help='Flag to override optimizer if resumed from checkpoint. This will reset epochs count.') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') + help='evaluate model on test set') parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(), help='collect activation statistics on phases: train, valid, and/or test' ' (WARNING: this slows down training)') diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py index 2a35dd7a6f85edac9977daeecf71005e53518b97..4a27458679f341a6776b31d457ac77e5847b502f 100755 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -115,16 +115,16 @@ def collateral_checker(log, *collateral_list): TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args']) test_configs = [ - TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [48.220, 92.930]), + TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.610, 92.080]), TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'. format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')), - DS_CIFAR, accuracy_checker, [91.640, 99.610]), + DS_CIFAR, accuracy_checker, [91.710, 99.610]), TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'. format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')), - DS_CIFAR, accuracy_checker, [54.390, 94.280]), + DS_CIFAR, accuracy_checker, [54.590, 94.810]), TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'. format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')), - DS_CIFAR, collateral_checker, [('sensitivity.csv', 3165), ('sensitivity.png', 96158)]) + DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96158)]) ]