diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 6a640216578ea88cca4d9c0ac2c24b9df76f2b24..8db49acb366b83de9c5323c82f38164b44ff810a 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -65,7 +65,7 @@ class ClassifierCompressor(object):
         self.tflogger = TensorBoardLogger(msglogger.logdir)
         self.pylogger = PythonLogger(msglogger)
         (self.model, self.compression_scheduler, self.optimizer,
-            self.start_epoch, self.ending_epoch) = _init_learner(args)
+         self.start_epoch, self.ending_epoch) = _init_learner(args)
 
         # Define loss function (criterion)
         self.criterion = nn.CrossEntropyLoss().to(args.device)
@@ -74,13 +74,18 @@ class ClassifierCompressor(object):
 
     def load_datasets(self):
         """Load the datasets"""
-        self.train_loader, self.val_loader, self.test_loader = load_data(self.args)
+        if not all((self.train_loader, self.val_loader, self.test_loader)):
+            self.train_loader, self.val_loader, self.test_loader = load_data(self.args)
+        return self.data_loaders
+
+    @property
+    def data_loaders(self):
+        return self.train_loader, self.val_loader, self.test_loader
 
     def train_one_epoch(self, epoch, verbose=True):
-        if self.train_loader is None:
-            self.load_datasets()
+        """Train for one epoch"""
+        self.load_datasets()
 
-        # Train for one epoch
         with collectors_context(self.activations_collectors["train"]) as collectors:
             top1, top5, loss = train(self.train_loader, self.model, self.criterion, self.optimizer,
                                      epoch, self.compression_scheduler,
@@ -108,7 +113,9 @@ class ClassifierCompressor(object):
         return top1, top5, loss
 
     def validate_one_epoch(self, epoch, verbose=True):
-        # evaluate on validation set
+        """Evaluate on validation set"""
+        self.load_datasets()
+
         with collectors_context(self.activations_collectors["valid"]) as collectors:
             top1, top5, vloss = validate(self.val_loader, self.model, self.criterion,
                                          [self.pylogger], self.args, epoch)
@@ -147,9 +154,8 @@ class ClassifierCompressor(object):
             validate_one_epoch
             finalize_epoch
         """
-        if self.train_loader is None:
-            # Load the datasets lazily
-            self.load_datasets()
+        # Load the datasets lazily
+        self.load_datasets()
 
         perf_scores_history = []
         for epoch in range(self.start_epoch, self.ending_epoch):
@@ -158,14 +164,12 @@ class ClassifierCompressor(object):
             self._finalize_epoch(epoch, perf_scores_history, top1, top5)
 
     def validate(self, epoch=-1):
-        if self.val_loader is None:
-            self.load_datasets()
+        self.load_datasets()
         return validate(self.val_loader, self.model, self.criterion,
                         [self.tflogger, self.pylogger], self.args, epoch)
 
     def test(self):
-        if self.test_loader is None:
-            self.load_datasets()
+        self.load_datasets()
         return test(self.test_loader, self.model, self.criterion,
                     self.pylogger, self.activations_collectors, args=self.args)
 
@@ -372,8 +376,7 @@ def _init_learner(args):
         model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
             model, args.resumed_checkpoint_path, model_device=args.device)
     elif args.load_model_path:
-        model = apputils.load_lean_checkpoint(model, args.load_model_path,
-                                              model_device=args.device)
+        model = apputils.load_lean_checkpoint(model, args.load_model_path, model_device=args.device)
     if args.reset_optimizer:
         start_epoch = 0
         if optimizer is not None:
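
The net effect of the patch is that `load_datasets()` becomes idempotent and returns the loaders via the new `data_loaders` property, so `train_one_epoch`, `validate_one_epoch`, `validate`, and `test` can call it unconditionally instead of each checking a loader for `None`. Below is a minimal, standalone sketch of that pattern; `DummyClassifierCompressor` and `_load_data` are illustrative stand-ins, not distiller's real `ClassifierCompressor` or `apputils.load_data`.

```python
class DummyClassifierCompressor:
    """Illustrative stand-in for the lazy dataset-loading pattern in this patch."""

    def __init__(self):
        self.train_loader = self.val_loader = self.test_loader = None

    def _load_data(self):
        # Stand-in for distiller.apputils.load_data(self.args)
        return "train_loader", "val_loader", "test_loader"

    def load_datasets(self):
        """Load the datasets only if any loader is still missing (idempotent)."""
        if not all((self.train_loader, self.val_loader, self.test_loader)):
            self.train_loader, self.val_loader, self.test_loader = self._load_data()
        return self.data_loaders

    @property
    def data_loaders(self):
        return self.train_loader, self.val_loader, self.test_loader


compressor = DummyClassifierCompressor()
print(compressor.load_datasets())  # first call loads the data
print(compressor.load_datasets())  # later calls reuse the cached loaders
```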