diff --git a/distiller/thinning.py b/distiller/thinning.py
index dbd2c56e7f28ad29a6c836a8473a1ecbf9026399..e019e7defb4afe14becb1fb893c68e1ba9d26dfb 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -29,15 +29,13 @@ documented in a different place.
 
 import math
 import logging
-import copy
-import re
 from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
 from distiller import normalize_module_name, denormalize_module_name
 from apputils import SummaryGraph
-from models import ALL_MODEL_NAMES, create_model
+from models import create_model
 msglogger = logging.getLogger()
 
 ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
@@ -67,7 +65,7 @@ These tuples can have 2 values, or 4 values.
 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'ChannelRemover', 'remove_channels',
            'FilterRemover',  'remove_filters',
-           'find_nonzero_channels',
+           'find_nonzero_channels', 'find_nonzero_channels_list',
            'execute_thinning_recipes_list']
 
 
@@ -177,15 +175,21 @@ def find_nonzero_channels(param, param_name):
     k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
     nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1))
 
-    if num_channels > len(nonzero_channels):
+    if num_channels > nonzero_channels.nelement():
         msglogger.info("In tensor %s found %d/%d zero channels", param_name,
-                       num_filters - len(nonzero_channels), num_filters)
+                       num_channels - nonzero_channels.nelement(), num_channels)
 
     return nonzero_channels
 
 
+def find_nonzero_channels_list(param, param_name):
+    nnz_channels = find_nonzero_channels(param, param_name)
+    nnz_channels = nnz_channels.view(nnz_channels.numel())
+    return nnz_channels.cpu().numpy().tolist()
+
+
 def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
-    if len(thinning_recipe.modules)>0 or len(thinning_recipe.parameters)>0:
+    if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
         # Now actually remove the filters, chaneels and make the weight tensors smaller
         execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
 
@@ -195,7 +199,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe):
             model.thinning_recipes.append(thinning_recipe)
         else:
             model.thinning_recipes = [thinning_recipe]
-        msglogger.info("Created, applied and saved a filter-thinning recipe")
+        msglogger.info("Created, applied and saved a thinning recipe")
     else:
         msglogger.error("Failed to create a thinning recipe")
 
@@ -220,7 +224,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
     msglogger.info("Invoking create_thinning_recipe_channels")
 
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
-    layers = {mod_name : m for mod_name, m in model.named_modules()}
+    layers = {mod_name: m for mod_name, m in model.named_modules()}
 
     # Traverse all of the model's parameters, search for zero-channels, and
     # create a thinning recipe that descibes the required changes to the model.
@@ -231,16 +235,18 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
 
         num_channels = param.size(1)
         nonzero_channels = find_nonzero_channels(param, param_name)
-
+        num_nnz_channels = nonzero_channels.nelement()
+        if num_nnz_channels == 0:
+            raise ValueError("Trying to set zero channels for parameter %s is not allowed" % param_name)
         # If there are non-zero channels in this tensor then continue to next tensor
-        if num_channels <= len(nonzero_channels):
+        if num_channels <= num_nnz_channels:
             continue
 
         # We are removing channels, so update the number of incoming channels (IFMs)
         # in the convolutional layer
         layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=len(nonzero_channels))
+        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
 
         # Select only the non-zero filters
         indices = nonzero_channels.data.squeeze()
@@ -252,7 +258,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         predecessors = [denormalize_module_name(model, predecessor) for predecessor in predecessors]
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels))
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
 
             # Now remove channels from the weights tensor of the successor conv
             append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
@@ -263,7 +269,8 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
             bn_layer_name = denormalize_module_name(model, bn_layers[0])
-            bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_channels), thin_features=indices)
+            bn_thinning(thinning_recipe, layers, bn_layer_name,
+                        len_thin_features=num_nnz_channels, thin_features=indices)
 
     return thinning_recipe
 
@@ -281,7 +288,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
     msglogger.info("Invoking create_thinning_recipe_filters")
 
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
-    layers = {mod_name : m for mod_name, m in model.named_modules()}
+    layers = {mod_name: m for mod_name, m in model.named_modules()}
 
     for param_name, param in model.named_parameters():
         # We are only interested in 4D weights
@@ -292,20 +299,22 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
         filter_view = param.view(param.size(0), -1)
         num_filters = filter_view.size()[0]
         nonzero_filters = torch.nonzero(filter_view.abs().sum(dim=1))
-
+        num_nnz_filters = nonzero_filters.nelement()
+        if num_nnz_filters == 0:
+            raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name)
         # If there are non-zero filters in this tensor then continue to next tensor
