diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 3cd46ca321a99edac1a9a725ffeb5dc77a7a6410..ec6a1e12a410e24a83b1ac70d84f4740e9b07547 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -31,6 +31,7 @@ import torch.jit as jit import pandas as pd from tabulate import tabulate import pydot +import distiller def onnx_name_2_pytorch_name(name, op_type): @@ -90,7 +91,7 @@ class SummaryGraph(object): def __init__(self, model, dummy_input): with torch.onnx.set_training(model, False): - trace, _ = jit.get_trace_graph(model, dummy_input) + trace, _ = jit.get_trace_graph(model, dummy_input.cuda()) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # composing a GEMM operation; etc. @@ -588,6 +589,7 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F print("Unsupported dataset (%s) - aborting draw operation" % dataset) return + model = distiller.make_non_parallel_copy(model) g = SummaryGraph(model, dummy_input) draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles) print("Network PNG image generation completed") diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 322411469de789dfd9d0303e6a64bfa0f8d5f256..cf67720376532ae87f05a8f0795f3885fdbe3383 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -170,11 +170,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None def model_performance_summary(model, dummy_input, batch_size=1): - """Collect performance data - - warning: in PyTorch 0.4 this function does not return correct values when - the graph contains torch.nn.DataParallel layers. - """ + """Collect performance data""" def install_perf_collector(m): if isinstance(m, torch.nn.Conv2d): hook_handles.append(m.register_forward_hook( @@ -188,11 +184,14 @@ def model_performance_summary(model, dummy_input, batch_size=1): hook_handles = [] memo = [] + + model = distiller.make_non_parallel_copy(model) model.apply(install_perf_collector) # Now run the forward path and collect the data - model(dummy_input); + model(dummy_input.cuda()) # Unregister from the forward hooks - for handle in hook_handles: handle.remove() + for handle in hook_handles: + handle.remove() return df diff --git a/distiller/scheduler.py b/distiller/scheduler.py index ba7a56b95b14c5232eeb7640cc17e75f9b5a8932..ff7fd9d7e6d8b8102cce63858acbe02388819f9f 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -160,8 +160,7 @@ class CompressionScheduler(object): masks = {} for name, masker in self.zeros_mask_dict.items(): masks[name] = masker.mask - state = {'masks_dict': masks, - 'parallel_model': isinstance(self.model, torch.nn.DataParallel)} + state = {'masks_dict': masks} return state def load_state_dict(self, state): @@ -184,17 +183,6 @@ class CompressionScheduler(object): print("\t\t" + k) exit(1) - curr_model_parallel = isinstance(self.model, torch.nn.DataParallel) - # Fallback to 'True' for old checkpoints that don't have this attribute, since parallel=True is the - # default for create_model - loaded_model_parallel = state.get('parallel_model', True) for name, mask in self.zeros_mask_dict.items(): - # DataParallel modules wrap the actual module with a module named "module"... - if loaded_model_parallel and not curr_model_parallel: - load_name = 'module.' + name - elif curr_model_parallel and not loaded_model_parallel: - load_name = name.replace('module.', '', 1) - else: - load_name = name masker = self.zeros_mask_dict[name] - masker.mask = loaded_masks[load_name] + masker.mask = loaded_masks[name] diff --git a/distiller/thinning.py b/distiller/thinning.py index 94346f34750c26763e564a4150ad816f3bfaeaee..de7b6252dc8c056ceda5cfc73e46061c956e5dd6 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -18,7 +18,7 @@ Thinning a model is the process of taking a dense network architecture with a parameter model that has structure-sparsity (filters or channels) in the weights tensors of convolution layers, and making changes - in the network architecture and parameters, in order to completely remove the structures. +in the network architecture and parameters, in order to completely remove the structures. The new architecture is smaller (condensed), with less channels and filters in some of the convolution layers. Linear and BatchNormalization layers are also adjusted as required. diff --git a/distiller/utils.py b/distiller/utils.py index c977b809973ee8c30939782221af06479b4bf132..da90f64d856daee6af59c9a2aa56857e4bcc43d1 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -19,10 +19,11 @@ This module contains various tensor sparsity/density measurement functions, together with some random helper functions. """ -from functools import reduce import numpy as np import torch from torch.autograd import Variable +import torch.nn as nn +from copy import deepcopy def to_np(var): @@ -259,3 +260,46 @@ def log_weights_sparsity(model, epoch, loggers): """Log information about the weights sparsity""" for logger in loggers: logger.log_weights_sparsity(model, epoch) + + +class DoNothingModuleWrapper(nn.Module): + """Implement a nn.Module which wraps another nn.Module. + + The DoNothingModuleWrapper wrapper does nothing but forward + to the wrapped module. + One use-case for this class, is for replacing nn.DataParallel + by a module that does nothing :-). This is a trick we use + to transform data-parallel to serialized models. + """ + def __init__(self, module): + super(DoNothingModuleWrapper, self).__init__() + self.wrapped_module = module + + def forward(self, *inputs, **kwargs): + return self.wrapped_module(*inputs, **kwargs) + + +def make_non_parallel_copy(model): + """Make a non-data-parallel copy of the provided model. + + nn.DataParallel instances are replaced by DoNothingModuleWrapper + instances. + """ + def replace_data_parallel(container, prefix=''): + for name, module in container.named_children(): + full_name = prefix + name + if isinstance(module, nn.DataParallel): + # msglogger.debug('Replacing module {}'.format(full_name)) + container._modules[name] = DoNothingModuleWrapper(module.module) + if len(module._modules) > 0: + # For a container we call recursively + replace_data_parallel(module, full_name + '.') + + # Make a copy of the model, because we're going to change it + new_model = deepcopy(model) + if isinstance(new_model, nn.DataParallel): + #new_model = new_model.module # + new_model = DoNothingModuleWrapper(new_model.module) + + replace_data_parallel(new_model) + return new_model diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index b92997259259ac5cb647d51934e81fba72e1909d..a540dd71d85636150e0d0155a38bbb9406a48695 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -209,9 +209,7 @@ def main(): args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet' # Create the model - png_summary = args.summary is not None and args.summary.startswith('png') - is_parallel = not png_summary and args.summary != 'compute' # For PNG summary, parallel graphs are illegible - model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus) + model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus) compression_scheduler = None # Create a couple of logging backends. TensorBoardLogger writes log files in a format diff --git a/models/__init__.py b/models/__init__.py index 41c5ba7f2c89297da6deb25e5fb9dfdf1a0e4685..d2a5c4fbb277a21ed16f291ede774b48e4d442ca 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -70,8 +70,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel: model.features = torch.nn.DataParallel(model.features, device_ids=device_ids) - model.cuda() elif parallel: - model = torch.nn.DataParallel(model, device_ids=device_ids).cuda() + model = torch.nn.DataParallel(model, device_ids=device_ids) + model.cuda() return model diff --git a/tests/common.py b/tests/common.py index 75c4bc37baa86206a6965d635e078b1a4c8e5326..bcee3bf1f403b834784b192b646b51e705919c36 100755 --- a/tests/common.py +++ b/tests/common.py @@ -23,8 +23,8 @@ import distiller from models import create_model -def setup_test(arch, dataset): - model = create_model(False, dataset, arch, parallel=False) +def setup_test(arch, dataset, parallel=True): + model = create_model(False, dataset, arch, parallel=parallel) assert model is not None # Create the masks diff --git a/tests/test_pruning.py b/tests/test_pruning.py index b16e9012af1941d7cbbfbe86b106c8927c9cc494..5007cd3d7b6a26fc2f4906e0af80787799238faa 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -41,54 +41,79 @@ NetConfig = namedtuple("test_config", "arch dataset bn_name module_pairs") # # Model configurations # -def simplenet(): - return NetConfig(arch="simplenet_cifar", dataset="cifar10", - module_pairs=[("conv1", "conv2")], - bn_name=None) - - -def resnet20_cifar(): - return NetConfig(arch="resnet20_cifar", dataset="cifar10", - module_pairs=[("layer1.0.conv1", "layer1.0.conv2")], - bn_name="layer1.0.bn1") - - -def vgg19_imagenet(): - return NetConfig(arch="vgg19", dataset="imagenet", - module_pairs=[("features.21", "features.23"), - ("features.23", "features.25"), - ("features.25", "features.28"), - ("features.28", "features.30"), - ("features.30", "features.32"), - ("features.32", "features.34")], - bn_name=None) - - -def test_ranked_filter_pruning(): - ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.1) - ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=0.5) - ranked_filter_pruning(simplenet(), ratio_to_prune=0.5) - ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1) - model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1) - test_conv_fc_interface(model, zeros_mask_dict) - - -def test_prune_all_filters(): +def simplenet(is_parallel): + if is_parallel: + return NetConfig(arch="simplenet_cifar", dataset="cifar10", + module_pairs=[("module.conv1", "module.conv2")], + bn_name=None) + else: + return NetConfig(arch="simplenet_cifar", dataset="cifar10", + module_pairs=[("conv1", "conv2")], + bn_name=None) + + +def resnet20_cifar(is_parallel): + if is_parallel: + return NetConfig(arch="resnet20_cifar", dataset="cifar10", + module_pairs=[("module.layer1.0.conv1", "module.layer1.0.conv2")], + bn_name="module.layer1.0.bn1") + else: + return NetConfig(arch="resnet20_cifar", dataset="cifar10", + module_pairs=[("layer1.0.conv1", "layer1.0.conv2")], + bn_name="layer1.0.bn1") + + +def vgg19_imagenet(is_parallel): + if is_parallel: + return NetConfig(arch="vgg19", dataset="imagenet", + module_pairs=[("features.module.21", "features.module.23"), + ("features.module.23", "features.module.25"), + ("features.module.25", "features.module.28"), + ("features.module.28", "features.module.30"), + ("features.module.30", "features.module.32"), + ("features.module.32", "features.module.34")], + bn_name=None) + else: + return NetConfig(arch="vgg19", dataset="imagenet", + module_pairs=[("features.21", "features.23"), + ("features.23", "features.25"), + ("features.25", "features.28"), + ("features.28", "features.30"), + ("features.30", "features.32"), + ("features.32", "features.34")], + bn_name=None) + +@pytest.fixture(params=[False]) +def parallel(request): + return request.param + +def test_ranked_filter_pruning(parallel): + ranked_filter_pruning(resnet20_cifar(parallel), ratio_to_prune=0.1, is_parallel=parallel) + ranked_filter_pruning(resnet20_cifar(parallel), ratio_to_prune=0.5, is_parallel=parallel) + ranked_filter_pruning(simplenet(parallel), ratio_to_prune=0.5, is_parallel=parallel) + ranked_filter_pruning(vgg19_imagenet(parallel), ratio_to_prune=0.1, is_parallel=parallel) + model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(parallel), + ratio_to_prune=0.1, + is_parallel=parallel) + test_conv_fc_interface(parallel, model, zeros_mask_dict) + + +def test_prune_all_filters(parallel): """Pruning all of the filteres in a weights tensor of a Convolution is illegal and should raise an exception. """ with pytest.raises(ValueError): - ranked_filter_pruning(resnet20_cifar(), ratio_to_prune=1.0) + ranked_filter_pruning(resnet20_cifar(parallel), ratio_to_prune=1.0, is_parallel=parallel) -def ranked_filter_pruning(config, ratio_to_prune): +def ranked_filter_pruning(config, ratio_to_prune, is_parallel): """Test L1 ranking and pruning of filters. First we rank and prune the filters of a Convolutional layer using a L1RankedStructureParameterPruner. Then we physically remove the filters from the model (via "thining" process). """ - model, zeros_mask_dict = common.setup_test(config.arch, config.dataset) + model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel) for pair in config.module_pairs: # Test that we can access the weights tensor of the first convolution in layer 1 @@ -132,22 +157,29 @@ def ranked_filter_pruning(config, ratio_to_prune): return model, zeros_mask_dict -def test_arbitrary_channel_pruning(): - arbitrary_channel_pruning(resnet20_cifar(), channels_to_remove=[0, 2]) - arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 2]) +def test_arbitrary_channel_pruning(parallel): + arbitrary_channel_pruning(resnet20_cifar(parallel), + channels_to_remove=[0, 2], + is_parallel=parallel) + arbitrary_channel_pruning(simplenet(parallel), + channels_to_remove=[0, 2], + is_parallel=parallel) -def test_prune_all_channels(): +def test_prune_all_channels(parallel): """Pruning all of the channels in a weights tensor of a Convolution is illegal and should raise an exception. """ with pytest.raises(ValueError): - arbitrary_channel_pruning(resnet20_cifar(), - channels_to_remove=[ch for ch in range(16)]) + arbitrary_channel_pruning(resnet20_cifar(parallel), + channels_to_remove=[ch for ch in range(16)], + is_parallel=parallel) -def test_channel_pruning_conv_bias(): - arbitrary_channel_pruning(simplenet(), channels_to_remove=[0, 1]) +def test_channel_pruning_conv_bias(parallel): + arbitrary_channel_pruning(simplenet(parallel), + channels_to_remove=[0, 1], + is_parallel=parallel) def create_channels_mask(conv_p, channels_to_remove): @@ -167,7 +199,7 @@ def create_channels_mask(conv_p, channels_to_remove): mask = channels.expand(num_filters, num_channels) mask.unsqueeze_(-1) mask.unsqueeze_(-1) - mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous() + mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous().cuda() assert mask.shape == conv_p.shape return mask @@ -177,21 +209,21 @@ def run_forward_backward(model, optimizer, dummy_input): criterion = torch.nn.CrossEntropyLoss().cuda() model.train() output = model(dummy_input) - target = torch.LongTensor(1).random_(2) + target = torch.LongTensor(1).random_(2).cuda() loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() -def arbitrary_channel_pruning(config, channels_to_remove): +def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): """Test removal of arbitrary channels. The test receives a specification of channels to remove. Based on this specification, the channels are pruned and then physically removed from the model (via a "thinning" process). """ - model, zeros_mask_dict = common.setup_test(config.arch, config.dataset) + model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel) assert len(config.module_pairs) == 1 # This is a temporary restriction on the test pair = config.module_pairs[0] @@ -233,7 +265,7 @@ def arbitrary_channel_pruning(config, channels_to_remove): assert bn1.bias.size(0) == cnt_nnz_channels assert bn1.weight.size(0) == cnt_nnz_channels - dummy_input = torch.randn(1, 3, 32, 32) + dummy_input = torch.randn(1, 3, 32, 32).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1) run_forward_backward(model, optimizer, dummy_input) @@ -244,7 +276,7 @@ def arbitrary_channel_pruning(config, channels_to_remove): # - tensors are already thin, so this is a new flow) # (1) save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None) - model_2 = create_model(False, config.dataset, config.arch, parallel=False) + model_2 = create_model(False, config.dataset, config.arch, parallel=is_parallel) model(dummy_input) model_2(dummy_input) conv2 = common.find_module_by_name(model_2, pair[1]) @@ -269,19 +301,22 @@ def arbitrary_channel_pruning(config, channels_to_remove): logger.info("test_arbitrary_channel_pruning - Done 2") -def test_conv_fc_interface(model=None, zeros_mask_dict=None): +def test_conv_fc_interface(is_parallel=parallel, model=None, zeros_mask_dict=None): """A special case of convolution filter-pruning occurs when the next layer is fully-connected (linear). This test is for this case and uses VGG16. """ arch = "vgg19" dataset = "imagenet" ratio_to_prune = 0.1 - conv_name = "features.34" + if is_parallel: + conv_name = "features.module.34" + else: + conv_name = "features.34" fc_name = "classifier.0" - dummy_input = torch.randn(1, 3, 224, 224) + dummy_input = torch.randn(1, 3, 224, 224).cuda() if model is None or zeros_mask_dict is None: - model, zeros_mask_dict = common.setup_test(arch, dataset) + model, zeros_mask_dict = common.setup_test(arch, dataset, is_parallel) # Run forward and backward passes, in order to create the gradients and optimizer params optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1) @@ -323,8 +358,11 @@ def test_conv_fc_interface(model=None, zeros_mask_dict=None): if __name__ == '__main__': - test_ranked_filter_pruning() - test_arbitrary_channel_pruning() - test_prune_all_channels() - model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(), ratio_to_prune=0.1) - test_conv_fc_interface(model, zeros_mask_dict) + for is_parallel in [True, False]: + test_ranked_filter_pruning(is_parallel) + test_arbitrary_channel_pruning(is_parallel) + test_prune_all_channels(is_parallel) + model, zeros_mask_dict = ranked_filter_pruning(vgg19_imagenet(is_parallel), + ratio_to_prune=0.1, + is_parallel=is_parallel) + test_conv_fc_interface(is_parallel, model, zeros_mask_dict)