Unverified Commit 1210f412 authored by Neta Zmora, committed by GitHub

Fix issue #148 + refactor load_checkpoint.py (#153)

The root cause of issue #148 is that DataParallel modules cannot execute on the CPU
(on machines that have both CPUs and GPUs).
Therefore, we do not wrap models with DataParallel when they are loaded for the CPU,
but we do wrap them with DataParallel when they are loaded onto the GPUs (to make them run faster).
The names of the module keys saved in a checkpoint file depend on whether the modules are
wrapped by a DataParallel module or not, so loading a checkpoint that was created on the GPUs
into a CPU model (and vice versa) fails on the key names.
This is all PyTorch behavior, and despite the community asking for a fix -
e.g. https://github.com/pytorch/pytorch/issues/7457 - it is still pending.
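For illustration only (not part of this commit): wrapping a model in torch.nn.DataParallel prefixes
every parameter name with "module.", so a state_dict saved from the wrapped model no longer matches
a plain CPU model.

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    print(list(model.state_dict().keys())[0])    # '0.weight'

    wrapped = nn.DataParallel(model)
    print(list(wrapped.state_dict().keys())[0])  # 'module.0.weight'

    # Loading the wrapped state_dict back into the plain model fails on the mismatched key names:
    # model.load_state_dict(wrapped.state_dict())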

This commit adds code to catch key errors when loading a GPU-generated checkpoint
(i.e. one saved with DataParallel) onto a CPU model, and to convert the names of the keys.
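The conversion amounts to stripping that "module." prefix. A rough sketch of the retry-and-normalize
pattern (the commit's real helper is distiller.utils.normalize_module_name; strip_dataparallel_prefix
below is only an illustrative stand-in):

    def strip_dataparallel_prefix(name):
        # Illustrative stand-in for distiller.utils.normalize_module_name:
        # drop the 'module.' prefix that nn.DataParallel prepends to key names.
        return name[len('module.'):] if name.startswith('module.') else name

    def load_state_dict_forgiving(model, state_dict):
        # Try the keys as-is; on a mismatch, normalize the names and retry.
        # nn.Module.load_state_dict reports key mismatches as a RuntimeError,
        # while distiller's CompressionScheduler.load_state_dict surfaces a KeyError.
        try:
            model.load_state_dict(state_dict)
        except (KeyError, RuntimeError):
            model.load_state_dict({strip_dataparallel_prefix(k): v
                                   for k, v in state_dict.items()})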

This PR also merges a refactoring of load_checkpoint.py done by @barrh, who also added
a test to further exercise checkpoint loading.
parent ed976cff
@@ -26,6 +26,7 @@ from errno import ENOENT
 import logging
 import torch
 import distiller
+from distiller.utils import normalize_module_name
 msglogger = logging.getLogger()
@@ -80,45 +81,60 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         chkpt_file: the checkpoint file
         optimizer: the optimizer to which we will load the serialized state
     """
+    if not os.path.isfile(chkpt_file):
+        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+
+    msglogger.info("=> loading checkpoint %s", chkpt_file)
+    checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
+    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
+
+    if 'state_dict' not in checkpoint:
+        raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
+
+    checkpoint_epoch = checkpoint.get('epoch', None)
+    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
+
+    best_top1 = checkpoint.get('best_top1', None)
+    if best_top1 is not None:
+        msglogger.info("   best top@1: %.3f", best_top1)
+
     compression_scheduler = None
-    start_epoch = 0
-
-    if os.path.isfile(chkpt_file):
-        msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
-        msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-        start_epoch = checkpoint.get('epoch', -1) + 1
-        best_top1 = checkpoint.get('best_top1', None)
-        if best_top1 is not None:
-            msglogger.info("   best top@1: %.3f", best_top1)
-
-        if 'compression_sched' in checkpoint:
-            compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
-            msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
-                           checkpoint['epoch'])
-        else:
-            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
-
-        if 'thinning_recipes' in checkpoint:
-            if 'compression_sched' not in checkpoint:
-                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
-            msglogger.info("Loaded a thinning recipe from the checkpoint")
-            # Cache the recipes in case we need them later
-            model.thinning_recipes = checkpoint['thinning_recipes']
-            distiller.execute_thinning_recipes_list(model,
-                                                    compression_scheduler.zeros_mask_dict,
-                                                    model.thinning_recipes)
-
-        if 'quantizer_metadata' in checkpoint:
-            msglogger.info('Loaded quantizer metadata from the checkpoint')
-            qmd = checkpoint['quantizer_metadata']
-            quantizer = qmd['type'](model, **qmd['params'])
-            quantizer.prepare_model()
-
-        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
-        model.load_state_dict(checkpoint['state_dict'])
-        return model, compression_scheduler, start_epoch
-    else:
-        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+    normalize_dataparallel_keys = False
+    if 'compression_sched' in checkpoint:
+        compression_scheduler = distiller.CompressionScheduler(model)
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_dataparallel_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
+    else:
+        msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
+
+    if 'thinning_recipes' in checkpoint:
+        if 'compression_sched' not in checkpoint:
+            raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    if 'quantizer_metadata' in checkpoint:
+        msglogger.info('Loaded quantizer metadata from the checkpoint')
+        qmd = checkpoint['quantizer_metadata']
+        quantizer = qmd['type'](model, **qmd['params'])
+        quantizer.prepare_model()
+
+    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
+                                                                   e=checkpoint_epoch))
+    if normalize_dataparallel_keys:
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+    model.load_state_dict(checkpoint['state_dict'])
+    return (model, compression_scheduler, start_epoch)
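With this change, a checkpoint produced on GPUs (with DataParallel) can be loaded straight into a CPU
model. A short usage sketch, mirroring the new test added below (per that test, device_ids=-1 asks
create_model for a plain CPU model without a DataParallel wrapper):

    import distiller
    from apputils import load_checkpoint
    from models import create_model

    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
    model, compression_scheduler, start_epoch = load_checkpoint(
        model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
    assert distiller.model_device(model) == 'cpu'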
@@ -18,12 +18,14 @@
 This implements the scheduling of the compression policies.
 """
+import contextlib
+from functools import partial
 import logging
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
-from .utils import model_device
+from .utils import model_device, normalize_module_name
 msglogger = logging.getLogger()
@@ -187,26 +189,31 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state):
+    def load_state_dict(self, state, normalize_dataparallel_keys):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
 
         Arguments:
             state_dict (dict): scheduler state. Should be an object returned
                 from a call to :meth:`state_dict`. It is a dictionary of parameter
                 names (keys) and parameter masks (values).
+            normalize_dataparallel_keys (bool): indicates if we should convert the keys from
+                DataParallel format. This should be set to True when loading a model
+                from a GPU-checkpoint onto a CPU (because currently we don't use DataParallel
+                on the CPU).
         """
         try:
             loaded_masks = state['masks_dict']
-        except Exception as exception:
-            print("ERROR: could not load the CompressionScheduler state")
-            print("Exception: %s %s" % (type(exception), exception))
-            print("\t\tFound the following keys in the state dictionary:")
-            for k in state.keys():
-                print("\t\t" + k)
-            exit(1)
+        except KeyError as exception:
+            msglogger.error('could not load the CompressionScheduler state.'
+                            ' masks_dict is missing from state')
+            with contextlib.suppress(TypeError):
+                msglogger.debug('Scheduler state keys are: {}'.format(', '.join(state)))
+            raise
+        if normalize_dataparallel_keys:
+            loaded_masks = {normalize_module_name(k): v for k, v in loaded_masks.items()}
         device = model_device(self.model)
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
@@ -17,13 +17,18 @@
 import logging
 import os
 import sys
+import tempfile
+import torch
+import pytest
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
-from models import create_model
+import distiller
 from apputils import load_checkpoint
+from models import create_model
 
 
 def test_load():
     logger = logging.getLogger('simple_example')
@@ -34,7 +39,42 @@ def test_load():
     assert compression_scheduler is not None
     assert start_epoch == 180
 
+
+def test_load_state_dict():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+    assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0
+    assert compression_scheduler is None
+    assert start_epoch == 0
+
+
+def test_load_dumb_checkpoint():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save(state_dict_arrays, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        with pytest.raises(ValueError):
+            model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+
+def test_load_negative():
+    with pytest.raises(FileNotFoundError):
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+
+
+def test_load_gpu_model_on_cpu():
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
+    model, compression_scheduler, start_epoch = load_checkpoint(model,
+                                                                '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+    assert compression_scheduler is not None
+    assert start_epoch == 180
+    assert distiller.model_device(model) == 'cpu'
+
+
+if __name__ == '__main__':
+    test_load_gpu_model_on_cpu()