From 14531cbfbe3ba86ad30e1b049c7d03cd887d7b5c Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 4 Oct 2018 15:32:38 +0300
Subject: [PATCH] inspect_ckpt.py: temporary fix

Temporary fix for dependency on distiller class hierarchy when
serializing a model that contains a thinning recipe.
---
 examples/classifier_compression/inspect_ckpt.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py
index c88962b..344a82b 100755
--- a/examples/classifier_compression/inspect_ckpt.py
+++ b/examples/classifier_compression/inspect_ckpt.py
@@ -29,6 +29,15 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule
 import torch
 import argparse
 from tabulate import tabulate
+import sys
+import os
+script_dir = os.path.dirname(__file__)
+module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
+try:
+    import distiller
+except ImportError:
+    sys.path.append(module_path)
+    import distiller
 
 
 def inspect_checkpoint(chkpt_file, args):
@@ -55,11 +64,15 @@ def inspect_checkpoint(chkpt_file, args):
             print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join(
                   list(compression_sched["masks_dict"].keys()))))
 
+    if args.thinning and "thinning_recipes" in checkpoint:
+        for recipe in checkpoint["thinning_recipes"]:
+            print(recipe)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Distiller checkpoint inspection')
     parser.add_argument('chkpt_file', help='path to the checkpoint file')
     parser.add_argument('-m', '--model', action='store_true', help='print the model keys')
     parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys')
+    parser.add_argument('-t', '--thinning', action='store_true', help='print the thinning keys')
     args = parser.parse_args()
     inspect_checkpoint(args.chkpt_file, args)
-- 
GitLab