From e6625e4abc723837a6d7693522d304d05ad2c43d Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 13 Aug 2019 22:36:34 +0300
Subject: [PATCH] AMC: environment.py - factor out initialization code

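Extract feature-map sample collection, model-info logging, and
action-space configuration from DistillerWrapperEnvironment.__init__
into dedicated helper methods.  Also remove the dead fallback in the
KeyError handler, which computed a default modules list only to raise
ValueError immediately afterwards.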
---
 examples/auto_compression/amc/environment.py | 80 +++++++++++---------
 1 file changed, 44 insertions(+), 36 deletions(-)

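Note for reviewers unfamiliar with how the feature-map samples are gathered:
the sketch below illustrates the forward-hook pattern that
collect_intermediate_featuremap_samples() builds on. It is a minimal
illustration in plain PyTorch, not Distiller's implementation; the
collect_featuremaps() helper, the cache layout, and the example layer names
are assumptions made for the sketch.

    # Minimal sketch (assumed names) of forward-hook feature-map collection.
    import torch
    import torch.nn as nn

    def collect_featuremaps(model, data_loader, acceptance_criterion, cache):
        """Cache the inputs of every accepted module during forward passes."""
        handles = []
        for name, module in model.named_modules():
            if acceptance_criterion(module, name):
                # Use a factory so each hook closes over its own layer name
                def make_hook(layer_name):
                    def hook(mod, inputs, output):
                        # Store a detached copy of the layer's input feature-map
                        cache.setdefault(layer_name, []).append(inputs[0].detach().cpu())
                    return hook
                handles.append(module.register_forward_hook(make_hook(name)))
        model.eval()
        with torch.no_grad():
            for images, _ in data_loader:
                model(images)
        for h in handles:
            h.remove()  # remove hooks so later passes are unaffected

    # Acceptance criterion mirroring the patch: Conv2d layers from the config list
    modules_list = ["features.3", "features.6"]  # hypothetical layer names
    criterion = lambda m, name: isinstance(m, nn.Conv2d) and name in modules_list
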
diff --git a/examples/auto_compression/amc/environment.py b/examples/auto_compression/amc/environment.py
index 01dfd0b..d248d4b 100755
--- a/examples/auto_compression/amc/environment.py
+++ b/examples/auto_compression/amc/environment.py
@@ -324,59 +324,67 @@ class DistillerWrapperEnvironment(gym.Env):
         try:
             modules_list = amc_cfg.modules_dict[app_args.arch]
         except KeyError:
-            msglogger.warning("!!! The config file does not specify the modules to compress for %s" % app_args.arch)
-            # Default to using all convolution layers
-            distiller.assign_layer_fq_names(model)
-            modules_list = [mod.distiller_name for mod in model.modules() if type(mod)==torch.nn.Conv2d]
-            msglogger.warning("Using the following layers: %s" % ", ".join(modules_list))
             raise ValueError("The config file does not specify the modules to compress for %s" % app_args.arch)
-
         self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern)
         self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()
         self.reset(init_only=True)
-        msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch, 
-                                                               self.net_wrapper.model_metadata.num_layers(),
-                                                               self.net_wrapper.model_metadata.num_pruned_layers())
-        msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
-        msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))
         self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
-        log_amc_config(amc_cfg)
-
         self.episode = 0
         self.best_reward = float("-inf")
         self.action_low = amc_cfg.action_range[0]
         self.action_high = amc_cfg.action_range[1]
+        self._log_model_info()
+        log_amc_config(amc_cfg)
+        self._configure_action_space()
+        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
+        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
+        self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))
+
+        if self.amc_cfg.pruning_method == "fm-reconstruction":
+            self._collect_fm_reconstruction_samples(modules_list)
 
+    def _collect_fm_reconstruction_samples(self, modules_list):
+        """Run the forward-pass on the selected dataset and collect feature-map samples.
+
+        These data will be used when we optimize the compressed-net's weights by trying
+        to reconstruct these samples.
+        """
+        from functools import partial
+        if self.amc_cfg.pruning_pattern != "channels":
+            raise ValueError("Feature-map reconstruction is only supported when pruning weight channels")
+
+        def acceptance_criterion(m, mod_names):
+            # Collect feature-maps only for Conv2d layers that appear in our modules list.
+            return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names
+
+        # For feature-map reconstruction we need to collect a representative set
+        # of inter-layer feature-maps
+        from distiller.pruning import FMReconstructionChannelPruner
+        collect_intermediate_featuremap_samples(
+            self.net_wrapper.model,
+            self.net_wrapper.validate,
+            partial(acceptance_criterion, mod_names=modules_list),
+            partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
+                    n_points_per_fm=self.amc_cfg.n_points_per_fm))
+
+    def _log_model_info(self):
+        msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch,
+                                                               self.net_wrapper.model_metadata.num_layers(),
+                                                               self.net_wrapper.model_metadata.num_pruned_layers())
+        msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
+        msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))
+
+    def _configure_action_space(self):
         if is_using_continuous_action_space(self.amc_cfg.agent_algo):
             if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                 self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
             else:
-                self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,)) 
+                self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,))
             self.action_space.default_action = self.action_low
         else:
             self.action_space = spaces.Discrete(10)
-        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
-        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
-        self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))
 
-        if self.amc_cfg.pruning_method == "fm-reconstruction":
-            if self.amc_cfg.pruning_pattern != "channels":
-                raise ValueError("Feature-map reconstruction is only supported when pruning weights channels")
-            
-            from functools import partial
-            def acceptance_criterion(m, mod_names):
-                # Collect feature-maps only for Conv2d layers, if they are in our modules list.
-                return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names
-
-            # For feature-map reconstruction we need to collect a representative set 
-            # of inter-layer feature-maps
-            from distiller.pruning import FMReconstructionChannelPruner
-            collect_intermediate_featuremap_samples(
-                self.net_wrapper.model,
-                self.net_wrapper.validate, 
-                partial(acceptance_criterion, mod_names=modules_list),
-                partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook, 
-                        n_points_per_fm=self.amc_cfg.n_points_per_fm))
+
     @property
     def steps_per_episode(self):
         return self.net_wrapper.model_metadata.num_pruned_layers() 
@@ -431,7 +439,7 @@ class DistillerWrapperEnvironment(gym.Env):
     def step(self, pruning_action):
         """Take a step, given an action.
 
-        The action represents the desired sparsity for the "current" layer.
+        The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove).
         This function is invoked by the Agent.
         """
         pruning_action = float(pruning_action[0])
-- 
GitLab
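
A note on the action semantics referenced in step(): the sketch below shows
how an agent action could be translated into a sparsity level under the two
action spaces configured by _configure_action_space(). The clamping into
[action_low, action_high], the example action_range values, and the discrete
mapping i -> i/10 are assumptions for illustration; the actual translation
lives elsewhere in environment.py.

    # Sketch (assumptions noted above) of translating agent actions to sparsity.
    import numpy as np
    from gym import spaces

    action_low, action_high = 0.05, 0.95   # hypothetical amc_cfg.action_range

    continuous_space = spaces.Box(action_low, action_high, shape=(1,), dtype=np.float32)
    discrete_space = spaces.Discrete(10)

    def to_sparsity(action, space):
        """Map a raw agent action to the fraction of weights to remove."""
        if isinstance(space, spaces.Box):
            # Continuous agents emit a 1-element vector, as in step()
            return float(np.clip(action[0], action_low, action_high))
        return action / space.n  # discrete bucket i -> sparsity i/10

    print(to_sparsity(np.array([0.7]), continuous_space))  # 0.7
    print(to_sparsity(7, discrete_space))                  # 0.7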