diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index a7509f57fc8f25686d333144f321d0f557ada584..83ae242b4c0d8c9281e3d87226a0d005ab79341a 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -50,7 +50,8 @@ def save_checkpoint(epoch, arch, model, optimizer, scheduler=None, best_top1=Non
     checkpoint['state_dict'] = model.state_dict()
     if best_top1 is not None:
         checkpoint['best_top1'] = best_top1
-    checkpoint['optimizer'] = optimizer.state_dict()
+    if optimizer is not None:
+        checkpoint['optimizer'] = optimizer.state_dict()
     if scheduler is not None:
         checkpoint['compression_sched'] = scheduler.state_dict()
     if hasattr(model, 'thinning_recipes'):
@@ -86,18 +87,21 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
             msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                            checkpoint['epoch'])

-            if 'thinning_recipes' in checkpoint:
-                msglogger.info("Loaded a thinning recipe from the checkpoint")
-                distiller.execute_thinning_recipes_list(model,
-                                                        compression_scheduler.zeros_mask_dict,
-                                                        checkpoint['thinning_recipes'])
+        if 'thinning_recipes' in checkpoint:
+            if 'compression_sched' not in checkpoint:
+                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+            msglogger.info("Loaded a thinning recipe from the checkpoint")
+            # Cache the recipes in case we need them later
+            model.thinning_recipes = checkpoint['thinning_recipes']
+            distiller.execute_thinning_recipes_list(model,
+                                                    compression_scheduler.zeros_mask_dict,
+                                                    model.thinning_recipes)
         else:
             msglogger.info("Warning: compression schedule data does not exist in the checkpoint")

         msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
         model.load_state_dict(checkpoint['state_dict'])
-
         return model, compression_scheduler, start_epoch
     else:
         msglogger.info("Error: no checkpoint found at %s", chkpt_file)
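
Taken together, these two hunks change the checkpointing contract: save_checkpoint() now tolerates optimizer=None by simply omitting the 'optimizer' key, and load_checkpoint() refuses a checkpoint that carries thinning recipes without a compression schedule. A minimal usage sketch of the resulting flow, assuming a Distiller checkout on PYTHONPATH (the calls and the default 'checkpoint.pth.tar' filename mirror the new test below, not any other public API):

    import distiller
    from models import create_model
    from apputils import save_checkpoint, load_checkpoint

    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    scheduler = distiller.CompressionScheduler(model)

    # optimizer=None is now accepted: the 'optimizer' key is simply omitted.
    # Saving a thinned model *without* a scheduler would store thinning_recipes
    # but no compression_sched, and the next load_checkpoint() would raise KeyError.
    save_checkpoint(epoch=0, arch='resnet20_cifar', model=model,
                    optimizer=None, scheduler=scheduler)
    model, scheduler, start_epoch = load_checkpoint(model, 'checkpoint.pth.tar')
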
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 4a2cfdb18a22661463d3ad89ade5794e6b46f44f..caa3192528b6a6518beafbc5cc349b2a83e734cb 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -66,6 +66,7 @@ These tuples can have 2 values, or 4 values.
 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'ChannelRemover', 'remove_channels',
            'FilterRemover', 'remove_filters',
+           'find_nonzero_channels',
           'execute_thinning_recipes_list']

 def create_graph(dataset, arch):
@@ -266,7 +267,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         predecessors = [denormalize_layer_name(model, predecessor) for predecessor in predecessors]
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(thinning_recipe, layer_name, key='out_channels', val=len(nonzero_channels))
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels))

             # Now remove channels from the weights tensor of the successor conv
             append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
@@ -415,7 +416,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
    # Invoke this function when you want to use a list of thinning recipes to convert a programmed model
    # to a thinned model. For example, this is invoked when loading a model from a checkpoint.
    for i, recipe in enumerate(recipe_list):
-        msglogger.info("recipe %d" % i)
+        msglogger.info("recipe %d:" % i)
        execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)

def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
@@ -431,10 +432,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
    for layer_name, directives in recipe.modules.items():
        for attr, val in directives.items():
            if attr in ['running_mean', 'running_var']:
-                msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(val[1])))
-                setattr(layers[layer_name], attr,
-                        torch.index_select(getattr(layers[layer_name], attr),
-                                           dim=val[0], index=val[1]))
+                running = getattr(layers[layer_name], attr)
+                dim_to_trim = val[0]
+                indices_to_select = val[1]
+                # Check if we're trying to trim a parameter that is already "thin"
+                if running.size(dim_to_trim) != len(indices_to_select):
+                    msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(indices_to_select)))
+                    setattr(layers[layer_name], attr,
+                            torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
            else:
                msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
                setattr(layers[layer_name], attr, val)
@@ -448,28 +453,32 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=Fal
            indices = directive[1]
            if len(directive) == 4:  # TODO: this code is hard to follow
                selection_view = param.view(*directive[2])
-                param.data = torch.index_select(selection_view, dim, indices)
+                # Check if we're trying to trim a parameter that is already "thin"
+                if param.data.size(dim) != len(indices):
+                    param.data = torch.index_select(selection_view, dim, indices)

                if param.grad is not None:
                    # We also need to change the dimensions of the gradient tensor.
                    grad_selection_view = param.grad.resize_(*directive[2])
-                    param.grad = torch.index_select(grad_selection_view, dim, indices)
+                    if grad_selection_view.size(dim) != len(indices):
+                        param.grad = torch.index_select(grad_selection_view, dim, indices)

                param.data = param.view(*directive[3])
                if param.grad is not None:
                    param.grad = param.grad.resize_(*directive[3])
            else:
-                param.data = torch.index_select(param.data, dim, indices)
+                if param.data.size(dim) != len(indices):
+                    param.data = torch.index_select(param.data, dim, indices)

                # We also need to change the dimensions of the gradient tensor.
                # If have not done a backward-pass thus far, then the gradient will
                # not exist, and therefore won't need to be re-dimensioned.
-                if param.grad is not None:
-                    param.grad = torch.index_select(param.grad, dim, indices)
+                if param.grad is not None and param.grad.size(dim) != len(indices):
+                    param.grad = torch.index_select(param.grad, dim, indices)
                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len(indices)))

            if not loaded_from_file:
                # If the masks are loaded from a checkpoint file, then we don't need to change
                # their shape, because they are already correctly shaped
                mask = zeros_mask_dict[param_name].mask
-                if mask is not None:  # and (mask.size(dim) != len(indices)):
+                if mask is not None and (mask.size(dim) != len(indices)):
                    zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
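
The recurring guard in this file — skip torch.index_select() when the tensor is already the target size along the trimmed dimension — is what makes recipe replay idempotent, so a recipe list can be executed against a model whose tensors were already thinned before the checkpoint was saved. A standalone sketch of the pattern (trim() is a hypothetical helper, not part of the patch):

    import torch

    def trim(tensor, dim, indices):
        # Skip tensors that are already "thin" along this dimension, e.g. when
        # replaying a thinning recipe on a model that was saved after thinning.
        if tensor.size(dim) == len(indices):
            return tensor
        return torch.index_select(tensor, dim, indices)

    weights = torch.randn(16, 3, 3, 3)                     # 16 filters of a conv layer
    keep = torch.tensor([i for i in range(16) if i != 5])  # drop filter 5
    weights = trim(weights, 0, keep)                       # 16 -> 15 filters
    weights = trim(weights, 0, keep)                       # no-op: already 15
    assert weights.size(0) == 15
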
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 93f3c77a87e273db4971f3e827eb367b32358cd4..fa01f22f7076592b686f9c6731e99ef6ea567b10 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -25,7 +25,7 @@
 import distiller
 import pytest
 from models import ALL_MODEL_NAMES, create_model
-from apputils import SummaryGraph, onnx_name_2_pytorch_name
+from apputils import SummaryGraph, onnx_name_2_pytorch_name, save_checkpoint, load_checkpoint

 # Logging configuration
 logging.basicConfig(level=logging.INFO)
@@ -39,8 +39,9 @@ def find_module_by_name(model, module_to_find):
             return m
     return None

-def test_ranked_filter_pruning():
-    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
+
+def setup_test(arch, dataset):
+    model = create_model(False, dataset, arch, parallel=False)
     assert model is not None

     # Create the masks
@@ -48,6 +49,10 @@
     for name, param in model.named_parameters():
         masker = distiller.ParameterMasker(name)
         zeros_mask_dict[name] = masker
+    return model, zeros_mask_dict
+
+def test_ranked_filter_pruning():
+    model, zeros_mask_dict = setup_test("resnet20_cifar", "cifar10")

     # Test that we can access the weights tensor of the first convolution in layer 1
     conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
     assert conv1_p is not None
@@ -68,18 +73,110 @@
     expected_pruning = int(0.1 * conv1.out_channels) / conv1.out_channels
     assert distiller.sparsity_3D(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning

-    # Test masker
+    # Use the mask to prune
     assert distiller.sparsity_3D(conv1_p) == 0
     zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p)
     assert distiller.sparsity_3D(conv1_p) == expected_pruning

     # Remove filters
-    assert conv1.out_channels == 16
     conv2 = find_module_by_name(model, "layer1.0.conv2")
     assert conv2 is not None
-    assert conv1.in_channels == 16
+    assert conv1.out_channels == 16
+    assert conv2.in_channels == 16

     # Test thinning
     distiller.remove_filters(model, zeros_mask_dict, "resnet20_cifar", "cifar10")
     assert conv1.out_channels == 15
     assert conv2.in_channels == 15
+
+
+def test_arbitrary_channel_pruning():
+    ARCH = "resnet20_cifar"
+    DATASET = "cifar10"
+
+    model, zeros_mask_dict = setup_test(ARCH, DATASET)
+
+    conv2 = find_module_by_name(model, "layer1.0.conv2")
+    assert conv2 is not None
+
+    # Test that we can access the weights tensor of the first convolution in layer 1
+    conv2_p = distiller.model_find_param(model, "layer1.0.conv2.weight")
+    assert conv2_p is not None
+
+    assert conv2_p.dim() == 4
+    num_filters = conv2_p.size(0)
+    num_channels = conv2_p.size(1)
+    kernel_height = conv2_p.size(2)
+    kernel_width = conv2_p.size(3)
+
+    channels_to_remove = [0, 2]
+
+    # Let's build our 4D mask.
+    # We start with a 1D mask of channels, with all but our specified channels set to one
+    channels = torch.ones(num_channels)
+    for ch in channels_to_remove:
+        channels[ch] = 0
+
+    # Now let's expand back up to a 4D mask
+    mask = channels.expand(num_filters, num_channels)
+    mask.unsqueeze_(-1)
+    mask.unsqueeze_(-1)
+    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()
+
+    assert mask.shape == conv2_p.shape
+    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
+
+    # Cool, so now we have a mask for pruning our channels.
+    # Use the mask to prune
+    zeros_mask_dict["layer1.0.conv2.weight"].mask = mask
+    zeros_mask_dict["layer1.0.conv2.weight"].apply_mask(conv2_p)
+    all_channels = set([ch for ch in range(num_channels)])
+    channels_removed = all_channels - set(distiller.find_nonzero_channels(conv2_p, "layer1.0.conv2.weight"))
+    logger.info(channels_removed)
+
+    # Now, let's do the actual network thinning
+    distiller.remove_channels(model, zeros_mask_dict, ARCH, DATASET)
+    conv1 = find_module_by_name(model, "layer1.0.conv1")
+    logger.info(conv1)
+    logger.info(conv2)
+    assert conv1.out_channels == 14
+    assert conv2.in_channels == 14
+    assert conv1.weight.size(0) == 14
+    assert conv2.weight.size(1) == 14
+    bn1 = find_module_by_name(model, "layer1.0.bn1")
+    assert bn1.running_var.size(0) == 14
+    assert bn1.running_mean.size(0) == 14
+    assert bn1.num_features == 14
+    assert bn1.bias.size(0) == 14
+    assert bn1.weight.size(0) == 14
+
+    # Let's test saving and loading a thinned model.
+    # We save 3 times, and load twice, to make sure to cover some corner cases:
+    #  - Make sure that after loading, the model still has hold of the thinning recipes
+    #  - Make sure that after a 2nd load, there is no problem loading (in this case, the
+    #    tensors are already thin, so this is a new flow)
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None)
+    model_2 = create_model(False, DATASET, ARCH, parallel=False)
+    dummy_input = torch.randn(1, 3, 32, 32)
+    model(dummy_input)
+    model_2(dummy_input)
+    conv2 = find_module_by_name(model_2, "layer1.0.conv2")
+    assert conv2 is not None
+    with pytest.raises(KeyError):
+        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+
+    compression_scheduler = distiller.CompressionScheduler(model)
+    assert hasattr(model, 'thinning_recipes')
+    save_checkpoint(epoch=0, arch=ARCH, model=model, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    assert hasattr(model_2, 'thinning_recipes')
+    logger.info("test_arbitrary_channel_pruning - Done")
+
+    save_checkpoint(epoch=0, arch=ARCH, model=model_2, optimizer=None, scheduler=compression_scheduler)
+    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    logger.info("test_arbitrary_channel_pruning - Done 2")
+
+
+if __name__ == '__main__':
+    test_arbitrary_channel_pruning()
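
The 1D-to-4D mask expansion in test_arbitrary_channel_pruning can be hard to picture. Here is a toy-dimension sketch of the same construction (all names are local to this example): a per-channel vector is broadcast across the filter dimension, then across the kernel's spatial dimensions, yielding a [filters, channels, k, k] mask whose zeroed channels survive multiplication as all-zero slices — exactly what find_nonzero_channels and remove_channels later detect and strip.

    import torch

    num_filters, num_channels, k = 2, 4, 3
    channels = torch.ones(num_channels)
    channels[0] = channels[2] = 0                  # 1D channel mask: [0., 1., 0., 1.]

    mask = channels.expand(num_filters, num_channels)   # broadcast over filters -> [F, C]
    mask = mask.unsqueeze(-1).unsqueeze(-1)             # [F, C, 1, 1]
    mask = mask.expand(num_filters, num_channels, k, k).contiguous()

    weights = torch.randn(num_filters, num_channels, k, k)
    pruned = weights * mask
    assert pruned[:, 0].abs().sum().item() == 0    # channel 0 is all-zero in every filter
    assert pruned[:, 2].abs().sum().item() == 0    # and so is channel 2

This is why the test then expects conv1.out_channels and conv2.in_channels to drop from 16 to 14: two of conv2's input channels were masked, so thinning removes them and the corresponding filters of the preceding convolution.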