diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index ce7648250553b744012f122c33068cbde2a53696..2cfbbba8d2878648cf89d90d909a1b1fa035ada7 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -77,7 +77,6 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
         checkpoint['quantizer_metadata'] = model.quantizer_metadata
 
     checkpoint['extras'] = extras
-
     torch.save(checkpoint, fullpath)
     if is_best:
         shutil.copyfile(fullpath, fullpath_best)
@@ -101,8 +100,8 @@ def get_contents_table(d):
 
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
-                    lean_checkpoint=False, strict=False):
+def load_checkpoint(model, chkpt_file, optimizer=None,
+                    model_device=None, lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -114,6 +113,52 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
                 This should be set to either 'cpu' or 'cuda'.
     :returns: updated model, compression_scheduler, optimizer, start_epoch
     """
+    def _load_compression_scheduler():
+        normalize_keys = False
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
+        return normalize_keys
+
+    def _load_and_execute_thinning_recipes():
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
+                                      for recipe in model.thinning_recipes]
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    def _load_optimizer():
+        """Initialize optimizer with model parameters and load src_state_dict"""
+        try:
+            cls, src_state_dict = checkpoint['optimizer_type'], checkpoint['optimizer_state_dict']
+            # Initialize the dest_optimizer with a dummy learning rate,
+            # this is required to support SGD.__init__()
+            dest_optimizer = cls(model.parameters(), lr=1)
+            dest_optimizer.load_state_dict(src_state_dict)
+            msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
+                type=type(dest_optimizer)))
+            optimizer_param_groups = dest_optimizer.state_dict()['param_groups']
+            msglogger.info('Optimizer Args: {}'.format(
+                dict((k, v) for k, v in optimizer_param_groups[0].items()
+                     if k != 'params')))
+            return dest_optimizer
+        except KeyError:
+            # Older checkpoints do not support optimizer loading: They either had an 'optimizer' field
+            # (different name) which was not used during the load, or they didn't even checkpoint
+            # the optimizer.
+            msglogger.warning('Optimizer could not be loaded from checkpoint.')
+            return None
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
 
@@ -133,30 +178,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
         compression_scheduler = distiller.CompressionScheduler(model)
-        try:
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        except KeyError as e:
-            # A very common source of this KeyError is loading a GPU model on the CPU.
-            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
-            normalize_dataparallel_keys = True
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
-            checkpoint_epoch))
+        normalize_dataparallel_keys = _load_compression_scheduler()
     else:
         msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
 
     if 'thinning_recipes' in checkpoint:
-        if 'compression_sched' not in checkpoint:
-            msglogger.warning("Found thinning_recipes key, but missing mandatory key compression_sched")
+        if not compression_scheduler:
+            msglogger.warning("Found thinning_recipes key, but missing key compression_scheduler")
             compression_scheduler = distiller.CompressionScheduler(model)
-        msglogger.info("Loaded a thinning recipe from the checkpoint")
-        # Cache the recipes in case we need them later
-        model.thinning_recipes = checkpoint['thinning_recipes']
-        if normalize_dataparallel_keys:
-            model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
-        distiller.execute_thinning_recipes_list(model,
-                                                compression_scheduler.zeros_mask_dict,
-                                                model.thinning_recipes)
+        _load_and_execute_thinning_recipes()
 
     if 'quantizer_metadata' in checkpoint:
         msglogger.info('Loaded quantizer metadata from the checkpoint')
@@ -165,49 +195,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
         quantizer.prepare_model(qmd['dummy_input'])
 
     if normalize_dataparallel_keys:
-        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
     anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
     if anomalous_keys:
         # This is pytorch 1.1+
         missing_keys, unexpected_keys = anomalous_keys
         if unexpected_keys:
-            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" %
+                              (chkpt_file, len(unexpected_keys)))
         if missing_keys:
-            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
-
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" %
+                             (chkpt_file, len(missing_keys)))
+
     if model_device is not None:
         model.to(model_device)
 
     if lean_checkpoint:
         msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)))
-        return (model, None, None, 0)
-
-    def _load_optimizer(cls, src_state_dict, model):
-        """Initiate optimizer with model parameters and load src_state_dict"""
-        # initiate the dest_optimizer with a dummy learning rate,
-        # this is required to support SGD.__init__()
-        dest_optimizer = cls(model.parameters(), lr=1)
-        dest_optimizer.load_state_dict(src_state_dict)
-        return dest_optimizer
-
-    try:
-        optimizer = _load_optimizer(checkpoint['optimizer_type'],
-                                    checkpoint['optimizer_state_dict'], model)
-    except KeyError:
-        # Older checkpoints do support optimizer loading: They either had an 'optimizer' field
-        # (different name) which was not used during the load, or they didn't even checkpoint
-        # the optimizer.
-        optimizer = None
-
-    if optimizer is not None:
-        msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
-            type=type(optimizer)))
-        msglogger.info('Optimizer Args: {}'.format(
-            dict((k,v) for k,v in optimizer.state_dict()['param_groups'][0].items()
-                 if k != 'params')))
-    else:
-        msglogger.warning('Optimizer could not be loaded from checkpoint.')
+        return model, None, None, 0
 
+    optimizer = _load_optimizer()
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
-    return (model, compression_scheduler, optimizer, start_epoch)
+    return model, compression_scheduler, optimizer, start_epoch
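
Usage sketch (not part of the diff above): a minimal example of calling the refactored load_checkpoint(). The create_model() arguments, checkpoint filename and device strings are illustrative assumptions, not values taken from this change.

    import distiller.apputils as apputils
    from distiller.models import create_model

    # Hypothetical model; any architecture known to distiller's create_model() works here.
    model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar')

    # Full load: restores the weights, compression schedule, thinning recipes and optimizer;
    # returns the scheduler, optimizer and start epoch alongside the model.
    model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
        model, 'checkpoint.pth.tar', model_device='cuda')

    # Lean load: only 'state_dict' is restored; the remaining return values are None, None, 0.
    model, _, _, _ = apputils.load_checkpoint(
        model, 'checkpoint.pth.tar', model_device='cpu', lean_checkpoint=True)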