diff --git a/distiller/thinning.py b/distiller/thinning.py
index fb95f83833fd46f1bfab5f4b197db283500c7ebe..241e18467aea75ff1950c506000aac86cd824921 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -21,13 +21,16 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
+from apputils import SummaryGraph
+from models import ALL_MODEL_NAMES, create_model
 
 msglogger = logging.getLogger()
 
 ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
 
 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'resnet_cifar_remove_filters', 'resnet_cifar_remove_channels', 'ResnetCifarChannelRemover',
-           'ResnetCifarFilterRemover', 'execute_thinning_recipe']
+           'ResnetCifarFilterRemover', 'execute_thinning_recipe',
+           'FilterRemover', 'vgg_remove_filters']
 
 # This is a dictionary that maps ResNet-Cifar connectivity, of convolutional layers.
 # BN layers connectivity are implied; and shortcuts are not currently handled.
@@ -94,6 +97,24 @@ conv_connectivity = {
     'module.layer3.8.conv1': 'module.layer3.8.conv2',
     'module.layer3.8.conv2': 'module.fc'}
 
+vgg_conv_connectivity = {
+    'features.module.0': 'features.module.2',
+    'features.module.2': 'features.module.5',
+    'features.module.5': 'features.module.7',
+    'features.module.7': 'features.module.10',
+    'features.module.10': 'features.module.12',
+    'features.module.12': 'features.module.14',
+    'features.module.14': 'features.module.16',
+    'features.module.16': 'features.module.19',
+    'features.module.19': 'features.module.21',
+    'features.module.21': 'features.module.23',
+    'features.module.23': 'features.module.25',
+    'features.module.25': 'features.module.28',
+    'features.module.28': 'features.module.30',
+    'features.module.30': 'features.module.32',
+    'features.module.32': 'features.module.34',
+    'features.module.34': 'classifier.0'}
+
 def find_predecessors(layer_name):
     predecessors = []
     for layer, followers in conv_connectivity.items():
@@ -116,8 +137,9 @@ def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_featur
     bn_directive['running_var'] = (0, thin_features)
     thinning_recipe.modules[bn_name] = bn_directive
 
-    thinning_recipe.parameters[bn_name+'.weight'] = (0, thin_features)
-    thinning_recipe.parameters[bn_name+'.bias'] = (0, thin_features)
+
+    thinning_recipe.parameters[bn_name+'.weight'] = [(0, thin_features)]
+    thinning_recipe.parameters[bn_name+'.bias'] = [(0, thin_features)]
 
 
 def resnet_cifar_remove_layers(resnet_cifar_model):
@@ -192,7 +214,9 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_
 
             # Select only the non-zero filters
            indices = nonzero_channels.data.squeeze()
-            thinning_recipe.parameters[name] = (1, indices)
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((1, indices))
+            thinning_recipe.parameters[name] = param_directive
 
             assert layer_name in conv_connectivity
             predecessors = find_predecessors(layer_name)
@@ -204,8 +228,9 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_
                 thinning_recipe.modules[predecessor] = conv_directive
 
                 # Now remove channels from the weights tensor of the follower conv
-                #predecessor_param = distiller.model_find_param(resnet_cifar_model, predecessor+'.weight')
-                thinning_recipe.parameters[predecessor+'.weight'] = (0, indices)
+                param_directive = thinning_recipe.parameters.get(predecessor+'.weight', [])
+                param_directive.append((0, indices))
+                thinning_recipe.parameters[predecessor+'.weight'] = param_directive
 
                 # Now handle the BatchNormalization layer that follows the convolution
                 bn_name = predecessor.replace('conv', 'bn')
@@ -217,22 +242,48 @@ def resnet_cifar_create_thinning_recipe_channels(resnet_cifar_model, zeros_mask_
 
 def resnet_cifar_remove_filters(resnet_cifar_model, zeros_mask_dict):
     thinning_recipe = resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict)
+    return remove_filters(resnet_cifar_model, zeros_mask_dict, thinning_recipe)
+
+
+def vgg_remove_filters(vgg_model, zeros_mask_dict):
+    thinning_recipe = vgg_create_thinning_recipe_filters(vgg_model, zeros_mask_dict)
+    return remove_filters(vgg_model, zeros_mask_dict, thinning_recipe)
 
+def remove_filters(model, zeros_mask_dict, thinning_recipe):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters)>0:
         # Stash the recipe, so that it will be serialized together with the model
-        resnet_cifar_model.thinning_recipe = thinning_recipe
+        model.thinning_recipe = thinning_recipe
         # Now actually remove the filters, chaneels and make the weight tensors smaller
-        execute_thinning_recipe(resnet_cifar_model, zeros_mask_dict, thinning_recipe)
+        execute_thinning_recipe(model, zeros_mask_dict, thinning_recipe)
         msglogger.info("Created, applied and saved a filter-thinning recipe")
 
-    return resnet_cifar_model
+    return model
+
+def get_input(dataset):
+    if dataset == 'imagenet':
+        return torch.randn((1, 3, 224, 224), requires_grad=False)
+    elif dataset == 'cifar10':
+        return torch.randn((1, 3, 32, 32))
+    return None
+
+def create_graph(dataset, arch):
+    dummy_input = get_input(dataset)
+    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
+
+    model = create_model(False, dataset, arch, parallel=False)
+    assert model is not None
+    return SummaryGraph(model, dummy_input)
+def normalize_layer_name(layer_name):
+    if layer_name.startswith('module.'):
+        layer_name = layer_name[len('module.'):]
+    return layer_name
 
 
 def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_dict):
     """Remove filters from ResNet-Cifar
     Caveats: (1) supports only ResNet50-Cifar; (2) only module.layerX.Y.conv1.weight
     """
-    msglogger.info("Invoking resnet_cifar_remove_filters")
+    msglogger.info("Invoking resnet_cifar_create_thinning_recipe_filters")
 
     layers = {thin_name : m for thin_name, m in resnet_cifar_model.named_modules()}
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
@@ -259,9 +310,12 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_d
 
             # Select only the non-zero filters
             indices = nonzero_filters.data.squeeze()
-            thinning_recipe.parameters[name] = (0, indices)
+            #thinning_recipe.parameters[name] = (0, indices)
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[name] = param_directive
 
-            assert layer_name in conv_connectivity
+            assert layer_name in conv_connectivity, "layer {} is not in conv_connectivity {}".format(layer_name, conv_connectivity)
             followers = conv_connectivity[layer_name] if isinstance(conv_connectivity[layer_name], list) else [conv_connectivity[layer_name]]
             for follower in followers:
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
@@ -270,8 +324,12 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_d
                 thinning_recipe.modules[follower] = conv_directive
 
                 # Now remove channels from the weights tensor of the follower conv
-                #follower_param = distiller.model_find_param(resnet_cifar_model, follower+'.weight')
-                thinning_recipe.parameters[follower+'.weight'] = (1, indices)
+                param_directive = thinning_recipe.parameters.get(follower+'.weight', [])
+
+                msglogger.info("appending {}".format(follower+'.weight'))
+
+                param_directive.append((1, indices))
+                thinning_recipe.parameters[follower+'.weight'] = param_directive
 
             # Now handle the BatchNormalization layer that follows the convolution
             bn_name = layer_name.replace('conv', 'bn')
@@ -280,6 +338,80 @@ def resnet_cifar_create_thinning_recipe_filters(resnet_cifar_model, zeros_mask_d
                 bn_thinning(thinning_recipe, layers, bn_name,
                             len_thin_features=len(nonzero_filters), thin_features=indices)
     return thinning_recipe
+import math
+def vgg_create_thinning_recipe_filters(vgg_model, zeros_mask_dict):
+    """Remove filters from VGG
+    """
+    msglogger.info("Invoking vgg_create_thinning_recipe_filters")
+
+    layers = {thin_name : m for thin_name, m in vgg_model.named_modules()}
+    thinning_recipe = ThinningRecipe(modules={}, parameters={})
+
+    for name, p_thin in vgg_model.named_parameters():
+        if p_thin.dim() != 4:
+            continue
+
+        # Find the number of filters, in this weights tensor, that are not 100% sparse
+        filter_view = p_thin.view(p_thin.size(0), -1)
+        num_filters = filter_view.size()[0]
+        nonzero_filters = torch.nonzero(filter_view.abs().sum(dim=1))
+
+        # If there are zero-filters in this tensor then...
+        if num_filters > len(nonzero_filters):
+            msglogger.info("In tensor %s found %d/%d zero filters", name,
+                           num_filters - len(nonzero_filters), num_filters)
+
+            # Update the number of outgoing channels (OFMs) in the convolutional layer
+            layer_name = name[:-len('weights')]
+            assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
+            conv_directive = thinning_recipe.modules.get(layer_name, {})
+            conv_directive['out_channels'] = len(nonzero_filters)
+            thinning_recipe.modules[layer_name] = conv_directive
+
+            # Select only the non-zero filters
+            indices = nonzero_filters.data.squeeze()
+            param_directive = thinning_recipe.parameters.get(name, [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[name] = param_directive
+
+            param_directive = thinning_recipe.parameters.get(layer_name+'.bias', [])
+            param_directive.append((0, indices))
+            thinning_recipe.parameters[layer_name+'.bias'] = param_directive
+
+            assert layer_name in vgg_conv_connectivity, "layer {} is not in vgg_conv_connectivity {}".format(layer_name, vgg_conv_connectivity)
+            followers = vgg_conv_connectivity[layer_name] if isinstance(vgg_conv_connectivity[layer_name], list) else [vgg_conv_connectivity[layer_name]]
+            for follower in followers:
+                # For each of the convolutional layers that follow, we have to reduce the number of input channels.
+                conv_directive = thinning_recipe.modules.get(follower, {})
+
+                if isinstance(layers[follower], torch.nn.modules.Conv2d):
+                    conv_directive['in_channels'] = len(nonzero_filters)
+                    msglogger.info("{}: setting in_channels = {}".format(follower, len(nonzero_filters)))
+                elif isinstance(layers[follower], torch.nn.modules.Linear):
+                    # TODO: this code is hard to follow
+                    fm_size = layers[follower].in_features // layers[layer_name].out_channels
+                    conv_directive['in_features'] = fm_size * len(nonzero_filters)
+                    #assert 22589 == conv_directive['in_features']
+                    msglogger.info("{}: setting in_features = {}".format(follower, conv_directive['in_features']))
+
+                thinning_recipe.modules[follower] = conv_directive
+
+                # Now remove channels from the weights tensor of the follower conv
+                param_directive = thinning_recipe.parameters.get(follower+'.weight', [])
+                if isinstance(layers[follower], torch.nn.modules.Conv2d):
+                    param_directive.append((1, indices))
+                elif isinstance(layers[follower], torch.nn.modules.Linear):
+                    # TODO: this code is hard to follow
+                    fm_size = layers[follower].in_features // layers[layer_name].out_channels
+                    fm_height = fm_width = int(math.sqrt(fm_size))
+                    selection_view = (layers[follower].out_features, layers[layer_name].out_channels, fm_height, fm_width)
+                    #param_directive.append((1, indices, (4096,512,7,7)))
+                    true_view = (layers[follower].out_features, conv_directive['in_features'])
+                    param_directive.append((1, indices, selection_view, true_view))
+
+                thinning_recipe.parameters[follower+'.weight'] = param_directive
+    return thinning_recipe
+
 
 class ResnetCifarChannelRemover(ScheduledTrainingPolicy):
     """A policy which applies a network thinning function"""
@@ -290,7 +422,7 @@ class ResnetCifarChannelRemover(ScheduledTrainingPolicy):
         self.thinning_func(model, zeros_mask_dict)
 
 
-class ResnetCifarFilterRemover(ScheduledTrainingPolicy):
+class FilterRemover(ScheduledTrainingPolicy):
     """A policy which applies a network thinning function"""
     def __init__(self, thinning_func_str):
         self.thinning_func = globals()[thinning_func_str]
@@ -305,6 +437,11 @@ class ResnetCifarFilterRemover(ScheduledTrainingPolicy):
             self.thinning_func(model, zeros_mask_dict)
             self.done = True
 
 
+# This is for backward compatibility with some of the published schedules
+class ResnetCifarFilterRemover(FilterRemover):
+    def __init__(self, thinning_func_str):
+        super().__init__(thinning_func_str)
+
 def execute_thinning_recipe(model, zeros_mask_dict, recipe):
     """Apply a thinning recipe to a model.
@@ -319,18 +456,32 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe):
 
     for layer_name, directives in recipe.modules.items():
         for attr, val in directives.items():
             if attr in ['running_mean', 'running_var']:
+                msglogger.info("{} thinning: setting {} to {}".format(layer_name, attr, len(val[1])))
                 setattr(layers[layer_name], attr,
                         torch.index_select(getattr(layers[layer_name], attr), dim=val[0], index=val[1]))
             else:
+                msglogger.info("{} thinning: setting {} to {}".format(layer_name, attr, val))
                 setattr(layers[layer_name], attr, val)
 
-    for param_name, info in recipe.parameters.items():
+    assert len(recipe.parameters) > 0
+
+    for param_name, param_directives in recipe.parameters.items():
         param = distiller.model_find_param(model, param_name)
-        param.data = torch.index_select(param.data, dim=info[0], index=info[1])
-        mask = zeros_mask_dict[param_name].mask
-        if (mask is not None) and (zeros_mask_dict[param_name].mask.size(dim=info[0]) != len(info[1])):
-            zeros_mask_dict[param_name].mask = torch.index_select(mask, dim=info[0], index=info[1])
+        for directive in param_directives:
+            dim = directive[0]
+            indices = directive[1]
+            if len(directive) == 4:  # TODO: this code is hard to follow
+                layer_name = param_name[:-len('weights')]
+
+                selection_view = param.view(*directive[2])
+                param.data = torch.index_select(selection_view, dim, indices)
+                #param.data = param.view(4096, 22589)
+                param.data = param.view(*directive[3])
+            else:
+                param.data = torch.index_select(param.data, dim, indices)
+            msglogger.info("thinning: changing param {} shape: {}".format(param_name, len(indices)))
 
-    model.thinning_recipe = recipe
-    return model
+            mask = zeros_mask_dict[param_name].mask
+            if (mask is not None) and (zeros_mask_dict[param_name].mask.size(dim) != len(indices)):
+                zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
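The central change in this patch is that a ThinningRecipe parameter entry is now a list of directives rather than a single (dim, indices) tuple, so one weight tensor can accumulate a thinning along dim 0 (its own pruned filters) and along dim 1 (input channels removed because the preceding layer was thinned). The snippet below is a minimal illustrative sketch of that structure, not part of the diff; the layer names and index values are made up.

from collections import namedtuple
import torch

ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
recipe = ThinningRecipe(modules={}, parameters={})

# Hypothetical example: 'features.module.2' loses two of its own filters
# (dim 0) and also loses input channels (dim 1) because 'features.module.0'
# was thinned before it, so its weight entry accumulates two directives
# instead of overwriting a single tuple.
name = 'features.module.2.weight'
directives = recipe.parameters.get(name, [])
directives.append((0, torch.tensor([0, 2, 3])))   # surviving output filters
directives.append((1, torch.tensor([1, 2])))      # surviving input channels
recipe.parameters[name] = directives

# The matching module directive records the new layer attributes.
recipe.modules['features.module.2'] = {'out_channels': 3, 'in_channels': 2}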
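The 4-tuple directive (dim, indices, selection_view, true_view) covers the case where the layer following a thinned conv is a Linear layer (VGG's classifier.0): the flat weight matrix has to be viewed as (out_features, in_channels, H, W) before input channels can be dropped. Below is a self-contained sketch of that reshape using made-up sizes (VGG16 on ImageNet would use 4096 x 512*7*7); it only restates the logic of the len(directive) == 4 branch in execute_thinning_recipe and is not taken from the diff.

import torch

out_features, in_channels, fm_h, fm_w = 10, 8, 3, 3          # hypothetical sizes
linear_weight = torch.randn(out_features, in_channels * fm_h * fm_w)

# Suppose only these filters of the preceding conv survived thinning.
surviving = torch.tensor([0, 1, 2, 3, 4, 5])

selection_view = (out_features, in_channels, fm_h, fm_w)
true_view = (out_features, len(surviving) * fm_h * fm_w)

# View the flat weight as 4-D, drop the input channels whose producing
# filters were removed, then flatten back to the Linear layer's 2-D shape.
thinned = torch.index_select(linear_weight.view(*selection_view), 1, surviving)
thinned = thinned.contiguous().view(*true_view)
print(thinned.shape)   # torch.Size([10, 54])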