diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 6c5ab5903b45d1e677e590daa7c39fa7045d7279..bf228801163df2d09cae1989a93b7c0d8b2d78ee 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -152,7 +152,9 @@ parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='early help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)') parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int, help='number of best scores to track and report (default: 1)') - +parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False, + help='Load a model without DataParallel wrapping it') + quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time' '("post-training quantization)') quant_group.add_argument('--quantize-eval', '--qe', action='store_true', @@ -289,7 +291,8 @@ def main(): args.exiterrors = [] # Create the model - model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus) + model = create_model(args.pretrained, args.dataset, args.arch, + parallel=not args.load_serialized, device_ids=args.gpus) compression_scheduler = None # Create a couple of logging backends. TensorBoardLogger writes log files in a format # that can be read by Google's Tensor Board. PythonLogger writes to the Python logger. @@ -666,7 +669,7 @@ def earlyexit_validate_loss(output, target, criterion, args): this_batch_size = target.size()[0] earlyexit_validate_criterion = nn.CrossEntropyLoss(reduction='none').cuda() for exitnum in range(args.num_exits): - # calculate losses at each sample separately in the minibatch. + # calculate losses at each sample separately in the minibatch. args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target) # for batch_size > 1, we need to reduce this down to an average over the batch args.losses_exits[exitnum].add(torch.mean(args.loss_exits[exitnum]))