From d81927d966ccbfa0695ccbf7495bcd49ada51092 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 31 Jan 2019 14:33:53 +0200
Subject: [PATCH] More robust handling of loading non-Distiller checkpoints

Specifically, gracefully handle a missing 'epoch' key in a loaded
checkpoint file.
---
 apputils/checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 5f84883..3792152 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -87,7 +87,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
     msglogger.info("=> loading checkpoint %s", chkpt_file)
     checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
     msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-    start_epoch = checkpoint['epoch'] + 1
+    start_epoch = checkpoint.get('epoch', -1) + 1
     best_top1 = checkpoint.get('best_top1', None)
     if best_top1 is not None:
         msglogger.info("   best top@1: %.3f", best_top1)
@@ -116,7 +116,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         quantizer = qmd['type'](model, **qmd['params'])
         quantizer.prepare_model()

-    msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
+    msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
     model.load_state_dict(checkpoint['state_dict'])
     return model, compression_scheduler, start_epoch

-- 
GitLab