diff --git a/distiller/__init__.py b/distiller/__init__.py
index 708d6288f5bcff7548095bdd255cb3bee332f9cd..643f3f5afcff8b0a9b71e397c4802af3e91419ed 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -15,7 +15,7 @@
 #
 
 from .utils import *
-from .thresholding import GroupThresholdMixin, threshold_mask
+from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
 from .config import file_config, dict_config
 from .model_summaries import *
 from .scheduler import *
@@ -25,19 +25,14 @@ from .policy import *
 from .thinning import *
 from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
 
-#del utils
+
 del dict_config
 del thinning
-#del model_summaries
-#del scheduler
-#del sensitivity
-#del directives
-#del thresholding
-#del policy
 
 # Distiller version
 __version__ = "0.3.0-pre"
 
+
 def model_find_param_name(model, param_to_find):
     """Look up the name of a model parameter.
 
@@ -69,6 +64,7 @@ def model_find_module_name(model, module_to_find):
             return name
     return None
 
+
 def model_find_param(model, param_to_find_name):
     """Look a model parameter by its name
 
diff --git a/distiller/data_loggers/__init__.py b/distiller/data_loggers/__init__.py
index 2ffca80dfbf4fb3767bf3a92799ab52fc7c71b36..a3351f0fe6db92956f0516074e0f86d8f0baea6f 100755
--- a/distiller/data_loggers/__init__.py
+++ b/distiller/data_loggers/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from .collector import ActivationSparsityCollector
+from .collector import *
 from .logger import PythonLogger, TensorBoardLogger, CsvLogger
 
 del logger
diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 1254f7a389ef837b423fb0a3da93aa65f8a1c0ef..a18a368efa3061bf1a0c594b1be799d83bc8e62d 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -14,88 +14,302 @@
 # limitations under the License.
 #
 
+import xlsxwriter
+import os
+from collections import OrderedDict
+from contextlib import contextmanager
 import torch
-from distiller.utils import sparsity
 from torchnet.meter import AverageValueMeter
 import logging
+import distiller
 msglogger = logging.getLogger()
 
-__all__ = ['ActivationSparsityCollector']
+__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
+           'collector_context', 'collectors_context']
 
-class DataCollector(object):
-    def __init__(self):
-        pass
 
+class ActivationStatsCollector(object):
+    """Collect model activation statistics information.
 
-class ActivationSparsityCollector(DataCollector):
-    """Collect model activation sparsity information.
+    ActivationStatsCollector is the base class for classes that collect activations statistics.
+    You may collect statistics on different phases of the optimization process (training, validation, test).
 
-    CNN models with ReLU layers, exhibit sparse activations.
-    ActivationSparsityCollector will collect information about this sparsity.
-    Currently we only record the mean sparsity of the activations, but this can be expanded
-    to collect std and other statistics.
+    Statistics data are accessible via .value() or by accessing individual modules.
 
-    The current implementation activation sparsity collection has a few caveats:
-    * It is slow
+    The current implementation has a few caveats:
+    * It is slow - therefore it is advisable to use this only when needed.
     * It can't access the activations of torch.Functions, only torch.Modules.
-    * The layer names are mangled
 
-    ActivationSparsityCollector uses the forward hook of modules in order to access the
+    ActivationStatsCollector uses the forward hook of modules in order to access the
     feature-maps.  This is both slow and limits us to seeing only the outputs of torch.Modules.
-    We can remove some of the slowness, by choosing to log only specific layers.  By default,
-    we only logs torch.nn.ReLU activations.
+    We can remove some of the slowness, by choosing to log only specific layers or use it only
+    during validation or test.  By default, we only log torch.nn.ReLU activations.
 
     The layer names are mangled, because torch.Modules don't have names and we need to invent
-    a unique name per layer.
-    """
-    def __init__(self, model, classes=[torch.nn.ReLU]):
-        """Since only specific layers produce sparse feature-maps, the
-        ActivationSparsityCollector constructor accepts an optional list of layers to log."""
+    a unique name per layer.  To assign human-readable names, it is advisable to invoke the following
+    before starting the statistics collection:
 
-        super(ActivationSparsityCollector, self).__init__()
+        distiller.utils.assign_layer_fq_names(model)
+    """
+    def __init__(self, model, stat_name, classes):
+        """
+        Args:
+            model - the model we are monitoring.
+            statistics_dict - a dictionary of {stat_name: statistics_function}, where name
+                provides a means for us to access the statistics data at a later time; and the
+                statistics_function is a function that gets an activation as input and returns
+                some statistic.
+                For example, the dictionary below collects element-wise activation sparsity
+                statistics:
+                    {"sparsity": distiller.utils.sparsity}
+            classes - a list of class types for which we collect activation statistics.
+                You can access a module's activation statistics by referring to module.<stat_name>
+                For example:
+                    print(module.sparsity)
+        """
+        super(ActivationStatsCollector, self).__init__()
         self.model = model
+        self.stat_name = stat_name
         self.classes = classes
-        self._init_activations_sparsity(model)
+        self.fwd_hook_handles = []
 
     def value(self):
-        """Return a dictionary containing {layer_name: mean sparsity}"""
-        activation_sparsity = {}
-        _collect_activations_sparsity(self.model, activation_sparsity)
-        return activation_sparsity
+        """Return a dictionary containing {layer_name: statistic}"""
+        activation_stats = OrderedDict()
+        self.__value(self.model, activation_stats)
+        return activation_stats
+
+    def __value(self, module, activation_stats):
+        for child_module in module._modules.values():
+            self.__value(child_module, activation_stats)
+        self._collect_activations_stats(module, activation_stats)
+
+    def start(self):
+        """Start collecting activation stats.
+
+        This will iteratively register the modules' forward-hooks, so that the collector
+        will be called from the forward traversal and get exposed to activation data.
+        """
+        assert len(self.fwd_hook_handles) == 0
+        self.__start(self.model)
 
+    def stop(self):
+        """Stop collecting activation stats.
 
-    def _init_activations_sparsity(self, module, name=''):
-        def __activation_sparsity_cb(module, input, output):
-            """Record the activation sparsity of 'module'
+        This will iteratively unregister the modules' forward-hooks.
+        """
+        for handle in self.fwd_hook_handles:
+            handle.remove()
+        self.fwd_hook_handles = []
 
-            This is a callback from the forward() of 'module'.
-            """
-            module.sparsity.add(sparsity(output.data))
+    def reset(self):
+        """Reset the statsitics counters of this collector."""
+        self.__reset(self.model)
+        return self
 
-        has_children = False
+    def __reset(self, module):
+        for child_module in module._modules.values():
+            self.__reset(child_module)
+        self._reset_counter(module)
+
+    def __activation_stats_cb(self, module, input, output):
+        """Handle new activations ('output' argument).
+
+        This is invoked from the forward-pass callback of module 'module'.
+        """
+        raise NotImplementedError
+
+    def __start(self, module, name=''):
+        """Iteratively register to the forward-pass callback of all eligable modules.
+
+        Eligable modules are currently filtered by their class type.
+        """
+        is_leaf_node = True
         for name, sub_module in module._modules.items():
-            self._init_activations_sparsity(sub_module, name)
-            has_children = True
-        if not has_children:
+            self.__start(sub_module, name)
+            is_leaf_node = False
+
+        if is_leaf_node:
             if type(module) in self.classes:
-                module.register_forward_hook(__activation_sparsity_cb)
-                module.sparsity = AverageValueMeter()
-                if hasattr(module, 'ref_name'):
-                    module.sparsity.name = 'sparsity_' + module.ref_name
-                else:
-                    module.sparsity.name = 'sparsity_' + name + '_' + module.__class__.__name__ + '_' + str(id(module))
+                self.fwd_hook_handles.append(module.register_forward_hook(self._activation_stats_cb))
+                self._start_counter(module)
+
+    def _reset_counter(self, mod):
+        """Reset a specific statistic counter - this is subclass-specific code"""
+        raise NotImplementedError
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        """Handle new activations - this is subclass-specific code"""
+        raise NotImplementedError
+
+
+class SummaryActivationStatsCollector(ActivationStatsCollector):
+    """This class collects activiations statistical summaries.
+
+    This Collector computes the mean of some statistic of the activation.  It is rather
+    light-weight and quicker than collecting a record per activation.
+    The statistic function is configured in the constructor.
+    """
+    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU]):
+        super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
+        self.summary_fn = summary_fn
+
+    def _activation_stats_cb(self, module, input, output):
+        """Record the activation sparsity of 'module'
+
+        This is a callback from the forward() of 'module'.
+        """
+        getattr(module, self.stat_name).add(self.summary_fn(output.data))
+
+    def _start_counter(self, module):
+        if not hasattr(module, self.stat_name):
+            setattr(module, self.stat_name, AverageValueMeter())
+            # Assign a name to this summary
+            if hasattr(module, 'distiller_name'):
+                getattr(module, self.stat_name).name = '_'.join((self.stat_name, module.distiller_name))
+            else:
+                getattr(module, self.stat_name).name = '_'.join((self.stat_name,
+                                                                 module.__class__.__name__,
+                                                                 str(id(module))))
+
+    def _reset_counter(self, mod):
+        if hasattr(mod, self.stat_name):
+            getattr(mod, self.stat_name).reset()
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        if hasattr(module, self.stat_name):
+            mean = getattr(module, self.stat_name).mean
+            if isinstance(mean, torch.Tensor):
+                mean = mean.tolist()
+            activation_stats[getattr(module, self.stat_name).name] = mean
+
+    def to_xlsx(self, fname):
+        """Save the records to an Excel workbook, with one worksheet per layer.
+        """
+        fname = ".".join([fname, 'xlsx'])
+        try:
+            os.remove(fname)
+        except OSError:
+            pass
+
+        records_dict = self.value()
+        with xlsxwriter.Workbook(fname) as workbook:
+            worksheet = workbook.add_worksheet(self.stat_name)
+            col_names = []
+            for col, (module_name, module_summary_data) in enumerate(records_dict.items()):
+                if not isinstance(module_summary_data, list):
+                    module_summary_data = [module_summary_data]
+                worksheet.write_column(1, col, module_summary_data)
+                col_names.append(module_name)
+            worksheet.write_row(0, 0, col_names)
+
+
+class RecordsActivationStatsCollector(ActivationStatsCollector):
+    """This class collects activiations statistical records.
+
+    This Collector computes a hard-coded set of activations statsitics and collects a
+    record per activation.  The activation records of the entire model (only filtered modules),
+    can be saved to an Excel workbook.
+
+    For obvious reasons, this is slower than SummaryActivationStatsCollector.
+    """
+    def __init__(self, model, classes=[torch.nn.ReLU]):
+        super(RecordsActivationStatsCollector, self).__init__(model, "statsitics_records", classes)
+
+    def _activation_stats_cb(self, module, input, output):
+        """Record the activation sparsity of 'module'
+
+        This is a callback from the forward() of 'module'.
+        """
+        def to_np(stats):
+            if isinstance(stats, tuple):
+                return stats[0].detach().cpu().numpy()
+            else:
+                return stats.detach().cpu().numpy()
+
+        # We get a batch of activations, from which we collect statistics
+        act = output.view(output.size(0), -1)
+        batch_min_list = to_np(torch.min(act, dim=1)).tolist()
+        batch_max_list = to_np(torch.max(act, dim=1)).tolist()
+        batch_mean_list = to_np(torch.mean(act, dim=1)).tolist()
+        batch_std_list = to_np(torch.std(act, dim=1)).tolist()
+        batch_l2_list = to_np(torch.norm(act, p=2, dim=1)).tolist()
+
+        module.statsitics_records['min'].extend(batch_min_list)
+        module.statsitics_records['max'].extend(batch_max_list)
+        module.statsitics_records['mean'].extend(batch_mean_list)
+        module.statsitics_records['std'].extend(batch_std_list)
+        module.statsitics_records['l2'].extend(batch_l2_list)
+        module.statsitics_records['shape'] = distiller.size2str(output)
 
     @staticmethod
-    def _collect_activations_sparsity(model, activation_sparsity, name=''):
-        for name, module in model._modules.items():
-            _collect_activations_sparsity(module, activation_sparsity, name)
+    def _create_records_dict():
+        records = OrderedDict()
+        for stat_name in ['min', 'max', 'mean', 'std', 'l2']:
+            records[stat_name] = []
+        records['shape'] = ''
+        return records
+
+    def to_xlsx(self, fname):
+        """Save the records to an Excel workbook, with one worksheet per layer.
+        """
+        fname = ".".join([fname, 'xlsx'])
+        try:
+            os.remove(fname)
+        except OSError:
+            pass
+
+        records_dict = self.value()
+        with xlsxwriter.Workbook(fname) as workbook:
+            for module_name, module_act_records in records_dict.items():
+                worksheet = workbook.add_worksheet(module_name)
+                col_names = []
+                for col, (col_name, col_data) in enumerate(module_act_records.items()):
+                    if col_name == 'shape':
+                        continue
+                    worksheet.write_column(1, col, col_data)
+                    col_names.append(col_name)
+                worksheet.write_row(0, 0, col_names)
+                worksheet.write(0, len(col_names)+2, module_act_records['shape'])
+
+    def _start_counter(self, module):
+        if not hasattr(module, "statsitics_records"):
+            module.statsitics_records = self._create_records_dict()
+
+    def _reset_counter(self, mod):
+        if hasattr(mod, "statsitics_records"):
+            mod.statsitics_records = self._create_records_dict()
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        if hasattr(module, "statsitics_records"):
+            activation_stats[module.distiller_name] = module.statsitics_records
+
+
+@contextmanager
+def collector_context(collector):
+    """A context manager for an activation collector"""
+    if collector is not None:
+        collector.reset().start()
+    yield collector
+    if collector is not None:
+        collector.stop()
+
 
-        if hasattr(model, 'sparsity'):
-            activation_sparsity[model.sparsity.name] = model.sparsity.mean
+@contextmanager
+def collectors_context(collectors_dict):
+    """A context manager for a dictionary of collectors"""
+    if len(collectors_dict) == 0:
+        yield collectors_dict
+        return
+    for collector in collectors_dict.values():
+        collector.reset().start()
+    yield collectors_dict
+    for collector in collectors_dict.values():
+        collector.stop()
 
 
-class TrainingProgressCollector(DataCollector):
-    def __init__(self, stats = {}):
+class TrainingProgressCollector(object):
+    def __init__(self, stats={}):
         super(TrainingProgressCollector, self).__init__()
         object.__setattr__(self, '_stats', stats)
 
diff --git a/distiller/data_loggers/logger.py b/distiller/data_loggers/logger.py
index 796c6c3870cdec901ab8767182be485e8d8c91e4..89ac4730fc938e2a2da01677096f3d10ae33167c 100755
--- a/distiller/data_loggers/logger.py
+++ b/distiller/data_loggers/logger.py
@@ -28,7 +28,7 @@ import torch
 import numpy as np
 import tabulate
 import distiller
-from distiller.utils import density, sparsity, sparsity_2D, size_to_str, to_np
+from distiller.utils import density, sparsity, sparsity_2D, size_to_str, to_np, norm_filters
 # TensorBoard logger
 from .tbbackend import TBBackend
 # Visdom logger
@@ -53,7 +53,7 @@ class DataLogger(object):
     def log_training_progress(self, model, epoch, i, set_size, batch_time, data_time, classerr, losses, print_freq, collectors):
         raise NotImplementedError
 
-    def log_activation_sparsity(self, activation_sparsity, logcontext):
+    def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch):
         raise NotImplementedError
 
     def log_weights_sparsity(self, model, epoch):
@@ -81,12 +81,11 @@ class PythonLogger(DataLogger):
                 log = log + '{name} {val:.6f}    '.format(name=name, val=val)
         self.pylogger.info(log)
 
-
-    def log_activation_sparsity(self, activation_sparsity, logcontext):
+    def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch):
         data = []
-        for layer, sparsity in activation_sparsity.items():
-            data.append([layer, sparsity*100])
-        t = tabulate.tabulate(data, headers=['Layer', 'sparsity (%)'], tablefmt='psql', floatfmt=".1f")
+        for layer, statistic in activation_stats.items():
+            data.append([layer, statistic])
+        t = tabulate.tabulate(data, headers=['Layer', stat_name], tablefmt='psql', floatfmt=".2f")
         msglogger.info('\n' + t)
 
     def log_weights_sparsity(self, model, epoch):
@@ -119,9 +118,9 @@ class TensorBoardLogger(DataLogger):
             self.tblogger.scalar_summary(prefix+tag, value, total_steps(total, epoch, completed))
         self.tblogger.sync_to_file()
 
-    def log_activation_sparsity(self, activation_sparsity, epoch):
-        group = 'sparsity/activations/'
-        for tag, value in activation_sparsity.items():
+    def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch):
+        group = stat_name + '/activations/' + phase + "/"
+        for tag, value in activation_stats.items():
             self.tblogger.scalar_summary(group+tag, value, epoch)
         self.tblogger.sync_to_file()
 
@@ -142,6 +141,15 @@ class TensorBoardLogger(DataLogger):
         self.tblogger.scalar_summary("sprasity/weights/total", 100*(1 - sparse_params_size/params_size), epoch)
         self.tblogger.sync_to_file()
 
+    def log_weights_filter_magnitude(self, model, epoch, multi_graphs=False):
+        """Log the L1-magnitude of the weights tensors.
+        """
+        for name, param in model.state_dict().items():
+            if param.dim() in [4]:
+                self.tblogger.list_summary('magnitude/filters/' + name,
+                                           list(to_np(norm_filters(param))), epoch, multi_graphs)
+        self.tblogger.sync_to_file()
+
     def log_weights_distribution(self, named_params, steps_completed):
         if named_params is None:
             return
@@ -174,6 +182,6 @@ class CsvLogger(DataLogger):
                     params_size += torch.numel(param)
                     sparse_params_size += param.numel() * _density
                     writer.writerow([name, size_to_str(param.size()),
-                                       torch.numel(param),
-                                       int(_density * param.numel()),
-                                       (1-_density)*100])
+                                     torch.numel(param),
+                                     int(_density * param.numel()),
+                                     (1-_density)*100])
diff --git a/distiller/data_loggers/tbbackend.py b/distiller/data_loggers/tbbackend.py
index 0761ef2cfd866cfefeb5f420bff6481ba8194775..ccd077a5426153e5814b1e28fbf3fbfc56634e28 100755
--- a/distiller/data_loggers/tbbackend.py
+++ b/distiller/data_loggers/tbbackend.py
@@ -18,13 +18,16 @@
 Writes logs to a file using a Google's TensorBoard protobuf format.
 See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto
 """
+import os
 import tensorflow as tf
 import numpy as np
 
-class TBBackend(object):
 
+class TBBackend(object):
     def __init__(self, log_dir):
-        self.writer = tf.summary.FileWriter(log_dir)
+        self.writers = []
+        self.log_dir = log_dir
+        self.writers.append(tf.summary.FileWriter(log_dir))
 
     def scalar_summary(self, tag, scalar, step):
         """From TF documentation:
@@ -32,7 +35,26 @@ class TBBackend(object):
             value: value associated with the tag (a float).
         """
         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
-        self.writer.add_summary(summary, step)
+        self.writers[0].add_summary(summary, step)
+
+    def list_summary(self, tag, list, step, multi_graphs):
+        """Log a relatively small list of scalars.
+
+        We want to track the progress of multiple scalar parameters in a single graph.
+        The list provides a single value for each of the parameters we are tracking.
+        
+        NOTE: There are two ways to log multiple values in TB and neither one is optimal.
+        1. Use a single writer: in this case all of the parameters use the same color, and
+           distinguishing between them is difficult.
+        2. Use multiple writers: in this case each parameter has its own color which helps
+           to visually separate the parameters.  However, each writer logs to a different
+           file and this creates a lot of files which slow down the TB load.
+        """
+        for i, scalar in enumerate(list):
+            if multi_graphs and (i+1 > len(self.writers)):
+                self.writers.append(tf.summary.FileWriter(os.path.join(self.log_dir, str(i))))
+            summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
+            self.writers[0 if not multi_graphs else i].add_summary(summary, step)
 
     def histogram_summary(self, tag, tensor, step):
         """
@@ -49,11 +71,11 @@ class TBBackend(object):
         """
         hist, edges = np.histogram(tensor, bins=200)
         tfhist = tf.HistogramProto(
-            min = np.min(tensor),
-            max = np.max(tensor),
-            num = int(np.prod(tensor.shape)),
-            sum = np.sum(tensor),
-            sum_squares = np.sum(np.square(tensor)))
+            min=np.min(tensor),
+            max=np.max(tensor),
+            num=int(np.prod(tensor.shape)),
+            sum=np.sum(tensor),
+            sum_squares=np.sum(np.square(tensor)))
 
         # From the TF documentation:
         #   Parallel arrays encoding the bucket boundaries and the bucket values.
@@ -64,7 +86,8 @@ class TBBackend(object):
         tfhist.bucket.extend(hist)
 
         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=tfhist)])
-        self.writer.add_summary(summary, step)
+        self.writers[0].add_summary(summary, step)
 
     def sync_to_file(self):
-        self.writer.flush()
+        for writer in self.writers:
+            writer.flush()
diff --git a/distiller/policy.py b/distiller/policy.py
index 6d705332f6f4d76a9d7ffc43fc492b85c7a06565..dc85c60def1dee5aa53554f60e9150b462c44375 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -90,6 +90,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         if self.levels is not None:
             self.pruner.levels = self.levels
 
+        meta['model'] = model
         for param_name, param in model.named_parameters():
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
 
diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index dc8a1c76ca7910f00194b5c7c07f28695cc68fac..24fe7e5b965ba19b1f3fd3739c4c62a3935245ba 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -19,11 +19,14 @@
 """
 
 from .magnitude_pruner import MagnitudeParameterPruner
