Commit 65919dc0 authored by Neta Zmora, committed by GitHub

Attention-Based Guided Structured Sparsity (GSS) (#51)

* Added GSS ("Attention-Based Guided Structured Sparsity of Deep Neural Networks") and an example of ResNet20 channel pruning.
    - The idea is to regularize the variance of the distribution of the parameter structures: some structures are driven to zero entirely, while the rest should retain high values, yielding high variance across structures (a minimal sketch of this penalty follows below).
    - A new regularizer class, GroupVarianceRegularizer, regularizes the group variance (effectively rewarding the loss function for high variance between the groups).
    - When tested on ResNet20, GSS did not show any improvement over SSL.

* Added a sample of filter pruning for ResNet20 CIFAR using SSL ("Learning Structured Sparsity in Deep Neural Networks").

* Added an example of pruning 45% of the compute (1.8x MAC reduction) while suffering a 0.8% accuracy loss on ResNet20 CIFAR.

* Added a ResNet50 ImageNet example of L1-magnitude fine-grained pruning, using an AGP schedule: 46% sparsity with a 0.6% accuracy increase. This is an example of using pruning as a regularizer.
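
For intuition, here is a minimal PyTorch sketch of the variance penalty described in the first bullet. The function name and the EPSILON guard on the division are illustrative assumptions; the committed implementation is the GroupVarianceRegularizer class in the diff below.

import torch

EPSILON = 1e-8  # guard against division by zero; this value is an assumption of the sketch

def group_variance_penalty(conv_weight, strength):
    """Toy GSS-style penalty for one 4D convolution weight (illustrative only).

    Each group is an input channel: all the kernels, one per filter, that read the
    same input-activation channel.  Penalizing 1/variance of the group L2 norms
    rewards weight distributions in which some channels go to zero completely
    while the rest keep high values.
    """
    assert conv_weight.dim() == 4  # (num_filters, num_channels, k_h, k_w)
    num_channels = conv_weight.size(1)
    # One row per input channel, gathering that channel's kernels from every filter
    channel_rows = conv_weight.transpose(0, 1).contiguous().view(num_channels, -1)
    channels_l2 = channel_rows.norm(2, dim=1)  # L2 norm of each channel group
    return strength * (1. / (channels_l2.var() + EPSILON))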
parent 06f2b065
@@ -41,7 +41,7 @@ from torch.optim.lr_scheduler import *
 import distiller
 from distiller.thinning import *
 from distiller.pruning import *
-from distiller.regularization import L1Regularizer, GroupLassoRegularizer
+from distiller.regularization import *
 from distiller.learning_rate import *
 from distiller.quantization import *
@@ -148,7 +148,7 @@ def file_config(model, optimizer, filename):
         sched_dict = yaml_ordered_load(stream)
         return dict_config(model, optimizer, sched_dict)
     except yaml.YAMLError as exc:
-        print("\nFATAL Parsing error while parsing the pruning schedule configuration file %s" % filename)
+        print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
         exit(1)
@@ -15,7 +15,7 @@
 #
 from .l1_regularizer import L1Regularizer
-from .group_regularizer import GroupLassoRegularizer
+from .group_regularizer import GroupLassoRegularizer, GroupVarianceRegularizer
 del l1_regularizer
 del group_regularizer
@@ -35,13 +35,14 @@ group and individual feature levels"
 [2] Jenatton, Rodolphe; Audibert, Jean-Yves; Bach, Francis (2009). "Structured Variable Selection with
     Sparsity-Inducing Norms". Journal of Machine Learning Research. 12 (2011): 2777–2824. arXiv:0904.3523
-[3] J. Friedman, T. Hastie, and R. Tibshirani, “A note on the group lassoand a sparse group lasso,”
+[3] J. Friedman, T. Hastie, and R. Tibshirani, “A note on the group lasso and a sparse group lasso,”
     arXiv preprint arXiv:1001.0736, 2010
 """
 from .regularizer import _Regularizer, EPSILON
 import distiller

 class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer):
     def __init__(self, name, model, reg_regims, threshold_criteria=None):
         """
@@ -66,8 +67,10 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer):
             regularizer_loss += GroupLassoRegularizer.__2d_rowwise_reg(param, strength)
             regularizer_loss += GroupLassoRegularizer.__2d_colwise_reg(param, strength)
         elif group_type == 'Channels':
-            regularizer_loss += GroupLassoRegularizer.__3d_channelwise_reg(param, strength)
-        elif group_type == '3D':
+            # This is also known as "input channels"
+            regularizer_loss += GroupLassoRegularizer._3d_channelwise_reg(param, strength)
+        elif group_type == 'Filters' or group_type == '3D':
+            # This is also known as "output channels"
             regularizer_loss += GroupLassoRegularizer.__3d_filterwise_reg(param, strength)
         elif group_type == '4D':
             regularizer_loss += GroupLassoRegularizer.__4d_layerwise_reg(param, strength)
@@ -75,8 +78,7 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer):
             regularizer_loss += GroupLassoRegularizer.__3d_channelwise_reg(param, strength)
             regularizer_loss += GroupLassoRegularizer.__4d_layerwise_reg(param, strength)
         else:
-            print("FATAL ERROR: Unknown parameter grouping: %s" % group_type)
-            exit(1)
+            raise ValueError('Unknown parameter grouping: ' + group_type)
         return regularizer_loss
@@ -94,7 +96,7 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer):
         if dim == -1:
             # We only have single group
             return groups.norm(2) * strength
-        return groups.pow(2).sum(dim=dim).add(EPSILON).pow(1/2.).sum().mul_(strength)
+        return groups.norm(2, dim=dim).sum().mul_(strength)

     @staticmethod
     def __4d_layerwise_reg(layer_weights, strength, dim=0):
@@ -110,36 +112,26 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer):
         assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights"
         # create a filter structure
-        filters_view = layer_weights.view(-1, layer_weights.size(1) * layer_weights.size(2) * layer_weights.size(3))
+        filters_view = layer_weights.view(layer_weights.size(0), -1)
         return GroupLassoRegularizer.__grouplasso_reg(filters_view, strength, dim=1)

     @staticmethod
-    def __2d_rowwise_reg(layer_weights, strength):
-        assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights"
-        return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=1)
-
-    @staticmethod
-    def __2d_colwise_reg(layer_weights, strength):
-        assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights"
-        return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=0)
-
-    @staticmethod
-    def __2d_kernelwise_reg(layer_weights, strength):
-        """Group Lasso with one of:
-        - group = 2D weights kernel (convolution)
-        - group = 2D columns (fully connected)
-        - group = 2D rows (fully connected)
-        """
-        assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights"
-        view_2d = layer_weights.view(-1, layer_weights.size(2) * layer_weights.size(3))
-        return GroupLassoRegularizer.__grouplasso_reg(view_2d, strength, dim=1)
-
-    @staticmethod
-    def __3d_channelwise_reg(layer_weights, strength):
-        """Group Lasso with one of:
+    def _3d_channelwise_reg(layer_weights, strength):
+        """Group Lasso with group = 3D input channel
         """
         assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights"

+        # Sum of all channel L2s * regularization_strength
+        layer_channels_l2 = GroupLassoRegularizer._channels_l2(layer_weights).sum().mul_(strength)
+        return layer_channels_l2
+
+    @staticmethod
+    def _channels_l2(layer_weights):
+        """Compute the L2-norm of convolution input channels weights.
+
+        A weights input channel is composed of all the kernels that are applied to the
+        same activation input channel.  Each kernel belongs to a different weights filter.
+        """
         # Now, for each group, we want to select a specific channel from all of the filters
         num_filters = layer_weights.size(0)
         num_kernels_per_filter = layer_weights.size(1)
@@ -158,6 +150,53 @@ class GroupLassoRegularizer(distiller.GroupThresholdMixin, _Regularizer):
         k_sq_sums_mat = k_sq_sums.view(num_filters, num_kernels_per_filter).t()

         # Now it's easy, just do Group Lasso on groups=rows
-        #groups_loss = k_sq_sums_mat.sum(dim=0).add(EPSILON).pow(1/2.).sum().mul_(strength)
-        groups_loss = k_sq_sums_mat.sum(dim=1).add(EPSILON).pow(1/2.).sum().mul_(strength)
-        return groups_loss
+        channels_l2 = k_sq_sums_mat.sum(dim=1).add(EPSILON).pow(1/2.)
+        return channels_l2
+
+    @staticmethod
+    def __2d_rowwise_reg(layer_weights, strength):
+        assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights"
+        return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=1)
+
+    @staticmethod
+    def __2d_colwise_reg(layer_weights, strength):
+        assert layer_weights.dim() == 2, "This regularization is only supported for 2D weights"
+        return GroupLassoRegularizer.__grouplasso_reg(layer_weights, strength, dim=0)
+
+    @staticmethod
+    def __2d_kernelwise_reg(layer_weights, strength):
+        """Group Lasso with one of:
+        - group = 2D weights kernel (convolution)
+        - group = 2D columns (fully connected)
+        - group = 2D rows (fully connected)
+        """
+        assert layer_weights.dim() == 4, "This regularization is only supported for 4D weights"
+        view_2d = layer_weights.view(-1, layer_weights.size(2) * layer_weights.size(3))
+        return GroupLassoRegularizer.__grouplasso_reg(view_2d, strength, dim=1)
+
+
+class GroupVarianceRegularizer(GroupLassoRegularizer):
+    """Group variance regularization.
+
+    As described in [4].
+    [4] Amirsina Torfi, Rouzbeh A. Shirvani, Sobhan Soleymani, Nasser M. Nasrabadi,
+        “Attention-Based Guided Structured Sparsity of Deep Neural Networks,”
+        arXiv preprint arXiv:1802.09902, ICLR 2018
+    """
+    def __init__(self, name, model, reg_regims):
+        super(GroupVarianceRegularizer, self).__init__(name, model, reg_regims)
+
+    def loss(self, param, param_name, regularizer_loss, zeros_mask_dict):
+        if param_name in self.reg_regims.keys():
+            group_type = self.reg_regims[param_name][1]
+            strength = self.reg_regims[param_name][0]
+            if group_type == 'Channels':
+                channels_l2 = GroupLassoRegularizer._channels_l2(param)
+                var = channels_l2.var()
+                var_loss = 1 / var
+                regularizer_loss += strength * var_loss
+            else:
+                raise ValueError('Unknown parameter grouping: ' + group_type)
+        return regularizer_loss
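
Both classes share the loss() interface, so a caller can fold either penalty into the training loss. A hedged usage sketch, based only on the signatures visible in this diff (model and data_loss are assumed context, and passing zeros_mask_dict=None is a placeholder since the loss() above never consults it; in practice Distiller's compression scheduler drives this iteration):

# reg_regims maps a parameter name to [strength, group-type], mirroring the YAML schedules below
reg_regims = {'module.layer1.0.conv2.weight': [0.000008, 'Channels']}
regularizer = GroupVarianceRegularizer('gss_demo', model, reg_regims)

regularizer_loss = 0
for name, param in model.named_parameters():
    regularizer_loss = regularizer.loss(param, name, regularizer_loss, zeros_mask_dict=None)
loss = data_loss + regularizer_loss  # data_loss: the usual task loss, e.g. cross-entropy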
@@ -76,7 +76,7 @@ class GroupThresholdMixin(object):
             binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0)
             return binary_map.expand(param.size(0), param.size(1))

-        elif group_type == '3D':
+        elif group_type == '3D' or group_type == 'Filters':
             assert param.dim() == 4, "This thresholding is only supported for 4D weights"
             view_filters = param.view(param.size(0), -1)
             thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
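
The thresholding change above routes the new 'Filters' group type through the existing 3D path: each filter is flattened to one row and compared, as a unit, against its threshold. A rough sketch of what a Mean_Abs-style per-filter test might compute (this reading of the criterion is an assumption, not Distiller's exact implementation):

import torch

def filters_below_threshold(param, threshold):
    """A filter is a removal candidate when the mean of its absolute weights
    falls below the threshold (assumed Mean_Abs semantics)."""
    assert param.dim() == 4
    view_filters = param.view(param.size(0), -1)  # one row per filter, as in the diff
    return view_filters.abs().mean(dim=1) < threshold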
# GSS (Guided Structured Sparsity).
# "Attention-Based Guided Structured Sparsity of Deep Neural Networks",
# Amirsina Torfi, Rouzbeh A. Shirvani, Sobhan Soleymani, Nasser M. Nasrabadi
# ICLR 2018
# https://arxiv.org/abs/1802.09902
#
# Add group variance regularization to SSL.
# So far I haven't produced results better than SSL. The regularization strengh of the variance does not come into play
# because it seems like the variance cost diminishes very quickly.
#
# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training.yaml -j=1 --deterministic
#
lr_schedulers:
training_lr:
class: StepLR
step_size: 45
gamma: 0.10
regularizers:
Channels_l2_regularizer:
class: GroupLassoRegularizer
reg_regims:
module.layer1.0.conv2.weight: [0.0028, Channels]
module.layer1.1.conv2.weight: [0.0028, Channels]
module.layer1.2.conv2.weight: [0.0024, Channels]
module.layer2.0.conv2.weight: [0.0016, Channels] # sensitive
module.layer2.1.conv2.weight: [0.0028, Channels]
module.layer2.2.conv2.weight: [0.0028, Channels]
module.layer3.0.conv2.weight: [0.0008, Channels] # sensitive
module.layer3.1.conv2.weight: [0.0028, Channels]
#module.layer3.2.conv2.weight: [0.0006, Channels] # very sensitive
threshold_criteria: Mean_Abs
  Channels_variance_regularizer:
class: GroupVarianceRegularizer
reg_regims:
module.layer1.0.conv2.weight: [0.000008, Channels]
module.layer1.1.conv2.weight: [0.000008, Channels]
module.layer1.2.conv2.weight: [0.000008, Channels]
module.layer2.0.conv2.weight: [0.000008, Channels]
module.layer2.1.conv2.weight: [0.000008, Channels]
module.layer2.2.conv2.weight: [0.000008, Channels]
module.layer3.0.conv2.weight: [0.000008, Channels]
module.layer3.1.conv2.weight: [0.000008, Channels]
#module.layer3.2.conv2.weight: [0.000008, Channels]
extensions:
net_thinner:
class: 'ChannelRemover'
thinning_func_str: remove_channels
arch: 'resnet20_cifar'
dataset: 'cifar10'
policies:
- lr_scheduler:
instance_name: training_lr
starting_epoch: 45
ending_epoch: 300
frequency: 1
  # After completing the regularization, we perform network thinning and exit.
- extension:
instance_name: net_thinner
epochs: [179]
- regularizer:
instance_name: Channels_l2_regularizer
args:
keep_mask: True
starting_epoch: 0
ending_epoch: 180
frequency: 1
- regularizer:
      instance_name: Channels_variance_regularizer
args:
keep_mask: True
starting_epoch: 0
ending_epoch: 180
frequency: 1
@@ -87,7 +87,7 @@ extensions:
   net_thinner:
     class: 'ChannelRemover'
     thinning_func_str: remove_channels
-    arch: 'resnet56_cifar'
+    arch: 'resnet20_cifar'
     dataset: 'cifar10'

 policies:
@@ -9,6 +9,9 @@
 #
 # time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training_x1.8.yaml -j=1 --deterministic
 #
+# To fine-tune:
+# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=...
+#
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
@@ -89,7 +92,7 @@ extensions:
   net_thinner:
     class: 'ChannelRemover'
     thinning_func_str: remove_channels
-    arch: 'resnet56_cifar'
+    arch: 'resnet20_cifar'
     dataset: 'cifar10'

 policies:
# SSL: Filter regularization
#
# This is a not-so-successful attempt at filter regularization:
# Total MACs: 27,800,192 = 68% compute density.
# Test Top1 after training: 90.39
# Test Top1 after fine-tuning: 90.93
#
# To train:
# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_filter-removal_training.yaml -j=1 --deterministic --name="filters"
#
# To fine-tune:
# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=...
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.42230 | -0.00107 | 0.29227 |
# | 1 | module.layer1.0.conv1.weight | (13, 16, 3, 3) | 1872 | 1872 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04139 | -0.00293 | 0.02218 |
# | 2 | module.layer1.0.conv2.weight | (16, 13, 3, 3) | 1872 | 1872 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16320 | -0.00121 | 0.10359 |
# | 3 | module.layer1.1.conv1.weight | (9, 16, 3, 3) | 1296 | 1296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02499 | -0.00091 | 0.01594 |
# | 4 | module.layer1.1.conv2.weight | (16, 9, 3, 3) | 1296 | 1296 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13183 | -0.01035 | 0.09682 |
# | 5 | module.layer1.2.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07616 | -0.00278 | 0.05246 |
# | 6 | module.layer1.2.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18164 | -0.01895 | 0.13244 |
# | 7 | module.layer2.0.conv1.weight | (25, 16, 3, 3) | 3600 | 3600 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02687 | 0.00002 | 0.01887 |
# | 8 | module.layer2.0.conv2.weight | (32, 25, 3, 3) | 7200 | 7200 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11768 | -0.01350 | 0.09049 |
# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.30952 | -0.03258 | 0.21696 |
# | 10 | module.layer2.1.conv1.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09844 | -0.00413 | 0.07454 |
# | 11 | module.layer2.1.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11731 | -0.00819 | 0.09292 |
# | 12 | module.layer2.2.conv1.weight | (4, 32, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02086 | -0.00164 | 0.01553 |
# | 13 | module.layer2.2.conv2.weight | (32, 4, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08841 | -0.00491 | 0.06650 |
# | 14 | module.layer3.0.conv1.weight | (48, 32, 3, 3) | 13824 | 13824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01713 | -0.00044 | 0.01255 |
# | 15 | module.layer3.0.conv2.weight | (64, 48, 3, 3) | 27648 | 27648 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09969 | -0.00692 | 0.07733 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19055 | -0.01650 | 0.14967 |
# | 17 | module.layer3.1.conv1.weight | (30, 64, 3, 3) | 17280 | 17280 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01936 | 0.00019 | 0.01468 |
# | 18 | module.layer3.1.conv2.weight | (64, 30, 3, 3) | 17280 | 17280 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08263 | -0.01507 | 0.06434 |
# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07410 | -0.00536 | 0.05833 |
# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06848 | -0.00032 | 0.05342 |
# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.66850 | -0.00003 | 0.54848 |
# | 22 | Total sparsity: | - | 194144 | 194144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-09-22 15:44:14,320 - Total sparsity: 0.00
#
# 2018-09-22 15:44:14,321 - --- validate (epoch=179)-----------
# 2018-09-22 15:44:14,321 - 5000 samples (256 per mini-batch)
# 2018-09-22 15:44:15,800 - ==> Top1: 90.460 Top5: 99.720 Loss: 0.332
#
# 2018-09-22 15:44:15,802 - ==> Best Top1: 90.900 On Epoch: 148
#
# 2018-09-22 15:44:15,802 - Saving checkpoint to: logs/filters___2018.09.22-151047/filters_checkpoint.pth.tar
# 2018-09-22 15:44:15,818 - --- test ---------------------
# 2018-09-22 15:44:15,818 - 10000 samples (256 per mini-batch)
# 2018-09-22 15:44:17,459 - ==> Top1: 90.390 Top5: 99.750 Loss: 0.349
lr_schedulers:
training_lr:
class: StepLR
step_size: 45
gamma: 0.10
regularizers:
Filters_groups_regularizer:
class: GroupLassoRegularizer
reg_regims:
module.layer1.0.conv1.weight: [0.0008, Filters]
module.layer1.1.conv1.weight: [0.0008, Filters]
module.layer1.2.conv1.weight: [0.0006, Filters]
module.layer2.0.conv1.weight: [0.0008, Filters]
module.layer2.1.conv1.weight: [0.0002, Filters]
module.layer2.2.conv1.weight: [0.0008, Filters]
module.layer3.0.conv1.weight: [0.0012, Filters]
module.layer3.1.conv1.weight: [0.0010, Filters]
module.layer3.2.conv1.weight: [0.0002, Filters]
threshold_criteria: Mean_Abs
extensions:
net_thinner:
class: 'FilterRemover'
thinning_func_str: remove_filters
arch: 'resnet20_cifar'
dataset: 'cifar10'
policies:
- lr_scheduler:
instance_name: training_lr
starting_epoch: 45
ending_epoch: 300
frequency: 1
  # After completing the regularization, we perform network thinning and exit.
- extension:
instance_name: net_thinner
epochs: [179]
- regularizer:
instance_name: Filters_groups_regularizer
args:
keep_mask: True
starting_epoch: 0
ending_epoch: 180
frequency: 1
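
For reference, a small sketch of the quantity that a reg_regims entry such as [0.0008, Filters] adds to the loss, following __grouplasso_reg and __3d_filterwise_reg from the diff above (the standalone function is illustrative; the real computation lives inside GroupLassoRegularizer):

import torch

def filterwise_group_lasso(conv_weight, strength):
    """Group Lasso over filters: the sum of per-filter L2 norms, scaled by the
    regularization strength.  Filters whose norm is driven to zero can later be
    physically removed by the FilterRemover extension."""
    assert conv_weight.dim() == 4
    filters_view = conv_weight.view(conv_weight.size(0), -1)  # one row per filter
    return filters_view.norm(2, dim=1).sum() * strength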