From b21f449b0e0b4ba0be4781b226a653cfd1186bbe Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sat, 30 Jun 2018 20:48:02 +0300
Subject: [PATCH] Bug fix: add support for thinning the optimizer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

You no longer need to use —momentum=0 when removing structures
dynamically.
The SGD momentum update (velocity) is dependent on the weights, which
PyTorch optimizers cache internally.  This caching is not a problem for
filter/channel removal (thinning) because although we dynamically
change the shapes of the weights tensors, we don’t change the weights
tensors themselves.
PyTorch’s SGD creates tensors to store the momentum updates, and these
tensors have the same shape as the weights tensors.  When we change the
weights tensors, we need to make the appropriate changes in the Optimizer,
or disable the momentum.
We added a new function - thinning.optimizer_thinning() - to do this.
This function is brittle as it is tested only on optim.SGD and relies on the
internal representation of the SGD optimizer, which can change w/o notice.
For example, optim.Adam uses state['exp_avg'], state['exp_avg_sq']
Which also depend the shape of the weight tensors.
We needed to pass the Optimizer instance to Thinning policies
(ChannelRemover, FilterRemover) via the callbacks, which required us
to change the callback interface.
In the future we plan a bigger change to the callback API, to allow
passing of arbitrary context from the training environment to Distiller.

Also in this commit:
* compress_classifier.py had special handling for resnet layer-removal, which
is used in examples/ssl/ssl_4D-removal_training.yaml.
This is a brittle and ugly hack.  Until we have a more elegant solution, I’m
Removing support for layer-removal.
* Added to the tests invocation of forward and backward passes over a model.
This tests more of the real flows, which use the optimizer and construct
gradient tensors.
* Added a test of a special case of convolution filter-pruning which occurs
when the next layer is fully-connected (linear)
---
 distiller/policy.py                           |  14 +-
 distiller/scheduler.py                        |  23 +-
 distiller/thinning.py                         |  80 ++++--
 .../compress_classifier.py                    |  18 +-
 .../vgg19.schedule_filter_rank.yaml           |   2 +-
 examples/ssl/ssl_4D-removal_training.yaml     |  10 +
 tests/test_pruning.py                         | 227 +++++++++++++-----
 7 files changed, 255 insertions(+), 119 deletions(-)

diff --git a/distiller/policy.py b/distiller/policy.py
index ca3f1c1..76d295f 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -40,18 +40,18 @@ class ScheduledTrainingPolicy(object):
         """A new epcoh is about to begin"""
         pass
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizr=None):
         """The forward-pass of a new mini-batch is about to begin"""
         pass
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict):
+                             regularizer_loss, zeros_mask_dict, optimizer=None):
         """The mini-batch training pass has completed the forward-pass,
         and is about to begin the backward pass.
         """
         pass
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         """The mini-batch training pass has ended"""
         pass
 
@@ -81,7 +81,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         for param_name, param in model.named_parameters():
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         for param_name, param in model.named_parameters():
             zeros_mask_dict[param_name].apply_mask(param)
 
@@ -100,11 +100,11 @@ class RegularizationPolicy(ScheduledTrainingPolicy):
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict):
+                             regularizer_loss, zeros_mask_dict, optimizer=None):
         for param_name, param in model.named_parameters():
             self.regularizer.loss(param, param_name, regularizer_loss, zeros_mask_dict)
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.regularizer.threshold_criteria is None:
             return
 
@@ -141,7 +141,7 @@ class QuantizationPolicy(ScheduledTrainingPolicy):
         self.quantizer.prepare_model()
         self.quantizer.quantize_params()
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # After parameters update, quantize the parameters again
         # (Doing this here ensures the model parameters are quantized at training completion (and at validation time)
         self.quantizer.quantize_params()
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index e4c1d14..ba7a56b 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -86,24 +86,24 @@ class CompressionScheduler(object):
                 self.policies[epoch].append(policy)
             assert len(self.policies[epoch]) > 0
 
