Skip to content
Snippets Groups Projects
Commit 081035d3 authored by Neta Zmora's avatar Neta Zmora
Browse files

Tests: added VGG16-Cifar channel pruning test

parent f8df3020
No related branches found
No related tags found
No related merge requests found
......@@ -85,6 +85,18 @@ def vgg19_imagenet(is_parallel):
("features.32", "features.34")],
bn_name=None)
def vgg16_cifar(is_parallel):
if is_parallel:
return NetConfig(arch="vgg16_cifar", dataset="cifar10",
module_pairs=[("features.module.0", "features.module.2")],
bn_name=None)
else:
return NetConfig(arch="vgg16_cifar", dataset="cifar10",
module_pairs=[("features.0", "features.2")],
bn_name=None)
@pytest.fixture(params=[True, False])
def parallel(request):
return request.param
......@@ -168,6 +180,9 @@ def test_arbitrary_channel_pruning(parallel):
arbitrary_channel_pruning(vgg19_imagenet(parallel),
channels_to_remove=[0, 2],
is_parallel=parallel)
arbitrary_channel_pruning(vgg16_cifar(parallel),
channels_to_remove=[0, 2],
is_parallel=parallel)
def test_prune_all_channels(parallel):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment