From d28587cd9db7a3f5dbbd125be7243636fc4708a8 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 13 Aug 2019 22:17:21 +0300
Subject: [PATCH] image_classifier.py: lazily init data-loaders when calling
 validate/test

---
 distiller/apputils/image_classifier.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 499abc9..6a64021 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -55,7 +55,7 @@ class ClassifierCompressor(object):
     """
     def __init__(self, args, script_dir):
         self.args = args
-        _override_args(args)
+        _infer_implicit_args(args)
         self.logdir = _init_logger(args, script_dir)
         _config_determinism(args)
         _config_compute_device(args)
@@ -158,10 +158,14 @@ class ClassifierCompressor(object):
             self._finalize_epoch(epoch, perf_scores_history, top1, top5)
 
     def validate(self, epoch=-1):
+        if self.val_loader is None:
+            self.load_datasets()
         return validate(self.val_loader, self.model, self.criterion,
                         [self.tflogger, self.pylogger], self.args, epoch)
 
     def test(self):
+        if self.test_loader is None:
+            self.load_datasets()
         return test(self.test_loader, self.model, self.criterion,
                     self.pylogger, self.activations_collectors, args=self.args)
 
@@ -329,7 +333,6 @@ def _config_compute_device(args):
     if args.gpus is not None:
         try:
             args.gpus = [int(s) for s in args.gpus.split(',')]
-
         except ValueError:
             raise ValueError('ERROR: Argument --gpus must be a comma-separated list of integers only')
         available_gpus = torch.cuda.device_count()
@@ -341,10 +344,12 @@ def _config_compute_device(args):
         torch.cuda.set_device(args.gpus[0])
 
 
-def _override_args(args):
+def _infer_implicit_args(args):
     # Infer the dataset from the model name
-    args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
-    args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
+    if not hasattr(args, 'dataset'):
+        args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
+    if not hasattr(args, "num_classes"):
+        args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
 
 
 def _init_learner(args):
@@ -376,8 +381,8 @@ def _init_learner(args):
             msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')
 
     if optimizer is None:
-        optimizer = torch.optim.SGD(model.parameters(),
-            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
+                                    momentum=args.momentum, weight_decay=args.weight_decay)
         msglogger.debug('Optimizer Type: %s', type(optimizer))
         msglogger.debug('Optimizer Args: %s', optimizer.defaults)
 
@@ -592,6 +597,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     # NOTE: this breaks previous behavior, which returned a history of (top1, top5) values
     return classerr.value(1), classerr.value(5), losses[OVERALL_LOSS_KEY]
 
+
 def validate(val_loader, model, criterion, loggers, args, epoch=-1):
     """Model validation"""
     if epoch > -1:
@@ -617,6 +623,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args):
 def _is_earlyexit(args):
     return hasattr(args, 'earlyexit_thresholds') and args.earlyexit_thresholds
+
 
 def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     """Execute the validation/test loop."""
     losses = {'objective_loss': tnt.AverageValueMeter()}
--
GitLab
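
Note on the pattern: the patch guards ClassifierCompressor.validate() and test() so that data loaders are
built on first use instead of requiring an explicit load_datasets() call beforehand. The standalone Python
sketch below illustrates only that lazy-initialization idea; the class and loader contents (LazyEvaluator,
placeholder batch lists) are illustrative stand-ins and not the actual Distiller API.

    # Minimal sketch of lazy data-loader initialization, assuming simplified stand-in names.
    class LazyEvaluator:
        def __init__(self):
            # Loaders start out unset; they are created only when first needed.
            self.val_loader = None
            self.test_loader = None

        def load_datasets(self):
            # In the real code this would build the loaders from the dataset path.
            self.val_loader = ["val-batch-0", "val-batch-1"]
            self.test_loader = ["test-batch-0"]

        def validate(self):
            # Lazily build the loaders the first time validation is requested.
            if self.val_loader is None:
                self.load_datasets()
            return len(self.val_loader)

        def test(self):
            if self.test_loader is None:
                self.load_datasets()
            return len(self.test_loader)

    if __name__ == "__main__":
        evaluator = LazyEvaluator()
        # No explicit load_datasets() call is needed before evaluating.
        print(evaluator.validate())  # loaders are created here on demand
        print(evaluator.test())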