diff --git a/README.md b/README.md index a767009b81bc21ffadb9e49e9a193ea11a28e739..77ed092dd0d414305da0ab59c0b947d305c92d29 100755 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ $ python3 compress_classifier.py --resume=../ssl/checkpoints/checkpoint_trained_ This example performs 8-bit quantization of ResNet20 for CIFAR10. We've included in the git repository the checkpoint of a ResNet20 model that we've trained with 32-bit floats, so we'll take this model and quantize it: ``` -$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10 --resume ../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar --quantize-eval --evaluate +$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10 --resume ../ssl/checkpoints/checkpoint_trained_dense.pth.tar --quantize-eval --evaluate ``` The command-line above will save a checkpoint named `quantized_checkpoint.pth.tar` containing the quantized model parameters. See more examples [here](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant.md). diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index e6fc056b72f19c7f2c275a17fcfc637109b4bfd4..5f848831bf0fb80709d3b06ddd315bd83ced64f2 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): if os.path.isfile(chkpt_file): msglogger.info("=> loading checkpoint %s", chkpt_file) - checkpoint = torch.load(chkpt_file) + checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage) msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys()))) start_epoch = checkpoint['epoch'] + 1 best_top1 = checkpoint.get('best_top1', None) diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index e04ce067d6a7e40a34fb8f2f962d51e32485797c..08808666d02ba39796bd3cdb4ec046545423e151 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -213,7 +213,8 @@ def model_performance_summary(model, dummy_input, batch_size=1): model = distiller.make_non_parallel_copy(model) model.apply(install_perf_collector) # Now run the forward path and collect the data - model(dummy_input.cuda()) + dummy_input = dummy_input.to(distiller.model_device(model)) + model(dummy_input) # Unregister from the forward hooks for handle in hook_handles: handle.remove() diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index bede4887de3ee225eb8d798df37a955b056b372e..e267a686df218f67065f04b2e93a5e6407f5554d 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -99,10 +99,10 @@ class StructuredAGP(AutomatedGradualPrunerBase): # TODO: this class parameterization is cumbersome: the ranking functions (per structure) # should come from the YAML schedule class L1RankedStructureParameterPruner_AGP(StructuredAGP): - def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None): super().__init__(name, initial_sparsity, final_sparsity) - self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0, - weights=weights, group_dependency=group_dependency) + self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0, weights=weights, + group_dependency=group_dependency, kwargs=kwargs) class 
ActivationAPoZRankedFilterPruner_AGP(StructuredAGP): diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 58f66818b21d10b8b859f3b3777cd0c8088a7a84..1ca9237a42a5d83616c9566ee1487599d8317dd7 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -14,6 +14,7 @@ # limitations under the License. # +from functools import partial import numpy as np import logging import torch @@ -82,12 +83,17 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner): This class prunes to a prescribed percentage of structured-sparsity (level pruning). """ - def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None): super().__init__(name, group_type, desired_sparsity, weights, group_dependency) - if group_type not in ['3D', 'Filters', 'Channels', 'Rows']: - raise ValueError("Structure {} was requested but" - "currently only filter (3D) and channel ranking is supported". + if group_type not in ['3D', 'Filters', 'Channels', 'Rows', 'Blocks']: + raise ValueError("Structure {} was requested but " + "currently ranking of this shape is not supported". format(group_type)) + if group_type == 'Blocks': + try: + self.block_shape = kwargs['block_shape'] + except KeyError: + raise ValueError("When defining a block pruner you must also specify the block shape") def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): if fraction_to_prune == 0: @@ -98,6 +104,8 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner): group_pruning_fn = self.rank_and_prune_channels elif self.group_type == 'Rows': group_pruning_fn = self.rank_and_prune_rows + elif self.group_type == 'Blocks': + group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape) binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map) return binary_map @@ -207,6 +215,92 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner): distiller.sparsity(zeros_mask_dict[param_name].mask), fraction_to_prune, num_rows_to_prune, rows_mags.size(0)) + @staticmethod + def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, + zeros_mask_dict=None, model=None, binary_map=None, block_shape=None): + """Block-wise pruning for 4D tensors. + + The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width]. + The dimension 'block_repetitions' specifies in how many consecutive filters the "basic block" + (shaped as [block_depth, block_height, block_width]) repeats to produce a (4D) "super block". 
+ + For example: + + block_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.70 + group_type: Blocks + kwargs: + block_shape: [1,8,1,1] # [block_repetitions, block_depth, block_height, block_width] + + Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1 + """ + if len(block_shape) != 4: + raise ValueError("The block shape must be specified as a 4-element tuple") + block_repetitions, block_depth, block_height, block_width = block_shape + if not block_width == block_height == 1: + raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1") + super_block_volume = distiller.volume(block_shape) + num_super_blocks = distiller.volume(param) / super_block_volume + if distiller.volume(param) % super_block_volume != 0: + raise ValueError("The super-block size must divide the weight tensor exactly.") + + num_filters = param.size(0) + num_channels = param.size(1) + kernel_size = param.size(2) * param.size(3) + + if block_depth > 1: + view_dims = ( + num_filters*num_channels//(block_repetitions*block_depth), + block_repetitions*block_depth, + kernel_size, + ) + else: + view_dims = ( + num_filters // block_repetitions, + block_repetitions, + -1, + ) + + def rank_blocks(fraction_to_prune, param): + # Create a view where each block is a column + view1 = param.view(*view_dims) + # Next, compute the sums of each column (block) + block_sums = view1.abs().sum(dim=1) + + # Now group by channels + block_mags = block_sums.view(-1) # flatten + k = int(fraction_to_prune * block_mags.size(0)) + if k == 0: + msglogger.info("Too few blocks (%d)- can't prune %.1f%% blocks", + block_mags.size(0), 100*fraction_to_prune) + return None, None + + bottomk, _ = torch.topk(block_mags, k, largest=False, sorted=True) + return bottomk, block_mags + + def binary_map_to_mask(binary_map, param): + a = binary_map.view(view_dims[0], view_dims[2]) + c = a.unsqueeze(1) + d = c.expand(*view_dims).contiguous() + return d.view(num_filters, num_channels, param.size(2), param.size(3)) + + if binary_map is None: + bottomk_blocks, block_mags = rank_blocks(fraction_to_prune, param) + if bottomk_blocks is None: + # Empty list means that fraction_to_prune is too low to prune anything + return + threshold = bottomk_blocks[-1] + binary_map = block_mags.gt(threshold).type(param.data.type()) + + if zeros_mask_dict is not None: + zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param) + msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, + distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape), + fraction_to_prune, binary_map.sum().item(), num_super_blocks) + return binary_map + def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map): if binary_map is None: diff --git a/distiller/thinning.py b/distiller/thinning.py index ad3ff26d3fb0909901b0021b2de70dadc48b9039..5cdff7f2ff14491073dfa3429d4978a3d793f11d 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -68,12 +68,13 @@ def create_graph(dataset, arch): if dataset == 'imagenet': dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False) elif dataset == 'cifar10': - dummy_input = torch.randn((1, 3, 32, 32)) + dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False) assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset) model = create_model(False, dataset, arch, 
parallel=False) assert model is not None - return SummaryGraph(model, dummy_input.cuda()) + dummy_input = dummy_input.to(distiller.model_device(model)) + return SummaryGraph(model, dummy_input) def param_name_2_layer_name(param_name): @@ -486,7 +487,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr msglogger.debug("[thinning] {}: setting {} to {}". format(layer_name, attr, indices_to_select.nelement())) setattr(layers[layer_name], attr, - torch.index_select(running, dim=dim_to_trim, index=indices_to_select)) + torch.index_select(running, dim=dim_to_trim, index=indices_to_select.to(running.device))) else: msglogger.debug("[thinning] {}: setting {} to {}".format(layer_name, attr, val)) setattr(layers[layer_name], attr, val) @@ -521,13 +522,13 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr param.grad = param.grad.resize_(*directive[3]) else: if param.data.size(dim) != len_indices: - param.data = torch.index_select(param.data, dim, indices) + param.data = torch.index_select(param.data, dim, indices.to(param.device)) msglogger.debug("[thinning] changed param {} shape: {}".format(param_name, len_indices)) # We also need to change the dimensions of the gradient tensor. # If have not done a backward-pass thus far, then the gradient will # not exist, and therefore won't need to be re-dimensioned. if param.grad is not None and param.grad.size(dim) != len_indices: - param.grad = torch.index_select(param.grad, dim, indices) + param.grad = torch.index_select(param.grad, dim, indices.to(param.device)) if optimizer_thinning(optimizer, param, dim, indices): msglogger.debug("Updated velocity buffer %s" % param_name) diff --git a/distiller/thresholding.py b/distiller/thresholding.py index da2c3b629446b721aecae93751f566fb2f35dc0a..be03a27410034d216c63021e1a5ddb18c80c0ec9 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -62,7 +62,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria) view_2d = param.view(-1, param.size(2) * param.size(3)) # 1. Determine if the kernel "value" is below the threshold, by creating a 1D # thresholds tensor with length = #IFMs * # OFMs - thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda() + thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).to(param.device) # 2. Create a binary thresholds mask, where we use the mean of the abs values of the # elements in each channel as the threshold filter. # 3. 
Apply the threshold filter @@ -71,20 +71,20 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria) elif group_type == 'Rows': assert param.dim() == 2, "This regularization is only supported for 2D weights" - thresholds = torch.Tensor([threshold] * param.size(0)).cuda() + thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device) binary_map = threshold_policy(param, thresholds, threshold_criteria) return binary_map elif group_type == 'Cols': assert param.dim() == 2, "This regularization is only supported for 2D weights" - thresholds = torch.Tensor([threshold] * param.size(1)).cuda() + thresholds = torch.Tensor([threshold] * param.size(1)).to(param.device) binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0) return binary_map elif group_type == '3D' or group_type == 'Filters': assert param.dim() == 4, "This thresholding is only supported for 4D weights" view_filters = param.view(param.size(0), -1) - thresholds = torch.Tensor([threshold] * param.size(0)).cuda() + thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device) binary_map = threshold_policy(view_filters, thresholds, threshold_criteria) return binary_map @@ -109,7 +109,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria) # Next, compute the sum of the squares (of the elements in each row/kernel) kernel_means = view_2d.abs().mean(dim=1) k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t() - thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda() + thresholds = torch.Tensor([threshold] * num_kernels_per_filter).to(param.device) binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type()) return binary_map diff --git a/distiller/utils.py b/distiller/utils.py index 17d1097c52b02fdb87ced151ba86fc78d08d3e65..2c4bcf066bf53d7f08b77965088c1f041a6ee2a7 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -25,6 +25,14 @@ import torch.nn as nn from copy import deepcopy +def model_device(model): + """Determine the device the model is allocated on.""" + # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180 + if next(model.parameters()).is_cuda: + return 'cuda' + return 'cpu' + + def to_np(var): return var.data.cpu().numpy() @@ -131,7 +139,7 @@ def volume(tensor): """return the volume of a pytorch tensor""" if isinstance(tensor, torch.FloatTensor) or isinstance(tensor, torch.cuda.FloatTensor): return np.prod(tensor.shape) - if isinstance(tensor, tuple): + if isinstance(tensor, tuple) or isinstance(tensor, list): return np.prod(tensor) raise ValueError @@ -246,6 +254,50 @@ def density_ch(tensor): return 1 - sparsity_ch(tensor) +def sparsity_blocks(tensor, block_shape): + """Block-wise sparsity for 4D tensors + + Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1 + """ + if tensor.dim() != 4: + raise ValueError("sparsity_blocks is only supported for 4-D tensors") + + if len(block_shape) != 4: + raise ValueError("Block shape must be specified as a 4-element tuple") + block_repetitions, block_depth, block_height, block_width = block_shape + if not block_width == block_height == 1: + raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1") + + super_block_volume = volume(block_shape) + num_super_blocks = volume(tensor) / super_block_volume + + num_filters, num_channels = tensor.size(0), tensor.size(1) + kernel_size = tensor.size(2) * tensor.size(3) + + # Create a view where each block 
is a column + if block_depth > 1: + view_dims = ( + num_filters*num_channels//(block_repetitions*block_depth), + block_repetitions*block_depth, + kernel_size, + ) + else: + view_dims = ( + num_filters // block_repetitions, + block_repetitions, + -1, + ) + view1 = tensor.view(*view_dims) + + # Next, compute the sums of each column (block) + block_sums = view1.abs().sum(dim=1) + + # Next, compute the sums of each column (block) + block_sums = view1.abs().sum(dim=1) + nonzero_blocks = len(torch.nonzero(block_sums)) + return 1 - nonzero_blocks/num_super_blocks + + def sparsity_matrix(tensor, dim): """Generic sparsity computation for 2D matrices""" if tensor.dim() != 2: diff --git a/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml new file mode 100755 index 0000000000000000000000000000000000000000..342492cd57ae00f93555028a51a71776cd9cf4a1 --- /dev/null +++ b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml @@ -0,0 +1,147 @@ +# +# This schedule performs 1x1x8 block pruning using L1-norm ranking and AGP for the setting the pruning-rate decay. +# The final Linear layer (FC) is also pruned to 70%. +# +# Best Top1: 76.358 (epoch 72) vs. 76.15 baseline (+0.2%) +# +# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml --validation-size=0 --num-best-scores=10 +# +# Parameters: +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11032 | -0.00044 | 0.06760 | +# | 1 | module.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 4096 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06353 | -0.00371 | 0.03587 | +# | 2 | module.layer1.0.conv2.weight | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 28.12500 | 7.81250 | 69.98698 | 0.02277 | 0.00061 | 0.00835 | +# | 3 | module.layer1.0.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03142 | 0.00034 | 0.01880 | +# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05117 | -0.00298 | 0.02858 | +# | 5 | module.layer1.1.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02722 | 0.00105 | 0.01803 | +# | 6 | module.layer1.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 18.75000 | 1.56250 | 69.98698 | 0.02097 | 0.00016 | 0.00841 | +# | 7 | module.layer1.1.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02896 | -0.00002 | 0.01815 | +# | 8 | module.layer1.2.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02671 | 0.00015 | 0.01929 | +# | 9 | 
module.layer1.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 13.47656 | 0.00000 | 69.98698 | 0.02149 | -0.00033 | 0.00930 | +# | 10 | module.layer1.2.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02755 | -0.00215 | 0.01658 | +# | 11 | module.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03152 | -0.00126 | 0.02213 | +# | 12 | module.layer2.0.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 19.04297 | 0.00000 | 69.99783 | 0.01489 | -0.00011 | 0.00633 | +# | 13 | module.layer2.0.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02486 | 0.00003 | 0.01535 | +# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02046 | -0.00033 | 0.01198 | +# | 15 | module.layer2.1.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01482 | -0.00005 | 0.00895 | +# | 16 | module.layer2.1.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 22.36328 | 0.78125 | 69.99783 | 0.01512 | 0.00037 | 0.00598 | +# | 17 | module.layer2.1.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01964 | -0.00101 | 0.01122 | +# | 18 | module.layer2.2.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02073 | -0.00067 | 0.01437 | +# | 19 | module.layer2.2.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 14.64844 | 0.00000 | 69.99783 | 0.01522 | 0.00006 | 0.00622 | +# | 20 | module.layer2.2.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02328 | -0.00032 | 0.01636 | +# | 21 | module.layer2.3.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02161 | -0.00079 | 0.01598 | +# | 22 | module.layer2.3.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 12.79297 | 0.00000 | 69.99783 | 0.01498 | -0.00022 | 0.00650 | +# | 23 | module.layer2.3.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02159 | -0.00091 | 0.01488 | +# | 24 | module.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02732 | -0.00100 | 0.01950 | +# | 25 | module.layer3.0.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 24.70703 | 0.00000 | 69.99919 | 0.01165 | -0.00010 | 0.00486 | +# | 26 | module.layer3.0.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02073 | -0.00034 | 0.01470 | +# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01429 | 0.00003 | 0.00978 | +# | 28 | module.layer3.1.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01360 | -0.00050 | 0.00955 | +# | 29 | module.layer3.1.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 16.29639 | 0.00000 | 69.99919 | 0.01055 | 0.00002 | 0.00442 | +# | 30 | 
module.layer3.1.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01820 | -0.00090 | 0.01307 | +# | 31 | module.layer3.2.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01407 | -0.00041 | 0.01007 | +# | 32 | module.layer3.2.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.88965 | 0.00000 | 69.99919 | 0.01011 | -0.00021 | 0.00433 | +# | 33 | module.layer3.2.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01698 | -0.00063 | 0.01240 | +# | 34 | module.layer3.3.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01550 | -0.00058 | 0.01146 | +# | 35 | module.layer3.3.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.92627 | 0.00000 | 69.99919 | 0.00985 | -0.00019 | 0.00429 | +# | 36 | module.layer3.3.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01625 | -0.00094 | 0.01197 | +# | 37 | module.layer3.4.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01617 | -0.00080 | 0.01216 | +# | 38 | module.layer3.4.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.99951 | 0.00000 | 69.99919 | 0.00980 | -0.00028 | 0.00428 | +# | 39 | module.layer3.4.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01623 | -0.00131 | 0.01196 | +# | 40 | module.layer3.5.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01757 | -0.00075 | 0.01337 | +# | 41 | module.layer3.5.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.16943 | 0.00000 | 69.99919 | 0.01001 | -0.00032 | 0.00438 | +# | 42 | module.layer3.5.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01725 | -0.00191 | 0.01294 | +# | 43 | module.layer4.0.conv1.weight | (512, 1024, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02114 | -0.00099 | 0.01632 | +# | 44 | module.layer4.0.conv2.weight | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 19.15894 | 0.00000 | 69.99986 | 0.00801 | -0.00012 | 0.00358 | +# | 45 | module.layer4.0.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01369 | -0.00055 | 0.01057 | +# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00892 | -0.00015 | 0.00679 | +# | 47 | module.layer4.1.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01330 | -0.00056 | 0.01038 | +# | 48 | module.layer4.1.conv2.weight | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 13.93127 | 0.00000 | 69.99986 | 0.00781 | -0.00028 | 0.00351 | +# | 49 | module.layer4.1.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01347 | -0.00007 | 0.01039 | +# | 50 | module.layer4.2.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 
0.01628 | -0.00034 | 0.01277 | +# | 51 | module.layer4.2.conv2.weight | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 23.70911 | 0.00000 | 69.99986 | 0.00686 | -0.00021 | 0.00310 | +# | 52 | module.layer4.2.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01262 | 0.00002 | 0.00943 | +# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 614400 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 70.00000 | 0.03148 | 0.00299 | 0.01480 | +# | 54 | Total sparsity: | - | 25502912 | 16147304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 36.68447 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-12-24 13:52:24,645 - Total sparsity: 36.68 +# +# 2018-12-24 13:52:24,645 - --- validate (epoch=72)----------- +# 2018-12-24 13:52:24,645 - 50000 samples (256 per mini-batch) +# 2018-12-24 13:52:44,774 - Epoch: [72][ 50/ 195] Loss 0.676330 Top1 82.195312 Top5 96.039062 +# 2018-12-24 13:52:52,702 - Epoch: [72][ 100/ 195] Loss 0.799058 Top1 79.386719 Top5 94.863281 +# 2018-12-24 13:53:00,916 - Epoch: [72][ 150/ 195] Loss 0.911178 Top1 77.216146 Top5 93.466146 +# 2018-12-24 13:53:08,224 - ==> Top1: 76.358 Top5: 92.972 Loss: 0.952 +# +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.454 on Epoch: 1 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.446 on Epoch: 0 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.416 on Epoch: 3 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.358 on Epoch: 72 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.344 on Epoch: 2 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.326 on Epoch: 69 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.320 on Epoch: 68 +# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.318 on Epoch: 70 +# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.300 on Epoch: 58 +# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.284 on Epoch: 71 + +version: 1 + +pruners: + fc_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.70 + weights: module.fc.weight + + + block_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.70 + group_type: Blocks + kwargs: + block_shape: [1,8,1,1] # [block_repetition, block_depth, block_height, block_width] + weights: [module.layer1.0.conv2.weight, + module.layer1.1.conv2.weight, + module.layer1.2.conv2.weight, + module.layer2.0.conv2.weight, + module.layer2.1.conv2.weight, + module.layer2.2.conv2.weight, + module.layer2.3.conv2.weight, + module.layer3.0.conv2.weight, + module.layer3.1.conv2.weight, + module.layer3.2.conv2.weight, + module.layer3.3.conv2.weight, + module.layer3.4.conv2.weight, + module.layer3.5.conv2.weight, + module.layer4.0.conv2.weight, + module.layer4.1.conv2.weight, + module.layer4.2.conv2.weight] + + +lr_schedulers: + pruning_lr: + class: ExponentialLR + gamma: 0.95 + + +policies: + - pruner: + instance_name : block_pruner + starting_epoch: 0 + ending_epoch: 30 + frequency: 1 + + - pruner: + instance_name : fc_pruner + starting_epoch: 0 + ending_epoch: 30 + frequency: 3 + + - lr_scheduler: + instance_name: pruning_lr + starting_epoch: 40 + ending_epoch: 80 + frequency: 1 diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml new file mode 100755 index 
0000000000000000000000000000000000000000..e5713192a62a37ad580bb0ee96c4045341435fe9 --- /dev/null +++ b/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml @@ -0,0 +1,192 @@ +# +# This schedule performs filter-pruning using L1-norm ranking and AGP for the setting the pruning-rate decay. +# +# Best Top1: 74.782 (epoch 94) +# No. of Parameters: 12,671,168 (of 25,502,912) = 49.69% dense (50.31% sparse) +# Total MACs: 2,037,186,560 (of 4,089,184,256) = 49.82% compute = 2.01x +# +# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_2.yaml --validation-size=0 --num-best-scores=10 +# +# Parameters: +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11098 | -0.00043 | 0.06774 | +# | 1 | module.layer1.0.conv1.weight | (32, 64, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07631 | -0.00587 | 0.04636 | +# | 2 | module.layer1.0.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04019 | 0.00147 | 0.02596 | +# | 3 | module.layer1.0.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03788 | -0.00045 | 0.02391 | +# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05137 | -0.00304 | 0.02857 | +# | 5 | module.layer1.1.conv1.weight | (32, 256, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03148 | 0.00120 | 0.02169 | +# | 6 | module.layer1.1.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03669 | 0.00017 | 0.02582 | +# | 7 | module.layer1.1.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03162 | -0.00060 | 0.02006 | +# | 8 | module.layer1.2.conv1.weight | (32, 256, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02993 | 0.00020 | 0.02192 | +# | 9 | module.layer1.2.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03611 | 0.00009 | 0.02719 | +# | 10 | module.layer1.2.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02778 | -0.00228 | 0.01659 | +# | 11 | module.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03164 | -0.00144 | 0.02232 | +# | 12 | module.layer2.0.conv2.weight | (64, 128, 3, 3) | 73728 | 73728 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02147 | 0.00000 | 0.01595 | +# | 13 | module.layer2.0.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02703 | 0.00005 | 0.01656 | +# | 14 
| module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02051 | -0.00038 | 0.01206 | +# | 15 | module.layer2.1.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01744 | -0.00008 | 0.01081 | +# | 16 | module.layer2.1.conv2.weight | (128, 64, 3, 3) | 73728 | 73728 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02022 | 0.00011 | 0.01301 | +# | 17 | module.layer2.1.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01982 | -0.00107 | 0.01153 | +# | 18 | module.layer2.2.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02328 | -0.00053 | 0.01618 | +# | 19 | module.layer2.2.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02380 | 0.00012 | 0.01667 | +# | 20 | module.layer2.2.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02561 | 0.00015 | 0.01784 | +# | 21 | module.layer2.3.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02327 | -0.00090 | 0.01733 | +# | 22 | module.layer2.3.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02368 | -0.00043 | 0.01789 | +# | 23 | module.layer2.3.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02287 | -0.00116 | 0.01577 | +# | 24 | module.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02737 | -0.00126 | 0.01964 | +# | 25 | module.layer3.0.conv2.weight | (128, 256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01679 | -0.00019 | 0.01241 | +# | 26 | module.layer3.0.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02290 | -0.00043 | 0.01647 | +# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01431 | -0.00000 | 0.00982 | +# | 28 | module.layer3.1.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01517 | -0.00037 | 0.01072 | +# | 29 | module.layer3.1.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01683 | -0.00006 | 0.01212 | +# | 30 | module.layer3.1.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01959 | -0.00063 | 0.01394 | +# | 31 | module.layer3.2.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01547 | -0.00032 | 0.01103 | +# | 32 | module.layer3.2.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01644 | -0.00056 | 0.01214 | +# | 33 | module.layer3.2.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01832 | -0.00054 | 0.01331 | +# | 34 | module.layer3.3.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01675 | -0.00058 | 0.01250 | +# | 35 | 
module.layer3.3.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01552 | -0.00053 | 0.01179 | +# | 36 | module.layer3.3.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01741 | -0.00095 | 0.01280 | +# | 37 | module.layer3.4.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01738 | -0.00080 | 0.01312 | +# | 38 | module.layer3.4.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01539 | -0.00064 | 0.01169 | +# | 39 | module.layer3.4.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01709 | -0.00126 | 0.01253 | +# | 40 | module.layer3.5.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01868 | -0.00072 | 0.01434 | +# | 41 | module.layer3.5.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01528 | -0.00073 | 0.01170 | +# | 42 | module.layer3.5.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01853 | -0.00212 | 0.01393 | +# | 43 | module.layer4.0.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02219 | -0.00087 | 0.01715 | +# | 44 | module.layer4.0.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01234 | -0.00011 | 0.00962 | +# | 45 | module.layer4.0.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01454 | -0.00058 | 0.01133 | +# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00905 | -0.00018 | 0.00689 | +# | 47 | module.layer4.1.conv1.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01431 | -0.00032 | 0.01119 | +# | 48 | module.layer4.1.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01231 | -0.00060 | 0.00965 | +# | 49 | module.layer4.1.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01433 | 0.00003 | 0.01110 | +# | 50 | module.layer4.2.conv1.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01778 | -0.00008 | 0.01397 | +# | 51 | module.layer4.2.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01080 | -0.00034 | 0.00850 | +# | 52 | module.layer4.2.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01315 | 0.00019 | 0.00992 | +# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 2048000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03325 | 0.00000 | 0.02289 | +# | 54 | Total sparsity: | - | 12671168 | 12671168 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | +# 
+----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-12-09 13:27:25,875 - Total sparsity: 0.00 +# +# 2018-12-09 13:27:25,875 - --- validate (epoch=99)----------- +# 2018-12-09 13:27:25,875 - 50000 samples (256 per mini-batch) +# 2018-12-09 13:27:46,138 - Epoch: [99][ 50/ 195] Loss 0.728680 Top1 80.640625 Top5 95.507812 +# 2018-12-09 13:27:53,943 - Epoch: [99][ 100/ 195] Loss 0.850403 Top1 78.128906 Top5 94.128906 +# 2018-12-09 13:28:03,180 - Epoch: [99][ 150/ 195] Loss 0.973435 Top1 75.731771 Top5 92.619792 +# 2018-12-09 13:28:10,151 - ==> Top1: 74.738 Top5: 92.080 Loss: 1.018 +# +# 2018-12-09 13:28:10,230 - ==> Best Top1: 75.896 on Epoch: 0 +# 2018-12-09 13:28:10,230 - ==> Best Top1: 75.402 on Epoch: 1 +# 2018-12-09 13:28:10,230 - ==> Best Top1: 74.916 on Epoch: 2 +# 2018-12-09 13:28:10,230 - ==> Best Top1: 74.782 on Epoch: 94 <========== +# 2018-12-09 13:28:10,230 - ==> Best Top1: 74.776 on Epoch: 93 +# 2018-12-09 13:28:10,230 - ==> Best Top1: 74.774 on Epoch: 84 +# 2018-12-09 13:28:10,230 - ==> Best Top1: 74.772 on Epoch: 97 +# 2018-12-09 13:28:10,231 - ==> Best Top1: 74.770 on Epoch: 98 +# 2018-12-09 13:28:10,231 - ==> Best Top1: 74.738 on Epoch: 99 +# 2018-12-09 13:28:10,231 - ==> Best Top1: 74.726 on Epoch: 91 +# 2018-12-09 13:28:10,231 - Saving checkpoint to: logs/resnet50_filters_v3.1___2018.12.07-154945/resnet50_filters_v3.1_checkpoint.pth.tar +# 2018-12-09 13:28:10,458 - --- test --------------------- +# 2018-12-09 13:28:10,458 - 50000 samples (256 per mini-batch) +# 2018-12-09 13:28:30,687 - Test: [ 50/ 195] Loss 0.728680 Top1 80.640625 Top5 95.507812 +# 2018-12-09 13:28:38,854 - Test: [ 100/ 195] Loss 0.850403 Top1 78.128906 Top5 94.128906 +# 2018-12-09 13:28:47,691 - Test: [ 150/ 195] Loss 0.973435 Top1 75.731771 Top5 92.619792 +# 2018-12-09 13:28:54,669 - ==> Top1: 74.738 Top5: 92.080 Loss: 1.018 + +version: 1 + +pruners: + fc_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.87 + weights: module.fc.weight + + filter_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.50 + group_type: Filters + weights: [module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + #module.layer2.0.conv1.weight, + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight, + module.layer2.3.conv1.weight, + #module.layer3.0.conv1.weight, + module.layer3.1.conv1.weight, + module.layer3.2.conv1.weight, + module.layer3.3.conv1.weight, + module.layer3.4.conv1.weight, + module.layer3.5.conv1.weight, + module.layer4.0.conv1.weight, + module.layer4.1.conv1.weight, + module.layer4.2.conv1.weight, + + + module.layer1.0.conv2.weight, + module.layer1.1.conv2.weight, + module.layer1.2.conv2.weight, + module.layer2.0.conv2.weight, + #module.layer2.1.conv2.weight, + module.layer2.2.conv2.weight, + module.layer2.3.conv2.weight, + module.layer3.0.conv2.weight, + module.layer3.1.conv2.weight, + module.layer3.2.conv2.weight, + module.layer3.3.conv2.weight, + module.layer3.4.conv2.weight, + module.layer3.5.conv2.weight, + module.layer4.0.conv2.weight, + module.layer4.1.conv2.weight, + module.layer4.2.conv2.weight] + + fine_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.70 + weights: [ + module.layer4.0.conv2.weight, + module.layer4.0.conv3.weight, + module.layer4.0.downsample.0.weight, + 
module.layer4.1.conv1.weight, + module.layer4.1.conv2.weight, + module.layer4.1.conv3.weight, + module.layer4.2.conv1.weight, + module.layer4.2.conv2.weight, + module.layer4.2.conv3.weight] + +extensions: + net_thinner: + class: 'FilterRemover' + thinning_func_str: remove_filters + arch: 'resnet50' + dataset: 'imagenet' + +lr_schedulers: + pruning_lr: + class: ExponentialLR + gamma: 0.95 + +policies: + - pruner: + instance_name : filter_pruner +# args: +# mini_batch_pruning_frequency: 1 + starting_epoch: 0 + ending_epoch: 30 + frequency: 1 + +# After completeing the pruning, we perform network thinning and continue fine-tuning. + - extension: + instance_name: net_thinner + epochs: [31] + + + - lr_scheduler: + instance_name: pruning_lr + starting_epoch: 40 + ending_epoch: 80 + frequency: 1 diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d68c3fc1b3770a43d0282ecaa92a8bab04e482f2 --- /dev/null +++ b/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml @@ -0,0 +1,171 @@ + +# This schedule performs filter-pruning using L1-norm ranking and AGP for the setting the pruning-rate decay. +# +# Best Top1: 75.748 (epoch 94) +# No. of Parameters: 17,329,344 (of 25,502,912) = 67.95% dense (32.05% sparse) +# Total MACs: 2,753,298,432 (of 4,089,184,256) = 67.33% compute = 1.49x +# +# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_3.yaml --validation-size=0 --num-best-scores=10 +# +# Parameters: +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11053 | -0.00040 | 0.06769 | +# | 1 | module.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 4096 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06357 | -0.00404 | 0.03573 | +# | 2 | module.layer1.0.conv2.weight | (32, 64, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03310 | 0.00093 | 0.02084 | +# | 3 | module.layer1.0.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03770 | -0.00022 | 0.02367 | +# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05107 | -0.00305 | 0.02849 | +# | 5 | module.layer1.1.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02716 | 0.00097 | 0.01802 | +# | 6 | module.layer1.1.conv2.weight | (32, 64, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03056 | 0.00020 | 0.02092 | +# | 7 | module.layer1.1.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03139 | -0.00050 | 0.01988 | +# | 8 | 
module.layer1.2.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02660 | 0.00006 | 0.01926 | +# | 9 | module.layer1.2.conv2.weight | (32, 64, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03028 | -0.00037 | 0.02278 | +# | 10 | module.layer1.2.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02762 | -0.00230 | 0.01640 | +# | 11 | module.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03149 | -0.00137 | 0.02218 | +# | 12 | module.layer2.0.conv2.weight | (64, 128, 3, 3) | 73728 | 73728 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02133 | 0.00000 | 0.01584 | +# | 13 | module.layer2.0.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02686 | 0.00009 | 0.01642 | +# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02047 | -0.00043 | 0.01202 | +# | 15 | module.layer2.1.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01479 | -0.00015 | 0.00897 | +# | 16 | module.layer2.1.conv2.weight | (64, 128, 3, 3) | 73728 | 73728 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02168 | 0.00056 | 0.01426 | +# | 17 | module.layer2.1.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02224 | -0.00137 | 0.01297 | +# | 18 | module.layer2.2.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02067 | -0.00070 | 0.01441 | +# | 19 | module.layer2.2.conv2.weight | (64, 128, 3, 3) | 73728 | 73728 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02121 | -0.00012 | 0.01501 | +# | 20 | module.layer2.2.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02533 | 0.00031 | 0.01765 | +# | 21 | module.layer2.3.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02153 | -0.00086 | 0.01597 | +# | 22 | module.layer2.3.conv2.weight | (64, 128, 3, 3) | 73728 | 73728 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02162 | -0.00050 | 0.01635 | +# | 23 | module.layer2.3.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02281 | -0.00109 | 0.01573 | +# | 24 | module.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02727 | -0.00112 | 0.01952 | +# | 25 | module.layer3.0.conv2.weight | (128, 256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01670 | -0.00017 | 0.01233 | +# | 26 | module.layer3.0.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02275 | -0.00041 | 0.01634 | +# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01429 | 0.00004 | 0.00978 | +# | 28 | module.layer3.1.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01355 | -0.00048 | 0.00953 | +# | 29 | module.layer3.1.conv2.weight | (128, 
256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01497 | -0.00012 | 0.01089 | +# | 30 | module.layer3.1.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01943 | -0.00062 | 0.01378 | +# | 31 | module.layer3.2.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01404 | -0.00045 | 0.01007 | +# | 32 | module.layer3.2.conv2.weight | (128, 256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01462 | -0.00055 | 0.01092 | +# | 33 | module.layer3.2.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01819 | -0.00049 | 0.01321 | +# | 34 | module.layer3.3.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01545 | -0.00069 | 0.01146 | +# | 35 | module.layer3.3.conv2.weight | (128, 256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01450 | -0.00058 | 0.01108 | +# | 36 | module.layer3.3.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01730 | -0.00090 | 0.01271 | +# | 37 | module.layer3.4.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01610 | -0.00086 | 0.01212 | +# | 38 | module.layer3.4.conv2.weight | (128, 256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01435 | -0.00078 | 0.01100 | +# | 39 | module.layer3.4.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01690 | -0.00115 | 0.01236 | +# | 40 | module.layer3.5.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01752 | -0.00083 | 0.01335 | +# | 41 | module.layer3.5.conv2.weight | (128, 256, 3, 3) | 294912 | 294912 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01461 | -0.00076 | 0.01118 | +# | 42 | module.layer3.5.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01839 | -0.00203 | 0.01380 | +# | 43 | module.layer4.0.conv1.weight | (512, 1024, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02106 | -0.00114 | 0.01631 | +# | 44 | module.layer4.0.conv2.weight | (256, 512, 3, 3) | 1179648 | 1179648 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01159 | -0.00021 | 0.00906 | +# | 45 | module.layer4.0.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01445 | -0.00059 | 0.01123 | +# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00895 | -0.00014 | 0.00681 | +# | 47 | module.layer4.1.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01328 | -0.00062 | 0.01036 | +# | 48 | module.layer4.1.conv2.weight | (256, 512, 3, 3) | 1179648 | 1179648 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01152 | -0.00057 | 0.00906 | +# | 49 | module.layer4.1.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01414 | 0.00001 | 0.01094 | +# | 50 | 
module.layer4.2.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01636 | -0.00033 | 0.01284 | +# | 51 | module.layer4.2.conv2.weight | (256, 512, 3, 3) | 1179648 | 1179648 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01016 | -0.00044 | 0.00802 | +# | 52 | module.layer4.2.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01318 | 0.00010 | 0.00993 | +# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 2048000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03310 | 0.00000 | 0.02281 | +# | 54 | Total sparsity: | - | 17329344 | 17329344 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-12-04 23:32:08,902 - Total sparsity: 0.00 +# +# 2018-12-04 23:32:08,903 - --- validate (epoch=99)----------- +# 2018-12-04 23:32:08,903 - 50000 samples (256 per mini-batch) +# 2018-12-04 23:32:27,743 - Epoch: [99][ 50/ 195] Loss 0.683687 Top1 81.867188 Top5 95.937500 +# 2018-12-04 23:32:36,850 - Epoch: [99][ 100/ 195] Loss 0.810284 Top1 79.027344 Top5 94.648438 +# 2018-12-04 23:32:45,252 - Epoch: [99][ 150/ 195] Loss 0.934295 Top1 76.565104 Top5 93.072917 +# 2018-12-04 23:32:52,622 - ==> Top1: 75.654 Top5: 92.596 Loss: 0.978 +# +# 2018-12-04 23:32:52,693 - ==> Best Top1: 76.334 on Epoch: 0 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 76.316 on Epoch: 1 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.902 on Epoch: 3 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.748 on Epoch: 94 <======== +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.732 on Epoch: 85 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.728 on Epoch: 95 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.698 on Epoch: 84 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.674 on Epoch: 90 +# 2018-12-04 23:32:52,694 - ==> Best Top1: 75.664 on Epoch: 80 +# 2018-12-04 23:32:52,695 - ==> Best Top1: 75.654 on Epoch: 99 +# 2018-12-04 23:32:52,695 - Saving checkpoint to: logs/resnet50_filters___2018.12.02-224517/resnet50_filters_checkpoint.pth.tar +# 2018-12-04 23:32:53,013 - --- test --------------------- +# 2018-12-04 23:32:53,014 - 50000 samples (256 per mini-batch) +# 2018-12-04 23:33:12,090 - Test: [ 50/ 195] Loss 0.683687 Top1 81.867188 Top5 95.937500 +# 2018-12-04 23:33:20,491 - Test: [ 100/ 195] Loss 0.810284 Top1 79.027344 Top5 94.648438 +# 2018-12-04 23:33:28,604 - Test: [ 150/ 195] Loss 0.934295 Top1 76.565104 Top5 93.072917 +# 2018-12-04 23:33:36,294 - ==> Top1: 75.654 Top5: 92.596 Loss: 0.978 + +version: 1 + +pruners: + fc_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.87 + weights: module.fc.weight + + filter_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.50 + group_type: Filters + weights: [module.layer1.0.conv2.weight, + module.layer1.1.conv2.weight, + module.layer1.2.conv2.weight, + module.layer2.0.conv2.weight, + module.layer2.1.conv2.weight, + module.layer2.2.conv2.weight, + module.layer2.3.conv2.weight, + module.layer3.0.conv2.weight, + module.layer3.1.conv2.weight, + module.layer3.2.conv2.weight, + module.layer3.3.conv2.weight, + module.layer3.4.conv2.weight, + module.layer3.5.conv2.weight, + module.layer4.0.conv2.weight, + 
+              module.layer4.1.conv2.weight,
+              module.layer4.2.conv2.weight]
+
+  fine_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: [
+      module.layer4.0.conv2.weight,
+      module.layer4.0.conv3.weight,
+      module.layer4.0.downsample.0.weight,
+      module.layer4.1.conv1.weight,
+      module.layer4.1.conv2.weight,
+      module.layer4.1.conv3.weight,
+      module.layer4.2.conv1.weight,
+      module.layer4.2.conv2.weight,
+      module.layer4.2.conv3.weight]
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet50'
+      dataset: 'imagenet'
+
+lr_schedulers:
+  pruning_lr:
+    class: ExponentialLR
+    gamma: 0.95
+
+policies:
+  - pruner:
+      instance_name : filter_pruner
+      starting_epoch: 0
+      ending_epoch: 30
+      frequency: 2
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+  - extension:
+      instance_name: net_thinner
+      epochs: [31]
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+      starting_epoch: 40
+      ending_epoch: 100
+      frequency: 1
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index d0b53db6512d6ebda4043a9789342ebf185f725b..daf46fb8fe1b6d51804b008736d5ed98c3693cb2 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -142,6 +142,10 @@ parser.add_argument('--deterministic', '--det', action='store_true',
                     help='Ensure deterministic execution for re-producible results.')
 parser.add_argument('--gpus', metavar='DEV_ID', default=None,
                     help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
+parser.add_argument('--cpu', action='store_true',
+                    help='Use CPU only. \n'
+                         'Flag not set => uses GPUs according to the --gpus flag value.\n'
+                         'Flag set => overrides the --gpus flag')
 parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
 parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
 parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
@@ -290,20 +294,25 @@ def main():
         # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
         cudnn.benchmark = True

-    if args.gpus is not None:
-        try:
-            args.gpus = [int(s) for s in args.gpus.split(',')]
-        except ValueError:
-            msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
-            exit(1)
-        available_gpus = torch.cuda.device_count()
-        for dev_id in args.gpus:
-            if dev_id >= available_gpus:
-                msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
-                                .format(dev_id, available_gpus))
+    if args.cpu or not torch.cuda.is_available():
+        # Set GPU index to -1 if using CPU
+        args.device = 'cpu'
+    else:
+        args.device = 'cuda'
+        if args.gpus is not None:
+            try:
+                args.gpus = [int(s) for s in args.gpus.split(',')]
+            except ValueError:
+                msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
                 exit(1)
-        # Set default device in case the first one on the list != 0
-        torch.cuda.set_device(args.gpus[0])
+            available_gpus = torch.cuda.device_count()
+            for dev_id in args.gpus:
+                if dev_id >= available_gpus:
+                    msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
+                                    .format(dev_id, available_gpus))
+                    exit(1)
+            # Set default device in case the first one on the list != 0
+            torch.cuda.set_device(args.gpus[0])

     # Infer the dataset from the model name
     args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
@@ -332,10 +341,11 @@ def main():
     if args.resume:
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
             model, chkpt_file=args.resume)
-        model.cuda()
+        model.to(args.device)

     # Define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
+    criterion = nn.CrossEntropyLoss().to(args.device)
+
     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
@@ -372,7 +382,7 @@ def main():
         # requires a compression schedule configuration file in YAML.
         compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
         # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
-        model.cuda()
+        model.to(args.device)
     elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)

@@ -476,7 +486,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
     for train_step, (inputs, target) in enumerate(train_loader):
         # Measure data loading time
         data_time.add(time.time() - end)
-        inputs, target = inputs.to('cuda'), target.to('cuda')
+        inputs, target = inputs.to(args.device), target.to(args.device)

         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
@@ -600,7 +610,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     end = time.time()
     for validation_step, (inputs, target) in enumerate(data_loader):
         with torch.no_grad():
-            inputs, target = inputs.to('cuda'), target.to('cuda')
+            inputs, target = inputs.to(args.device), target.to(args.device)
             # compute output from model
             output = model(inputs)

@@ -675,7 +685,8 @@ def earlyexit_validate_loss(output, target, criterion, args):
     # but with a grouping of samples equal to the batch size.
     # Note that final group might not be a full batch - so determine actual size.
     this_batch_size = target.size()[0]
-    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).cuda()
+    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).to(args.device)
+
     for exitnum in range(args.num_exits):
         # calculate losses at each sample separately in the minibatch.
         args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
@@ -744,7 +755,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
                                                           args.qe_bits_accum, args.qe_mode, args.qe_clip_acts,
                                                           args.qe_no_clip_layers, args.qe_per_channel)
         quantizer.prepare_model()
-        model.cuda()
+        model.to(args.device)

     top1, _, _ = test(test_loader, model, criterion, loggers, activations_collectors, args=args)

diff --git a/models/__init__.py b/models/__init__.py
index 7bde40697c3cc328541a2f790b43efe47c587071..193951d579dd36ac663998ad2387e41986a56109 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -51,6 +51,10 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         dataset:
         arch:
         parallel:
+        device_ids: Devices on which model should be created -
+                    None - GPU if available, otherwise CPU
+                    -1 - CPU
+                    >=0 - GPU device IDs
     """
     msglogger.info('==> using %s dataset' % dataset)

@@ -81,5 +85,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     elif parallel:
         model = torch.nn.DataParallel(model, device_ids=device_ids)

-    model.cuda()
+    if torch.cuda.is_available() and device_ids != -1:
+        model.cuda()
+
     return model
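
For readers who want to see the device-handling pattern introduced above in isolation, here is a minimal, self-contained sketch. It is not Distiller code: the `resolve_device` helper and the toy `nn.Linear` model are illustrative assumptions, but the logic mirrors what the patch does in `compress_classifier.py` (a `--cpu` flag, or a missing CUDA runtime, forces `'cpu'`; otherwise an optional `--gpus` list is validated and its first entry becomes the default CUDA device), and everything downstream consumes a single device string via `.to(device)`:

```python
import argparse

import torch
import torch.nn as nn


def resolve_device(cpu_flag, gpus_arg):
    """Turn a --cpu flag and an optional comma-separated GPU list into a device string."""
    if cpu_flag or not torch.cuda.is_available():
        return 'cpu'
    if gpus_arg is not None:
        gpus = [int(s) for s in gpus_arg.split(',')]
        available = torch.cuda.device_count()
        for dev_id in gpus:
            if dev_id >= available:
                raise ValueError('GPU device ID {0} requested, but only {1} devices available'
                                 .format(dev_id, available))
        # Make the first listed GPU the default device, as the patch does.
        torch.cuda.set_device(gpus[0])
    return 'cuda'


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--gpus', default=None)
    args = parser.parse_args()

    device = resolve_device(args.cpu, args.gpus)

    # Everything that previously called .cuda() now uses .to(device),
    # so the same script runs unchanged on CPU-only machines.
    model = nn.Linear(10, 2).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    inputs = torch.randn(4, 10).to(device)
    target = torch.randint(0, 2, (4,)).to(device)

    loss = criterion(model(inputs), target)
    print(device, loss.item())
```

Running this sketch with `--cpu` on a CUDA machine, or on a machine with no CUDA support at all, exercises the same CPU path that the patch adds; without the flag it behaves like the original GPU-only flow.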