diff --git a/distiller/norms.py b/distiller/norms.py
index bafda4e16795a78decba1058756e2727f680e1cf..a73a07fc1ee65d45fb5e5289e5446afda69f038a 100644
--- a/distiller/norms.py
+++ b/distiller/norms.py
@@ -40,7 +40,7 @@
 __all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm",
            "rows_lp_norm", "cols_lp_norm",
            "rows_norm", "cols_norm",
            "l1_norm", "l2_norm", "max_norm",
-           "rank_channels", "rank_filters", "rank_cols"]
+           "rank_channels", "rank_filters"]
 
 
 class NamedFunction:
@@ -138,9 +138,12 @@ def channels_lp_norm(param, p=1, group_len=1, length_normalized=False):
 
 
 def channels_norm(param, norm_fn, group_len=1, length_normalized=False):
-    """Compute some norm of 3D channels of 4D parameter tensors.
+    """Compute some norm of parameter input-channels.
+
+    Computing the norms of a weight-matrix's input channels is logically similar to computing
+    the norms of a 4D Conv weight tensor's input channels, so we treat both cases.
+    Assumes 2D or 4D weights tensors.
 
-    Assumes 4D weights tensors.
     Args:
        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
        norm_fn: a callable that computes a norm
@@ -154,6 +157,13 @@ def channels_norm(param, norm_fn, group_len=1, length_normalized=False):
     """
     assert param.dim() in (2, 4), "param has invalid dimensions"
     if param.dim() == 2:
+        # For GEMM operations, PyTorch stores the weights matrices in a transposed format, i.e.
+        # a matrix is transposed before the GEMM is performed.  This is because the output is
+        # computed as follows (see torch.nn.functional.linear):
+        #     y = x(W^T) + b ; where W^T is the transpose of W
+        #
+        # Therefore, W is expected to have shape (output_channels, input_channels), and to compute
+        # the norms of input_channels, we compute the norms of W's columns.
        return cols_norm(param, norm_fn, group_len, length_normalized)
     param = param.transpose(0, 1).contiguous()
     group_size = group_len * np.prod(param.shape[1:])
@@ -316,13 +326,3 @@ def rank_filters(param, group_len, magnitude_fn, fraction_to_partition, rounding
     mags = filters_norm(param, magnitude_fn, group_len, length_normalized=True)
     return k_smallest_elems(mags, n_filters_to_prune, noise)
-
-def rank_cols(param, group_len, magnitude_fn, fraction_to_partition, rounding_fn, noise):
-    assert param.dim() == 2, "This ranking is only supported for 2D tensors"
-    COLS_DIM = 0
-    n_cols = param.size(COLS_DIM)
-    n_cols_to_prune = num_structs_to_prune(n_cols, group_len, fraction_to_partition, rounding_fn)
-    if n_cols_to_prune == 0:
-        return None, None
-    mags = cols_norm(param, magnitude_fn, group_len, length_normalized=True)
-    return k_smallest_elems(mags, n_cols_to_prune, noise)
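To make the transposed-GEMM convention above concrete, here is a standalone sketch (illustration only, not part of the patch; the shapes are made up) showing that the per-input-channel L1 magnitudes of an nn.Linear weight are exactly the column norms of W:

import torch

W = torch.randn(4, 3)   # nn.Linear stores weights as (output_channels=4, input_channels=3)
x = torch.randn(1, 3)
b = torch.zeros(4)
y = x @ W.t() + b       # what torch.nn.functional.linear computes

# Input channel j only ever multiplies column j of W, so one L1 norm per
# input channel is a column-wise sum of absolute values:
col_l1 = W.abs().sum(dim=0)                    # shape: (3,)
assert torch.allclose(col_l1, W.t().norm(p=1, dim=1))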
diff --git a/distiller/policy.py b/distiller/policy.py
index d460d2476a4cea2a94c8b3e34316bc55901c9909..e040d6559903b36b31dbdbab6d294c72f9ca5c37 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -22,14 +22,17 @@
 - QuantizationPolicy: quantization scheduling
 """
 import torch
+import torch.nn as nn
 import torch.optim.lr_scheduler
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 import logging
-msglogger = logging.getLogger()
+import distiller
+
 __all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy',
            'ScheduledTrainingPolicy', 'PolicyLoss', 'LossComponent']
+msglogger = logging.getLogger()
 
 PolicyLoss = namedtuple('PolicyLoss', ['overall_loss', 'loss_components'])
 LossComponent = namedtuple('LossComponent', ['name', 'value'])
@@ -94,7 +97,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
                 disable this masking set:
                     pruner_args['mask_on_forward_only'] = False
-            use_double_copies: when set to 'True', two sets of weights are used. In the forward-pass we use
+            use_double_copies: when set to `True`, two sets of weights are used. In the forward-pass we use
                 masked weights to compute the loss, but in the backward-pass we update the unmasked weights
                 (using gradients computed from the masked-weights loss).
 
@@ -103,7 +106,15 @@ class PruningPolicy(ScheduledTrainingPolicy):
                 fine-grained control over pruning than that provided by CompressionScheduler (epoch
                 granularity).  When setting 'mini_batch_pruning_frequency' to a value other than zero,
                 make sure to configure the policy's schedule to once-every-epoch.
-        """
+
+            fold_batchnorm: when set to `True`, the weights of BatchNorm modules are folded into the weights of
+                Conv-2D modules (if Conv2D->BN edges exist in the model graph).  Each weight filter is attenuated
+                using a different pair of (gamma, beta) coefficients, so `fold_batchnorm` is relevant for
+                fine-grained and filter-ranking pruning methods.  We attenuate using the running values of the
+                mean and variance, as is done in quantization.
+                This control argument is only supported for Conv-2D modules (i.e. other convolution operation
+                variants and Linear operations are not supported).
+        """
         super(PruningPolicy, self).__init__(classes, layers)
         self.pruner = pruner
         # Copy external policy configuration, if available
@@ -120,10 +131,51 @@ class PruningPolicy(ScheduledTrainingPolicy):
         self.use_double_copies = pruner_args.get('use_double_copies', False)
         self.discard_masks_at_minibatch_end = pruner_args.get('discard_masks_at_minibatch_end', False)
         self.skip_first_minibatch = pruner_args.get('skip_first_minibatch', False)
-        # Initiliaze state
+        self.fold_bn = pruner_args.get('fold_batchnorm', False)
+        # These are required for BN-folding.  We cache them to improve performance
+        self.named_modules = None
+        self.sg = None
+        # Initialize state
         self.is_last_epoch = False
         self.is_initialized = False
 
+    @staticmethod
+    def _fold_batchnorm(model, param_name, param, named_modules, sg):
+        def _get_all_parameters(param_module, bn_module):
+            w, b, gamma, beta = param_module.weight, param_module.bias, bn_module.weight, bn_module.bias
+            if not bn_module.affine:
+                gamma = 1.
+                beta = 0.
+            return w, b, gamma, beta
+
+        def get_bn_folded_weights(conv_module, bn_module):
+            """Compute the weights of `conv_module` after folding the successor BN layer.
+
+            In inference, DL frameworks and graph-compilers fold the batch normalization into
+            the weights, as defined by equations 20 and 21 of https://arxiv.org/pdf/1806.08342.pdf
+
+            :param conv_module: nn.Conv2d module
+            :param bn_module: nn.BatchNorm2d module which succeeds `conv_module`
+            :return: Folded weights
+            """
+            w, b, gamma, beta = _get_all_parameters(conv_module, bn_module)
+            with torch.no_grad():
+                sigma_running = torch.sqrt(bn_module.running_var + bn_module.eps)
+                w_corrected = w * (gamma / sigma_running).view(-1, 1, 1, 1)
+            return w_corrected
+
+        layer_name = distiller.utils.param_name_2_module_name(param_name)
+        if not isinstance(named_modules[layer_name], nn.Conv2d):
+            return param
+
+        bn_layers = sg.successors_f(layer_name, ['BatchNormalization'])
+        if bn_layers:
+            assert len(bn_layers) == 1
+            bn_module = named_modules[bn_layers[0]]
+            conv_module = named_modules[layer_name]
+            param = get_bn_folded_weights(conv_module, bn_module)
+        return param
+
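As a sanity check of the folding arithmetic in get_bn_folded_weights() (a standalone sketch, illustration only and not part of the patch; the layer sizes are made up), scaling each filter by gamma/sigma reproduces the Conv->BN composition when the BN mean and beta terms are zeroed out — the bias terms are handled separately by equations 20-21 of the cited paper:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(8).eval()        # inference mode: use the running statistics
bn.running_var.uniform_(0.5, 2.0)    # pretend these statistics were gathered during training
bn.running_mean.zero_()
bn.bias.data.zero_()                 # beta = 0, so only the per-filter scale matters here

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # gamma / sigma_running
    w_folded = conv.weight * scale.view(-1, 1, 1, 1)          # one scale per output filter

    x = torch.randn(1, 3, 8, 8)
    assert torch.allclose(bn(conv(x)), nn.functional.conv2d(x, w_folded), atol=1e-5)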
     def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
         msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
@@ -132,7 +184,16 @@ class PruningPolicy(ScheduledTrainingPolicy):
         meta['model'] = model
         is_initialized = self.is_initialized
+
+        if self.fold_bn:
+            # Cache this information (required for BN-folding) to improve performance
+            self.named_modules = OrderedDict(model.named_modules())
+            dummy_input = torch.randn(model.input_shape)
+            self.sg = distiller.SummaryGraph(model, dummy_input)
+
         for param_name, param in model.named_parameters():
+            if self.fold_bn:
+                param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
             if not is_initialized:
                 # Initialize the maskers
                 masker = zeros_mask_dict[param_name]
@@ -141,6 +202,7 @@
                 # register for the backward hook of the parameters
                 if self.mask_gradients:
                     masker.backward_hook_handle = param.register_hook(masker.mask_gradient)
+
                 self.is_initialized = True
                 if not self.skip_first_minibatch:
                     self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
@@ -164,6 +226,8 @@
         for param_name, param in model.named_parameters():
             if set_masks:
+                if self.fold_bn:
+                    param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
                 self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
             zeros_mask_dict[param_name].apply_mask(param)
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index d906a12e24ad3e59b917ed0e555267bb6b6de414..c2e6d4ce44291895fc9dda0642915c0a86e18b9c 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -56,9 +56,10 @@ class AutomatedGradualPrunerBase(_ParameterPruner):
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         target_sparsity = self.compute_target_sparsity(meta)
-        self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
+        self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
+    def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
+        """Set the parameter mask using a target sparsity.
Override this in subclasses""" raise NotImplementedError @@ -78,7 +79,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): return super().set_param_mask(param, param_name, zeros_mask_dict, meta) - def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): + def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity) @@ -92,8 +93,8 @@ class StructuredAGP(AutomatedGradualPrunerBase): super().__init__(name, initial_sparsity, final_sparsity) self.pruner = None - def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model): - self.pruner.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model) + def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model): + self.pruner._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, model) class L1RankedStructureParameterPruner_AGP(StructuredAGP): diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 145303140f87b508097943aacc9ca6ed5bdb6af7..ed66d78d2e0ac7fa549a6ccc5bb39b09e24071d5 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -74,9 +74,9 @@ class _RankedStructureParameterPruner(_ParameterPruner): model = meta['model'] except TypeError: model = None - return self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, fraction_to_prune, model) + return self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, fraction_to_prune, model) - def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model): + def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model): if not self.is_supported(param_name): return @@ -127,12 +127,10 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner): def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): if fraction_to_prune == 0: return - if self.group_type in ['3D', 'Filters']: + if self.group_type in ('3D', 'Filters'): group_pruning_fn = partial(self.rank_and_prune_filters, noise=self.noise) - elif self.group_type == 'Channels': + elif self.group_type in ('Channels', 'Rows'): group_pruning_fn = partial(self.rank_and_prune_channels, noise=self.noise) - elif self.group_type == 'Rows': - group_pruning_fn = self.rank_and_prune_rows elif self.group_type == 'Blocks': group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape) @@ -147,12 +145,6 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner): model=None, binary_map=None, magnitude_fn=distiller.norms.l1_norm, noise=0.0, group_size=1, rounding_fn=math.floor): if binary_map is None: - - if param.dim() == 2: - # 2D Linear parameters - return LpRankedStructureParameterPruner.rank_and_prune_rows(fraction_to_prune, param, param_name, - zeros_mask_dict, model, binary_map, - magnitude_fn, group_size) bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn, fraction_to_prune, rounding_fn, noise) if bottomk_channels is None: @@ -194,43 +186,6 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner): fraction_to_prune) return binary_map - @staticmethod - 
def rank_and_prune_rows(fraction_to_prune, param, param_name, - zeros_mask_dict, model=None, binary_map=None, - magnitude_fn=distiller.norms.l1_norm, group_size=1): - """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows. - - PyTorch stores the weights matrices in a transposed format. I.e. before performing GEMM, a matrix is - transposed. This is because the output is computed as follows: - y = x(W^T) + b ; where W^T is the transpose of W - - Removing input_channels from W^T, is removing rows of W^T, which is removing columns of W. - - To deal with this rotation, we can either transpose the matrix and then proceed to compute the masks - as usual, or we can treat columns as rows, and rows as columns :-(. - We choose the latter, because transposing very large matrices can be detrimental to performance. Note - that computing mean L1-norm of columns is also not optimal, because consecutive column elements are far - away from each other in memory, and this means poor use of caches and system memory. - """ - if binary_map is None: - bottomk_cols, cols_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn, fraction_to_prune, - rounding_fn=math.floor, noise=None) - if bottomk_cols is None: - # Empty list means that fraction_to_prune is too low to prune anything - msglogger.info("Too few cols - can't prune %.1f%% cols", 100 * fraction_to_prune) - return - threshold = bottomk_cols[-1] - binary_map = cols_mags.gt(threshold).type(param.data.type()) - - if zeros_mask_dict is not None: - mask, _ = distiller.thresholding.expand_binary_map(param, 'Cols', binary_map) - zeros_mask_dict[param_name].mask = mask - msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f", - magnitude_fn, param_name, - distiller.sparsity(mask), - fraction_to_prune) - return binary_map - @staticmethod def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None, model=None, binary_map=None, block_shape=None, @@ -691,14 +646,9 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner): noise=0): assert binary_map is None if binary_map is None: - op_type = 'conv' if param.dim() == 4 else 'fc' - if op_type == 'conv': - bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn, - fraction_to_prune, rounding_fn, noise) - else: - bottomk_channels, channel_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn, - fraction_to_prune, rounding_fn=math.floor, - noise=None) + bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn, + fraction_to_prune, rounding_fn, noise) + # Todo: this little piece of code can be refactored if bottomk_channels is None: # Empty list means that fraction_to_prune is too low to prune anything diff --git a/distiller/thresholding.py b/distiller/thresholding.py index 2f93662f6fd227e3b6d2eb497a639286c8020d4b..cb38909667788baa6d5d360230a08b6031654ccc 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -162,14 +162,17 @@ def expand_binary_map(param, group_type, binary_map): assert binary_map is not None # Now let's expand back up to a 4D mask + if group_type == 'Channels' and param.dim() == 2: + group_type = 'Cols' + if group_type == '2D': a = binary_map.expand(param.size(2) * param.size(3), param.size(0) * param.size(1)).t() - return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map + return a.view(*param.shape), binary_map elif group_type == 'Rows': return 
binary_map.expand(param.size(1), param.size(0)).t(), binary_map elif group_type == 'Cols': - return binary_map.expand(param.size(0), param.size(1)), binary_map + return binary_map.expand(*param.shape), binary_map elif group_type == '3D' or group_type == 'Filters': a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t() return a.view(*param.shape), binary_map @@ -178,5 +181,5 @@ def expand_binary_map(param, group_type, binary_map): a = binary_map.expand(num_filters, num_channels) c = a.unsqueeze(-1) d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous() - return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map + return d.view(*param.shape), binary_map diff --git a/examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml new file mode 100755 index 0000000000000000000000000000000000000000..f23ef843e51c874f239ca36ba98b89d978eefa88 --- /dev/null +++ b/examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml @@ -0,0 +1,140 @@ +# This is the same as examples/agp-pruning/resnet20_filters.schedule_agp.yaml, but with BN-folding for ranking +# +# Baseline results: +# Top1: 91.780 Top5: 99.710 Loss: 0.376 +# Total MACs: 40,813,184 +# # of parameters: 270,896 +# +# Results: +# Top1: 91.34 +# Total MACs: 30,655,104 +# Total sparsity: 46.3 +# # of parameters: 120,000 (=55.7% of the baseline parameters) +# +# time python3 compress_classifier.py --arch resnet20_cifar $CIFAR10_PATH -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer --gpu=0 +# +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.41630 | -0.00526 | 0.29222 | +# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15170 | -0.01458 | 0.10325 | +# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14991 | 0.00199 | 0.10401 | +# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13939 | -0.01859 | 0.10473 | +# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13140 | -0.00861 | 0.10033 | +# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17184 | -0.00677 | 0.11853 | +# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13944 | -0.00140 | 0.09709 | +# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17399 
| -0.00476 | 0.13297 | +# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18404 | 0.00603 | 0.14192 | +# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.34235 | -0.01751 | 0.25170 | +# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09837 | -0.00903 | 0.07080 | +# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08172 | -0.00530 | 0.05841 | +# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12200 | -0.00766 | 0.09124 | +# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08811 | 0.00425 | 0.06511 | +# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13276 | -0.00406 | 0.10314 | +# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09774 | -0.00467 | 0.07597 | +# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) | 1024 | 1024 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16717 | -0.00613 | 0.13268 | +# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.76367 | 1.56250 | 69.99783 | 0.07646 | -0.00394 | 0.03609 | +# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 1.56250 | 10.59570 | 0.00000 | 69.99783 | 0.07015 | -0.00542 | 0.03294 | +# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 13.37891 | 3.12500 | 69.99783 | 0.07038 | -0.00402 | 0.03317 | +# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 3.12500 | 30.32227 | 0.00000 | 69.99783 | 0.04269 | 0.00014 | 0.01844 | +# | 21 | module.fc.weight | (10, 64) | 640 | 320 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.55338 | -0.00001 | 0.31654 | +# | 22 | Total sparsity: | - | 223536 | 120000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 46.31737 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# Total sparsity: 46.32 +# +# --- validate (epoch=179)----------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 91.120 Top5: 99.660 Loss: 0.374 +# +# ==> Best [Top1: 91.340 Top5: 99.650 Sparsity:46.32 NNZ-Params: 120000 on epoch: 109] +# Saving checkpoint to: logs/2019.10.31-200150/checkpoint.pth.tar +# --- test --------------------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 91.120 Top5: 99.660 Loss: 0.364 +# +# +# real 33m39.771s +# user 289m4.681s +# sys 17m43.188s + +version: 1 + +pruners: + low_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.10 + final_sparsity: 0.50 + group_type: Filters + weights: [module.layer2.0.conv1.weight, module.layer2.0.conv2.weight, + module.layer2.0.downsample.0.weight, + module.layer2.1.conv2.weight, module.layer2.2.conv2.weight, + module.layer2.1.conv1.weight, module.layer2.2.conv1.weight] + + fine_pruner: + 
class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.70 + weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight, + module.layer3.2.conv1.weight, module.layer3.2.conv2.weight] + + fc_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.50 + group_type: Rows + weights: [module.fc.weight] + + +lr_schedulers: + pruning_lr: + class: StepLR + step_size: 50 + gamma: 0.10 + +extensions: + net_thinner: + class: 'FilterRemover' + thinning_func_str: remove_filters + arch: 'resnet20_cifar' + dataset: 'cifar10' + + +policies: + - pruner: + instance_name : low_pruner + args: + fold_batchnorm: True + starting_epoch: 0 + ending_epoch: 30 + frequency: 2 + +# After completing the pruning, we perform network thinning and continue fine-tuning. +# When there is ambiguity in the scheduling order of policies, Distiller follows the +# order of declaration. Because epoch 30 is the end of one pruner, and the beginning +# of two others, and because we want the thinning to happen at the beginning of +# epoch 30, it is important to declare the thinning policy here and not lower in the +# file. + - extension: + instance_name: net_thinner + epochs: [30] + + - pruner: + instance_name : fine_pruner + args: + fold_batchnorm: True + starting_epoch: 30 + ending_epoch: 50 + frequency: 2 + + - pruner: + instance_name : fc_pruner + starting_epoch: 30 + ending_epoch: 50 + frequency: 2 + + - lr_scheduler: + instance_name: pruning_lr + starting_epoch: 0 + ending_epoch: 400 + frequency: 1 diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml index 362e52c00c800528060502c14002e65cddf3f07c..eabf8c65b23538533b8bf506280fe902090b8ede 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml @@ -9,60 +9,57 @@ # # of parameters: 270,896 # # Results: -# Top1: 91.73 +# Top1: 91.34 # Total MACs: 30,655,104 -# Total sparsity: 41.10 +# Total sparsity: 46.3% # # of parameters: 120,000 (=55.7% of the baseline parameters) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer +# time python3 compress_classifier.py --arch resnet20_cifar $CIFAR10_PATH -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer --gpu=0 # -# Parameters: -# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ -# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | -# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| -# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.42267 | -0.01028 | 0.29903 | -# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15895 | -0.01265 | 0.11210 | -# | 2 | 
module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15610 | 0.00257 | 0.11472 | -# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13888 | -0.01590 | 0.10543 | -# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13052 | -0.00519 | 0.10135 | -# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18351 | -0.01298 | 0.13564 | -# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14909 | -0.00098 | 0.11435 | -# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17438 | -0.00580 | 0.13427 | -# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18654 | -0.00126 | 0.14499 | -# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.34412 | -0.01243 | 0.24940 | -# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11833 | -0.00937 | 0.08865 | -# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09171 | -0.00197 | 0.06956 | -# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13403 | -0.01057 | 0.09999 | -# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09652 | 0.00544 | 0.07033 | -# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13635 | -0.00543 | 0.10654 | -# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09992 | -0.00600 | 0.07893 | -# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) | 1024 | 1024 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17133 | -0.00926 | 0.13503 | -# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 8.47168 | 1.56250 | 69.99783 | 0.07819 | -0.00423 | 0.03752 | -# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 1.56250 | 8.37402 | 0.00000 | 69.99783 | 0.07238 | -0.00539 | 0.03450 | -# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 11.93848 | 3.12500 | 69.99783 | 0.07195 | -0.00571 | 0.03462 | -# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 3.12500 | 28.75977 | 1.56250 | 69.99783 | 0.04405 | 0.00060 | 0.02004 | -# | 21 | module.fc.weight | (10, 64) | 640 | 320 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.57112 | -0.00001 | 0.36129 | -# | 22 | Total sparsity: | - | 223536 | 120000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 46.31737 | 0.00000 | 0.00000 | 0.00000 | -# 
+----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ -# Total sparsity: 46.32 +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.41726 | -0.00601 | 0.29649 | +# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15451 | -0.01086 | 0.10477 | +# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15074 | -0.00062 | 0.10780 | +# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13604 | -0.01868 | 0.10372 | +# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12923 | -0.00463 | 0.09968 | +# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17818 | -0.01222 | 0.13184 | +# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14475 | -0.00069 | 0.11089 | +# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16991 | 0.00091 | 0.12894 | +# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18059 | 0.00199 | 0.14176 | +# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.34145 | -0.03631 | 0.25094 | +# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13213 | -0.00809 | 0.10198 | +# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10230 | 0.00805 | 0.07883 | +# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11462 | -0.00682 | 0.08532 | +# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08012 | 0.00611 | 0.05776 | +# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13316 | -0.00256 | 0.10497 | +# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09755 | -0.00598 | 0.07722 | +# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) | 1024 | 1024 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16702 | 0.00251 | 0.12968 | +# | 17 | 
module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 10.25391 | 4.68750 | 69.99783 | 0.07554 | -0.00373 | 0.03568 | +# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 4.68750 | 11.49902 | 0.00000 | 69.99783 | 0.06968 | -0.00573 | 0.03275 | +# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 15.57617 | 4.68750 | 69.99783 | 0.06895 | -0.00504 | 0.03245 | +# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 4.68750 | 32.08008 | 0.00000 | 69.99783 | 0.04180 | 0.00053 | 0.01793 | +# | 21 | module.fc.weight | (10, 64) | 640 | 320 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.55055 | -0.00001 | 0.31038 | +# | 22 | Total sparsity: | - | 223536 | 120000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 46.31737 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# Total sparsity: 46.32 # -# --- validate (epoch=359)----------- -# 10000 samples (256 per mini-batch) -# ==> Top1: 91.490 Top5: 99.710 Loss: 0.346 +# --- validate (epoch=179)----------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 91.110 Top5: 99.700 Loss: 0.379 # -# ==> Best Top1: 91.730 On Epoch: 344 +# ==> Best [Top1: 91.340 Top5: 99.670 Sparsity:46.32 NNZ-Params: 120000 on epoch: 127] +# Saving checkpoint to: logs/2019.10.31-204215/checkpoint.pth.tar +# --- test --------------------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 91.110 Top5: 99.700 Loss: 0.363 # -# Saving checkpoint to: logs/2018.10.30-150931/checkpoint.pth.tar -# --- test --------------------- -# 10000 samples (256 per mini-batch) -# ==> Top1: 91.490 Top5: 99.710 Loss: 0.346 # -# -# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.10.30-150931/2018.10.30-150931.log -# -# real 36m36.329s -# user 82m32.685s -# sys 10m8.746s +# real 31m38.549s +# user 293m36.410s +# sys 18m6.041s version: 1 @@ -113,6 +110,16 @@ policies: ending_epoch: 30 frequency: 2 +# After completing the pruning, we perform network thinning and continue fine-tuning. +# When there is ambiguity in the scheduling order of policies, Distiller follows the +# order of declaration. Because epoch 30 is the end of one pruner, and the beginning +# of two others, and because we want the thinning to happen at the beginning of +# epoch 30, it is important to declare the thinning policy here and not lower in the +# file. + - extension: + instance_name: net_thinner + epochs: [30] + - pruner: instance_name : fine_pruner starting_epoch: 30 @@ -125,11 +132,6 @@ policies: ending_epoch: 50 frequency: 2 -# After completeing the pruning, we perform network thinning and continue fine-tuning. 
- - extension: - instance_name: net_thinner - epochs: [32] - - lr_scheduler: instance_name: pruning_lr starting_epoch: 0 diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml index 1ea06610ad142adb4b641f1f4ab0c0f03dc7e7af..cc07e90d330db41523fded09b9fe6384c681816c 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml @@ -9,60 +9,59 @@ # # of parameters: 270,896 # # Results: -# Top1: 91.200 Top5: 99.660 Loss: 1.551 +# Top1: 91.630 Top5: 99.670 # Total MACs: 30,638,720 # Total sparsity: 41.84 # # of parameters: 143,488 (=53% of the baseline parameters) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --gpus=0 --vs=0 # -# Parameters: -# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ -# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | -# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| -# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.27315 | -0.00387 | 0.19394 | -# | 1 | module.layer1.0.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12038 | -0.01295 | 0.08811 | -# | 2 | module.layer1.0.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11879 | -0.00031 | 0.08735 | -# | 3 | module.layer1.1.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10293 | -0.01274 | 0.07795 | -# | 4 | module.layer1.1.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09285 | -0.00276 | 0.07141 | -# | 5 | module.layer1.2.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12849 | -0.00345 | 0.09355 | -# | 6 | module.layer1.2.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10689 | -0.00381 | 0.08038 | -# | 7 | module.layer2.0.conv1.weight | (20, 16, 3, 3) | 2880 | 2880 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10467 | -0.00371 | 0.08149 | -# | 8 | module.layer2.0.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08897 | -0.00502 | 0.06938 | -# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17695 | -0.01111 | 0.12479 | -# | 10 | module.layer2.1.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 
0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07736 | -0.00531 | 0.06118 | -# | 11 | module.layer2.1.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06832 | -0.00404 | 0.05406 | -# | 12 | module.layer2.2.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07965 | -0.00904 | 0.06278 | -# | 13 | module.layer2.2.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06305 | 0.00122 | 0.04955 | -# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06753 | -0.00459 | 0.05371 | -# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06379 | -0.00297 | 0.05078 | -# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08779 | -0.00584 | 0.06956 | -# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.12891 | 0.00000 | 69.99783 | 0.05191 | -0.00319 | 0.02604 | -# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.98340 | 0.00000 | 69.99783 | 0.04658 | -0.00360 | 0.02330 | -# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 10.10742 | 0.00000 | 69.99783 | 0.04563 | -0.00393 | 0.02297 | -# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 31.20117 | 0.00000 | 69.99783 | 0.02453 | 0.00005 | 0.01240 | -# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.50451 | -0.00001 | 0.43698 | -# | 22 | Total sparsity: | - | 246704 | 143488 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 41.83799 | 0.00000 | 0.00000 | 0.00000 | -# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ -# Total sparsity: 41.84 +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.38465 | -0.00533 | 0.27349 | +# | 1 | module.layer1.0.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17334 | -0.01720 | 0.12535 | +# | 2 | module.layer1.0.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17280 | 0.00148 | 0.12660 | +# | 3 | module.layer1.1.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14518 | -0.02108 | 0.11044 | +# | 4 | module.layer1.1.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 
0.00000 | 0.13157 | -0.00240 | 0.09998 | +# | 5 | module.layer1.2.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18724 | -0.00470 | 0.13594 | +# | 6 | module.layer1.2.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15303 | -0.00564 | 0.11591 | +# | 7 | module.layer2.0.conv1.weight | (20, 16, 3, 3) | 2880 | 2880 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15062 | -0.00379 | 0.11690 | +# | 8 | module.layer2.0.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12943 | -0.00739 | 0.10150 | +# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.25227 | -0.01490 | 0.17715 | +# | 10 | module.layer2.1.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11074 | -0.00783 | 0.08721 | +# | 11 | module.layer2.1.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09753 | -0.00582 | 0.07681 | +# | 12 | module.layer2.2.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11501 | -0.01363 | 0.09121 | +# | 13 | module.layer2.2.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09145 | 0.00280 | 0.07167 | +# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09772 | -0.00674 | 0.07769 | +# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09316 | -0.00339 | 0.07396 | +# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12438 | -0.00958 | 0.09868 | +# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 6.71387 | 0.00000 | 69.99783 | 0.07404 | -0.00405 | 0.03694 | +# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.25098 | 0.00000 | 69.99783 | 0.06739 | -0.00494 | 0.03356 | +# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 10.37598 | 0.00000 | 69.99783 | 0.06739 | -0.00414 | 0.03368 | +# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 28.49121 | 0.00000 | 69.99783 | 0.03788 | 0.00048 | 0.01900 | +# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.54585 | -0.00002 | 0.46076 | +# | 22 | Total sparsity: | - | 246704 | 143488 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 41.83799 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# Total sparsity: 41.84 # -# --- validate (epoch=359)----------- -# 5000 samples (256 per mini-batch) -# ==> Top1: 93.460 Top5: 99.800 Loss: 1.530 +# --- validate (epoch=179)----------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 91.430 Top5: 99.640 Loss: 0.365 # -# ==> Best Top1: 97.320 On Epoch: 180 +# ==> Best [Top1: 91.630 Top5: 99.670 Sparsity:41.84 NNZ-Params: 143488 on epoch: 
74] +# Saving checkpoint to: logs/2019.10.31-235045/checkpoint.pth.tar +# --- test --------------------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 91.430 Top5: 99.640 Loss: 0.379 # -# Saving checkpoint to: logs/2018.10.15-115941/checkpoint.pth.tar -# --- test --------------------- -# 10000 samples (256 per mini-batch) -# ==> Top1: 91.200 Top5: 99.660 Loss: 1.551 # +# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller_remote/examples/classifier_compression/logs/2019.10.31-235045/2019.10.31-235045.log # -# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-115941/2018.10.15-115941.log -# -# real 32m31.997s -# user 72m58.813s -# sys 9m1.245s +# real 52m57.688s +# user 304m20.353s +# sys 10m56.498s version: 1 pruners: @@ -112,7 +111,7 @@ policies: ending_epoch: 40 frequency: 2 -# After completeing the pruning, we perform network thinning and continue fine-tuning. +# After completing the pruning, we perform network thinning and continue fine-tuning. - extension: instance_name: net_thinner epochs: [22] diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml index f958ec13ce230b5764e1bcec22eeef49d1671752..2f965af2239c68c272196b32b5b84d6a3a17cc05 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml @@ -9,59 +9,57 @@ # # of parameters: 270,896 # # Results: -# Top1: 91.470 on Epoch: 288 +# Top1: 91.29 # Total MACs: 30,433,920 (74.6% of the original compute) # Total sparsity: 56.41% # # of parameters: 95922 (=35.4% of the baseline parameters ==> 64.6% sparsity) # # time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0 # -# Parameters: -# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ -# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | -# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| -# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.41372 | -0.00535 | 0.29289 | -# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15610 | -0.01373 | 0.11096 | -# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15429 | 0.00180 | 0.11294 | -# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13297 | -0.01580 | 0.10052 | -# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12638 | -0.00556 | 0.09699 | -# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17940 | -0.01313 | 0.13183 | -# | 6 | 
module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14671 | -0.00056 | 0.11065 |
-# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16872 | -0.00380 | 0.12838 |
-# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18371 | 0.00119 | 0.14401 |
-# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.33976 | 0.00148 | 0.24721 |
-# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12741 | -0.00734 | 0.09754 |
-# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10207 | 0.00286 | 0.07914 |
-# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13480 | -0.00943 | 0.10174 |
-# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09721 | 0.00049 | 0.07094 |
-# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 4608 | 0.00000 | 0.00000 | 0.00000 | 2.63672 | 1.56250 | 50.00000 | 0.11758 | -0.00484 | 0.07093 |
-# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 18432 | 0.00000 | 0.00000 | 1.56250 | 2.00195 | 0.00000 | 50.00000 | 0.08720 | -0.00522 | 0.05143 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) | 1024 | 1024 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16003 | -0.01049 | 0.12534 |
-# | 17 | module.layer3.1.conv1.weight | (63, 64, 3, 3) | 36288 | 10887 | 0.00000 | 0.00000 | 0.00000 | 9.20139 | 1.58730 | 69.99835 | 0.07613 | -0.00415 | 0.03605 |
-# | 18 | module.layer3.1.conv2.weight | (64, 63, 3, 3) | 36288 | 10887 | 0.00000 | 0.00000 | 1.58730 | 9.10218 | 0.00000 | 69.99835 | 0.07025 | -0.00544 | 0.03305 |
-# | 19 | module.layer3.2.conv1.weight | (62, 64, 3, 3) | 35712 | 10714 | 0.00000 | 0.00000 | 0.00000 | 13.33165 | 3.22581 | 69.99888 | 0.07118 | -0.00550 | 0.03367 |
-# | 20 | module.layer3.2.conv2.weight | (64, 62, 3, 3) | 35712 | 10714 | 0.00000 | 0.00000 | 3.22581 | 28.80544 | 0.00000 | 69.99888 | 0.04353 | 0.00071 | 0.01894 |
-# | 21 | module.fc.weight | (10, 64) | 640 | 320 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.57334 | -0.00001 | 0.35840 |
-# | 22 | Total sparsity: | - | 220080 | 95922 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 56.41494 | 0.00000 | 0.00000 | 0.00000 |
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 56.41
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.41066 | -0.01044 | 0.28684 |
+# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15311 | -0.01381 | 0.10292 |
+# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15199 | 0.00052 | 0.10768 |
+# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13703 | -0.01426 | 0.10459 |
+# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12956 | -0.00363 | 0.10109 |
+# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18044 | -0.01347 | 0.13435 |
+# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14809 | -0.00015 | 0.11282 |
+# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16746 | -0.00541 | 0.13009 |
+# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18162 | -0.00449 | 0.14069 |
+# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.34670 | 0.00145 | 0.23974 |
+# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11932 | -0.00470 | 0.08793 |
+# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09387 | 0.00694 | 0.06755 |
+# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11968 | -0.00680 | 0.08725 |
+# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08844 | 0.00324 | 0.06366 |
+# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 4608 | 0.00000 | 0.00000 | 0.00000 | 1.26953 | 0.00000 | 50.00000 | 0.11985 | -0.00543 | 0.07131 |
+# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.48828 | 0.00000 | 50.00000 | 0.08819 | -0.00512 | 0.05183 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) | 1024 | 1024 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15923 | -0.01150 | 0.12339 |
+# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.03125 | 0.00000 | 69.99783 | 0.07690 | -0.00368 | 0.03622 |
+# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 8.78906 | 0.00000 | 69.99783 | 0.07113 | -0.00574 | 0.03346 |
+# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 18.65234 | 9.37500 | 69.99783 | 0.06893 | -0.00443 | 0.03236 |
+# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 9.37500 | 33.39844 | 0.00000 | 69.99783 | 0.04186 | 0.00078 | 0.01802 |
+# | 21 | module.fc.weight | (10, 64) | 640 | 320 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.54733 | -0.00001 | 0.30963 |
+# | 22 | Total sparsity: | - | 223536 | 96960 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 56.62444 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 56.62
 #
-# --- validate (epoch=359)-----------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.140 Top5: 99.750 Loss: 0.331
+# --- validate (epoch=179)-----------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.290 Top5: 99.760 Loss: 0.344
 #
-# ==> Best Top1: 91.470 on Epoch: 288
-# Saving checkpoint to: logs/2018.11.08-232134/checkpoint.pth.tar
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.140 Top5: 99.750 Loss: 0.331
+# ==> Best [Top1: 91.290 Top5: 99.760 Sparsity:56.62 NNZ-Params: 96960 on epoch: 179]
+# Saving checkpoint to: logs/2019.11.01-005544/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.290 Top5: 99.760 Loss: 0.344
 #
 #
-# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.11.08-232134/2018.11.08-232134.log
-#
-# real 37m51.274s
-# user 85m48.506s
-# sys 10m35.410s
+# real 31m53.439s
+# user 264m58.198s
+# sys 16m57.435s
 
 version: 1
 
@@ -147,7 +145,7 @@ policies:
 #      ending_epoch: 20
 #      frequency: 2
 
-# After completeing the pruning, we perform network thinning and continue fine-tuning.
+# After completing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
       epochs: [32]
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
index 147d7b8e0cd86b3afa6af9545d397f645b6f2736..954752c089b2a7a232aaff7d81c8808c1f9b65a2 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
@@ -9,52 +9,57 @@
 # # of parameters: 270,896
 #
 # Results:
-#     Top1: 90.86
+#     Top1: 90.89
 #     Total MACs: 24,723,776 (=1.65x MACs)
 #     Total sparsity: 39.66
 #     # of parameters: 78,776 (=29.1% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0
+# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0 --gpu=0
 #
-# Parameters:
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
-# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.43875 | -0.01073 | 0.30863 |
-# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16813 | -0.01061 | 0.11996 |
-# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16491 | 0.00070 | 0.12291 |
-# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14755 | -0.01732 | 0.11317 |
-# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13875 | -0.00517 | 0.10758 |
-# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18923 | -0.01326 | 0.14210 |
-# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15583 | -0.00106 | 0.11905 |
-# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18225 | -0.00390 | 0.14086 |
-# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19739 | -0.00785 | 0.15657 |
-# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.34375 | -0.01899 | 0.24210 |
-# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13369 | -0.00798 | 0.10439 |
-# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10876 | -0.00311 | 0.08349 |
-# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13521 | -0.01331 | 0.09820 |
-# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09794 | 0.00274 | 0.07076 |
-# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14004 | -0.00417 | 0.11131 |
-# | 15 | module.layer3.0.conv2.weight | (32, 64, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12559 | -0.00528 | 0.09919 |
-# | 16 | module.layer3.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.20064 | 0.00941 | 0.15694 |
-# | 17 | module.layer3.1.conv1.weight | (64, 32, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 7.76367 | 1.56250 | 69.99783 | 0.10111 | -0.00460 | 0.04898 |
-# | 18 | module.layer3.1.conv2.weight | (32, 64, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 1.56250 | 8.83789 | 0.00000 | 69.99783 | 0.09000 | -0.00652 | 0.04327 |
-# | 19 | module.layer3.2.conv1.weight | (64, 32, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 9.37500 | 0.00000 | 69.99783 | 0.09371 | -0.00490 | 0.04521 |
-# | 20 | module.layer3.2.conv2.weight | (32, 64, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 24.65820 | 0.00000 | 69.99783 | 0.06228 | 0.00012 | 0.02827 |
-# | 21 | module.fc.weight | (10, 32) | 320 | 160 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.78238 | -0.00000 | 0.48823 |
-# | 22 | Total sparsity: | - | 130544 | 78776 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 39.65560 | 0.00000 | 0.00000 | 0.00000 |
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 39.66
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.43641 | -0.01077 | 0.30848 |
+# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16754 | -0.01429 | 0.11853 |
+# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16230 | 0.00232 | 0.12090 |
+# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14305 | -0.01696 | 0.10940 |
+# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13429 | -0.00717 | 0.10360 |
+# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19055 | -0.01061 | 0.14304 |
+# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15439 | 0.00421 | 0.11755 |
+# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18164 | -0.00175 | 0.14038 |
+# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.20045 | -0.01430 | 0.15603 |
+# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.33919 | -0.02323 | 0.25281 |
+# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13992 | -0.00223 | 0.10972 |
+# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11157 | 0.00652 | 0.08719 |
+# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14406 | -0.00417 | 0.10841 |
+# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10744 | 0.00565 | 0.08125 |
+# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13871 | -0.00321 | 0.10984 |
+# | 15 | module.layer3.0.conv2.weight | (32, 64, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12541 | -0.00566 | 0.09928 |
+# | 16 | module.layer3.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19916 | -0.00610 | 0.15371 |
+# | 17 | module.layer3.1.conv1.weight | (64, 32, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 7.61719 | 1.56250 | 69.99783 | 0.10042 | -0.00397 | 0.04862 |
+# | 18 | module.layer3.1.conv2.weight | (32, 64, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 1.56250 | 9.71680 | 0.00000 | 69.99783 | 0.09041 | -0.00711 | 0.04355 |
+# | 19 | module.layer3.2.conv1.weight | (64, 32, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 12.30469 | 3.12500 | 69.99783 | 0.09180 | -0.00314 | 0.04391 |
+# | 20 | module.layer3.2.conv2.weight | (32, 64, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 3.12500 | 29.29688 | 0.00000 | 69.99783 | 0.06059 | 0.00070 | 0.02763 |
+# | 21 | module.fc.weight | (10, 32) | 320 | 160 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.73275 | -0.00001 | 0.41788 |
+# | 22 | Total sparsity: | - | 130544 | 78776 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 39.65560 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 39.66
 #
-# --- validate (epoch=359)-----------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 90.580 Top5: 99.650 Loss: 0.341
+# --- validate (epoch=179)-----------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 90.660 Top5: 99.700 Loss: 0.343
 #
-# ==> Best Top1: 90.860 on Epoch: 294
-# Saving checkpoint to: logs/2018.11.30-152150/checkpoint.pth.tar
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 90.580 Top5: 99.650 Loss: 0.341
+# ==> Best [Top1: 90.890 Top5: 99.750 Sparsity:39.66 NNZ-Params: 78776 on epoch: 107]
+# Saving checkpoint to: logs/2019.11.01-013919/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 90.660 Top5: 99.700 Loss: 0.358
+#
+#
+# real 30m44.680s
+# user 290m49.048s
+# sys 17m17.928s
 
 version: 1
 
@@ -141,7 +146,7 @@ policies:
 #      ending_epoch: 20
 #      frequency: 2
 
-# After completeing the pruning, we perform network thinning and continue fine-tuning.
+# After completing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
       epochs: [32]
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 0862962c5f7fab78eb2b546d7c778bebf91fa1ec..269d262280dd0496004d210e48489398de04047d 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -453,9 +453,9 @@ def test_row_pruning():
                           [7., 8., 9.]])
     from distiller.pruning import L1RankedStructureParameterPruner
-    masker = distiller.scheduler.ParameterMasker("why name")
+    masker = distiller.scheduler.ParameterMasker("debug name")
     zeros_mask_dict = {"some name": masker}
-    L1RankedStructureParameterPruner.rank_and_prune_rows(0.5, param, "some name", zeros_mask_dict)
+    L1RankedStructureParameterPruner.rank_and_prune_channels(0.5, param, "some name", zeros_mask_dict)
 
     print(distiller.sparsity_rows(masker.mask))
     assert math.isclose(distiller.sparsity_rows(masker.mask), 1/3)
     pass
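
# Note on the "Total sparsity" figures in the tables above: the summary is simply
# 1 - NNZ(sparse) / NNZ(dense), expressed as a percentage. A minimal Python sketch
# that reproduces both reported totals from the NNZ columns; the helper name
# `total_sparsity_pct` is illustrative only and is not a Distiller API.

def total_sparsity_pct(nnz_dense, nnz_sparse):
    # Percentage of weight elements that are zero after pruning.
    return 100.0 * (1.0 - nnz_sparse / nnz_dense)

# Totals row of the first schedule above: 223,536 dense vs. 96,960 sparse non-zeros.
print(total_sparsity_pct(223536, 96960))   # 56.624... -> logged as "Total sparsity: 56.62"

# Totals row of resnet20_filters.schedule_agp_4.yaml: 130,544 vs. 78,776.
print(total_sparsity_pct(130544, 78776))   # 39.655... -> logged as "Total sparsity: 39.66"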