Skip to content
Snippets Groups Projects
Commit d28587cd authored by Neta Zmora's avatar Neta Zmora
Browse files

image_classifier.py: lazily init data-loaders when calling validate/test

parent 54e72012
No related branches found
No related tags found
No related merge requests found
...@@ -55,7 +55,7 @@ class ClassifierCompressor(object): ...@@ -55,7 +55,7 @@ class ClassifierCompressor(object):
""" """
def __init__(self, args, script_dir): def __init__(self, args, script_dir):
self.args = args self.args = args
_override_args(args) _infer_implicit_args(args)
self.logdir = _init_logger(args, script_dir) self.logdir = _init_logger(args, script_dir)
_config_determinism(args) _config_determinism(args)
_config_compute_device(args) _config_compute_device(args)
...@@ -158,10 +158,14 @@ class ClassifierCompressor(object): ...@@ -158,10 +158,14 @@ class ClassifierCompressor(object):
self._finalize_epoch(epoch, perf_scores_history, top1, top5) self._finalize_epoch(epoch, perf_scores_history, top1, top5)
def validate(self, epoch=-1):
    """Run one evaluation pass over the validation set.

    Data loaders are built lazily on first use, so callers may invoke
    this without an explicit prior call to load_datasets().
    """
    if self.val_loader is None:
        self.load_datasets()
    loggers = [self.tflogger, self.pylogger]
    return validate(self.val_loader, self.model, self.criterion,
                    loggers, self.args, epoch)
def test(self):
    """Evaluate the model on the test set.

    Builds the data loaders on demand if they have not been created yet.
    """
    if self.test_loader is None:
        self.load_datasets()
    return test(self.test_loader, self.model, self.criterion, self.pylogger,
                self.activations_collectors, args=self.args)
...@@ -329,7 +333,6 @@ def _config_compute_device(args): ...@@ -329,7 +333,6 @@ def _config_compute_device(args):
if args.gpus is not None: if args.gpus is not None:
try: try:
args.gpus = [int(s) for s in args.gpus.split(',')] args.gpus = [int(s) for s in args.gpus.split(',')]
except ValueError: except ValueError:
raise ValueError('ERROR: Argument --gpus must be a comma-separated list of integers only') raise ValueError('ERROR: Argument --gpus must be a comma-separated list of integers only')
available_gpus = torch.cuda.device_count() available_gpus = torch.cuda.device_count()
...@@ -341,10 +344,12 @@ def _config_compute_device(args): ...@@ -341,10 +344,12 @@ def _config_compute_device(args):
torch.cuda.set_device(args.gpus[0]) torch.cuda.set_device(args.gpus[0])
def _override_args(args): def _infer_implicit_args(args):
# Infer the dataset from the model name # Infer the dataset from the model name
args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch) if not hasattr(args, 'dataset'):
args.num_classes = distiller.apputils.classification_num_classes(args.dataset) args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
if not hasattr(args, "num_classes"):
args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
def _init_learner(args): def _init_learner(args):
...@@ -376,8 +381,8 @@ def _init_learner(args): ...@@ -376,8 +381,8 @@ def _init_learner(args):
msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0') msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')
if optimizer is None: if optimizer is None:
optimizer = torch.optim.SGD(model.parameters(), optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) momentum=args.momentum, weight_decay=args.weight_decay)
msglogger.debug('Optimizer Type: %s', type(optimizer)) msglogger.debug('Optimizer Type: %s', type(optimizer))
msglogger.debug('Optimizer Args: %s', optimizer.defaults) msglogger.debug('Optimizer Args: %s', optimizer.defaults)
...@@ -592,6 +597,7 @@ def train(train_loader, model, criterion, optimizer, epoch, ...@@ -592,6 +597,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
# NOTE: this breaks previous behavior, which returned a history of (top1, top5) values # NOTE: this breaks previous behavior, which returned a history of (top1, top5) values
return classerr.value(1), classerr.value(5), losses[OVERALL_LOSS_KEY] return classerr.value(1), classerr.value(5), losses[OVERALL_LOSS_KEY]
def validate(val_loader, model, criterion, loggers, args, epoch=-1): def validate(val_loader, model, criterion, loggers, args, epoch=-1):
"""Model validation""" """Model validation"""
if epoch > -1: if epoch > -1:
...@@ -617,6 +623,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args): ...@@ -617,6 +623,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args):
def _is_earlyexit(args): def _is_earlyexit(args):
return hasattr(args, 'earlyexit_thresholds') and args.earlyexit_thresholds return hasattr(args, 'earlyexit_thresholds') and args.earlyexit_thresholds
def _validate(data_loader, model, criterion, loggers, args, epoch=-1): def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
"""Execute the validation/test loop.""" """Execute the validation/test loop."""
losses = {'objective_loss': tnt.AverageValueMeter()} losses = {'objective_loss': tnt.AverageValueMeter()}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment