diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index a7509f57fc8f25686d333144f321d0f557ada584..83ae242b4c0d8c9281e3d87226a0d005ab79341a 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -50,7 +50,8 @@ def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=Non
     checkpoint['state_dict'] = model.state_dict()
     if best_top1 is not None:
         checkpoint['best_top1'] = best_top1
-    checkpoint['optimizer'] = optimizer.state_dict()
+    if optimizer is not None:
+        checkpoint['optimizer'] = optimizer.state_dict()
     if scheduler is not None:
         checkpoint['compression_sched'] = scheduler.state_dict()
     if hasattr(model, 'thinning_recipes'):
@@ -86,18 +87,21 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
             msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                            checkpoint['epoch'])

-            if 'thinning_recipes' in checkpoint:
-                msglogger.info("Loaded a thinning recipe from the checkpoint")
-                distiller.execute_thinning_recipes_list(model,
-                                                        compression_scheduler.zeros_mask_dict,
-                                                        checkpoint['thinning_recipes'])
+        if 'thinning_recipes' in checkpoint:
+            if 'compression_sched' not in checkpoint:
+                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+            msglogger.info("Loaded a thinning recipe from the checkpoint")
+            # Cache the recipes in case we need them later
+            model.thinning_recipes = checkpoint['thinning_recipes']
+            distiller.execute_thinning_recipes_list(model,
+                                                    compression_scheduler.zeros_mask_dict,
+                                                    model.thinning_recipes)
         else:
             msglogger.info("Warning: compression schedule data does not exist in the checkpoint")

         msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
         model.load_state_dict(checkpoint['state_dict'])
-
         return model, compression_scheduler, start_epoch
     else:
         msglogger.info("Error: no checkpoint found at %s", chkpt_file)
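
Taken together, these two hunks change the checkpointing contract: save_checkpoint() now tolerates optimizer=None by simply omitting the 'optimizer' key, and load_checkpoint() refuses a checkpoint that carries thinning recipes without a compression schedule. A minimal usage sketch of the resulting flow, assuming a Distiller checkout on PYTHONPATH (the calls and the default 'checkpoint.pth.tar' filename mirror the new test below, not any other public API):

    import distiller
    from models import create_model
    from apputils import save_checkpoint, load_checkpoint

    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    scheduler = distiller.CompressionScheduler(model)

    # optimizer=None is now accepted: the 'optimizer' key is simply omitted.
    # Saving a thinned model *without* a scheduler would store thinning_recipes
    # but no compression_sched, and the next load_checkpoint() would raise KeyError.
    save_checkpoint(epoch=0, arch='resnet20_cifar', model=model,
                    optimizer=None, scheduler=scheduler)
    model, scheduler, start_epoch = load_checkpoint(model, 'checkpoint.pth.tar')
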
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 4a2cfdb18a22661463d3ad89ade5794e6b46f44f..caa3192528b6a6518beafbc5cc349b2a83e734cb 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -66,6 +66,7 @@ These tuples can have 2 values, or 4 values.
 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'ChannelRemover', 'remove_channels',
            'FilterRemover', 'remove_filters',
+           'find_nonzero_channels',
           'execute_thinning_recipes_list']

 def create_graph(dataset, arch):
@@ -266,7 +267,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         predecessors = [denormalize_layer_name(model, predecessor) for predecessor in predecessors]
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(thinning_recipe, layer_name, key='out_channels', val=len(nonzero_channels))
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels))

             # Now remove channels from the weights tensor of the successor conv
             append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
@@ -415,7 +416,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
    # Invoke this function when you want to use a list of thinning recipes to convert a programmed model
    # to a thinned model. For example, this is invoked when loading a model from a checkpoint.
    for i, recipe in enumerate(recipe_list):
-        msglogger.info("recipe %d" % i)
+        msglogger.info("recipe %d:" % i)
        execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)

def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
@@ -431,10 +432,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
    for layer_name, directives in recipe.modules.items():
        for attr, val in directives.items():
            if attr in ['running_mean', 'running_var']:
-                msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(val[1])))
-                setattr(layers[layer_name], attr,
-                        torch.index_select(getattr(layers[layer_name], attr),
-                                           dim=val[0], index=val[1]))
+                running = getattr(layers[layer_name], attr)
+                dim_to_trim = val[0]
+                indices_to_select = val[1]
+                # Check if we're trying to trim a parameter that is already "thin"
+                if running.size(dim_to_trim) != len(indices_to_select):
+                    msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(indices_to_select)))
+                    setattr(layers[layer_name], attr,
+                            torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
            else:
                msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
                setattr(layers[layer_name], attr, val)
@@ -448,28 +453,32 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
            indices = directive[1]
            if len(directive) == 4:  # TODO: this code is hard to follow
                selection_view = param.view(*directive[2])
-                param.data = torch.index_select(selection_view, dim, indices)
+                # Check if we're trying to trim a parameter that is already "thin"
+                if param.data.size(dim) != len(indices):
+                    param.data = torch.index_select(selection_view, dim, indices)

                if param.grad is not None:
                    # We also need to change the dimensions of the gradient tensor.
                    grad_selection_view = param.grad.resize_(*directive[2])
-                    param.grad = torch.index_select(grad_selection_view, dim, indices)
+                    if grad_selection_view.size(dim) != len(indices):
+                        param.grad = torch.index_select(grad_selection_view, dim, indices)

                param.data = param.view(*directive[3])
                if param.grad is not None:
                    param.grad = param.grad.resize_(*directive[3])
            else:
-                param.data = torch.index_select(param.data, dim, indices)
+                if param.data.size(dim) != len(indices):
+                    param.data = torch.index_select(param.data, dim, indices)

                # We also need to change the dimensions of the gradient tensor.
                # If have not done a backward-pass thus far, then the gradient will
                # not exist, and therefore won't need to be re-dimensioned.
-                if param.grad is not None:
-                    param.grad = torch.index_select(param.grad, dim, indices)
+                if param.grad is not None and param.grad.size(dim) != len(indices):
+                    param.grad = torch.index_select(param.grad, dim, indices)
                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len(indices)))

            if not loaded_from_file:
                # If the masks are loaded from a checkpoint file, then we don't need to change
                # their shape, because they are already correctly shaped
                mask = zeros_mask_dict[param_name].mask
-                if mask is not None:  # and (mask.size(dim) != len(indices)):
+                if mask is not None and (mask.size(dim) != len(indices)):
                    zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
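
The recurring guard in this file — skip torch.index_select() when the tensor is already the target size along the trimmed dimension — is what makes recipe replay idempotent, so a recipe list can be executed against a model whose tensors were already thinned before the checkpoint was saved. A standalone sketch of the pattern (trim() is a hypothetical helper, not part of the patch):

    import torch

    def trim(tensor, dim, indices):
        # Skip tensors that are already "thin" along this dimension, e.g. when
        # replaying a thinning recipe on a model that was saved after thinning.
        if tensor.size(dim) == len(indices):
            return tensor
        return torch.index_select(tensor, dim, indices)

    weights = torch.randn(16, 3, 3, 3)                     # 16 filters of a conv layer
    keep = torch.tensor([i for i in range(16) if i != 5])  # drop filter 5
    weights = trim(weights, 0, keep)                       # 16 -> 15 filters
    weights = trim(weights, 0, keep)                       # no-op: already 15
    assert weights.size(0) == 15
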
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 93f3c77a87e273db4971f3e827eb367b32358cd4..fa01f22f7076592b686f9c6731e99ef6ea567b10 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -25,7 +25,7 @@
 import distiller
 import pytest
 from models import ALL_MODEL_NAMES, create_model
-from apputils import SummaryGraph, onnx_name_2_pytorch_name
+from apputils import SummaryGraph, onnx_name_2_pytorch_name, save_checkpoint, load_checkpoint

 # Logging configuration
 logging.basicConfig(level=logging.INFO)
@@ -39,8 +39,9 @@ def find_module_by_name(model, module_to_find):
             return m
     return None

-def test_ranked_filter_pruning():
-    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
+
+def setup_test(arch, dataset):
+    model = create_model(False, dataset, arch, parallel=False)
     assert model is not None

     # Create the masks
@@ -48,6 +49,10 @@
     for name, param in model.named_parameters():
         masker = distiller.ParameterMasker(name)
         zeros_mask_dict[name] = masker
+    return model, zeros_mask_dict
+
+def test_ranked_filter_pruning():
+    model, zeros_mask_dict = setup_test("resnet20_cifar", "cifar10")

     # Test that we can access the weights tensor of the first convolution in layer 1
     conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
     assert conv1_p is not None
@@ -68,18 +73,110 @@
     expected_pruning = int(0.1 * conv1.out_channels) / conv1.out_channels
     assert distiller.sparsity_3D(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning

-    # Test masker
+    # Use the mask to prune
     assert distiller.sparsity_3D(conv1_p) == 0
     zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p)
     assert distiller.sparsity_3D(conv1_p) == expected_pruning

     # Remove filters
-    assert conv1.out_channels == 16
     conv2 = find_module_by_name(model, "layer1.0.conv2")
     assert conv2 is not None
-    assert conv1.in_channels == 16
+    assert conv1.out_channels == 16
+    assert conv2.in_channels == 16

     # Test thinning
     distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10")
     assert conv1.out_channels == 15
     assert conv2.in_channels == 15
+
+
+def test_arbitrary_channel_pruning():
+    ARCH = "resnet20_cifar"
+    DATASET = "cifar10"
+
+    model, zeros_mask_dict = setup_test(ARCH, DATASET)
+
+    conv2 = find_module_by_name(model, "layer1.0.conv2")
+    assert conv2 is not None
+
+    # Test that we can access the weights tensor of the first convolution in layer 1
+    conv2_p = distiller.model_find_param(model, "layer1.0.conv2.weight")
+    assert conv2_p is not None
+
+    assert conv2_p.dim() == 4
+    num_filters = conv2_p.size(0)
+    num_channels = conv2_p.size(1)
+    kernel_height = conv2_p.size(2)
+    kernel_width = conv2_p.size(3)
+
+    channels_to_remove = [0, 2]
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels, with all but our specified channels set to one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv2_p.shape
+    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
+
+    # Cool, so now we have a mask for pruning our channels.
+    # Use the mask to prune
+    zeros_mask_dict["layer1.0.conv2.weight"].mask = mask
+    zeros_mask_dict["layer1.0.conv2.weight"].apply_mask(conv2_p)
+    all_channels = set([ch for ch in range(num_channels)])
+    channels_removed = all_channels - set(distiller.find_nonzero_channels(conv2_p, "layer1.0.conv2.weight"))
+    logger.info(channels_removed)
+
+    # Now, let's do the actual network thinning
+    distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
+    conv1 = find_module_by_name(model, "layer1.0.conv1")
+    logger.info(conv1)
+    logger.info(conv2)
+    assert conv1.out_channels == 14
+    assert conv2.in_channels == 14
+    assert conv1.weight.size(0) == 14
+    assert conv2.weight.size(1) == 14
+    bn1 = find_module_by_name(model, "layer1.0.bn1")
+    assert bn1.running_var.size(0) == 14
+    assert bn1.running_mean.size(0) == 14
+    assert bn1.num_features == 14
+    assert bn1.bias.size(0) == 14
+    assert bn1.weight.size(0) == 14
+
+    # Let's test saving and loading a thinned model.
+    # We save 3 times, and load twice, to make sure to cover some corner cases:
+    #  - Make sure that after loading, the model still has hold of the thinning recipes
+    #  - Make sure that after a 2nd load, there is no problem loading (in this case, the
+    #    tensors are already thin, so this is a new flow)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
+    model_2 = create_model(False, DATASET, ARCH, parallel=False)
+    dummy_input = torch.randn(1, 3, 32, 32)
+    model(dummy_input)
+    model_2(dummy_input)
+    conv2 = find_module_by_name(model_2, "layer1.0.conv2")
+    assert conv2 is not None
+    with pytest.raises(KeyError):
+        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+
+    compression_scheduler = distiller.CompressionScheduler(model)
+    assert hasattr(model, 'thinning_recipes')
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_arbitrary_channel_pruning - Done")
+
+    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    logger.info("test_arbitrary_channel_pruning - Done 2")
+
+
+if __name__ == '__main__':
+    test_arbitrary_channel_pruning()
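
The 1D-to-4D mask expansion in test_arbitrary_channel_pruning can be hard to picture. Here is a toy-dimension sketch of the same construction (all names are local to this example): a per-channel vector is broadcast across the filter dimension, then across the kernel's spatial dimensions, yielding a [filters, channels, k, k] mask whose zeroed channels survive multiplication as all-zero slices — exactly what find_nonzero_channels and remove_channels later detect and strip.

    import torch

    num_filters, num_channels, k = 2, 4, 3
    channels = torch.ones(num_channels)
    channels[0] = channels[2] = 0                  # 1D channel mask: [0., 1., 0., 1.]

    mask = channels.expand(num_filters, num_channels)   # broadcast over filters -> [F, C]
    mask = mask.unsqueeze(-1).unsqueeze(-1)             # [F, C, 1, 1]
    mask = mask.expand(num_filters, num_channels, k, k).contiguous()

    weights = torch.randn(num_filters, num_channels, k, k)
    pruned = weights * mask
    assert pruned[:, 0].abs().sum().item() == 0    # channel 0 is all-zero in every filter
    assert pruned[:, 2].abs().sum().item() == 0    # and so is channel 2

This is why the test then expects conv1.out_channels and conv2.in_channels to drop from 16 to 14: two of conv2's input channels were masked, so thinning removes them and the corresponding filters of the preceding convolution.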