From 1210f412be838436e5460f374be9acc6acb7b749 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Tue, 12 Feb 2019 16:44:04 +0200
Subject: [PATCH] Fix issue #148 + refactor load_checkpoint.py (#153)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The root cause of issue #148 is that DataParallel modules cannot execute on
the CPU (on machines that have both CPUs and GPUs). Therefore, we don't wrap
models loaded for the CPU in DataParallel, but we do wrap models loaded on
the GPU with DataParallel (to make them run faster).
The names of the module keys saved in a checkpoint file depend on whether
the modules are wrapped by a DataParallel module or not, so loading a
checkpoint that was created on the GPU into a CPU model (and vice versa)
fails on the mismatched keys.
This is PyTorch behavior, and despite the community asking for a fix -
e.g. https://github.com/pytorch/pytorch/issues/7457 - one is still pending.
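
For example, a checkpoint saved from a DataParallel-wrapped model stores its
parameters under keys carrying a "module." prefix, while the same model
created for the CPU expects the bare key names (the keys below are
illustrative):

    GPU (DataParallel) checkpoint key     CPU model key
    module.conv1.weight                   conv1.weight
    module.layer1.0.conv1.weight          layer1.0.conv1.weight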

This commit adds code to catch the key errors raised when loading a
GPU-generated checkpoint (i.e. one wrapped with DataParallel) onto a CPU
model, and to convert the names of the keys accordingly.

This PR also merges a refactoring of load_checkpoint.py done by @barrh, who
also added a test that further exercises checkpoint loading.
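
As a rough sketch of the fix (relying on normalize_module_name(), imported
from distiller.utils, to strip the "module." prefix that DataParallel
prepends), renaming the keys of a loaded state-dict amounts to:

    # Minimal sketch of the key renaming performed in load_checkpoint()
    state_dict = {normalize_module_name(k): v for k, v in state_dict.items()}

The renaming is attempted only after a plain load raises a KeyError, so
checkpoints whose keys already match the model are loaded unchanged.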
---
 apputils/checkpoint.py | 96 ++++++++++++++++++++++++------------------
 distiller/scheduler.py | 30 ++++++++-----
 tests/test_infra.py    | 42 +++++++++++++++++-
 3 files changed, 116 insertions(+), 52 deletions(-)

diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 3792152..26f7fff 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -26,6 +26,7 @@ from errno import ENOENT
 import logging
 import torch
 import distiller
+from distiller.utils import normalize_module_name
 msglogger = logging.getLogger()
 
 
@@ -80,45 +81,60 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         chkpt_file: the checkpoint file
         optimizer: the optimizer to which we will load the serialized state
     """
+    if not os.path.isfile(chkpt_file):
+        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+
+    msglogger.info("=> loading checkpoint %s", chkpt_file)
+    checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
+    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
+
+    if 'state_dict' not in checkpoint:
+        raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
+
+    checkpoint_epoch = checkpoint.get('epoch', None)
+    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
+
+    best_top1 = checkpoint.get('best_top1', None)
+    if best_top1 is not None:
+        msglogger.info("   best top@1: %.3f", best_top1)
+
     compression_scheduler = None
-    start_epoch = 0
-
-    if os.path.isfile(chkpt_file):
-        msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
-        msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-        start_epoch = checkpoint.get('epoch', -1) + 1
-        best_top1 = checkpoint.get('best_top1', None)
-        if best_top1 is not None:
-            msglogger.info("   best top@1: %.3f", best_top1)
-
-        if 'compression_sched' in checkpoint:
-            compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
-            msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
-                           checkpoint['epoch'])
-        else:
-            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
-
-        if 'thinning_recipes' in checkpoint:
-            if 'compression_sched' not in checkpoint:
-                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
-            msglogger.info("Loaded a thinning recipe from the checkpoint")
-            # Cache the recipes in case we need them later
-            model.thinning_recipes = checkpoint['thinning_recipes']
-            distiller.execute_thinning_recipes_list(model,
-                                                    compression_scheduler.zeros_mask_dict,
-                                                    model.thinning_recipes)
-
-        if 'quantizer_metadata' in checkpoint:
-            msglogger.info('Loaded quantizer metadata from the checkpoint')
-            qmd = checkpoint['quantizer_metadata']
-            quantizer = qmd['type'](model, **qmd['params'])
-            quantizer.prepare_model()
-
-        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
-
-        model.load_state_dict(checkpoint['state_dict'])
-        return model, compression_scheduler, start_epoch
+    normalize_dataparallel_keys = False
+    if 'compression_sched' in checkpoint:
+        compression_scheduler = distiller.CompressionScheduler(model)
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_dataparallel_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
     else:
-        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+        msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
+
+    if 'thinning_recipes' in checkpoint:
+        if 'compression_sched' not in checkpoint:
+            raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    if 'quantizer_metadata' in checkpoint:
+        msglogger.info('Loaded quantizer metadata from the checkpoint')
+        qmd = checkpoint['quantizer_metadata']
+        quantizer = qmd['type'](model, **qmd['params'])
+        quantizer.prepare_model()
+
+    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
+                                                                   e=checkpoint_epoch))
+    if normalize_dataparallel_keys:
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+    model.load_state_dict(checkpoint['state_dict'])
+    return model, compression_scheduler, start_epoch
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 6e65238..a5e360d 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -18,12 +18,14 @@
 
 This implements the scheduling of the compression policies.
 """
+import contextlib
 from functools import partial
 import logging
+
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
-from .utils import model_device
+from .utils import model_device, normalize_module_name
 msglogger = logging.getLogger()
 
 
@@ -187,26 +189,32 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state):
+    def load_state_dict(self, state, normalize_dataparallel_keys):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
 
         Arguments:
             state_dict (dict): scheduler state. Should be an object returned
-                from a call to :meth:`state_dict`.  It is a dictionary of parameter
+                from a call to :meth:`state_dict`. It is a dictionary of parameter
                 names (keys) and parameter masks (values).
+            normalize_dataparallel_keys (bool): indicates whether the keys should be converted
+                from DataParallel format.  This should be set to True when loading a
+                GPU checkpoint onto a CPU model (because currently we don't use
+                DataParallel on the CPU).
         """
         try:
             loaded_masks = state['masks_dict']
-        except Exception as exception:
-            print("ERROR: could not load the CompressionScheduler state")
-            print("Exception: %s %s" % (type(exception), exception))
-            print("\t\tFound the following keys in the state dictionary:")
-            for k in state.keys():
-                print("\t\t" + k)
-            exit(1)
-
+        except KeyError as exception:
+            msglogger.error('could not load the CompressionScheduler state. '
+                            'masks_dict is missing from state')
+            with contextlib.suppress(TypeError):
+                msglogger.debug('Scheduler state keys are: {}'.format(', '.join(state)))
+            raise
+
+        if normalize_dataparallel_keys:
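+            # Strip DataParallel's 'module.' prefix so mask keys match the CPU model's module names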
+            loaded_masks = {normalize_module_name(k): v for k, v in loaded_masks.items()}
         device = model_device(self.model)
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
diff --git a/tests/test_infra.py b/tests/test_infra.py
index 0995980..a101bdb 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -17,13 +17,18 @@
 import logging
 import os
 import sys
+import tempfile
+
+import torch
 import pytest
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
 
-from models import create_model
+import distiller
 from apputils import load_checkpoint
+from models import create_model
+
 
 def test_load():
     logger = logging.getLogger('simple_example')
@@ -34,7 +39,42 @@ def test_load():
     assert compression_scheduler is not None
     assert start_epoch == 180
 
+def test_load_state_dict():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+    assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0
+    assert compression_scheduler is None
+    assert start_epoch == 0
+
+def test_load_dumb_checkpoint():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save(state_dict_arrays, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        with pytest.raises(ValueError):
+            model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
 def test_load_negative():
     with pytest.raises(FileNotFoundError):
         model = create_model(False, 'cifar10', 'resnet20_cifar')
         model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+
+
+def test_load_gpu_model_on_cpu():
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
+    model, compression_scheduler, start_epoch = load_checkpoint(model,
+                                                                '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+    assert compression_scheduler is not None
+    assert start_epoch == 180
+    assert distiller.model_device(model) == 'cpu'
+
+if __name__ == '__main__':
+    test_load_gpu_model_on_cpu()
-- 
GitLab