diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index be7c4dc9916226f9eb19dd87b9dda3d97db80a25..b8f71c118a44ddb560bb98f3881e21f8a65fc396 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -101,7 +101,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False):
+def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
+                    lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -164,7 +165,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lea
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
-    model.load_state_dict(checkpoint['state_dict'])
+    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
+    if anomalous_keys:
+        # This is pytorch 1.1+
+        missing_keys, unexpected_keys = anomalous_keys
+        if unexpected_keys:
+            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+        if missing_keys:
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
+
     if model_device is not None:
         model.to(model_device)
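For context, a minimal usage sketch of the new `strict` flag follows. The checkpoint path and the torchvision model are placeholders for illustration, and the sketch assumes Distiller's usual four-value return from `load_checkpoint` (model, compression scheduler, optimizer, starting epoch); it is not part of the patch itself.

    import torchvision.models as models
    from distiller.apputils.checkpoint import load_checkpoint

    CHECKPOINT_PATH = "logs/resnet18_checkpoint.pth.tar"  # hypothetical path
    model = models.resnet18()

    # Default behaviour (strict=False): unexpected keys in the checkpoint only
    # produce a warning, while missing keys still raise ValueError, as in the
    # patch above.
    model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
        model, CHECKPOINT_PATH)

    # strict=True is forwarded to model.load_state_dict(), so PyTorch itself
    # raises on any mismatch between the checkpoint and the model.
    model, _, _, _ = load_checkpoint(model, CHECKPOINT_PATH, strict=True)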