-        self.sched_metadata[policy] = { 'starting_epoch' : starting_epoch,
-                                        'ending_epoch' : ending_epoch,
-                                        'frequency' : frequency }
+        self.sched_metadata[policy] = {'starting_epoch': starting_epoch,
+                                       'ending_epoch': ending_epoch,
+                                       'frequency': frequency}
 
-    def on_epoch_begin(self, epoch):
+    def on_epoch_begin(self, epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 meta = self.sched_metadata[policy]
                 meta['current_epoch'] = epoch
                 policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta)
 
-    def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch):
+    def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch,
-                                          self.zeros_mask_dict)
+                                          self.zeros_mask_dict, optimizer)
 
-    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss):
+    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None):
         # Last chance to compute the regularization loss, and optionally add it to the data loss
         regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
 
@@ -114,25 +114,26 @@ class CompressionScheduler(object):
                                             loss, regularizer_loss, self.zeros_mask_dict)
         return regularizer_loss
 
-    def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch):
+    def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         # When we get to this point, the weights are no longer maksed.  This is because during the backward
         # pass, the weights are updated.  So we choose to lazily apply the pruning mask, only if some
         # component is being called-back.
         weights_are_masked = False
- 
+
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 if not weights_are_masked:
                     self.apply_mask()
                     weights_are_masked = True
                 policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch,
-                                        self.zeros_mask_dict)
+                                        self.zeros_mask_dict, optimizer)
 
-    def on_epoch_end(self, epoch):
+    def on_epoch_end(self, epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 meta = self.sched_metadata[policy]
                 meta['current_epoch'] = epoch
+                meta['optimizer'] = optimizer
                 policy.on_epoch_end(self.model, self.zeros_mask_dict, meta)
 
     def apply_mask(self):
diff --git a/distiller/thinning.py b/distiller/thinning.py
index eb27e52..94346f3 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -149,10 +149,10 @@ def resnet_cifar_remove_layers(model):
         model.module.layer_gates[layer][block][conv] = False
 
 
-def remove_channels(model, zeros_mask_dict, arch, dataset):
+def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
     sgraph = create_graph(dataset, arch)
     thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
 
 
@@ -188,10 +188,10 @@ def find_nonzero_channels_list(param, param_name):
     return nnz_channels.cpu().numpy().tolist()
 
 
-def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
+def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
         # Now actually remove the filters, chaneels and make the weight tensors smaller
-        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
+        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
 
         # Stash the recipe, so that it will be serialized together with the model
         if hasattr(model, 'thinning_recipes'):
@@ -204,10 +204,10 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
         msglogger.error("Failed to create a thinning recipe")
 
 
-def remove_filters(model, zeros_mask_dict, arch, dataset):
+def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
     sgraph = create_graph(dataset, arch)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
 
 
@@ -375,7 +375,7 @@ class ChannelRemover(ScheduledTrainingPolicy):
         self.dataset = dataset
 
     def on_epoch_end(self, model, zeros_mask_dict, meta):
-        self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset)
+        self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, meta.get('optimizer', None))
 
 
 class FilterRemover(ScheduledTrainingPolicy):
@@ -387,24 +387,24 @@ class FilterRemover(ScheduledTrainingPolicy):
         self.done = False
         self.active_cb = "on_minibatch_begin"
 
-    def __apply(self, model, zeros_mask_dict):
+    def __apply(self, model, zeros_mask_dict, optimizer):
         if not self.done:
             # We want to execute the thinning function only once, not every invocation of on_minibatch_begin
-            self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset)
+            self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, optimizer=optimizer)
             self.done = True
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # We hook onto the on_minibatch_begin because we want to run after the pruner which sparsified
         # the tensors.  Pruners configure their pruning mask in on_epoch_begin, but apply the mask
         # only in on_minibatch_begin
         if self.active_cb != "on_minibatch_begin":
             return
-        self.__apply(model, zeros_mask_dict)
+        self.__apply(model, zeros_mask_dict, optimizer)
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.active_cb != "on_minibatch_end":
             return
