From 477320933c976e29b1305c7d70fb0c8c0333029f Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 29 May 2018 23:04:28 +0300
Subject: [PATCH] Filter removal: support filter removal for VGG

This is a temporary implementation that allows filter-removal and
netowrk thinning for VGG.
The implementation continues the present design for network thinning,
which is problematic because parts of the solution are specific to
each model.

Leveraging some new features in PyTorch 0.4, we are now able to provide a
more generic solution to thinning, which we will push to 'master' soon.
This commit bridges the feature gap, for VGG filter-removal, for the
meantime.
---
 distiller/thinning.py | 195 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 173 insertions(+), 22 deletions(-)

diff --git a/distiller/thinning.py b/distiller/thinning.py
index fb95f83..241e184 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -21,13 +21,16 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
+from apputils import SummaryGraph
+from models import ALL_MODEL_NAMES, create_model
 msglogger = logging.getLogger()
 
 ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
 
 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', 'resnet_cifar_remove_filters',
            'resnet_cifar_remove_channels', 'ResnetCifarChannelRemover',
-           'ResnetCifarFilterRemover', 'execute_thinning_recipe']
+           'ResnetCifarFilterRemover', 'execute_thinning_recipe',
+           'FilterRemover', 'vgg_remove_filters']
 
 # This is a dictionary that maps ResNet-Cifar connectivity, of convolutional layers.
 # BN layers connectivity are implied; and shortcuts are not currently handled.
@@ -94,6 +97,24 @@ conv_connectivity = {
     'module.layer3.8.conv1': 'module.layer3.8.conv2',
     'module.layer3.8.conv2': 'module.fc'}
 
+vgg_conv_connectivity = {
+    'features.module.0':   'features.module.2',
+    'features.module.2':   'features.module.5',
+    'features.module.5':   'features.module.7',
+    'features.module.7':   'features.module.10',
+    'features.module.10':  'features.module.12',
+    'features.module.12':  'features.module.14',
+    'features.module.14':  'features.module.16',
+    'features.module.16':  'features.module.19',
+    'features.module.19':  'features.module.21',
+    'features.module.21':  'features.module.23',
+    'features.module.23':  'features.module.25',
+    'features.module.25':  'features.module.28',
+    'features.module.28':  'features.module.30',
+    'features.module.30':  'features.module.32',
+    'features.module.32':  'features.module.34',
+    'features.module.34':  'classifier.0'}
+
 def find_predecessors(layer_name):
     predecessors = []
     for layer, followers in conv_connectivity.items():
@@ -116,8 +137,9 @@ def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_featur
     bn_directive['running_var'] = (0, thin_features)
 
     thinning_recipe.modules[bn_name] = bn_directive
-    thinning_recipe.parameters[bn_name+'.weight'] = (0, thin_features)
-    thinning_recipe.parameters[bn_name+'.bias'] = (0, thin_features)
+
+    thinning_recipe.parameters[bn_name+'.weight'] = [(0, thin_features)]
+    thinning_recipe.parameters[bn_name+'.bias'] = [(0, thin_features)]
 
 
 def resnet_cifar_remove_layers(resnet_cifar_model):
@@ -192,7 +214,9 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_
 
             # Select only the non-zero filters
             indices = nonzero_channels.data.squeeze()
-            thinning_recipe.parameters[name] = (1, indices)
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((1, indices))
+            thinning_recipe.parameters[name] = param_directive
 
             assert layer_name in conv_connectivity
             predecessors = find_predecessors(layer_name)
@@ -204,8 +228,9 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_
                 thinning_recipe.modules[predecessor] = conv_directive
 
                 # Now remove channels from the weights tensor of the follower conv
-                #predecessor_param = distiller.model_find_param(resnet_cifar_model, predecessor+'.weight')
-                thinning_recipe.parameters[predecessor+'.weight'] = (0, indices)
+                param_directive = thinning_recipe.parameters.get(predecessor+'.weight', [])
+                param_directive.append((0, indices))
+                thinning_recipe.parameters[predecessor+'.weight'] = param_directive
 
                 # Now handle the BatchNormalization layer that follows the convolution
                 bn_name = predecessor.replace('conv', 'bn')
@@ -217,22 +242,48 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_
 
 def resnet_cifar_remove_filters(resnet_cifar_model, zeros_mask_dict):
     thinning_recipe = resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict)
+    return remove_filters(resnet_cifar_model, zeros_mask_dict, thinning_recipe)
+
+
+def vgg_remove_filters(vgg_model, zeros_mask_dict):
+    thinning_recipe = vgg_create_thinning_recipe_filters(vgg_model, zeros_mask_dict)
+    return remove_filters(vgg_model, zeros_mask_dict, thinning_recipe)
 
+def remove_filters(model, zeros_mask_dict, thinning_recipe):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters)>0:
         # Stash the recipe, so that it will be serialized together with the model
