From a43b9f101dbf95e034c404f89162ce0082e12ecf Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Thu, 9 Aug 2018 13:26:36 +0300 Subject: [PATCH] Generalize the loss value returned from before_backward_pass callbacks (#38) * Instead of a single additive value (which so far represented only the regularizer loss), callbacks return a new overall loss * Policy callbacks also return the individual loss components used to calculate the new overall loss. * Add boolean flag to the Scheduler's callback so applications can choose if they want to get individual loss components, or just the new overall loss * In compress_classifier.py, log the individual loss components * Add test for the loss-from-callback flow --- distiller/policy.py | 30 +++++-- distiller/scheduler.py | 40 +++++++-- .../compress_classifier.py | 85 +++++++++++-------- examples/word_language_model/main.py | 9 +- tests/test_loss.py | 68 +++++++++++++++ 5 files changed, 176 insertions(+), 56 deletions(-) create mode 100644 tests/test_loss.py diff --git a/distiller/policy.py b/distiller/policy.py index 76d295f..6d70533 100755 --- a/distiller/policy.py +++ b/distiller/policy.py @@ -21,11 +21,17 @@ - LRPolicy: learning-rate decay scheduling """ import torch +from collections import namedtuple import logging msglogger = logging.getLogger() -__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy'] +__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy', + 'PolicyLoss', 'LossComponent'] + +PolicyLoss = namedtuple('PolicyLoss', ['overall_loss', 'loss_components']) +LossComponent = namedtuple('LossComponent', ['name', 'value']) + class ScheduledTrainingPolicy(object): """ Base class for all scheduled training policies. @@ -40,14 +46,20 @@ class ScheduledTrainingPolicy(object): """A new epcoh is about to begin""" pass - def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizr=None): + def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None): """The forward-pass of a new mini-batch is about to begin""" pass - def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, - regularizer_loss, zeros_mask_dict, optimizer=None): + def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, zeros_mask_dict, + optimizer=None): """The mini-batch training pass has completed the forward-pass, and is about to begin the backward pass. + + This callback receives a 'loss' argument. The callback should not modify this argument, but it can + optionally return an instance of 'PolicyLoss' which will be used in place of `loss'. + + Note: The 'loss_components' parameter within 'PolicyLoss' should contain any new, individual loss components + the callback contributed to 'overall_loss'. It should not contain the incoming 'loss' argument. """ pass @@ -81,7 +93,7 @@ class PruningPolicy(ScheduledTrainingPolicy): for param_name, param in model.named_parameters(): self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta) - def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer): + def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None): for param_name, param in model.named_parameters(): zeros_mask_dict[param_name].apply_mask(param) @@ -100,10 +112,16 @@ class RegularizationPolicy(ScheduledTrainingPolicy): self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1) def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, - regularizer_loss, zeros_mask_dict, optimizer=None): + zeros_mask_dict, optimizer=None): + regularizer_loss = torch.tensor(0, dtype=torch.float, device=loss.device) + for param_name, param in model.named_parameters(): self.regularizer.loss(param, param_name, regularizer_loss, zeros_mask_dict) + policy_loss = PolicyLoss(loss + regularizer_loss, + [LossComponent(self.regularizer.__class__.__name__ + '_loss', regularizer_loss)]) + return policy_loss + def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer): if self.regularizer.threshold_criteria is None: return diff --git a/distiller/scheduler.py b/distiller/scheduler.py index d052d51..3ef8a97 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -22,6 +22,7 @@ from functools import partial import logging import torch from .quantization.quantizer import FP_BKP_PREFIX +from .policy import PolicyLoss, LossComponent msglogger = logging.getLogger() @@ -112,16 +113,24 @@ class CompressionScheduler(object): policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch, self.zeros_mask_dict, optimizer) - def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None): - # Last chance to compute the regularization loss, and optionally add it to the data loss - regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device) - + def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None, + return_loss_components=False): + # We pass the loss to the policies, which may override it + overall_loss = loss + loss_components = [] if epoch in self.policies: for policy in self.policies[epoch]: - # regularizer_loss is passed to policy objects which may increase it. - policy.before_backward_pass(self.model, epoch, minibatch_id, minibatches_per_epoch, - loss, regularizer_loss, self.zeros_mask_dict) - return regularizer_loss + policy_loss = policy.before_backward_pass(self.model, epoch, minibatch_id, minibatches_per_epoch, + overall_loss, self.zeros_mask_dict) + if policy_loss is not None: + curr_loss_components = self.verify_policy_loss(policy_loss) + overall_loss = policy_loss.overall_loss + loss_components += curr_loss_components + + if return_loss_components: + return PolicyLoss(overall_loss, loss_components) + + return overall_loss def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None): # When we get to this point, the weights are no longer maksed. This is because during the backward @@ -161,7 +170,7 @@ class CompressionScheduler(object): def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. - Curently it contains just the pruning mask. + Currently it contains just the pruning mask. """ masks = {} for name, masker in self.zeros_mask_dict.items(): @@ -192,3 +201,16 @@ class CompressionScheduler(object): for name, mask in self.zeros_mask_dict.items(): masker = self.zeros_mask_dict[name] masker.mask = loaded_masks[name] + + @staticmethod + def verify_policy_loss(policy_loss): + if not isinstance(policy_loss, PolicyLoss): + raise TypeError("A Policy's before_backward_pass must return either None or an instance of " + + PolicyLoss.__name__) + curr_loss_components = policy_loss.loss_components + if not isinstance(curr_loss_components, list): + curr_loss_components = [curr_loss_components] + if not all(isinstance(lc, LossComponent) for lc in curr_loss_components): + raise TypeError("Expected an instance of " + LossComponent.__name__ + + " or a list of such instances") + return curr_loss_components diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 02e86e2..4ffbba3 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -144,9 +144,12 @@ parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1, help='Portion of training dataset to set aside for validation') parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK') parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK') -parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', help='Display the confusion matrix') -parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None, help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)') -parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None, help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)') +parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', + help='Display the confusion matrix') +parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None, + help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)') +parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None, + help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)') def check_pytorch_version(): @@ -302,7 +305,8 @@ def main(): OrderedDict([('Loss', vloss), ('Top1', top1), ('Top5', top5)])) - distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, loggers=[tflogger]) + distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, + loggers=[tflogger]) if compression_scheduler: compression_scheduler.on_epoch_end(epoch, optimizer) @@ -320,19 +324,22 @@ def main(): test(test_loader, model, criterion, [pylogger], args=args) +OVERALL_LOSS_KEY = 'Overall Loss' +OBJECTIVE_LOSS_KEY = 'Objective Loss' + + def train(train_loader, model, criterion, optimizer, epoch, compression_scheduler, loggers, args): """Training loop for one epoch.""" - losses = {'objective_loss': tnt.AverageValueMeter(), - 'regularizer_loss': tnt.AverageValueMeter()} - if compression_scheduler is None: - # Initialize the regularizer loss to zero - losses['regularizer_loss'].add(0) + losses = OrderedDict([(OVERALL_LOSS_KEY, tnt.AverageValueMeter()), + (OBJECTIVE_LOSS_KEY, tnt.AverageValueMeter())]) classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)) batch_time = tnt.AverageValueMeter() data_time = tnt.AverageValueMeter() - # For Early Exit, we define statistics for each exit - so exiterrors is analogous to classerr for the non-Early Exit case + + # For Early Exit, we define statistics for each exit + # So exiterrors is analogous to classerr for the non-Early Exit case if args.earlyexit_lossweights: args.exiterrors = [] for exitnum in range(args.num_exits): @@ -368,13 +375,19 @@ def train(train_loader, model, criterion, optimizer, epoch, # Measure accuracy and record loss loss = earlyexit_loss(output, target_var, criterion, args) - losses['objective_loss'].add(loss.item()) + losses[OBJECTIVE_LOSS_KEY].add(loss.item()) if compression_scheduler: - # Before running the backward phase, we add any regularization loss computed by the scheduler - regularizer_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss, optimizer) - loss += regularizer_loss - losses['regularizer_loss'].add(regularizer_loss.item()) + # Before running the backward phase, we allow the scheduler to modify the loss + # (e.g. add regularization loss) + agg_loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss, + optimizer=optimizer, return_loss_components=True) + loss = agg_loss.overall_loss + losses[OVERALL_LOSS_KEY].add(loss.item()) + for lc in agg_loss.loss_components: + if lc.name not in losses: + losses[lc.name] = tnt.AverageValueMeter() + losses[lc.name].add(lc.value.item()) # Compute the gradient and do SGD step optimizer.zero_grad() @@ -389,28 +402,23 @@ def train(train_loader, model, criterion, optimizer, epoch, if steps_completed % args.print_freq == 0: # Log some statistics - lr = optimizer.param_groups[0]['lr'] + errs = OrderedDict() if not args.earlyexit_lossweights: - stats = ('Peformance/Training/', - OrderedDict([ - ('Loss', losses['objective_loss'].mean), - ('Reg Loss', losses['regularizer_loss'].mean), - ('Top1', classerr.value(1)), - ('Top5', classerr.value(5)), - ('LR', lr), - ('Time', batch_time.mean)])) + errs['Top1'] = classerr.value(1) + errs['Top5'] = classerr.value(5) else: # for Early Exit case, the Top1 and Top5 stats are computed for each exit. - stats_dict = OrderedDict() - stats_dict['Objective Loss'] = losses['objective_loss'].mean for exitnum in range(args.num_exits): - t1 = 'Top1_exit' + str(exitnum) - t5 = 'Top5_exit' + str(exitnum) - stats_dict[t1] = args.exiterrors[exitnum].value(1) - stats_dict[t5] = args.exiterrors[exitnum].value(5) - stats_dict['LR'] = lr - stats_dict['Time'] = batch_time.mean - stats = ('Peformance/Training/', stats_dict) + errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1) + errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5) + + stats_dict = OrderedDict() + for loss_name, meter in losses.items(): + stats_dict[loss_name] = meter.mean + stats_dict.update(errs) + stats_dict['LR'] = optimizer.param_groups[0]['lr'] + stats_dict['Time'] = batch_time.mean + stats = ('Peformance/Training/', stats_dict) params = model.named_parameters() if args.log_params_histograms else None distiller.log_training_progress(stats, @@ -520,7 +528,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1): msglogger.info('==> Confusion:\n%s', str(confusion.value())) return classerr.value(1), classerr.value(5), losses['objective_loss'].mean else: - #print some interesting summary stats for number of data points that could exit early + # Print some interesting summary stats for number of data points that could exit early top1k_stats = [0] * args.num_exits top5k_stats = [0] * args.num_exits losses_exits_stats = [0] * args.num_exits @@ -534,8 +542,8 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1): losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean for exitnum in range(args.num_exits): if args.exit_taken[exitnum]: - msglogger.info("Percent Early Exit %d: %.3f", exitnum, (args.exit_taken[exitnum]*100.0) / sum_exit_stats) - + msglogger.info("Percent Early Exit %d: %.3f", exitnum, + (args.exit_taken[exitnum]*100.0) / sum_exit_stats) return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1] @@ -563,6 +571,7 @@ def get_inference_var(tensor): return torch.autograd.Variable(tensor) return torch.autograd.Variable(tensor, volatile=True) + def earlyexit_loss(output, target_var, criterion, args): loss = 0 sum_lossweights = 0 @@ -575,6 +584,7 @@ def earlyexit_loss(output, target_var, criterion, args): args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target_var) return loss + def earlyexit_validate_loss(output, target_var, criterion, args): for exitnum in range(args.num_exits): args.loss_exits[exitnum] = criterion(output[exitnum], target_var) @@ -594,7 +604,8 @@ def earlyexit_validate_loss(output, target_var, criterion, args): args.exit_taken[exitnum] += 1 else: # skip the early exits and include results from end of net - args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum], ndmin=2)), + args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum], + ndmin=2)), torch.full([1], target_var[batchnum], dtype=torch.long)) args.exit_taken[args.num_exits-1] += 1 diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py index a8ff493..56d92c5 100755 --- a/examples/word_language_model/main.py +++ b/examples/word_language_model/main.py @@ -236,10 +236,11 @@ def train(epoch, optimizer, compression_scheduler=None): loss = criterion(output.view(-1, ntokens), targets) if compression_scheduler: - # Before running the backward phase, we add any regularization loss computed by the scheduler - regularizer_loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch, - minibatches_per_epoch=steps_per_epoch, loss=loss) - loss += regularizer_loss + # Before running the backward phase, we allow the scheduler to modify the loss + # (e.g. add regularization loss) + loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch, + minibatches_per_epoch=steps_per_epoch, loss=loss, + return_loss_components=False) optimizer.zero_grad() loss.backward() diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 0000000..6c197d0 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,68 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import os +import sys +import torch.nn as nn +from copy import deepcopy +import pytest + +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) +from distiller import ScheduledTrainingPolicy, CompressionScheduler +from distiller.policy import PolicyLoss, LossComponent + + +class DummyPolicy(ScheduledTrainingPolicy): + def __init__(self, idx): + super(DummyPolicy, self).__init__() + self.loss_val = torch.randint(0, 10000, (1,)) + self.idx = idx + + def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, + zeros_mask_dict, optimizer=None): + return PolicyLoss(loss + self.loss_val, [LossComponent('Dummy Loss ' + str(self.idx), self.loss_val)]) + + +@pytest.mark.parametrize("check_loss_components", [False, True]) +def test_multiple_policies_loss(check_loss_components): + model = nn.Module() + scheduler = CompressionScheduler(model, device=torch.device('cpu')) + num_policies = 3 + expected_overall_loss = 0 + expected_policy_losses = [] + for i in range(num_policies): + policy = DummyPolicy(i) + expected_overall_loss += policy.loss_val + expected_policy_losses.append(policy.loss_val) + scheduler.add_policy(policy, epochs=[0]) + + main_loss = torch.randint(0, 10000, (1,)) + expected_overall_loss += main_loss + main_loss_before = deepcopy(main_loss) + + policies_loss = scheduler.before_backward_pass(0, 0, 1, main_loss, return_loss_components=check_loss_components) + + assert main_loss_before == main_loss + if check_loss_components: + assert expected_overall_loss == policies_loss.overall_loss + for idx, lc in enumerate(policies_loss.loss_components): + assert lc.name == 'Dummy Loss ' + str(idx) + assert expected_policy_losses[idx] == lc.value.item() + else: + assert expected_overall_loss == policies_loss -- GitLab