-        self.__apply(model, zeros_mask_dict)
+        self.__apply(model, zeros_mask_dict, optimizer)
 
     def on_epoch_end(self, model, zeros_mask_dict, meta):
         # The epoch has ended and we reset the 'done' flag, so that the FilterRemover instance can be reused
@@ -416,15 +416,45 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
     # to a thinned model. For example, this is invoked when loading a model from a checkpoint.
     for i, recipe in enumerate(recipe_list):
         msglogger.info("recipe %d:" % i)
-        execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)
+        execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer=None, loaded_from_file=True)
 
 
-def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
+def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
+    """Adjust the size of the SGD vecolity-tracking tensors.
+
+    The SGD momentum update (velocity) is dependent on the weights, and because during thinning we
+    dynamically change the weights shapes, we need to make the apporpriate changes in the Optimizer,
+    or disable the momentum.
+
+    This function is brittle as it is tested on SGD only and relies on the internal representation of
+    the SGD optimizer, which can change w/o notice.
+    """
+    if optimizer is None or not isinstance(optimizer, torch.optim.SGD):
+        return False
+    for group in optimizer.param_groups:
+        momentum = group.get('momentum', 0)
+        if momentum == 0:
+            continue
+        for p in group['params']:
+            if id(p) != id(param):
+                continue
+            param_state = optimizer.state[p]
+            if 'momentum_buffer' in param_state:
+                param_state['momentum_buffer'] = torch.index_select(param_state['momentum_buffer'], dim, indices)
+                if new_shape is not None:
+                    msglogger.info("optimizer_thinning: new shape {}".format(*new_shape))
+                    param_state['momentum_buffer'] = param_state['momentum_buffer'].resize_(*new_shape)
+                return True
+    return False
+
+
+def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_from_file=False):
     """Apply a thinning recipe to a model.
 
     This will remove filters and channels, as well as handle batch-normalization parameter
     adjustment, and thinning of weight tensors.
     """
+
     layers = {}
     for name, m in model.named_modules():
         layers[name] = m
@@ -458,12 +488,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
                 # Check if we're trying to trim a parameter that is already "thin"
                 if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(selection_view, dim, indices)
-
-                if param.grad is not None:
-                    # We also need to change the dimensions of the gradient tensor.
-                    grad_selection_view = param.grad.resize_(*directive[2])
-                    if grad_selection_view.size(dim) != len_indices:
-                        param.grad = torch.index_select(grad_selection_view, dim, indices)
+                    if param.grad is not None:
+                        # We also need to change the dimensions of the gradient tensor.
+                        grad_selection_view = param.grad.resize_(*directive[2])
+                        if grad_selection_view.size(dim) != len_indices:
+                            param.grad = torch.index_select(grad_selection_view, dim, indices)
+                            if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                                msglogger.info("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
+                                               format(param_name, dim, len_indices, directive[3]))
 
                 param.data = param.view(*directive[3])
                 if param.grad is not None:
@@ -471,12 +503,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
             else:
                 if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(param.data, dim, indices)
+                    msglogger.info("[thinning] changed param {} shape: {}".format(param_name, len_indices))
                 # We also need to change the dimensions of the gradient tensor.
                 # If have not done a backward-pass thus far, then the gradient will
                 # not exist, and therefore won't need to be re-dimensioned.
                 if param.grad is not None and param.grad.size(dim) != len_indices:
-                        param.grad = torch.index_select(param.grad, dim, indices)
-                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len_indices))
+                    param.grad = torch.index_select(param.grad, dim, indices)
+                    if optimizer_thinning(optimizer, param, dim, indices):
+                        msglogger.info("Updated velocity buffer %s" % param_name)
 
             if not loaded_from_file:
                 # If the masks are loaded from a checkpoint file, then we don't need to change
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index f33e171..c6db0b8 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -223,10 +223,6 @@ def main():
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
             model, chkpt_file=args.resume)
 
