diff --git a/distiller/config.py b/distiller/config.py index 2661a2cea1afcdb3d6e4aee4c181a832be3f6d28..2c16fbac9945bf5a11a07f7c33990a7370eaa2d1 100755 --- a/distiller/config.py +++ b/distiller/config.py @@ -41,7 +41,7 @@ from torch.optim.lr_scheduler import * import distiller from distiller.thinning import * from distiller.pruning import * -from distiller.regularization import L1Regularizer, GroupLassoRegularizer +from distiller.regularization import * from distiller.learning_rate import * from distiller.quantization import * @@ -148,7 +148,7 @@ def file_config(model, optimizer, filename): sched_dict = yaml_ordered_load(stream) return dict_config(model, optimizer, sched_dict) except yaml.YAMLError as exc: - print("\nFATAL Parsing error while parsing the pruning schedule configuration file %s" % filename) + print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename) exit(1) diff --git a/distiller/regularization/__init__.py b/distiller/regularization/__init__.py index e8faf1c3de257523c6a77c5485eea90493f1397e..a87a2119677f6e602ec1252f8cbef0bdcc0f1116 100755 --- a/distiller/regularization/__init__.py +++ b/distiller/regularization/__init__.py @@ -15,7 +15,7 @@ # from .l1_regularizer import L1Regularizer -from .group_regularizer import GroupLassoRegularizer +from .group_regularizer import GroupLassoRegularizer, GroupVarianceRegularizer del l1_regularizer del group_regularizer diff --git a/distiller/regularization/group_regularizer.py b/distiller/regularization/group_regularizer.py index 41ae68df625a2132fd588940f5ea6583db64b101..6b6a7284752209b1023327830a6843be3f54ff57 100755 --- a/distiller/regularization/group_regularizer.py +++ b/distiller/regularization/group_regularizer.py @@ -35,13 +35,14 @@ group and individual feature levels" [2] Jenatton, Rodolphe; Audibert, Jean-Yves; Bach, Francis (2009). "Structured Variable Selection with Sparsity-Inducing Norms". Journal of Machine Learning Research. 12 (2011): 2777–2824. 
arXiv:0904.3523 -[3] J. Friedman, T. Hastie, and R. Tibshirani, “A note on the group lassoand a sparse group lasso,†+[3] J. Friedman, T. Hastie, and R. Tibshirani, “A note on the group lasso and a sparse group lasso,†arXiv preprint arXiv:1001.0736, 2010 """ from .regularizer import _Regularizer, EPSILON import distiller + class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer): def __init__(self, name, model, reg_regims, threshold_criteria=None): """ @@ -66,8 +67,10 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer): regularizer_loss += GroupLassoRegularizer.__2d_rowwise_reg(param, strength) regularizer_loss += GroupLassoRegularizer.__2d_colwise_reg(param, strength) elif group_type == 'Channels': - regularizer_loss += GroupLassoRegularizer.__3d_channelwise_reg(param, strength) - elif group_type == '3D': + # This is also known as "input channels" + regularizer_loss += GroupLassoRegularizer._3d_channelwise_reg(param, strength) + elif group_type == 'Filters' or group_type == '3D': + # This is also known as "output channels" regularizer_loss += GroupLassoRegularizer.__3d_filterwise_reg(param, strength) elif group_type == '4D': regularizer_loss += GroupLassoRegularizer.__4d_layerwise_reg(param, strength) @@ -75,8 +78,7 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer): regularizer_loss += GroupLassoRegularizer.__3d_channelwise_reg(param, strength) regularizer_loss += GroupLassoRegularizer.__4d_layerwise_reg(param, strength) else: - print("FATAL ERROR: Unknown parameter grouping: %s" % group_type) - exit(1) + raise ValueError('Unknown parameter grouping: ' + group_type) return regularizer_loss @@ -94,7 +96,7 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer): if dim == -1: # We only have single group return groups.norm(2) * strength - return groups.pow(2).sum(dim=dim).add(EPSILON).pow(1/2.).sum().mul_(strength) + return groups.norm(2, dim=dim).sum().mul_(strength) @staticmethod 
def __4d_layerwise_reg(layer_weights, strength, dim=0): @@ -110,36 +112,26 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer): assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights" # create a filter structure - filters_view = layer_weights.view(-1, layer_weights.size(1) * layer_weights.size(2) * layer_weights.size(3)) + filters_view = layer_weights.view(layer_weights.size(0), -1) return GroupLassoRegularizer.__grouplasso_reg(filters_view, strength, dim=1) @staticmethod - def __2d_rowwise_reg(layer_weights, strength): - assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights" - return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=1) - - @staticmethod - def __2d_colwise_reg(layer_weights, strength): - assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights" - return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=0) - - @staticmethod - def __2d_kernelwise_reg(layer_weights, strength): - """Group Lasso with one of: - - group = 2D weights kernel (convolution) - - group = 2D columns (fully connected) - - group = 2D rows (fully connected) + def _3d_channelwise_reg(layer_weights, strength): + """Group Lasso with group = 3D input channel """ assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights" - view_2d = layer_weights.view(-1, layer_weights.size(2) * layer_weights.size(3)) - return GroupLassoRegularizer.__grouplasso_reg(view_2d, strength, dim=1) + + # Sum of all channel L2s * regulization_strength + layer_channels_l2 = GroupLassoRegularizer._channels_l2(layer_weights).sum().mul_(strength) + return layer_channels_l2 @staticmethod - def __3d_channelwise_reg(layer_weights, strength): - """Group Lasso with one of: - """ - assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights" + def _channels_l2(layer_weights): + """Compute the L2-norm of convolution 
input channels weights. + A weights input channel is composed of all the kernels that are applied to the + same activation input channel. Each kernel belongs to a different weights filter. + """ # Now, for each group, we want to select a specific channel from all of the filters num_filters = layer_weights.size(0) num_kernels_per_filter = layer_weights.size(1) @@ -158,6 +150,53 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer): k_sq_sums_mat = k_sq_sums.view(num_filters, num_kernels_per_filter).t() # Now it's easy, just do Group Lasso on groups=rows - #groups_loss = k_sq_sums_mat.sum(dim=0).add(EPSILON).pow(1/2.).sum().mul_(strength) - groups_loss = k_sq_sums_mat.sum(dim=1).add(EPSILON).pow(1/2.).sum().mul_(strength) - return groups_loss + channels_l2 = k_sq_sums_mat.sum(dim=1).add(EPSILON).pow(1/2.) + return channels_l2 + + @staticmethod + def __2d_rowwise_reg(layer_weights, strength): + assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights" + return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=1) + + @staticmethod + def __2d_colwise_reg(layer_weights, strength): + assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights" + return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=0) + + @staticmethod + def __2d_kernelwise_reg(layer_weights, strength): + """Group Lasso with one of: + - group = 2D weights kernel (convolution) + - group = 2D columns (fully connected) + - group = 2D rows (fully connected) + """ + assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights" + view_2d = layer_weights.view(-1, layer_weights.size(2) * layer_weights.size(3)) + return GroupLassoRegularizer.__grouplasso_reg(view_2d, strength, dim=1) + + +class GroupVarianceRegularizer(GroupLassoRegularizer): + """Group variance regularization. + + As described in [4]. + + [4] Amirsina Torfi, Rouzbeh A. 
Shirvani, Sobhan Soleymani, Nasser M. Nasrabadi, + “Attention-Based Guided Structured Sparsity of Deep Neural Networks,” + arXiv preprint arXiv:1802.09902, ICLR 2018 + """ + def __init__(self, name, model, reg_regims): + super(GroupVarianceRegularizer, self).__init__(name, model, reg_regims) + + def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): + if param_name in self.reg_regims.keys(): + group_type = self.reg_regims[param_name][1] + strength = self.reg_regims[param_name][0] + + if group_type == 'Channels': + channels_l2 = GroupLassoRegularizer._channels_l2(param) + var = channels_l2.var() + var_loss = 1 / var + regularizer_loss += strength * var_loss + else: + raise ValueError('Unknown parameter grouping: ' + group_type) + return regularizer_loss diff --git a/distiller/thresholding.py b/distiller/thresholding.py index a00cbebe66c31206802cf68f3bf6e42ed20082db..4cb81898d104578cb15e82c6dffe9c06d66e46df 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -76,7 +76,7 @@ class GroupThresholdMixin(object): binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0) return binary_map.expand(param.size(0), param.size(1)) - elif group_type == '3D': + elif group_type == '3D' or group_type == 'Filters': assert param.dim() == 4, "This thresholding is only supported for 4D weights" view_filters = param.view(param.size(0), -1) thresholds = torch.Tensor([threshold] * param.size(0)).cuda() diff --git a/examples/gss/gss_channels-removal_training.yaml b/examples/gss/gss_channels-removal_training.yaml new file mode 100755 index 0000000000000000000000000000000000000000..a6103ee05a247ba192791afbebd6de47ca09b5b7 --- /dev/null +++ b/examples/gss/gss_channels-removal_training.yaml @@ -0,0 +1,82 @@ +# GSS (Guided Structured Sparsity). +# "Attention-Based Guided Structured Sparsity of Deep Neural Networks", +# Amirsina Torfi, Rouzbeh A. Shirvani, Sobhan Soleymani, Nasser M.
Nasrabadi +# ICLR 2018 +# https://arxiv.org/abs/1802.09902 +# +# Add group variance regularization to SSL. +# So far I haven't produced results better than SSL. The regularization strength of the variance does not come into play +# because it seems like the variance cost diminishes very quickly. +# +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training.yaml -j=1 --deterministic +# + + +lr_schedulers: + training_lr: + class: StepLR + step_size: 45 + gamma: 0.10 + +regularizers: + Channels_l2_regularizer: + class: GroupLassoRegularizer + reg_regims: + module.layer1.0.conv2.weight: [0.0028, Channels] + module.layer1.1.conv2.weight: [0.0028, Channels] + module.layer1.2.conv2.weight: [0.0024, Channels] + module.layer2.0.conv2.weight: [0.0016, Channels] # sensitive + module.layer2.1.conv2.weight: [0.0028, Channels] + module.layer2.2.conv2.weight: [0.0028, Channels] + module.layer3.0.conv2.weight: [0.0008, Channels] # sensitive + module.layer3.1.conv2.weight: [0.0028, Channels] + #module.layer3.2.conv2.weight: [0.0006, Channels] # very sensitive + threshold_criteria: Mean_Abs + + Channels_variance_reguralizer: + class: GroupVarianceRegularizer + reg_regims: + module.layer1.0.conv2.weight: [0.000008, Channels] + module.layer1.1.conv2.weight: [0.000008, Channels] + module.layer1.2.conv2.weight: [0.000008, Channels] + module.layer2.0.conv2.weight: [0.000008, Channels] + module.layer2.1.conv2.weight: [0.000008, Channels] + module.layer2.2.conv2.weight: [0.000008, Channels] + module.layer3.0.conv2.weight: [0.000008, Channels] + module.layer3.1.conv2.weight: [0.000008, Channels] + #module.layer3.2.conv2.weight: [0.000008, Channels] + +extensions: + net_thinner: + class: 'ChannelRemover' + thinning_func_str: remove_channels + arch: 'resnet20_cifar' + dataset: 'cifar10' + +policies: + - lr_scheduler: + instance_name: training_lr + starting_epoch: 45 + ending_epoch: 300 +
frequency: 1 + +# After completing the regularization, we perform network thinning and exit. + - extension: + instance_name: net_thinner + epochs: [179] + + - regularizer: + instance_name: Channels_l2_regularizer + args: + keep_mask: True + starting_epoch: 0 + ending_epoch: 180 + frequency: 1 + + - regularizer: + instance_name: Channels_variance_reguralizer + args: + keep_mask: True + starting_epoch: 0 + ending_epoch: 180 + frequency: 1 diff --git a/examples/ssl/ssl_channels-removal_training.yaml b/examples/ssl/ssl_channels-removal_training.yaml index e99ba3b6974263ea649b6f31d4ebf6647c0f7493..d53dfaab4bfaeae551959534e8feb07b39a9cbcd 100755 --- a/examples/ssl/ssl_channels-removal_training.yaml +++ b/examples/ssl/ssl_channels-removal_training.yaml @@ -87,7 +87,7 @@ extensions: net_thinner: class: 'ChannelRemover' thinning_func_str: remove_channels - arch: 'resnet56_cifar' + arch: 'resnet20_cifar' dataset: 'cifar10' policies: diff --git a/examples/ssl/ssl_channels-removal_training_x1.8.yaml b/examples/ssl/ssl_channels-removal_training_x1.8.yaml index 9ab0b5d890ba17790d10dba17416cfd933f90dc3..00625c6dfc9d03687b45eeacb2ef5f4e716118d4 100755 --- a/examples/ssl/ssl_channels-removal_training_x1.8.yaml +++ b/examples/ssl/ssl_channels-removal_training_x1.8.yaml @@ -9,6 +9,9 @@ # # time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training_x1.8.yaml -j=1 --deterministic # +# To fine-tune: +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=...
+# # Parameters: # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ # | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | @@ -89,7 +92,7 @@ extensions: net_thinner: class: 'ChannelRemover' thinning_func_str: remove_channels - arch: 'resnet56_cifar' + arch: 'resnet20_cifar' dataset: 'cifar10' policies: diff --git a/examples/ssl/ssl_filter-removal_training.yaml b/examples/ssl/ssl_filter-removal_training.yaml new file mode 100755 index 0000000000000000000000000000000000000000..84c30e861bca8c8f147c5d124947e6804471dec7 --- /dev/null +++ b/examples/ssl/ssl_filter-removal_training.yaml @@ -0,0 +1,101 @@ +# SSL: Filter regularization +# +# This is a not-so-successful attempt at filter regularization: +# Total MACs: 27,800,192 = 68% compute density. +# Test Top1 after training: 90.39 +# Test Top1 after fine-tuning: 90.93 +# +# To train: +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_filter-removal_training.yaml -j=1 --deterministic --name="filters" +# +# To fine-tune: +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=... 
+# +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.42230 | -0.00107 | 0.29227 | +# | 1 | module.layer1.0.conv1.weight | (13, 16, 3, 3) | 1872 | 1872 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04139 | -0.00293 | 0.02218 | +# | 2 | module.layer1.0.conv2.weight | (16, 13, 3, 3) | 1872 | 1872 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16320 | -0.00121 | 0.10359 | +# | 3 | module.layer1.1.conv1.weight | (9, 16, 3, 3) | 1296 | 1296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02499 | -0.00091 | 0.01594 | +# | 4 | module.layer1.1.conv2.weight | (16, 9, 3, 3) | 1296 | 1296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13183 | -0.01035 | 0.09682 | +# | 5 | module.layer1.2.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07616 | -0.00278 | 0.05246 | +# | 6 | module.layer1.2.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18164 | -0.01895 | 0.13244 | +# | 7 | module.layer2.0.conv1.weight | (25, 16, 3, 3) | 3600 | 3600 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02687 | 0.00002 | 0.01887 | +# | 8 | module.layer2.0.conv2.weight | (32, 25, 3, 3) | 7200 | 7200 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11768 | -0.01350 | 
0.09049 | +# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.30952 | -0.03258 | 0.21696 | +# | 10 | module.layer2.1.conv1.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09844 | -0.00413 | 0.07454 | +# | 11 | module.layer2.1.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11731 | -0.00819 | 0.09292 | +# | 12 | module.layer2.2.conv1.weight | (4, 32, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02086 | -0.00164 | 0.01553 | +# | 13 | module.layer2.2.conv2.weight | (32, 4, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08841 | -0.00491 | 0.06650 | +# | 14 | module.layer3.0.conv1.weight | (48, 32, 3, 3) | 13824 | 13824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01713 | -0.00044 | 0.01255 | +# | 15 | module.layer3.0.conv2.weight | (64, 48, 3, 3) | 27648 | 27648 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09969 | -0.00692 | 0.07733 | +# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19055 | -0.01650 | 0.14967 | +# | 17 | module.layer3.1.conv1.weight | (30, 64, 3, 3) | 17280 | 17280 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01936 | 0.00019 | 0.01468 | +# | 18 | module.layer3.1.conv2.weight | (64, 30, 3, 3) | 17280 | 17280 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08263 | -0.01507 | 0.06434 | +# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07410 | -0.00536 | 0.05833 | +# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06848 | -0.00032 | 0.05342 | +# | 21 | 
module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.66850 | -0.00003 | 0.54848 | +# | 22 | Total sparsity: | - | 194144 | 194144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-09-22 15:44:14,320 - Total sparsity: 0.00 +# +# 2018-09-22 15:44:14,321 - --- validate (epoch=179)----------- +# 2018-09-22 15:44:14,321 - 5000 samples (256 per mini-batch) +# 2018-09-22 15:44:15,800 - ==> Top1: 90.460 Top5: 99.720 Loss: 0.332 +# +# 2018-09-22 15:44:15,802 - ==> Best Top1: 90.900 On Epoch: 148 +# +# 2018-09-22 15:44:15,802 - Saving checkpoint to: logs/filters___2018.09.22-151047/filters_checkpoint.pth.tar +# 2018-09-22 15:44:15,818 - --- test --------------------- +# 2018-09-22 15:44:15,818 - 10000 samples (256 per mini-batch) +# 2018-09-22 15:44:17,459 - ==> Top1: 90.390 Top5: 99.750 Loss: 0.349 + +lr_schedulers: + training_lr: + class: StepLR + step_size: 45 + gamma: 0.10 + +regularizers: + Filters_groups_regularizer: + class: GroupLassoRegularizer + reg_regims: + module.layer1.0.conv1.weight: [0.0008, Filters] + module.layer1.1.conv1.weight: [0.0008, Filters] + module.layer1.2.conv1.weight: [0.0006, Filters] + module.layer2.0.conv1.weight: [0.0008, Filters] + module.layer2.1.conv1.weight: [0.0002, Filters] + module.layer2.2.conv1.weight: [0.0008, Filters] + module.layer3.0.conv1.weight: [0.0012, Filters] + module.layer3.1.conv1.weight: [0.0010, Filters] + module.layer3.2.conv1.weight: [0.0002, Filters] + threshold_criteria: Mean_Abs + +extensions: + net_thinner: + class: 'FilterRemover' + thinning_func_str: remove_filters + arch: 'resnet20_cifar' + dataset: 'cifar10' + +policies: + - lr_scheduler: + instance_name: training_lr + starting_epoch: 45 + ending_epoch: 300 + 
frequency: 1 + +# After completing the regularization, we perform network thinning and exit. + - extension: + instance_name: net_thinner + epochs: [179] + + - regularizer: + instance_name: Filters_groups_regularizer + args: + keep_mask: True + starting_epoch: 0 + ending_epoch: 180 + frequency: 1