diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 48977795f06b77ad337f1eadaaf26e1be859ba8c..dd3d34817387b36953b6a8108b4f0e8c778beb88 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -58,8 +58,19 @@ def model_summary(model, optimizer, what, dataset=None): elif what == 'optimizer': optimizer_summary(optimizer) elif what == 'model': - print(model) # print the simple form of the model - + # print the simple form of the model + print(model) + elif what == 'modules': + # Print the names of non-leaf modules + # Remember that in PyTorch not every node is a module (e.g. F.relu). + # Also remember that parameterless modules, like nn.MaxPool2d, can be used multiple + # times in the same model, but they will only appear once in the modules list. + nodes = [] + for name, module in model.named_modules(): + # Only print leaf modules + if len(module._modules) == 0: + nodes.append([name, module.__class__.__name__]) + print(tabulate(nodes, headers=['Name', 'Type'])) def optimizer_summary(optimizer): assert isinstance(optimizer, torch.optim.SGD) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 326ce2b35140af280952da11ec1cd47627217444..26bedf5a49d33474ea27a5500796ccf6d95b1261 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -115,7 +115,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true', help='collect activation statistics (WARNING: this slows down training)') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)') -SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'png'] +SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png'] parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES, help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES))