From 31c1bd8975710798d4a88008e91a05933925e9e8 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 23 Oct 2019 21:29:45 +0300
Subject: [PATCH] inspect_ckpt.py: support for very large models

Force loading on the CPU which always has more memory than a
single GPU.  This is useful for models that cannot be loaded onto
a single GPU.
---
 .../classifier_compression/inspect_ckpt.py    | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py
index 727aafd..7d6d05a 100755
--- a/examples/classifier_compression/inspect_ckpt.py
+++ b/examples/classifier_compression/inspect_ckpt.py
@@ -33,15 +33,26 @@ import distiller
 from distiller.apputils.checkpoint import get_contents_table
 
 
+def print_sparsities(masks_dict):
+    mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
+                       if mask is not None]
+    print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))
+
+
 def inspect_checkpoint(chkpt_file, args):
     print("Inspecting checkpoint file: ", chkpt_file)
-    checkpoint = torch.load(chkpt_file)
+    # force loading on the CPU which always has more memory than the GPU(s)
+    checkpoint = torch.load(chkpt_file, map_location='cpu')
 
     print(get_contents_table(checkpoint))
 
     if 'extras' in checkpoint and checkpoint['extras']:
         print("\nContents of Checkpoint['extras']:")
         print(get_contents_table(checkpoint['extras']))
+        try:
+            print_sparsities(checkpoint["extras"]["creation_masks"])
+        except KeyError:
+            pass
 
     if args.model and "state_dict" in checkpoint:
         print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys()))))
@@ -51,12 +62,12 @@ def inspect_checkpoint(chkpt_file, args):
         print("\nSchedule keys (compression_sched):\n{}\n".format("\n\t".join(list(compression_sched.keys()))))
         sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()]
         print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid"))
-        if "masks_dict" in checkpoint["compression_sched"]:
+        try:
             masks_dict = compression_sched["masks_dict"]
             print("compression_sched[\"masks_dict\"] keys:\n{}".format(list(masks_dict.keys())))
-            mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
-                               if mask is not None]
-            print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))
+            print_sparsities(masks_dict)
+        except KeyError:
+            pass
 
     if args.thinning and "thinning_recipes" in checkpoint:
         for recipe in checkpoint["thinning_recipes"]:
-- 
GitLab