-        if num_filters <= len(nonzero_filters):
+        if num_filters <= num_nnz_filters:
             msglogger.debug("SKipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape))
             continue
 
         msglogger.info("In tensor %s found %d/%d zero filters", param_name,
-                       num_filters - len(nonzero_filters), num_filters)
+                       num_filters - num_nnz_filters, num_filters)
 
         # We are removing filters, so update the number of outgoing channels (OFMs)
         # in the convolutional layer
         layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(thinning_recipe, layer_name, key='out_channels', val=len(nonzero_filters))
+        append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
 
         # Select only the non-zero filters
         indices = nonzero_filters.data.squeeze()
@@ -323,8 +332,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
 
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
-                append_module_directive(thinning_recipe, successor, key='in_channels', val=len(nonzero_filters))
-                msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, len(nonzero_filters)))
+                append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
+                msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters))
 
                 # Now remove channels from the weights tensor of the successor conv
                 append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
@@ -332,8 +341,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             elif isinstance(layers[successor], torch.nn.modules.Linear):
                 # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member
                 fm_size = layers[successor].in_features // layers[layer_name].out_channels
-                in_features = fm_size * len(nonzero_filters)
-                #append_module_directive(thinning_recipe, layer_name, key='in_features', val=in_features)
+                in_features = fm_size * num_nnz_filters
                 append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
                 msglogger.info("[recipe] {}: setting in_features = {}".format(successor, in_features))
 
@@ -350,7 +358,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
             bn_layer_name = denormalize_module_name(model, bn_layers[0])
-            bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_filters), thin_features=indices)
+            bn_thinning(thinning_recipe, layers, bn_layer_name,
+                        len_thin_features=num_nnz_filters, thin_features=indices)
     return thinning_recipe
 
 
@@ -423,8 +432,9 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
                 dim_to_trim = val[0]
                 indices_to_select = val[1]
                 # Check if we're trying to trim a parameter that is already "thin"
-                if running.size(dim_to_trim) != len(indices_to_select):
-                    msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(indices_to_select)))
+                if running.size(dim_to_trim) != indices_to_select.nelement():
+                    msglogger.info("[thinning] {}: setting {} to {}".
+                                   format(layer_name, attr, indices_to_select.nelement()))
                     setattr(layers[layer_name], attr,
                             torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
             else:
@@ -438,34 +448,35 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
         for directive in param_directives:
             dim = directive[0]
             indices = directive[1]
+            len_indices = indices.nelement()
             if len(directive) == 4:  # TODO: this code is hard to follow
                 selection_view = param.view(*directive[2])
                 # Check if we're trying to trim a parameter that is already "thin"
-                if param.data.size(dim) != len(indices):
+                if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(selection_view, dim, indices)
 
                 if param.grad is not None:
                     # We also need to change the dimensions of the gradient tensor.
                     grad_selection_view = param.grad.resize_(*directive[2])
-                    if grad_selection_view.size(dim) != len(indices):
+                    if grad_selection_view.size(dim) != len_indices:
                         param.grad = torch.index_select(grad_selection_view, dim, indices)
 
                 param.data = param.view(*directive[3])
                 if param.grad is not None:
                     param.grad = param.grad.resize_(*directive[3])
             else:
-                if param.data.size(dim) != len(indices):
+                if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(param.data, dim, indices)
                 # We also need to change the dimensions of the gradient tensor.
                 # If have not done a backward-pass thus far, then the gradient will
                 # not exist, and therefore won't need to be re-dimensioned.
-                if param.grad is not None and param.grad.size(dim) != len(indices):
+                if param.grad is not None and param.grad.size(dim) != len_indices:
                         param.grad = torch.index_select(param.grad, dim, indices)
-                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len(indices)))
+                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len_indices))
 
             if not loaded_from_file:
                 # If the masks are loaded from a checkpoint file, then we don't need to change
                 # their shape, because they are already correctly shaped
                 mask = zeros_mask_dict[param_name].mask
-                if mask is not None and (mask.size(dim) != len(indices)):
+                if mask is not None and (mask.size(dim) != len_indices):
                     zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
diff --git a/tests/common.py b/tests/common.py
new file mode 100755
index 0000000000000000000000000000000000000000..75c4bc37baa86206a6965d635e078b1a4c8e5326
--- /dev/null
+++ b/tests/common.py
@@ -0,0 +1,42 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+module_path = os.path.abspath(os.path.join('..'))
+if module_path not in sys.path:
+    sys.path.append(module_path)
+import distiller
+from models import create_model
+
+
+def setup_test(arch, dataset):
+    model = create_model(False, dataset, arch, parallel=False)
+    assert model is not None
+
+    # Create the masks
+    zeros_mask_dict = {}
+    for name, param in model.named_parameters():
+        masker = distiller.ParameterMasker(name)
+        zeros_mask_dict[name] = masker
+    return model, zeros_mask_dict
+
+
+def find_module_by_name(model, module_to_find):
+    for name, m in model.named_modules():
+        if name == module_to_find:
+            return m
+    return None
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index fa01f22f7076592b686f9c6731e99ef6ea567b10..e71dcb5fb81b4f0dab99737a750cc1d2f6b1701c 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -18,14 +18,16 @@ import logging
 import torch
 import os
 import sys
