diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index f3be179f1e953b263ef9442b8f906813d013a699..3aa1fe7ca8976fd9a1f64115ba3c2b0a6cabfd4c 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -156,6 +156,12 @@ class ClassifierCompressor(object):
             validate_one_epoch
             finalize_epoch
         """
+        if self.start_epoch >= self.ending_epoch:
+            msglogger.error(
+                'epoch count is too low, starting epoch is {} but total epochs set to {}'.format(
+                    self.start_epoch, self.ending_epoch))
+            raise ValueError('Epochs parameter is too low. Nothing to do.')
+
         # Load the datasets lazily
         self.load_datasets()
 
@@ -230,7 +236,7 @@ def init_classifier_compression_arg_parser():
                         help='collect activation statistics on phases: train, valid, and/or test'
                              ' (WARNING: this slows down training)')
     parser.add_argument('--activation-histograms', '--act-hist',
-                        type=distiller.utils.float_range_argparse_checker(exc_min=True),
+                        type=float_range(exc_min=True),
                         metavar='PORTION_OF_TEST_SET',
                         help='Run the model in evaluation mode on the specified portion of the test dataset and '
                              'generate activation histograms. NOTE: This slows down evaluation significantly')
@@ -251,8 +257,6 @@ def init_classifier_compression_arg_parser():
                         help='an optional parameter for sensitivity testing '
                              'providing the range of sparsities to test.\n'
                              'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
-    parser.add_argument('--extras', default=None, type=str,
-                        help='file with extra configuration information')
     parser.add_argument('--deterministic', '--det', action='store_true',
                         help='Ensure deterministic execution for re-producible results.')
     parser.add_argument('--seed', type=int, default=None,
@@ -404,13 +408,7 @@ def _init_learner(args):
     elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)
 
-    ending_epoch = args.epochs
-    if start_epoch >= ending_epoch:
-        msglogger.error(
-            'epoch count is too low, starting epoch is {} but total epochs set to {}'.format(
-                start_epoch, ending_epoch))
-        raise ValueError('Epochs parameter is too low. Nothing to do.')
-    return model, compression_scheduler, optimizer, start_epoch, ending_epoch
+    return model, compression_scheduler, optimizer, start_epoch, args.epochs
 
 
 def create_activation_stats_collectors(model, *phases):
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 95edfcb3651fa0092f050a513891a7ee174bc7aa..bf628db183fdfbc8187dd95ba543e8c3f3cb0c17 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -17,7 +17,7 @@
 """This package contains ImageNet and CIFAR image classification models for pytorch"""
 
 import copy
-
+from functools import partial
 import torch
 import torchvision.models as torch_models
 from . import cifar10 as cifar10_models
@@ -58,9 +58,9 @@ MNIST_MODEL_NAMES = sorted(name for name in mnist_models.__dict__
 ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
                             set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES)))
 
+
 # A temporary monkey-patch to get past this Torchvision bug:
 # https://github.com/pytorch/pytorch/issues/20516
-from functools import partial
 def patch_torchvision_mobilenet_v2_bug(model):
     def patched_forward(self, x):
         x = self.features(x)
@@ -202,7 +202,7 @@ def _is_registered_extension(arch, dataset, pretrained):
     try:
         return _model_extensions[(arch, dataset)] is not None
     except KeyError:
-        return None
+        return False
 
 
 def _create_extension_model(arch, dataset):