Skip to content
Snippets Groups Projects
Commit 14531cbf authored by Neta Zmora's avatar Neta Zmora
Browse files

inspect_ckpt.py: temporary fix

Temporary fix for dependency on distiller class hierarchy when
serializing a model that contains a thinning recipe.
parent 7a97f229
No related branches found
No related tags found
No related merge requests found
...@@ -29,6 +29,15 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule ...@@ -29,6 +29,15 @@ $ python3 inspect_ckpt.py checkpoint.pth.tar --model --schedule
import torch import torch
import argparse import argparse
from tabulate import tabulate from tabulate import tabulate
import sys
import os
script_dir = os.path.dirname(__file__)
module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
try:
import distiller
except ImportError:
sys.path.append(module_path)
import distiller
def inspect_checkpoint(chkpt_file, args): def inspect_checkpoint(chkpt_file, args):
...@@ -55,11 +64,15 @@ def inspect_checkpoint(chkpt_file, args): ...@@ -55,11 +64,15 @@ def inspect_checkpoint(chkpt_file, args):
print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join( print("compression_sched[\"masks_dict\"] keys:\n{}".format(", ".join(
list(compression_sched["masks_dict"].keys())))) list(compression_sched["masks_dict"].keys()))))
if args.thinning and "thinning_recipes" in checkpoint:
for recipe in checkpoint["thinning_recipes"]:
print(recipe)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Distiller checkpoint inspection') parser = argparse.ArgumentParser(description='Distiller checkpoint inspection')
parser.add_argument('chkpt_file', help='path to the checkpoint file') parser.add_argument('chkpt_file', help='path to the checkpoint file')
parser.add_argument('-m', '--model', action='store_true', help='print the model keys') parser.add_argument('-m', '--model', action='store_true', help='print the model keys')
parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys') parser.add_argument('-s', '--schedule', action='store_true', help='print the schedule keys')
parser.add_argument('-t', '--thinning', action='store_true', help='print the thinning keys')
args = parser.parse_args() args = parser.parse_args()
inspect_checkpoint(args.chkpt_file, args) inspect_checkpoint(args.chkpt_file, args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment