From a1cf95951428b566acde14b382f0e9e95b119abf Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Tue, 26 Jun 2018 15:21:51 +0300 Subject: [PATCH] =?UTF-8?q?Model=20thinning:=20bug=20fix=20=E2=80=93=20agg?= =?UTF-8?q?ressive=20channel/filter=20pruning=20raises=20an=20exception?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix bug: taking the len() of a zero-dimensional ‘indices’ tensor is not legal. Use nelement() instead. A zero-dim ‘indices’ tensor occurs when the pruning is very aggressive and leaves one channel or filter in the tensor. * Protect against pruning of all channels/filters of a layer: Raise ValueError if trying to create (through thinning) a Convolution layer with zero channels or filters. * Tests: * Some PEP8 cleanup. * Add some test documentation. * Refactored some test code to tests/common.py * Added testing of pruning all the channels/filters in a Convolution --- distiller/thinning.py | 75 ++++++++++++++----------- tests/common.py | 42 ++++++++++++++ tests/test_pruning.py | 125 ++++++++++++++++++++++++++---------------- 3 files changed, 162 insertions(+), 80 deletions(-) create mode 100755 tests/common.py diff --git a/distiller/thinning.py b/distiller/thinning.py index dbd2c56..e019e7d 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -29,15 +29,13 @@ documented in a different place. import math import logging -import copy -import re from collections import namedtuple import torch from .policy import ScheduledTrainingPolicy import distiller from distiller import normalize_module_name, denormalize_module_name from apputils import SummaryGraph -from models import ALL_MODEL_NAMES, create_model +from models import create_model msglogger = logging.getLogger() ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters']) @@ -67,7 +65,7 @@ These tuples can have 2 values, or 4 values. 
__all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', 'ChannelRemover', 'remove_channels', 'FilterRemover', 'remove_filters', - 'find_nonzero_channels', + 'find_nonzero_channels', 'find_nonzero_channels_list', 'execute_thinning_recipes_list'] @@ -177,15 +175,21 @@ def find_nonzero_channels(param, param_name): k_sums_mat = kernel_sums.view(num_filters, num_channels).t() nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1)) - if num_channels > len(nonzero_channels): + if num_channels > nonzero_channels.nelement(): msglogger.info("In tensor %s found %d/%d zero channels", param_name, - num_filters - len(nonzero_channels), num_filters) + num_channels - nonzero_channels.nelement(), num_channels) return nonzero_channels +def find_nonzero_channels_list(param, param_name): + nnz_channels = find_nonzero_channels(param, param_name) + nnz_channels = nnz_channels.view(nnz_channels.numel()) + return nnz_channels.cpu().numpy().tolist() + + def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe): - if len(thinning_recipe.modules)>0 or len(thinning_recipe.parameters)>0: + if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0: # Now actually remove the filters, chaneels and make the weight tensors smaller execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe) @@ -195,7 +199,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe): model.thinning_recipes.append(thinning_recipe) else: model.thinning_recipes = [thinning_recipe] - msglogger.info("Created, applied and saved a filter-thinning recipe") + msglogger.info("Created, applied and saved a thinning recipe") else: msglogger.error("Failed to create a thinning recipe") @@ -220,7 +224,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): msglogger.info("Invoking create_thinning_recipe_channels") thinning_recipe = ThinningRecipe(modules={}, parameters={}) - layers = {mod_name : m for mod_name, m in model.named_modules()} + layers = {mod_name: m 
for mod_name, m in model.named_modules()} # Traverse all of the model's parameters, search for zero-channels, and # create a thinning recipe that descibes the required changes to the model. @@ -231,16 +235,18 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): num_channels = param.size(1) nonzero_channels = find_nonzero_channels(param, param_name) - + num_nnz_channels = nonzero_channels.nelement() + if num_nnz_channels == 0: + raise ValueError("Trying to set zero channels for parameter %s is not allowed" % param_name) # If there are non-zero channels in this tensor then continue to next tensor - if num_channels <= len(nonzero_channels): + if num_channels <= num_nnz_channels: continue # We are removing channels, so update the number of incoming channels (IFMs) # in the convolutional layer layer_name = param_name_2_layer_name(param_name) assert isinstance(layers[layer_name], torch.nn.modules.Conv2d) - append_module_directive(thinning_recipe, layer_name, key='in_channels', val=len(nonzero_channels)) + append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels) # Select only the non-zero filters indices = nonzero_channels.data.squeeze() @@ -252,7 +258,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): predecessors = [denormalize_module_name(model, predecessor) for predecessor in predecessors] for predecessor in predecessors: # For each of the convolutional layers that preceed, we have to reduce the number of output channels. 
- append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels)) + append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels) # Now remove channels from the weights tensor of the successor conv append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices)) @@ -263,7 +269,8 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): assert len(bn_layers) == 1 # Thinning of the BN layer that follows the convolution bn_layer_name = denormalize_module_name(model, bn_layers[0]) - bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_channels), thin_features=indices) + bn_thinning(thinning_recipe, layers, bn_layer_name, + len_thin_features=num_nnz_channels, thin_features=indices) return thinning_recipe @@ -281,7 +288,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): msglogger.info("Invoking create_thinning_recipe_filters") thinning_recipe = ThinningRecipe(modules={}, parameters={}) - layers = {mod_name : m for mod_name, m in model.named_modules()} + layers = {mod_name: m for mod_name, m in model.named_modules()} for param_name, param in model.named_parameters(): # We are only interested in 4D weights @@ -292,20 +299,22 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): filter_view = param.view(param.size(0), -1) num_filters = filter_view.size()[0] nonzero_filters = torch.nonzero(filter_view.abs().sum(dim=1)) - + num_nnz_filters = nonzero_filters.nelement() + if num_nnz_filters == 0: + raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name) # If there are non-zero filters in this tensor then continue to next tensor - if num_filters <= len(nonzero_filters): + if num_filters <= num_nnz_filters: msglogger.debug("SKipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape)) continue msglogger.info("In tensor %s found %d/%d zero filters", param_name, - 
num_filters - len(nonzero_filters), num_filters) + num_filters - num_nnz_filters, num_filters) # We are removing filters, so update the number of outgoing channels (OFMs) # in the convolutional layer layer_name = param_name_2_layer_name(param_name) assert isinstance(layers[layer_name], torch.nn.modules.Conv2d) - append_module_directive(thinning_recipe, layer_name, key='out_channels', val=len(nonzero_filters)) + append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters) # Select only the non-zero filters indices = nonzero_filters.data.squeeze() @@ -323,8 +332,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): if isinstance(layers[successor], torch.nn.modules.Conv2d): # For each of the convolutional layers that follow, we have to reduce the number of input channels. - append_module_directive(thinning_recipe, successor, key='in_channels', val=len(nonzero_filters)) - msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, len(nonzero_filters))) + append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters) + msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters)) # Now remove channels from the weights tensor of the successor conv append_param_directive(thinning_recipe, successor+'.weight', (1, indices)) @@ -332,8 +341,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): elif isinstance(layers[successor], torch.nn.modules.Linear): # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member fm_size = layers[successor].in_features // layers[layer_name].out_channels - in_features = fm_size * len(nonzero_filters) - #append_module_directive(thinning_recipe, layer_name, key='in_features', val=in_features) + in_features = fm_size * num_nnz_filters append_module_directive(thinning_recipe, successor, key='in_features', val=in_features) msglogger.info("[recipe] {}: setting in_features = 
{}".format(successor, in_features)) @@ -350,7 +358,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): assert len(bn_layers) == 1 # Thinning of the BN layer that follows the convolution bn_layer_name = denormalize_module_name(model, bn_layers[0]) - bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_filters), thin_features=indices) + bn_thinning(thinning_recipe, layers, bn_layer_name, + len_thin_features=num_nnz_filters, thin_features=indices) return thinning_recipe @@ -423,8 +432,9 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal dim_to_trim = val[0] indices_to_select = val[1] # Check if we're trying to trim a parameter that is already "thin" - if running.size(dim_to_trim) != len(indices_to_select): - msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(indices_to_select))) + if running.size(dim_to_trim) != indices_to_select.nelement(): + msglogger.info("[thinning] {}: setting {} to {}". + format(layer_name, attr, indices_to_select.nelement())) setattr(layers[layer_name], attr, torch.index_select(running, dim=dim_to_trim, index=indices_to_select)) else: @@ -438,34 +448,35 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal for directive in param_directives: dim = directive[0] indices = directive[1] + len_indices = indices.nelement() if len(directive) == 4: # TODO: this code is hard to follow selection_view = param.view(*directive[2]) # Check if we're trying to trim a parameter that is already "thin" - if param.data.size(dim) != len(indices): + if param.data.size(dim) != len_indices: param.data = torch.index_select(selection_view, dim, indices) if param.grad is not None: # We also need to change the dimensions of the gradient tensor. 
grad_selection_view = param.grad.resize_(*directive[2]) - if grad_selection_view.size(dim) != len(indices): + if grad_selection_view.size(dim) != len_indices: param.grad = torch.index_select(grad_selection_view, dim, indices) param.data = param.view(*directive[3]) if param.grad is not None: param.grad = param.grad.resize_(*directive[3]) else: - if param.data.size(dim) != len(indices): + if param.data.size(dim) != len_indices: param.data = torch.index_select(param.data, dim, indices) # We also need to change the dimensions of the gradient tensor. # If have not done a backward-pass thus far, then the gradient will # not exist, and therefore won't need to be re-dimensioned. - if param.grad is not None and param.grad.size(dim) != len(indices): + if param.grad is not None and param.grad.size(dim) != len_indices: param.grad = torch.index_select(param.grad, dim, indices) - msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len(indices))) + msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len_indices)) if not loaded_from_file: # If the masks are loaded from a checkpoint file, then we don't need to change # their shape, because they are already correctly shaped mask = zeros_mask_dict[param_name].mask - if mask is not None and (mask.size(dim) != len(indices)): + if mask is not None and (mask.size(dim) != len_indices): zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices) diff --git a/tests/common.py b/tests/common.py new file mode 100755 index 0000000..75c4bc3 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) +import distiller +from models import create_model + + +def setup_test(arch, dataset): + model = create_model(False, dataset, arch, parallel=False) + assert model is not None + + # Create the masks + zeros_mask_dict = {} + for name, param in model.named_parameters(): + masker = distiller.ParameterMasker(name) + zeros_mask_dict[name] = masker + return model, zeros_mask_dict + + +def find_module_by_name(model, module_to_find): + for name, m in model.named_modules(): + if name == module_to_find: + return m + return None diff --git a/tests/test_pruning.py b/tests/test_pruning.py index fa01f22..e71dcb5 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -18,14 +18,16 @@ import logging import torch import os import sys -module_path = os.path.abspath(os.path.join('..')) -if module_path not in sys.path: +try: + import distiller +except ImportError: + module_path = os.path.abspath(os.path.join('..')) sys.path.append(module_path) -import distiller - + import distiller +import common import pytest -from models import ALL_MODEL_NAMES, create_model -from apputils import SummaryGraph, onnx_name_2_pytorch_name, save_checkpoint, load_checkpoint +from models import create_model +from apputils import save_checkpoint, load_checkpoint # Logging configuration logging.basicConfig(level=logging.INFO) @@ -33,44 +35,47 @@ fh = logging.FileHandler('test.log') logger = logging.getLogger() logger.addHandler(fh) -def 
find_module_by_name(model, module_to_find): - for name, m in model.named_modules(): - if name == module_to_find: - return m - return None + +def test_ranked_filter_pruning(): + ranked_filter_pruning(ratio_to_prune=0.1) + ranked_filter_pruning(ratio_to_prune=0.5) -def setup_test(arch, dataset): - model = create_model(False, dataset, arch, parallel=False) - assert model is not None +def test_prune_all_filters(): + """Pruning all of the filteres in a weights tensor of a Convolution + is illegal and should raise an exception. + """ + with pytest.raises(ValueError): + ranked_filter_pruning(ratio_to_prune=1.0) - # Create the masks - zeros_mask_dict = {} - for name, param in model.named_parameters(): - masker = distiller.ParameterMasker(name) - zeros_mask_dict[name] = masker - return model, zeros_mask_dict -def test_ranked_filter_pruning(): - model, zeros_mask_dict = setup_test("resnet20_cifar", "cifar10") +def ranked_filter_pruning(ratio_to_prune): + """Test L1 ranking and pruning of filters. + + First we rank and prune the filters of a Convolutional layer using + a L1RankedStructureParameterPruner. Then we physically remove the + filters from the model (via "thining" process). 
+ """ + model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10") # Test that we can access the weights tensor of the first convolution in layer 1 conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight") assert conv1_p is not None - # Test that there are no zero-channels + # Test that there are no zero-filters assert distiller.sparsity_3D(conv1_p) == 0.0 # Create a filter-ranking pruner - reg_regims = {"layer1.0.conv1.weight" : [0.1, "3D"]} + reg_regims = {"layer1.0.conv1.weight": [ratio_to_prune, "3D"]} pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims) pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None) - conv1 = find_module_by_name(model, "layer1.0.conv1") + conv1 = common.find_module_by_name(model, "layer1.0.conv1") assert conv1 is not None # Test that the mask has the correct fraction of filters pruned. # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters - expected_pruning = int(0.1 * conv1.out_channels) / conv1.out_channels + expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels) + expected_pruning = expected_cnt_removed_filters / conv1.out_channels assert distiller.sparsity_3D(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning # Use the mask to prune @@ -79,24 +84,42 @@ def test_ranked_filter_pruning(): assert distiller.sparsity_3D(conv1_p) == expected_pruning # Remove filters - conv2 = find_module_by_name(model, "layer1.0.conv2") + conv2 = common.find_module_by_name(model, "layer1.0.conv2") assert conv2 is not None assert conv1.out_channels == 16 assert conv2.in_channels == 16 # Test thinning distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10") - assert conv1.out_channels == 15 - assert conv2.in_channels == 15 + assert conv1.out_channels == 16 - expected_cnt_removed_filters + assert conv2.in_channels == 16 - expected_cnt_removed_filters def 
test_arbitrary_channel_pruning(): + arbitrary_channel_pruning(channels_to_remove=[0, 2]) + + +def test_prune_all_channels(): + """Pruning all of the channels in a weights tensor of a Convolution + is illegal and should raise an exception. + """ + with pytest.raises(ValueError): + arbitrary_channel_pruning(channels_to_remove=[ch for ch in range(16)]) + + +def arbitrary_channel_pruning(channels_to_remove): + """Test removal of arbitrary channels. + + The test receives a specification of channels to remove. + Based on this specification, the channels are pruned and then physically + removed from the model (via a "thinning" process). + """ ARCH = "resnet20_cifar" DATASET = "cifar10" - model, zeros_mask_dict = setup_test(ARCH, DATASET) + model, zeros_mask_dict = common.setup_test(ARCH, DATASET) - conv2 = find_module_by_name(model, "layer1.0.conv2") + conv2 = common.find_module_by_name(model, "layer1.0.conv2") assert conv2 is not None # Test that we can access the weights tensor of the first convolution in layer 1 @@ -108,8 +131,7 @@ def test_arbitrary_channel_pruning(): num_channels = conv2_p.size(1) kernel_height = conv2_p.size(2) kernel_width = conv2_p.size(3) - - channels_to_remove = [0, 2] + cnt_nnz_channels = num_channels - len(channels_to_remove) # Let's build our 4D mask. 
# We start with a 1D mask of channels, with all but our specified channels set to one @@ -131,52 +153,59 @@ def test_arbitrary_channel_pruning(): zeros_mask_dict["layer1.0.conv2.weight"].mask = mask zeros_mask_dict["layer1.0.conv2.weight"].apply_mask(conv2_p) all_channels = set([ch for ch in range(num_channels)]) - channels_removed = all_channels - set(distiller.find_nonzero_channels(conv2_p, "layer1.0.conv2.weight")) - logger.info(channels_removed) + nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "layer1.0.conv2.weight")) + channels_removed = all_channels - nnz_channels + logger.info("Channels removed {}".format(channels_removed)) # Now, let's do the actual network thinning distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET) - conv1 = find_module_by_name(model, "layer1.0.conv1") + conv1 = common.find_module_by_name(model, "layer1.0.conv1") logger.info(conv1) logger.info(conv2) - assert conv1.out_channels == 14 - assert conv2.in_channels == 14 - assert conv1.weight.size(0) == 14 - assert conv2.weight.size(1) == 14 - bn1 = find_module_by_name(model, "layer1.0.bn1") - assert bn1.running_var.size(0) == 14 - assert bn1.running_mean.size(0) == 14 - assert bn1.num_features == 14 - assert bn1.bias.size(0) == 14 - assert bn1.weight.size(0) == 14 + assert conv1.out_channels == cnt_nnz_channels + assert conv2.in_channels == cnt_nnz_channels + assert conv1.weight.size(0) == cnt_nnz_channels + assert conv2.weight.size(1) == cnt_nnz_channels + bn1 = common.find_module_by_name(model, "layer1.0.bn1") + assert bn1.running_var.size(0) == cnt_nnz_channels + assert bn1.running_mean.size(0) == cnt_nnz_channels + assert bn1.num_features == cnt_nnz_channels + assert bn1.bias.size(0) == cnt_nnz_channels + assert bn1.weight.size(0) == cnt_nnz_channels # Let's test saving and loading a thinned model. 
# We save 3 times, and load twice, to make sure to cover some corner cases: # - Make sure that after loading, the model still has hold of the thinning recipes # - Make sure that after a 2nd load, there no problem loading (in this case, the # - tensors are already thin, so this is a new flow) + # (1) save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None) model_2 = create_model(False, DATASET, ARCH, parallel=False) dummy_input = torch.randn(1, 3, 32, 32) model(dummy_input) model_2(dummy_input) - conv2 = find_module_by_name(model_2, "layer1.0.conv2") + conv2 = common.find_module_by_name(model_2, "layer1.0.conv2") assert conv2 is not None with pytest.raises(KeyError): model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') - compression_scheduler = distiller.CompressionScheduler(model) hasattr(model, 'thinning_recipes') + + # (2) save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler) model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') assert hasattr(model_2, 'thinning_recipes') logger.info("test_arbitrary_channel_pruning - Done") + # (3) save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler) model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') + assert hasattr(model_2, 'thinning_recipes') logger.info("test_arbitrary_channel_pruning - Done 2") + if __name__ == '__main__': test_arbitrary_channel_pruning() + test_prune_all_channels() -- GitLab