diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py index 727aafd59a58a981992c15e87fafd7f678841267..7d6d05a28eb34e4a5be895ed75d55ff02bfc6f6e 100755 --- a/examples/classifier_compression/inspect_ckpt.py +++ b/examples/classifier_compression/inspect_ckpt.py @@ -33,15 +33,26 @@ import distiller from distiller.apputils.checkpoint import get_contents_table +def print_sparsities(masks_dict): + mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items() + if mask is not None] + print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid")) + + def inspect_checkpoint(chkpt_file, args): print("Inspecting checkpoint file: ", chkpt_file) - checkpoint = torch.load(chkpt_file) + # force loading on the CPU which always has more memory than the GPU(s) + checkpoint = torch.load(chkpt_file, map_location='cpu') print(get_contents_table(checkpoint)) if 'extras' in checkpoint and checkpoint['extras']: print("\nContents of Checkpoint['extras']:") print(get_contents_table(checkpoint['extras'])) + try: + print_sparsities(checkpoint["extras"]["creation_masks"]) + except KeyError: + pass if args.model and "state_dict" in checkpoint: print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys())))) @@ -51,12 +62,12 @@ def inspect_checkpoint(chkpt_file, args): print("\nSchedule keys (compression_sched):\n{}\n".format("\n\t".join(list(compression_sched.keys())))) sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()] print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid")) - if "masks_dict" in checkpoint["compression_sched"]: + try: masks_dict = compression_sched["masks_dict"] print("compression_sched[\"masks_dict\"] keys:\n{}".format(list(masks_dict.keys()))) - mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items() - if mask is not None] - print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid")) + print_sparsities(masks_dict) + except KeyError: + pass if args.thinning and "thinning_recipes" in checkpoint: for recipe in checkpoint["thinning_recipes"]: