Skip to content
Snippets Groups Projects
Commit 31c1bd89 authored by Neta Zmora's avatar Neta Zmora
Browse files

inspect_ckpt.py: support for very large models

Force loading on the CPU which always has more memory than a
single GPU.  This is useful for models that cannot be loaded onto
a single GPU.
parent 2c2a9417
No related branches found
No related tags found
No related merge requests found
...@@ -33,15 +33,26 @@ import distiller ...@@ -33,15 +33,26 @@ import distiller
from distiller.apputils.checkpoint import get_contents_table from distiller.apputils.checkpoint import get_contents_table
def print_sparsities(masks_dict):
mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items()
if mask is not None]
print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid"))
def inspect_checkpoint(chkpt_file, args): def inspect_checkpoint(chkpt_file, args):
print("Inspecting checkpoint file: ", chkpt_file) print("Inspecting checkpoint file: ", chkpt_file)
checkpoint = torch.load(chkpt_file) # force loading on the CPU which always has more memory than the GPU(s)
checkpoint = torch.load(chkpt_file, map_location='cpu')
print(get_contents_table(checkpoint)) print(get_contents_table(checkpoint))
if 'extras' in checkpoint and checkpoint['extras']: if 'extras' in checkpoint and checkpoint['extras']:
print("\nContents of Checkpoint['extras']:") print("\nContents of Checkpoint['extras']:")
print(get_contents_table(checkpoint['extras'])) print(get_contents_table(checkpoint['extras']))
try:
print_sparsities(checkpoint["extras"]["creation_masks"])
except KeyError:
pass
if args.model and "state_dict" in checkpoint: if args.model and "state_dict" in checkpoint:
print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys())))) print("\nModel keys (state_dict):\n{}".format(", ".join(list(checkpoint["state_dict"].keys()))))
...@@ -51,12 +62,12 @@ def inspect_checkpoint(chkpt_file, args): ...@@ -51,12 +62,12 @@ def inspect_checkpoint(chkpt_file, args):
print("\nSchedule keys (compression_sched):\n{}\n".format("\n\t".join(list(compression_sched.keys())))) print("\nSchedule keys (compression_sched):\n{}\n".format("\n\t".join(list(compression_sched.keys()))))
sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()] sched_keys = [[k, type(compression_sched[k]).__name__] for k in compression_sched.keys()]
print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid")) print(tabulate(sched_keys, headers=["Key", "Type"], tablefmt="fancy_grid"))
if "masks_dict" in checkpoint["compression_sched"]: try:
masks_dict = compression_sched["masks_dict"] masks_dict = compression_sched["masks_dict"]
print("compression_sched[\"masks_dict\"] keys:\n{}".format(list(masks_dict.keys()))) print("compression_sched[\"masks_dict\"] keys:\n{}".format(list(masks_dict.keys())))
mask_sparsities = [(param_name, distiller.sparsity(mask)) for param_name, mask in masks_dict.items() print_sparsities(masks_dict)
if mask is not None] except KeyError:
print(tabulate(mask_sparsities, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid")) pass
if args.thinning and "thinning_recipes" in checkpoint: if args.thinning and "thinning_recipes" in checkpoint:
for recipe in checkpoint["thinning_recipes"]: for recipe in checkpoint["thinning_recipes"]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment