diff --git a/distiller/apputils/__init__.py b/distiller/apputils/__init__.py
index 82bbec1a455c4647633da2970b2d0dea57ff8e16..bf3714d42229900bac18fdb021f1ec03476c0871 100755
--- a/distiller/apputils/__init__.py
+++ b/distiller/apputils/__init__.py
@@ -22,8 +22,10 @@ from .data_loaders import *
 from .checkpoint import *
 from .execution_env import *
 from .dataset_summaries import *
+from .performance_tracker import *
 
 del data_loaders
 del checkpoint
 del execution_env
 del dataset_summaries
+del performance_tracker
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 907772cc5b21a39b4fc1bf725ffe5e96e180e570..59bcfb3429666eceadc7df1a646768f7404cccad 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,16 +28,15 @@ import torch.backends.cudnn as cudnn
 import torch.optim
 import torch.utils.data
 import torchnet.meter as tnt
+import parser
+from functools import partial
+import argparse
 import distiller
 import distiller.apputils as apputils
 from distiller.data_loggers import *
 import distiller.quantization as quantization
 import distiller.models as models
 from distiller.models import create_model
-import parser
-from functools import partial
-import operator
-import argparse
 from distiller.utils import float_range_argparse_checker as float_range
 
 # Logger handle
@@ -76,6 +75,7 @@ class ClassifierCompressor(object):
         self.train_loader, self.val_loader, self.test_loader = (None, None, None)
         self.activations_collectors = create_activation_stats_collectors(
             self.model, *self.args.activation_stats)
+        self.performance_tracker = apputils.SparsityAccuracyTracker(self.args.num_best_scores)
     
     def load_datasets(self):
         """Load the datasets"""
@@ -155,20 +155,20 @@ class ClassifierCompressor(object):
                                             total_steps=1, log_freq=1, loggers=[self.tflogger])
         return top1, top5, vloss
 
-    def _finalize_epoch(self, epoch, perf_scores_history, top1, top5):
+    def _finalize_epoch(self, epoch, top1, top5):
         # Update the list of top scores achieved so far, and save the checkpoint
-        update_training_scores_history(perf_scores_history, self.model,
-                                       top1, top5, epoch, self.args.num_best_scores)
-        is_best = epoch == perf_scores_history[0].epoch
+        self.performance_tracker.step(self.model, epoch, top1=top1, top5=top5)
+        _log_best_scores(self.performance_tracker, msglogger)
+        best_score = self.performance_tracker.best_scores()[0]
+        is_best = epoch == best_score.epoch
         checkpoint_extras = {'current_top1': top1,
-                             'best_top1': perf_scores_history[0].top1,
-                             'best_epoch': perf_scores_history[0].epoch}
+                             'best_top1': best_score.top1,
+                             'best_epoch': best_score.epoch}
         if msglogger.logdir:
             apputils.save_checkpoint(epoch, self.args.arch, self.model, optimizer=self.optimizer,
                                      scheduler=self.compression_scheduler, extras=checkpoint_extras,
                                      is_best=is_best, name=self.args.name, dir=msglogger.logdir)
 
-
     def run_training_loop(self):
         """Run the main training loop with compression.
 
@@ -186,12 +186,12 @@ class ClassifierCompressor(object):
         # Load the datasets lazily
         self.load_datasets()
 
-        perf_scores_history = []       
+        self.performance_tracker.reset()
         for epoch in range(self.start_epoch, self.ending_epoch):
             msglogger.info('\n')
             top1, top5, loss = self.train_validate_with_scheduling(epoch)
-            self._finalize_epoch(epoch, perf_scores_history, top1, top5)
-        return perf_scores_history
+            self._finalize_epoch(epoch, top1, top5)
+        return self.performance_tracker.perf_scores_history
 
     def validate(self, epoch=-1):
         self.load_datasets()
@@ -533,7 +533,6 @@ def train(train_loader, model, criterion, optimizer, epoch,
                                         steps_per_epoch, args.print_freq,
                                         loggers)
 
-
     OVERALL_LOSS_KEY = 'Overall Loss'
     OBJECTIVE_LOSS_KEY = 'Objective Loss'
 
@@ -737,21 +736,6 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
         return total_top1, total_top5, losses_exits_stats[args.num_exits-1]
 
 
-def update_training_scores_history(perf_scores_history, model, top1, top5, epoch, num_best_scores):
-    """ Update the list of top training scores achieved so far, and log the best scores so far"""
-
-    model_sparsity, _, params_nnz_cnt = distiller.model_params_stats(model)
-    perf_scores_history.append(distiller.MutableNamedTuple({'params_nnz_cnt': -params_nnz_cnt,
-                                                            'sparsity': model_sparsity,
-                                                            'top1': top1, 'top5': top5, 'epoch': epoch}))
-    # Keep perf_scores_history sorted from best to worst
-    # Sort by sparsity as main sort key, then sort by top1, top5 and epoch
-    perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'top1', 'top5', 'epoch'), reverse=True)
-    for score in perf_scores_history[:num_best_scores]:
-        msglogger.info('==> Best [Top1: %.3f   Top5: %.3f   Sparsity:%.2f   NNZ-Params: %d on epoch: %d]',
-                       score.top1, score.top5, score.sparsity, -score.params_nnz_cnt, score.epoch)
-
-
 def earlyexit_loss(output, target, criterion, args):
     """Compute the weighted sum of the exits losses
 
@@ -915,3 +899,18 @@ def acts_histogram_collection(model, criterion, loggers, args):
                       loggers=loggers, args=args, activations_collectors=None)
     collect_histograms(model, test_fn, save_dir=msglogger.logdir,
                        classes=None, nbins=2048, save_hist_imgs=True)
+
+
+def _log_best_scores(performance_tracker, logger, how_many=-1):
+    """Utility to log the best scores.
+
+    This function is currently written for pruning use cases, but can be generalized.
+    """
+    assert isinstance(performance_tracker, apputils.SparsityAccuracyTracker)
+    if how_many < 1:
+        how_many = performance_tracker.max_len
+    how_many = min(how_many, performance_tracker.max_len)
+    best_scores = performance_tracker.best_scores(how_many)
+    for score in best_scores:
+        logger.info('==> Best [Top1: %.3f   Top5: %.3f   Sparsity:%.2f   NNZ-Params: %d on epoch: %d]',
+                    score.top1, score.top5, score.sparsity, -score.params_nnz_cnt, score.epoch)
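The `_log_best_scores` docstring above notes that the tracker abstraction can be generalized beyond pruning. A minimal sketch of one such generalization, built on the `TrainingPerformanceTracker` base class added below; the `AccuracyTracker` name and its accuracy-only sort order are illustrative assumptions, not part of this patch:

    import operator
    import distiller

    class AccuracyTracker(distiller.apputils.TrainingPerformanceTracker):
        """Hypothetical tracker that ranks epochs by accuracy alone."""
        def step(self, model, epoch, **kwargs):
            # `model` is unused here, but kept for base-class interface compatibility.
            self.perf_scores_history.append(distiller.MutableNamedTuple({
                'top1': kwargs['top1'], 'top5': kwargs['top5'], 'epoch': epoch}))
            # Keep the history sorted best-to-worst: top1, then top5, then epoch.
            self.perf_scores_history.sort(
                key=operator.attrgetter('top1', 'top5', 'epoch'), reverse=True)
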
diff --git a/distiller/apputils/performance_tracker.py b/distiller/apputils/performance_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6f4bc151b25167f3238eac8fde3ebf6276c1a7b
--- /dev/null
+++ b/distiller/apputils/performance_tracker.py
@@ -0,0 +1,68 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Performance trackers used to track the best performing epochs when training.
+"""
+import operator
+import distiller
+
+
+__all__ = ["TrainingPerformanceTracker",
+           "SparsityAccuracyTracker"]
+
+
+class TrainingPerformanceTracker(object):
+    """Base class for performance trackers using Top1 and Top5 accuracy metrics"""
+    def __init__(self, num_best_scores):
+        self.perf_scores_history = []
+        self.max_len = num_best_scores
+
+    def reset(self):
+        self.perf_scores_history = []
+
+    def step(self, model, epoch, **kwargs):
+        """Update the list of top training scores achieved so far"""
+        raise NotImplementedError
+
+    def best_scores(self, how_many=1):
+        """Returns `how_many` best scores experienced so far"""
+        if how_many < 1:
+            how_many = self.max_len
+        how_many = min(how_many, self.max_len)
+        return self.perf_scores_history[:how_many]
+
+
+class SparsityAccuracyTracker(TrainingPerformanceTracker):
+    """A performance tracker which prioritizes non-zero parameters.
+
+    Sort the performance history using the count of non-zero parameters
+    as main sort key, then sort by top1, top5 and and finally epoch number.
+
+    Expects 'top1' and 'top5' to appear in the kwargs.
+    """
+    def step(self, model, epoch, **kwargs):
+        assert all(score in kwargs for score in ('top1', 'top5'))
+        model_sparsity, _, params_nnz_cnt = distiller.model_params_stats(model)
+        self.perf_scores_history.append(distiller.MutableNamedTuple({
+            'params_nnz_cnt': -params_nnz_cnt,
+            'sparsity': model_sparsity,
+            'top1': kwargs['top1'],
+            'top5': kwargs['top5'],
+            'epoch': epoch}))
+        # Keep perf_scores_history sorted from best to worst
+        self.perf_scores_history.sort(
+            key=operator.attrgetter('params_nnz_cnt', 'top1', 'top5', 'epoch'),
+            reverse=True)
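
For context, a minimal usage sketch of the tracker API introduced above, assuming a trained `model` is in scope; `num_epochs` and the `validate` call are stand-ins for the caller's training loop, not part of this patch:

    import distiller.apputils as apputils

    tracker = apputils.SparsityAccuracyTracker(num_best_scores=3)
    tracker.reset()
    for epoch in range(num_epochs):           # num_epochs: caller-defined
        top1, top5, _ = validate(epoch)       # hypothetical validation step
        tracker.step(model, epoch, top1=top1, top5=top5)
    best = tracker.best_scores()[0]           # best_scores() returns a list
    print('Best epoch: %d  Top1: %.3f  Sparsity: %.2f'
          % (best.epoch, best.top1, best.sparsity))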