diff --git a/examples/auto_compression/amc/amc.py b/examples/auto_compression/amc/amc.py index d30473e27b082e6caa26ff68d4f6d87129b72a8b..4b98877e2b4753fad1872c56c7166ed0c5f0f6dc 100755 --- a/examples/auto_compression/amc/amc.py +++ b/examples/auto_compression/amc/amc.py @@ -131,7 +131,6 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo def create_environment(): env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services) - #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode) env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode return env @@ -151,10 +150,10 @@ def train_auto_compressor(model, args, optimizer_data, validate_fn, save_checkpo elif args.amc_rllib == "coach": from rl_libs.coach import coach_if rl = coach_if.RlLibInterface() - env_cfg = {'model': model, - 'app_args': app_args, - 'amc_cfg': amc_cfg, - 'services': services} + env_cfg = {'model': env1.model, + 'app_args': env1.app_args, + 'amc_cfg': env1.amc_cfg, + 'services': env1.services} steps_per_episode = env1.steps_per_episode rl.solve(**env_cfg, steps_per_episode=steps_per_episode) elif args.amc_rllib == "random":