From 718f777b8f6b12ffa8ce620c0cd5c20755ee7197 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Fri, 13 Jul 2018 13:32:14 +0300
Subject: [PATCH] ADC (Automatic Deep Compression) example + features, tests,
 bug fixes  (#28)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is a merge of the ADC branch and master.
ADC (using a DDPG RL agent to compress image classifiers) is still WiP and requires
An unreleased version of Coach (https://github.com/NervanaSystems/coach).

Small features in this commit:
-Added model_find_module() - find module object given its name
- Add channel ranking and pruning: pruning/ranked_structures_pruner.py
- Add a CIFAR10 VGG16 model: models/cifar10/vgg_cifar.py
- Thinning: change the level of some log messages – some of the messages were
moved to ‘debug’ level because they are not usually interesting.
- Add a function to print nicely formatted integers - distiller/utils.py
- Sensitivity analysis for channels-removal
- compress_classifier.py – handle keyboard interrupts
- compress_classifier.py – fix re-raise of exceptions, so they maintain call-stack

-Added tests:
-- test_summarygraph.py: test_simplenet() - Added a regression test to target a bug that occurs when taking the predecessor of the first node in a graph
-- test_ranking.py - test_ch_ranking, test_ranked_channel_pruning
-- test_model_summary.py - test_png_generation, test_summary (sparsity/ compute/model/modules)

- Bug fixes in this commit:
-- Thinning bug fix: handle zero-sized 'indices' tensor
During the thinning process, the 'indices' tensor can become zero-sized,
and will have an undefiend length. Therefore, we need to check for this
situation when assessing the number of elements in 'indices'
-- Language model: adjust main.py to new distiller.model_summary API
---
 distiller/__init__.py                         |  16 +
 distiller/data_loggers/logger.py              |   8 +-
 distiller/model_summaries.py                  |  10 +-
 distiller/pruning/ranked_structures_pruner.py |  56 ++-
 distiller/sensitivity.py                      |  11 +-
 distiller/thinning.py                         |  24 +-
 distiller/utils.py                            |   4 +
 examples/automated_deep_compression/ADC.py    | 357 ++++++++++++++++++
 .../presets/ADC_DDPG.py                       |  73 ++++
 .../compress_classifier.py                    |  35 +-
 .../resnet56_cifar_baseline_training.yaml     | 143 ++++---
 examples/word_language_model/main.py          |   6 +-
 models/__init__.py                            |   1 -
 models/cifar10/__init__.py                    |   1 +
 models/cifar10/vgg_cifar.py                   | 133 +++++++
 tests/common.py                               |   2 +-
 tests/test_model_summary.py                   |  63 ++++
 tests/test_pruning.py                         |   2 -
 tests/test_ranking.py                         | 101 +++++
 tests/test_summarygraph.py                    |  25 ++
 20 files changed, 953 insertions(+), 118 deletions(-)
 create mode 100755 examples/automated_deep_compression/ADC.py
 create mode 100755 examples/automated_deep_compression/presets/ADC_DDPG.py
 create mode 100755 models/cifar10/vgg_cifar.py
 create mode 100755 tests/test_model_summary.py
 create mode 100755 tests/test_ranking.py

diff --git a/distiller/__init__.py b/distiller/__init__.py
index 6dec91f..6b0c21f 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -82,3 +82,19 @@ def model_find_param(model, param_to_find_name):
         if name == param_to_find_name:
             return param
     return None
+
+
+def model_find_module(model, module_to_find):
+    """Given a module name, find the module in the provided model.
+
+    Arguments:
+        model: the model to search
+        module_to_find: the module whose name we want to look up
+
+    Returns:
+        The module or None, if the module was not found.
+    """
+    for name, m in model.named_modules():
+        if name == module_to_find:
+            return m
+    return None
diff --git a/distiller/data_loggers/logger.py b/distiller/data_loggers/logger.py
index ae3ad55..796c6c3 100755
--- a/distiller/data_loggers/logger.py
+++ b/distiller/data_loggers/logger.py
@@ -70,13 +70,15 @@ class PythonLogger(DataLogger):
 
     def log_training_progress(self, stats_dict, epoch, completed, total, freq):
         stats_dict = stats_dict[1]
-        if epoch>-1:
+        if epoch > -1:
             log = 'Epoch: [{}][{:5d}/{:5d}]    '.format(epoch, completed, int(total))
         else:
             log = 'Test: [{:5d}/{:5d}]    '.format(completed, int(total))
-            #log = 'Test: [{1:5d}/{2:5d}]    '.format(total)
         for name, val in stats_dict.items():
-            log = log + '{name} {val:.6f}    '.format(name=name, val=val)
+            if isinstance(val, int):
+                log = log + '{name} {val}    '.format(name=name, val=distiller.pretty_int(val))
+            else:
+                log = log + '{name} {val:.6f}    '.format(name=name, val=val)
         self.pylogger.info(log)
 
 
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index cf67720..8138402 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -56,7 +56,6 @@ def model_summary(model, what, dataset=None):
         total_macs = df['MACs'].sum()
         print(t)
         print("Total MACs: " + "{:,}".format(total_macs))
-
     elif what == 'model':
         # print the simple form of the model
         print(model)
@@ -71,21 +70,20 @@ def model_summary(model, what, dataset=None):
             if len(module._modules) == 0:
                 nodes.append([name, module.__class__.__name__])
         print(tabulate(nodes, headers=['Name', 'Type']))
+    else:
+        raise ValueError("%s is not a supported summary type" % what)
 
 
 def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2, 4]):
-
     df = pd.DataFrame(columns=['Name', 'Shape', 'NNZ (dense)', 'NNZ (sparse)',
-                               'Cols (%)','Rows (%)', 'Ch (%)', '2D (%)', '3D (%)',
+                               'Cols (%)', 'Rows (%)', 'Ch (%)', '2D (%)', '3D (%)',
                                'Fine (%)', 'Std', 'Mean', 'Abs-Mean'])
     pd.set_option('precision', 2)
     params_size = 0
     sparse_params_size = 0
-    summary_param_types = ['weight', 'bias']
     for name, param in model.state_dict().items():
         # Extract just the actual parameter's name, which in this context we treat as its "type"
-        curr_param_type = name.split('.')[-1]
-        if param.dim() in param_dims and curr_param_type in summary_param_types:
+        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
             _density = distiller.density(param)
             params_size += torch.numel(param)
             sparse_params_size += param.numel() * _density
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 683f118..4816410 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -20,6 +20,7 @@ import distiller
 from .pruner import _ParameterPruner
 msglogger = logging.getLogger()
 
+# TODO: support different policies for ranking structures
 class L1RankedStructureParameterPruner(_ParameterPruner):
     """Uses mean L1-norm to rank structures and prune a specified percentage of structures
     """
@@ -28,6 +29,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         self.name = name
         self.reg_regims = reg_regims
 
+
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         if param_name not in self.reg_regims.keys():
             return
@@ -37,7 +39,59 @@ class L1RankedStructureParameterPruner(_ParameterPruner):
         if fraction_to_prune == 0:
             return
 
