From 795590c8869e849c107289bc300585a42f00f178 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 11 Nov 2019 00:08:07 +0200
Subject: [PATCH] early-exit: further refactoring and resnet50-imagenet

Refactor the early-exit (EE) code and move it into a separate file.
Fix resnet50-earlyexit (the inputs of the nn.Linear layers were wrong).

Caveats:
1. resnet50-earlyexit still needs to be tested for accuracy.
2. There is still too much EE code dispersed in apputils/image_classifier.py
   and compress_classifier.py.
---
 distiller/__init__.py                         |  1 +
 distiller/apputils/image_classifier.py        |  4 +-
 distiller/early_exit.py                       | 81 ++++++++++++++++
 .../models/cifar10/resnet_cifar_earlyexit.py  | 50 ++---------
 distiller/models/imagenet/resnet_earlyexit.py | 87 ++++++-----------
 distiller/modules/__init__.py                 | 13 ++-
 tests/test_pruning.py                         |  2 +-
 7 files changed, 143 insertions(+), 95 deletions(-)
 create mode 100644 distiller/early_exit.py

diff --git a/distiller/__init__.py b/distiller/__init__.py
index 62556f4..220b8d0 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -26,6 +26,7 @@ from .policy import *
 from .thinning import *
 from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
 from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name
+from .early_exit import EarlyExitMgr
 
 import logging
 logging.captureWarnings(True)
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index f73d6b8..98cadfd 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -534,7 +534,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     data_time = tnt.AverageValueMeter()
 
     # For Early Exit, we define statistics for each exit, so
-    # exiterrors is analogous to classerr for the non-Early Exit case
+    # `exiterrors` is analogous to `classerr` in the non-Early Exit case
     if early_exit_mode(args):
         args.exiterrors = []
         for exitnum in range(args.num_exits):
@@ -749,6 +749,8 @@ def earlyexit_loss(output, target, criterion, args):
     sum_lossweights = sum(args.earlyexit_lossweights)
     assert sum_lossweights < 1
     for exitnum in range(args.num_exits-1):
+        if output[exitnum] is None:
+            continue
         exit_loss = criterion(output[exitnum], target)
         weighted_loss += args.earlyexit_lossweights[exitnum] * exit_loss
         args.exiterrors[exitnum].add(output[exitnum].detach(), target)
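For context, the weighting scheme that earlyexit_loss implements can be sketched
as follows. This is only a sketch with a hypothetical helper name, not the code
of the hunk above; the final exit's weight, 1 - sum(earlyexit_lossweights), is
implied by the assert and is applied after the lines shown in the hunk:

    # Sketch of the early-exit loss weighting (hypothetical helper; the final
    # exit's weight, 1 - sum(lossweights), is an assumption implied by the assert).
    def earlyexit_loss_sketch(outputs, target, criterion, lossweights):
        # outputs: [exit_0_logits, ..., final_logits]; lossweights covers early exits only
        assert sum(lossweights) < 1
        loss = 0.
        for out, weight in zip(outputs[:-1], lossweights):
            if out is None:     # this exit branch was not traversed
                continue
            loss += weight * criterion(out, target)
        # The remaining weight goes to the network's final output.
        loss += (1 - sum(lossweights)) * criterion(outputs[-1], target)
        return loss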
diff --git a/distiller/early_exit.py b/distiller/early_exit.py
new file mode 100644
index 0000000..721cf7f
--- /dev/null
+++ b/distiller/early_exit.py
@@ -0,0 +1,81 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+__all__ = ["EarlyExitMgr"]
+
+
+from distiller.modules import BranchPoint
+
+
+class EarlyExitMgr(object):
+    def __init__(self):
+        self.exit_points = []
+
+    def attach_exits(self, model, exits_def):
+        # For each exit point, we:
+        # 1. Cache the name of the exit_point module (i.e. the name of the module
+        #    whose output we forward to the exit branch).
+        # 2. Override the exit_point module with an instance of BranchPoint.
+        for exit_point, exit_branch in exits_def:
+            self.exit_points.append(exit_point)
+            replaced_module = _find_module(model, exit_point)
+            assert replaced_module is not None, "Could not find exit point {}".format(exit_point)
+            parent_name, node_name = _split_module_name(exit_point)
+            parent_module = _find_module(model, parent_name)
+            # Replace the module `node_name` with an instance of `BranchPoint`
+            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))
+
+    def get_exits_outputs(self, model):
+        """Collect the outputs of all the exits and return them.
+
+        The output of each exit is cached during the network's forward pass.
+        """
+        outputs = []
+        for exit_point in self.exit_points:
+            parent_name, node_name = _split_module_name(exit_point)
+            parent_module = _find_module(model, parent_name)
+            output = parent_module.__getattr__(node_name).output
+            assert output is not None
+            outputs.append(output)
+        return outputs
+
+    def delete_exits_outputs(self, model):
+        """Delete the cached outputs of all the exits.
+
+        Some exit branches may not be traversed, so we delete the cached
+        outputs to make sure that stale outputs do not participate in the backprop.
+        """
+        for exit_point in self.exit_points:
+            parent_name, node_name = _split_module_name(exit_point)
+            parent_module = _find_module(model, parent_name)
+            branch_point = parent_module.__getattr__(node_name)
+            branch_point.output = None
+
+
+def _find_module(model, mod_name):
+    """Locate a module, given its full name"""
+    for name, module in model.named_modules():
+        if name == mod_name:
+            return module
+    return None
+
+
+def _split_module_name(mod_name):
+    name_parts = mod_name.split('.')
+    parent = '.'.join(name_parts[:-1])
+    node = name_parts[-1]
+    return parent, node
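To make the new API concrete, here is a minimal, hypothetical usage sketch. The
toy model and all names in it are illustrative and not part of this patch, and
it assumes BranchPoint runs the wrapped module, feeds the result to the exit
branch, caches the branch output in its `output` attribute, and returns the
wrapped module's output unchanged:

    import torch
    import torch.nn as nn
    import distiller

    class TinyNet(nn.Module):
        """Toy backbone, purely for illustration."""
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.relu = nn.ReLU()
            self.fc = nn.Linear(8 * 32 * 32, 10)

        def forward(self, x):
            x = self.relu(self.conv(x))
            return self.fc(x.flatten(1))

    model = TinyNet()
    ee_mgr = distiller.EarlyExitMgr()
    # Attach one exit branch to the output of the 'relu' module.
    ee_mgr.attach_exits(model, [('relu',
                                 nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                               nn.Flatten(),
                                               nn.Linear(8, 10)))])

    ee_mgr.delete_exits_outputs(model)        # clear stale cached outputs
    final = model(torch.randn(1, 3, 32, 32))  # exits cache their outputs here
    outputs = ee_mgr.get_exits_outputs(model) + [final]  # [exit0_logits, final_logits]

This delete/forward/collect sequence is exactly the pattern the refactored
models below follow in their forward() methods.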
diff --git a/distiller/models/cifar10/resnet_cifar_earlyexit.py b/distiller/models/cifar10/resnet_cifar_earlyexit.py
index e323153..fd7ee6b 100644
--- a/distiller/models/cifar10/resnet_cifar_earlyexit.py
+++ b/distiller/models/cifar10/resnet_cifar_earlyexit.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ changes for the 10-class Cifar-10 dataset.
 from .resnet_cifar import BasicBlock
 from .resnet_cifar import ResNetCifar
 import torch.nn as nn
-from distiller.modules import BranchPoint
+import distiller
 
 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -52,55 +52,23 @@ def conv3x3(in_planes, out_planes, stride=1):
 
 def get_exits_def():
     exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
-                                                  nn.Flatten(),
-                                                  nn.Linear(1600, NUM_CLASSES)))]
+                                                   nn.Flatten(),
+                                                   nn.Linear(1600, NUM_CLASSES)))]
     return exits_def
 
 
-def find_module(model, mod_name):
-    """Locate a module, given its full name"""
-    for name, module in model.named_modules():
-        if name == mod_name:
-            return module
-    return None
-
-
-def split_module_name(mod_name):
-    name_parts = mod_name.split('.')
-    parent = '.'.join(name_parts[:-1])
-    node = name_parts[-1]
-    return parent, node
-
 
 class ResNetCifarEarlyExit(ResNetCifar):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.exit_points = []
-        self.attach_exits(get_exits_def())
-
-    def attach_exits(self, exits_def):
-        # For each exit point, we:
-        # 1. Cache the name of the exit_point module (i.e. the name of the module
-        #    whose output we forward to the exit branch).
-        # 2. Override the exit_point module with an instance of BranchPoint
-        for exit_point, exit_branch in exits_def:
-            self.exit_points.append(exit_point)
-            replaced_module = find_module(self, exit_point)
-            parent_name, node_name = split_module_name(exit_point)
-            parent_module = find_module(self, parent_name)
-            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))
+        self.ee_mgr = distiller.EarlyExitMgr()
+        self.ee_mgr.attach_exits(self, get_exits_def())
 
     def forward(self, x):
-        # Run the input through the network
+        self.ee_mgr.delete_exits_outputs(self)
+        # Run the input through the network (including exits)
         x = super().forward(x)
-        # Collect the outputs of all the exits and return them
-        outputs = []
-        for exit_point in self.exit_points:
-            parent_name, node_name = split_module_name(exit_point)
-            parent_module = find_module(self, parent_name)
-            output = parent_module.__getattr__(node_name).output
-            outputs.append(output)
-        outputs += [x]
+        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
         return outputs
diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py
index 4e6ba99..03fd6c9 100644
--- a/distiller/models/imagenet/resnet_earlyexit.py
+++ b/distiller/models/imagenet/resnet_earlyexit.py
@@ -1,10 +1,8 @@
 import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
 import torchvision.models as models
-from torchvision.models.resnet import Bottleneck
-from torchvision.models.resnet import BasicBlock
-
+from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
+from .resnet import DistillerBottleneck
+import distiller
 
 __all__ = ['resnet50_earlyexit']
 
@@ -15,59 +13,46 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)
 
 
-class ResNetEarlyExit(models.ResNet):
+def get_exits_def(num_classes):
+    expansion = 1  # models.ResNet.BasicBlock.expansion
+    exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.Flatten(),
+                                                  nn.Linear(12 * expansion, num_classes))),
+                 ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Flatten(),
+                                                  nn.Linear(12 * expansion, num_classes)))]
+    return exits_def
 
-    def __init__(self, block, layers, num_classes=1000):
-        super(ResNetEarlyExit, self).__init__(block, layers, num_classes)
-        # Define early exit layers
-        self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True)
-        self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True)
-        self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True)
-        self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes)
-        self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes)
 
+class ResNetEarlyExit(models.ResNet):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ee_mgr = distiller.EarlyExitMgr()
+        self.ee_mgr.attach_exits(self, get_exits_def(num_classes=1000))
 
     def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-
-        # Add early exit layers
-        exit0 = self.avgpool(x)
-        exit0 = self.conv1_exit0(exit0)
-        exit0 = self.conv2_exit0(exit0)
-        exit0 = self.avgpool(exit0)
-        exit0 = exit0.view(exit0.size(0), -1)
-        exit0 = self.fc_exit0(exit0)
+        self.ee_mgr.delete_exits_outputs(self)
+        # Run the input through the network (including exits)
+        x = super().forward(x)
+        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
+        return outputs
 
-        x = self.layer2(x)
-        # Add early exit layers
-        exit1 = self.conv1_exit1(x)
-        exit1 = self.avgpool(exit1)
-        exit1 = exit1.view(exit1.size(0), -1)
-        exit1 = self.fc_exit1(exit1)
-
-        x = self.layer3(x)
-        x = self.layer4(x)
-
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
 
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    assert not pretrained, "A pre-trained early-exit model is not available"
+    model = ResNetEarlyExit(block, layers, **kwargs)
+    return model
 
-    # return a list of probabilities
-    output = []
-    output.append(exit0)
-    output.append(exit1)
-    output.append(x)
-    return output
 
+def resnet50_earlyexit(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-50 model, with early-exit branches.
 
-def resnet50_earlyexit(pretrained=False, **kwargs):
-    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+            (currently unsupported)
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs)
-    return model
+    return _resnet('resnet50', DistillerBottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
\ No newline at end of file
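Both model files now delegate the exit mechanics to EarlyExitMgr and
BranchPoint. For readers unfamiliar with distiller.modules.BranchPoint, its
behavior is roughly that of the following sketch, inferred from how
EarlyExitMgr uses it above (an illustration, not the actual implementation):

    import torch.nn as nn

    class BranchPointSketch(nn.Module):
        """Wrap `branched_module` and feed its output to `branch_net` as well."""
        def __init__(self, branched_module, branch_net):
            super().__init__()
            self.branched_module = branched_module  # the original exit-point module
            self.branch_net = branch_net            # the exit branch
            self.output = None                      # cached exit output

        def forward(self, x):
            x = self.branched_module(x)
            self.output = self.branch_net(x)  # cached; collected by get_exits_outputs()
            return x                          # the main network path is unchanged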
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index bc80489..e46bc81 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -24,4 +24,15 @@ from .topology import *
 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm', 'Mean', 'BranchPoint']
+           'Norm', 'Mean', 'BranchPoint', 'Print']
+
+
+class Print(nn.Module):
+    """Utility module that prints the shape of its input and returns it unchanged.
+
+    Temporarily splice it into a model under construction to learn, e.g., the
+    input size of an nn.Linear layer, instead of calculating the shape by hand.
+    """
+    def forward(self, x):
+        print(x.size())
+        return x
\ No newline at end of file
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 23a795c..0862962 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -123,6 +123,7 @@ def test_ranked_filter_pruning(parallel):
                                             is_parallel=parallel)
     test_vgg19_conv_fc_interface(parallel, model=model, zeros_mask_dict=zeros_mask_dict)
 
+    # TODO: add a similar test for ranked channel pruning
 
 def test_prune_all_filters(parallel):
     """Pruning all of the filters in a weights tensor of a Convolution
@@ -145,7 +146,6 @@ def ranked_filter_pruning(config, ratio_to_prune, is_parallel, rounding_fn=math.
     logger.info("executing: %s (invoked by %s)" %
                 (inspect.currentframe().f_code.co_name,
                  inspect.currentframe().f_back.f_code.co_name))
-
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)
 
     for pair in config.module_pairs:
-- 
GitLab
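A usage note for the new Print module: to discover the input size of an exit
branch's final nn.Linear, temporarily end the branch with Print(), run a single
forward pass, and read off the printed shape. A hypothetical example follows;
the tensor sizes are illustrative, assuming the 28x28, 512-channel feature maps
that layer2 of a ResNet-50 produces for 224x224 inputs:

    import torch
    import torch.nn as nn
    from distiller.modules import Print

    # End the branch with Print() in place of the final nn.Linear ...
    probe = nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True),
                          nn.AdaptiveAvgPool2d((1, 1)),
                          nn.Flatten(),
                          Print())
    # ... then run one forward pass and read the printed size: torch.Size([1, 12]),
    # i.e. the nn.Linear that ends this exit branch needs in_features=12.
    probe(torch.randn(1, 512, 28, 28))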