-        if 'resnet' in args.arch and 'preact' not in args.arch and 'cifar' in args.arch:
-            distiller.resnet_cifar_remove_layers(model)
-            #model = distiller.resnet_cifar_remove_channels(model, compression_scheduler.zeros_mask_dict)
-
     # Define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
@@ -320,11 +316,11 @@ def main():
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
                               ('Top5', top5)]))
-        distiller.log_training_progress(stats, None,epoch, steps_completed=0, total_steps=1,
+        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1,
                                         log_freq=1, loggers=[tflogger])
 
         if compression_scheduler:
-            compression_scheduler.on_epoch_end(epoch)
+            compression_scheduler.on_epoch_end(epoch, optimizer)
 
         # remember best top1 and save checkpoint
         is_best = top1 > best_top1
@@ -339,8 +335,8 @@ def main():
 def train(train_loader, model, criterion, optimizer, epoch,
           compression_scheduler, loggers, print_freq, log_params_hist):
     """Training loop for one epoch."""
-    losses = {'objective_loss'   : tnt.AverageValueMeter(),
-              'regularizer_loss' : tnt.AverageValueMeter()}
+    losses = {'objective_loss':   tnt.AverageValueMeter(),
+              'regularizer_loss': tnt.AverageValueMeter()}
     if compression_scheduler is None:
         # Initialize the regularizer loss to zero
         losses['regularizer_loss'].add(0)
@@ -368,7 +364,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
 
         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
-            compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch)
+            compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer)
         output = model(input_var)
         loss = criterion(output, target_var)
 
@@ -378,7 +374,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
 
         if compression_scheduler:
             # Before running the backward phase, we add any regularization loss computed by the scheduler
-            regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss)
+            regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss, optimizer)
             loss += regularizer_loss
             losses['regularizer_loss'].add(regularizer_loss.item())
 
@@ -387,7 +383,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
         loss.backward()
         optimizer.step()
         if compression_scheduler:
-            compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch)
+            compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch, optimizer)
 
         # measure elapsed time
         batch_time.add(time.time() - end)
diff --git a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml
index 9efc6dc..cb20dd1 100755
--- a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml
@@ -2,7 +2,7 @@
 # This schedule performs 3D (filter-wise) regularization of some of the convolution layers, together with
 # element-wise pruning using sensitivity-pruning.
 #
-# time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained --momentum=0
+# time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained
 #
 
 
diff --git a/examples/ssl/ssl_4D-removal_training.yaml b/examples/ssl/ssl_4D-removal_training.yaml
index 0fcbe0f..ef1b120 100755
--- a/examples/ssl/ssl_4D-removal_training.yaml
+++ b/examples/ssl/ssl_4D-removal_training.yaml
@@ -1,3 +1,13 @@
+#
+# THIS IS CURRENTLY BROKEN!
+# There are two challenges to getting this to work:
+#   1. If we want to remove layers, we should use a thinning recipe and leave instructions.  Otherwise, it's very hard to
+#      decide if we should remove layers or not.
+#   2. We need a generic solution for removing layers.  Perhaps a new module type which wraps a layer and has a bypass-gate.
+#
+# If you still want to use this, until we implement a proper solution, invoke thinning.resnet_cifar_remove_layers from
+# compress_classifier.py.
+#
 # We used this schedule to train CIFAR10-ResNet20 from scratch with SSL.
 # After running this schedule, use ssl_4D-removal_finetuning.yaml to "physically" remove the layers and fine-tune the smaller model.
 #
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index b6c6a06..b16e901 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -35,7 +35,7 @@ fh = logging.FileHandler('test.log')
 logger = logging.getLogger()
 logger.addHandler(fh)
 
-NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_name")
+NetConfig = namedtuple("test_config", "arch dataset bn_name module_pairs")
 
 
 #
