From e46e196e482d6e9bb4b3aa60fe0869e4422da5a2 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 25 Jul 2018 13:01:40 +0300 Subject: [PATCH] compress_classifier.py: changed the signature of validate() and test() Due to the various uses of these functions, we need to pass an ever-growing number of arguments to these functions and the API is becoming bloated and unstable. Also added the option to log the confusion matrix. --- .../compress_classifier.py | 60 ++++++++++++------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 36cee2e..4a9234c 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -143,6 +143,8 @@ parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help=' parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1, help='Portion of training dataset to set aside for validation') parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK') +parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK') +parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', help='Display the confusion matrix') def check_pytorch_version(): @@ -209,6 +211,7 @@ def main(): # Infer the dataset from the model name args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet' + args.num_classes = 10 if args.dataset == 'cifar10' else 1000 # Create the model model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus) @@ -232,19 +235,24 @@ def main(): msglogger.info('Optimizer Args: %s', optimizer.defaults) if args.ADC: - HAVE_GYM_INSTALLED = False - if not HAVE_GYM_INSTALLED: + import examples.automated_deep_compression.ADC as ADC + HAVE_COACH_INSTALLED = True + if not HAVE_COACH_INSTALLED: raise 
ValueError("ADC is currently experimental and uses non-public Coach features") - import examples.automated_deep_compression.ADC as ADC train_loader, val_loader, test_loader, _ = apputils.load_data( args.dataset, os.path.expanduser(args.data), args.batch_size, args.workers, args.validation_size, args.deterministic) + args.display_confusion = True validate_fn = partial(validate, val_loader=test_loader, criterion=criterion, - loggers=[pylogger], print_freq=args.print_freq) + loggers=[pylogger], args=args) + + if args.ADC_params is not None: + ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn) + exit() - save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, name='adc') + save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir) ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn) exit() @@ -278,7 +286,7 @@ def main(): # model. The ouptut is saved to CSV and PNG. msglogger.info("Running sensitivity tests") test_fnc = partial(test, test_loader=test_loader, criterion=criterion, - loggers=[pylogger], print_freq=args.print_freq) + loggers=[pylogger], args=args) which_params = [param_name for param_name, _ in model.named_parameters()] sensitivity = distiller.perform_sensitivity_analysis(model, net_params=which_params, @@ -300,7 +308,7 @@ def main(): quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8) quantizer.prepare_model() model.cuda() - top1, _, _ = test(test_loader, model, criterion, [pylogger], args.print_freq) + top1, _, _ = test(test_loader, model, criterion, [pylogger], args=args) if args.quantize: checkpoint_name = 'quantized' apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, @@ -315,8 +323,6 @@ def main(): # Model is re-transferred to GPU in case parameters were added (e.g. 
PACTQuantizer) model.cuda() - best_epoch = start_epoch - for epoch in range(start_epoch, start_epoch + args.epochs): # This is the main training loop. msglogger.info('\n') @@ -325,14 +331,14 @@ def main(): # Train for one epoch train(train_loader, model, criterion, optimizer, epoch, compression_scheduler, - loggers=[tflogger, pylogger], print_freq=args.print_freq, log_params_hist=args.log_params_histograms) + loggers=[tflogger, pylogger], args=args) distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger]) if args.activation_stats: distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger], collector=activations_sparsity) # evaluate on validation set - top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args.print_freq, epoch) + top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch) stats = ('Peformance/Validation/', OrderedDict([('Loss', vloss), ('Top1', top1), @@ -353,11 +359,11 @@ def main(): args.name, msglogger.logdir) # Finally run results on the test set - test(test_loader, model, criterion, [pylogger], args.print_freq) + test(test_loader, model, criterion, [pylogger], args=args) def train(train_loader, model, criterion, optimizer, epoch, - compression_scheduler, loggers, print_freq, log_params_hist): + compression_scheduler, loggers, args): """Training loop for one epoch.""" losses = {'objective_loss': tnt.AverageValueMeter(), 'regularizer_loss': tnt.AverageValueMeter()} @@ -413,7 +419,7 @@ def train(train_loader, model, criterion, optimizer, epoch, batch_time.add(time.time() - end) steps_completed = (train_step+1) - if steps_completed % print_freq == 0: + if steps_completed % args.print_freq == 0: # Log some statistics lr = optimizer.param_groups[0]['lr'] stats = ('Peformance/Training/', @@ -425,36 +431,39 @@ def train(train_loader, model, criterion, optimizer, epoch, ('LR', lr), ('Time', batch_time.mean)])) + params = model.named_parameters() if args.log_params_histograms else 
None distiller.log_training_progress(stats, - model.named_parameters() if log_params_hist else None, + params, epoch, steps_completed, - steps_per_epoch, print_freq, + steps_per_epoch, args.print_freq, loggers) end = time.time() -def validate(val_loader, model, criterion, loggers, print_freq, epoch=-1): +def validate(val_loader, model, criterion, loggers, args, epoch=-1): """Model validation""" if epoch > -1: msglogger.info('--- validate (epoch=%d)-----------', epoch) else: msglogger.info('--- validate ---------------------') - return _validate(val_loader, model, criterion, loggers, print_freq, epoch) + return _validate(val_loader, model, criterion, loggers, args, epoch) -def test(test_loader, model, criterion, loggers, print_freq): +def test(test_loader, model, criterion, loggers, args): """Model Test""" msglogger.info('--- test ---------------------') - return _validate(test_loader, model, criterion, loggers, print_freq) + return _validate(test_loader, model, criterion, loggers, args) -def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): +def _validate(data_loader, model, criterion, loggers, args, epoch=-1): """Execute the validation/test loop.""" losses = {'objective_loss': tnt.AverageValueMeter()} classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)) batch_time = tnt.AverageValueMeter() total_samples = len(data_loader.sampler) batch_size = data_loader.batch_size + if args.display_confusion: + confusion = tnt.ConfusionMeter(args.num_classes) total_steps = total_samples / batch_size msglogger.info('%d samples (%d per mini-batch)', total_samples, batch_size) @@ -475,22 +484,27 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): # measure accuracy and record loss losses['objective_loss'].add(loss.item()) classerr.add(output.data, target) + if args.display_confusion: + confusion.add(output.data, target) # measure elapsed time batch_time.add(time.time() - end) end = time.time() steps_completed = 
(validation_step+1) - if steps_completed % print_freq == 0: + if steps_completed % args.print_freq == 0: stats = ('', OrderedDict([('Loss', losses['objective_loss'].mean), ('Top1', classerr.value(1)), ('Top5', classerr.value(5))])) distiller.log_training_progress(stats, None, epoch, steps_completed, - total_steps, print_freq, loggers) + total_steps, args.print_freq, loggers) msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n', classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean) + + if args.display_confusion: + msglogger.info('==> Confusion:\n%s', str(confusion.value())) return classerr.value(1), classerr.value(5), losses['objective_loss'].mean -- GitLab