From 3f7a94089828c74e2ec751e825b331dc1fc67e08 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 1 Sep 2019 20:54:27 +0300
Subject: [PATCH] AMC: add pruning of FC layers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

FMReconstructionChannelPruner: add support for nn.Linear layers
utils.py: add non_zero_channels()
thinning: support removing channels from FC layers preceding Conv layers
test_pruning.py: add test_row_pruning()
scheduler: init from a dictionary of Maskers
coach_if.py – fix imports of Clipped-PPO and TD3
---
 distiller/pruning/ranked_structures_pruner.py | 116 ++++---
 distiller/scheduler.py                        |   5 +-
 distiller/thinning.py                         |  96 +++---
 distiller/utils.py                            |  34 +-
 examples/auto_compression/amc/amc.py          |  29 +-
 .../amc/auto_compression_channels.yaml        |   6 +-
 examples/auto_compression/amc/environment.py  | 304 ++++++++++++------
 examples/auto_compression/amc/rewards.py      |  28 +-
 .../amc/rl_libs/coach/coach_if.py             |  13 +-
 .../amc/rl_libs/private/private_if.py         |   1 -
 .../amc/utils/data_dependencies.py            |   2 +-
 tests/test_pruning.py                         |  18 +-
 12 files changed, 387 insertions(+), 265 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 96b8dce..84c96b7 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -263,6 +263,18 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
         #     param.data = torch.randn_like(param)
         return binary_map
 
+    @staticmethod
+    def rank_rows(magnitude_fn, fraction_to_prune, param):  # , group_size, rounding_fn, noise
+        assert param.dim() == 2, "This pruning is only supported for 2D weights"
+        ROWS_DIM = 0
+        cols_mags = magnitude_fn(param, dim=ROWS_DIM)
+        num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM))
+        if num_cols_to_prune == 0:
+            msglogger.info("Too few rows - can't prune %.1f%% rows", 100*fraction_to_prune)
+            return None, None
+        bottomk_cols, _ = torch.topk(cols_mags, num_cols_to_prune, largest=False, sorted=True)
+        return bottomk_cols, cols_mags
+
     @staticmethod
     def rank_and_prune_rows(fraction_to_prune, param, param_name,
                             zeros_mask_dict, model=None, binary_map=None,
@@ -270,30 +282,31 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
         """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
 
         PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
-        transposed.  This is counter-intuitive.  To deal with this, we can either transpose the matrix and
-        then proceed to compute the masks as usual, or we can treat columns as rows, and rows as columns :-(.
+        transposed.  This is because the output is computed as follows:
+            y = x(W^T) + b ; where W^T is the transpose of W
+
+        Removing input_channels from W^T, is removing rows of W^T, which is removing columns of W.
+
+        To deal with this rotation, we can either transpose the matrix and then proceed to compute the masks
+        as usual, or we can treat columns as rows, and rows as columns :-(.
         We choose the latter, because transposing very large matrices can be detrimental to performance.  Note
-        that computing mean L1-norm of columns is also not optimal, because consequtive column elements are far
+        that computing mean L1-norm of columns is also not optimal, because consecutive column elements are far
         away from each other in memory, and this means poor use of caches and system memory.
""" - - assert param.dim() == 2, "This pruning is only supported for 2D weights" - ROWS_DIM = 0 + bottomk_cols, cols_mags = LpRankedStructureParameterPruner.rank_rows(magnitude_fn, fraction_to_prune, param) THRESHOLD_DIM = 'Cols' - rows_mags = magnitude_fn(param, dim=ROWS_DIM) - num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0)) - if num_rows_to_prune == 0: - msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune) - return - bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True) - threshold = bottomk_rows[-1] + threshold = bottomk_cols[-1] threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type) + ROWS_DIM = 0 + num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM)) msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", threshold_type, param_name, distiller.sparsity(zeros_mask_dict[param_name].mask), - fraction_to_prune, num_rows_to_prune, rows_mags.size(0)) + fraction_to_prune, num_cols_to_prune, cols_mags.size(ROWS_DIM)) return binary_map @staticmethod @@ -680,8 +691,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner): Use this in conjunction with distiller.features_collector.collect_intermediate_featuremap_samples, which orchestrates the process of feature-map collection. - This foward-hook samples random points. - + This foward-hook samples random points in the output feature-maps of 'module'. After collecting the feature-map samples, distiller.FMReconstructionChannelPruner can be used. Arguments: @@ -697,11 +707,15 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner): # Sample random (uniform) points in each feature-map. # This method is biased toward small feature-maps. 
@@ -680,8 +693,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
     Use this in conjunction with distiller.features_collector.collect_intermediate_featuremap_samples,
     which orchestrates the process of feature-map collection.
 
-    This foward-hook samples random points.
-
+    This forward-hook samples random points in the output feature-maps of 'module'.
     After collecting the feature-map samples, distiller.FMReconstructionChannelPruner can be used.
 
     Arguments:
@@ -697,11 +709,15 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
 
     # Sample random (uniform) points in each feature-map.
     # This method is biased toward small feature-maps.
-    randx = np.random.randint(0, output.size(2), n_points_per_fm)
-    randy = np.random.randint(0, output.size(3), n_points_per_fm)
+    if isinstance(module, torch.nn.Conv2d):
+        randx = np.random.randint(0, output.size(2), n_points_per_fm)
+        randy = np.random.randint(0, output.size(3), n_points_per_fm)
 
     X = input[0]
-    if module.kernel_size == (1,1):
+    if isinstance(module, torch.nn.Linear):
+        X = X.detach().cpu().clone()
+        Y = output.detach().cpu().clone()
+    elif module.kernel_size == (1, 1):
         X = X[:, :, randx, randy].detach().cpu().clone()
         Y = output[:, :, randx, randy].detach().cpu().clone()
     else:
@@ -736,9 +752,10 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_channels(fraction_to_prune, param, param_name,
-                                                  zeros_mask_dict, model, binary_map,
-                                                  group_size=self.group_size,
+
+        binary_map = self.rank_and_prune_channels(fraction_to_prune, param, param_name,
+                                                  zeros_mask_dict, model, binary_map,
+                                                  group_size=self.group_size,
                                                   rounding_fn=self.rounding_fn, noise=self.noise)
         return binary_map
 
@@ -750,16 +767,24 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
                                 noise=0):
         assert binary_map is None
         if binary_map is None:
-            bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_channels(
-                magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise)
-            # Todo: this little piece of code can be refactored
+            op_type = 'conv' if param.dim() == 4 else 'fc'
+            if op_type == 'conv':
+                bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_channels(
+                    magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise)
+
+            else:
+                bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_rows(
+                    magnitude_fn, fraction_to_prune, param)
+
+            # Todo: this little piece of code can be refactored
             if bottomk_channels is None:
                 # Empty list means that fraction_to_prune is too low to prune anything
                 return
-
+
             threshold = bottomk_channels[-1]
             binary_map = channel_mags.gt(threshold)
 
+        # These are the indices of channels we want to keep
         indices = binary_map.nonzero().squeeze()
         if len(indices.shape) == 0:
@@ -779,7 +804,9 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
         # min(MSE) to compute the weights, we need to start by removing feature-map
         # channels from the input.  Then we perform the MSE regression to generate
         # a smaller weights tensor.
-        if conv.kernel_size == (1,1):
+        if op_type == 'fc':
+            X = X[:, binary_map]
+        elif conv.kernel_size == (1, 1):
             X = X[:, binary_map, :]
             X = X.transpose(1, 2)
             X = X.contiguous().view(-1, X.size(2))
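
The "MSE regression" referenced in the comments above is, for the FC and 1x1 cases, an ordinary least-squares problem: given the sampled inputs X restricted to the retained channels, and the original outputs Y, find the weights minimizing ||X·W^T - Y||^2. A hedged sketch of that step with NumPy; the shapes and the lstsq call are illustrative, since the solver the patch actually uses is outside the hunks shown here:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((512, 13))      # samples x retained input channels
    Y = rng.standard_normal((512, 64))      # samples x output channels

    # Solve min ||X @ W.T - Y||^2 for W; lstsq returns W.T
    new_w_T, *_ = np.linalg.lstsq(X, Y, rcond=None)
    new_w = new_w_T.T                       # (64, 13), analogous to the patch's new_w
    print(new_w.shape)
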
@@ -797,18 +824,29 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
         new_w = torch.from_numpy(new_w)   # shape: (num_filters, num_non_masked_channels * k^2)
         cnt_retained_channels = binary_map.sum()
 
-        # Expand the weights back to their original size,
-        new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))
-
-        # Copy the weights that we learned from minimizing the feature-maps least squares error,
-        # to our actual weights tensor.
-        param.detach()[:,indices,:,:] = new_w.type(param.type())
-
+        if op_type == 'conv':
+            # Expand the weights back to their original size,
+            new_w = new_w.contiguous().view(param.size(0), cnt_retained_channels, param.size(2), param.size(3))
+
+            # Copy the weights that we learned from minimizing the feature-maps least squares error,
+            # to our actual weights tensor.
+            param.detach()[:,indices,:,:] = new_w.type(param.type())
+        else:
+            param.detach()[:, indices] = new_w.type(param.type())
+
         if zeros_mask_dict is not None:
             binary_map = binary_map.type(param.type())
-            zeros_mask_dict[param_name].mask = LpRankedStructureParameterPruner.ch_binary_map_to_mask(binary_map, param)
-            msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-                           param_name,
-                           distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
-                           fraction_to_prune, binary_map.sum().item(), param.size(1))
+            if op_type == 'conv':
+                zeros_mask_dict[param_name].mask = LpRankedStructureParameterPruner.ch_binary_map_to_mask(binary_map, param)
+                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                               param_name,
+                               distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
+                               fraction_to_prune, binary_map.sum().item(), param.size(1))
+            else:
+                msglogger.debug("fc sparsity = %.2f" % (1 - binary_map.sum().item() / binary_map.size(0)))
+                zeros_mask_dict[param_name].mask = binary_map.expand(param.size(0), param.size(1))
+                msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                               param_name,
+                               distiller.sparsity_cols(zeros_mask_dict[param_name].mask),
+                               fraction_to_prune, binary_map.sum().item(), param.size(1))
         return binary_map
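
In the fc branch above, the recorded mask broadcasts the per-column binary map over every row, so the full-size mask is derived from one 1-D tensor. A standalone illustration, with sizes invented for the example:

    import torch

    out_features, in_features = 3, 5
    binary_map = torch.tensor([1., 0., 1., 1., 0.])        # keep columns 0, 2 and 3
    mask = binary_map.expand(out_features, in_features)    # every row repeats the column pattern
    print(mask)
    # tensor([[1., 0., 1., 1., 0.],
    #         [1., 0., 1., 1., 0.],
    #         [1., 0., 1., 1., 0.]])
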
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index a0b7490..be4b14d 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -34,13 +34,13 @@ class CompressionScheduler(object):
     """Responsible for scheduling pruning and masking parameters.
 
     """
-    def __init__(self, model, device=torch.device("cuda")):
+    def __init__(self, model, zeros_mask_dict=None, device=torch.device("cuda")):
         self.model = model
         self.device = device
         self.policies = {}
         self.sched_metadata = {}
         # Create the masker objects and place them in a dictionary indexed by the parameter name
-        self.zeros_mask_dict = create_model_masks_dict(model)
+        self.zeros_mask_dict = zeros_mask_dict or create_model_masks_dict(model)
 
     def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
         """Add a new policy to the schedule.
@@ -212,7 +212,6 @@ class CompressionScheduler(object):
             if name not in masks_dict:
                 masks_dict[name] = None
         state = {'masks_dict': masks_dict}
-
         self.load_state_dict(state, normalize_dataparallel_keys)
 
     @staticmethod
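
The new zeros_mask_dict parameter lets a caller hand an existing dictionary of maskers to the scheduler, instead of round-tripping masks through load_state_dict as the AMC environment used to do (see create_scheduler further down). A hypothetical usage sketch, assuming model is an nn.Module and that create_model_masks_dict is importable as shown:

    import distiller

    # Maskers created earlier, possibly already holding masks computed by a pruner.
    zeros_mask_dict = distiller.create_model_masks_dict(model)
    scheduler = distiller.CompressionScheduler(model, zeros_mask_dict)
    assert scheduler.zeros_mask_dict is zeros_mask_dict   # shared, not copied
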
""" - def handle_layer(layer_name, param_name, num_nnz_channels): + def handle_layer(layer_name, param_name, nnz_channels): # We are removing channels, so update the number of incoming channels (IFMs) # in the convolutional layer - assert isinstance(layers[layer_name], torch.nn.modules.Conv2d) - append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels) + assert isinstance(layers[layer_name], (torch.nn.modules.Conv2d, torch.nn.modules.Linear)) + + append_module_directive(thinning_recipe, layer_name, key='in_channels', val=nnz_channels) # Select only the non-zero channels indices = nonzero_channels.data.squeeze() - dim = 1 if layers[layer_name].groups == 1 else 0 + dim = 1 if isinstance(layers[layer_name], torch.nn.modules.Conv2d) and layers[layer_name].groups == 1 else 0 + if isinstance(layers[layer_name], torch.nn.modules.Linear): + dim = 1 append_param_directive(thinning_recipe, param_name, (dim, indices)) - # Find all instances of Convolution layers that immediately preceed this layer + # Find all instances of Convolution layers that immediately precede this layer predecessors = sgraph.predecessors_f(layer_name, ['Conv']) if not predecessors: msglogger.info("Could not find predecessors for name=%s" % layer_name) for predecessor in predecessors: - # For each of the convolutional layers that preceed, we have to reduce the number of output channels. - append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels) + # For each of the convolution layers that precede, we have to reduce the number of output channels. + append_module_directive(thinning_recipe, predecessor, key='out_channels', val=nnz_channels) if layers[predecessor].groups == 1: # Now remove filters from the weights tensor of the predecessor conv @@ -246,13 +219,13 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): if layers[predecessor].bias is not None: # This convolution has bias coefficients append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices)) - append_module_directive(thinning_recipe, predecessor, key='groups', val=num_nnz_channels) + append_module_directive(thinning_recipe, predecessor, key='groups', val=nnz_channels) # In the special case of a Convolutional layer with (groups == in_channels), if we # change in_channels, we also need to change out_channels, which means that we # have to perform filter removal for this layer as well param_name = predecessor+'.weight' - handle_layer(predecessor, param_name, num_nnz_channels) + handle_layer(predecessor, param_name, nnz_channels) else: raise ValueError("Distiller thinning code currently does not handle this conv.groups configuration") @@ -262,7 +235,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): # Thinning of the BN layer that follows the convolution msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer)) append_bn_thinning_directive(thinning_recipe, layers, bn_layer, - len_thin_features=num_nnz_channels, thin_features=indices) + len_thin_features=nnz_channels, thin_features=indices) msglogger.debug("Invoking create_thinning_recipe_channels") thinning_recipe = ThinningRecipe(modules={}, parameters={}) @@ -271,19 +244,25 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): # Traverse all of the model's parameters, search for zero-channels, and # create a thinning recipe that descibes the required changes to the model. 
     for layer_name, param_name, param in sgraph.named_params_layers():
-        # We are only interested in 4D weights (of Convolution layers)
-        if param.dim() != 4:
-            continue
-        num_channels = param.size(1)
-        nonzero_channels = find_nonzero_channels(param, param_name)
-        num_nnz_channels = nonzero_channels.nelement()
-        if num_nnz_channels == 0:
-            raise ValueError("Trying to set zero channels for parameter %s is not allowed" % param_name)
-        # If there are non-zero channels in this tensor then continue to next tensor
-        if num_channels <= num_nnz_channels:
-            continue
-        handle_layer(layer_name, param_name, num_nnz_channels)
-    msglogger.debug(thinning_recipe)
+        if param.dim() in (2, 4):
+            num_channels = param.size(1)
+            # Find nonzero input channels
+            if param.dim() == 2:
+                # 2D weights (of Linear layers)
+                col_sums = param.abs().sum(dim=0)
+                nonzero_channels = torch.nonzero(col_sums)
+                num_nnz_channels = nonzero_channels.nelement()
+            elif param.dim() == 4:
+                # 4D weights (of Convolution layers)
+                nonzero_channels = distiller.non_zero_channels(param)
+                num_nnz_channels = nonzero_channels.nelement()
+            if num_nnz_channels == 0:
+                raise ValueError("Trying to zero all channels for parameter %s is not allowed" % param_name)
+
+            # If all channels in this tensor are non-zero, continue to the next tensor
+            if num_channels <= num_nnz_channels:
+                continue
+            handle_layer(layer_name, param_name, num_nnz_channels)
     return thinning_recipe
@@ -504,6 +483,6 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
 
     with torch.no_grad():
         for param_name, param_directives in recipe.parameters.items():
             msglogger.debug("{} : {}".format(param_name, param_directives))
             param = distiller.model_find_param(model, param_name)
             assert param is not None
@@ -533,7 +512,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 else:
                     if param.data.size(dim) != len_indices:
                         msglogger.debug("[thinning] changing param {} ({}) dim:{}  new len: {}".format(
-                                        param_name, param.shape, dim, len_indices))
+                            param_name, param.shape, dim, len_indices))
                         assert param.size(dim) > len_indices
                         param.data = torch.index_select(param.data, dim, indices.to(param.device))
                     msglogger.debug("[thinning] changed param {}".format(param_name))
@@ -546,13 +525,14 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
             if optimizer_thinning(optimizer, param, dim, indices):
                 msglogger.debug("Updated velocity buffer %s" % param_name)
 
-        if not loaded_from_file:
+        if not loaded_from_file and zeros_mask_dict:
             # If the masks are loaded from a checkpoint file, then we don't need to change
             # their shape, because they are already correctly shaped
             mask = zeros_mask_dict[param_name].mask
             if mask is not None and (mask.size(dim) != len_indices):
                 zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
 
 
+# Todo: consider removing this function
 def resnet_cifar_remove_layers(model):
     """Remove layers from ResNet-Cifar.
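
execute_thinning_recipe shrinks parameters with torch.index_select, as the hunks above show. On a toy Linear weight, dropping the zeroed input channels looks like this (a self-contained illustration, not Distiller code):

    import torch

    w = torch.arange(12.).view(3, 4)            # (out_features=3, in_features=4)
    indices = torch.tensor([0, 2, 3])           # non-zero input channels to keep
    w_thin = torch.index_select(w, 1, indices)  # dim=1 drops input-channel columns
    print(w_thin.shape)                         # torch.Size([3, 3])
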
+ """ if tensor.dim() != 4: - return 0 + raise ValueError("Expecting a 4D tensor") - num_filters = tensor.size(0) - num_kernels_per_filter = tensor.size(1) + n_filters, n_channels, k_h, k_w = (tensor.size(i) for i in range(4)) - # First, reshape the weights tensor such that each channel (kernel) in the original - # tensor, is now a row in the 2D tensor. - view_2d = tensor.view(-1, tensor.size(2) * tensor.size(3)) + # First, reshape the weights tensor such that each channel (kernel) in + # the original tensor, is now a row in a 2D tensor. + view_2d = tensor.view(-1, k_h * k_w) # Next, compute the sums of each kernel kernel_sums = view_2d.abs().sum(dim=1) # Now group by channels - k_sums_mat = kernel_sums.view(num_filters, num_kernels_per_filter).t() - nonzero_channels = len(torch.nonzero(k_sums_mat.abs().sum(dim=1))) - return 1 - nonzero_channels/num_kernels_per_filter + k_sums_mat = kernel_sums.view(n_filters, n_channels).t() + nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1)) + return nonzero_channels + + +def sparsity_ch(tensor): + """Channel-wise sparsity for 4D tensors""" + if tensor.dim() != 4: + return 0 + nonzero_channels = len(non_zero_channels(tensor)) + n_channels = tensor.size(1) + return 1 - nonzero_channels/n_channels def density_ch(tensor): diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py index 5155867..46d633c 100755 --- a/examples/auto_compression/amc/amc.py +++ b/examples/auto_compression/amc/amc.py @@ -19,28 +19,12 @@ $ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkp """ -import math import os -import copy import logging -import numpy as np -import torch -import csv import traceback from functools import partial -try: - import gym -except ImportError as e: - print("WARNING: to use automated compression you will need to install extra packages") - print("See instructions in the interface of each RL library.") - raise e -from gym import spaces import distiller -from collections import OrderedDict, namedtuple -from types import SimpleNamespace -from distiller import normalize_module_name, SummaryGraph from environment import DistillerWrapperEnvironment, Observation -from utils.features_collector import collect_intermediate_featuremap_samples import distiller.apputils as apputils import distiller.apputils.image_classifier as classifier from rewards import reward_factory @@ -117,7 +101,7 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo ddpg_cfg = distiller.utils.MutableNamedTuple({ 'heatup_noise': 0.5, 'initial_training_noise': 0.5, - 'training_noise_decay': 0.99555, #0.98, #0.996, + 'training_noise_decay': 0.95, 'num_heatup_episodes': args.amc_heatup_episodes, 'num_training_episodes': args.amc_training_episodes, 'actor_lr': 1e-4, @@ -147,7 +131,8 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo def create_environment(): env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services) - env.amc_cfg.ddpg_cfg.replay_buffer_size = 100 * env.steps_per_episode + #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode) + env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode return env env1 = create_environment() @@ -180,7 +165,7 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo raise ValueError("unsupported rl library: ", args.amc_rllib) -def config_verbose(verbose): +def 
diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py
index 5155867..46d633c 100755
--- a/examples/auto_compression/amc/amc.py
+++ b/examples/auto_compression/amc/amc.py
@@ -19,28 +19,12 @@ $ python3 amc.py --arch=resnet20_cifar ${CIFAR10_PATH} --resume=../../ssl/checkp
 """
 
-import math
 import os
-import copy
 import logging
-import numpy as np
-import torch
-import csv
 import traceback
 from functools import partial
-try:
-    import gym
-except ImportError as e:
-    print("WARNING: to use automated compression you will need to install extra packages")
-    print("See instructions in the interface of each RL library.")
-    raise e
-from gym import spaces
 import distiller
-from collections import OrderedDict, namedtuple
-from types import SimpleNamespace
-from distiller import normalize_module_name, SummaryGraph
 from environment import DistillerWrapperEnvironment, Observation
-from utils.features_collector import collect_intermediate_featuremap_samples
 import distiller.apputils as apputils
 import distiller.apputils.image_classifier as classifier
 from rewards import reward_factory
@@ -117,7 +101,7 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpoint_fn):
         ddpg_cfg = distiller.utils.MutableNamedTuple({
                 'heatup_noise': 0.5,
                 'initial_training_noise': 0.5,
-                'training_noise_decay': 0.99555,  #0.98, #0.996,
+                'training_noise_decay': 0.95,
                 'num_heatup_episodes': args.amc_heatup_episodes,
                 'num_training_episodes': args.amc_training_episodes,
                 'actor_lr': 1e-4,
@@ -147,7 +131,7 @@
 
         def create_environment():
             env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services)
-            env.amc_cfg.ddpg_cfg.replay_buffer_size = 100 * env.steps_per_episode
+            env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode
             return env
 
         env1 = create_environment()
@@ -180,7 +164,7 @@
         raise ValueError("unsupported rl library: ", args.amc_rllib)
 
 
-def config_verbose(verbose):
+def config_verbose(verbose, display_summaries=False):
     if verbose:
         loglevel = logging.DEBUG
     else:
@@ -188,11 +172,14 @@
         logging.getLogger().setLevel(logging.WARNING)
     for module in ["examples.auto_compression.amc",
                    "distiller.apputils.image_classifier",
-                   "distiller.data_loggers.logger",
-                   "distiller.thinning",
+                   "distiller.thinning",
                    "distiller.pruning.ranked_structures_pruner"]:
         logging.getLogger(module).setLevel(loglevel)
 
+    # display training progress summaries
+    summaries_lvl = logging.INFO if display_summaries else logging.WARNING
+    logging.getLogger("examples.auto_compression.amc.summaries").setLevel(summaries_lvl)
+
 
 if __name__ == '__main__':
     try:
diff --git a/examples/auto_compression/amc/auto_compression_channels.yaml b/examples/auto_compression/amc/auto_compression_channels.yaml
index a58493d..558c975 100755
--- a/examples/auto_compression/amc/auto_compression_channels.yaml
+++ b/examples/auto_compression/amc/auto_compression_channels.yaml
@@ -9,7 +9,8 @@ network:
                  "module.model.3.3", "module.model.4.3", "module.model.5.3",
                  "module.model.6.3", "module.model.7.3", "module.model.8.3",
                  "module.model.9.3", "module.model.10.3", "module.model.11.3",
-                 "module.model.12.3", "module.model.13.3"]
+                 "module.model.12.3", "module.model.13.3",
+                 "module.fc"]
 
   mobilenet_v2:   # Only conv 1x1, without shortcut connection dependencies
@@ -94,7 +95,8 @@ network:
                 "module.layer2.2.conv1", "module.layer2.2.conv2",
                 "module.layer3.0.conv1", "module.layer3.0.conv2",
                 "module.layer3.1.conv1", "module.layer3.1.conv2",
-                "module.layer3.2.conv1", "module.layer3.2.conv2"]
+                "module.layer3.2.conv1", "module.layer3.2.conv2",
+                "module.fc"]
 
   simplenet_mnist: ["module.conv2"]
diff --git a/examples/auto_compression/amc/environment.py b/examples/auto_compression/amc/environment.py
index d248d4b..79415ab 100755
--- a/examples/auto_compression/amc/environment.py
+++ b/examples/auto_compression/amc/environment.py
@@ -29,13 +29,7 @@ import copy
 import logging
 import numpy as np
 import torch
-try:
-    import gym
-except ImportError as e:
-    print("WARNING: to use automated compression you will need to install extra packages")
-    print("See instructions in the header of examples/automated_deep_compression/ADC.py")
-    raise e
-from gym import spaces
+import gym
 import distiller
 from collections import OrderedDict, namedtuple
 from types import SimpleNamespace
@@ -45,7 +39,7 @@ from utils.ac_loggers import AMCStatsLogger, FineTuneStatsLogger
 
 msglogger = logging.getLogger("examples.auto_compression.amc")
 
-Observation = namedtuple('Observation', ['t', 'n', 'c', 'h', 'w', 'stride', 'k', 'MACs',
+Observation = namedtuple('Observation', ['t', 'type', 'n', 'c', 'h', 'w', 'stride', 'k', 'MACs',
                                          'Weights', 'reduced', 'rest', 'prev_a'])
 ObservationLen = len(Observation._fields)
 ALMOST_ONE = 0.9999
@@ -141,8 +135,8 @@ class NetworkMetadata(object):
     def is_prunable(self, layer_id):
         return layer_id in self.pruned_idxs
 
-    def is_reducible(self, layer_id):
-        return layer_id in self.pruned_idxs or layer_id in self.dependent_idxs
+    def is_compressible(self, layer_id):
+        return layer_id in (self.pruned_idxs + self.dependent_idxs)
 
     def num_pruned_layers(self):
         return len(self.pruned_idxs)
@@ -165,6 +159,7 @@ class NetworkWrapper(object):
                                                            pruning_pattern, modules_list)
         self.cached_perf_summary = self.cached_model_metadata.performance_summary()
         self.reset(model)
+        self.sparsification_masks = None
 
     def reset(self, model):
         self.model = model
@@ -218,9 +213,7 @@ class NetworkWrapper(object):
         return ret
 
     def create_scheduler(self):
-        scheduler = distiller.CompressionScheduler(self.model)
-        masks = {param_name: masker.mask for param_name, masker in self.zeros_mask_dict.items()}
-        scheduler.load_state_dict(state={'masks_dict': masks})
+        scheduler = distiller.CompressionScheduler(self.model, self.zeros_mask_dict)
         return scheduler
 
     def remove_structures(self, layer_id, fraction_to_prune, prune_what, prune_how,
@@ -233,9 +226,9 @@
             raise ValueError("idx=%d is not in correct range " % layer_id)
         if fraction_to_prune < 0:
             raise ValueError("fraction_to_prune=%.3f is illegal" % fraction_to_prune)
-
         if fraction_to_prune == 0:
             return 0
+
         if fraction_to_prune == 1.0:
             # For now, prevent the removal of entire layers
             fraction_to_prune = ALMOST_ONE
@@ -249,6 +242,8 @@
 
         if prune_what == "channels":
             calculate_sparsity = distiller.sparsity_ch
+            if layer.type == "Linear":
+                calculate_sparsity = distiller.sparsity_rows
             remove_structures_fn = distiller.remove_channels
             group_type = "Channels"
         elif prune_what == "filters":
@@ -258,7 +253,7 @@
         else:
             raise ValueError("unsupported structure {}".format(prune_what))
 
-        if prune_how == "l1-rank" or prune_how == "stochastic-l1-rank":
+        if prune_how in ["l1-rank", "stochastic-l1-rank"]:
            # Create a channel/filter-ranking pruner
            pruner = distiller.pruning.L1RankedStructureParameterPruner(
                "auto_pruner", group_type, fraction_to_prune, conv_pname,
@@ -275,14 +270,15 @@
         del pruner
 
         if (self.zeros_mask_dict[conv_pname].mask is None or
-            0 == calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)):
+                0 == calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)):
             msglogger.debug("remove_structures: aborting because there are no structures to prune")
             return 0
 
         final_action = calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)
         # Use the mask to prune
         self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
-        if apply_thinning:
+        if apply_thinning:
+            self.cache_sparsification_masks()
             remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
         self.model_metadata.reduce_layer_macs(layer, final_action)
@@ -295,7 +291,7 @@
         return top1, top5, vloss
 
     def train(self, num_epochs, episode=0):
-        # Train for zero or more epochs
+        """Train for zero or more epochs"""
         opt_cfg = self.app_args.optimizer_data
         optimizer = torch.optim.SGD(self.model.parameters(), lr=opt_cfg['lr'], momentum=opt_cfg['momentum'],
                                     weight_decay=opt_cfg['weight_decay'])
@@ -309,13 +305,18 @@
         del compression_scheduler
         return acc_list
 
+    def cache_sparsification_masks(self):
+        masks = {param_name: masker.mask for param_name, masker in self.zeros_mask_dict.items()}
+        self.sparsification_masks = copy.deepcopy(masks)
+
 
 class DistillerWrapperEnvironment(gym.Env):
     def __init__(self, model, app_args, amc_cfg, services):
-        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
+        self.pylogger = distiller.data_loggers.PythonLogger(
+            logging.getLogger("examples.auto_compression.amc.summaries"))
         logdir = logging.getLogger().logdir
         self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
-        self.verbose = False
+        self._render = False
         self.orig_model = copy.deepcopy(model)
         self.app_args = app_args
         self.amc_cfg = amc_cfg
@@ -331,12 +332,12 @@ class DistillerWrapperEnvironment(gym.Env):
         self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
         self.episode = 0
         self.best_reward = float("-inf")
float("-inf") - self.action_low = amc_cfg.action_range[0] - self.action_high = amc_cfg.action_range[1] + self.action_low, self.action_high = amc_cfg.action_range + #self.action_high = amc_cfg.action_range[1] self._log_model_info() log_amc_config(amc_cfg) self._configure_action_space() - self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),)) + self.observation_space = gym.spaces.Box(0, float("inf"), shape=(len(Observation._fields),)) self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv')) self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv')) @@ -355,7 +356,7 @@ class DistillerWrapperEnvironment(gym.Env): def acceptance_criterion(m, mod_names): # Collect feature-maps only for Conv2d layers, if they are in our modules list. - return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names + return isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)) and m.distiller_name in mod_names # For feature-map reconstruction we need to collect a representative set # of inter-layer feature-maps @@ -377,12 +378,12 @@ class DistillerWrapperEnvironment(gym.Env): def _configure_action_space(self): if is_using_continuous_action_space(self.amc_cfg.agent_algo): if self.amc_cfg.agent_algo == "ClippedPPO-continuous": - self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,)) + self.action_space = gym.spaces.Box(PPO_MIN, PPO_MAX, shape=(1,)) else: - self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,)) + self.action_space = gym.spaces.Box(self.action_low, self.action_high, shape=(1,)) self.action_space.default_action = self.action_low else: - self.action_space = spaces.Discrete(10) + self.action_space = gym.spaces.Discrete(10) @property @@ -401,7 +402,7 @@ class DistillerWrapperEnvironment(gym.Env): if hasattr(self.net_wrapper.model, 'intermediate_fms'): self.model.intermediate_fms = self.net_wrapper.model.intermediate_fms self.net_wrapper.reset(self.model) - self._removed_macs = 0 + self.removed_macs = 0 self.action_history = [] self.agent_action_history = [] self.model_representation = self.get_model_representation() @@ -421,17 +422,13 @@ class DistillerWrapperEnvironment(gym.Env): """Return the amount of MACs removed so far. This is normalized to the range 0..1 """ - return self._removed_macs / self.original_model_macs + return self.removed_macs / self.original_model_macs def render(self, mode='human'): """Provide some feedback to the user about what's going on. This is invoked by the Agent. """ - if self.current_state_id == 0: - msglogger.info("+" + "-" * 50 + "+") - msglogger.info("Starting a new episode %d", self.episode) - msglogger.info("+" + "-" * 50 + "+") - if not self.verbose: + if not self._render: return msglogger.info("Render Environment: current_state_id=%d" % self.current_state_id) distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger]) @@ -442,6 +439,10 @@ class DistillerWrapperEnvironment(gym.Env): The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove). This function is invoked by the Agent. 
""" + if self.current_state_id == 0: + msglogger.info("+" + "-" * 50 + "+") + msglogger.info("Episode %d is starting" % self.episode) + pruning_action = float(pruning_action[0]) msglogger.debug("env.step - current_state_id=%d (%s) episode=%d action=%.2f" % (self.current_state_id, self.current_layer().name, self.episode, pruning_action)) @@ -487,13 +488,13 @@ class DistillerWrapperEnvironment(gym.Env): layer_macs_after_action = self.net_wrapper.layer_macs(self.current_layer()) # Update the various counters after taking the step - self._removed_macs += (total_macs_before - total_macs_after_act) + self.removed_macs += (total_macs_before - total_macs_after_act) msglogger.debug("\tactual_action={}".format(pruning_action)) msglogger.debug("\tlayer_macs={} layer_macs_after_action={} removed now={}".format(layer_macs, layer_macs_after_action, (layer_macs - layer_macs_after_action))) - msglogger.debug("\tself._removed_macs={}".format(self._removed_macs)) + msglogger.debug("\tself._removed_macs={}".format(self.removed_macs)) assert math.isclose(layer_macs_after_action / layer_macs, 1 - pruning_action) stats = ('Performance/Validation/', @@ -504,13 +505,11 @@ class DistillerWrapperEnvironment(gym.Env): total_steps=self.net_wrapper.num_pruned_layers(), log_freq=1, loggers=[self.tflogger]) if self.episode_is_done(): - msglogger.info("Episode is ending") + msglogger.info("Episode %d is ending" % self.episode) observation = self.get_final_obs() - reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act) - normalized_macs = total_macs_after_act / self.original_model_macs * 100 - normalized_nnz = total_nnz_after_act / self.original_model_size * 100 - self.finalize_episode(top1, reward, total_macs_after_act, normalized_macs, - normalized_nnz, self.action_history, self.agent_action_history) + reward, top1, top5, vloss = self.compute_reward(total_macs_after_act, total_nnz_after_act) + self.finalize_episode(reward, (top1, top5, vloss), total_macs_after_act, total_nnz_after_act, + self.action_history, self.agent_action_history) self.episode += 1 else: self.current_layer_id = self.net_wrapper.model_metadata.pruned_idxs[self.current_state_id] @@ -519,16 +518,16 @@ class DistillerWrapperEnvironment(gym.Env): self.net_wrapper.train(1, self.episode) observation = self.get_obs() if self.amc_cfg.reward_frequency is not None and self.current_state_id % self.amc_cfg.reward_frequency == 0: - reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act, log_stats=False) + reward, top1, top5, vloss = self.compute_reward(total_macs_after_act, total_nnz_after_act) else: reward = 0 self.prev_action = pruning_action if self.episode_is_done(): + normalized_macs = total_macs_after_act / self.original_model_macs * 100 info = {"accuracy": top1, "compress_ratio": normalized_macs} - msglogger.info(self.removed_macs_pct) if self.amc_cfg.protocol == "mac-constrained": # Sanity check (special case only for "mac-constrained") - #assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.01 + assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.002 # 0.01 pass else: info = {} @@ -538,8 +537,6 @@ class DistillerWrapperEnvironment(gym.Env): """Produce a state embedding (i.e. 
an observation)""" current_layer_macs = self.net_wrapper.layer_net_macs(self.current_layer()) current_layer_macs_pct = current_layer_macs/self.original_model_macs - current_layer = self.current_layer() - conv_module = distiller.model_find_module(self.model, current_layer.name) obs = self.model_representation[self.current_state_id, :] obs[-1] = self.prev_action @@ -571,17 +568,31 @@ class DistillerWrapperEnvironment(gym.Env): for state_id, layer_id in enumerate(self.net_wrapper.model_metadata.pruned_idxs): layer = self.net_wrapper.get_layer(layer_id) layer_macs = self.net_wrapper.layer_macs(layer) - conv_module = distiller.model_find_module(self.model, layer.name) - obs = [state_id, - conv_module.out_channels, - conv_module.in_channels, - layer.ifm_h, - layer.ifm_w, - layer.stride[0], - layer.k, - distiller.volume(conv_module.weight), - layer_macs, - 0, 0, 0] + mod = distiller.model_find_module(self.model, layer.name) + if isinstance(mod, torch.nn.Conv2d): + obs = [state_id, + 0, + mod.out_channels, + mod.in_channels, + layer.ifm_h, + layer.ifm_w, + layer.stride[0], + layer.k, + distiller.volume(mod.weight), + layer_macs, + 0, 0, 0] + elif isinstance(mod, torch.nn.Linear): + obs = [state_id, + 1, + mod.out_features, + mod.in_features, + layer.ifm_h, + layer.ifm_w, + 0, + 1, + distiller.volume(mod.weight), + layer_macs, + 0, 0, 0] network_obs[state_id:] = np.array(obs) # Feature normalization @@ -596,8 +607,8 @@ class DistillerWrapperEnvironment(gym.Env): def rest_macs_raw(self): """Return the number of remaining MACs in the layers following the current layer""" - rest, prunable_rest = 0, 0 - prunable_layers, rest_layers, layers_to_ignore = list(), list(), list() + nonprunable_rest, prunable_rest = 0, 0 + prunable_layers, nonprunable_layers, layers_to_ignore = list(), list(), list() # Create a list of the IDs of the layers that are dependent on the current_layer. # We want to ignore these layers when we compute prunable_layers (and prunable_rest). 
@@ -596,8 +607,8 @@
 
     def rest_macs_raw(self):
         """Return the number of remaining MACs in the layers following the current layer"""
-        rest, prunable_rest = 0, 0
-        prunable_layers, rest_layers, layers_to_ignore = list(), list(), list()
+        nonprunable_rest, prunable_rest = 0, 0
+        prunable_layers, nonprunable_layers, layers_to_ignore = list(), list(), list()
 
         # Create a list of the IDs of the layers that are dependent on the current_layer.
         # We want to ignore these layers when we compute prunable_layers (and prunable_rest).
@@ -606,17 +617,21 @@
 
         for layer_id in range(self.current_layer_id+1, self.net_wrapper.model_metadata.num_layers()):
             layer_macs = self.net_wrapper.layer_net_macs(self.net_wrapper.get_layer(layer_id))
-            if self.net_wrapper.model_metadata.is_reducible(layer_id):
+            if self.net_wrapper.model_metadata.is_compressible(layer_id):
                 if layer_id not in layers_to_ignore:
                     prunable_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
                     prunable_rest += layer_macs
-            else:
-                rest_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
-                rest += layer_macs
-        msglogger.debug("prunable_layers={} rest_layers={}".format(prunable_layers, rest_layers))
-        msglogger.debug("layer_id=%d, prunable_rest=%.3f rest=%.3f" % (self.current_layer_id, prunable_rest, rest))
-        return prunable_rest, rest
+
+        for layer_id in list(range(0, self.net_wrapper.model_metadata.num_layers())):
+            if not self.net_wrapper.model_metadata.is_compressible(layer_id):
+                layer_macs = self.net_wrapper.layer_net_macs(self.net_wrapper.get_layer(layer_id))
+                nonprunable_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
+                nonprunable_rest += layer_macs
+
+        msglogger.debug("prunable_layers={} nonprunable_layers={}".format(prunable_layers, nonprunable_layers))
+        msglogger.debug("layer_id=%d (%s), prunable_rest=%.3f nonprunable_rest=%.3f" %
+                        (self.current_layer_id, self.current_layer().name, prunable_rest, nonprunable_rest))
+        return prunable_rest, nonprunable_rest
 
     def rest_macs(self):
         return sum(self.rest_macs_raw()) / self.original_model_macs
@@ -625,13 +640,12 @@
         current_density = compressed_model_total_macs / self.original_model_macs
         return self.amc_cfg.target_density >= current_density
 
-    def compute_reward(self, total_macs, total_nnz, log_stats=True):
+    def compute_reward(self, total_macs, total_nnz):
         """Compute the reward.
         We use the validation dataset (the size of the validation dataset is
         configured when the data-loader is instantiated)"""
-        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
-        compression = distiller.model_numel(self.model, param_dims=[4]) / self.original_model_size
+        num_elements = distiller.model_params_size(self.model, param_dims=[2, 4], param_types=['weight'])
 
         # Fine-tune (this is a nop if self.amc_cfg.num_ft_epochs==0)
         accuracies = self.net_wrapper.train(self.amc_cfg.num_ft_epochs, self.episode)
@@ -639,28 +653,15 @@
         top1, top5, vloss = self.net_wrapper.validate()
         reward = self.amc_cfg.reward_fn(self, top1, top5, vloss, total_macs)
+        return reward, top1, top5, vloss
 
-        if log_stats:
-            macs_normalized = total_macs/self.original_model_macs
-            msglogger.info("Total parameters left: %.2f%%" % (compression*100))
-            msglogger.info("Total compute left: %.2f%%" % (total_macs/self.original_model_macs*100))
-
-            stats = ('Performance/EpisodeEnd/',
-                     OrderedDict([('Loss', vloss),
-                                  ('Top1', top1),
-                                  ('Top5', top5),
-                                  ('reward', reward),
-                                  ('total_macs', int(total_macs)),
-                                  ('macs_normalized', macs_normalized*100),
-                                  ('log(total_macs)', math.log(total_macs)),
-                                  ('total_nnz', int(total_nnz))]))
-            distiller.log_training_progress(stats, None, self.episode, steps_completed=0, total_steps=1,
-                                            log_freq=1, loggers=[self.tflogger, self.pylogger])
-        return reward, top1
-
-    def finalize_episode(self, top1, reward, total_macs, normalized_macs,
-                         normalized_nnz, action_history, agent_action_history):
+    def finalize_episode(self, reward, val_results, total_macs, total_nnz,
+                         action_history, agent_action_history, log_stats=True):
         """Write the details of one network to the logger and create a checkpoint file"""
+        top1, top5, vloss = val_results
+        normalized_macs = total_macs / self.original_model_macs * 100
+        normalized_nnz = total_nnz / self.original_model_size * 100
+
         if reward > self.best_reward:
             self.best_reward = reward
             ckpt_name = self.save_checkpoint(is_best=True)
@@ -674,6 +675,20 @@
                   ckpt_name, json.dumps(action_history), json.dumps(agent_action_history),
                   json.dumps(performance)]
         self.stats_logger.add_record(fields)
+        msglogger.info("Top1: %.2f - compute: %.2f%% - params:%.2f%% - actions: %s",
+                       top1, normalized_macs, normalized_nnz, self.action_history)
+        if log_stats:
+            stats = ('Performance/EpisodeEnd/',
+                     OrderedDict([('Loss', vloss),
+                                  ('Top1', top1),
+                                  ('Top5', top5),
+                                  ('reward', reward),
+                                  ('total_macs', int(total_macs)),
+                                  ('macs_normalized', normalized_macs),
+                                  ('log(total_macs)', math.log(total_macs)),
+                                  ('total_nnz', int(total_nnz))]))
+            distiller.log_training_progress(stats, None, self.episode, steps_completed=0, total_steps=1,
+                                            log_freq=1, loggers=[self.tflogger, self.pylogger])
 
     def save_checkpoint(self, is_best=False):
         """Save the learned-model checkpoint"""
@@ -685,8 +700,9 @@
         if is_best or self.amc_cfg.save_chkpts:
             # Always save the best episodes, and depending on amc_cfg.save_chkpts save all other episodes
             scheduler = self.net_wrapper.create_scheduler()
+            extras = {"creation_masks": self.net_wrapper.sparsification_masks}
             self.services.save_checkpoint_fn(epoch=0, model=self.model,
-                                             scheduler=scheduler, name=fname)
+                                             scheduler=scheduler, name=fname, extras=extras)
             del scheduler
         return fname
@@ -694,6 +710,7 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
     def make_conv(model, conv_module, g, name, seq_id, layer_id):
         conv = SimpleNamespace()
+        conv.type = "Conv2D"
         conv.name = name
         conv.id = layer_id
         conv.t = seq_id
@@ -718,6 +735,7 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
 
     def make_fc(model, fc_module, g, name, seq_id, layer_id):
         fc = SimpleNamespace()
+        fc.type = "Linear"
         fc.name = name
         fc.id = layer_id
         fc.t = seq_id
@@ -732,6 +750,11 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
         fc.n_ifm = fc_op['attrs']['n_ifm']
         fc_pname = name + ".weight"
         fc_p = distiller.model_find_param(model, fc_pname)
+        fc.ofm_h = g.param_shape(fc_op['outputs'][0])[0]
+        fc.ofm_w = g.param_shape(fc_op['outputs'][0])[1]
+        fc.ifm_h = g.param_shape(fc_op['inputs'][0])[0]
+        fc.ifm_w = g.param_shape(fc_op['inputs'][0])[1]
+
         return fc
 
     dummy_input = distiller.get_dummy_input(dataset)
@@ -741,28 +764,34@@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
     dependent_layers = set()
     total_macs = 0
     total_params = 0
-    layers = OrderedDict({mod_name: m for mod_name, m in model.named_modules()
-                          if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))})
+
+    unfiltered_layers = layers_topological_order(model, dummy_input)
+    mods = dict(model.named_modules())
+    layers = OrderedDict({mod_name: mods[mod_name] for mod_name in unfiltered_layers
+                          if mod_name in mods and
+                          isinstance(mods[mod_name], (torch.nn.Conv2d, torch.nn.Linear))})
+
     for layer_id, (name, m) in enumerate(layers.items()):
-        if isinstance(m, torch.nn.Conv2d):
-            conv = make_conv(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
-            all_layers[layer_id] = conv
-            total_params += conv.weights_vol
-            total_macs += conv.macs
+        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
+            if isinstance(m, torch.nn.Conv2d):
+                new_layer = make_conv(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
+                all_layers[layer_id] = new_layer
+                total_params += new_layer.weights_vol
+                total_macs += new_layer.macs
+            elif isinstance(m, torch.nn.Linear):
+                new_layer = make_fc(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
+                all_layers[layer_id] = new_layer
+                total_params += new_layer.weights_vol
+                total_macs += new_layer.macs
 
         if layers_to_prune is None or name in layers_to_prune:
             pruned_indices.append(layer_id)
             # Find the data-dependent layers of this convolution
             from utils.data_dependencies import find_dependencies
-            conv.dependencies = list()
-            find_dependencies(dependency_type, g, all_layers, name, conv.dependencies)
-            dependent_layers.add(tuple(conv.dependencies))
-        elif isinstance(m, torch.nn.Linear):
-            fc = make_fc(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
-            all_layers[layer_id] = fc
-            total_macs += fc.macs
-            total_params += fc.weights_vol
-
+            new_layer.dependencies = list()
+            find_dependencies(dependency_type, g, all_layers, name, new_layer.dependencies)
+            dependent_layers.add(tuple(new_layer.dependencies))
+
     def convert_layer_names_to_indices(layer_names):
         """Args:
             layer_names - list of layer names
@@ -777,6 +806,67 @@ def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
     return all_layers, pruned_indices, dependent_indices, total_macs, total_params
 
 
+def layers_topological_order(model, dummy_input, recurrent=False):
+    """
+    Prepares an ordered list of layers to process sequentially.  This list has all the layers ordered by their
+    topological order in the graph.
+    Args:
+        model (nn.Module): the model to process.
+        dummy_input (torch.Tensor): an input to be passed through the model.
+        recurrent (bool): indication on whether the model might have recurrent connections.
+    """
+
+    class _OpRank:
+        def __init__(self, adj_entry, rank=None):
+            self.adj_entry = adj_entry
+            self._rank = rank or 0
+
+        @property
+        def rank(self):
+            return self._rank
+
+        @rank.setter
+        def rank(self, val):
+            self._rank = max(val, self._rank)
+
+        def __repr__(self):
+            return '_OpRank(\'%s\' | %d)' % (self.adj_entry.op_meta.name, self.rank)
+
+    adj_map = SummaryGraph(model, dummy_input).adjacency_map()
+    ranked_ops = {k: _OpRank(v, 0) for k, v in adj_map.items()}
+
+    def _recurrent_ancestor(ranked_ops_dict, dest_op_name, src_op_name):
+        def _is_descendant(parent_op_name, dest_op_name):
+            successors_names = [op.name for op in adj_map[parent_op_name].successors]
+            if dest_op_name in successors_names:
+                return True
+            for succ_name in successors_names:
+                if _is_descendant(succ_name, dest_op_name):
+                    return True
+            return False
+
+        return _is_descendant(dest_op_name, src_op_name) and \
+            (0 < ranked_ops_dict[dest_op_name].rank < ranked_ops_dict[src_op_name].rank)
+
+    def rank_op(ranked_ops_dict, op_name, rank):
+        ranked_ops_dict[op_name].rank = rank
+        for child_op in adj_map[op_name].successors:
+            # In recurrent models: if a successor is also an ancestor - we don't increment its rank.
+            if not recurrent or not _recurrent_ancestor(ranked_ops_dict, child_op.name, op_name):
+                rank_op(ranked_ops_dict, child_op.name, ranked_ops_dict[op_name].rank + 1)
+
+    roots = [k for k, v in adj_map.items() if len(v.predecessors) == 0]
+    for root_op_name in roots:
+        rank_op(ranked_ops, root_op_name, 0)
+
+    ret = sorted([k for k in ranked_ops.keys()],
+                 key=lambda k: ranked_ops[k].rank)
+
+    # Check that only the actual roots have a rank of 0
+    assert {k for k in ret if ranked_ops[k].rank == 0} <= set(roots)
+    return ret
+
+
 import pandas as pd
 def sample_networks(net_wrapper, services):
     """Sample networks from the posterior distribution.
diff --git a/examples/auto_compression/amc/rewards.py b/examples/auto_compression/amc/rewards.py
index ab16cb4..24c7487 100755
--- a/examples/auto_compression/amc/rewards.py
+++ b/examples/auto_compression/amc/rewards.py
@@ -58,25 +58,23 @@ def mac_constrained_experimental_reward_fn(env, top1, top5, vloss, total_macs):
 
 def mac_constrained_clamp_action(env, pruning_action):
     """Compute a resource-constrained action"""
-
-    # Todo: this is tightly coupled to the environment - refactor
-    flops = env.net_wrapper.layer_macs(env.current_layer())
-    assert flops > 0
-    reduced = env._removed_macs
-    prunable_rest, rest = env.rest_macs_raw()
-    rest += prunable_rest * env.action_high  # how much we have to remove in other layers
-    target_reduction = (1 - env.amc_cfg.target_density) * env.original_model_macs
+    layer_macs = env.net_wrapper.layer_macs(env.current_layer())
+    assert layer_macs > 0
+    reduced = env.removed_macs
+    prunable_rest, non_prunable_rest = env.rest_macs_raw()
+    rest = prunable_rest * min(0.9, env.action_high)
+    target_reduction = (1. - env.amc_cfg.target_density) * env.original_model_macs
 
     assert reduced == env.original_model_macs - env.net_wrapper.total_macs
     duty = target_reduction - (reduced + rest)
-    pruning_action_final = min(env.action_high, max(pruning_action, duty/flops))
+    pruning_action_final = min(1., max(pruning_action, duty/layer_macs))
 
-    msglogger.debug("\t\tflops=%.3f reduced=%.3f rest=%.3f duty=%.3f" % (flops, reduced, rest, duty))
+    msglogger.debug("\t\tlayer_macs=%.3f reduced=%.3f rest=%.3f duty=%.3f" % (layer_macs, reduced, rest, duty))
     msglogger.debug("\t\tpruning_action=%.3f pruning_action_final=%.3f" % (pruning_action, pruning_action_final))
-    msglogger.debug("\t\ttarget={:.2f} reduced={:.2f} rest={:.2f} duty={:.2f} flops={:.2f}".
-                    format( 1-env.amc_cfg.target_density, reduced/env.original_model_macs,
-                            rest/env.original_model_macs,
-                            duty/env.original_model_macs,
-                            flops/env.original_model_macs))
+    msglogger.debug("\t\ttarget={:.2f} reduced={:.2f} rest={:.2f} duty={:.2f} layer_macs={:.2f}\n".
+                    format(1-env.amc_cfg.target_density, reduced/env.original_model_macs,
+                           rest/env.original_model_macs,
+                           duty/env.original_model_macs,
+                           layer_macs/env.original_model_macs))
 
     if pruning_action_final != pruning_action:
         msglogger.debug("pruning_action={:.2f}==>pruning_action_final={:.2f}".format(pruning_action,
                                                                                      pruning_action_final))
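
The clamping above can be sanity-checked with concrete numbers, all of them invented for this example: when the remaining "duty" is already covered by what the other prunable layers can remove, the agent's action passes through unchanged.

    original_macs, target_density, reduced = 1000., 0.5, 400.
    layer_macs = 200.                      # MACs of the current layer
    prunable_rest = 250.
    rest = prunable_rest * min(0.9, 0.8)   # assume action_high = 0.8 -> 200.0

    target_reduction = (1. - target_density) * original_macs   # 500.0
    duty = target_reduction - (reduced + rest)                 # -100.0
    pruning_action = 0.1
    final = min(1., max(pruning_action, duty / layer_macs))    # max(0.1, -0.5) -> 0.1
    print(final)
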
diff --git a/examples/auto_compression/amc/rl_libs/coach/coach_if.py b/examples/auto_compression/amc/rl_libs/coach/coach_if.py
index 41e2804..bef8e36 100755
--- a/examples/auto_compression/amc/rl_libs/coach/coach_if.py
+++ b/examples/auto_compression/amc/rl_libs/coach/coach_if.py
@@ -47,10 +47,11 @@ class RlLibInterface(object):
             graph_manager.heatup_steps = EnvironmentEpisodes(amc_cfg.ddpg_cfg.num_heatup_episodes)
             # Replay buffer size
             graph_manager.agent_params.memory.max_size = (MemoryGranularity.Transitions, amc_cfg.ddpg_cfg.replay_buffer_size)
+            amc_cfg.ddpg_cfg.training_noise_decay = amc_cfg.ddpg_cfg.training_noise_decay ** (1. / steps_per_episode)
         elif "ClippedPPO" in amc_cfg.agent_algo:
-            from examples.automated_deep_compression.rl_libs.coach.presets.ADC_ClippedPPO import graph_manager, agent_params
+            from examples.auto_compression.amc.rl_libs.coach.presets.ADC_ClippedPPO import graph_manager, agent_params
         elif "TD3" in amc_cfg.agent_algo:
-            from examples.automated_deep_compression.rl_libs.coach.presets.ADC_TD3 import graph_manager, agent_params
+            from examples.auto_compression.amc.rl_libs.coach.presets.ADC_TD3 import graph_manager, agent_params
         else:
             raise ValueError("The agent algorithm you are trying to use (%s) is not supported" % amc_cfg.amc_agent_algo)
 
@@ -61,10 +62,10 @@
         graph_manager.steps_between_evaluation_periods = EnvironmentEpisodes(n_training_episodes)
 
         # These parameters are passed to the Distiller environment
-        env_cfg  = {'model': model,
-                    'app_args': app_args,
-                    'amc_cfg': amc_cfg,
-                    'services': services}
+        env_cfg = {'model': model,
+                   'app_args': app_args,
+                   'amc_cfg': amc_cfg,
+                   'services': services}
         graph_manager.env_params.additional_simulator_parameters = env_cfg
 
         coach_logs_dir = os.path.join(msglogger.logdir, 'coach')
diff --git a/examples/auto_compression/amc/rl_libs/private/private_if.py b/examples/auto_compression/amc/rl_libs/private/private_if.py
index 9d71601..2ca43ce 100755
--- a/examples/auto_compression/amc/rl_libs/private/private_if.py
+++ b/examples/auto_compression/amc/rl_libs/private/private_if.py
@@ -44,7 +44,6 @@ class RlLibInterface(object):
         agent_args.lr_a = env.amc_cfg.ddpg_cfg.actor_lr
         agent_args.hidden1 = 300
         agent_args.hidden2 = 300
-        agent_args.rmsize = 100
         agent_args.rmsize = env.amc_cfg.ddpg_cfg.replay_buffer_size
         agent_args.window_length = 1
         agent_args.train_episode = (env.amc_cfg.ddpg_cfg.num_heatup_episodes +
diff --git a/examples/auto_compression/amc/utils/data_dependencies.py b/examples/auto_compression/amc/utils/data_dependencies.py
index d82c665..1e44e29 100755
--- a/examples/auto_compression/amc/utils/data_dependencies.py
+++ b/examples/auto_compression/amc/utils/data_dependencies.py
@@ -38,7 +38,7 @@ def find_dependencies(dependency_type, sgraph, layers, layer_name, dependencies_
 
 
 def _find_dependencies_channels(sgraph, layers, layer_name, dependencies_list):
-    # Find all instances of Convolution layers that immediately preceed this layer
+    # Find all instances of Convolution layers that immediately precede this layer
     predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
     for predecessor in predecessors:
         dependencies_list.append(predecessor)
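
The new test below can be verified by hand: the column L1 magnitudes of the 3x3 matrix are 12, 15 and 18, int(0.5 * 3) = 1 structure is pruned, and the threshold is the weakest magnitude, 12, so exactly one third of the weights end up masked:

    import torch

    param = torch.tensor([[1., 2., 3.],
                          [4., 5., 6.],
                          [7., 8., 9.]])
    mags = param.abs().sum(dim=0)            # tensor([12., 15., 18.])
    num_to_prune = int(0.5 * mags.numel())   # 1
    print(torch.topk(mags, num_to_prune, largest=False).values)   # tensor([12.])
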
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 0239321..dd1b04f 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -18,6 +18,8 @@ import numpy as np
 import logging
 import math
 import torch
+from functools import partial
+
 import distiller
 import common
 import pytest
@@ -289,7 +291,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     zeros_mask_dict[pair[1] + ".weight"].mask = mask
     zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
     all_channels = set([ch for ch in range(num_channels)])
-    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight"))
+    nnz_channels = set(distiller.non_zero_channels(conv2_p).view(-1).tolist())
     channels_removed = all_channels - nnz_channels
     logger.info("Channels removed {}".format(channels_removed))
 
@@ -457,6 +459,19 @@ def test_magnitude_pruning():
     assert common.almost_equal(distiller.sparsity(b), 1/distiller.volume(a))
 
 
+def test_row_pruning():
+    param = torch.tensor([[1., 2., 3.],
+                          [4., 5., 6.],
+                          [7., 8., 9.]])
+    from distiller.pruning import L1RankedStructureParameterPruner
+
+    masker = distiller.scheduler.ParameterMasker("why name")
+    zeros_mask_dict = {"some name": masker}
+    L1RankedStructureParameterPruner.rank_and_prune_rows(0.5, param, "some name", zeros_mask_dict)
+    print(distiller.sparsity_rows(masker.mask))
+    assert math.isclose(distiller.sparsity_rows(masker.mask), 1/3)
+
+
 if __name__ == '__main__':
     for is_parallel in [True, False]:
         test_ranked_filter_pruning(is_parallel)
@@ -477,3 +492,4 @@ if __name__ == '__main__':
         arbitrary_channel_pruning(mobilenet_imagenet(is_parallel),
                                   channels_to_remove=[0, 2],
                                   is_parallel=is_parallel)
+    test_row_pruning()
\ No newline at end of file
-- 
GitLab