@@ -43,20 +43,34 @@ NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_nam
 #
 def simplenet():
     return NetConfig(arch="simplenet_cifar", dataset="cifar10",
-                     conv1_name="conv1", conv2_name="conv2",
+                     module_pairs=[("conv1", "conv2")],
                      bn_name=None)
 
 
 def resnet20_cifar():
     return NetConfig(arch="resnet20_cifar", dataset="cifar10",
-                     conv1_name="layer1.0.conv1", conv2_name="layer1.0.conv2",
+                     module_pairs=[("layer1.0.conv1", "layer1.0.conv2")],
                      bn_name="layer1.0.bn1")
 
 
+def vgg19_imagenet():
+    return NetConfig(arch="vgg19", dataset="imagenet",
+                     module_pairs=[("features.21", "features.23"),
+                                   ("features.23", "features.25"),
+                                   ("features.25", "features.28"),
+                                   ("features.28", "features.30"),
+                                   ("features.30", "features.32"),
+                                   ("features.32", "features.34")],
+                     bn_name=None)
+
+
 def test_ranked_filter_pruning():
     ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.1)
     ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.5)
     ranked_filter_pruning(simplenet(), ratio_to_prune=0.5)
+    ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1)
+    model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1)
+    test_conv_fc_interface(model, zeros_mask_dict)
 
 
 def test_prune_all_filters():
@@ -76,44 +90,46 @@ def ranked_filter_pruning(config, ratio_to_prune):
     """
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset)
 
-    # Test that we can access the weights tensor of the first convolution in layer 1
-    conv1_p = distiller.model_find_param(model, config.conv1_name + ".weight")
-    assert conv1_p is not None
-    num_filters = conv1_p.size(0)
-
-    # Test that there are no zero-filters
-    assert distiller.sparsity_3D(conv1_p) == 0.0
-
-    # Create a filter-ranking pruner
-    reg_regims = {config.conv1_name + ".weight": [ratio_to_prune, "3D"]}
-    pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
-    pruner.set_param_mask(conv1_p, config.conv1_name + ".weight", zeros_mask_dict, meta=None)
-
-    conv1 = common.find_module_by_name(model, config.conv1_name)
-    assert conv1 is not None
-    # Test that the mask has the correct fraction of filters pruned.
-    # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters
-    expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
-    expected_pruning = expected_cnt_removed_filters / conv1.out_channels
-    masker = zeros_mask_dict[config.conv1_name + ".weight"]
-    assert masker is not None
-    assert distiller.sparsity_3D(masker.mask) == expected_pruning
-
-    # Use the mask to prune
-    assert distiller.sparsity_3D(conv1_p) == 0
-    masker.apply_mask(conv1_p)
-    assert distiller.sparsity_3D(conv1_p) == expected_pruning
-
-    # Remove filters
-    conv2 = common.find_module_by_name(model, config.conv2_name)
-    assert conv2 is not None
-    assert conv1.out_channels == num_filters
-    assert conv2.in_channels == num_filters
+    for pair in config.module_pairs:
+        # Test that we can access the weights tensor of the first convolution in layer 1
+        conv1_p = distiller.model_find_param(model, pair[0] + ".weight")
+        assert conv1_p is not None
+        num_filters = conv1_p.size(0)
+
+        # Test that there are no zero-filters
+        assert distiller.sparsity_3D(conv1_p) == 0.0
+
+        # Create a filter-ranking pruner
+        reg_regims = {pair[0] + ".weight": [ratio_to_prune, "3D"]}
+        pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
+        pruner.set_param_mask(conv1_p, pair[0] + ".weight", zeros_mask_dict, meta=None)
+
+        conv1 = common.find_module_by_name(model, pair[0])
+        assert conv1 is not None
+        # Test that the mask has the correct fraction of filters pruned.
+        # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters
+        expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
+        expected_pruning = expected_cnt_removed_filters / conv1.out_channels
+        masker = zeros_mask_dict[pair[0] + ".weight"]
+        assert masker is not None
+        assert distiller.sparsity_3D(masker.mask) == expected_pruning
+
+        # Use the mask to prune
+        assert distiller.sparsity_3D(conv1_p) == 0
+        masker.apply_mask(conv1_p)
+        assert distiller.sparsity_3D(conv1_p) == expected_pruning
+
+        # Remove filters
+        conv2 = common.find_module_by_name(model, pair[1])
+        assert conv2 is not None
+        assert conv1.out_channels == num_filters
+        assert conv2.in_channels == num_filters
 
     # Test thinning
-    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset)
+    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
     assert conv1.out_channels == num_filters - expected_cnt_removed_filters
     assert conv2.in_channels == num_filters - expected_cnt_removed_filters
+    return model, zeros_mask_dict
 
 
 def test_arbitrary_channel_pruning():
@@ -134,6 +150,40 @@ def test_channel_pruning_conv_bias():
     arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 1])
 
 
+def create_channels_mask(conv_p, channels_to_remove):
+    assert conv_p.dim() == 4
+    num_filters = conv_p.size(0)
+    num_channels = conv_p.size(1)
+    kernel_height = conv_p.size(2)
+    kernel_width = conv_p.size(3)
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels, with all but our specified channels set to one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv_p.shape
+    return mask
+
+
+def run_forward_backward(model, optimizer, dummy_input):
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+    model.train()
+    output = model(dummy_input)
+    target = torch.LongTensor(1).random_(2)
+    loss = criterion(output, target)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+
 def arbitrary_channel_pruning(config, channels_to_remove):
     """Test removal of arbitrary channels.
 