-        assert group_type == "3D", "Currently only filter ranking is supported"
+        if group_type not in ['3D', 'Channels']:
+            raise ValueError("Currently only filter (3D) and channel ranking is supported")
+        if group_type == "3D":
+            return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict)
+        elif group_type == "Channels":
+            return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict)
+
+    @staticmethod
+    def rank_channels(fraction_to_prune, param):
+        num_filters = param.size(0)
+        num_channels = param.size(1)
+        kernel_size = param.size(2) * param.size(3)
+
+        # First, reshape the weights tensor such that each channel (kernel) in the original
+        # tensor, is now a row in the 2D tensor.
+        view_2d = param.view(-1, kernel_size)
+        # Next, compute the sums of each kernel
+        kernel_sums = view_2d.abs().sum(dim=1)
+        # Now group by channels
+        k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
+        channel_mags = k_sums_mat.mean(dim=1)
+        k = int(fraction_to_prune * channel_mags.size(0))
+        if k == 0:
+            msglogger.info("Too few channels (%d)- can't prune %.1f%% channels",
+                            num_channels, 100*fraction_to_prune)
+            return None, None
+
+        bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True)
+        return bottomk, channel_mags
+
+
+    def rank_prune_channels(self, fraction_to_prune, param, param_name, zeros_mask_dict):
+        bottomk_channels, channel_mags = self.rank_channels(fraction_to_prune, param)
+        if bottomk_channels is None:
+            # Empty list means that fraction_to_prune is too low to prune anything
+            return
+
+        num_filters = param.size(0)
+        num_channels = param.size(1)
+
+        threshold = bottomk_channels[-1]
+        binary_map = channel_mags.gt(threshold).type(param.data.type())
+        a = binary_map.expand(num_filters, num_channels)
+        c = a.unsqueeze(-1)
+        d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous()
+        zeros_mask_dict[param_name].mask = d.view(num_filters, num_channels, param.size(2), param.size(3))
+
+        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+                       distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
+                       fraction_to_prune, len(bottomk_channels), num_channels)
+
+
+    def rank_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
         view_filters = param.view(param.size(0), -1)
         filter_mags = view_filters.data.abs().mean(dim=1)
