From 54a5867e8d810e5938f5186a6fc1ae89e4f5b041 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 22 Oct 2018 20:36:17 +0300
Subject: [PATCH] Activation statistics collection (#61)

Activation statistics can be leveraged to make pruning and quantization decisions, so we
added support for collecting this data.
Two types of activation statistics are supported: summary statistics, and detailed records
per activation.
Currently we support the following summaries:
- Average activation sparsity, per layer
- Average L1-norm for each activation channel, per layer
- Average sparsity for each activation channel, per layer

For the detailed records, we collect several statistics per activation and store them in a
record.  This collection method generates more detailed data, but it consumes more time, so
use it judiciously.
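
For illustration, each monitored module accumulates a record shaped roughly as follows (the
field names are those used by this patch; the values shown are hypothetical):

    # Per-module record kept by the detailed-records collector.
    # Each list holds one entry per activation seen during collection.
    record = {
        'min':   [0.0, 0.0],         # minimum of each flattened activation
        'max':   [4.71, 5.08],       # maximum
        'mean':  [0.83, 0.91],       # mean
        'std':   [1.10, 1.24],       # standard deviation
        'l2':    [37.2, 39.0],       # L2-norm
        'shape': '(2, 64, 32, 32)',  # shape of the activations tensor
    }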

* You can collect activation data for the different training phases: training/validation/test.
* You can access the data directly from each module for which you chose to collect stats.
* You can also create an Excel workbook with the stats (see the usage sketch below).
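
A minimal usage sketch (my_model and evaluate are placeholders for your own model and
evaluation loop; the collector names and functions are those added by this patch):

    import distiller
    from distiller.data_loggers import SummaryActivationStatsCollector, collector_context

    # Assign human-readable layer names before collecting statistics
    distiller.utils.assign_layer_fq_names(my_model)

    # Collect the mean element-wise activation sparsity of every ReLU layer
    collector = SummaryActivationStatsCollector(my_model, "sparsity",
                                                distiller.utils.sparsity)
    with collector_context(collector):
        evaluate(my_model)                      # user-supplied forward passes
        sparsity_per_layer = collector.value()  # {layer_name: mean sparsity}
    collector.to_xlsx("act_sparsity")           # writes act_sparsity.xlsx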

To demonstrate the use of activation statistics collection, we added a sample schedule which
prunes weight filters by their activation APoZ, following:
"Network Trimming: A Data-Driven Neuron Pruning Approach towards
Efficient Deep Architectures",
Hengyuan Hu, Rui Peng, Yu-Wing Tai, Chi-Keung Tang, ICLR 2016
https://arxiv.org/abs/1607.03250
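
As a reminder, a channel's APoZ is the average fraction of zero (post-ReLU) elements in its
feature map.  A minimal sketch of such a per-channel summary function, suitable for use as
the summary_fn of SummaryActivationStatsCollector (this function is illustrative - it is not
the code added by this patch):

    def activation_channels_apoz(activation):
        """Average Percentage of Zeros of each channel in a (N, C, H, W) activations batch."""
        view_2d = activation.view(activation.size(0), activation.size(1), -1)  # (N, C, H*W)
        # Fraction of zeros per (sample, channel), averaged over the batch
        return 100 * (view_2d == 0).float().mean(dim=2).mean(dim=0)            # shape: (C,)

    # e.g.: SummaryActivationStatsCollector(model, "apoz_channels", activation_channels_apoz)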

We also refactored the AGP code (AutomatedGradualPruner) to support structured pruning,
and specifically we separated the AGP schedule from the filter pruning criterion.  We added
examples of ranking filter importance based on activation APoZ (ActivationAPoZRankedFilterPruner),
random ranking (RandomRankedFilterPruner), filter gradients (GradientRankedFilterPruner),
and filter L1-norm (L1RankedStructureParameterPruner).
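
For example, constructing one of the new AGP-scheduled pruners directly in Python looks
roughly like this (the weights-tensor name below is a placeholder; in practice these classes
are instantiated from a YAML schedule, as in the new example files):

    from distiller.pruning import ActivationAPoZRankedFilterPruner_AGP

    pruner = ActivationAPoZRankedFilterPruner_AGP(
        name='filter_pruner',
        initial_sparsity=0.05,
        final_sparsity=0.50,
        reg_regims={'module.layer1.0.conv1.weight': 'Filters'})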
---
 distiller/__init__.py                         |   12 +-
 distiller/data_loggers/__init__.py            |    2 +-
 distiller/data_loggers/collector.py           |  316 +++++-
 distiller/data_loggers/logger.py              |   34 +-
 distiller/data_loggers/tbbackend.py           |   43 +-
 distiller/policy.py                           |    1 +
 distiller/pruning/__init__.py                 |    7 +-
 distiller/pruning/automated_gradual_pruner.py |   60 +-
 distiller/pruning/ranked_structures_pruner.py |  187 ++-
 distiller/thresholding.py                     |  182 +--
 distiller/utils.py                            |   18 -
 .../resnet20_filters.schedule_agp.yaml        |   98 +-
 .../resnet20_filters.schedule_agp_2.yaml      |   80 +-
 .../compress_classifier.py                    |   97 +-
 .../resnet56_cifar_activation_apoz.yaml       |  212 ++++
 .../resnet56_cifar_activation_apoz_v2.yaml    |  198 ++++
 .../resnet56_cifar_filter_rank.yaml           |   20 +-
 .../resnet56_cifar_filter_rank_v2.yaml        |  207 ++++
 ...net50.imagenet.sensitivity_filter_wise.csv |  584 ++++++++++
 jupyter/imagenet_classes.py                   | 1003 +++++++++++++++++
 requirements.txt                              |    1 +
 tests/test_summarygraph.py                    |    2 +-
 22 files changed, 3052 insertions(+), 312 deletions(-)
 create mode 100755 examples/network_trimming/resnet56_cifar_activation_apoz.yaml
 create mode 100755 examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
 create mode 100755 examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
 create mode 100644 examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv
 create mode 100755 jupyter/imagenet_classes.py

diff --git a/distiller/__init__.py b/distiller/__init__.py
index 708d628..643f3f5 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -15,7 +15,7 @@
 #
 
 from .utils import *
-from .thresholding import GroupThresholdMixin, threshold_mask
+from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
 from .config import file_config, dict_config
 from .model_summaries import *
 from .scheduler import *
@@ -25,19 +25,14 @@ from .policy import *
 from .thinning import *
 from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
 
-#del utils
+
 del dict_config
 del thinning
-#del model_summaries
-#del scheduler
-#del sensitivity
-#del directives
-#del thresholding
-#del policy
 
 # Distiller version
 __version__ = "0.3.0-pre"
 
+
 def model_find_param_name(model, param_to_find):
     """Look up the name of a model parameter.
 
@@ -69,6 +64,7 @@ def model_find_module_name(model, module_to_find):
             return name
     return None
 
+
 def model_find_param(model, param_to_find_name):
     """Look a model parameter by its name
 
diff --git a/distiller/data_loggers/__init__.py b/distiller/data_loggers/__init__.py
index 2ffca80..a3351f0 100755
--- a/distiller/data_loggers/__init__.py
+++ b/distiller/data_loggers/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from .collector import ActivationSparsityCollector
+from .collector import *
 from .logger import PythonLogger, TensorBoardLogger, CsvLogger
 
 del logger
diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 1254f7a..a18a368 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -14,88 +14,302 @@
 # limitations under the License.
 #
 
+import xlsxwriter
+import os
+from collections import OrderedDict
+from contextlib import contextmanager
 import torch
-from distiller.utils import sparsity
 from torchnet.meter import AverageValueMeter
 import logging
+import distiller
 msglogger = logging.getLogger()
 
-__all__ = ['ActivationSparsityCollector']
+__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
+           'collector_context', 'collectors_context']
 
-class DataCollector(object):
-    def __init__(self):
-        pass
 
+class ActivationStatsCollector(object):
+    """Collect model activation statistics information.
 
-class ActivationSparsityCollector(DataCollector):
-    """Collect model activation sparsity information.
+    ActivationStatsCollector is the base class for classes that collect activations statistics.
+    You may collect statistics on different phases of the optimization process (training, validation, test).
 
-    CNN models with ReLU layers, exhibit sparse activations.
-    ActivationSparsityCollector will collect information about this sparsity.
-    Currently we only record the mean sparsity of the activations, but this can be expanded
-    to collect std and other statistics.
+    Statistics data are accessible via .value() or by accessing individual modules.
 
-    The current implementation activation sparsity collection has a few caveats:
-    * It is slow
+    The current implementation has a few caveats:
+    * It is slow - therefore it is advisable to use this only when needed.
     * It can't access the activations of torch.Functions, only torch.Modules.
-    * The layer names are mangled
 
-    ActivationSparsityCollector uses the forward hook of modules in order to access the
+    ActivationStatsCollector uses the forward hook of modules in order to access the
     feature-maps.  This is both slow and limits us to seeing only the outputs of torch.Modules.
-    We can remove some of the slowness, by choosing to log only specific layers.  By default,
-    we only logs torch.nn.ReLU activations.
+    We can remove some of the slowness, by choosing to log only specific layers or use it only
+    during validation or test.  By default, we only log torch.nn.ReLU activations.
 
     The layer names are mangled, because torch.Modules don't have names and we need to invent
-    a unique name per layer.
-    """
-    def __init__(self, model, classes=[torch.nn.ReLU]):
-        """Since only specific layers produce sparse feature-maps, the
-        ActivationSparsityCollector constructor accepts an optional list of layers to log."""
+    a unique name per layer.  To assign human-readable names, it is advisable to invoke the following
+    before starting the statistics collection:
 
-        super(ActivationSparsityCollector, self).__init__()
+        distiller.utils.assign_layer_fq_names(model)
+    """
+    def __init__(self, model, stat_name, classes):
+        """
+        Args:
+            model - the model we are monitoring.
+            statistics_dict - a dictionary of {stat_name: statistics_function}, where name
+                provides a means for us to access the statistics data at a later time; and the
+                statistics_function is a function that gets an activation as input and returns
+                some statistic.
+                For example, the dictionary below collects element-wise activation sparsity
+                statistics:
+                    {"sparsity": distiller.utils.sparsity}
+            classes - a list of class types for which we collect activation statistics.
+                You can access a module's activation statistics by referring to module.<stat_name>
+                For example:
+                    print(module.sparsity)
+        """
+        super(ActivationStatsCollector, self).__init__()
         self.model = model
+        self.stat_name = stat_name
         self.classes = classes
-        self._init_activations_sparsity(model)
+        self.fwd_hook_handles = []
 
     def value(self):
-        """Return a dictionary containing {layer_name: mean sparsity}"""
-        activation_sparsity = {}
-        _collect_activations_sparsity(self.model, activation_sparsity)
-        return activation_sparsity
+        """Return a dictionary containing {layer_name: statistic}"""
+        activation_stats = OrderedDict()
+        self.__value(self.model, activation_stats)
+        return activation_stats
+
+    def __value(self, module, activation_stats):
+        for child_module in module._modules.values():
+            self.__value(child_module, activation_stats)
+        self._collect_activations_stats(module, activation_stats)
+
+    def start(self):
+        """Start collecting activation stats.
+
+        This will iteratively register the modules' forward-hooks, so that the collector
+        will be called from the forward traversal and get exposed to activation data.
+        """
+        assert len(self.fwd_hook_handles) == 0
+        self.__start(self.model)
 
+    def stop(self):
+        """Stop collecting activation stats.
 
-    def _init_activations_sparsity(self, module, name=''):
-        def __activation_sparsity_cb(module, input, output):
-            """Record the activation sparsity of 'module'
+        This will iteratively unregister the modules' forward-hooks.
+        """
+        for handle in self.fwd_hook_handles:
+            handle.remove()
+        self.fwd_hook_handles = []
 
-            This is a callback from the forward() of 'module'.
-            """
-            module.sparsity.add(sparsity(output.data))
+    def reset(self):
+        """Reset the statsitics counters of this collector."""
+        self.__reset(self.model)
+        return self
 
-        has_children = False
+    def __reset(self, module):
+        for child_module in module._modules.values():
+            self.__reset(child_module)
+        self._reset_counter(module)
+
+    def _activation_stats_cb(self, module, input, output):
+        """Handle new activations ('output' argument).
+
+        This is invoked from the forward-pass callback of module 'module'.
+        """
+        raise NotImplementedError
+
+    def __start(self, module, name=''):
+        """Iteratively register to the forward-pass callback of all eligable modules.
+
+        Eligable modules are currently filtered by their class type.
+        """
+        is_leaf_node = True
         for name, sub_module in module._modules.items():
-            self._init_activations_sparsity(sub_module, name)
-            has_children = True
-        if not has_children:
+            self.__start(sub_module, name)
+            is_leaf_node = False
+
+        if is_leaf_node:
             if type(module) in self.classes:
-                module.register_forward_hook(__activation_sparsity_cb)
-                module.sparsity = AverageValueMeter()
-                if hasattr(module, 'ref_name'):
-                    module.sparsity.name = 'sparsity_' + module.ref_name
-                else:
-                    module.sparsity.name = 'sparsity_' + name + '_' + module.__class__.__name__ + '_' + str(id(module))
+                self.fwd_hook_handles.append(module.register_forward_hook(self._activation_stats_cb))
+                self._start_counter(module)
+
+    def _reset_counter(self, mod):
+        """Reset a specific statistic counter - this is subclass-specific code"""
+        raise NotImplementedError
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        """Handle new activations - this is subclass-specific code"""
+        raise NotImplementedError
+
+    def _start_counter(self, module):
+        """Start/create the statistic counter of a module - this is subclass-specific code"""
+        raise NotImplementedError
+
+
+class SummaryActivationStatsCollector(ActivationStatsCollector):
+    """This class collects activiations statistical summaries.
+
+    This Collector computes the mean of some statistic of the activation.  It is rather
+    light-weight and quicker than collecting a record per activation.
+    The statistic function is configured in the constructor.
+    """
+    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU]):
+        super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
+        self.summary_fn = summary_fn
+
+    def _activation_stats_cb(self, module, input, output):
+        """Record the activation sparsity of 'module'
+
+        This is a callback from the forward() of 'module'.
+        """
+        getattr(module, self.stat_name).add(self.summary_fn(output.data))
+
+    def _start_counter(self, module):
+        if not hasattr(module, self.stat_name):
+            setattr(module, self.stat_name, AverageValueMeter())
+            # Assign a name to this summary
+            if hasattr(module, 'distiller_name'):
+                getattr(module, self.stat_name).name = '_'.join((self.stat_name, module.distiller_name))
+            else:
+                getattr(module, self.stat_name).name = '_'.join((self.stat_name,
+                                                                 module.__class__.__name__,
+                                                                 str(id(module))))
+
+    def _reset_counter(self, mod):
+        if hasattr(mod, self.stat_name):
+            getattr(mod, self.stat_name).reset()
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        if hasattr(module, self.stat_name):
+            mean = getattr(module, self.stat_name).mean
+            if isinstance(mean, torch.Tensor):
+                mean = mean.tolist()
+            activation_stats[getattr(module, self.stat_name).name] = mean
+
+    def to_xlsx(self, fname):
+        """Save the records to an Excel workbook, with one worksheet per layer.
+        """
+        fname = ".".join([fname, 'xlsx'])
+        try:
+            os.remove(fname)
+        except OSError:
+            pass
+
+        records_dict = self.value()
+        with xlsxwriter.Workbook(fname) as workbook:
+            worksheet = workbook.add_worksheet(self.stat_name)
+            col_names = []
+            for col, (module_name, module_summary_data) in enumerate(records_dict.items()):
+                if not isinstance(module_summary_data, list):
+                    module_summary_data = [module_summary_data]
+                worksheet.write_column(1, col, module_summary_data)
+                col_names.append(module_name)
+            worksheet.write_row(0, 0, col_names)
+
+
+class RecordsActivationStatsCollector(ActivationStatsCollector):
+    """This class collects activiations statistical records.
+
+    This Collector computes a hard-coded set of activations statsitics and collects a
+    record per activation.  The activation records of the entire model (only filtered modules),
+    can be saved to an Excel workbook.
+
+    For obvious reasons, this is slower than SummaryActivationStatsCollector.
+    """
+    def __init__(self, model, classes=[torch.nn.ReLU]):
+        super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
+
+    def _activation_stats_cb(self, module, input, output):
+        """Record the activation sparsity of 'module'
+
+        This is a callback from the forward() of 'module'.
+        """
+        def to_np(stats):
+            if isinstance(stats, tuple):
+                return stats[0].detach().cpu().numpy()
+            else:
+                return stats.detach().cpu().numpy()
+
+        # We get a batch of activations, from which we collect statistics
+        act = output.view(output.size(0), -1)
+        batch_min_list = to_np(torch.min(act, dim=1)).tolist()
+        batch_max_list = to_np(torch.max(act, dim=1)).tolist()
+        batch_mean_list = to_np(torch.mean(act, dim=1)).tolist()
+        batch_std_list = to_np(torch.std(act, dim=1)).tolist()
+        batch_l2_list = to_np(torch.norm(act, p=2, dim=1)).tolist()
+
+        module.statistics_records['min'].extend(batch_min_list)
+        module.statistics_records['max'].extend(batch_max_list)
+        module.statistics_records['mean'].extend(batch_mean_list)
+        module.statistics_records['std'].extend(batch_std_list)
+        module.statistics_records['l2'].extend(batch_l2_list)
+        module.statistics_records['shape'] = distiller.size2str(output)
 
     @staticmethod
-    def _collect_activations_sparsity(model, activation_sparsity, name=''):
-        for name, module in model._modules.items():
-            _collect_activations_sparsity(module, activation_sparsity, name)
+    def _create_records_dict():
+        records = OrderedDict()
+        for stat_name in ['min', 'max', 'mean', 'std', 'l2']:
+            records[stat_name] = []
+        records['shape'] = ''
+        return records
+
+    def to_xlsx(self, fname):
+        """Save the records to an Excel workbook, with one worksheet per layer.
+        """
+        fname = ".".join([fname, 'xlsx'])
+        try:
+            os.remove(fname)
+        except OSError:
+            pass
+
+        records_dict = self.value()
+        with xlsxwriter.Workbook(fname) as workbook:
+            for module_name, module_act_records in records_dict.items():
+                worksheet = workbook.add_worksheet(module_name)
+                col_names = []
+                for col, (col_name, col_data) in enumerate(module_act_records.items()):
+                    if col_name == 'shape':
+                        continue
+                    worksheet.write_column(1, col, col_data)
+                    col_names.append(col_name)
+                worksheet.write_row(0, 0, col_names)
+                worksheet.write(0, len(col_names)+2, module_act_records['shape'])
+
+    def _start_counter(self, module):
+        if not hasattr(module, "statsitics_records"):
+            module.statsitics_records = self._create_records_dict()
+
+    def _reset_counter(self, mod):
+        if hasattr(mod, "statsitics_records"):
+            mod.statsitics_records = self._create_records_dict()
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        if hasattr(module, "statsitics_records"):
+            activation_stats[module.distiller_name] = module.statsitics_records
+
+
+@contextmanager
+def collector_context(collector):
+    """A context manager for an activation collector"""
+    if collector is not None:
+        collector.reset().start()
+    try:
+        yield collector
+    finally:
+        if collector is not None:
+            collector.stop()
+
 
-        if hasattr(model, 'sparsity'):
-            activation_sparsity[model.sparsity.name] = model.sparsity.mean
+@contextmanager
+def collectors_context(collectors_dict):
+    """A context manager for a dictionary of collectors"""
+    if len(collectors_dict) == 0:
+        yield collectors_dict
+        return
+    for collector in collectors_dict.values():
+        collector.reset().start()
+    try:
+        yield collectors_dict
+    finally:
+        for collector in collectors_dict.values():
+            collector.stop()
 
 
-class TrainingProgressCollector(DataCollector):
-    def __init__(self, stats = {}):
+class TrainingProgressCollector(object):
+    def __init__(self, stats={}):
         super(TrainingProgressCollector, self).__init__()
         object.__setattr__(self, '_stats', stats)
 
diff --git a/distiller/data_loggers/logger.py b/distiller/data_loggers/logger.py
index 796c6c3..89ac473 100755
--- a/distiller/data_loggers/logger.py
+++ b/distiller/data_loggers/logger.py
@@ -28,7 +28,7 @@ import torch
 import numpy as np
 import tabulate
 import distiller
-from distiller.utils import density, sparsity, sparsity_2D, size_to_str, to_np
+from distiller.utils import density, sparsity, sparsity_2D, size_to_str, to_np, norm_filters
 # TensorBoard logger
 from .tbbackend import TBBackend
 # Visdom logger
@@ -53,7 +53,7 @@ class DataLogger(object):
     def log_training_progress(self, model, epoch, i, set_size, batch_time, data_time, classerr, losses, print_freq, collectors):
         raise NotImplementedError
 
-    def log_activation_sparsity(self, activation_sparsity, logcontext):
+    def log_activation_statistic(self, phase, stat_name, activation_stats, epoch):
         raise NotImplementedError
 
     def log_weights_sparsity(self, model, epoch):
@@ -81,12 +81,11 @@ class PythonLogger(DataLogger):
                 log = log + '{name} {val:.6f}    '.format(name=name, val=val)
         self.pylogger.info(log)
 
-
-    def log_activation_sparsity(self, activation_sparsity, logcontext):
+    def log_activation_statistic(self, phase, stat_name, activation_stats, epoch):
         data = []
-        for layer, sparsity in activation_sparsity.items():
-            data.append([layer, sparsity*100])
-        t = tabulate.tabulate(data, headers=['Layer', 'sparsity (%)'], tablefmt='psql', floatfmt=".1f")
+        for layer, statistic in activation_stats.items():
+            data.append([layer, statistic])
+        t = tabulate.tabulate(data, headers=['Layer', stat_name], tablefmt='psql', floatfmt=".2f")
         msglogger.info('\n' + t)
 
     def log_weights_sparsity(self, model, epoch):
@@ -119,9 +118,9 @@ class TensorBoardLogger(DataLogger):
             self.tblogger.scalar_summary(prefix+tag, value, total_steps(total, epoch, completed))
         self.tblogger.sync_to_file()
 
-    def log_activation_sparsity(self, activation_sparsity, epoch):
-        group = 'sparsity/activations/'
-        for tag, value in activation_sparsity.items():
+    def log_activation_statistic(self, phase, stat_name, activation_stats, epoch):
+        group = stat_name + '/activations/' + phase + "/"
+        for tag, value in activation_stats.items():
             self.tblogger.scalar_summary(group+tag, value, epoch)
         self.tblogger.sync_to_file()
 
@@ -142,6 +141,15 @@ class TensorBoardLogger(DataLogger):
         self.tblogger.scalar_summary("sprasity/weights/total", 100*(1 - sparse_params_size/params_size), epoch)
         self.tblogger.sync_to_file()
 
+    def log_weights_filter_magnitude(self, model, epoch, multi_graphs=False):
+        """Log the L1-magnitude of the weights tensors.
+        """
+        for name, param in model.state_dict().items():
+            if param.dim() in [4]:
+                self.tblogger.list_summary('magnitude/filters/' + name,
+                                           list(to_np(norm_filters(param))), epoch, multi_graphs)
+        self.tblogger.sync_to_file()
+
     def log_weights_distribution(self, named_params, steps_completed):
         if named_params is None:
             return
@@ -174,6 +182,6 @@ class CsvLogger(DataLogger):
                     params_size += torch.numel(param)
                     sparse_params_size += param.numel() * _density
                     writer.writerow([name, size_to_str(param.size()),
-                                       torch.numel(param),
-                                       int(_density * param.numel()),
-                                       (1-_density)*100])
+                                     torch.numel(param),
+                                     int(_density * param.numel()),
+                                     (1-_density)*100])
diff --git a/distiller/data_loggers/tbbackend.py b/distiller/data_loggers/tbbackend.py
index 0761ef2..ccd077a 100755
--- a/distiller/data_loggers/tbbackend.py
+++ b/distiller/data_loggers/tbbackend.py
@@ -18,13 +18,16 @@
 Writes logs to a file using a Google's TensorBoard protobuf format.
 See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto
 """
+import os
 import tensorflow as tf
 import numpy as np
 
-class TBBackend(object):
 
+class TBBackend(object):
     def __init__(self, log_dir):
-        self.writer = tf.summary.FileWriter(log_dir)
+        self.writers = []
+        self.log_dir = log_dir
+        self.writers.append(tf.summary.FileWriter(log_dir))
 
     def scalar_summary(self, tag, scalar, step):
         """From TF documentation:
@@ -32,7 +35,26 @@ class TBBackend(object):
             value: value associated with the tag (a float).
         """
         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
-        self.writer.add_summary(summary, step)
+        self.writers[0].add_summary(summary, step)
+
+    def list_summary(self, tag, scalar_list, step, multi_graphs):
+        """Log a relatively small list of scalars.
+
+        We want to track the progress of multiple scalar parameters in a single graph.
+        The list provides a single value for each of the parameters we are tracking.
+
+        NOTE: There are two ways to log multiple values in TB and neither one is optimal.
+        1. Use a single writer: in this case all of the parameters use the same color, and
+           distinguishing between them is difficult.
+        2. Use multiple writers: in this case each parameter has its own color which helps
+           to visually separate the parameters.  However, each writer logs to a different
+           file and this creates a lot of files which slow down the TB load.
+        """
+        for i, scalar in enumerate(scalar_list):
+            if multi_graphs and (i+1 > len(self.writers)):
+                self.writers.append(tf.summary.FileWriter(os.path.join(self.log_dir, str(i))))
+            summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
+            self.writers[0 if not multi_graphs else i].add_summary(summary, step)
 
     def histogram_summary(self, tag, tensor, step):
         """
@@ -49,11 +71,11 @@ class TBBackend(object):
         """
         hist, edges = np.histogram(tensor, bins=200)
         tfhist = tf.HistogramProto(
-            min = np.min(tensor),
-            max = np.max(tensor),
-            num = int(np.prod(tensor.shape)),
-            sum = np.sum(tensor),
-            sum_squares = np.sum(np.square(tensor)))
+            min=np.min(tensor),
+            max=np.max(tensor),
+            num=int(np.prod(tensor.shape)),
+            sum=np.sum(tensor),
+            sum_squares=np.sum(np.square(tensor)))
 
         # From the TF documentation:
         #   Parallel arrays encoding the bucket boundaries and the bucket values.
@@ -64,7 +86,8 @@ class TBBackend(object):
         tfhist.bucket.extend(hist)
 
         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=tfhist)])
-        self.writer.add_summary(summary, step)
+        self.writers[0].add_summary(summary, step)
 
     def sync_to_file(self):
-        self.writer.flush()
+        for writer in self.writers:
+            writer.flush()
diff --git a/distiller/policy.py b/distiller/policy.py
index 6d70533..dc85c60 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -90,6 +90,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         if self.levels is not None:
             self.pruner.levels = self.levels
 
+        meta['model'] = model
         for param_name, param in model.named_parameters():
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
 
diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index dc8a1c7..24fe7e5 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -19,11 +19,14 @@
 """
 
 from .magnitude_pruner import MagnitudeParameterPruner
-from .automated_gradual_pruner import AutomatedGradualPruner, StructuredAutomatedGradualPruner
+from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureParameterPruner_AGP, \
+                                      ActivationAPoZRankedFilterPruner_AGP, GradientRankedFilterPruner_AGP, \
+                                      RandomRankedFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .structure_pruner import StructureParameterPruner
-from .ranked_structures_pruner import L1RankedStructureParameterPruner
+from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \
+                                      RandomRankedFilterPruner, GradientRankedFilterPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 
 del magnitude_pruner
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index dff2168..dfb80e5 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -16,7 +16,7 @@
 
 from .pruner import _ParameterPruner
 from .level_pruner import SparsityLevelParameterPruner
-from .ranked_structures_pruner import L1RankedStructureParameterPruner
+from .ranked_structures_pruner import *
 from distiller.utils import *
 # import logging
 # msglogger = logging.getLogger()
@@ -61,28 +61,56 @@ class AutomatedGradualPruner(_ParameterPruner):
         target_sparsity = (self.final_sparsity +
                            (self.initial_sparsity-self.final_sparsity) *
                            (1.0 - ((current_epoch-starting_epoch)/span))**3)
-        self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity)
+        self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
 
     @staticmethod
-    def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity):
+    def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model=None):
         return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity)
 
 
-class StructuredAutomatedGradualPruner(AutomatedGradualPruner):
+class CriterionParameterizedAGP(AutomatedGradualPruner):
     def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
         self.reg_regims = reg_regims
         weights = [weight for weight in reg_regims.keys()]
-        if not all([group in ['3D', 'Filters', 'Channels'] for group in reg_regims.values()]):
-            raise ValueError("Currently only filter (3D) and channel pruning is supported")
-        super(StructuredAutomatedGradualPruner, self).__init__(name, initial_sparsity,
-                                                               final_sparsity, weights,
-                                                               pruning_fn=self.prune_to_target_sparsity)
+        if not all([group in ['3D', 'Filters', 'Channels', 'Rows'] for group in reg_regims.values()]):
+            raise ValueError("Unsupported group structure")
+        super(CriterionParameterizedAGP, self).__init__(name, initial_sparsity,
+                                                        final_sparsity, weights,
+                                                        pruning_fn=self.prune_to_target_sparsity)
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity):
+    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model):
         if self.reg_regims[param_name] in ['3D', 'Filters']:
-            L1RankedStructureParameterPruner.rank_prune_filters(target_sparsity, param,
-                                                                param_name, zeros_mask_dict)
-        else:
-            if self.reg_regims[param_name] == 'Channels':
-                L1RankedStructureParameterPruner.rank_prune_channels(target_sparsity, param,
-                                                                     param_name, zeros_mask_dict)
+            self.filters_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model)
+        elif self.reg_regims[param_name] == 'Channels':
+            self.channels_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model)
+        elif self.reg_regims[param_name] == 'Rows':
+            self.rows_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model)
+
+
+# TODO: this class parameterization is cumbersome: the ranking functions (per structure)
+# should come from the YAML schedule
+
+class L1RankedStructureParameterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(L1RankedStructureParameterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = L1RankedStructureParameterPruner.rank_prune_filters
+        self.channels_ranking_fn = L1RankedStructureParameterPruner.rank_prune_channels
+        self.rows_ranking_fn = L1RankedStructureParameterPruner.rank_prune_rows
+
+
+class ActivationAPoZRankedFilterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(ActivationAPoZRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = ActivationAPoZRankedFilterPruner.rank_prune_filters
+
+
+class GradientRankedFilterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(GradientRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = GradientRankedFilterPruner.rank_prune_filters
+
+
+class RandomRankedFilterPruner_AGP(CriterionParameterizedAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, reg_regims):
+        super(RandomRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims)
+        self.filters_ranking_fn = RandomRankedFilterPruner.rank_prune_filters
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 407cf73..d1b36c3 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import numpy as np
 import logging
 import torch
 import distiller
@@ -22,15 +23,20 @@ msglogger = logging.getLogger()
 
 
 # TODO: support different policies for ranking structures
-class L1RankedStructureParameterPruner(_ParameterPruner):
+class RankedStructureParameterPruner(_ParameterPruner):
     """Uses mean L1-norm to rank structures and prune a specified percentage of structures
     """
     def __init__(self, name, reg_regims):
-        super(L1RankedStructureParameterPruner, self).__init__(name)
-        self.name = name
+        super(RankedStructureParameterPruner, self).__init__(name)
         self.reg_regims = reg_regims
 
 
+class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
+    """Uses mean L1-norm to rank structures and prune a specified percentage of structures
+    """
+    def __init__(self, name, reg_regims):
+        super(L1RankedStructureParameterPruner, self).__init__(name, reg_regims)
+
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if param_name not in self.reg_regims.keys():
             return
@@ -44,9 +50,11 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
             return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict)
         elif group_type == 'Channels':
             return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict)
+        elif group_type == 'Rows':
+            return self.rank_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict)
         else:
             raise ValueError("Currently only filter (3D) and channel ranking is supported")
-            
+
     @staticmethod
     def rank_channels(fraction_to_prune, param):
         num_filters = param.size(0)
@@ -71,7 +79,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         return bottomk, channel_mags
 
     @staticmethod
-    def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict):
+    def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict, model=None):
         bottomk_channels, channel_mags = L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param)
         if bottomk_channels is None:
             # Empty list means that fraction_to_prune is too low to prune anything
@@ -92,20 +100,177 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
                        fraction_to_prune, len(bottomk_channels), num_channels)
 
     @staticmethod
-    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict):
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        # First we rank the filters
         view_filters = param.view(param.size(0), -1)
-        filter_mags = view_filters.data.norm(1, dim=1)  # same as view_filters.data.abs().sum(dim=1)
+        filter_mags = view_filters.data.abs().mean(dim=1)
         topk_filters = int(fraction_to_prune * filter_mags.size(0))
         if topk_filters == 0:
             msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
             return
-
         bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
         threshold = bottomk[-1]
-        binary_map = filter_mags.gt(threshold).type(param.data.type())
-        expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
-        zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3))
+        # Then we threshold
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs')
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, topk_filters, filter_mags.size(0))
+
+    @staticmethod
+    def rank_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict, model=None):
+        """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
+
+        PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
+        transposed.  This is counter-intuitive.  To deal with this, we can either transpose the matrix and
+        then proceed to compute the masks as usual, or we can treat columns as rows, and rows as columns :-(.
+        We choose the latter, because transposing very large matrices can be detrimental to performance.  Note
+    that computing mean L1-norm of columns is also not optimal, because consecutive column elements are far
+        away from each other in memory, and this means poor use of caches and system memory.
+        """
+
+        assert param.dim() == 2, "This thresholding is only supported for 2D weights"
+        ROWS_DIM = 0
+        THRESHOLD_DIM = 'Cols'
+        rows_mags = param.abs().mean(dim=ROWS_DIM)
+        num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
+        if num_rows_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
+            return
+        bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
+        threshold = bottomk_rows[-1]
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, 'Mean_Abs')
+        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+                       distiller.sparsity(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
+
+
+class RankedFiltersParameterPruner(RankedStructureParameterPruner):
+    """Base class for the special (but often-used) case of ranking filters
+    """
+    def __init__(self, name, reg_regims):
+        super(RankedFiltersParameterPruner, self).__init__(name, reg_regims)
+
+    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
+        if param_name not in self.reg_regims.keys():
+            return
+
+        group_type = self.reg_regims[param_name][1]
+        fraction_to_prune = self.reg_regims[param_name][0]
+        if fraction_to_prune == 0:
+            return
+
+        if group_type in ['3D', 'Filters']:
+            return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, meta['model'])
+        else:
+            raise ValueError("Currently only filter (3D) ranking is supported")
+
+    @staticmethod
+    def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters):
+        """Expand a 1D list of retained filter indices into a 4D binary mask for 'param'."""
+        binary_map = torch.zeros(num_filters).cuda()
+        binary_map[filters_ordered_by_criterion] = 1
+        expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
+        return expanded.view(param.shape)
+
+
+class ActivationAPoZRankedFilterPruner(RankedFiltersParameterPruner):
+    """Uses mean APoZ (average percentage of zeros) activation channels to rank structures
+    and prune a specified percentage of structures.
+
+    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures",
+    Hengyuan Hu, Rui Peng, Yu-Wing Tai, Chi-Keung Tang, ICLR 2016
+    https://arxiv.org/abs/1607.03250
+    """
+    def __init__(self, name, reg_regims):
+        super(ActivationAPoZRankedFilterPruner, self).__init__(name, reg_regims)
+
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+
+        # Use the parameter name to locate the module that has the activation sparsity statistics
+        fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
+        module = distiller.find_module_by_fq_name(model, fq_name)
+        if module is None:
+            raise ValueError("Could not find a layer named %s in the model."
+                             "\nMake sure to use assign_layer_fq_names()" % fq_name)
+        if not hasattr(module, 'apoz_channels'):
+            raise ValueError("Could not find attribute \'apoz_channels\' in module %s."
+                             "\nMake sure to use SummaryActivationStatsCollector(\"apoz_channels\")" % fq_name)
+
+        apoz, std = module.apoz_channels.value()
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+        if num_filters_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            return
+
+        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
+        filters_ordered_by_apoz = np.argsort(-apoz)[:-num_filters_to_prune]
+        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_apoz,
+                                                                                               param, num_filters)
+
+        msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       param_name,
+                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_filters_to_prune, num_filters)
+
+
+class RandomRankedFilterPruner(RankedFiltersParameterPruner):
+    """A Random raanking of filters.
+
+    This is used for sanity testing of other algorithms.
+    """
+    def __init__(self, name, reg_regims):
+        super(RandomRankedFilterPruner, self).__init__(name, reg_regims)
+
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+
+        if num_filters_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            return
+
+        filters_ordered_randomly = np.random.permutation(num_filters)[:-num_filters_to_prune]
+        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_randomly,
+                                                                                               param, num_filters)
+
+        msglogger.info("RandomRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       param_name,
+                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_filters_to_prune, num_filters)
+
+
+class GradientRankedFilterPruner(RankedFiltersParameterPruner):
+    """
+    """
+    def __init__(self, name, reg_regims):
+        super(RandomRankedFilterPruner, self).__init__(name, reg_regims)
+
+    @staticmethod
+    def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model):
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+        if num_filters_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            return
+
+        # Compute the element-wise product of the filters and the filter gradients
+        view_filters = param.view(param.size(0), -1)
+        view_filter_grads = param.grad.view(param.size(0), -1)
+        weighted_gradients = view_filter_grads * view_filters
+        weighted_gradients = weighted_gradients.sum(dim=1)
+
+        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
+        filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
+        zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_gradient,
+                                                                                               param, num_filters)
+        msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       param_name,
+                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, num_filters_to_prune, num_filters)
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 4cb8189..38c9923 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -20,6 +20,7 @@ The code below supports fine-grained tensor thresholding and group-wise threshol
 """
 import torch
 
+
 def threshold_mask(weights, threshold):
     """Create a threshold mask for the provided parameter tensor using
     magnitude thresholding.
@@ -32,95 +33,104 @@ def threshold_mask(weights, threshold):
     """
     return torch.gt(torch.abs(weights), threshold).type(weights.type())
 
+
 class GroupThresholdMixin(object):
-    """A mixin class to add group thresholding capabilities"""
+    """A mixin class to add group thresholding capabilities
 
+    TODO: this does not need to be a mixin - it should be made a simple function.  We keep this until we refactor
+    """
     def group_threshold_mask(self, param, group_type, threshold, threshold_criteria):
-        """Return a threshold mask for the provided parameter and group type.
-
-        Args:
-            param: The parameter to mask
-            group_type: The elements grouping type (structure).
-              One of:2D, 3D, 4D, Channels, Row, Cols
-            threshold: The threshold
-            threshold_criteria: The thresholding criteria.
-              'Mean_Abs' thresholds the entire element group using the mean of the
-              absolute values of the tensor elements.
-              'Max' thresholds the entire group using the magnitude of the largest
-              element in the group.
-        """
-        if group_type == '2D':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            view_2d = param.view(-1, param.size(2) * param.size(3))
-            # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
-            #    thresholds tensor with length = #IFMs * # OFMs
-            thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda()
-            # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
-            #    elements in each channel as the threshold filter.
-            # 3. Apply the threshold filter
-            binary_map = self.threshold_policy(view_2d, thresholds, threshold_criteria)
-            # 3. Finally, expand the thresholds and view as a 4D tensor
-            a = binary_map.expand(param.size(2) * param.size(3),
-                                  param.size(0) * param.size(1)).t()
-            return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
-
-        elif group_type == 'Rows':
-            assert param.dim() == 2, "This regularization is only supported for 2D weights"
-            thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
-            binary_map = self.threshold_policy(param, thresholds, threshold_criteria)
-            return binary_map.expand(param.size(1), param.size(0)).t()
-
-        elif group_type == 'Cols':
-            assert param.dim() == 2, "This regularization is only supported for 2D weights"
-            thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
-            binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0)
-            return binary_map.expand(param.size(0), param.size(1))
-
-        elif group_type == '3D' or group_type == 'Filters':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            view_filters = param.view(param.size(0), -1)
-            thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
-            binary_map = self.threshold_policy(view_filters, thresholds, threshold_criteria)
-            a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
-            return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
-
-        elif group_type == '4D':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            if threshold_criteria == 'Mean_Abs':
-                if param.data.abs().mean() > threshold:
-                    return None
-                return torch.zeros_like(param.data)
-            elif threshold_criteria == 'Max':
-                if param.data.abs().max() > threshold:
-                    return None
-                return torch.zeros_like(param.data)
-            exit("Invalid threshold_criteria {}".format(self.threshold_criteria))
-
-        elif group_type == 'Channels':
-            assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-            num_filters = param.size(0)
-            num_kernels_per_filter = param.size(1)
-
-            view_2d = param.view(-1, param.size(2) * param.size(3))
-            # Next, compute the sum of the squares (of the elements in each row/kernel)
-            kernel_means = view_2d.abs().mean(dim=1)
-            k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
-            thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
-            binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
-
-            # Now let's expand back up to a 4D mask
-            a = binary_map.expand(num_filters, num_kernels_per_filter)
-            c = a.unsqueeze(-1)
-            d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
-            return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
-
-
-    def threshold_policy(self, weights, thresholds, threshold_criteria, dim=1):
-        """
-        """
+        return group_threshold_mask(param, group_type, threshold, threshold_criteria)
+
+
+def group_threshold_mask(param, group_type, threshold, threshold_criteria):
+    """Return a threshold mask for the provided parameter and group type.
+
+    Args:
+        param: The parameter to mask
+        group_type: The elements grouping type (structure).
+          One of: 2D, 3D, 4D, Channels, Rows, Cols
+        threshold: The threshold
+        threshold_criteria: The thresholding criteria.
+          'Mean_Abs' thresholds the entire element group using the mean of the
+          absolute values of the tensor elements.
+          'Max' thresholds the entire group using the magnitude of the largest
+          element in the group.
+    """
+    if group_type == '2D':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        view_2d = param.view(-1, param.size(2) * param.size(3))
+        # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
+        #    thresholds tensor with length = #IFMs * # OFMs
+        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).cuda()
+        # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
+        #    elements in each channel as the threshold filter.
+        # 3. Apply the threshold filter
+        binary_map = threshold_policy(view_2d, thresholds, threshold_criteria)
+        # 4. Finally, expand the binary map and view it as a 4D tensor
+        a = binary_map.expand(param.size(2) * param.size(3),
+                              param.size(0) * param.size(1)).t()
+        return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
+
+    elif group_type == 'Rows':
+        assert param.dim() == 2, "This regularization is only supported for 2D weights"
+        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        binary_map = threshold_policy(param, thresholds, threshold_criteria)
+        return binary_map.expand(param.size(1), param.size(0)).t()
+
+    elif group_type == 'Cols':
+        assert param.dim() == 2, "This regularization is only supported for 2D weights"
+        thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
+        binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0)
+        return binary_map.expand(param.size(0), param.size(1))
+
+    elif group_type == '3D' or group_type == 'Filters':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        view_filters = param.view(param.size(0), -1)
+        thresholds = torch.Tensor([threshold] * param.size(0)).cuda()
+        binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
+        a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
+        return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
+
+    elif group_type == '4D':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         if threshold_criteria == 'Mean_Abs':
-            return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+            if param.data.abs().mean() > threshold:
+                return None
+            return torch.zeros_like(param.data)
         elif threshold_criteria == 'Max':
-            maxv, _ = weights.data.abs().max(dim=dim)
-            return maxv.gt(thresholds).type(weights.type())
+            if param.data.abs().max() > threshold:
+                return None
+            return torch.zeros_like(param.data)
         exit("Invalid threshold_criteria {}".format(threshold_criteria))
+
+    elif group_type == 'Channels':
+        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_kernels_per_filter = param.size(1)
+
+        view_2d = param.view(-1, param.size(2) * param.size(3))
+        # Next, compute the mean of the absolute values of the elements in each row (kernel)
+        kernel_means = view_2d.abs().mean(dim=1)
+        k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
+        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda()
+        binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
+
+        # Now let's expand back up to a 4D mask
+        a = binary_map.expand(num_filters, num_kernels_per_filter)
+        c = a.unsqueeze(-1)
+        d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
+        return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
+
+
+def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
+    """
+    """
+    if threshold_criteria == 'Mean_Abs':
+        return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'L1':
+        return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'Max':
+        maxv, _ = weights.data.abs().max(dim=dim)
+        return maxv.gt(thresholds).type(weights.type())
+    exit("Invalid threshold_criteria {}".format(threshold_criteria))
diff --git a/distiller/utils.py b/distiller/utils.py
index a193dfc..4676474 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -21,7 +21,6 @@ with some random helper functions.
 """
 import numpy as np
 import torch
-from torch.autograd import Variable
 import torch.nn as nn
 from copy import deepcopy
 
@@ -401,23 +400,6 @@ def has_children(module):
         return False
 
 
-class DoNothingModuleWrapper(nn.Module):
-    """Implement a nn.Module which wraps another nn.Module.
-
-    The DoNothingModuleWrapper wrapper does nothing but forward
-    to the wrapped module.
-    One use-case for this class, is for replacing nn.DataParallel
-    by a module that does nothing :-).  This is a trick we use
-    to transform data-parallel to serialized models.
-    """
-    def __init__(self, module):
-        super(DoNothingModuleWrapper, self).__init__()
-        self.module = module
-
-    def forward(self, *inputs, **kwargs):
-        return self.module(*inputs, **kwargs)
-
-
 def make_non_parallel_copy(model):
     """Make a non-data-parallel copy of the provided model.
 
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
index 83cc2fd..406180e 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -1,53 +1,73 @@
+# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
+# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
+# 2. Fine-grained pruning to reduce the parameter memory requirements of layers with large weight tensors.
+# 3. Row pruning for the last linear (fully-connected) layer.
+#
 # Baseline results:
-# Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Total MACs: 40,813,184
+#
+# Results:
+#     Top1: 91.760    Top5: 99.700    Loss: 1.546
+#     Total MACs: 35,947,136
+#     Total sparsity: 41.10
+#
 # time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 # |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.39017 | -0.00681 |    0.27798 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14674 | -0.00888 |    0.10358 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14363 |  0.00146 |    0.10571 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12673 | -0.01323 |    0.09655 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11736 | -0.00420 |    0.09039 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16400 | -0.00959 |    0.12023 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13288 | -0.00014 |    0.10020 |
-# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14688 | -0.00195 |    0.11372 |
-# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12828 | -0.00643 |    0.10049 |
-# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25453 | -0.00949 |    0.17990 |
-# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10884 | -0.00760 |    0.08639 |
-# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09702 | -0.00599 |    0.07635 |
-# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11464 | -0.01339 |    0.09051 |
-# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09177 |  0.00195 |    0.07188 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09764 | -0.00680 |    0.07753 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09308 | -0.00392 |    0.07406 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12596 | -0.00848 |    0.09993 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  6.49414 |  0.00000 |   69.99783 | 0.07444 | -0.00396 |    0.03728 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.49512 |  0.00000 |   69.99783 | 0.06792 | -0.00462 |    0.03381 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  9.81445 |  0.00000 |   69.99783 | 0.06811 | -0.00477 |    0.03417 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 26.00098 |  0.00000 |   69.99783 | 0.03877 |  0.00056 |    0.01954 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.56077 | -0.00002 |    0.48798 |
-# | 22 | Total sparsity:                     | -              |        251888 |         148672 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   40.97694 | 0.00000 |  0.00000 |    0.00000 |
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.26754 | -0.00478 |    0.18996 |
+# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10113 | -0.00595 |    0.07182 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09882 | -0.00013 |    0.07256 |
+# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08715 | -0.01028 |    0.06691 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08150 | -0.00316 |    0.06242 |
+# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11227 | -0.00627 |    0.08206 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09145 |  0.00145 |    0.06919 |
+# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09975 | -0.00178 |    0.07747 |
+# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08692 | -0.00438 |    0.06784 |
+# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17339 | -0.00644 |    0.12457 |
+# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07515 | -0.00582 |    0.05967 |
+# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06694 | -0.00409 |    0.05272 |
+# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07822 | -0.00873 |    0.06161 |
+# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06251 |  0.00119 |    0.04923 |
+# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06655 | -0.00436 |    0.05293 |
+# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06298 | -0.00286 |    0.05019 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08574 | -0.00490 |    0.06750 |
+# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.00684 |  0.00000 |   69.99783 | 0.05113 | -0.00318 |    0.02568 |
+# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.64160 |  0.00000 |   69.99783 | 0.04585 | -0.00355 |    0.02293 |
+# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.88867 |  0.00000 |   69.99783 | 0.04487 | -0.00409 |    0.02258 |
+# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 31.51855 |  1.56250 |   69.99783 | 0.02512 |  0.00008 |    0.01251 |
+# | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.48359 | -0.00001 |    0.30379 |
+# | 22 | Total sparsity:                     | -              |        251888 |         148352 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.10398 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 40.98
+# Total sparsity: 41.10
 #
 # --- validate (epoch=359)-----------
 # 5000 samples (256 per mini-batch)
-# ==> Top1: 93.320    Top5: 99.880    Loss: 0.246
+# ==> Top1: 93.720    Top5: 99.880    Loss: 1.529
 #
-# ==> Best Top1: 93.740   On Epoch: 265
+# ==> Best Top1: 96.900   On Epoch: 181
 #
-# Saving checkpoint to: logs/2018.10.09-020359/checkpoint.pth.tar
+# Saving checkpoint to: logs/2018.10.15-111439/checkpoint.pth.tar
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 91.580    Top5: 99.710    Loss: 0.355
+# ==> Top1: 91.760    Top5: 99.700    Loss: 1.546
+#
 #
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-111439/2018.10.15-111439.log
+#
+# real    31m55.802s
+# user    73m1.353s
+# sys     8m46.687s
+
 
 version: 1
+
 pruners:
   low_pruner:
-    class: StructuredAutomatedGradualPruner
+    class: L1RankedStructureParameterPruner_AGP
     initial_sparsity : 0.10
     final_sparsity: 0.40
     reg_regims:
@@ -62,6 +82,13 @@ pruners:
     weights: [module.layer3.1.conv1.weight,  module.layer3.1.conv2.weight,
               module.layer3.2.conv1.weight,  module.layer3.2.conv2.weight]
 
+  fc_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity : 0.05
+    final_sparsity: 0.50
+    reg_regims:
+      module.fc.weight: Rows
+
 lr_schedulers:
   pruning_lr:
     class: StepLR
@@ -88,8 +115,15 @@ policies:
     starting_epoch: 200
     ending_epoch: 220
     frequency: 2
-  # Currently the thinner is disabled until the end, because it interacts with the sparsity
-  # goals of the StructuredAutomatedGradualPruner.
+
+  - pruner:
+      instance_name : fc_pruner
+    starting_epoch: 200
+    ending_epoch: 220
+    frequency: 2
+
+  # Currently the thinner is disabled until the structure pruner is done, because it interacts
+  # with the sparsity goals of the L1RankedStructureParameterPruner_AGP.
   # This can be fixed rather easily.
   # - extension:
   #     instance_name: net_thinner
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
index c91a157..730dd9a 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
@@ -1,60 +1,72 @@
+# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
+# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
+# 2. Fine-grained pruning to reduce the parameter memory requirements of layers with large weight tensors.
+# 3. Row pruning for the last linear (fully-connected) layer.
+#
 # Baseline results:
-# Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Total MACs: 40,813,184
+#
+# Results:
+#     Top1: 91.200    Top5: 99.660    Loss: 1.551
+#     Total MACs: 30,638,720
+#     Total sparsity: 41.84
+#
 # time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
 #
-# 
+#
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 # |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.39196 | -0.00533 |    0.27677 |
-# |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17516 | -0.01627 |    0.12761 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17375 |  0.00208 |    0.12753 |
-# |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14753 | -0.02355 |    0.11205 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13242 | -0.00280 |    0.10184 |
-# |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18848 | -0.00708 |    0.13828 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15502 | -0.00528 |    0.11709 |
-# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15266 | -0.00169 |    0.11773 |
-# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13070 | -0.00823 |    0.10204 |
-# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25380 | -0.01324 |    0.17815 |
-# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11349 | -0.00928 |    0.08977 |
-# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09904 | -0.00621 |    0.07856 |
-# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11538 | -0.01280 |    0.09106 |
-# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09239 |  0.00091 |    0.07236 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09853 | -0.00671 |    0.07821 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09391 | -0.00407 |    0.07466 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12660 | -0.00968 |    0.10101 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  6.56738 |  0.00000 |   69.99783 | 0.07488 | -0.00414 |    0.03739 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.69043 |  0.00000 |   69.99783 | 0.06839 | -0.00472 |    0.03404 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  9.47266 |  0.00000 |   69.99783 | 0.06867 | -0.00485 |    0.03450 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 26.41602 |  0.00000 |   69.99783 | 0.03915 |  0.00033 |    0.01970 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.56522 | -0.00002 |    0.49040 |
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.27315 | -0.00387 |    0.19394 |
+# |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12038 | -0.01295 |    0.08811 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11879 | -0.00031 |    0.08735 |
+# |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10293 | -0.01274 |    0.07795 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09285 | -0.00276 |    0.07141 |
+# |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12849 | -0.00345 |    0.09355 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10689 | -0.00381 |    0.08038 |
+# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10467 | -0.00371 |    0.08149 |
+# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08897 | -0.00502 |    0.06938 |
+# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17695 | -0.01111 |    0.12479 |
+# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07736 | -0.00531 |    0.06118 |
+# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06832 | -0.00404 |    0.05406 |
+# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07965 | -0.00904 |    0.06278 |
+# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06305 |  0.00122 |    0.04955 |
+# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06753 | -0.00459 |    0.05371 |
+# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06379 | -0.00297 |    0.05078 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08779 | -0.00584 |    0.06956 |
+# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.12891 |  0.00000 |   69.99783 | 0.05191 | -0.00319 |    0.02604 |
+# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.98340 |  0.00000 |   69.99783 | 0.04658 | -0.00360 |    0.02330 |
+# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.10742 |  0.00000 |   69.99783 | 0.04563 | -0.00393 |    0.02297 |
+# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 31.20117 |  0.00000 |   69.99783 | 0.02453 |  0.00005 |    0.01240 |
+# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.50451 | -0.00001 |    0.43698 |
 # | 22 | Total sparsity:                     | -              |        246704 |         143488 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.83799 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # Total sparsity: 41.84
 #
 # --- validate (epoch=359)-----------
 # 5000 samples (256 per mini-batch)
-# ==> Top1: 92.540    Top5: 99.960    Loss: 0.246
+# ==> Top1: 93.460    Top5: 99.800    Loss: 1.530
 #
-# ==> Best Top1: 93.580   On Epoch: 328
+# ==> Best Top1: 97.320   On Epoch: 180
 #
-# Saving checkpoint to: logs/2018.10.09-200709/checkpoint.pth.tar
+# Saving checkpoint to: logs/2018.10.15-115941/checkpoint.pth.tar
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 91.190    Top5: 99.660    Loss: 0.372
+# ==> Top1: 91.200    Top5: 99.660    Loss: 1.551
 #
 #
-# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.09-200709/2018.10.09-200709.log
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-115941/2018.10.15-115941.log
 #
-# real    32m23.439s
-# user    74m59.073s
-# sys     9m7.764s
+# real    32m31.997s
+# user    72m58.813s
+# sys     9m1.245s
 
 version: 1
 pruners:
   low_pruner:
-    class: StructuredAutomatedGradualPruner
+    class: L1RankedStructureParameterPruner_AGP
     initial_sparsity : 0.10
     final_sparsity: 0.40
     reg_regims:
@@ -99,7 +111,7 @@ policies:
     ending_epoch: 220
     frequency: 2
   # Currently the thinner is disabled until the end, because it interacts with the sparsity
-  # goals of the StructuredAutomatedGradualPruner.
+  # goals of the L1RankedStructureParameterPruner_AGP.
   # This can be fixed rather easily.
   # - extension:
   #     instance_name: net_thinner
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index e0687f1..dc0c5be 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -57,7 +57,7 @@ import os
 import sys
 import random
 import traceback
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from functools import partial
 import numpy as np
 import torch
@@ -75,7 +75,7 @@ except ImportError:
     sys.path.append(module_path)
     import distiller
 import apputils
-from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
+from distiller.data_loggers import *
 import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model
 
@@ -118,7 +118,7 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--act-stats', dest='activation_stats', action='store_true', default=False,
+parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
                     help='collect activation statistics (WARNING: this slows down training)')
 parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                     help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
@@ -153,6 +153,7 @@ parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='early
 
 distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)
 
+
 def check_pytorch_version():
     if torch.__version__ < '0.4.0':
         print("\nNOTICE:")
@@ -166,6 +167,48 @@ def check_pytorch_version():
         exit(1)
 
 
+def create_activation_stats_collectors(model, collection_phase):
+    """Create objects that collect activation statistics.
+
+    This is a utility function that creates four collectors:
+    1. Fine-grained sparsity levels of the activations
+    2. L1-magnitude of each of the activation channels
+    3. Average percentage of zeros (APoZ) of each of the activation channels
+    4. Detailed per-activation records (for torch.nn.Conv2d modules)
+
+    Args:
+        model - the model on which we want to collect statistics
+        collection_phase - the statistics collection phase: "train" (training),
+                "valid" (validation) or "test"; None disables collection
+
+    WARNING! Enabling activation statistics collection will significantly slow down training!
+    """
+    class missingdict(dict):
+        """This is a little trick to prevent KeyError"""
+        def __missing__(self, key):
+            return None  # note, does *not* set self[key] - we don't want defaultdict's behavior
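+        # e.g. missingdict()["absent"] returns None and leaves the dict empty,
+        # whereas defaultdict(lambda: None) would insert the key on access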
+
+    distiller.utils.assign_layer_fq_names(model)
+
+    activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()}
+    if collection_phase is None:
+        return activations_collectors
+    collectors = missingdict()
+    collectors["sparsity"] = SummaryActivationStatsCollector(model, "sparsity", distiller.utils.sparsity)
+    collectors["l1_channels"] = SummaryActivationStatsCollector(model, "l1_channels",
+                                                                distiller.utils.activation_channels_l1)
+    collectors["apoz_channels"] = SummaryActivationStatsCollector(model, "apoz_channels",
+                                                                  distiller.utils.activation_channels_apoz)
+    collectors["records"] = RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
+    activations_collectors[collection_phase] = collectors
+    return activations_collectors
+
+
+def save_collectors_data(collectors, directory):
+    """Utility function that saves all activation statistics to Excel workbooks
+    """
+    for name, collector in collectors.items():
+        collector.to_xlsx(os.path.join(directory, name))
+
+
 def main():
     global msglogger
     check_pytorch_version()
@@ -267,18 +310,13 @@ def main():
     msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                    len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))
 
-    activations_sparsity = None
-    if args.activation_stats:
-        # If your model has ReLU layers, then those layers have sparse activations.
-        # ActivationSparsityCollector will collect information about this sparsity.
-        # WARNING! Enabling activation sparsity collection will significantly slow down training!
-        activations_sparsity = ActivationSparsityCollector(model)
+    activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats)
 
     if args.sensitivity is not None:
         return sensitivity_analysis(model, criterion, test_loader, pylogger, args)
 
     if args.evaluate:
-        return evaluate_model(model, criterion, test_loader, pylogger, args)
+        return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args)
 
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
@@ -313,15 +351,20 @@ def main():
             compression_scheduler.on_epoch_begin(epoch)
 
         # Train for one epoch
-        train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
-              loggers=[tflogger, pylogger], args=args)
-        distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
-        if args.activation_stats:
-            distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger],
-                                              collector=activations_sparsity)
+        with collectors_context(activations_collectors["train"]) as collectors:
+            train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
+                  loggers=[tflogger, pylogger], args=args)
+            distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
+            distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
+                                                collector=collectors["sparsity"])
 
         # evaluate on validation set
-        top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
+        with collectors_context(activations_collectors["valid"]) as collectors:
+            top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
+            distiller.log_activation_statsitics(epoch, "valid", loggers=[tflogger],
+                                                collector=collectors["sparsity"])
+            save_collectors_data(collectors, msglogger.logdir)
+
         stats = ('Peformance/Validation/',
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
@@ -342,7 +385,7 @@ def main():
                                  args.name, msglogger.logdir)
 
     # Finally run results on the test set
-    test(test_loader, model, criterion, [pylogger], args=args)
+    test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
 
 
 OVERALL_LOSS_KEY = 'Overall Loss'
@@ -460,10 +503,15 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1):
     return _validate(val_loader, model, criterion, loggers, args, epoch)
 
 
-def test(test_loader, model, criterion, loggers, args):
+def test(test_loader, model, criterion, loggers, activations_collectors, args):
     """Model Test"""
     msglogger.info('--- test ---------------------')
-    return _validate(test_loader, model, criterion, loggers, args)
+
+    with collectors_context(activations_collectors["test"]) as collectors:
+        top1, top5, losses = _validate(test_loader, model, criterion, loggers, args)
+        distiller.log_activation_statsitics(-1, "test", loggers, collector=collectors['sparsity'])
+
+    return top1, top5, losses
 
 
 def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
@@ -606,7 +654,7 @@ def earlyexit_validate_loss(output, target, criterion, args):
                 args.exit_taken[args.num_exits-1] += 1
 
 
-def evaluate_model(model, criterion, test_loader, loggers, args):
+def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
     # This sample application can be invoked to evaluate the accuracy of your model on
     # the test dataset.
     # You can optionally quantize the model to 8-bit integer before evaluation.
@@ -621,7 +669,9 @@ def evaluate_model(model, criterion, test_loader, loggers, args):
         quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
         quantizer.prepare_model()
         model.cuda()
-    top1, _, _ = test(test_loader, model, criterion, loggers, args=args)
+
+    top1, _, _ = test(test_loader, model, criterion, loggers, activations_collectors, args=args)
+
     if args.quantize:
         checkpoint_name = 'quantized'
         apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
@@ -645,7 +695,8 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args):
     if not isinstance(loggers, list):
         loggers = [loggers]
     test_fnc = partial(test, test_loader=data_loader, criterion=criterion,
-                       loggers=loggers, args=args)
+                       loggers=loggers, args=args,
+                       activations_collectors=create_activation_stats_collectors(model, None))
     which_params = [param_name for param_name, _ in model.named_parameters()]
     sensitivity = distiller.perform_sensitivity_analysis(model,
                                                          net_params=which_params,
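
Taken together, the collection flow that the sample application now follows can be
sketched as below. This is illustrative only: the names are those introduced by the
patch, the toy model stands in for the application's real network, and it assumes
collectors_context() starts and stops the collectors around the forward passes, as
the training-loop changes above suggest:

    import torch
    import distiller
    from distiller.data_loggers import SummaryActivationStatsCollector, collectors_context

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    distiller.utils.assign_layer_fq_names(model)    # collectors key their stats by module name
    collectors = {"sparsity": SummaryActivationStatsCollector(model, "sparsity",
                                                              distiller.utils.sparsity)}
    with collectors_context(collectors):
        model(torch.randn(16, 3, 32, 32))           # forward passes feed the collector hooks
    collectors["sparsity"].to_xlsx("sparsity")      # export, as save_collectors_data() does
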
diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
new file mode 100755
index 0000000..1e9e455
--- /dev/null
+++ b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
@@ -0,0 +1,212 @@
+#
+# This schedule uses the average percentage of zeros (APoZ) in the activations to rank filters.
+# Compare this to examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml - the pruning time
+# here is much longer, due to the callbacks required for collecting the activation statistics (this can be improved,
+# for example, by disabling the detailed records collection).
+# This reduces compute to 62.7% of the baseline MACs (x1.6) while slightly increasing the Top1 accuracy.
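+# (The APoZ of an output channel is the fraction of its post-ReLU activations that are zero, averaged over the
+# evaluation samples; filters whose output channels are mostly zero are considered least important and are
+# pruned first.)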
+#
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.364
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 93.030    Top5: 99.650    Loss: 1.533
+#     Total MACs: 78,856,832
+#
+#
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25444 |  0.01128 |    0.13307 |
+# |  1 | module.layer1.0.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07351 |  0.00182 |    0.04119 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07510 | -0.00968 |    0.05190 |
+# |  3 | module.layer1.1.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06982 |  0.00599 |    0.04476 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05886 | -0.01451 |    0.04284 |
+# |  5 | module.layer1.2.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06894 | -0.00031 |    0.04735 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06561 | -0.00311 |    0.04952 |
+# |  7 | module.layer1.3.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07374 | -0.00087 |    0.05137 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07346 | -0.00474 |    0.05348 |
+# |  9 | module.layer1.4.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06855 |  0.00053 |    0.04867 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07078 | -0.01038 |    0.05366 |
+# | 11 | module.layer1.5.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09372 | -0.00430 |    0.06283 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09056 | -0.00089 |    0.06517 |
+# | 13 | module.layer1.6.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08050 | -0.00971 |    0.06157 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08000 | -0.00081 |    0.06004 |
+# | 15 | module.layer1.7.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09966 | -0.01270 |    0.07424 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09293 |  0.00685 |    0.07128 |
+# | 17 | module.layer1.8.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08764 | -0.01361 |    0.06730 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07053 |  0.00491 |    0.05341 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09426 | -0.00345 |    0.07094 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07798 | -0.00154 |    0.05783 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16513 |  0.00688 |    0.11354 |
+# | 24 | module.layer2.2.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05589 | -0.00558 |    0.04355 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04979 | -0.00491 |    0.03863 |
+# | 26 | module.layer2.3.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05622 | -0.00471 |    0.04379 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04541 | -0.00271 |    0.03535 |
+# | 28 | module.layer2.4.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05166 | -0.00597 |    0.03896 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04098 | -0.00381 |    0.03114 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04188 | -0.00373 |    0.03040 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03249 | -0.00190 |    0.02291 |
+# | 32 | module.layer2.6.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04584 | -0.00553 |    0.03569 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03655 | -0.00216 |    0.02758 |
+# | 34 | module.layer2.7.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05110 | -0.00700 |    0.03909 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03884 | -0.00129 |    0.02946 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03331 | -0.00269 |    0.02211 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02406 | -0.00014 |    0.01479 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05957 | -0.00091 |    0.04658 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05103 | -0.00016 |    0.03729 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09200 |  0.00203 |    0.06440 |
+# | 41 | module.layer3.1.conv1.weight        | (58, 64, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03758 | -0.00117 |    0.02728 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 58, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03577 | -0.00397 |    0.02686 |
+# | 43 | module.layer3.2.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03704 | -0.00146 |    0.02762 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03409 | -0.00464 |    0.02638 |
+# | 45 | module.layer3.3.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03887 | -0.00274 |    0.03015 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03390 | -0.00448 |    0.02648 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04296 | -0.00361 |    0.03345 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03454 | -0.00255 |    0.02628 |
+# | 49 | module.layer3.5.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04208 | -0.00441 |    0.03301 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03186 | -0.00319 |    0.02431 |
+# | 51 | module.layer3.6.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03119 | -0.00262 |    0.02419 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02298 | -0.00015 |    0.01670 |
+# | 53 | module.layer3.7.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02912 | -0.00265 |    0.02235 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02078 | -0.00010 |    0.01524 |
+# | 55 | module.layer3.8.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03270 | -0.00269 |    0.02542 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02244 |  0.00045 |    0.01630 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42577 | -0.00001 |    0.33523 |
+# | 58 | Total sparsity:                     | -              |        634640 |         634640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 0.00
+#
+# --- validate (epoch=249)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 92.740    Top5: 99.720    Loss: 1.534
+#
+# ==> Best Top1: 92.760   On Epoch: 237
+#
+# Saving checkpoint to: logs/2018.10.16-013006/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 93.030    Top5: 99.650    Loss: 1.533
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-013006/2018.10.16-013006.log
+#
+# real    49m0.623s
+# user    90m51.054s
+# sys     8m36.745s
+
+version: 1
+pruners:
+  filter_pruner_60:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.6
+    reg_regims:
+      module.layer1.0.conv1.weight: Filters
+      module.layer1.1.conv1.weight: Filters
+      module.layer1.2.conv1.weight: Filters
+      module.layer1.3.conv1.weight: Filters
+      module.layer1.4.conv1.weight: Filters
+      module.layer1.5.conv1.weight: Filters
+      module.layer1.6.conv1.weight: Filters
+      module.layer1.7.conv1.weight: Filters
+      module.layer1.8.conv1.weight: Filters
+
+  filter_pruner_50:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.5
+    reg_regims:
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+      module.layer2.3.conv1.weight: Filters
+      module.layer2.4.conv1.weight: Filters
+      module.layer2.6.conv1.weight: Filters
+      module.layer2.7.conv1.weight: Filters
+
+  filter_pruner_10:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0
+    final_sparsity: 0.1
+    reg_regims:
+      module.layer3.1.conv1.weight: Filters
+
+  filter_pruner_30:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.3
+    reg_regims:
+        module.layer3.2.conv1.weight: Filters
+        module.layer3.3.conv1.weight: Filters
+        module.layer3.5.conv1.weight: Filters
+        module.layer3.6.conv1.weight: Filters
+        module.layer3.7.conv1.weight: Filters
+        module.layer3.8.conv1.weight: Filters
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet56_cifar'
+      dataset: 'cifar10'
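+      # The thinner physically removes the zeroed filters from the model (see the
+      # 'epochs: [200]' policy below), which is what shrinks the parameter shapes and
+      # the MAC count reported in the table above.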
+
+lr_schedulers:
+   exp_finetuning_lr:
+     class: ExponentialLR
+     gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name: filter_pruner_60
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_50
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_30
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_10
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - extension:
+      instance_name: net_thinner
+    epochs: [200]
+
+  - lr_scheduler:
+      instance_name: exp_finetuning_lr
+    starting_epoch: 190
+    ending_epoch: 300
+    frequency: 1
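
For reference, a rough standalone illustration of the per-channel APoZ statistic these
schedules rank filters by (the patch's actual implementation is
distiller.utils.activation_channels_apoz; apoz_per_channel below is only a sketch):

    import torch

    def apoz_per_channel(activation: torch.Tensor) -> torch.Tensor:
        """Average Percentage of Zeros per channel of a (N, C, H, W) activation."""
        n, c = activation.size(0), activation.size(1)
        flat = activation.view(n, c, -1)
        # Fraction of zero elements per (sample, channel), then average over the batch
        return (flat == 0).float().mean(dim=2).mean(dim=0)

    relu_out = torch.relu(torch.randn(8, 16, 32, 32))
    print(apoz_per_channel(relu_out))   # roughly 0.5 for each of the 16 channels
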
diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
new file mode 100755
index 0000000..1062520
--- /dev/null
+++ b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
@@ -0,0 +1,198 @@
+#
+# This schedule uses the average percentage of zeros (APoZ) in the activations to rank filters.
+# Compared to examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml, the pruning time is
+# much longer due to the callbacks required for collecting the activation statistics (this can be improved by
+# disabling the detailed records collection, for example).
+# This reduces compute to 53.9% of the original MACs (1.85x) with a small (0.26%) decrease in Top1.
+#
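+# For reference, the APoZ of an output channel c is the mean fraction of zero (post-ReLU)
+# activations in that channel over the evaluation data:
+#     APoZ(c) = (1 / (N*M)) * sum_{n=1..N} sum_{m=1..M} 1[a_c(n, m) == 0]
+# where N is the number of samples and M the number of spatial positions; filters whose
+# channels have the highest APoZ are pruned first.
+#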
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.364
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 92.590    Top5: 99.630    Loss: 1.537
+#     Total MACs: 67,797,632
+#
+#
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid
+#
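+# (--act-stats=valid enables activation-statistics collection on the validation set,
+# which the APoZ pruner needs in order to rank filters.)
+#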
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25855 |  0.00995 |    0.13518 |
+# |  1 | module.layer1.0.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08280 | -0.00200 |    0.04624 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08243 | -0.01280 |    0.05715 |
+# |  3 | module.layer1.1.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07365 |  0.00145 |    0.04511 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06361 | -0.01105 |    0.04664 |
+# |  5 | module.layer1.2.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07599 | -0.00007 |    0.05344 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07223 | -0.00321 |    0.05500 |
+# |  7 | module.layer1.3.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08744 |  0.00275 |    0.06146 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08422 | -0.01179 |    0.06336 |
+# |  9 | module.layer1.4.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07359 | -0.00328 |    0.05202 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07655 | -0.00830 |    0.05807 |
+# | 11 | module.layer1.5.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10343 | -0.00167 |    0.06877 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10276 | -0.00419 |    0.07489 |
+# | 13 | module.layer1.6.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08513 | -0.00885 |    0.06337 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08756 | -0.00307 |    0.06605 |
+# | 15 | module.layer1.7.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10795 | -0.01498 |    0.07969 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09896 |  0.00700 |    0.07549 |
+# | 17 | module.layer1.8.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09739 | -0.01620 |    0.07525 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07876 |  0.00671 |    0.05963 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09674 | -0.00379 |    0.07281 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07978 | -0.00164 |    0.05939 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16693 |  0.00597 |    0.11503 |
+# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06721 | -0.00298 |    0.05084 |
+# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05842 | -0.00357 |    0.04602 |
+# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05689 | -0.00687 |    0.04438 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05025 | -0.00454 |    0.03910 |
+# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05853 | -0.00510 |    0.04560 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04758 | -0.00213 |    0.03705 |
+# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05571 | -0.00672 |    0.04273 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04397 | -0.00459 |    0.03407 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04363 | -0.00326 |    0.03157 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03392 | -0.00203 |    0.02392 |
+# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04665 | -0.00607 |    0.03589 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03699 | -0.00209 |    0.02802 |
+# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05658 | -0.00710 |    0.04287 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04273 | -0.00281 |    0.03222 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03511 | -0.00271 |    0.02330 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02510 | -0.00009 |    0.01561 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06056 | -0.00144 |    0.04736 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05184 | -0.00042 |    0.03792 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09288 |  0.00286 |    0.06527 |
+# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03912 | -0.00155 |    0.02842 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03727 | -0.00412 |    0.02804 |
+# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03857 | -0.00132 |    0.02876 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03551 | -0.00487 |    0.02752 |
+# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04020 | -0.00281 |    0.03119 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03524 | -0.00470 |    0.02756 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04392 | -0.00341 |    0.03419 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03533 | -0.00286 |    0.02693 |
+# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04352 | -0.00462 |    0.03410 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03269 | -0.00326 |    0.02494 |
+# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03254 | -0.00281 |    0.02521 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02411 | -0.00045 |    0.01755 |
+# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03012 | -0.00277 |    0.02301 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02162 | -0.00051 |    0.01584 |
+# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03413 | -0.00266 |    0.02653 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02345 |  0.00058 |    0.01716 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43253 | -0.00001 |    0.34022 |
+# | 58 | Total sparsity:                     | -              |        570704 |         570704 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 0.00
+#
+# --- validate (epoch=249)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 91.880    Top5: 99.520    Loss: 1.543
+#
+# ==> Best Top1: 92.680   On Epoch: 237
+#
+# Saving checkpoint to: logs/2018.10.16-115506/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 92.590    Top5: 99.630    Loss: 1.537
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-115506/2018.10.16-115506.log
+#
+# real    49m10.711s
+# user    91m13.215s
+# sys     8m34.108s
+
+version: 1
+pruners:
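+  # This v2 schedule targets one sparsity step (10%) more per layer group than
+  # resnet56_cifar_activation_apoz.yaml: 0.7/0.6/0.4/0.2 instead of 0.6/0.5/0.3/0.1.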
+  filter_pruner_70:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity: 0.10
+    final_sparsity: 0.7
+    reg_regims:
+      module.layer1.0.conv1.weight: Filters
+      module.layer1.1.conv1.weight: Filters
+      module.layer1.2.conv1.weight: Filters
+      module.layer1.3.conv1.weight: Filters
+      module.layer1.4.conv1.weight: Filters
+      module.layer1.5.conv1.weight: Filters
+      module.layer1.6.conv1.weight: Filters
+      module.layer1.7.conv1.weight: Filters
+      module.layer1.8.conv1.weight: Filters
+
+  filter_pruner_60:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity: 0.10
+    final_sparsity: 0.6
+    reg_regims:
+      module.layer2.1.conv1.weight: Filters
+      module.layer2.2.conv1.weight: Filters
+      module.layer2.3.conv1.weight: Filters
+      module.layer2.4.conv1.weight: Filters
+      module.layer2.6.conv1.weight: Filters
+      module.layer2.7.conv1.weight: Filters
+
+  filter_pruner_20:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity: 0
+    final_sparsity: 0.2
+    reg_regims:
+      module.layer3.1.conv1.weight: Filters
+
+  filter_pruner_40:
+    class: ActivationAPoZRankedFilterPruner_AGP
+    initial_sparsity: 0.10
+    final_sparsity: 0.4
+    reg_regims:
+      module.layer3.2.conv1.weight: Filters
+      module.layer3.3.conv1.weight: Filters
+      module.layer3.5.conv1.weight: Filters
+      module.layer3.6.conv1.weight: Filters
+      module.layer3.7.conv1.weight: Filters
+      module.layer3.8.conv1.weight: Filters
+
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet56_cifar'
+      dataset: 'cifar10'
+
+lr_schedulers:
+   exp_finetuning_lr:
+     class: ExponentialLR
+     gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name: filter_pruner_70
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_60
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_40
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - pruner:
+      instance_name: filter_pruner_20
+    starting_epoch: 181
+    ending_epoch: 200
+    frequency: 2
+
+  - extension:
+      instance_name: net_thinner
+    epochs: [200]
+
+  - lr_scheduler:
+      instance_name: exp_finetuning_lr
+    starting_epoch: 190
+    ending_epoch: 300
+    frequency: 1
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
index a3c282c..08b2b01 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
@@ -1,8 +1,8 @@
-#
 # This schedule performs filter ranking and removal, for the convolution layers in ResNet56-CIFAR, as described in
-# Pruning Filters for Efficient Convnets.
-# Filters are ranked and pruned accordingly.
+# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
+# ICLR 2017, arXiv:1608.08710
 #
+# Filters are ranked and pruned accordingly.
 # This is followed by network thinning which removes the filters entirely from the model, and changes the convolution
 # layers' dimensions accordingly.  Convolution layers that follow have their respective channels removed as well, as do
 # Batch normalization layers.
@@ -18,6 +18,14 @@
 #
 # Results: 62.7% of the original convolution MACs (when calculated using direct convolution)
 #
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.464
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 92.830    Top5: 99.760    Loss: 0.489
+#     Total MACs: 78,856,832
+#
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
@@ -79,19 +87,19 @@
 # | 54 | module.layer3.7.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02753 |  0.00051 |    0.02028 |
 # | 55 | module.layer3.8.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02897 | -0.00308 |    0.02240 |
 # | 56 | module.layer3.8.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02130 | -0.00061 |    0.01556 |
-# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.47902 | -0.00002 |    0.37518 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.47902 | -0.00002 |    0.47518 |
 # | 58 | Total sparsity:                     | -              |        634640 |         634640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # Total sparsity: 0.00
 #
 # --- validate (epoch=249)-----------
 # 5000 samples (256 per mini-batch)
-# ==> Top1: 92.640    Top5: 99.820    Loss: 0.353
+# ==> Top1: 92.640    Top5: 99.820    Loss: 0.453
 #
 # Saving checkpoint
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 92.830    Top5: 99.760    Loss: 0.389
+# ==> Top1: 92.830    Top5: 99.760    Loss: 0.489
 #
 #
 # Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/private-distiller/examples/classifier_compression/logs/2018.04.17-005852/2018.04.17-005852.log
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
new file mode 100755
index 0000000..d0f58ce
--- /dev/null
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
@@ -0,0 +1,207 @@
+#
+# This schedule performs filter ranking and removal, for the convolution layers in ResNet56-CIFAR, as described in
+# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
+# ICLR 2017, arXiv:1608.08710
+#
+# Filters are ranked and pruned accordingly.
+# This is followed by network thinning which removes the filters entirely from the model, and changes the convolution
+# layers' dimensions accordingly.  Convolution layers that follow have their respective channels removed as well, as do
+# Batch normalization layers.
+#
+# The authors write that: "Since there is no projection mapping for choosing the identity feature maps, we only
+# consider pruning the first layer of the residual block."
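+# Accordingly, only the first convolution (conv1) of each residual block is listed in
+# reg_regims below.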
+#
+# Note that to use the command-line below, you will need the baseline ResNet56 model (checkpoint.resnet56_cifar_baseline.pth.tar).
+# You may either train this model from scratch, or download it from the link below.
+# https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
+#
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic
+#
+# Results: 53.9% (1.85x) of the original convolution MACs (when calculated using direct convolution)
+#
+# Baseline results:
+#     Top1: 92.850    Top5: 99.780    Loss: 0.464
+#     Total MACs: 125,747,840
+#
+# Results:
+#     Top1: 92.740    Top5: 99.640    Loss: 1.534
+#     Total MACs: 67,797,632
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25775 |  0.01021 |    0.13389 |
+# |  1 | module.layer1.0.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08400 |  0.00016 |    0.04778 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08528 | -0.00987 |    0.05921 |
+# |  3 | module.layer1.1.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08004 |  0.00482 |    0.05321 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06757 | -0.01286 |    0.04833 |
+# |  5 | module.layer1.2.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07856 |  0.00001 |    0.05491 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06927 | -0.00455 |    0.05305 |
+# |  7 | module.layer1.3.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08582 |  0.00157 |    0.06020 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08446 | -0.00188 |    0.06118 |
+# |  9 | module.layer1.4.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08725 |  0.00491 |    0.06379 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07957 | -0.01561 |    0.06278 |
+# | 11 | module.layer1.5.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10966 | -0.00952 |    0.07636 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10307 |  0.00342 |    0.07403 |
+# | 13 | module.layer1.6.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10200 | -0.00991 |    0.07828 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09630 |  0.00492 |    0.07303 |
+# | 15 | module.layer1.7.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11086 | -0.01292 |    0.08203 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10022 |  0.00847 |    0.07685 |
+# | 17 | module.layer1.8.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10323 | -0.01191 |    0.07876 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07648 |  0.00615 |    0.05865 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09559 | -0.00432 |    0.07201 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07934 | -0.00183 |    0.05900 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16645 |  0.00714 |    0.11508 |
+# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06975 | -0.00591 |    0.05336 |
+# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05907 | -0.00069 |    0.04646 |
+# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06476 | -0.00591 |    0.05015 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05564 | -0.00607 |    0.04301 |
+# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06246 | -0.00269 |    0.04888 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05009 | -0.00171 |    0.03892 |
+# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06047 | -0.00494 |    0.04774 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04690 | -0.00493 |    0.03661 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04318 | -0.00403 |    0.03144 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03371 | -0.00219 |    0.02386 |
+# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05296 | -0.00465 |    0.04163 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04041 | -0.00099 |    0.03073 |
+# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06190 | -0.00745 |    0.04871 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04535 |  0.00052 |    0.03475 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03378 | -0.00309 |    0.02261 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02457 | -0.00035 |    0.01523 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06005 | -0.00122 |    0.04697 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05147 | -0.00011 |    0.03767 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09252 |  0.00159 |    0.06504 |
+# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03913 | -0.00138 |    0.02846 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03736 | -0.00428 |    0.02826 |
+# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03982 | -0.00118 |    0.02987 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03651 | -0.00484 |    0.02836 |
+# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04171 | -0.00306 |    0.03253 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03619 | -0.00400 |    0.02820 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04342 | -0.00387 |    0.03380 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03494 | -0.00264 |    0.02668 |
+# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04597 | -0.00467 |    0.03630 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03378 | -0.00285 |    0.02578 |
+# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03428 | -0.00242 |    0.02700 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02501 | -0.00001 |    0.01828 |
+# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03189 | -0.00319 |    0.02489 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02260 | -0.00034 |    0.01673 |
+# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03693 | -0.00290 |    0.02890 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02474 |  0.00102 |    0.01800 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42898 | -0.00001 |    0.33739 |
+# | 58 | Total sparsity:                     | -              |        570704 |         570704 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# Total sparsity: 0.00
+#
+# --- validate (epoch=249)-----------
+# 5000 samples (256 per mini-batch)
+# ==> Top1: 92.180    Top5: 99.660    Loss: 1.540
+#
+# ==> Best Top1: 92.580   On Epoch: 238
+#
+# Saving checkpoint to: logs/2018.10.16-103816/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 92.740    Top5: 99.640    Loss: 1.534
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-103816/2018.10.16-103816.log
+#
+# real    29m38.833s
+# user    65m54.241s
+# sys     6m17.564s
+
+
+version: 1
+pruners:
+  filter_pruner:
+    class: 'L1RankedStructureParameterPruner'
+    reg_regims:
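+      # Each entry maps a weights tensor to [fraction-of-filters-to-prune, structure];
+      # '3D' selects whole filters, ranked here by their L1-norm.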
+      'module.layer1.0.conv1.weight': [0.7, '3D']
+      'module.layer1.1.conv1.weight': [0.7, '3D']
+      'module.layer1.2.conv1.weight': [0.7, '3D']
+      'module.layer1.3.conv1.weight': [0.7, '3D']
+      'module.layer1.4.conv1.weight': [0.7, '3D']
+      'module.layer1.5.conv1.weight': [0.7, '3D']
+      'module.layer1.6.conv1.weight': [0.7, '3D']
+      'module.layer1.7.conv1.weight': [0.7, '3D']
+      'module.layer1.8.conv1.weight': [0.7, '3D']
+
+      'module.layer2.1.conv1.weight': [0.6, '3D']
+      'module.layer2.2.conv1.weight': [0.6, '3D']
+      'module.layer2.3.conv1.weight': [0.6, '3D']
+      'module.layer2.4.conv1.weight': [0.6, '3D']
+      'module.layer2.6.conv1.weight': [0.6, '3D']
+      'module.layer2.7.conv1.weight': [0.6, '3D']
+
+      'module.layer3.1.conv1.weight': [0.2, '3D']
+      'module.layer3.2.conv1.weight': [0.4, '3D']
+      'module.layer3.3.conv1.weight': [0.4, '3D']
+      'module.layer3.5.conv1.weight': [0.4, '3D']
+      'module.layer3.6.conv1.weight': [0.4, '3D']
+      'module.layer3.7.conv1.weight': [0.4, '3D']
+      'module.layer3.8.conv1.weight': [0.4, '3D']
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet56_cifar'
+      dataset: 'cifar10'
+
+lr_schedulers:
+   exp_finetuning_lr:
+     class: ExponentialLR
+     gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name: filter_pruner
+    epochs: [180]
+
+  - extension:
+      instance_name: net_thinner
+    epochs: [180]
+
+  - lr_scheduler:
+      instance_name: exp_finetuning_lr
+    starting_epoch: 190
+    ending_epoch: 300
+    frequency: 1
diff --git a/examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv b/examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv
new file mode 100644
index 0000000..d9183c9
--- /dev/null
+++ b/examples/sensitivity-analysis/resnet50-imagenet/resnet50.imagenet.sensitivity_filter_wise.csv
@@ -0,0 +1,584 @@
+parameter,sparsity,top1,top5,loss
+module.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.conv1.weight,0.05,76.092,92.878,0.9647577350997194
+module.conv1.weight,0.1,75.83,92.88,0.9732017052568949
+module.conv1.weight,0.15000000000000002,75.544,92.672,0.9858137353190358
+module.conv1.weight,0.2,75.036,92.434,1.00316541261819
+module.conv1.weight,0.25,67.07600000000001,87.788,1.367336580947954
+module.conv1.weight,0.30000000000000004,63.832,85.552,1.5320777771424268
+module.conv1.weight,0.35000000000000003,62.72599999999999,84.932,1.5777477780166933
+module.conv1.weight,0.4,54.571999999999996,78.654,2.0279791546719426
+module.conv1.weight,0.45,45.67400000000001,69.792,2.5944640064726063
+module.conv1.weight,0.5,12.258000000000003,26.004000000000005,6.298526190981571
+module.layer1.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.conv1.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.conv1.weight,0.1,75.932,92.77,0.9672794009045679
+module.layer1.0.conv1.weight,0.15000000000000002,76.012,92.842,0.9668000321455148
+module.layer1.0.conv1.weight,0.2,75.316,92.622,0.9959884060128614
+module.layer1.0.conv1.weight,0.25,75.302,92.594,0.997434276403213
+module.layer1.0.conv1.weight,0.30000000000000004,75.18599999999999,92.47800000000001,1.0021765834975
+module.layer1.0.conv1.weight,0.35000000000000003,74.424,91.99,1.039658749438063
+module.layer1.0.conv1.weight,0.4,73.13799999999999,91.216,1.0942878348334713
+module.layer1.0.conv1.weight,0.45,70.798,89.90599999999999,1.1983419334401888
+module.layer1.0.conv1.weight,0.5,65.846,86.896,1.4433029138920255
+module.layer1.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.conv2.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.conv2.weight,0.1,76.14,92.86999999999999,0.9648298266894967
+module.layer1.0.conv2.weight,0.15000000000000002,76.088,92.83200000000001,0.9653575117034573
+module.layer1.0.conv2.weight,0.2,76.042,92.878,0.9661754479973899
+module.layer1.0.conv2.weight,0.25,75.82,92.824,0.9717823474534921
+module.layer1.0.conv2.weight,0.30000000000000004,75.726,92.704,0.9781641415795503
+module.layer1.0.conv2.weight,0.35000000000000003,73.08600000000001,91.33200000000001,1.098903458337394
+module.layer1.0.conv2.weight,0.4,73.046,91.366,1.1005394557604984
+module.layer1.0.conv2.weight,0.45,64.318,86.012,1.5271362866065938
+module.layer1.0.conv2.weight,0.5,64.916,86.41799999999999,1.4901175172049177
+module.layer1.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.conv3.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.conv3.weight,0.1,76.13,92.862,0.9646787382662295
+module.layer1.0.conv3.weight,0.15000000000000002,76.134,92.86,0.9646867231598922
+module.layer1.0.conv3.weight,0.2,76.13600000000001,92.876,0.9648272743334576
+module.layer1.0.conv3.weight,0.25,76.126,92.866,0.964804031487022
+module.layer1.0.conv3.weight,0.30000000000000004,76.042,92.892,0.966239919391822
+module.layer1.0.conv3.weight,0.35000000000000003,75.982,92.91,0.96753498220018
+module.layer1.0.conv3.weight,0.4,75.94,92.876,0.9695769828193038
+module.layer1.0.conv3.weight,0.45,75.86,92.81200000000001,0.9735630200225487
+module.layer1.0.conv3.weight,0.5,75.388,92.572,0.9887932550390158
+module.layer1.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.0.downsample.0.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.0.downsample.0.weight,0.1,76.13,92.862,0.9646787385703351
+module.layer1.0.downsample.0.weight,0.15000000000000002,76.084,92.852,0.9652105177543602
+module.layer1.0.downsample.0.weight,0.2,75.978,92.854,0.9662235065230306
+module.layer1.0.downsample.0.weight,0.25,75.778,92.784,0.9759335609114901
+module.layer1.0.downsample.0.weight,0.30000000000000004,74.91399999999999,92.336,1.0159372910857198
+module.layer1.0.downsample.0.weight,0.35000000000000003,71.514,90.47800000000001,1.1738625028911904
+module.layer1.0.downsample.0.weight,0.4,67.556,87.998,1.3625655029805341
+module.layer1.0.downsample.0.weight,0.45,61.681999999999995,84.002,1.6625876446463614
+module.layer1.0.downsample.0.weight,0.5,40.928,65.08800000000001,3.0252077293639297
+module.layer1.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.1.conv1.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer1.1.conv1.weight,0.1,76.098,92.876,0.9652870166666652
+module.layer1.1.conv1.weight,0.15000000000000002,76.078,92.88,0.9660209278214948
+module.layer1.1.conv1.weight,0.2,75.906,92.854,0.9702612833709134
+module.layer1.1.conv1.weight,0.25,75.874,92.838,0.9720629371550619
+module.layer1.1.conv1.weight,0.30000000000000004,75.71600000000001,92.784,0.9748685826756516
+module.layer1.1.conv1.weight,0.35000000000000003,75.516,92.636,0.9820135380996733
+module.layer1.1.conv1.weight,0.4,75.422,92.58999999999999,0.9858410236026562
+module.layer1.1.conv1.weight,0.45,75.342,92.552,0.9901515745690892
+module.layer1.1.conv1.weight,0.5,75.214,92.44,0.9973136783406443
+module.layer1.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.1.conv2.weight,0.05,76.074,92.866,0.9656340561959206
+module.layer1.1.conv2.weight,0.1,76.106,92.882,0.9665291072154533
+module.layer1.1.conv2.weight,0.15000000000000002,76.004,92.838,0.9674485273355128
+module.layer1.1.conv2.weight,0.2,75.74799999999999,92.798,0.973839068215112
+module.layer1.1.conv2.weight,0.25,75.754,92.774,0.976552191209428
+module.layer1.1.conv2.weight,0.30000000000000004,75.634,92.796,0.9795367948862973
+module.layer1.1.conv2.weight,0.35000000000000003,75.57000000000001,92.67800000000001,0.9839644473122093
+module.layer1.1.conv2.weight,0.4,75.566,92.636,0.9858285662318981
+module.layer1.1.conv2.weight,0.45,75.09599999999999,92.49000000000001,1.0086073198792886
+module.layer1.1.conv2.weight,0.5,74.634,92.238,1.0265772062904983
+module.layer1.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.1.conv3.weight,0.05,76.13,92.862,0.9646781417636237
+module.layer1.1.conv3.weight,0.1,76.13600000000001,92.862,0.9646657041597122
+module.layer1.1.conv3.weight,0.15000000000000002,76.128,92.862,0.9646495287211573
+module.layer1.1.conv3.weight,0.2,76.122,92.85799999999999,0.964540821633169
+module.layer1.1.conv3.weight,0.25,76.142,92.85600000000001,0.9644717303465824
+module.layer1.1.conv3.weight,0.30000000000000004,76.118,92.876,0.9645412101277283
+module.layer1.1.conv3.weight,0.35000000000000003,76.1,92.9,0.9652030228504112
+module.layer1.1.conv3.weight,0.4,76.056,92.876,0.9660147736419221
+module.layer1.1.conv3.weight,0.45,76.018,92.85,0.9658604563042826
+module.layer1.1.conv3.weight,0.5,75.958,92.862,0.9682643982220668
+module.layer1.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.2.conv1.weight,0.05,76.054,92.892,0.9656125142105988
+module.layer1.2.conv1.weight,0.1,76.058,92.84400000000001,0.9668843527229466
+module.layer1.2.conv1.weight,0.15000000000000002,76.08200000000001,92.872,0.9668020492001456
+module.layer1.2.conv1.weight,0.2,75.932,92.808,0.9682766031093746
+module.layer1.2.conv1.weight,0.25,75.862,92.75999999999999,0.9708546974829264
+module.layer1.2.conv1.weight,0.30000000000000004,75.75800000000001,92.718,0.9728293174079484
+module.layer1.2.conv1.weight,0.35000000000000003,75.71799999999999,92.7,0.9746806868637092
+module.layer1.2.conv1.weight,0.4,75.38199999999999,92.526,0.9875993205576526
+module.layer1.2.conv1.weight,0.45,75.18599999999999,92.452,0.9937687095026578
+module.layer1.2.conv1.weight,0.5,74.81400000000001,92.218,1.0118083756188954
+module.layer1.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.2.conv2.weight,0.05,75.91799999999999,92.826,0.9685604718266703
+module.layer1.2.conv2.weight,0.1,75.924,92.82000000000001,0.9695663569228992
+module.layer1.2.conv2.weight,0.15000000000000002,75.666,92.742,0.9778546592866886
+module.layer1.2.conv2.weight,0.2,75.606,92.684,0.9812715381992112
+module.layer1.2.conv2.weight,0.25,75.22999999999999,92.542,0.9929108396172521
+module.layer1.2.conv2.weight,0.30000000000000004,75.046,92.414,1.0019573495552248
+module.layer1.2.conv2.weight,0.35000000000000003,74.944,92.24600000000001,1.011767355230998
+module.layer1.2.conv2.weight,0.4,74.58,92.134,1.019189273672445
+module.layer1.2.conv2.weight,0.45,74.078,91.912,1.0405801353710036
+module.layer1.2.conv2.weight,0.5,73.89,91.752,1.049945247036461
+module.layer1.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer1.2.conv3.weight,0.05,76.12400000000001,92.86,0.9646835384168189
+module.layer1.2.conv3.weight,0.1,76.12400000000001,92.86,0.9647042825818065
+module.layer1.2.conv3.weight,0.15000000000000002,76.118,92.85600000000001,0.9646848114017325
+module.layer1.2.conv3.weight,0.2,76.132,92.854,0.9646729855056924
+module.layer1.2.conv3.weight,0.25,76.134,92.854,0.9646716789931671
+module.layer1.2.conv3.weight,0.30000000000000004,76.128,92.862,0.9646537449895121
+module.layer1.2.conv3.weight,0.35000000000000003,76.11,92.882,0.9645740807968741
+module.layer1.2.conv3.weight,0.4,76.102,92.874,0.9645784062390422
+module.layer1.2.conv3.weight,0.45,76.116,92.874,0.9646241778165708
+module.layer1.2.conv3.weight,0.5,76.118,92.86,0.9649834648839066
+module.layer2.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.conv1.weight,0.05,76.018,92.842,0.967156604996749
+module.layer2.0.conv1.weight,0.1,75.53999999999999,92.694,0.98511602874009
+module.layer2.0.conv1.weight,0.15000000000000002,75.384,92.57,0.9942189067298051
+module.layer2.0.conv1.weight,0.2,73.736,91.75999999999999,1.0647947725896927
+module.layer2.0.conv1.weight,0.25,72.794,91.312,1.107239385876728
+module.layer2.0.conv1.weight,0.30000000000000004,69.922,89.432,1.2431173596759222
+module.layer2.0.conv1.weight,0.35000000000000003,69.928,89.59,1.2365096223597627
+module.layer2.0.conv1.weight,0.4,66.234,87.068,1.4263274966149908
+module.layer2.0.conv1.weight,0.45,44.62199999999999,67.55,2.767483531820531
+module.layer2.0.conv1.weight,0.5,36.982000000000006,59.326,3.3448752748722943
+module.layer2.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.conv2.weight,0.05,75.902,92.842,0.969782282351231
+module.layer2.0.conv2.weight,0.1,75.756,92.824,0.9730839444210331
+module.layer2.0.conv2.weight,0.15000000000000002,75.64999999999999,92.72,0.978511454727577
+module.layer2.0.conv2.weight,0.2,75.58,92.676,0.9845758525996795
+module.layer2.0.conv2.weight,0.25,75.21799999999999,92.5,0.997505821591737
+module.layer2.0.conv2.weight,0.30000000000000004,74.788,92.288,1.0151040011218613
+module.layer2.0.conv2.weight,0.35000000000000003,74.42,92.078,1.0322041135965574
+module.layer2.0.conv2.weight,0.4,74.246,91.974,1.0409574107247952
+module.layer2.0.conv2.weight,0.45,73.40599999999999,91.438,1.077756118409488
+module.layer2.0.conv2.weight,0.5,72.55,90.956,1.1181802381666337
+module.layer2.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.conv3.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer2.0.conv3.weight,0.1,76.13,92.862,0.9646787385703351
+module.layer2.0.conv3.weight,0.15000000000000002,76.13,92.85799999999999,0.9647160731256009
+module.layer2.0.conv3.weight,0.2,76.13799999999999,92.85600000000001,0.9647735800518065
+module.layer2.0.conv3.weight,0.25,76.12400000000001,92.852,0.9647925913485942
+module.layer2.0.conv3.weight,0.30000000000000004,76.126,92.84400000000001,0.9647947986971356
+module.layer2.0.conv3.weight,0.35000000000000003,76.076,92.882,0.9641403041171785
+module.layer2.0.conv3.weight,0.4,76.066,92.908,0.9645294097005104
+module.layer2.0.conv3.weight,0.45,75.91600000000001,92.886,0.9665729468878436
+module.layer2.0.conv3.weight,0.5,75.72,92.74,0.9742113863479118
+module.layer2.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.0.downsample.0.weight,0.05,76.13,92.862,0.9646787385703351
+module.layer2.0.downsample.0.weight,0.1,76.13,92.862,0.9646787385703351
+module.layer2.0.downsample.0.weight,0.15000000000000002,76.112,92.866,0.9650651668103376
+module.layer2.0.downsample.0.weight,0.2,76.07,92.86,0.9653989572306069
+module.layer2.0.downsample.0.weight,0.25,76.028,92.824,0.9663104700038628
+module.layer2.0.downsample.0.weight,0.30000000000000004,75.98,92.868,0.9676333344256389
+module.layer2.0.downsample.0.weight,0.35000000000000003,75.96000000000001,92.838,0.970427606482895
+module.layer2.0.downsample.0.weight,0.4,75.876,92.794,0.9731549210086157
+module.layer2.0.downsample.0.weight,0.45,75.616,92.73400000000001,0.9814816470809126
+module.layer2.0.downsample.0.weight,0.5,75.412,92.67999999999999,0.9890781333099821
+module.layer2.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.1.conv1.weight,0.05,76.096,92.872,0.9646778918650688
+module.layer2.1.conv1.weight,0.1,76.058,92.896,0.9657731824079335
+module.layer2.1.conv1.weight,0.15000000000000002,75.986,92.866,0.9670812168291635
+module.layer2.1.conv1.weight,0.2,75.924,92.848,0.9687874529282655
+module.layer2.1.conv1.weight,0.25,75.864,92.81200000000001,0.9736961963377434
+module.layer2.1.conv1.weight,0.30000000000000004,75.726,92.774,0.9777020039607072
+module.layer2.1.conv1.weight,0.35000000000000003,75.602,92.738,0.980759804878307
+module.layer2.1.conv1.weight,0.4,75.214,92.548,0.9945249997687585
+module.layer2.1.conv1.weight,0.45,75.03999999999999,92.49000000000001,1.0022178821417755
+module.layer2.1.conv1.weight,0.5,74.806,92.388,1.0118249537689343
+module.layer2.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.1.conv2.weight,0.05,76.05,92.872,0.9652618023053723
+module.layer2.1.conv2.weight,0.1,75.996,92.85799999999999,0.9672502354547686
+module.layer2.1.conv2.weight,0.15000000000000002,75.934,92.792,0.9704457092650083
+module.layer2.1.conv2.weight,0.2,75.858,92.75,0.9751517729339547
+module.layer2.1.conv2.weight,0.25,75.48,92.582,0.9862348007760484
+module.layer2.1.conv2.weight,0.30000000000000004,74.222,91.824,1.0404229535892298
+module.layer2.1.conv2.weight,0.35000000000000003,74.146,91.73599999999999,1.0469928039427934
+module.layer2.1.conv2.weight,0.4,69.038,88.73400000000001,1.28229090145656
+module.layer2.1.conv2.weight,0.45,54.476,77.58,2.1224025232451305
+module.layer2.1.conv2.weight,0.5,54.217999999999996,77.29599999999999,2.1405887241874417
+module.layer2.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.1.conv3.weight,0.05,76.128,92.864,0.9646533330788418
+module.layer2.1.conv3.weight,0.1,76.122,92.864,0.9646252431735699
+module.layer2.1.conv3.weight,0.15000000000000002,76.12400000000001,92.86,0.9647012647925591
+module.layer2.1.conv3.weight,0.2,76.12400000000001,92.85799999999999,0.9647608776481782
+module.layer2.1.conv3.weight,0.25,76.132,92.85799999999999,0.9646967381847145
+module.layer2.1.conv3.weight,0.30000000000000004,76.13600000000001,92.866,0.9647041466467233
+module.layer2.1.conv3.weight,0.35000000000000003,76.128,92.864,0.964786077032284
+module.layer2.1.conv3.weight,0.4,76.12400000000001,92.888,0.9648655662123035
+module.layer2.1.conv3.weight,0.45,76.104,92.9,0.9648101361734531
+module.layer2.1.conv3.weight,0.5,76.07,92.9,0.9647237465393779
+module.layer2.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.2.conv1.weight,0.05,76.046,92.854,0.9665528654443971
+module.layer2.2.conv1.weight,0.1,75.964,92.862,0.9686528508912543
+module.layer2.2.conv1.weight,0.15000000000000002,75.848,92.784,0.9736106615437541
+module.layer2.2.conv1.weight,0.2,75.69399999999999,92.77,0.9788862500263723
+module.layer2.2.conv1.weight,0.25,75.512,92.67999999999999,0.9830762848866231
+module.layer2.2.conv1.weight,0.30000000000000004,75.354,92.656,0.9894499518737504
+module.layer2.2.conv1.weight,0.35000000000000003,75.17399999999999,92.51599999999999,0.9970121695374953
+module.layer2.2.conv1.weight,0.4,74.972,92.414,1.0066485434618535
+module.layer2.2.conv1.weight,0.45,74.85,92.34400000000001,1.012786984291612
+module.layer2.2.conv1.weight,0.5,74.642,92.20200000000001,1.0226188807615206
+module.layer2.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.2.conv2.weight,0.05,76.03800000000001,92.838,0.968614383011448
+module.layer2.2.conv2.weight,0.1,75.888,92.842,0.9717724581001971
+module.layer2.2.conv2.weight,0.15000000000000002,75.806,92.852,0.974677234659998
+module.layer2.2.conv2.weight,0.2,75.708,92.83200000000001,0.9765448045669768
+module.layer2.2.conv2.weight,0.25,75.66,92.776,0.9798674343952113
+module.layer2.2.conv2.weight,0.30000000000000004,75.386,92.686,0.9875597967481125
+module.layer2.2.conv2.weight,0.35000000000000003,75.322,92.64,0.9919184984601273
+module.layer2.2.conv2.weight,0.4,75.226,92.506,0.9981514553026278
+module.layer2.2.conv2.weight,0.45,75.102,92.49199999999999,1.002877431712589
+module.layer2.2.conv2.weight,0.5,74.874,92.306,1.0117304964485216
+module.layer2.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.2.conv3.weight,0.05,76.126,92.876,0.9645333172259287
+module.layer2.2.conv3.weight,0.1,76.078,92.862,0.9660904666751013
+module.layer2.2.conv3.weight,0.15000000000000002,76.044,92.838,0.966414061431982
+module.layer2.2.conv3.weight,0.2,76.018,92.80199999999999,0.9685548517320841
+module.layer2.2.conv3.weight,0.25,76.006,92.824,0.9688468615011292
+module.layer2.2.conv3.weight,0.30000000000000004,75.97,92.78,0.9721222430923765
+module.layer2.2.conv3.weight,0.35000000000000003,75.73,92.746,0.9787512461141661
+module.layer2.2.conv3.weight,0.4,75.66199999999999,92.732,0.9809944002451944
+module.layer2.2.conv3.weight,0.45,75.51,92.622,0.9858916249810431
+module.layer2.2.conv3.weight,0.5,75.442,92.572,0.9902838661658522
+module.layer2.3.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.3.conv1.weight,0.05,76.042,92.882,0.9676482852320281
+module.layer2.3.conv1.weight,0.1,75.998,92.884,0.9686084858008795
+module.layer2.3.conv1.weight,0.15000000000000002,75.852,92.83,0.9724806093287712
+module.layer2.3.conv1.weight,0.2,75.76,92.774,0.9753189640385767
+module.layer2.3.conv1.weight,0.25,75.538,92.702,0.9820158223868634
+module.layer2.3.conv1.weight,0.30000000000000004,75.414,92.704,0.9897886653031617
+module.layer2.3.conv1.weight,0.35000000000000003,75.03800000000001,92.51,1.0011441475730771
+module.layer2.3.conv1.weight,0.4,74.8,92.364,1.0166383479930918
+module.layer2.3.conv1.weight,0.45,74.402,92.148,1.034298639668494
+module.layer2.3.conv1.weight,0.5,73.746,91.718,1.064169044701421
+module.layer2.3.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.3.conv2.weight,0.05,76.00200000000001,92.86,0.9662534713897172
+module.layer2.3.conv2.weight,0.1,76.00200000000001,92.862,0.96831920879836
+module.layer2.3.conv2.weight,0.15000000000000002,75.946,92.84400000000001,0.9721023115728582
+module.layer2.3.conv2.weight,0.2,75.732,92.794,0.9761088758098834
+module.layer2.3.conv2.weight,0.25,75.614,92.73599999999999,0.981304814164736
+module.layer2.3.conv2.weight,0.30000000000000004,75.542,92.708,0.9833431544200503
+module.layer2.3.conv2.weight,0.35000000000000003,75.42,92.646,0.9895992186300612
+module.layer2.3.conv2.weight,0.4,75.166,92.434,1.0021974734809933
+module.layer2.3.conv2.weight,0.45,74.996,92.396,1.0080492892587674
+module.layer2.3.conv2.weight,0.5,74.832,92.30000000000001,1.0158192809595141
+module.layer2.3.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer2.3.conv3.weight,0.05,76.116,92.874,0.964743162448309
+module.layer2.3.conv3.weight,0.1,76.104,92.862,0.9648478018994233
+module.layer2.3.conv3.weight,0.15000000000000002,76.116,92.866,0.964911586852098
+module.layer2.3.conv3.weight,0.2,76.094,92.884,0.9648562783033265
+module.layer2.3.conv3.weight,0.25,76.094,92.85799999999999,0.9652072199601301
+module.layer2.3.conv3.weight,0.30000000000000004,76.028,92.84,0.9652869780452884
+module.layer2.3.conv3.weight,0.35000000000000003,76.072,92.894,0.9655525322471346
+module.layer2.3.conv3.weight,0.4,75.924,92.854,0.9668436814479684
+module.layer2.3.conv3.weight,0.45,75.85,92.794,0.9713384634530061
+module.layer2.3.conv3.weight,0.5,75.76,92.726,0.9754472542174012
+module.layer3.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.conv1.weight,0.05,75.854,92.7,0.9755539494965755
+module.layer3.0.conv1.weight,0.1,75.558,92.604,0.9903189021409771
+module.layer3.0.conv1.weight,0.15000000000000002,74.932,92.384,1.0120573165465376
+module.layer3.0.conv1.weight,0.2,73.97,91.902,1.0564837084741014
+module.layer3.0.conv1.weight,0.25,71.67999999999999,90.572,1.160918220497515
+module.layer3.0.conv1.weight,0.30000000000000004,70.22999999999999,89.724,1.2242342644200028
+module.layer3.0.conv1.weight,0.35000000000000003,67.42,87.63,1.3680912685029358
+module.layer3.0.conv1.weight,0.4,63.426,85.04,1.5690640997217635
+module.layer3.0.conv1.weight,0.45,62.834,84.516,1.6085462114032438
+module.layer3.0.conv1.weight,0.5,59.484,82.27,1.779903719619829
+module.layer3.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.conv2.weight,0.05,75.984,92.886,0.9716181799921453
+module.layer3.0.conv2.weight,0.1,75.67200000000001,92.73599999999999,0.9827592867065448
+module.layer3.0.conv2.weight,0.15000000000000002,75.4,92.578,0.9944861367040749
+module.layer3.0.conv2.weight,0.2,75.108,92.384,1.0120520112769948
+module.layer3.0.conv2.weight,0.25,74.46,92.052,1.0373907407023468
+module.layer3.0.conv2.weight,0.30000000000000004,74.05999999999999,91.838,1.0571384563737984
+module.layer3.0.conv2.weight,0.35000000000000003,72.89,91.12599999999999,1.1066943608528506
+module.layer3.0.conv2.weight,0.4,71.754,90.564,1.1546887049109358
+module.layer3.0.conv2.weight,0.45,69.98400000000001,89.476,1.2325848809310374
+module.layer3.0.conv2.weight,0.5,67.73599999999999,87.976,1.3425824202749195
+module.layer3.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.conv3.weight,0.05,76.128,92.86999999999999,0.9647133990514034
+module.layer3.0.conv3.weight,0.1,76.074,92.828,0.965464632121884
+module.layer3.0.conv3.weight,0.15000000000000002,76.058,92.84599999999999,0.9675507292303505
+module.layer3.0.conv3.weight,0.2,76.008,92.83,0.9695352081741606
+module.layer3.0.conv3.weight,0.25,75.882,92.814,0.9731980219331319
+module.layer3.0.conv3.weight,0.30000000000000004,75.698,92.758,0.9806116086487863
+module.layer3.0.conv3.weight,0.35000000000000003,75.524,92.666,0.9887452712472609
+module.layer3.0.conv3.weight,0.4,75.30399999999999,92.55,0.9961166965718176
+module.layer3.0.conv3.weight,0.45,74.752,92.188,1.017145576373655
+module.layer3.0.conv3.weight,0.5,74.166,91.93,1.0403414208974158
+module.layer3.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.0.downsample.0.weight,0.05,76.122,92.88,0.9646817222997849
+module.layer3.0.downsample.0.weight,0.1,76.09,92.868,0.9652682708538308
+module.layer3.0.downsample.0.weight,0.15000000000000002,76.02,92.83200000000001,0.9678591250308928
+module.layer3.0.downsample.0.weight,0.2,76.012,92.84599999999999,0.969518158356754
+module.layer3.0.downsample.0.weight,0.25,75.964,92.862,0.9715871051410021
+module.layer3.0.downsample.0.weight,0.30000000000000004,75.83,92.762,0.9757750733774535
+module.layer3.0.downsample.0.weight,0.35000000000000003,75.75800000000001,92.74,0.9812161533200012
+module.layer3.0.downsample.0.weight,0.4,75.568,92.652,0.988263431404318
+module.layer3.0.downsample.0.weight,0.45,75.35000000000001,92.60000000000001,0.9997599098299228
+module.layer3.0.downsample.0.weight,0.5,75.042,92.36800000000001,1.0125716319497748
+module.layer3.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.1.conv1.weight,0.05,76.026,92.84,0.9650524181826997
+module.layer3.1.conv1.weight,0.1,75.996,92.836,0.9674490037925388
+module.layer3.1.conv1.weight,0.15000000000000002,75.858,92.742,0.971184025309524
+module.layer3.1.conv1.weight,0.2,75.702,92.742,0.9757259215931502
+module.layer3.1.conv1.weight,0.25,75.622,92.67999999999999,0.9801329179685943
+module.layer3.1.conv1.weight,0.30000000000000004,75.486,92.56,0.9874053938048225
+module.layer3.1.conv1.weight,0.35000000000000003,75.42999999999999,92.562,0.9897024807121072
+module.layer3.1.conv1.weight,0.4,75.322,92.45,0.99729801554765
+module.layer3.1.conv1.weight,0.45,75.018,92.35600000000001,1.0083161153051319
+module.layer3.1.conv1.weight,0.5,74.71799999999999,92.12400000000001,1.0199289657175545
+module.layer3.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.1.conv2.weight,0.05,75.982,92.854,0.9691590736715163
+module.layer3.1.conv2.weight,0.1,75.818,92.686,0.9758419072901713
+module.layer3.1.conv2.weight,0.15000000000000002,75.742,92.58800000000001,0.9815147864271184
+module.layer3.1.conv2.weight,0.2,75.578,92.58,0.9874765161348843
+module.layer3.1.conv2.weight,0.25,75.47,92.462,0.994418795619692
+module.layer3.1.conv2.weight,0.30000000000000004,75.224,92.328,0.9989697071058415
+module.layer3.1.conv2.weight,0.35000000000000003,75.18199999999999,92.374,1.000461941394879
+module.layer3.1.conv2.weight,0.4,74.89399999999999,92.244,1.0080392082430885
+module.layer3.1.conv2.weight,0.45,74.602,92.148,1.0224535768585545
+module.layer3.1.conv2.weight,0.5,74.366,92.00200000000001,1.0328381399110875
+module.layer3.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.1.conv3.weight,0.05,76.062,92.864,0.9646131991579829
+module.layer3.1.conv3.weight,0.1,76.028,92.86,0.9663307377118234
+module.layer3.1.conv3.weight,0.15000000000000002,76.03999999999999,92.86,0.967569736953901
+module.layer3.1.conv3.weight,0.2,75.948,92.814,0.9681062712048996
+module.layer3.1.conv3.weight,0.25,75.944,92.878,0.9703570115475021
+module.layer3.1.conv3.weight,0.30000000000000004,75.896,92.816,0.9743181323366505
+module.layer3.1.conv3.weight,0.35000000000000003,75.83800000000001,92.778,0.9771007541186953
+module.layer3.1.conv3.weight,0.4,75.732,92.752,0.9815992211200752
+module.layer3.1.conv3.weight,0.45,75.684,92.666,0.985187043492891
+module.layer3.1.conv3.weight,0.5,75.512,92.598,0.989619911401247
+module.layer3.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.2.conv1.weight,0.05,76.112,92.902,0.9664402683170475
+module.layer3.2.conv1.weight,0.1,76.012,92.84,0.9691802668480238
+module.layer3.2.conv1.weight,0.15000000000000002,75.878,92.80000000000001,0.9741810408173774
+module.layer3.2.conv1.weight,0.2,75.888,92.75999999999999,0.9778152308141698
+module.layer3.2.conv1.weight,0.25,75.726,92.694,0.985200674025988
+module.layer3.2.conv1.weight,0.30000000000000004,75.602,92.638,0.9895688426889937
+module.layer3.2.conv1.weight,0.35000000000000003,75.384,92.57600000000001,0.9999833124480682
+module.layer3.2.conv1.weight,0.4,75.126,92.45,1.0099009264792718
+module.layer3.2.conv1.weight,0.45,74.91799999999999,92.374,1.016800964213147
+module.layer3.2.conv1.weight,0.5,74.784,92.27,1.026382406147159
+module.layer3.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.2.conv2.weight,0.05,76.05999999999999,92.874,0.9645183723495929
+module.layer3.2.conv2.weight,0.1,75.954,92.85,0.9664772243372033
+module.layer3.2.conv2.weight,0.15000000000000002,75.89200000000001,92.824,0.9708176260547979
+module.layer3.2.conv2.weight,0.2,75.848,92.768,0.9740906326594406
+module.layer3.2.conv2.weight,0.25,75.74,92.688,0.9793189029608457
+module.layer3.2.conv2.weight,0.30000000000000004,75.486,92.628,0.9856890367001899
+module.layer3.2.conv2.weight,0.35000000000000003,75.30399999999999,92.55799999999999,0.9944385276461133
+module.layer3.2.conv2.weight,0.4,75.11,92.432,1.0005023837545692
+module.layer3.2.conv2.weight,0.45,74.88,92.366,1.007985245816562
+module.layer3.2.conv2.weight,0.5,74.57400000000001,92.142,1.0231264646412168
+module.layer3.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.2.conv3.weight,0.05,76.076,92.86999999999999,0.9658319840625846
+module.layer3.2.conv3.weight,0.1,76.058,92.898,0.9680370671712619
+module.layer3.2.conv3.weight,0.15000000000000002,76.064,92.876,0.9695597740308358
+module.layer3.2.conv3.weight,0.2,76.01,92.876,0.9713810268713504
+module.layer3.2.conv3.weight,0.25,75.966,92.81200000000001,0.9739112689026764
+module.layer3.2.conv3.weight,0.30000000000000004,75.856,92.80000000000001,0.9776264067967332
+module.layer3.2.conv3.weight,0.35000000000000003,75.77000000000001,92.782,0.9793802074023656
+module.layer3.2.conv3.weight,0.4,75.704,92.73599999999999,0.9823789858088203
+module.layer3.2.conv3.weight,0.45,75.628,92.69800000000001,0.9874243092324051
+module.layer3.2.conv3.weight,0.5,75.47,92.584,0.9941450587036659
+module.layer3.3.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.3.conv1.weight,0.05,76.05199999999999,92.854,0.9672309211170188
+module.layer3.3.conv1.weight,0.1,76.0,92.862,0.971594920039785
+module.layer3.3.conv1.weight,0.15000000000000002,75.858,92.84599999999999,0.9759238329620991
+module.layer3.3.conv1.weight,0.2,75.788,92.794,0.9799342356926326
+module.layer3.3.conv1.weight,0.25,75.664,92.71199999999999,0.9868303231742916
+module.layer3.3.conv1.weight,0.30000000000000004,75.44800000000001,92.566,0.997164914802629
+module.layer3.3.conv1.weight,0.35000000000000003,75.168,92.44200000000001,1.0106435110982583
+module.layer3.3.conv1.weight,0.4,74.988,92.306,1.0185190551743215
+module.layer3.3.conv1.weight,0.45,74.83999999999999,92.188,1.0264643902073107
+module.layer3.3.conv1.weight,0.5,74.7,92.08800000000001,1.0364262275397775
+module.layer3.3.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.3.conv2.weight,0.05,76.022,92.808,0.9676625841886412
+module.layer3.3.conv2.weight,0.1,75.82,92.792,0.9736834481662636
+module.layer3.3.conv2.weight,0.15000000000000002,75.736,92.726,0.9789236883575819
+module.layer3.3.conv2.weight,0.2,75.464,92.58999999999999,0.9897794537246225
+module.layer3.3.conv2.weight,0.25,75.318,92.57,0.9949841195223281
+module.layer3.3.conv2.weight,0.30000000000000004,75.198,92.464,1.0017243420743207
+module.layer3.3.conv2.weight,0.35000000000000003,75.006,92.412,1.0079616335581758
+module.layer3.3.conv2.weight,0.4,74.856,92.294,1.0192331128886765
+module.layer3.3.conv2.weight,0.45,74.52799999999999,92.14,1.030934373258936
+module.layer3.3.conv2.weight,0.5,74.278,91.946,1.04510385771187
+module.layer3.3.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.3.conv3.weight,0.05,76.126,92.86,0.9650276996651469
+module.layer3.3.conv3.weight,0.1,76.098,92.842,0.9651636611290126
+module.layer3.3.conv3.weight,0.15000000000000002,76.05,92.86999999999999,0.9664237293205701
+module.layer3.3.conv3.weight,0.2,76.054,92.876,0.968392280595643
+module.layer3.3.conv3.weight,0.25,76.012,92.842,0.970424373843232
+module.layer3.3.conv3.weight,0.30000000000000004,75.928,92.822,0.9723988493182223
+module.layer3.3.conv3.weight,0.35000000000000003,75.868,92.806,0.9749150700411023
+module.layer3.3.conv3.weight,0.4,75.87,92.75999999999999,0.9775160366327179
+module.layer3.3.conv3.weight,0.45,75.784,92.768,0.9805781351668494
+module.layer3.3.conv3.weight,0.5,75.79599999999999,92.72399999999999,0.9821375484521293
+module.layer3.4.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.4.conv1.weight,0.05,75.958,92.852,0.9707729078981338
+module.layer3.4.conv1.weight,0.1,75.926,92.842,0.9745264629624331
+module.layer3.4.conv1.weight,0.15000000000000002,75.79,92.73599999999999,0.9820926417501608
+module.layer3.4.conv1.weight,0.2,75.506,92.662,0.990671943797141
+module.layer3.4.conv1.weight,0.25,75.362,92.56,0.9998092267434208
+module.layer3.4.conv1.weight,0.30000000000000004,75.252,92.464,1.0118998679883624
+module.layer3.4.conv1.weight,0.35000000000000003,74.91,92.308,1.0283412074252054
+module.layer3.4.conv1.weight,0.4,74.634,92.11399999999999,1.0410314060899681
+module.layer3.4.conv1.weight,0.45,74.39,91.91,1.0556378165373999
+module.layer3.4.conv1.weight,0.5,74.16,91.804,1.0647825668660966
+module.layer3.4.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.4.conv2.weight,0.05,76.028,92.826,0.9699842729890834
+module.layer3.4.conv2.weight,0.1,75.95,92.828,0.9719456350620915
+module.layer3.4.conv2.weight,0.15000000000000002,75.708,92.726,0.9797904483061665
+module.layer3.4.conv2.weight,0.2,75.646,92.662,0.9835705510055528
+module.layer3.4.conv2.weight,0.25,75.38,92.514,0.9943041349095958
+module.layer3.4.conv2.weight,0.30000000000000004,75.136,92.384,1.0046290529473703
+module.layer3.4.conv2.weight,0.35000000000000003,75.008,92.30000000000001,1.0104411206379234
+module.layer3.4.conv2.weight,0.4,74.76400000000001,92.12400000000001,1.0218490987864075
+module.layer3.4.conv2.weight,0.45,74.366,91.962,1.0363766772254388
+module.layer3.4.conv2.weight,0.5,73.676,91.602,1.0667164083190113
+module.layer3.4.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.4.conv3.weight,0.05,76.106,92.872,0.964855870574104
+module.layer3.4.conv3.weight,0.1,76.05,92.876,0.9654396832445448
+module.layer3.4.conv3.weight,0.15000000000000002,75.956,92.852,0.968098088338667
+module.layer3.4.conv3.weight,0.2,75.91799999999999,92.84599999999999,0.9696684817270358
+module.layer3.4.conv3.weight,0.25,75.928,92.85600000000001,0.9708511539715896
+module.layer3.4.conv3.weight,0.30000000000000004,75.856,92.854,0.972894585710399
+module.layer3.4.conv3.weight,0.35000000000000003,75.86,92.85799999999999,0.9759889332463548
+module.layer3.4.conv3.weight,0.4,75.78,92.784,0.9785844093682812
+module.layer3.4.conv3.weight,0.45,75.682,92.754,0.982491150194285
+module.layer3.4.conv3.weight,0.5,75.606,92.682,0.9880777105536995
+module.layer3.5.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.5.conv1.weight,0.05,75.926,92.798,0.9710209958103239
+module.layer3.5.conv1.weight,0.1,75.68599999999999,92.706,0.9788274871451516
+module.layer3.5.conv1.weight,0.15000000000000002,75.596,92.646,0.9845531354753339
+module.layer3.5.conv1.weight,0.2,75.424,92.538,0.9935152600614392
+module.layer3.5.conv1.weight,0.25,75.18199999999999,92.34599999999999,1.0075640448806236
+module.layer3.5.conv1.weight,0.30000000000000004,74.83800000000001,92.2,1.021532102506987
+module.layer3.5.conv1.weight,0.35000000000000003,74.514,92.006,1.0346408500811277
+module.layer3.5.conv1.weight,0.4,74.33000000000001,91.89,1.0447048462015025
+module.layer3.5.conv1.weight,0.45,74.146,91.768,1.053496342076331
+module.layer3.5.conv1.weight,0.5,73.646,91.52799999999999,1.0740979883287631
+module.layer3.5.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.5.conv2.weight,0.05,75.998,92.918,0.9653302145247556
+module.layer3.5.conv2.weight,0.1,75.848,92.826,0.9723002608029214
+module.layer3.5.conv2.weight,0.15000000000000002,75.698,92.75999999999999,0.9797882666545261
+module.layer3.5.conv2.weight,0.2,75.28200000000001,92.61200000000001,0.991536684030173
+module.layer3.5.conv2.weight,0.25,75.168,92.572,0.9951652087727368
+module.layer3.5.conv2.weight,0.30000000000000004,75.078,92.472,1.0034233816728295
+module.layer3.5.conv2.weight,0.35000000000000003,74.83800000000001,92.384,1.013210628865933
+module.layer3.5.conv2.weight,0.4,74.462,92.20200000000001,1.0288722234569039
+module.layer3.5.conv2.weight,0.45,74.07000000000001,91.948,1.0470701186936726
+module.layer3.5.conv2.weight,0.5,73.854,91.796,1.0605612496028145
+module.layer3.5.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer3.5.conv3.weight,0.05,76.044,92.848,0.9663973350305944
+module.layer3.5.conv3.weight,0.1,75.978,92.84400000000001,0.9683438963731941
+module.layer3.5.conv3.weight,0.15000000000000002,75.96199999999999,92.836,0.9695205550108633
+module.layer3.5.conv3.weight,0.2,75.81,92.78,0.9731556126961902
+module.layer3.5.conv3.weight,0.25,75.77000000000001,92.78,0.9747894256546789
+module.layer3.5.conv3.weight,0.30000000000000004,75.78,92.73599999999999,0.9779086081045018
+module.layer3.5.conv3.weight,0.35000000000000003,75.682,92.666,0.9823050108947311
+module.layer3.5.conv3.weight,0.4,75.624,92.64,0.9860117449900325
+module.layer3.5.conv3.weight,0.45,75.506,92.58999999999999,0.9931091208543096
+module.layer3.5.conv3.weight,0.5,75.252,92.494,1.001836841887966
+module.layer4.0.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.conv1.weight,0.05,75.452,92.72200000000001,0.980977897969436
+module.layer4.0.conv1.weight,0.1,75.06200000000001,92.476,0.99777424723214
+module.layer4.0.conv1.weight,0.15000000000000002,74.49600000000001,92.296,1.0166220766093048
+module.layer4.0.conv1.weight,0.2,73.66,91.852,1.0479011927940405
+module.layer4.0.conv1.weight,0.25,72.994,91.36999999999999,1.0818585154353357
+module.layer4.0.conv1.weight,0.30000000000000004,71.96199999999999,90.944,1.122837725342537
+module.layer4.0.conv1.weight,0.35000000000000003,71.224,90.57,1.1563155524888813
+module.layer4.0.conv1.weight,0.4,69.818,89.60000000000001,1.2258589152170691
+module.layer4.0.conv1.weight,0.45,68.00399999999999,88.592,1.3101937196084439
+module.layer4.0.conv1.weight,0.5,65.932,87.21799999999999,1.4147236747096998
+module.layer4.0.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.conv2.weight,0.05,75.824,92.766,0.9820537386196001
+module.layer4.0.conv2.weight,0.1,75.442,92.582,0.9957953673236224
+module.layer4.0.conv2.weight,0.15000000000000002,75.17,92.49199999999999,1.005282780953816
+module.layer4.0.conv2.weight,0.2,74.852,92.304,1.0212704213146044
+module.layer4.0.conv2.weight,0.25,74.292,91.964,1.0414224162089585
+module.layer4.0.conv2.weight,0.30000000000000004,73.604,91.634,1.0696578053187356
+module.layer4.0.conv2.weight,0.35000000000000003,72.89999999999999,91.208,1.0999143880848987
+module.layer4.0.conv2.weight,0.4,72.184,90.8,1.1301327659463394
+module.layer4.0.conv2.weight,0.45,71.114,90.238,1.173654002346554
+module.layer4.0.conv2.weight,0.5,69.95599999999999,89.628,1.222327757398693
+module.layer4.0.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.conv3.weight,0.05,76.068,92.872,0.974233268703125
+module.layer4.0.conv3.weight,0.1,75.902,92.764,0.9797552155748923
+module.layer4.0.conv3.weight,0.15000000000000002,75.80799999999999,92.738,0.9865577943166911
+module.layer4.0.conv3.weight,0.2,75.632,92.642,0.9941004342114437
+module.layer4.0.conv3.weight,0.25,75.39,92.508,1.0038409359296971
+module.layer4.0.conv3.weight,0.30000000000000004,75.114,92.41,1.0120449089730275
+module.layer4.0.conv3.weight,0.35000000000000003,74.892,92.244,1.023147908704622
+module.layer4.0.conv3.weight,0.4,74.56,92.092,1.0372554889442975
+module.layer4.0.conv3.weight,0.45,74.076,91.852,1.0547336231993174
+module.layer4.0.conv3.weight,0.5,73.584,91.602,1.077762626719718
+module.layer4.0.downsample.0.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.0.downsample.0.weight,0.05,76.012,92.872,0.9663943356397199
+module.layer4.0.downsample.0.weight,0.1,75.99000000000001,92.866,0.9696360729938869
+module.layer4.0.downsample.0.weight,0.15000000000000002,75.91,92.84599999999999,0.9735238836431992
+module.layer4.0.downsample.0.weight,0.2,75.782,92.81,0.975955544001594
+module.layer4.0.downsample.0.weight,0.25,75.71799999999999,92.726,0.9828369682844804
+module.layer4.0.downsample.0.weight,0.30000000000000004,75.58,92.704,0.9846665353647297
+module.layer4.0.downsample.0.weight,0.35000000000000003,75.55199999999999,92.65,0.9895600051600107
+module.layer4.0.downsample.0.weight,0.4,75.38199999999999,92.586,0.9927099836724144
+module.layer4.0.downsample.0.weight,0.45,75.12,92.534,0.9976098318489229
+module.layer4.0.downsample.0.weight,0.5,74.932,92.428,1.0028844749440953
+module.layer4.1.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.1.conv1.weight,0.05,75.332,92.654,0.98502134539339
+module.layer4.1.conv1.weight,0.1,75.018,92.47800000000001,1.0129908440368516
+module.layer4.1.conv1.weight,0.15000000000000002,74.4,92.008,1.0427387130199646
+module.layer4.1.conv1.weight,0.2,73.52799999999999,91.636,1.0739348117186098
+module.layer4.1.conv1.weight,0.25,72.952,91.32000000000001,1.103592288889447
+module.layer4.1.conv1.weight,0.30000000000000004,72.392,90.914,1.1367468675788572
+module.layer4.1.conv1.weight,0.35000000000000003,71.892,90.664,1.1603467205957496
+module.layer4.1.conv1.weight,0.4,71.218,90.34400000000001,1.1915501374371191
+module.layer4.1.conv1.weight,0.45,69.674,89.46600000000001,1.2544893468825187
+module.layer4.1.conv1.weight,0.5,68.68799999999999,88.822,1.2998564492683022
+module.layer4.1.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.1.conv2.weight,0.05,75.53,92.622,0.9798856420176366
+module.layer4.1.conv2.weight,0.1,74.52600000000001,92.082,1.0328085240052662
+module.layer4.1.conv2.weight,0.15000000000000002,73.35199999999999,91.58800000000001,1.1326180742103231
+module.layer4.1.conv2.weight,0.2,72.074,90.818,1.2893840082141816
+module.layer4.1.conv2.weight,0.25,70.956,90.19200000000001,1.4188573299622047
+module.layer4.1.conv2.weight,0.30000000000000004,68.238,88.714,1.7882773164583712
+module.layer4.1.conv2.weight,0.35000000000000003,67.46600000000001,88.27000000000001,1.970356911420823
+module.layer4.1.conv2.weight,0.4,66.98200000000001,87.858,2.050041017483691
+module.layer4.1.conv2.weight,0.45,64.67,86.59,2.2327518517873735
+module.layer4.1.conv2.weight,0.5,62.056,84.88,2.648381370062732
+module.layer4.1.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.1.conv3.weight,0.05,76.074,92.872,0.9659452228521812
+module.layer4.1.conv3.weight,0.1,75.988,92.842,0.967287575940088
+module.layer4.1.conv3.weight,0.15000000000000002,75.934,92.772,0.9751768792618293
+module.layer4.1.conv3.weight,0.2,75.866,92.73400000000001,0.9757551226232728
+module.layer4.1.conv3.weight,0.25,75.8,92.686,0.981689058548334
+module.layer4.1.conv3.weight,0.30000000000000004,75.628,92.642,0.9852055123417963
+module.layer4.1.conv3.weight,0.35000000000000003,75.49,92.536,0.9915929486100771
+module.layer4.1.conv3.weight,0.4,75.272,92.498,1.0091029061954848
+module.layer4.1.conv3.weight,0.45,75.028,92.34400000000001,1.022709506050665
+module.layer4.1.conv3.weight,0.5,74.638,92.148,1.0381865026999488
+module.layer4.2.conv1.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.2.conv1.weight,0.05,75.846,92.826,0.9717933706635113
+module.layer4.2.conv1.weight,0.1,75.276,92.472,0.9955067696163846
+module.layer4.2.conv1.weight,0.15000000000000002,74.656,92.20200000000001,1.0159382923525204
+module.layer4.2.conv1.weight,0.2,74.12,91.932,1.0420076324775516
+module.layer4.2.conv1.weight,0.25,73.394,91.662,1.08965770018344
+module.layer4.2.conv1.weight,0.30000000000000004,72.66,91.33200000000001,1.1267745362556716
+module.layer4.2.conv1.weight,0.35000000000000003,72.174,91.024,1.1684093069361183
+module.layer4.2.conv1.weight,0.4,71.224,90.53,1.239526554181868
+module.layer4.2.conv1.weight,0.45,70.39800000000001,89.862,1.310148460828528
+module.layer4.2.conv1.weight,0.5,69.206,89.23400000000001,1.4035242774656844
+module.layer4.2.conv2.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.2.conv2.weight,0.05,75.828,92.818,0.9695281652467589
+module.layer4.2.conv2.weight,0.1,75.676,92.702,0.975221398989765
+module.layer4.2.conv2.weight,0.15000000000000002,75.462,92.52,0.9864033583019457
+module.layer4.2.conv2.weight,0.2,74.98400000000001,92.362,1.0007667366643338
+module.layer4.2.conv2.weight,0.25,74.54799999999999,92.176,1.0197215826839818
+module.layer4.2.conv2.weight,0.30000000000000004,73.988,92.006,1.0427208835525168
+module.layer4.2.conv2.weight,0.35000000000000003,73.628,91.768,1.0686493072734802
+module.layer4.2.conv2.weight,0.4,73.262,91.602,1.097705980800852
+module.layer4.2.conv2.weight,0.45,72.57000000000001,91.266,1.1520085818305306
+module.layer4.2.conv2.weight,0.5,71.44800000000001,90.73400000000001,1.2216127002421697
+module.layer4.2.conv3.weight,0.0,76.13,92.862,0.9646787385703351
+module.layer4.2.conv3.weight,0.05,76.088,92.882,0.9627235708188039
+module.layer4.2.conv3.weight,0.1,76.062,92.868,0.9608848703911115
+module.layer4.2.conv3.weight,0.15000000000000002,76.032,92.848,0.9586635404551517
+module.layer4.2.conv3.weight,0.2,75.888,92.826,0.9583521980260096
+module.layer4.2.conv3.weight,0.25,75.874,92.794,0.9590057633361035
+module.layer4.2.conv3.weight,0.30000000000000004,75.71799999999999,92.752,0.9618615930025672
+module.layer4.2.conv3.weight,0.35000000000000003,75.598,92.702,0.9675871675111806
+module.layer4.2.conv3.weight,0.4,75.422,92.648,0.9774525021107828
+module.layer4.2.conv3.weight,0.45,75.216,92.584,0.9908461873324553
+module.layer4.2.conv3.weight,0.5,74.978,92.554,1.0100993886590002
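
A minimal sketch for exploring the sensitivity data above with pandas (the
column meanings, i.e. parameter name, sparsity level, top-1, top-5, and loss,
are inferred from the values; the path and the presence of a header row in the
full file are assumptions):

    import pandas as pd

    df = pd.read_csv(
        "examples/sensitivity-analysis/resnet50-imagenet/"
        "resnet50.imagenet.sensitivity_filter_wise.csv",
        names=["param", "sparsity", "top1", "top5", "loss"],
    )
    # Drop a header row, if the full file carries one, then coerce to numbers.
    df = df[pd.to_numeric(df["sparsity"], errors="coerce").notna()]
    num_cols = ["sparsity", "top1", "top5", "loss"]
    df[num_cols] = df[num_cols].astype(float)

    # Layers ranked by top-1 accuracy lost when half their filters are pruned:
    base = df[df["sparsity"] == 0.0].set_index("param")["top1"]
    pruned = df[df["sparsity"] == 0.5].set_index("param")["top1"]
    print((base - pruned).sort_values(ascending=False).head())
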
diff --git a/jupyter/imagenet_classes.py b/jupyter/imagenet_classes.py
new file mode 100755
index 0000000..d186665
--- /dev/null
+++ b/jupyter/imagenet_classes.py
@@ -0,0 +1,1003 @@
+# https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a#file-imagenet1000_clsid_to_human-txt
+
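+# Minimal usage sketch (assuming the module is importable, e.g. from a
+# notebook running in this jupyter/ directory):
+#   from imagenet_classes import imagenet_classes
+#   print(imagenet_classes[207])  # -> 'golden retriever'
+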
+imagenet_classes = {
+ 0: 'tench, Tinca tinca',
+ 1: 'goldfish, Carassius auratus',
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
+ 3: 'tiger shark, Galeocerdo cuvieri',
+ 4: 'hammerhead, hammerhead shark',
+ 5: 'electric ray, crampfish, numbfish, torpedo',
+ 6: 'stingray',
+ 7: 'cock',
+ 8: 'hen',
+ 9: 'ostrich, Struthio camelus',
+ 10: 'brambling, Fringilla montifringilla',
+ 11: 'goldfinch, Carduelis carduelis',
+ 12: 'house finch, linnet, Carpodacus mexicanus',
+ 13: 'junco, snowbird',
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
+ 15: 'robin, American robin, Turdus migratorius',
+ 16: 'bulbul',
+ 17: 'jay',
+ 18: 'magpie',
+ 19: 'chickadee',
+ 20: 'water ouzel, dipper',
+ 21: 'kite',
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
+ 23: 'vulture',
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
+ 25: 'European fire salamander, Salamandra salamandra',
+ 26: 'common newt, Triturus vulgaris',
+ 27: 'eft',
+ 28: 'spotted salamander, Ambystoma maculatum',
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
+ 30: 'bullfrog, Rana catesbeiana',
+ 31: 'tree frog, tree-frog',
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
+ 35: 'mud turtle',
+ 36: 'terrapin',
+ 37: 'box turtle, box tortoise',
+ 38: 'banded gecko',
+ 39: 'common iguana, iguana, Iguana iguana',
+ 40: 'American chameleon, anole, Anolis carolinensis',
+ 41: 'whiptail, whiptail lizard',
+ 42: 'agama',
+ 43: 'frilled lizard, Chlamydosaurus kingi',
+ 44: 'alligator lizard',
+ 45: 'Gila monster, Heloderma suspectum',
+ 46: 'green lizard, Lacerta viridis',
+ 47: 'African chameleon, Chamaeleo chamaeleon',
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
+ 50: 'American alligator, Alligator mississipiensis',
+ 51: 'triceratops',
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
+ 53: 'ringneck snake, ring-necked snake, ring snake',
+ 54: 'hognose snake, puff adder, sand viper',
+ 55: 'green snake, grass snake',
+ 56: 'king snake, kingsnake',
+ 57: 'garter snake, grass snake',
+ 58: 'water snake',
+ 59: 'vine snake',
+ 60: 'night snake, Hypsiglena torquata',
+ 61: 'boa constrictor, Constrictor constrictor',
+ 62: 'rock python, rock snake, Python sebae',
+ 63: 'Indian cobra, Naja naja',
+ 64: 'green mamba',
+ 65: 'sea snake',
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
+ 69: 'trilobite',
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
+ 71: 'scorpion',
+ 72: 'black and gold garden spider, Argiope aurantia',
+ 73: 'barn spider, Araneus cavaticus',
+ 74: 'garden spider, Aranea diademata',
+ 75: 'black widow, Latrodectus mactans',
+ 76: 'tarantula',
+ 77: 'wolf spider, hunting spider',
+ 78: 'tick',
+ 79: 'centipede',
+ 80: 'black grouse',
+ 81: 'ptarmigan',
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
+ 84: 'peacock',
+ 85: 'quail',
+ 86: 'partridge',
+ 87: 'African grey, African gray, Psittacus erithacus',
+ 88: 'macaw',
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
+ 90: 'lorikeet',
+ 91: 'coucal',
+ 92: 'bee eater',
+ 93: 'hornbill',
+ 94: 'hummingbird',
+ 95: 'jacamar',
+ 96: 'toucan',
+ 97: 'drake',
+ 98: 'red-breasted merganser, Mergus serrator',
+ 99: 'goose',
+ 100: 'black swan, Cygnus atratus',
+ 101: 'tusker',
+ 102: 'echidna, spiny anteater, anteater',
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
+ 104: 'wallaby, brush kangaroo',
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
+ 106: 'wombat',
+ 107: 'jellyfish',
+ 108: 'sea anemone, anemone',
+ 109: 'brain coral',
+ 110: 'flatworm, platyhelminth',
+ 111: 'nematode, nematode worm, roundworm',
+ 112: 'conch',
+ 113: 'snail',
+ 114: 'slug',
+ 115: 'sea slug, nudibranch',
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
+ 118: 'Dungeness crab, Cancer magister',
+ 119: 'rock crab, Cancer irroratus',
+ 120: 'fiddler crab',
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
+ 125: 'hermit crab',
+ 126: 'isopod',
+ 127: 'white stork, Ciconia ciconia',
+ 128: 'black stork, Ciconia nigra',
+ 129: 'spoonbill',
+ 130: 'flamingo',
+ 131: 'little blue heron, Egretta caerulea',
+ 132: 'American egret, great white heron, Egretta albus',
+ 133: 'bittern',
+ 134: 'crane',
+ 135: 'limpkin, Aramus pictus',
+ 136: 'European gallinule, Porphyrio porphyrio',
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
+ 138: 'bustard',
+ 139: 'ruddy turnstone, Arenaria interpres',
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
+ 141: 'redshank, Tringa totanus',
+ 142: 'dowitcher',
+ 143: 'oystercatcher, oyster catcher',
+ 144: 'pelican',
+ 145: 'king penguin, Aptenodytes patagonica',
+ 146: 'albatross, mollymawk',
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
+ 149: 'dugong, Dugong dugon',
+ 150: 'sea lion',
+ 151: 'Chihuahua',
+ 152: 'Japanese spaniel',
+ 153: 'Maltese dog, Maltese terrier, Maltese',
+ 154: 'Pekinese, Pekingese, Peke',
+ 155: 'Shih-Tzu',
+ 156: 'Blenheim spaniel',
+ 157: 'papillon',
+ 158: 'toy terrier',
+ 159: 'Rhodesian ridgeback',
+ 160: 'Afghan hound, Afghan',
+ 161: 'basset, basset hound',
+ 162: 'beagle',
+ 163: 'bloodhound, sleuthhound',
+ 164: 'bluetick',
+ 165: 'black-and-tan coonhound',
+ 166: 'Walker hound, Walker foxhound',
+ 167: 'English foxhound',
+ 168: 'redbone',
+ 169: 'borzoi, Russian wolfhound',
+ 170: 'Irish wolfhound',
+ 171: 'Italian greyhound',
+ 172: 'whippet',
+ 173: 'Ibizan hound, Ibizan Podenco',
+ 174: 'Norwegian elkhound, elkhound',
+ 175: 'otterhound, otter hound',
+ 176: 'Saluki, gazelle hound',
+ 177: 'Scottish deerhound, deerhound',
+ 178: 'Weimaraner',
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
+ 181: 'Bedlington terrier',
+ 182: 'Border terrier',
+ 183: 'Kerry blue terrier',
+ 184: 'Irish terrier',
+ 185: 'Norfolk terrier',
+ 186: 'Norwich terrier',
+ 187: 'Yorkshire terrier',
+ 188: 'wire-haired fox terrier',
+ 189: 'Lakeland terrier',
+ 190: 'Sealyham terrier, Sealyham',
+ 191: 'Airedale, Airedale terrier',
+ 192: 'cairn, cairn terrier',
+ 193: 'Australian terrier',
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
+ 195: 'Boston bull, Boston terrier',
+ 196: 'miniature schnauzer',
+ 197: 'giant schnauzer',
+ 198: 'standard schnauzer',
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
+ 200: 'Tibetan terrier, chrysanthemum dog',
+ 201: 'silky terrier, Sydney silky',
+ 202: 'soft-coated wheaten terrier',
+ 203: 'West Highland white terrier',
+ 204: 'Lhasa, Lhasa apso',
+ 205: 'flat-coated retriever',
+ 206: 'curly-coated retriever',
+ 207: 'golden retriever',
+ 208: 'Labrador retriever',
+ 209: 'Chesapeake Bay retriever',
+ 210: 'German short-haired pointer',
+ 211: 'vizsla, Hungarian pointer',
+ 212: 'English setter',
+ 213: 'Irish setter, red setter',
+ 214: 'Gordon setter',
+ 215: 'Brittany spaniel',
+ 216: 'clumber, clumber spaniel',
+ 217: 'English springer, English springer spaniel',
+ 218: 'Welsh springer spaniel',
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
+ 220: 'Sussex spaniel',
+ 221: 'Irish water spaniel',
+ 222: 'kuvasz',
+ 223: 'schipperke',
+ 224: 'groenendael',
+ 225: 'malinois',
+ 226: 'briard',
+ 227: 'kelpie',
+ 228: 'komondor',
+ 229: 'Old English sheepdog, bobtail',
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
+ 231: 'collie',
+ 232: 'Border collie',
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
+ 234: 'Rottweiler',
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
+ 236: 'Doberman, Doberman pinscher',
+ 237: 'miniature pinscher',
+ 238: 'Greater Swiss Mountain dog',
+ 239: 'Bernese mountain dog',
+ 240: 'Appenzeller',
+ 241: 'EntleBucher',
+ 242: 'boxer',
+ 243: 'bull mastiff',
+ 244: 'Tibetan mastiff',
+ 245: 'French bulldog',
+ 246: 'Great Dane',
+ 247: 'Saint Bernard, St Bernard',
+ 248: 'Eskimo dog, husky',
+ 249: 'malamute, malemute, Alaskan malamute',
+ 250: 'Siberian husky',
+ 251: 'dalmatian, coach dog, carriage dog',
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
+ 253: 'basenji',
+ 254: 'pug, pug-dog',
+ 255: 'Leonberg',
+ 256: 'Newfoundland, Newfoundland dog',
+ 257: 'Great Pyrenees',
+ 258: 'Samoyed, Samoyede',
+ 259: 'Pomeranian',
+ 260: 'chow, chow chow',
+ 261: 'keeshond',
+ 262: 'Brabancon griffon',
+ 263: 'Pembroke, Pembroke Welsh corgi',
+ 264: 'Cardigan, Cardigan Welsh corgi',
+ 265: 'toy poodle',
+ 266: 'miniature poodle',
+ 267: 'standard poodle',
+ 268: 'Mexican hairless',
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
+ 273: 'dingo, warrigal, warragal, Canis dingo',
+ 274: 'dhole, Cuon alpinus',
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
+ 276: 'hyena, hyaena',
+ 277: 'red fox, Vulpes vulpes',
+ 278: 'kit fox, Vulpes macrotis',
+ 279: 'Arctic fox, white fox, Alopex lagopus',
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
+ 281: 'tabby, tabby cat',
+ 282: 'tiger cat',
+ 283: 'Persian cat',
+ 284: 'Siamese cat, Siamese',
+ 285: 'Egyptian cat',
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
+ 287: 'lynx, catamount',
+ 288: 'leopard, Panthera pardus',
+ 289: 'snow leopard, ounce, Panthera uncia',
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
+ 291: 'lion, king of beasts, Panthera leo',
+ 292: 'tiger, Panthera tigris',
+ 293: 'cheetah, chetah, Acinonyx jubatus',
+ 294: 'brown bear, bruin, Ursus arctos',
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
+ 298: 'mongoose',
+ 299: 'meerkat, mierkat',
+ 300: 'tiger beetle',
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
+ 302: 'ground beetle, carabid beetle',
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
+ 304: 'leaf beetle, chrysomelid',
+ 305: 'dung beetle',
+ 306: 'rhinoceros beetle',
+ 307: 'weevil',
+ 308: 'fly',
+ 309: 'bee',
+ 310: 'ant, emmet, pismire',
+ 311: 'grasshopper, hopper',
+ 312: 'cricket',
+ 313: 'walking stick, walkingstick, stick insect',
+ 314: 'cockroach, roach',
+ 315: 'mantis, mantid',
+ 316: 'cicada, cicala',
+ 317: 'leafhopper',
+ 318: 'lacewing, lacewing fly',
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ 320: 'damselfly',
+ 321: 'admiral',
+ 322: 'ringlet, ringlet butterfly',
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
+ 324: 'cabbage butterfly',
+ 325: 'sulphur butterfly, sulfur butterfly',
+ 326: 'lycaenid, lycaenid butterfly',
+ 327: 'starfish, sea star',
+ 328: 'sea urchin',
+ 329: 'sea cucumber, holothurian',
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
+ 331: 'hare',
+ 332: 'Angora, Angora rabbit',
+ 333: 'hamster',
+ 334: 'porcupine, hedgehog',
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
+ 336: 'marmot',
+ 337: 'beaver',
+ 338: 'guinea pig, Cavia cobaya',
+ 339: 'sorrel',
+ 340: 'zebra',
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
+ 342: 'wild boar, boar, Sus scrofa',
+ 343: 'warthog',
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
+ 345: 'ox',
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
+ 347: 'bison',
+ 348: 'ram, tup',
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
+ 350: 'ibex, Capra ibex',
+ 351: 'hartebeest',
+ 352: 'impala, Aepyceros melampus',
+ 353: 'gazelle',
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
+ 355: 'llama',
+ 356: 'weasel',
+ 357: 'mink',
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
+ 360: 'otter',
+ 361: 'skunk, polecat, wood pussy',
+ 362: 'badger',
+ 363: 'armadillo',
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
+ 366: 'gorilla, Gorilla gorilla',
+ 367: 'chimpanzee, chimp, Pan troglodytes',
+ 368: 'gibbon, Hylobates lar',
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
+ 370: 'guenon, guenon monkey',
+ 371: 'patas, hussar monkey, Erythrocebus patas',
+ 372: 'baboon',
+ 373: 'macaque',
+ 374: 'langur',
+ 375: 'colobus, colobus monkey',
+ 376: 'proboscis monkey, Nasalis larvatus',
+ 377: 'marmoset',
+ 378: 'capuchin, ringtail, Cebus capucinus',
+ 379: 'howler monkey, howler',
+ 380: 'titi, titi monkey',
+ 381: 'spider monkey, Ateles geoffroyi',
+ 382: 'squirrel monkey, Saimiri sciureus',
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
+ 385: 'Indian elephant, Elephas maximus',
+ 386: 'African elephant, Loxodonta africana',
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
+ 389: 'barracouta, snoek',
+ 390: 'eel',
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
+ 392: 'rock beauty, Holocanthus tricolor',
+ 393: 'anemone fish',
+ 394: 'sturgeon',
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
+ 396: 'lionfish',
+ 397: 'puffer, pufferfish, blowfish, globefish',
+ 398: 'abacus',
+ 399: 'abaya',
+ 400: "academic gown, academic robe, judge's robe",
+ 401: 'accordion, piano accordion, squeeze box',
+ 402: 'acoustic guitar',
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
+ 404: 'airliner',
+ 405: 'airship, dirigible',
+ 406: 'altar',
+ 407: 'ambulance',
+ 408: 'amphibian, amphibious vehicle',
+ 409: 'analog clock',
+ 410: 'apiary, bee house',
+ 411: 'apron',
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
+ 413: 'assault rifle, assault gun',
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
+ 415: 'bakery, bakeshop, bakehouse',
+ 416: 'balance beam, beam',
+ 417: 'balloon',
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
+ 419: 'Band Aid',
+ 420: 'banjo',
+ 421: 'bannister, banister, balustrade, balusters, handrail',
+ 422: 'barbell',
+ 423: 'barber chair',
+ 424: 'barbershop',
+ 425: 'barn',
+ 426: 'barometer',
+ 427: 'barrel, cask',
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
+ 429: 'baseball',
+ 430: 'basketball',
+ 431: 'bassinet',
+ 432: 'bassoon',
+ 433: 'bathing cap, swimming cap',
+ 434: 'bath towel',
+ 435: 'bathtub, bathing tub, bath, tub',
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
+ 437: 'beacon, lighthouse, beacon light, pharos',
+ 438: 'beaker',
+ 439: 'bearskin, busby, shako',
+ 440: 'beer bottle',
+ 441: 'beer glass',
+ 442: 'bell cote, bell cot',
+ 443: 'bib',
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
+ 445: 'bikini, two-piece',
+ 446: 'binder, ring-binder',
+ 447: 'binoculars, field glasses, opera glasses',
+ 448: 'birdhouse',
+ 449: 'boathouse',
+ 450: 'bobsled, bobsleigh, bob',
+ 451: 'bolo tie, bolo, bola tie, bola',
+ 452: 'bonnet, poke bonnet',
+ 453: 'bookcase',
+ 454: 'bookshop, bookstore, bookstall',
+ 455: 'bottlecap',
+ 456: 'bow',
+ 457: 'bow tie, bow-tie, bowtie',
+ 458: 'brass, memorial tablet, plaque',
+ 459: 'brassiere, bra, bandeau',
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
+ 461: 'breastplate, aegis, egis',
+ 462: 'broom',
+ 463: 'bucket, pail',
+ 464: 'buckle',
+ 465: 'bulletproof vest',
+ 466: 'bullet train, bullet',
+ 467: 'butcher shop, meat market',
+ 468: 'cab, hack, taxi, taxicab',
+ 469: 'caldron, cauldron',
+ 470: 'candle, taper, wax light',
+ 471: 'cannon',
+ 472: 'canoe',
+ 473: 'can opener, tin opener',
+ 474: 'cardigan',
+ 475: 'car mirror',
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
+ 477: "carpenter's kit, tool kit",
+ 478: 'carton',
+ 479: 'car wheel',
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
+ 481: 'cassette',
+ 482: 'cassette player',
+ 483: 'castle',
+ 484: 'catamaran',
+ 485: 'CD player',
+ 486: 'cello, violoncello',
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
+ 488: 'chain',
+ 489: 'chainlink fence',
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
+ 491: 'chain saw, chainsaw',
+ 492: 'chest',
+ 493: 'chiffonier, commode',
+ 494: 'chime, bell, gong',
+ 495: 'china cabinet, china closet',
+ 496: 'Christmas stocking',
+ 497: 'church, church building',
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
+ 499: 'cleaver, meat cleaver, chopper',
+ 500: 'cliff dwelling',
+ 501: 'cloak',
+ 502: 'clog, geta, patten, sabot',
+ 503: 'cocktail shaker',
+ 504: 'coffee mug',
+ 505: 'coffeepot',
+ 506: 'coil, spiral, volute, whorl, helix',
+ 507: 'combination lock',
+ 508: 'computer keyboard, keypad',
+ 509: 'confectionery, confectionary, candy store',
+ 510: 'container ship, containership, container vessel',
+ 511: 'convertible',
+ 512: 'corkscrew, bottle screw',
+ 513: 'cornet, horn, trumpet, trump',
+ 514: 'cowboy boot',
+ 515: 'cowboy hat, ten-gallon hat',
+ 516: 'cradle',
+ 517: 'crane',
+ 518: 'crash helmet',
+ 519: 'crate',
+ 520: 'crib, cot',
+ 521: 'Crock Pot',
+ 522: 'croquet ball',
+ 523: 'crutch',
+ 524: 'cuirass',
+ 525: 'dam, dike, dyke',
+ 526: 'desk',
+ 527: 'desktop computer',
+ 528: 'dial telephone, dial phone',
+ 529: 'diaper, nappy, napkin',
+ 530: 'digital clock',
+ 531: 'digital watch',
+ 532: 'dining table, board',
+ 533: 'dishrag, dishcloth',
+ 534: 'dishwasher, dish washer, dishwashing machine',
+ 535: 'disk brake, disc brake',
+ 536: 'dock, dockage, docking facility',
+ 537: 'dogsled, dog sled, dog sleigh',
+ 538: 'dome',
+ 539: 'doormat, welcome mat',
+ 540: 'drilling platform, offshore rig',
+ 541: 'drum, membranophone, tympan',
+ 542: 'drumstick',
+ 543: 'dumbbell',
+ 544: 'Dutch oven',
+ 545: 'electric fan, blower',
+ 546: 'electric guitar',
+ 547: 'electric locomotive',
+ 548: 'entertainment center',
+ 549: 'envelope',
+ 550: 'espresso maker',
+ 551: 'face powder',
+ 552: 'feather boa, boa',
+ 553: 'file, file cabinet, filing cabinet',
+ 554: 'fireboat',
+ 555: 'fire engine, fire truck',
+ 556: 'fire screen, fireguard',
+ 557: 'flagpole, flagstaff',
+ 558: 'flute, transverse flute',
+ 559: 'folding chair',
+ 560: 'football helmet',
+ 561: 'forklift',
+ 562: 'fountain',
+ 563: 'fountain pen',
+ 564: 'four-poster',
+ 565: 'freight car',
+ 566: 'French horn, horn',
+ 567: 'frying pan, frypan, skillet',
+ 568: 'fur coat',
+ 569: 'garbage truck, dustcart',
+ 570: 'gasmask, respirator, gas helmet',
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
+ 572: 'goblet',
+ 573: 'go-kart',
+ 574: 'golf ball',
+ 575: 'golfcart, golf cart',
+ 576: 'gondola',
+ 577: 'gong, tam-tam',
+ 578: 'gown',
+ 579: 'grand piano, grand',
+ 580: 'greenhouse, nursery, glasshouse',
+ 581: 'grille, radiator grille',
+ 582: 'grocery store, grocery, food market, market',
+ 583: 'guillotine',
+ 584: 'hair slide',
+ 585: 'hair spray',
+ 586: 'half track',
+ 587: 'hammer',
+ 588: 'hamper',
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
+ 590: 'hand-held computer, hand-held microcomputer',
+ 591: 'handkerchief, hankie, hanky, hankey',
+ 592: 'hard disc, hard disk, fixed disk',
+ 593: 'harmonica, mouth organ, harp, mouth harp',
+ 594: 'harp',
+ 595: 'harvester, reaper',
+ 596: 'hatchet',
+ 597: 'holster',
+ 598: 'home theater, home theatre',
+ 599: 'honeycomb',
+ 600: 'hook, claw',
+ 601: 'hoopskirt, crinoline',
+ 602: 'horizontal bar, high bar',
+ 603: 'horse cart, horse-cart',
+ 604: 'hourglass',
+ 605: 'iPod',
+ 606: 'iron, smoothing iron',
+ 607: "jack-o'-lantern",
+ 608: 'jean, blue jean, denim',
+ 609: 'jeep, landrover',
+ 610: 'jersey, T-shirt, tee shirt',
+ 611: 'jigsaw puzzle',
+ 612: 'jinrikisha, ricksha, rickshaw',
+ 613: 'joystick',
+ 614: 'kimono',
+ 615: 'knee pad',
+ 616: 'knot',
+ 617: 'lab coat, laboratory coat',
+ 618: 'ladle',
+ 619: 'lampshade, lamp shade',
+ 620: 'laptop, laptop computer',
+ 621: 'lawn mower, mower',
+ 622: 'lens cap, lens cover',
+ 623: 'letter opener, paper knife, paperknife',
+ 624: 'library',
+ 625: 'lifeboat',
+ 626: 'lighter, light, igniter, ignitor',
+ 627: 'limousine, limo',
+ 628: 'liner, ocean liner',
+ 629: 'lipstick, lip rouge',
+ 630: 'Loafer',
+ 631: 'lotion',
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
+ 633: "loupe, jeweler's loupe",
+ 634: 'lumbermill, sawmill',
+ 635: 'magnetic compass',
+ 636: 'mailbag, postbag',
+ 637: 'mailbox, letter box',
+ 638: 'maillot',
+ 639: 'maillot, tank suit',
+ 640: 'manhole cover',
+ 641: 'maraca',
+ 642: 'marimba, xylophone',
+ 643: 'mask',
+ 644: 'matchstick',
+ 645: 'maypole',
+ 646: 'maze, labyrinth',
+ 647: 'measuring cup',
+ 648: 'medicine chest, medicine cabinet',
+ 649: 'megalith, megalithic structure',
+ 650: 'microphone, mike',
+ 651: 'microwave, microwave oven',
+ 652: 'military uniform',
+ 653: 'milk can',
+ 654: 'minibus',
+ 655: 'miniskirt, mini',
+ 656: 'minivan',
+ 657: 'missile',
+ 658: 'mitten',
+ 659: 'mixing bowl',
+ 660: 'mobile home, manufactured home',
+ 661: 'Model T',
+ 662: 'modem',
+ 663: 'monastery',
+ 664: 'monitor',
+ 665: 'moped',
+ 666: 'mortar',
+ 667: 'mortarboard',
+ 668: 'mosque',
+ 669: 'mosquito net',
+ 670: 'motor scooter, scooter',
+ 671: 'mountain bike, all-terrain bike, off-roader',
+ 672: 'mountain tent',
+ 673: 'mouse, computer mouse',
+ 674: 'mousetrap',
+ 675: 'moving van',
+ 676: 'muzzle',
+ 677: 'nail',
+ 678: 'neck brace',
+ 679: 'necklace',
+ 680: 'nipple',
+ 681: 'notebook, notebook computer',
+ 682: 'obelisk',
+ 683: 'oboe, hautboy, hautbois',
+ 684: 'ocarina, sweet potato',
+ 685: 'odometer, hodometer, mileometer, milometer',
+ 686: 'oil filter',
+ 687: 'organ, pipe organ',
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
+ 689: 'overskirt',
+ 690: 'oxcart',
+ 691: 'oxygen mask',
+ 692: 'packet',
+ 693: 'paddle, boat paddle',
+ 694: 'paddlewheel, paddle wheel',
+ 695: 'padlock',
+ 696: 'paintbrush',
+ 697: "pajama, pyjama, pj's, jammies",
+ 698: 'palace',
+ 699: 'panpipe, pandean pipe, syrinx',
+ 700: 'paper towel',
+ 701: 'parachute, chute',
+ 702: 'parallel bars, bars',
+ 703: 'park bench',
+ 704: 'parking meter',
+ 705: 'passenger car, coach, carriage',
+ 706: 'patio, terrace',
+ 707: 'pay-phone, pay-station',
+ 708: 'pedestal, plinth, footstall',
+ 709: 'pencil box, pencil case',
+ 710: 'pencil sharpener',
+ 711: 'perfume, essence',
+ 712: 'Petri dish',
+ 713: 'photocopier',
+ 714: 'pick, plectrum, plectron',
+ 715: 'pickelhaube',
+ 716: 'picket fence, paling',
+ 717: 'pickup, pickup truck',
+ 718: 'pier',
+ 719: 'piggy bank, penny bank',
+ 720: 'pill bottle',
+ 721: 'pillow',
+ 722: 'ping-pong ball',
+ 723: 'pinwheel',
+ 724: 'pirate, pirate ship',
+ 725: 'pitcher, ewer',
+ 726: "plane, carpenter's plane, woodworking plane",
+ 727: 'planetarium',
+ 728: 'plastic bag',
+ 729: 'plate rack',
+ 730: 'plow, plough',
+ 731: "plunger, plumber's helper",
+ 732: 'Polaroid camera, Polaroid Land camera',
+ 733: 'pole',
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
+ 735: 'poncho',
+ 736: 'pool table, billiard table, snooker table',
+ 737: 'pop bottle, soda bottle',
+ 738: 'pot, flowerpot',
+ 739: "potter's wheel",
+ 740: 'power drill',
+ 741: 'prayer rug, prayer mat',
+ 742: 'printer',
+ 743: 'prison, prison house',
+ 744: 'projectile, missile',
+ 745: 'projector',
+ 746: 'puck, hockey puck',
+ 747: 'punching bag, punch bag, punching ball, punchball',
+ 748: 'purse',
+ 749: 'quill, quill pen',
+ 750: 'quilt, comforter, comfort, puff',
+ 751: 'racer, race car, racing car',
+ 752: 'racket, racquet',
+ 753: 'radiator',
+ 754: 'radio, wireless',
+ 755: 'radio telescope, radio reflector',
+ 756: 'rain barrel',
+ 757: 'recreational vehicle, RV, R.V.',
+ 758: 'reel',
+ 759: 'reflex camera',
+ 760: 'refrigerator, icebox',
+ 761: 'remote control, remote',
+ 762: 'restaurant, eating house, eating place, eatery',
+ 763: 'revolver, six-gun, six-shooter',
+ 764: 'rifle',
+ 765: 'rocking chair, rocker',
+ 766: 'rotisserie',
+ 767: 'rubber eraser, rubber, pencil eraser',
+ 768: 'rugby ball',
+ 769: 'rule, ruler',
+ 770: 'running shoe',
+ 771: 'safe',
+ 772: 'safety pin',
+ 773: 'saltshaker, salt shaker',
+ 774: 'sandal',
+ 775: 'sarong',
+ 776: 'sax, saxophone',
+ 777: 'scabbard',
+ 778: 'scale, weighing machine',
+ 779: 'school bus',
+ 780: 'schooner',
+ 781: 'scoreboard',
+ 782: 'screen, CRT screen',
+ 783: 'screw',
+ 784: 'screwdriver',
+ 785: 'seat belt, seatbelt',
+ 786: 'sewing machine',
+ 787: 'shield, buckler',
+ 788: 'shoe shop, shoe-shop, shoe store',
+ 789: 'shoji',
+ 790: 'shopping basket',
+ 791: 'shopping cart',
+ 792: 'shovel',
+ 793: 'shower cap',
+ 794: 'shower curtain',
+ 795: 'ski',
+ 796: 'ski mask',
+ 797: 'sleeping bag',
+ 798: 'slide rule, slipstick',
+ 799: 'sliding door',
+ 800: 'slot, one-armed bandit',
+ 801: 'snorkel',
+ 802: 'snowmobile',
+ 803: 'snowplow, snowplough',
+ 804: 'soap dispenser',
+ 805: 'soccer ball',
+ 806: 'sock',
+ 807: 'solar dish, solar collector, solar furnace',
+ 808: 'sombrero',
+ 809: 'soup bowl',
+ 810: 'space bar',
+ 811: 'space heater',
+ 812: 'space shuttle',
+ 813: 'spatula',
+ 814: 'speedboat',
+ 815: "spider web, spider's web",
+ 816: 'spindle',
+ 817: 'sports car, sport car',
+ 818: 'spotlight, spot',
+ 819: 'stage',
+ 820: 'steam locomotive',
+ 821: 'steel arch bridge',
+ 822: 'steel drum',
+ 823: 'stethoscope',
+ 824: 'stole',
+ 825: 'stone wall',
+ 826: 'stopwatch, stop watch',
+ 827: 'stove',
+ 828: 'strainer',
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
+ 830: 'stretcher',
+ 831: 'studio couch, day bed',
+ 832: 'stupa, tope',
+ 833: 'submarine, pigboat, sub, U-boat',
+ 834: 'suit, suit of clothes',
+ 835: 'sundial',
+ 836: 'sunglass',
+ 837: 'sunglasses, dark glasses, shades',
+ 838: 'sunscreen, sunblock, sun blocker',
+ 839: 'suspension bridge',
+ 840: 'swab, swob, mop',
+ 841: 'sweatshirt',
+ 842: 'swimming trunks, bathing trunks',
+ 843: 'swing',
+ 844: 'switch, electric switch, electrical switch',
+ 845: 'syringe',
+ 846: 'table lamp',
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
+ 848: 'tape player',
+ 849: 'teapot',
+ 850: 'teddy, teddy bear',
+ 851: 'television, television system',
+ 852: 'tennis ball',
+ 853: 'thatch, thatched roof',
+ 854: 'theater curtain, theatre curtain',
+ 855: 'thimble',
+ 856: 'thresher, thrasher, threshing machine',
+ 857: 'throne',
+ 858: 'tile roof',
+ 859: 'toaster',
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
+ 861: 'toilet seat',
+ 862: 'torch',
+ 863: 'totem pole',
+ 864: 'tow truck, tow car, wrecker',
+ 865: 'toyshop',
+ 866: 'tractor',
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
+ 868: 'tray',
+ 869: 'trench coat',
+ 870: 'tricycle, trike, velocipede',
+ 871: 'trimaran',
+ 872: 'tripod',
+ 873: 'triumphal arch',
+ 874: 'trolleybus, trolley coach, trackless trolley',
+ 875: 'trombone',
+ 876: 'tub, vat',
+ 877: 'turnstile',
+ 878: 'typewriter keyboard',
+ 879: 'umbrella',
+ 880: 'unicycle, monocycle',
+ 881: 'upright, upright piano',
+ 882: 'vacuum, vacuum cleaner',
+ 883: 'vase',
+ 884: 'vault',
+ 885: 'velvet',
+ 886: 'vending machine',
+ 887: 'vestment',
+ 888: 'viaduct',
+ 889: 'violin, fiddle',
+ 890: 'volleyball',
+ 891: 'waffle iron',
+ 892: 'wall clock',
+ 893: 'wallet, billfold, notecase, pocketbook',
+ 894: 'wardrobe, closet, press',
+ 895: 'warplane, military plane',
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
+ 897: 'washer, automatic washer, washing machine',
+ 898: 'water bottle',
+ 899: 'water jug',
+ 900: 'water tower',
+ 901: 'whiskey jug',
+ 902: 'whistle',
+ 903: 'wig',
+ 904: 'window screen',
+ 905: 'window shade',
+ 906: 'Windsor tie',
+ 907: 'wine bottle',
+ 908: 'wing',
+ 909: 'wok',
+ 910: 'wooden spoon',
+ 911: 'wool, woolen, woollen',
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
+ 913: 'wreck',
+ 914: 'yawl',
+ 915: 'yurt',
+ 916: 'web site, website, internet site, site',
+ 917: 'comic book',
+ 918: 'crossword puzzle, crossword',
+ 919: 'street sign',
+ 920: 'traffic light, traffic signal, stoplight',
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
+ 922: 'menu',
+ 923: 'plate',
+ 924: 'guacamole',
+ 925: 'consomme',
+ 926: 'hot pot, hotpot',
+ 927: 'trifle',
+ 928: 'ice cream, icecream',
+ 929: 'ice lolly, lolly, lollipop, popsicle',
+ 930: 'French loaf',
+ 931: 'bagel, beigel',
+ 932: 'pretzel',
+ 933: 'cheeseburger',
+ 934: 'hotdog, hot dog, red hot',
+ 935: 'mashed potato',
+ 936: 'head cabbage',
+ 937: 'broccoli',
+ 938: 'cauliflower',
+ 939: 'zucchini, courgette',
+ 940: 'spaghetti squash',
+ 941: 'acorn squash',
+ 942: 'butternut squash',
+ 943: 'cucumber, cuke',
+ 944: 'artichoke, globe artichoke',
+ 945: 'bell pepper',
+ 946: 'cardoon',
+ 947: 'mushroom',
+ 948: 'Granny Smith',
+ 949: 'strawberry',
+ 950: 'orange',
+ 951: 'lemon',
+ 952: 'fig',
+ 953: 'pineapple, ananas',
+ 954: 'banana',
+ 955: 'jackfruit, jak, jack',
+ 956: 'custard apple',
+ 957: 'pomegranate',
+ 958: 'hay',
+ 959: 'carbonara',
+ 960: 'chocolate sauce, chocolate syrup',
+ 961: 'dough',
+ 962: 'meat loaf, meatloaf',
+ 963: 'pizza, pizza pie',
+ 964: 'potpie',
+ 965: 'burrito',
+ 966: 'red wine',
+ 967: 'espresso',
+ 968: 'cup',
+ 969: 'eggnog',
+ 970: 'alp',
+ 971: 'bubble',
+ 972: 'cliff, drop, drop-off',
+ 973: 'coral reef',
+ 974: 'geyser',
+ 975: 'lakeside, lakeshore',
+ 976: 'promontory, headland, head, foreland',
+ 977: 'sandbar, sand bar',
+ 978: 'seashore, coast, seacoast, sea-coast',
+ 979: 'valley, vale',
+ 980: 'volcano',
+ 981: 'ballplayer, baseball player',
+ 982: 'groom, bridegroom',
+ 983: 'scuba diver',
+ 984: 'rapeseed',
+ 985: 'daisy',
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+ 987: 'corn',
+ 988: 'acorn',
+ 989: 'hip, rose hip, rosehip',
+ 990: 'buckeye, horse chestnut, conker',
+ 991: 'coral fungus',
+ 992: 'agaric',
+ 993: 'gyromitra',
+ 994: 'stinkhorn, carrion fungus',
+ 995: 'earthstar',
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
+ 997: 'bolete',
+ 998: 'ear, spike, capitulum',
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
diff --git a/requirements.txt b/requirements.txt
index 7a1fc68..8196c05 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ ipywidgets==7.1.2
 bqplot==0.10.5
 pyyaml
 pytest==3.5.1
+xlsxwriter==1.1.1
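
The newly pinned xlsxwriter is a library for writing .xlsx workbooks; a
minimal sketch of its core API (the filename, sheet name, and rows below are
purely illustrative):

    import xlsxwriter

    workbook = xlsxwriter.Workbook("stats.xlsx")      # create the workbook file
    worksheet = workbook.add_worksheet("summary")     # add one sheet
    worksheet.write_row(0, 0, ["layer", "sparsity"])  # header row
    worksheet.write_row(1, 0, ["conv1", 0.42])        # one data row
    workbook.close()                                  # flush to disk
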
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 6fcda7e..3d4153f 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -183,7 +183,7 @@ def test_connectivity_summary():
     assert len(summary) == 73
 
     verbose_summary = connectivity_summary_verbose(g)
-    assert len(verbose_summary  ) == 73
+    assert len(verbose_summary) == 73
 
 
 if __name__ == '__main__':
-- 
GitLab