diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py index f45026b49c3669352622d750779c9adc99dd4c5d..d317bdead0e235680c9088152f8eb28770545b84 100755 --- a/distiller/apputils/checkpoint.py +++ b/distiller/apputils/checkpoint.py @@ -198,8 +198,8 @@ def load_checkpoint(model, chkpt_file, optimizer=None, if not model: model = _create_model_from_ckpt() if not model: - raise ValueError("You didn't provide a model, and the checkpoint doesn't contain" - "enough information to create one") + raise ValueError("You didn't provide a model, and the checkpoint %s doesn't contain " + "enough information to create one", chkpt_file) checkpoint_epoch = checkpoint.get('epoch', None) start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0 diff --git a/distiller/thinning.py b/distiller/thinning.py index 0d790c21817afe251fe9c3ef340eee05e27eece7..74ee6f73a07a684e0f663007376f5a86b3afc6b3 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -60,9 +60,37 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', 'StructureRemover', 'ChannelRemover', 'remove_channels', 'FilterRemover', 'remove_filters', + 'contract_model', 'execute_thinning_recipes_list', 'get_normalized_recipe'] +def contract_model(model, zeros_mask_dict, arch, dataset, optimizer): + """Contract a model by removing filters and channels + + The algorithm searches for weight filters and channels that have all + zero-coefficients, and shrinks the model by removing these channels + and filters from the model definition, along with any related parameters. + """ + remove_filters(model, zeros_mask_dict, arch, dataset, optimizer) + remove_channels(model, zeros_mask_dict, arch, dataset, optimizer) + + +def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer): + """Contract a model by removing weight channels""" + sgraph = _create_graph(dataset, model) + thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict) + apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer) + return model + + +def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer): + """Contract a model by removing weight filters""" + sgraph = _create_graph(dataset, model) + thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict) + apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer) + return model + + def _create_graph(dataset, model): dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model)) return SummaryGraph(model, dummy_input) @@ -75,7 +103,7 @@ def get_normalized_recipe(recipe): ) -def directives_equal(d1, d2): +def _directives_equal(d1, d2): """Test if two directives are equal""" if len(d1) != len(d2): return False @@ -88,7 +116,7 @@ def directives_equal(d1, d2): assert ValueError("Unsupported directive length") -def append_param_directive(thinning_recipe, param_name, directive): +def _append_param_directive(thinning_recipe, param_name, directive): """Add a parameter directive to a recipe. Parameter directives contain instructions for changing the physical shape of parameters. @@ -98,14 +126,14 @@ def append_param_directive(thinning_recipe, param_name, directive): # Duplicate parameter directives are rooted out because they can create erronous conditions. # For example, if the first directive changes the change of the parameter, a second # directive will cause an exception. 
- if directives_equal(d, directive): + if _directives_equal(d, directive): return msglogger.debug("\t[recipe] param_directive for {} = {}".format(param_name, directive)) param_directives.append(directive) thinning_recipe.parameters[param_name] = param_directives -def append_module_directive(thinning_recipe, module_name, key, val): +def _append_module_directive(thinning_recipe, module_name, key, val): """Add a module directive to a recipe. Parameter directives contain instructions for changing the attributes of @@ -117,7 +145,7 @@ def append_module_directive(thinning_recipe, module_name, key, val): thinning_recipe.modules[module_name] = mod_directive -def append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_features, thin_features): +def _append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_features, thin_features): """Adjust the sizes of the parameters of a BatchNormalization layer. This function is invoked after the Convolution layer preceeding a BN layer has @@ -142,13 +170,6 @@ def append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_feat thinning_recipe.parameters[bn_name+'.bias'] = [(0, thin_features)] -def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer): - sgraph = _create_graph(dataset, model) - thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict) - apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer) - return model - - def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer): if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0: # Now actually remove the filters, channels and make the weight tensors smaller @@ -165,13 +186,6 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer): msglogger.error("Failed to create a thinning recipe") -def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer): - sgraph = _create_graph(dataset, model) - thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict) - apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer) - return model - - def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): """Create a recipe for removing channels from Convolution layers. @@ -187,14 +201,14 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): # in the convolutional layer assert isinstance(layers[layer_name], (torch.nn.modules.Conv2d, torch.nn.modules.Linear)) - append_module_directive(thinning_recipe, layer_name, key='in_channels', val=nnz_channels) + _append_module_directive(thinning_recipe, layer_name, key='in_channels', val=nnz_channels) # Select only the non-zero channels indices = nonzero_channels.data.squeeze() dim = 1 if isinstance(layers[layer_name], torch.nn.modules.Conv2d) and layers[layer_name].groups == 1 else 0 if isinstance(layers[layer_name], torch.nn.modules.Linear): dim = 1 - append_param_directive(thinning_recipe, param_name, (dim, indices)) + _append_param_directive(thinning_recipe, param_name, (dim, indices)) # Find all instances of Convolution layers that immediately precede this layer predecessors = sgraph.predecessors_f(layer_name, ['Conv']) @@ -202,24 +216,24 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): msglogger.info("Could not find predecessors for name=%s" % layer_name) for predecessor in predecessors: # For each of the convolution layers that precede, we have to reduce the number of output channels. 
- append_module_directive(thinning_recipe, predecessor, key='out_channels', val=nnz_channels) + _append_module_directive(thinning_recipe, predecessor, key='out_channels', val=nnz_channels) if layers[predecessor].groups == 1: # Now remove filters from the weights tensor of the predecessor conv - append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices)) + _append_param_directive(thinning_recipe, predecessor + '.weight', (0, indices)) if layers[predecessor].bias is not None: # This convolution has bias coefficients - append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices)) + _append_param_directive(thinning_recipe, predecessor + '.bias', (0, indices)) elif layers[predecessor].groups == layers[predecessor].in_channels: # This is a group-wise convolution, and a special one at that (groups == in_channels). # Now remove filters from the weights tensor of the predecessor conv - append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices)) + _append_param_directive(thinning_recipe, predecessor + '.weight', (0, indices)) if layers[predecessor].bias is not None: # This convolution has bias coefficients - append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices)) - append_module_directive(thinning_recipe, predecessor, key='groups', val=nnz_channels) + _append_param_directive(thinning_recipe, predecessor + '.bias', (0, indices)) + _append_module_directive(thinning_recipe, predecessor, key='groups', val=nnz_channels) # In the special case of a Convolutional layer with (groups == in_channels), if we # change in_channels, we also need to change out_channels, which means that we @@ -234,8 +248,8 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): for bn_layer in bn_layers: # Thinning of the BN layer that follows the convolution msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer)) - append_bn_thinning_directive(thinning_recipe, layers, bn_layer, - len_thin_features=nnz_channels, thin_features=indices) + _append_bn_thinning_directive(thinning_recipe, layers, bn_layer, + len_thin_features=nnz_channels, thin_features=indices) msglogger.debug("Invoking create_thinning_recipe_channels") thinning_recipe = ThinningRecipe(modules={}, parameters={}) @@ -279,15 +293,15 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): # We are removing filters, so update the number of outgoing channels (OFMs) # in the convolutional layer assert isinstance(layers[layer_name], torch.nn.modules.Conv2d) - append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters) + _append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters) # Select only the non-zero filters indices = nonzero_filters.data.squeeze() - append_param_directive(thinning_recipe, param_name, (0, indices)) + _append_param_directive(thinning_recipe, param_name, (0, indices)) if layers[layer_name].bias is not None: # This convolution has bias coefficients - append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices)) + _append_param_directive(thinning_recipe, layer_name + '.bias', (0, indices)) # Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer successors = sgraph.successors_f(layer_name, ['Conv', 'Gemm']) @@ -295,36 +309,22 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): if isinstance(layers[successor], torch.nn.modules.Conv2d): handle_conv_successor(thinning_recipe, layers, successor, 
num_nnz_filters, indices) elif isinstance(layers[successor], torch.nn.modules.Linear): - # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member - fm_size = layers[successor].in_features // layers[layer_name].out_channels - in_features = fm_size * num_nnz_filters - append_module_directive(thinning_recipe, successor, key='in_features', val=in_features) - msglogger.debug("[recipe] Linear {}: fm_size = {} layers[{}].out_channels={}".format( - successor, in_features, layer_name, layers[layer_name].out_channels)) - msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features)) - - # Now remove channels from the weights tensor of the successor FC layer: - # This is a bit tricky: - fm_height = fm_width = int(math.sqrt(fm_size)) - view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width) - view_2D = (layers[successor].out_features, in_features) - append_param_directive(thinning_recipe, successor+'.weight', - (1, indices, view_4D, view_2D)) + handle_linear_successor(successor, indices) # Now handle the BatchNormalization layer that follows the convolution handle_bn_layers(layers, layer_name, num_nnz_filters, indices) def handle_conv_successor(thinning_recipe, layers, successor, num_nnz_filters, indices): # For each of the convolutional layers that follow, we have to reduce the number of input channels. - append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters) + _append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters) if layers[successor].groups == 1: # Now remove channels from the weights tensor of the successor conv - append_param_directive(thinning_recipe, successor+'.weight', (1, indices)) + _append_param_directive(thinning_recipe, successor + '.weight', (1, indices)) elif layers[successor].groups == layers[successor].in_channels: # Special case: number of groups is equal to the number of input channels - append_param_directive(thinning_recipe, successor+'.weight', (0, indices)) - append_module_directive(thinning_recipe, successor, key='groups', val=num_nnz_filters) + _append_param_directive(thinning_recipe, successor + '.weight', (0, indices)) + _append_module_directive(thinning_recipe, successor, key='groups', val=num_nnz_filters) # In the special case of a Convolutional layer with (groups == in_channels), if we # change in_channels, we also need to change out_channels, which means that we @@ -334,13 +334,30 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): else: raise ValueError("Distiller thinning code currently does not handle this conv.groups configuration") + def handle_linear_successor(successor, indices): + # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member + fm_size = layers[successor].in_features // layers[layer_name].out_channels + in_features = fm_size * num_nnz_filters + _append_module_directive(thinning_recipe, successor, key='in_features', val=in_features) + msglogger.debug("[recipe] Linear {}: fm_size = {} layers[{}].out_channels={}".format( + successor, in_features, layer_name, layers[layer_name].out_channels)) + msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features)) + + # Now remove channels from the weights tensor of the successor FC layer: + # This is a bit tricky: + fm_height = fm_width = int(math.sqrt(fm_size)) + view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width) + 
view_2D = (layers[successor].out_features, in_features) + _append_param_directive(thinning_recipe, successor + '.weight', + (1, indices, view_4D, view_2D)) + def handle_bn_layers(layers, layer_name, num_nnz_filters, indices): bn_layers = sgraph.successors_f(layer_name, ['BatchNormalization']) if bn_layers: assert len(bn_layers) == 1 # Thinning of the BN layer that follows the convolution - append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0], - len_thin_features=num_nnz_filters, thin_features=indices) + _append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0], + len_thin_features=num_nnz_filters, thin_features=indices) msglogger.debug("Invoking create_thinning_recipe_filters") @@ -423,7 +440,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list): msglogger.debug("Executed %d recipes" % len(recipe_list)) -def optimizer_thinning(optimizer, param, dim, indices, new_shape=None): +def _optimizer_thinning(optimizer, param, dim, indices, new_shape=None): """Adjust the size of the SGD velocity-tracking tensors. The SGD momentum update (velocity) is dependent on the weights, and because during thinning we @@ -504,7 +521,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr if grad_selection_view.size(dim) != len_indices: param.grad = torch.index_select(grad_selection_view, dim, indices) # update optimizer - if optimizer_thinning(optimizer, param, dim, indices, directive[3]): + if _optimizer_thinning(optimizer, param, dim, indices, directive[3]): msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})". format(param_name, dim, len_indices, directive[3])) @@ -524,7 +541,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr if param.grad is not None and param.grad.size(dim) != len_indices: param.grad = torch.index_select(param.grad, dim, indices.to(param.device)) # update optimizer - if optimizer_thinning(optimizer, param, dim, indices): + if _optimizer_thinning(optimizer, param, dim, indices): msglogger.debug("Updated velocity buffer %s" % param_name) if not loaded_from_file and zeros_mask_dict: diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 500eeb57b613690974b55239d0c6a3376746dc6a..b7c45032f464529c65efb76eb9cd3377f4f0d7f3 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -120,7 +120,7 @@ def handle_subapps(model, criterion, optimizer, compression_scheduler, pylogger, #zeros_mask_dict = distiller.create_model_masks_dict(model) assert args.resumed_checkpoint_path is not None, \ "You must use --resume-from to provide a checkpoint file to thinnify" - distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None) + distiller.contract_model(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None) apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler, name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")), dir=msglogger.logdir) diff --git a/examples/lottery_ticket/README.md b/examples/lottery_ticket/README.md index b9497391896b6570a283f424591b94cddc7bf5a1..8842a7949da2b18f12401a50d595d580c9feffa1 100644 --- a/examples/lottery_ticket/README.md +++ b/examples/lottery_ticket/README.md @@ -9,11 +9,11 @@ most the same number of iterations." their smallest-magnitude weights. 
The set of connections that survives this process is the architecture of a winning ticket. Unique to our work, the winning ticket’s weights are the values to which these connections were initialized before training. This forms our central experiment: ->1. Randomly initialize a neural network f(x; theta_0) (where theta_0 ~ D_0). ->2. Train the network for j iterations, reaching parameters theta_j . +>1. Randomly initialize a neural network f(x; theta_0) (where theta_0 ~ D_0). +>2. Train the network for j iterations, reaching parameters theta_j. >3. Prune s% of the parameters, creating a mask m where Pm = (100 - s)%. ->4. To extract the winning ticket, reset the remaining parameters to their values intheta_0, creating -the untrained network f(x;m *theta_0). +>4. To extract the winning ticket, reset the remaining parameters to their values in theta_0, creating +the untrained network f(x; m * theta_0). ### Example Train a ResNet20-CIFAR10 network from scratch, and save the untrained, randomized initial network weights in a checkpoint file. @@ -23,7 +23,7 @@ python3 compress_classifier.py --arch resnet20_cifar ${CIFAR10_PATH} -p=50 --ep ``` After training the network, we have two outputs: the best trained network (`resnet20_best.pth.tar`) and the initial untrained network (`resnet20_untrained_checkpoint.pth.tar`).<br> -In this example, we copy them into the `examples/lottery_ticket` directory for convenience. +In this example, we copy both checkpoints into the `examples/lottery_ticket` directory for convenience. ```bash cp logs/resnet20___2019.08.22-220243/resnet20_best.pth.tar ../lottery_ticket/
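The main user-visible change in the diff above is the new `distiller.contract_model()` entry point, which the `--thinnify` flow in `compress_classifier.py` now calls in place of `remove_filters()`. As a rough illustration of driving it from a standalone script, here is a minimal sketch; `create_model()`, the 4-tuple returned by `apputils.load_checkpoint()`, and the `thinnify_checkpoint()` helper are assumptions based on typical Distiller usage, not part of this patch.

```python
# Sketch only -- assumes a Distiller installation and a checkpoint produced by a
# pruning run. create_model() and the load_checkpoint() return tuple are assumed;
# contract_model() and save_checkpoint() are used as shown in the diff above.
import distiller
import distiller.apputils as apputils
from distiller.models import create_model


def thinnify_checkpoint(chkpt_file, arch="resnet20_cifar", dataset="cifar10"):
    model = create_model(False, dataset, arch)  # network skeleton with the original shapes
    # Load the pruned weights; the scheduler carries the zeros_mask_dict used below.
    model, scheduler, optimizer, start_epoch = apputils.load_checkpoint(model, chkpt_file)

    # contract_model() removes all-zero filters and then all-zero channels,
    # physically shrinking the weight tensors and the module attributes
    # (out_channels, in_channels, groups, BN features, ...).
    distiller.contract_model(model, scheduler.zeros_mask_dict, arch, dataset, optimizer=None)

    # Save the thinned network together with the scheduler (and hence the thinning
    # recipe) so it can be reloaded with its smaller shapes.
    apputils.save_checkpoint(0, arch, model, optimizer=None, scheduler=scheduler,
                             name="thinned", dir=".")
    return model
```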
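The renamed recipe-building helpers (`_append_module_directive`, `_append_param_directive`, `_append_bn_thinning_directive`) all populate a `ThinningRecipe`, which pairs module-attribute changes with `(dim, indices)` parameter directives. Roughly, a recipe for a single pruned convolution might look like the sketch below; the layer names are hypothetical and the directive layout is inferred from the calls in the diff, so treat it as an illustration rather than exact output.

```python
# Illustration only: an approximation of what create_thinning_recipe_filters()
# might emit when filters 0, 2, 3 and 5 of 'layer1.0.conv1' survive pruning.
# Module names are made up; the (dim, indices) tuples and {attribute: value}
# module directives follow the _append_*_directive calls in the diff.
import torch
from distiller.thinning import ThinningRecipe

kept = torch.tensor([0, 2, 3, 5])                # indices of the non-zero filters
recipe = ThinningRecipe(
    modules={
        'layer1.0.conv1': {'out_channels': 4},   # the pruned conv loses filters
        'layer1.0.conv2': {'in_channels': 4},    # its successor loses input channels
    },
    parameters={
        'layer1.0.conv1.weight': [(0, kept)],    # select surviving filters along dim 0
        'layer1.0.conv2.weight': [(1, kept)],    # select surviving channels along dim 1
        'layer1.0.bn1.weight':   [(0, kept)],    # the BN that follows is thinned too
        'layer1.0.bn1.bias':     [(0, kept)],
    },
)
```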