diff --git a/distiller/policy.py b/distiller/policy.py
index ca3f1c1a8de2b9fa6da9b84b75c786c85ff5938a..76d295f5f059c14867203479c883b1aa1e9d66d6 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -40,18 +40,18 @@ class ScheduledTrainingPolicy(object):
         """A new epcoh is about to begin"""
         pass
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None):
         """The forward-pass of a new mini-batch is about to begin"""
         pass
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict):
+                             regularizer_loss, zeros_mask_dict, optimizer=None):
         """The mini-batch training pass has completed the forward-pass,
         and is about to begin the backward pass.
         """
         pass
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         """The mini-batch training pass has ended"""
         pass
 
@@ -81,7 +81,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         for param_name, param in model.named_parameters():
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         for param_name, param in model.named_parameters():
             zeros_mask_dict[param_name].apply_mask(param)
 
@@ -100,11 +100,11 @@ class RegularizationPolicy(ScheduledTrainingPolicy):
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict):
+                             regularizer_loss, zeros_mask_dict, optimizer=None):
         for param_name, param in model.named_parameters():
             self.regularizer.loss(param, param_name, regularizer_loss, zeros_mask_dict)
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.regularizer.threshold_criteria is None:
             return
 
@@ -141,7 +141,7 @@ class QuantizationPolicy(ScheduledTrainingPolicy):
         self.quantizer.prepare_model()
         self.quantizer.quantize_params()
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # After the parameters are updated, quantize the parameters again.
         # (Doing this here ensures the model parameters are quantized at training completion and at validation time.)
         self.quantizer.quantize_params()
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index e4c1d140c51a91a1de604936bbdf5dbe60d0fd92..ba7a56b95b14c5232eeb7640cc17e75f9b5a8932 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -86,24 +86,24 @@ class CompressionScheduler(object):
                 self.policies[epoch].append(policy)
             assert len(self.policies[epoch]) > 0
 
-        self.sched_metadata[policy] = { 'starting_epoch' : starting_epoch,
-                                        'ending_epoch' : ending_epoch,
-                                        'frequency' : frequency }
+        self.sched_metadata[policy] = {'starting_epoch': starting_epoch,
+                                       'ending_epoch': ending_epoch,
+                                       'frequency': frequency}
 
-    def on_epoch_begin(self, epoch):
+    def on_epoch_begin(self, epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 meta = self.sched_metadata[policy]
                 meta['current_epoch'] = epoch
                 policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta)
 
-    def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch):
+    def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch,
-                                          self.zeros_mask_dict)
+                                          self.zeros_mask_dict, optimizer)
 
-    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss):
+    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None):
         # Last chance to compute the regularization loss, and optionally add it to the data loss
         regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
 
@@ -114,25 +114,26 @@ class CompressionScheduler(object):
                                             loss, regularizer_loss, self.zeros_mask_dict)
         return regularizer_loss
 
-    def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch):
+    def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         # When we get to this point, the weights are no longer masked.  This is because the optimizer
         # has just updated the weights.  So we choose to lazily apply the pruning mask, only if some
         # component is being called back.
         weights_are_masked = False
- 
+
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 if not weights_are_masked:
                     self.apply_mask()
                     weights_are_masked = True
                 policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch,
-                                        self.zeros_mask_dict)
+                                        self.zeros_mask_dict, optimizer)
 
-    def on_epoch_end(self, epoch):
+    def on_epoch_end(self, epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 meta = self.sched_metadata[policy]
                 meta['current_epoch'] = epoch
+                meta['optimizer'] = optimizer
                 policy.on_epoch_end(self.model, self.zeros_mask_dict, meta)
 
     def apply_mask(self):
diff --git a/distiller/thinning.py b/distiller/thinning.py
index eb27e52624057c6457627b96664c0c0c006b2492..94346f34750c26763e564a4150ad816f3bfaeaee 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -149,10 +149,10 @@ def resnet_cifar_remove_layers(model):
         model.module.layer_gates[layer][block][conv] = False
 
 
-def remove_channels(model, zeros_mask_dict, arch, dataset):
+def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
     sgraph = create_graph(dataset, arch)
     thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
 
 
@@ -188,10 +188,10 @@ def find_nonzero_channels_list(param, param_name):
     return nnz_channels.cpu().numpy().tolist()
 
 
-def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
+def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
         # Now actually remove the filters and channels, and make the weight tensors smaller
-        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
+        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
 
         # Stash the recipe, so that it will be serialized together with the model
         if hasattr(model, 'thinning_recipes'):
@@ -204,10 +204,10 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
         msglogger.error("Failed to create a thinning recipe")
 
 
-def remove_filters(model, zeros_mask_dict, arch, dataset):
+def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
     sgraph = create_graph(dataset, arch)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
 
 
@@ -375,7 +375,7 @@ class ChannelRemover(ScheduledTrainingPolicy):
         self.dataset = dataset
 
     def on_epoch_end(self, model, zeros_mask_dict, meta):
-        self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset)
+        self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, meta.get('optimizer', None))
 
 
 class FilterRemover(ScheduledTrainingPolicy):
@@ -387,24 +387,24 @@ class FilterRemover(ScheduledTrainingPolicy):
         self.done = False
         self.active_cb = "on_minibatch_begin"
 
-    def __apply(self, model, zeros_mask_dict):
+    def __apply(self, model, zeros_mask_dict, optimizer):
         if not self.done:
             # We want to execute the thinning function only once, not every invocation of on_minibatch_begin
-            self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset)
+            self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, optimizer=optimizer)
             self.done = True
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # We hook onto the on_minibatch_begin because we want to run after the pruner which sparsified
         # the tensors.  Pruners configure their pruning mask in on_epoch_begin, but apply the mask
         # only in on_minibatch_begin
         if self.active_cb != "on_minibatch_begin":
             return
-        self.__apply(model, zeros_mask_dict)
+        self.__apply(model, zeros_mask_dict, optimizer)
 
-    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict):
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.active_cb != "on_minibatch_end":
             return
-        self.__apply(model, zeros_mask_dict)
+        self.__apply(model, zeros_mask_dict, optimizer)
 
     def on_epoch_end(self, model, zeros_mask_dict, meta):
         # The epoch has ended and we reset the 'done' flag, so that the FilterRemover instance can be reused
@@ -416,15 +416,45 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
     # to a thinned model. For example, this is invoked when loading a model from a checkpoint.
     for i, recipe in enumerate(recipe_list):
         msglogger.info("recipe %d:" % i)
-        execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)
+        execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer=None, loaded_from_file=True)
 
 
-def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
+def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
+    """Adjust the size of the SGD vecolity-tracking tensors.
+
+    The SGD momentum update (velocity) depends on the weights, and because thinning dynamically
+    changes the weight shapes, we need to make the appropriate changes in the Optimizer,
+    or disable the momentum.
+
+    This function is brittle as it is tested on SGD only and relies on the internal representation of
+    the SGD optimizer, which can change w/o notice.
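+
+    Illustrative sketch of the intended effect (the shapes below are made up for the example):
+    if a Conv2d weight was thinned along dim=0 from shape [64, 3, 3, 3] to [60, 3, 3, 3] by keeping
+    the rows in `indices`, then the matching momentum buffer must be shrunk the same way, roughly:
+
+        buf = optimizer.state[param]['momentum_buffer']                           # shape [64, 3, 3, 3]
+        optimizer.state[param]['momentum_buffer'] = torch.index_select(buf, 0, indices)  # [60, 3, 3, 3]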
+    """
+    if optimizer is None or not isinstance(optimizer, torch.optim.SGD):
+        return False
+    for group in optimizer.param_groups:
+        momentum = group.get('momentum', 0)
+        if momentum == 0:
+            continue
+        for p in group['params']:
+            if id(p) != id(param):
+                continue
+            param_state = optimizer.state[p]
+            if 'momentum_buffer' in param_state:
+                param_state['momentum_buffer'] = torch.index_select(param_state['momentum_buffer'], dim, indices)
+                if new_shape is not None:
+                    msglogger.info("optimizer_thinning: new shape {}".format(new_shape))
+                    param_state['momentum_buffer'] = param_state['momentum_buffer'].resize_(*new_shape)
+                return True
+    return False
+
+
+def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_from_file=False):
     """Apply a thinning recipe to a model.
 
     This will remove filters and channels, as well as handle batch-normalization parameter
     adjustment, and thinning of weight tensors.
     """
+
     layers = {}
     for name, m in model.named_modules():
         layers[name] = m
@@ -458,12 +488,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
                 # Check if we're trying to trim a parameter that is already "thin"
                 if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(selection_view, dim, indices)
-
-                if param.grad is not None:
-                    # We also need to change the dimensions of the gradient tensor.
-                    grad_selection_view = param.grad.resize_(*directive[2])
-                    if grad_selection_view.size(dim) != len_indices:
-                        param.grad = torch.index_select(grad_selection_view, dim, indices)
+                    if param.grad is not None:
+                        # We also need to change the dimensions of the gradient tensor.
+                        grad_selection_view = param.grad.resize_(*directive[2])
+                        if grad_selection_view.size(dim) != len_indices:
+                            param.grad = torch.index_select(grad_selection_view, dim, indices)
+                            if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                                msglogger.info("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
+                                               format(param_name, dim, len_indices, directive[3]))
 
                 param.data = param.view(*directive[3])
                 if param.grad is not None:
@@ -471,12 +503,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
             else:
                 if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(param.data, dim, indices)
+                    msglogger.info("[thinning] changed param {} shape: {}".format(param_name, len_indices))
                 # We also need to change the dimensions of the gradient tensor.
                 # If we have not done a backward pass thus far, then the gradient will
                 # not exist, and therefore won't need to be re-dimensioned.
                 if param.grad is not None and param.grad.size(dim) != len_indices:
-                        param.grad = torch.index_select(param.grad, dim, indices)
-                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len_indices))
+                    param.grad = torch.index_select(param.grad, dim, indices)
+                    if optimizer_thinning(optimizer, param, dim, indices):
+                        msglogger.info("Updated velocity buffer %s" % param_name)
 
             if not loaded_from_file:
                 # If the masks are loaded from a checkpoint file, then we don't need to change
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index f33e171980a20646649a5b0dc3dd82b9b957f99c..c6db0b880ba27097f38de87ce7aa5441f26190cb 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -223,10 +223,6 @@ def main():
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
             model, chkpt_file=args.resume)
 
-        if 'resnet' in args.arch and 'preact' not in args.arch and 'cifar' in args.arch:
-            distiller.resnet_cifar_remove_layers(model)
-            #model = distiller.resnet_cifar_remove_channels(model, compression_scheduler.zeros_mask_dict)
-
     # Define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
@@ -320,11 +316,11 @@ def main():
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
                               ('Top5', top5)]))
-        distiller.log_training_progress(stats, None,epoch, steps_completed=0, total_steps=1,
+        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1,
                                         log_freq=1, loggers=[tflogger])
 
         if compression_scheduler:
-            compression_scheduler.on_epoch_end(epoch)
+            compression_scheduler.on_epoch_end(epoch, optimizer)
 
         # remember best top1 and save checkpoint
         is_best = top1 > best_top1
@@ -339,8 +335,8 @@ def main():
 def train(train_loader, model, criterion, optimizer, epoch,
           compression_scheduler, loggers, print_freq, log_params_hist):
     """Training loop for one epoch."""
-    losses = {'objective_loss'   : tnt.AverageValueMeter(),
-              'regularizer_loss' : tnt.AverageValueMeter()}
+    losses = {'objective_loss':   tnt.AverageValueMeter(),
+              'regularizer_loss': tnt.AverageValueMeter()}
     if compression_scheduler is None:
         # Initialize the regularizer loss to zero
         losses['regularizer_loss'].add(0)
@@ -368,7 +364,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
 
         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
-            compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch)
+            compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer)
         output = model(input_var)
         loss = criterion(output, target_var)
 
@@ -378,7 +374,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
 
         if compression_scheduler:
             # Before running the backward phase, we add any regularization loss computed by the scheduler
-            regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss)
+            regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss, optimizer)
             loss += regularizer_loss
             losses['regularizer_loss'].add(regularizer_loss.item())
 
@@ -387,7 +383,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
         loss.backward()
         optimizer.step()
         if compression_scheduler:
-            compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch)
+            compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch, optimizer)
 
         # measure elapsed time
         batch_time.add(time.time() - end)
diff --git a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml
index 9efc6dc69c1362538721c6b723324d0750f2b1e3..cb20dd1ea5aea4e8d58803ce14b35cffaf10990b 100755
--- a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml
@@ -2,7 +2,7 @@
 # This schedule performs 3D (filter-wise) regularization of some of the convolution layers, together with
 # element-wise pruning using sensitivity-pruning.
 #
-# time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained --momentum=0
+# time python3 compress_classifier.py -a=vgg19 -p=50 ../../../data.imagenet --epochs=10 --lr=0.00001 --compress=../pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml --pretrained
 #
 
 
diff --git a/examples/ssl/ssl_4D-removal_training.yaml b/examples/ssl/ssl_4D-removal_training.yaml
index 0fcbe0f8dd95e569d9dab2dcf2440f528c587981..ef1b120452e8b07a8e2e595c4d76589bf2bec93a 100755
--- a/examples/ssl/ssl_4D-removal_training.yaml
+++ b/examples/ssl/ssl_4D-removal_training.yaml
@@ -1,3 +1,13 @@
+#
+# THIS IS CURRENTLY BROKEN!
+# There are two challenges to getting this to work:
+#   1. If we want to remove layers, we should use a thinning recipe that leaves instructions with the model.
+#      Otherwise, it is very hard to decide later whether the layers should be removed.
+#   2. We need a generic solution for removing layers.  Perhaps a new module type which wraps a layer and has a bypass-gate.
+#
+# If you still want to use this, until we implement a proper solution, invoke thinning.resnet_cifar_remove_layers from
+# compress_classifier.py.
+#
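+# A rough sketch of that manual workaround (illustrative; the exact call site in compress_classifier.py may differ):
+#
+#     model, compression_scheduler, start_epoch = apputils.load_checkpoint(model, chkpt_file=args.resume)
+#     distiller.resnet_cifar_remove_layers(model)
+#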
 # We used this schedule to train CIFAR10-ResNet20 from scratch with SSL.
 # After running this schedule, use ssl_4D-removal_finetuning.yaml to "physically" remove the layers and fine-tune the smaller model.
 #
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index b6c6a06e1a91c50b6ac2509f1ab28041651bfc19..b16e9012af1941d7cbbfbe86b106c8927c9cc494 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -35,7 +35,7 @@ fh = logging.FileHandler('test.log')
 logger = logging.getLogger()
 logger.addHandler(fh)
 
-NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_name")
+NetConfig = namedtuple("test_config", "arch dataset bn_name module_pairs")
 
 
 #
@@ -43,20 +43,34 @@ NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_nam
 #
 def simplenet():
     return NetConfig(arch="simplenet_cifar", dataset="cifar10",
-                     conv1_name="conv1", conv2_name="conv2",
+                     module_pairs=[("conv1", "conv2")],
                      bn_name=None)
 
 
 def resnet20_cifar():
     return NetConfig(arch="resnet20_cifar", dataset="cifar10",
-                     conv1_name="layer1.0.conv1", conv2_name="layer1.0.conv2",
+                     module_pairs=[("layer1.0.conv1", "layer1.0.conv2")],
                      bn_name="layer1.0.bn1")
 
 
+def vgg19_imagenet():
+    return NetConfig(arch="vgg19", dataset="imagenet",
+                     module_pairs=[("features.21", "features.23"),
+                                   ("features.23", "features.25"),
+                                   ("features.25", "features.28"),
+                                   ("features.28", "features.30"),
+                                   ("features.30", "features.32"),
+                                   ("features.32", "features.34")],
+                     bn_name=None)
+
+
 def test_ranked_filter_pruning():
     ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.1)
     ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.5)
     ranked_filter_pruning(simplenet(), ratio_to_prune=0.5)
+    ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1)
+    model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1)
+    test_conv_fc_interface(model, zeros_mask_dict)
 
 
 def test_prune_all_filters():
@@ -76,44 +90,46 @@ def ranked_filter_pruning(config, ratio_to_prune):
     """
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset)
 
-    # Test that we can access the weights tensor of the first convolution in layer 1
-    conv1_p = distiller.model_find_param(model, config.conv1_name + ".weight")
-    assert conv1_p is not None
-    num_filters = conv1_p.size(0)
-
-    # Test that there are no zero-filters
-    assert distiller.sparsity_3D(conv1_p) == 0.0
-
-    # Create a filter-ranking pruner
-    reg_regims = {config.conv1_name + ".weight": [ratio_to_prune, "3D"]}
-    pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
-    pruner.set_param_mask(conv1_p, config.conv1_name + ".weight", zeros_mask_dict, meta=None)
-
-    conv1 = common.find_module_by_name(model, config.conv1_name)
-    assert conv1 is not None
-    # Test that the mask has the correct fraction of filters pruned.
-    # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters
-    expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
-    expected_pruning = expected_cnt_removed_filters / conv1.out_channels
-    masker = zeros_mask_dict[config.conv1_name + ".weight"]
-    assert masker is not None
-    assert distiller.sparsity_3D(masker.mask) == expected_pruning
-
-    # Use the mask to prune
-    assert distiller.sparsity_3D(conv1_p) == 0
-    masker.apply_mask(conv1_p)
-    assert distiller.sparsity_3D(conv1_p) == expected_pruning
-
-    # Remove filters
-    conv2 = common.find_module_by_name(model, config.conv2_name)
-    assert conv2 is not None
-    assert conv1.out_channels == num_filters
-    assert conv2.in_channels == num_filters
+    for pair in config.module_pairs:
+        # Test that we can access the weights tensor of the first convolution in the pair
+        conv1_p = distiller.model_find_param(model, pair[0] + ".weight")
+        assert conv1_p is not None
+        num_filters = conv1_p.size(0)
+
+        # Test that there are no zero-filters
+        assert distiller.sparsity_3D(conv1_p) == 0.0
+
+        # Create a filter-ranking pruner
+        reg_regims = {pair[0] + ".weight": [ratio_to_prune, "3D"]}
+        pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
+        pruner.set_param_mask(conv1_p, pair[0] + ".weight", zeros_mask_dict, meta=None)
+
+        conv1 = common.find_module_by_name(model, pair[0])
+        assert conv1 is not None
+        # Test that the mask has the correct fraction of filters pruned.
+        # For example, if we ask for 10% but there are only 16 filters, we have to settle for 1/16 of the filters.
+        expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
+        expected_pruning = expected_cnt_removed_filters / conv1.out_channels
+        masker = zeros_mask_dict[pair[0] + ".weight"]
+        assert masker is not None
+        assert distiller.sparsity_3D(masker.mask) == expected_pruning
+
+        # Use the mask to prune
+        assert distiller.sparsity_3D(conv1_p) == 0
+        masker.apply_mask(conv1_p)
+        assert distiller.sparsity_3D(conv1_p) == expected_pruning
+
+        # Remove filters
+        conv2 = common.find_module_by_name(model, pair[1])
+        assert conv2 is not None
+        assert conv1.out_channels == num_filters
+        assert conv2.in_channels == num_filters
 
     # Test thinning
-    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset)
+    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
     assert conv1.out_channels == num_filters - expected_cnt_removed_filters
     assert conv2.in_channels == num_filters - expected_cnt_removed_filters
+    return model, zeros_mask_dict
 
 
 def test_arbitrary_channel_pruning():
@@ -134,6 +150,40 @@ def test_channel_pruning_conv_bias():
     arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 1])
 
 
+def create_channels_mask(conv_p, channels_to_remove):
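+    """Build a 4D weights mask that zeros out entire input channels of a convolution.
+
+    For example (illustrative), for conv_p of shape [F, C, Kh, Kw] and channels_to_remove=[0, 1],
+    the returned mask is all ones except mask[:, 0, :, :] and mask[:, 1, :, :], which are zero.
+    """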
+    assert conv_p.dim() == 4
+    num_filters = conv_p.size(0)
+    num_channels = conv_p.size(1)
+    kernel_height = conv_p.size(2)
+    kernel_width = conv_p.size(3)
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels, with all but our specified channels set to one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv_p.shape
+    return mask
+
+
+def run_forward_backward(model, optimizer, dummy_input):
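+    """Run a single training step: forward pass, backward pass, and optimizer step.
+
+    This materializes parameter gradients and optimizer state (e.g. SGD momentum buffers), which the
+    thinning tests rely on to verify that gradient and optimizer shapes are updated correctly.
+    """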
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+    model.train()
+    output = model(dummy_input)
+    target = torch.LongTensor(1).random_(2)
+    loss = criterion(output, target)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+
 def arbitrary_channel_pruning(config, channels_to_remove):
     """Test removal of arbitrary channels.
 
@@ -143,49 +193,34 @@ def arbitrary_channel_pruning(config, channels_to_remove):
     """
     model, zeros_mask_dict = common.setup_test(config.arch, config.dataset)
 
-    conv2 = common.find_module_by_name(model, config.conv2_name)
+    assert len(config.module_pairs) == 1   # This is a temporary restriction on the test
+    pair = config.module_pairs[0]
+    conv2 = common.find_module_by_name(model, pair[1])
     assert conv2 is not None
 
     # Test that we can access the weights tensor of the second convolution in the pair
-    conv2_p = distiller.model_find_param(model, config.conv2_name + ".weight")
+    conv2_p = distiller.model_find_param(model, pair[1] + ".weight")
     assert conv2_p is not None
 
     assert conv2_p.dim() == 4
-    num_filters = conv2_p.size(0)
     num_channels = conv2_p.size(1)
-    kernel_height = conv2_p.size(2)
-    kernel_width = conv2_p.size(3)
     cnt_nnz_channels = num_channels - len(channels_to_remove)
-
-    # Let's build our 4D mask.
-    # We start with a 1D mask of channels, with all but our specified channels set to one
-    channels = torch.ones(num_channels)
-    for ch in channels_to_remove:
-        channels[ch] = 0
-
-    # Now let's expand back up to a 4D mask
-    mask = channels.expand(num_filters, num_channels)
-    mask.unsqueeze_(-1)
-    mask.unsqueeze_(-1)
-    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
-
-    assert mask.shape == conv2_p.shape
+    mask = create_channels_mask(conv2_p, channels_to_remove)
     assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
-
     # Cool, so now we have a mask for pruning our channels.
+
     # Use the mask to prune
-    zeros_mask_dict[config.conv2_name + ".weight"].mask = mask
-    zeros_mask_dict[config.conv2_name + ".weight"].apply_mask(conv2_p)
+    zeros_mask_dict[pair[1] + ".weight"].mask = mask
+    zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
     all_channels = set([ch for ch in range(num_channels)])
-    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, config.conv2_name + ".weight"))
+    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight"))
     channels_removed = all_channels - nnz_channels
     logger.info("Channels removed {}".format(channels_removed))
 
     # Now, let's do the actual network thinning
-    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset)
-    conv1 = common.find_module_by_name(model, config.conv1_name)
-    logger.info(conv1)
-    logger.info(conv2)
+    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
+    conv1 = common.find_module_by_name(model, pair[0])
+
     assert conv1.out_channels == cnt_nnz_channels
     assert conv2.in_channels == cnt_nnz_channels
     assert conv1.weight.size(0) == cnt_nnz_channels
@@ -198,6 +233,10 @@ def arbitrary_channel_pruning(config, channels_to_remove):
         assert bn1.bias.size(0) == cnt_nnz_channels
         assert bn1.weight.size(0) == cnt_nnz_channels
 
+    dummy_input = torch.randn(1, 3, 32, 32)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
+    run_forward_backward(model, optimizer, dummy_input)
+
     # Let's test saving and loading a thinned model.
     # We save 3 times, and load twice, to make sure to cover some corner cases:
     #   - Make sure that after loading, the model still has hold of the thinning recipes
@@ -206,16 +245,17 @@ def arbitrary_channel_pruning(config, channels_to_remove):
     # (1)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
     model_2 = create_model(False, config.dataset, config.arch, parallel=False)
-    dummy_input = torch.randn(1, 3, 32, 32)
     model(dummy_input)
     model_2(dummy_input)
-    conv2 = common.find_module_by_name(model_2, config.conv2_name)
+    conv2 = common.find_module_by_name(model_2, pair[1])
     assert conv2 is not None
     with pytest.raises(KeyError):
         model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
     compression_scheduler = distiller.CompressionScheduler(model)
     assert hasattr(model, 'thinning_recipes')
 
+    run_forward_backward(model, optimizer, dummy_input)
+
     # (2)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
     model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
@@ -229,7 +269,62 @@ def arbitrary_channel_pruning(config, channels_to_remove):
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
+def test_conv_fc_interface(model=None, zeros_mask_dict=None):
+    """A special case of convolution filter-pruning occurs when the next layer is
+    fully-connected (linear).  This test covers this case and uses VGG19.
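+
+    When filters are removed from the last convolution, the in_features of the following Linear layer
+    must shrink accordingly (by the per-channel feature-map size for each removed filter); this is the
+    behavior verified below.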
+    """
+    arch = "vgg19"
+    dataset = "imagenet"
+    ratio_to_prune = 0.1
+    conv_name = "features.34"
+    fc_name = "classifier.0"
+    dummy_input = torch.randn(1, 3, 224, 224)
+
+    if model is None or zeros_mask_dict is None:
+        model, zeros_mask_dict = common.setup_test(arch, dataset)
+
+    # Run forward and backward passes, in order to create the gradients and optimizer params
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
+    run_forward_backward(model, optimizer, dummy_input)
+
+    conv = common.find_module_by_name(model, conv_name)
+    assert conv is not None
+
+    conv_p = distiller.model_find_param(model, conv_name + ".weight")
+    assert conv_p is not None
+    assert conv_p.dim() == 4
+
+    # Create a filter-ranking pruner
+    reg_regims = {conv_name + ".weight": [ratio_to_prune, "3D"]}
+    pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
+    pruner.set_param_mask(conv_p, conv_name + ".weight", zeros_mask_dict, meta=None)
+
+    # Use the mask to prune
+    masker = zeros_mask_dict[conv_name + ".weight"]
+    assert masker is not None
+    masker.apply_mask(conv_p)
+    num_filters = conv_p.size(0)
+    expected_cnt_removed_filters = int(ratio_to_prune * conv.out_channels)
+
+    # Remove filters
+    fc = common.find_module_by_name(model, fc_name)
+    assert fc is not None
+
+    # Test thinning
+    fm_size = fc.in_features // conv.out_channels
+    num_nnz_filters = num_filters - expected_cnt_removed_filters
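+    # Each output channel of the final conv contributes fm_size activations to the classifier input,
+    # so after thinning we expect fc.in_features == fm_size * num_nnz_filters.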
+    distiller.remove_filters(model, zeros_mask_dict, arch, dataset, optimizer)
+    assert conv.out_channels == num_nnz_filters
+    assert fc.in_features == fm_size * num_nnz_filters
+
+    # Run again, to make sure the optimizer and gradients shapes were updated correctly
+    run_forward_backward(model, optimizer, dummy_input)
+    run_forward_backward(model, optimizer, dummy_input)
+
+
 if __name__ == '__main__':
     test_ranked_filter_pruning()
     test_arbitrary_channel_pruning()
     test_prune_all_channels()
+    model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1)
+    test_conv_fc_interface(model, zeros_mask_dict)