From 081035d333fd151caacbce7ccd922c05d3ecae9d Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 22 Jul 2018 14:57:12 +0300
Subject: [PATCH] Tests: added VGG16-Cifar channel pruning test

---
 tests/test_pruning.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index a59aaea..fc3b0f1 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -85,6 +85,18 @@ def vgg19_imagenet(is_parallel):
                                        ("features.32", "features.34")],
                          bn_name=None)
 
+
+def vgg16_cifar(is_parallel):
+    if is_parallel:
+        return NetConfig(arch="vgg16_cifar", dataset="cifar10",
+                         module_pairs=[("features.module.0", "features.module.2")],
+                         bn_name=None)
+    else:
+        return NetConfig(arch="vgg16_cifar", dataset="cifar10",
+                         module_pairs=[("features.0", "features.2")],
+                         bn_name=None)
+
+
 @pytest.fixture(params=[True, False])
 def parallel(request):
     return request.param
@@ -168,6 +180,9 @@ def test_arbitrary_channel_pruning(parallel):
     arbitrary_channel_pruning(vgg19_imagenet(parallel),
                               channels_to_remove=[0, 2],
                               is_parallel=parallel)
+    arbitrary_channel_pruning(vgg16_cifar(parallel),
+                              channels_to_remove=[0, 2],
+                              is_parallel=parallel)
 
 
 def test_prune_all_channels(parallel):
-- 
GitLab