Commit 795590c8 authored by Neta Zmora
early-exit: further refactoring and resnet50-imagenet

Refactor the early-exit (EE) code and place it in a separate file.
Fix resnet50-earlyexit (the inputs of the nn.Linear layers were wrong).

Caveats:
1. resnet50-earlyexit still needs to be tested for performance.
2. There is still too much EE code dispersed in apputils/image_classifier.py
   and compress_classifier.py.
parent 49a2a967
distiller/__init__.py:

@@ -26,6 +26,7 @@ from .policy import *
 from .thinning import *
 from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
 from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name
+from .early_exit import EarlyExitMgr
 import logging
 logging.captureWarnings(True)
distiller/apputils/image_classifier.py (per the commit message, EE code still lives here):

@@ -534,7 +534,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     data_time = tnt.AverageValueMeter()

     # For Early Exit, we define statistics for each exit, so
-    # exiterrors is analogous to classerr for the non-Early Exit case
+    # `exiterrors` is analogous to `classerr` in the non-Early Exit case
     if early_exit_mode(args):
         args.exiterrors = []
         for exitnum in range(args.num_exits):
@@ -749,6 +749,8 @@ def earlyexit_loss(output, target, criterion, args):
     sum_lossweights = sum(args.earlyexit_lossweights)
     assert sum_lossweights < 1
     for exitnum in range(args.num_exits-1):
+        if output[exitnum] is None:
+            continue
         exit_loss = criterion(output[exitnum], target)
         weighted_loss += args.earlyexit_lossweights[exitnum] * exit_loss
         args.exiterrors[exitnum].add(output[exitnum].detach(), target)
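As an aside, here is a small numeric sketch of how this weighting plays out. The assumption that the final exit receives the remaining weight, 1 - sum(earlyexit_lossweights), is mine (suggested by the assert above), and the numbers are made up:

    earlyexit_lossweights = [0.3, 0.3]   # weights for the two early exits
    exit_losses = [2.1, 1.4, 0.9]        # criterion(output, target) per exit; last is the final exit

    weighted_loss = sum(w * l for w, l in zip(earlyexit_lossweights, exit_losses[:-1]))
    weighted_loss += (1 - sum(earlyexit_lossweights)) * exit_losses[-1]
    print(weighted_loss)  # 0.3*2.1 + 0.3*1.4 + 0.4*0.9 = ~1.41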
distiller/early_exit.py (new file):

#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["EarlyExitMgr"]
from distiller.modules import BranchPoint
class EarlyExitMgr(object):
def __init__(self):
self.exit_points = []
def attach_exits(self, model, exits_def):
# For each exit point, we:
# 1. Cache the name of the exit_point module (i.e. the name of the module
# whose output we forward to the exit branch).
# 2. Override the exit_point module with an instance of BranchPoint
for exit_point, exit_branch in exits_def:
self.exit_points.append(exit_point)
replaced_module = _find_module(model, exit_point)
assert replaced_module is not None, "Could not find exit point {}".format(exit_point)
parent_name, node_name = _split_module_name(exit_point)
parent_module = _find_module(model, parent_name)
# Replace the module `node_name` with an instance of `BranchPoint`
parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))
def get_exits_outputs(self, model):
"""Collect the outputs of all the exits and return them.
The output of each exit was cached during the network forward.
"""
outputs = []
for exit_point in self.exit_points:
parent_name, node_name = _split_module_name(exit_point)
parent_module = _find_module(model, parent_name)
output = parent_module.__getattr__(node_name).output
assert output is not None
outputs.append(output)
return outputs
def delete_exits_outputs(self, model):
"""Delete the outputs of all the exits.
Some exit branches may not be traversed, so we need to delete the cached
outputs to make sure these outputs do not participate in the backprop.
"""
outputs = []
for exit_point in self.exit_points:
parent_name, node_name = _split_module_name(exit_point)
parent_module = _find_module(model, parent_name)
branch_point = parent_module.__getattr__(node_name)
branch_point.output = None
return outputs
def _find_module(model, mod_name):
"""Locate a module, given its full name"""
for name, module in model.named_modules():
if name == mod_name:
return module
return None
def _split_module_name(mod_name):
name_parts = mod_name.split('.')
parent = '.'.join(name_parts[:-1])
node = name_parts[-1]
return parent, node
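For illustration, a minimal usage sketch of the new manager on a toy model (the model and exit definition here are mine, not from the commit):

    import torch
    import torch.nn as nn
    from distiller import EarlyExitMgr

    # Toy two-stage model; the exit taps the output of the 'stage1' module.
    model = nn.Sequential()
    model.add_module('stage1', nn.Linear(8, 8))
    model.add_module('stage2', nn.Linear(8, 2))

    ee_mgr = EarlyExitMgr()
    ee_mgr.attach_exits(model, [('stage1', nn.Linear(8, 2))])  # (exit_point, exit_branch)

    ee_mgr.delete_exits_outputs(model)    # clear any stale cached outputs
    main_out = model(torch.randn(4, 8))   # the forward also runs the exit branch
    outputs = ee_mgr.get_exits_outputs(model) + [main_out]  # [exit0, final]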
distiller/models/cifar10/resnet_cifar_earlyexit.py:

 #
-# Copyright (c) 2018 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ changes for the 10-class Cifar-10 dataset.
 from .resnet_cifar import BasicBlock
 from .resnet_cifar import ResNetCifar
 import torch.nn as nn
-from distiller.modules import BranchPoint
+import distiller

 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -52,55 +52,23 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)

 def get_exits_def():
     exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
                                                   nn.Flatten(),
                                                   nn.Linear(1600, NUM_CLASSES)))]
     return exits_def
-def find_module(model, mod_name):
-    """Locate a module, given its full name"""
-    for name, module in model.named_modules():
-        if name == mod_name:
-            return module
-    return None
-
-def split_module_name(mod_name):
-    name_parts = mod_name.split('.')
-    parent = '.'.join(name_parts[:-1])
-    node = name_parts[-1]
-    return parent, node

 class ResNetCifarEarlyExit(ResNetCifar):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.exit_points = []
-        self.attach_exits(get_exits_def())
-
-    def attach_exits(self, exits_def):
-        # For each exit point, we:
-        # 1. Cache the name of the exit_point module (i.e. the name of the module
-        #    whose output we forward to the exit branch).
-        # 2. Override the exit_point module with an instance of BranchPoint
-        for exit_point, exit_branch in exits_def:
-            self.exit_points.append(exit_point)
-            replaced_module = find_module(self, exit_point)
-            parent_name, node_name = split_module_name(exit_point)
-            parent_module = find_module(self, parent_name)
-            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))
+        self.ee_mgr = distiller.EarlyExitMgr()
+        self.ee_mgr.attach_exits(self, get_exits_def())

     def forward(self, x):
-        # Run the input through the network
+        self.ee_mgr.delete_exits_outputs(self)
+        # Run the input through the network (including exits)
         x = super().forward(x)
-        # Collect the outputs of all the exits and return them
-        outputs = []
-        for exit_point in self.exit_points:
-            parent_name, node_name = split_module_name(exit_point)
-            parent_module = find_module(self, parent_name)
-            output = parent_module.__getattr__(node_name).output
-            outputs.append(output)
-        outputs += [x]
+        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
         return outputs
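A quick sanity check on the 1600 in get_exits_def above (my arithmetic, assuming CIFAR's 32x32 input, where ResNetCifar's layer1 outputs 16 channels at 32x32 spatial resolution):

    import torch
    import torch.nn as nn

    # AvgPool2d(3) defaults to stride=3: floor((32 - 3) / 3) + 1 = 10,
    # so Flatten yields 16 * 10 * 10 = 1600 features for nn.Linear(1600, NUM_CLASSES).
    x = torch.randn(1, 16, 32, 32)
    print(nn.Sequential(nn.AvgPool2d(3), nn.Flatten())(x).shape)  # torch.Size([1, 1600])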
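The forward method above relies on distiller.modules.BranchPoint to cache each exit's output during the super().forward(x) call. Roughly, BranchPoint behaves like this sketch (my approximation, not the library source):

    import torch.nn as nn

    class BranchPointSketch(nn.Module):
        """Wrap a module, feed its output to an exit branch, and cache the result."""
        def __init__(self, branched_module, branch_net):
            super().__init__()
            self.branched_module = branched_module
            self.branch_net = branch_net
            self.output = None  # read later by EarlyExitMgr.get_exits_outputs()

        def forward(self, x):
            x = self.branched_module(x)       # the original module's computation
            self.output = self.branch_net(x)  # cache the exit's prediction
            return x                          # the main network continues undisturbed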
distiller/models/imagenet/resnet_earlyexit.py:

 import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
 import torchvision.models as models
-from torchvision.models.resnet import Bottleneck
-from torchvision.models.resnet import BasicBlock
+from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
+from .resnet import DistillerBottleneck
+import distiller

 __all__ = ['resnet50_earlyexit']
@@ -15,59 +13,49 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)

-class ResNetEarlyExit(models.ResNet):
-
-    def __init__(self, block, layers, num_classes=1000):
-        super(ResNetEarlyExit, self).__init__(block, layers, num_classes)
-        # Define early exit layers
-        self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True)
-        self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True)
-        self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True)
-        self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes)
-        self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-        # Add early exit layers
-        exit0 = self.avgpool(x)
-        exit0 = self.conv1_exit0(exit0)
-        exit0 = self.conv2_exit0(exit0)
-        exit0 = self.avgpool(exit0)
-        exit0 = exit0.view(exit0.size(0), -1)
-        exit0 = self.fc_exit0(exit0)
-
-        x = self.layer2(x)
-        # Add early exit layers
-        exit1 = self.conv1_exit1(x)
-        exit1 = self.avgpool(exit1)
-        exit1 = exit1.view(exit1.size(0), -1)
-        exit1 = self.fc_exit1(exit1)
-
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-
-        # return a list of probabilities
-        output = []
-        output.append(exit0)
-        output.append(exit1)
-        output.append(x)
-        return output
-
-def resnet50_earlyexit(pretrained=False, **kwargs):
-    """Constructs a ResNet-50 model.
-    """
-    model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs)
-    return model
+def get_exits_def(num_classes):
+    expansion = 1  # models.ResNet.BasicBlock.expansion
+    exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  #nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Flatten(),
+                                                  nn.Linear(12 * expansion, num_classes))),
+                                                  #distiller.modules.Print())),
+                 ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.Flatten(),
+                                                  #distiller.modules.Print()))]
+                                                  nn.Linear(12 * expansion, num_classes)))]
+    return exits_def
+
+
+class ResNetEarlyExit(models.ResNet):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ee_mgr = distiller.EarlyExitMgr()
+        self.ee_mgr.attach_exits(self, get_exits_def(num_classes=1000))
+
+    def forward(self, x):
+        self.ee_mgr.delete_exits_outputs(self)
+        # Run the input through the network (including exits)
+        x = super().forward(x)
+        outputs = self.ee_mgr.get_exits_outputs(self) + [x]
+        return outputs
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNetEarlyExit(block, layers, **kwargs)
+    assert not pretrained
+    return model
+
+
+def resnet50_earlyexit(pretrained=False, progress=True, **kwargs):
+    """Constructs a ResNet-50 model, with early exit branches.

+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', DistillerBottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
\ No newline at end of file
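To see why nn.Linear(12 * expansion, num_classes) is now correct for the first exit, here is a shape trace (my arithmetic, assuming a 224x224 ImageNet input, where layer1 of ResNet-50 outputs 256 channels at 56x56):

    import torch
    import torch.nn as nn

    # AdaptiveAvgPool2d((1, 1)) -> 256x1x1; each 7x7/stride-2/pad-3 conv maps
    # 1x1 -> floor((1 + 2*3 - 7) / 2) + 1 = 1, so Flatten yields 12 features.
    branch = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                           nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True),
                           nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True),
                           nn.Flatten())
    print(branch(torch.randn(1, 256, 56, 56)).shape)  # torch.Size([1, 12])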
distiller/modules/__init__.py:

@@ -24,4 +24,15 @@ from .topology import *
 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm', 'Mean', 'BranchPoint']
+           'Norm', 'Mean', 'BranchPoint', 'Print']
+
+
+class Print(nn.Module):
+    """Utility module to temporarily replace other modules so you can assess activation shapes.
+
+    This is useful, e.g., when creating a new model and you want to know the size of the
+    input to an nn.Linear layer without calculating the shape by hand.
+    """
+    def forward(self, x):
+        print(x.size())
+        return x
\ No newline at end of file
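For example, Print can stand in for the eventual nn.Linear while you probe shapes (a toy probe of my own, not from the commit):

    import torch
    import torch.nn as nn
    from distiller.modules import Print

    # Place Print where the nn.Linear will eventually go; one forward pass
    # prints the flattened activation size, which becomes in_features.
    probe = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.Flatten(), Print())
    probe(torch.randn(1, 3, 32, 32))  # prints torch.Size([1, 7200])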
tests/test_pruning.py:

@@ -123,6 +123,7 @@ def test_ranked_filter_pruning(parallel):
                                             is_parallel=parallel)
     test_vgg19_conv_fc_interface(parallel, model=model, zeros_mask_dict=zeros_mask_dict)

+
 # todo: add a similar test for ranked channel pruning
 def test_prune_all_filters(parallel):
     """Pruning all of the filteres in a weights tensor of a Convolution
@@ -145,7 +146,6 @@ def ranked_filter_pruning(config, ratio_to_prune, is_parallel, rounding_fn=math.
     logger.info("executing: %s (invoked by %s)" % (inspect.currentframe().f_code.co_name,
                                                    inspect.currentframe().f_back.f_code.co_name))
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)
-
     for pair in config.module_pairs: