From 9e78723897e6fc0435536a04a47128100341e991 Mon Sep 17 00:00:00 2001
From: Soumendu Kumar Ghosh <31513237+soumendukrg@users.noreply.github.com>
Date: Mon, 27 Apr 2020 14:57:24 -0400
Subject: [PATCH] Modify image size and training for Inception Models (#425)

* Merge pytorch 1.3 commits

This PR is a fix for issue #422.

1. ImageNet models usually use input size [batch, 3, 224, 224], but all Inception models require an input image size of [batch, 3, 299, 299].

2. Inception models have auxiliary branches which contribute to the loss only during training.  The reported classification loss only considers the main classification loss.

3. Inception_V3 normalizes the input inside the network itself.  More details can be found in @soumendukrg's PR #425 [comments](https://github.com/NervanaSystems/distiller/pull/425#issuecomment-557941736).

NOTE: Training Inception_V3 is currently only possible on a single GPU. The problem is described in the issue linked below, and I have verified that it persists in torch 1.3.0:
[inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048](https://github.com/pytorch/vision/issues/1048)

Co-authored-by: Neta Zmora <neta.zmora@intel.com>
---
 distiller/apputils/data_loaders.py     | 41 ++++++++++++-------
 distiller/apputils/image_classifier.py | 56 ++++++++++++++++++++++++--
 distiller/models/__init__.py           | 13 +++++-
 3 files changed, 90 insertions(+), 20 deletions(-)

diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index ae18998..e187750 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -24,6 +24,7 @@ import torch
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 from torch.utils.data.sampler import Sampler
+from functools import partial
 import numpy as np
 import distiller
 
@@ -58,19 +59,21 @@ def classification_get_input_shape(dataset):
         raise ValueError("dataset %s is not supported" % dataset)
 
 
-def __dataset_factory(dataset):
+def __dataset_factory(dataset, arch):
     return {'cifar10': cifar10_get_datasets,
             'mnist': mnist_get_datasets,
-            'imagenet': imagenet_get_datasets}.get(dataset, None)
+            'imagenet': partial(imagenet_get_datasets, arch=arch)}.get(dataset, None)
 
 
-def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False,
+def load_data(dataset, arch, data_dir,
+              batch_size, workers, validation_split=0.1, deterministic=False,
               effective_train_size=1., effective_valid_size=1., effective_test_size=1.,
               fixed_subset=False, sequential=False, test_only=False):
     """Load a dataset.
 
     Args:
         dataset: a string with the name of the dataset to load (cifar10/imagenet)
+        arch: a string with the name of the model architecture
         data_dir: the directory where the dataset resides
         batch_size: the batch size
         workers: the number of worker threads to use for loading the data
@@ -86,12 +89,12 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete
     """
     if dataset not in DATASETS_NAMES:
         raise ValueError('load_data does not support dataset %s" % dataset')
-    datasets_fn = __dataset_factory(dataset)
-    return get_data_loaders(datasets_fn, data_dir, batch_size, workers, 
+    datasets_fn = __dataset_factory(dataset, arch)
+    return get_data_loaders(datasets_fn, data_dir, batch_size, workers,
                             validation_split=validation_split,
-                            deterministic=deterministic, 
+                            deterministic=deterministic,
                             effective_train_size=effective_train_size,
-                            effective_valid_size=effective_valid_size, 
+                            effective_valid_size=effective_valid_size,
                             effective_test_size=effective_test_size,
                             fixed_subset=fixed_subset,
                             sequential=sequential,
@@ -163,20 +166,29 @@ def cifar10_get_datasets(data_dir, load_train=True, load_test=True):
 
     return train_dataset, test_dataset
 
-
-def imagenet_get_datasets(data_dir, load_train=True, load_test=True):
+  
+def imagenet_get_datasets(data_dir, arch, load_train=True, load_test=True):
     """
     Load the ImageNet dataset.
     """
+    # Inception networks expect input images of size (3, 299, 299)
+    if distiller.models.is_inception(arch):
+        resize, crop = 336, 299
+    else:
+        resize, crop = 256, 224
+    if arch == 'googlenet':
+        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
+                                         std=[0.5, 0.5, 0.5])
+    else:
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
     train_dir = os.path.join(data_dir, 'train')
     test_dir = os.path.join(data_dir, 'val')
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
 
     train_dataset = None
     if load_train:
         train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(224),
+            transforms.RandomResizedCrop(crop),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             normalize,
@@ -187,8 +199,8 @@ def imagenet_get_datasets(data_dir, load_train=True, load_test=True):
     test_dataset = None
     if load_test:
         test_transform = transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
+            transforms.Resize(resize),
+            transforms.CenterCrop(crop),
             transforms.ToTensor(),
             normalize,
         ])
@@ -197,7 +209,6 @@ def imagenet_get_datasets(data_dir, load_train=True, load_test=True):
 
     return train_dataset, test_dataset
 
-
 def __image_size(dataset):
     # un-squeeze is used here to add the batch dimension (value=1), which is missing
     return dataset[0][0].unsqueeze(0).size()
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index c2e956f..3957df7 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -472,7 +472,7 @@ def save_collectors_data(collectors, directory):
 def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_val=True, load_test=True):
     test_only = not load_train and not load_val
 
