From f89ba9616812a2cd3bc53f142e8fd6deb88be614 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 2 Sep 2019 15:07:51 +0300
Subject: [PATCH] AMC: non-functional refactoring

Mainly: moved NetworkWrapper to a separate file.
---
 examples/auto_compression/amc/environment.py  | 454 +-----------------
 .../rl_libs/coach/presets/ADC_ClippedPPO.py   |   2 +-
 .../amc/rl_libs/coach/presets/ADC_PPO.py      |  53 --
 .../amc/rl_libs/coach/presets/ADC_TD3.py      |   2 +-
 .../auto_compression/amc/utils/net_wrapper.py | 426 ++++++++++++++++
 5 files changed, 429 insertions(+), 508 deletions(-)
 delete mode 100755 examples/auto_compression/amc/rl_libs/coach/presets/ADC_PPO.py
 create mode 100644 examples/auto_compression/amc/utils/net_wrapper.py

diff --git a/examples/auto_compression/amc/environment.py b/examples/auto_compression/amc/environment.py
index 79415ab..e6fbcef 100755
--- a/examples/auto_compression/amc/environment.py
+++ b/examples/auto_compression/amc/environment.py
@@ -32,8 +32,6 @@ import torch
 import gym
 import distiller
 from collections import OrderedDict, namedtuple
-from types import SimpleNamespace
-from distiller import normalize_module_name, SummaryGraph
 from utils.features_collector import collect_intermediate_featuremap_samples
 from utils.ac_loggers import AMCStatsLogger, FineTuneStatsLogger
 
@@ -77,237 +75,7 @@ def adjust_ppo_output(ppo_pruning_action, action_high, action_low):
     return float(pruning_action)
 
 
-
-class NetworkMetadata(object):
-    def __init__(self, model, dataset, dependency_type, modules_list):
-        details = get_network_details(model, dataset, dependency_type, modules_list)
-        self.all_layers, self.pruned_idxs, self.dependent_idxs, self._total_macs, self._total_nnz = details
-
-    @property
-    def total_macs(self):
-        return self._total_macs
-
-    @property
-    def total_nnz(self):
-        return self._total_nnz
-
-    def layer_net_macs(self, layer):
-        """Returns a MACs of a specific layer"""
-        return layer.macs
-
-    def layer_macs(self, layer):
-        """Returns a MACs of a specific layer, with the impact on pruning-dependent layers"""
-        macs = layer.macs
-        for dependent_mod in layer.dependencies:
-            macs += self.name2layer(dependent_mod).macs
-        return macs
-
-    def reduce_layer_macs(self, layer, reduction):
-        total_macs_reduced = layer.macs * reduction
-        total_nnz_reduced = layer.weights_vol * reduction
-        layer.macs -= total_macs_reduced
-        layer.weights_vol -= total_nnz_reduced
-        for dependent_mod in layer.dependencies:
-            macs_reduced = self.name2layer(dependent_mod).macs * reduction
-            nnz_reduced = self.name2layer(dependent_mod).weights_vol * reduction
-            total_macs_reduced += macs_reduced
-            total_nnz_reduced += nnz_reduced
-            self.name2layer(dependent_mod).macs -= macs_reduced
-            self.name2layer(dependent_mod).weights_vol -= nnz_reduced
-        self._total_macs -= total_macs_reduced
-        self._total_nnz -= total_nnz_reduced
-
-    def name2layer(self, name):
-        layers = [layer for layer in self.all_layers.values() if layer.name == name]
-        if len(layers) == 1:
-            return layers[0]
-        raise ValueError("illegal module name %s" % name)
-
-    def model_budget(self):
-        return self._total_macs, self._total_nnz
-
-    def get_layer(self, layer_id):
-        return self.all_layers[layer_id]
-
-    def get_pruned_layer(self, layer_id):
-        assert self.is_prunable(layer_id)
-        return self.get_layer(layer_id)
-
-    def is_prunable(self, layer_id):
-        return layer_id in self.pruned_idxs
-
-    def is_compressible(self, layer_id):
-        return layer_id in (self.pruned_idxs + self.dependent_idxs)
-
-    def num_pruned_layers(self):
-        return len(self.pruned_idxs)
-
-    def num_layers(self):
-        return len(self.all_layers)
-
-    def performance_summary(self):
-        # return OrderedDict({layer.name: (layer.macs, layer.weights_vol)
-        #                     for layer in self.all_layers.values()})
-        return OrderedDict({layer.name: layer.macs
-                            for layer in self.all_layers.values()})
-
-
-class NetworkWrapper(object):
-    def __init__(self, model, app_args, services, modules_list, pruning_pattern):
-        self.app_args = app_args
-        self.services = services
-        self.cached_model_metadata = NetworkMetadata(model, app_args.dataset,
-                                                     pruning_pattern, modules_list)
-        self.cached_perf_summary = self.cached_model_metadata.performance_summary()
-        self.reset(model)
-        self.sparsification_masks = None
-
-    def reset(self, model):
-        self.model = model
-        self.zeros_mask_dict = distiller.create_model_masks_dict(self.model)
-        self.model_metadata = copy.deepcopy(self.cached_model_metadata)
-
-    def get_resources_requirements(self):
-        total_macs, total_nnz = self.model_metadata.model_budget()
-        return total_macs, total_nnz
-
-    @property
-    def arch(self):
-        return self.app_args.arch
-
-    def num_pruned_layers(self):
-        return self.model_metadata.num_pruned_layers()
-
-    def get_pruned_layer(self, layer_id):
-        return self.model_metadata.get_pruned_layer(layer_id)
-
-    def get_layer(self, idx):
-        return self.model_metadata.get_layer(idx)
-
-    def layer_macs(self, layer):
-        return self.model_metadata.layer_macs(layer)
-
-    def layer_net_macs(self, layer):
-        return self.model_metadata.layer_net_macs(layer)
-
-    def name2layer(self, name):
-        return self.model_metadata.name2layer(name)
-
-    @property
-    def total_macs(self):
-        return self.model_metadata.total_macs
-
-    @property
-    def total_nnz(self):
-        return self.model_metadata.total_nnz
-
-    def performance_summary(self):
-        """Return a dictionary representing the performance the model.
-
-        We calculate the performance of each layer relative to the original (uncompressed) model.
-        """
-        current_perf = self.model_metadata.performance_summary()
-        ret = OrderedDict()
-        #return OrderedDict({k: v/v_baseline for ((k, v), (v_baseline)) in zip(current_perf.items(), self.cached_perf_summary.values())})
-        for k, v in current_perf.items():
-            ret[k] = v / self.cached_perf_summary[k]
-        return ret
-
-    def create_scheduler(self):
-        scheduler = distiller.CompressionScheduler(self.model, self.zeros_mask_dict)
-        return scheduler
-
-    def remove_structures(self, layer_id, fraction_to_prune, prune_what, prune_how,
-                          group_size, apply_thinning, ranking_noise):
-        """Physically remove channels and corresponding filters from the model
-
-        Returns the compute-sparsity of the layer with index 'layer_id'
-        """
-        if layer_id not in self.model_metadata.pruned_idxs:
-            raise ValueError("idx=%d is not in correct range " % layer_id)
-        if fraction_to_prune < 0:
-            raise ValueError("fraction_to_prune=%.3f is illegal" % fraction_to_prune)
-        if fraction_to_prune == 0:
-            return 0
-
-        if fraction_to_prune == 1.0:
-            # For now, prevent the removal of entire layers
-            fraction_to_prune = ALMOST_ONE
-
-        layer = self.model_metadata.get_pruned_layer(layer_id)
-        macs_before = self.layer_net_macs(layer)
-        conv_pname = layer.name + ".weight"
-        conv_p = distiller.model_find_param(self.model, conv_pname)
-
-        msglogger.debug("ADC: trying to remove %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname))
-
-        if prune_what == "channels":
-            calculate_sparsity = distiller.sparsity_ch
-            if layer.type == "Linear":
-                calculate_sparsity = distiller.sparsity_rows
-            remove_structures_fn = distiller.remove_channels
-            group_type = "Channels"
-        elif prune_what == "filters":
-            calculate_sparsity = distiller.sparsity_3D
-            group_type = "Filters"
-            remove_structures_fn = distiller.remove_filters
-        else:
-            raise ValueError("unsupported structure {}".format(prune_what))
-
-        if prune_how in ["l1-rank", "stochastic-l1-rank"]:
-            # Create a channel/filter-ranking pruner
-            pruner = distiller.pruning.L1RankedStructureParameterPruner(
-                "auto_pruner", group_type, fraction_to_prune, conv_pname,
-                noise=ranking_noise, group_size=group_size)
-            meta = None
-        elif prune_how == "fm-reconstruction":
-            pruner = distiller.pruning.FMReconstructionChannelPruner(
-                "auto_pruner", group_type, fraction_to_prune, conv_pname,
-                group_size, math.ceil, ranking_noise=ranking_noise)
-            meta = {'model': self.model}
-        else:
-            raise ValueError("Unknown pruning method")
-        pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=meta)
-        del pruner
-
-        if (self.zeros_mask_dict[conv_pname].mask is None or
-                0 == calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)):
-            msglogger.debug("remove_structures: aborting because there are no structures to prune")
-            return 0
-        final_action = calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)
-
-        # Use the mask to prune
-        self.zeros_mask_dict[conv_pname].apply_mask(conv_p)
-        if apply_thinning:
-            self.cache_spasification_masks()
-            remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None)
-
-        self.model_metadata.reduce_layer_macs(layer, final_action)
-        macs_after = self.layer_net_macs(layer)
-        assert 1. - (macs_after / macs_before) == final_action
-        return final_action
-
-    def validate(self):
-        top1, top5, vloss = self.services.validate_fn(model=self.model)
-        return top1, top5, vloss
-
-    def train(self, num_epochs, episode=0):
-        """Train for zero or more epochs"""
-        opt_cfg = self.app_args.optimizer_data
-        optimizer = torch.optim.SGD(self.model.parameters(), lr=opt_cfg['lr'],
-                                    momentum=opt_cfg['momentum'], weight_decay=opt_cfg['weight_decay'])
-        compression_scheduler = self.create_scheduler()
-        acc_list = []
-        for _ in range(num_epochs):
-            # Fine-tune the model
-            accuracies = self.services.train_fn(model=self.model, compression_scheduler=compression_scheduler,
-                                                optimizer=optimizer, epoch=episode)
-            acc_list.extend(accuracies)
-        del compression_scheduler
-        return acc_list
-
-    def cache_spasification_masks(self):
-        masks = {param_name: masker.mask for param_name, masker in self.zeros_mask_dict.items()}
-        self.sparsification_masks = copy.deepcopy(masks)
+from utils.net_wrapper import NetworkWrapper
 
 
 class DistillerWrapperEnvironment(gym.Env):
@@ -705,223 +473,3 @@ class DistillerWrapperEnvironment(gym.Env):
                                                    scheduler=scheduler, name=fname, extras=extras)
         del scheduler
         return fname
-
-
-def get_network_details(model, dataset, dependency_type, layers_to_prune=None):
-    def make_conv(model, conv_module, g, name, seq_id, layer_id):
-        conv = SimpleNamespace()
-        conv.type = "Conv2D"
-        conv.name = name
-        conv.id = layer_id
-        conv.t = seq_id
-        conv.k = conv_module.kernel_size[0]
-        conv.stride = conv_module.stride
-
-        # Use the SummaryGraph to obtain some other details of the models
-        conv_op = g.find_op(normalize_module_name(name))
-        assert conv_op is not None
-
-        conv.weights_vol = conv_op['attrs']['weights_vol']
-        conv.macs = conv_op['attrs']['MACs']
-        conv.n_ofm = conv_op['attrs']['n_ofm']
-        conv.n_ifm = conv_op['attrs']['n_ifm']
-        conv_pname = name + ".weight"
-        conv_p = distiller.model_find_param(model, conv_pname)
-        conv.ofm_h = g.param_shape(conv_op['outputs'][0])[2]
-        conv.ofm_w = g.param_shape(conv_op['outputs'][0])[3]
-        conv.ifm_h = g.param_shape(conv_op['inputs'][0])[2]
-        conv.ifm_w = g.param_shape(conv_op['inputs'][0])[3]
-        return conv
-
-    def make_fc(model, fc_module, g, name, seq_id, layer_id):
-        fc = SimpleNamespace()
-        fc.type = "Linear"
-        fc.name = name
-        fc.id = layer_id
-        fc.t = seq_id
-
-        # Use the SummaryGraph to obtain some other details of the models
-        fc_op = g.find_op(normalize_module_name(name))
-        assert fc_op is not None
-
-        fc.weights_vol = fc_op['attrs']['weights_vol']
-        fc.macs = fc_op['attrs']['MACs']
-        fc.n_ofm = fc_op['attrs']['n_ofm']
-        fc.n_ifm = fc_op['attrs']['n_ifm']
-        fc_pname = name + ".weight"
-        fc_p = distiller.model_find_param(model, fc_pname)
-        fc.ofm_h = g.param_shape(fc_op['outputs'][0])[0]
-        fc.ofm_w = g.param_shape(fc_op['outputs'][0])[1]
-        fc.ifm_h = g.param_shape(fc_op['inputs'][0])[0]
-        fc.ifm_w = g.param_shape(fc_op['inputs'][0])[1]
-
-        return fc
-
-    dummy_input = distiller.get_dummy_input(dataset)
-    g = SummaryGraph(model, dummy_input)
-    all_layers = OrderedDict()
-    pruned_indices = list()
-    dependent_layers = set()
-    total_macs = 0
-    total_params = 0
-
-    unfiltered_layers = layers_topological_order(model, dummy_input)
-    mods = dict(model.named_modules())
-    layers = OrderedDict({mod_name: mods[mod_name] for mod_name in unfiltered_layers
-                          if mod_name in mods and
-                          isinstance(mods[mod_name], (torch.nn.Conv2d, torch.nn.Linear))})
-
-    # layers = OrderedDict({mod_name: m for mod_name, m in model.named_modules()
-    #                       if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))})
-    for layer_id, (name, m) in enumerate(layers.items()):
-        if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
-            if isinstance(m, torch.nn.Conv2d):
-                new_layer = make_conv(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
-                all_layers[layer_id] = new_layer
-                total_params += new_layer.weights_vol
-                total_macs += new_layer.macs
-            elif isinstance(m, torch.nn.Linear):
-                new_layer = make_fc(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
-                all_layers[layer_id] = new_layer
-                total_params += new_layer.weights_vol
-                total_macs += new_layer.macs
-
-            if layers_to_prune is None or name in layers_to_prune:
-                pruned_indices.append(layer_id)
-                # Find the data-dependent layers of this convolution
-                from utils.data_dependencies import find_dependencies
-                new_layer.dependencies = list()
-                find_dependencies(dependency_type, g, all_layers, name, new_layer.dependencies)
-                dependent_layers.add(tuple(new_layer.dependencies))
-
-    def convert_layer_names_to_indices(layer_names):
-        """Args:
-            layer_names - list of layer names
-        Returns:
-            list of layer indices
-        """
-        layer_indices = [index for name in layer_names for index,
-                         layer in all_layers.items() if layer.name == name[0]]
-        return layer_indices
-
-    dependent_indices = convert_layer_names_to_indices(dependent_layers)
-    return all_layers, pruned_indices, dependent_indices, total_macs, total_params
-
-
-def layers_topological_order(model, dummy_input, recurrent=False):
-    """
-    Prepares an ordered list of layers to quantize sequentially. This list has all the layers ordered by their
-    topological order in the graph.
-    Args:
-        model (nn.Module): the model to quantize.
-        dummy_input (torch.Tensor): an input to be passed through the model.
-        recurrent (bool): indication on whether the model might have recurrent connections.
-    """
-
-    class _OpRank:
-        def __init__(self, adj_entry, rank=None):
-            self.adj_entry = adj_entry
-            self._rank = rank or 0
-
-        @property
-        def rank(self):
-            return self._rank
-
-        @rank.setter
-        def rank(self, val):
-            self._rank = max(val, self._rank)
-
-        def __repr__(self):
-            return '_OpRank(\'%s\' | %d)' % (self.adj_entry.op_meta.name, self.rank)
-
-    adj_map = SummaryGraph(model, dummy_input).adjacency_map()
-    ranked_ops = {k: _OpRank(v, 0) for k, v in adj_map.items()}
-
-    def _recurrent_ancestor(ranked_ops_dict, dest_op_name, src_op_name):
-        def _is_descendant(parent_op_name, dest_op_name):
-            successors_names = [op.name for op in adj_map[parent_op_name].successors]
-            if dest_op_name in successors_names:
-                return True
-            for succ_name in successors_names:
-                if _is_descendant(succ_name, dest_op_name):
-                    return True
-            return False
-
-        return _is_descendant(dest_op_name, src_op_name) and \
-            (0 < ranked_ops_dict[dest_op_name].rank < ranked_ops_dict[src_op_name].rank)
-
-    def rank_op(ranked_ops_dict, op_name, rank):
-        ranked_ops_dict[op_name].rank = rank
-        for child_op in adj_map[op_name].successors:
-            # In recurrent models: if a successor is also an ancestor - we don't increment its rank.
-            if not recurrent or not _recurrent_ancestor(ranked_ops_dict, child_op.name, op_name):
-                rank_op(ranked_ops_dict, child_op.name, ranked_ops_dict[op_name].rank + 1)
-
-    roots = [k for k, v in adj_map.items() if len(v.predecessors) == 0]
-    for root_op_name in roots:
-        rank_op(ranked_ops, root_op_name, 0)
-
-    # Take only the modules from the original model
-    # module_dict = dict(model.named_modules())
-    # Neta
-    ret = sorted([k for k in ranked_ops.keys()],
-                 key=lambda k: ranked_ops[k].rank)
-
-    # Check that only the actual roots have a rank of 0
-    assert {k for k in ret if ranked_ops[k].rank == 0} <= set(roots)
-    return ret
-
-
-import pandas as pd
-def sample_networks(net_wrapper, services):
-    """Sample networks from the posterior distribution.
-
-    1. Sort the networks we discovered using AMC by their reward.
-    2. Use the top 10% best-performing networks discovered by AMC to postulate a posterior distribution of the
-       density/sparsity of each layer:
-        p([layers-sparsity] | Top1, L1)
-    3. Sample 100 networks from this distribution.
-       For each such network: fine-tune, score using Top1, and save
-    """
-    #fname = "logs/resnet20___2019.01.29-102912/amc.csv"
-    fname = "logs/resnet20___2019.02.03-210001/amc.csv"
-    df = pd.read_csv(fname)
-
-    #top1_sorted_df = df.sort_values(by=['top1'], ascending=False)
-    top1_sorted_df = df.sort_values(by=['reward'], ascending=False)
-    top10pct = top1_sorted_df[:int(len(df.index) * 0.1)]
-
-    original_macs, _ = net_wrapper.get_resources_requirements()
-    layer_sparsities_list = []
-    for index, row in top10pct.iterrows():
-        layer_sparsities = row['action_history']
-        layer_sparsities = layer_sparsities[1:-1].split(",")  # convert from string to list
-        layer_sparsities = [float(sparsity) for sparsity in layer_sparsities]
-        layer_sparsities_list.append(layer_sparsities)
-
-    layer_sparsities = np.array(layer_sparsities_list)
-    mean = layer_sparsities.mean(axis=0)
-    cov = np.cov(layer_sparsities.T)
-    num_networks = 100
-    data = np.random.multivariate_normal(mean, cov, num_networks)
-
-    orig_model = net_wrapper.model
-    for i in range(num_networks):
-        model = copy.deepcopy(orig_model)
-        net_wrapper.reset(model)
-        for layer_id, sparsity_level in enumerate(data[i]):
-            sparsity_level = min(max(0, sparsity_level), ALMOST_ONE)
-            net_wrapper.remove_structures(layer_id,
-                                          fraction_to_prune=sparsity_level,
-                                          prune_what="channels")
-
-        net_wrapper.train(1)
-        top1, top5, vloss = net_wrapper.validate()
-
-        """Save the learned-model checkpoint"""
-        scheduler = net_wrapper.create_scheduler()
-        total_macs, _ = net_wrapper.get_resources_requirements()
-        fname = "{}_top1_{:2f}__density_{:2f}_sampled".format(net_wrapper.arch, top1, total_macs/original_macs)
-        services.save_checkpoint_fn(epoch=0, model=net_wrapper.model,
-                                    scheduler=scheduler, name=fname)
-        del scheduler
diff --git a/examples/auto_compression/amc/rl_libs/coach/presets/ADC_ClippedPPO.py b/examples/auto_compression/amc/rl_libs/coach/presets/ADC_ClippedPPO.py
index 9c222ba..9fa419a 100755
--- a/examples/auto_compression/amc/rl_libs/coach/presets/ADC_ClippedPPO.py
+++ b/examples/auto_compression/amc/rl_libs/coach/presets/ADC_ClippedPPO.py
@@ -62,7 +62,7 @@ agent_params.pre_network_filter.add_observation_filter('observation', 'normalize
 # Environment #
 ###############
 env_params = GymVectorEnvironment()
-env_params.level = '../automated_deep_compression/ADC.py:DistillerWrapperEnvironment'
+env_params.level = './environment.py:DistillerWrapperEnvironment'
 
 vis_params = VisualizationParameters()
 vis_params.dump_parameters_documentation = False
diff --git a/examples/auto_compression/amc/rl_libs/coach/presets/ADC_PPO.py b/examples/auto_compression/amc/rl_libs/coach/presets/ADC_PPO.py
deleted file mode 100755
index 13b204c..0000000
--- a/examples/auto_compression/amc/rl_libs/coach/presets/ADC_PPO.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from rl_coach.agents.ppo_agent import PPOAgentParameters
-from rl_coach.architectures.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
-from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
-from rl_coach.environments.environment import SingleLevelSelection
-from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
-from rl_coach.filters.filter import InputFilter
-from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
-from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
-from rl_coach.graph_managers.graph_manager import ScheduleParameters
-
-####################
-# Graph Scheduling #
-####################
-schedule_params = ScheduleParameters()
-schedule_params.improve_steps = TrainingSteps(10000000000)
-schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2000)
-schedule_params.evaluation_steps = EnvironmentEpisodes(1)
-schedule_params.heatup_steps = EnvironmentSteps(0)
-
-#########
-# Agent #
-#########
-agent_params = PPOAgentParameters()
-agent_params.network_wrappers['actor'].learning_rate = 0.001
-agent_params.network_wrappers['critic'].learning_rate = 0.001
-
-agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(64)]
-agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(64)]
-agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(64)]
-agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64)]
-
-agent_params.input_filter = InputFilter()
-agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
-
-# Distributed Coach synchronization type.
-agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
-
-###############
-# Environment #
-###############
-env_params = GymVectorEnvironment()
-env_params.level = '../automated_deep_compression/ADC.py:DistillerWrapperEnvironment'
-
-vis_params = VisualizationParameters()
-vis_params.dump_parameters_documentation = False
-vis_params.render = True
-vis_params.native_rendering = True
-vis_params.dump_signals_to_csv_every_x_episodes = 1
-graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
-                                    schedule_params=schedule_params, vis_params=vis_params)
-
-
diff --git a/examples/auto_compression/amc/rl_libs/coach/presets/ADC_TD3.py b/examples/auto_compression/amc/rl_libs/coach/presets/ADC_TD3.py
index 3f9bd4b..5091cd8 100755
--- a/examples/auto_compression/amc/rl_libs/coach/presets/ADC_TD3.py
+++ b/examples/auto_compression/amc/rl_libs/coach/presets/ADC_TD3.py
@@ -41,7 +41,7 @@ agent_params.algorithm.act_for_full_episodes = True
 # Environment #
 ###############
 env_params = GymVectorEnvironment()
-env_params.level = '../automated_deep_compression/ADC.py:DistillerWrapperEnvironment'
+env_params.level = './environment.py:DistillerWrapperEnvironment'
 
 vis_params = VisualizationParameters()
 
diff --git a/examples/auto_compression/amc/utils/net_wrapper.py b/examples/auto_compression/amc/utils/net_wrapper.py
new file mode 100644
index 0000000..9bdbe3f
--- /dev/null
+++ b/examples/auto_compression/amc/utils/net_wrapper.py
@@ -0,0 +1,426 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+import copy
+import logging
+import torch
+import distiller
+from collections import OrderedDict, namedtuple
+from types import SimpleNamespace
+from distiller import normalize_module_name, SummaryGraph
+
+
+__all__ = ["NetworkWrapper"]
+msglogger = logging.getLogger()
+
+# Assumed constant, mirroring the ALMOST_ONE value defined in environment.py:
+# remove_structures() below references ALMOST_ONE, so this module needs its own
+# definition (otherwise pruning a full layer would raise a NameError).
+ALMOST_ONE = 0.9999
+
+
+class NetworkWrapper(object):
+    def __init__(self, model, app_args, services, modules_list, pruning_pattern):
+        self.app_args = app_args
+        self.services = services
+        self.cached_model_metadata = NetworkMetadata(model, app_args.dataset,
+                                                     pruning_pattern, modules_list)
+        self.cached_perf_summary = self.cached_model_metadata.performance_summary()
+        self.reset(model)
+        self.sparsification_masks = None
+
+    def reset(self, model):
+        self.model = model
+        self.zeros_mask_dict = distiller.create_model_masks_dict(self.model)
+        self.model_metadata = copy.deepcopy(self.cached_model_metadata)
+
+    def get_resources_requirements(self):
+        total_macs, total_nnz = self.model_metadata.model_budget()
+        return total_macs, total_nnz
+
+    @property
+    def arch(self):
+        return self.app_args.arch
+
+    def num_pruned_layers(self):
+        return self.model_metadata.num_pruned_layers()
+
+    def get_pruned_layer(self, layer_id):
+        return self.model_metadata.get_pruned_layer(layer_id)
+
+    def get_layer(self, idx):
+        return self.model_metadata.get_layer(idx)
+
+    def layer_macs(self, layer):
+        return self.model_metadata.layer_macs(layer)
+
+    def layer_net_macs(self, layer):
+        return self.model_metadata.layer_net_macs(layer)
+
+    def name2layer(self, name):
+        return self.model_metadata.name2layer(name)
+
+    @property
+    def total_macs(self):
+        return self.model_metadata.total_macs
+
+    @property
+    def total_nnz(self):
+        return self.model_metadata.total_nnz
+
+    def performance_summary(self):
+        """Return a dictionary representing the performance of the model.
+
+        We calculate the performance of each layer relative to the original (uncompressed) model.
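+        For example, a layer whose MACs were reduced by pruning from 100 to 75
+        appears in the returned dictionary with the value 0.75.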
+ """ + current_perf = self.model_metadata.performance_summary() + ret = OrderedDict() + #return OrderedDict({k: v/v_baseline for ((k, v), (v_baseline)) in zip(current_perf.items(), self.cached_perf_summary.values())}) + for k, v in current_perf.items(): + ret[k] = v / self.cached_perf_summary[k] + return ret + + def create_scheduler(self): + scheduler = distiller.CompressionScheduler(self.model, self.zeros_mask_dict) + return scheduler + + def remove_structures(self, layer_id, fraction_to_prune, prune_what, prune_how, + group_size, apply_thinning, ranking_noise): + """Physically remove channels and corresponding filters from the model + + Returns the compute-sparsity of the layer with index 'layer_id' + """ + if layer_id not in self.model_metadata.pruned_idxs: + raise ValueError("idx=%d is not in correct range " % layer_id) + if fraction_to_prune < 0: + raise ValueError("fraction_to_prune=%.3f is illegal" % fraction_to_prune) + if fraction_to_prune == 0: + return 0 + + if fraction_to_prune == 1.0: + # For now, prevent the removal of entire layers + fraction_to_prune = ALMOST_ONE + + layer = self.model_metadata.get_pruned_layer(layer_id) + macs_before = self.layer_net_macs(layer) + conv_pname = layer.name + ".weight" + conv_p = distiller.model_find_param(self.model, conv_pname) + + msglogger.debug("ADC: trying to remove %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname)) + + if prune_what == "channels": + calculate_sparsity = distiller.sparsity_ch + if layer.type == "Linear": + calculate_sparsity = distiller.sparsity_rows + remove_structures_fn = distiller.remove_channels + group_type = "Channels" + elif prune_what == "filters": + calculate_sparsity = distiller.sparsity_3D + group_type = "Filters" + remove_structures_fn = distiller.remove_filters + else: + raise ValueError("unsupported structure {}".format(prune_what)) + + if prune_how in ["l1-rank", "stochastic-l1-rank"]: + # Create a channel/filter-ranking pruner + pruner = distiller.pruning.L1RankedStructureParameterPruner( + "auto_pruner", group_type, fraction_to_prune, conv_pname, + noise=ranking_noise, group_size=group_size) + meta = None + elif prune_how == "fm-reconstruction": + pruner = distiller.pruning.FMReconstructionChannelPruner( + "auto_pruner", group_type, fraction_to_prune, conv_pname, + group_size, math.ceil, ranking_noise=ranking_noise) + meta = {'model': self.model} + else: + raise ValueError("Unknown pruning method") + pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=meta) + del pruner + + if (self.zeros_mask_dict[conv_pname].mask is None or + 0 == calculate_sparsity(self.zeros_mask_dict[conv_pname].mask)): + msglogger.debug("remove_structures: aborting because there are no structures to prune") + return 0 + final_action = calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) + + # Use the mask to prune + self.zeros_mask_dict[conv_pname].apply_mask(conv_p) + if apply_thinning: + self.cache_spasification_masks() + remove_structures_fn(self.model, self.zeros_mask_dict, self.app_args.arch, self.app_args.dataset, optimizer=None) + + self.model_metadata.reduce_layer_macs(layer, final_action) + macs_after = self.layer_net_macs(layer) + assert 1. 
- (macs_after / macs_before) == final_action + return final_action + + def validate(self): + top1, top5, vloss = self.services.validate_fn(model=self.model) + return top1, top5, vloss + + def train(self, num_epochs, episode=0): + """Train for zero or more epochs""" + opt_cfg = self.app_args.optimizer_data + optimizer = torch.optim.SGD(self.model.parameters(), lr=opt_cfg['lr'], + momentum=opt_cfg['momentum'], weight_decay=opt_cfg['weight_decay']) + compression_scheduler = self.create_scheduler() + acc_list = [] + for _ in range(num_epochs): + # Fine-tune the model + accuracies = self.services.train_fn(model=self.model, compression_scheduler=compression_scheduler, + optimizer=optimizer, epoch=episode) + acc_list.extend(accuracies) + del compression_scheduler + return acc_list + + def cache_spasification_masks(self): + masks = {param_name: masker.mask for param_name, masker in self.zeros_mask_dict.items()} + self.sparsification_masks = copy.deepcopy(masks) + + +class NetworkMetadata(object): + def __init__(self, model, dataset, dependency_type, modules_list): + details = get_network_details(model, dataset, dependency_type, modules_list) + self.all_layers, self.pruned_idxs, self.dependent_idxs, self._total_macs, self._total_nnz = details + + @property + def total_macs(self): + return self._total_macs + + @property + def total_nnz(self): + return self._total_nnz + + def layer_net_macs(self, layer): + """Returns a MACs of a specific layer""" + return layer.macs + + def layer_macs(self, layer): + """Returns a MACs of a specific layer, with the impact on pruning-dependent layers""" + macs = layer.macs + for dependent_mod in layer.dependencies: + macs += self.name2layer(dependent_mod).macs + return macs + + def reduce_layer_macs(self, layer, reduction): + total_macs_reduced = layer.macs * reduction + total_nnz_reduced = layer.weights_vol * reduction + layer.macs -= total_macs_reduced + layer.weights_vol -= total_nnz_reduced + for dependent_mod in layer.dependencies: + macs_reduced = self.name2layer(dependent_mod).macs * reduction + nnz_reduced = self.name2layer(dependent_mod).weights_vol * reduction + total_macs_reduced += macs_reduced + total_nnz_reduced += nnz_reduced + self.name2layer(dependent_mod).macs -= macs_reduced + self.name2layer(dependent_mod).weights_vol -= nnz_reduced + self._total_macs -= total_macs_reduced + self._total_nnz -= total_nnz_reduced + + def name2layer(self, name): + layers = [layer for layer in self.all_layers.values() if layer.name == name] + if len(layers) == 1: + return layers[0] + raise ValueError("illegal module name %s" % name) + + def model_budget(self): + return self._total_macs, self._total_nnz + + def get_layer(self, layer_id): + return self.all_layers[layer_id] + + def get_pruned_layer(self, layer_id): + assert self.is_prunable(layer_id) + return self.get_layer(layer_id) + + def is_prunable(self, layer_id): + return layer_id in self.pruned_idxs + + def is_compressible(self, layer_id): + return layer_id in (self.pruned_idxs + self.dependent_idxs) + + def num_pruned_layers(self): + return len(self.pruned_idxs) + + def num_layers(self): + return len(self.all_layers) + + def performance_summary(self): + # return OrderedDict({layer.name: (layer.macs, layer.weights_vol) + # for layer in self.all_layers.values()}) + return OrderedDict({layer.name: layer.macs + for layer in self.all_layers.values()}) + + +def get_network_details(model, dataset, dependency_type, layers_to_prune=None): + def make_conv(model, conv_module, g, name, seq_id, layer_id): + conv = 
+        conv.type = "Conv2D"
+        conv.name = name
+        conv.id = layer_id
+        conv.t = seq_id
+        conv.k = conv_module.kernel_size[0]
+        conv.stride = conv_module.stride
+
+        # Use the SummaryGraph to obtain some other details of the model
+        conv_op = g.find_op(normalize_module_name(name))
+        assert conv_op is not None
+
+        conv.weights_vol = conv_op['attrs']['weights_vol']
+        conv.macs = conv_op['attrs']['MACs']
+        conv.n_ofm = conv_op['attrs']['n_ofm']
+        conv.n_ifm = conv_op['attrs']['n_ifm']
+        conv_pname = name + ".weight"
+        conv_p = distiller.model_find_param(model, conv_pname)
+        conv.ofm_h = g.param_shape(conv_op['outputs'][0])[2]
+        conv.ofm_w = g.param_shape(conv_op['outputs'][0])[3]
+        conv.ifm_h = g.param_shape(conv_op['inputs'][0])[2]
+        conv.ifm_w = g.param_shape(conv_op['inputs'][0])[3]
+        return conv
+
+    def make_fc(model, fc_module, g, name, seq_id, layer_id):
+        fc = SimpleNamespace()
+        fc.type = "Linear"
+        fc.name = name
+        fc.id = layer_id
+        fc.t = seq_id
+
+        # Use the SummaryGraph to obtain some other details of the model
+        fc_op = g.find_op(normalize_module_name(name))
+        assert fc_op is not None
+
+        fc.weights_vol = fc_op['attrs']['weights_vol']
+        fc.macs = fc_op['attrs']['MACs']
+        fc.n_ofm = fc_op['attrs']['n_ofm']
+        fc.n_ifm = fc_op['attrs']['n_ifm']
+        fc_pname = name + ".weight"
+        fc_p = distiller.model_find_param(model, fc_pname)
+        fc.ofm_h = g.param_shape(fc_op['outputs'][0])[0]
+        fc.ofm_w = g.param_shape(fc_op['outputs'][0])[1]
+        fc.ifm_h = g.param_shape(fc_op['inputs'][0])[0]
+        fc.ifm_w = g.param_shape(fc_op['inputs'][0])[1]
+
+        return fc
+
+    dummy_input = distiller.get_dummy_input(dataset)
+    g = SummaryGraph(model, dummy_input)
+    all_layers = OrderedDict()
+    pruned_indices = list()
+    dependent_layers = set()
+    total_macs = 0
+    total_params = 0
+
+    unfiltered_layers = layers_topological_order(model, dummy_input)
+    mods = dict(model.named_modules())
+    layers = OrderedDict({mod_name: mods[mod_name] for mod_name in unfiltered_layers
+                          if mod_name in mods and
+                          isinstance(mods[mod_name], (torch.nn.Conv2d, torch.nn.Linear))})
+
+    # layers = OrderedDict({mod_name: m for mod_name, m in model.named_modules()
+    #                       if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))})
+    for layer_id, (name, m) in enumerate(layers.items()):
+        if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
+            if isinstance(m, torch.nn.Conv2d):
+                new_layer = make_conv(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
+                all_layers[layer_id] = new_layer
+                total_params += new_layer.weights_vol
+                total_macs += new_layer.macs
+            elif isinstance(m, torch.nn.Linear):
+                new_layer = make_fc(model, m, g, name, seq_id=len(pruned_indices), layer_id=layer_id)
+                all_layers[layer_id] = new_layer
+                total_params += new_layer.weights_vol
+                total_macs += new_layer.macs
+
+            if layers_to_prune is None or name in layers_to_prune:
+                pruned_indices.append(layer_id)
+                # Find the data-dependent layers of this convolution
+                from utils.data_dependencies import find_dependencies
+                new_layer.dependencies = list()
+                find_dependencies(dependency_type, g, all_layers, name, new_layer.dependencies)
+                dependent_layers.add(tuple(new_layer.dependencies))
+
+    def convert_layer_names_to_indices(layer_names):
+        """Args:
+            layer_names - list of layer names
+        Returns:
+            list of layer indices
+        """
+        layer_indices = [index for name in layer_names for index,
+                         layer in all_layers.items() if layer.name == name[0]]
+        return layer_indices
+
+    dependent_indices = convert_layer_names_to_indices(dependent_layers)
+    return all_layers, pruned_indices, dependent_indices, total_macs, total_params
+
+
+def layers_topological_order(model, dummy_input, recurrent=False):
+    """
+    Prepares an ordered list of layers to quantize sequentially. This list has all the layers ordered by their
+    topological order in the graph.
+    Args:
+        model (nn.Module): the model to quantize.
+        dummy_input (torch.Tensor): an input to be passed through the model.
+        recurrent (bool): indication on whether the model might have recurrent connections.
+    """
+
+    class _OpRank:
+        def __init__(self, adj_entry, rank=None):
+            self.adj_entry = adj_entry
+            self._rank = rank or 0
+
+        @property
+        def rank(self):
+            return self._rank
+
+        @rank.setter
+        def rank(self, val):
+            self._rank = max(val, self._rank)
+
+        def __repr__(self):
+            return '_OpRank(\'%s\' | %d)' % (self.adj_entry.op_meta.name, self.rank)
+
+    adj_map = SummaryGraph(model, dummy_input).adjacency_map()
+    ranked_ops = {k: _OpRank(v, 0) for k, v in adj_map.items()}
+
+    def _recurrent_ancestor(ranked_ops_dict, dest_op_name, src_op_name):
+        def _is_descendant(parent_op_name, dest_op_name):
+            successors_names = [op.name for op in adj_map[parent_op_name].successors]
+            if dest_op_name in successors_names:
+                return True
+            for succ_name in successors_names:
+                if _is_descendant(succ_name, dest_op_name):
+                    return True
+            return False
+
+        return _is_descendant(dest_op_name, src_op_name) and \
+            (0 < ranked_ops_dict[dest_op_name].rank < ranked_ops_dict[src_op_name].rank)
+
+    def rank_op(ranked_ops_dict, op_name, rank):
+        ranked_ops_dict[op_name].rank = rank
+        for child_op in adj_map[op_name].successors:
+            # In recurrent models: if a successor is also an ancestor - we don't increment its rank.
+            if not recurrent or not _recurrent_ancestor(ranked_ops_dict, child_op.name, op_name):
+                rank_op(ranked_ops_dict, child_op.name, ranked_ops_dict[op_name].rank + 1)
+
+    roots = [k for k, v in adj_map.items() if len(v.predecessors) == 0]
+    for root_op_name in roots:
+        rank_op(ranked_ops, root_op_name, 0)
+
+    # Take only the modules from the original model
+    # module_dict = dict(model.named_modules())
+    # Neta
+    ret = sorted([k for k in ranked_ops.keys()],
+                 key=lambda k: ranked_ops[k].rank)
+
+    # Check that only the actual roots have a rank of 0
+    assert {k for k in ret if ranked_ops[k].rank == 0} <= set(roots)
+    return ret
-- 
GitLab
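
The sketch below illustrates how the relocated NetworkWrapper can be driven after this
patch, using only the APIs visible above. It is a minimal sketch, not code from the
commit: `model`, `app_args`, and `services` stand for objects the AMC application builds
elsewhere (`app_args` is assumed to carry `dataset`, `arch`, and `optimizer_data`;
`services` to provide `train_fn`, `validate_fn`, and `save_checkpoint_fn`), and the
"filters" value for `pruning_pattern` is a hypothetical dependency type.

    from utils.net_wrapper import NetworkWrapper

    # Wrap a model; the network metadata (per-layer MACs, weight volumes, and
    # pruning dependencies) is computed once and cached for fast reset().
    net_wrapper = NetworkWrapper(model, app_args, services,
                                 modules_list=None,          # None => all Conv2d/Linear layers are prunable
                                 pruning_pattern="filters")  # hypothetical dependency type
    total_macs, total_nnz = net_wrapper.get_resources_requirements()

    # Prune 50% of the filters of the first prunable layer using L1 ranking,
    # physically thin the model, then fine-tune for one epoch and validate.
    layer_id = net_wrapper.model_metadata.pruned_idxs[0]
    achieved = net_wrapper.remove_structures(layer_id, fraction_to_prune=0.5,
                                             prune_what="filters", prune_how="l1-rank",
                                             group_size=1, apply_thinning=True,
                                             ranking_noise=0.0)
    net_wrapper.train(num_epochs=1)
    top1, top5, vloss = net_wrapper.validate()

Because the cached metadata is deep-copied on every reset(), an RL agent can repeatedly
re-wrap a fresh copy of the model at each episode without re-running SummaryGraph.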