From 795590c8869e849c107289bc300585a42f00f178 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 11 Nov 2019 00:08:07 +0200
Subject: [PATCH] early-exit: further refactoring and resnet50-imagenet

Refactor the early-exit (EE) code and move it into a separate file.
Fix resnet50-earlyexit (the inputs of the nn.Linear layers were wrong).

Caveats:
1. resnet50-earlyexit still needs to be tested for performance.
2. There is still too much EE code dispersed in apputils/image_classifier.py
   and compress_classifier.py.
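
For reference, a minimal usage sketch of the new manager (the model, exit
definition and shapes below are illustrative assumptions, not part of this
patch):

    import torch
    import torch.nn as nn
    import distiller
    from distiller.models.cifar10.resnet_cifar import resnet20_cifar

    model = resnet20_cifar()
    ee_mgr = distiller.EarlyExitMgr()
    # Replace 'layer1.2.relu2' with a BranchPoint feeding a small classifier.
    ee_mgr.attach_exits(model, [('layer1.2.relu2',
                                 nn.Sequential(nn.AvgPool2d(3),
                                               nn.Flatten(),
                                               nn.Linear(1600, 10)))])
    ee_mgr.delete_exits_outputs(model)   # clear stale cached exit outputs
    final_out = model(torch.randn(1, 3, 32, 32))
    outputs = ee_mgr.get_exits_outputs(model) + [final_out]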
---
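Note: the patch relies on distiller.modules.BranchPoint, which is not shown
in this diff. As a rough sketch (inferred from its usage here, not the
authoritative implementation), BranchPoint behaves like:

    import torch.nn as nn

    class BranchPoint(nn.Module):
        """Wrap `branched_module` and feed its output to `branch_net`."""
        def __init__(self, branched_module, branch_net):
            super().__init__()
            self.branched_module = branched_module
            self.branch_net = branch_net
            self.output = None  # cached exit output, read by EarlyExitMgr

        def forward(self, x):
            x1 = self.branched_module(x)
            self.output = self.branch_net(x1)  # cache the exit's output
            return x1  # the main network continues from x1
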
 distiller/__init__.py                         |  1 +
 distiller/apputils/image_classifier.py        |  6 +++++-
 distiller/early_exit.py                       | 87 +++++++++++++++++
 .../models/cifar10/resnet_cifar_earlyexit.py  | 50 ++---------
 distiller/models/imagenet/resnet_earlyexit.py | 89 ++++++-----
 distiller/modules/__init__.py                 | 15 +++-
 tests/test_pruning.py                         |  2 +-
 7 files changed, 155 insertions(+), 95 deletions(-)
 create mode 100644 distiller/early_exit.py

diff --git a/distiller/__init__.py b/distiller/__init__.py
index 62556f4..220b8d0 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -26,6 +26,7 @@ from .policy import *
 from .thinning import *
 from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
 from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name
+from .early_exit import EarlyExitMgr
 
 import logging
 logging.captureWarnings(True)
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index f73d6b8..98cadfd 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -534,7 +534,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     data_time = tnt.AverageValueMeter()
 
     # For Early Exit, we define statistics for each exit, so
-    # exiterrors is analogous to classerr for the non-Early Exit case
+    # `exiterrors` is analogous to `classerr` in the non-Early Exit case
     if early_exit_mode(args):
         args.exiterrors = []
         for exitnum in range(args.num_exits):
@@ -749,6 +749,10 @@ def earlyexit_loss(output, target, criterion, args):
     sum_lossweights = sum(args.earlyexit_lossweights)
     assert sum_lossweights < 1
     for exitnum in range(args.num_exits-1):
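+        # An exit's output may be None if its branch was not traversed in this
+        # forward pass (see EarlyExitMgr.delete_exits_outputs); skip such exits.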
+        if output[exitnum] is None:
+            continue
         exit_loss = criterion(output[exitnum], target)
         weighted_loss += args.earlyexit_lossweights[exitnum] * exit_loss
         args.exiterrors[exitnum].add(output[exitnum].detach(), target)
diff --git a/distiller/early_exit.py b/distiller/early_exit.py
new file mode 100644
index 0000000..721cf7f
--- /dev/null
+++ b/distiller/early_exit.py
@@ -0,0 +1,87 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
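+"""Early-exit management.
+
+An exit is attached by replacing a named module with a BranchPoint, which
+forwards the module's output to the exit branch and caches the branch's
+output during the network's forward pass.
+"""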
+
+__all__ = ["EarlyExitMgr"]
+
+
+from distiller.modules import BranchPoint
+
+
+class EarlyExitMgr(object):
+    def __init__(self):
+        self.exit_points = []
+
+    def attach_exits(self, model, exits_def):
+        # For each exit point, we:
+        # 1. Cache the name of the exit_point module (i.e. the name of the module
+        #    whose output we forward to the exit branch).
+        # 2. Override the exit_point module with an instance of BranchPoint
+        for exit_point, exit_branch in exits_def:
+            self.exit_points.append(exit_point)
+            replaced_module = _find_module(model, exit_point)
+            assert replaced_module is not None, "Could not find exit point {}".format(exit_point)
+            parent_name, node_name = _split_module_name(exit_point)
+            parent_module = _find_module(model, parent_name)
+            # Replace the module `node_name` with an instance of `BranchPoint`
+            setattr(parent_module, node_name, BranchPoint(replaced_module, exit_branch))
+
+    def get_exits_outputs(self, model):
+        """Collect the outputs of all the exits and return them.
+
+        The output of each exit was cached during the network forward.
+        """
+        outputs = []
+        for exit_point in self.exit_points:
+            parent_name, node_name = _split_module_name(exit_point)
+            parent_module = _find_module(model, parent_name)
+            output = getattr(parent_module, node_name).output
+            assert output is not None
+            outputs.append(output)
+        return outputs
+
+    def delete_exits_outputs(self, model):
+        """Delete the cached outputs of all the exits.
+
+        Some exit branches may not be traversed, so we delete the cached
+        outputs to make sure stale outputs do not participate in the backprop.
+        """
+        for exit_point in self.exit_points:
+            parent_name, node_name = _split_module_name(exit_point)
+            parent_module = _find_module(model, parent_name)
+            branch_point = getattr(parent_module, node_name)
+            branch_point.output = None
+
+
+def _find_module(model, mod_name):
+    """Locate a module, given its full name"""
+    for name, module in model.named_modules():
+        if name == mod_name:
+            return module
+    return None
+
+
+def _split_module_name(mod_name):
+    name_parts = mod_name.split('.')
+    parent = '.'.join(name_parts[:-1])
+    node = name_parts[-1]
+    return parent, node
diff --git a/distiller/models/cifar10/resnet_cifar_earlyexit.py b/distiller/models/cifar10/resnet_cifar_earlyexit.py
index e323153..fd7ee6b 100644
--- a/distiller/models/cifar10/resnet_cifar_earlyexit.py
+++ b/distiller/models/cifar10/resnet_cifar_earlyexit.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ changes for the 10-class Cifar-10 dataset.
 from .resnet_cifar import BasicBlock
 from .resnet_cifar import ResNetCifar
 import torch.nn as nn
-from distiller.modules import BranchPoint
+import distiller
 
 
 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -52,55 +52,23 @@ def conv3x3(in_planes, out_planes, stride=1):
 
 def get_exits_def():
     exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
-                            nn.Flatten(),
-                            nn.Linear(1600, NUM_CLASSES)))]
+                                                  nn.Flatten(),
+                                                  nn.Linear(1600, NUM_CLASSES)))]
     return exits_def
 
 
-def find_module(model, mod_name):
-    """Locate a module, given its full name"""
-    for name, module in model.named_modules():
-        if name == mod_name:
-            return module
-    return None
-
-
-def split_module_name(mod_name):
-    name_parts = mod_name.split('.')
-    parent = '.'.join(name_parts[:-1])
-    node = name_parts[-1]
-    return parent, node
-
 
 class ResNetCifarEarlyExit(ResNetCifar):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.exit_points = []
-        self.attach_exits(get_exits_def())
-
-    def attach_exits(self, exits_def):
-        # For each exit point, we:
-        # 1. Cache the name of the exit_point module (i.e. the name of the module
-        #    whose output we forward to the exit branch).
-        # 2. Override the exit_point module with an instance of BranchPoint
-        for exit_point, exit_branch in exits_def:
-            self.exit_points.append(exit_point)
-            replaced_module = find_module(self, exit_point)
-            parent_name, node_name = split_module_name(exit_point)
-            parent_module = find_module(self, parent_name)
-            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))
+        self.ee_mgr = distiller.EarlyExitMgr()
+        self.ee_mgr.attach_exits(self, get_exits_def())
 
     def forward(self, x):
-        # Run the input through the network
+        self.ee_mgr.delete_exits_outputs(self)
+        # Run the input through the network (including exits)
         x = super().forward(x)
-        # Collect the outputs of all the exits and return them
-        outputs = []
-        for exit_point in self.exit_points:
-            parent_name, node_name = split_module_name(exit_point)
-            parent_module = find_module(self, parent_name)
-            output = parent_module.__getattr__(node_name).output
-            outputs.append(output)
-        outputs += [x]
+        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
         return outputs
 
 
diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py
index 4e6ba99..03fd6c9 100644
--- a/distiller/models/imagenet/resnet_earlyexit.py
+++ b/distiller/models/imagenet/resnet_earlyexit.py
@@ -1,10 +1,8 @@
 import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
 import torchvision.models as models
-from torchvision.models.resnet import Bottleneck
-from torchvision.models.resnet import BasicBlock
-
+from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
+from .resnet import DistillerBottleneck
+import distiller
 
 __all__ = ['resnet50_earlyexit']
 
@@ -15,59 +13,48 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)
 
 
-class ResNetEarlyExit(models.ResNet):
+def get_exits_def(num_classes):
+    expansion = 1  # torchvision BasicBlock.expansion
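+    # Tip: to discover the input size of an exit's nn.Linear layer, temporarily
+    # append a distiller.modules.Print() module to the exit branch.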
+    exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.Flatten(),
+                                                  nn.Linear(12 * expansion, num_classes))),
+                 ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Flatten(),
+                                                  nn.Linear(12 * expansion, num_classes)))]
+    return exits_def
 
-    def __init__(self, block, layers, num_classes=1000):
-        super(ResNetEarlyExit, self).__init__(block, layers, num_classes)
 
-        # Define early exit layers
-        self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True)
-        self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True)
-        self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True)
-        self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes)
-        self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes)
+class ResNetEarlyExit(models.ResNet):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ee_mgr = distiller.EarlyExitMgr()
+        self.ee_mgr.attach_exits(self, get_exits_def(num_classes=1000))
 
     def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-
-        # Add early exit layers
-        exit0 = self.avgpool(x)
-        exit0 = self.conv1_exit0(exit0)
-        exit0 = self.conv2_exit0(exit0)
-        exit0 = self.avgpool(exit0)
-        exit0 = exit0.view(exit0.size(0), -1)
-        exit0 = self.fc_exit0(exit0)
+        self.ee_mgr.delete_exits_outputs(self)
+        # Run the input through the network (including exits)
+        x = super().forward(x)
+        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
+        return outputs
 
-        x = self.layer2(x)
 
-        # Add early exit layers
-        exit1 = self.conv1_exit1(x)
-        exit1 = self.avgpool(exit1)
-        exit1 = exit1.view(exit1.size(0), -1)
-        exit1 = self.fc_exit1(exit1)
-
-        x = self.layer3(x)
-        x = self.layer4(x)
-
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNetEarlyExit(block, layers, **kwargs)
+    assert not pretrained, "No pre-trained weights are available for early-exit models"
+    return model
 
-        # return a list of probabilities
-        output = []
-        output.append(exit0)
-        output.append(exit1)
-        output.append(x)
-        return output
 
+def resnet50_earlyexit(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-50 model, with early exit branches.
 
-def resnet50_earlyexit(pretrained=False, **kwargs):
-    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): unsupported; must be False (no pre-trained weights)
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs)
-    return model
+    return _resnet('resnet50', DistillerBottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index bc80489..e46bc81 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -24,4 +24,17 @@ from .topology import *
 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm', 'Mean', 'BranchPoint']
+           'Norm', 'Mean', 'BranchPoint', 'Print']
+
+import torch.nn as nn  # used by the Print module below
+
+
+class Print(nn.Module):
+    """Utility module to temporarily replace modules to assess activation shape.
+
+    This is useful, e.g., when creating a new model and you want to know the
+    size of the input to an nn.Linear layer without computing the shape by hand.
+    """
+    def forward(self, x):
+        print(x.size())
+        return x
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 23a795c..0862962 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -123,6 +123,7 @@ def test_ranked_filter_pruning(parallel):
                                                    is_parallel=parallel)
     test_vgg19_conv_fc_interface(parallel, model=model, zeros_mask_dict=zeros_mask_dict)
 
+
 # todo: add a similar test for ranked channel pruning
 def test_prune_all_filters(parallel):
     """Pruning all of the filteres in a weights tensor of a Convolution
@@ -145,7 +146,6 @@ def ranked_filter_pruning(config, ratio_to_prune, is_parallel, rounding_fn=math.
     logger.info("executing: %s (invoked by %s)" % (inspect.currentframe().f_code.co_name,
                                                    inspect.currentframe().f_back.f_code.co_name))
 
-
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)
 
     for pair in config.module_pairs:
-- 
GitLab