diff --git a/distiller/thinning.py b/distiller/thinning.py
index e019e7defb4afe14becb1fb893c68e1ba9d26dfb..eb27e52624057c6457627b96664c0c0c006b2492 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -263,6 +263,11 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
             # Now remove channels from the weights tensor of the successor conv
             append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
 
+            if layers[predecessor].bias is not None:
+                # This convolution has bias coefficients
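+                # The bias holds one coefficient per output filter, so it is thinned along dim 0 with the same indices as the weights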
+                append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
+
         # Now handle the BatchNormalization layer that follows the convolution
         bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
         if len(bn_layers) > 0:
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index e71dcb5fb81b4f0dab99737a750cc1d2f6b1701c..38ad480730037769e23b5757a9bd4c305ff657a3 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -204,6 +204,100 @@ def arbitrary_channel_pruning(channels_to_remove):
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
+def test_channel_pruning_conv_bias():
+    """Test removal of arbitrary channels, for Convolutions with bias term.
+
+    This is different from test_arbitrary_channel_pruning() in that this model
+    has Convolution layers with bias tensors.
+    """
+    ARCH = "simplenet_cifar"
+    DATASET = "cifar10"
+    channels_to_remove = [0, 1]
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET)
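+    # zeros_mask_dict maps each parameter name to the masker object that holds its pruning mask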
+
+    conv2 = common.find_module_by_name(model, "conv2")
+    assert conv2 is not None
+
+    # Test that we can access the weights tensor of the second convolution
+    conv2_p = distiller.model_find_param(model, "conv2.weight")
+    assert conv2_p is not None
+
+    assert conv2_p.dim() == 4
+    num_filters = conv2_p.size(0)
+    num_channels = conv2_p.size(1)
+    kernel_height = conv2_p.size(2)
+    kernel_width = conv2_p.size(3)
+    cnt_nnz_channels = num_channels - len(channels_to_remove)
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels, with all but our specified channels set to one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv2_p.shape
+    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
+
+    # Cool, so now we have a mask for pruning our channels.
+    # Use the mask to prune
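+    # apply_mask() multiplies the parameter by the mask in-place, zeroing the pruned channels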
+    zeros_mask_dict["conv2.weight"].mask = mask
+    zeros_mask_dict["conv2.weight"].apply_mask(conv2_p)
+    all_channels = set([ch for ch in range(num_channels)])
+    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "conv2.weight"))
+    channels_removed = all_channels - nnz_channels
+    logger.info("Channels removed {}".format(channels_removed))
+
+    # Now, let's do the actual network thinning
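+    # remove_channels() builds a thinning recipe from the zeroed channels and executes it, physically shrinking the affected tensors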
+    distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
+    conv1 = common.find_module_by_name(model, "conv1")
+    logger.info(conv1)
+    logger.info(conv2)
+    assert conv1.out_channels == cnt_nnz_channels
+    assert conv2.in_channels == cnt_nnz_channels
+    assert conv1.weight.size(0) == cnt_nnz_channels
+    assert conv2.weight.size(1) == cnt_nnz_channels
+
+    # Let's test saving and loading a thinned model.
+    # We save three times and load three times (the first load is expected to fail),
+    # to make sure we cover some corner cases:
+    #   - Make sure that after loading, the model still holds the thinning recipes
+    #   - Make sure that a 2nd load causes no problem (the tensors are already thin, so this is a new flow)
+    # (1)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
+    model_2 = create_model(False, DATASET, ARCH, parallel=False)
+    dummy_input = torch.randn(1, 3, 32, 32)
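+    # Sanity check: both the thinned model and the freshly created model should run a forward pass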
+    model(dummy_input)
+    model_2(dummy_input)
+    conv2 = common.find_module_by_name(model_2, "conv2")
+    assert conv2 is not None
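+    # The first checkpoint was saved without a compression scheduler, so loading it is expected to fail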
+    with pytest.raises(KeyError):
+        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    compression_scheduler = distiller.CompressionScheduler(model)
+    assert hasattr(model, 'thinning_recipes')
+
+    # (2)
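+    # This time the scheduler is saved as well, so the load should succeed and attach the thinning recipes to model_2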
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_arbitrary_channel_pruning - Done")
+
+    # (3)
+    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
 if __name__ == '__main__':