diff --git a/distiller/sensitivity.py b/distiller/sensitivity.py
index 69f8a0d..8db2773 100755
--- a/distiller/sensitivity.py
+++ b/distiller/sensitivity.py
@@ -65,7 +65,8 @@ def perform_sensitivity_analysis(model, net_params, sparsities, test_func, group
     The test_func is expected to execute the model on a test/validation dataset,
     and return the results for top1 and top5 accuracies, and the loss value.
     """
-    assert group in ['element', 'filter']
+    if group not in ['element', 'filter', 'channel']:
+        raise ValueError("group parameter contains an illegal value: {}".format(group))
     sensitivities = OrderedDict()
 
     for param_name in net_params:
@@ -86,12 +87,18 @@ def perform_sensitivity_analysis(model, net_params, sparsities, test_func, group
                 # Element-wise sparasity
                 sparsity_levels = {param_name: sparsity_level}
                 pruner = distiller.pruning.SparsityLevelParameterPruner(name='sensitivity', levels=sparsity_levels)
-            else:
+            elif group == 'filter':
                 # Filter ranking
                 if model.state_dict()[param_name].dim() != 4:
                     continue
                 regims = {param_name: [sparsity_level, '3D']}
                 pruner = distiller.pruning.L1RankedStructureParameterPruner(name='sensitivity', reg_regims=regims)
+            elif group == 'channel':
+                # Filter ranking
+                if model.state_dict()[param_name].dim() != 4:
+                    continue
+                regims = {param_name: [sparsity_level, 'Channels']}
+                pruner = distiller.pruning.L1RankedStructureParameterPruner(name='sensitivity', reg_regims=regims)
 
             policy = distiller.PruningPolicy(pruner, pruner_args=None)
             scheduler = CompressionScheduler(model_cpy)
diff --git a/distiller/thinning.py b/distiller/thinning.py
index de7b625..684d715 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -15,13 +15,11 @@
 #
 
 """Model thinning support.
-
 Thinning a model is the process of taking a dense network architecture with a parameter model that
 has structure-sparsity (filters or channels) in the weights tensors of convolution layers, and making changes
 in the network architecture and parameters, in order to completely remove the structures.
 The new architecture is smaller (condensed), with less channels and filters in some of the convolution layers.
 Linear and BatchNormalization layers are also adjusted as required.
-
 To perform thinning, we create a SummaryGraph (‘sgraph’) of our model.  We use the ‘sgraph’ to infer the
 data-dependency between the modules in the PyTorch network.  This entire process is not trivial and will be
 documented in a different place.
@@ -42,11 +40,8 @@ ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
 """A ThinningRecipe is composed of two sets of instructions.
 1. Instructions for setting module attributes (e.g. Conv2d.out_channels).  This set
 is called 'ThinningRecipe.modules'.
-
 2. Information on how to select specific dimensions from parameter tensors.  This
 set is called 'ThinningRecipe.parameters'.
-
-
 ThinningRecipe.modules is a dictionary keyed by the module names (strings).  Values
 are called 'module-directives', and are grouped in another dictionary, whose keys are
 the module attributes.  For example:
@@ -55,7 +50,6 @@ the module attributes.  For example:
         out_channels: 512
     classifier.0:
         in_channels: 22589
-
 ThinningRecipe.parameters is a dictionary keyed by the parameter names (strings).
 Values are called 'parameter directives', and each directive is a list of tuples.
 These tuples can have 2 values, or 4 values.
@@ -78,7 +72,7 @@ def create_graph(dataset, arch):
 
     model = create_model(False, dataset, arch, parallel=False)
     assert model is not None
-    return SummaryGraph(model, dummy_input)
+    return SummaryGraph(model, dummy_input.cuda())
 
 
 def param_name_2_layer_name(param_name):
@@ -99,7 +93,6 @@ def append_module_directive(thinning_recipe, module_name, key, val):
 
 def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_features):
     """Adjust the sizes of the parameters of a BatchNormalization layer
-
     This function is invoked after the Convolution layer preceeding a BN layer has
     changed dimensions (filters or channels were removed), and the BN layer also
     requires updating as a result.
@@ -123,7 +116,6 @@ def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_featur
 
 def resnet_cifar_remove_layers(model):
     """Remove layers from ResNet-Cifar
-
     Search for convolution layers which have 100% sparse weight tensors and remove
     them from the model.  This ugly code is specific to ResNet for Cifar, using the
     layer gating mechanism that we added in order to remove layers from the network.
@@ -158,7 +150,6 @@ def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
 
 def find_nonzero_channels(param, param_name):
     """Count the number of non-zero channels in a weights tensor.
-
     Non-zero channels are channels that have at least one coefficient that is
     non-zero.  Counting non-zero channels involves some tensor acrobatics.
     """
@@ -213,7 +204,6 @@ def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
 
 def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
     """Create a recipe for removing channels from Convolution layers.
-
     The 4D weights of the model parameters (i.e. the convolution parameters) are
     examined one by one, to determine which has channels that are all zeros.
     For each weights tensor that has at least one zero-channel, we create a
@@ -281,7 +271,6 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
 
 def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
     """Create a recipe for removing filters from Convolution layers.
-
     The 4D weights of the model parameters (i.e. the convolution parameters) are
     examined one by one, to determine which has filters that are all zeros.
     For each weights tensor that has at least one zero-filter, we create a
@@ -337,7 +326,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
                 append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
-                msglogger.info("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters))
+                msglogger.debug("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters))
 
                 # Now remove channels from the weights tensor of the successor conv
                 append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
@@ -347,7 +336,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
                 fm_size = layers[successor].in_features // layers[layer_name].out_channels
                 in_features = fm_size * num_nnz_filters
                 append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
-                msglogger.info("[recipe] {}: setting in_features = {}".format(successor, in_features))
+                msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features))
 
                 # Now remove channels from the weights tensor of the successor FC layer:
                 # This is a bit tricky:
@@ -450,7 +439,6 @@ def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
 
 def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_from_file=False):
     """Apply a thinning recipe to a model.
-
     This will remove filters and channels, as well as handle batch-normalization parameter
     adjustment, and thinning of weight tensors.
     """
@@ -467,12 +455,12 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 indices_to_select = val[1]
                 # Check if we're trying to trim a parameter that is already "thin"
                 if running.size(dim_to_trim) != indices_to_select.nelement():
-                    msglogger.info("[thinning] {}: setting {} to {}".
+                    msglogger.debug("[thinning] {}: setting {} to {}".
                                    format(layer_name, attr, indices_to_select.nelement()))
                     setattr(layers[layer_name], attr,
                             torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
             else:
-                msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
+                msglogger.debug("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
                 setattr(layers[layer_name], attr, val)
 
     assert len(recipe.parameters) > 0
@@ -503,7 +491,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
             else:
                 if param.data.size(dim) != len_indices:
                     param.data = torch.index_select(param.data, dim, indices)
-                    msglogger.info("[thinning] changed param {} shape: {}".format(param_name, len_indices))
+                    msglogger.debug("[thinning] changed param {} shape: {}".format(param_name, len_indices))
                 # We also need to change the dimensions of the gradient tensor.
                 # If have not done a backward-pass thus far, then the gradient will
                 # not exist, and therefore won't need to be re-dimensioned.
diff --git a/distiller/utils.py b/distiller/utils.py
index da90f64..1727953 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -52,6 +52,10 @@ def size_to_str(torch_size):
     return '('+(', ').join(['%d' % v for v in torch_size])+')'
 
 
+def pretty_int(i):
+    return "{:,}".format(i)
+
+
 def normalize_module_name(layer_name):
     """Normalize a module's name.
 
diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py
new file mode 100755
index 0000000..cca53e1
--- /dev/null
+++ b/examples/automated_deep_compression/ADC.py
@@ -0,0 +1,357 @@
+import random
+import math
+import copy
+import logging
+import numpy as np
+import torch
+import gym
+from gym import spaces
+import distiller
+from apputils import SummaryGraph
+from collections import OrderedDict, namedtuple
+from types import SimpleNamespace
+from distiller import normalize_module_name
+
+from base_parameters import TaskParameters
+from examples.automated_deep_compression.presets.ADC_DDPG import graph_manager
+
+msglogger = logging.getLogger()
+Observation = namedtuple('Observation', ['t', 'n', 'c', 'h', 'w', 'stride', 'k', 'MACs', 'reduced', 'rest', 'prev_a'])
+ALMOST_ONE = 0.9999
+
+# TODO: this is also defined in test_pruning.py
+def create_model_masks(model):
+    # Create the masks
+    zeros_mask_dict = {}
+    for name, param in model.named_parameters():
+        masker = distiller.ParameterMasker(name)
+        zeros_mask_dict[name] = masker
+    return zeros_mask_dict
+
+
+USE_COACH = True
+
+
+def do_adc(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn):
+    np.random.seed()
+
+    if USE_COACH:
+        task_parameters = TaskParameters(framework_type="tensorflow",
+                                         experiment_path="./experiments/test")
+        extra_params = {'save_checkpoint_secs': None,
+                        'render': True}
+        task_parameters.__dict__.update(extra_params)
+
+        graph_manager.env_params.additional_simulator_parameters = {
+            'model': model,
+            'dataset': dataset,
+            'arch': arch,
+            'data_loader': data_loader,
+            'validate_fn': validate_fn,
+            'save_checkpoint_fn': save_checkpoint_fn
+        }
+        graph_manager.create_graph(task_parameters)
+        graph_manager.improve()
+        return
+
+    """Random ADC agent"""
+    env = CNNEnvironment(model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn)
+
+    for ep in range(10):
+        observation = env.reset()
+        for t in range(100):
+            env.render(0, 0)
+            msglogger.info("[episode={}:{}] observation = {}".format(ep, t, observation))
+            # take a random action
+            action = env.action_space.sample()
+            observation, reward, done, info = env.step(action)
+            if done:
+                msglogger.info("Episode finished after {} timesteps".format(t+1))
+                break
+
+
+class RandomADCActionSpace(object):
+    def sample(self):
+        return random.uniform(0, 1)
+
+
+def collect_conv_details(model, dataset):
+    if dataset == 'imagenet':
+        dummy_input = torch.randn(1, 3, 224, 224)
+    elif dataset == 'cifar10':
+        dummy_input = torch.randn(1, 3, 32, 32)
+    else:
+        raise ValueError("dataset %s is not supported" % dataset)
+
+    g = SummaryGraph(model.cuda(), dummy_input.cuda())
+    conv_layers = OrderedDict()
+    total_macs = 0
+    for id, (name, m) in enumerate(model.named_modules()):
+        if isinstance(m, torch.nn.Conv2d):
+            conv = SimpleNamespace()
+            conv.t = len(conv_layers)
+            conv.k = m.kernel_size[0]
+            conv.stride = m.stride
+
+            # Use the SummaryGraph to obtain some other details of the models
+            conv_op = g.find_op(normalize_module_name(name))
+            assert conv_op is not None
+
+            conv.macs = conv_op['attrs']['MACs']
+            total_macs += conv.macs
+            conv.ofm_h = g.param_shape(conv_op['outputs'][0])[2]
+            conv.ofm_w = g.param_shape(conv_op['outputs'][0])[3]
+            conv.ifm_h = g.param_shape(conv_op['inputs'][0])[2]
+            conv.ifm_w = g.param_shape(conv_op['inputs'][0])[3]
+
+            conv.name = name
+            conv.id = id
+            conv_layers[len(conv_layers)] = conv
+
+    return conv_layers, total_macs
+
+
+class CNNEnvironment(gym.Env):
+    metadata = {'render.modes': ['human']}
+    STATE_EMBEDDING_LEN = len(Observation._fields)
+
+    def __init__(self, model, dataset, arch, data_loader, validate_fn, save_checkpoint_fn):
+        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
+        self.tflogger = distiller.data_loggers.TensorBoardLogger(msglogger.logdir)
+
+        self.action_space = RandomADCActionSpace()
+        self.dataset = dataset
+        self.arch = arch
+        self.data_loader = data_loader
+        self.validate_fn = validate_fn
+        self.save_checkpoint_fn = save_checkpoint_fn
+        self.orig_model = model
+
+        self.conv_layers, self.dense_model_macs = collect_conv_details(model, dataset)
+        self.reset(init_only=True)
+        msglogger.info("Model %s has %d Convolution layers", arch, len(self.conv_layers))
+        msglogger.info("\tTotal MACs: %s" % distiller.pretty_int(self.dense_model_macs))
+
+        self.debug_stats = {'episode': 0}
+
+        # Gym
+        # spaces documentation: https://gym.openai.com/docs/
+        self.action_space = spaces.Box(0, 1, shape=(1,))
+        self.observation_space = spaces.Box(0, float("inf"), shape=(self.STATE_EMBEDDING_LEN,))
+
+    def reset(self, init_only=False):
+        """Reset the environment.
+        This is invoked by the Agent.
+        """
+        msglogger.info("Resetting the environment")
+        self.current_layer_id = 0
+        self.prev_action = 0
+        self.model = copy.deepcopy(self.orig_model)
+        self.zeros_mask_dict = create_model_masks(self.model)
+        self._remaining_macs = self.dense_model_macs
+        self._removed_macs = 0
+
+        # self.unprocessed_layers = []
+        # for conv in self.conv_layers:
+        #     self.unprocessed_layers.append(conv)
+        # self.processed_layers = []
+        if init_only:
+            return
+
+        #layer_macs = self.get_macs(self.current_layer())
+        #return self._get_obs(layer_macs)
+        obs, _, _, _, = self.step(0)
+        return obs
+
+
+    def num_layers(self):
+        return len(self.conv_layers)
+
+    def current_layer(self):
+        try:
+            return self.conv_layers[self.current_layer_id]
+        except KeyError:
+            return None
+
+    def episode_is_done(self):
+        return self.current_layer_id == self.num_layers()
+
+    def remaining_macs(self):
+        """Return the amount of MACs remaining in the model's unprocessed
+        Convolution layers.
+        This is normalized to the range 0..1
+        """
+        #return 1 - self.sum_list_macs(self.unprocessed_layers) / self.dense_model_macs
+        return self._remaining_macs / self.dense_model_macs
+
+    def removed_macs(self):
+        """Return the amount of MACs removed so far.
+        This is normalized to the range 0..1
+        """
+        #return self.sum_list_macs(self.processed_layers) / self.dense_model_macs
+        return self._removed_macs / self.dense_model_macs
+
+    # def sum_list_macs(self, conv_list):
+    #     """Sum the MACs in the provided list of Convolution layers"""
+    #     total_macs = 0
+    #     for conv in conv_list:
+    #         total_macs += conv.macs
+    #     return total_macs
+
+    def render(self, mode, close):
+        """Provide some feedback to the user about what's going on
+        This is invoked by the Agent.
+        """
+        if self.current_layer_id == 0:
+            msglogger.info("+" + "-" * 50 + "+")
+            msglogger.info("Starting a new episode")
+            msglogger.info("+" + "-" * 50 + "+")
+
+        msglogger.info("Environment: current_layer_id=%d" % self.current_layer_id)
+        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
+
+    def step(self, action):
+        """Take a step, given an action.
+        This is invoked by the Agent.
+        """
+        layer_macs = self.get_macs(self.current_layer())
+        if action > 0:
+            actual_action = self.__remove_channels(self.current_layer_id, action)
+        else:
+            actual_action = 0
+        layer_macs_after_action = self.get_macs(self.current_layer())
+
+        # Update the various counters after taking the step
+        self.current_layer_id += 1
+        next_layer_macs = self.get_macs(self.current_layer())
+        self._removed_macs += (layer_macs - layer_macs_after_action)
+        self._remaining_macs -= next_layer_macs
+
+        #self.prev_action = actual_action
+        if self.episode_is_done():
+            observation = self.get_final_obs()
+            reward = self.compute_reward()
+            # Save the learned-model checkpoint
+            scheduler = distiller.CompressionScheduler(self.model)
+            scheduler.load_state_dict(state={'masks_dict': self.zeros_mask_dict})
+            self.save_checkpoint_fn(epoch=self.debug_stats['episode'], model=self.model, scheduler=scheduler)
+            self.debug_stats['episode'] += 1
+        else:
+            observation = self._get_obs(next_layer_macs)
+            if True:
+                reward = 0
+            else:
+                reward = self.compute_reward()
+
+        self.prev_action = actual_action
+        info = {}
+        return observation, reward, self.episode_is_done(), info
+
+    def _get_obs(self, macs):
+        """Produce a state embedding (i.e. an observation)"""
+
+        layer = self.current_layer()
+        conv_module = distiller.model_find_module(self.model, layer.name)
+
+        obs = np.array([layer.t, conv_module.out_channels, conv_module.in_channels,
+                        layer.ifm_h, layer.ifm_w, layer.stride[0], layer.k,
+                        macs/self.dense_model_macs, self.removed_macs(), self.remaining_macs(), self.prev_action])
+
+        assert len(obs) == self.STATE_EMBEDDING_LEN
+        assert (macs/self.dense_model_macs + self.removed_macs() + self.remaining_macs()) <= 1
+        msglogger.info("obs={}".format(Observation._make(obs)))
+        return obs
+
+    def get_final_obs(self):
+        """Return the final stae embedding (observation)
+        The final state is reached after we traverse all of the Convolution layers.
+        """
+        obs = np.array([-1, 0, 0,
+                         0, 0, 0, 0,
+                         0, self.removed_macs(), 0, self.prev_action])
+        assert len(obs) == self.STATE_EMBEDDING_LEN
+        return obs
+
+    def get_macs(self, layer):
+        """Return the number of MACs required to compute <layer>'s Convolution"""
+        if layer is None:
+            return 0
+
+        conv_module = distiller.model_find_module(self.model, layer.name)
+        # MACs = volume(OFM) * (#IFM * K^2)
+        return (conv_module.out_channels * layer.ofm_h * layer.ofm_w) * (conv_module.in_channels * layer.k**2)
+
+    def __remove_channels(self, idx, fraction_to_prune, prune_what="channels"):
+        """Physically remove channels and corresponding filters from the model"""
+        if idx not in range(self.num_layers()):
+            raise ValueError("idx=%d is not in correct range (0-%d)" % (idx, self.num_layers()))
+        if fraction_to_prune < 0:
+            raise ValueError("fraction_to_prune=%f is illegal" % (fraction_to_prune))
+
+        if fraction_to_prune == 0:
+            return 0
+        if fraction_to_prune == 1.0:
+            # For now, prevent the removal of entire layers
+            fraction_to_prune = ALMOST_ONE
+
+        layer = self.conv_layers[idx]
+        conv_pname = layer.name + ".weight"
+        conv_p = distiller.model_find_param(self.model, conv_pname)
+
+        msglogger.info("ADC: removing %.1f%% channels from %s" % (fraction_to_prune*100, conv_pname))
+
+        if prune_what == "channels":
+            calculate_sparsity = distiller.sparsity_ch
+            reg_regims = {conv_pname: [fraction_to_prune, "Channels"]}
+            remove_structures = distiller.remove_channels
+        else:
+            calculate_sparsity = distiller.sparsity_3D
+            reg_regims = {conv_pname: [fraction_to_prune, "3D"]}
+            remove_structures = distiller.remove_filters
+
+        # Create a channel-ranking pruner
+        pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", reg_regims)
+        pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=None)
+
+        if (self.zeros_mask_dict[conv_pname].mask is None or
+            calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) == 0):
+            msglogger.info("__remove_channels: aborting because there are no channels to prune")
+            return 0
+
+        # Use the mask to prune
+        self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
+        actual_sparsity = calculate_sparsity(conv_p)
+        remove_structures(self.model, self.zeros_mask_dict, self.arch, self.dataset, optimizer=None)
+        return actual_sparsity
+
+    def compute_reward(self):
+        """The ADC paper defines reward = -Error"""
+        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
+
+        top1, top5, vloss = self.validate_fn(model=self.model, epoch=self.debug_stats['episode'])
+        _, total_macs = collect_conv_details(self.model, self.dataset)
+        reward = -1 * vloss * math.log(total_macs)
+        #reward = -1 * vloss * math.sqrt(math.log(total_macs))
+        #reward = top1 / math.log(total_macs)
+        #alpha = 0.9
+        #reward = -1 * ( (1-alpha)*(top1/100) + 10*alpha*(total_macs/self.dense_model_macs) )
+
+        #alpha = 0.99
+        #reward = -1 * ( (1-alpha)*(top1/100) + alpha*(total_macs/self.dense_model_macs) )
+
+        #reward = vloss * math.log(total_macs)
+        #reward = -1 * vloss * (total_macs / self.dense_model_macs)
+        #reward = top1 * (self.dense_model_macs / total_macs)
+        #reward = -1 * math.log(total_macs)
+        #reward =  -1 * vloss
+        stats = ('Peformance/Validation/',
+                 OrderedDict([('Loss', vloss),
+                              ('Top1', top1),
+                              ('Top5', top5),
+                              ('reward', reward),
+                              ('total_macs', int(total_macs)),
+                              ('log(total_macs)', math.log(total_macs))]))
+        distiller.log_training_progress(stats, None, self.debug_stats['episode'], steps_completed=0, total_steps=1,
+                                        log_freq=1, loggers=[self.tflogger, self.pylogger])
+
+        return reward
diff --git a/examples/automated_deep_compression/presets/ADC_DDPG.py b/examples/automated_deep_compression/presets/ADC_DDPG.py
new file mode 100755
index 0000000..e7e5a1f
--- /dev/null
+++ b/examples/automated_deep_compression/presets/ADC_DDPG.py
@@ -0,0 +1,73 @@
+from agents.ddpg_agent import DDPGAgentParameters
+from graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from graph_managers.graph_manager import ScheduleParameters
+from base_parameters import VisualizationParameters
+from core_types import EnvironmentEpisodes, EnvironmentSteps
+from environments.gym_environment import MujocoInputFilter, GymEnvironmentParameters, MujocoOutputFilter
+from exploration_policies.additive_noise import AdditiveNoiseParameters
+from exploration_policies.truncated_normal import TruncatedNormalParameters
+from schedules import ConstantSchedule, PieceWiseSchedule, ExponentialSchedule
+from memories.memory import MemoryGranularity
+from architectures.tensorflow_components.architecture import Dense
+
+####################
+# Block Scheduling #
+####################
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = EnvironmentEpisodes(400)
+if True:
+    schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
+    schedule_params.evaluation_steps = EnvironmentEpisodes(3)
+else:
+    schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(1)
+    schedule_params.evaluation_steps = EnvironmentEpisodes(1)
+schedule_params.heatup_steps = EnvironmentSteps(2)
+
+#####################
+# DDPG Agent Params #
+#####################
+agent_params = DDPGAgentParameters()
+agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense([300])]
+agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([300])]
+agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense([300])]
+agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]
+agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = [Dense([300])]
+#agent_params.network_wrappers['critic'].clip_gradients = 100
+#agent_params.network_wrappers['actor'].clip_gradients = 100
+
+agent_params.algorithm.rate_for_copying_weights_to_target = 0.01  # Tau pg. 11
+agent_params.memory.max_size = (MemoryGranularity.Transitions, 2000)
+# agent_params.memory.max_size = (MemoryGranularity.Episodes, 2000)
+agent_params.exploration = TruncatedNormalParameters() # AdditiveNoiseParameters()
+steps_per_episode = 13
+agent_params.exploration.noise_percentage_schedule = PieceWiseSchedule([(ConstantSchedule(0.5), EnvironmentSteps(100*steps_per_episode)),
+                                                                        (ExponentialSchedule(0.5, 0, 0.95), EnvironmentSteps(350*steps_per_episode))])
+agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
+agent_params.input_filter = MujocoInputFilter()
+agent_params.output_filter = MujocoOutputFilter()
+# agent_params.network_wrappers['actor'].learning_rate = 0.0001
+# agent_params.network_wrappers['critic'].learning_rate = 0.0001
+# These seem like good values for Reward = -Error
+agent_params.network_wrappers['actor'].learning_rate = 0.0001
+agent_params.network_wrappers['critic'].learning_rate = 0.0001
+# agent_params.network_wrappers['actor'].learning_rate = 0.1
+# agent_params.network_wrappers['critic'].learning_rate = 0.1
+# agent_params.network_wrappers['actor'].learning_rate =  0.000001
+# agent_params.network_wrappers['critic'].learning_rate = 0.000001
+
+##############################
+#      Gym                   #
+##############################
+env_params = GymEnvironmentParameters()
+#env_params.level = '/home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/automated_deep_compression/gym_env/distiller_adc/distiller_adc.py:AutomatedDeepCompression'
+# This path works when training from Coach
+#env_params.level = '../distiller/examples/automated_deep_compression/gym_env/distiller_adc/distiller_adc.py:AutomatedDeepCompression'
+# This path works when training from Distiller
+#env_params.level = '../automated_deep_compression/gym_env/distiller_adc/distiller_adc.py:AutomatedDeepCompression'
+env_params.level = '../automated_deep_compression/ADC.py:CNNEnvironment'
+
+
+vis_params = VisualizationParameters()
+
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+                                    schedule_params=schedule_params, vis_params=vis_params)
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index a540dd7..f227b06 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -67,11 +67,11 @@ import torch.backends.cudnn as cudnn
 import torch.optim
 import torch.utils.data
 import torchnet.meter as tnt