-from .automated_gradual_pruner import AutomatedGradualPruner, StructuredAutomatedGradualPruner
+from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureParameterPruner_AGP, \
+                                      ActivationAPoZRankedFilterPruner_AGP, GradientRankedFilterPruner_AGP, \
+                                      RandomRankedFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .structure_pruner import StructureParameterPruner
-from .ranked_structures_pruner import L1RankedStructureParameterPruner
+from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \
+                                      RandomRankedFilterPruner, GradientRankedFilterPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 
 del magnitude_pruner
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index dff2168320811349d2644e2dadd94c21155695d2..dfb80e58bc4f59def418f3ebeb3b001a7fdc9ae2 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -16,7 +16,7 @@
 
 from .pruner import _ParameterPruner
 from .level_pruner import SparsityLevelParameterPruner
-from .ranked_structures_pruner import L1RankedStructureParameterPruner
+from .ranked_structures_pruner import *
 from distiller.utils import *
 # import logging
 # msglogger = logging.getLogger()
@@ -61,28 +61,56 @@ class AutomatedGradualPruner(_ParameterPruner):
         target_sparsity = (self.final_sparsity +
                            (self.initial_sparsity-self.final_sparsity) *
                            (1.0 - ((current_epoch-starting_epoch)/span))**3)
-        self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity)
+        self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
 
     @staticmethod
-    def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity):
+    def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model=None):
         return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity)
 
 
-class StructuredAutomatedGradualPruner(AutomatedGradualPruner):
+class CriterionParameterizedAGP(AutomatedGradualPruner):
     def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
         self.reg_regims = reg_regims
         weights = [weight for weight in reg_regims.keys()]
-        if not all([group in ['3D', 'Filters', 'Channels'] for group in reg_regims.values()]):
-            raise ValueError("Currently only filter (3D) and channel pruning is supported")
-        super(StructuredAutomatedGradualPruner, self).__init__(name, initial_sparsity,
-                                                               final_sparsity, weights,
-                                                               pruning_fn=self.prune_to_target_sparsity)
+        if not all([group in ['3D', 'Filters', 'Channels', 'Rows'] for group in reg_regims.values()]):
+            raise ValueError("Unsupported group structure")
+        super(CriterionParameterizedAGP, self).__init__(name, initial_sparsity,
+                                                        final_sparsity, weights,
+                                                        pruning_fn=self.prune_to_target_sparsity)
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity):
+    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model):
         if self.reg_regims[param_name] in ['3D', 'Filters']:
-            L1RankedStructureParameterPruner.rank_prune_filters(target_sparsity, param,
-                                                                param_name, zeros_mask_dict)
-        else:
-            if self.reg_regims[param_name] == 'Channels':
-                L1RankedStructureParameterPruner.rank_prune_channels(target_sparsity, param,
-                                                                     param_name, zeros_mask_dict)
+            self.filters_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model)
+        elif self.reg_regims[param_name] == 'Channels':
+            self.channels_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model)
+        elif self.reg_regims[param_name] == 'Rows':
+            self.rows_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model)
+
+
+# TODO: this class parameterization is cumbersome: the ranking functions (per structure)
+# should come from the YAML schedule
+
+class L1RankedStructureParameterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(L1RankedStructureParameterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = L1RankedStructureParameterPruner.rank_prune_filters
+        self.channels_ranking_fn = L1RankedStructureParameterPruner.rank_prune_channels
+        self.rows_ranking_fn = L1RankedStructureParameterPruner.rank_prune_rows
+
+
+class ActivationAPoZRankedFilterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(ActivationAPoZRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = ActivationAPoZRankedFilterPruner.rank_prune_filters
+
+
+class GradientRankedFilterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(GradientRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = GradientRankedFilterPruner.rank_prune_filters
+
+
+class RandomRankedFilterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(RandomRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = RandomRankedFilterPruner.rank_prune_filters
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 407cf737edbad9b4ac4dc445ac30b55aa0ee5b3f..d1b36c3b2c531cf542d41b32bc52f958dbdae2c3 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import numpy as np
 import logging
 import torch
 import distiller
@@ -22,15 +23,20 @@ msglogger = logging.getLogger()
 
 
 # TODO: support different policies for ranking structures
-class L1RankedStructureParameterPruner(_ParameterPruner):
+class RankedStructureParameterPruner(_ParameterPruner):
     """Uses mean L1-norm to rank structures and prune a specified percentage of structures
     """
     def __init__(self, name, reg_regims):
-        super(L1RankedStructureParameterPruner, self).__init__(name)
-        self.name = name
+        super(RankedStructureParameterPruner, self).__init__(name)
         self.reg_regims = reg_regims
 
 
+class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
+    """Uses mean L1-norm to rank structures and prune a specified percentage of structures
+    """
+    def __init__(self, name, reg_regims):
+        super(L1RankedStructureParameterPruner, self).__init__(name, reg_regims)
+
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if param_name not in self.reg_regims.keys():
             return
@@ -44,9 +50,11 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
             return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict)
         elif group_type == 'Channels':
             return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict)
+        elif group_type == 'Rows':
+            return self.rank_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict)
         else:
             raise ValueError("Currently only filter (3D) and channel ranking is supported")
-            
+
     @staticmethod
     def rank_channels(fraction_to_prune, param):
         num_filters = param.size(0)
@@ -71,7 +79,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         return bottomk, channel_mags
 
     @staticmethod
-    def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict):
+    def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict, model=None):
         bottomk_channels, channel_mags = L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param)
         if bottomk_channels is None:
             # Empty list means that fraction_to_prune is too low to prune anything
@@ -92,20 +100,177 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
                        fraction_to_prune, len(bottomk_channels), num_channels)
 
     @staticmethod
-    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict):
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        # First we rank the filters
         view_filters = param.view(param.size(0), -1)
-        filter_mags = view_filters.data.norm(1, dim=1)  # same as view_filters.data.abs().sum(dim=1)
+        filter_mags = view_filters.data.abs().mean(dim=1)
         topk_filters = int(fraction_to_prune * filter_mags.size(0))
         if topk_filters == 0:
             msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
             return
-
         bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
         threshold = bottomk[-1]
-        binary_map = filter_mags.gt(threshold).type(param.data.type())
-        expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
-        zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3))
+        # Then we threshold
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs')
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, topk_filters, filter_mags.size(0))
+
+    @staticmethod
+    def rank_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict, model=None):
+        """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
+
+        PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
+        transposed.  This is counter-intuitive.  To deal with this, we can either transpose the matrix and
+        then proceed to compute the masks as usual, or we can treat columns as rows, and rows as columns :-(.
+        We choose the latter, because transposing very large matrices can be detrimental to performance.  Note
+        that computing mean L1-norm of columns is also not optimal, because consequtive column elements are far
+        away from each other in memory, and this means poor use of caches and system memory.
+        """
+
+        assert param.dim() == 2, "This thresholding is only supported for 2D weights"
+        ROWS_DIM = 0
+        THRESHOLD_DIM = 'Cols'
+        rows_mags = param.abs().mean(dim=ROWS_DIM)
+        num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
+        if num_rows_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
+            return
+        bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
+        threshold = bottomk_rows[-1]
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, 'Mean_Abs')
+        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+                       distiller.sparsity(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
+
+
+class RankedFiltersParameterPruner(RankedStructureParameterPruner):
+    """Base class for the special (but often-used) case of ranking filters
+    """
+    def __init__(self, name, reg_regims):
+        super(RankedFiltersParameterPruner, self).__init__(name, reg_regims)
+
+    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
+        if param_name not in self.reg_regims.keys():
+            return
+
+        group_type = self.reg_regims[param_name][1]
+        fraction_to_prune = self.reg_regims[param_name][0]
+        if fraction_to_prune == 0:
+            return
+
+        if group_type in ['3D', 'Filters']:
+            return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, meta['model'])
+        else:
+            raise ValueError("Currently only filter (3D) ranking is supported")
+
+    @staticmethod
+    def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters):
+        binary_map = torch.zeros(num_filters).cuda()
+        binary_map[filters_ordered_by_criterion] = 1
+        #msglogger.info("binary_map: {}".format(binary_map))
+        expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
+        return expanded.view(param.shape)
+
+
+class ActivationAPoZRankedFilterPruner(RankedFiltersParameterPruner):
+    """Uses mean APoZ (average percentage of zeros) activation channels to rank structures
+    and prune a specified percentage of structures.
+
+    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures",
+    Hengyuan Hu, Rui Peng, Yu-Wing Tai, Chi-Keung Tang, ICLR 2016
+    https://arxiv.org/abs/1607.03250
+    """
+    def __init__(self, name, reg_regims):
+        super(ActivationAPoZRankedFilterParameterPruner, self).__init__(name, reg_regims)
+
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+
+        # Use the parameter name to locate the module that has the activation sparsity statistics
+        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
+        module = distiller.find_module_by_fq_name(model, fq_name)
+        if module is None:
+            raise ValueError("Could not find a layer named %s in the model."
+                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
+        if not hasattr(module, 'apoz_channels'):
+            raise ValueError("Could not find attribute \'apoz_channels\' in module %s."
+                             "\nMake sure to use SummaryActivationStatsCollector(\"apoz_channels\")" % fq_name)
+
+        apoz, std = module.apoz_channels.value()
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+        if num_filters_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            return
+
+        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
+        filters_ordered_by_apoz = np.argsort(-apoz)[:-num_filters_to_prune]
+        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_apoz,
+                                                                                               param, num_filters)
+
+        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       param_name,
+                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_filters_to_prune, num_filters)
+
+
+class RandomRankedFilterPruner(RankedFiltersParameterPruner):
+    """A Random raanking of filters.
+
+    This is used for sanity testing of other algorithms.
+    """
+    def __init__(self, name, reg_regims):
+        super(RandomRankedFilterPruner, self).__init__(name, reg_regims)
+
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+
+        if num_filters_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            return
+
+        filters_ordered_randomly = np.random.permutation(num_filters)[:-num_filters_to_prune]
+        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_randomly,
+                                                                                               param, num_filters)
+
+        msglogger.info("RandomRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       param_name,
+                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_filters_to_prune, num_filters)
+
+
+class GradientRankedFilterPruner(RankedFiltersParameterPruner):
+    """
+    """
+    def __init__(self, name, reg_regims):
+        super(RandomRankedFilterPruner, self).__init__(name, reg_regims)
+
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+        if num_filters_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            return
+
+        # Compute the multiplicatipn of the filters times the filter_gradienrs
+        view_filters = param.view(param.size(0), -1)
+        view_filter_grads = param.grad.view(param.size(0), -1)
+        weighted_gradients = view_filter_grads * view_filters
+        weighted_gradients = weighted_gradients.sum(dim=1)
+
+        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
+        filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
+        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_gradient,
+                                                                                               param, num_filters)
+        msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       param_name,
+                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_filters_to_prune, num_filters)
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 4cb81898d104578cb15e82c6dffe9c06d66e46df..38c9923908bcf54bb6d9cf441b9c48032a52f3bf 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -20,6 +20,7 @@ The code below supports fine-grained tensor thresholding and group-wise threshol
 """
 import torch
 
+
 def threshold_mask(weights, threshold):
     """Create a threshold mask for the provided parameter tensor using
     magnitude thresholding.
@@ -32,95 +33,104 @@ def threshold_mask(weights, threshold):
     """
     return torch.gt(torch.abs(weights), threshold).type(weights.type())
 
+
 class GroupThresholdMixin(object):
-    """A mixin class to add group thresholding capabilities"""
+    """A mixin class to add group thresholding capabilities
 
+    TODO: this does not need to be a mixin - it should be made a simple function.  We keep this until we refactor
+    """
     def group_threshold_mask(self, param, group_type, threshold, threshold_criteria):
-        """Return a threshold mask for the provided parameter and group type.
-
-        Args:
-            param: The parameter to mask
-            group_type: The elements grouping type (structure).
-              One of:2D, 3D, 4D, Channels, Row, Cols
-            threshold: The threshold
-            threshold_criteria: The thresholding criteria.
-              'Mean_Abs' thresholds the entire element group using the mean of the
-              absolute values of the tensor elements.
-              'Max' thresholds the entire group using the magnitude of the largest
-              element in the group.
-        """
-        if group_type == '2D':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            view_2d = param.view(-1, param.size(2) * param.size(3))
-            # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
-            #    thresholds tensor with length = #IFMs * # OFMs
-            thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda()
-            # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
-            #    elements in each channel as the threshold filter.
-            # 3. Apply the threshold filter
-            binary_map = self.threshold_policy(view_2d, thresholds, threshold_criteria)
-            # 3. Finally, expand the thresholds and view as a 4D tensor
-            a = binary_map.expand(param.size(2) * param.size(3),
-                                  param.size(0) * param.size(1)).t()
-            return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
-
-        elif group_type == 'Rows':
-            assert param.dim() == 2, "This regularization is only supported for 2D weights"
-            thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
-            binary_map = self.threshold_policy(param, thresholds, threshold_criteria)
-            return binary_map.expand(param.size(1), param.size(0)).t()
-
-        elif group_type == 'Cols':
-            assert param.dim() == 2, "This regularization is only supported for 2D weights"
-            thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
-            binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0)
-            return binary_map.expand(param.size(0), param.size(1))
-
-        elif group_type == '3D' or group_type == 'Filters':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            view_filters = param.view(param.size(0), -1)
-            thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
-            binary_map = self.threshold_policy(view_filters, thresholds, threshold_criteria)
-            a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
-            return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
-
-        elif group_type == '4D':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            if threshold_criteria == 'Mean_Abs':
-                if param.data.abs().mean() > threshold:
-                    return None
-                return torch.zeros_like(param.data)
-            elif threshold_criteria == 'Max':
-                if param.data.abs().max() > threshold:
-                    return None
-                return torch.zeros_like(param.data)
-            exit("Invalid threshold_criteria {}".format(self.threshold_criteria))
-
-        elif group_type == 'Channels':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            num_filters = param.size(0)
-            num_kernels_per_filter = param.size(1)
-
-            view_2d = param.view(-1, param.size(2) * param.size(3))
-            # Next, compute the sum of the squares (of the elements in each row/kernel)
-            kernel_means = view_2d.abs().mean(dim=1)
-            k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
-            thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
-            binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
-
-            # Now let's expand back up to a 4D mask
-            a = binary_map.expand(num_filters, num_kernels_per_filter)
-            c = a.unsqueeze(-1)
-            d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
-            return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
-
-
-    def threshold_policy(self, weights, thresholds, threshold_criteria, dim=1):
-        """
-        """
+        return group_threshold_mask(param, group_type, threshold, threshold_criteria)
+
+
+def group_threshold_mask(param, group_type, threshold, threshold_criteria):
+    """Return a threshold mask for the provided parameter and group type.
+
+    Args:
+        param: The parameter to mask
+        group_type: The elements grouping type (structure).
+          One of:2D, 3D, 4D, Channels, Row, Cols
+        threshold: The threshold
+        threshold_criteria: The thresholding criteria.
+          'Mean_Abs' thresholds the entire element group using the mean of the
+          absolute values of the tensor elements.
+          'Max' thresholds the entire group using the magnitude of the largest
+          element in the group.
+    """
+    if group_type == '2D':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        view_2d = param.view(-1, param.size(2) * param.size(3))
+        # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
+        #    thresholds tensor with length = #IFMs * # OFMs
+        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda()
+        # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
+        #    elements in each channel as the threshold filter.
+        # 3. Apply the threshold filter
+        binary_map = threshold_policy(view_2d, thresholds, threshold_criteria)
+        # 3. Finally, expand the thresholds and view as a 4D tensor
+        a = binary_map.expand(param.size(2) * param.size(3),
+                              param.size(0) * param.size(1)).t()
+        return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
+
+    elif group_type == 'Rows':
+        assert param.dim() == 2, "This regularization is only supported for 2D weights"
+        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        binary_map = threshold_policy(param, thresholds, threshold_criteria)
+        return binary_map.expand(param.size(1), param.size(0)).t()
+
+    elif group_type == 'Cols':
+        assert param.dim() == 2, "This regularization is only supported for 2D weights"
+        thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
+        binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0)
+        return binary_map.expand(param.size(0), param.size(1))
+
+    elif group_type == '3D' or group_type == 'Filters':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        view_filters = param.view(param.size(0), -1)
+        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
+        a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
+        return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
+
+    elif group_type == '4D':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         if threshold_criteria == 'Mean_Abs':
-            return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+            if param.data.abs().mean() > threshold:
+                return None
+            return torch.zeros_like(param.data)
         elif threshold_criteria == 'Max':
-            maxv, _ = weights.data.abs().max(dim=dim)
-            return maxv.gt(thresholds).type(weights.type())
+            if param.data.abs().max() > threshold:
+                return None
+            return torch.zeros_like(param.data)
         exit("Invalid threshold_criteria {}".format(threshold_criteria))
+
+    elif group_type == 'Channels':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_kernels_per_filter = param.size(1)
+
+        view_2d = param.view(-1, param.size(2) * param.size(3))
+        # Next, compute the sum of the squares (of the elements in each row/kernel)
+        kernel_means = view_2d.abs().mean(dim=1)
+        k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
+        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
+        binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
+
+        # Now let's expand back up to a 4D mask
+        a = binary_map.expand(num_filters, num_kernels_per_filter)
+        c = a.unsqueeze(-1)
+        d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
+        return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
+
+
+def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
+    """
+    """
+    if threshold_criteria == 'Mean_Abs':
+        return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'L1':
+        return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'Max':
+        maxv, _ = weights.data.abs().max(dim=dim)
+        return maxv.gt(thresholds).type(weights.type())
+    exit("Invalid threshold_criteria {}".format(threshold_criteria))
diff --git a/distiller/utils.py b/distiller/utils.py
index a193dfcdda35a9139f5c7a8b9d15140134debe84..4676474f441b6dc1801f5f0179240ff98795aedc 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -21,7 +21,6 @@ with some random helper functions.
 """
 import numpy as np
 import torch
-from torch.autograd import Variable
 import torch.nn as nn
 from copy import deepcopy
 
@@ -401,23 +400,6 @@ def has_children(module):
         return False
 
 
