From 60a4f44a11435cea757ca807788493b5d044506d Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 5 Nov 2018 15:59:32 +0200
Subject: [PATCH] Dynamic Network Surgery (#69)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Added an implementation of:
Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
NIPS 2016, https://arxiv.org/abs/1608.04493.

- Added SplicingPruner: a pruner that both prunes and splices connections.
- Included an example schedule on ResNet20 CIFAR.
- New features for compress_classifier.py:
  1. Added the "--masks-sparsity" flag which, when enabled, logs the sparsity
     of the weight masks during training.
  2. Added a new command-line argument to report the top N best accuracy
     scores, instead of just the highest score. This is sometimes useful when
     pruning a pre-trained model, because the best Top1 accuracy may occur in
     the first few pruning epochs.
- New features for PruningPolicy:
  1. The pruning policy can use two copies of the weights: one is used during
     the forward-pass, the other during the backward-pass. This is controlled
     by the "use_double_copies" and "mask_on_forward_only" arguments.
  2. If we enable "mask_on_forward_only", we probably want to permanently
     apply the mask at some point (usually once the pruning phase is done).
     This is controlled by the "keep_mask" argument.
  3. We introduce a first implementation of scheduling at the
     training-iteration granularity (i.e. at the mini-batch granularity).
     Until now we could only schedule pruning at the epoch granularity. This
     is controlled by the "mini_batch_pruning_frequency" argument (disable by
     setting it to zero).

Some of the abstractions may have leaked from PruningPolicy to
CompressionScheduler. We need to reexamine this in the future.
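The mechanism in item 1 above can be summarized with a small, self-contained
PyTorch sketch. This is an illustration of the idea only, not the Distiller
API: the toy layer, mask, and training loop below are hypothetical. The point
is that the loss and gradients are computed with masked weights, while the
optimizer update is applied to the dense (unmasked) weights, so pruned
connections can later be spliced back in.

# Hypothetical illustration of masking on the forward-pass only, while
# updating a dense copy of the weights (the names below are not Distiller's).
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 4)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)
# Toy mask: keep roughly the larger half of the weights.
mask = (layer.weight.abs() >= layer.weight.abs().median()).float()

for step in range(10):
    x, target = torch.randn(16, 8), torch.randn(16, 4)

    dense_copy = layer.weight.data.clone()   # "use_double_copies": remember the dense weights
    layer.weight.data.mul_(mask)             # the forward-pass uses masked weights
    loss = nn.functional.mse_loss(layer(x), target)

    optimizer.zero_grad()
    loss.backward()                          # gradients come from the masked-weights loss
    layer.weight.data.copy_(dense_copy)      # restore the dense weights ("mask_on_forward_only":
    optimizer.step()                         #  do not re-mask), then apply the update to them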
--- distiller/model_summaries.py | 30 +++- distiller/policy.py | 74 +++++++++- distiller/pruning/__init__.py | 1 + distiller/pruning/splicing_pruner.py | 88 +++++++++++ distiller/scheduler.py | 28 +++- distiller/utils.py | 37 ++++- .../compress_classifier.py | 28 ++-- .../resnet20.network_surgery.yaml | 139 ++++++++++++++++++ .../resnet56_cifar_filter_rank_v2.yaml | 4 +- 9 files changed, 402 insertions(+), 27 deletions(-) create mode 100755 distiller/pruning/splicing_pruner.py create mode 100755 examples/network_surgery/resnet20.network_surgery.yaml diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 6e57421..e04ce06 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -34,7 +34,7 @@ msglogger = logging.getLogger() __all__ = ['model_summary', 'weights_sparsity_summary', 'weights_sparsity_tbl_summary', - 'model_performance_summary', 'model_performance_tbl_summary'] + 'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary'] def model_summary(model, what, dataset=None): @@ -126,6 +126,34 @@ def weights_sparsity_tbl_summary(model, return_total_sparsity=False, param_dims= return t +def masks_sparsity_summary(model, scheduler, param_dims=[2, 4]): + df = pd.DataFrame(columns=['Name', 'Fine (%)']) + pd.set_option('precision', 2) + params_size = 0 + sparse_params_size = 0 + for name, param in model.state_dict().items(): + # Extract just the actual parameter's name, which in this context we treat as its "type" + if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): + mask = scheduler.zeros_mask_dict[name].mask + if mask is None: + _density = 1 + else: + _density = distiller.density(mask) + params_size += torch.numel(param) + sparse_params_size += param.numel() * _density + df.loc[len(df.index)] = ([name, (1-_density)*100]) + + assert params_size != 0 + total_sparsity = (1 - sparse_params_size/params_size)*100 + df.loc[len(df.index)] = (['Total sparsity:', total_sparsity]) + return df + + +def masks_sparsity_tbl_summary(model, scheduler, param_dims=[2, 4]): + df = masks_sparsity_summary(model, scheduler, param_dims=param_dims) + return tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f") + + # Performance data collection code follows from here down def conv_visitor(self, input, output, df, model, memo): diff --git a/distiller/policy.py b/distiller/policy.py index dc85c60..f42219f 100755 --- a/distiller/policy.py +++ b/distiller/policy.py @@ -46,7 +46,8 @@ class ScheduledTrainingPolicy(object): """A new epcoh is about to begin""" pass - def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None): + def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, + zeros_mask_dict, meta, optimizer=None): """The forward-pass of a new mini-batch is about to begin""" pass @@ -74,30 +75,89 @@ class ScheduledTrainingPolicy(object): class PruningPolicy(ScheduledTrainingPolicy): """Base class for pruning policies. - - The current implementation restricts the pruning step to the beginning of - each epoch. This can be easily changed. """ def __init__(self, pruner, pruner_args, classes=None, layers=None): + """ + Arguments: + mask_on_forward_only: controls what we do after the weights are updated by the backward pass. + In issue #53 (https://github.com/NervanaSystems/distiller/issues/53) we explain why in some + cases masked weights will be updated to a non-zero value, even if their gradients are masked + (e.g. 
when using SGD with momentum). Therefore, to circumvent this weights-update performed by + the backward pass, we usually mask the weights again - right after the backward pass. To + disable this masking set: + pruner_args['mask_on_forward_only'] = False + + use_double_copies: when set to 'True', two sets of weights are used. In the forward-pass we use + masked weights to compute the loss, but in the backward-pass we update the unmasked weights (using + gradients computed from the masked-weights loss). + + mini_batch_pruning_frequency: this controls pruning scheduling at the mini-batch granularity. Every + mini_batch_pruning_frequency training steps (i.e. mini_batches) we perform pruning. This provides more + fine-grained control over pruning than that provided by CompressionScheduler (epoch granularity). + When setting 'mini_batch_pruning_frequency' to a value other than zero, make sure to configure the policy's + schedule to once-every-epoch. + """ super(PruningPolicy, self).__init__(classes, layers) self.pruner = pruner self.levels = None - if pruner_args is not None and 'levels' in pruner_args: - self.levels = pruner_args['levels'] + self.keep_mask = False + self.mini_batch_pruning_frequency = 0 + self.mask_on_forward_only = False + self.use_double_copies = False + if pruner_args is not None: + if 'levels' in pruner_args: + self.levels = pruner_args['levels'] + self.keep_mask = pruner_args.get('keep_mask', False) + self.mini_batch_pruning_frequency = pruner_args.get('mini_batch_pruning_frequency', 0) + self.mask_on_forward_only = pruner_args.get('mask_on_forward_only', False) + self.use_double_copies = pruner_args.get('use_double_copies', False) + self.is_last_epoch = False + self.mini_batch_id = 0 # The ID of the mini_batch within the present epoch + self.global_mini_batch_id = 0 # The ID of the mini_batch within the present training session def on_epoch_begin(self, model, zeros_mask_dict, meta): msglogger.debug("Pruner {} is about to prune".format(self.pruner.name)) + self.mini_batch_id = 0 + self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1) + self.is_first_epoch = meta['current_epoch'] == meta['starting_epoch'] if self.levels is not None: self.pruner.levels = self.levels + if self.is_first_epoch: + self.global_mini_batch_id = 0 + meta['model'] = model for param_name, param in model.named_parameters(): + if self.mask_on_forward_only and self.is_first_epoch: + zeros_mask_dict[param_name].use_double_copies = self.use_double_copies + zeros_mask_dict[param_name].mask_on_forward_only = self.mask_on_forward_only self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta) - def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None): + def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, + zeros_mask_dict, meta, optimizer=None): + self.mini_batch_id += 1 + self.global_mini_batch_id += 1 + if (self.mini_batch_pruning_frequency != 0 and + self.global_mini_batch_id % self.mini_batch_pruning_frequency == 0): + for param_name, param in model.named_parameters(): + self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta) + for param_name, param in model.named_parameters(): zeros_mask_dict[param_name].apply_mask(param) + def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer): + for param_name, param in model.named_parameters(): + zeros_mask_dict[param_name].remove_mask(param) + + def on_epoch_end(self, model, zeros_mask_dict, 
meta): + """The current epoch has ended""" + is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1) + if self.keep_mask and is_last_epoch: + for param_name, param in model.named_parameters(): + zeros_mask_dict[param_name].use_double_copies = False + zeros_mask_dict[param_name].mask_on_forward_only = False + zeros_mask_dict[param_name].apply_mask(param) + class RegularizationPolicy(ScheduledTrainingPolicy): """Regularization policy. diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 24fe7e5..2f576e9 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -24,6 +24,7 @@ from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureP RandomRankedFilterPruner_AGP from .level_pruner import SparsityLevelParameterPruner from .sensitivity_pruner import SensitivityPruner +from .splicing_pruner import SplicingPruner from .structure_pruner import StructureParameterPruner from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \ RandomRankedFilterPruner, GradientRankedFilterPruner diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py new file mode 100755 index 0000000..dfe7e25 --- /dev/null +++ b/distiller/pruning/splicing_pruner.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from .pruner import _ParameterPruner +import torch +import logging +msglogger = logging.getLogger() + + +class SplicingPruner(_ParameterPruner): + """A pruner that both prunes and splices connections. + + The idea of pruning and splicing working in tandem was first proposed in the following + NIPS paper from Intel Labs China in 2016: + Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen. + NIPS 2016, https://arxiv.org/abs/1608.04493. + + A SplicingPruner works best with a Dynamic Network Surgery schedule. + The original Caffe code from the authors of the paper is available here: + https://github.com/yiwenguo/Dynamic-Network-Surgery/blob/master/src/caffe/layers/compress_conv_layer.cpp + """ + + def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0): + """Arguments: + """ + super(SplicingPruner, self).__init__(name) + self.sensitivities = sensitivities + self.low_thresh_mult = low_thresh_mult + self.hi_thresh_mult = hi_thresh_mult + self.sensitivity_multiplier = sensitivity_multiplier + + def set_param_mask(self, param, param_name, zeros_mask_dict, meta): + if param_name not in self.sensitivities: + if '*' not in self.sensitivities: + return + else: + sensitivity = self.sensitivities['*'] + else: + sensitivity = self.sensitivities[param_name] + + if not hasattr(param, '_std'): + # Compute the mean and standard-deviation once, and cache them. 
+ param._std = torch.std(param.abs()).item() + param._mean = torch.mean(param.abs()).item() + + if self.sensitivity_multiplier > 0: + # Linearly growing sensitivity - for now this is hard-coded + starting_epoch = meta['starting_epoch'] + current_epoch = meta['current_epoch'] + sensitivity *= (current_epoch - starting_epoch) * self.sensitivity_multiplier + 1 + + threshold_low = (param._mean + param._std * sensitivity) * self.low_thresh_mult + threshold_hi = (param._mean + param._std * sensitivity) * self.hi_thresh_mult + + if zeros_mask_dict[param_name].mask is None: + zeros_mask_dict[param_name].mask = torch.ones_like(param) + + # This code performs the code in equation (3) of the "Dynamic Network Surgery" paper: + # + # 0 if a > |W| + # h(W) = mask if a <= |W| < b + # 1 if b <= |W| + # + # h(W) is the so-called "network surgery function". + # mask is the mask used in the previous iteration. + # a and b are the low and high thresholds, respectively. + # We followed the example implementation from Yiwen Guo in Caffe, and used the + # weight tensor's starting mean and std. + # This is very similar to the initialization performed by distiller.SensitivityPruner. + + masked_weights = param.mul(zeros_mask_dict[param_name].mask).abs() + a = masked_weights.ge(threshold_low) + b = a & zeros_mask_dict[param_name].mask.type(torch.cuda.ByteTensor) + zeros_mask_dict[param_name].mask = (b | masked_weights.ge(threshold_hi)).type(torch.cuda.FloatTensor) diff --git a/distiller/scheduler.py b/distiller/scheduler.py index d831c19..11b7e69 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -33,6 +33,9 @@ class ParameterMasker(object): self.mask = None # Mask lazily initialized by pruners self.param_name = param_name # For debug/logging purposes self.is_regularization_mask = False + self.use_double_copies = False + self.mask_on_forward_only = False + self.unmasked_copy = None def apply_mask(self, tensor): """Apply a mask on the weights tensor.""" @@ -40,11 +43,22 @@ class ParameterMasker(object): msglogger.debug('No mask for parameter {0}'.format(self.param_name)) return msglogger.debug('Masking parameter {0}'.format(self.param_name)) + if self.use_double_copies: + self.unmasked_copy = tensor.clone() tensor.data.mul_(self.mask) if self.is_regularization_mask: self.mask = None return tensor + def remove_mask(self, tensor): + if self.mask is None: + msglogger.debug('No mask for parameter {0}'.format(self.param_name)) + return + if not self.use_double_copies: + msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name)) + return + tensor.data = self.unmasked_copy.data + def create_model_masks_dict(model): """A convinience function to create a dictionary of paramter maskers for a model""" @@ -64,7 +78,6 @@ class CompressionScheduler(object): self.device = device self.policies = {} self.sched_metadata = {} - self.zeros_mask_dict = {} for name, param in self.model.named_parameters(): masker = ParameterMasker(name) @@ -101,8 +114,10 @@ class CompressionScheduler(object): def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None): if epoch in self.policies: for policy in self.policies[epoch]: + meta = self.sched_metadata[policy] + meta['current_epoch'] = epoch policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch, - self.zeros_mask_dict, optimizer) + self.zeros_mask_dict, meta, optimizer) def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None, return_loss_components=False): @@ 
-131,7 +146,7 @@ class CompressionScheduler(object): # # Therefore we choose to always apply the pruning mask. In the future we may optimize this by applying # the mask only if the some policy is actually using the mask. - self.apply_mask() + self.apply_mask(is_forward=False) if epoch in self.policies: for policy in self.policies[epoch]: policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch, @@ -145,10 +160,13 @@ class CompressionScheduler(object): meta['optimizer'] = optimizer policy.on_epoch_end(self.model, self.zeros_mask_dict, meta) - def apply_mask(self): + def apply_mask(self, is_forward=True): for name, param in self.model.named_parameters(): try: - self.zeros_mask_dict[name].apply_mask(param) + if is_forward or not self.zeros_mask_dict[name].mask_on_forward_only: + # When we mask on forward-pass only, we allow the gradients to change + # the weights. + self.zeros_mask_dict[name].apply_mask(param) except KeyError: # Quantizers for training modify some model parameters by adding a prefix # If this is the source of the error, workaround and move on diff --git a/distiller/utils.py b/distiller/utils.py index 9e93755..25ccc08 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -49,6 +49,21 @@ def pretty_int(i): return "{:,}".format(i) +class MutableNamedTuple(dict): + def __init__(self, init_dict): + for k, v in init_dict.items(): + self[k] = v + + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, val): + if key in self.__dict__: + self.__dict__[key] = val + else: + self[key] = val + + def assign_layer_fq_names(container, name=None): """Assign human-readable names to the modules (layers). @@ -108,11 +123,13 @@ def denormalize_module_name(parallel_model, normalized_name): return normalized_name # Did not find a module with the name <normalized_name> -def volume(tensor_desc): +def volume(tensor): """return the volume of a pytorch tensor""" - if isinstance(tensor_desc, tuple): - return np.prod(tensor_desc) - return np.prod(tensor_desc.shape) + if isinstance(tensor, torch.FloatTensor) or isinstance(tensor, torch.cuda.FloatTensor): + return np.prod(tensor.shape) + if isinstance(tensor, tuple): + return np.prod(tensor) + raise ValueError def density(tensor): @@ -269,6 +286,18 @@ def density_rows(tensor, transposed=True): return 1 - sparsity_rows(tensor, transposed) +def model_sparsity(model, param_dims=[2, 4]): + params_size = 0 + sparse_params_size = 0 + for name, param in model.state_dict().items(): + if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']): + _density = density(param) + params_size += torch.numel(param) + sparse_params_size += param.numel() * _density + total_sparsity = (1 - sparse_params_size/params_size)*100 + return total_sparsity + + def norm_filters(weights, p=1): """Compute the p-norm of convolution filters. 
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 588b58d..d0663c6 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -120,6 +120,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None, help='collect activation statistics (WARNING: this slows down training)') +parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False, + help='print masks sparsity table at end of each epoch') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)') SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx'] @@ -148,6 +150,8 @@ parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earl help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)') parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None, help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)') +parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int, + help='number of best scores to track and report (default: 1)') quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time' '("post-training quantization)') @@ -238,8 +242,8 @@ def main(): msglogger.debug("Distiller: %s", distiller.__version__) start_epoch = 0 - best_top1 = 0 - best_epoch = 0 + best_epochs = [distiller.MutableNamedTuple({'epoch': 0, 'top1': 0, 'sparsity': 0}) + for i in range(args.num_best_scores)] if args.deterministic: # Experiment reproducibility is sometimes important. 
Pete Warden expounded about this @@ -372,6 +376,8 @@ def main(): distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger]) distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger], collector=collectors["sparsity"]) + if args.masks_sparsity: + msglogger.info(distiller.masks_sparsity_tbl_summary(model, compression_scheduler)) # evaluate on validation set with collectors_context(activations_collectors["valid"]) as collectors: @@ -391,13 +397,18 @@ def main(): compression_scheduler.on_epoch_end(epoch, optimizer) # remember best top1 and save checkpoint - is_best = top1 > best_top1 + #sparsity = distiller.model_sparsity(model) + is_best = top1 > best_epochs[0].top1 if is_best: - best_epoch = epoch - best_top1 = top1 - msglogger.info('==> Best Top1: %.3f On Epoch: %d\n', best_top1, best_epoch) - apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best, - args.name, msglogger.logdir) + best_epochs[0].epoch = epoch + best_epochs[0].top1 = top1 + #best_epoch.sparsity = sparsity + best_epochs = sorted(best_epochs, key=lambda score: score.top1) + for score in reversed(best_epochs): + if score.top1 > 0: + msglogger.info('==> Best Top1: %.3f on Epoch: %d', score.top1, score.epoch) + apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, + best_epochs[0].top1, is_best, args.name, msglogger.logdir) # Finally run results on the test set test(test_loader, model, criterion, [pylogger], activations_collectors, args=args) @@ -525,7 +536,6 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args): with collectors_context(activations_collectors["test"]) as collectors: top1, top5, lossses = _validate(test_loader, model, criterion, loggers, args) distiller.log_activation_statsitics(-1, "test", loggers, collector=collectors['sparsity']) - return top1, top5, lossses diff --git a/examples/network_surgery/resnet20.network_surgery.yaml b/examples/network_surgery/resnet20.network_surgery.yaml new file mode 100755 index 0000000..ba8740e --- /dev/null +++ b/examples/network_surgery/resnet20.network_surgery.yaml @@ -0,0 +1,139 @@ +# This schedule follows the methodology proposed by Intel Labs China in the paper: +# Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen. +# NIPS 2016, https://arxiv.org/abs/1600.504493. +# +# We have not been able to see great results with our implementation of the paper, as we understood it. In the paper +# two sets of weights are used: in the forward-pass Guo et. al use masked weights to compute the loss, but in the +# backward-pass they update the unmasked weights (using gradients computed from the masked-weights loss). +# To replicate this behavior set in the SplicingPruner policy: +# use_double_copies: True +# mask_on_forward_only: True +# +# We found that using two copies of weights reduces the accuracy results, and so we disable this configuration in the +# example schedule below. +# +# The "mask_on_forward_only" configuration controls what we do after the weights are updated by the backward pass. +# In issue #53 (https://github.com/NervanaSystems/distiller/issues/53) we explain why in some cases masked weights +# will be updated to a non-zero value, even if their gradients are masked (e.g. when using SGD with momentum). +# Therefore, to circumvent this weights-update performed by the backward pass, we usually mask the weights again - +# right after the backward pass. 
To disable this masking set: +# mask_on_forward_only: False +# +# +# Baseline results: +# Top1: 91.780 Top5: 99.710 Loss: 0.376 +# Total MACs: 40,813,184 +# # of parameters: 270,896 +# +# Results: +# Best Top1: 91.490 (on Epoch: 339) +# Total MACs: 40,813,184 +# Total sparsity: 69.1% +# # of parameters: 83,671 +# +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic --validation-size=0 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10 +# +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.36322 | -0.00677 | 0.25811 | +# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 641 | 0.00000 | 0.00000 | 6.25000 | 19.53125 | 0.00000 | 72.17882 | 0.12306 | -0.00732 | 0.05892 | +# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 629 | 0.00000 | 0.00000 | 0.00000 | 19.14062 | 0.00000 | 72.69965 | 0.12015 | -0.00124 | 0.05719 | +# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 670 | 0.00000 | 0.00000 | 0.00000 | 16.40625 | 0.00000 | 70.92014 | 0.10390 | -0.00880 | 0.05262 | +# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 659 | 0.00000 | 0.00000 | 0.00000 | 14.84375 | 0.00000 | 71.39757 | 0.09757 | -0.00261 | 0.04895 | +# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 714 | 0.00000 | 0.00000 | 0.00000 | 20.31250 | 0.00000 | 69.01042 | 0.13813 | -0.00660 | 0.07117 | +# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 685 | 0.00000 | 0.00000 | 0.00000 | 19.14062 | 0.00000 | 70.26910 | 0.10906 | 0.00146 | 0.05552 | +# | 7 | module.layer2.0.conv1.weight | (32, 16, 3, 3) | 4608 | 1534 | 0.00000 | 0.00000 | 0.00000 | 14.64844 | 0.00000 | 66.71007 | 0.10999 | 0.00096 | 0.05983 | +# | 8 | module.layer2.0.conv2.weight | (32, 32, 3, 3) | 9216 | 2915 | 0.00000 | 0.00000 | 0.00000 | 8.30078 | 0.00000 | 68.37023 | 0.09245 | -0.00413 | 0.04926 | +# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 137 | 0.00000 | 0.00000 | 0.00000 | 73.24219 | 0.00000 | 73.24219 | 0.20471 | -0.00300 | 0.09534 | +# | 10 | module.layer2.1.conv1.weight | (32, 32, 3, 3) | 9216 | 2738 | 0.00000 | 0.00000 | 0.00000 | 9.96094 | 0.00000 | 70.29080 | 0.08032 | -0.00450 | 0.04164 | +# | 11 | module.layer2.1.conv2.weight | (32, 32, 3, 3) | 9216 | 2791 | 0.00000 | 0.00000 | 0.00000 | 8.39844 | 0.00000 | 69.71571 | 0.07054 | -0.00283 | 0.03704 | +# | 12 | module.layer2.2.conv1.weight | (32, 32, 3, 3) | 9216 | 2770 | 0.00000 | 0.00000 | 0.00000 | 11.03516 | 0.00000 | 69.94358 | 0.08060 | -0.00651 | 0.04184 | +# | 13 | module.layer2.2.conv2.weight | (32, 32, 3, 3) | 9216 | 2745 | 0.00000 | 0.00000 | 0.00000 | 11.52344 | 0.00000 | 70.21484 | 0.06489 | 0.00081 | 0.03376 | +# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 5852 | 0.00000 | 
0.00000 | 0.00000 | 13.37891 | 0.00000 | 68.25087 | 0.07912 | -0.00226 | 0.04262 | +# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 11772 | 0.00000 | 0.00000 | 0.00000 | 4.83398 | 0.00000 | 68.06641 | 0.07390 | -0.00186 | 0.03979 | +# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 667 | 0.00000 | 0.00000 | 0.00000 | 67.43164 | 0.00000 | 67.43164 | 0.11225 | -0.00579 | 0.06123 | +# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 10906 | 0.00000 | 0.00000 | 0.00000 | 6.98242 | 0.00000 | 70.41558 | 0.07001 | -0.00362 | 0.03640 | +# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 10850 | 0.00000 | 0.00000 | 0.00000 | 8.20312 | 0.00000 | 70.56749 | 0.06218 | -0.00419 | 0.03227 | +# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11053 | 0.00000 | 0.00000 | 0.00000 | 10.42480 | 0.00000 | 70.01682 | 0.06197 | -0.00483 | 0.03245 | +# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11871 | 0.00000 | 0.00000 | 0.00000 | 23.90137 | 0.00000 | 67.79785 | 0.03901 | 0.00004 | 0.02103 | +# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.56654 | -0.00002 | 0.48838 | +# | 22 | Total sparsity: | - | 270896 | 83671 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 69.11324 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-11-03 12:20:32,056 - Total sparsity: 69.11 +# 2018-11-03 12:20:32,108 - --- validate (epoch=359)----------- +# 2018-11-03 12:20:32,109 - 10000 samples (256 per mini-batch) +# 2018-11-03 12:20:34,825 - ==> Top1: 91.430 Top5: 99.670 Loss: 0.368 +# +# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.490 on Epoch: 339 <====== BEST RESULT +# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.480 on Epoch: 222 +# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.470 on Epoch: 280 +# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.450 on Epoch: 328 +# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.440 on Epoch: 298 +# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.440 on Epoch: 314 +# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.440 on Epoch: 319 +# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 283 +# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 304 +# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 359 +# 2018-11-03 12:20:34,828 - Saving checkpoint to: logs/2018.11.03-111011/checkpoint.pth.tar +# 2018-11-03 12:20:34,859 - --- test --------------------- +# 2018-11-03 12:20:34,859 - 10000 samples (256 per mini-batch) +# 2018-11-03 12:20:37,610 - ==> Top1: 91.430 Top5: 99.670 Loss: 0.368 + + +version: 1 +pruners: + pruner1: + class: SplicingPruner + low_thresh_mult: 0.9 # 0.6 + hi_thresh_mult: 1.1 # 0.7 + sensitivity_multiplier: 0.015 + sensitivities: + #module.conv1.weight: 0.50 + module.layer1.0.conv1.weight: 0.050 + module.layer1.0.conv2.weight: 0.050 + module.layer1.1.conv1.weight: 0.050 + module.layer1.1.conv2.weight: 0.050 + module.layer1.2.conv1.weight: 0.010 + module.layer1.2.conv2.weight: 0.050 + module.layer2.0.conv1.weight: 0.010 + module.layer2.0.conv2.weight: 0.030 + module.layer2.0.downsample.0.weight: 0.050 + module.layer2.1.conv1.weight: 0.050 + module.layer2.1.conv2.weight: 0.050 + module.layer2.2.conv1.weight: 0.040 + module.layer2.2.conv2.weight: 0.050 + module.layer3.0.conv1.weight: 0.040 + module.layer3.0.conv2.weight: 
0.020 + module.layer3.0.downsample.0.weight: 0.050 + module.layer3.1.conv1.weight: 0.050 + module.layer3.1.conv2.weight: 0.050 + module.layer3.2.conv1.weight: 0.050 + module.layer3.2.conv2.weight: 0.050 + #module.fc.weight + +lr_schedulers: + training_lr: + class: StepLR + step_size: 45 + gamma: 0.10 + +policies: + - pruner: + instance_name: pruner1 + args: + keep_mask: True + mini_batch_pruning_frequency: 1 + mask_on_forward_only: True + # use_double_copies: True + starting_epoch: 180 + ending_epoch: 280 + frequency: 1 + + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 225 + ending_epoch: 400 + frequency: 1 diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml index d0f58ce..5f67404 100755 --- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml @@ -21,11 +21,13 @@ # # Baseline results: # Top1: 92.850 Top5: 99.780 Loss: 0.464 +# Parameters: 851,504 # Total MACs: 125,747,840 # # Results: # Top1: 92.740 Top5: 99.640 Loss: 1.534 -# Total MACs: 67,797,632 +# Parameters: 570,704 (= 33% sparse ) +# Total MACs: 67,797,632 (=1.85x less MACs) # # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ -- GitLab
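Postscript: the network-surgery function h(W) from equation (3), quoted in the
SplicingPruner comments above, can be illustrated with the following
self-contained sketch. The helper name surgery_mask, the toy tensor, and the
threshold values are hypothetical and are not part of this patch. Weights
whose magnitude falls below the low threshold a are pruned, weights at or
above the high threshold b are activated (spliced back in), and weights in
between keep their previous mask value.

# Hypothetical sketch of equation (3) from "Dynamic Network Surgery":
#           0         if a <= |W| is false (i.e. a > |W|)
#   h(W) =  prev_mask if a <= |W| < b
#           1         if b <= |W|
import torch

def surgery_mask(weight, prev_mask, threshold_low, threshold_hi):
    magnitude = weight.abs()
    in_between = magnitude.ge(threshold_low) & magnitude.lt(threshold_hi)
    # Prune below a, splice above b, otherwise keep the previous decision.
    return torch.where(in_between, prev_mask, magnitude.ge(threshold_hi).float())

weights = torch.randn(64, 64)
# Thresholds derived from the tensor's starting mean/std, as in SplicingPruner.
mean, std = weights.abs().mean(), weights.abs().std()
sensitivity, low_mult, hi_mult = 0.5, 0.9, 1.1   # illustrative values only
a = (mean + std * sensitivity) * low_mult
b = (mean + std * sensitivity) * hi_mult

mask = torch.ones_like(weights)
mask = surgery_mask(weights, mask, a, b)
print('mask density after surgery: {:.2f}'.format(mask.mean().item()))

Note that this sketch applies the thresholds to the raw weight magnitudes, as
in the quoted equation, whereas the SplicingPruner code above thresholds the
masked magnitudes (param.mul(mask).abs()).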