Commit a7473c95 authored by Neta Zmora

Early-exit refactoring: flexible exits installation

Exits can now be attached to any point in the network by specifying the
name of the attachment node and the exit-branch subgraph.
parent 660a0da5
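For context, a minimal sketch of the exit-definition format this commit introduces. The node name, pooling size, and 1600-feature linear head mirror get_exits_def() in the diff below; NUM_CLASSES = 10 is an assumption about the surrounding CIFAR-10 file:

    import torch.nn as nn

    NUM_CLASSES = 10  # CIFAR-10; assumed to be defined elsewhere in the file

    # Each exit is an (attachment-node name, exit-branch subgraph) pair.
    # The name is the fully-qualified module name within the model; the
    # subgraph is an ordinary nn.Module that consumes that node's output.
    exits_def = [('layer1.2.relu2',
                  nn.Sequential(nn.AvgPool2d(3),
                                nn.Flatten(),
                                nn.Linear(1600, NUM_CLASSES)))]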
@@ -14,12 +14,11 @@
 # limitations under the License.
 #
-"""Resnet for CIFAR10
+"""Resnet for CIFAR10 with Early Exit branches

 Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
 This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate
 changes for the 10-class Cifar-10 dataset.
-This ResNet also has layer gates, to be able to dynamically remove layers.

 @inproceedings{DBLP:conf/cvpr/HeZRS16,
    author    = {Kaiming He and
@@ -34,12 +33,10 @@ This ResNet also has layer gates, to be able to dynamically remove layers.
 }
 """
-import torch.nn as nn
-import math
-import torch.utils.model_zoo as model_zoo
-import torchvision.models as models
 from .resnet_cifar import BasicBlock
 from .resnet_cifar import ResNetCifar
+import torch.nn as nn
+from distiller.modules import BranchPoint

 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -53,48 +50,58 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)


-class ExitBranch(nn.Module):
-    def __init__(self, num_classes):
-        super().__init__()
-        self.avg_pool = nn.AvgPool2d(3)
-        self.linear = nn.Linear(1600, num_classes)
-
-    def forward(self, x):
-        x = self.avg_pool(x)
-        x = x.view(x.size(0), -1)
-        x = self.linear(x)
-        return x
+def get_exits_def():
+    exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
+                                                  nn.Flatten(),
+                                                  nn.Linear(1600, NUM_CLASSES)))]
+    return exits_def


-class BranchPoint(nn.Module):
-    def __init__(self, original_m, exit_m):
-        super().__init__()
-        self.original_m = original_m
-        self.exit_m = exit_m
-        self.output = None
-
-    def forward(self, x):
-        x1 = self.original_m.forward(x)
-        x2 = self.exit_m.forward(x1)
-        self.output = x2
-        return x1
+def find_module(model, mod_name):
+    """Locate a module, given its full name"""
+    for name, module in model.named_modules():
+        if name == mod_name:
+            return module
+    return None
+
+
+def split_module_name(mod_name):
+    name_parts = mod_name.split('.')
+    parent = '.'.join(name_parts[:-1])
+    node = name_parts[-1]
+    return parent, node


 class ResNetCifarEarlyExit(ResNetCifar):
-    def __init__(self, block, layers, num_classes=NUM_CLASSES):
-        super().__init__(block, layers, num_classes)
-
-        # Define early exit branches and install them
-        self.exit_branch = ExitBranch(num_classes)
-        self.layer1 = BranchPoint(self.layer1, self.exit_branch)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exit_points = []
+        self.attach_exits(get_exits_def())
+
+    def attach_exits(self, exits_def):
+        # For each exit point, we:
+        # 1. Cache the name of the exit-point module (i.e. the name of the module
+        #    whose output we forward to the exit branch).
+        # 2. Override the exit-point module with an instance of BranchPoint.
+        for exit_point, exit_branch in exits_def:
+            self.exit_points.append(exit_point)
+            replaced_module = find_module(self, exit_point)
+            parent_name, node_name = split_module_name(exit_point)
+            parent_module = find_module(self, parent_name)
+            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))

     def forward(self, x):
+        # Run the input through the network
         x = super().forward(x)
-        # return a list of probabilities
-        exit0 = self.layer1.output
-        output = (exit0, x)
-        return output
+
+        # Collect the outputs of all the exits and return them
+        outputs = []
+        for exit_point in self.exit_points:
+            parent_name, node_name = split_module_name(exit_point)
+            parent_module = find_module(self, parent_name)
+            output = parent_module.__getattr__(node_name).output
+            outputs.append(output)
+        outputs += [x]
+        return outputs


 def resnet20_cifar_earlyexit(**kwargs):
...
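A quick illustration of the two helper functions above (a sketch; the node name is the one used by get_exits_def()):

    parent, node = split_module_name('layer1.2.relu2')
    # parent == 'layer1.2', node == 'relu2'
    #
    # attach_exits() then uses find_module(self, 'layer1.2') to fetch the
    # block that owns 'relu2' and replaces that attribute with a BranchPoint
    # wrapper. forward() returns [exit_0_output, ..., final_output] in
    # exit-definition order.

The next hunk exports the relocated BranchPoint from distiller.modules so that model code can import it.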
@@ -19,8 +19,9 @@ from .grouping import *
 from .matmul import *
 from .rnn import *
 from .aggregate import *
+from .topology import *

 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm', 'Mean']
+           'Norm', 'Mean', 'BranchPoint']
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Modules related to a model's topology"""

import torch.nn as nn


class BranchPoint(nn.Module):
    """Add a branch to an existing model."""
    def __init__(self, branched_module, branch_net):
        """
        :param branched_module: the module in the original network to which we add a branch.
        :param branch_net: the new branch
        """
        super().__init__()
        self.branched_module = branched_module
        self.branch_net = branch_net
        self.output = None

    def forward(self, x):
        x1 = self.branched_module.forward(x)
        self.output = self.branch_net.forward(x1)
        return x1


# This class is "borrowed" from PyTorch 1.3 until we integrate it
class Flatten(nn.Module):
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)


# A temporary trick: use torch.nn's Flatten if it exists; otherwise add
# our Flatten to the `torch.nn` namespace for convenience.
try:
    Flatten = nn.Flatten
except AttributeError:
    nn.Flatten = Flatten
\ No newline at end of file
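For reference, a minimal usage sketch of BranchPoint (the convolution and branch shapes here are illustrative assumptions, not part of this commit):

    import torch
    import torch.nn as nn
    from distiller.modules import BranchPoint

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    branch = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
    bp = BranchPoint(conv, branch)

    x = torch.randn(1, 3, 32, 32)
    y = bp(x)            # y == conv(x); the main data path is unchanged
    logits = bp.output   # the branch output is cached as a side effect

Because forward() returns only the original module's output, wrapping a module in a BranchPoint leaves the host network's data flow intact; exit outputs are read back later from the .output attribute.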