+script_dir = os.path.dirname(__file__)
+module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
 try:
     import distiller
 except ImportError:
-    script_dir = os.path.dirname(__file__)
-    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
     sys.path.append(module_path)
     import distiller
 import apputils
@@ -79,6 +79,7 @@ from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSp
 import distiller.quantization as quantization
 from models import ALL_MODEL_NAMES, create_model
 
+
 # Logger handle
 msglogger = None
 
@@ -127,7 +128,7 @@ parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                     ' | '.join(SUMMARY_CHOICES))
 parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
                     help='configuration file for pruning the model (default is to use hard-coded schedule)')
-parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter'],
+parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
                     help='test the sensitivity of layers to pruning')
 parser.add_argument('--extras', default=None, type=str,
                     help='file with extra configuration information')
@@ -141,6 +142,7 @@ parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experime
 parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
 parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
                     help='Portion of training dataset to set aside for validation')
+parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
 
 
 def check_pytorch_version():
@@ -210,7 +212,6 @@ def main():
 
     # Create the model
     model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus)
-
     compression_scheduler = None
     # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
     # that can be read by Google's Tensor Board.  PythonLogger writes to the Python logger.
@@ -230,6 +231,23 @@ def main():
     msglogger.info('Optimizer Type: %s', type(optimizer))
     msglogger.info('Optimizer Args: %s', optimizer.defaults)
 
+    if args.ADC:
+        HAVE_GYM_INSTALLED = False
+        if not HAVE_GYM_INSTALLED:
+            raise ValueError("ADC is currently experimental and uses non-public Coach features")
+
+        import examples.automated_deep_compression.ADC as ADC
+        train_loader, val_loader, test_loader, _ = apputils.load_data(
+            args.dataset, os.path.expanduser(args.data), args.batch_size,
+            args.workers, args.validation_size, args.deterministic)
+
+        validate_fn = partial(validate, val_loader=test_loader, criterion=criterion,
+                              loggers=[pylogger], print_freq=args.print_freq)
+
+        save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, name='adc')
+        ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn)
+        exit()
+
     # This sample application can be invoked to produce various summary reports.
     if args.summary:
         which_summary = args.summary
@@ -264,7 +282,7 @@ def main():
         which_params = [param_name for param_name, _ in model.named_parameters()]
         sensitivity = distiller.perform_sensitivity_analysis(model,
                                                              net_params=which_params,
-                                                             sparsities=np.arange(0.0, 0.50, 0.05) if args.sensitivity == 'filter' else np.arange(0.0, 0.95, 0.05),
+                                                             sparsities=np.arange(0.0, 0.95, 0.05),
                                                              test_func=test_fnc,
                                                              group=args.sensitivity)
         distiller.sensitivities_to_png(sensitivity, 'sensitivity.png')
@@ -358,7 +376,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
         data_time.add(time.time() - end)
 
         target = target.cuda(async=True)
-        input_var = torch.autograd.Variable(inputs)
+        input_var = inputs.cuda()
         target_var = torch.autograd.Variable(target)
 
         # Execute the forward phase, compute the output and measure loss
@@ -487,6 +505,7 @@ class PytorchNoGrad(object):
 
 def get_inference_var(tensor):
     """This is a temporary function to bridge some difference between PyTorch 3.x and 4.x"""
+    tensor = tensor.cuda(async=True)
     if torch.__version__ >= '0.4':
         return torch.autograd.Variable(tensor)
     return torch.autograd.Variable(tensor, volatile=True)
@@ -495,10 +514,12 @@ def get_inference_var(tensor):
 if __name__ == '__main__':
     try:
         main()
+    except KeyboardInterrupt:
+        print("\n-- KeyboardInterrupt --")
     except Exception as e:
         if msglogger is not None:
             msglogger.error(traceback.format_exc())
-        raise e
+        raise
     finally:
         if msglogger is not None:
             msglogger.info('')
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml
index 01c5c7a..8692301 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml
@@ -1,88 +1,83 @@
 # We used this schedule to train CIFAR10-ResNet56 from scratch
 #
-# time python3 compress_classifier.py --arch resnet56_cifar  ../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml -j=1 --deterministic
+# time python3 compress_classifier.py --arch resnet56_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml -j=1 --deterministic
 #
-# Target: 6.96% error was reported Pruning Filters for Efficient Convnets 
+# Target: 6.96% error was reported Pruning Filters for Efficient Convnets
 #
+# Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 # |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34148 |  0.01379 |    0.14357 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06301 |  0.00203 |    0.02347 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05760 |  0.00007 |    0.02742 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04780 |  0.00338 |    0.02383 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04858 | -0.00358 |    0.02670 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07968 |  0.00273 |    0.04429 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07640 | -0.00262 |    0.04895 |
-# |  7 | module.layer1.3.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10044 | -0.00384 |    0.05374 |
-# |  8 | module.layer1.3.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09931 | -0.00360 |    0.06238 |
-# |  9 | module.layer1.4.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08298 | -0.00024 |    0.05489 |
-# | 10 | module.layer1.4.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08289 | -0.00766 |    0.05761 |
-# | 11 | module.layer1.5.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11349 | -0.00590 |    0.08049 |
-# | 12 | module.layer1.5.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10642 | -0.00195 |    0.07803 |
-# | 13 | module.layer1.6.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10367 | -0.00788 |    0.07537 |
-# | 14 | module.layer1.6.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09865 | -0.00195 |    0.07261 |
-# | 15 | module.layer1.7.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12612 | -0.00886 |    0.09447 |
-# | 16 | module.layer1.7.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11011 |  0.00163 |    0.08398 |
-# | 17 | module.layer1.8.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13849 | -0.01522 |    0.10323 |
-# | 18 | module.layer1.8.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10821 | -0.00555 |    0.08318 |
-# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13763 | -0.00246 |    0.10269 |
-# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11410 | -0.00401 |    0.08719 |
-# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25910 |  0.01282 |    0.18712 |
-# | 22 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09005 | -0.00572 |    0.06956 |
-# | 23 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08083 | -0.00496 |    0.06368 |
-# | 24 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07592 | -0.00750 |    0.05929 |
-# | 25 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06707 | -0.00587 |    0.05252 |
-# | 26 | module.layer2.3.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07829 | -0.00719 |    0.06119 |
-# | 27 | module.layer2.3.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06452 | -0.00374 |    0.05061 |
-# | 28 | module.layer2.4.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06574 | -0.00771 |    0.04972 |
-# | 29 | module.layer2.4.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05378 | -0.00263 |    0.03984 |
-# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07436 | -0.00515 |    0.05701 |
-# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06059 | -0.00472 |    0.04677 |
-# | 32 | module.layer2.6.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06222 | -0.00527 |    0.04587 |
-# | 33 | module.layer2.6.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04993 | -0.00212 |    0.03606 |
-# | 34 | module.layer2.7.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06119 | -0.00785 |    0.04308 |
-# | 35 | module.layer2.7.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04664 | -0.00216 |    0.03203 |
-# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06416 | -0.00867 |    0.04732 |
-# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04904 | -0.00276 |    0.03586 |
-# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08855 | -0.00176 |    0.06946 |
-# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07219 |  0.00106 |    0.05211 |
-# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13456 |  0.00539 |    0.09422 |
-# | 41 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05082 | -0.00166 |    0.03574 |
-# | 42 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04510 | -0.00510 |    0.03232 |
-# | 43 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05256 | -0.00417 |    0.03748 |
-# | 44 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04448 | -0.00243 |    0.03171 |
-# | 45 | module.layer3.3.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04190 | -0.00189 |    0.03038 |
-# | 46 | module.layer3.3.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03494 | -0.00418 |    0.02498 |
-# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04426 | -0.00368 |    0.03268 |
-# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03460 | -0.00293 |    0.02468 |
-# | 49 | module.layer3.5.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04880 | -0.00321 |    0.03613 || 50 | module.layer3.5.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03709 |  0.00014 |    0.02571 |
-# | 51 | module.layer3.6.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02699 | -0.00166 |    0.01931 |
-# | 52 | module.layer3.6.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02024 | -0.00064 |    0.01354 |
-# | 53 | module.layer3.7.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02466 | -0.00162 |    0.01766 |
-# | 54 | module.layer3.7.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01816 | -0.00159 |    0.01202 |
-# | 55 | module.layer3.8.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03662 | -0.00271 |    0.02692 |
-# | 56 | module.layer3.8.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02626 |  0.00011 |    0.01813 |
-# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.52207 | -0.00001 |    0.39151 |
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.39191 |  0.00826 |    0.18757 |
+# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08334 | -0.00180 |    0.03892 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08565 | -0.00033 |    0.05106 |
+# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08190 |  0.00082 |    0.04765 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08365 | -0.00600 |    0.05459 |
+# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09640 | -0.00182 |    0.06337 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09881 | -0.00400 |    0.07056 |
+# |  7 | module.layer1.3.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13412 | -0.00416 |    0.08827 |
+# |  8 | module.layer1.3.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12693 | -0.00271 |    0.09395 |
+# |  9 | module.layer1.4.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12149 | -0.01105 |    0.09064 |
+# | 10 | module.layer1.4.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11322 |  0.00333 |    0.08556 |
+# | 11 | module.layer1.5.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12076 | -0.01164 |    0.09311 |
+# | 12 | module.layer1.5.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11627 | -0.00355 |    0.08882 |
+# | 13 | module.layer1.6.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12492 | -0.00637 |    0.09493 |
+# | 14 | module.layer1.6.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11240 | -0.00837 |    0.08710 |
+# | 15 | module.layer1.7.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13819 | -0.00735 |    0.10096 |
+# | 16 | module.layer1.7.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11107 | -0.00293 |    0.08613 |
+# | 17 | module.layer1.8.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12269 | -0.01133 |    0.09511 |
+# | 18 | module.layer1.8.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09276 |  0.00240 |    0.07117 |
+# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13876 | -0.01190 |    0.11061 |
+# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12728 | -0.00499 |    0.10012 |
+# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.24306 | -0.01255 |    0.19073 |
+# | 22 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11474 | -0.00995 |    0.09044 |
+# | 23 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10452 | -0.00440 |    0.08196 |
+# | 24 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09873 | -0.00629 |    0.07833 |
+# | 25 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08747 | -0.00393 |    0.06891 |
+# | 26 | module.layer2.3.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09434 | -0.00762 |    0.07469 |
+# | 27 | module.layer2.3.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07984 | -0.00449 |    0.06271 |
+# | 28 | module.layer2.4.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08767 | -0.00733 |    0.06852 |
+# | 29 | module.layer2.4.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06642 | -0.00396 |    0.05196 |
+# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07521 | -0.00699 |    0.05799 |
+# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05739 | -0.00351 |    0.04334 |
+# | 32 | module.layer2.6.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06130 | -0.00595 |    0.04791 |
+# | 33 | module.layer2.6.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04703 | -0.00519 |    0.03527 |
+# | 34 | module.layer2.7.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06366 | -0.00734 |    0.04806 |
+# | 35 | module.layer2.7.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04591 | -0.00131 |    0.03282 |
+# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05903 | -0.00606 |    0.04555 |
+# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04344 | -0.00566 |    0.03290 |
+# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08262 |  0.00251 |    0.06520 |
+# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06248 |  0.00073 |    0.04578 |
+# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12275 |  0.01139 |    0.08651 |
+# | 41 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03438 | -0.00186 |    0.02419 |
+# | 42 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03091 | -0.00368 |    0.02203 |
+# | 43 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03477 | -0.00226 |    0.02499 |
+# | 44 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03012 | -0.00350 |    0.02159 |
+# | 45 | module.layer3.3.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03577 | -0.00166 |    0.02608 |
+# | 46 | module.layer3.3.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02962 | -0.00124 |    0.02115 |
+# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03694 | -0.00285 |    0.02677 |
+# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02916 | -0.00165 |    0.02024 |
+# | 49 | module.layer3.5.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03158 | -0.00180 |    0.02342 |
+# | 50 | module.layer3.5.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02527 | -0.00177 |    0.01787 |
+# | 51 | module.layer3.6.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03074 | -0.00169 |    0.02256 |
+# | 52 | module.layer3.6.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02406 | -0.00006 |    0.01658 |
+# | 53 | module.layer3.7.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03160 | -0.00249 |    0.02294 |
+# | 54 | module.layer3.7.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02298 | -0.00083 |    0.01553 |
+# | 55 | module.layer3.8.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02594 | -0.00219 |    0.01890 |
+# | 56 | module.layer3.8.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01986 | -0.00061 |    0.01318 |
+# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.52562 | -0.00003 |    0.39168 |
 # | 58 | Total sparsity:                     | -              |        851504 |         851504 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 0.00
+# 2018-07-02 16:36:31,555 - Total sparsity: 0.00
 #
-# --- validate (epoch=179)-----------
-# 5000 samples (256 per mini-batch)
-# ==> Top1: 93.000    Top5: 99.820    Loss: 0.314
+# 2018-07-02 16:36:31,555 - --- validate (epoch=179)-----------
+# 2018-07-02 16:36:31,555 - 5000 samples (256 per mini-batch)
+# 2018-07-02 16:36:33,121 - ==> Top1: 91.520    Top5: 99.680    Loss: 0.387
 #
-# Saving checkpoint
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 92.970    Top5: 99.740    Loss: 0.349
-#
-#
-# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/private-distiller/examples/classifier_compression/logs/2018.04.09-222954/2018.04.09-222954.log
-#
-# real    91m56.310s
-# user    176m50.080s
-# sys     27m5.873s
+# 2018-07-02 16:36:33,123 - Saving checkpoint to: logs/2018.07.02-152746/checkpoint.pth.tar
+# 2018-07-02 16:36:33,159 - --- test ---------------------
+# 2018-07-02 16:36:33,159 - 10000 samples (256 per mini-batch)
+# 2018-07-02 16:36:36,194 - ==> Top1: 92.850    Top5: 99.780    Loss: 0.364
 
 lr_schedulers:
   training_lr:
diff --git a/examples/word_language_model/main.py b/examples/word_language_model/main.py
index 767b80c..a8ff493 100755
--- a/examples/word_language_model/main.py
+++ b/examples/word_language_model/main.py
@@ -306,7 +306,7 @@ if args.summary:
             threshold = bottomk.data[-1]
             msglogger.info("parameter %s: q = %.2f" %(name, threshold))
     else:
-        distiller.model_summary(model, None, which_summary, 'wikitext2')
+        distiller.model_summary(model, which_summary, 'wikitext2')
     exit(0)
 
 compression_scheduler = None
@@ -317,8 +317,8 @@ if args.compress:
     compression_scheduler = distiller.config.file_config(model, None, args.compress)
 
 optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                 momentum=args.momentum,
-                                 weight_decay=args.weight_decay)
+                            momentum=args.momentum,
+                            weight_decay=args.weight_decay)
 lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           patience=0, verbose=True, factor=0.5)
 
diff --git a/models/__init__.py b/models/__init__.py
index d2a5c4f..b34d7a2 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -72,6 +72,5 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
     elif parallel:
         model = torch.nn.DataParallel(model, device_ids=device_ids)
-
     model.cuda()
     return model
diff --git a/models/cifar10/__init__.py b/models/cifar10/__init__.py
index 3b72572..e4f636f 100755
--- a/models/cifar10/__init__.py
+++ b/models/cifar10/__init__.py
@@ -19,3 +19,4 @@
 from .simplenet_cifar import *
 from .resnet_cifar import *
 from .preresnet_cifar import *
