Commit 99124355 authored by Neta Zmora

checkpoint.py: non-functional code refactoring

Rearranged the code for easier reading and maintenance
parent bdafebea
@@ -77,7 +77,6 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
         checkpoint['quantizer_metadata'] = model.quantizer_metadata
-
     checkpoint['extras'] = extras
     torch.save(checkpoint, fullpath)
     if is_best:
         shutil.copyfile(fullpath, fullpath_best)
 
@@ -101,8 +100,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
-                    lean_checkpoint=False, strict=False):
+def load_checkpoint(model, chkpt_file, optimizer=None,
+                    model_device=None, lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -114,6 +113,52 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
             This should be set to either 'cpu' or 'cuda'.
     :returns: updated model, compression_scheduler, optimizer, start_epoch
     """
 
+    def _load_compression_scheduler():
+        normalize_keys = False
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
+        return normalize_keys
+
+    def _load_and_execute_thinning_recipes():
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
+                                      for recipe in model.thinning_recipes]
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    def _load_optimizer():
+        """Initialize optimizer with model parameters and load src_state_dict"""
+        try:
+            cls, src_state_dict = checkpoint['optimizer_type'], checkpoint['optimizer_state_dict']
+            # Initialize the dest_optimizer with a dummy learning rate,
+            # this is required to support SGD.__init__()
+            dest_optimizer = cls(model.parameters(), lr=1)
+            dest_optimizer.load_state_dict(src_state_dict)
+            msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
+                type=type(dest_optimizer)))
+            optimizer_param_groups = dest_optimizer.state_dict()['param_groups']
+            msglogger.info('Optimizer Args: {}'.format(
+                dict((k, v) for k, v in optimizer_param_groups[0].items()
+                     if k != 'params')))
+            return dest_optimizer
+        except KeyError:
+            # Older checkpoints do not support optimizer loading: they either had an 'optimizer' field
+            # (different name) which was not used during the load, or they didn't even checkpoint
+            # the optimizer.
+            msglogger.warning('Optimizer could not be loaded from checkpoint.')
+            return None
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
@@ -133,30 +178,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
         compression_scheduler = distiller.CompressionScheduler(model)
-        try:
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        except KeyError as e:
-            # A very common source of this KeyError is loading a GPU model on the CPU.
-            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
-            normalize_dataparallel_keys = True
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
-            checkpoint_epoch))
+        normalize_dataparallel_keys = _load_compression_scheduler()
     else:
         msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
 
     if 'thinning_recipes' in checkpoint:
-        if 'compression_sched' not in checkpoint:
-            msglogger.warning("Found thinning_recipes key, but missing mandatory key compression_sched")
+        if not compression_scheduler:
+            msglogger.warning("Found thinning_recipes key, but missing key compression_scheduler")
             compression_scheduler = distiller.CompressionScheduler(model)
-        msglogger.info("Loaded a thinning recipe from the checkpoint")
-        # Cache the recipes in case we need them later
-        model.thinning_recipes = checkpoint['thinning_recipes']
-        if normalize_dataparallel_keys:
-            model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
-        distiller.execute_thinning_recipes_list(model,
-                                                compression_scheduler.zeros_mask_dict,
-                                                model.thinning_recipes)
+        _load_and_execute_thinning_recipes()
 
     if 'quantizer_metadata' in checkpoint:
         msglogger.info('Loaded quantizer metadata from the checkpoint')
@@ -165,49 +195,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
         quantizer.prepare_model(qmd['dummy_input'])
 
     if normalize_dataparallel_keys:
        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
     anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
     if anomalous_keys:
         # This is pytorch 1.1+
         missing_keys, unexpected_keys = anomalous_keys
         if unexpected_keys:
-            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" %
+                              (chkpt_file, len(unexpected_keys)))
         if missing_keys:
-            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" %
+                             (chkpt_file, len(missing_keys)))
 
     if model_device is not None:
         model.to(model_device)
 
     if lean_checkpoint:
         msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)))
-        return (model, None, None, 0)
+        return model, None, None, 0
 
-    def _load_optimizer(cls, src_state_dict, model):
-        """Initiate optimizer with model parameters and load src_state_dict"""
-        # initiate the dest_optimizer with a dummy learning rate,
-        # this is required to support SGD.__init__()
-        dest_optimizer = cls(model.parameters(), lr=1)
-        dest_optimizer.load_state_dict(src_state_dict)
-        return dest_optimizer
-
-    try:
-        optimizer = _load_optimizer(checkpoint['optimizer_type'],
-                                    checkpoint['optimizer_state_dict'], model)
-    except KeyError:
-        # Older checkpoints do support optimizer loading: They either had an 'optimizer' field
-        # (different name) which was not used during the load, or they didn't even checkpoint
-        # the optimizer.
-        optimizer = None
-
-    if optimizer is not None:
-        msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
-            type=type(optimizer)))
-        msglogger.info('Optimizer Args: {}'.format(
-            dict((k,v) for k,v in optimizer.state_dict()['param_groups'][0].items()
-                 if k != 'params')))
-    else:
-        msglogger.warning('Optimizer could not be loaded from checkpoint.')
+    optimizer = _load_optimizer()
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
-    return (model, compression_scheduler, optimizer, start_epoch)
+    return model, compression_scheduler, optimizer, start_epoch
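
For context, here is a minimal usage sketch of the API touched by this commit. It assumes Distiller's usual module layout (checkpoint utilities under distiller.apputils, model construction via distiller.models.create_model); the architecture name, dataset, extras dict, and checkpoint filename below are illustrative placeholders, not part of this commit.

import distiller.apputils as apputils
import distiller.models as models

# Build a model whose parameter names match the checkpoint (placeholder arch/dataset).
model = models.create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar')

# Save a checkpoint; only the fields that appear in this diff are used here.
apputils.save_checkpoint(epoch=0, arch='resnet20_cifar', model=model,
                         extras={'best_top1': 91.2}, is_best=True)

# Reload it; the refactored load_checkpoint() still returns the same 4-tuple.
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
    model, chkpt_file='checkpoint.pth.tar', model_device='cuda')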