Commit a1cf9595 authored by Neta Zmora

Model thinning: bug fix – aggressive channel/filter pruning raises an exception

* Fix bug: taking the len() of a zero-dimensional ‘indices’ tensor is not legal.
    Use nelement() instead.
    A zero-dim ‘indices’ tensor occurs when the pruning is very aggressive and
    leaves one channel or filter in the tensor.
* Protect against pruning of all channels/filters of a layer: raise ValueError if
  trying to create (through thinning) a Convolution layer with zero channels or filters.
* Tests:
  * Some PEP8 cleanup.
  * Added some test documentation.
  * Refactored some test code into tests/common.py.
  * Added a test for pruning all of the channels/filters of a Convolution layer.
parent 94d3d518
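For context, a minimal standalone sketch (not part of the commit) of the zero-dimensional 'indices' problem this fix addresses; the tensor values are made up for illustration:

import torch

# Hypothetical per-channel sums where aggressive pruning left a single non-zero channel.
channel_sums = torch.tensor([0., 0., 3.7, 0.])

# torch.nonzero() returns a [1, 1] tensor here; squeeze() turns it into a 0-dim tensor.
indices = torch.nonzero(channel_sums).squeeze()

print(indices.dim())        # 0
print(indices.nelement())   # 1  -- valid for any shape, which is why the fix switches to nelement()
# len(indices)              # raises "len() of a 0-d tensor" -- the exception this commit fixes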
@@ -29,15 +29,13 @@ documented in a different place.
import math
import logging
import copy
import re
from collections import namedtuple
import torch
from .policy import ScheduledTrainingPolicy
import distiller
from distiller import normalize_module_name, denormalize_module_name
from apputils import SummaryGraph
from models import ALL_MODEL_NAMES, create_model
from models import create_model
msglogger = logging.getLogger()
ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
@@ -67,7 +65,7 @@ These tuples can have 2 values, or 4 values.
__all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
'ChannelRemover', 'remove_channels',
'FilterRemover', 'remove_filters',
'find_nonzero_channels',
'find_nonzero_channels', 'find_nonzero_channels_list',
'execute_thinning_recipes_list']
@@ -177,15 +175,21 @@ def find_nonzero_channels(param, param_name):
k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1))
if num_channels > len(nonzero_channels):
if num_channels > nonzero_channels.nelement():
msglogger.info("In tensor %s found %d/%d zero channels", param_name,
num_filters - len(nonzero_channels), num_filters)
num_channels - nonzero_channels.nelement(), num_channels)
return nonzero_channels
def find_nonzero_channels_list(param, param_name):
nnz_channels = find_nonzero_channels(param, param_name)
nnz_channels = nnz_channels.view(nnz_channels.numel())
return nnz_channels.cpu().numpy().tolist()
def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
if len(thinning_recipe.modules)>0 or len(thinning_recipe.parameters)>0:
if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
# Now actually remove the filters/channels and make the weight tensors smaller
execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
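As a quick illustration of the new find_nonzero_channels_list() helper, a hedged usage sketch (the weight shape and parameter name below are assumptions; the tests in this commit call it the same way against a real model):

import torch
import distiller

# Assumed Conv2d weight layout: (num_filters, num_channels, k, k); zero out channel 0 everywhere.
conv_weight = torch.randn(16, 16, 3, 3)
conv_weight[:, 0, :, :] = 0

# The helper returns plain Python ints, so set arithmetic in the tests becomes straightforward.
nnz_channels = distiller.find_nonzero_channels_list(conv_weight, "some.conv.weight")
removed = set(range(16)) - set(nnz_channels)   # -> {0}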
@@ -195,7 +199,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
model.thinning_recipes.append(thinning_recipe)
else:
model.thinning_recipes = [thinning_recipe]
msglogger.info("Created, applied and saved a filter-thinning recipe")
msglogger.info("Created, applied and saved a thinning recipe")
else:
msglogger.error("Failed to create a thinning recipe")
@@ -220,7 +224,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
msglogger.info("Invoking create_thinning_recipe_channels")
thinning_recipe = ThinningRecipe(modules={}, parameters={})
layers = {mod_name : m for mod_name, m in model.named_modules()}
layers = {mod_name: m for mod_name, m in model.named_modules()}
# Traverse all of the model's parameters, search for zero-channels, and
# create a thinning recipe that describes the required changes to the model.
@@ -231,16 +235,18 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
num_channels = param.size(1)
nonzero_channels = find_nonzero_channels(param, param_name)
num_nnz_channels = nonzero_channels.nelement()
if num_nnz_channels == 0:
raise ValueError("Trying to set zero channels for parameter %s is not allowed" % param_name)
# If all channels in this tensor are non-zero, continue to the next tensor
if num_channels <= len(nonzero_channels):
if num_channels <= num_nnz_channels:
continue
# We are removing channels, so update the number of incoming channels (IFMs)
# in the convolutional layer
layer_name = param_name_2_layer_name(param_name)
assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
append_module_directive(thinning_recipe, layer_name, key='in_channels', val=len(nonzero_channels))
append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
# Select only the non-zero filters
indices = nonzero_channels.data.squeeze()
@@ -252,7 +258,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
predecessors = [denormalize_module_name(model, predecessor) for predecessor in predecessors]
for predecessor in predecessors:
# For each of the convolutional layers that precede, we have to reduce the number of output channels.
append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels))
append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
# Now remove channels from the weights tensor of the successor conv
append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
@@ -263,7 +269,8 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
assert len(bn_layers) == 1
# Thinning of the BN layer that follows the convolution
bn_layer_name = denormalize_module_name(model, bn_layers[0])
bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_channels), thin_features=indices)
bn_thinning(thinning_recipe, layers, bn_layer_name,
len_thin_features=num_nnz_channels, thin_features=indices)
return thinning_recipe
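For orientation, a rough sketch of the kind of recipe these loops build for the channel-pruning case. The exact dictionary layout produced by append_module_directive()/append_param_directive() is not shown in this diff, so treat the field contents as an assumption:

from collections import namedtuple
import torch

ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])

# Assumed recipe for removing channels 0 and 2 of layer1.0.conv2 (16 -> 14 channels),
# which also shrinks the out_channels of its predecessor, layer1.0.conv1.
indices = torch.tensor([1] + list(range(3, 16)))
recipe = ThinningRecipe(
    modules={'layer1.0.conv2': {'in_channels': 14},
             'layer1.0.conv1': {'out_channels': 14}},
    parameters={'layer1.0.conv2.weight': [(1, indices)],
                'layer1.0.conv1.weight': [(0, indices)]})
# (BatchNorm directives produced by bn_thinning() are omitted from this sketch.)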
@@ -281,7 +288,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
msglogger.info("Invoking create_thinning_recipe_filters")
thinning_recipe = ThinningRecipe(modules={}, parameters={})
layers = {mod_name : m for mod_name, m in model.named_modules()}
layers = {mod_name: m for mod_name, m in model.named_modules()}
for param_name, param in model.named_parameters():
# We are only interested in 4D weights
@@ -292,20 +299,22 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
filter_view = param.view(param.size(0), -1)
num_filters = filter_view.size()[0]
nonzero_filters = torch.nonzero(filter_view.abs().sum(dim=1))
num_nnz_filters = nonzero_filters.nelement()
if num_nnz_filters == 0:
raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name)
# If all filters in this tensor are non-zero, continue to the next tensor
if num_filters <= len(nonzero_filters):
if num_filters <= num_nnz_filters:
msglogger.debug("SKipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape))
continue
msglogger.info("In tensor %s found %d/%d zero filters", param_name,
num_filters - len(nonzero_filters), num_filters)
num_filters - num_nnz_filters, num_filters)
# We are removing filters, so update the number of outgoing channels (OFMs)
# in the convolutional layer
layer_name = param_name_2_layer_name(param_name)
assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
append_module_directive(thinning_recipe, layer_name, key='out_channels', val=len(nonzero_filters))
append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
# Select only the non-zero filters
indices = nonzero_filters.data.squeeze()
@@ -323,8 +332,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
if isinstance(layers[successor], torch.nn.modules.Conv2d):
# For each of the convolutional layers that follow, we have to reduce the number of input channels.
append_module_directive(thinning_recipe, successor, key='in_channels', val=len(nonzero_filters))
msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, len(nonzero_filters)))
append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters))
# Now remove channels from the weights tensor of the successor conv
append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
@@ -332,8 +341,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
elif isinstance(layers[successor], torch.nn.modules.Linear):
# If a Linear (Fully-Connected) layer follows, we need to update its in_features member
fm_size = layers[successor].in_features // layers[layer_name].out_channels
in_features = fm_size * len(nonzero_filters)
#append_module_directive(thinning_recipe, layer_name, key='in_features', val=in_features)
in_features = fm_size * num_nnz_filters
append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
msglogger.info("[recipe] {}: setting in_features = {}".format(successor, in_features))
@@ -350,7 +358,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
assert len(bn_layers) == 1
# Thinning of the BN layer that follows the convolution
bn_layer_name = denormalize_module_name(model, bn_layers[0])
bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_filters), thin_features=indices)
bn_thinning(thinning_recipe, layers, bn_layer_name,
len_thin_features=num_nnz_filters, thin_features=indices)
return thinning_recipe
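To make the effect of the new guard concrete, a sketch built from the same pieces the tests below use: prune 100% of a layer's filters, then attempt to thin the model:

import pytest
import distiller
import common

model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10")

# Rank-prune 100% of the filters of layer1.0.conv1, then apply the mask to the weights.
reg_regims = {"layer1.0.conv1.weight": [1.0, "3D"]}
pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None)
zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p)

# Thinning a Convolution down to zero filters is now rejected with an explicit ValueError.
with pytest.raises(ValueError):
    distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10")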
@@ -423,8 +432,9 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
dim_to_trim = val[0]
indices_to_select = val[1]
# Check if we're trying to trim a parameter that is already "thin"
if running.size(dim_to_trim) != len(indices_to_select):
msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(indices_to_select)))
if running.size(dim_to_trim) != indices_to_select.nelement():
msglogger.info("[thinning] {}: setting {} to {}".
format(layer_name, attr, indices_to_select.nelement()))
setattr(layers[layer_name], attr,
torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
else:
@@ -438,34 +448,35 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
for directive in param_directives:
dim = directive[0]
indices = directive[1]
len_indices = indices.nelement()
if len(directive) == 4: # TODO: this code is hard to follow
selection_view = param.view(*directive[2])
# Check if we're trying to trim a parameter that is already "thin"
if param.data.size(dim) != len(indices):
if param.data.size(dim) != len_indices:
param.data = torch.index_select(selection_view, dim, indices)
if param.grad is not None:
# We also need to change the dimensions of the gradient tensor.
grad_selection_view = param.grad.resize_(*directive[2])
if grad_selection_view.size(dim) != len(indices):
if grad_selection_view.size(dim) != len_indices:
param.grad = torch.index_select(grad_selection_view, dim, indices)
param.data = param.view(*directive[3])
if param.grad is not None:
param.grad = param.grad.resize_(*directive[3])
else:
if param.data.size(dim) != len(indices):
if param.data.size(dim) != len_indices:
param.data = torch.index_select(param.data, dim, indices)
# We also need to change the dimensions of the gradient tensor.
# If have not done a backward-pass thus far, then the gradient will
# not exist, and therefore won't need to be re-dimensioned.
if param.grad is not None and param.grad.size(dim) != len(indices):
if param.grad is not None and param.grad.size(dim) != len_indices:
param.grad = torch.index_select(param.grad, dim, indices)
msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len(indices)))
msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len_indices))
if not loaded_from_file:
# If the masks are loaded from a checkpoint file, then we don't need to change
# their shape, because they are already correctly shaped
mask = zeros_mask_dict[param_name].mask
if mask is not None and (mask.size(dim) != len(indices)):
if mask is not None and (mask.size(dim) != len_indices):
zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
tests/common.py (new file):

#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
import distiller
from models import create_model
def setup_test(arch, dataset):
model = create_model(False, dataset, arch, parallel=False)
assert model is not None
# Create the masks
zeros_mask_dict = {}
for name, param in model.named_parameters():
masker = distiller.ParameterMasker(name)
zeros_mask_dict[name] = masker
return model, zeros_mask_dict
def find_module_by_name(model, module_to_find):
for name, m in model.named_modules():
if name == module_to_find:
return m
return None
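A short sketch of how the refactored helpers in tests/common.py are meant to be used from a test; it mirrors the calls that appear in the test diff below:

import common
import distiller

# Build a non-parallel resnet20_cifar model plus one ParameterMasker per named parameter.
model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10")

conv1 = common.find_module_by_name(model, "layer1.0.conv1")
assert conv1 is not None
assert distiller.sparsity_3D(distiller.model_find_param(model, "layer1.0.conv1.weight")) == 0.0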
@@ -18,14 +16,16 @@ import logging
import torch
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
try:
import distiller
except ImportError:
module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)
import distiller
import distiller
import common
import pytest
from models import ALL_MODEL_NAMES, create_model
from apputils import SummaryGraph, onnx_name_2_pytorch_name, save_checkpoint, load_checkpoint
from models import create_model
from apputils import save_checkpoint, load_checkpoint
# Logging configuration
logging.basicConfig(level=logging.INFO)
@@ -33,44 +35,47 @@ fh = logging.FileHandler('test.log')
logger = logging.getLogger()
logger.addHandler(fh)
def find_module_by_name(model, module_to_find):
for name, m in model.named_modules():
if name == module_to_find:
return m
return None
def test_ranked_filter_pruning():
ranked_filter_pruning(ratio_to_prune=0.1)
ranked_filter_pruning(ratio_to_prune=0.5)
def setup_test(arch, dataset):
model = create_model(False, dataset, arch, parallel=False)
assert model is not None
def test_prune_all_filters():
"""Pruning all of the filteres in a weights tensor of a Convolution
is illegal and should raise an exception.
"""
with pytest.raises(ValueError):
ranked_filter_pruning(ratio_to_prune=1.0)
# Create the masks
zeros_mask_dict = {}
for name, param in model.named_parameters():
masker = distiller.ParameterMasker(name)
zeros_mask_dict[name] = masker
return model, zeros_mask_dict
def test_ranked_filter_pruning():
model, zeros_mask_dict = setup_test("resnet20_cifar", "cifar10")
def ranked_filter_pruning(ratio_to_prune):
"""Test L1 ranking and pruning of filters.
First we rank and prune the filters of a Convolutional layer using
an L1RankedStructureParameterPruner. Then we physically remove the
filters from the model (via a "thinning" process).
"""
model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10")
# Test that we can access the weights tensor of the first convolution in layer 1
conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
assert conv1_p is not None
# Test that there are no zero-channels
# Test that there are no zero-filters
assert distiller.sparsity_3D(conv1_p) == 0.0
# Create a filter-ranking pruner
reg_regims = {"layer1.0.conv1.weight" : [0.1, "3D"]}
reg_regims = {"layer1.0.conv1.weight": [ratio_to_prune, "3D"]}
pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None)
conv1 = find_module_by_name(model, "layer1.0.conv1")
conv1 = common.find_module_by_name(model, "layer1.0.conv1")
assert conv1 is not None
# Test that the mask has the correct fraction of filters pruned.
# We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters
expected_pruning = int(0.1 * conv1.out_channels) / conv1.out_channels
expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
expected_pruning = expected_cnt_removed_filters / conv1.out_channels
assert distiller.sparsity_3D(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning
# Use the mask to prune
@@ -79,24 +84,42 @@ def test_ranked_filter_pruning():
assert distiller.sparsity_3D(conv1_p) == expected_pruning
# Remove filters
conv2 = find_module_by_name(model, "layer1.0.conv2")
conv2 = common.find_module_by_name(model, "layer1.0.conv2")
assert conv2 is not None
assert conv1.out_channels == 16
assert conv2.in_channels == 16
# Test thinning
distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10")
assert conv1.out_channels == 15
assert conv2.in_channels == 15
assert conv1.out_channels == 16 - expected_cnt_removed_filters
assert conv2.in_channels == 16 - expected_cnt_removed_filters
def test_arbitrary_channel_pruning():
arbitrary_channel_pruning(channels_to_remove=[0, 2])
def test_prune_all_channels():
"""Pruning all of the channels in a weights tensor of a Convolution
is illegal and should raise an exception.
"""
with pytest.raises(ValueError):
arbitrary_channel_pruning(channels_to_remove=[ch for ch in range(16)])
def arbitrary_channel_pruning(channels_to_remove):
"""Test removal of arbitrary channels.
The test receives a specification of channels to remove.
Based on this specification, the channels are pruned and then physically
removed from the model (via a "thinning" process).
"""
ARCH = "resnet20_cifar"
DATASET = "cifar10"
model, zeros_mask_dict = setup_test(ARCH, DATASET)
model, zeros_mask_dict = common.setup_test(ARCH, DATASET)
conv2 = find_module_by_name(model, "layer1.0.conv2")
conv2 = common.find_module_by_name(model, "layer1.0.conv2")
assert conv2 is not None
# Test that we can access the weights tensor of the first convolution in layer 1
@@ -108,8 +131,7 @@ def test_arbitrary_channel_pruning():
num_channels = conv2_p.size(1)
kernel_height = conv2_p.size(2)
kernel_width = conv2_p.size(3)
channels_to_remove = [0, 2]
cnt_nnz_channels = num_channels - len(channels_to_remove)
# Let's build our 4D mask.
# We start with a 1D mask of channels, with all but our specified channels set to one
@@ -131,52 +153,59 @@ def test_arbitrary_channel_pruning():
zeros_mask_dict["layer1.0.conv2.weight"].mask = mask
zeros_mask_dict["layer1.0.conv2.weight"].apply_mask(conv2_p)
all_channels = set([ch for ch in range(num_channels)])
channels_removed = all_channels - set(distiller.find_nonzero_channels(conv2_p, "layer1.0.conv2.weight"))
logger.info(channels_removed)
nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "layer1.0.conv2.weight"))
channels_removed = all_channels - nnz_channels
logger.info("Channels removed {}".format(channels_removed))
# Now, let's do the actual network thinning
distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
conv1 = find_module_by_name(model, "layer1.0.conv1")
conv1 = common.find_module_by_name(model, "layer1.0.conv1")
logger.info(conv1)
logger.info(conv2)
assert conv1.out_channels == 14
assert conv2.in_channels == 14
assert conv1.weight.size(0) == 14
assert conv2.weight.size(1) == 14
bn1 = find_module_by_name(model, "layer1.0.bn1")
assert bn1.running_var.size(0) == 14
assert bn1.running_mean.size(0) == 14
assert bn1.num_features == 14
assert bn1.bias.size(0) == 14
assert bn1.weight.size(0) == 14
assert conv1.out_channels == cnt_nnz_channels
assert conv2.in_channels == cnt_nnz_channels
assert conv1.weight.size(0) == cnt_nnz_channels
assert conv2.weight.size(1) == cnt_nnz_channels
bn1 = common.find_module_by_name(model, "layer1.0.bn1")
assert bn1.running_var.size(0) == cnt_nnz_channels
assert bn1.running_mean.size(0) == cnt_nnz_channels
assert bn1.num_features == cnt_nnz_channels
assert bn1.bias.size(0) == cnt_nnz_channels
assert bn1.weight.size(0) == cnt_nnz_channels
# Let's test saving and loading a thinned model.
# We save 3 times, and load twice, to make sure to cover some corner cases:
# - Make sure that after loading, the model still has hold of the thinning recipes
# - Make sure that after a 2nd load, there is no problem loading (in this case, the
# - tensors are already thin, so this is a new flow)
# (1)
save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
model_2 = create_model(False, DATASET, ARCH, parallel=False)
dummy_input = torch.randn(1, 3, 32, 32)
model(dummy_input)
model_2(dummy_input)
conv2 = find_module_by_name(model_2, "layer1.0.conv2")
conv2 = common.find_module_by_name(model_2, "layer1.0.conv2")
assert conv2 is not None
with pytest.raises(KeyError):
model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
compression_scheduler = distiller.CompressionScheduler(model)
hasattr(model, 'thinning_recipes')
# (2)
save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
assert hasattr(model_2, 'thinning_recipes')
logger.info("test_arbitrary_channel_pruning - Done")
# (3)
save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
assert hasattr(model_2, 'thinning_recipes')
logger.info("test_arbitrary_channel_pruning - Done 2")
if __name__ == '__main__':
test_arbitrary_channel_pruning()
test_prune_all_channels()