diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 38ad480730037769e23b5757a9bd4c305ff657a3..b6c6a06e1a91c50b6ac2509f1ab28041651bfc19 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+from collections import namedtuple
 import logging
 import torch
 import os
@@ -35,10 +35,28 @@
 fh = logging.FileHandler('test.log')
 logger = logging.getLogger()
 logger.addHandler(fh)
 
+NetConfig = namedtuple("test_config", "arch dataset conv1_name conv2_name bn_name")
+
+
+#
+# Model configurations
+#
+def simplenet():
+    return NetConfig(arch="simplenet_cifar", dataset="cifar10",
+                     conv1_name="conv1", conv2_name="conv2",
+                     bn_name=None)
+
+
+def resnet20_cifar():
+    return NetConfig(arch="resnet20_cifar", dataset="cifar10",
+                     conv1_name="layer1.0.conv1", conv2_name="layer1.0.conv2",
+                     bn_name="layer1.0.bn1")
+
 
 def test_ranked_filter_pruning():
-    ranked_filter_pruning(ratio_to_prune=0.1)
-    ranked_filter_pruning(ratio_to_prune=0.5)
+    ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.1)
+    ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.5)
+    ranked_filter_pruning(simplenet(), ratio_to_prune=0.5)
 
 
 def test_prune_all_filters():
@@ -46,57 +64,61 @@
     is illegal and should raise an exception.
     """
     with pytest.raises(ValueError):
-        ranked_filter_pruning(ratio_to_prune=1.0)
+        ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=1.0)
 
 
-def ranked_filter_pruning(ratio_to_prune):
+def ranked_filter_pruning(config, ratio_to_prune):
     """Test L1 ranking and pruning of filters.
 
     First we rank and prune the filters of a Convolutional layer using
     an L1RankedStructureParameterPruner.  Then we physically remove the
     filters from the model (via a "thinning" process).
     """
-    model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10")
+    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset)
 
     # Test that we can access the weights tensor of the first convolution in layer 1
-    conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
+    conv1_p = distiller.model_find_param(model, config.conv1_name + ".weight")
     assert conv1_p is not None
+    num_filters = conv1_p.size(0)
 
     # Test that there are no zero-filters
     assert distiller.sparsity_3D(conv1_p) == 0.0
 
     # Create a filter-ranking pruner
-    reg_regims = {"layer1.0.conv1.weight": [ratio_to_prune, "3D"]}
+    reg_regims = {config.conv1_name + ".weight": [ratio_to_prune, "3D"]}
     pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
-    pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None)
+    pruner.set_param_mask(conv1_p, config.conv1_name + ".weight", zeros_mask_dict, meta=None)
 
-    conv1 = common.find_module_by_name(model, "layer1.0.conv1")
+    conv1 = common.find_module_by_name(model, config.conv1_name)
     assert conv1 is not None
     # Test that the mask has the correct fraction of filters pruned.
     # We asked for 10%, but there are only 16 filters, so we have to settle for pruning 1/16 of the filters
     expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
     expected_pruning = expected_cnt_removed_filters / conv1.out_channels
-    assert distiller.sparsity_3D(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning
+    masker = zeros_mask_dict[config.conv1_name + ".weight"]
+    assert masker is not None
+    assert distiller.sparsity_3D(masker.mask) == expected_pruning
 
     # Use the mask to prune
     assert distiller.sparsity_3D(conv1_p) == 0
-    zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p)
+    masker.apply_mask(conv1_p)
     assert distiller.sparsity_3D(conv1_p) == expected_pruning
 
     # Remove filters
-    conv2 = common.find_module_by_name(model, "layer1.0.conv2")
+    conv2 = common.find_module_by_name(model, config.conv2_name)
     assert conv2 is not None
-    assert conv1.out_channels == 16
-    assert conv2.in_channels == 16
+    assert conv1.out_channels == num_filters
+    assert conv2.in_channels == num_filters
 
     # Test thinning
-    distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10")
-    assert conv1.out_channels == 16 - expected_cnt_removed_filters
-    assert conv2.in_channels == 16 - expected_cnt_removed_filters
+    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset)
+    assert conv1.out_channels == num_filters - expected_cnt_removed_filters
+    assert conv2.in_channels == num_filters - expected_cnt_removed_filters
 
 
 def test_arbitrary_channel_pruning():
-    arbitrary_channel_pruning(channels_to_remove=[0, 2])
+    arbitrary_channel_pruning(resnet20_cifar(), channels_to_remove=[0, 2])
+    arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 2])
 
 
 def test_prune_all_channels():
@@ -104,123 +126,28 @@
     is illegal and should raise an exception.
     """
     with pytest.raises(ValueError):
-        arbitrary_channel_pruning(channels_to_remove=[ch for ch in range(16)])
+        arbitrary_channel_pruning(resnet20_cifar(),
+                                  channels_to_remove=[ch for ch in range(16)])
+
+
+def test_channel_pruning_conv_bias():
+    arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 1])
 
 
-def arbitrary_channel_pruning(channels_to_remove):
+def arbitrary_channel_pruning(config, channels_to_remove):
     """Test removal of arbitrary channels.
 
     The test receives a specification of channels to remove.
     Based on this specification, the channels are pruned and then physically
     removed from the model (via a "thinning" process).
     """
-    ARCH = "resnet20_cifar"
-    DATASET = "cifar10"
-
-    model, zeros_mask_dict = common.setup_test(ARCH, DATASET)
-
-    conv2 = common.find_module_by_name(model, "layer1.0.conv2")
-    assert conv2 is not None
-
-    # Test that we can access the weights tensor of the first convolution in layer 1
-    conv2_p = distiller.model_find_param(model, "layer1.0.conv2.weight")
-    assert conv2_p is not None
-
-    assert conv2_p.dim() == 4
-    num_filters = conv2_p.size(0)
-    num_channels = conv2_p.size(1)
-    kernel_height = conv2_p.size(2)
-    kernel_width = conv2_p.size(3)
-    cnt_nnz_channels = num_channels - len(channels_to_remove)
-
-    # Let's build our 4D mask.
-    # We start with a 1D mask of channels, with all but our specified channels set to one
-    channels = torch.ones(num_channels)
-    for ch in channels_to_remove:
-        channels[ch] = 0
-
-    # Now let's expand back up to a 4D mask
-    mask = channels.expand(num_filters, num_channels)
-    mask.unsqueeze_(-1)
-    mask.unsqueeze_(-1)
-    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
-
-    assert mask.shape == conv2_p.shape
-    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
-
-    # Cool, so now we have a mask for pruning our channels.
-    # Use the mask to prune
-    zeros_mask_dict["layer1.0.conv2.weight"].mask = mask
-    zeros_mask_dict["layer1.0.conv2.weight"].apply_mask(conv2_p)
-    all_channels = set([ch for ch in range(num_channels)])
-    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "layer1.0.conv2.weight"))
-    channels_removed = all_channels - nnz_channels
-    logger.info("Channels removed {}".format(channels_removed))
-
-    # Now, let's do the actual network thinning
-    distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
-    conv1 = common.find_module_by_name(model, "layer1.0.conv1")
-    logger.info(conv1)
-    logger.info(conv2)
-    assert conv1.out_channels == cnt_nnz_channels
-    assert conv2.in_channels == cnt_nnz_channels
-    assert conv1.weight.size(0) == cnt_nnz_channels
-    assert conv2.weight.size(1) == cnt_nnz_channels
-    bn1 = common.find_module_by_name(model, "layer1.0.bn1")
-    assert bn1.running_var.size(0) == cnt_nnz_channels
-    assert bn1.running_mean.size(0) == cnt_nnz_channels
-    assert bn1.num_features == cnt_nnz_channels
-    assert bn1.bias.size(0) == cnt_nnz_channels
-    assert bn1.weight.size(0) == cnt_nnz_channels
-
-    # Let's test saving and loading a thinned model.
-    # We save 3 times, and load twice, to make sure to cover some corner cases:
-    #  - Make sure that after loading, the model still has hold of the thinning recipes
-    #  - Make sure that after a 2nd load, there no problem loading (in this case, the
-    #  - tensors are already thin, so this is a new flow)
-    # (1)
-    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
-    model_2 = create_model(False, DATASET, ARCH, parallel=False)
-    dummy_input = torch.randn(1, 3, 32, 32)
-    model(dummy_input)
-    model_2(dummy_input)
-    conv2 = common.find_module_by_name(model_2, "layer1.0.conv2")
-    assert conv2 is not None
-    with pytest.raises(KeyError):
-        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
-    compression_scheduler = distiller.CompressionScheduler(model)
-    hasattr(model, 'thinning_recipes')
-
-    # (2)
-    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
-    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
-    assert hasattr(model_2, 'thinning_recipes')
-    logger.info("test_arbitrary_channel_pruning - Done")
-
-    # (3)
-    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
-    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
-    assert hasattr(model_2, 'thinning_recipes')
-    logger.info("test_arbitrary_channel_pruning - Done 2")
-
-
-def test_channel_pruning_conv_bias():
-    """Test removal of arbitrary channels, for Convolutions with bias term.
-
-    This is different from test_arbitrary_channel_pruning() in that this model
-    has Convolution layers with biases tensors.
- """ - ARCH = "simplenet_cifar" - DATASET = "cifar10" - channels_to_remove = [0, 1] - has_bn = False - model, zeros_mask_dict = common.setup_test(ARCH, DATASET) + model, zeros_mask_dict = common.setup_test(config.arch, config.dataset) - conv2 = common.find_module_by_name(model, "conv2") + conv2 = common.find_module_by_name(model, config.conv2_name) assert conv2 is not None # Test that we can access the weights tensor of the first convolution in layer 1 - conv2_p = distiller.model_find_param(model, "conv2.weight") + conv2_p = distiller.model_find_param(model, config.conv2_name + ".weight") assert conv2_p is not None assert conv2_p.dim() == 4 @@ -247,24 +174,24 @@ def test_channel_pruning_conv_bias(): # Cool, so now we have a mask for pruning our channels. # Use the mask to prune - zeros_mask_dict["conv2.weight"].mask = mask - zeros_mask_dict["conv2.weight"].apply_mask(conv2_p) + zeros_mask_dict[config.conv2_name + ".weight"].mask = mask + zeros_mask_dict[config.conv2_name + ".weight"].apply_mask(conv2_p) all_channels = set([ch for ch in range(num_channels)]) - nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, "conv2.weight")) + nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, config.conv2_name + ".weight")) channels_removed = all_channels - nnz_channels logger.info("Channels removed {}".format(channels_removed)) # Now, let's do the actual network thinning - distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET) - conv1 = common.find_module_by_name(model, "conv1") + distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset) + conv1 = common.find_module_by_name(model, config.conv1_name) logger.info(conv1) logger.info(conv2) assert conv1.out_channels == cnt_nnz_channels assert conv2.in_channels == cnt_nnz_channels assert conv1.weight.size(0) == cnt_nnz_channels assert conv2.weight.size(1) == cnt_nnz_channels - if has_bn: - bn1 = common.find_module_by_name(model, "layer1.0.bn1") + if config.bn_name is not None: + bn1 = common.find_module_by_name(model, config.bn_name) assert bn1.running_var.size(0) == cnt_nnz_channels assert bn1.running_mean.size(0) == cnt_nnz_channels assert bn1.num_features == cnt_nnz_channels @@ -277,12 +204,12 @@ def test_channel_pruning_conv_bias(): # - Make sure that after a 2nd load, there no problem loading (in this case, the # - tensors are already thin, so this is a new flow) # (1) - save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None) - model_2 = create_model(False, DATASET, ARCH, parallel=False) + save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None) + model_2 = create_model(False, config.dataset, config.arch, parallel=False) dummy_input = torch.randn(1, 3, 32, 32) model(dummy_input) model_2(dummy_input) - conv2 = common.find_module_by_name(model_2, "conv2") + conv2 = common.find_module_by_name(model_2, config.conv2_name) assert conv2 is not None with pytest.raises(KeyError): model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') @@ -290,18 +217,19 @@ def test_channel_pruning_conv_bias(): hasattr(model, 'thinning_recipes') # (2) - save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler) + save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler) model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar') assert hasattr(model_2, 'thinning_recipes') logger.info("test_arbitrary_channel_pruning - Done") # (3) - 
-    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
+    save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler)
     model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
     assert hasattr(model_2, 'thinning_recipes')
     logger.info("test_arbitrary_channel_pruning - Done 2")
 
 
 if __name__ == '__main__':
+    test_ranked_filter_pruning()
     test_arbitrary_channel_pruning()
     test_prune_all_channels()
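
Note on the expected pruning level asserted in ranked_filter_pruning(): the requested prune ratio is quantized to a whole number of filters, so with conv1's 16 filters a ratio of 0.1 can only be realized as a single removed filter, i.e. an actual sparsity of 1/16. A minimal standalone sketch of that arithmetic (not part of the patch; the function name here is illustrative only):

    def expected_pruning(ratio_to_prune, out_channels):
        # Quantize the requested ratio to a whole number of removed filters
        expected_cnt_removed_filters = int(ratio_to_prune * out_channels)
        return expected_cnt_removed_filters / out_channels

    assert expected_pruning(0.1, 16) == 1 / 16  # 0.0625, not the requested 0.1
    assert expected_pruning(0.5, 16) == 8 / 16  # exactly realizable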
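The 4D channel mask that arbitrary_channel_pruning() builds (visible in the removed hunk, and kept unchanged between the two hunks of the consolidated function) can be reproduced in isolation with plain PyTorch. A minimal sketch, assuming the 16x16x3x3 weight shape of resnet20_cifar's layer1.0.conv2 and the test's default channels_to_remove=[0, 2]; out-of-place unsqueeze() is used here for clarity where the test uses in-place unsqueeze_():

    import torch

    num_filters, num_channels, kernel_height, kernel_width = 16, 16, 3, 3
    channels_to_remove = [0, 2]

    # 1D mask with a zero for every input channel we want to prune
    channels = torch.ones(num_channels)
    for ch in channels_to_remove:
        channels[ch] = 0

    # Broadcast up to the 4D weight shape: each pruned channel zeroes a full
    # kernel_height x kernel_width slice in every one of the filters
    mask = channels.expand(num_filters, num_channels)
    mask = mask.unsqueeze(-1).unsqueeze(-1)
    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()

    assert mask.shape == (num_filters, num_channels, kernel_height, kernel_width)
    assert mask[:, 0].sum() == 0                                           # pruned channel
    assert mask[:, 1].sum() == num_filters * kernel_height * kernel_width  # kept channel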