From 0fd409702f258c45bcfbf49e396a6e1853e31953 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 26 Jun 2018 15:42:13 +0300
Subject: [PATCH] Model thinning: bug fix – properly handle thinning of Convs
 with biases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The channel-thinning code does not correctly handle channel removal
when the Convolution layer has a bias tensor.
---
 distiller/thinning.py |  4 ++
 tests/test_pruning.py | 96 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+)

diff --git a/distiller/thinning.py b/distiller/thinning.py
index e019e7d..eb27e52 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -263,6 +263,10 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         # Now remove channels from the weights tensor of the predecessor conv
         append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
 
+        if layers[predecessor].bias is not None:
+            # This convolution has bias coefficients
+            append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
+
         # Now handle the BatchNormalization layer that follows the convolution
         bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
         if len(bn_layers) > 0:
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index e71dcb5..38ad480 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -204,6 +204,102 @@ def arbitrary_channel_pruning(channels_to_remove):
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
+def test_channel_pruning_conv_bias():
+    """Test removal of arbitrary channels, for Convolutions with a bias term.
+
+    This is different from test_arbitrary_channel_pruning() in that this model
+    has Convolution layers with bias tensors.
+    """
+    ARCH = "simplenet_cifar"
+    DATASET = "cifar10"
+    channels_to_remove = [0, 1]
+    has_bn = False
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET)
+
+    conv2 = common.find_module_by_name(model, "conv2")
+    assert conv2 is not None
+
+    # Test that we can access the weights tensor of the second convolution, conv2
+    conv2_p = distiller.model_find_param(model, "conv2.weight")
+    assert conv2_p is not None
+
+    assert conv2_p.dim() == 4
+    num_filters = conv2_p.size(0)
+    num_channels = conv2_p.size(1)
+    kernel_height = conv2_p.size(2)
+    kernel_width = conv2_p.size(3)
+    cnt_nnz_channels = num_channels - len(channels_to_remove)
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels: the channels to remove are zero, all others are one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv2_p.shape
+    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
+
+    # Cool, so now we have a mask for pruning our channels.
+    # Use the mask to prune
+    zeros_mask_dict["conv2.weight"].mask = mask
+    zeros_mask_dict["conv2.weight"].apply_mask(conv2_p)
+    all_channels = set([ch for ch in range(num_channels)])
+    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "conv2.weight"))
+    channels_removed = all_channels - nnz_channels
+    logger.info("Channels removed {}".format(channels_removed))
+
+    # Now, let's do the actual network thinning
+    distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
+    conv1 = common.find_module_by_name(model, "conv1")
+    logger.info(conv1)
+    logger.info(conv2)
+    assert conv1.out_channels == cnt_nnz_channels
+    assert conv2.in_channels == cnt_nnz_channels
+    assert conv1.weight.size(0) == cnt_nnz_channels
+    assert conv2.weight.size(1) == cnt_nnz_channels
+    if has_bn:
+        bn1 = common.find_module_by_name(model, "layer1.0.bn1")
+        assert bn1.running_var.size(0) == cnt_nnz_channels
+        assert bn1.running_mean.size(0) == cnt_nnz_channels
+        assert bn1.num_features == cnt_nnz_channels
+        assert bn1.bias.size(0) == cnt_nnz_channels
+        assert bn1.weight.size(0) == cnt_nnz_channels
+
+    # Let's test saving and loading a thinned model.
+    # We save 3 times, and load twice, to make sure to cover some corner cases:
+    #  - Make sure that after loading, the model still has hold of the thinning recipes
+    #  - Make sure that after a 2nd load, there is no problem loading (in this case, the
+    #    tensors are already thin, so this is a new flow)
+    # (1)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
+    model_2 = create_model(False, DATASET, ARCH, parallel=False)
+    dummy_input = torch.randn(1, 3, 32, 32)
+    model(dummy_input)
+    model_2(dummy_input)
+    conv2 = common.find_module_by_name(model_2, "conv2")
+    assert conv2 is not None
+    with pytest.raises(KeyError):
+        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    compression_scheduler = distiller.CompressionScheduler(model)
+    assert hasattr(model, 'thinning_recipes')
+
+    # (2)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_channel_pruning_conv_bias - Done")
+
+    # (3)
+    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_channel_pruning_conv_bias - Done 2")
 
 
 if __name__ == '__main__':
-- 
GitLab
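Note (not part of the patch): a minimal, self-contained sketch of the invariant
the fix restores. A Conv2d bias holds one coefficient per output filter, so when
filters are removed along dimension 0 of the weight tensor, the bias must be
sliced with the same indices; before this patch the thinning recipe dropped only
the weights. The layer sizes and the `keep` indices below are made up purely for
illustration.

    import torch
    import torch.nn as nn

    # A convolution with a bias tensor -- the case the patch fixes.
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, bias=True)
    keep = torch.tensor([0, 2, 3, 5, 6, 7])  # indices of the filters we keep

    # Thin the layer: slice dim 0 of the weights AND dim 0 of the bias
    # with the same indices (the bias slice is what the patch adds).
    thin = nn.Conv2d(3, len(keep), kernel_size=3, bias=True)
    with torch.no_grad():
        thin.weight.copy_(conv.weight.index_select(0, keep))
        thin.bias.copy_(conv.bias.index_select(0, keep))

    # Each output channel is computed independently, so the thinned conv
    # must reproduce exactly the kept channels of the original output.
    x = torch.randn(1, 3, 32, 32)
    assert torch.allclose(thin(x), conv(x).index_select(1, keep))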