-module_path = os.path.abspath(os.path.join('..'))
-if module_path not in sys.path:
+try:
+    import distiller
+except ImportError:
+    module_path = os.path.abspath(os.path.join('..'))
     sys.path.append(module_path)
-import distiller
-
+    import distiller
+import common
 import pytest
-from models import ALL_MODEL_NAMES, create_model
-from apputils import SummaryGraph, onnx_name_2_pytorch_name, save_checkpoint, load_checkpoint
+from models import create_model
+from apputils import save_checkpoint, load_checkpoint
 
 # Logging configuration
 logging.basicConfig(level=logging.INFO)
@@ -33,44 +35,47 @@ fh = logging.FileHandler('test.log')
 logger = logging.getLogger()
 logger.addHandler(fh)
 
-def find_module_by_name(model, module_to_find):
-    for name, m in model.named_modules():
-        if name == module_to_find:
-            return m
-    return None
+
+def test_ranked_filter_pruning():
+    ranked_filter_pruning(ratio_to_prune=0.1)
+    ranked_filter_pruning(ratio_to_prune=0.5)
 
 
-def setup_test(arch, dataset):
-    model = create_model(False, dataset, arch, parallel=False)
-    assert model is not None
+def test_prune_all_filters():
+    """Pruning all of the filteres in a weights tensor of a Convolution
+    is illegal and should raise an exception.
+    """
+    with pytest.raises(ValueError):
+        ranked_filter_pruning(ratio_to_prune=1.0)
 
-    # Create the masks
-    zeros_mask_dict = {}
-    for name, param in model.named_parameters():
-        masker = distiller.ParameterMasker(name)
-        zeros_mask_dict[name] = masker
-    return model, zeros_mask_dict
 
-def test_ranked_filter_pruning():
-    model, zeros_mask_dict = setup_test("resnet20_cifar", "cifar10")
+def ranked_filter_pruning(ratio_to_prune):
+    """Test L1 ranking and pruning of filters.
+
+    First we rank and prune the filters of a Convolutional layer using
+    a L1RankedStructureParameterPruner.  Then we physically remove the
+    filters from the model (via "thining" process).
+    """
+    model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10")
 
     # Test that we can access the weights tensor of the first convolution in layer 1
     conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
     assert conv1_p is not None
 
-    # Test that there are no zero-channels
+    # Test that there are no zero-filters
     assert distiller.sparsity_3D(conv1_p) == 0.0
 
     # Create a filter-ranking pruner
-    reg_regims = {"layer1.0.conv1.weight" : [0.1, "3D"]}
+    reg_regims = {"layer1.0.conv1.weight": [ratio_to_prune, "3D"]}
     pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
     pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None)
 
-    conv1 = find_module_by_name(model, "layer1.0.conv1")
+    conv1 = common.find_module_by_name(model, "layer1.0.conv1")
     assert conv1 is not None
     # Test that the mask has the correct fraction of filters pruned.
     # We asked for 10%, but there are only 16 filters, so we have to settle for 1/16 filters
-    expected_pruning = int(0.1 * conv1.out_channels) / conv1.out_channels
+    expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
+    expected_pruning = expected_cnt_removed_filters / conv1.out_channels
     assert distiller.sparsity_3D(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning
 
     # Use the mask to prune
@@ -79,24 +84,42 @@ def test_ranked_filter_pruning():
     assert distiller.sparsity_3D(conv1_p) == expected_pruning
 
     # Remove filters
-    conv2 = find_module_by_name(model, "layer1.0.conv2")
+    conv2 = common.find_module_by_name(model, "layer1.0.conv2")
     assert conv2 is not None
     assert conv1.out_channels == 16
     assert conv2.in_channels == 16
 
     # Test thinning
     distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10")
-    assert conv1.out_channels == 15
-    assert conv2.in_channels == 15
+    assert conv1.out_channels == 16 - expected_cnt_removed_filters
+    assert conv2.in_channels == 16 - expected_cnt_removed_filters
 
 
 def test_arbitrary_channel_pruning():
+    arbitrary_channel_pruning(channels_to_remove=[0, 2])
+
+
+def test_prune_all_channels():
+    """Pruning all of the channels in a weights tensor of a Convolution
+    is illegal and should raise an exception.
+    """
+    with pytest.raises(ValueError):
+        arbitrary_channel_pruning(channels_to_remove=[ch for ch in range(16)])
+
+
+def arbitrary_channel_pruning(channels_to_remove):
+    """Test removal of arbitrary channels.
+
+    The test receives a specification of channels to remove.
+    Based on this specification, the channels are pruned and then physically
+    removed from the model (via a "thinning" process).
+    """
     ARCH = "resnet20_cifar"
     DATASET = "cifar10"
 
-    model, zeros_mask_dict = setup_test(ARCH, DATASET)
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET)
 
