diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index ae189984173c26f5f9ff3f8617fa1ffafc6ca01c..e187750765b85b4e4b59a996fe975dfec9b739e3 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -24,6 +24,7 @@ import torch import torchvision.transforms as transforms import torchvision.datasets as datasets from torch.utils.data.sampler import Sampler +from functools import partial import numpy as np import distiller @@ -58,19 +59,21 @@ def classification_get_input_shape(dataset): raise ValueError("dataset %s is not supported" % dataset) -def __dataset_factory(dataset): +def __dataset_factory(dataset, arch): return {'cifar10': cifar10_get_datasets, 'mnist': mnist_get_datasets, - 'imagenet': imagenet_get_datasets}.get(dataset, None) + 'imagenet': partial(imagenet_get_datasets, arch=arch)}.get(dataset, None) -def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False, +def load_data(dataset, arch, data_dir, + batch_size, workers, validation_split=0.1, deterministic=False, effective_train_size=1., effective_valid_size=1., effective_test_size=1., fixed_subset=False, sequential=False, test_only=False): """Load a dataset. Args: dataset: a string with the name of the dataset to load (cifar10/imagenet) + arch: a string with the name of the model architecture data_dir: the directory where the dataset resides batch_size: the batch size workers: the number of worker threads to use for loading the data @@ -86,12 +89,12 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete """ if dataset not in DATASETS_NAMES: raise ValueError('load_data does not support dataset %s" % dataset') - datasets_fn = __dataset_factory(dataset) - return get_data_loaders(datasets_fn, data_dir, batch_size, workers, + datasets_fn = __dataset_factory(dataset, arch) + return get_data_loaders(datasets_fn, data_dir, batch_size, workers, validation_split=validation_split, - deterministic=deterministic, + deterministic=deterministic, effective_train_size=effective_train_size, - effective_valid_size=effective_valid_size, + effective_valid_size=effective_valid_size, effective_test_size=effective_test_size, fixed_subset=fixed_subset, sequential=sequential, @@ -163,20 +166,29 @@ def cifar10_get_datasets(data_dir, load_train=True, load_test=True): return train_dataset, test_dataset - -def imagenet_get_datasets(data_dir, load_train=True, load_test=True): + +def imagenet_get_datasets(data_dir, arch, load_train=True, load_test=True): """ Load the ImageNet dataset. """ + # Inception Network accepts image of size 3, 299, 299 + if distiller.models.is_inception(arch): + resize, crop = 336, 299 + else: + resize, crop = 256, 224 + if arch == 'googlenet': + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + else: + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) train_dir = os.path.join(data_dir, 'train') test_dir = os.path.join(data_dir, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) train_dataset = None if load_train: train_transform = transforms.Compose([ - transforms.RandomResizedCrop(224), + transforms.RandomResizedCrop(crop), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, @@ -187,8 +199,8 @@ def imagenet_get_datasets(data_dir, load_train=True, load_test=True): test_dataset = None if load_test: test_transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), + transforms.Resize(resize), + transforms.CenterCrop(crop), transforms.ToTensor(), normalize, ]) @@ -197,7 +209,6 @@ def imagenet_get_datasets(data_dir, load_train=True, load_test=True): return train_dataset, test_dataset - def __image_size(dataset): # un-squeeze is used here to add the batch dimension (value=1), which is missing return dataset[0][0].unsqueeze(0).size() diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index c2e956f3a2069acb3f4dd4339bcf884bdf02508f..3957df7a5702960d7bc2f19889225382502b91d0 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -472,7 +472,7 @@ def save_collectors_data(collectors, directory): def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_val=True, load_test=True): test_only = not load_train and not load_val - train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, + train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, args.arch, os.path.expanduser(args.data), args.batch_size, args.workers, args.validation_split, args.deterministic, args.effective_train_size, args.effective_valid_size, args.effective_test_size, @@ -488,7 +488,7 @@ def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_ loaders = [loaders[i] for i, flag in enumerate(flags) if flag] if len(loaders) == 1: - # Unpack the list for convinience + # Unpack the list for convenience loaders = loaders[0] return loaders @@ -579,9 +579,19 @@ def train(train_loader, model, criterion, optimizer, epoch, output = args.kd_policy.forward(inputs) if not early_exit_mode(args): - loss = criterion(output, target) + # Handle loss calculation for inception models separately due to auxiliary outputs + # if user turned off auxiliary classifiers by hand, then loss should be calculated normally, + # so, we have this check to ensure we only call this function when output is a tuple + if models.is_inception(args.arch) and isinstance(output, tuple): + loss = inception_training_loss(output, target, criterion, args) + else: + loss = criterion(output, target) # Measure accuracy - classerr.add(output.detach(), target) + # For inception models, we only consider accuracy of main classifier + if isinstance(output, tuple): + classerr.add(output[0].detach(), target) + else: + classerr.add(output.detach(), target) acc_stats.append([classerr.value(1), classerr.value(5)]) else: # Measure accuracy and record loss @@ -741,6 +751,44 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1): return total_top1, total_top5, losses_exits_stats[args.num_exits-1] +def inception_training_loss(output, target, criterion, args): + """Compute weighted loss for Inception networks as they have auxiliary classifiers + + Auxiliary classifiers were added to inception networks to tackle the vanishing gradient problem + They apply softmax to outputs of one or more intermediate inception modules and compute auxiliary + loss over same labels. + Note that auxiliary loss is purely used for training purposes, as they are disabled during inference. + + GoogleNet has 2 auxiliary classifiers, hence two 3 outputs in total, output[0] is main classifier output, + output[1] is aux2 classifier output and output[2] is aux1 classifier output and the weights of the + aux losses are weighted by 0.3 according to the paper (C. Szegedy et al., "Going deeper with convolutions," + 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Boston, MA, 2015, pp. 1-9.) + + All other versions of Inception networks have only one auxiliary classifier, and the auxiliary loss + is weighted by 0.4 according to PyTorch documentation + # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 + """ + weighted_loss = 0 + if args.arch == 'googlenet': + # DEFAULT, aux classifiers are NOT included in PyTorch Pretrained googlenet model as they are NOT trained, + # they are only present if network is trained from scratch. If you need to fine tune googlenet (e.g. after + # pruning a pretrained model), then you have to explicitly enable aux classifiers when creating the model + # DEFAULT, in case of pretrained model, output length is 1, so loss will be calculated in main training loop + # instead of here, as we enter this function only if output is a tuple (len>1) + # TODO: Enable user to feed some input to add aux classifiers for pretrained googlenet model + outputs, aux2_outputs, aux1_outputs = output # extract all 3 outputs + loss0 = criterion(outputs, target) + loss1 = criterion(aux1_outputs, target) + loss2 = criterion(aux2_outputs, target) + weighted_loss = loss0 + 0.3*loss1 + 0.3*loss2 + else: + outputs, aux_outputs = output # extract two outputs + loss0 = criterion(outputs, target) + loss1 = criterion(aux_outputs, target) + weighted_loss = loss0 + 0.4*loss1 + return weighted_loss + + def earlyexit_loss(output, target, criterion, args): """Compute the weighted sum of the exits losses diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index c0569fccf0846c320656599f3dd53ba37a45917e..d5c288f8f19072fca908aef11457107eff5c4de7 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -158,6 +158,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): return model.to(device) +def is_inception(arch): + return arch in [ # Torchvision architectures + 'inception_v3', 'googlenet', + # Cadene architectures + 'inceptionv3', 'inceptionv4', 'inceptionresnetv2'] + + def _create_imagenet_model(arch, pretrained): dataset = "imagenet" cadene = False @@ -166,9 +173,13 @@ def _create_imagenet_model(arch, pretrained): model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) elif arch in TORCHVISION_MODEL_NAMES: try: - model = getattr(torch_models, arch)(pretrained=pretrained) + if is_inception(arch): + model = getattr(torch_models, arch)(pretrained=pretrained, transform_input=False) + else: + model = getattr(torch_models, arch)(pretrained=pretrained) if arch == "mobilenet_v2": patch_torchvision_mobilenet_v2(model) + except NotImplementedError: # In torchvision 0.3, trying to download a model that has no # pretrained image available will raise NotImplementedError