+from .vgg_cifar import *
diff --git a/models/cifar10/vgg_cifar.py b/models/cifar10/vgg_cifar.py
new file mode 100755
index 0000000..0b5a5bb
--- /dev/null
+++ b/models/cifar10/vgg_cifar.py
@@ -0,0 +1,133 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""VGG for CIFAR10
+
+VGG for CIFAR10, based on "Very Deep Convolutional Networks for Large-Scale
+Image Recognition".
+This is based on TorchVision's implementation of VGG for ImageNet, with
+appropriate changes for the 10-class Cifar-10 dataset.
+We replaced the three linear classifiers with a single one.
+"""
+
+import torch.nn as nn
+
+__all__ = [
+    'VGGCifar', 'vgg11_cifar', 'vgg11_bn_cifar', 'vgg13_cifar', 'vgg13_bn_cifar', 'vgg16_cifar', 'vgg16_bn_cifar',
+    'vgg19_bn_cifar', 'vgg19_cifar',
+]
+
+
+class VGGCifar(nn.Module):
+    def __init__(self, features, num_classes=10, init_weights=True):
+        super(VGGCifar, self).__init__()
+        self.features = features
+        self.classifier = nn.Linear(512, num_classes)
+        if init_weights:
+            self._initialize_weights()
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+def make_layers(cfg, batch_norm=False):
+    layers = []
+    in_channels = 3
+    for v in cfg:
+        if v == 'M':
+            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+        else:
+            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+            if batch_norm:
+                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+            else:
+                layers += [conv2d, nn.ReLU(inplace=True)]
+            in_channels = v
+    return nn.Sequential(*layers)
+
+
+cfg = {
+    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+def vgg11_cifar(**kwargs):
+    """VGG 11-layer model (configuration "A")"""
+    model = VGGCifar(make_layers(cfg['A']), **kwargs)
+    return model
+
+
+def vgg11_bn_cifar(**kwargs):
+    """VGG 11-layer model (configuration "A") with batch normalization"""
+    model = VGGCifar(make_layers(cfg['A'], batch_norm=True), **kwargs)
+    return model
+
+
+def vgg13_cifar(**kwargs):
+    """VGG 13-layer model (configuration "B")"""
+    model = VGGCifar(make_layers(cfg['B']), **kwargs)
+    return model
+
+
+def vgg13_bn_cifar(**kwargs):
+    """VGG 13-layer model (configuration "B") with batch normalization"""
+    model = VGGCifar(make_layers(cfg['B'], batch_norm=True), **kwargs)
+    return model
+
+
+def vgg16_cifar(**kwargs):
+    """VGG 16-layer model (configuration "D")
+    """
+    model = VGGCifar(make_layers(cfg['D']), **kwargs)
+    return model
+
+
+def vgg16_bn_cifar(**kwargs):
+    """VGG 16-layer model (configuration "D") with batch normalization"""
+    model = VGGCifar(make_layers(cfg['D'], batch_norm=True), **kwargs)
+    return model
+
+
+def vgg19_cifar(**kwargs):
+    """VGG 19-layer model (configuration "E")
+    """
+    model = VGGCifar(make_layers(cfg['E']), **kwargs)
+    return model
+
+
+def vgg19_bn_cifar(**kwargs):
+    """VGG 19-layer model (configuration 'E') with batch normalization"""
+    model = VGGCifar(make_layers(cfg['E'], batch_norm=True), **kwargs)
+    return model
diff --git a/tests/common.py b/tests/common.py
index bcee3bf..324f2b8 100755
--- a/tests/common.py
+++ b/tests/common.py
@@ -23,7 +23,7 @@ import distiller
 from models import create_model
 
 
-def setup_test(arch, dataset, parallel=True):
+def setup_test(arch, dataset, parallel):
     model = create_model(False, dataset, arch, parallel=parallel)
     assert model is not None
 
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
new file mode 100755
index 0000000..f63e290
--- /dev/null
+++ b/tests/test_model_summary.py
@@ -0,0 +1,63 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import torch
+import os
+import sys
+module_path = os.path.abspath(os.path.join('..'))
+if module_path not in sys.path:
+    sys.path.append(module_path)
+import distiller
+import pytest
+import common  # common test code
+import apputils
+
+# Logging configuration
+logging.basicConfig(level=logging.INFO)
+fh = logging.FileHandler('test.log')
+logger = logging.getLogger()
+logger.addHandler(fh)
+
+
+def test_png_generation():
+    DATASET = "cifar10"
+    ARCH = "resnet20_cifar"
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
+    # 2 different ways to create a PNG
+    apputils.draw_img_classifier_to_file(model, 'model.png', DATASET, True)
+    apputils.draw_img_classifier_to_file(model, 'model.png', DATASET, False)
+
+
+def test_negative():
+    DATASET = "cifar10"
+    ARCH = "resnet20_cifar"
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
+
+    with pytest.raises(ValueError):
+        # png is not a supported summary type, so we expect this to fail with a ValueError
+        distiller.model_summary(model, what='png', dataset=DATASET)
+
+
+def test_summary():
+    DATASET = "cifar10"
+    ARCH = "resnet20_cifar"
+    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
+
+    distiller.model_summary(model, what='sparsity', dataset=DATASET)
+    distiller.model_summary(model, what='compute', dataset=DATASET)
+    distiller.model_summary(model, what='model', dataset=DATASET)
+    distiller.model_summary(model, what='modules', dataset=DATASET)
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 7b26709..57d2975 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -108,7 +108,6 @@ def test_prune_all_filters(parallel):
 
 def ranked_filter_pruning(config, ratio_to_prune, is_parallel):
     """Test L1 ranking and pruning of filters.
-
     First we rank and prune the filters of a Convolutional layer using
     a L1RankedStructureParameterPruner.  Then we physically remove the
     filters from the model (via "thining" process).
@@ -218,7 +217,6 @@ def run_forward_backward(model, optimizer, dummy_input):
 
 def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     """Test removal of arbitrary channels.
-
     The test receives a specification of channels to remove.
     Based on this specification, the channels are pruned and then physically
     removed from the model (via a "thinning" process).
diff --git a/tests/test_ranking.py b/tests/test_ranking.py
new file mode 100755
index 0000000..a0fa14a
--- /dev/null
+++ b/tests/test_ranking.py
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import torch
+import os
+import sys
+try:
+    import distiller
+except ImportError:
+    module_path = os.path.abspath(os.path.join('..'))
+    if module_path not in sys.path:
+        sys.path.append(module_path)
+    import distiller
+import common  # common test code
+
+# Logging configuration
+logging.basicConfig(level=logging.INFO)
+fh = logging.FileHandler('test.log')
+logger = logging.getLogger()
+logger.addHandler(fh)
+
+
+def test_ch_ranking():
+    # Tensor with shape [3, 2, 2, 2] -- 3 filters, 2 channels
+    param = torch.tensor([[[[11., 12],
+                            [13,  14]],
+
+                           [[15., 16],
+                            [17,  18]]],
+                          # Filter #2
+                          [[[21., 22],
+                            [23,  24]],
+
+                           [[25., 26],
+                            [27,  28]]],
+                          # Filter #3
+                          [[[31., 32],
+                            [33,  34]],
+
+                           [[35., 36],
+                            [37,  38]]]])
+
+    fraction_to_prune = 0.5
+    bottomk_channels, channel_mags = distiller.pruning.L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param)
+    logger.info("bottom {}% channels: {}".format(fraction_to_prune*100, bottomk_channels))
+    assert bottomk_channels == torch.tensor([90.])
+
+
+def test_ranked_channel_pruning():
+    model, zeros_mask_dict = common.setup_test("resnet20_cifar", "cifar10", parallel=False)
+
+    # Test that we can access the weights tensor of the first convolution in layer 1
+    conv1_p = distiller.model_find_param(model, "layer1.0.conv1.weight")
+    assert conv1_p is not None
+
+    # Test that there are no zero-channels
+    assert distiller.sparsity_ch(conv1_p) == 0.0
+
+    # # Create a channel-ranking pruner
+    reg_regims = {"layer1.0.conv1.weight": [0.1, "Channels"]}
+    pruner = distiller.pruning.L1RankedStructureParameterPruner("channel_pruner", reg_regims)
+    pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None)
+
+    conv1 = common.find_module_by_name(model, "layer1.0.conv1")
+    assert conv1 is not None
+
+    # Test that the mask has the correct fraction of channels pruned.
+    # We asked for 10%, but there are only 16 channels, so we have to settle for 1/16 channels
+    logger.info("layer1.0.conv1 = {}".format(conv1))
+    expected_pruning = int(0.1 * conv1.in_channels) / conv1.in_channels
+    assert distiller.sparsity_ch(zeros_mask_dict["layer1.0.conv1.weight"].mask) == expected_pruning
+
+    # Use the mask to prune
+    assert distiller.sparsity_ch(conv1_p) == 0
+    zeros_mask_dict["layer1.0.conv1.weight"].apply_mask(conv1_p)
+    assert distiller.sparsity_ch(conv1_p) == expected_pruning
+
+    # Remove channels (and filters)
+    conv0 = common.find_module_by_name(model, "conv1")
+    assert conv0 is not None
+    assert conv0.out_channels == 16
+    assert conv1.in_channels == 16
+
+    # Test thinning
+    distiller.remove_channels(model, zeros_mask_dict, "resnet20_cifar", "cifar10", optimizer=None)
+    assert conv0.out_channels == 15
+    assert conv1.in_channels == 15
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index f36f565..b5bf6a8 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -18,6 +18,7 @@ import logging
 import torch
 import os
 import sys
+import pytest
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
@@ -154,6 +155,30 @@ def test_simplenet():
     assert len(preds) == 1
 
 
+def test_simplenet():
+    g = create_graph('cifar10', 'simplenet_cifar')
+    assert g is not None
+    preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds))
+    assert len(preds) == 0
+
+    preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds))
+    assert len(preds) == 1
+
+
+def test_simplenet():
+    g = create_graph('cifar10', 'simplenet_cifar')
+    assert g is not None
+    preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds))
+    assert len(preds) == 0
+
+    preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds))
+    assert len(preds) == 1
+
+
 def name_test(dataset, arch):
     model = create_model(False, dataset, arch, parallel=False)
     modelp = create_model(False, dataset, arch, parallel=True)
-- 
GitLab