Commit a7473c95 authored by Neta Zmora

Early-exit refactoring: flexible exits installation

Exits can now be attached to any point in the network
by specifying the name of the attachment node and
the exit-branch subgraph.
parent 660a0da5
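As a hedged illustration of the new mechanism (not part of the commit itself), an exit definition simply pairs the name of an attachment node with an exit-branch sub-graph. The sketch below mirrors get_exits_def from the diff and assumes NUM_CLASSES = 10 for CIFAR-10.

import torch.nn as nn

NUM_CLASSES = 10  # assumption: CIFAR-10

def example_exits_def():
    # One exit, attached to the output of the module named 'layer1.2.relu2';
    # the branch itself is an arbitrary nn.Module sub-graph.
    return [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
                                              nn.Flatten(),
                                              nn.Linear(1600, NUM_CLASSES)))]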
@@ -14,12 +14,11 @@
# limitations under the License.
#
"""Resnet for CIFAR10
"""Resnet for CIFAR10 with Early Exit branches
Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate
changes for the 10-class Cifar-10 dataset.
This ResNet also has layer gates, to be able to dynamically remove layers.
@inproceedings{DBLP:conf/cvpr/HeZRS16,
author = {Kaiming He and
@@ -34,12 +33,10 @@ This ResNet also has layer gates, to be able to dynamically remove layers.
}
"""
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torchvision.models as models
from .resnet_cifar import BasicBlock
from .resnet_cifar import ResNetCifar
import torch.nn as nn
from distiller.modules import BranchPoint
__all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
@@ -53,48 +50,58 @@ def conv3x3(in_planes, out_planes, stride=1):
                     padding=1, bias=False)
class ExitBranch(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.avg_pool = nn.AvgPool2d(3)
        self.linear = nn.Linear(1600, num_classes)

def get_exits_def():
    exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
                                                  nn.Flatten(),
                                                  nn.Linear(1600, NUM_CLASSES)))]
    return exits_def

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
def find_module(model, mod_name):
    """Locate a module, given its full name"""
    for name, module in model.named_modules():
        if name == mod_name:
            return module
    return None

class BranchPoint(nn.Module):
    def __init__(self, original_m, exit_m):
        super().__init__()
        self.original_m = original_m
        self.exit_m = exit_m
        self.output = None

    def forward(self, x):
        x1 = self.original_m.forward(x)
        x2 = self.exit_m.forward(x1)
        self.output = x2
        return x1

def split_module_name(mod_name):
    name_parts = mod_name.split('.')
    parent = '.'.join(name_parts[:-1])
    node = name_parts[-1]
    return parent, node
class ResNetCifarEarlyExit(ResNetCifar):
    def __init__(self, block, layers, num_classes=NUM_CLASSES):
        super().__init__(block, layers, num_classes)
        # Define early exit branches and install them
        self.exit_branch = ExitBranch(num_classes)
        self.layer1 = BranchPoint(self.layer1, self.exit_branch)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exit_points = []
        self.attach_exits(get_exits_def())

    def attach_exits(self, exits_def):
        # For each exit point, we:
        # 1. Cache the name of the exit_point module (i.e. the name of the module
        #    whose output we forward to the exit branch).
        # 2. Override the exit_point module with an instance of BranchPoint
        for exit_point, exit_branch in exits_def:
            self.exit_points.append(exit_point)
            replaced_module = find_module(self, exit_point)
            parent_name, node_name = split_module_name(exit_point)
            parent_module = find_module(self, parent_name)
            parent_module.__setattr__(node_name, BranchPoint(replaced_module, exit_branch))

    def forward(self, x):
        # Run the input through the network
        x = super().forward(x)
        # return a list of probabilities
        exit0 = self.layer1.output
        output = (exit0, x)
        return output
        # Collect the outputs of all the exits and return them
        outputs = []
        for exit_point in self.exit_points:
            parent_name, node_name = split_module_name(exit_point)
            parent_module = find_module(self, parent_name)
            output = parent_module.__getattr__(node_name).output
            outputs.append(output)
        outputs += [x]
        return outputs
def resnet20_cifar_earlyexit(**kwargs):
@@ -19,8 +19,9 @@ from .grouping import *
from .matmul import *
from .rnn import *
from .aggregate import *
from .topology import *
__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
           'Concat', 'Chunk', 'Split', 'Stack',
           'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
           'Norm', 'Mean']
           'Norm', 'Mean', 'BranchPoint']
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Modules related to a model's topology"""
import torch.nn as nn
class BranchPoint(nn.Module):
    """Add a branch to an existing model."""
    def __init__(self, branched_module, branch_net):
        """
        :param branched_module: the module in the original network to which we add a branch.
        :param branch_net: the new branch
        """
        super().__init__()
        self.branched_module = branched_module
        self.branch_net = branch_net
        self.output = None

    def forward(self, x):
        x1 = self.branched_module.forward(x)
        self.output = self.branch_net.forward(x1)
        return x1

# This class is "borrowed" from PyTorch 1.3 until we integrate it
class Flatten(nn.Module):
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)

# A temporary trick to see if we need to add Flatten to the `torch.nn` namespace for convenience.
try:
    Flatten = nn.Flatten
except AttributeError:
    nn.Flatten = Flatten
\ No newline at end of file
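For context, here is a hedged usage sketch of the new BranchPoint module (the trunk and branch modules below are hypothetical, and a PyTorch version providing nn.Flatten is assumed): wrapping a module routes its output both onward through the network and into the side branch, whose result is cached on .output.

import torch
import torch.nn as nn
from distiller.modules import BranchPoint  # exported by this commit

# Hypothetical trunk module and exit branch, for illustration only.
trunk = nn.Conv2d(3, 16, kernel_size=3, padding=1)
branch = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

bp = BranchPoint(trunk, branch)
x = torch.randn(1, 3, 32, 32)
y = bp(x)                # the trunk's output, forwarded to the rest of the network
exit_logits = bp.output  # the branch's output, cached for later collection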