diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index a59aaeaa092a47ac25a0c81149e3b03d753da6c9..fc3b0f19adb3cec0fd0f7feadab084f212d03a69 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -85,6 +85,18 @@ def vgg19_imagenet(is_parallel):
                                        ("features.32", "features.34")],
                          bn_name=None)
 
+
+def vgg16_cifar(is_parallel):
+    if is_parallel:
+        return NetConfig(arch="vgg16_cifar", dataset="cifar10",
+                         module_pairs=[("features.module.0", "features.module.2")],
+                         bn_name=None)
+    else:
+        return NetConfig(arch="vgg16_cifar", dataset="cifar10",
+                         module_pairs=[("features.0", "features.2")],
+                         bn_name=None)
+
+
 @pytest.fixture(params=[True, False])
 def parallel(request):
     return request.param
@@ -168,6 +180,9 @@ def test_arbitrary_channel_pruning(parallel):
     arbitrary_channel_pruning(vgg19_imagenet(parallel),
                               channels_to_remove=[0, 2],
                               is_parallel=parallel)
+    arbitrary_channel_pruning(vgg16_cifar(parallel),
+                              channels_to_remove=[0, 2],
+                              is_parallel=parallel)
 
 
 def test_prune_all_channels(parallel):