Commit 11402988 authored by Neta Zmora

compress_classifier.py: add an option to load a model in serialized mode

By default, when we create a model we wrap it with DataParallel to benefit
from data-parallelism across GPUs (mainly for convolution layers).

But sometimes we don't want the sample application to do this: for
example, when we receive a model that was trained serially.
This commit adds a new command-line argument to the application that
prevents the use of DataParallel.
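A minimal sketch of the wrapping decision this flag controls; `wrap_model` is a hypothetical helper for illustration, not Distiller's actual `create_model`:

```python
import torch
import torch.nn as nn

def wrap_model(model, parallel=True, device_ids=None):
    """Optionally wrap a model in DataParallel.

    When `parallel` is False (e.g. the model was trained serially),
    the model is returned as-is, so its state_dict keys carry no
    'module.' prefix.
    """
    if parallel and torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=device_ids)
    return model

# A serially-trained model stays unwrapped:
serial = wrap_model(nn.Linear(4, 2), parallel=False)
print(type(serial).__name__)  # Linear
```

This matters in practice because `DataParallel` prefixes every parameter name with `module.`, so loading a serially-saved checkpoint into a wrapped model (or vice versa) fails on mismatched state_dict keys.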
parent 3876a912
No related branches found
No related tags found
No related merge requests found
@@ -152,7 +152,9 @@ parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='early
                     help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)')
 parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
                     help='number of best scores to track and report (default: 1)')
+parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
+                    help='Load a model without DataParallel wrapping it')
 quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
                                        '("post-training quantization)')
 quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
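The new flag behaves like any `store_true` argument; a standalone sketch with a throwaway parser built just for the demo:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--load-serialized', dest='load_serialized',
                    action='store_true', default=False,
                    help='Load a model without DataParallel wrapping it')

# Flag absent: defaults to False, so the model gets DataParallel wrapping.
assert parser.parse_args([]).load_serialized is False
# Flag present: the sample application skips the DataParallel wrapping.
assert parser.parse_args(['--load-serialized']).load_serialized is True
```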
@@ -289,7 +291,8 @@ def main():
     args.exiterrors = []
     # Create the model
-    model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus)
+    model = create_model(args.pretrained, args.dataset, args.arch,
+                         parallel=not args.load_serialized, device_ids=args.gpus)
     compression_scheduler = None
     # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
     # that can be read by Google's Tensor Board.  PythonLogger writes to the Python logger.
@@ -666,7 +669,7 @@ def earlyexit_validate_loss(output, target, criterion, args):
     this_batch_size = target.size()[0]
     earlyexit_validate_criterion = nn.CrossEntropyLoss(reduction='none').cuda()
     for exitnum in range(args.num_exits):
-        # calculate losses at each sample separately in the minibatch.
+        # calculate losses at each sample separately in the minibatch.
         args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
         # for batch_size > 1, we need to reduce this down to an average over the batch
         args.losses_exits[exitnum].add(torch.mean(args.loss_exits[exitnum]))
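The per-sample loss computation in the hunk above can be sketched without the early-exit bookkeeping; this is a CPU-only reduction of the idea (the application itself calls `.cuda()` on the criterion):

```python
import torch
import torch.nn as nn

batch_size, num_classes = 4, 10
output = torch.randn(batch_size, num_classes)           # logits from one exit
target = torch.randint(0, num_classes, (batch_size,))

# reduction='none' keeps one loss value per sample in the minibatch ...
per_sample = nn.CrossEntropyLoss(reduction='none')(output, target)
assert per_sample.shape == (batch_size,)

# ... which is then reduced to a batch average for tracking,
# matching the default reduction='mean' behavior.
batch_avg = torch.mean(per_sample)
assert torch.isclose(batch_avg, nn.CrossEntropyLoss()(output, target))
```

Keeping the per-sample losses around (rather than only the mean) is what lets early-exit validation compare each sample's loss against its exit threshold individually.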