diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py index c88962b074eacd853bf45cc67116f2d839f27f6d..344a82b6744e798dfb874de8fc744eafe162e0bd 100755 --- a/examples/classifier_compression/inspect_ckpt.py +++ b/examples/classifier_compression/inspect_ckpt.py @@ -29,6 +29,15 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule import torch import argparse from tabulate import tabulate +import sys +import os +script_dir = os.path.dirname(__file__) +module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) +try: + import distiller +except ImportError: + sys.path.append(module_path) + import distiller def inspect_checkpoint(chkpt_file, args): @@ -55,11 +64,15 @@ def inspect_checkpoint(chkpt_file, args): print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join( list(compression_sched["masks_dict"].keys())))) + if args.thinning and "thinning_recipes" in checkpoint: + for recipe in checkpoint["thinning_recipes"]: + print(recipe) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Distiller checkpoint inspection') parser.add_argument('chkpt_file', help='path to the checkpoint file') parser.add_argument('-m', '--model', action='store_true', help='print the model keys') parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys') + parser.add_argument('-t', '--thinning', action='store_true', help='print the thinning keys') args = parser.parse_args() inspect_checkpoint(args.chkpt_file, args)