diff --git a/examples/classifier_compression/inspect_ckpt.py b/examples/classifier_compression/inspect_ckpt.py index 1d149366e57a5499d751b34e39b92b08577f5cc6..727aafd59a58a981992c15e87fafd7f678841267 100755 --- a/examples/classifier_compression/inspect_ckpt.py +++ b/examples/classifier_compression/inspect_ckpt.py @@ -29,7 +29,6 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule import torch import argparse from tabulate import tabulate - import distiller from distiller.apputils.checkpoint import get_contents_table @@ -53,8 +52,11 @@ def inspect_checkpoint(chkpt_file, args): sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()] print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid")) if "masks_dict" in checkpoint["compression_sched"]: - print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join( - list(compression_sched["masks_dict"].keys())))) + masks_dict = compression_sched["masks_dict"] + print("compression_sched[\"masks_dict\"] keys:\n{}".format(list(masks_dict.keys()))) + mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items() + if mask is not None] + print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid")) if args.thinning and "thinning_recipes" in checkpoint: for recipe in checkpoint["thinning_recipes"]: