From 6a940466acd33676b76b5e60180f0298430be38f Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 15 May 2018 17:09:49 +0300
Subject: [PATCH] New summary option: print module names

This is a niche feature that lets you print, from the command line, the
names of the modules in a model.  Non-leaf nodes are excluded from this
list.  Other caveats are documented in the code.
---
 distiller/model_summaries.py                        | 15 +++++++++++++--
 .../classifier_compression/compress_classifier.py   |  2 +-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 4897779..dd3d348 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -58,8 +58,19 @@ def model_summary(model, optimizer, what, dataset=None):
     elif what == 'optimizer':
         optimizer_summary(optimizer)
     elif what == 'model':
-        print(model) # print the simple form of the model
-
+        # print the simple form of the model
+        print(model)
+    elif what == 'modules':
+        # Print the names of leaf modules
+        # Remember that in PyTorch not every node is a module (e.g. F.relu).
+        # Also remember that parameterless modules, like nn.MaxPool2d, can be used multiple
+        # times in the same model, but they will only appear once in the modules list.
+        nodes = []
+        for name, module in model.named_modules():
+            # Only print leaf modules
+            if len(module._modules) == 0:
+                nodes.append([name, module.__class__.__name__])
+        print(tabulate(nodes, headers=['Name', 'Type']))
 
 def optimizer_summary(optimizer):
     assert isinstance(optimizer, torch.optim.SGD)
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 326ce2b..26bedf5 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -115,7 +115,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true',
                     help='collect activation statistics (WARNING: this slows down training)')
 parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                     help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
-SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'png']
+SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png']
 parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                     help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES))
-- 
GitLab
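
The following is a minimal, self-contained sketch (not part of the patch) of what the new
'modules' summary does: it walks model.named_modules() and keeps only leaf modules, i.e.
modules that have no children.  The toy nn.Sequential model below is an assumption made
purely for illustration; in Distiller the same traversal runs over the model chosen in
compress_classifier.py and is triggered with --summary=modules.

    import torch.nn as nn
    from tabulate import tabulate   # same helper the patch uses for printing

    # A stand-in model; any nn.Module works the same way.
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Sequential(nn.Conv2d(16, 32, kernel_size=3), nn.ReLU()),
    )

    nodes = []
    for name, module in model.named_modules():
        if len(module._modules) == 0:  # a leaf module has no child modules
            nodes.append([name, module.__class__.__name__])

    print(tabulate(nodes, headers=['Name', 'Type']))

The root nn.Sequential and the nested nn.Sequential container are skipped; only the
Conv2d/ReLU/MaxPool2d leaves are listed, which matches the behavior that the hunk
above adds to model_summary().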