diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index be7c4dc9916226f9eb19dd87b9dda3d97db80a25..b8f71c118a44ddb560bb98f3881e21f8a65fc396 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -101,7 +101,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False):
+def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
+                    lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -164,7 +165,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lea
 
     if normalize_dataparallel_keys:
             checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
-    model.load_state_dict(checkpoint['state_dict'])
+    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict=strict)
+    if anomalous_keys:
+        # PyTorch 1.1+ returns a (missing_keys, unexpected_keys) named tuple; older versions return None
+        missing_keys, unexpected_keys = anomalous_keys
+        if unexpected_keys:
+            msglogger.warning("The loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+        if missing_keys:
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
+
     if model_device is not None:
         model.to(model_device)
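
For reference, a minimal usage sketch of the new `strict` keyword (not part of the patch; the model choice, checkpoint path, and the usual (model, compression_scheduler, optimizer, start_epoch) return tuple are assumptions for illustration):

import torchvision.models as models
from distiller.apputils.checkpoint import load_checkpoint

model = models.resnet18()

# Default (strict=False): unexpected keys in the checkpoint are only logged as a
# warning, while missing keys still raise ValueError.
model, _, optimizer, start_epoch = load_checkpoint(
    model, "checkpoints/resnet18_best.pth.tar")

# strict=True is forwarded to torch.nn.Module.load_state_dict(), which raises
# RuntimeError on any key mismatch before the custom handling above is reached.
model, _, optimizer, start_epoch = load_checkpoint(
    model, "checkpoints/resnet18_best.pth.tar", strict=True)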