From a7473c9508716ed9a656a880b6fe820f311bfb3c Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Fri, 8 Nov 2019 18:34:39 +0200
Subject: [PATCH] Early-exit refactoring: flexible exits installation

Exits can now be attached at any point in the network by
specifying the name of the attachment node and the
exit-branch subgraph.
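
For example, an exit branch can be attached after the module named
'layer1.2.relu2' by pairing that name with a branch subgraph
(an illustrative sketch; module names and tensor shapes are
model-specific):

    exits_def = [('layer1.2.relu2',
                  nn.Sequential(nn.AvgPool2d(3),
                                nn.Flatten(),
                                nn.Linear(1600, NUM_CLASSES)))]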
---
 .../models/cifar10/resnet_cifar_earlyexit.py  | 83 ++++++++++---------
 distiller/modules/__init__.py                 |  3 +-
 distiller/modules/topology.py                 | 55 ++++++++++++
 3 files changed, 102 insertions(+), 39 deletions(-)
 create mode 100644 distiller/modules/topology.py

diff --git a/distiller/models/cifar10/resnet_cifar_earlyexit.py b/distiller/models/cifar10/resnet_cifar_earlyexit.py
index 26b67c5..e323153 100644
--- a/distiller/models/cifar10/resnet_cifar_earlyexit.py
+++ b/distiller/models/cifar10/resnet_cifar_earlyexit.py
@@ -14,12 +14,11 @@
 # limitations under the License.
 #
 
-"""Resnet for CIFAR10
+"""Resnet for CIFAR10 with Early Exit branches
 
 Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
 This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate
 changes for the 10-class Cifar-10 dataset.
-This ResNet also has layer gates, to be able to dynamically remove layers.
 
 @inproceedings{DBLP:conf/cvpr/HeZRS16,
   author    = {Kaiming He and
@@ -34,12 +33,10 @@ This ResNet also has layer gates, to be able to dynamically remove layers.
 }
 
 """
-import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
-import torchvision.models as models
 from .resnet_cifar import BasicBlock
 from .resnet_cifar import ResNetCifar
+import torch.nn as nn
+from distiller.modules import BranchPoint
 
 
 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -53,48 +50,58 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)
 
 
-class ExitBranch(nn.Module):
-    def __init__(self, num_classes):
-        super().__init__()
-        self.avg_pool = nn.AvgPool2d(3)
-        self.linear = nn.Linear(1600, num_classes)
+def get_exits_def():
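+    # 'layer1.2.relu2' is the ReLU output of the third BasicBlock in layer1.
+    # For the CIFAR ResNets this feature-map is 16 channels at 32x32, so
+    # AvgPool2d(3) produces 16x10x10 = 1600 features for the Linear layer.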
+    exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
+                                                  nn.Flatten(),
+                                                  nn.Linear(1600, NUM_CLASSES)))]
+    return exits_def
 
-    def forward(self, x):
-        x = self.avg_pool(x)
-        x = x.view(x.size(0), -1)
-        x = self.linear(x)
-        return x
 
+def find_module(model, mod_name):
+    """Locate a module, given its full name"""
+    for name, module in model.named_modules():
+        if name == mod_name:
+            return module
+    return None
 
-class BranchPoint(nn.Module):
-    def __init__(self, original_m, exit_m):
-        super().__init__()
-        self.original_m = original_m
-        self.exit_m = exit_m
-        self.output = None
 
-    def forward(self, x):
-        x1 = self.original_m.forward(x)
-        x2 = self.exit_m.forward(x1)
-        self.output = x2
-        return x1
+def split_module_name(mod_name):
+    """Split a full module name into its parent's name and the node's name.
+
+    E.g. 'layer1.2.relu2' -> ('layer1.2', 'relu2')
+    """
+    name_parts = mod_name.split('.')
+    parent = '.'.join(name_parts[:-1])
+    node = name_parts[-1]
+    return parent, node
 
 
 class ResNetCifarEarlyExit(ResNetCifar):
-    def __init__(self, block, layers, num_classes=NUM_CLASSES):
-        super().__init__(block, layers, num_classes)
-
-        # Define early exit branches and install them
-        self.exit_branch = ExitBranch(num_classes)
-        self.layer1 = BranchPoint(self.layer1, self.exit_branch)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exit_points = []
+        self.attach_exits(get_exits_def())
+
+    def attach_exits(self, exits_def):
+        # For each exit point, we:
+        # 1. Cache the name of the exit-point module (i.e. the name of the module
+        #    whose output we forward to the exit branch).
+        # 2. Replace the exit-point module with a BranchPoint that wraps both the
+        #    original module and the exit branch.
+        for exit_point, exit_branch in exits_def:
+            self.exit_points.append(exit_point)
+            replaced_module = find_module(self, exit_point)
+            parent_name, node_name = split_module_name(exit_point)
+            parent_module = find_module(self, parent_name)
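+            # e.g. for exit_point 'layer1.2.relu2': replace the BasicBlock's 'relu2'
+            # module with a BranchPoint wrapping 'relu2' and the exit branch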
+            setattr(parent_module, node_name, BranchPoint(replaced_module, exit_branch))
 
     def forward(self, x):
+        # Run the input through the network
         x = super().forward(x)
-
-        # return a list of probabilities
-        exit0 = self.layer1.output
-        output = (exit0, x)
-        return output
+        # Collect the outputs of all the exits, followed by the main-path output
+        outputs = []
+        for exit_point in self.exit_points:
+            parent_name, node_name = split_module_name(exit_point)
+            parent_module = find_module(self, parent_name)
+            output = getattr(parent_module, node_name).output
+            outputs.append(output)
+        outputs.append(x)
+        return outputs
 
 
 def resnet20_cifar_earlyexit(**kwargs):
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 71c72ee..bc80489 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -19,8 +19,9 @@ from .grouping import *
 from .matmul import *
 from .rnn import *
 from .aggregate import *
+from .topology import *
 
 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm', 'Mean']
+           'Norm', 'Mean', 'BranchPoint']
diff --git a/distiller/modules/topology.py b/distiller/modules/topology.py
new file mode 100644
index 0000000..81f5c36
--- /dev/null
+++ b/distiller/modules/topology.py
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Modules related to a model's topology"""
+import torch.nn as nn
+
+
+class BranchPoint(nn.Module):
+    """Add a branch to an existing model."""
+    def __init__(self, branched_module, branch_net):
+        """
+        :param branched_module: the module in the original network to which we add a branch.
+        :param branch_net: the new branch
+        """
+        super().__init__()
+        self.branched_module = branched_module
+        self.branch_net = branch_net
+        self.output = None
+
+    def forward(self, x):
+        x1 = self.branched_module(x)
+        self.output = self.branch_net(x1)
+        return x1
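+
+# Illustrative usage (hypothetical names):
+#     model.layer1 = BranchPoint(model.layer1, exit_net)
+#     y = model(x)                     # main-path output, unchanged
+#     exit_out = model.layer1.output   # cached output of the exit branch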
+
+
+# This class is "borrowed" from PyTorch 1.3, until we require a PyTorch version that includes it
+class Flatten(nn.Module):
+    __constants__ = ['start_dim', 'end_dim']
+
+    def __init__(self, start_dim=1, end_dim=-1):
+        super(Flatten, self).__init__()
+        self.start_dim = start_dim
+        self.end_dim = end_dim
+
+    def forward(self, input):
+        return input.flatten(self.start_dim, self.end_dim)
+
+
+# A temporary trick: use the native `nn.Flatten` if this PyTorch version provides it;
+# otherwise, register our implementation in the `torch.nn` namespace for convenience.
+try:
+    Flatten = nn.Flatten
+except AttributeError:
+    nn.Flatten = Flatten
\ No newline at end of file
-- 
GitLab