diff --git a/distiller/config.py b/distiller/config.py
index 2c16fbac9945bf5a11a07f7c33990a7370eaa2d1..ad54cba42c2fca51a655894d2734f5bd56dc7639 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -58,8 +58,7 @@ def dict_config(model, optimizer, sched_dict):
     regularizers = __factory('regularizers', model, sched_dict)
     quantizers = __factory('quantizers', model, sched_dict, optimizer=optimizer)
     if len(quantizers) > 1:
-        print("\nError: Multiple Quantizers not supported")
-        exit(1)
+        raise ValueError("\nError: Multiple Quantizers not supported")
     extensions = __factory('extensions', model, sched_dict)
 
     try:
@@ -72,7 +71,7 @@ def dict_config(model, optimizer, sched_dict):
             except TypeError as e:
                 print('\n\nFatal Error: a policy is defined with a null pruner')
                 print('Here\'s the policy definition for your reference:\n{}'.format(json.dumps(policy_def, indent=1)))
-                exit(1)
+                raise
             assert instance_name in pruners, "Pruner {} was not defined in the list of pruners".format(instance_name)
             pruner = pruners[instance_name]
             policy = distiller.PruningPolicy(pruner, args)
@@ -105,8 +104,7 @@ def dict_config(model, optimizer, sched_dict):
             policy = extension
 
         else:
-            print("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]" % policy_def)
-            exit(1)
+            raise ValueError("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [{}]".format(policy_def))
 
         add_policy_to_scheduler(policy, policy_def, schedule)
 
@@ -126,8 +124,7 @@ def dict_config(model, optimizer, sched_dict):
     except Exception as exception:
         print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
         print("Exception: %s %s" % (type(exception), exception))
-        exit(1)
-
+        raise
     return schedule
 
 
@@ -149,7 +146,7 @@ def file_config(model, optimizer, filename):
             return dict_config(model, optimizer, sched_dict)
         except yaml.YAMLError as exc:
             print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
-            exit(1)
+            raise
 
 
 def __factory(container_type, model, sched_dict, **kwargs):
@@ -164,14 +161,17 @@ def __factory(container_type, model, sched_dict, **kwargs):
                     cfg_kwargs['name'] = name
                     class_ = globals()[cfg_kwargs['class']]
                     container[name] = class_(**__filter_kwargs(cfg_kwargs, class_.__init__))
+                except NameError as error:
+                    print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name))
+                    raise
                 except Exception as exception:
                     print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name))
                     print("Exception: %s %s" % (type(exception), exception))
-                    exit(1)
+                    raise
         except Exception as exception:
             print("\nFatal while creating %s" % container_type)
             print("Exception: %s %s" % (type(exception), exception))
-            exit(1)
+            raise
 
     return container
 
 
diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index 5705a7a9435779019931d76159aca552e017dc3c..dc8a1c76ca7910f00194b5c7c07f28695cc68fac 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -19,7 +19,7 @@
 """
 
 from .magnitude_pruner import MagnitudeParameterPruner
-from .automated_gradual_pruner import AutomatedGradualPruner
+from .automated_gradual_pruner import AutomatedGradualPruner, StructuredAutomatedGradualPruner
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .structure_pruner import StructureParameterPruner
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index 25a072ade43728887bbdb5746a8e8ac9dffcf03c..dff2168320811349d2644e2dadd94c21155695d2 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -16,7 +16,11 @@
 from .pruner import _ParameterPruner
 from .level_pruner import SparsityLevelParameterPruner
+from .ranked_structures_pruner import L1RankedStructureParameterPruner
 from distiller.utils import *
+# import logging
+# msglogger = logging.getLogger()
+
 
 
 class AutomatedGradualPruner(_ParameterPruner):
     """Prune to an exact pruning level specification.
@@ -30,13 +34,18 @@ class AutomatedGradualPruner(_ParameterPruner):
     (https://arxiv.org/pdf/1710.01878.pdf)
     """
 
-    def __init__(self, name, initial_sparsity, final_sparsity, weights):
+    def __init__(self, name, initial_sparsity, final_sparsity, weights,
+                 pruning_fn=None):
         super(AutomatedGradualPruner, self).__init__(name)
         self.initial_sparsity = initial_sparsity
         self.final_sparsity = final_sparsity
         assert final_sparsity > initial_sparsity
         self.params_names = weights
         assert self.params_names
+        if pruning_fn is None:
+            self.pruning_fn = self.prune_to_target_sparsity
+        else:
+            self.pruning_fn = pruning_fn
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if param_name not in self.params_names:
@@ -52,6 +61,28 @@ class AutomatedGradualPruner(_ParameterPruner):
         target_sparsity = (self.final_sparsity +
                            (self.initial_sparsity-self.final_sparsity) *
                            (1.0 - ((current_epoch-starting_epoch)/span))**3)
 
+        self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity)
+
+    @staticmethod
+    def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity):
+        return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity)
+
+
+class StructuredAutomatedGradualPruner(AutomatedGradualPruner):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        self.reg_regims = reg_regims
+        weights = [weight for weight in reg_regims.keys()]
+        if not all([group in ['3D', 'Filters', 'Channels'] for group in reg_regims.values()]):
+            raise ValueError("Currently only filter (3D) and channel pruning is supported")
+        super(StructuredAutomatedGradualPruner, self).__init__(name, initial_sparsity,
+                                                               final_sparsity, weights,
+                                                               pruning_fn=self.prune_to_target_sparsity)
-        SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict,
-                                                 target_sparsity)
+    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity):
+        if self.reg_regims[param_name] in ['3D', 'Filters']:
+            L1RankedStructureParameterPruner.rank_prune_filters(target_sparsity, param,
+                                                                param_name, zeros_mask_dict)
+        else:
+            if self.reg_regims[param_name] == 'Channels':
+                L1RankedStructureParameterPruner.rank_prune_channels(target_sparsity, param,
+                                                                     param_name, zeros_mask_dict)
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index b4a667cb931328fc3e022b79c856213028e9c891..407cf737edbad9b4ac4dc445ac30b55aa0ee5b3f 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -40,13 +40,13 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         if fraction_to_prune == 0:
             return
 
-        if group_type not in ['3D', 'Channels']:
-            raise ValueError("Currently only filter (3D) and channel ranking is supported")
-        if group_type == "3D":
+        if group_type in ['3D', 'Filters']:
             return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict)
-        elif group_type == "Channels":
+        elif group_type == 'Channels':
             return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict)
-
+        else:
+            raise ValueError("Currently only filter (3D) and channel ranking is supported")
+
     @staticmethod
     def rank_channels(fraction_to_prune, param):
         num_filters = param.size(0)
@@ -64,15 +64,15 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         k = int(fraction_to_prune * channel_mags.size(0))
         if k == 0:
             msglogger.info("Too few channels (%d)- can't prune %.1f%% channels",
-                            num_channels, 100*fraction_to_prune)
+                           num_channels, 100*fraction_to_prune)
             return None, None
 
         bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True)
         return bottomk, channel_mags
 
-
-    def rank_prune_channels(self, fraction_to_prune, param, param_name, zeros_mask_dict):
-        bottomk_channels, channel_mags = self.rank_channels(fraction_to_prune, param)
+    @staticmethod
+    def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict):
+        bottomk_channels, channel_mags = L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param)
         if bottomk_channels is None:
             # Empty list means that fraction_to_prune is too low to prune anything
             return
@@ -91,8 +91,8 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
                          distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                          fraction_to_prune, len(bottomk_channels), num_channels)
 
-
-    def rank_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict):
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         view_filters = param.view(param.size(0), -1)
         filter_mags = view_filters.data.norm(1, dim=1)  # same as view_filters.data.abs().sum(dim=1)
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..83cc2fdfe2cd62e7a3e78c04eaf7f304eccc7fa2
--- /dev/null
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -0,0 +1,109 @@
+# Baseline results:
+# Top1: 91.780 Top5: 99.710 Loss: 0.376
+# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.39017 | -0.00681 | 0.27798 |
+# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14674 | -0.00888 | 0.10358 |
+# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14363 | 0.00146 | 0.10571 |
+# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12673 | -0.01323 | 0.09655 |
+# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11736 | -0.00420 | 0.09039 |
+# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16400 | -0.00959 | 0.12023 |
+# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13288 | -0.00014 | 0.10020 |
+# | 7 | module.layer2.0.conv1.weight | (20, 16, 3, 3) | 2880 | 2880 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14688 | -0.00195 | 0.11372 |
+# | 8 | module.layer2.0.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12828 | -0.00643 | 0.10049 |
+# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.25453 | -0.00949 | 0.17990 |
+# | 10 | module.layer2.1.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10884 | -0.00760 | 0.08639 |
+# | 11 | module.layer2.1.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09702 | -0.00599 | 0.07635 |
+# | 12 | module.layer2.2.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11464 | -0.01339 | 0.09051 |
+# | 13 | module.layer2.2.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09177 | 0.00195 | 0.07188 |
+# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09764 | -0.00680 | 0.07753 |
+# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09308 | -0.00392 | 0.07406 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12596 | -0.00848 | 0.09993 |
+# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 6.49414 | 0.00000 | 69.99783 | 0.07444 | -0.00396 | 0.03728 |
+# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.49512 | 0.00000 | 69.99783 | 0.06792 | -0.00462 | 0.03381 |
+# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 9.81445 | 0.00000 | 69.99783 | 0.06811 | -0.00477 | 0.03417 |
+# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 26.00098 | 0.00000 | 69.99783 | 0.03877 | 0.00056 | 0.01954 |
+# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.56077 | -0.00002 | 0.48798 |
+# | 22 | Total sparsity: | - | 251888 | 148672 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 40.97694 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 40.98
+#
+# --- validate (epoch=359)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 93.320 Top5: 99.880 Loss: 0.246
+#
+# ==> Best Top1: 93.740 On Epoch: 265
+#
+# Saving checkpoint to: logs/2018.10.09-020359/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.580 Top5: 99.710 Loss: 0.355
+#
+
+version: 1
+pruners:
+  low_pruner:
+    class: StructuredAutomatedGradualPruner
+    initial_sparsity : 0.10
+    final_sparsity: 0.40
+    reg_regims:
+      module.layer2.0.conv1.weight: Filters
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+
+  fine_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight,
+              module.layer3.2.conv1.weight, module.layer3.2.conv2.weight]
+
+lr_schedulers:
+  pruning_lr:
+    class: StepLR
+    step_size: 50
+    gamma: 0.10
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet20_cifar'
+      dataset: 'cifar10'
+
+policies:
+  - pruner:
+      instance_name : low_pruner
+    starting_epoch: 180
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name : fine_pruner
+    starting_epoch: 200
+    ending_epoch: 220
+    frequency: 2
+  # Currently the thinner is disabled until the end, because it interacts with the sparsity
+  # goals of the StructuredAutomatedGradualPruner.
+  # This can be fixed rather easily.
+  # - extension:
+  #     instance_name: net_thinner
+  #     starting_epoch: 0
+  #     ending_epoch: 20
+  #     frequency: 2
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+  - extension:
+      instance_name: net_thinner
+      epochs: [202]
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 0
+    ending_epoch: 400
+    frequency: 1
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..c91a157934b587c0d9d6c18ecd9c98f1014f01a3
--- /dev/null
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
@@ -0,0 +1,119 @@
+# Baseline results:
+# Top1: 91.780 Top5: 99.710 Loss: 0.376
+# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
+#
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.39196 | -0.00533 | 0.27677 |
+# | 1 | module.layer1.0.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17516 | -0.01627 | 0.12761 |
+# | 2 | module.layer1.0.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17375 | 0.00208 | 0.12753 |
+# | 3 | module.layer1.1.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14753 | -0.02355 | 0.11205 |
+# | 4 | module.layer1.1.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13242 | -0.00280 | 0.10184 |
+# | 5 | module.layer1.2.conv1.weight | (10, 16, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18848 | -0.00708 | 0.13828 |
+# | 6 | module.layer1.2.conv2.weight | (16, 10, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15502 | -0.00528 | 0.11709 |
+# | 7 | module.layer2.0.conv1.weight | (20, 16, 3, 3) | 2880 | 2880 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15266 | -0.00169 | 0.11773 |
+# | 8 | module.layer2.0.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13070 | -0.00823 | 0.10204 |
+# | 9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.25380 | -0.01324 | 0.17815 |
+# | 10 | module.layer2.1.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11349 | -0.00928 | 0.08977 |
+# | 11 | module.layer2.1.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09904 | -0.00621 | 0.07856 |
+# | 12 | module.layer2.2.conv1.weight | (20, 32, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11538 | -0.01280 | 0.09106 |
+# | 13 | module.layer2.2.conv2.weight | (32, 20, 3, 3) | 5760 | 5760 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09239 | 0.00091 | 0.07236 |
+# | 14 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09853 | -0.00671 | 0.07821 |
+# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09391 | -0.00407 | 0.07466 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12660 | -0.00968 | 0.10101 |
+# | 17 | module.layer3.1.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 6.56738 | 0.00000 | 69.99783 | 0.07488 | -0.00414 | 0.03739 |
+# | 18 | module.layer3.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 7.69043 | 0.00000 | 69.99783 | 0.06839 | -0.00472 | 0.03404 |
+# | 19 | module.layer3.2.conv1.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 9.47266 | 0.00000 | 69.99783 | 0.06867 | -0.00485 | 0.03450 |
+# | 20 | module.layer3.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11060 | 0.00000 | 0.00000 | 0.00000 | 26.41602 | 0.00000 | 69.99783 | 0.03915 | 0.00033 | 0.01970 |
+# | 21 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.56522 | -0.00002 | 0.49040 |
+# | 22 | Total sparsity: | - | 246704 | 143488 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 41.83799 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 41.84
+#
+# --- validate (epoch=359)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 92.540 Top5: 99.960 Loss: 0.246
+#
+# ==> Best Top1: 93.580 On Epoch: 328
+#
+# Saving checkpoint to: logs/2018.10.09-200709/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.190 Top5: 99.660 Loss: 0.372
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.09-200709/2018.10.09-200709.log
+#
+# real 32m23.439s
+# user 74m59.073s
+# sys 9m7.764s
+
+version: 1
+pruners:
+  low_pruner:
+    class: StructuredAutomatedGradualPruner
+    initial_sparsity : 0.10
+    final_sparsity: 0.40
+    reg_regims:
+      module.layer1.0.conv1.weight: Filters
+      module.layer1.1.conv1.weight: Filters
+      module.layer1.2.conv1.weight: Filters
+      module.layer2.0.conv1.weight: Filters
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+
+  fine_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight,
+              module.layer3.2.conv1.weight, module.layer3.2.conv2.weight]
+
+lr_schedulers:
+  pruning_lr:
+    class: StepLR
+    step_size: 50
+    gamma: 0.10
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet20_cifar'
+      dataset: 'cifar10'
+
+policies:
+  - pruner:
+      instance_name : low_pruner
+    starting_epoch: 180
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name : fine_pruner
+    starting_epoch: 200
+    ending_epoch: 220
+    frequency: 2
+  # Currently the thinner is disabled until the end, because it interacts with the sparsity
+  # goals of the StructuredAutomatedGradualPruner.
+  # This can be fixed rather easily.
+  # - extension:
+  #     instance_name: net_thinner
+  #     starting_epoch: 0
+  #     ending_epoch: 20
+  #     frequency: 2
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+  - extension:
+      instance_name: net_thinner
+      epochs: [202]
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 0
+    ending_epoch: 400
+    frequency: 1
diff --git a/models/__init__.py b/models/__init__.py
index 91f8cec936dd8e634b881dd2650a6ab1fec4fdee..8d40c77d8f59b1fc8ff987a716a0f1ea2182fd4f 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -74,5 +74,10 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
             model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
         elif parallel:
             model = torch.nn.DataParallel(model, device_ids=device_ids)
+
+    # explicitly add a softmax layer, because it is useful when exporting to ONNX
+    model.original_forward = model.forward
+    softmax = torch.nn.Softmax(dim=1)
+    model.forward = lambda input: softmax(model.original_forward(input))
     model.cuda()
     return model
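
Editor's note: the new StructuredAutomatedGradualPruner is normally driven by the YAML schedules above, but it can also be constructed directly in Python. The following is a minimal sketch, not part of the patch, that mirrors the 'low_pruner' entry of resnet20_filters.schedule_agp.yaml; the constructor arguments are taken from the changed automated_gradual_pruner.py, and the sketch assumes distiller is importable.

```python
from distiller.pruning import StructuredAutomatedGradualPruner

# Mirrors the 'low_pruner' section of resnet20_filters.schedule_agp.yaml:
# filter ('Filters'/'3D') sparsity is ramped from 10% to 40% by the AGP schedule.
pruner = StructuredAutomatedGradualPruner(
    name='low_pruner',
    initial_sparsity=0.10,
    final_sparsity=0.40,
    reg_regims={
        'module.layer2.0.conv1.weight': 'Filters',
        'module.layer2.1.conv1.weight': 'Filters',
        'module.layer2.2.conv1.weight': 'Filters',
    })
```

When the scheduler later calls set_param_mask(), this pruner ranks filters by L1 magnitude via L1RankedStructureParameterPruner instead of applying element-wise level pruning.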
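
Editor's note: the models/__init__.py change wraps the model's forward() with a Softmax so that an exported ONNX graph yields probabilities. A rough, hypothetical export sketch follows; the input shape, file name, and argument values are illustrative only, and a CUDA device is assumed because create_model() calls model.cuda().

```python
import torch
from models import create_model

# create_model() now returns a model whose forward() applies a final Softmax(dim=1).
model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=False)
dummy_input = torch.randn(1, 3, 32, 32).cuda()  # CIFAR-10 sized input
torch.onnx.export(model, dummy_input, 'resnet20_cifar.onnx')
```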