From e6625e4abc723837a6d7693522d304d05ad2c43d Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Tue, 13 Aug 2019 22:36:34 +0300 Subject: [PATCH] AMC: environment.py - factor initiaization code --- examples/auto_compression/amc/environment.py | 80 +++++++++++--------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/examples/auto_compression/amc/environment.py b/examples/auto_compression/amc/environment.py index 01dfd0b..d248d4b 100755 --- a/examples/auto_compression/amc/environment.py +++ b/examples/auto_compression/amc/environment.py @@ -324,59 +324,67 @@ class DistillerWrapperEnvironment(gym.Env): try: modules_list = amc_cfg.modules_dict[app_args.arch] except KeyError: - msglogger.warning("!!! The config file does not specify the modules to compress for %s" % app_args.arch) - # Default to using all convolution layers - distiller.assign_layer_fq_names(model) - modules_list = [mod.distiller_name for mod in model.modules() if type(mod)==torch.nn.Conv2d] - msglogger.warning("Using the following layers: %s" % ", ".join(modules_list)) raise ValueError("The config file does not specify the modules to compress for %s" % app_args.arch) - self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern) self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements() self.reset(init_only=True) - msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch, - self.net_wrapper.model_metadata.num_layers(), - self.net_wrapper.model_metadata.num_pruned_layers()) - msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs)) - msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size)) self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers() # Hack for Coach-TD3 - log_amc_config(amc_cfg) - self.episode = 0 self.best_reward = float("-inf") self.action_low = amc_cfg.action_range[0] self.action_high = amc_cfg.action_range[1] + self._log_model_info() + log_amc_config(amc_cfg) + self._configure_action_space() + self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),)) + self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv')) + self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv')) + + if self.amc_cfg.pruning_method == "fm-reconstruction": + self._collect_fm_reconstruction_samples(modules_list) + def _collect_fm_reconstruction_samples(self, modules_list): + """Run the forward-pass on the selected dataset and collect feature-map samples. + + These data will be used when we optimize the compressed-net's weights by trying + to reconstruct these samples. + """ + from functools import partial + if self.amc_cfg.pruning_pattern != "channels": + raise ValueError("Feature-map reconstruction is only supported when pruning weights channels") + + def acceptance_criterion(m, mod_names): + # Collect feature-maps only for Conv2d layers, if they are in our modules list. + return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names + + # For feature-map reconstruction we need to collect a representative set + # of inter-layer feature-maps + from distiller.pruning import FMReconstructionChannelPruner + collect_intermediate_featuremap_samples( + self.net_wrapper.model, + self.net_wrapper.validate, + partial(acceptance_criterion, mod_names=modules_list), + partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook, + n_points_per_fm=self.amc_cfg.n_points_per_fm)) + + def _log_model_info(self): + msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch, + self.net_wrapper.model_metadata.num_layers(), + self.net_wrapper.model_metadata.num_pruned_layers()) + msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs)) + msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size)) + + def _configure_action_space(self): if is_using_continuous_action_space(self.amc_cfg.agent_algo): if self.amc_cfg.agent_algo == "ClippedPPO-continuous": self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,)) else: - self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,)) + self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,)) self.action_space.default_action = self.action_low else: self.action_space = spaces.Discrete(10) - self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),)) - self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv')) - self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv')) - if self.amc_cfg.pruning_method == "fm-reconstruction": - if self.amc_cfg.pruning_pattern != "channels": - raise ValueError("Feature-map reconstruction is only supported when pruning weights channels") - - from functools import partial - def acceptance_criterion(m, mod_names): - # Collect feature-maps only for Conv2d layers, if they are in our modules list. - return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names - - # For feature-map reconstruction we need to collect a representative set - # of inter-layer feature-maps - from distiller.pruning import FMReconstructionChannelPruner - collect_intermediate_featuremap_samples( - self.net_wrapper.model, - self.net_wrapper.validate, - partial(acceptance_criterion, mod_names=modules_list), - partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook, - n_points_per_fm=self.amc_cfg.n_points_per_fm)) + @property def steps_per_episode(self): return self.net_wrapper.model_metadata.num_pruned_layers() @@ -431,7 +439,7 @@ class DistillerWrapperEnvironment(gym.Env): def step(self, pruning_action): """Take a step, given an action. - The action represents the desired sparsity for the "current" layer. + The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove). This function is invoked by the Agent. """ pruning_action = float(pruning_action[0]) -- GitLab