diff --git a/distiller/policy.py b/distiller/policy.py
index 76d295f5f059c14867203479c883b1aa1e9d66d6..6d705332f6f4d76a9d7ffc43fc492b85c7a06565 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -21,11 +21,17 @@
 - LRPolicy: learning-rate decay scheduling
 """
 import torch
+from collections import namedtuple
 
 import logging
 msglogger = logging.getLogger()
 
-__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy']
+__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy',
+           'PolicyLoss', 'LossComponent']
+
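+# 'overall_loss' is the loss tensor to back-propagate; 'loss_components' lists any additional loss
+# terms a policy contributed to it (e.g. a regularization loss), as (name, value) LossComponent tuples.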
+PolicyLoss = namedtuple('PolicyLoss', ['overall_loss', 'loss_components'])
+LossComponent = namedtuple('LossComponent', ['name', 'value'])
+
 
 class ScheduledTrainingPolicy(object):
     """ Base class for all scheduled training policies.
@@ -40,14 +46,20 @@ class ScheduledTrainingPolicy(object):
         """A new epcoh is about to begin"""
         pass
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizr=None):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None):
         """The forward-pass of a new mini-batch is about to begin"""
         pass
 
-    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict, optimizer=None):
+    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, zeros_mask_dict,
+                             optimizer=None):
         """The mini-batch training pass has completed the forward-pass,
         and is about to begin the backward pass.
+
+        This callback receives a 'loss' argument. The callback should not modify this argument, but it may
+        optionally return an instance of 'PolicyLoss', which will be used in place of 'loss'.
+
+        Note: The 'loss_components' field of 'PolicyLoss' should contain only the new, individual loss
+              components the callback contributed to 'overall_loss', not the incoming 'loss' argument.
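+
+        For example, a hypothetical policy that adds a fixed auxiliary penalty could implement:
+
+            def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
+                                     zeros_mask_dict, optimizer=None):
+                aux_loss = torch.tensor(0.1, device=loss.device)  # 'aux_loss' is illustrative only
+                return PolicyLoss(loss + aux_loss, [LossComponent('aux_loss', aux_loss)])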
         """
         pass
 
@@ -81,7 +93,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         for param_name, param in model.named_parameters():
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None):
         for param_name, param in model.named_parameters():
             zeros_mask_dict[param_name].apply_mask(param)
 
@@ -100,10 +112,16 @@ class RegularizationPolicy(ScheduledTrainingPolicy):
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
-                             regularizer_loss, zeros_mask_dict, optimizer=None):
+                             zeros_mask_dict, optimizer=None):
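+        # Accumulate the regularization loss in a fresh tensor, on the same device as the data loss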
+        regularizer_loss = torch.tensor(0, dtype=torch.float, device=loss.device)
+
         for param_name, param in model.named_parameters():
             self.regularizer.loss(param, param_name, regularizer_loss, zeros_mask_dict)
 
+        policy_loss = PolicyLoss(loss + regularizer_loss,
+                                 [LossComponent(self.regularizer.__class__.__name__ + '_loss', regularizer_loss)])
+        return policy_loss
+
     def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         if self.regularizer.threshold_criteria is None:
             return
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index d052d51bc6a563ccad9427dc72d3a7dc2c844c83..3ef8a97e141298424cc571069781b527a9df64b4 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -22,6 +22,7 @@ from functools import partial
 import logging
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
+from .policy import PolicyLoss, LossComponent
 
 msglogger = logging.getLogger()
 
@@ -112,16 +113,24 @@ class CompressionScheduler(object):
                 policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch,
                                           self.zeros_mask_dict, optimizer)
 
-    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None):
-        # Last chance to compute the regularization loss, and optionally add it to the data loss
-        regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
-
+    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None,
+                             return_loss_components=False):
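+        """Invoke the before_backward_pass callbacks of the policies scheduled for this epoch.
+
+        Each policy may return a PolicyLoss that replaces the loss passed to the policies that follow it.
+        If return_loss_components is True, a PolicyLoss aggregating the final overall loss and all of the
+        policies' loss components is returned; otherwise, just the overall loss tensor is returned.
+        """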
+        # We pass the loss to the policies, which may override it
+        overall_loss = loss
+        loss_components = []
         if epoch in self.policies:
             for policy in self.policies[epoch]:
-                # regularizer_loss is passed to policy objects which may increase it.
-                policy.before_backward_pass(self.model, epoch, minibatch_id, minibatches_per_epoch,
-                                            loss, regularizer_loss, self.zeros_mask_dict)
-        return regularizer_loss
+                policy_loss = policy.before_backward_pass(self.model, epoch, minibatch_id, minibatches_per_epoch,
+                                                          overall_loss, self.zeros_mask_dict)
+                if policy_loss is not None:
+                    curr_loss_components = self.verify_policy_loss(policy_loss)
+                    overall_loss = policy_loss.overall_loss
+                    loss_components += curr_loss_components
+
+        if return_loss_components:
+            return PolicyLoss(overall_loss, loss_components)
+
+        return overall_loss
 
     def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         # When we get to this point, the weights are no longer masked.  This is because during the backward
@@ -161,7 +170,7 @@ class CompressionScheduler(object):
     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
 
-        Curently it contains just the pruning mask.
+        Currently it contains just the pruning mask.
         """
         masks = {}
         for name, masker in self.zeros_mask_dict.items():
@@ -192,3 +201,16 @@ class CompressionScheduler(object):
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
             masker.mask = loaded_masks[name]
+
+    @staticmethod
+    def verify_policy_loss(policy_loss):
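+        """Verify that a value returned from a policy's before_backward_pass is a well-formed PolicyLoss,
+        and return its loss components as a list."""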
+        if not isinstance(policy_loss, PolicyLoss):
+            raise TypeError("A Policy's before_backward_pass must return either None or an instance of " +
+                            PolicyLoss.__name__)
+        curr_loss_components = policy_loss.loss_components
+        if not isinstance(curr_loss_components, list):
+            curr_loss_components = [curr_loss_components]
+        if not all(isinstance(lc, LossComponent) for lc in curr_loss_components):
+            raise TypeError("Expected an instance of " + LossComponent.__name__ +
+                            " or a list of such instances")
+        return curr_loss_components
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 02e86e2aa61f9a194612dc0151a064f4ceea6e66..4ffbba3d26c38c86e5a89e25c9605f2e30eaed86 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -144,9 +144,12 @@ parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
                     help='Portion of training dataset to set aside for validation')
 parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
 parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
-parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', help='Display the confusion matrix')
-parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None, help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)')
-parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None, help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)')
+parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
+                    help='Display the confusion matrix')
+parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
+                    help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)')
+parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
+                    help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)')
 
 
 def check_pytorch_version():
@@ -302,7 +305,8 @@ def main():
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
                               ('Top5', top5)]))
-        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, loggers=[tflogger])
+        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1,
+                                        loggers=[tflogger])
 
         if compression_scheduler:
             compression_scheduler.on_epoch_end(epoch, optimizer)
@@ -320,19 +324,22 @@ def main():
     test(test_loader, model, criterion, [pylogger], args=args)
 
 
+OVERALL_LOSS_KEY = 'Overall Loss'
+OBJECTIVE_LOSS_KEY = 'Objective Loss'
+
+
 def train(train_loader, model, criterion, optimizer, epoch,
           compression_scheduler, loggers, args):
     """Training loop for one epoch."""
-    losses = {'objective_loss':   tnt.AverageValueMeter(),
-              'regularizer_loss': tnt.AverageValueMeter()}
-    if compression_scheduler is None:
-        # Initialize the regularizer loss to zero
-        losses['regularizer_loss'].add(0)
+    losses = OrderedDict([(OVERALL_LOSS_KEY, tnt.AverageValueMeter()),
+                          (OBJECTIVE_LOSS_KEY, tnt.AverageValueMeter())])
 
     classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
     batch_time = tnt.AverageValueMeter()
     data_time = tnt.AverageValueMeter()
-    # For Early Exit, we define statistics for each exit - so exiterrors is analogous to classerr for the non-Early Exit case
+
+    # For Early Exit, we define statistics for each exit, so exiterrors is
+    # analogous to classerr for the non-Early Exit case
     if args.earlyexit_lossweights:
         args.exiterrors = []
         for exitnum in range(args.num_exits):
@@ -368,13 +375,19 @@ def train(train_loader, model, criterion, optimizer, epoch,
             # Measure accuracy and record loss
             loss = earlyexit_loss(output, target_var, criterion, args)
 
-        losses['objective_loss'].add(loss.item())
+        losses[OBJECTIVE_LOSS_KEY].add(loss.item())
 
         if compression_scheduler:
-            # Before running the backward phase, we add any regularization loss computed by the scheduler
-            regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss, optimizer)
-            loss += regularizer_loss
-            losses['regularizer_loss'].add(regularizer_loss.item())
+            # Before running the backward phase, we allow the scheduler to modify the loss
+            # (e.g. add regularization loss)
+            agg_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss,
+                                                                  optimizer=optimizer, return_loss_components=True)
+            loss = agg_loss.overall_loss
+            losses[OVERALL_LOSS_KEY].add(loss.item())
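+            # Each loss component reported by the policies gets its own averaging meter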
+            for lc in agg_loss.loss_components:
+                if lc.name not in losses:
+                    losses[lc.name] = tnt.AverageValueMeter()
+                losses[lc.name].add(lc.value.item())
 
         # Compute the gradient and do SGD step
         optimizer.zero_grad()
@@ -389,28 +402,23 @@ def train(train_loader, model, criterion, optimizer, epoch,
 
         if steps_completed % args.print_freq == 0:
             # Log some statistics
-            lr = optimizer.param_groups[0]['lr']
+            errs = OrderedDict()
             if not args.earlyexit_lossweights:
-                stats = ('Peformance/Training/',
-                         OrderedDict([
-                             ('Loss', losses['objective_loss'].mean),
-                             ('Reg Loss', losses['regularizer_loss'].mean),
-                             ('Top1', classerr.value(1)),
-                             ('Top5', classerr.value(5)),
-                             ('LR', lr),
-                             ('Time', batch_time.mean)]))
+                errs['Top1'] = classerr.value(1)
+                errs['Top5'] = classerr.value(5)
             else:
                 # for Early Exit case, the Top1 and Top5 stats are computed for each exit.
-                stats_dict = OrderedDict()
-                stats_dict['Objective Loss'] = losses['objective_loss'].mean
                 for exitnum in range(args.num_exits):
-                    t1 = 'Top1_exit' + str(exitnum)
-                    t5 = 'Top5_exit' + str(exitnum)
-                    stats_dict[t1] = args.exiterrors[exitnum].value(1)
-                    stats_dict[t5] = args.exiterrors[exitnum].value(5)
-                stats_dict['LR'] = lr
-                stats_dict['Time'] = batch_time.mean
-                stats = ('Peformance/Training/', stats_dict)
+                    errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1)
+                    errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5)
+
+            stats_dict = OrderedDict()
+            for loss_name, meter in losses.items():
+                stats_dict[loss_name] = meter.mean
+            stats_dict.update(errs)
+            stats_dict['LR'] = optimizer.param_groups[0]['lr']
+            stats_dict['Time'] = batch_time.mean
+            stats = ('Performance/Training/', stats_dict)
 
             params = model.named_parameters() if args.log_params_histograms else None
             distiller.log_training_progress(stats,
@@ -520,7 +528,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             msglogger.info('==> Confusion:\n%s', str(confusion.value()))
         return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
     else:
-        #print some interesting summary stats for number of data points that could exit early
+        # Print some interesting summary stats for number of data points that could exit early
         top1k_stats = [0] * args.num_exits
         top5k_stats = [0] * args.num_exits
         losses_exits_stats = [0] * args.num_exits
@@ -534,8 +542,8 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
                 losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean
         for exitnum in range(args.num_exits):
             if args.exit_taken[exitnum]:
-                msglogger.info("Percent Early Exit %d: %.3f", exitnum, (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
-
+                msglogger.info("Percent Early Exit %d: %.3f", exitnum,
+                               (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
 
         return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]
 
@@ -563,6 +571,7 @@ def get_inference_var(tensor):
         return torch.autograd.Variable(tensor)
     return torch.autograd.Variable(tensor, volatile=True)
 
+
 def earlyexit_loss(output, target_var, criterion, args):
     loss = 0
     sum_lossweights = 0
@@ -575,6 +584,7 @@ def earlyexit_loss(output, target_var, criterion, args):
     args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target_var)
     return loss
 
+
 def earlyexit_validate_loss(output, target_var, criterion, args):
     for exitnum in range(args.num_exits):
         args.loss_exits[exitnum] = criterion(output[exitnum], target_var)
@@ -594,7 +604,8 @@ def earlyexit_validate_loss(output, target_var, criterion, args):
                 args.exit_taken[exitnum] += 1
             else:
                 # skip the early exits and include results from end of net
-                args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum], ndmin=2)),
+                args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum],
+                                                                            ndmin=2)),
                         torch.full([1], target_var[batchnum], dtype=torch.long))
                 args.exit_taken[args.num_exits-1] += 1
 
diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py
index a8ff493d7a76733cb52e47df3ab42482f3904d9e..56d92c5c547a4aeb9ff4fdc84ecaa034684f12d4 100755
--- a/examples/word_language_model/main.py
+++ b/examples/word_language_model/main.py
@@ -236,10 +236,11 @@ def train(epoch, optimizer, compression_scheduler=None):
         loss = criterion(output.view(-1, ntokens), targets)
 
         if compression_scheduler:
-            # Before running the backward phase, we add any regularization loss computed by the scheduler
-            regularizer_loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch,
-                                                                          minibatches_per_epoch=steps_per_epoch, loss=loss)
-            loss += regularizer_loss
+            # Before running the backward phase, we allow the scheduler to modify the loss
+            # (e.g. add regularization loss)
+            loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch,
+                                                              minibatches_per_epoch=steps_per_epoch, loss=loss,
+                                                              return_loss_components=False)
 
         optimizer.zero_grad()
         loss.backward()
diff --git a/tests/test_loss.py b/tests/test_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c197d09bf1d39ea68066aabf64d73696effe5e0
--- /dev/null
+++ b/tests/test_loss.py
@@ -0,0 +1,68 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import os
+import sys
+import torch.nn as nn
+from copy import deepcopy
+import pytest
+
+module_path = os.path.abspath(os.path.join('..'))
+if module_path not in sys.path:
+    sys.path.append(module_path)
+from distiller import ScheduledTrainingPolicy, CompressionScheduler
+from distiller.policy import PolicyLoss, LossComponent
+
+
+class DummyPolicy(ScheduledTrainingPolicy):
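+    """A stub policy that adds a fixed, randomly chosen value to the loss and reports it as a LossComponent."""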
+    def __init__(self, idx):
+        super(DummyPolicy, self).__init__()
+        self.loss_val = torch.randint(0, 10000, (1,))
+        self.idx = idx
+
+    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
+                             zeros_mask_dict, optimizer=None):
+        return PolicyLoss(loss + self.loss_val, [LossComponent('Dummy Loss ' + str(self.idx), self.loss_val)])
+
+
+@pytest.mark.parametrize("check_loss_components", [False, True])
+def test_multiple_policies_loss(check_loss_components):
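+    """Schedule several policies that each add a value to the loss, and check the aggregated result."""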
+    model = nn.Module()
+    scheduler = CompressionScheduler(model, device=torch.device('cpu'))
+    num_policies = 3
+    expected_overall_loss = 0
+    expected_policy_losses = []
+    for i in range(num_policies):
+        policy = DummyPolicy(i)
+        expected_overall_loss += policy.loss_val
+        expected_policy_losses.append(policy.loss_val)
+        scheduler.add_policy(policy, epochs=[0])
+
+    main_loss = torch.randint(0, 10000, (1,))
+    expected_overall_loss += main_loss
+    main_loss_before = deepcopy(main_loss)
+
+    policies_loss = scheduler.before_backward_pass(0, 0, 1, main_loss, return_loss_components=check_loss_components)
+
+    assert main_loss_before == main_loss
+    if check_loss_components:
+        assert expected_overall_loss == policies_loss.overall_loss
+        for idx, lc in enumerate(policies_loss.loss_components):
+            assert lc.name == 'Dummy Loss ' + str(idx)
+            assert expected_policy_losses[idx] == lc.value.item()
+    else:
+        assert expected_overall_loss == policies_loss