-class DoNothingModuleWrapper(nn.Module):
-    """Implement a nn.Module which wraps another nn.Module.
-
-    The DoNothingModuleWrapper wrapper does nothing but forward
-    to the wrapped module.
-    One use-case for this class, is for replacing nn.DataParallel
-    by a module that does nothing :-).  This is a trick we use
-    to transform data-parallel to serialized models.
-    """
-    def __init__(self, module):
-        super(DoNothingModuleWrapper, self).__init__()
-        self.module = module
-
-    def forward(self, *inputs, **kwargs):
-        return self.module(*inputs, **kwargs)
-
-
 def make_non_parallel_copy(model):
     """Make a non-data-parallel copy of the provided model.
 
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
index 83cc2fdfe2cd62e7a3e78c04eaf7f304eccc7fa2..406180e3e5f4f05f05bbc602b4143263c80088c8 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -1,53 +1,73 @@
+# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
+# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
+# 2. Fine grained pruning to reduce the parameter memory requirements of layers with large weights tensors.
+# 3. Row pruning for the last linear (fully-connected) layer.
+#
 # Baseline results:
-# Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Total MACs: 40,813,184
+#
+# Results:
+#     Top1: 91.760    Top5: 99.700    Loss: 1.546
+#     Total MACs: 35,947,136
+#     Total sparsity: 41.10
+#
 # time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 # |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.39017 | -0.00681 |    0.27798 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14674 | -0.00888 |    0.10358 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14363 |  0.00146 |    0.10571 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12673 | -0.01323 |    0.09655 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11736 | -0.00420 |    0.09039 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16400 | -0.00959 |    0.12023 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13288 | -0.00014 |    0.10020 |
-# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14688 | -0.00195 |    0.11372 |
-# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12828 | -0.00643 |    0.10049 |
-# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25453 | -0.00949 |    0.17990 |
-# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10884 | -0.00760 |    0.08639 |
-# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09702 | -0.00599 |    0.07635 |
-# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11464 | -0.01339 |    0.09051 |
-# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09177 |  0.00195 |    0.07188 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09764 | -0.00680 |    0.07753 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09308 | -0.00392 |    0.07406 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12596 | -0.00848 |    0.09993 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  6.49414 |  0.00000 |   69.99783 | 0.07444 | -0.00396 |    0.03728 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.49512 |  0.00000 |   69.99783 | 0.06792 | -0.00462 |    0.03381 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  9.81445 |  0.00000 |   69.99783 | 0.06811 | -0.00477 |    0.03417 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 26.00098 |  0.00000 |   69.99783 | 0.03877 |  0.00056 |    0.01954 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.56077 | -0.00002 |    0.48798 |
-# | 22 | Total sparsity:                     | -              |        251888 |         148672 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   40.97694 | 0.00000 |  0.00000 |    0.00000 |
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.26754 | -0.00478 |    0.18996 |
+# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10113 | -0.00595 |    0.07182 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09882 | -0.00013 |    0.07256 |
+# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08715 | -0.01028 |    0.06691 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08150 | -0.00316 |    0.06242 |
+# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11227 | -0.00627 |    0.08206 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09145 |  0.00145 |    0.06919 |
+# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09975 | -0.00178 |    0.07747 |
+# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08692 | -0.00438 |    0.06784 |
+# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17339 | -0.00644 |    0.12457 |
+# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07515 | -0.00582 |    0.05967 |
+# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06694 | -0.00409 |    0.05272 |
+# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07822 | -0.00873 |    0.06161 |
+# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06251 |  0.00119 |    0.04923 |
+# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06655 | -0.00436 |    0.05293 |
+# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06298 | -0.00286 |    0.05019 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08574 | -0.00490 |    0.06750 |
+# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.00684 |  0.00000 |   69.99783 | 0.05113 | -0.00318 |    0.02568 |
+# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.64160 |  0.00000 |   69.99783 | 0.04585 | -0.00355 |    0.02293 |
+# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.88867 |  0.00000 |   69.99783 | 0.04487 | -0.00409 |    0.02258 |
+# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 31.51855 |  1.56250 |   69.99783 | 0.02512 |  0.00008 |    0.01251 |
+# | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.48359 | -0.00001 |    0.30379 |
+# | 22 | Total sparsity:                     | -              |        251888 |         148352 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.10398 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 40.98
+# Total sparsity: 41.10
 #
 # --- validate (epoch=359)-----------
 # 5000 samples (256 per mini-batch)
-# ==> Top1: 93.320    Top5: 99.880    Loss: 0.246
+# ==> Top1: 93.720    Top5: 99.880    Loss: 1.529
 #
-# ==> Best Top1: 93.740   On Epoch: 265
+# ==> Best Top1: 96.900   On Epoch: 181
 #
-# Saving checkpoint to: logs/2018.10.09-020359/checkpoint.pth.tar
+# Saving checkpoint to: logs/2018.10.15-111439/checkpoint.pth.tar
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 91.580    Top5: 99.710    Loss: 0.355
+# ==> Top1: 91.760    Top5: 99.700    Loss: 1.546
+#
 #
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-111439/2018.10.15-111439.log
+#
+# real    31m55.802s
+# user    73m1.353s
+# sys     8m46.687s
+
 
 version: 1
+
 pruners:
   low_pruner:
-    class: StructuredAutomatedGradualPruner
+    class: L1RankedStructureParameterPruner_AGP
     initial_sparsity : 0.10
     final_sparsity: 0.40
     reg_regims:
@@ -62,6 +82,13 @@ pruners:
     weights: [module.layer3.1.conv1.weight,  module.layer3.1.conv2.weight,
               module.layer3.2.conv1.weight,  module.layer3.2.conv2.weight]
 
+  fc_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity : 0.05
+    final_sparsity: 0.50
+    reg_regims:
+      module.fc.weight: Rows
+
 lr_schedulers:
   pruning_lr:
     class: StepLR
@@ -88,8 +115,15 @@ policies:
     starting_epoch: 200
     ending_epoch: 220
     frequency: 2
-  # Currently the thinner is disabled until the end, because it interacts with the sparsity
-  # goals of the StructuredAutomatedGradualPruner.
+
+  - pruner:
+      instance_name : fc_pruner
+    starting_epoch: 200
+    ending_epoch: 220
+    frequency: 2
+
+  # Currently the thinner is disabled until the the structure pruner is done, because it interacts
+  # with the sparsity goals of the L1RankedStructureParameterPruner_AGP.
   # This can be fixed rather easily.
   # - extension:
   #     instance_name: net_thinner
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
index c91a157934b587c0d9d6c18ecd9c98f1014f01a3..730dd9a830a427f9e39a10dcaaf8d91ebdb6e45c 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
@@ -1,60 +1,72 @@
+# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
+# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
+# 2. Fine grained pruning to reduce the parameter memory requirements of layers with large weights tensors.
+# 3. Row pruning for the last linear (fully-connected) layer.
+#
 # Baseline results:
-# Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Total MACs: 40,813,184
+#
+# Results:
+#     Top1: 91.200    Top5: 99.660    Loss: 1.551
+#     Total MACs: 30,638,720
+#     Total sparsity: 41.84
+#
 # time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
 #
-# 
+#
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 # |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.39196 | -0.00533 |    0.27677 |
-# |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17516 | -0.01627 |    0.12761 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17375 |  0.00208 |    0.12753 |
-# |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14753 | -0.02355 |    0.11205 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13242 | -0.00280 |    0.10184 |
-# |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18848 | -0.00708 |    0.13828 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15502 | -0.00528 |    0.11709 |
-# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15266 | -0.00169 |    0.11773 |
-# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13070 | -0.00823 |    0.10204 |
-# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25380 | -0.01324 |    0.17815 |
-# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11349 | -0.00928 |    0.08977 |
-# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09904 | -0.00621 |    0.07856 |
-# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11538 | -0.01280 |    0.09106 |
-# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09239 |  0.00091 |    0.07236 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09853 | -0.00671 |    0.07821 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09391 | -0.00407 |    0.07466 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12660 | -0.00968 |    0.10101 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  6.56738 |  0.00000 |   69.99783 | 0.07488 | -0.00414 |    0.03739 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.69043 |  0.00000 |   69.99783 | 0.06839 | -0.00472 |    0.03404 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  9.47266 |  0.00000 |   69.99783 | 0.06867 | -0.00485 |    0.03450 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 26.41602 |  0.00000 |   69.99783 | 0.03915 |  0.00033 |    0.01970 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.56522 | -0.00002 |    0.49040 |
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.27315 | -0.00387 |    0.19394 |
+# |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12038 | -0.01295 |    0.08811 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11879 | -0.00031 |    0.08735 |
+# |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10293 | -0.01274 |    0.07795 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09285 | -0.00276 |    0.07141 |
+# |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12849 | -0.00345 |    0.09355 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10689 | -0.00381 |    0.08038 |
+# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10467 | -0.00371 |    0.08149 |
+# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08897 | -0.00502 |    0.06938 |
+# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17695 | -0.01111 |    0.12479 |
+# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07736 | -0.00531 |    0.06118 |
+# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06832 | -0.00404 |    0.05406 |
+# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07965 | -0.00904 |    0.06278 |
+# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06305 |  0.00122 |    0.04955 |
+# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06753 | -0.00459 |    0.05371 |
+# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06379 | -0.00297 |    0.05078 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08779 | -0.00584 |    0.06956 |
+# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.12891 |  0.00000 |   69.99783 | 0.05191 | -0.00319 |    0.02604 |
+# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.98340 |  0.00000 |   69.99783 | 0.04658 | -0.00360 |    0.02330 |
+# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.10742 |  0.00000 |   69.99783 | 0.04563 | -0.00393 |    0.02297 |
+# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 31.20117 |  0.00000 |   69.99783 | 0.02453 |  0.00005 |    0.01240 |
+# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.50451 | -0.00001 |    0.43698 |
 # | 22 | Total sparsity:                     | -              |        246704 |         143488 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.83799 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # Total sparsity: 41.84
 #
 # --- validate (epoch=359)-----------
 # 5000 samples (256 per mini-batch)
-# ==> Top1: 92.540    Top5: 99.960    Loss: 0.246
+# ==> Top1: 93.460    Top5: 99.800    Loss: 1.530
 #
-# ==> Best Top1: 93.580   On Epoch: 328
+# ==> Best Top1: 97.320   On Epoch: 180
 #
-# Saving checkpoint to: logs/2018.10.09-200709/checkpoint.pth.tar
+# Saving checkpoint to: logs/2018.10.15-115941/checkpoint.pth.tar
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 91.190    Top5: 99.660    Loss: 0.372
+# ==> Top1: 91.200    Top5: 99.660    Loss: 1.551
 #
 #
-# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.09-200709/2018.10.09-200709.log
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-115941/2018.10.15-115941.log
 #
-# real    32m23.439s
-# user    74m59.073s
-# sys     9m7.764s
+# real    32m31.997s
+# user    72m58.813s
+# sys     9m1.245s
 
 version: 1
 pruners:
   low_pruner:
-    class: StructuredAutomatedGradualPruner
+    class: L1RankedStructureParameterPruner_AGP
     initial_sparsity : 0.10
     final_sparsity: 0.40
     reg_regims:
@@ -99,7 +111,7 @@ policies:
     ending_epoch: 220
     frequency: 2
   # Currently the thinner is disabled until the end, because it interacts with the sparsity
-  # goals of the StructuredAutomatedGradualPruner.
+  # goals of the L1RankedStructureParameterPruner_AGP.
   # This can be fixed rather easily.
   # - extension:
   #     instance_name: net_thinner
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index e0687f1babe67352cf6a6b07f7307f66b9b604ab..dc0c5becb3b6ff4f5e489574bf2e8f49ae94ae14 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -57,7 +57,7 @@ import os
 import sys
 import random
 import traceback
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from functools import partial
 import numpy as np
 import torch
@@ -75,7 +75,7 @@ except ImportError:
     sys.path.append(module_path)
     import distiller
 import apputils
-from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
+from distiller.data_loggers import *
 import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model
 
@@ -118,7 +118,7 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--act-stats', dest='activation_stats', action='store_true', default=False,
+parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
                     help='collect activation statistics (WARNING: this slows down training)')
 parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                     help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
@@ -153,6 +153,7 @@ parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='early
 
 distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)
 
+
 def check_pytorch_version():
     if torch.__version__ < '0.4.0':
         print("\nNOTICE:")
@@ -166,6 +167,48 @@ def check_pytorch_version():
         exit(1)
 
 
+def create_activation_stats_collectors(model, collection_phase):
+    """Create objects that collect activation statistics.
+
+    This is a utility function that creates two collectors:
+    1. Fine-grade sparsity levels of the activations
+    2. L1-magnitude of each of the activation channels
+
+    Args:
+        model - the model on which we want to collect statistics
+        phase - the statistics collection phase which is either "train" (for training),
+                or "valid" (for validation)
+
+    WARNING! Enabling activation statsitics collection will significantly slow down training!
+    """
+    class missingdict(dict):
+        """This is a little trick to prevent KeyError"""
+        def __missing__(self, key):
+            return None  # note, does *not* set self[key] - we don't want defaultdict's behavior
+
+    distiller.utils.assign_layer_fq_names(model)
+
+    activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()}
+    if collection_phase is None:
+        return activations_collectors
+    collectors = missingdict()
+    collectors["sparsity"] = SummaryActivationStatsCollector(model, "sparsity", distiller.utils.sparsity)
+    collectors["l1_channels"] = SummaryActivationStatsCollector(model, "l1_channels",
+                                                                distiller.utils.activation_channels_l1)
+    collectors["apoz_channels"] = SummaryActivationStatsCollector(model, "apoz_channels",
+                                                                  distiller.utils.activation_channels_apoz)
+    collectors["records"] = RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
+    activations_collectors[collection_phase] = collectors
+    return activations_collectors
+
+
+def save_collectors_data(collectors, directory):
+    """Utility function that saves all activation statistics to Excel workbooks
+    """
+    for name, collector in collectors.items():
+        collector.to_xlsx(os.path.join(directory, name))
+
+
 def main():
     global msglogger
     check_pytorch_version()
@@ -267,18 +310,13 @@ def main():
     msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                    len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))
 
-    activations_sparsity = None
-    if args.activation_stats:
-        # If your model has ReLU layers, then those layers have sparse activations.
-        # ActivationSparsityCollector will collect information about this sparsity.
-        # WARNING! Enabling activation sparsity collection will significantly slow down training!
-        activations_sparsity = ActivationSparsityCollector(model)
+    activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats)
 
     if args.sensitivity is not None:
         return sensitivity_analysis(model, criterion, test_loader, pylogger, args)
 
     if args.evaluate:
-        return evaluate_model(model, criterion, test_loader, pylogger, args)
+        return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args)
 
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
@@ -313,15 +351,20 @@ def main():
             compression_scheduler.on_epoch_begin(epoch)
 
         # Train for one epoch
-        train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
-              loggers=[tflogger, pylogger], args=args)
-        distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
-        if args.activation_stats:
-            distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger],
-                                              collector=activations_sparsity)
+        with collectors_context(activations_collectors["train"]) as collectors:
+            train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
+                  loggers=[tflogger, pylogger], args=args)
+            distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
+            distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
+                                                collector=collectors["sparsity"])
 
         # evaluate on validation set
-        top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
+        with collectors_context(activations_collectors["valid"]) as collectors:
+            top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
+            distiller.log_activation_statsitics(epoch, "valid", loggers=[tflogger],
+                                                collector=collectors["sparsity"])
+            save_collectors_data(collectors, msglogger.logdir)
+
         stats = ('Peformance/Validation/',
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
@@ -342,7 +385,7 @@ def main():
                                  args.name, msglogger.logdir)
 
     # Finally run results on the test set
-    test(test_loader, model, criterion, [pylogger], args=args)
+    test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
 
 
 OVERALL_LOSS_KEY = 'Overall Loss'
@@ -460,10 +503,15 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1):
     return _validate(val_loader, model, criterion, loggers, args, epoch)
 
 
-def test(test_loader, model, criterion, loggers, args):
+def test(test_loader, model, criterion, loggers, activations_collectors, args):
     """Model Test"""
     msglogger.info('--- test ---------------------')
-    return _validate(test_loader, model, criterion, loggers, args)
+
+    with collectors_context(activations_collectors["test"]) as collectors:
+        top1, top5, lossses = _validate(test_loader, model, criterion, loggers, args)
+        distiller.log_activation_statsitics(-1, "test", loggers, collector=collectors['sparsity'])
+
+    return top1, top5, lossses
 
 
 def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
@@ -606,7 +654,7 @@ def earlyexit_validate_loss(output, target, criterion, args):
                 args.exit_taken[args.num_exits-1] += 1
 
 
-def evaluate_model(model, criterion, test_loader, loggers, args):
+def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
     # This sample application can be invoked to evaluate the accuracy of your model on
     # the test dataset.
     # You can optionally quantize the model to 8-bit integer before evaluation.
@@ -621,7 +669,9 @@ def evaluate_model(model, criterion, test_loader, loggers, args):
         quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
         quantizer.prepare_model()
         model.cuda()
-    top1, _, _ = test(test_loader, model, criterion, loggers, args=args)
+
+    top1, _, _ = test(test_loader, model, criterion, loggers, activations_collectors, args=args)
+
     if args.quantize:
         checkpoint_name = 'quantized'
         apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
@@ -645,7 +695,8 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args):
     if not isinstance(loggers, list):
         loggers = [loggers]
     test_fnc = partial(test, test_loader=data_loader, criterion=criterion,
-                       loggers=loggers, args=args)
+                       loggers=loggers, args=args,
+                       activations_collectors=create_activation_stats_collectors(model, None))
     which_params = [param_name for param_name, _ in model.named_parameters()]
     sensitivity = distiller.perform_sensitivity_analysis(model,
                                                          net_params=which_params,
diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..1e9e455680e08980013cb05830038994c310ea85
--- /dev/null
+++ b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
@@ -0,0 +1,212 @@
+#
+# This schedule uses the average percentage of zeros (APoZ) in the activations, to rank filters.
+# Compare this to examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml - the pruning time is
+# much longer due to the callbacks required for collecting the activation statistics (this can be improved by disabling
+# of the detailed records collection, for example).
+# This provides 62.7% compute compression (x1.6) while increasing the Top1.
+#
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.364
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 93.030    Top5: 99.650    Loss: 1.533
+#     Total MACs: 78,856,832
+#
+#
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25444 |  0.01128 |    0.13307 |
+# |  1 | module.layer1.0.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07351 |  0.00182 |    0.04119 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07510 | -0.00968 |    0.05190 |
+# |  3 | module.layer1.1.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06982 |  0.00599 |    0.04476 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05886 | -0.01451 |    0.04284 |
+# |  5 | module.layer1.2.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06894 | -0.00031 |    0.04735 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06561 | -0.00311 |    0.04952 |
+# |  7 | module.layer1.3.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07374 | -0.00087 |    0.05137 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07346 | -0.00474 |    0.05348 |
+# |  9 | module.layer1.4.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06855 |  0.00053 |    0.04867 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07078 | -0.01038 |    0.05366 |
+# | 11 | module.layer1.5.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09372 | -0.00430 |    0.06283 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09056 | -0.00089 |    0.06517 |
+# | 13 | module.layer1.6.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08050 | -0.00971 |    0.06157 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08000 | -0.00081 |    0.06004 |
+# | 15 | module.layer1.7.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09966 | -0.01270 |    0.07424 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09293 |  0.00685 |    0.07128 |
+# | 17 | module.layer1.8.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08764 | -0.01361 |    0.06730 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07053 |  0.00491 |    0.05341 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09426 | -0.00345 |    0.07094 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07798 | -0.00154 |    0.05783 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16513 |  0.00688 |    0.11354 |
+# | 24 | module.layer2.2.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05589 | -0.00558 |    0.04355 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04979 | -0.00491 |    0.03863 |
+# | 26 | module.layer2.3.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05622 | -0.00471 |    0.04379 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04541 | -0.00271 |    0.03535 |
+# | 28 | module.layer2.4.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05166 | -0.00597 |    0.03896 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04098 | -0.00381 |    0.03114 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04188 | -0.00373 |    0.03040 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03249 | -0.00190 |    0.02291 |
+# | 32 | module.layer2.6.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04584 | -0.00553 |    0.03569 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03655 | -0.00216 |    0.02758 |
+# | 34 | module.layer2.7.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05110 | -0.00700 |    0.03909 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03884 | -0.00129 |    0.02946 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03331 | -0.00269 |    0.02211 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02406 | -0.00014 |    0.01479 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05957 | -0.00091 |    0.04658 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05103 | -0.00016 |    0.03729 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09200 |  0.00203 |    0.06440 |
+# | 41 | module.layer3.1.conv1.weight        | (58, 64, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03758 | -0.00117 |    0.02728 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 58, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03577 | -0.00397 |    0.02686 |
+# | 43 | module.layer3.2.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03704 | -0.00146 |    0.02762 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03409 | -0.00464 |    0.02638 |
+# | 45 | module.layer3.3.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03887 | -0.00274 |    0.03015 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03390 | -0.00448 |    0.02648 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04296 | -0.00361 |    0.03345 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03454 | -0.00255 |    0.02628 |
+# | 49 | module.layer3.5.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04208 | -0.00441 |    0.03301 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03186 | -0.00319 |    0.02431 |
+# | 51 | module.layer3.6.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03119 | -0.00262 |    0.02419 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02298 | -0.00015 |    0.01670 |
+# | 53 | module.layer3.7.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02912 | -0.00265 |    0.02235 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02078 | -0.00010 |    0.01524 |
+# | 55 | module.layer3.8.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03270 | -0.00269 |    0.02542 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02244 |  0.00045 |    0.01630 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42577 | -0.00001 |    0.33523 |
+# | 58 | Total sparsity:                     | -              |        634640 |         634640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 0.00
+#
+# --- validate (epoch=249)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 92.740    Top5: 99.720    Loss: 1.534
+#
+# ==> Best Top1: 92.760   On Epoch: 237
+#
+# Saving checkpoint to: logs/2018.10.16-013006/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 93.030    Top5: 99.650    Loss: 1.533
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-013006/2018.10.16-013006.log
+#
+# real    49m0.623s
+# user    90m51.054s
+# sys     8m36.745s
+
+version: 1
+pruners:
+#   filter_pruner:
+#     class: 'ActivationAPoZRankedStructureParameterPruner'
+#     reg_regims:
+#       'module.layer1.0.conv1.weight': Filters
+
+  # filter_pruner:
+  #   class: ActivationAPoZRankedFilterPruner_AGP
+  #   initial_sparsity : 0.10
+  #   final_sparsity: 0.6
+  #   reg_regims:
+  #     module.layer1.0.conv1.weight: Filters
+
+  filter_pruner_60:
+    #class: StructuredAutomatedGradualPruner
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.6
+    reg_regims:
+      module.layer1.0.conv1.weight: Filters
+      module.layer1.1.conv1.weight: Filters
+      module.layer1.2.conv1.weight: Filters
+      module.layer1.3.conv1.weight: Filters
+      module.layer1.4.conv1.weight: Filters
+      module.layer1.5.conv1.weight: Filters
+      module.layer1.6.conv1.weight: Filters
+      module.layer1.7.conv1.weight: Filters
+      module.layer1.8.conv1.weight: Filters
+
+  filter_pruner_50:
+    #class: StructuredAutomatedGradualPruner
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.5
+    reg_regims:
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+      module.layer2.3.conv1.weight: Filters
+      module.layer2.4.conv1.weight: Filters
+      module.layer2.6.conv1.weight: Filters
+      module.layer2.7.conv1.weight: Filters
+
+  filter_pruner_10:
+    #class: StructuredAutomatedGradualPruner
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0
+    final_sparsity: 0.1
+    reg_regims:
+      module.layer3.1.conv1.weight: Filters
+
+  filter_pruner_30:
+    #class: StructuredAutomatedGradualPruner
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.3
+    reg_regims:
+        module.layer3.2.conv1.weight: Filters
+        module.layer3.3.conv1.weight: Filters
+        module.layer3.5.conv1.weight: Filters
+        module.layer3.6.conv1.weight: Filters
+        module.layer3.7.conv1.weight: Filters
+        module.layer3.8.conv1.weight: Filters
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet56_cifar'
+      dataset: 'cifar10'
+
+lr_schedulers:
+   exp_finetuning_lr:
+     class: ExponentialLR
+     gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name: filter_pruner_60
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_50
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_30
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_10
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - extension:
+      instance_name: net_thinner
+    epochs: [200]
+
+  - lr_scheduler:
+      instance_name: exp_finetuning_lr
+    starting_epoch: 190
+    ending_epoch: 300
+    frequency: 1
diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..1062520d51d39a72426a69edc6391244d8e88213
--- /dev/null
+++ b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
@@ -0,0 +1,198 @@
+#
+# This schedule uses the average percentage of zeros (APoZ) in the activations, to rank filters.
+# Compare this to examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml - the pruning time is
+# much longer due to the callbacks required for collecting the activation statistics (this can be improved by disabling
+# of the detailed records collection, for example).
+# This provides 62.7% compute compression (x1.6) while increasing the Top1.
+#
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.364
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 92.590    Top5: 99.630    Loss: 1.537
+#     Total MACs: 67,797,632
+#
+#
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25855 |  0.00995 |    0.13518 |
+# |  1 | module.layer1.0.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08280 | -0.00200 |    0.04624 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08243 | -0.01280 |    0.05715 |
+# |  3 | module.layer1.1.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07365 |  0.00145 |    0.04511 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06361 | -0.01105 |    0.04664 |
+# |  5 | module.layer1.2.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07599 | -0.00007 |    0.05344 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07223 | -0.00321 |    0.05500 |
+# |  7 | module.layer1.3.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08744 |  0.00275 |    0.06146 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08422 | -0.01179 |    0.06336 |
+# |  9 | module.layer1.4.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07359 | -0.00328 |    0.05202 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07655 | -0.00830 |    0.05807 |
+# | 11 | module.layer1.5.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10343 | -0.00167 |    0.06877 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10276 | -0.00419 |    0.07489 |
+# | 13 | module.layer1.6.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08513 | -0.00885 |    0.06337 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08756 | -0.00307 |    0.06605 |
+# | 15 | module.layer1.7.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10795 | -0.01498 |    0.07969 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09896 |  0.00700 |    0.07549 |
+# | 17 | module.layer1.8.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09739 | -0.01620 |    0.07525 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07876 |  0.00671 |    0.05963 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09674 | -0.00379 |    0.07281 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07978 | -0.00164 |    0.05939 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16693 |  0.00597 |    0.11503 |
+# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06721 | -0.00298 |    0.05084 |
+# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05842 | -0.00357 |    0.04602 |
+# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05689 | -0.00687 |    0.04438 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05025 | -0.00454 |    0.03910 |
+# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05853 | -0.00510 |    0.04560 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04758 | -0.00213 |    0.03705 |
+# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05571 | -0.00672 |    0.04273 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04397 | -0.00459 |    0.03407 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04363 | -0.00326 |    0.03157 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03392 | -0.00203 |    0.02392 |
+# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04665 | -0.00607 |    0.03589 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03699 | -0.00209 |    0.02802 |
+# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05658 | -0.00710 |    0.04287 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04273 | -0.00281 |    0.03222 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03511 | -0.00271 |    0.02330 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02510 | -0.00009 |    0.01561 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06056 | -0.00144 |    0.04736 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05184 | -0.00042 |    0.03792 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09288 |  0.00286 |    0.06527 |
+# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03912 | -0.00155 |    0.02842 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03727 | -0.00412 |    0.02804 |
+# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03857 | -0.00132 |    0.02876 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03551 | -0.00487 |    0.02752 |
+# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04020 | -0.00281 |    0.03119 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03524 | -0.00470 |    0.02756 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04392 | -0.00341 |    0.03419 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03533 | -0.00286 |    0.02693 |
+# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04352 | -0.00462 |    0.03410 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03269 | -0.00326 |    0.02494 |
+# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03254 | -0.00281 |    0.02521 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02411 | -0.00045 |    0.01755 |
+# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03012 | -0.00277 |    0.02301 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02162 | -0.00051 |    0.01584 |
+# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03413 | -0.00266 |    0.02653 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02345 |  0.00058 |    0.01716 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43253 | -0.00001 |    0.34022 |
+# | 58 | Total sparsity:                     | -              |        570704 |         570704 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 0.00
+#
+# --- validate (epoch=249)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 91.880    Top5: 99.520    Loss: 1.543
+#
+# ==> Best Top1: 92.680   On Epoch: 237
+#
+# Saving checkpoint to: logs/2018.10.16-115506/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 92.590    Top5: 99.630    Loss: 1.537
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-115506/2018.10.16-115506.log
+#
+# real    49m10.711s
+# user    91m13.215s
+# sys     8m34.108s
+
+version: 1
+pruners:
+  filter_pruner_60:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.7
+    reg_regims:
+      module.layer1.0.conv1.weight: Filters
+      module.layer1.1.conv1.weight: Filters
+      module.layer1.2.conv1.weight: Filters
+      module.layer1.3.conv1.weight: Filters
+      module.layer1.4.conv1.weight: Filters
+      module.layer1.5.conv1.weight: Filters
+      module.layer1.6.conv1.weight: Filters
+      module.layer1.7.conv1.weight: Filters
+      module.layer1.8.conv1.weight: Filters
+
+  filter_pruner_50:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.6
+    reg_regims:
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+      module.layer2.3.conv1.weight: Filters
+      module.layer2.4.conv1.weight: Filters
+      module.layer2.6.conv1.weight: Filters
+      module.layer2.7.conv1.weight: Filters
+
+  filter_pruner_10:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0
+    final_sparsity: 0.2
+    reg_regims:
+      module.layer3.1.conv1.weight: Filters
+
+  filter_pruner_30:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.4
+    reg_regims:
+        module.layer3.2.conv1.weight: Filters
+        module.layer3.3.conv1.weight: Filters
+        module.layer3.5.conv1.weight: Filters
+        module.layer3.6.conv1.weight: Filters
+        module.layer3.7.conv1.weight: Filters
+        module.layer3.8.conv1.weight: Filters
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet56_cifar'
+      dataset: 'cifar10'
+
+lr_schedulers:
+   exp_finetuning_lr:
+     class: ExponentialLR
+     gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name: filter_pruner_60
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_50
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_30
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_10
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - extension:
+      instance_name: net_thinner
+    epochs: [200]
+
+  - lr_scheduler:
+      instance_name: exp_finetuning_lr
+    starting_epoch: 190
+    ending_epoch: 300
+    frequency: 1
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
index a3c282cce9208d0bedca8f34fc17eea34a8222e7..08b2b012b4959d994d3360e16a3875606b5a481c 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
@@ -1,8 +1,8 @@
-#
 # This schedule performs filter ranking and removal, for the convolution layers in ResNet56-CIFAR, as described in
-# Pruning Filters for Efficient Convnets.
-# Filters are ranked and pruned accordingly.
+# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
+# ICLR 2017, arXiv:1608.087
 #
+# Filters are ranked and pruned accordingly.
 # This is followed by network thinning which removes the filters entirely from the model, and changes the convolution
 # layers' dimensions accordingly.  Convolution layers that follow have their respective channels removed as well, as do
 # Batch normailization layers.
@@ -18,6 +18,14 @@
 #
 # Results: 62.7% of the original convolution MACs (when calculated using direct convolution)
 #
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.464
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 92.830    Top5: 99.760    Loss: 0.489
+#     Total MACs: 78,856,832
+#
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
@@ -79,19 +87,19 @@
 # | 54 | module.layer3.7.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02753 |  0.00051 |    0.02028 |
 # | 55 | module.layer3.8.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02897 | -0.00308 |    0.02240 |
 # | 56 | module.layer3.8.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02130 | -0.00061 |    0.01556 |
-# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.47902 | -0.00002 |    0.37518 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.47902 | -0.00002 |    0.47518 |
 # | 58 | Total sparsity:                     | -              |        634640 |         634640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # Total sparsity: 0.00
 #
 # --- validate (epoch=249)-----------
 # 5000 samples (256 per mini-batch)
-# ==> Top1: 92.640    Top5: 99.820    Loss: 0.353
+# ==> Top1: 92.640    Top5: 99.820    Loss: 0.453
 #
 # Saving checkpoint
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 92.830    Top5: 99.760    Loss: 0.389
+# ==> Top1: 92.830    Top5: 99.760    Loss: 0.489
 #
 #
 # Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/private-distiller/examples/classifier_compression/logs/2018.04.17-005852/2018.04.17-005852.log
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..d0f58ce176d70ea63590b7e2bd49961bd615b291
--- /dev/null
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
@@ -0,0 +1,207 @@
+#
+# This schedule performs filter ranking and removal, for the convolution layers in ResNet56-CIFAR, as described in
+# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
+# ICLR 2017, arXiv:1608.087
+#
+# Filters are ranked and pruned accordingly.
+# This is followed by network thinning which removes the filters entirely from the model, and changes the convolution
+# layers' dimensions accordingly.  Convolution layers that follow have their respective channels removed as well, as do
+# Batch normailization layers.
+#
+# The authors write that: "Since there is no projection mapping for choosing the identity featuremaps, we only
+# consider pruning the first layer of the residual block."
+#
+# Note that to use the command-line below, you will need the baseline ResNet56 model (checkpoint.resnet56_cifar_baseline.pth.tar).
+# You may either train this model from scratch, or download it from the link below.
+# https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
+#
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic
+#
+# Results: 53.9% (1.85x) of the original convolution MACs (when calculated using direct convolution)
+#
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.464
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 92.740    Top5: 99.640    Loss: 1.534
+#     Total MACs: 67,797,632
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25775 |  0.01021 |    0.13389 |
+# |  1 | module.layer1.0.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08400 |  0.00016 |    0.04778 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08528 | -0.00987 |    0.05921 |
+# |  3 | module.layer1.1.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08004 |  0.00482 |    0.05321 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06757 | -0.01286 |    0.04833 |
+# |  5 | module.layer1.2.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07856 |  0.00001 |    0.05491 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06927 | -0.00455 |    0.05305 |
+# |  7 | module.layer1.3.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08582 |  0.00157 |    0.06020 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08446 | -0.00188 |    0.06118 |
+# |  9 | module.layer1.4.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08725 |  0.00491 |    0.06379 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07957 | -0.01561 |    0.06278 |
+# | 11 | module.layer1.5.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10966 | -0.00952 |    0.07636 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10307 |  0.00342 |    0.07403 |
+# | 13 | module.layer1.6.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10200 | -0.00991 |    0.07828 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09630 |  0.00492 |    0.07303 |
+# | 15 | module.layer1.7.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11086 | -0.01292 |    0.08203 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10022 |  0.00847 |    0.07685 |
+# | 17 | module.layer1.8.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10323 | -0.01191 |    0.07876 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07648 |  0.00615 |    0.05865 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09559 | -0.00432 |    0.07201 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07934 | -0.00183 |    0.05900 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16645 |  0.00714 |    0.11508 |
+# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06975 | -0.00591 |    0.05336 |
+# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05907 | -0.00069 |    0.04646 |
+# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06476 | -0.00591 |    0.05015 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05564 | -0.00607 |    0.04301 |
+# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06246 | -0.00269 |    0.04888 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05009 | -0.00171 |    0.03892 |
+# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06047 | -0.00494 |    0.04774 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04690 | -0.00493 |    0.03661 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04318 | -0.00403 |    0.03144 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03371 | -0.00219 |    0.02386 |
+# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05296 | -0.00465 |    0.04163 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04041 | -0.00099 |    0.03073 |
+# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06190 | -0.00745 |    0.04871 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04535 |  0.00052 |    0.03475 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03378 | -0.00309 |    0.02261 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02457 | -0.00035 |    0.01523 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06005 | -0.00122 |    0.04697 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05147 | -0.00011 |    0.03767 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09252 |  0.00159 |    0.06504 |
+# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03913 | -0.00138 |    0.02846 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03736 | -0.00428 |    0.02826 |
+# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03982 | -0.00118 |    0.02987 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03651 | -0.00484 |    0.02836 |
+# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04171 | -0.00306 |    0.03253 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03619 | -0.00400 |    0.02820 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04342 | -0.00387 |    0.03380 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03494 | -0.00264 |    0.02668 |
+# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04597 | -0.00467 |    0.03630 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03378 | -0.00285 |    0.02578 |
+# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03428 | -0.00242 |    0.02700 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02501 | -0.00001 |    0.01828 |
+# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03189 | -0.00319 |    0.02489 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02260 | -0.00034 |    0.01673 |
+# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03693 | -0.00290 |    0.02890 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02474 |  0.00102 |    0.01800 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42898 | -0.00001 |    0.33739 |
+# | 58 | Total sparsity:                     | -              |        570704 |         570704 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 0.00
+#
+# --- validate (epoch=249)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 92.180    Top5: 99.660    Loss: 1.540
+#
+# ==> Best Top1: 92.580   On Epoch: 238
+#
+# Saving checkpoint to: logs/2018.10.16-103816/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 92.740    Top5: 99.640    Loss: 1.534
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-103816/2018.10.16-103816.log
+#
+# real    29m38.833s
+# user    65m54.241s
+# sys     6m17.564s
+
+
+version: 1
+pruners:
+  filter_pruner:
+    class: 'L1RankedStructureParameterPruner'
+    reg_regims:
+      #'module.conv1.weight':          [0.2, '3D']
+      'module.layer1.0.conv1.weight': [0.7, '3D']
+      #'module.layer1.0.conv2.weight': [0.4, '3D']
+      'module.layer1.1.conv1.weight': [0.7, '3D']
+      #'module.layer1.1.conv2.weight': [0.7, '3D']
+      'module.layer1.2.conv1.weight': [0.7, '3D']
+      #'module.layer1.2.conv2.weight': [0.7, '3D']
+      'module.layer1.3.conv1.weight': [0.7, '3D']
+      #'module.layer1.3.conv2.weight': [0.7, '3D']
+      'module.layer1.4.conv1.weight': [0.7, '3D']
+      #'module.layer1.4.conv2.weight': [0.7, '3D']
+      'module.layer1.5.conv1.weight': [0.7, '3D']
+      #'module.layer1.5.conv2.weight': [0.7, '3D']
+      'module.layer1.6.conv1.weight': [0.7, '3D']
+      #'module.layer1.6.conv2.weight': [0.7, '3D']
+      'module.layer1.7.conv1.weight': [0.7, '3D']
+      #'module.layer1.7.conv2.weight': [0.7, '3D']
+      'module.layer1.8.conv1.weight': [0.7, '3D']
+      #'module.layer1.8.conv2.weight': [0.2, '3D']
+
+      ##'module.layer2.0.conv1.weight': [0.6, '3D']
+      #'module.layer2.0.conv2.weight': [0.4, '3D']
+      #'module.layer2.0.downsample.0.weight': [0.4, '3D']
+      'module.layer2.1.conv1.weight': [0.6, '3D']
+      #'module.layer2.1.conv2.weight': [0.4, '3D']
+      'module.layer2.2.conv1.weight': [0.6, '3D']
+      #'module.layer2.2.conv2.weight': [0.4, '3D']
+      'module.layer2.3.conv1.weight': [0.6, '3D']
+      #'module.layer2.3.conv2.weight': [0.4, '3D']
+
+      'module.layer2.4.conv1.weight': [0.6, '3D']
+      # 'module.layer2.4.conv2.weight': [0.4, '3D']
+      #'module.layer2.5.conv1.weight': [0.6, '3D']
+      # 'module.layer2.5.conv2.weight': [0.4, '3D']
+      'module.layer2.6.conv1.weight': [0.6, '3D']
+      # 'module.layer2.6.conv2.weight': [0.2, '3D']
+      'module.layer2.7.conv1.weight': [0.6, '3D']
+      # 'module.layer2.7.conv2.weight': [0.2, '3D']
+      ##'module.layer2.8.conv1.weight': [0.4, '3D']
+      # 'module.layer2.8.conv2.weight': [0.2, '3D']
+
+      #'module.layer3.0.conv1.weight': [0.1, '3D']
+      # 'module.layer3.0.conv2.weight': [0.1, '3D']
+      # 'module.layer3.0.downsample.0.weight': [0.1, '3D']
+      'module.layer3.1.conv1.weight': [0.2, '3D']
+      # 'module.layer3.1.conv2.weight': [0.1, '3D']
+      'module.layer3.2.conv1.weight': [0.4, '3D']
+      # 'module.layer3.2.conv2.weight': [0.1, '3D']
+      'module.layer3.3.conv1.weight': [0.4, '3D']
+      # 'module.layer3.3.conv2.weight': [0.1, '3D']
+      #'module.layer3.4.conv1.weight': [0.1, '3D']
+      # 'module.layer3.4.conv2.weight': [0.1, '3D']
+      'module.layer3.5.conv1.weight': [0.4, '3D']
+      #'module.layer3.5.conv2.weight': [0.1, '3D']
+      'module.layer3.6.conv1.weight': [0.4, '3D']
+      # 'module.layer3.6.conv2.weight': [0.1, '3D']
+      'module.layer3.7.conv1.weight': [0.4, '3D']
+      # 'module.layer3.7.conv2.weight': [0.1, '3D']
+      'module.layer3.8.conv1.weight': [0.4, '3D']
+      # 'module.layer3.8.conv2.weight': [0.2, '3D']
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet56_cifar'
+      dataset: 'cifar10'
+
+lr_schedulers:
+   exp_finetuning_lr:
+     class: ExponentialLR
+     gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name: filter_pruner
+    epochs: [180]
+
+  - extension:
+      instance_name: net_thinner
+    epochs: [180]
+
+  - lr_scheduler:
+      instance_name: exp_finetuning_lr
+    starting_epoch: 190
+    ending_epoch: 300
+    frequency: 1
diff --git a/examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv b/examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv
new file mode 100644
index 0000000000000000000000000000000000000000..d9183c90f26ba983dd2263f295d79cd3bfabfc91
--- /dev/null
+++ b/examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv
@@ -0,0 +1,584 @@
+parameter,sparsity,top1,top5,loss
+module.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.conv1.weight,0.05,76.092,92.878,0.9647577350997194
+module.conv1.weight,0.1,75.83,92.88,0.9732017052568949
+module.conv1.weight,0.15000000000000002,75.544,92.672,0.9858137353190358
+module.conv1.weight,0.2,75.036,92.434,1.00316541261819
+module.conv1.weight,0.25,67.07600000000001,87.788,1.367336580947954
+module.conv1.weight,0.30000000000000004,63.832,85.552,1.5320777771424268
+module.conv1.weight,0.35000000000000003,62.72599999999999,84.932,1.5777477780166933
+module.conv1.weight,0.4,54.571999999999996,78.654,2.0279791546719426
+module.conv1.weight,0.45,45.67400000000001,69.792,2.5944640064726063
+module.conv1.weight,0.5,12.258000000000003,26.004000000000005,6.298526190981571
+module.layer1.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.conv1.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.conv1.weight,0.1,75.932,92.77,0.9672794009045679
+module.layer1.0.conv1.weight,0.15000000000000002,76.012,92.842,0.9668000321455148
+module.layer1.0.conv1.weight,0.2,75.316,92.622,0.9959884060128614
+module.layer1.0.conv1.weight,0.25,75.302,92.594,0.997434276403213
+module.layer1.0.conv1.weight,0.30000000000000004,75.18599999999999,92.47800000000001,1.0021765834975
+module.layer1.0.conv1.weight,0.35000000000000003,74.424,91.99,1.039658749438063
+module.layer1.0.conv1.weight,0.4,73.13799999999999,91.216,1.0942878348334713
+module.layer1.0.conv1.weight,0.45,70.798,89.90599999999999,1.1983419334401888
+module.layer1.0.conv1.weight,0.5,65.846,86.896,1.4433029138920255
+module.layer1.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.conv2.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.conv2.weight,0.1,76.14,92.86999999999999,0.9648298266894967
+module.layer1.0.conv2.weight,0.15000000000000002,76.088,92.83200000000001,0.9653575117034573
+module.layer1.0.conv2.weight,0.2,76.042,92.878,0.9661754479973899
+module.layer1.0.conv2.weight,0.25,75.82,92.824,0.9717823474534921
+module.layer1.0.conv2.weight,0.30000000000000004,75.726,92.704,0.9781641415795503
+module.layer1.0.conv2.weight,0.35000000000000003,73.08600000000001,91.33200000000001,1.098903458337394
+module.layer1.0.conv2.weight,0.4,73.046,91.366,1.1005394557604984
+module.layer1.0.conv2.weight,0.45,64.318,86.012,1.5271362866065938
+module.layer1.0.conv2.weight,0.5,64.916,86.41799999999999,1.4901175172049177
+module.layer1.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.conv3.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.conv3.weight,0.1,76.13,92.862,0.9646787382662295
+module.layer1.0.conv3.weight,0.15000000000000002,76.134,92.86,0.9646867231598922
+module.layer1.0.conv3.weight,0.2,76.13600000000001,92.876,0.9648272743334576
+module.layer1.0.conv3.weight,0.25,76.126,92.866,0.964804031487022
+module.layer1.0.conv3.weight,0.30000000000000004,76.042,92.892,0.966239919391822
+module.layer1.0.conv3.weight,0.35000000000000003,75.982,92.91,0.96753498220018
+module.layer1.0.conv3.weight,0.4,75.94,92.876,0.9695769828193038
+module.layer1.0.conv3.weight,0.45,75.86,92.81200000000001,0.9735630200225487
+module.layer1.0.conv3.weight,0.5,75.388,92.572,0.9887932550390158
+module.layer1.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.downsample.0.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.downsample.0.weight,0.1,76.13,92.862,0.9646787385703351
+module.layer1.0.downsample.0.weight,0.15000000000000002,76.084,92.852,0.9652105177543602
+module.layer1.0.downsample.0.weight,0.2,75.978,92.854,0.9662235065230306
+module.layer1.0.downsample.0.weight,0.25,75.778,92.784,0.9759335609114901
+module.layer1.0.downsample.0.weight,0.30000000000000004,74.91399999999999,92.336,1.0159372910857198
+module.layer1.0.downsample.0.weight,0.35000000000000003,71.514,90.47800000000001,1.1738625028911904
+module.layer1.0.downsample.0.weight,0.4,67.556,87.998,1.3625655029805341
+module.layer1.0.downsample.0.weight,0.45,61.681999999999995,84.002,1.6625876446463614
+module.layer1.0.downsample.0.weight,0.5,40.928,65.08800000000001,3.0252077293639297
+module.layer1.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.1.conv1.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.1.conv1.weight,0.1,76.098,92.876,0.9652870166666652
+module.layer1.1.conv1.weight,0.15000000000000002,76.078,92.88,0.9660209278214948
+module.layer1.1.conv1.weight,0.2,75.906,92.854,0.9702612833709134
+module.layer1.1.conv1.weight,0.25,75.874,92.838,0.9720629371550619
+module.layer1.1.conv1.weight,0.30000000000000004,75.71600000000001,92.784,0.9748685826756516
+module.layer1.1.conv1.weight,0.35000000000000003,75.516,92.636,0.9820135380996733
+module.layer1.1.conv1.weight,0.4,75.422,92.58999999999999,0.9858410236026562
+module.layer1.1.conv1.weight,0.45,75.342,92.552,0.9901515745690892
+module.layer1.1.conv1.weight,0.5,75.214,92.44,0.9973136783406443
+module.layer1.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.1.conv2.weight,0.05,76.074,92.866,0.9656340561959206
+module.layer1.1.conv2.weight,0.1,76.106,92.882,0.9665291072154533
+module.layer1.1.conv2.weight,0.15000000000000002,76.004,92.838,0.9674485273355128
+module.layer1.1.conv2.weight,0.2,75.74799999999999,92.798,0.973839068215112
+module.layer1.1.conv2.weight,0.25,75.754,92.774,0.976552191209428
+module.layer1.1.conv2.weight,0.30000000000000004,75.634,92.796,0.9795367948862973
+module.layer1.1.conv2.weight,0.35000000000000003,75.57000000000001,92.67800000000001,0.9839644473122093
+module.layer1.1.conv2.weight,0.4,75.566,92.636,0.9858285662318981
+module.layer1.1.conv2.weight,0.45,75.09599999999999,92.49000000000001,1.0086073198792886
+module.layer1.1.conv2.weight,0.5,74.634,92.238,1.0265772062904983
+module.layer1.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.1.conv3.weight,0.05,76.13,92.862,0.9646781417636237
+module.layer1.1.conv3.weight,0.1,76.13600000000001,92.862,0.9646657041597122
+module.layer1.1.conv3.weight,0.15000000000000002,76.128,92.862,0.9646495287211573
+module.layer1.1.conv3.weight,0.2,76.122,92.85799999999999,0.964540821633169
+module.layer1.1.conv3.weight,0.25,76.142,92.85600000000001,0.9644717303465824
+module.layer1.1.conv3.weight,0.30000000000000004,76.118,92.876,0.9645412101277283
+module.layer1.1.conv3.weight,0.35000000000000003,76.1,92.9,0.9652030228504112
+module.layer1.1.conv3.weight,0.4,76.056,92.876,0.9660147736419221
+module.layer1.1.conv3.weight,0.45,76.018,92.85,0.9658604563042826
+module.layer1.1.conv3.weight,0.5,75.958,92.862,0.9682643982220668
+module.layer1.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.2.conv1.weight,0.05,76.054,92.892,0.9656125142105988
+module.layer1.2.conv1.weight,0.1,76.058,92.84400000000001,0.9668843527229466
+module.layer1.2.conv1.weight,0.15000000000000002,76.08200000000001,92.872,0.9668020492001456
+module.layer1.2.conv1.weight,0.2,75.932,92.808,0.9682766031093746
+module.layer1.2.conv1.weight,0.25,75.862,92.75999999999999,0.9708546974829264
+module.layer1.2.conv1.weight,0.30000000000000004,75.75800000000001,92.718,0.9728293174079484
+module.layer1.2.conv1.weight,0.35000000000000003,75.71799999999999,92.7,0.9746806868637092
+module.layer1.2.conv1.weight,0.4,75.38199999999999,92.526,0.9875993205576526
+module.layer1.2.conv1.weight,0.45,75.18599999999999,92.452,0.9937687095026578
+module.layer1.2.conv1.weight,0.5,74.81400000000001,92.218,1.0118083756188954
+module.layer1.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.2.conv2.weight,0.05,75.91799999999999,92.826,0.9685604718266703
+module.layer1.2.conv2.weight,0.1,75.924,92.82000000000001,0.9695663569228992
+module.layer1.2.conv2.weight,0.15000000000000002,75.666,92.742,0.9778546592866886
+module.layer1.2.conv2.weight,0.2,75.606,92.684,0.9812715381992112
+module.layer1.2.conv2.weight,0.25,75.22999999999999,92.542,0.9929108396172521
+module.layer1.2.conv2.weight,0.30000000000000004,75.046,92.414,1.0019573495552248
+module.layer1.2.conv2.weight,0.35000000000000003,74.944,92.24600000000001,1.011767355230998
+module.layer1.2.conv2.weight,0.4,74.58,92.134,1.019189273672445
+module.layer1.2.conv2.weight,0.45,74.078,91.912,1.0405801353710036
+module.layer1.2.conv2.weight,0.5,73.89,91.752,1.049945247036461
+module.layer1.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.2.conv3.weight,0.05,76.12400000000001,92.86,0.9646835384168189
+module.layer1.2.conv3.weight,0.1,76.12400000000001,92.86,0.9647042825818065
+module.layer1.2.conv3.weight,0.15000000000000002,76.118,92.85600000000001,0.9646848114017325
+module.layer1.2.conv3.weight,0.2,76.132,92.854,0.9646729855056924
+module.layer1.2.conv3.weight,0.25,76.134,92.854,0.9646716789931671
+module.layer1.2.conv3.weight,0.30000000000000004,76.128,92.862,0.9646537449895121
+module.layer1.2.conv3.weight,0.35000000000000003,76.11,92.882,0.9645740807968741
+module.layer1.2.conv3.weight,0.4,76.102,92.874,0.9645784062390422
+module.layer1.2.conv3.weight,0.45,76.116,92.874,0.9646241778165708
+module.layer1.2.conv3.weight,0.5,76.118,92.86,0.9649834648839066
+module.layer2.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.conv1.weight,0.05,76.018,92.842,0.967156604996749
+module.layer2.0.conv1.weight,0.1,75.53999999999999,92.694,0.98511602874009
+module.layer2.0.conv1.weight,0.15000000000000002,75.384,92.57,0.9942189067298051
+module.layer2.0.conv1.weight,0.2,73.736,91.75999999999999,1.0647947725896927
+module.layer2.0.conv1.weight,0.25,72.794,91.312,1.107239385876728
+module.layer2.0.conv1.weight,0.30000000000000004,69.922,89.432,1.2431173596759222
+module.layer2.0.conv1.weight,0.35000000000000003,69.928,89.59,1.2365096223597627
+module.layer2.0.conv1.weight,0.4,66.234,87.068,1.4263274966149908
+module.layer2.0.conv1.weight,0.45,44.62199999999999,67.55,2.767483531820531
+module.layer2.0.conv1.weight,0.5,36.982000000000006,59.326,3.3448752748722943
+module.layer2.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.conv2.weight,0.05,75.902,92.842,0.969782282351231
+module.layer2.0.conv2.weight,0.1,75.756,92.824,0.9730839444210331
+module.layer2.0.conv2.weight,0.15000000000000002,75.64999999999999,92.72,0.978511454727577
+module.layer2.0.conv2.weight,0.2,75.58,92.676,0.9845758525996795
+module.layer2.0.conv2.weight,0.25,75.21799999999999,92.5,0.997505821591737
+module.layer2.0.conv2.weight,0.30000000000000004,74.788,92.288,1.0151040011218613
+module.layer2.0.conv2.weight,0.35000000000000003,74.42,92.078,1.0322041135965574
+module.layer2.0.conv2.weight,0.4,74.246,91.974,1.0409574107247952
+module.layer2.0.conv2.weight,0.45,73.40599999999999,91.438,1.077756118409488
+module.layer2.0.conv2.weight,0.5,72.55,90.956,1.1181802381666337
+module.layer2.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.conv3.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer2.0.conv3.weight,0.1,76.13,92.862,0.9646787385703351
+module.layer2.0.conv3.weight,0.15000000000000002,76.13,92.85799999999999,0.9647160731256009
+module.layer2.0.conv3.weight,0.2,76.13799999999999,92.85600000000001,0.9647735800518065
+module.layer2.0.conv3.weight,0.25,76.12400000000001,92.852,0.9647925913485942
+module.layer2.0.conv3.weight,0.30000000000000004,76.126,92.84400000000001,0.9647947986971356
+module.layer2.0.conv3.weight,0.35000000000000003,76.076,92.882,0.9641403041171785
+module.layer2.0.conv3.weight,0.4,76.066,92.908,0.9645294097005104
+module.layer2.0.conv3.weight,0.45,75.91600000000001,92.886,0.9665729468878436
+module.layer2.0.conv3.weight,0.5,75.72,92.74,0.9742113863479118
+module.layer2.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.downsample.0.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer2.0.downsample.0.weight,0.1,76.13,92.862,0.9646787385703351
+module.layer2.0.downsample.0.weight,0.15000000000000002,76.112,92.866,0.9650651668103376
+module.layer2.0.downsample.0.weight,0.2,76.07,92.86,0.9653989572306069
+module.layer2.0.downsample.0.weight,0.25,76.028,92.824,0.9663104700038628
+module.layer2.0.downsample.0.weight,0.30000000000000004,75.98,92.868,0.9676333344256389
+module.layer2.0.downsample.0.weight,0.35000000000000003,75.96000000000001,92.838,0.970427606482895
+module.layer2.0.downsample.0.weight,0.4,75.876,92.794,0.9731549210086157
+module.layer2.0.downsample.0.weight,0.45,75.616,92.73400000000001,0.9814816470809126
+module.layer2.0.downsample.0.weight,0.5,75.412,92.67999999999999,0.9890781333099821
+module.layer2.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.1.conv1.weight,0.05,76.096,92.872,0.9646778918650688
+module.layer2.1.conv1.weight,0.1,76.058,92.896,0.9657731824079335
+module.layer2.1.conv1.weight,0.15000000000000002,75.986,92.866,0.9670812168291635
+module.layer2.1.conv1.weight,0.2,75.924,92.848,0.9687874529282655
+module.layer2.1.conv1.weight,0.25,75.864,92.81200000000001,0.9736961963377434
+module.layer2.1.conv1.weight,0.30000000000000004,75.726,92.774,0.9777020039607072
+module.layer2.1.conv1.weight,0.35000000000000003,75.602,92.738,0.980759804878307
+module.layer2.1.conv1.weight,0.4,75.214,92.548,0.9945249997687585
+module.layer2.1.conv1.weight,0.45,75.03999999999999,92.49000000000001,1.0022178821417755
+module.layer2.1.conv1.weight,0.5,74.806,92.388,1.0118249537689343
+module.layer2.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.1.conv2.weight,0.05,76.05,92.872,0.9652618023053723
+module.layer2.1.conv2.weight,0.1,75.996,92.85799999999999,0.9672502354547686
+module.layer2.1.conv2.weight,0.15000000000000002,75.934,92.792,0.9704457092650083
+module.layer2.1.conv2.weight,0.2,75.858,92.75,0.9751517729339547
+module.layer2.1.conv2.weight,0.25,75.48,92.582,0.9862348007760484
+module.layer2.1.conv2.weight,0.30000000000000004,74.222,91.824,1.0404229535892298
+module.layer2.1.conv2.weight,0.35000000000000003,74.146,91.73599999999999,1.0469928039427934
+module.layer2.1.conv2.weight,0.4,69.038,88.73400000000001,1.28229090145656
+module.layer2.1.conv2.weight,0.45,54.476,77.58,2.1224025232451305
+module.layer2.1.conv2.weight,0.5,54.217999999999996,77.29599999999999,2.1405887241874417
+module.layer2.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.1.conv3.weight,0.05,76.128,92.864,0.9646533330788418
+module.layer2.1.conv3.weight,0.1,76.122,92.864,0.9646252431735699
+module.layer2.1.conv3.weight,0.15000000000000002,76.12400000000001,92.86,0.9647012647925591
+module.layer2.1.conv3.weight,0.2,76.12400000000001,92.85799999999999,0.9647608776481782
+module.layer2.1.conv3.weight,0.25,76.132,92.85799999999999,0.9646967381847145
+module.layer2.1.conv3.weight,0.30000000000000004,76.13600000000001,92.866,0.9647041466467233
+module.layer2.1.conv3.weight,0.35000000000000003,76.128,92.864,0.964786077032284
+module.layer2.1.conv3.weight,0.4,76.12400000000001,92.888,0.9648655662123035
+module.layer2.1.conv3.weight,0.45,76.104,92.9,0.9648101361734531
+module.layer2.1.conv3.weight,0.5,76.07,92.9,0.9647237465393779
+module.layer2.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.2.conv1.weight,0.05,76.046,92.854,0.9665528654443971
+module.layer2.2.conv1.weight,0.1,75.964,92.862,0.9686528508912543
+module.layer2.2.conv1.weight,0.15000000000000002,75.848,92.784,0.9736106615437541
+module.layer2.2.conv1.weight,0.2,75.69399999999999,92.77,0.9788862500263723
+module.layer2.2.conv1.weight,0.25,75.512,92.67999999999999,0.9830762848866231
+module.layer2.2.conv1.weight,0.30000000000000004,75.354,92.656,0.9894499518737504
+module.layer2.2.conv1.weight,0.35000000000000003,75.17399999999999,92.51599999999999,0.9970121695374953
+module.layer2.2.conv1.weight,0.4,74.972,92.414,1.0066485434618535
+module.layer2.2.conv1.weight,0.45,74.85,92.34400000000001,1.012786984291612
+module.layer2.2.conv1.weight,0.5,74.642,92.20200000000001,1.0226188807615206
+module.layer2.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.2.conv2.weight,0.05,76.03800000000001,92.838,0.968614383011448
+module.layer2.2.conv2.weight,0.1,75.888,92.842,0.9717724581001971
+module.layer2.2.conv2.weight,0.15000000000000002,75.806,92.852,0.974677234659998
+module.layer2.2.conv2.weight,0.2,75.708,92.83200000000001,0.9765448045669768
+module.layer2.2.conv2.weight,0.25,75.66,92.776,0.9798674343952113
+module.layer2.2.conv2.weight,0.30000000000000004,75.386,92.686,0.9875597967481125
+module.layer2.2.conv2.weight,0.35000000000000003,75.322,92.64,0.9919184984601273
+module.layer2.2.conv2.weight,0.4,75.226,92.506,0.9981514553026278
+module.layer2.2.conv2.weight,0.45,75.102,92.49199999999999,1.002877431712589
+module.layer2.2.conv2.weight,0.5,74.874,92.306,1.0117304964485216
+module.layer2.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.2.conv3.weight,0.05,76.126,92.876,0.9645333172259287
+module.layer2.2.conv3.weight,0.1,76.078,92.862,0.9660904666751013
+module.layer2.2.conv3.weight,0.15000000000000002,76.044,92.838,0.966414061431982
+module.layer2.2.conv3.weight,0.2,76.018,92.80199999999999,0.9685548517320841
+module.layer2.2.conv3.weight,0.25,76.006,92.824,0.9688468615011292
+module.layer2.2.conv3.weight,0.30000000000000004,75.97,92.78,0.9721222430923765
+module.layer2.2.conv3.weight,0.35000000000000003,75.73,92.746,0.9787512461141661
+module.layer2.2.conv3.weight,0.4,75.66199999999999,92.732,0.9809944002451944
+module.layer2.2.conv3.weight,0.45,75.51,92.622,0.9858916249810431
+module.layer2.2.conv3.weight,0.5,75.442,92.572,0.9902838661658522
+module.layer2.3.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.3.conv1.weight,0.05,76.042,92.882,0.9676482852320281
+module.layer2.3.conv1.weight,0.1,75.998,92.884,0.9686084858008795
+module.layer2.3.conv1.weight,0.15000000000000002,75.852,92.83,0.9724806093287712
+module.layer2.3.conv1.weight,0.2,75.76,92.774,0.9753189640385767
+module.layer2.3.conv1.weight,0.25,75.538,92.702,0.9820158223868634
+module.layer2.3.conv1.weight,0.30000000000000004,75.414,92.704,0.9897886653031617
+module.layer2.3.conv1.weight,0.35000000000000003,75.03800000000001,92.51,1.0011441475730771
+module.layer2.3.conv1.weight,0.4,74.8,92.364,1.0166383479930918
+module.layer2.3.conv1.weight,0.45,74.402,92.148,1.034298639668494
+module.layer2.3.conv1.weight,0.5,73.746,91.718,1.064169044701421
+module.layer2.3.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.3.conv2.weight,0.05,76.00200000000001,92.86,0.9662534713897172
+module.layer2.3.conv2.weight,0.1,76.00200000000001,92.862,0.96831920879836
+module.layer2.3.conv2.weight,0.15000000000000002,75.946,92.84400000000001,0.9721023115728582
+module.layer2.3.conv2.weight,0.2,75.732,92.794,0.9761088758098834
+module.layer2.3.conv2.weight,0.25,75.614,92.73599999999999,0.981304814164736
+module.layer2.3.conv2.weight,0.30000000000000004,75.542,92.708,0.9833431544200503
+module.layer2.3.conv2.weight,0.35000000000000003,75.42,92.646,0.9895992186300612
+module.layer2.3.conv2.weight,0.4,75.166,92.434,1.0021974734809933
+module.layer2.3.conv2.weight,0.45,74.996,92.396,1.0080492892587674
+module.layer2.3.conv2.weight,0.5,74.832,92.30000000000001,1.0158192809595141
+module.layer2.3.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.3.conv3.weight,0.05,76.116,92.874,0.964743162448309
+module.layer2.3.conv3.weight,0.1,76.104,92.862,0.9648478018994233
+module.layer2.3.conv3.weight,0.15000000000000002,76.116,92.866,0.964911586852098
+module.layer2.3.conv3.weight,0.2,76.094,92.884,0.9648562783033265
+module.layer2.3.conv3.weight,0.25,76.094,92.85799999999999,0.9652072199601301
+module.layer2.3.conv3.weight,0.30000000000000004,76.028,92.84,0.9652869780452884
+module.layer2.3.conv3.weight,0.35000000000000003,76.072,92.894,0.9655525322471346
+module.layer2.3.conv3.weight,0.4,75.924,92.854,0.9668436814479684
+module.layer2.3.conv3.weight,0.45,75.85,92.794,0.9713384634530061
+module.layer2.3.conv3.weight,0.5,75.76,92.726,0.9754472542174012
+module.layer3.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.conv1.weight,0.05,75.854,92.7,0.9755539494965755
+module.layer3.0.conv1.weight,0.1,75.558,92.604,0.9903189021409771
+module.layer3.0.conv1.weight,0.15000000000000002,74.932,92.384,1.0120573165465376
+module.layer3.0.conv1.weight,0.2,73.97,91.902,1.0564837084741014
+module.layer3.0.conv1.weight,0.25,71.67999999999999,90.572,1.160918220497515
+module.layer3.0.conv1.weight,0.30000000000000004,70.22999999999999,89.724,1.2242342644200028
+module.layer3.0.conv1.weight,0.35000000000000003,67.42,87.63,1.3680912685029358
+module.layer3.0.conv1.weight,0.4,63.426,85.04,1.5690640997217635
+module.layer3.0.conv1.weight,0.45,62.834,84.516,1.6085462114032438
+module.layer3.0.conv1.weight,0.5,59.484,82.27,1.779903719619829
+module.layer3.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.conv2.weight,0.05,75.984,92.886,0.9716181799921453
+module.layer3.0.conv2.weight,0.1,75.67200000000001,92.73599999999999,0.9827592867065448
+module.layer3.0.conv2.weight,0.15000000000000002,75.4,92.578,0.9944861367040749
+module.layer3.0.conv2.weight,0.2,75.108,92.384,1.0120520112769948
+module.layer3.0.conv2.weight,0.25,74.46,92.052,1.0373907407023468
+module.layer3.0.conv2.weight,0.30000000000000004,74.05999999999999,91.838,1.0571384563737984
+module.layer3.0.conv2.weight,0.35000000000000003,72.89,91.12599999999999,1.1066943608528506
+module.layer3.0.conv2.weight,0.4,71.754,90.564,1.1546887049109358
+module.layer3.0.conv2.weight,0.45,69.98400000000001,89.476,1.2325848809310374
+module.layer3.0.conv2.weight,0.5,67.73599999999999,87.976,1.3425824202749195
+module.layer3.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.conv3.weight,0.05,76.128,92.86999999999999,0.9647133990514034
+module.layer3.0.conv3.weight,0.1,76.074,92.828,0.965464632121884
+module.layer3.0.conv3.weight,0.15000000000000002,76.058,92.84599999999999,0.9675507292303505
+module.layer3.0.conv3.weight,0.2,76.008,92.83,0.9695352081741606
+module.layer3.0.conv3.weight,0.25,75.882,92.814,0.9731980219331319
+module.layer3.0.conv3.weight,0.30000000000000004,75.698,92.758,0.9806116086487863
+module.layer3.0.conv3.weight,0.35000000000000003,75.524,92.666,0.9887452712472609
+module.layer3.0.conv3.weight,0.4,75.30399999999999,92.55,0.9961166965718176
+module.layer3.0.conv3.weight,0.45,74.752,92.188,1.017145576373655
+module.layer3.0.conv3.weight,0.5,74.166,91.93,1.0403414208974158
+module.layer3.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.downsample.0.weight,0.05,76.122,92.88,0.9646817222997849
+module.layer3.0.downsample.0.weight,0.1,76.09,92.868,0.9652682708538308
+module.layer3.0.downsample.0.weight,0.15000000000000002,76.02,92.83200000000001,0.9678591250308928
+module.layer3.0.downsample.0.weight,0.2,76.012,92.84599999999999,0.969518158356754
+module.layer3.0.downsample.0.weight,0.25,75.964,92.862,0.9715871051410021
+module.layer3.0.downsample.0.weight,0.30000000000000004,75.83,92.762,0.9757750733774535
+module.layer3.0.downsample.0.weight,0.35000000000000003,75.75800000000001,92.74,0.9812161533200012
+module.layer3.0.downsample.0.weight,0.4,75.568,92.652,0.988263431404318
+module.layer3.0.downsample.0.weight,0.45,75.35000000000001,92.60000000000001,0.9997599098299228
+module.layer3.0.downsample.0.weight,0.5,75.042,92.36800000000001,1.0125716319497748
+module.layer3.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.1.conv1.weight,0.05,76.026,92.84,0.9650524181826997
+module.layer3.1.conv1.weight,0.1,75.996,92.836,0.9674490037925388
+module.layer3.1.conv1.weight,0.15000000000000002,75.858,92.742,0.971184025309524
+module.layer3.1.conv1.weight,0.2,75.702,92.742,0.9757259215931502
+module.layer3.1.conv1.weight,0.25,75.622,92.67999999999999,0.9801329179685943
+module.layer3.1.conv1.weight,0.30000000000000004,75.486,92.56,0.9874053938048225
+module.layer3.1.conv1.weight,0.35000000000000003,75.42999999999999,92.562,0.9897024807121072
+module.layer3.1.conv1.weight,0.4,75.322,92.45,0.99729801554765
+module.layer3.1.conv1.weight,0.45,75.018,92.35600000000001,1.0083161153051319
+module.layer3.1.conv1.weight,0.5,74.71799999999999,92.12400000000001,1.0199289657175545
+module.layer3.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.1.conv2.weight,0.05,75.982,92.854,0.9691590736715163
+module.layer3.1.conv2.weight,0.1,75.818,92.686,0.9758419072901713
+module.layer3.1.conv2.weight,0.15000000000000002,75.742,92.58800000000001,0.9815147864271184
+module.layer3.1.conv2.weight,0.2,75.578,92.58,0.9874765161348843
+module.layer3.1.conv2.weight,0.25,75.47,92.462,0.994418795619692
+module.layer3.1.conv2.weight,0.30000000000000004,75.224,92.328,0.9989697071058415
+module.layer3.1.conv2.weight,0.35000000000000003,75.18199999999999,92.374,1.000461941394879
+module.layer3.1.conv2.weight,0.4,74.89399999999999,92.244,1.0080392082430885
+module.layer3.1.conv2.weight,0.45,74.602,92.148,1.0224535768585545
+module.layer3.1.conv2.weight,0.5,74.366,92.00200000000001,1.0328381399110875
+module.layer3.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.1.conv3.weight,0.05,76.062,92.864,0.9646131991579829
+module.layer3.1.conv3.weight,0.1,76.028,92.86,0.9663307377118234
+module.layer3.1.conv3.weight,0.15000000000000002,76.03999999999999,92.86,0.967569736953901
+module.layer3.1.conv3.weight,0.2,75.948,92.814,0.9681062712048996
+module.layer3.1.conv3.weight,0.25,75.944,92.878,0.9703570115475021
+module.layer3.1.conv3.weight,0.30000000000000004,75.896,92.816,0.9743181323366505
+module.layer3.1.conv3.weight,0.35000000000000003,75.83800000000001,92.778,0.9771007541186953
+module.layer3.1.conv3.weight,0.4,75.732,92.752,0.9815992211200752
+module.layer3.1.conv3.weight,0.45,75.684,92.666,0.985187043492891
+module.layer3.1.conv3.weight,0.5,75.512,92.598,0.989619911401247
+module.layer3.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.2.conv1.weight,0.05,76.112,92.902,0.9664402683170475
+module.layer3.2.conv1.weight,0.1,76.012,92.84,0.9691802668480238
+module.layer3.2.conv1.weight,0.15000000000000002,75.878,92.80000000000001,0.9741810408173774
+module.layer3.2.conv1.weight,0.2,75.888,92.75999999999999,0.9778152308141698
+module.layer3.2.conv1.weight,0.25,75.726,92.694,0.985200674025988
+module.layer3.2.conv1.weight,0.30000000000000004,75.602,92.638,0.9895688426889937
+module.layer3.2.conv1.weight,0.35000000000000003,75.384,92.57600000000001,0.9999833124480682
+module.layer3.2.conv1.weight,0.4,75.126,92.45,1.0099009264792718
+module.layer3.2.conv1.weight,0.45,74.91799999999999,92.374,1.016800964213147
+module.layer3.2.conv1.weight,0.5,74.784,92.27,1.026382406147159
+module.layer3.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.2.conv2.weight,0.05,76.05999999999999,92.874,0.9645183723495929
+module.layer3.2.conv2.weight,0.1,75.954,92.85,0.9664772243372033
+module.layer3.2.conv2.weight,0.15000000000000002,75.89200000000001,92.824,0.9708176260547979
+module.layer3.2.conv2.weight,0.2,75.848,92.768,0.9740906326594406
+module.layer3.2.conv2.weight,0.25,75.74,92.688,0.9793189029608457
+module.layer3.2.conv2.weight,0.30000000000000004,75.486,92.628,0.9856890367001899
+module.layer3.2.conv2.weight,0.35000000000000003,75.30399999999999,92.55799999999999,0.9944385276461133
+module.layer3.2.conv2.weight,0.4,75.11,92.432,1.0005023837545692
+module.layer3.2.conv2.weight,0.45,74.88,92.366,1.007985245816562
+module.layer3.2.conv2.weight,0.5,74.57400000000001,92.142,1.0231264646412168
+module.layer3.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.2.conv3.weight,0.05,76.076,92.86999999999999,0.9658319840625846
+module.layer3.2.conv3.weight,0.1,76.058,92.898,0.9680370671712619
+module.layer3.2.conv3.weight,0.15000000000000002,76.064,92.876,0.9695597740308358
+module.layer3.2.conv3.weight,0.2,76.01,92.876,0.9713810268713504
+module.layer3.2.conv3.weight,0.25,75.966,92.81200000000001,0.9739112689026764
+module.layer3.2.conv3.weight,0.30000000000000004,75.856,92.80000000000001,0.9776264067967332
+module.layer3.2.conv3.weight,0.35000000000000003,75.77000000000001,92.782,0.9793802074023656
+module.layer3.2.conv3.weight,0.4,75.704,92.73599999999999,0.9823789858088203
+module.layer3.2.conv3.weight,0.45,75.628,92.69800000000001,0.9874243092324051
+module.layer3.2.conv3.weight,0.5,75.47,92.584,0.9941450587036659
+module.layer3.3.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.3.conv1.weight,0.05,76.05199999999999,92.854,0.9672309211170188
+module.layer3.3.conv1.weight,0.1,76.0,92.862,0.971594920039785
+module.layer3.3.conv1.weight,0.15000000000000002,75.858,92.84599999999999,0.9759238329620991
+module.layer3.3.conv1.weight,0.2,75.788,92.794,0.9799342356926326
+module.layer3.3.conv1.weight,0.25,75.664,92.71199999999999,0.9868303231742916
+module.layer3.3.conv1.weight,0.30000000000000004,75.44800000000001,92.566,0.997164914802629
+module.layer3.3.conv1.weight,0.35000000000000003,75.168,92.44200000000001,1.0106435110982583
+module.layer3.3.conv1.weight,0.4,74.988,92.306,1.0185190551743215
+module.layer3.3.conv1.weight,0.45,74.83999999999999,92.188,1.0264643902073107
+module.layer3.3.conv1.weight,0.5,74.7,92.08800000000001,1.0364262275397775
+module.layer3.3.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.3.conv2.weight,0.05,76.022,92.808,0.9676625841886412
+module.layer3.3.conv2.weight,0.1,75.82,92.792,0.9736834481662636
+module.layer3.3.conv2.weight,0.15000000000000002,75.736,92.726,0.9789236883575819
+module.layer3.3.conv2.weight,0.2,75.464,92.58999999999999,0.9897794537246225
+module.layer3.3.conv2.weight,0.25,75.318,92.57,0.9949841195223281
+module.layer3.3.conv2.weight,0.30000000000000004,75.198,92.464,1.0017243420743207
+module.layer3.3.conv2.weight,0.35000000000000003,75.006,92.412,1.0079616335581758
+module.layer3.3.conv2.weight,0.4,74.856,92.294,1.0192331128886765
+module.layer3.3.conv2.weight,0.45,74.52799999999999,92.14,1.030934373258936
+module.layer3.3.conv2.weight,0.5,74.278,91.946,1.04510385771187
+module.layer3.3.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.3.conv3.weight,0.05,76.126,92.86,0.9650276996651469
+module.layer3.3.conv3.weight,0.1,76.098,92.842,0.9651636611290126
+module.layer3.3.conv3.weight,0.15000000000000002,76.05,92.86999999999999,0.9664237293205701
+module.layer3.3.conv3.weight,0.2,76.054,92.876,0.968392280595643
+module.layer3.3.conv3.weight,0.25,76.012,92.842,0.970424373843232
+module.layer3.3.conv3.weight,0.30000000000000004,75.928,92.822,0.9723988493182223
+module.layer3.3.conv3.weight,0.35000000000000003,75.868,92.806,0.9749150700411023
+module.layer3.3.conv3.weight,0.4,75.87,92.75999999999999,0.9775160366327179
+module.layer3.3.conv3.weight,0.45,75.784,92.768,0.9805781351668494
+module.layer3.3.conv3.weight,0.5,75.79599999999999,92.72399999999999,0.9821375484521293
+module.layer3.4.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.4.conv1.weight,0.05,75.958,92.852,0.9707729078981338
+module.layer3.4.conv1.weight,0.1,75.926,92.842,0.9745264629624331
+module.layer3.4.conv1.weight,0.15000000000000002,75.79,92.73599999999999,0.9820926417501608
+module.layer3.4.conv1.weight,0.2,75.506,92.662,0.990671943797141
+module.layer3.4.conv1.weight,0.25,75.362,92.56,0.9998092267434208
+module.layer3.4.conv1.weight,0.30000000000000004,75.252,92.464,1.0118998679883624
+module.layer3.4.conv1.weight,0.35000000000000003,74.91,92.308,1.0283412074252054
+module.layer3.4.conv1.weight,0.4,74.634,92.11399999999999,1.0410314060899681
+module.layer3.4.conv1.weight,0.45,74.39,91.91,1.0556378165373999
+module.layer3.4.conv1.weight,0.5,74.16,91.804,1.0647825668660966
+module.layer3.4.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.4.conv2.weight,0.05,76.028,92.826,0.9699842729890834
+module.layer3.4.conv2.weight,0.1,75.95,92.828,0.9719456350620915
+module.layer3.4.conv2.weight,0.15000000000000002,75.708,92.726,0.9797904483061665
+module.layer3.4.conv2.weight,0.2,75.646,92.662,0.9835705510055528
+module.layer3.4.conv2.weight,0.25,75.38,92.514,0.9943041349095958
+module.layer3.4.conv2.weight,0.30000000000000004,75.136,92.384,1.0046290529473703
+module.layer3.4.conv2.weight,0.35000000000000003,75.008,92.30000000000001,1.0104411206379234
+module.layer3.4.conv2.weight,0.4,74.76400000000001,92.12400000000001,1.0218490987864075
+module.layer3.4.conv2.weight,0.45,74.366,91.962,1.0363766772254388
+module.layer3.4.conv2.weight,0.5,73.676,91.602,1.0667164083190113
+module.layer3.4.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.4.conv3.weight,0.05,76.106,92.872,0.964855870574104
+module.layer3.4.conv3.weight,0.1,76.05,92.876,0.9654396832445448
+module.layer3.4.conv3.weight,0.15000000000000002,75.956,92.852,0.968098088338667
+module.layer3.4.conv3.weight,0.2,75.91799999999999,92.84599999999999,0.9696684817270358
+module.layer3.4.conv3.weight,0.25,75.928,92.85600000000001,0.9708511539715896
+module.layer3.4.conv3.weight,0.30000000000000004,75.856,92.854,0.972894585710399
+module.layer3.4.conv3.weight,0.35000000000000003,75.86,92.85799999999999,0.9759889332463548
+module.layer3.4.conv3.weight,0.4,75.78,92.784,0.9785844093682812
+module.layer3.4.conv3.weight,0.45,75.682,92.754,0.982491150194285
+module.layer3.4.conv3.weight,0.5,75.606,92.682,0.9880777105536995
+module.layer3.5.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.5.conv1.weight,0.05,75.926,92.798,0.9710209958103239
+module.layer3.5.conv1.weight,0.1,75.68599999999999,92.706,0.9788274871451516
+module.layer3.5.conv1.weight,0.15000000000000002,75.596,92.646,0.9845531354753339
+module.layer3.5.conv1.weight,0.2,75.424,92.538,0.9935152600614392
+module.layer3.5.conv1.weight,0.25,75.18199999999999,92.34599999999999,1.0075640448806236
+module.layer3.5.conv1.weight,0.30000000000000004,74.83800000000001,92.2,1.021532102506987
+module.layer3.5.conv1.weight,0.35000000000000003,74.514,92.006,1.0346408500811277
+module.layer3.5.conv1.weight,0.4,74.33000000000001,91.89,1.0447048462015025
+module.layer3.5.conv1.weight,0.45,74.146,91.768,1.053496342076331
+module.layer3.5.conv1.weight,0.5,73.646,91.52799999999999,1.0740979883287631
+module.layer3.5.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.5.conv2.weight,0.05,75.998,92.918,0.9653302145247556
+module.layer3.5.conv2.weight,0.1,75.848,92.826,0.9723002608029214
+module.layer3.5.conv2.weight,0.15000000000000002,75.698,92.75999999999999,0.9797882666545261
+module.layer3.5.conv2.weight,0.2,75.28200000000001,92.61200000000001,0.991536684030173
+module.layer3.5.conv2.weight,0.25,75.168,92.572,0.9951652087727368
+module.layer3.5.conv2.weight,0.30000000000000004,75.078,92.472,1.0034233816728295
+module.layer3.5.conv2.weight,0.35000000000000003,74.83800000000001,92.384,1.013210628865933
+module.layer3.5.conv2.weight,0.4,74.462,92.20200000000001,1.0288722234569039
+module.layer3.5.conv2.weight,0.45,74.07000000000001,91.948,1.0470701186936726
+module.layer3.5.conv2.weight,0.5,73.854,91.796,1.0605612496028145
+module.layer3.5.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.5.conv3.weight,0.05,76.044,92.848,0.9663973350305944
+module.layer3.5.conv3.weight,0.1,75.978,92.84400000000001,0.9683438963731941
+module.layer3.5.conv3.weight,0.15000000000000002,75.96199999999999,92.836,0.9695205550108633
+module.layer3.5.conv3.weight,0.2,75.81,92.78,0.9731556126961902
+module.layer3.5.conv3.weight,0.25,75.77000000000001,92.78,0.9747894256546789
+module.layer3.5.conv3.weight,0.30000000000000004,75.78,92.73599999999999,0.9779086081045018
+module.layer3.5.conv3.weight,0.35000000000000003,75.682,92.666,0.9823050108947311
+module.layer3.5.conv3.weight,0.4,75.624,92.64,0.9860117449900325
+module.layer3.5.conv3.weight,0.45,75.506,92.58999999999999,0.9931091208543096
+module.layer3.5.conv3.weight,0.5,75.252,92.494,1.001836841887966
+module.layer4.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.conv1.weight,0.05,75.452,92.72200000000001,0.980977897969436
+module.layer4.0.conv1.weight,0.1,75.06200000000001,92.476,0.99777424723214
+module.layer4.0.conv1.weight,0.15000000000000002,74.49600000000001,92.296,1.0166220766093048
+module.layer4.0.conv1.weight,0.2,73.66,91.852,1.0479011927940405
+module.layer4.0.conv1.weight,0.25,72.994,91.36999999999999,1.0818585154353357
+module.layer4.0.conv1.weight,0.30000000000000004,71.96199999999999,90.944,1.122837725342537
+module.layer4.0.conv1.weight,0.35000000000000003,71.224,90.57,1.1563155524888813
+module.layer4.0.conv1.weight,0.4,69.818,89.60000000000001,1.2258589152170691
+module.layer4.0.conv1.weight,0.45,68.00399999999999,88.592,1.3101937196084439
+module.layer4.0.conv1.weight,0.5,65.932,87.21799999999999,1.4147236747096998
+module.layer4.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.conv2.weight,0.05,75.824,92.766,0.9820537386196001
+module.layer4.0.conv2.weight,0.1,75.442,92.582,0.9957953673236224
+module.layer4.0.conv2.weight,0.15000000000000002,75.17,92.49199999999999,1.005282780953816
+module.layer4.0.conv2.weight,0.2,74.852,92.304,1.0212704213146044
+module.layer4.0.conv2.weight,0.25,74.292,91.964,1.0414224162089585
+module.layer4.0.conv2.weight,0.30000000000000004,73.604,91.634,1.0696578053187356
+module.layer4.0.conv2.weight,0.35000000000000003,72.89999999999999,91.208,1.0999143880848987
+module.layer4.0.conv2.weight,0.4,72.184,90.8,1.1301327659463394
+module.layer4.0.conv2.weight,0.45,71.114,90.238,1.173654002346554
+module.layer4.0.conv2.weight,0.5,69.95599999999999,89.628,1.222327757398693
+module.layer4.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.conv3.weight,0.05,76.068,92.872,0.974233268703125
+module.layer4.0.conv3.weight,0.1,75.902,92.764,0.9797552155748923
+module.layer4.0.conv3.weight,0.15000000000000002,75.80799999999999,92.738,0.9865577943166911
+module.layer4.0.conv3.weight,0.2,75.632,92.642,0.9941004342114437
+module.layer4.0.conv3.weight,0.25,75.39,92.508,1.0038409359296971
+module.layer4.0.conv3.weight,0.30000000000000004,75.114,92.41,1.0120449089730275
+module.layer4.0.conv3.weight,0.35000000000000003,74.892,92.244,1.023147908704622
+module.layer4.0.conv3.weight,0.4,74.56,92.092,1.0372554889442975
+module.layer4.0.conv3.weight,0.45,74.076,91.852,1.0547336231993174
+module.layer4.0.conv3.weight,0.5,73.584,91.602,1.077762626719718
+module.layer4.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.downsample.0.weight,0.05,76.012,92.872,0.9663943356397199
+module.layer4.0.downsample.0.weight,0.1,75.99000000000001,92.866,0.9696360729938869
+module.layer4.0.downsample.0.weight,0.15000000000000002,75.91,92.84599999999999,0.9735238836431992
+module.layer4.0.downsample.0.weight,0.2,75.782,92.81,0.975955544001594
+module.layer4.0.downsample.0.weight,0.25,75.71799999999999,92.726,0.9828369682844804
+module.layer4.0.downsample.0.weight,0.30000000000000004,75.58,92.704,0.9846665353647297
+module.layer4.0.downsample.0.weight,0.35000000000000003,75.55199999999999,92.65,0.9895600051600107
+module.layer4.0.downsample.0.weight,0.4,75.38199999999999,92.586,0.9927099836724144
+module.layer4.0.downsample.0.weight,0.45,75.12,92.534,0.9976098318489229
+module.layer4.0.downsample.0.weight,0.5,74.932,92.428,1.0028844749440953
+module.layer4.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.1.conv1.weight,0.05,75.332,92.654,0.98502134539339
+module.layer4.1.conv1.weight,0.1,75.018,92.47800000000001,1.0129908440368516
+module.layer4.1.conv1.weight,0.15000000000000002,74.4,92.008,1.0427387130199646
+module.layer4.1.conv1.weight,0.2,73.52799999999999,91.636,1.0739348117186098
+module.layer4.1.conv1.weight,0.25,72.952,91.32000000000001,1.103592288889447
+module.layer4.1.conv1.weight,0.30000000000000004,72.392,90.914,1.1367468675788572
+module.layer4.1.conv1.weight,0.35000000000000003,71.892,90.664,1.1603467205957496
+module.layer4.1.conv1.weight,0.4,71.218,90.34400000000001,1.1915501374371191
+module.layer4.1.conv1.weight,0.45,69.674,89.46600000000001,1.2544893468825187
+module.layer4.1.conv1.weight,0.5,68.68799999999999,88.822,1.2998564492683022
+module.layer4.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.1.conv2.weight,0.05,75.53,92.622,0.9798856420176366
+module.layer4.1.conv2.weight,0.1,74.52600000000001,92.082,1.0328085240052662
+module.layer4.1.conv2.weight,0.15000000000000002,73.35199999999999,91.58800000000001,1.1326180742103231
+module.layer4.1.conv2.weight,0.2,72.074,90.818,1.2893840082141816
+module.layer4.1.conv2.weight,0.25,70.956,90.19200000000001,1.4188573299622047
+module.layer4.1.conv2.weight,0.30000000000000004,68.238,88.714,1.7882773164583712
+module.layer4.1.conv2.weight,0.35000000000000003,67.46600000000001,88.27000000000001,1.970356911420823
+module.layer4.1.conv2.weight,0.4,66.98200000000001,87.858,2.050041017483691
+module.layer4.1.conv2.weight,0.45,64.67,86.59,2.2327518517873735
+module.layer4.1.conv2.weight,0.5,62.056,84.88,2.648381370062732
+module.layer4.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.1.conv3.weight,0.05,76.074,92.872,0.9659452228521812
+module.layer4.1.conv3.weight,0.1,75.988,92.842,0.967287575940088
+module.layer4.1.conv3.weight,0.15000000000000002,75.934,92.772,0.9751768792618293
+module.layer4.1.conv3.weight,0.2,75.866,92.73400000000001,0.9757551226232728
+module.layer4.1.conv3.weight,0.25,75.8,92.686,0.981689058548334
+module.layer4.1.conv3.weight,0.30000000000000004,75.628,92.642,0.9852055123417963
+module.layer4.1.conv3.weight,0.35000000000000003,75.49,92.536,0.9915929486100771
+module.layer4.1.conv3.weight,0.4,75.272,92.498,1.0091029061954848
+module.layer4.1.conv3.weight,0.45,75.028,92.34400000000001,1.022709506050665
+module.layer4.1.conv3.weight,0.5,74.638,92.148,1.0381865026999488
+module.layer4.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.2.conv1.weight,0.05,75.846,92.826,0.9717933706635113
+module.layer4.2.conv1.weight,0.1,75.276,92.472,0.9955067696163846
+module.layer4.2.conv1.weight,0.15000000000000002,74.656,92.20200000000001,1.0159382923525204
+module.layer4.2.conv1.weight,0.2,74.12,91.932,1.0420076324775516
+module.layer4.2.conv1.weight,0.25,73.394,91.662,1.08965770018344
+module.layer4.2.conv1.weight,0.30000000000000004,72.66,91.33200000000001,1.1267745362556716
+module.layer4.2.conv1.weight,0.35000000000000003,72.174,91.024,1.1684093069361183
+module.layer4.2.conv1.weight,0.4,71.224,90.53,1.239526554181868
+module.layer4.2.conv1.weight,0.45,70.39800000000001,89.862,1.310148460828528
+module.layer4.2.conv1.weight,0.5,69.206,89.23400000000001,1.4035242774656844
+module.layer4.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.2.conv2.weight,0.05,75.828,92.818,0.9695281652467589
+module.layer4.2.conv2.weight,0.1,75.676,92.702,0.975221398989765
+module.layer4.2.conv2.weight,0.15000000000000002,75.462,92.52,0.9864033583019457
+module.layer4.2.conv2.weight,0.2,74.98400000000001,92.362,1.0007667366643338
+module.layer4.2.conv2.weight,0.25,74.54799999999999,92.176,1.0197215826839818
+module.layer4.2.conv2.weight,0.30000000000000004,73.988,92.006,1.0427208835525168
+module.layer4.2.conv2.weight,0.35000000000000003,73.628,91.768,1.0686493072734802
+module.layer4.2.conv2.weight,0.4,73.262,91.602,1.097705980800852
+module.layer4.2.conv2.weight,0.45,72.57000000000001,91.266,1.1520085818305306
+module.layer4.2.conv2.weight,0.5,71.44800000000001,90.73400000000001,1.2216127002421697
+module.layer4.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.2.conv3.weight,0.05,76.088,92.882,0.9627235708188039
+module.layer4.2.conv3.weight,0.1,76.062,92.868,0.9608848703911115
+module.layer4.2.conv3.weight,0.15000000000000002,76.032,92.848,0.9586635404551517
+module.layer4.2.conv3.weight,0.2,75.888,92.826,0.9583521980260096
+module.layer4.2.conv3.weight,0.25,75.874,92.794,0.9590057633361035
+module.layer4.2.conv3.weight,0.30000000000000004,75.71799999999999,92.752,0.9618615930025672
+module.layer4.2.conv3.weight,0.35000000000000003,75.598,92.702,0.9675871675111806
+module.layer4.2.conv3.weight,0.4,75.422,92.648,0.9774525021107828
+module.layer4.2.conv3.weight,0.45,75.216,92.584,0.9908461873324553
+module.layer4.2.conv3.weight,0.5,74.978,92.554,1.0100993886590002
diff --git a/jupyter/imagenet_classes.py b/jupyter/imagenet_classes.py
new file mode 100755
index 0000000000000000000000000000000000000000..d1866650ddc3bc72830cc2d4ad26712718ec8877
--- /dev/null
+++ b/jupyter/imagenet_classes.py
@@ -0,0 +1,1003 @@
+# https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a#file-imagenet1000_clsid_to_human-txt
+
+imagenet_classes = {
+ 0: 'tench, Tinca tinca',
+ 1: 'goldfish, Carassius auratus',
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
+ 3: 'tiger shark, Galeocerdo cuvieri',
+ 4: 'hammerhead, hammerhead shark',
+ 5: 'electric ray, crampfish, numbfish, torpedo',
+ 6: 'stingray',
+ 7: 'cock',
+ 8: 'hen',
+ 9: 'ostrich, Struthio camelus',
+ 10: 'brambling, Fringilla montifringilla',
+ 11: 'goldfinch, Carduelis carduelis',
+ 12: 'house finch, linnet, Carpodacus mexicanus',
+ 13: 'junco, snowbird',
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
+ 15: 'robin, American robin, Turdus migratorius',
+ 16: 'bulbul',
+ 17: 'jay',
+ 18: 'magpie',
+ 19: 'chickadee',
+ 20: 'water ouzel, dipper',
+ 21: 'kite',
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
+ 23: 'vulture',
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
+ 25: 'European fire salamander, Salamandra salamandra',
+ 26: 'common newt, Triturus vulgaris',
+ 27: 'eft',
+ 28: 'spotted salamander, Ambystoma maculatum',
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
+ 30: 'bullfrog, Rana catesbeiana',
+ 31: 'tree frog, tree-frog',
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
+ 35: 'mud turtle',
+ 36: 'terrapin',
+ 37: 'box turtle, box tortoise',
+ 38: 'banded gecko',
+ 39: 'common iguana, iguana, Iguana iguana',
+ 40: 'American chameleon, anole, Anolis carolinensis',
+ 41: 'whiptail, whiptail lizard',
+ 42: 'agama',
+ 43: 'frilled lizard, Chlamydosaurus kingi',
+ 44: 'alligator lizard',
+ 45: 'Gila monster, Heloderma suspectum',
+ 46: 'green lizard, Lacerta viridis',
+ 47: 'African chameleon, Chamaeleo chamaeleon',
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
+ 50: 'American alligator, Alligator mississipiensis',
+ 51: 'triceratops',
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
+ 53: 'ringneck snake, ring-necked snake, ring snake',
+ 54: 'hognose snake, puff adder, sand viper',
+ 55: 'green snake, grass snake',
+ 56: 'king snake, kingsnake',
+ 57: 'garter snake, grass snake',
+ 58: 'water snake',
+ 59: 'vine snake',
+ 60: 'night snake, Hypsiglena torquata',
+ 61: 'boa constrictor, Constrictor constrictor',
+ 62: 'rock python, rock snake, Python sebae',
+ 63: 'Indian cobra, Naja naja',
+ 64: 'green mamba',
+ 65: 'sea snake',
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
+ 69: 'trilobite',
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
+ 71: 'scorpion',
+ 72: 'black and gold garden spider, Argiope aurantia',
+ 73: 'barn spider, Araneus cavaticus',
+ 74: 'garden spider, Aranea diademata',
+ 75: 'black widow, Latrodectus mactans',
+ 76: 'tarantula',
+ 77: 'wolf spider, hunting spider',
+ 78: 'tick',
+ 79: 'centipede',
+ 80: 'black grouse',
+ 81: 'ptarmigan',
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
+ 84: 'peacock',
+ 85: 'quail',
+ 86: 'partridge',
+ 87: 'African grey, African gray, Psittacus erithacus',
+ 88: 'macaw',
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
+ 90: 'lorikeet',
+ 91: 'coucal',
+ 92: 'bee eater',
+ 93: 'hornbill',
+ 94: 'hummingbird',
+ 95: 'jacamar',
+ 96: 'toucan',
+ 97: 'drake',
+ 98: 'red-breasted merganser, Mergus serrator',
+ 99: 'goose',
+ 100: 'black swan, Cygnus atratus',
+ 101: 'tusker',
+ 102: 'echidna, spiny anteater, anteater',
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
+ 104: 'wallaby, brush kangaroo',
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
+ 106: 'wombat',
+ 107: 'jellyfish',
+ 108: 'sea anemone, anemone',
+ 109: 'brain coral',
+ 110: 'flatworm, platyhelminth',
+ 111: 'nematode, nematode worm, roundworm',
+ 112: 'conch',
+ 113: 'snail',
+ 114: 'slug',
+ 115: 'sea slug, nudibranch',
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
+ 118: 'Dungeness crab, Cancer magister',
+ 119: 'rock crab, Cancer irroratus',
+ 120: 'fiddler crab',
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
+ 125: 'hermit crab',
+ 126: 'isopod',
+ 127: 'white stork, Ciconia ciconia',
+ 128: 'black stork, Ciconia nigra',
+ 129: 'spoonbill',
+ 130: 'flamingo',
+ 131: 'little blue heron, Egretta caerulea',
+ 132: 'American egret, great white heron, Egretta albus',
+ 133: 'bittern',
+ 134: 'crane',
+ 135: 'limpkin, Aramus pictus',
+ 136: 'European gallinule, Porphyrio porphyrio',
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
+ 138: 'bustard',
+ 139: 'ruddy turnstone, Arenaria interpres',
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
+ 141: 'redshank, Tringa totanus',
+ 142: 'dowitcher',
+ 143: 'oystercatcher, oyster catcher',
+ 144: 'pelican',
+ 145: 'king penguin, Aptenodytes patagonica',
+ 146: 'albatross, mollymawk',
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
+ 149: 'dugong, Dugong dugon',
+ 150: 'sea lion',
+ 151: 'Chihuahua',
+ 152: 'Japanese spaniel',
+ 153: 'Maltese dog, Maltese terrier, Maltese',
+ 154: 'Pekinese, Pekingese, Peke',
+ 155: 'Shih-Tzu',
+ 156: 'Blenheim spaniel',
+ 157: 'papillon',
+ 158: 'toy terrier',
+ 159: 'Rhodesian ridgeback',
+ 160: 'Afghan hound, Afghan',
+ 161: 'basset, basset hound',
+ 162: 'beagle',
+ 163: 'bloodhound, sleuthhound',
+ 164: 'bluetick',
+ 165: 'black-and-tan coonhound',
+ 166: 'Walker hound, Walker foxhound',
+ 167: 'English foxhound',
+ 168: 'redbone',
+ 169: 'borzoi, Russian wolfhound',
+ 170: 'Irish wolfhound',
+ 171: 'Italian greyhound',
+ 172: 'whippet',
+ 173: 'Ibizan hound, Ibizan Podenco',
+ 174: 'Norwegian elkhound, elkhound',
+ 175: 'otterhound, otter hound',
+ 176: 'Saluki, gazelle hound',
+ 177: 'Scottish deerhound, deerhound',
+ 178: 'Weimaraner',
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
+ 181: 'Bedlington terrier',
+ 182: 'Border terrier',
+ 183: 'Kerry blue terrier',
+ 184: 'Irish terrier',
+ 185: 'Norfolk terrier',
+ 186: 'Norwich terrier',
+ 187: 'Yorkshire terrier',
+ 188: 'wire-haired fox terrier',
+ 189: 'Lakeland terrier',
+ 190: 'Sealyham terrier, Sealyham',
+ 191: 'Airedale, Airedale terrier',
+ 192: 'cairn, cairn terrier',
+ 193: 'Australian terrier',
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
+ 195: 'Boston bull, Boston terrier',
+ 196: 'miniature schnauzer',
+ 197: 'giant schnauzer',
+ 198: 'standard schnauzer',
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
+ 200: 'Tibetan terrier, chrysanthemum dog',
+ 201: 'silky terrier, Sydney silky',
+ 202: 'soft-coated wheaten terrier',
+ 203: 'West Highland white terrier',
+ 204: 'Lhasa, Lhasa apso',
+ 205: 'flat-coated retriever',
+ 206: 'curly-coated retriever',
+ 207: 'golden retriever',
+ 208: 'Labrador retriever',
+ 209: 'Chesapeake Bay retriever',
+ 210: 'German short-haired pointer',
+ 211: 'vizsla, Hungarian pointer',
+ 212: 'English setter',
+ 213: 'Irish setter, red setter',
+ 214: 'Gordon setter',
+ 215: 'Brittany spaniel',
+ 216: 'clumber, clumber spaniel',
+ 217: 'English springer, English springer spaniel',
+ 218: 'Welsh springer spaniel',
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
+ 220: 'Sussex spaniel',
+ 221: 'Irish water spaniel',
+ 222: 'kuvasz',
+ 223: 'schipperke',
+ 224: 'groenendael',
+ 225: 'malinois',
+ 226: 'briard',
+ 227: 'kelpie',
+ 228: 'komondor',
+ 229: 'Old English sheepdog, bobtail',
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
+ 231: 'collie',
+ 232: 'Border collie',
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
+ 234: 'Rottweiler',
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
+ 236: 'Doberman, Doberman pinscher',
+ 237: 'miniature pinscher',
+ 238: 'Greater Swiss Mountain dog',
+ 239: 'Bernese mountain dog',
+ 240: 'Appenzeller',
+ 241: 'EntleBucher',
+ 242: 'boxer',
+ 243: 'bull mastiff',
+ 244: 'Tibetan mastiff',
+ 245: 'French bulldog',
+ 246: 'Great Dane',
+ 247: 'Saint Bernard, St Bernard',
+ 248: 'Eskimo dog, husky',
+ 249: 'malamute, malemute, Alaskan malamute',
+ 250: 'Siberian husky',
+ 251: 'dalmatian, coach dog, carriage dog',
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
+ 253: 'basenji',
+ 254: 'pug, pug-dog',
+ 255: 'Leonberg',
+ 256: 'Newfoundland, Newfoundland dog',
+ 257: 'Great Pyrenees',
+ 258: 'Samoyed, Samoyede',
+ 259: 'Pomeranian',
+ 260: 'chow, chow chow',
+ 261: 'keeshond',
+ 262: 'Brabancon griffon',
+ 263: 'Pembroke, Pembroke Welsh corgi',
+ 264: 'Cardigan, Cardigan Welsh corgi',
+ 265: 'toy poodle',
+ 266: 'miniature poodle',
+ 267: 'standard poodle',
+ 268: 'Mexican hairless',
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
+ 273: 'dingo, warrigal, warragal, Canis dingo',
+ 274: 'dhole, Cuon alpinus',
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
+ 276: 'hyena, hyaena',
+ 277: 'red fox, Vulpes vulpes',
+ 278: 'kit fox, Vulpes macrotis',
+ 279: 'Arctic fox, white fox, Alopex lagopus',
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
+ 281: 'tabby, tabby cat',
+ 282: 'tiger cat',
+ 283: 'Persian cat',
+ 284: 'Siamese cat, Siamese',
+ 285: 'Egyptian cat',
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
+ 287: 'lynx, catamount',
+ 288: 'leopard, Panthera pardus',
+ 289: 'snow leopard, ounce, Panthera uncia',
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
+ 291: 'lion, king of beasts, Panthera leo',
+ 292: 'tiger, Panthera tigris',
+ 293: 'cheetah, chetah, Acinonyx jubatus',
+ 294: 'brown bear, bruin, Ursus arctos',
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
+ 298: 'mongoose',
+ 299: 'meerkat, mierkat',
+ 300: 'tiger beetle',
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
+ 302: 'ground beetle, carabid beetle',
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
+ 304: 'leaf beetle, chrysomelid',
+ 305: 'dung beetle',
+ 306: 'rhinoceros beetle',
+ 307: 'weevil',
+ 308: 'fly',
+ 309: 'bee',
+ 310: 'ant, emmet, pismire',
+ 311: 'grasshopper, hopper',
+ 312: 'cricket',
+ 313: 'walking stick, walkingstick, stick insect',
+ 314: 'cockroach, roach',
+ 315: 'mantis, mantid',
+ 316: 'cicada, cicala',
+ 317: 'leafhopper',
+ 318: 'lacewing, lacewing fly',
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ 320: 'damselfly',
+ 321: 'admiral',
+ 322: 'ringlet, ringlet butterfly',
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
+ 324: 'cabbage butterfly',
+ 325: 'sulphur butterfly, sulfur butterfly',
+ 326: 'lycaenid, lycaenid butterfly',
+ 327: 'starfish, sea star',
+ 328: 'sea urchin',
+ 329: 'sea cucumber, holothurian',
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
+ 331: 'hare',
+ 332: 'Angora, Angora rabbit',
+ 333: 'hamster',
+ 334: 'porcupine, hedgehog',
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
+ 336: 'marmot',
+ 337: 'beaver',
+ 338: 'guinea pig, Cavia cobaya',
+ 339: 'sorrel',
+ 340: 'zebra',
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
+ 342: 'wild boar, boar, Sus scrofa',
+ 343: 'warthog',
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
+ 345: 'ox',
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
+ 347: 'bison',
+ 348: 'ram, tup',
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
+ 350: 'ibex, Capra ibex',
+ 351: 'hartebeest',
+ 352: 'impala, Aepyceros melampus',
+ 353: 'gazelle',
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
+ 355: 'llama',
+ 356: 'weasel',
+ 357: 'mink',
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
+ 360: 'otter',
+ 361: 'skunk, polecat, wood pussy',
+ 362: 'badger',
+ 363: 'armadillo',
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
+ 366: 'gorilla, Gorilla gorilla',
+ 367: 'chimpanzee, chimp, Pan troglodytes',
+ 368: 'gibbon, Hylobates lar',
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
+ 370: 'guenon, guenon monkey',
+ 371: 'patas, hussar monkey, Erythrocebus patas',
+ 372: 'baboon',
+ 373: 'macaque',
+ 374: 'langur',
+ 375: 'colobus, colobus monkey',
+ 376: 'proboscis monkey, Nasalis larvatus',
+ 377: 'marmoset',
+ 378: 'capuchin, ringtail, Cebus capucinus',
+ 379: 'howler monkey, howler',
+ 380: 'titi, titi monkey',
+ 381: 'spider monkey, Ateles geoffroyi',
+ 382: 'squirrel monkey, Saimiri sciureus',
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
+ 385: 'Indian elephant, Elephas maximus',
+ 386: 'African elephant, Loxodonta africana',
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
+ 389: 'barracouta, snoek',
+ 390: 'eel',
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
+ 392: 'rock beauty, Holocanthus tricolor',
+ 393: 'anemone fish',
+ 394: 'sturgeon',
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
+ 396: 'lionfish',
+ 397: 'puffer, pufferfish, blowfish, globefish',
+ 398: 'abacus',
+ 399: 'abaya',
+ 400: "academic gown, academic robe, judge's robe",
+ 401: 'accordion, piano accordion, squeeze box',
+ 402: 'acoustic guitar',
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
+ 404: 'airliner',
+ 405: 'airship, dirigible',
+ 406: 'altar',
+ 407: 'ambulance',
+ 408: 'amphibian, amphibious vehicle',
+ 409: 'analog clock',
+ 410: 'apiary, bee house',
+ 411: 'apron',
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
+ 413: 'assault rifle, assault gun',
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
+ 415: 'bakery, bakeshop, bakehouse',
+ 416: 'balance beam, beam',
+ 417: 'balloon',
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
+ 419: 'Band Aid',
+ 420: 'banjo',
+ 421: 'bannister, banister, balustrade, balusters, handrail',
+ 422: 'barbell',
+ 423: 'barber chair',
+ 424: 'barbershop',
+ 425: 'barn',
+ 426: 'barometer',
+ 427: 'barrel, cask',
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
+ 429: 'baseball',
+ 430: 'basketball',
+ 431: 'bassinet',
+ 432: 'bassoon',
+ 433: 'bathing cap, swimming cap',
+ 434: 'bath towel',
+ 435: 'bathtub, bathing tub, bath, tub',
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
+ 437: 'beacon, lighthouse, beacon light, pharos',
+ 438: 'beaker',
+ 439: 'bearskin, busby, shako',
+ 440: 'beer bottle',
+ 441: 'beer glass',
+ 442: 'bell cote, bell cot',
+ 443: 'bib',
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
+ 445: 'bikini, two-piece',
+ 446: 'binder, ring-binder',
+ 447: 'binoculars, field glasses, opera glasses',
+ 448: 'birdhouse',
+ 449: 'boathouse',
+ 450: 'bobsled, bobsleigh, bob',
+ 451: 'bolo tie, bolo, bola tie, bola',
+ 452: 'bonnet, poke bonnet',
+ 453: 'bookcase',
+ 454: 'bookshop, bookstore, bookstall',
+ 455: 'bottlecap',
+ 456: 'bow',
+ 457: 'bow tie, bow-tie, bowtie',
+ 458: 'brass, memorial tablet, plaque',
+ 459: 'brassiere, bra, bandeau',
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
+ 461: 'breastplate, aegis, egis',
+ 462: 'broom',
+ 463: 'bucket, pail',
+ 464: 'buckle',
+ 465: 'bulletproof vest',
+ 466: 'bullet train, bullet',
+ 467: 'butcher shop, meat market',
+ 468: 'cab, hack, taxi, taxicab',
+ 469: 'caldron, cauldron',
+ 470: 'candle, taper, wax light',
+ 471: 'cannon',
+ 472: 'canoe',
+ 473: 'can opener, tin opener',
+ 474: 'cardigan',
+ 475: 'car mirror',
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
+ 477: "carpenter's kit, tool kit",
+ 478: 'carton',
+ 479: 'car wheel',
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
+ 481: 'cassette',
+ 482: 'cassette player',
+ 483: 'castle',
+ 484: 'catamaran',
+ 485: 'CD player',
+ 486: 'cello, violoncello',
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
+ 488: 'chain',
+ 489: 'chainlink fence',
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
+ 491: 'chain saw, chainsaw',
+ 492: 'chest',
+ 493: 'chiffonier, commode',
+ 494: 'chime, bell, gong',
+ 495: 'china cabinet, china closet',
+ 496: 'Christmas stocking',
+ 497: 'church, church building',
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
+ 499: 'cleaver, meat cleaver, chopper',
+ 500: 'cliff dwelling',
+ 501: 'cloak',
+ 502: 'clog, geta, patten, sabot',
+ 503: 'cocktail shaker',
+ 504: 'coffee mug',
+ 505: 'coffeepot',
+ 506: 'coil, spiral, volute, whorl, helix',
+ 507: 'combination lock',
+ 508: 'computer keyboard, keypad',
+ 509: 'confectionery, confectionary, candy store',
+ 510: 'container ship, containership, container vessel',
+ 511: 'convertible',
+ 512: 'corkscrew, bottle screw',
+ 513: 'cornet, horn, trumpet, trump',
+ 514: 'cowboy boot',
+ 515: 'cowboy hat, ten-gallon hat',
+ 516: 'cradle',
+ 517: 'crane',
+ 518: 'crash helmet',
+ 519: 'crate',
+ 520: 'crib, cot',
+ 521: 'Crock Pot',
+ 522: 'croquet ball',
+ 523: 'crutch',
+ 524: 'cuirass',
+ 525: 'dam, dike, dyke',
+ 526: 'desk',
+ 527: 'desktop computer',
+ 528: 'dial telephone, dial phone',
+ 529: 'diaper, nappy, napkin',
+ 530: 'digital clock',
+ 531: 'digital watch',
+ 532: 'dining table, board',
+ 533: 'dishrag, dishcloth',
+ 534: 'dishwasher, dish washer, dishwashing machine',
+ 535: 'disk brake, disc brake',
+ 536: 'dock, dockage, docking facility',
+ 537: 'dogsled, dog sled, dog sleigh',
+ 538: 'dome',
+ 539: 'doormat, welcome mat',
+ 540: 'drilling platform, offshore rig',
+ 541: 'drum, membranophone, tympan',
+ 542: 'drumstick',
+ 543: 'dumbbell',
+ 544: 'Dutch oven',
+ 545: 'electric fan, blower',
+ 546: 'electric guitar',
+ 547: 'electric locomotive',
+ 548: 'entertainment center',
+ 549: 'envelope',
+ 550: 'espresso maker',
+ 551: 'face powder',
+ 552: 'feather boa, boa',
+ 553: 'file, file cabinet, filing cabinet',
+ 554: 'fireboat',
+ 555: 'fire engine, fire truck',
+ 556: 'fire screen, fireguard',
+ 557: 'flagpole, flagstaff',
+ 558: 'flute, transverse flute',
+ 559: 'folding chair',
+ 560: 'football helmet',
+ 561: 'forklift',
+ 562: 'fountain',
+ 563: 'fountain pen',
+ 564: 'four-poster',
+ 565: 'freight car',
+ 566: 'French horn, horn',
+ 567: 'frying pan, frypan, skillet',
+ 568: 'fur coat',
+ 569: 'garbage truck, dustcart',
+ 570: 'gasmask, respirator, gas helmet',
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
+ 572: 'goblet',
+ 573: 'go-kart',
+ 574: 'golf ball',
+ 575: 'golfcart, golf cart',
+ 576: 'gondola',
+ 577: 'gong, tam-tam',
+ 578: 'gown',
+ 579: 'grand piano, grand',
+ 580: 'greenhouse, nursery, glasshouse',
+ 581: 'grille, radiator grille',
+ 582: 'grocery store, grocery, food market, market',
+ 583: 'guillotine',
+ 584: 'hair slide',
+ 585: 'hair spray',
+ 586: 'half track',
+ 587: 'hammer',
+ 588: 'hamper',
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
+ 590: 'hand-held computer, hand-held microcomputer',
+ 591: 'handkerchief, hankie, hanky, hankey',
+ 592: 'hard disc, hard disk, fixed disk',
+ 593: 'harmonica, mouth organ, harp, mouth harp',
+ 594: 'harp',
+ 595: 'harvester, reaper',
+ 596: 'hatchet',
+ 597: 'holster',
+ 598: 'home theater, home theatre',
+ 599: 'honeycomb',
+ 600: 'hook, claw',
+ 601: 'hoopskirt, crinoline',
+ 602: 'horizontal bar, high bar',
+ 603: 'horse cart, horse-cart',
+ 604: 'hourglass',
+ 605: 'iPod',
+ 606: 'iron, smoothing iron',
+ 607: "jack-o'-lantern",
+ 608: 'jean, blue jean, denim',
+ 609: 'jeep, landrover',
+ 610: 'jersey, T-shirt, tee shirt',
+ 611: 'jigsaw puzzle',
+ 612: 'jinrikisha, ricksha, rickshaw',
+ 613: 'joystick',
+ 614: 'kimono',
+ 615: 'knee pad',
+ 616: 'knot',
+ 617: 'lab coat, laboratory coat',
+ 618: 'ladle',
+ 619: 'lampshade, lamp shade',
+ 620: 'laptop, laptop computer',
+ 621: 'lawn mower, mower',
+ 622: 'lens cap, lens cover',
+ 623: 'letter opener, paper knife, paperknife',
+ 624: 'library',
+ 625: 'lifeboat',
+ 626: 'lighter, light, igniter, ignitor',
+ 627: 'limousine, limo',
+ 628: 'liner, ocean liner',
+ 629: 'lipstick, lip rouge',
+ 630: 'Loafer',
+ 631: 'lotion',
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
+ 633: "loupe, jeweler's loupe",
+ 634: 'lumbermill, sawmill',
+ 635: 'magnetic compass',
+ 636: 'mailbag, postbag',
+ 637: 'mailbox, letter box',
+ 638: 'maillot',
+ 639: 'maillot, tank suit',
+ 640: 'manhole cover',
+ 641: 'maraca',
+ 642: 'marimba, xylophone',
+ 643: 'mask',
+ 644: 'matchstick',
+ 645: 'maypole',
+ 646: 'maze, labyrinth',
+ 647: 'measuring cup',
+ 648: 'medicine chest, medicine cabinet',
+ 649: 'megalith, megalithic structure',
+ 650: 'microphone, mike',
+ 651: 'microwave, microwave oven',
+ 652: 'military uniform',
+ 653: 'milk can',
+ 654: 'minibus',
+ 655: 'miniskirt, mini',
+ 656: 'minivan',
+ 657: 'missile',
+ 658: 'mitten',
+ 659: 'mixing bowl',
+ 660: 'mobile home, manufactured home',
+ 661: 'Model T',
+ 662: 'modem',
+ 663: 'monastery',
+ 664: 'monitor',
+ 665: 'moped',
+ 666: 'mortar',
+ 667: 'mortarboard',
+ 668: 'mosque',
+ 669: 'mosquito net',
+ 670: 'motor scooter, scooter',
+ 671: 'mountain bike, all-terrain bike, off-roader',
+ 672: 'mountain tent',
+ 673: 'mouse, computer mouse',
+ 674: 'mousetrap',
+ 675: 'moving van',
+ 676: 'muzzle',
+ 677: 'nail',
+ 678: 'neck brace',
+ 679: 'necklace',
+ 680: 'nipple',
+ 681: 'notebook, notebook computer',
+ 682: 'obelisk',
+ 683: 'oboe, hautboy, hautbois',
+ 684: 'ocarina, sweet potato',
+ 685: 'odometer, hodometer, mileometer, milometer',
+ 686: 'oil filter',
+ 687: 'organ, pipe organ',
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
+ 689: 'overskirt',
+ 690: 'oxcart',
+ 691: 'oxygen mask',
+ 692: 'packet',
+ 693: 'paddle, boat paddle',
+ 694: 'paddlewheel, paddle wheel',
+ 695: 'padlock',
+ 696: 'paintbrush',
+ 697: "pajama, pyjama, pj's, jammies",
+ 698: 'palace',
+ 699: 'panpipe, pandean pipe, syrinx',
+ 700: 'paper towel',
+ 701: 'parachute, chute',
+ 702: 'parallel bars, bars',
+ 703: 'park bench',
+ 704: 'parking meter',
+ 705: 'passenger car, coach, carriage',
+ 706: 'patio, terrace',
+ 707: 'pay-phone, pay-station',
+ 708: 'pedestal, plinth, footstall',
+ 709: 'pencil box, pencil case',
+ 710: 'pencil sharpener',
+ 711: 'perfume, essence',
+ 712: 'Petri dish',
+ 713: 'photocopier',
+ 714: 'pick, plectrum, plectron',
+ 715: 'pickelhaube',
+ 716: 'picket fence, paling',
+ 717: 'pickup, pickup truck',
+ 718: 'pier',
+ 719: 'piggy bank, penny bank',
+ 720: 'pill bottle',
+ 721: 'pillow',
+ 722: 'ping-pong ball',
+ 723: 'pinwheel',
+ 724: 'pirate, pirate ship',
+ 725: 'pitcher, ewer',
+ 726: "plane, carpenter's plane, woodworking plane",
+ 727: 'planetarium',
+ 728: 'plastic bag',
+ 729: 'plate rack',
+ 730: 'plow, plough',
+ 731: "plunger, plumber's helper",
+ 732: 'Polaroid camera, Polaroid Land camera',
+ 733: 'pole',
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
+ 735: 'poncho',
+ 736: 'pool table, billiard table, snooker table',
+ 737: 'pop bottle, soda bottle',
+ 738: 'pot, flowerpot',
+ 739: "potter's wheel",
+ 740: 'power drill',
+ 741: 'prayer rug, prayer mat',
+ 742: 'printer',
+ 743: 'prison, prison house',
+ 744: 'projectile, missile',
+ 745: 'projector',
+ 746: 'puck, hockey puck',
+ 747: 'punching bag, punch bag, punching ball, punchball',
+ 748: 'purse',
+ 749: 'quill, quill pen',
+ 750: 'quilt, comforter, comfort, puff',
+ 751: 'racer, race car, racing car',
+ 752: 'racket, racquet',
+ 753: 'radiator',
+ 754: 'radio, wireless',
+ 755: 'radio telescope, radio reflector',
+ 756: 'rain barrel',
+ 757: 'recreational vehicle, RV, R.V.',
+ 758: 'reel',
+ 759: 'reflex camera',
+ 760: 'refrigerator, icebox',
+ 761: 'remote control, remote',
+ 762: 'restaurant, eating house, eating place, eatery',
+ 763: 'revolver, six-gun, six-shooter',
+ 764: 'rifle',
+ 765: 'rocking chair, rocker',
+ 766: 'rotisserie',
+ 767: 'rubber eraser, rubber, pencil eraser',
+ 768: 'rugby ball',
+ 769: 'rule, ruler',
+ 770: 'running shoe',
+ 771: 'safe',
+ 772: 'safety pin',
+ 773: 'saltshaker, salt shaker',
+ 774: 'sandal',
+ 775: 'sarong',
+ 776: 'sax, saxophone',
+ 777: 'scabbard',
+ 778: 'scale, weighing machine',
+ 779: 'school bus',
+ 780: 'schooner',
+ 781: 'scoreboard',
+ 782: 'screen, CRT screen',
+ 783: 'screw',
+ 784: 'screwdriver',
+ 785: 'seat belt, seatbelt',
+ 786: 'sewing machine',
+ 787: 'shield, buckler',
+ 788: 'shoe shop, shoe-shop, shoe store',
+ 789: 'shoji',
+ 790: 'shopping basket',
+ 791: 'shopping cart',
+ 792: 'shovel',
+ 793: 'shower cap',
+ 794: 'shower curtain',
+ 795: 'ski',
+ 796: 'ski mask',
+ 797: 'sleeping bag',
+ 798: 'slide rule, slipstick',
+ 799: 'sliding door',
+ 800: 'slot, one-armed bandit',
+ 801: 'snorkel',
+ 802: 'snowmobile',
+ 803: 'snowplow, snowplough',
+ 804: 'soap dispenser',
+ 805: 'soccer ball',
+ 806: 'sock',
+ 807: 'solar dish, solar collector, solar furnace',
+ 808: 'sombrero',
+ 809: 'soup bowl',
+ 810: 'space bar',
+ 811: 'space heater',
+ 812: 'space shuttle',
+ 813: 'spatula',
+ 814: 'speedboat',
+ 815: "spider web, spider's web",
+ 816: 'spindle',
+ 817: 'sports car, sport car',
+ 818: 'spotlight, spot',
+ 819: 'stage',
+ 820: 'steam locomotive',
+ 821: 'steel arch bridge',
+ 822: 'steel drum',
+ 823: 'stethoscope',
+ 824: 'stole',
+ 825: 'stone wall',
+ 826: 'stopwatch, stop watch',
+ 827: 'stove',
+ 828: 'strainer',
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
+ 830: 'stretcher',
+ 831: 'studio couch, day bed',
+ 832: 'stupa, tope',
+ 833: 'submarine, pigboat, sub, U-boat',
+ 834: 'suit, suit of clothes',
+ 835: 'sundial',
+ 836: 'sunglass',
+ 837: 'sunglasses, dark glasses, shades',
+ 838: 'sunscreen, sunblock, sun blocker',
+ 839: 'suspension bridge',
+ 840: 'swab, swob, mop',
+ 841: 'sweatshirt',
+ 842: 'swimming trunks, bathing trunks',
+ 843: 'swing',
+ 844: 'switch, electric switch, electrical switch',
+ 845: 'syringe',
+ 846: 'table lamp',
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
+ 848: 'tape player',
+ 849: 'teapot',
+ 850: 'teddy, teddy bear',
+ 851: 'television, television system',
+ 852: 'tennis ball',
+ 853: 'thatch, thatched roof',
+ 854: 'theater curtain, theatre curtain',
+ 855: 'thimble',
+ 856: 'thresher, thrasher, threshing machine',
+ 857: 'throne',
+ 858: 'tile roof',
+ 859: 'toaster',
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
+ 861: 'toilet seat',
+ 862: 'torch',
+ 863: 'totem pole',
+ 864: 'tow truck, tow car, wrecker',
+ 865: 'toyshop',
+ 866: 'tractor',
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
+ 868: 'tray',
+ 869: 'trench coat',
+ 870: 'tricycle, trike, velocipede',
+ 871: 'trimaran',
+ 872: 'tripod',
+ 873: 'triumphal arch',
+ 874: 'trolleybus, trolley coach, trackless trolley',
+ 875: 'trombone',
+ 876: 'tub, vat',
+ 877: 'turnstile',
+ 878: 'typewriter keyboard',
+ 879: 'umbrella',
+ 880: 'unicycle, monocycle',
+ 881: 'upright, upright piano',
+ 882: 'vacuum, vacuum cleaner',
+ 883: 'vase',
+ 884: 'vault',
+ 885: 'velvet',
+ 886: 'vending machine',
+ 887: 'vestment',
+ 888: 'viaduct',
+ 889: 'violin, fiddle',
+ 890: 'volleyball',
+ 891: 'waffle iron',
+ 892: 'wall clock',
+ 893: 'wallet, billfold, notecase, pocketbook',
+ 894: 'wardrobe, closet, press',
+ 895: 'warplane, military plane',
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
+ 897: 'washer, automatic washer, washing machine',
+ 898: 'water bottle',
+ 899: 'water jug',
+ 900: 'water tower',
+ 901: 'whiskey jug',
+ 902: 'whistle',
+ 903: 'wig',
+ 904: 'window screen',
+ 905: 'window shade',
+ 906: 'Windsor tie',
+ 907: 'wine bottle',
+ 908: 'wing',
+ 909: 'wok',
+ 910: 'wooden spoon',
+ 911: 'wool, woolen, woollen',
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
+ 913: 'wreck',
+ 914: 'yawl',
+ 915: 'yurt',
+ 916: 'web site, website, internet site, site',
+ 917: 'comic book',
+ 918: 'crossword puzzle, crossword',
+ 919: 'street sign',
+ 920: 'traffic light, traffic signal, stoplight',
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
+ 922: 'menu',
+ 923: 'plate',
+ 924: 'guacamole',
+ 925: 'consomme',
+ 926: 'hot pot, hotpot',
+ 927: 'trifle',
+ 928: 'ice cream, icecream',
+ 929: 'ice lolly, lolly, lollipop, popsicle',
+ 930: 'French loaf',
+ 931: 'bagel, beigel',
+ 932: 'pretzel',
+ 933: 'cheeseburger',
+ 934: 'hotdog, hot dog, red hot',
+ 935: 'mashed potato',
+ 936: 'head cabbage',
+ 937: 'broccoli',
+ 938: 'cauliflower',
+ 939: 'zucchini, courgette',
+ 940: 'spaghetti squash',
+ 941: 'acorn squash',
+ 942: 'butternut squash',
+ 943: 'cucumber, cuke',
+ 944: 'artichoke, globe artichoke',
+ 945: 'bell pepper',
+ 946: 'cardoon',
+ 947: 'mushroom',
+ 948: 'Granny Smith',
+ 949: 'strawberry',
+ 950: 'orange',
+ 951: 'lemon',
+ 952: 'fig',
+ 953: 'pineapple, ananas',
+ 954: 'banana',
+ 955: 'jackfruit, jak, jack',
+ 956: 'custard apple',
+ 957: 'pomegranate',
+ 958: 'hay',
+ 959: 'carbonara',
+ 960: 'chocolate sauce, chocolate syrup',
+ 961: 'dough',
+ 962: 'meat loaf, meatloaf',
+ 963: 'pizza, pizza pie',
+ 964: 'potpie',
+ 965: 'burrito',
+ 966: 'red wine',
+ 967: 'espresso',
+ 968: 'cup',
+ 969: 'eggnog',
+ 970: 'alp',
+ 971: 'bubble',
+ 972: 'cliff, drop, drop-off',
+ 973: 'coral reef',
+ 974: 'geyser',
+ 975: 'lakeside, lakeshore',
+ 976: 'promontory, headland, head, foreland',
+ 977: 'sandbar, sand bar',
+ 978: 'seashore, coast, seacoast, sea-coast',
+ 979: 'valley, vale',
+ 980: 'volcano',
+ 981: 'ballplayer, baseball player',
+ 982: 'groom, bridegroom',
+ 983: 'scuba diver',
+ 984: 'rapeseed',
+ 985: 'daisy',
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+ 987: 'corn',
+ 988: 'acorn',
+ 989: 'hip, rose hip, rosehip',
+ 990: 'buckeye, horse chestnut, conker',
+ 991: 'coral fungus',
+ 992: 'agaric',
+ 993: 'gyromitra',
+ 994: 'stinkhorn, carrion fungus',
+ 995: 'earthstar',
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
+ 997: 'bolete',
+ 998: 'ear, spike, capitulum',
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
diff --git a/requirements.txt b/requirements.txt
index 7a1fc6876208008f51675eb80aae4ca2632c7994..8196c05e72713c94f05a7ce15f1499deec44a7bc 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ ipywidgets==7.1.2
 bqplot==0.10.5
 pyyaml
 pytest==3.5.1
+xlsxwriter==1.1.1
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 6fcda7eda91932943bc8e3803124a2c135930b80..3d4153f0b87bb3120befdd91dc1e4537a4760133 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -183,7 +183,7 @@ def test_connectivity_summary():
     assert len(summary) == 73
 
     verbose_summary = connectivity_summary_verbose(g)
-    assert len(verbose_summary  ) == 73
+    assert len(verbose_summary) == 73
 
 
 if __name__ == '__main__':