From 5c83a0447480683c2b718c81fb96edadb1ae2874 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Wed, 26 Jun 2019 00:13:59 +0300
Subject: [PATCH] Checkpoint loading: allow loading non-strict state-keys
 (#300)

Change the default behavior of load_state_dict() so that the keys in
the loaded checkpoint do not need to match the keys in Distiller's
model exactly.

However, we place one restriction on non-strict checkpoint loading:
even when loading non-strictly, we raise an exception if any keys are
missing (extra keys are accepted). This is because a model that
silently receives only part of its state-keys, while the user expects
all of them to be initialized from the checkpoint, can waste a great
deal of the user's time. We want the user to be fully aware that not
all of the state-keys were initialized from the loaded checkpoint.
---
 distiller/apputils/checkpoint.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index be7c4dc..b8f71c1 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -101,7 +101,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False):
+def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
+                    lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -164,7 +165,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lea
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in
                                     checkpoint['state_dict'].items()}
-    model.load_state_dict(checkpoint['state_dict'])
+    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
+    if anomalous_keys:
+        # This is pytorch 1.1+
+        missing_keys, unexpected_keys = anomalous_keys
+        if unexpected_keys:
+            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+        if missing_keys:
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
+
     if model_device is not None:
         model.to(model_device)
 
-- 
GitLab
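
For readers who want to see the policy in isolation, here is a minimal
sketch of the behavior the second hunk implements, outside of Distiller.
It assumes PyTorch 1.1+, where load_state_dict(strict=False) returns a
named tuple of (missing_keys, unexpected_keys); the toy model and the
'extra.weight' checkpoint key below are hypothetical, for illustration
only.

import torch
import torch.nn as nn

# A toy model standing in for Distiller's model (hypothetical).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Build a checkpoint with one extra key and no missing keys.
state_dict = model.state_dict()
state_dict['extra.weight'] = torch.zeros(1)

# PyTorch 1.1+ returns IncompatibleKeys(missing_keys, unexpected_keys).
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if unexpected_keys:
    # Extra keys are tolerated: warn, as the patch does.
    print("Warning: checkpoint contains %d unexpected state keys" % len(unexpected_keys))
if missing_keys:
    # Missing keys are not tolerated: fail loudly, as the patch does.
    raise ValueError("checkpoint is missing %d state keys" % len(missing_keys))

The asymmetry is deliberate: a checkpoint that carries surplus state is
merely noisy, while a model that is only partially initialized can
silently waste hours of training or evaluation before the problem is
noticed.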