Skip to content
Snippets Groups Projects
Commit 73df57bf authored by Neta Zmora's avatar Neta Zmora
Browse files

Bug fix: compression schedule configuration parsing

Used the wrong indentation when parsing RegularizationPolicy
parent 4240ec94
No related branches found
No related tags found
No related merge requests found
......@@ -43,17 +43,17 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
name: the name of the checkpoint file
dir: directory in which to save the checkpoint
"""
msglogger.info("Saving checkpoint")
if not os.path.isdir(dir):
msglogger.info("Error: Directory to save checkpoint doesn't exist - {0}".format(os.path.abspath(dir)))
exit(1)
filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar'
fullpath = os.path.join(dir, filename)
msglogger.info("Saving checkpoint to: %s" % fullpath)
filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
fullpath_best = os.path.join(dir, filename_best)
checkpoint = {}
checkpoint['epoch'] = epoch
checkpoint['arch'] = arch
checkpoint['arch'] = arch
checkpoint['state_dict'] = model.state_dict()
if best_top1 is not None:
checkpoint['best_top1'] = best_top1
......@@ -85,6 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
if os.path.isfile(chkpt_file):
msglogger.info("=> loading checkpoint %s", chkpt_file)
checkpoint = torch.load(chkpt_file)
msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
start_epoch = checkpoint['epoch'] + 1
best_top1 = checkpoint.get('best_top1', None)
if best_top1 is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment