From 3f7a94089828c74e2ec751e825b331dc1fc67e08 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 1 Sep 2019 20:54:27 +0300
Subject: [PATCH] AMC: add pruning of FC layers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

FMReconstructionChannelPruner: add support for nn.Linear layers
utils.py: add non_zero_channels()
thinning: support removing channels from FC layers preceding Conv layers
test_pruning.py: add test_row_pruning()
scheduler: init from a dictionary of Maskers
coach_if.py – fix imports of Clipped-PPO and TD3
---
 distiller/pruning/ranked_structures_pruner.py | 114 ++++---
 distiller/scheduler.py                        |   5 +-
 distiller/thinning.py                         |  98 +++---
 distiller/utils.py                            |  34 +-
 examples/auto_compression/amc/amc.py          |  30 +-
 .../amc/auto_compression_channels.yaml        |   6 +-
 examples/auto_compression/amc/environment.py  | 309 ++++++++++++------
 examples/auto_compression/amc/rewards.py      |  28 +-
 .../amc/rl_libs/coach/coach_if.py             |  13 +-
 .../amc/rl_libs/private/private_if.py         |   1 -
 .../amc/utils/data_dependencies.py            |   2 +-
 tests/test_pruning.py                         |  19 +-
 12 files changed, 394 insertions(+), 265 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 96b8dce..84c96b7 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -263,6 +263,18 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
         # param.data = torch.randn_like(param)
         return binary_map
 
+    @staticmethod
+    def rank_rows(magnitude_fn, fraction_to_prune, param): # , group_size, rounding_fn, noise):
+        assert param.dim() == 2, "This pruning is only supported for 2D weights"
+        ROWS_DIM = 0
+        cols_mags = magnitude_fn(param, dim=ROWS_DIM)
+        num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM))
+        if num_cols_to_prune == 0:
+            msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
+            return None, None
+        bottomk_cols, _ = torch.topk(cols_mags, num_cols_to_prune, largest=False, sorted=True)
+        return bottomk_cols, cols_mags
+
     @staticmethod
     def rank_and_prune_rows(fraction_to_prune, param, param_name,
                             zeros_mask_dict, model=None, binary_map=None,
@@ -270,30 +282,29 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
         """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
 
         PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
-        transposed.  This is counter-intuitive.  To deal with this, we can either transpose the matrix and
-        then proceed to compute the masks as usual, or we can treat columns as rows, and rows as columns :-(.
+        transposed.  This is because the output is computed as follows:
+            y = x(W^T) + b ; where W^T is the transpose of W
+
+        Removing input_channels from W^T, is removing rows of W^T, which is removing columns of W.
+
+        To deal with this rotation, we can either transpose the matrix and then proceed to compute the masks
+        as usual, or we can treat columns as rows, and rows as columns :-(.
         We choose the latter, because transposing very large matrices can be detrimental to performance.  Note
-        that computing mean L1-norm of columns is also not optimal, because consequtive column elements are far
+        that computing mean L1-norm of columns is also not optimal, because consecutive column elements are far
         away from each other in memory, and this means poor use of caches and system memory.
         """
-
-        assert param.dim() == 2, "This pruning is only supported for 2D weights"
-        ROWS_DIM = 0
+        bottomk_cols, cols_mags = LpRankedStructureParameterPruner.rank_rows(magnitude_fn, fraction_to_prune, param)
         THRESHOLD_DIM = 'Cols'
-        rows_mags = magnitude_fn(param, dim=ROWS_DIM)
-        num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
-        if num_rows_to_prune == 0:
-            msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
-            return
-        bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
-        threshold = bottomk_rows[-1]
+        threshold = bottomk_cols[-1]
         threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM,
                                                                                       threshold, threshold_type)
+        ROWS_DIM = 0
+        num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM))
         msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                        threshold_type, param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
-                       fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
+                       fraction_to_prune, num_cols_to_prune, cols_mags.size(ROWS_DIM))
         return binary_map
 
     @staticmethod
@@ -680,8 +691,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
         Use this in conjunction with distiller.features_collector.collect_intermediate_featuremap_samples,
         which orchestrates the process of feature-map collection.
 
-        This foward-hook samples random points.
-
+        This foward-hook samples random points in the output feature-maps of 'module'.
         After collecting the feature-map samples, distiller.FMReconstructionChannelPruner can be used.
 
         Arguments:
@@ -697,11 +707,15 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
 
         # Sample random (uniform) points in each feature-map.
         # This method is biased toward small feature-maps.
-        randx = np.random.randint(0, output.size(2), n_points_per_fm)
-        randy = np.random.randint(0, output.size(3), n_points_per_fm)
+        if isinstance(module, torch.nn.Conv2d):
+            randx = np.random.randint(0, output.size(2), n_points_per_fm)
+            randy = np.random.randint(0, output.size(3), n_points_per_fm)
 
         X = input[0]
-        if module.kernel_size == (1,1):
+        if isinstance(module, torch.nn.Linear):
+            X = X.detach().cpu().clone()
+            Y = output.detach().cpu().clone()
+        elif module.kernel_size == (1, 1):
             X = X[:, :, randx, randy].detach().cpu().clone()
             Y = output[:, :, randx, randy].detach().cpu().clone()
         else:
@@ -736,9 +750,10 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_channels(fraction_to_prune, param, param_name, 
-                                                  zeros_mask_dict, model, binary_map, 
-                                                  group_size=self.group_size, 
+
+        binary_map = self.rank_and_prune_channels(fraction_to_prune, param, param_name,
+                                                  zeros_mask_dict, model, binary_map,
+                                                  group_size=self.group_size,
                                                   rounding_fn=self.rounding_fn,
                                                   noise=self.noise)
         return binary_map
@@ -750,16 +765,24 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
                                 noise=0):
         assert binary_map is None
         if binary_map is None:
-            bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_channels(
-                magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise)
-            # Todo: this little piece of code can be refactored                                                                                
+            op_type = 'conv' if param.dim() == 4 else 'fc'
+            if op_type == 'conv':
+                bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_channels(
+                    magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise)
+
+            else:
+                bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_rows(
+                     magnitude_fn, fraction_to_prune, param)
+
+            # Todo: this little piece of code can be refactored
             if bottomk_channels is None:
                 # Empty list means that fraction_to_prune is too low to prune anything
                 return
-            
+
             threshold = bottomk_channels[-1]
             binary_map = channel_mags.gt(threshold)
 
+
             # These are the indices of channels we want to keep
             indices = binary_map.nonzero().squeeze()
             if len(indices.shape) == 0:
@@ -779,7 +802,9 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
             # min(MSE) to compute the weights, we need to start by removing feature-map 
             # channels from the input.  Then we perform the MSE regression to generate
             # a smaller weights tensor.
-            if conv.kernel_size == (1,1):
+            if op_type == 'fc':
+                X = X[:, binary_map]
+            elif conv.kernel_size == (1, 1):
                 X = X[:, binary_map, :]
                 X = X.transpose(1, 2)
                 X = X.contiguous().view(-1, X.size(2))
@@ -797,18 +822,29 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
             new_w = torch.from_numpy(new_w) # shape: (num_filters, num_non_masked_channels * k^2)
             cnt_retained_channels = binary_map.sum()
 
-            # Expand the weights back to their original size,
-            new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))
-            
-            # Copy the weights that we learned from minimizing the feature-maps least squares error,
-            # to our actual weights tensor.
-            param.detach()[:,indices,:,:] = new_w.type(param.type())
-            
+            if op_type == 'conv':
+                # Expand the weights back to their original size,
+                new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))
+
+                # Copy the weights that we learned from minimizing the feature-maps least squares error,
+                # to our actual weights tensor.
+                param.detach()[:,indices,:,:] = new_w.type(param.type())
+            else:
+                param.detach()[:, indices] = new_w.type(param.type())
+
         if zeros_mask_dict is not None:
             binary_map = binary_map.type(param.type())
-            zeros_mask_dict[param_name].mask = LpRankedStructureParameterPruner.ch_binary_map_to_mask(binary_map, param)
-            msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-                           param_name,
-                           distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
-                           fraction_to_prune, binary_map.sum().item(), param.size(1))
+            if op_type == 'conv':
+                zeros_mask_dict[param_name].mask = LpRankedStructureParameterPruner.ch_binary_map_to_mask(binary_map, param)
+                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                               param_name,
+                               distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
+                               fraction_to_prune, binary_map.sum().item(), param.size(1))
+            else:
+                msglogger.error("fc sparsity = %.2f" % (1 - binary_map.sum().item() / binary_map.size(0)))
+                zeros_mask_dict[param_name].mask = binary_map.expand(param.size(0), param.size(1))
+                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                               param_name,
+                               distiller.sparsity_cols(zeros_mask_dict[param_name].mask),
+                               fraction_to_prune, binary_map.sum().item(), param.size(1))
         return binary_map
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index a0b7490..be4b14d 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -34,13 +34,13 @@ class CompressionScheduler(object):
     """Responsible for scheduling pruning and masking parameters.
 
     """
-    def __init__(self, model, device=torch.device("cuda")):
+    def __init__(self, model, zeros_mask_dict=None, device=torch.device("cuda")):
         self.model = model
         self.device = device
         self.policies = {}
         self.sched_metadata = {}
         # Create the masker objects and place them in a dictionary indexed by the parameter name
-        self.zeros_mask_dict = create_model_masks_dict(model)
+        self.zeros_mask_dict = zeros_mask_dict or create_model_masks_dict(model)
 
     def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
         """Add a new policy to the schedule.
@@ -212,7 +212,6 @@ class CompressionScheduler(object):
             if name not in masks_dict:
                 masks_dict[name] = None
         state = {'masks_dict': masks_dict}
-
         self.load_state_dict(state, normalize_dataparallel_keys)
 
     @staticmethod
diff --git a/distiller/thinning.py b/distiller/thinning.py
index d06fe3a..0d790c2 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -60,11 +60,10 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'StructureRemover',
            'ChannelRemover', 'remove_channels',
            'FilterRemover',  'remove_filters',
-           'find_nonzero_channels', 'find_nonzero_channels_list',
            'execute_thinning_recipes_list', 'get_normalized_recipe']
 
 
-def create_graph(dataset, model):
+def _create_graph(dataset, model):
     dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
 
@@ -144,41 +143,12 @@ def append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_feat
 
 
 def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, model)
+    sgraph = _create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
 
 
-def find_nonzero_channels(param, param_name):
-    """Count the number of non-zero channels in a weights tensor.
-
-    Non-zero channels are channels that have at least one coefficient that is
-    non-zero.  Counting non-zero channels involves some tensor acrobatics.
-    """
-    num_filters, num_channels = param.size(0), param.size(1)
-
-    # First, reshape the weights tensor such that each channel (kernel) in the original
-    # tensor, is now a row in the 2D tensor.
-    view_2d = param.view(-1, param.size(2) * param.size(3))
-    # Next, compute the sums of each kernel
-    kernel_sums = view_2d.abs().sum(dim=1)
-    # Now group by channels
-    k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
-    nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1))
-
-    if num_channels > nonzero_channels.nelement():
-        msglogger.debug("In tensor %s found %d/%d zero channels", param_name,
-                        num_channels - nonzero_channels.nelement(), num_channels)
-    return nonzero_channels
-
-# Todo: consider removing this function
-def find_nonzero_channels_list(param, param_name):
-    nnz_channels = find_nonzero_channels(param, param_name)
-    nnz_channels = nnz_channels.view(nnz_channels.numel())
-    return nnz_channels.cpu().numpy().tolist()
-
-
 def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
     if len(thinning_recipe.modules) > 0 or len(thinning_recipe.parameters) > 0:
         # Now actually remove the filters, channels and make the weight tensors smaller
@@ -196,7 +166,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
 
 
 def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, model)
+    sgraph = _create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -212,24 +182,27 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
     The thinning recipe contains meta-instructions of how the model
     should be changed in order to remove the channels.
     """
-    def handle_layer(layer_name, param_name, num_nnz_channels):
+    def handle_layer(layer_name, param_name, nnz_channels):
         # We are removing channels, so update the number of incoming channels (IFMs)
         # in the convolutional layer
-        assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
+        assert isinstance(layers[layer_name], (torch.nn.modules.Conv2d, torch.nn.modules.Linear))
+
+        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=nnz_channels)
 
         # Select only the non-zero channels
         indices = nonzero_channels.data.squeeze()
-        dim = 1 if layers[layer_name].groups == 1 else 0
+        dim = 1 if isinstance(layers[layer_name], torch.nn.modules.Conv2d) and layers[layer_name].groups == 1 else 0
+        if isinstance(layers[layer_name], torch.nn.modules.Linear):
+            dim = 1
         append_param_directive(thinning_recipe, param_name, (dim, indices))
 
-        # Find all instances of Convolution layers that immediately preceed this layer
+        # Find all instances of Convolution layers that immediately precede this layer
         predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
         if not predecessors:
             msglogger.info("Could not find predecessors for name=%s" % layer_name)
         for predecessor in predecessors:
-            # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
+            # For each of the convolution layers that precede, we have to reduce the number of output channels.
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=nnz_channels)
 
             if layers[predecessor].groups == 1:
                 # Now remove filters from the weights tensor of the predecessor conv
@@ -246,13 +219,13 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
                 if layers[predecessor].bias is not None:
                     # This convolution has bias coefficients
                     append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
-                append_module_directive(thinning_recipe, predecessor, key='groups', val=num_nnz_channels)
+                append_module_directive(thinning_recipe, predecessor, key='groups', val=nnz_channels)
 
                 # In the special case of a Convolutional layer with (groups == in_channels), if we
                 # change in_channels, we also need to change out_channels, which means that we
                 # have to perform filter removal for this layer as well
                 param_name = predecessor+'.weight'
-                handle_layer(predecessor, param_name, num_nnz_channels)
+                handle_layer(predecessor, param_name, nnz_channels)
             else:
                 raise ValueError("Distiller thinning code currently does not handle this conv.groups configuration")
 
@@ -262,7 +235,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
             # Thinning of the BN layer that follows the convolution
             msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer))
             append_bn_thinning_directive(thinning_recipe, layers, bn_layer,
-                                         len_thin_features=num_nnz_channels, thin_features=indices)
+                                         len_thin_features=nnz_channels, thin_features=indices)
 
     msglogger.debug("Invoking create_thinning_recipe_channels")
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
@@ -271,19 +244,25 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
     # Traverse all of the model's parameters, search for zero-channels, and
     # create a thinning recipe that descibes the required changes to the model.
     for layer_name, param_name, param in sgraph.named_params_layers():
-        # We are only interested in 4D weights (of Convolution layers)
-        if param.dim() != 4:
-            continue
-        num_channels = param.size(1)
-        nonzero_channels = find_nonzero_channels(param, param_name)
-        num_nnz_channels = nonzero_channels.nelement()
-        if num_nnz_channels == 0:
-            raise ValueError("Trying to set zero channels for parameter %s is not allowed" % param_name)
-        # If there are non-zero channels in this tensor then continue to next tensor
-        if num_channels <= num_nnz_channels:
-            continue
-        handle_layer(layer_name, param_name, num_nnz_channels)       
-    msglogger.debug(thinning_recipe)
+        if param.dim() in (2, 4):
+            num_channels = param.size(1)
+            # Find nonzero input channels
+            if param.dim() == 2:
+                # 2D weights (of Linear layers)
+                col_sums = param.abs().sum(dim=0)
+                nonzero_channels = torch.nonzero(col_sums)
+                num_nnz_channels = nonzero_channels.nelement()
+            elif param.dim() == 4:
+                # 4D weights (of Convolution layers)
+                nonzero_channels = distiller.non_zero_channels(param)
+                num_nnz_channels = nonzero_channels.nelement()
+            if num_nnz_channels == 0:
+                raise ValueError("Trying to zero all channels for parameter %s is not allowed" % param_name)
+
+            # If there are no non-zero channels in this tensor then continue to next tensor
+            if num_channels <= num_nnz_channels:
+                 continue
+            handle_layer(layer_name, param_name, num_nnz_channels)
     return thinning_recipe
 
 
@@ -504,6 +483,8 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
 
     with torch.no_grad():
         for param_name, param_directives in recipe.parameters.items():
+            if param_name == "module.fc.weight":
+                debug = True
             msglogger.debug("{} : {}".format(param_name, param_directives))
             param = distiller.model_find_param(model, param_name)
             assert param is not None
@@ -533,7 +514,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 else:
                     if param.data.size(dim) != len_indices:
                         msglogger.debug("[thinning] changing param {} ({})  dim:{}  new len: {}".format(
-                            param_name, param.shape, dim, len_indices))
+                                        param_name, param.shape, dim, len_indices))
                         assert param.size(dim) > len_indices
                         param.data = torch.index_select(param.data, dim, indices.to(param.device))
                         msglogger.debug("[thinning] changed param {}".format(param_name))
@@ -546,13 +527,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                         if optimizer_thinning(optimizer, param, dim, indices):
                             msglogger.debug("Updated velocity buffer %s" % param_name)
 
-                if not loaded_from_file:
+                if not loaded_from_file and zeros_mask_dict:
                     # If the masks are loaded from a checkpoint file, then we don't need to change
                     # their shape, because they are already correctly shaped
                     mask = zeros_mask_dict[param_name].mask
                     if mask is not None and (mask.size(dim) != len_indices):
                         zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
 
+
 # Todo: consider removing this function
 def resnet_cifar_remove_layers(model):
     """Remove layers from ResNet-Cifar.
diff --git a/distiller/utils.py b/distiller/utils.py
index f00e333..1b55f25 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -246,23 +246,35 @@ def density_2D(tensor):
     return 1 - sparsity_2D(tensor)
 
 
-def sparsity_ch(tensor):
-    """Channel-wise sparsity for 4D tensors"""
+def non_zero_channels(tensor):
+    """Returns the indices of non-zero channels.
+
+    Non-zero channels are channels that have at least one coefficient that
+    is not zero.  Counting non-zero channels involves some tensor acrobatics.
+    """
     if tensor.dim() != 4:
-        return 0
+        raise ValueError("Expecting a 4D tensor")
 
-    num_filters = tensor.size(0)
-    num_kernels_per_filter = tensor.size(1)
+    n_filters, n_channels, k_h, k_w = (tensor.size(i) for i in range(4))
 
-    # First, reshape the weights tensor such that each channel (kernel) in the original
-    # tensor, is now a row in the 2D tensor.
-    view_2d = tensor.view(-1, tensor.size(2) * tensor.size(3))
+    # First, reshape the weights tensor such that each channel (kernel) in
+    # the original tensor, is now a row in a 2D tensor.
+    view_2d = tensor.view(-1, k_h * k_w)
     # Next, compute the sums of each kernel
     kernel_sums = view_2d.abs().sum(dim=1)
     # Now group by channels
-    k_sums_mat = kernel_sums.view(num_filters, num_kernels_per_filter).t()
-    nonzero_channels = len(torch.nonzero(k_sums_mat.abs().sum(dim=1)))
-    return 1 - nonzero_channels/num_kernels_per_filter
+    k_sums_mat = kernel_sums.view(n_filters, n_channels).t()
+    nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1))
+    return nonzero_channels
+
+
+def sparsity_ch(tensor):
+    """Channel-wise sparsity for 4D tensors"""
+    if tensor.dim() != 4:
+        return 0
+    nonzero_channels = len(non_zero_channels(tensor))
+    n_channels = tensor.size(1)
+    return 1 - nonzero_channels/n_channels
 
 
 def density_ch(tensor):
diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py
index 5155867..46d633c 100755
--- a/examples/auto_compression/amc/amc.py
+++ b/examples/auto_compression/amc/amc.py
@@ -19,28 +19,12 @@ $ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkp
 """
 
 
-import math
 import os
-import copy
 import logging
-import numpy as np
-import torch
-import csv
 import traceback
 from functools import partial
-try:
-    import gym
-except ImportError as e:
-    print("WARNING: to use automated compression you will need to install extra packages")
-    print("See instructions in the interface of each RL library.")
-    raise e
-from gym import spaces
 import distiller
-from collections import OrderedDict, namedtuple
-from types import SimpleNamespace
-from distiller import normalize_module_name, SummaryGraph
 from environment import DistillerWrapperEnvironment, Observation
-from utils.features_collector import collect_intermediate_featuremap_samples
 import distiller.apputils as apputils
 import distiller.apputils.image_classifier as classifier
 from rewards import reward_factory
@@ -117,7 +101,7 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
     ddpg_cfg = distiller.utils.MutableNamedTuple({
             'heatup_noise': 0.5,
             'initial_training_noise': 0.5,
-            'training_noise_decay': 0.99555, #0.98, #0.996,
+            'training_noise_decay': 0.95,
             'num_heatup_episodes': args.amc_heatup_episodes,
             'num_training_episodes': args.amc_training_episodes,
             'actor_lr': 1e-4,
@@ -147,7 +131,8 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
 
     def create_environment():
         env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services)
-        env.amc_cfg.ddpg_cfg.replay_buffer_size = 100 * env.steps_per_episode
+        #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode)
+        env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode
         return env
 
     env1 = create_environment()
@@ -180,7 +165,7 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
         raise ValueError("unsupported rl library: ", args.amc_rllib)
 
 
-def config_verbose(verbose):
+def config_verbose(verbose, display_summaries=False):
     if verbose:
         loglevel = logging.DEBUG
     else:
@@ -188,11 +173,14 @@ def config_verbose(verbose):
         logging.getLogger().setLevel(logging.WARNING)
     for module in ["examples.auto_compression.amc",
                    "distiller.apputils.image_classifier",
-                   "distiller.data_loggers.logger",
-                   "distiller.thinning", 
+                   "distiller.thinning",
                    "distiller.pruning.ranked_structures_pruner"]:
         logging.getLogger(module).setLevel(loglevel)
 
+    # display training progress summaries
+    summaries_lvl = logging.INFO if display_summaries else logging.WARNING
+    logging.getLogger("examples.auto_compression.amc.summaries").setLevel(summaries_lvl)
+
 
 if __name__ == '__main__':
     try:
diff --git a/examples/auto_compression/amc/auto_compression_channels.yaml b/examples/auto_compression/amc/auto_compression_channels.yaml
index a58493d..558c975 100755
--- a/examples/auto_compression/amc/auto_compression_channels.yaml
+++ b/examples/auto_compression/amc/auto_compression_channels.yaml
@@ -9,7 +9,8 @@ network:
       "module.model.3.3",  "module.model.4.3", "module.model.5.3",
       "module.model.6.3",  "module.model.7.3", "module.model.8.3",
       "module.model.9.3",  "module.model.10.3", "module.model.11.3",
-      "module.model.12.3", "module.model.13.3"]
+      "module.model.12.3", "module.model.13.3",
+      "module.fc"]
 
   mobilenet_v2:
       # Only conv 1x1, without shortcut connection dependencies
@@ -94,7 +95,8 @@ network:
       "module.layer2.2.conv1", "module.layer2.2.conv2",
       "module.layer3.0.conv1", "module.layer3.0.conv2",
       "module.layer3.1.conv1", "module.layer3.1.conv2",
-      "module.layer3.2.conv1", "module.layer3.2.conv2"]
+      "module.layer3.2.conv1", "module.layer3.2.conv2",
+      "module.fc"]
 
   simplenet_mnist:
     ["module.conv2"]
diff --git a/examples/auto_compression/amc/environment.py b/examples/auto_compression/amc/environment.py
index d248d4b..79415ab 100755
--- a/examples/auto_compression/amc/environment.py
+++ b/examples/auto_compression/amc/environment.py
@@ -29,13 +29,7 @@ import copy
 import logging
 import numpy as np
 import torch
-try:
-    import gym
-except ImportError as e:
-    print("WARNING: to use automated compression you will need to install extra packages")
-    print("See instructions in the header of examples/automated_deep_compression/ADC.py")
-    raise e
-from gym import spaces
+import gym
 import distiller
 from collections import OrderedDict, namedtuple
 from types import SimpleNamespace
@@ -45,7 +39,7 @@ from utils.ac_loggers import AMCStatsLogger, FineTuneStatsLogger
 
 
 msglogger = logging.getLogger("examples.auto_compression.amc")
-Observation = namedtuple('Observation', ['t', 'n', 'c',  'h', 'w', 'stride', 'k', 'MACs',
+Observation = namedtuple('Observation', ['t', 'type', 'n', 'c',  'h', 'w', 'stride', 'k', 'MACs',
                                          'Weights', 'reduced', 'rest', 'prev_a'])
 ObservationLen = len(Observation._fields)
 ALMOST_ONE = 0.9999
@@ -141,8 +135,8 @@ class NetworkMetadata(object):
     def is_prunable(self, layer_id):
         return layer_id in self.pruned_idxs
 
-    def is_reducible(self, layer_id):
-        return layer_id in self.pruned_idxs or layer_id in self.dependent_idxs
+    def is_compressible(self, layer_id):
+        return layer_id in (self.pruned_idxs + self.dependent_idxs)
 
     def num_pruned_layers(self):
         return len(self.pruned_idxs)
@@ -165,6 +159,7 @@ class NetworkWrapper(object):
                                                      pruning_pattern, modules_list)
         self.cached_perf_summary = self.cached_model_metadata.performance_summary()
         self.reset(model)
+        self.sparsification_masks = None
 
     def reset(self, model):
         self.model = model
@@ -218,9 +213,7 @@ class NetworkWrapper(object):
         return ret
 
     def create_scheduler(self):
-        scheduler = distiller.CompressionScheduler(self.model)
-        masks = {param_name: masker.mask for param_name, masker in self.zeros_mask_dict.items()}
-        scheduler.load_state_dict(state={'masks_dict': masks})
+        scheduler = distiller.CompressionScheduler(self.model, self.zeros_mask_dict)
         return scheduler
 
     def remove_structures(self, layer_id, fraction_to_prune, prune_what, prune_how, 
@@ -233,9 +226,9 @@ class NetworkWrapper(object):
             raise ValueError("idx=%d is not in correct range " % layer_id)
         if fraction_to_prune < 0:
             raise ValueError("fraction_to_prune=%.3f is illegal" % fraction_to_prune)
-
         if fraction_to_prune == 0:
             return 0
+
         if fraction_to_prune == 1.0:
             # For now, prevent the removal of entire layers
             fraction_to_prune = ALMOST_ONE
@@ -249,6 +242,8 @@ class NetworkWrapper(object):
 
         if prune_what == "channels":
             calculate_sparsity = distiller.sparsity_ch
+            if layer.type == "Linear":
+                calculate_sparsity = distiller.sparsity_rows
             remove_structures_fn = distiller.remove_channels
             group_type = "Channels"
         elif prune_what == "filters":
@@ -258,7 +253,7 @@ class NetworkWrapper(object):
         else:
             raise ValueError("unsupported structure {}".format(prune_what))
 
-        if prune_how == "l1-rank" or prune_how == "stochastic-l1-rank":
+        if prune_how in ["l1-rank", "stochastic-l1-rank"]:
             # Create a channel/filter-ranking pruner
             pruner = distiller.pruning.L1RankedStructureParameterPruner(
                 "auto_pruner", group_type, fraction_to_prune, conv_pname,
@@ -275,14 +270,15 @@ class NetworkWrapper(object):
         del pruner
 
         if (self.zeros_mask_dict[conv_pname].mask is None or 
-                0 == calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)):
+            0 == calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)):
             msglogger.debug("remove_structures: aborting because there are no structures to prune")
             return 0
         final_action = calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)
 
         # Use the mask to prune
         self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
-        if apply_thinning:     
+        if apply_thinning:
+            self.cache_spasification_masks()
             remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
 
         self.model_metadata.reduce_layer_macs(layer, final_action)
@@ -295,7 +291,7 @@ class NetworkWrapper(object):
         return top1, top5, vloss
 
     def train(self, num_epochs, episode=0):
-        # Train for zero or more epochs
+        """Train for zero or more epochs"""
         opt_cfg = self.app_args.optimizer_data
         optimizer = torch.optim.SGD(self.model.parameters(), lr=opt_cfg['lr'],
                                     momentum=opt_cfg['momentum'], weight_decay=opt_cfg['weight_decay'])
@@ -309,13 +305,18 @@ class NetworkWrapper(object):
         del compression_scheduler
         return acc_list
 
+    def cache_spasification_masks(self):
+        masks = {param_name: masker.mask for param_name, masker in self.zeros_mask_dict.items()}
+        self.sparsification_masks = copy.deepcopy(masks)
+
 
 class DistillerWrapperEnvironment(gym.Env):
     def __init__(self, model, app_args, amc_cfg, services):
-        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
+        self.pylogger = distiller.data_loggers.PythonLogger(
+            logging.getLogger("examples.auto_compression.amc.summaries"))
         logdir = logging.getLogger().logdir
         self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
-        self.verbose = False
+        self._render = False
         self.orig_model = copy.deepcopy(model)
         self.app_args = app_args
         self.amc_cfg = amc_cfg
@@ -331,12 +332,12 @@ class DistillerWrapperEnvironment(gym.Env):
         self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
         self.episode = 0
         self.best_reward = float("-inf")
-        self.action_low = amc_cfg.action_range[0]
-        self.action_high = amc_cfg.action_range[1]
+        self.action_low, self.action_high = amc_cfg.action_range
+        #self.action_high = amc_cfg.action_range[1]
         self._log_model_info()
         log_amc_config(amc_cfg)
         self._configure_action_space()
-        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
+        self.observation_space = gym.spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
         self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
         self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))
 
@@ -355,7 +356,7 @@ class DistillerWrapperEnvironment(gym.Env):
 
         def acceptance_criterion(m, mod_names):
             # Collect feature-maps only for Conv2d layers, if they are in our modules list.
-            return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names
+            return isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)) and m.distiller_name in mod_names
 
         # For feature-map reconstruction we need to collect a representative set
         # of inter-layer feature-maps
@@ -377,12 +378,12 @@ class DistillerWrapperEnvironment(gym.Env):
     def _configure_action_space(self):
         if is_using_continuous_action_space(self.amc_cfg.agent_algo):
             if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
-                self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
+                self.action_space = gym.spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
             else:
-                self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,))
+                self.action_space = gym.spaces.Box(self.action_low, self.action_high, shape=(1,))
             self.action_space.default_action = self.action_low
         else:
-            self.action_space = spaces.Discrete(10)
+            self.action_space = gym.spaces.Discrete(10)
 
 
     @property
@@ -401,7 +402,7 @@ class DistillerWrapperEnvironment(gym.Env):
         if hasattr(self.net_wrapper.model, 'intermediate_fms'):
             self.model.intermediate_fms = self.net_wrapper.model.intermediate_fms
         self.net_wrapper.reset(self.model)
-        self._removed_macs = 0
+        self.removed_macs = 0
         self.action_history = []
         self.agent_action_history = []
         self.model_representation = self.get_model_representation()
@@ -421,17 +422,13 @@ class DistillerWrapperEnvironment(gym.Env):
         """Return the amount of MACs removed so far.
         This is normalized to the range 0..1
         """
-        return self._removed_macs / self.original_model_macs
+        return self.removed_macs / self.original_model_macs
 
     def render(self, mode='human'):
         """Provide some feedback to the user about what's going on.
         This is invoked by the Agent.
         """
-        if self.current_state_id == 0:
-            msglogger.info("+" + "-" * 50 + "+")
-            msglogger.info("Starting a new episode %d", self.episode)
-            msglogger.info("+" + "-" * 50 + "+")
-        if not self.verbose:
+        if not self._render:
             return
         msglogger.info("Render Environment: current_state_id=%d" % self.current_state_id)
         distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
