diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 0d845debc5912f7e0b4cc762bbc56ec85cf7e4b9..326ce2b35140af280952da11ec1cd47627217444 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -196,7 +196,8 @@ def main(): args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet' # Create the model - model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus) + is_parallel = args.summary != 'png' # For PNG summary, parallel graphs are illegible + model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus) compression_scheduler = None # Create a couple of logging backends. TensorBoardLogger writes log files in a format