-    conv2 = find_module_by_name(model, "layer1.0.conv2")
+    conv2 = common.find_module_by_name(model, "layer1.0.conv2")
     assert conv2 is not None
 
     # Test that we can access the weights tensor of the first convolution in layer 1
@@ -108,8 +131,7 @@ def test_arbitrary_channel_pruning():
     num_channels = conv2_p.size(1)
     kernel_height = conv2_p.size(2)
     kernel_width = conv2_p.size(3)
-
-    channels_to_remove = [0, 2]
+    cnt_nnz_channels = num_channels - len(channels_to_remove)
 
     # Let's build our 4D mask.
     # We start with a 1D mask of channels, with all but our specified channels set to one
@@ -131,52 +153,59 @@ def test_arbitrary_channel_pruning():
     zeros_mask_dict["layer1.0.conv2.weight"].mask = mask
     zeros_mask_dict["layer1.0.conv2.weight"].apply_mask(conv2_p)
     all_channels = set([ch for ch in range(num_channels)])
-    channels_removed = all_channels - set(distiller.find_nonzero_channels(conv2_p, "layer1.0.conv2.weight"))
-    logger.info(channels_removed)
+    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "layer1.0.conv2.weight"))
+    channels_removed = all_channels - nnz_channels
+    logger.info("Channels removed {}".format(channels_removed))
 
     # Now, let's do the actual network thinning
     distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
-    conv1 = find_module_by_name(model, "layer1.0.conv1")
+    conv1 = common.find_module_by_name(model, "layer1.0.conv1")
     logger.info(conv1)
     logger.info(conv2)
-    assert conv1.out_channels == 14
-    assert conv2.in_channels == 14
-    assert conv1.weight.size(0) == 14
-    assert conv2.weight.size(1) == 14
-    bn1 = find_module_by_name(model, "layer1.0.bn1")
-    assert bn1.running_var.size(0) == 14
-    assert bn1.running_mean.size(0) == 14
-    assert bn1.num_features == 14
-    assert bn1.bias.size(0) == 14
-    assert bn1.weight.size(0) == 14
+    assert conv1.out_channels == cnt_nnz_channels
+    assert conv2.in_channels == cnt_nnz_channels
+    assert conv1.weight.size(0) == cnt_nnz_channels
+    assert conv2.weight.size(1) == cnt_nnz_channels
+    bn1 = common.find_module_by_name(model, "layer1.0.bn1")
+    assert bn1.running_var.size(0) == cnt_nnz_channels
+    assert bn1.running_mean.size(0) == cnt_nnz_channels
+    assert bn1.num_features == cnt_nnz_channels
+    assert bn1.bias.size(0) == cnt_nnz_channels
+    assert bn1.weight.size(0) == cnt_nnz_channels
 
     # Let's test saving and loading a thinned model.
     # We save 3 times, and load twice, to make sure to cover some corner cases:
     #   - Make sure that after loading, the model still has hold of the thinning recipes
     #   - Make sure that after a 2nd load, there no problem loading (in this case, the
     #   - tensors are already thin, so this is a new flow)
+    # (1)
     save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
     model_2 = create_model(False, DATASET, ARCH, parallel=False)
     dummy_input = torch.randn(1, 3, 32, 32)
     model(dummy_input)
     model_2(dummy_input)
-    conv2 = find_module_by_name(model_2, "layer1.0.conv2")
+    conv2 = common.find_module_by_name(model_2, "layer1.0.conv2")
     assert conv2 is not None
     with pytest.raises(KeyError):
         model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
-
     compression_scheduler = distiller.CompressionScheduler(model)
     hasattr(model, 'thinning_recipes')
+
+    # (2)
     save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
     model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
     assert hasattr(model_2, 'thinning_recipes')
     logger.info("test_arbitrary_channel_pruning - Done")
 
+    # (3)
     save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
     model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
 
+
 if __name__ == '__main__':
     test_arbitrary_channel_pruning()
+    test_prune_all_channels()