-    train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset,
+    train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, args.arch,
                               os.path.expanduser(args.data), args.batch_size,
                               args.workers, args.validation_split, args.deterministic,
                               args.effective_train_size, args.effective_valid_size, args.effective_test_size,
@@ -488,7 +488,7 @@ def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_
     loaders = [loaders[i] for i, flag in enumerate(flags) if flag]
     
     if len(loaders) == 1:
-        # Unpack the list for convinience
+        # Unpack the list for convenience
         loaders = loaders[0]
     return loaders
 
@@ -579,9 +579,19 @@ def train(train_loader, model, criterion, optimizer, epoch,
             output = args.kd_policy.forward(inputs)
 
         if not early_exit_mode(args):
-            loss = criterion(output, target)
+            # Inception models return auxiliary outputs during training, so their loss is computed separately.
+            # If the user turned off the auxiliary classifiers by hand, the loss should be calculated normally,
+            # so we also check that the output is a tuple before calling the inception-specific loss function.
+            if models.is_inception(args.arch) and isinstance(output, tuple):
+                loss = inception_training_loss(output, target, criterion, args)
+            else:
+                loss = criterion(output, target)
             # Measure accuracy
-            classerr.add(output.detach(), target)
+            # For inception models, we only consider accuracy of main classifier
+            if isinstance(output, tuple):
+                classerr.add(output[0].detach(), target)
+            else:
+                classerr.add(output.detach(), target)
             acc_stats.append([classerr.value(1), classerr.value(5)])
         else:
             # Measure accuracy and record loss
@@ -741,6 +751,44 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
         return total_top1, total_top5, losses_exits_stats[args.num_exits-1]
 
 
+def inception_training_loss(output, target, criterion, args):
+    """Compute weighted loss for Inception networks as they have auxiliary classifiers
+
+    Auxiliary classifiers were added to inception networks to tackle the vanishing gradient problem.
+    They apply softmax to the outputs of one or more intermediate inception modules and compute an
+    auxiliary loss over the same labels.
+    Note that the auxiliary losses are used purely for training; the aux classifiers are disabled during inference.
+
+    GoogLeNet has 2 auxiliary classifiers, hence 3 outputs in total: output[0] is the main classifier output,
+    output[1] is the aux2 classifier output and output[2] is the aux1 classifier output. Each auxiliary
+    loss is weighted by 0.3 according to the paper (C. Szegedy et al., "Going deeper with convolutions,"
+    2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Boston, MA, 2015, pp. 1-9.)
+
+    All other versions of Inception networks have only one auxiliary classifier, and the auxiliary loss
+    is weighted by 0.4, as suggested in the following PyTorch forum discussion:
+    https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
+    """
+    weighted_loss = 0
+    if args.arch == 'googlenet':
+        # By default, aux classifiers are NOT included in the PyTorch pretrained googlenet model, as they are NOT
+        # trained; they are only present if the network is trained from scratch. If you need to fine-tune googlenet
+        # (e.g. after pruning a pretrained model), you have to explicitly enable aux classifiers when creating the model.
+        # By default, for a pretrained model the output length is 1, so the loss is calculated in the main training
+        # loop instead of here, as we enter this function only if the output is a tuple (len>1).
+        # TODO: Enable user to feed some input to add aux classifiers for pretrained googlenet model
+        outputs, aux2_outputs, aux1_outputs = output    # extract all 3 outputs
+        loss0 = criterion(outputs, target)
+        loss1 = criterion(aux1_outputs, target)
+        loss2 = criterion(aux2_outputs, target)
+        weighted_loss = loss0 + 0.3*loss1 + 0.3*loss2
+    else:
+        outputs, aux_outputs = output    # extract two outputs
+        loss0 = criterion(outputs, target)
+        loss1 = criterion(aux_outputs, target)
+        weighted_loss = loss0 + 0.4*loss1
+    return weighted_loss
+
+
 def earlyexit_loss(output, target, criterion, args):
     """Compute the weighted sum of the exits losses
 
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index c0569fc..d5c288f 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -158,6 +158,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     return model.to(device)
 
 
+def is_inception(arch):
+    return arch in [ # Torchvision architectures
+                    'inception_v3', 'googlenet',
+                    # Cadene architectures
+                    'inceptionv3', 'inceptionv4', 'inceptionresnetv2']
+
+
 def _create_imagenet_model(arch, pretrained):
     dataset = "imagenet"
     cadene = False
@@ -166,9 +173,13 @@ def _create_imagenet_model(arch, pretrained):
         model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
     elif arch in TORCHVISION_MODEL_NAMES:
         try:
-            model = getattr(torch_models, arch)(pretrained=pretrained)
+            if is_inception(arch):
+                model = getattr(torch_models, arch)(pretrained=pretrained, transform_input=False)
+            else:
+                model = getattr(torch_models, arch)(pretrained=pretrained)
             if arch == "mobilenet_v2":
                 patch_torchvision_mobilenet_v2(model)
+
         except NotImplementedError:
             # In torchvision 0.3, trying to download a model that has no
             # pretrained image available will raise NotImplementedError
-- 
GitLab