From 081035d333fd151caacbce7ccd922c05d3ecae9d Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sun, 22 Jul 2018 14:57:12 +0300 Subject: [PATCH] Tests: added VGG16-Cifar channel pruning test --- tests/test_pruning.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_pruning.py b/tests/test_pruning.py index a59aaea..fc3b0f1 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -85,6 +85,18 @@ def vgg19_imagenet(is_parallel): ("features.32", "features.34")], bn_name=None) + +def vgg16_cifar(is_parallel): + if is_parallel: + return NetConfig(arch="vgg16_cifar", dataset="cifar10", + module_pairs=[("features.module.0", "features.module.2")], + bn_name=None) + else: + return NetConfig(arch="vgg16_cifar", dataset="cifar10", + module_pairs=[("features.0", "features.2")], + bn_name=None) + + @pytest.fixture(params=[True, False]) def parallel(request): return request.param @@ -168,6 +180,9 @@ def test_arbitrary_channel_pruning(parallel): arbitrary_channel_pruning(vgg19_imagenet(parallel), channels_to_remove=[0, 2], is_parallel=parallel) + arbitrary_channel_pruning(vgg16_cifar(parallel), + channels_to_remove=[0, 2], + is_parallel=parallel) def test_prune_all_channels(parallel): -- GitLab