diff --git a/distiller/models/cifar10/resnet_cifar_earlyexit.py b/distiller/models/cifar10/resnet_cifar_earlyexit.py
index 26b67c5bbb91002602c8604ec858d6f95124d29d..e3231536019100935f7f8bc75e4a6d498fca746b 100644
--- a/distiller/models/cifar10/resnet_cifar_earlyexit.py
+++ b/distiller/models/cifar10/resnet_cifar_earlyexit.py
@@ -14,12 +14,11 @@
 # limitations under the License.
 #
 
-"""Resnet for CIFAR10
+"""Resnet for CIFAR10 with Early Exit branches
 
 Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
 This is based on TorchVision's implementation of ResNet for ImageNet, with
 appropriate changes for the 10-class Cifar-10 dataset.
-This ResNet also has layer gates, to be able to dynamically remove layers.
 
 @inproceedings{DBLP:conf/cvpr/HeZRS16,
   author = {Kaiming He and
@@ -34,12 +33,10 @@ This ResNet also has layer gates, to be able to dynamically remove layers.
 }
 """
 
-import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
-import torchvision.models as models
 from .resnet_cifar import BasicBlock
 from .resnet_cifar import ResNetCifar
+import torch.nn as nn
+from distiller.modules import BranchPoint
 
 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -53,48 +50,58 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)
 
 
-class ExitBranch(nn.Module):
-    def __init__(self, num_classes):
-        super().__init__()
-        self.avg_pool = nn.AvgPool2d(3)
-        self.linear = nn.Linear(1600, num_classes)
+def get_exits_def():
+    exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
+                                                  nn.Flatten(),
+                                                  nn.Linear(1600, NUM_CLASSES)))]
+    return exits_def
 
-    def forward(self, x):
-        x = self.avg_pool(x)
-        x = x.view(x.size(0), -1)
-        x = self.linear(x)
-        return x
 
+def find_module(model, mod_name):
+    """Locate a module, given its full name"""
+    for name, module in model.named_modules():
+        if name == mod_name:
+            return module
+    return None
 
-class BranchPoint(nn.Module):
-    def __init__(self, original_m, exit_m):
-        super().__init__()
-        self.original_m = original_m
-        self.exit_m = exit_m
-        self.output = None
 
-    def forward(self, x):
-        x1 = self.original_m.forward(x)
-        x2 = self.exit_m.forward(x1)
-        self.output = x2
-        return x1
+def split_module_name(mod_name):
+    name_parts = mod_name.split('.')
+    parent = '.'.join(name_parts[:-1])
+    node = name_parts[-1]
+    return parent, node
 
 
 class ResNetCifarEarlyExit(ResNetCifar):
-    def __init__(self, block, layers, num_classes=NUM_CLASSES):
-        super().__init__(block, layers, num_classes)
-
-        # Define early exit branches and install them
-        self.exit_branch = ExitBranch(num_classes)
-        self.layer1 = BranchPoint(self.layer1, self.exit_branch)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exit_points = []
+        self.attach_exits(get_exits_def())
+
+    def attach_exits(self, exits_def):
+        # For each exit point, we:
+        # 1. Cache the name of the exit_point module (i.e. the name of the module
+        #    whose output we forward to the exit branch).
+        # 2. Override the exit_point module with an instance of BranchPoint.
+        for exit_point, exit_branch in exits_def:
+            self.exit_points.append(exit_point)
+            replaced_module = find_module(self, exit_point)
+            parent_name, node_name = split_module_name(exit_point)
+            parent_module = find_module(self, parent_name)
+            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))
 
     def forward(self, x):
+        # Run the input through the network
         x = super().forward(x)
-
-        # return a list of probabilities
-        exit0 = self.layer1.output
-        output = (exit0, x)
-        return output
+        # Collect the outputs of all the exits and return them
+        outputs = []
+        for exit_point in self.exit_points:
+            parent_name, node_name = split_module_name(exit_point)
+            parent_module = find_module(self, parent_name)
+            output = parent_module.__getattr__(node_name).output
+            outputs.append(output)
+        outputs += [x]
+        return outputs
 
 
 def resnet20_cifar_earlyexit(**kwargs):
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 71c72eef6d8cd75b5c4dc83ca9d4fc5a881bb8ac..bc804897f3847c19f2e8fecb377032c4b98803d9 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -19,8 +19,9 @@ from .grouping import *
 from .matmul import *
 from .rnn import *
 from .aggregate import *
+from .topology import *
 
 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm', 'Mean']
+           'Norm', 'Mean', 'BranchPoint']
diff --git a/distiller/modules/topology.py b/distiller/modules/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..81f5c364279f8d6069f46eb289fe9fb9735505b7
--- /dev/null
+++ b/distiller/modules/topology.py
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Modules related to a model's topology"""
+import torch.nn as nn
+
+
+class BranchPoint(nn.Module):
+    """Add a branch to an existing model."""
+    def __init__(self, branched_module, branch_net):
+        """
+        :param branched_module: the module in the original network to which we add a branch.
+        :param branch_net: the new branch
+        """
+        super().__init__()
+        self.branched_module = branched_module
+        self.branch_net = branch_net
+        self.output = None
+
+    def forward(self, x):
+        x1 = self.branched_module.forward(x)
+        self.output = self.branch_net.forward(x1)
+        return x1
+
+
+# This class is "borrowed" from PyTorch 1.3 until we integrate it
+class Flatten(nn.Module):
+    __constants__ = ['start_dim', 'end_dim']
+
+    def __init__(self, start_dim=1, end_dim=-1):
+        super(Flatten, self).__init__()
+        self.start_dim = start_dim
+        self.end_dim = end_dim
+
+    def forward(self, input):
+        return input.flatten(self.start_dim, self.end_dim)
+
+
+# A temporary trick: use the native nn.Flatten if this PyTorch version provides it; otherwise add ours to torch.nn.
+try:
+    Flatten = nn.Flatten
+except AttributeError:
+    nn.Flatten = Flatten
\ No newline at end of file
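
For context, here is a minimal sketch (not part of the patch) of how the new `BranchPoint` module is expected to behave once this change is applied. The stand-in backbone stage, exit head, and tensor shapes are assumptions made up for the example; the patch itself attaches an `AvgPool2d`/`Flatten`/`Linear` head after `layer1.2.relu2`:

```python
import torch
import torch.nn as nn
from distiller.modules import BranchPoint

# Hypothetical stand-ins for a backbone stage and an early-exit head.
backbone_stage = nn.Conv2d(3, 16, kernel_size=3, padding=1)
exit_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

bp = BranchPoint(backbone_stage, exit_head)

x = torch.randn(4, 3, 32, 32)
y = bp(x)                  # main path is unchanged: y == backbone_stage(x)
early_logits = bp.output   # the branch's output is cached as a side effect

print(y.shape)             # torch.Size([4, 16, 32, 32])
print(early_logits.shape)  # torch.Size([4, 10])
```

This caching is what `ResNetCifarEarlyExit.forward` relies on: after running the input through the backbone, it walks `self.exit_points`, reads each `BranchPoint.output`, and returns `[exit_0_output, ..., final_output]`. Because each output is stored as module state, the collected branch outputs always reflect the most recent forward pass.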