diff --git a/distiller/config.py b/distiller/config.py
index 2c16fbac9945bf5a11a07f7c33990a7370eaa2d1..ad54cba42c2fca51a655894d2734f5bd56dc7639 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -58,8 +58,7 @@ def dict_config(model, optimizer, sched_dict):
     regularizers = __factory('regularizers', model, sched_dict)
     quantizers = __factory('quantizers', model, sched_dict, optimizer=optimizer)
     if len(quantizers) > 1:
-        print("\nError: Multiple Quantizers not supported")
-        exit(1)
+        raise ValueError("\nError: Multiple Quantizers not supported")
     extensions = __factory('extensions', model, sched_dict)
 
     try:
@@ -72,7 +71,7 @@ def dict_config(model, optimizer, sched_dict):
             except TypeError as e:
                 print('\n\nFatal Error: a policy is defined with a null pruner')
                 print('Here\'s the policy definition for your reference:\n{}'.format(json.dumps(policy_def, indent=1)))
-                exit(1)
+                raise
             assert instance_name in pruners, "Pruner {} was not defined in the list of pruners".format(instance_name)
             pruner = pruners[instance_name]
             policy = distiller.PruningPolicy(pruner, args)
@@ -105,8 +104,7 @@ def dict_config(model, optimizer, sched_dict):
             policy = extension
 
         else:
-            print("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]" % policy_def)
-            exit(1)
+            raise ValueError("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [{}]".format(policy_def))
 
         add_policy_to_scheduler(policy, policy_def, schedule)
 
@@ -126,8 +124,7 @@ def dict_config(model, optimizer, sched_dict):
     except Exception as exception:
         print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
         print("Exception: %s %s" % (type(exception), exception))
-        exit(1)
-
+        raise
     return schedule
 
 
@@ -149,7 +146,7 @@ def file_config(model, optimizer, filename):
             return dict_config(model, optimizer, sched_dict)
         except yaml.YAMLError as exc:
             print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
-            exit(1)
+            raise
 
 
 def __factory(container_type, model, sched_dict, **kwargs):
@@ -164,14 +161,17 @@ def __factory(container_type, model, sched_dict, **kwargs):
                     cfg_kwargs['name'] = name
                     class_ = globals()[cfg_kwargs['class']]
                     container[name] = class_(**__filter_kwargs(cfg_kwargs, class_.__init__))
+                except NameError as error:
+                    print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name))
+                    raise
                 except Exception as exception:
                     print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name))
                     print("Exception: %s %s" % (type(exception), exception))
-                    exit(1)
+                    raise
         except Exception as exception:
             print("\nFatal while creating %s" % container_type)
             print("Exception: %s %s" % (type(exception), exception))
-            exit(1)
+            raise
 
     return container
 
 
diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index 5705a7a9435779019931d76159aca552e017dc3c..dc8a1c76ca7910f00194b5c7c07f28695cc68fac 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -19,7 +19,7 @@
 """
 
 from .magnitude_pruner import MagnitudeParameterPruner
-from .automated_gradual_pruner import AutomatedGradualPruner
+from .automated_gradual_pruner import AutomatedGradualPruner, StructuredAutomatedGradualPruner
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .structure_pruner import StructureParameterPruner
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index 25a072ade43728887bbdb5746a8e8ac9dffcf03c..dff2168320811349d2644e2dadd94c21155695d2 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -16,7 +16,11 @@
 from .pruner import _ParameterPruner
 from .level_pruner import SparsityLevelParameterPruner
+from .ranked_structures_pruner import L1RankedStructureParameterPruner
 from distiller.utils import *
+# import logging
+# msglogger = logging.getLogger()
+
 
 
 class AutomatedGradualPruner(_ParameterPruner):
     """Prune to an exact pruning level specification.
@@ -30,13 +34,18 @@ class AutomatedGradualPruner(_ParameterPruner):
     (https://arxiv.org/pdf/1710.01878.pdf)
     """
 
-    def __init__(self, name, initial_sparsity, final_sparsity, weights):
+    def __init__(self, name, initial_sparsity, final_sparsity, weights,
+                 pruning_fn=None):
         super(AutomatedGradualPruner, self).__init__(name)
         self.initial_sparsity = initial_sparsity
         self.final_sparsity = final_sparsity
         assert final_sparsity > initial_sparsity
         self.params_names = weights
         assert self.params_names
+        if pruning_fn is None:
+            self.pruning_fn = self.prune_to_target_sparsity
+        else:
+            self.pruning_fn = pruning_fn
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if param_name not in self.params_names:
@@ -52,6 +61,28 @@ class AutomatedGradualPruner(_ParameterPruner):
         target_sparsity = (self.final_sparsity +
                            (self.initial_sparsity-self.final_sparsity) *
                            (1.0 - ((current_epoch-starting_epoch)/span))**3)
 
+        self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity)
+
+    @staticmethod
+    def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity):
+        return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity)
+
+
+class StructuredAutomatedGradualPruner(AutomatedGradualPruner):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        self.reg_regims = reg_regims
+        weights = [weight for weight in reg_regims.keys()]
+        if not all([group in ['3D', 'Filters', 'Channels'] for group in reg_regims.values()]):
+            raise ValueError("Currently only filter (3D) and channel pruning is supported")
+        super(StructuredAutomatedGradualPruner, self).__init__(name, initial_sparsity,
+                                                               final_sparsity, weights,
+                                                               pruning_fn=self.prune_to_target_sparsity)
-        SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict,
-                                                 target_sparsity)
+    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity):
+        if self.reg_regims[param_name] in ['3D', 'Filters']:
+            L1RankedStructureParameterPruner.rank_prune_filters(target_sparsity, param,
+                                                                param_name, zeros_mask_dict)
+        else:
+            if self.reg_regims[param_name] == 'Channels':
+                L1RankedStructureParameterPruner.rank_prune_channels(target_sparsity, param,
+                                                                     param_name, zeros_mask_dict)
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index b4a667cb931328fc3e022b79c856213028e9c891..407cf737edbad9b4ac4dc445ac30b55aa0ee5b3f 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -40,13 +40,13 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         if fraction_to_prune == 0:
             return
 
-        if group_type not in ['3D', 'Channels']:
-            raise ValueError("Currently only filter (3D) and channel ranking is supported")
-        if group_type == "3D":
+        if group_type in ['3D', 'Filters']:
             return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict)
-        elif group_type == "Channels":
+        elif group_type == 'Channels':
             return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict)
-
+        else:
+            raise ValueError("Currently only filter (3D) and channel ranking is supported")
+
     @staticmethod
     def rank_channels(fraction_to_prune, param):
         num_filters = param.size(0)
@@ -64,15 +64,15 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         k = int(fraction_to_prune * channel_mags.size(0))
         if k == 0:
             msglogger.info("Too few channels (%d)- can't prune %.1f%% channels",
-                            num_channels, 100*fraction_to_prune)
+                           num_channels, 100*fraction_to_prune)
             return None, None
 
         bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True)
         return bottomk, channel_mags
 
-
-    def rank_prune_channels(self, fraction_to_prune, param, param_name, zeros_mask_dict):
-        bottomk_channels, channel_mags = self.rank_channels(fraction_to_prune, param)
+    @staticmethod
+    def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict):
+        bottomk_channels, channel_mags = L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param)
         if bottomk_channels is None:
             # Empty list means that fraction_to_prune is too low to prune anything
             return
@@ -91,8 +91,8 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
                          distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                          fraction_to_prune, len(bottomk_channels), num_channels)
 
-
-    def rank_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict):
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         view_filters = param.view(param.size(0), -1)
         filter_mags = view_filters.data.norm(1, dim=1)  # same as view_filters.data.abs().sum(dim=1)
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..83cc2fdfe2cd62e7a3e78c04eaf7f304eccc7fa2
--- /dev/null
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -0,0 +1,109 @@
+# Baseline results:
+# Top1: 91.780 Top5: 99.710 Loss: 0.376
+# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.39017 | -0.00681 | 0.27798 |
+# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14674 | -0.00888 | 0.10358 |
+# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14363 | 0.00146 | 0.10571 |
+# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12673 | -0.01323 | 0.09655 |
+# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11736 | -0.00420 | 0.09039 |
+# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16400 | -0.00959 | 0.12023 |
+# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13288 | -0.00014 | 0.10020 |
+# | 7 | module.layer2.0.conv1.weight | (20, 16, 3, 3) | 2880 | 2880 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14688 | -0.00195 | 0.11372 |
+# | 8 | module.layer2.0.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12828 | -0.00643 | 0.10049 |
+# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.25453 | -0.00949 | 0.17990 |
+# | 10 | module.layer2.1.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10884 | -0.00760 | 0.08639 |
+# | 11 | module.layer2.1.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09702 | -0.00599 | 0.07635 |
+# | 12 | module.layer2.2.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11464 | -0.01339 | 0.09051 |
+# | 13 | module.layer2.2.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09177 | 0.00195 | 0.07188 |
+# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09764 | -0.00680 | 0.07753 |
+# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09308 | -0.00392 | 0.07406 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12596 | -0.00848 | 0.09993 |
+# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 6.49414 | 0.00000 | 69.99783 | 0.07444 | -0.00396 | 0.03728 |
+# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.49512 | 0.00000 | 69.99783 | 0.06792 | -0.00462 | 0.03381 |
+# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 9.81445 | 0.00000 | 69.99783 | 0.06811 | -0.00477 | 0.03417 |
+# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 26.00098 | 0.00000 | 69.99783 | 0.03877 | 0.00056 | 0.01954 |
+# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.56077 | -0.00002 | 0.48798 |
+# | 22 | Total sparsity: | - | 251888 | 148672 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 40.97694 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 40.98
+#
+# --- validate (epoch=359)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 93.320 Top5: 99.880 Loss: 0.246
+#
+# ==> Best Top1: 93.740 On Epoch: 265
+#
+# Saving checkpoint to: logs/2018.10.09-020359/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.580 Top5: 99.710 Loss: 0.355
+#
+
+version: 1
+pruners:
+  low_pruner:
+    class: StructuredAutomatedGradualPruner
+    initial_sparsity : 0.10
+    final_sparsity: 0.40
+    reg_regims:
+      module.layer2.0.conv1.weight: Filters
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+
+  fine_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight,
+              module.layer3.2.conv1.weight, module.layer3.2.conv2.weight]
+
+lr_schedulers:
+  pruning_lr:
+    class: StepLR
+    step_size: 50
+    gamma: 0.10
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet20_cifar'
+      dataset: 'cifar10'
+
+policies:
+  - pruner:
+      instance_name : low_pruner
+    starting_epoch: 180
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name : fine_pruner
+    starting_epoch: 200
+    ending_epoch: 220
+    frequency: 2
+  # Currently the thinner is disabled until the end, because it interacts with the sparsity
+  # goals of the StructuredAutomatedGradualPruner.
+  # This can be fixed rather easily.
+  # - extension:
+  #     instance_name: net_thinner
+  #     starting_epoch: 0
+  #     ending_epoch: 20
+  #     frequency: 2
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+  - extension:
+      instance_name: net_thinner
+      epochs: [202]
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 0
+    ending_epoch: 400
+    frequency: 1
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..c91a157934b587c0d9d6c18ecd9c98f1014f01a3
--- /dev/null
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
@@ -0,0 +1,119 @@
+# Baseline results:
+# Top1: 91.780 Top5: 99.710 Loss: 0.376
+# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
+#
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.39196 | -0.00533 | 0.27677 |
+# | 1 | module.layer1.0.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17516 | -0.01627 | 0.12761 |
+# | 2 | module.layer1.0.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17375 | 0.00208 | 0.12753 |
+# | 3 | module.layer1.1.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14753 | -0.02355 | 0.11205 |
+# | 4 | module.layer1.1.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13242 | -0.00280 | 0.10184 |
+# | 5 | module.layer1.2.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18848 | -0.00708 | 0.13828 |
+# | 6 | module.layer1.2.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15502 | -0.00528 | 0.11709 |
+# | 7 | module.layer2.0.conv1.weight | (20, 16, 3, 3) | 2880 | 2880 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15266 | -0.00169 | 0.11773 |
+# | 8 | module.layer2.0.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13070 | -0.00823 | 0.10204 |
+# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.25380 | -0.01324 | 0.17815 |
+# | 10 | module.layer2.1.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11349 | -0.00928 | 0.08977 |
+# | 11 | module.layer2.1.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09904 | -0.00621 | 0.07856 |
+# | 12 | module.layer2.2.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11538 | -0.01280 | 0.09106 |
+# | 13 | module.layer2.2.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09239 | 0.00091 | 0.07236 |
+# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09853 | -0.00671 | 0.07821 |
+# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09391 | -0.00407 | 0.07466 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12660 | -0.00968 | 0.10101 |
+# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 6.56738 | 0.00000 | 69.99783 | 0.07488 | -0.00414 | 0.03739 |
+# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.69043 | 0.00000 | 69.99783 | 0.06839 | -0.00472 | 0.03404 |
+# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 9.47266 | 0.00000 | 69.99783 | 0.06867 | -0.00485 | 0.03450 |
+# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 26.41602 | 0.00000 | 69.99783 | 0.03915 | 0.00033 | 0.01970 |
+# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.56522 | -0.00002 | 0.49040 |
+# | 22 | Total sparsity: | - | 246704 | 143488 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 41.83799 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 41.84
+#
+# --- validate (epoch=359)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 92.540 Top5: 99.960 Loss: 0.246
+#
+# ==> Best Top1: 93.580 On Epoch: 328
+#
+# Saving checkpoint to: logs/2018.10.09-200709/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.190 Top5: 99.660 Loss: 0.372
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.09-200709/2018.10.09-200709.log
+#
+# real 32m23.439s
+# user 74m59.073s
+# sys 9m7.764s
+
+version: 1
+pruners:
+  low_pruner:
+    class: StructuredAutomatedGradualPruner
+    initial_sparsity : 0.10
+    final_sparsity: 0.40
+    reg_regims:
+      module.layer1.0.conv1.weight: Filters
+      module.layer1.1.conv1.weight: Filters
+      module.layer1.2.conv1.weight: Filters
+      module.layer2.0.conv1.weight: Filters
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+
+  fine_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight,
+              module.layer3.2.conv1.weight, module.layer3.2.conv2.weight]
+
+lr_schedulers:
+  pruning_lr:
+    class: StepLR
+    step_size: 50
+    gamma: 0.10
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet20_cifar'
+      dataset: 'cifar10'
+
+policies:
+  - pruner:
+      instance_name : low_pruner
+    starting_epoch: 180
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name : fine_pruner
+    starting_epoch: 200
+    ending_epoch: 220
+    frequency: 2
+  # Currently the thinner is disabled until the end, because it interacts with the sparsity
+  # goals of the StructuredAutomatedGradualPruner.
+  # This can be fixed rather easily.
+  # - extension:
+  #     instance_name: net_thinner
+  #     starting_epoch: 0
+  #     ending_epoch: 20
+  #     frequency: 2
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+  - extension:
+      instance_name: net_thinner
+      epochs: [202]
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 0
+    ending_epoch: 400
+    frequency: 1
diff --git a/models/__init__.py b/models/__init__.py
index 91f8cec936dd8e634b881dd2650a6ab1fec4fdee..8d40c77d8f59b1fc8ff987a716a0f1ea2182fd4f 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -74,5 +74,10 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
             model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
         elif parallel:
             model = torch.nn.DataParallel(model, device_ids=device_ids)
+
+    # explicitly add a softmax layer, because it is useful when exporting to ONNX
+    model.original_forward = model.forward
+    softmax = torch.nn.Softmax(dim=1)
+    model.forward = lambda input: softmax(model.original_forward(input))
     model.cuda()
     return model
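
Editor's note: the new StructuredAutomatedGradualPruner is normally driven by the YAML schedules above, but it can also be constructed directly in Python. The following is a minimal sketch, not part of the patch, that mirrors the 'low_pruner' entry of resnet20_filters.schedule_agp.yaml; the constructor arguments are taken from the changed automated_gradual_pruner.py, and the sketch assumes distiller is importable.

```python
from distiller.pruning import StructuredAutomatedGradualPruner

# Mirrors the 'low_pruner' section of resnet20_filters.schedule_agp.yaml:
# filter ('Filters'/'3D') sparsity is ramped from 10% to 40% by the AGP schedule.
pruner = StructuredAutomatedGradualPruner(
    name='low_pruner',
    initial_sparsity=0.10,
    final_sparsity=0.40,
    reg_regims={
        'module.layer2.0.conv1.weight': 'Filters',
        'module.layer2.1.conv1.weight': 'Filters',
        'module.layer2.2.conv1.weight': 'Filters',
    })
```

When the scheduler later calls set_param_mask(), this pruner ranks filters by L1 magnitude via L1RankedStructureParameterPruner instead of applying element-wise level pruning.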
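
Editor's note: the models/__init__.py change wraps the model's forward() with a Softmax so that an exported ONNX graph yields probabilities. A rough, hypothetical export sketch follows; the input shape, file name, and argument values are illustrative only, and a CUDA device is assumed because create_model() calls model.cuda().

```python
import torch
from models import create_model

# create_model() now returns a model whose forward() applies a final Softmax(dim=1).
model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=False)
dummy_input = torch.randn(1, 3, 32, 32).cuda()  # CIFAR-10 sized input
torch.onnx.export(model, dummy_input, 'resnet20_cifar.onnx')
```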