From 5c83a0447480683c2b718c81fb96edadb1ae2874 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Wed, 26 Jun 2019 00:13:59 +0300
Subject: [PATCH] Checkpoint loading: allow loading non-strict state-keys
 (#300)

Change the default behavior of load_state_dict() so that the keys in
the loaded checkpoint do not need to match the keys in Distiller's
model exactly.

However, we place one restriction on non-strict checkpoint loading:
even when loading non-strictly, we raise an exception if any keys are
missing (extra keys are accepted). This is because a model that
silently receives only part of its state-keys, while the user expects
all of them to be initialized from the checkpoint, can waste a great
deal of the user's time. We want the user to be fully aware that not
all of the state-keys were initialized from the loaded checkpoint.
---
 distiller/apputils/checkpoint.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index be7c4dc..b8f71c1 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -101,7 +101,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False):
+def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
+                    lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -164,7 +165,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lea
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in
                                     checkpoint['state_dict'].items()}
-    model.load_state_dict(checkpoint['state_dict'])
+    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
+    if anomalous_keys:
+        # This is pytorch 1.1+
+        missing_keys, unexpected_keys = anomalous_keys
+        if unexpected_keys:
+            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+        if missing_keys:
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
+
     if model_device is not None:
         model.to(model_device)
 
-- 
GitLab
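
For readers who want to see the policy in isolation, here is a minimal
sketch of the behavior the second hunk implements, outside of Distiller.
It assumes PyTorch 1.1+, where load_state_dict(strict=False) returns a
named tuple of (missing_keys, unexpected_keys); the toy model and the
'extra.weight' checkpoint key below are hypothetical, for illustration
only.

import torch
import torch.nn as nn

# A toy model standing in for Distiller's model (hypothetical).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Build a checkpoint with one extra key and no missing keys.
state_dict = model.state_dict()
state_dict['extra.weight'] = torch.zeros(1)

# PyTorch 1.1+ returns IncompatibleKeys(missing_keys, unexpected_keys).
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if unexpected_keys:
    # Extra keys are tolerated: warn, as the patch does.
    print("Warning: checkpoint contains %d unexpected state keys" % len(unexpected_keys))
if missing_keys:
    # Missing keys are not tolerated: fail loudly, as the patch does.
    raise ValueError("checkpoint is missing %d state keys" % len(missing_keys))

The asymmetry is deliberate: a checkpoint that carries surplus state is
merely noisy, while a model that is only partially initialized can
silently waste hours of training or evaluation before the problem is
noticed.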