From 738d57f4c4f42fe64f392ee920136fcfb3e3f311 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Mon, 7 Oct 2019 13:51:39 +0300 Subject: [PATCH] AMC: fix the replay_buffer_size when using Coach and DDPG --- examples/auto_compression/amc/amc.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py index d30473e..4b98877 100755 --- a/examples/auto_compression/amc/amc.py +++ b/examples/auto_compression/amc/amc.py @@ -131,7 +131,6 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo def create_environment(): env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services) - #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode) env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode return env @@ -151,10 +150,10 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo elif args.amc_rllib == "coach": from rl_libs.coach import coach_if rl = coach_if.RlLibInterface() - env_cfg = {'model': model, - 'app_args': app_args, - 'amc_cfg': amc_cfg, - 'services': services} + env_cfg = {'model': env1.model, + 'app_args': env1.app_args, + 'amc_cfg': env1.amc_cfg, + 'services': env1.services} steps_per_episode = env1.steps_per_episode rl.solve(**env_cfg, steps_per_episode=steps_per_episode) elif args.amc_rllib == "random": -- GitLab