From 14531cbfbe3ba86ad30e1b049c7d03cd887d7b5c Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 4 Oct 2018 15:32:38 +0300 Subject: [PATCH] inspect_ckpt.py: temporary fix Temporary fix for dependency on distiller class hierarchy when serializing a model that contains a thinning recipe. --- examples/classifier_compression/inspect_ckpt.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py index c88962b..344a82b 100755 --- a/examples/classifier_compression/inspect_ckpt.py +++ b/examples/classifier_compression/inspect_ckpt.py @@ -29,6 +29,15 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule import torch import argparse from tabulate import tabulate +import sys +import os +script_dir = os.path.dirname(__file__) +module_path = os.path.abspath(os.path.join(script_dir, '..', '..')) +try: + import distiller +except ImportError: + sys.path.append(module_path) + import distiller def inspect_checkpoint(chkpt_file, args): @@ -55,11 +64,15 @@ def inspect_checkpoint(chkpt_file, args): print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join( list(compression_sched["masks_dict"].keys())))) + if args.thinning and "thinning_recipes" in checkpoint: + for recipe in checkpoint["thinning_recipes"]: + print(recipe) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Distiller checkpoint inspection') parser.add_argument('chkpt_file', help='path to the checkpoint file') parser.add_argument('-m', '--model', action='store_true', help='print the model keys') parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys') + parser.add_argument('-t', '--thinning', action='store_true', help='print the thinning keys') args = parser.parse_args() inspect_checkpoint(args.chkpt_file, args) -- GitLab