Skip to content
Snippets Groups Projects
Commit 73df57bf authored by Neta Zmora's avatar Neta Zmora
Browse files

Bug fix: compression schedule configuration parsing

Used the wrong indentation when parsing RegularizationPolicy
parent 4240ec94
No related branches found
No related tags found
No related merge requests found
...@@ -43,17 +43,17 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None, ...@@ -43,17 +43,17 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
name: the name of the checkpoint file name: the name of the checkpoint file
dir: directory in which to save the checkpoint dir: directory in which to save the checkpoint
""" """
msglogger.info("Saving checkpoint")
if not os.path.isdir(dir): if not os.path.isdir(dir):
msglogger.info("Error: Directory to save checkpoint doesn't exist - {0}".format(os.path.abspath(dir))) msglogger.info("Error: Directory to save checkpoint doesn't exist - {0}".format(os.path.abspath(dir)))
exit(1) exit(1)
filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar' filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar'
fullpath = os.path.join(dir, filename) fullpath = os.path.join(dir, filename)
msglogger.info("Saving checkpoint to: %s" % fullpath)
filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar' filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
fullpath_best = os.path.join(dir, filename_best) fullpath_best = os.path.join(dir, filename_best)
checkpoint = {} checkpoint = {}
checkpoint['epoch'] = epoch checkpoint['epoch'] = epoch
checkpoint['arch'] = arch checkpoint['arch'] = arch
checkpoint['state_dict'] = model.state_dict() checkpoint['state_dict'] = model.state_dict()
if best_top1 is not None: if best_top1 is not None:
checkpoint['best_top1'] = best_top1 checkpoint['best_top1'] = best_top1
...@@ -85,6 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): ...@@ -85,6 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
if os.path.isfile(chkpt_file): if os.path.isfile(chkpt_file):
msglogger.info("=> loading checkpoint %s", chkpt_file) msglogger.info("=> loading checkpoint %s", chkpt_file)
checkpoint = torch.load(chkpt_file) checkpoint = torch.load(chkpt_file)
msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
start_epoch = checkpoint['epoch'] + 1 start_epoch = checkpoint['epoch'] + 1
best_top1 = checkpoint.get('best_top1', None) best_top1 = checkpoint.get('best_top1', None)
if best_top1 is not None: if best_top1 is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment