From ba05f6cf48d4f7524f1632364d6addcb33c9d415 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 13 Feb 2019 01:01:19 +0200 Subject: [PATCH] CPU support: fix the case of loading a thinned GPU-model on the CPU This commit fixes (and adds a test) for the case that we wish to load a thinned GPU checkpoint onto the CPU. --- apputils/checkpoint.py | 4 ++-- distiller/thinning.py | 8 +++++++- tests/test_infra.py | 37 +++++++++++++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index 26f7fff..91005eb 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file) msglogger.info("=> loading checkpoint %s", chkpt_file) - checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage) + checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage) msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint))) if 'state_dict' not in checkpoint: @@ -121,7 +121,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): # Cache the recipes in case we need them later model.thinning_recipes = checkpoint['thinning_recipes'] if normalize_dataparallel_keys: - model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()} + model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes] distiller.execute_thinning_recipes_list(model, compression_scheduler.zeros_mask_dict, model.thinning_recipes) diff --git a/distiller/thinning.py b/distiller/thinning.py index 4346687..43ab135 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -61,7 +61,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', 'ChannelRemover', 'remove_channels', 'FilterRemover', 'remove_filters', 'find_nonzero_channels', 
'find_nonzero_channels_list', - 'execute_thinning_recipes_list'] + 'execute_thinning_recipes_list', 'get_normalized_recipe'] def create_graph(dataset, arch): @@ -77,6 +77,12 @@ def create_graph(dataset, arch): return SummaryGraph(model, dummy_input) +def get_normalized_recipe(recipe): + new_recipe = ThinningRecipe(modules={normalize_module_name(k): v for k, v in recipe.modules.items()}, + parameters={normalize_module_name(k): v for k, v in recipe.parameters.items()}) + return new_recipe + + def param_name_2_layer_name(param_name): return param_name[:-len('weights')] diff --git a/tests/test_infra.py b/tests/test_infra.py index a101bdb..7a15be1 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -26,7 +26,7 @@ if module_path not in sys.path: sys.path.append(module_path) import distiller -from apputils import load_checkpoint +from apputils import save_checkpoint, load_checkpoint from models import create_model @@ -39,6 +39,7 @@ def test_load(): assert compression_scheduler is not None assert start_epoch == 180 + def test_load_state_dict(): # prepare lean checkpoint state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') @@ -52,6 +53,7 @@ def test_load_state_dict(): assert compression_scheduler is None assert start_epoch == 0 + def test_load_dumb_checkpoint(): # prepare lean checkpoint state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict') @@ -62,6 +64,7 @@ def test_load_dumb_checkpoint(): with pytest.raises(ValueError): model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name) + def test_load_negative(): with pytest.raises(FileNotFoundError): model = create_model(False, 'cifar10', 'resnet20_cifar') @@ -69,12 +72,42 @@ def test_load_negative(): def test_load_gpu_model_on_cpu(): - model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1) + # Issue #148 + CPU_DEVICE_ID = -1 + model = create_model(False, 
'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID) model, compression_scheduler, start_epoch = load_checkpoint(model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar') assert compression_scheduler is not None assert start_epoch == 180 assert distiller.model_device(model) == 'cpu' + +def test_load_gpu_model_on_cpu_with_thinning(): + # Issue #148 + # 1. create a GPU model and remove 50% of the filters in one of the layers (thinning) + # 2. save the thinned model in a checkpoint file + # 3. load the checkpoint and place it on the CPU + CPU_DEVICE_ID = -1 + gpu_model = create_model(False, 'cifar10', 'resnet20_cifar') + conv_pname = "module.layer1.0.conv1.weight" + conv_p = distiller.model_find_param(gpu_model, conv_pname) + pruner = distiller.pruning.L1RankedStructureParameterPruner("test_pruner", group_type="Filters", + desired_sparsity=0.5, weights=conv_pname) + zeros_mask_dict = distiller.create_model_masks_dict(gpu_model) + pruner.set_param_mask(conv_p, conv_pname, zeros_mask_dict, meta=None) + + # Use the mask to prune + zeros_mask_dict[conv_pname].apply_mask(conv_p) + distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None) + assert hasattr(gpu_model, 'thinning_recipes') + scheduler = distiller.CompressionScheduler(gpu_model) + save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None) + + CPU_DEVICE_ID = -1 + cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID) + load_checkpoint(cpu_model, "checkpoint.pth.tar") + assert distiller.model_device(cpu_model) == 'cpu' + + if __name__ == '__main__': test_load_gpu_model_on_cpu() -- GitLab