@@ -143,49 +193,34 @@ def arbitrary_channel_pruning(config, channels_to_remove):
     """
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset)
 
-    conv2 = common.find_module_by_name(model, config.conv2_name)
+    assert len(config.module_pairs) == 1   # This is a temporary restriction on the test
+    pair = config.module_pairs[0]
+    conv2 = common.find_module_by_name(model, pair[1])
     assert conv2 is not None
 
     # Test that we can access the weights tensor of the first convolution in layer 1
-    conv2_p = distiller.model_find_param(model, config.conv2_name + ".weight")
+    conv2_p = distiller.model_find_param(model, pair[1] + ".weight")
     assert conv2_p is not None
 
     assert conv2_p.dim() == 4
-    num_filters = conv2_p.size(0)
     num_channels = conv2_p.size(1)
-    kernel_height = conv2_p.size(2)
-    kernel_width = conv2_p.size(3)
     cnt_nnz_channels = num_channels - len(channels_to_remove)
-
-    # Let's build our 4D mask.
-    # We start with a 1D mask of channels, with all but our specified channels set to one
-    channels = torch.ones(num_channels)
-    for ch in channels_to_remove:
-        channels[ch] = 0
-
-    # Now let's expand back up to a 4D mask
-    mask = channels.expand(num_filters, num_channels)
-    mask.unsqueeze_(-1)
-    mask.unsqueeze_(-1)
-    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
-
-    assert mask.shape == conv2_p.shape
+    mask = create_channels_mask(conv2_p, channels_to_remove)
     assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
-
     # Cool, so now we have a mask for pruning our channels.
+
     # Use the mask to prune
-    zeros_mask_dict[config.conv2_name + ".weight"].mask = mask
-    zeros_mask_dict[config.conv2_name + ".weight"].apply_mask(conv2_p)
+    zeros_mask_dict[pair[1] + ".weight"].mask = mask
+    zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
     all_channels = set([ch for ch in range(num_channels)])
-    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, config.conv2_name + ".weight"))
+    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight"))
     channels_removed = all_channels - nnz_channels
     logger.info("Channels removed {}".format(channels_removed))
 
     # Now, let's do the actual network thinning
-    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset)
-    conv1 = common.find_module_by_name(model, config.conv1_name)
-    logger.info(conv1)
-    logger.info(conv2)
+    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
+    conv1 = common.find_module_by_name(model, pair[0])
+
     assert conv1.out_channels == cnt_nnz_channels
     assert conv2.in_channels == cnt_nnz_channels
     assert conv1.weight.size(0) == cnt_nnz_channels
@@ -198,6 +233,10 @@ def arbitrary_channel_pruning(config, channels_to_remove):
         assert bn1.bias.size(0) == cnt_nnz_channels
         assert bn1.weight.size(0) == cnt_nnz_channels
 
+    dummy_input = torch.randn(1, 3, 32, 32)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
+    run_forward_backward(model, optimizer, dummy_input)
+
     # Let's test saving and loading a thinned model.
     # We save 3 times, and load twice, to make sure to cover some corner cases:
     #   - Make sure that after loading, the model still has hold of the thinning recipes
@@ -206,16 +245,17 @@ def arbitrary_channel_pruning(config, channels_to_remove):
     # (1)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
     model_2 = create_model(False, config.dataset, config.arch, parallel=False)
-    dummy_input = torch.randn(1, 3, 32, 32)
     model(dummy_input)
     model_2(dummy_input)
-    conv2 = common.find_module_by_name(model_2, config.conv2_name)
+    conv2 = common.find_module_by_name(model_2, pair[1])
     assert conv2 is not None
     with pytest.raises(KeyError):
         model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
     compression_scheduler = distiller.CompressionScheduler(model)
     hasattr(model, 'thinning_recipes')
 
+    run_forward_backward(model, optimizer, dummy_input)
+
     # (2)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
     model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
@@ -229,7 +269,62 @@ def arbitrary_channel_pruning(config, channels_to_remove):
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
+def test_conv_fc_interface(model=None, zeros_mask_dict=None):
+    """A special case of convolution filter-pruning occurs when the next layer is
+    fully-connected (linear).  This test is for this case and uses VGG16.
+    """
+    arch = "vgg19"
+    dataset = "imagenet"
+    ratio_to_prune = 0.1
+    conv_name = "features.34"
+    fc_name = "classifier.0"
+    dummy_input = torch.randn(1, 3, 224, 224)
+
+    if model is None or zeros_mask_dict is None:
+        model, zeros_mask_dict = common.setup_test(arch, dataset)
+
+    # Run forward and backward passes, in order to create the gradients and optimizer params
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
+    run_forward_backward(model, optimizer, dummy_input)
+
+    conv = common.find_module_by_name(model, conv_name)
+    assert conv is not None
+
+    conv_p = distiller.model_find_param(model, conv_name + ".weight")
+    assert conv_p is not None
+    assert conv_p.dim() == 4
+
+    # Create a filter-ranking pruner
+    reg_regims = {conv_name + ".weight": [ratio_to_prune, "3D"]}
+    pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
+    pruner.set_param_mask(conv_p, conv_name + ".weight", zeros_mask_dict, meta=None)
+
+    # Use the mask to prune
+    masker = zeros_mask_dict[conv_name + ".weight"]
+    assert masker is not None
+    masker.apply_mask(conv_p)
+    num_filters = conv_p.size(0)
+    expected_cnt_removed_filters = int(ratio_to_prune * conv.out_channels)
+
+    # Remove filters
+    fc = common.find_module_by_name(model, fc_name)
+    assert fc is not None
+
+    # Test thinning
+    fm_size = fc.in_features // conv.out_channels
+    num_nnz_filters = num_filters - expected_cnt_removed_filters
+    distiller.remove_filters(model, zeros_mask_dict, arch, dataset, optimizer)
+    assert conv.out_channels == num_nnz_filters
+    assert fc.in_features == fm_size * num_nnz_filters
+
+    # Run again, to make sure the optimizer and gradients shapes were updated correctly
+    run_forward_backward(model, optimizer, dummy_input)
+    run_forward_backward(model, optimizer, dummy_input)
+
+
 if __name__ == '__main__':
     test_ranked_filter_pruning()
     test_arbitrary_channel_pruning()
     test_prune_all_channels()
+    model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1)
+    test_conv_fc_interface(model, zeros_mask_dict)
-- 
GitLab