From ba05f6cf48d4f7524f1632364d6addcb33c9d415 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 13 Feb 2019 01:01:19 +0200
Subject: [PATCH] CPU support: fix the case of loading a thinned GPU-model on
 the CPU

This commit fixes (and adds a test for) the case that we wish to load
a thinned GPU checkpoint onto the CPU.
---
 apputils/checkpoint.py |  4 ++--
 distiller/thinning.py  |  8 +++++++-
 tests/test_infra.py    | 37 +++++++++++++++++++++++++++++++++++--
 3 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 26f7fff..91005eb 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
 
     msglogger.info("=> loading checkpoint %s", chkpt_file)
-    checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
+    checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
     msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
 
     if 'state_dict' not in checkpoint:
@@ -121,7 +121,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         # Cache the recipes in case we need them later
         model.thinning_recipes = checkpoint['thinning_recipes']
         if normalize_dataparallel_keys:
-            model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}         
+            model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
         distiller.execute_thinning_recipes_list(model,
                                                 compression_scheduler.zeros_mask_dict,
                                                 model.thinning_recipes)
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 4346687..43ab135 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -61,7 +61,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'ChannelRemover', 'remove_channels',
            'FilterRemover',  'remove_filters',
            'find_nonzero_channels', 'find_nonzero_channels_list',
-           'execute_thinning_recipes_list']
+           'execute_thinning_recipes_list', 'get_normalized_recipe']
 
 
 def create_graph(dataset, arch):
@@ -77,6 +77,12 @@ def create_graph(dataset, arch):
     return SummaryGraph(model, dummy_input)
 
 
+def get_normalized_recipe(recipe):
+    new_recipe = ThinningRecipe(modules={normalize_module_name(k): v for k, v in recipe.modules.items()},
+                                parameters={normalize_module_name(k): v for k, v in recipe.parameters.items()})
+    return new_recipe
+
+
 def param_name_2_layer_name(param_name):
     return param_name[:-len('weights')]
 
diff --git a/tests/test_infra.py b/tests/test_infra.py
index a101bdb..7a15be1 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -26,7 +26,7 @@ if module_path not in sys.path:
     sys.path.append(module_path)
 
 import distiller
-from apputils import load_checkpoint
+from apputils import save_checkpoint, load_checkpoint
 from models import create_model
 
 
@@ -39,6 +39,7 @@ def test_load():
     assert compression_scheduler is not None
     assert start_epoch == 180
 
+
 def test_load_state_dict():
     # prepare lean checkpoint
     state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
@@ -52,6 +53,7 @@ def test_load_state_dict():
     assert compression_scheduler is None
     assert start_epoch == 0
 
+
 def test_load_dumb_checkpoint():
     # prepare lean checkpoint
     state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
@@ -62,6 +64,7 @@ def test_load_dumb_checkpoint():
         with pytest.raises(ValueError):
             model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
 
+
 def test_load_negative():
     with pytest.raises(FileNotFoundError):
         model = create_model(False, 'cifar10', 'resnet20_cifar')
@@ -69,12 +72,42 @@ def test_load_negative():
 
 
 def test_load_gpu_model_on_cpu():
-    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
+    # Issue #148
+    CPU_DEVICE_ID = -1
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
     model, compression_scheduler, start_epoch = load_checkpoint(model,
                                                                 '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
     assert compression_scheduler is not None
     assert start_epoch == 180
     assert distiller.model_device(model) == 'cpu'
 
+
+def test_load_gpu_model_on_cpu_with_thinning():
+    # Issue #148
+    # 1. create a GPU model and remove 50% of the filters in one of the layers (thinning)
+    # 2. save the thinned model in a checkpoint file
+    # 3. load the checkpoint and place it on the CPU
+    CPU_DEVICE_ID = -1
+    gpu_model = create_model(False, 'cifar10', 'resnet20_cifar')
+    conv_pname = "module.layer1.0.conv1.weight"
+    conv_p = distiller.model_find_param(gpu_model, conv_pname)
+    pruner = distiller.pruning.L1RankedStructureParameterPruner("test_pruner", group_type="Filters",
+                                                                desired_sparsity=0.5, weights=conv_pname)
+    zeros_mask_dict = distiller.create_model_masks_dict(gpu_model)
+    pruner.set_param_mask(conv_p, conv_pname, zeros_mask_dict, meta=None)
+
+    # Use the mask to prune
+    zeros_mask_dict[conv_pname].apply_mask(conv_p)
+    distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
+    assert hasattr(gpu_model, 'thinning_recipes')
+    scheduler = distiller.CompressionScheduler(gpu_model)
+    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None)
+
+    CPU_DEVICE_ID = -1
+    cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
+    load_checkpoint(cpu_model, "checkpoint.pth.tar")
+    assert distiller.model_device(cpu_model) == 'cpu'
+
+
 if __name__ == '__main__':
     test_load_gpu_model_on_cpu()
-- 
GitLab