From d28587cd9db7a3f5dbbd125be7243636fc4708a8 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 13 Aug 2019 22:17:21 +0300
Subject: [PATCH] image_classifier.py: lazily init data-loaders when calling
 validate/test

---
 distiller/apputils/image_classifier.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 499abc9..6a64021 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -55,7 +55,7 @@ class ClassifierCompressor(object):
     """
     def __init__(self, args, script_dir):
         self.args = args
-        _override_args(args)
+        _infer_implicit_args(args)
         self.logdir = _init_logger(args, script_dir)
         _config_determinism(args)
         _config_compute_device(args)
@@ -158,10 +158,14 @@ class ClassifierCompressor(object):
             self._finalize_epoch(epoch, perf_scores_history, top1, top5)
 
     def validate(self, epoch=-1):
+        if self.val_loader is None:
+            self.load_datasets()
         return validate(self.val_loader, self.model, self.criterion,
                         [self.tflogger, self.pylogger], self.args, epoch)
 
     def test(self):
+        if self.test_loader is None:
+            self.load_datasets()
         return test(self.test_loader, self.model, self.criterion,
                     self.pylogger, self.activations_collectors, args=self.args)
 
@@ -329,7 +333,6 @@ def _config_compute_device(args):
         if args.gpus is not None:
             try:
                 args.gpus = [int(s) for s in args.gpus.split(',')]
-                
             except ValueError:
                 raise ValueError('ERROR: Argument --gpus must be a comma-separated list of integers only')
             available_gpus = torch.cuda.device_count()
@@ -341,10 +344,12 @@ def _config_compute_device(args):
             torch.cuda.set_device(args.gpus[0])
 
 
-def _override_args(args):
+def _infer_implicit_args(args):
     # Infer the dataset from the model name
-    args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
-    args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
+    if not hasattr(args, 'dataset'):
+        args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
+    if not hasattr(args, "num_classes"):
+        args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
 
 
 def _init_learner(args):
@@ -376,8 +381,8 @@ def _init_learner(args):
             msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')
 
     if optimizer is None:
-        optimizer = torch.optim.SGD(model.parameters(),
-            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
+                                    momentum=args.momentum, weight_decay=args.weight_decay)
         msglogger.debug('Optimizer Type: %s', type(optimizer))
         msglogger.debug('Optimizer Args: %s', optimizer.defaults)
 
@@ -592,6 +597,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     # NOTE: this breaks previous behavior, which returned a history of (top1, top5) values
     return classerr.value(1), classerr.value(5), losses[OVERALL_LOSS_KEY]
 
+
 def validate(val_loader, model, criterion, loggers, args, epoch=-1):
     """Model validation"""
     if epoch > -1:
@@ -617,6 +623,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args):
 def _is_earlyexit(args):
     return hasattr(args, 'earlyexit_thresholds') and args.earlyexit_thresholds
 
+
 def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     """Execute the validation/test loop."""
     losses = {'objective_loss': tnt.AverageValueMeter()}
-- 
GitLab