From 738d57f4c4f42fe64f392ee920136fcfb3e3f311 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 7 Oct 2019 13:51:39 +0300
Subject: [PATCH] AMC: fix the replay_buffer_size when using Coach and DDPG

---
 examples/auto_compression/amc/amc.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py
index d30473e..4b98877 100755
--- a/examples/auto_compression/amc/amc.py
+++ b/examples/auto_compression/amc/amc.py
@@ -131,7 +131,6 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
 
     def create_environment():
         env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services)
-        #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode)
         env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode
         return env
 
@@ -151,10 +150,10 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo
     elif args.amc_rllib == "coach":
         from rl_libs.coach import coach_if
         rl = coach_if.RlLibInterface()
-        env_cfg  = {'model': model, 
-                    'app_args': app_args,
-                    'amc_cfg': amc_cfg,
-                    'services': services}
+        env_cfg = {'model': env1.model,
+                   'app_args': env1.app_args,
+                   'amc_cfg': env1.amc_cfg,
+                   'services': env1.services}
         steps_per_episode = env1.steps_per_episode
         rl.solve(**env_cfg, steps_per_episode=steps_per_episode)
     elif args.amc_rllib == "random":
-- 
GitLab