@@ -442,6 +439,10 @@ class DistillerWrapperEnvironment(gym.Env):
         The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove).
         This function is invoked by the Agent.
         """
+        if self.current_state_id == 0:
+            msglogger.info("+" + "-" * 50 + "+")
+            msglogger.info("Episode %d is starting" % self.episode)
+
         pruning_action = float(pruning_action[0])
         msglogger.debug("env.step - current_state_id=%d (%s) episode=%d action=%.2f" %
                         (self.current_state_id, self.current_layer().name, self.episode, pruning_action))
@@ -487,13 +488,13 @@ class DistillerWrapperEnvironment(gym.Env):
         layer_macs_after_action = self.net_wrapper.layer_macs(self.current_layer())
 
         # Update the various counters after taking the step
-        self._removed_macs += (total_macs_before - total_macs_after_act)
+        self.removed_macs += (total_macs_before - total_macs_after_act)
 
         msglogger.debug("\tactual_action={}".format(pruning_action))
         msglogger.debug("\tlayer_macs={} layer_macs_after_action={} removed now={}".format(layer_macs,
                                                                                         layer_macs_after_action,
                                                                                         (layer_macs - layer_macs_after_action)))
-        msglogger.debug("\tself._removed_macs={}".format(self._removed_macs))
+        msglogger.debug("\tself._removed_macs={}".format(self.removed_macs))
         assert math.isclose(layer_macs_after_action / layer_macs, 1 - pruning_action)
 
         stats = ('Performance/Validation/',
@@ -504,13 +505,11 @@ class DistillerWrapperEnvironment(gym.Env):
                                         total_steps=self.net_wrapper.num_pruned_layers(), log_freq=1, loggers=[self.tflogger])
 
         if self.episode_is_done():
-            msglogger.info("Episode is ending")
+            msglogger.info("Episode %d is ending" % self.episode)
             observation = self.get_final_obs()
-            reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act)
-            normalized_macs = total_macs_after_act / self.original_model_macs * 100
-            normalized_nnz = total_nnz_after_act / self.original_model_size * 100
-            self.finalize_episode(top1, reward, total_macs_after_act, normalized_macs,
-                                  normalized_nnz, self.action_history, self.agent_action_history)
+            reward, top1, top5, vloss = self.compute_reward(total_macs_after_act, total_nnz_after_act)
+            self.finalize_episode(reward, (top1, top5, vloss), total_macs_after_act, total_nnz_after_act,
+                                  self.action_history, self.agent_action_history)
             self.episode += 1
         else:
             self.current_layer_id = self.net_wrapper.model_metadata.pruned_idxs[self.current_state_id]
@@ -519,16 +518,16 @@ class DistillerWrapperEnvironment(gym.Env):
                 self.net_wrapper.train(1, self.episode)
             observation = self.get_obs()
             if self.amc_cfg.reward_frequency is not None and self.current_state_id % self.amc_cfg.reward_frequency == 0:
-                reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act, log_stats=False)
+                reward, top1, top5, vloss = self.compute_reward(total_macs_after_act, total_nnz_after_act)
             else:
                 reward = 0
         self.prev_action = pruning_action
         if self.episode_is_done():
+            normalized_macs = total_macs_after_act / self.original_model_macs * 100
             info = {"accuracy": top1, "compress_ratio": normalized_macs}
-            msglogger.info(self.removed_macs_pct)
             if self.amc_cfg.protocol == "mac-constrained":
                 # Sanity check (special case only for "mac-constrained")
-                #assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.01
+                assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.002 # 0.01
                 pass
         else:
             info = {}
@@ -538,8 +537,6 @@ class DistillerWrapperEnvironment(gym.Env):
         """Produce a state embedding (i.e. an observation)"""
         current_layer_macs = self.net_wrapper.layer_net_macs(self.current_layer())
         current_layer_macs_pct = current_layer_macs/self.original_model_macs
-        current_layer = self.current_layer()
-        conv_module = distiller.model_find_module(self.model, current_layer.name)
 
         obs = self.model_representation[self.current_state_id, :]
         obs[-1] = self.prev_action
@@ -571,17 +568,31 @@ class DistillerWrapperEnvironment(gym.Env):
         for state_id, layer_id in enumerate(self.net_wrapper.model_metadata.pruned_idxs):
             layer = self.net_wrapper.get_layer(layer_id)
             layer_macs = self.net_wrapper.layer_macs(layer)
-            conv_module = distiller.model_find_module(self.model, layer.name)
-            obs = [state_id,
-                   conv_module.out_channels,
-                   conv_module.in_channels,
-                   layer.ifm_h,
-                   layer.ifm_w,
-                   layer.stride[0],
-                   layer.k,
-                   distiller.volume(conv_module.weight),
-                   layer_macs,
-                   0, 0, 0]
+            mod = distiller.model_find_module(self.model, layer.name)
+            if isinstance(mod, torch.nn.Conv2d):
+                obs = [state_id,
+                       0,
+                       mod.out_channels,
+                       mod.in_channels,
+                       layer.ifm_h,
+                       layer.ifm_w,
+                       layer.stride[0],
+                       layer.k,
+                       distiller.volume(mod.weight),
+                       layer_macs,
+                       0, 0, 0]
+            elif isinstance(mod, torch.nn.Linear):
+                obs = [state_id,
+                       1,
+                       mod.out_features,
+                       mod.in_features,
+                       layer.ifm_h,
+                       layer.ifm_w,
+                       0,
+                       1,
+                       distiller.volume(mod.weight),
+                       layer_macs,
+                       0, 0, 0]
             network_obs[state_id:] = np.array(obs)
 
         # Feature normalization
@@ -596,8 +607,8 @@ class DistillerWrapperEnvironment(gym.Env):
 
     def rest_macs_raw(self):
         """Return the number of remaining MACs in the layers following the current layer"""
-        rest, prunable_rest = 0, 0
-        prunable_layers, rest_layers, layers_to_ignore = list(), list(), list()
+        nonprunable_rest, prunable_rest = 0, 0
+        prunable_layers, nonprunable_layers, layers_to_ignore = list(), list(), list()
 
         # Create a list of the IDs of the layers that are dependent on the current_layer.
         # We want to ignore these layers when we compute prunable_layers (and prunable_rest).
@@ -606,17 +617,21 @@ class DistillerWrapperEnvironment(gym.Env):
 
         for layer_id in range(self.current_layer_id+1, self.net_wrapper.model_metadata.num_layers()):
             layer_macs = self.net_wrapper.layer_net_macs(self.net_wrapper.get_layer(layer_id))
-            if self.net_wrapper.model_metadata.is_reducible(layer_id):
+            if self.net_wrapper.model_metadata.is_compressible(layer_id):
                 if layer_id not in layers_to_ignore:
                     prunable_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
                     prunable_rest += layer_macs
-            else:
-                rest_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
-                rest += layer_macs
 
-        msglogger.debug("prunable_layers={} rest_layers={}".format(prunable_layers, rest_layers))
-        msglogger.debug("layer_id=%d, prunable_rest=%.3f rest=%.3f" % (self.current_layer_id, prunable_rest, rest))
-        return prunable_rest, rest
+        for layer_id in list(range(0, self.net_wrapper.model_metadata.num_layers())):
+            if not self.net_wrapper.model_metadata.is_compressible(layer_id): #and
+                layer_macs = self.net_wrapper.layer_net_macs(self.net_wrapper.get_layer(layer_id))
+                nonprunable_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
+                nonprunable_rest += layer_macs
+
+        msglogger.debug("prunable_layers={} nonprunable_layers={}".format(prunable_layers, nonprunable_layers))
+        msglogger.debug("layer_id=%d (%s), prunable_rest=%.3f nonprunable_rest=%.3f" %
+                        (self.current_layer_id, self.current_layer().name, prunable_rest, nonprunable_rest))
+        return prunable_rest, nonprunable_rest
 
     def rest_macs(self):
         return sum(self.rest_macs_raw()) / self.original_model_macs
@@ -625,13 +640,12 @@ class DistillerWrapperEnvironment(gym.Env):
         current_density = compressed_model_total_macs / self.original_model_macs
         return self.amc_cfg.target_density >= current_density
 
-    def compute_reward(self, total_macs, total_nnz, log_stats=True):
+    def compute_reward(self, total_macs, total_nnz):
         """Compute the reward.
 
         We use the validation dataset (the size of the validation dataset is
         configured when the data-loader is instantiated)"""
-        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
-        compression = distiller.model_numel(self.model, param_dims=[4]) / self.original_model_size
+        num_elements = distiller.model_params_size(self.model, param_dims=[2, 4], param_types=['weight'])
 
         # Fine-tune (this is a nop if self.amc_cfg.num_ft_epochs==0)
         accuracies = self.net_wrapper.train(self.amc_cfg.num_ft_epochs, self.episode)
@@ -639,28 +653,15 @@ class DistillerWrapperEnvironment(gym.Env):
 
         top1, top5, vloss = self.net_wrapper.validate()
         reward = self.amc_cfg.reward_fn(self, top1, top5, vloss, total_macs)
+        return reward, top1, top5, vloss
 
-        if log_stats:
-            macs_normalized = total_macs/self.original_model_macs
-            msglogger.info("Total parameters left: %.2f%%" % (compression*100))
-            msglogger.info("Total compute left: %.2f%%" % (total_macs/self.original_model_macs*100))
-
-            stats = ('Performance/EpisodeEnd/',
-                     OrderedDict([('Loss', vloss),
-                                  ('Top1', top1),
-                                  ('Top5', top5),
-                                  ('reward', reward),
-                                  ('total_macs', int(total_macs)),
-                                  ('macs_normalized', macs_normalized*100),
-                                  ('log(total_macs)', math.log(total_macs)),
-                                  ('total_nnz', int(total_nnz))]))
-            distiller.log_training_progress(stats, None, self.episode, steps_completed=0, total_steps=1,
-                                            log_freq=1, loggers=[self.tflogger, self.pylogger])
-        return reward, top1
-
-    def finalize_episode(self, top1, reward, total_macs, normalized_macs,
-                         normalized_nnz, action_history, agent_action_history):
+    def finalize_episode(self, reward, val_results, total_macs, total_nnz,
+                         action_history, agent_action_history, log_stats=True):
         """Write the details of one network to the logger and create a checkpoint file"""
+        top1, top5, vloss = val_results
+        normalized_macs = total_macs / self.original_model_macs * 100
+        normalized_nnz = total_nnz / self.original_model_size * 100
+
         if reward > self.best_reward:
             self.best_reward = reward
             ckpt_name = self.save_checkpoint(is_best=True)
@@ -674,6 +675,20 @@ class DistillerWrapperEnvironment(gym.Env):
                   ckpt_name, json.dumps(action_history), json.dumps(agent_action_history),
                   json.dumps(performance)]
         self.stats_logger.add_record(fields)
+        msglogger.info("Top1: %.2f - compute: %.2f%% - params:%.2f%% - actions: %s",
+                       top1, normalized_macs, normalized_nnz, self.action_history)
+        if log_stats:
+            stats = ('Performance/EpisodeEnd/',
+                     OrderedDict([('Loss', vloss),
+                                  ('Top1', top1),
+                                  ('Top5', top5),
+                                  ('reward', reward),
+                                  ('total_macs', int(total_macs)),
+                                  ('macs_normalized', normalized_macs),
+                                  ('log(total_macs)', math.log(total_macs)),
+                                  ('total_nnz', int(total_nnz))]))
+            distiller.log_training_progress(stats, None, self.episode, steps_completed=0, total_steps=1,
+                                            log_freq=1, loggers=[self.tflogger, self.pylogger])
 
     def save_checkpoint(self, is_best=False):
         """Save the learned-model checkpoint"""
@@ -685,8 +700,9 @@ class DistillerWrapperEnvironment(gym.Env):
         if is_best or self.amc_cfg.save_chkpts:
             # Always save the best episodes, and depending on amc_cfg.save_chkpts save all other episodes
             scheduler = self.net_wrapper.create_scheduler()
+            extras = {"creation_masks": self.net_wrapper.sparsification_masks}
             self.services.save_checkpoint_fn(epoch=0, model=self.model,
-                                            scheduler=scheduler, name=fname)
+                                             scheduler=scheduler, name=fname, extras=extras)
             del scheduler
         return fname
 
@@ -694,6 +710,7 @@ class DistillerWrapperEnvironment(gym.Env):
 def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
     def make_conv(model, conv_module, g, name, seq_id, layer_id):
         conv = SimpleNamespace()
+        conv.type = "Conv2D"
         conv.name = name
         conv.id = layer_id
         conv.t = seq_id
@@ -718,6 +735,7 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
 
     def make_fc(model, fc_module, g, name, seq_id, layer_id):
         fc = SimpleNamespace()
+        fc.type = "Linear"
         fc.name = name
         fc.id = layer_id
         fc.t = seq_id
@@ -732,6 +750,11 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
         fc.n_ifm = fc_op['attrs']['n_ifm']
         fc_pname = name + ".weight"
         fc_p = distiller.model_find_param(model, fc_pname)
+        fc.ofm_h = g.param_shape(fc_op['outputs'][0])[0]
+        fc.ofm_w = g.param_shape(fc_op['outputs'][0])[1]
+        fc.ifm_h = g.param_shape(fc_op['inputs'][0])[0]
+        fc.ifm_w = g.param_shape(fc_op['inputs'][0])[1]
+
         return fc
 
     dummy_input = distiller.get_dummy_input(dataset)
@@ -741,28 +764,36 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
     dependent_layers = set()
     total_macs = 0
     total_params = 0
-    layers = OrderedDict({mod_name: m for mod_name, m in model.named_modules() 
-                          if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))})
+
+    unfiltered_layers = layers_topological_order(model, dummy_input)
+    mods = dict(model.named_modules())
+    layers = OrderedDict({mod_name: mods[mod_name] for mod_name in unfiltered_layers
+                          if mod_name in mods and
+                          isinstance(mods[mod_name], (torch.nn.Conv2d, torch.nn.Linear))})
+
+    # layers = OrderedDict({mod_name: m for mod_name, m in model.named_modules()
+    #                       if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))})
     for layer_id, (name, m) in enumerate(layers.items()):
-        if isinstance(m, torch.nn.Conv2d):
-            conv = make_conv(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
-            all_layers[layer_id] = conv
-            total_params += conv.weights_vol
-            total_macs += conv.macs
+        if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
+            if isinstance(m, torch.nn.Conv2d):
+                new_layer = make_conv(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
+                all_layers[layer_id] = new_layer
+                total_params += new_layer.weights_vol
+                total_macs += new_layer.macs
+            elif isinstance(m, torch.nn.Linear):
+                new_layer = make_fc(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
+                all_layers[layer_id] = new_layer
+                total_params += new_layer.weights_vol
+                total_macs += new_layer.macs
 
             if layers_to_prune is None or name in layers_to_prune:
                 pruned_indices.append(layer_id)
                 # Find the data-dependent layers of this convolution
                 from utils.data_dependencies import find_dependencies
-                conv.dependencies = list()
-                find_dependencies(dependency_type, g, all_layers, name, conv.dependencies)
-                dependent_layers.add(tuple(conv.dependencies))
-        elif isinstance(m, torch.nn.Linear):
-            fc = make_fc(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
-            all_layers[layer_id] = fc
-            total_macs += fc.macs
-            total_params += fc.weights_vol
- 
+                new_layer.dependencies = list()
+                find_dependencies(dependency_type, g, all_layers, name, new_layer.dependencies)
+                dependent_layers.add(tuple(new_layer.dependencies))
+
     def convert_layer_names_to_indices(layer_names):
         """Args:
             layer_names - list of layer names