-        resnet_cifar_model.thinning_recipe = thinning_recipe
+        model.thinning_recipe = thinning_recipe
         # Now actually remove the filters, chaneels and make the weight tensors smaller
-        execute_thinning_recipe(resnet_cifar_model, zeros_mask_dict, thinning_recipe)
+        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
         msglogger.info("Created, applied and saved a filter-thinning recipe")
-    return resnet_cifar_model
+    return model
+
+def get_input(dataset):
+    if dataset == 'imagenet':
+        return torch.randn((1, 3, 224, 224), requires_grad=False)
+    elif dataset == 'cifar10':
+        return torch.randn((1, 3, 32, 32))
+    return None
+
+def create_graph(dataset, arch):
+    dummy_input = get_input(dataset)
+    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
+
+    model = create_model(False, dataset, arch, parallel=False)
+    assert model is not None
+    return SummaryGraph(model, dummy_input)
 
+def normalize_layer_name(layer_name):
+    if layer_name.startswith('module.'):
+        layer_name = layer_name[len('module.'):]
+    return layer_name
 
 def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict):
     """Remove filters from ResNet-Cifar
 
     Caveats: (1) supports only ResNet50-Cifar; (2) only module.layerX.Y.conv1.weight
     """
-    msglogger.info("Invoking resnet_cifar_remove_filters")
+    msglogger.info("Invoking resnet_cifar_create_thinning_recipe_filters")
 
     layers = {thin_name : m for thin_name, m in resnet_cifar_model.named_modules()}
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
@@ -259,9 +310,12 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_d
 
             # Select only the non-zero filters
             indices = nonzero_filters.data.squeeze()
-            thinning_recipe.parameters[name] = (0, indices)
+            #thinning_recipe.parameters[name] = (0, indices)
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[name] = param_directive
 
-            assert layer_name in conv_connectivity
+            assert layer_name in conv_connectivity, "layer {} is not in conv_connectivity {}".format(layer_name, conv_connectivity)
             followers = conv_connectivity[layer_name] if isinstance(conv_connectivity[layer_name], list) else [conv_connectivity[layer_name]]
             for follower in followers:
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
@@ -270,8 +324,12 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_d
                 thinning_recipe.modules[follower] = conv_directive
 
                 # Now remove channels from the weights tensor of the follower conv
-                #follower_param = distiller.model_find_param(resnet_cifar_model, follower+'.weight')
-                thinning_recipe.parameters[follower+'.weight'] = (1, indices)
+                param_directive = thinning_recipe.parameters.get(follower+'.weight', [])
+
+                msglogger.info("appending {}".format(follower+'.weight'))
+
+                param_directive.append((1, indices))
+                thinning_recipe.parameters[follower+'.weight'] = param_directive
 
             # Now handle the BatchNormalization layer that follows the convolution
             bn_name = layer_name.replace('conv', 'bn')
@@ -280,6 +338,80 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_d
             bn_thinning(thinning_recipe, layers, bn_name, len_thin_features=len(nonzero_filters), thin_features=indices)
     return thinning_recipe
 
