diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index f45026b49c3669352622d750779c9adc99dd4c5d..d317bdead0e235680c9088152f8eb28770545b84 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -198,8 +198,8 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
     if not model:
         model = _create_model_from_ckpt()
         if not model:
-            raise ValueError("You didn't provide a model, and the checkpoint doesn't contain"
-                             "enough information to create one")
+            raise ValueError("You didn't provide a model, and the checkpoint %s doesn't contain "
+                             "enough information to create one" % chkpt_file)
 
     checkpoint_epoch = checkpoint.get('epoch', None)
     start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
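
A note on the new error message: exception constructors do not apply logger-style "%s"
interpolation to extra arguments, so the checkpoint path has to be formatted into the
message before ValueError is raised. A minimal sketch (the path below is a hypothetical
example, not taken from the patch):

```python
chkpt_file = "logs/resnet20_checkpoint.pth.tar"  # hypothetical path, for illustration only

# Passing chkpt_file as a second constructor argument would leave the literal "%s"
# in the message; interpolating with the % operator (or str.format) puts the
# offending file name into the exception text.
msg = ("You didn't provide a model, and the checkpoint %s doesn't contain "
       "enough information to create one" % chkpt_file)
raise ValueError(msg)
```
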
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 0d790c21817afe251fe9c3ef340eee05e27eece7..74ee6f73a07a684e0f663007376f5a86b3afc6b3 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -60,9 +60,37 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'StructureRemover',
            'ChannelRemover', 'remove_channels',
            'FilterRemover',  'remove_filters',
+           'contract_model',
            'execute_thinning_recipes_list', 'get_normalized_recipe']
 
 
+def contract_model(model, zeros_mask_dict, arch, dataset, optimizer):
+    """Contract a model by removing filters and channels
+
+    The algorithm searches for weight filters and channels whose coefficients
+    are all zero, and shrinks the model by removing these filters and channels
+    from the model definition, along with any related parameters.
+    """
+    remove_filters(model, zeros_mask_dict, arch, dataset, optimizer)
+    remove_channels(model, zeros_mask_dict, arch, dataset, optimizer)
+
+
+def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
+    """Contract a model by removing weight channels"""
+    sgraph = _create_graph(dataset, model)
+    thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
+    return model
+
+
+def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
+    """Contract a model by removing weight filters"""
+    sgraph = _create_graph(dataset, model)
+    thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
+    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
+    return model
+
+
 def _create_graph(dataset, model):
     dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
@@ -75,7 +103,7 @@ def get_normalized_recipe(recipe):
         )
 
 
-def directives_equal(d1, d2):
+def _directives_equal(d1, d2):
     """Test if two directives are equal"""
     if len(d1) != len(d2):
         return False
@@ -88,7 +116,7 @@ def directives_equal(d1, d2):
     raise ValueError("Unsupported directive length")
 
 
-def append_param_directive(thinning_recipe, param_name, directive):
+def _append_param_directive(thinning_recipe, param_name, directive):
     """Add a parameter directive to a recipe.
 
     Parameter directives contain instructions for changing the physical shape of parameters.
@@ -98,14 +126,14 @@ def append_param_directive(thinning_recipe, param_name, directive):
         # Duplicate parameter directives are rooted out because they can create erroneous conditions.
         # For example, if the first directive changes the shape of the parameter, a second
         # directive will cause an exception.
-        if directives_equal(d, directive):
+        if _directives_equal(d, directive):
             return
     msglogger.debug("\t[recipe] param_directive for {} = {}".format(param_name, directive))
     param_directives.append(directive)
     thinning_recipe.parameters[param_name] = param_directives
 
 
-def append_module_directive(thinning_recipe, module_name, key, val):
+def _append_module_directive(thinning_recipe, module_name, key, val):
     """Add a module directive to a recipe.
 
     Module directives contain instructions for changing the attributes of
@@ -117,7 +145,7 @@ def append_module_directive(thinning_recipe, module_name, key, val):
     thinning_recipe.modules[module_name] = mod_directive
 
 
-def append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_features, thin_features):
+def _append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_features, thin_features):
     """Adjust the sizes of the parameters of a BatchNormalization layer.
 
     This function is invoked after the Convolution layer preceding a BN layer has
@@ -142,13 +170,6 @@ def append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_feat
     thinning_recipe.parameters[bn_name+'.bias'] = [(0, thin_features)]
 
 
-def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = _create_graph(dataset, model)
-    thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
-    return model
-
-
 def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
         # Now actually remove the filters, channels and make the weight tensors smaller
@@ -165,13 +186,6 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
         msglogger.error("Failed to create a thinning recipe")
 
 
-def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = _create_graph(dataset, model)
-    thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
-    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
-    return model
-
-
 def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
     """Create a recipe for removing channels from Convolution layers.
 
@@ -187,14 +201,14 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         # in the convolutional layer
         assert isinstance(layers[layer_name], (torch.nn.modules.Conv2d, torch.nn.modules.Linear))
 
-        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=nnz_channels)
+        _append_module_directive(thinning_recipe, layer_name, key='in_channels', val=nnz_channels)
 
         # Select only the non-zero channels
         indices = nonzero_channels.data.squeeze()
         dim = 1 if isinstance(layers[layer_name], torch.nn.modules.Conv2d) and layers[layer_name].groups == 1 else 0
         if isinstance(layers[layer_name], torch.nn.modules.Linear):
             dim = 1
-        append_param_directive(thinning_recipe, param_name, (dim, indices))
+        _append_param_directive(thinning_recipe, param_name, (dim, indices))
 
         # Find all instances of Convolution layers that immediately precede this layer
         predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
@@ -202,24 +216,24 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
             msglogger.info("Could not find predecessors for name=%s" % layer_name)
         for predecessor in predecessors:
             # For each of the convolution layers that precede, we have to reduce the number of output channels.
-            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=nnz_channels)
+            _append_module_directive(thinning_recipe, predecessor, key='out_channels', val=nnz_channels)
 
             if layers[predecessor].groups == 1:
                 # Now remove filters from the weights tensor of the predecessor conv
-                append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
+                _append_param_directive(thinning_recipe, predecessor + '.weight', (0, indices))
                 if layers[predecessor].bias is not None:
                     # This convolution has bias coefficients
-                    append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
+                    _append_param_directive(thinning_recipe, predecessor + '.bias', (0, indices))
 
             elif layers[predecessor].groups == layers[predecessor].in_channels:
                 # This is a group-wise convolution, and a special one at that (groups == in_channels).
                 # Now remove filters from the weights tensor of the predecessor conv
-                append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
+                _append_param_directive(thinning_recipe, predecessor + '.weight', (0, indices))
 
                 if layers[predecessor].bias is not None:
                     # This convolution has bias coefficients
-                    append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
-                append_module_directive(thinning_recipe, predecessor, key='groups', val=nnz_channels)
+                    _append_param_directive(thinning_recipe, predecessor + '.bias', (0, indices))
+                _append_module_directive(thinning_recipe, predecessor, key='groups', val=nnz_channels)
 
                 # In the special case of a Convolutional layer with (groups == in_channels), if we
                 # change in_channels, we also need to change out_channels, which means that we
@@ -234,8 +248,8 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         for bn_layer in bn_layers:
             # Thinning of the BN layer that follows the convolution
             msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer))
-            append_bn_thinning_directive(thinning_recipe, layers, bn_layer,
-                                         len_thin_features=nnz_channels, thin_features=indices)
+            _append_bn_thinning_directive(thinning_recipe, layers, bn_layer,
+                                          len_thin_features=nnz_channels, thin_features=indices)
 
     msglogger.debug("Invoking create_thinning_recipe_channels")
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
@@ -279,15 +293,15 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
         # We are removing filters, so update the number of outgoing channels (OFMs)
         # in the convolutional layer
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
+        _append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
 
         # Select only the non-zero filters
         indices = nonzero_filters.data.squeeze()
-        append_param_directive(thinning_recipe, param_name, (0, indices))
+        _append_param_directive(thinning_recipe, param_name, (0, indices))
 
         if layers[layer_name].bias is not None:
             # This convolution has bias coefficients
-            append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices))
+            _append_param_directive(thinning_recipe, layer_name + '.bias', (0, indices))
 
         # Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer
         successors = sgraph.successors_f(layer_name, ['Conv', 'Gemm'])
@@ -295,36 +309,22 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
                 handle_conv_successor(thinning_recipe, layers, successor, num_nnz_filters, indices)
             elif isinstance(layers[successor], torch.nn.modules.Linear):
-                # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member
-                fm_size = layers[successor].in_features // layers[layer_name].out_channels
-                in_features = fm_size * num_nnz_filters
-                append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
-                msglogger.debug("[recipe] Linear {}: fm_size = {}  layers[{}].out_channels={}".format(
-                                successor, in_features, layer_name, layers[layer_name].out_channels))
-                msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features))
-
-                # Now remove channels from the weights tensor of the successor FC layer:
-                # This is a bit tricky:
-                fm_height = fm_width = int(math.sqrt(fm_size))
-                view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width)
-                view_2D = (layers[successor].out_features, in_features)
-                append_param_directive(thinning_recipe, successor+'.weight',
-                                       (1, indices, view_4D, view_2D))
+                handle_linear_successor(successor, indices)
 
         # Now handle the BatchNormalization layer that follows the convolution
         handle_bn_layers(layers, layer_name, num_nnz_filters, indices)
 
     def handle_conv_successor(thinning_recipe, layers, successor, num_nnz_filters, indices):
         # For each of the convolutional layers that follow, we have to reduce the number of input channels.
-        append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
+        _append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
 
         if layers[successor].groups == 1:
             # Now remove channels from the weights tensor of the successor conv
-            append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
+            _append_param_directive(thinning_recipe, successor + '.weight', (1, indices))
         elif layers[successor].groups == layers[successor].in_channels:
             # Special case: number of groups is equal to the number of input channels
-            append_param_directive(thinning_recipe, successor+'.weight', (0, indices))
-            append_module_directive(thinning_recipe, successor, key='groups', val=num_nnz_filters)
+            _append_param_directive(thinning_recipe, successor + '.weight', (0, indices))
+            _append_module_directive(thinning_recipe, successor, key='groups', val=num_nnz_filters)
 
             # In the special case of a Convolutional layer with (groups == in_channels), if we
             # change in_channels, we also need to change out_channels, which means that we
@@ -334,13 +334,30 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
         else:
             raise ValueError("Distiller thinning code currently does not handle this conv.groups configuration")
 
+    def handle_linear_successor(successor, indices):
+        # If a Linear (Fully-Connected) layer follows, we need to update its in_features member
+        fm_size = layers[successor].in_features // layers[layer_name].out_channels
+        in_features = fm_size * num_nnz_filters
+        _append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
+        msglogger.debug("[recipe] Linear {}: fm_size = {}  layers[{}].out_channels={}".format(
+            successor, in_features, layer_name, layers[layer_name].out_channels))
+        msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features))
+
+        # Now remove channels from the weights tensor of the successor FC layer:
+        # This is a bit tricky:
+        fm_height = fm_width = int(math.sqrt(fm_size))
+        view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width)
+        view_2D = (layers[successor].out_features, in_features)
+        _append_param_directive(thinning_recipe, successor + '.weight',
+                                (1, indices, view_4D, view_2D))
+
     def handle_bn_layers(layers, layer_name, num_nnz_filters, indices):
         bn_layers = sgraph.successors_f(layer_name, ['BatchNormalization'])
         if bn_layers:
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
-            append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0],
-                                         len_thin_features=num_nnz_filters, thin_features=indices)
+            _append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0],
+                                          len_thin_features=num_nnz_filters, thin_features=indices)
 
     msglogger.debug("Invoking create_thinning_recipe_filters")
 
@@ -423,7 +440,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
     msglogger.debug("Executed %d recipes" % len(recipe_list))
 
 
-def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
+def _optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
     """Adjust the size of the SGD velocity-tracking tensors.
 
     The SGD momentum update (velocity) is dependent on the weights, and because during thinning we
@@ -504,7 +521,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                             if grad_selection_view.size(dim) != len_indices:
                                 param.grad = torch.index_select(grad_selection_view, dim, indices)
                                 # update optimizer
-                                if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                                if _optimizer_thinning(optimizer, param, dim, indices, directive[3]):
                                     msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
                                                     format(param_name, dim, len_indices, directive[3]))
 
@@ -524,7 +541,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                     if param.grad is not None and param.grad.size(dim) != len_indices:
                         param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
                         # update optimizer
-                        if optimizer_thinning(optimizer, param, dim, indices):
+                        if _optimizer_thinning(optimizer, param, dim, indices):
                             msglogger.debug("Updated velocity buffer %s" % param_name)
 
                 if not loaded_from_file and zeros_mask_dict:
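
A minimal usage sketch of the new contract_model() entry point. The model/dataset names,
the call to create_model_masks_dict(), and the pruning step are illustrative assumptions,
not part of this patch:

```python
import distiller
from distiller.models import create_model

# Hypothetical setup: a ResNet20/CIFAR-10 model and a mask dictionary that a
# filter/channel-pruning schedule has already filled with some all-zero masks.
model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar')
zeros_mask_dict = distiller.create_model_masks_dict(model)
# ... run a filter/channel-pruning schedule here so that some masks become all-zero ...

# Physically shrink the model in place: zero filters are removed first, then zero channels.
distiller.contract_model(model, zeros_mask_dict,
                         arch='resnet20_cifar', dataset='cifar10', optimizer=None)
```

apply_and_save_recipe() stashes each thinning recipe on the model, so the contracted
shapes are serialized together with the model in subsequent checkpoints.
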
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 500eeb57b613690974b55239d0c6a3376746dc6a..b7c45032f464529c65efb76eb9cd3377f4f0d7f3 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -120,7 +120,7 @@ def handle_subapps(model, criterion, optimizer, compression_scheduler, pylogger,
         #zeros_mask_dict = distiller.create_model_masks_dict(model)
         assert args.resumed_checkpoint_path is not None, \
             "You must use --resume-from to provide a checkpoint file to thinnify"
-        distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None)
+        distiller.contract_model(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None)
         apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler,
                                  name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")),
                                  dir=msglogger.logdir)
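
For reference, roughly the same flow as the thinnify sub-app above, expressed standalone.
The checkpoint file name is hypothetical, and the (model, scheduler, optimizer, start_epoch)
return of load_checkpoint() is assumed:

```python
import distiller
import distiller.apputils as apputils
from distiller.models import create_model

# Hypothetical input: a pruned ResNet20/CIFAR-10 checkpoint that was saved together
# with its compression scheduler (and therefore with its zeros_mask_dict).
model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar')
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
    model, "resnet20_pruned_best.pth.tar")

# Remove the all-zero filters and channels, then save the physically smaller model.
distiller.contract_model(model, compression_scheduler.zeros_mask_dict,
                         'resnet20_cifar', 'cifar10', optimizer=None)
apputils.save_checkpoint(0, 'resnet20_cifar', model, optimizer=None,
                         scheduler=compression_scheduler,
                         name="resnet20_pruned_best_thinned")
```
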
diff --git a/examples/lottery_ticket/README.md b/examples/lottery_ticket/README.md
index b9497391896b6570a283f424591b94cddc7bf5a1..8842a7949da2b18f12401a50d595d580c9feffa1 100644
--- a/examples/lottery_ticket/README.md
+++ b/examples/lottery_ticket/README.md
@@ -9,11 +9,11 @@ most the same number of iterations."
 their smallest-magnitude weights. The set of connections that survives this process is the architecture
 of a winning ticket. Unique to our work, the winning ticket’s weights are the values to which these
 connections were initialized before training. This forms our central experiment:
->1. Randomly initialize a neural network f(x; theta_0) (where theta_0 ~ D_0).
->2. Train the network for j iterations, reaching parameters theta_j .
+>1. Randomly initialize a neural network f(x; theta_0) (where theta_0 ~ D_0).
+>2. Train the network for j iterations, reaching parameters theta_j.
 >3. Prune s% of the parameters, creating a mask m where Pm = (100 - s)%.
->4. To extract the winning ticket, reset the remaining parameters to their values intheta_0, creating
-the untrained network f(x;m *theta_0).
+>4. To extract the winning ticket, reset the remaining parameters to their values in theta_0, creating
+>the untrained network f(x; m * theta_0).
 
 ### Example
 Train a ResNet20-CIFAR10 network from scratch, and save the untrained, randomized initial network weights in a checkpoint file.
@@ -23,7 +23,7 @@ python3 compress_classifier.py --arch resnet20_cifar  ${CIFAR10_PATH} -p=50 --ep
 ``` 
 
 After training the network, we have two outputs: the best trained network (`resnet20_best.pth.tar`) and the initial untrained network (`resnet20_untrained_checkpoint.pth.tar`).<br>
-In this example, we copy them into the `examples/lottery_ticket` directory for convenience.
+In this example, we copy both checkpoints into the `examples/lottery_ticket` directory for convenience.
 
 ```bash
 cp logs/resnet20___2019.08.22-220243/resnet20_best.pth.tar ../lottery_ticket/