diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 36cee2edcee1c0734712810d90970ed28e1c4d85..4a9234ce7e34c02a50fdaa41c850d4f5e8aa2390 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -143,6 +143,8 @@ parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help=' parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1, help='Portion of training dataset to set aside for validation') parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK') +parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK') +parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', help='Display the confusion matrix') def check_pytorch_version(): @@ -209,6 +211,7 @@ def main(): # Infer the dataset from the model name args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet' + args.num_classes = 10 if args.dataset == 'cifar10' else 1000 # Create the model model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus) @@ -232,19 +235,24 @@ def main(): msglogger.info('Optimizer Args: %s', optimizer.defaults) if args.ADC: - HAVE_GYM_INSTALLED = False - if not HAVE_GYM_INSTALLED: + import examples.automated_deep_compression.ADC as ADC + HAVE_COACH_INSTALLED = True + if not HAVE_COACH_INSTALLED: raise ValueError("ADC is currently experimental and uses non-public Coach features") - import examples.automated_deep_compression.ADC as ADC train_loader, val_loader, test_loader, _ = apputils.load_data( args.dataset, os.path.expanduser(args.data), args.batch_size, args.workers, args.validation_size, args.deterministic) + args.display_confusion = True validate_fn = partial(validate, val_loader=test_loader, criterion=criterion, - loggers=[pylogger], print_freq=args.print_freq) + loggers=[pylogger], args=args) + + if args.ADC_params is not None: + ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn) + exit() - save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, name='adc') + save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir) ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn) exit() @@ -278,7 +286,7 @@ def main(): # model. The ouptut is saved to CSV and PNG. msglogger.info("Running sensitivity tests") test_fnc = partial(test, test_loader=test_loader, criterion=criterion, - loggers=[pylogger], print_freq=args.print_freq) + loggers=[pylogger], args=args) which_params = [param_name for param_name, _ in model.named_parameters()] sensitivity = distiller.perform_sensitivity_analysis(model, net_params=which_params, @@ -300,7 +308,7 @@ def main(): quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8) quantizer.prepare_model() model.cuda() - top1, _, _ = test(test_loader, model, criterion, [pylogger], args.print_freq) + top1, _, _ = test(test_loader, model, criterion, [pylogger], args=args) if args.quantize: checkpoint_name = 'quantized' apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, @@ -315,8 +323,6 @@ def main(): # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer) model.cuda() - best_epoch = start_epoch - for epoch in range(start_epoch, start_epoch + args.epochs): # This is the main training loop. msglogger.info('\n') @@ -325,14 +331,14 @@ def main(): # Train for one epoch train(train_loader, model, criterion, optimizer, epoch, compression_scheduler, - loggers=[tflogger, pylogger], print_freq=args.print_freq, log_params_hist=args.log_params_histograms) + loggers=[tflogger, pylogger], args=args) distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger]) if args.activation_stats: distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger], collector=activations_sparsity) # evaluate on validation set - top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args.print_freq, epoch) + top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch) stats = ('Peformance/Validation/', OrderedDict([('Loss', vloss), ('Top1', top1), @@ -353,11 +359,11 @@ def main(): args.name, msglogger.logdir) # Finally run results on the test set - test(test_loader, model, criterion, [pylogger], args.print_freq) + test(test_loader, model, criterion, [pylogger], args=args) def train(train_loader, model, criterion, optimizer, epoch, - compression_scheduler, loggers, print_freq, log_params_hist): + compression_scheduler, loggers, args): """Training loop for one epoch.""" losses = {'objective_loss': tnt.AverageValueMeter(), 'regularizer_loss': tnt.AverageValueMeter()} @@ -413,7 +419,7 @@ def train(train_loader, model, criterion, optimizer, epoch, batch_time.add(time.time() - end) steps_completed = (train_step+1) - if steps_completed % print_freq == 0: + if steps_completed % args.print_freq == 0: # Log some statistics lr = optimizer.param_groups[0]['lr'] stats = ('Peformance/Training/', @@ -425,36 +431,39 @@ def train(train_loader, model, criterion, optimizer, epoch, ('LR', lr), ('Time', batch_time.mean)])) + params = model.named_parameters() if args.log_params_histograms else None distiller.log_training_progress(stats, - model.named_parameters() if log_params_hist else None, + params, epoch, steps_completed, - steps_per_epoch, print_freq, + steps_per_epoch, args.print_freq, loggers) end = time.time() -def validate(val_loader, model, criterion, loggers, print_freq, epoch=-1): +def validate(val_loader, model, criterion, loggers, args, epoch=-1): """Model validation""" if epoch > -1: msglogger.info('--- validate (epoch=%d)-----------', epoch) else: msglogger.info('--- validate ---------------------') - return _validate(val_loader, model, criterion, loggers, print_freq, epoch) + return _validate(val_loader, model, criterion, loggers, args, epoch) -def test(test_loader, model, criterion, loggers, print_freq): +def test(test_loader, model, criterion, loggers, args): """Model Test""" msglogger.info('--- test ---------------------') - return _validate(test_loader, model, criterion, loggers, print_freq) + return _validate(test_loader, model, criterion, loggers, args) -def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): +def _validate(data_loader, model, criterion, loggers, args, epoch=-1): """Execute the validation/test loop.""" losses = {'objective_loss': tnt.AverageValueMeter()} classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)) batch_time = tnt.AverageValueMeter() total_samples = len(data_loader.sampler) batch_size = data_loader.batch_size + if args.display_confusion: + confusion = tnt.ConfusionMeter(args.num_classes) total_steps = total_samples / batch_size msglogger.info('%d samples (%d per mini-batch)', total_samples, batch_size) @@ -475,22 +484,27 @@ def _validate(data_loader, model, criterion, loggers, print_freq, epoch=-1): # measure accuracy and record loss losses['objective_loss'].add(loss.item()) classerr.add(output.data, target) + if args.display_confusion: + confusion.add(output.data, target) # measure elapsed time batch_time.add(time.time() - end) end = time.time() steps_completed = (validation_step+1) - if steps_completed % print_freq == 0: + if steps_completed % args.print_freq == 0: stats = ('', OrderedDict([('Loss', losses['objective_loss'].mean), ('Top1', classerr.value(1)), ('Top5', classerr.value(5))])) distiller.log_training_progress(stats, None, epoch, steps_completed, - total_steps, print_freq, loggers) + total_steps, args.print_freq, loggers) msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n', classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean) + + if args.display_confusion: + msglogger.info('==> Confusion:\n%s', str(confusion.value())) return classerr.value(1), classerr.value(5), losses['objective_loss'].mean