Commit 47732093 authored by Neta Zmora

Filter removal: support filter removal for VGG

This is a temporary implementation that enables filter removal and
network thinning for VGG.
The implementation continues the present design for network thinning,
which is problematic because parts of the solution are specific to
each model.

Leveraging new features in PyTorch 0.4, we are now able to provide a
more generic solution to thinning, which we will push to 'master' soon.
In the meantime, this commit bridges the feature gap for VGG
filter removal.
parent 6f7c5ae4
@@ -21,13 +21,16 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
+from apputils import SummaryGraph
+from models import ALL_MODEL_NAMES, create_model

 msglogger = logging.getLogger()

 ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])

 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', 'resnet_cifar_remove_filters',
            'resnet_cifar_remove_channels', 'ResnetCifarChannelRemover',
-           'ResnetCifarFilterRemover', 'execute_thinning_recipe']
+           'ResnetCifarFilterRemover', 'execute_thinning_recipe',
+           'FilterRemover', 'vgg_remove_filters']
 # This dictionary maps the connectivity of ResNet-Cifar convolutional layers.
 # BN layer connectivity is implied; shortcuts are not currently handled.
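For context: a ThinningRecipe (defined above) pairs two dictionaries. `modules` maps a layer name to module-attribute changes (e.g. a smaller out_channels), and after this commit `parameters` maps a parameter name to a list of (dim, indices) directives, so one tensor can be thinned along more than one dimension. A minimal sketch, with a hypothetical layer name:

    # Hypothetical recipe: shrink module.conv1 to 4 filters and keep only
    # the listed filter indices (dim 0) of its weight tensor.
    recipe = ThinningRecipe(modules={}, parameters={})
    recipe.modules['module.conv1'] = {'out_channels': 4}
    recipe.parameters['module.conv1.weight'] = [(0, torch.tensor([0, 2, 5, 7]))]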
@@ -94,6 +97,24 @@ conv_connectivity = {
     'module.layer3.8.conv1': 'module.layer3.8.conv2',
     'module.layer3.8.conv2': 'module.fc'}
+vgg_conv_connectivity = {
+    'features.module.0': 'features.module.2',
+    'features.module.2': 'features.module.5',
+    'features.module.5': 'features.module.7',
+    'features.module.7': 'features.module.10',
+    'features.module.10': 'features.module.12',
+    'features.module.12': 'features.module.14',
+    'features.module.14': 'features.module.16',
+    'features.module.16': 'features.module.19',
+    'features.module.19': 'features.module.21',
+    'features.module.21': 'features.module.23',
+    'features.module.23': 'features.module.25',
+    'features.module.25': 'features.module.28',
+    'features.module.28': 'features.module.30',
+    'features.module.30': 'features.module.32',
+    'features.module.32': 'features.module.34',
+    'features.module.34': 'classifier.0'}
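vgg_conv_connectivity plays the same role for VGG that conv_connectivity plays for ResNet-Cifar: when a convolution loses filters, its follower (the next parameterized layer, skipping the ReLU and pooling modules in between) must lose the matching input channels. A lookup sketch, assuming the module names above:

    # The first VGG conv feeds the second; the last conv feeds the classifier.
    vgg_conv_connectivity['features.module.0']    # -> 'features.module.2'
    vgg_conv_connectivity['features.module.34']   # -> 'classifier.0'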
 def find_predecessors(layer_name):
     predecessors = []
     for layer, followers in conv_connectivity.items():
@@ -116,8 +137,9 @@ def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_features):
     bn_directive['running_var'] = (0, thin_features)
     thinning_recipe.modules[bn_name] = bn_directive

-    thinning_recipe.parameters[bn_name+'.weight'] = (0, thin_features)
-    thinning_recipe.parameters[bn_name+'.bias'] = (0, thin_features)
+    thinning_recipe.parameters[bn_name+'.weight'] = [(0, thin_features)]
+    thinning_recipe.parameters[bn_name+'.bias'] = [(0, thin_features)]

 def resnet_cifar_remove_layers(resnet_cifar_model):
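The change above wraps each BatchNorm directive in a list, matching the list-of-directives convention now used for weight tensors. When the recipe is executed, thinning a BN layer means selecting the surviving features from all four of its per-channel tensors. A sketch of the net effect, assuming bn is a torch.nn.BatchNorm2d and indices holds the surviving filter indices:

    # Keep only the surviving channels in every per-channel BN tensor.
    bn.running_mean = torch.index_select(bn.running_mean, 0, indices)
    bn.running_var = torch.index_select(bn.running_var, 0, indices)
    bn.weight.data = torch.index_select(bn.weight.data, 0, indices)
    bn.bias.data = torch.index_select(bn.bias.data, 0, indices)
    bn.num_features = len(indices)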
@@ -192,7 +214,9 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_dict):
             # Select only the non-zero filters
             indices = nonzero_channels.data.squeeze()
-            thinning_recipe.parameters[name] = (1, indices)
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((1, indices))
+            thinning_recipe.parameters[name] = param_directive

             assert layer_name in conv_connectivity
             predecessors = find_predecessors(layer_name)
@@ -204,8 +228,9 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_dict):
                 thinning_recipe.modules[predecessor] = conv_directive

                 # Now remove channels from the weights tensor of the follower conv
-                #predecessor_param = distiller.model_find_param(resnet_cifar_model, predecessor+'.weight')
-                thinning_recipe.parameters[predecessor+'.weight'] = (0, indices)
+                param_directive = thinning_recipe.parameters.get(predecessor+'.weight', [])
+                param_directive.append((0, indices))
+                thinning_recipe.parameters[predecessor+'.weight'] = param_directive

                 # Now handle the BatchNormalization layer that follows the convolution
                 bn_name = predecessor.replace('conv', 'bn')
@@ -217,22 +242,48 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_dict):
 def resnet_cifar_remove_filters(resnet_cifar_model, zeros_mask_dict):
     thinning_recipe = resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict)
+    return remove_filters(resnet_cifar_model, zeros_mask_dict, thinning_recipe)
+
+def vgg_remove_filters(vgg_model, zeros_mask_dict):
+    thinning_recipe = vgg_create_thinning_recipe_filters(vgg_model, zeros_mask_dict)
+    return remove_filters(vgg_model, zeros_mask_dict, thinning_recipe)
+
+def remove_filters(model, zeros_mask_dict, thinning_recipe):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
         # Stash the recipe, so that it will be serialized together with the model
-        resnet_cifar_model.thinning_recipe = thinning_recipe
+        model.thinning_recipe = thinning_recipe
         # Now actually remove the filters and channels, and make the weight tensors smaller
-        execute_thinning_recipe(resnet_cifar_model, zeros_mask_dict, thinning_recipe)
+        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
         msglogger.info("Created, applied and saved a filter-thinning recipe")
-    return resnet_cifar_model
+    return model
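With the body hoisted into the model-agnostic remove_filters, both entry points reduce to recipe creation followed by execution. A usage sketch, assuming model and zeros_mask_dict come from the training loop:

    model = resnet_cifar_remove_filters(model, zeros_mask_dict)  # ResNet-Cifar
    model = vgg_remove_filters(model, zeros_mask_dict)           # VGG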
+def get_input(dataset):
+    if dataset == 'imagenet':
+        return torch.randn((1, 3, 224, 224), requires_grad=False)
+    elif dataset == 'cifar10':
+        return torch.randn((1, 3, 32, 32))
+    return None
+
+def create_graph(dataset, arch):
+    dummy_input = get_input(dataset)
+    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
+    model = create_model(False, dataset, arch, parallel=False)
+    assert model is not None
+    return SummaryGraph(model, dummy_input)
+
+def normalize_layer_name(layer_name):
+    if layer_name.startswith('module.'):
+        layer_name = layer_name[len('module.'):]
+    return layer_name
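These helpers are groundwork for the more generic, PyTorch-0.4-based thinning mentioned in the commit message: create_graph traces the model with a dataset-appropriate dummy input and returns a SummaryGraph, from which layer connectivity could be derived instead of hard-coded. A usage sketch, assuming 'resnet20_cifar' is a registered architecture name:

    # Trace a CIFAR-10 model with a 1x3x32x32 dummy input.
    sgraph = create_graph('cifar10', 'resnet20_cifar')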
 def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict):
     """Remove filters from ResNet-Cifar
     Caveats: (1) supports only ResNet50-Cifar; (2) only module.layerX.Y.conv1.weight
     """
-    msglogger.info("Invoking resnet_cifar_remove_filters")
+    msglogger.info("Invoking resnet_cifar_create_thinning_recipe_filters")
     layers = {thin_name: m for thin_name, m in resnet_cifar_model.named_modules()}
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
@@ -259,9 +310,12 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict):
             # Select only the non-zero filters
             indices = nonzero_filters.data.squeeze()
-            thinning_recipe.parameters[name] = (0, indices)
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[name] = param_directive

-            assert layer_name in conv_connectivity
+            assert layer_name in conv_connectivity, "layer {} is not in conv_connectivity {}".format(layer_name, conv_connectivity)
             followers = conv_connectivity[layer_name] if isinstance(conv_connectivity[layer_name], list) else [conv_connectivity[layer_name]]
             for follower in followers:
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
@@ -270,8 +324,12 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict):
                 thinning_recipe.modules[follower] = conv_directive

                 # Now remove channels from the weights tensor of the follower conv
-                #follower_param = distiller.model_find_param(resnet_cifar_model, follower+'.weight')
-                thinning_recipe.parameters[follower+'.weight'] = (1, indices)
+                param_directive = thinning_recipe.parameters.get(follower+'.weight', [])
+                msglogger.info("appending {}".format(follower+'.weight'))
+                param_directive.append((1, indices))
+                thinning_recipe.parameters[follower+'.weight'] = param_directive

             # Now handle the BatchNormalization layer that follows the convolution
             bn_name = layer_name.replace('conv', 'bn')
@@ -280,6 +338,80 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict):
                 bn_thinning(thinning_recipe, layers, bn_name, len_thin_features=len(nonzero_filters), thin_features=indices)
     return thinning_recipe
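The move from a single (dim, indices) tuple to a list of directives matters here: one conv weight can legitimately receive two directives, on dim 0 when its own filters are pruned and on dim 1 when its predecessor's filters are pruned, and the old dict assignment silently overwrote the first. A sketch of applying an accumulated list, where w, own_filters, and pred_filters are hypothetical tensors:

    # index_select only shrinks the selected dimension, so the
    # directives can be applied independently, one after the other.
    for dim, indices in [(0, own_filters), (1, pred_filters)]:
        w.data = torch.index_select(w.data, dim, indices)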
+import math
+
+def vgg_create_thinning_recipe_filters(vgg_model, zeros_mask_dict):
+    """Remove filters from VGG
+    """
+    msglogger.info("Invoking vgg_create_thinning_recipe_filters")
+    layers = {thin_name: m for thin_name, m in vgg_model.named_modules()}
+    thinning_recipe = ThinningRecipe(modules={}, parameters={})
+
+    for name, p_thin in vgg_model.named_parameters():
+        if p_thin.dim() != 4:
+            continue
+        # Find the number of filters, in this weights tensor, that are not 100% sparse
+        filter_view = p_thin.view(p_thin.size(0), -1)
+        num_filters = filter_view.size()[0]
+        nonzero_filters = torch.nonzero(filter_view.abs().sum(dim=1))
+
+        # If there are zero-filters in this tensor then...
+        if num_filters > len(nonzero_filters):
+            msglogger.info("In tensor %s found %d/%d zero filters", name,
+                           num_filters - len(nonzero_filters), num_filters)
+
+            # Update the number of outgoing channels (OFMs) in the convolutional layer
+            layer_name = name[:-len('weights')]  # strip the trailing '.weight' (len('weights') == len('.weight'))
+            assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
+            conv_directive = thinning_recipe.modules.get(layer_name, {})
+            conv_directive['out_channels'] = len(nonzero_filters)
+            thinning_recipe.modules[layer_name] = conv_directive
+
+            # Select only the non-zero filters
+            indices = nonzero_filters.data.squeeze()
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[name] = param_directive
+
+            param_directive = thinning_recipe.parameters.get(layer_name+'.bias', [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[layer_name+'.bias'] = param_directive
+
+            assert layer_name in vgg_conv_connectivity, "layer {} is not in vgg_conv_connectivity {}".format(layer_name, vgg_conv_connectivity)
+            followers = vgg_conv_connectivity[layer_name] if isinstance(vgg_conv_connectivity[layer_name], list) else [vgg_conv_connectivity[layer_name]]
+            for follower in followers:
+                # For each of the layers that follow, we have to reduce the number of input channels.
+                conv_directive = thinning_recipe.modules.get(follower, {})
+                if isinstance(layers[follower], torch.nn.modules.Conv2d):
+                    conv_directive['in_channels'] = len(nonzero_filters)
+                    msglogger.info("{}: setting in_channels = {}".format(follower, len(nonzero_filters)))
+                elif isinstance(layers[follower], torch.nn.modules.Linear):
+                    # TODO: this code is hard to follow
+                    # Each output channel feeds fm_size elements of the Linear layer's input
+                    # (integer division keeps in_features an int)
+                    fm_size = layers[follower].in_features // layers[layer_name].out_channels
+                    conv_directive['in_features'] = fm_size * len(nonzero_filters)
+                    msglogger.info("{}: setting in_features = {}".format(follower, conv_directive['in_features']))
+                thinning_recipe.modules[follower] = conv_directive
+
+                # Now remove channels from the weights tensor of the follower layer
+                param_directive = thinning_recipe.parameters.get(follower+'.weight', [])
+                if isinstance(layers[follower], torch.nn.modules.Conv2d):
+                    param_directive.append((1, indices))
+                elif isinstance(layers[follower], torch.nn.modules.Linear):
+                    # TODO: this code is hard to follow
+                    fm_size = layers[follower].in_features // layers[layer_name].out_channels
+                    fm_height = fm_width = int(math.sqrt(fm_size))
+                    selection_view = (layers[follower].out_features, layers[layer_name].out_channels, fm_height, fm_width)
+                    true_view = (layers[follower].out_features, conv_directive['in_features'])
+                    param_directive.append((1, indices, selection_view, true_view))
+                thinning_recipe.parameters[follower+'.weight'] = param_directive
+    return thinning_recipe
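The Conv2d-to-Linear boundary is the VGG-specific wrinkle. For VGG-16 on ImageNet, classifier.0 has in_features = 25088 and the last conv has out_channels = 512, so fm_size = 25088 / 512 = 49, i.e. each filter contributes a 7x7 feature map, and removing one filter removes 49 columns of the Linear weight. The 4-element directive records the two views needed to select the survivors. A worked sketch with those numbers, assuming 461 of the 512 filters survive:

    in_features, out_channels, out_features = 25088, 512, 4096
    fm_size = in_features // out_channels           # 49
    fm_height = fm_width = int(math.sqrt(fm_size))  # 7
    selection_view = (out_features, out_channels, fm_height, fm_width)  # (4096, 512, 7, 7)
    true_view = (out_features, fm_size * 461)       # (4096, 22589)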
 class ResnetCifarChannelRemover(ScheduledTrainingPolicy):
     """A policy which applies a network thinning function"""
@@ -290,7 +422,7 @@ class ResnetCifarChannelRemover(ScheduledTrainingPolicy):
         self.thinning_func(model, zeros_mask_dict)

-class ResnetCifarFilterRemover(ScheduledTrainingPolicy):
+class FilterRemover(ScheduledTrainingPolicy):
     """A policy which applies a network thinning function"""
     def __init__(self, thinning_func_str):
         self.thinning_func = globals()[thinning_func_str]
@@ -305,6 +437,11 @@ class ResnetCifarFilterRemover(ScheduledTrainingPolicy):
             self.thinning_func(model, zeros_mask_dict)
             self.done = True
+# This is for backward compatibility with some of the published schedules
+class ResnetCifarFilterRemover(FilterRemover):
+    def __init__(self, thinning_func_str):
+        super().__init__(thinning_func_str)
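FilterRemover resolves its thinning function by name at construction time, which is what lets a pruning schedule select vgg_remove_filters or resnet_cifar_remove_filters as a string; the ResnetCifarFilterRemover subclass merely preserves the class name used in already-published schedules. A dispatch sketch, assuming the string comes from such a schedule:

    # globals() maps the configured string to a function in this module.
    policy = FilterRemover('vgg_remove_filters')
    # At the scheduled point, the policy invokes:
    #     policy.thinning_func(model, zeros_mask_dict)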
 def execute_thinning_recipe(model, zeros_mask_dict, recipe):
     """Apply a thinning recipe to a model.
@@ -319,18 +456,32 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe):
     for layer_name, directives in recipe.modules.items():
         for attr, val in directives.items():
             if attr in ['running_mean', 'running_var']:
+                msglogger.info("{} thinning: setting {} to {}".format(layer_name, attr, len(val[1])))
                 setattr(layers[layer_name], attr,
                         torch.index_select(getattr(layers[layer_name], attr),
                                            dim=val[0], index=val[1]))
             else:
+                msglogger.info("{} thinning: setting {} to {}".format(layer_name, attr, val))
                 setattr(layers[layer_name], attr, val)
-    for param_name, info in recipe.parameters.items():
+    assert len(recipe.parameters) > 0
+    for param_name, param_directives in recipe.parameters.items():
         param = distiller.model_find_param(model, param_name)
-        param.data = torch.index_select(param.data, dim=info[0], index=info[1])
-        mask = zeros_mask_dict[param_name].mask
-        if (mask is not None) and (zeros_mask_dict[param_name].mask.size(dim=info[0]) != len(info[1])):
-            zeros_mask_dict[param_name].mask = torch.index_select(mask, dim=info[0], index=info[1])
-    model.thinning_recipe = recipe
-    return model
+        for directive in param_directives:
+            dim = directive[0]
+            indices = directive[1]
+            if len(directive) == 4:  # TODO: this code is hard to follow
+                layer_name = param_name[:-len('weights')]
+                selection_view = param.view(*directive[2])
+                param.data = torch.index_select(selection_view, dim, indices)
+                param.data = param.view(*directive[3])
+            else:
+                param.data = torch.index_select(param.data, dim, indices)
+            msglogger.info("thinning: changing param {} shape: {}".format(param_name, len(indices)))
+            mask = zeros_mask_dict[param_name].mask
+            if (mask is not None) and (zeros_mask_dict[param_name].mask.size(dim) != len(indices)):
+                zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
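For a 4-element Linear directive, execution reshapes the flat weight into the per-filter view, selects the surviving channels on dim 1, and flattens back; 2-element directives are a plain index_select. A numeric sketch continuing the VGG-16 example above, with hypothetical surviving indices:

    # (4096, 25088) -> view (4096, 512, 7, 7) -> keep 461 channels -> (4096, 22589)
    w = torch.randn(4096, 25088)
    indices = torch.arange(461, dtype=torch.long)   # hypothetical surviving filters
    w = torch.index_select(w.view(4096, 512, 7, 7), 1, indices)
    w = w.view(4096, 49 * 461)                      # back to 2-D: (4096, 22589)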