From 0fd409702f258c45bcfbf49e396a6e1853e31953 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 26 Jun 2018 15:42:13 +0300
Subject: [PATCH] Model thinning: bug fix – properly handle thinning of Convs
 with biases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The channel-thinning code does not correctly handle channel removal
when the Convolution layer has a bias tensor.
---
 distiller/thinning.py |  4 ++
 tests/test_pruning.py | 96 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+)

diff --git a/distiller/thinning.py b/distiller/thinning.py
index e019e7d..eb27e52 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -263,6 +263,10 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         # Now remove channels from the weights tensor of the predecessor conv
         append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
 
+        if layers[predecessor].bias is not None:
+            # This convolution has bias coefficients
+            append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
+
         # Now handle the BatchNormalization layer that follows the convolution
         bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
         if len(bn_layers) > 0:
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index e71dcb5..38ad480 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -204,6 +204,102 @@ def arbitrary_channel_pruning(channels_to_remove):
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
+def test_channel_pruning_conv_bias():
+    """Test removal of arbitrary channels, for Convolutions with a bias term.
+
+    This is different from test_arbitrary_channel_pruning() in that this model
+    has Convolution layers with bias tensors.
+    """
+    ARCH = "simplenet_cifar"
+    DATASET = "cifar10"
+    channels_to_remove = [0, 1]
+    has_bn = False
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET)
+
+    conv2 = common.find_module_by_name(model, "conv2")
+    assert conv2 is not None
+
+    # Test that we can access the weights tensor of the second convolution, conv2
+    conv2_p = distiller.model_find_param(model, "conv2.weight")
+    assert conv2_p is not None
+
+    assert conv2_p.dim() == 4
+    num_filters = conv2_p.size(0)
+    num_channels = conv2_p.size(1)
+    kernel_height = conv2_p.size(2)
+    kernel_width = conv2_p.size(3)
+    cnt_nnz_channels = num_channels - len(channels_to_remove)
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels: the channels to remove are zero, all others are one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv2_p.shape
+    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
+
+    # Cool, so now we have a mask for pruning our channels.
+    # Use the mask to prune
+    zeros_mask_dict["conv2.weight"].mask = mask
+    zeros_mask_dict["conv2.weight"].apply_mask(conv2_p)
+    all_channels = set([ch for ch in range(num_channels)])
+    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "conv2.weight"))
+    channels_removed = all_channels - nnz_channels
+    logger.info("Channels removed {}".format(channels_removed))
+
+    # Now, let's do the actual network thinning
+    distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
+    conv1 = common.find_module_by_name(model, "conv1")
+    logger.info(conv1)
+    logger.info(conv2)
+    assert conv1.out_channels == cnt_nnz_channels
+    assert conv2.in_channels == cnt_nnz_channels
+    assert conv1.weight.size(0) == cnt_nnz_channels
+    assert conv2.weight.size(1) == cnt_nnz_channels
+    if has_bn:
+        bn1 = common.find_module_by_name(model, "layer1.0.bn1")
+        assert bn1.running_var.size(0) == cnt_nnz_channels
+        assert bn1.running_mean.size(0) == cnt_nnz_channels
+        assert bn1.num_features == cnt_nnz_channels
+        assert bn1.bias.size(0) == cnt_nnz_channels
+        assert bn1.weight.size(0) == cnt_nnz_channels
+
+    # Let's test saving and loading a thinned model.
+    # We save 3 times, and load twice, to make sure to cover some corner cases:
+    #  - Make sure that after loading, the model still has hold of the thinning recipes
+    #  - Make sure that after a 2nd load, there is no problem loading (in this case, the
+    #    tensors are already thin, so this is a new flow)
+    # (1)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
+    model_2 = create_model(False, DATASET, ARCH, parallel=False)
+    dummy_input = torch.randn(1, 3, 32, 32)
+    model(dummy_input)
+    model_2(dummy_input)
+    conv2 = common.find_module_by_name(model_2, "conv2")
+    assert conv2 is not None
+    with pytest.raises(KeyError):
+        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    compression_scheduler = distiller.CompressionScheduler(model)
+    assert hasattr(model, 'thinning_recipes')
+
+    # (2)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_channel_pruning_conv_bias - Done")
+
+    # (3)
+    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_channel_pruning_conv_bias - Done 2")
 
 
 if __name__ == '__main__':
-- 
GitLab
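Note (not part of the patch): a minimal, self-contained sketch of the invariant
the fix restores. A Conv2d bias holds one coefficient per output filter, so when
filters are removed along dimension 0 of the weight tensor, the bias must be
sliced with the same indices; before this patch the thinning recipe dropped only
the weights. The layer sizes and the `keep` indices below are made up purely for
illustration.

    import torch
    import torch.nn as nn

    # A convolution with a bias tensor -- the case the patch fixes.
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, bias=True)
    keep = torch.tensor([0, 2, 3, 5, 6, 7])  # indices of the filters we keep

    # Thin the layer: slice dim 0 of the weights AND dim 0 of the bias
    # with the same indices (the bias slice is what the patch adds).
    thin = nn.Conv2d(3, len(keep), kernel_size=3, bias=True)
    with torch.no_grad():
        thin.weight.copy_(conv.weight.index_select(0, keep))
        thin.bias.copy_(conv.bias.index_select(0, keep))

    # Each output channel is computed independently, so the thinned conv
    # must reproduce exactly the kept channels of the original output.
    x = torch.randn(1, 3, 32, 32)
    assert torch.allclose(thin(x), conv(x).index_select(1, keep))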