Commit 8c07eb1e authored by Neta Zmora

AMC: added configuration option to set the frequency of computing a reward

--amc-reward-frequency
Computing the reward requires running the evaluated network on the test
dataset (or part of it) and may involve short-term fine-tuning before the
evaluation (depending on the configuration), so it can be expensive.
Use this new argument to set the number of agent steps/iterations between
reward computations.
parent 97d5e48c
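
To make the new knob concrete, here is a minimal sketch (not the Distiller
implementation) of the per-step gating that --amc-reward-frequency enables.
The names reward_frequency, current_layer_id and compute_reward mirror the
diff below; step_reward itself is a hypothetical helper added only for
illustration.

def step_reward(current_layer_id, reward_frequency, compute_reward):
    """Return an intermediate reward every `reward_frequency` agent steps.

    With the default of -1 (or any value <= 0), intermediate rewards are
    skipped entirely, so the costly evaluation (and optional short-term
    fine-tuning) runs only at the end of an episode.
    """
    if reward_frequency > 0 and current_layer_id % reward_frequency == 0:
        # compute_reward() evaluates the (thinned) network on the test set
        return compute_reward()
    return 0  # skip the expensive evaluation on this step

The default of -1 therefore keeps the previous behavior of computing the
reward only at the end of an episode, while e.g. --amc-reward-frequency 4
would also evaluate after every fourth pruned layer.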
@@ -214,11 +214,11 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
     amc_cfg = distiller.utils.MutableNamedTuple({
         'protocol': args.amc_protocol,
         'agent_algo': args.amc_agent_algo,
-        'compute_reward_every_step': args.amc_reward_every_step,
         'perform_thinning': perform_thinning,
         'num_ft_epochs': num_ft_epochs,
         'action_range': action_range,
-        'conv_cnt': conv_cnt})
+        'conv_cnt': conv_cnt,
+        'reward_frequency': args.amc_reward_frequency})
     #net_wrapper = NetworkWrapper(model, app_args, services)
     #return sample_networks(net_wrapper, services)
@@ -239,13 +239,14 @@ def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn
         raise ValueError("{} is not supported currently".format(args.amc_protocol))
     steps_per_episode = conv_cnt
-    amc_cfg.heatup_noise = 0.5
-    amc_cfg.initial_training_noise = 0.5
-    amc_cfg.training_noise_decay = 0.996  # 0.998
-    amc_cfg.num_heatup_epochs = args.amc_heatup_epochs
-    amc_cfg.num_training_epochs = args.amc_training_epochs
-    training_noise_duration = amc_cfg.num_training_epochs * steps_per_episode
-    heatup_duration = amc_cfg.num_heatup_epochs * steps_per_episode
+    if args.amc_agent_algo == "DDPG":
+        amc_cfg.heatup_noise = 0.5
+        amc_cfg.initial_training_noise = 0.5
+        amc_cfg.training_noise_decay = 0.996  # 0.998
+        amc_cfg.num_heatup_epochs = args.amc_heatup_epochs
+        amc_cfg.num_training_epochs = args.amc_training_epochs
+        training_noise_duration = amc_cfg.num_training_epochs * steps_per_episode
+        heatup_duration = amc_cfg.num_heatup_epochs * steps_per_episode
     if amc_cfg.agent_algo == "Random-policy":
         return random_agent(DistillerWrapperEnvironment(model, app_args, amc_cfg, services))
@@ -603,10 +604,10 @@ class DistillerWrapperEnvironment(gym.Env):
             self.episode += 1
         else:
            observation = self.get_obs()
-            reward = 0
-            if self.amc_cfg.compute_reward_every_step:
+            if self.amc_cfg.reward_frequency > 0 and self.current_layer_id % self.amc_cfg.reward_frequency == 0:
                 reward, top1, total_macs, total_nnz = self.compute_reward(False)
+            else:
+                reward = 0
         self.prev_action = pruning_action
         info = {}
         return observation, reward, self.episode_is_done(), info
@@ -22,8 +22,8 @@ def add_automl_args(argparser, arch_choices=None, enable_pretrained=False):
                        help='The number of epochs for heatup/exploration')
     group.add_argument('--amc-training-epochs', type=int, default=300,
                        help='The number of epochs for training/exploitation')
-    group.add_argument('--amc-reward-every-step', action='store_true', default=False,
-                       help='Compute the reward at every step')
+    group.add_argument('--amc-reward-frequency', type=int, default=-1,
+                       help='Reward computation frequency (measured in agent steps)')
     group.add_argument('--amc-target-density', type=float,
                        help='Target density of the network we are seeking')
     group.add_argument('--amc-agent-algo', choices=["ClippedPPO-continuous",