@@ -777,6 +808,70 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
     return all_layers, pruned_indices, dependent_indices, total_macs, total_params
 
 
+def layers_topological_order(model, dummy_input, recurrent=False):
+    """
+    Prepares an ordered list of layers to quantize sequentially. This list has all the layers ordered by their
+    topological order in the graph.
+    Args:
+        model (nn.Module): the model to quantize.
+        dummy_input (torch.Tensor): an input to be passed through the model.
+        recurrent (bool): indication on whether the model might have recurrent connections.
+    """
+
+    class _OpRank:
+        def __init__(self, adj_entry, rank=None):
+            self.adj_entry = adj_entry
+            self._rank = rank or 0
+
+        @property
+        def rank(self):
+            return self._rank
+
+        @rank.setter
+        def rank(self, val):
+            self._rank = max(val, self._rank)
+
+        def __repr__(self):
+            return '_OpRank(\'%s\' | %d)' % (self.adj_entry.op_meta.name, self.rank)
+
+    adj_map = SummaryGraph(model, dummy_input).adjacency_map()
+    ranked_ops = {k: _OpRank(v, 0) for k, v in adj_map.items()}
+
+    def _recurrent_ancestor(ranked_ops_dict, dest_op_name, src_op_name):
+        def _is_descendant(parent_op_name, dest_op_name):
+            successors_names = [op.name for op in adj_map[parent_op_name].successors]
+            if dest_op_name in successors_names:
+                return True
+            for succ_name in successors_names:
+                if _is_descendant(succ_name, dest_op_name):
+                    return True
+            return False
+
+        return _is_descendant(dest_op_name, src_op_name) and \
+            (0 < ranked_ops_dict[dest_op_name].rank < ranked_ops_dict[src_op_name].rank)
+
+    def rank_op(ranked_ops_dict, op_name, rank):
+        ranked_ops_dict[op_name].rank = rank
+        for child_op in adj_map[op_name].successors:
+            # In recurrent models: if a successor is also an ancestor - we don't increment its rank.
+            if not recurrent or not _recurrent_ancestor(ranked_ops_dict, child_op.name, op_name):
+                rank_op(ranked_ops_dict, child_op.name, ranked_ops_dict[op_name].rank + 1)
+
+    roots = [k for k, v in adj_map.items() if len(v.predecessors) == 0]
+    for root_op_name in roots:
+        rank_op(ranked_ops, root_op_name, 0)
+
+     # Take only the modules from the original model
+    # module_dict = dict(model.named_modules())
+    # Neta
+    ret = sorted([k for k in ranked_ops.keys()],
+                 key=lambda k: ranked_ops[k].rank)
+
+    # Check that only the actual roots have a rank of 0
+    assert {k for k in ret if ranked_ops[k].rank == 0} <= set(roots)
+    return ret
+
+
 import pandas as pd
 def sample_networks(net_wrapper, services):
     """Sample networks from the posterior distribution.
diff --git a/examples/auto_compression/amc/rewards.py b/examples/auto_compression/amc/rewards.py
index ab16cb4..24c7487 100755
--- a/examples/auto_compression/amc/rewards.py
+++ b/examples/auto_compression/amc/rewards.py
@@ -58,25 +58,23 @@ def mac_constrained_experimental_reward_fn(env, top1, top5, vloss, total_macs):
 
 def mac_constrained_clamp_action(env, pruning_action):
     """Compute a resource-constrained action"""
-
-    # Todo: this is tightly coupled to the environment - refactor
-    flops = env.net_wrapper.layer_macs(env.current_layer())
-    assert flops > 0
-    reduced = env._removed_macs
-    prunable_rest, rest = env.rest_macs_raw()
-    rest += prunable_rest * env.action_high  # how much we have to remove in other layers
-    target_reduction = (1 - env.amc_cfg.target_density) * env.original_model_macs
+    layer_macs = env.net_wrapper.layer_macs(env.current_layer())
+    assert layer_macs > 0
+    reduced = env.removed_macs
+    prunable_rest, non_prunable_rest = env.rest_macs_raw()
+    rest = prunable_rest * min(0.9, env.action_high)
+    target_reduction = (1. - env.amc_cfg.target_density) * env.original_model_macs
     assert reduced == env.original_model_macs - env.net_wrapper.total_macs
     duty = target_reduction - (reduced + rest)
-    pruning_action_final = min(env.action_high, max(pruning_action, duty/flops))
+    pruning_action_final = min(1., max(pruning_action, duty/layer_macs))
 
-    msglogger.debug("\t\tflops=%.3f  reduced=%.3f  rest=%.3f  duty=%.3f" % (flops, reduced, rest, duty))
+    msglogger.debug("\t\tflops=%.3f  reduced=%.3f  rest=%.3f  duty=%.3f" % (layer_macs, reduced, rest, duty))
     msglogger.debug("\t\tpruning_action=%.3f  pruning_action_final=%.3f" % (pruning_action, pruning_action_final))
-    msglogger.debug("\t\ttarget={:.2f} reduced={:.2f} rest={:.2f} duty={:.2f} flops={:.2f}".
-                        format( 1-env.amc_cfg.target_density, reduced/env.original_model_macs,
-                                rest/env.original_model_macs, 
-                                duty/env.original_model_macs,
-                                flops/env.original_model_macs))
+    msglogger.debug("\t\ttarget={:.2f} reduced={:.2f} rest={:.2f} duty={:.2f} flops={:.2f}\n".
+                    format(1-env.amc_cfg.target_density, reduced/env.original_model_macs,
+                           rest/env.original_model_macs,
+                           duty/env.original_model_macs,
+                           layer_macs/env.original_model_macs))
     if pruning_action_final != pruning_action:
         msglogger.debug("pruning_action={:.2f}==>pruning_action_final={:.2f}".format(pruning_action,
                                                                                      pruning_action_final))
diff --git a/examples/auto_compression/amc/rl_libs/coach/coach_if.py b/examples/auto_compression/amc/rl_libs/coach/coach_if.py
index 41e2804..bef8e36 100755
--- a/examples/auto_compression/amc/rl_libs/coach/coach_if.py
+++ b/examples/auto_compression/amc/rl_libs/coach/coach_if.py
@@ -47,10 +47,11 @@ class RlLibInterface(object):
             graph_manager.heatup_steps = EnvironmentEpisodes(amc_cfg.ddpg_cfg.num_heatup_episodes)
             # Replay buffer size
             graph_manager.agent_params.memory.max_size = (MemoryGranularity.Transitions, amc_cfg.ddpg_cfg.replay_buffer_size)
+            amc_cfg.ddpg_cfg.training_noise_decay = amc_cfg.ddpg_cfg.training_noise_decay ** (1. / steps_per_episode)
         elif "ClippedPPO" in amc_cfg.agent_algo:
-            from examples.automated_deep_compression.rl_libs.coach.presets.ADC_ClippedPPO import graph_manager, agent_params
+            from examples.auto_compression.amc.rl_libs.coach.presets.ADC_ClippedPPO import graph_manager, agent_params
         elif "TD3" in amc_cfg.agent_algo:
-            from examples.automated_deep_compression.rl_libs.coach.presets.ADC_TD3 import graph_manager, agent_params
+            from examples.auto_compression.amc.rl_libs.coach.presets.ADC_TD3 import graph_manager, agent_params
         else:
             raise ValueError("The agent algorithm you are trying to use (%s) is not supported" % amc_cfg.amc_agent_algo)
 
@@ -61,10 +62,10 @@ class RlLibInterface(object):
         graph_manager.steps_between_evaluation_periods = EnvironmentEpisodes(n_training_episodes)
 
         # These parameters are passed to the Distiller environment
-        env_cfg  = {'model': model, 
-                    'app_args': app_args,
-                    'amc_cfg': amc_cfg,
-                    'services': services}
+        env_cfg = {'model': model,
+                   'app_args': app_args,
+                   'amc_cfg': amc_cfg,
+                   'services': services}
         graph_manager.env_params.additional_simulator_parameters = env_cfg
 
         coach_logs_dir = os.path.join(msglogger.logdir, 'coach')
diff --git a/examples/auto_compression/amc/rl_libs/private/private_if.py b/examples/auto_compression/amc/rl_libs/private/private_if.py
index 9d71601..2ca43ce 100755
--- a/examples/auto_compression/amc/rl_libs/private/private_if.py
+++ b/examples/auto_compression/amc/rl_libs/private/private_if.py
@@ -44,7 +44,6 @@ class RlLibInterface(object):
         agent_args.lr_a = env.amc_cfg.ddpg_cfg.actor_lr
         agent_args.hidden1 = 300
         agent_args.hidden2 = 300
-        agent_args.rmsize = 100
         agent_args.rmsize = env.amc_cfg.ddpg_cfg.replay_buffer_size
         agent_args.window_length = 1
         agent_args.train_episode = (env.amc_cfg.ddpg_cfg.num_heatup_episodes +
diff --git a/examples/auto_compression/amc/utils/data_dependencies.py b/examples/auto_compression/amc/utils/data_dependencies.py
index d82c665..1e44e29 100755
--- a/examples/auto_compression/amc/utils/data_dependencies.py
+++ b/examples/auto_compression/amc/utils/data_dependencies.py
@@ -38,7 +38,7 @@ def find_dependencies(dependency_type, sgraph, layers, layer_name, dependencies_
 
 
 def _find_dependencies_channels(sgraph, layers, layer_name, dependencies_list):
-    # Find all instances of Convolution layers that immediately preceed this layer
+    # Find all instances of Convolution layers that immediately precede this layer
     predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
     for predecessor in predecessors:
         dependencies_list.append(predecessor)
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 0239321..dd1b04f 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -18,6 +18,8 @@ import numpy as np
 import logging
 import math
 import torch
+from functools import partial
+
 import distiller
 import common
 import pytest
@@ -289,7 +291,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     zeros_mask_dict[pair[1] + ".weight"].mask = mask
     zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
     all_channels = set([ch for ch in range(num_channels)])
-    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight"))
+    nnz_channels = set(distiller.non_zero_channels(conv2_p))
     channels_removed = all_channels - nnz_channels
     logger.info("Channels removed {}".format(channels_removed))
 
@@ -457,6 +459,20 @@ def test_magnitude_pruning():
     assert common.almost_equal(distiller.sparsity(b), 1/distiller.volume(a))
 
 
+def test_row_pruning():
+    param = torch.tensor([[1., 2., 3.],
+                          [4., 5., 6.],
+                          [7., 8., 9.]])
+    from distiller.pruning import L1RankedStructureParameterPruner
+
+    masker = distiller.scheduler.ParameterMasker("why name")
+    zeros_mask_dict = {"some name": masker}
+    L1RankedStructureParameterPruner.rank_and_prune_rows(0.5, param, "some name", zeros_mask_dict)
+    print(distiller.sparsity_rows(masker.mask))
+    assert math.isclose(distiller.sparsity_rows(masker.mask), 1/3)
+    pass
+
+
 if __name__ == '__main__':
     for is_parallel in [True, False]:
         test_ranked_filter_pruning(is_parallel)
@@ -477,3 +493,4 @@ if __name__ == '__main__':
         arbitrary_channel_pruning(mobilenet_imagenet(is_parallel),
                                   channels_to_remove=[0, 2],
                                   is_parallel=is_parallel)
+    test_row_pruning()
\ No newline at end of file
-- 
GitLab