Commit 1210f412 authored by Neta Zmora, committed by GitHub

Fix issue #148 + refactor load_checkpoint.py (#153)

The root cause of issue #148 is that DataParallel modules cannot execute on the CPU
(on machines that have both CPUs and GPUs).
Therefore, we do not wrap models loaded onto the CPU with DataParallel, but we do wrap
models loaded onto the GPUs with DataParallel (to make them run faster).
The names of the module keys saved in a checkpoint file depend on whether the modules
are wrapped in a DataParallel module or not, so loading a checkpoint that was created
on the GPU into a CPU model (and vice versa) fails on the keys.
This is all PyTorch behavior, and despite the community asking for a fix -
e.g. https://github.com/pytorch/pytorch/issues/7457 - it is still pending.
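For context, a minimal illustration of the key-name mismatch (not part of this commit):
wrapping a model in DataParallel prefixes every state_dict key with "module.", so a
checkpoint saved from the wrapped model cannot be loaded directly into the bare model.

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['0.weight', '0.bias']
print(list(wrapped.state_dict().keys()))  # ['module.0.weight', 'module.0.bias']

# model.load_state_dict(wrapped.state_dict()) raises a key error complaining about the
# 'module.'-prefixed names, which is the failure behind issue #148.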

This commit adds code to catch the key errors raised when a GPU-generated model
(i.e. one wrapped with DataParallel) is loaded onto a CPU, and to convert the key names.

This PR also merges a refactoring of load_checkpoint.py done by @barrh, who also added
a test to further exercise checkpoint loading.
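As a usage sketch that mirrors the test added below (where device_ids=-1 creates the
model on the CPU), loading a GPU-trained checkpoint onto a CPU model now works:

from apputils import load_checkpoint
from models import create_model

# Build the model on the CPU, then load a checkpoint saved from a DataParallel-wrapped
# GPU model; load_checkpoint normalizes the key names when it detects the mismatch.
model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
model, compression_scheduler, start_epoch = load_checkpoint(
    model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')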
parent ed976cff
@@ -26,6 +26,7 @@ from errno import ENOENT
 import logging
 import torch
 import distiller
+from distiller.utils import normalize_module_name
 msglogger = logging.getLogger()
@@ -80,45 +81,60 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         chkpt_file: the checkpoint file
         optimizer: the optimizer to which we will load the serialized state
     """
+    if not os.path.isfile(chkpt_file):
+        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+
+    msglogger.info("=> loading checkpoint %s", chkpt_file)
+    checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
+    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
+
+    if 'state_dict' not in checkpoint:
+        raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
+
+    checkpoint_epoch = checkpoint.get('epoch', None)
+    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
+
+    best_top1 = checkpoint.get('best_top1', None)
+    if best_top1 is not None:
+        msglogger.info("   best top@1: %.3f", best_top1)
+
     compression_scheduler = None
-    start_epoch = 0
-
-    if os.path.isfile(chkpt_file):
-        msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
-        msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-        start_epoch = checkpoint.get('epoch', -1) + 1
-        best_top1 = checkpoint.get('best_top1', None)
-        if best_top1 is not None:
-            msglogger.info("   best top@1: %.3f", best_top1)
-
-        if 'compression_sched' in checkpoint:
-            compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
-            msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
-                           checkpoint['epoch'])
-        else:
-            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
-
-        if 'thinning_recipes' in checkpoint:
-            if 'compression_sched' not in checkpoint:
-                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
-            msglogger.info("Loaded a thinning recipe from the checkpoint")
-            # Cache the recipes in case we need them later
-            model.thinning_recipes = checkpoint['thinning_recipes']
-            distiller.execute_thinning_recipes_list(model,
-                                                    compression_scheduler.zeros_mask_dict,
-                                                    model.thinning_recipes)
-
-        if 'quantizer_metadata' in checkpoint:
-            msglogger.info('Loaded quantizer metadata from the checkpoint')
-            qmd = checkpoint['quantizer_metadata']
-            quantizer = qmd['type'](model, **qmd['params'])
-            quantizer.prepare_model()
-
-        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
-        model.load_state_dict(checkpoint['state_dict'])
-        return model, compression_scheduler, start_epoch
-    else:
-        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+    normalize_dataparallel_keys = False
+    if 'compression_sched' in checkpoint:
+        compression_scheduler = distiller.CompressionScheduler(model)
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_dataparallel_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
+    else:
+        msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
+
+    if 'thinning_recipes' in checkpoint:
+        if 'compression_sched' not in checkpoint:
+            raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    if 'quantizer_metadata' in checkpoint:
+        msglogger.info('Loaded quantizer metadata from the checkpoint')
+        qmd = checkpoint['quantizer_metadata']
+        quantizer = qmd['type'](model, **qmd['params'])
+        quantizer.prepare_model()
+
+    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
+                                                                   e=checkpoint_epoch))
+    if normalize_dataparallel_keys:
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+    model.load_state_dict(checkpoint['state_dict'])
+    return (model, compression_scheduler, start_epoch)
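The code above relies on distiller.utils.normalize_module_name, whose implementation is
not shown in this diff. A minimal sketch of the assumed behavior (stripping the "module."
prefix that DataParallel adds to submodule names) could look like the following; the real
helper may handle additional cases:

def normalize_module_name(layer_name):
    # Hypothetical sketch only: map a DataParallel-style name such as 'module.layer1.conv1'
    # to its plain form 'layer1.conv1'. The actual distiller.utils helper may differ.
    if layer_name.startswith('module.'):
        return layer_name[len('module.'):]
    return layer_name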
@@ -18,12 +18,14 @@
 This implements the scheduling of the compression policies.
 """
+import contextlib
 from functools import partial
 import logging
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
-from .utils import model_device
+from .utils import model_device, normalize_module_name
 msglogger = logging.getLogger()
@@ -187,26 +189,31 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state

-    def load_state_dict(self, state):
+    def load_state_dict(self, state, normalize_dataparallel_keys):
         """Loads the scheduler state.

         Currently the scheduler state is comprised only of the set of pruning masks.

         Arguments:
             state_dict (dict): scheduler state. Should be an object returned
                 from a call to :meth:`state_dict`. It is a dictionary of parameter
                 names (keys) and parameter masks (values).
+            normalize_dataparallel_keys (bool): indicates if we should convert the keys from
+                DataParallel format. This should be set to True when loading a model
+                from a GPU-checkpoint onto a CPU (because currently we don't use DataParallel
+                on the CPU).
         """
         try:
             loaded_masks = state['masks_dict']
-        except Exception as exception:
-            print("ERROR: could not load the CompressionScheduler state")
-            print("Exception: %s %s" % (type(exception), exception))
-            print("\t\tFound the following keys in the state dictionary:")
-            for k in state.keys():
-                print("\t\t" + k)
-            exit(1)
+        except KeyError as exception:
+            msglogger.error('could not load the CompressionScheduler state.'
+                            ' masks_dict is missing from state')
+            with contextlib.suppress(TypeError):
+                msglogger.debug('Scheduler state keys are: {}'.format(', '.join(state)))
+            raise
+
+        if normalize_dataparallel_keys:
+            loaded_masks = {normalize_module_name(k): v for k, v in loaded_masks.items()}
         device = model_device(self.model)
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
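Tying the two diffs together: load_checkpoint first calls the scheduler's load_state_dict
with normalization disabled, and retries with it enabled when a KeyError points at
DataParallel-style mask names. Written standalone (variable names taken from the diff
above, for illustration only):

compression_scheduler = distiller.CompressionScheduler(model)
try:
    compression_scheduler.load_state_dict(checkpoint['compression_sched'], False)
except KeyError:
    # The masks were saved under 'module.'-prefixed names; retry with normalization.
    compression_scheduler.load_state_dict(checkpoint['compression_sched'], True)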
@@ -17,13 +17,18 @@
 import logging
 import os
 import sys
+import tempfile
+
+import torch
 import pytest

 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
-from models import create_model
+import distiller
 from apputils import load_checkpoint
+from models import create_model


 def test_load():
     logger = logging.getLogger('simple_example')
@@ -34,7 +39,42 @@ def test_load():
     assert compression_scheduler is not None
     assert start_epoch == 180


+def test_load_state_dict():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+    assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0
+    assert compression_scheduler is None
+    assert start_epoch == 0
+
+
+def test_load_dumb_checkpoint():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save(state_dict_arrays, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        with pytest.raises(ValueError):
+            model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+
 def test_load_negative():
     with pytest.raises(FileNotFoundError):
         model = create_model(False, 'cifar10', 'resnet20_cifar')
         model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+
+
+def test_load_gpu_model_on_cpu():
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
+    model, compression_scheduler, start_epoch = load_checkpoint(model,
+        '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+
+    assert compression_scheduler is not None
+    assert start_epoch == 180
+    assert distiller.model_device(model) == 'cpu'
+
+
+if __name__ == '__main__':
+    test_load_gpu_model_on_cpu()