diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 3792152f7227331f1b82f9634d1e76db57dc4582..26f7fff4f61c446f63a3c79453677e4c0cba0223 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -26,6 +26,7 @@ from errno import ENOENT
 import logging
 import torch
 import distiller
+from distiller.utils import normalize_module_name
 
 msglogger = logging.getLogger()
 
@@ -80,45 +81,60 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         chkpt_file: the checkpoint file
         optimizer: the optimizer to which we will load the serialized state
     """
+    if not os.path.isfile(chkpt_file):
+        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+
+    msglogger.info("=> loading checkpoint %s", chkpt_file)
+    checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
+    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
+
+    if 'state_dict' not in checkpoint:
+        raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
+
+    checkpoint_epoch = checkpoint.get('epoch', None)
+    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
+
+    best_top1 = checkpoint.get('best_top1', None)
+    if best_top1 is not None:
+        msglogger.info("   best top@1: %.3f", best_top1)
     compression_scheduler = None
-    start_epoch = 0
-
-    if os.path.isfile(chkpt_file):
-        msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
-        msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-        start_epoch = checkpoint.get('epoch', -1) + 1
-        best_top1 = checkpoint.get('best_top1', None)
-        if best_top1 is not None:
-            msglogger.info("   best top@1: %.3f", best_top1)
-
-        if 'compression_sched' in checkpoint:
-            compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
-            msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
-                           checkpoint['epoch'])
-        else:
-            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
-
-        if 'thinning_recipes' in checkpoint:
-            if 'compression_sched' not in checkpoint:
-                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
-            msglogger.info("Loaded a thinning recipe from the checkpoint")
-            # Cache the recipes in case we need them later
-            model.thinning_recipes = checkpoint['thinning_recipes']
-            distiller.execute_thinning_recipes_list(model,
-                                                    compression_scheduler.zeros_mask_dict,
-                                                    model.thinning_recipes)
-
-        if 'quantizer_metadata' in checkpoint:
-            msglogger.info('Loaded quantizer metadata from the checkpoint')
-            qmd = checkpoint['quantizer_metadata']
-            quantizer = qmd['type'](model, **qmd['params'])
-            quantizer.prepare_model()
-
-        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
-
-        model.load_state_dict(checkpoint['state_dict'])
-        return model, compression_scheduler, start_epoch
+    normalize_dataparallel_keys = False
+    if 'compression_sched' in checkpoint:
+        compression_scheduler = distiller.CompressionScheduler(model)
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        except KeyError:
+            # A very common source of this KeyError is loading a GPU checkpoint on the CPU.
+            # Retry after renaming the DataParallel keys, since DataParallel is not used on the CPU.
+            normalize_dataparallel_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
     else:
-        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+        msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
+
+    if 'thinning_recipes' in checkpoint:
+        if 'compression_sched' not in checkpoint:
+            raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    if 'quantizer_metadata' in checkpoint:
+        msglogger.info('Loaded quantizer metadata from the checkpoint')
+        qmd = checkpoint['quantizer_metadata']
+        quantizer = qmd['type'](model, **qmd['params'])
+        quantizer.prepare_model()
+
+    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
+                                                                   e=checkpoint_epoch))
+    if normalize_dataparallel_keys:
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+    model.load_state_dict(checkpoint['state_dict'])
+    return model, compression_scheduler, start_epoch
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 6e65238a0abc0d990207fd0e5782c3517bdddf2c..a5e360d01d3e7dfa0bb861d6d9a87fda83ed1fa6 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -18,12 +18,14 @@
 This implements the scheduling of the compression policies.
 """
+import contextlib
 from functools import partial
 import logging
+
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
-from .utils import model_device
+from .utils import model_device, normalize_module_name
 
 msglogger = logging.getLogger()
 
 
@@ -187,26 +189,31 @@ class CompressionScheduler(object):
             state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state):
+    def load_state_dict(self, state, normalize_dataparallel_keys):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
 
         Arguments:
             state_dict (dict): scheduler state. Should be an object returned
-                from a call to :meth:`state_dict`. It is a dictionary of parameter
+                from a call to :meth:`state_dict`.  It is a dictionary of parameter
                 names (keys) and parameter masks (values).
+            normalize_dataparallel_keys (bool): indicates if we should convert the keys from
+                DataParallel format.  This should be set to True when loading a model
+                from a GPU checkpoint onto a CPU (because currently we don't use DataParallel
+                on the CPU).
         """
         try:
             loaded_masks = state['masks_dict']
-        except Exception as exception:
-            print("ERROR: could not load the CompressionScheduler state")
-            print("Exception: %s %s" % (type(exception), exception))
-            print("\t\tFound the following keys in the state dictionary:")
-            for k in state.keys():
-                print("\t\t" + k)
-            exit(1)
-
+        except KeyError:
+            msglogger.error('could not load the CompressionScheduler state.'
+                            ' masks_dict is missing from state')
+            with contextlib.suppress(TypeError):
+                msglogger.debug('Scheduler state keys are: {}'.format(', '.join(state)))
+            raise
+
+        if normalize_dataparallel_keys:
+            loaded_masks = {normalize_module_name(k): v for k, v in loaded_masks.items()}
         device = model_device(self.model)
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
diff --git a/tests/test_infra.py b/tests/test_infra.py
index 099598016c5368b19096379ab8ff592473acd2a1..a101bdb82d4b5d5e058a0e53992b7a4b6e153cab 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -17,13 +17,18 @@
 import logging
 import os
 import sys
+import tempfile
+
+import torch
 import pytest
 
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
-from models import create_model
+import distiller
 from apputils import load_checkpoint
+from models import create_model
+
 
 def test_load():
     logger = logging.getLogger('simple_example')
@@ -34,7 +39,42 @@ def test_load():
     assert compression_scheduler is not None
     assert start_epoch == 180
 
+def test_load_state_dict():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+    assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0
+    assert compression_scheduler is None
+    assert start_epoch == 0
+
+def test_load_dumb_checkpoint():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save(state_dict_arrays, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        with pytest.raises(ValueError):
+            model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
 def test_load_negative():
     with pytest.raises(FileNotFoundError):
         model = create_model(False, 'cifar10', 'resnet20_cifar')
         model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+
+
+def test_load_gpu_model_on_cpu():
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
+    model, compression_scheduler, start_epoch = load_checkpoint(model,
+                        '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+    assert compression_scheduler is not None
+    assert start_epoch == 180
+    assert distiller.model_device(model) == 'cpu'
+
+if __name__ == '__main__':
+    test_load_gpu_model_on_cpu()
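
Reviewer note: the try/except retry in load_checkpoint exists because torch.nn.DataParallel registers the wrapped network under a "module." attribute, so every key in a state_dict saved from a DataParallel-wrapped model carries a "module." prefix. A minimal demonstration (the toy model is illustrative, not part of the patch):

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 2))     # toy stand-in for resnet20_cifar
    print(list(net.state_dict()))            # ['0.weight', '0.bias']

    wrapped = nn.DataParallel(net)           # the wrapper GPU training code typically applies
    print(list(wrapped.state_dict()))        # ['module.0.weight', 'module.0.bias']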
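
The patch undoes that prefix with normalize_module_name, imported from distiller.utils. A minimal sketch of the behavior the new code relies on, assuming the helper only needs to strip the DataParallel prefix (the actual implementation lives in distiller/utils.py and may cover more cases):

    def normalize_module_name(layer_name):
        # Assumed behavior: drop the 'module.' prefix added by DataParallel;
        # names that are already normalized pass through unchanged.
        if layer_name.startswith('module.'):
            return layer_name[len('module.'):]
        return layer_name

    assert normalize_module_name('module.layer1.0.conv1.weight') == 'layer1.0.conv1.weight'
    assert normalize_module_name('fc.weight') == 'fc.weight'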
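
Taken together, the load-a-GPU-checkpoint-on-a-CPU recipe that load_checkpoint now implements looks roughly like this in plain PyTorch (the path and the model variable are placeholders):

    import torch

    # Deserialize CUDA tensors directly onto the CPU, as load_checkpoint does.
    checkpoint = torch.load('checkpoint.pth.tar',
                            map_location=lambda storage, loc: storage)
    state_dict = checkpoint['state_dict']

    # Strip DataParallel prefixes if present; DataParallel is not used on the CPU.
    if any(k.startswith('module.') for k in state_dict):
        state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                      for k, v in state_dict.items()}
    model.load_state_dict(state_dict)  # 'model' is a placeholder nn.Module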