diff --git a/distiller/thinning.py b/distiller/thinning.py index 684d715535b407a3a5d1175974e2422114658e86..8e0b642aef1b1b9422b88101c866da2153f5e003 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -70,7 +70,7 @@ def create_graph(dataset, arch): dummy_input = torch.randn((1, 3, 32, 32)) assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset) - model = create_model(False, dataset, arch, parallel=False) + model = create_model(False, dataset, arch, parallel=True) assert model is not None return SummaryGraph(model, dummy_input.cuda()) @@ -85,7 +85,8 @@ def append_param_directive(thinning_recipe, param_name, directive): thinning_recipe.parameters[param_name] = param_directive -def append_module_directive(thinning_recipe, module_name, key, val): +def append_module_directive(model, thinning_recipe, module_name, key, val): + module_name = denormalize_module_name(model, module_name) mod_directive = thinning_recipe.modules.get(module_name, {}) mod_directive[key] = val thinning_recipe.modules[module_name] = mod_directive @@ -236,7 +237,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): # in the convolutional layer layer_name = param_name_2_layer_name(param_name) assert isinstance(layers[layer_name], torch.nn.modules.Conv2d) - append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels) + append_module_directive(model, thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels) # Select only the non-zero filters indices = nonzero_channels.data.squeeze() @@ -245,17 +246,19 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): # Find all instances of Convolution layers that immediately preceed this layer predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv']) # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used) - predecessors = [denormalize_module_name(model, predecessor) for predecessor in predecessors] + predecessors = [normalize_module_name(predecessor) for predecessor in predecessors] + if len(predecessors) == 0: + msglogger.info("Could not find predecessors for name={} normal={} {}".format(layer_name, normalize_module_name(layer_name), denormalize_module_name(model, layer_name))) for predecessor in predecessors: # For each of the convolutional layers that preceed, we have to reduce the number of output channels. - append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels) + append_module_directive(model, thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels) # Now remove channels from the weights tensor of the successor conv - append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices)) + append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.weight', (0, indices)) - if layers[predecessor].bias is not None: + if layers[denormalize_module_name(model, predecessor)].bias is not None: # This convolution has bias coefficients - append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices)) + append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.bias', (0, indices)) # Now handle the BatchNormalization layer that follows the convolution bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization']) @@ -265,7 +268,6 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): bn_layer_name = denormalize_module_name(model, bn_layers[0]) bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=num_nnz_channels, thin_features=indices) - return thinning_recipe @@ -307,7 +309,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): # in the convolutional layer layer_name = param_name_2_layer_name(param_name) assert isinstance(layers[layer_name], torch.nn.modules.Conv2d) - append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters) + append_module_directive(model, thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters) # Select only the non-zero filters indices = nonzero_filters.data.squeeze() @@ -325,17 +327,17 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): if isinstance(layers[successor], torch.nn.modules.Conv2d): # For each of the convolutional layers that follow, we have to reduce the number of input channels. - append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters) + append_module_directive(model, thinning_recipe, successor, key='in_channels', val=num_nnz_filters) msglogger.debug("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters)) # Now remove channels from the weights tensor of the successor conv - append_param_directive(thinning_recipe, successor+'.weight', (1, indices)) + append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices)) elif isinstance(layers[successor], torch.nn.modules.Linear): # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member fm_size = layers[successor].in_features // layers[layer_name].out_channels in_features = fm_size * num_nnz_filters - append_module_directive(thinning_recipe, successor, key='in_features', val=in_features) + append_module_directive(model, thinning_recipe, successor, key='in_features', val=in_features) msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features)) # Now remove channels from the weights tensor of the successor FC layer: @@ -343,7 +345,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): fm_height = fm_width = int(math.sqrt(fm_size)) view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width) view_2D = (layers[successor].out_features, in_features) - append_param_directive(thinning_recipe, successor+'.weight', (1, indices, view_4D, view_2D)) + append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices, view_4D, view_2D)) # Now handle the BatchNormalization layer that follows the convolution bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization']) @@ -404,8 +406,9 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list): # Invoke this function when you want to use a list of thinning recipes to convert a programmed model # to a thinned model. For example, this is invoked when loading a model from a checkpoint. for i, recipe in enumerate(recipe_list): - msglogger.info("recipe %d:" % i) + msglogger.debug("Executing recipe %d:" % i) execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer=None, loaded_from_file=True) + msglogger.info("Executed %d recipes" % len(recipe_list)) def optimizer_thinning(optimizer, param, dim, indices, new_shape=None): @@ -431,7 +434,7 @@ def optimizer_thinning(optimizer, param, dim, indices, new_shape=None): if 'momentum_buffer' in param_state: param_state['momentum_buffer'] = torch.index_select(param_state['momentum_buffer'], dim, indices) if new_shape is not None: - msglogger.info("optimizer_thinning: new shape {}".format(*new_shape)) + msglogger.debug("optimizer_thinning: new shape {}".format(*new_shape)) param_state['momentum_buffer'] = param_state['momentum_buffer'].resize_(*new_shape) return True return False @@ -443,10 +446,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr adjustment, and thinning of weight tensors. """ - layers = {} - for name, m in model.named_modules(): - layers[name] = m - + layers = {mod_name: m for mod_name, m in model.named_modules()} for layer_name, directives in recipe.modules.items(): for attr, val in directives.items(): if attr in ['running_mean', 'running_var']: @@ -467,6 +467,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr for param_name, param_directives in recipe.parameters.items(): param = distiller.model_find_param(model, param_name) + assert param is not None for directive in param_directives: dim = directive[0] indices = directive[1] @@ -482,8 +483,8 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr if grad_selection_view.size(dim) != len_indices: param.grad = torch.index_select(grad_selection_view, dim, indices) if optimizer_thinning(optimizer, param, dim, indices, directive[3]): - msglogger.info("Updated [4D] velocity buffer for {} (dim={},size={},shape={})". - format(param_name, dim, len_indices, directive[3])) + msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})". + format(param_name, dim, len_indices, directive[3])) param.data = param.view(*directive[3]) if param.grad is not None: @@ -498,7 +499,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr if param.grad is not None and param.grad.size(dim) != len_indices: param.grad = torch.index_select(param.grad, dim, indices) if optimizer_thinning(optimizer, param, dim, indices): - msglogger.info("Updated velocity buffer %s" % param_name) + msglogger.debug("Updated velocity buffer %s" % param_name) if not loaded_from_file: # If the masks are loaded from a checkpoint file, then we don't need to change diff --git a/distiller/utils.py b/distiller/utils.py index 17279538e113814dc92955b05790d0e1600ba110..588b1f53a296ed6e8ab17a9b3feb6ab305690ed4 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -80,7 +80,7 @@ def denormalize_module_name(parallel_model, normalized_name): if len(fully_qualified_name) > 0: return fully_qualified_name[-1] else: - return "" # Did not find a module with the name <normalized_name> + return normalized_name # Did not find a module with the name <normalized_name> def volume(tensor): @@ -277,10 +277,10 @@ class DoNothingModuleWrapper(nn.Module): """ def __init__(self, module): super(DoNothingModuleWrapper, self).__init__() - self.wrapped_module = module + self.module = module def forward(self, *inputs, **kwargs): - return self.wrapped_module(*inputs, **kwargs) + return self.module(*inputs, **kwargs) def make_non_parallel_copy(model): diff --git a/tests/common.py b/tests/common.py index 324f2b824681a2dc9f6e11dab894f62d2a81e494..0991321ceb9bea0b5c207d7899cc30d748fa5917 100755 --- a/tests/common.py +++ b/tests/common.py @@ -16,6 +16,7 @@ import os import sys +import torch module_path = os.path.abspath(os.path.join('..')) if module_path not in sys.path: sys.path.append(module_path) @@ -40,3 +41,11 @@ def find_module_by_name(model, module_to_find): if name == module_to_find: return m return None + + +def get_dummy_input(dataset): + if dataset == "imagenet": + return torch.randn(1, 3, 224, 224).cuda() + elif dataset == "cifar10": + return torch.randn(1, 3, 32, 32).cuda() + raise ValueError("Trying to use an unknown dataset " + dataset) diff --git a/tests/test_pruning.py b/tests/test_pruning.py index 57d297510ef2774ba13647a421c2a840498870fd..50fb78c20995c05fe21cabafd833be6e519dc96f 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -66,7 +66,8 @@ def resnet20_cifar(is_parallel): def vgg19_imagenet(is_parallel): if is_parallel: return NetConfig(arch="vgg19", dataset="imagenet", - module_pairs=[("features.module.21", "features.module.23"), + module_pairs=[("features.module.0", "features.module.2"), + ("features.module.21", "features.module.23"), ("features.module.23", "features.module.25"), ("features.module.25", "features.module.28"), ("features.module.28", "features.module.30"), @@ -163,6 +164,9 @@ def test_arbitrary_channel_pruning(parallel): arbitrary_channel_pruning(simplenet(parallel), channels_to_remove=[0, 2], is_parallel=parallel) + arbitrary_channel_pruning(vgg19_imagenet(parallel), + channels_to_remove=[0, 2], + is_parallel=parallel) def test_prune_all_channels(parallel): @@ -223,7 +227,6 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): """ model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel) - assert len(config.module_pairs) == 1 # This is a temporary restriction on the test pair = config.module_pairs[0] conv2 = common.find_module_by_name(model, pair[1]) assert conv2 is not None @@ -250,7 +253,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): # Now, let's do the actual network thinning distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None) conv1 = common.find_module_by_name(model, pair[0]) - + assert conv1 assert conv1.out_channels == cnt_nnz_channels assert conv2.in_channels == cnt_nnz_channels assert conv1.weight.size(0) == cnt_nnz_channels @@ -263,7 +266,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): assert bn1.bias.size(0) == cnt_nnz_channels assert bn1.weight.size(0) == cnt_nnz_channels - dummy_input = torch.randn(1, 3, 32, 32).cuda() + dummy_input = common.get_dummy_input(config.dataset) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1) run_forward_backward(model, optimizer, dummy_input) @@ -364,3 +367,7 @@ if __name__ == '__main__': ratio_to_prune=0.1, is_parallel=is_parallel) test_conv_fc_interface(is_parallel, model, zeros_mask_dict) + + arbitrary_channel_pruning(vgg19_imagenet(parallel), + channels_to_remove=[0, 2], + is_parallel=parallel)