diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index a3c4ae62463c10ec95dc6b1e7fdfb73503c225b9..acfa3a33da9d4e2aa3da786c9e3f65a8f5387572 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -24,6 +24,8 @@ import os
 import shutil
 from errno import ENOENT
 import logging
+from numbers import Number
+from tabulate import tabulate
 import torch
 import distiller
 from distiller.utils import normalize_module_name
@@ -31,34 +33,39 @@ msglogger = logging.getLogger()
 
 
 def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
-                    best_top1=None, is_best=False, name=None, dir='.'):
+                    extras=None, is_best=False, name=None, dir='.'):
     """Save a pytorch training checkpoint
 
     Args:
-        epoch: current epoch
-        arch: name of the network arechitecture/topology
+        epoch: current epoch number
+        arch: name of the network architecture/topology
         model: a pytorch model
         optimizer: the optimizer used in the training session
         scheduler: the CompressionScheduler instance used for training, if any
-        best_top1: the best top1 score seen so far
-        is_best: True if this is the best (top1 accuracy) model so far
+        extras: optional dict with additional user-defined data to be saved in the checkpoint.
+            Will be saved under the key 'extras'
+        is_best: If true, will save a copy of the checkpoint with the suffix 'best'
         name: the name of the checkpoint file
         dir: directory in which to save the checkpoint
     """
     if not os.path.isdir(dir):
         raise IOError(ENOENT, 'Checkpoint directory does not exist at', os.path.abspath(dir))
 
+    if extras is None:
+        extras = {}
+    if not isinstance(extras, dict):
+        raise TypeError('extras must be either a dict or None')
+
     filename = 'checkpoint.pth.tar' if name is None else name + '_checkpoint.pth.tar'
     fullpath = os.path.join(dir, filename)
     msglogger.info("Saving checkpoint to: %s" % fullpath)
     filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
     fullpath_best = os.path.join(dir, filename_best)
+
     checkpoint = {}
     checkpoint['epoch'] = epoch
     checkpoint['arch'] = arch
     checkpoint['state_dict'] = model.state_dict()
-    if best_top1 is not None:
-        checkpoint['best_top1'] = best_top1
     if optimizer is not None:
         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
         checkpoint['optimizer_type'] = type(optimizer)
@@ -69,6 +76,8 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
     if hasattr(model, 'quantizer_metadata'):
         checkpoint['quantizer_metadata'] = model.quantizer_metadata
 
+    checkpoint['extras'] = extras
+
     torch.save(checkpoint, fullpath)
     if is_best:
         shutil.copyfile(fullpath, fullpath_best)
@@ -79,6 +88,19 @@ def load_lean_checkpoint(model, chkpt_file, model_device=None):
                            lean_checkpoint=True)[0]
 
 
+def get_contents_table(d):
+    def inspect_val(val):
+        if isinstance(val, (Number, str)):
+            return val
+        elif isinstance(val, type):
+            return val.__name__
+        return None
+
+    contents = [[k, type(d[k]).__name__, inspect_val(d[k])] for k in d.keys()]
+    contents = sorted(contents, key=lambda entry: entry[0])
+    return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="fancy_grid")
+
+
 def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False):
     """Load a pytorch training checkpoint.
 
@@ -96,7 +118,10 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lea
 
     msglogger.info("=> loading checkpoint %s", chkpt_file)
     checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
-    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
+
+    msglogger.info('=> Checkpoint contents:\n{}\n'.format(get_contents_table(checkpoint)))
+    if 'extras' in checkpoint:
+        msglogger.info("=> Checkpoint['extras'] contents:\n{}\n".format(get_contents_table(checkpoint['extras'])))
 
     if 'state_dict' not in checkpoint:
         raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
@@ -104,10 +129,6 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lea
     checkpoint_epoch = checkpoint.get('epoch', None)
     start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
 
-    best_top1 = checkpoint.get('best_top1', None)
-    if best_top1 is not None:
-        msglogger.info("   best top@1: %.3f", best_top1)
-
     compression_scheduler = None
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
diff --git a/distiller/pruning/greedy_filter_pruning.py b/distiller/pruning/greedy_filter_pruning.py
index 60202ae16327bde20fdb3052478946f60f91429a..2b067d7ba3be1de57f3d233b393e642bcca14628 100755
--- a/distiller/pruning/greedy_filter_pruning.py
+++ b/distiller/pruning/greedy_filter_pruning.py
@@ -296,7 +296,7 @@ def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_
         results = (iteration, prec1, param_name, compute_density, total_macs, densities)
         record_network_details(results)
         scheduler = create_scheduler(pruned_model, zeros_mask_dict)
-        save_checkpoint(0, arch, pruned_model, optimizer=None, best_top1=prec1, scheduler=scheduler,
+        save_checkpoint(0, arch, pruned_model, optimizer=None, scheduler=scheduler, extras={'top1': prec1},
                         name="greedy__{}__{:.1f}__{:.1f}".format(str(iteration).zfill(3), compute_density*100, prec1),
                         dir=msglogger.logdir)
         del scheduler
@@ -307,6 +307,6 @@ def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_
     prec1, prec5, loss = test_fn(model=pruned_model)
     print(prec1, prec5, loss)
     scheduler = create_scheduler(pruned_model, zeros_mask_dict)
-    save_checkpoint(0, arch, pruned_model, optimizer=None, best_top1=prec1, scheduler=scheduler,
+    save_checkpoint(0, arch, pruned_model, optimizer=None, scheduler=scheduler, extras={'top1': prec1},
                     name='_'.join(("greedy", str(fraction_to_prune))), dir=msglogger.logdir)
 
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index ee8d0d53aaffd1bc203c607a42e1b49c095ff3ba..28a715ea4a2ad3b1d2632a2ad8ca36cfe8df6a5e 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -241,10 +241,12 @@ def main():
 
     if args.thinnify:
         #zeros_mask_dict = distiller.create_model_masks_dict(model)
-        assert args.resumed_checkpoint_path is not None, "You must use --resume-from to provide a checkpoint file to thinnify"
+        assert args.resumed_checkpoint_path is not None, \
+            "You must use --resume-from to provide a checkpoint file to thinnify"
         distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None)
         apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler,
-                                 name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")), dir=msglogger.logdir)
+                                 name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")),
+                                 dir=msglogger.logdir)
         print("Note: your model may have collapsed to random inference, so you may want to fine-tune")
         return
 
@@ -307,8 +309,11 @@ def main():
         # Update the list of top scores achieved so far, and save the checkpoint
         update_training_scores_history(perf_scores_history, model, top1, top5, epoch, args.num_best_scores)
         is_best = epoch == perf_scores_history[0].epoch
-        apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler,
-                                 perf_scores_history[0].top1, is_best, args.name, msglogger.logdir)
+        checkpoint_extras = {'current_top1': top1,
+                             'best_top1': perf_scores_history[0].top1,
+                             'best_epoch': perf_scores_history[0].epoch}
+        apputils.save_checkpoint(epoch, args.arch, model, optimizer=optimizer, scheduler=compression_scheduler,
+                                 extras=checkpoint_extras, is_best=is_best, name=args.name, dir=msglogger.logdir)
 
     # Finally run results on the test set
     test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
@@ -640,9 +645,9 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
 
     if args.quantize_eval:
         checkpoint_name = 'quantized'
-        apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, scheduler=scheduler,
+        apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=scheduler,
                                  name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
-                                 dir=msglogger.logdir)
+                                 dir=msglogger.logdir, extras={'quantized_top1': top1})
 
 
 def summarize_model(model, dataset, which_summary):
diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py
index 344a82b6744e798dfb874de8fc744eafe162e0bd..1d149366e57a5499d751b34e39b92b08577f5cc6 100755
--- a/examples/classifier_compression/inspect_ckpt.py
+++ b/examples/classifier_compression/inspect_ckpt.py
@@ -29,28 +29,20 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule
 """
 import torch
 import argparse
 from tabulate import tabulate
-import sys
-import os
-script_dir = os.path.dirname(__file__)
-module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
-try:
-    import distiller
-except ImportError:
-    sys.path.append(module_path)
-    import distiller
+import distiller
+from distiller.apputils.checkpoint import get_contents_table
 
 
-def inspect_checkpoint(chkpt_file, args):
-    def inspect_val(val):
-        if isinstance(val, (int, float, str)):
-            return val
-        return None
-
+
+def inspect_checkpoint(chkpt_file, args):
     print("Inspecting checkpoint file: ", chkpt_file)
     checkpoint = torch.load(chkpt_file)
-    chkpt_keys = [[k, type(checkpoint[k]).__name__, inspect_val(checkpoint[k])] for k in checkpoint.keys()]
-    print(tabulate(chkpt_keys, headers=["Key", "Type", "Value"], tablefmt="fancy_grid"))
+    print(get_contents_table(checkpoint))
+
+    if 'extras' in checkpoint and checkpoint['extras']:
+        print("\nContents of Checkpoint['extras']:")
+        print(get_contents_table(checkpoint['extras']))
 
     if args.model and "state_dict" in checkpoint:
         print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys()))))
@@ -68,6 +60,7 @@ def inspect_checkpoint(chkpt_file, args):
         for recipe in checkpoint["thinning_recipes"]:
             print(recipe)
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Distiller checkpoint inspection')
     parser.add_argument('chkpt_file', help='path to the checkpoint file')
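
Usage sketch for the new `extras` API introduced by this patch. The toy model, the arch string 'toy_net', and the metric values are illustrative assumptions, not part of the patch; the file name 'checkpoint.pth.tar' is the default produced when name=None:

    import torch
    import torch.nn as nn
    from distiller.apputils.checkpoint import (save_checkpoint, load_lean_checkpoint,
                                               get_contents_table)

    # Stand-in network and optimizer; any nn.Module works (assumption, for illustration).
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # User-defined data now travels in the 'extras' dict (stored under
    # checkpoint['extras']), replacing the old fixed 'best_top1' argument.
    save_checkpoint(epoch=0, arch='toy_net', model=model, optimizer=optimizer,
                    extras={'current_top1': 71.2, 'best_top1': 72.0}, dir='.')

    # get_contents_table() renders the same fancy_grid summary that
    # load_checkpoint() logs and inspect_ckpt.py prints.
    checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
    print(get_contents_table(checkpoint))
    print(get_contents_table(checkpoint['extras']))

    # Restore just the model weights from the saved file.
    model = load_lean_checkpoint(model, 'checkpoint.pth.tar')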