diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 8db49acb366b83de9c5323c82f38164b44ff810a..a6a859158b33a454fe7932f22e21a017675634ea 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -115,7 +115,6 @@ class ClassifierCompressor(object): def validate_one_epoch(self, epoch, verbose=True): """Evaluate on validation set""" self.load_datasets() - with collectors_context(self.activations_collectors["valid"]) as collectors: top1, top5, vloss = validate(self.val_loader, self.model, self.criterion, [self.pylogger], self.args, epoch) @@ -128,9 +127,8 @@ class ClassifierCompressor(object): OrderedDict([('Loss', vloss), ('Top1', top1), ('Top5', top5)])) - distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, - loggers=[self.tflogger]) - + distiller.log_training_progress(stats, None, epoch, steps_completed=0, + total_steps=1, log_freq=1, loggers=[self.tflogger]) return top1, top5, vloss def _finalize_epoch(self, epoch, perf_scores_history, top1, top5): @@ -284,7 +282,6 @@ def init_classifier_compression_arg_parser(): help='Load a model without DataParallel wrapping it') parser.add_argument('--thinnify', dest='thinnify', action='store_true', default=False, help='physically remove zero-filters and create a smaller model') - distiller.quantization.add_post_train_quant_args(parser) return parser @@ -490,6 +487,34 @@ def train(train_loader, model, criterion, optimizer, epoch, optimizer.step() compression_scheduler.on_minibatch_end(epoch) """ + def _log_training_progress(): + # Log some statistics + errs = OrderedDict() + if not early_exit_mode(args): + errs['Top1'] = classerr.value(1) + errs['Top5'] = classerr.value(5) + else: + # for Early Exit case, the Top1 and Top5 stats are computed for each exit. + for exitnum in range(args.num_exits): + errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1) + errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5) + + stats_dict = OrderedDict() + for loss_name, meter in losses.items(): + stats_dict[loss_name] = meter.mean + stats_dict.update(errs) + stats_dict['LR'] = optimizer.param_groups[0]['lr'] + stats_dict['Time'] = batch_time.mean + stats = ('Performance/Training/', stats_dict) + + params = model.named_parameters() if args.log_params_histograms else None + distiller.log_training_progress(stats, + params, + epoch, steps_completed, + steps_per_epoch, args.print_freq, + loggers) + + OVERALL_LOSS_KEY = 'Overall Loss' OBJECTIVE_LOSS_KEY = 'Objective Loss' @@ -570,31 +595,8 @@ def train(train_loader, model, criterion, optimizer, epoch, steps_completed = (train_step+1) if steps_completed % args.print_freq == 0: - # Log some statistics - errs = OrderedDict() - if not early_exit_mode(args): - errs['Top1'] = classerr.value(1) - errs['Top5'] = classerr.value(5) - else: - # for Early Exit case, the Top1 and Top5 stats are computed for each exit. - for exitnum in range(args.num_exits): - errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1) - errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5) + _log_training_progress() - stats_dict = OrderedDict() - for loss_name, meter in losses.items(): - stats_dict[loss_name] = meter.mean - stats_dict.update(errs) - stats_dict['LR'] = optimizer.param_groups[0]['lr'] - stats_dict['Time'] = batch_time.mean - stats = ('Performance/Training/', stats_dict) - - params = model.named_parameters() if args.log_params_histograms else None - distiller.log_training_progress(stats, - params, - epoch, steps_completed, - steps_per_epoch, args.print_freq, - loggers) end = time.time() #return acc_stats # NOTE: this breaks previous behavior, which returned a history of (top1, top5) values @@ -628,6 +630,29 @@ def _is_earlyexit(args): def _validate(data_loader, model, criterion, loggers, args, epoch=-1): + def _log_validation_progress(): + if not _is_earlyexit(args): + stats_dict = OrderedDict([('Loss', losses['objective_loss'].mean), + ('Top1', classerr.value(1)), + ('Top5', classerr.value(5))]) + else: + stats_dict = OrderedDict() + stats_dict['Test'] = validation_step + for exitnum in range(args.num_exits): + la_string = 'LossAvg' + str(exitnum) + stats_dict[la_string] = args.losses_exits[exitnum].mean + # Because of the nature of ClassErrorMeter, if an exit is never taken during the batch, + # then accessing the value(k) will cause a divide by zero. So we'll build the OrderedDict + # accordingly and we will not print for an exit error when that exit is never taken. + if args.exit_taken[exitnum]: + t1 = 'Top1_exit' + str(exitnum) + t5 = 'Top5_exit' + str(exitnum) + stats_dict[t1] = args.exiterrors[exitnum].value(1) + stats_dict[t5] = args.exiterrors[exitnum].value(5) + stats = ('Performance/Validation/', stats_dict) + distiller.log_training_progress(stats, None, epoch, steps_completed, + total_steps, args.print_freq, loggers) + """Execute the validation/test loop.""" losses = {'objective_loss': tnt.AverageValueMeter()} classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)) @@ -676,29 +701,8 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1): steps_completed = (validation_step+1) if steps_completed % args.print_freq == 0: - if not _is_earlyexit(args): - stats = ('', - OrderedDict([('Loss', losses['objective_loss'].mean), - ('Top1', classerr.value(1)), - ('Top5', classerr.value(5))])) - else: - stats_dict = OrderedDict() - stats_dict['Test'] = validation_step - for exitnum in range(args.num_exits): - la_string = 'LossAvg' + str(exitnum) - stats_dict[la_string] = args.losses_exits[exitnum].mean - # Because of the nature of ClassErrorMeter, if an exit is never taken during the batch, - # then accessing the value(k) will cause a divide by zero. So we'll build the OrderedDict - # accordingly and we will not print for an exit error when that exit is never taken. - if args.exit_taken[exitnum]: - t1 = 'Top1_exit' + str(exitnum) - t5 = 'Top5_exit' + str(exitnum) - stats_dict[t1] = args.exiterrors[exitnum].value(1) - stats_dict[t5] = args.exiterrors[exitnum].value(5) - stats = ('Performance/Validation/', stats_dict) - - distiller.log_training_progress(stats, None, epoch, steps_completed, - total_steps, args.print_freq, loggers) + _log_validation_progress() + if not _is_earlyexit(args): msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n', classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)