Commit 738d57f4 authored by Neta Zmora

AMC: fix the replay_buffer_size when using Coach and DDPG

parent 9f7f6b14
@@ -131,7 +131,6 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
 def create_environment():
     env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services)
-    #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode)
     env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode
     return env
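
In the hunk above, the DDPG replay buffer is sized to hold exactly the transitions gathered during the heat-up phase (heat-up episodes times steps per episode); the commented-out variant oversized it by a factor of 1.5. Below is a minimal sketch of that sizing arithmetic; DDPGConfig, size_replay_buffer and the example numbers are illustrative stand-ins, not Distiller's actual definitions.

# Sketch only: stand-in config class mirroring the attribute names in the diff.
from dataclasses import dataclass

@dataclass
class DDPGConfig:
    num_heatup_episodes: int       # random-exploration episodes run before training
    replay_buffer_size: int = 0    # filled in from the heat-up volume below

def size_replay_buffer(ddpg_cfg, steps_per_episode):
    # One transition is stored per pruning step, so the buffer holds exactly
    # the transitions collected during heat-up.
    ddpg_cfg.replay_buffer_size = ddpg_cfg.num_heatup_episodes * steps_per_episode
    return ddpg_cfg.replay_buffer_size

print(size_replay_buffer(DDPGConfig(num_heatup_episodes=100), steps_per_episode=13))  # 1300
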
@@ -151,10 +150,10 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
 elif args.amc_rllib == "coach":
     from rl_libs.coach import coach_if
     rl = coach_if.RlLibInterface()
-    env_cfg = {'model': model,
-               'app_args': app_args,
-               'amc_cfg': amc_cfg,
-               'services': services}
+    env_cfg = {'model': env1.model,
+               'app_args': env1.app_args,
+               'amc_cfg': env1.amc_cfg,
+               'services': env1.services}
     steps_per_episode = env1.steps_per_episode
     rl.solve(**env_cfg, steps_per_episode=steps_per_episode)
 elif args.amc_rllib == "random":
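
The second hunk addresses the Coach path: env_cfg is now built from the already-constructed env1, so rl.solve() receives the same amc_cfg (with the ddpg_cfg.replay_buffer_size set inside create_environment()) that the environment itself uses. A rough, self-contained sketch of that call pattern follows; the SimpleNamespace stand-ins and the minimal RlLibInterface are illustrative, and only the keyword names mirror the diff.

# Sketch only: placeholder environment and RL-library interface.
from types import SimpleNamespace

env1 = SimpleNamespace(
    model="net", app_args=None, services=None, steps_per_episode=13,
    amc_cfg=SimpleNamespace(ddpg_cfg=SimpleNamespace(replay_buffer_size=1300)))

class RlLibInterface:
    def solve(self, model, app_args, amc_cfg, services, steps_per_episode):
        # Because env1's own objects are passed through, any configuration set
        # while building the environment (e.g. replay_buffer_size) is what the
        # RL library actually sees.
        print(amc_cfg.ddpg_cfg.replay_buffer_size, steps_per_episode)

env_cfg = {'model': env1.model,
           'app_args': env1.app_args,
           'amc_cfg': env1.amc_cfg,
           'services': env1.services}
RlLibInterface().solve(**env_cfg, steps_per_episode=env1.steps_per_episode)
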