From d81927d966ccbfa0695ccbf7495bcd49ada51092 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 31 Jan 2019 14:33:53 +0200
Subject: [PATCH] More robust handling of loading non-Distiller checkpoints

Specifically, gracefully handle a missing 'epoch' key in a loaded
checkpoint file.
---
 apputils/checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 5f84883..3792152 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -87,7 +87,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         msglogger.info("=> loading checkpoint %s", chkpt_file)
         checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
         msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-        start_epoch = checkpoint['epoch'] + 1
+        start_epoch = checkpoint.get('epoch', -1) + 1
         best_top1 = checkpoint.get('best_top1', None)
         if best_top1 is not None:
             msglogger.info("   best top@1: %.3f", best_top1)
@@ -116,7 +116,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
             quantizer = qmd['type'](model, **qmd['params'])
             quantizer.prepare_model()
 
-        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
+        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
 
         model.load_state_dict(checkpoint['state_dict'])
         return model, compression_scheduler, start_epoch
-- 
GitLab