diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index f2f2a57599ee14a8fdb1dc4ac1d70c1fa3254395..c5cd3f658698316a535cf29a65b7235fd7636b07 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -152,6 +152,12 @@ class ClassifierCompressor(object): validate_one_epoch finalize_epoch """ + if self.start_epoch >= self.ending_epoch: + msglogger.error( + 'epoch count is too low, starting epoch is {} but total epochs set to {}'.format( + self.start_epoch, self.ending_epoch)) + raise ValueError('Epochs parameter is too low. Nothing to do.') + # Load the datasets lazily self.load_datasets() @@ -396,13 +402,7 @@ def _init_learner(args): elif compression_scheduler is None: compression_scheduler = distiller.CompressionScheduler(model) - ending_epoch = args.epochs - if start_epoch >= ending_epoch: - msglogger.error( - 'epoch count is too low, starting epoch is {} but total epochs set to {}'.format( - start_epoch, ending_epoch)) - raise ValueError('Epochs parameter is too low. Nothing to do.') - return model, compression_scheduler, optimizer, start_epoch, ending_epoch + return model, compression_scheduler, optimizer, start_epoch, args.epochs def create_activation_stats_collectors(model, *phases):