diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 8db49acb366b83de9c5323c82f38164b44ff810a..a6a859158b33a454fe7932f22e21a017675634ea 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -115,7 +115,6 @@ class ClassifierCompressor(object):
     def validate_one_epoch(self, epoch, verbose=True):
         """Evaluate on validation set"""
         self.load_datasets()
-
         with collectors_context(self.activations_collectors["valid"]) as collectors:
             top1, top5, vloss = validate(self.val_loader, self.model, self.criterion, 
                                          [self.pylogger], self.args, epoch)
@@ -128,9 +127,9 @@
             OrderedDict([('Loss', vloss),
                          ('Top1', top1),
                          ('Top5', top5)]))
-            distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1,
-                                            loggers=[self.tflogger])
-
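+            # Validation stats go to the TF logger once per epoch: a single "step"
+            # with steps_completed=0, total_steps=1.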
+            distiller.log_training_progress(stats, None, epoch, steps_completed=0,
+                                            total_steps=1, log_freq=1, loggers=[self.tflogger])
         return top1, top5, vloss
 
     def _finalize_epoch(self, epoch, perf_scores_history, top1, top5):
@@ -284,7 +283,6 @@ def init_classifier_compression_arg_parser():
                         help='Load a model without DataParallel wrapping it')
     parser.add_argument('--thinnify', dest='thinnify', action='store_true', default=False,
                         help='physically remove zero-filters and create a smaller model')
-
     distiller.quantization.add_post_train_quant_args(parser)
     return parser
 
@@ -490,6 +488,40 @@ def train(train_loader, model, criterion, optimizer, epoch,
         optimizer.step()
         compression_scheduler.on_minibatch_end(epoch)
     """
+    def _log_training_progress():
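+        # Nested helper: it closes over the training loop's locals (losses,
+        # classerr, batch_time, steps_completed, etc.) instead of taking arguments.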
+        # Log the current training statistics: losses, error rates, LR and batch time.
+        errs = OrderedDict()
+        if not early_exit_mode(args):
+            errs['Top1'] = classerr.value(1)
+            errs['Top5'] = classerr.value(5)
+        else:
+            # In early-exit mode, the Top1 and Top5 stats are computed separately for each exit.
+            for exitnum in range(args.num_exits):
+                errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1)
+                errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5)
+
+        stats_dict = OrderedDict()
+        for loss_name, meter in losses.items():
+            stats_dict[loss_name] = meter.mean
+        stats_dict.update(errs)
+        stats_dict['LR'] = optimizer.param_groups[0]['lr']
+        stats_dict['Time'] = batch_time.mean
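+        # The first element of the stats tuple is a prefix under which the
+        # attached loggers group these entries (e.g. as a TensorBoard tag prefix).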
+        stats = ('Performance/Training/', stats_dict)
+
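+        # Parameter histograms can consume significant space, so they are
+        # logged only when args.log_params_histograms is set.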
+        params = model.named_parameters() if args.log_params_histograms else None
+        distiller.log_training_progress(stats,
+                                        params,
+                                        epoch, steps_completed,
+                                        steps_per_epoch, args.print_freq,
+                                        loggers)
+
     OVERALL_LOSS_KEY = 'Overall Loss'
     OBJECTIVE_LOSS_KEY = 'Objective Loss'
 
@@ -570,31 +602,8 @@ def train(train_loader, model, criterion, optimizer, epoch,
         steps_completed = (train_step+1)
 
         if steps_completed % args.print_freq == 0:
-            # Log some statistics
-            errs = OrderedDict()
-            if not early_exit_mode(args):
-                errs['Top1'] = classerr.value(1)
-                errs['Top5'] = classerr.value(5)
-            else:
-                # for Early Exit case, the Top1 and Top5 stats are computed for each exit.
-                for exitnum in range(args.num_exits):
-                    errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1)
-                    errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5)
+            _log_training_progress()
 
-            stats_dict = OrderedDict()
-            for loss_name, meter in losses.items():
-                stats_dict[loss_name] = meter.mean
-            stats_dict.update(errs)
-            stats_dict['LR'] = optimizer.param_groups[0]['lr']
-            stats_dict['Time'] = batch_time.mean
-            stats = ('Performance/Training/', stats_dict)
-
-            params = model.named_parameters() if args.log_params_histograms else None
-            distiller.log_training_progress(stats,
-                                            params,
-                                            epoch, steps_completed,
-                                            steps_per_epoch, args.print_freq,
-                                            loggers)
         end = time.time()
     #return acc_stats
     # NOTE: this breaks previous behavior, which returned a history of (top1, top5) values
@@ -628,6 +637,31 @@ def _is_earlyexit(args):
 
 
 def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
+    """Execute the validation/test loop."""
+    def _log_validation_progress():
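+        # Report the running validation statistics; in early-exit mode,
+        # per-exit average losses and error rates are reported instead.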
+        if not _is_earlyexit(args):
+            stats_dict = OrderedDict([('Loss', losses['objective_loss'].mean),
+                                      ('Top1', classerr.value(1)),
+                                      ('Top5', classerr.value(5))])
+        else:
+            stats_dict = OrderedDict()
+            stats_dict['Test'] = validation_step
+            for exitnum in range(args.num_exits):
+                la_string = 'LossAvg' + str(exitnum)
+                stats_dict[la_string] = args.losses_exits[exitnum].mean
+                # Because of the nature of ClassErrorMeter, accessing value(k) for an
+                # exit that was never taken during the batch causes a divide-by-zero
+                # error, so exits that were never taken are omitted from the stats.
+                if args.exit_taken[exitnum]:
+                    t1 = 'Top1_exit' + str(exitnum)
+                    t5 = 'Top5_exit' + str(exitnum)
+                    stats_dict[t1] = args.exiterrors[exitnum].value(1)
+                    stats_dict[t5] = args.exiterrors[exitnum].value(5)
+        stats = ('Performance/Validation/', stats_dict)
+        distiller.log_training_progress(stats, None, epoch, steps_completed,
+                                        total_steps, args.print_freq, loggers)
+
     """Execute the validation/test loop."""
     losses = {'objective_loss': tnt.AverageValueMeter()}
     classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
@@ -676,29 +710,8 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
 
             steps_completed = (validation_step+1)
             if steps_completed % args.print_freq == 0:
-                if not _is_earlyexit(args):
-                    stats = ('',
-                            OrderedDict([('Loss', losses['objective_loss'].mean),
-                                         ('Top1', classerr.value(1)),
-                                         ('Top5', classerr.value(5))]))
-                else:
-                    stats_dict = OrderedDict()
-                    stats_dict['Test'] = validation_step
-                    for exitnum in range(args.num_exits):
-                        la_string = 'LossAvg' + str(exitnum)
-                        stats_dict[la_string] = args.losses_exits[exitnum].mean
-                        # Because of the nature of ClassErrorMeter, if an exit is never taken during the batch,
-                        # then accessing the value(k) will cause a divide by zero. So we'll build the OrderedDict
-                        # accordingly and we will not print for an exit error when that exit is never taken.
-                        if args.exit_taken[exitnum]:
-                            t1 = 'Top1_exit' + str(exitnum)
-                            t5 = 'Top5_exit' + str(exitnum)
-                            stats_dict[t1] = args.exiterrors[exitnum].value(1)
-                            stats_dict[t5] = args.exiterrors[exitnum].value(5)
-                    stats = ('Performance/Validation/', stats_dict)
-
-                distiller.log_training_progress(stats, None, epoch, steps_completed,
-                                                total_steps, args.print_freq, loggers)
+                _log_validation_progress()
+
     if not _is_earlyexit(args):
         msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                        classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)