diff --git a/tests/test_pruning.py b/tests/test_pruning.py index a59aaeaa092a47ac25a0c81149e3b03d753da6c9..fc3b0f19adb3cec0fd0f7feadab084f212d03a69 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -85,6 +85,18 @@ def vgg19_imagenet(is_parallel): ("features.32", "features.34")], bn_name=None) + +def vgg16_cifar(is_parallel): + if is_parallel: + return NetConfig(arch="vgg16_cifar", dataset="cifar10", + module_pairs=[("features.module.0", "features.module.2")], + bn_name=None) + else: + return NetConfig(arch="vgg16_cifar", dataset="cifar10", + module_pairs=[("features.0", "features.2")], + bn_name=None) + + @pytest.fixture(params=[True, False]) def parallel(request): return request.param @@ -168,6 +180,9 @@ def test_arbitrary_channel_pruning(parallel): arbitrary_channel_pruning(vgg19_imagenet(parallel), channels_to_remove=[0, 2], is_parallel=parallel) + arbitrary_channel_pruning(vgg16_cifar(parallel), + channels_to_remove=[0, 2], + is_parallel=parallel) def test_prune_all_channels(parallel):