+import math
+def vgg_create_thinning_recipe_filters(vgg_model, zeros_mask_dict):
+    """Remove filters from VGG
+    """
+    msglogger.info("Invoking vgg_create_thinning_recipe_filters")
+
+    layers = {thin_name : m for thin_name, m in vgg_model.named_modules()}
+    thinning_recipe = ThinningRecipe(modules={}, parameters={})
+
+    for name, p_thin in vgg_model.named_parameters():
+        if p_thin.dim() != 4:
+            continue
+
+        # Find the number of filters, in this weights tensor, that are not 100% sparse_model
+        filter_view = p_thin.view(p_thin.size(0), -1)
+        num_filters = filter_view.size()[0]
+        nonzero_filters = torch.nonzero(filter_view.abs().sum(dim=1))
+
+        # If there are zero-filters in this tensor then...
+        if num_filters > len(nonzero_filters):
+            msglogger.info("In tensor %s found %d/%d zero filters", name,
+                           num_filters - len(nonzero_filters), num_filters)
+
+            # Update the number of outgoing channels (OFMs) in the convolutional layer
+            layer_name = name[:-len('weights')]
+            assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
+            conv_directive = thinning_recipe.modules.get(layer_name, {})
+            conv_directive['out_channels'] = len(nonzero_filters)
+            thinning_recipe.modules[layer_name] = conv_directive
+
+            # Select only the non-zero filters
+            indices = nonzero_filters.data.squeeze()
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[name] = param_directive
+
+            param_directive = thinning_recipe.parameters.get(layer_name+'.bias', [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[layer_name+'.bias'] = param_directive
+
+            assert layer_name in vgg_conv_connectivity, "layer {} is not in vgg_conv_connectivity {}".format(layer_name, vgg_conv_connectivity)
+            followers = vgg_conv_connectivity[layer_name] if isinstance(vgg_conv_connectivity[layer_name], list) else [vgg_conv_connectivity[layer_name]]
+            for follower in followers:
+                # For each of the convolutional layers that follow, we have to reduce the number of input channels.
+                conv_directive = thinning_recipe.modules.get(follower, {})
+
+                if isinstance(layers[follower], torch.nn.modules.Conv2d):
+                    conv_directive['in_channels'] = len(nonzero_filters)
+                    msglogger.info("{}: setting in_channels = {}".format(follower, len(nonzero_filters)))
+                elif isinstance(layers[follower], torch.nn.modules.Linear):
+                    # TODO: this code is hard to follow
+                    fm_size = layers[follower].in_features / layers[layer_name].out_channels
+                    conv_directive['in_features'] = fm_size * len(nonzero_filters)
+                    #assert 22589 == conv_directive['in_features']
+                    msglogger.info("{}: setting in_features = {}".format(follower, conv_directive['in_features']))
+
+                    thinning_recipe.modules[follower] = conv_directive
+
+                # Now remove channels from the weights tensor of the follower conv
+                param_directive = thinning_recipe.parameters.get(follower+'.weight', [])
+                if isinstance(layers[follower], torch.nn.modules.Conv2d):
+                    param_directive.append((1, indices))
+                elif isinstance(layers[follower], torch.nn.modules.Linear):
+                    # TODO: this code is hard to follow
+                    fm_size = layers[follower].in_features / layers[layer_name].out_channels
+                    fm_height = fm_width = int(math.sqrt(fm_size))
+                    selection_view = (layers[follower].out_features, layers[layer_name].out_channels, fm_height, fm_width)
+                    #param_directive.append((1, indices, (4096,512,7,7)))
+                    true_view = (layers[follower].out_features, conv_directive['in_features'])
+                    param_directive.append((1, indices, selection_view, true_view))
+
+                thinning_recipe.parameters[follower+'.weight'] = param_directive
+    return thinning_recipe
+
 
 class ResnetCifarChannelRemover(ScheduledTrainingPolicy):
     """A policy which applies a network thinning function"""
@@ -290,7 +422,7 @@ class ResnetCifarChannelRemover(ScheduledTrainingPolicy):
         self.thinning_func(model, zeros_mask_dict)
 
 
-class ResnetCifarFilterRemover(ScheduledTrainingPolicy):
+class FilterRemover(ScheduledTrainingPolicy):
     """A policy which applies a network thinning function"""
     def __init__(self, thinning_func_str):
         self.thinning_func = globals()[thinning_func_str]
@@ -305,6 +437,11 @@ class ResnetCifarFilterRemover(ScheduledTrainingPolicy):
             self.thinning_func(model, zeros_mask_dict)
             self.done = True
 
+# This is for backward compatiblity with some of the published schedules
+class ResnetCifarFilterRemover(FilterRemover):
+    def __init__(self, thinning_func_str):
+        super().__init__(thinning_func_str)
+
 
 def execute_thinning_recipe(model, zeros_mask_dict, recipe):
     """Apply a thinning recipe to a model.
@@ -319,18 +456,32 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe):
     for layer_name, directives in recipe.modules.items():
         for attr, val in directives.items():
             if attr in ['running_mean', 'running_var']:
+                msglogger.info("{} thinning: setting {} to {}".format(layer_name, attr, len(val[1])))
                 setattr(layers[layer_name], attr,
                         torch.index_select(getattr(layers[layer_name], attr),
                                            dim=val[0], index=val[1]))
             else:
+                msglogger.info("{} thinning: setting {} to {}".format(layer_name, attr, val))
                 setattr(layers[layer_name], attr, val)
 
-    for param_name, info in recipe.parameters.items():
+    assert len(recipe.parameters) > 0
+
+    for param_name, param_directives in recipe.parameters.items():
         param = distiller.model_find_param(model, param_name)
-        param.data = torch.index_select(param.data, dim=info[0], index=info[1])
-        mask = zeros_mask_dict[param_name].mask
-        if (mask is not None) and (zeros_mask_dict[param_name].mask.size(dim=info[0]) != len(info[1])):
-            zeros_mask_dict[param_name].mask = torch.index_select(mask, dim=info[0], index=info[1])
+        for directive in param_directives:
+            dim = directive[0]
+            indices = directive[1]
+            if len(directive) == 4:  # TODO: this code is hard to follow
+                layer_name = param_name[:-len('weights')]
+
+                selection_view = param.view(*directive[2])
+                param.data = torch.index_select(selection_view, dim, indices)
+                #param.data = param.view(4096, 22589)
+                param.data = param.view(*directive[3])
+            else:
+                param.data = torch.index_select(param.data, dim, indices)
+                msglogger.info("thinning: changing param {} shape: {}".format(param_name, len(indices)))
 
-    model.thinning_recipe = recipe
-    return model
+            mask = zeros_mask_dict[param_name].mask
+            if (mask is not None) and (zeros_mask_dict[param_name].mask.size(dim) != len(indices)):
+                zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
-- 
GitLab