From ba11d30737d9c3898e8055e265b88d64c2ff167a Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 12 Aug 2019 01:30:23 +0300
Subject: [PATCH] classifier_compression/inspect_ckpt.py: print mask sparsities

When using flag `-s` which prints the compression scheduler
pruning mask keys, we also print a table with the fine-grain
sparsity of each mask.
---
 examples/classifier_compression/inspect_ckpt.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py
index 1d14936..727aafd 100755
--- a/examples/classifier_compression/inspect_ckpt.py
+++ b/examples/classifier_compression/inspect_ckpt.py
@@ -29,7 +29,6 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule
 import torch
 import argparse
 from tabulate import tabulate
-
 import distiller
 from distiller.apputils.checkpoint import get_contents_table
 
@@ -53,8 +52,11 @@ def inspect_checkpoint(chkpt_file, args):
         sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()]
         print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid"))
         if "masks_dict" in checkpoint["compression_sched"]:
-            print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join(
-                  list(compression_sched["masks_dict"].keys()))))
+            masks_dict = compression_sched["masks_dict"]
+            print("compression_sched[\"masks_dict\"] keys:\n{}".format(list(masks_dict.keys())))
+            mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
+                               if mask is not None]
+            print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))
 
     if args.thinning and "thinning_recipes" in checkpoint:
         for recipe in checkpoint["thinning_recipes"]:
-- 
GitLab