Commit e6625e4a authored by Neta Zmora

AMC: environment.py - factor initialization code

parent 7d712fab
@@ -324,59 +324,67 @@ class DistillerWrapperEnvironment(gym.Env):
         try:
             modules_list = amc_cfg.modules_dict[app_args.arch]
         except KeyError:
-            msglogger.warning("!!! The config file does not specify the modules to compress for %s" % app_args.arch)
-            # Default to using all convolution layers
-            distiller.assign_layer_fq_names(model)
-            modules_list = [mod.distiller_name for mod in model.modules() if type(mod)==torch.nn.Conv2d]
-            msglogger.warning("Using the following layers: %s" % ", ".join(modules_list))
             raise ValueError("The config file does not specify the modules to compress for %s" % app_args.arch)
         self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern)
         self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()
         self.reset(init_only=True)
-        msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch,
-                        self.net_wrapper.model_metadata.num_layers(),
-                        self.net_wrapper.model_metadata.num_pruned_layers())
-        msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
-        msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))
         self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
-        log_amc_config(amc_cfg)
         self.episode = 0
         self.best_reward = float("-inf")
         self.action_low = amc_cfg.action_range[0]
         self.action_high = amc_cfg.action_range[1]
+        self._log_model_info()
+        log_amc_config(amc_cfg)
+        self._configure_action_space()
+        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
+        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
+        self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))
+        if self.amc_cfg.pruning_method == "fm-reconstruction":
+            self._collect_fm_reconstruction_samples(modules_list)
+
+    def _collect_fm_reconstruction_samples(self, modules_list):
+        """Run the forward-pass on the selected dataset and collect feature-map samples.
+        These data will be used when we optimize the compressed-net's weights by trying
+        to reconstruct these samples.
+        """
+        from functools import partial
+        if self.amc_cfg.pruning_pattern != "channels":
+            raise ValueError("Feature-map reconstruction is only supported when pruning weights channels")
+
+        def acceptance_criterion(m, mod_names):
+            # Collect feature-maps only for Conv2d layers, if they are in our modules list.
+            return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names
+
+        # For feature-map reconstruction we need to collect a representative set
+        # of inter-layer feature-maps
+        from distiller.pruning import FMReconstructionChannelPruner
+        collect_intermediate_featuremap_samples(
+            self.net_wrapper.model,
+            self.net_wrapper.validate,
+            partial(acceptance_criterion, mod_names=modules_list),
+            partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
+                    n_points_per_fm=self.amc_cfg.n_points_per_fm))
+
+    def _log_model_info(self):
+        msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch,
+                        self.net_wrapper.model_metadata.num_layers(),
+                        self.net_wrapper.model_metadata.num_pruned_layers())
+        msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
+        msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))
+
+    def _configure_action_space(self):
         if is_using_continuous_action_space(self.amc_cfg.agent_algo):
             if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                 self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
             else:
                 self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,))
             self.action_space.default_action = self.action_low
         else:
             self.action_space = spaces.Discrete(10)
-        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
-        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
-        self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))
-        if self.amc_cfg.pruning_method == "fm-reconstruction":
-            if self.amc_cfg.pruning_pattern != "channels":
-                raise ValueError("Feature-map reconstruction is only supported when pruning weights channels")
-            from functools import partial
-
-            def acceptance_criterion(m, mod_names):
-                # Collect feature-maps only for Conv2d layers, if they are in our modules list.
-                return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names
-
-            # For feature-map reconstruction we need to collect a representative set
-            # of inter-layer feature-maps
-            from distiller.pruning import FMReconstructionChannelPruner
-            collect_intermediate_featuremap_samples(
-                self.net_wrapper.model,
-                self.net_wrapper.validate,
-                partial(acceptance_criterion, mod_names=modules_list),
-                partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
-                        n_points_per_fm=self.amc_cfg.n_points_per_fm))

     @property
     def steps_per_episode(self):
         return self.net_wrapper.model_metadata.num_pruned_layers()
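
Note on the factored-out _collect_fm_reconstruction_samples: it uses functools.partial twice, once to bind the modules list into the layer-acceptance predicate and once to bind n_points_per_fm into the caching forward-hook, so that both can be handed to the collector as single-purpose callbacks. Below is a minimal, self-contained sketch of that hook-collection pattern; the helper names (collect_samples, cache_fm_hook, run_forward_fn) are illustrative stand-ins, not Distiller's actual API:

# Sketch of the cached-forward-hook pattern used above (illustrative only).
from functools import partial
import torch

def acceptance_criterion(m, mod_names):
    # Accept only the Conv2d layers we plan to prune.
    return isinstance(m, torch.nn.Conv2d) and getattr(m, "name", None) in mod_names

def cache_fm_hook(module, inputs, output, cache):
    # Store a detached sample of this layer's input/output feature-maps.
    cache.setdefault(module, []).append((inputs[0].detach(), output.detach()))

def collect_samples(model, run_forward_fn, criterion_fn, hook_fn):
    cache = {}
    # Attach a caching hook to every module that passes the criterion.
    handles = [m.register_forward_hook(partial(hook_fn, cache=cache))
               for m in model.modules() if criterion_fn(m)]
    with torch.no_grad():
        run_forward_fn()   # e.g. one pass over a validation subset
    for h in handles:
        h.remove()         # always detach hooks when done
    return cache

Detaching the cached tensors and removing the hook handles keeps the collection pass from leaking autograd state or stale callbacks into the pruning steps that follow.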
@@ -431,7 +439,7 @@ class DistillerWrapperEnvironment(gym.Env):
     def step(self, pruning_action):
         """Take a step, given an action.
-        The action represents the desired sparsity for the "current" layer.
+        The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove).
         This function is invoked by the Agent.
         """
         pruning_action = float(pruning_action[0])
...
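
For context, step() follows the classic gym contract: the agent passes a one-element action whose scalar value is the target sparsity for the layer currently under consideration, and the episode ends once every prunable layer has been visited. A sketch of a driver loop, assuming the continuous action space configured above and the pre-0.26 gym 4-tuple step API (random_rollout is illustrative, not part of Distiller):

# Illustrative random-agent rollout against a gym.Env such as
# DistillerWrapperEnvironment (assumes `env` is already constructed).
import numpy as np

def random_rollout(env):
    observation = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        # One scalar action: the desired sparsity for the "current" layer,
        # i.e. the fraction of weights to remove, sampled from the action box.
        action = np.random.uniform(env.action_space.low, env.action_space.high)
        observation, reward, done, info = env.step(action)
        total_reward += reward
    return total_reward

Because env.step() reads pruning_action[0], the action must be indexable (a one-element array, as sampled here), and its value is clipped inside the environment to [action_low, action_high].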