Commit d9f6bfed authored by Neta Zmora

Update ADC_DDPG.py to comply with latest Coach changes
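Visible changes in this diff: the critic's 'action' input embedder moves from EmbedderScheme.Empty to a Dense(300) embedder merged with the observation embedding via EmbeddingMergerType.Sum; evaluation is disabled in the schedule (steps_between_evaluation_periods and evaluation_steps are both set to 0 episodes); the renamed exploration parameter evaluation_noise replaces the older evaluation_noise_percentage; and features that exist only on the Coach branch "distiller-AMC-induced-changes" (RewardEwmaNormalizationFilter, override_episode_rewards_with_the_last_transition_reward) are left commented out with enable-notes.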

parent 434540d4
@@ -6,21 +6,20 @@
 from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.gym_environment import GymVectorEnvironment
 from rl_coach.exploration_policies.truncated_normal import TruncatedNormalParameters
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.memories.memory import MemoryGranularity
 from rl_coach.base_parameters import EmbedderScheme
 from rl_coach.architectures.tensorflow_components.layers import Dense
-steps_per_episode = 13
 from rl_coach.base_parameters import EmbeddingMergerType
 from rl_coach.filters.filter import InputFilter
 # !!!! Enable when using branch "distiller-AMC-induced-changes"
 # from rl_coach.filters.reward import RewardEwmaNormalizationFilter
 import numpy as np
 ####################
 # Graph Scheduling #
 ####################
 schedule_params = ScheduleParameters()
 schedule_params.improve_steps = EnvironmentEpisodes(800)
-schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(5000)
-schedule_params.evaluation_steps = EnvironmentEpisodes(0)  # Neta: 0
 schedule_params.heatup_steps = EnvironmentEpisodes(100)
+schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(0)
+schedule_params.evaluation_steps = EnvironmentEpisodes(0)
#####################
# DDPG Agent Params #
@@ -31,7 +30,9 @@
 agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(300)]
 agent_params.network_wrappers['actor'].heads_parameters[0].activation_function = 'sigmoid'
 agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(300)]
 agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
-agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
+agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = [Dense(300)]
+agent_params.network_wrappers['critic'].embedding_merger_type = EmbeddingMergerType.Sum
 agent_params.network_wrappers['actor'].optimizer_type = 'Adam'
 agent_params.network_wrappers['actor'].adam_optimizer_beta1 = 0.9
@@ -44,8 +45,11 @@
 agent_params.network_wrappers['critic'].adam_optimizer_beta1 = 0.9
 agent_params.network_wrappers['critic'].adam_optimizer_beta2 = 0.999
 agent_params.network_wrappers['critic'].optimizer_epsilon = 1e-8
-agent_params.network_wrappers['actor'].learning_rate = 0.0001
-agent_params.network_wrappers['critic'].learning_rate = 0.001
+agent_params.network_wrappers['actor'].learning_rate = 1e-4
+agent_params.network_wrappers['critic'].learning_rate = 1e-3
+# !!!! Enable when using branch "distiller-AMC-induced-changes"
+# agent_params.algorithm.override_episode_rewards_with_the_last_transition_reward = True
 agent_params.algorithm.rate_for_copying_weights_to_target = 0.01  # Tau pg. 11
 #agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
@@ -53,15 +57,16 @@
 agent_params.algorithm.heatup_using_network_decisions = False  # We want uniform-random action selection during heatup
 agent_params.algorithm.discount = 1
 agent_params.algorithm.use_non_zero_discount_for_terminal_states = True
 # See: https://nervanasystems.github.io/coach/components/agents/policy_optimization/ddpg.html?highlight=ddpg#rl_coach.agents.ddpg_agent.DDPGAlgorithmParameters
 # Replay buffer size
 agent_params.memory.max_size = (MemoryGranularity.Transitions, 2000)
 agent_params.exploration = TruncatedNormalParameters()
-agent_params.algorithm.use_target_network_for_evaluation = True
-#agent_params.exploration.evaluation_noise_percentage = 0  # Neta new
 #agent_params.exploration = AdditiveNoiseParameters()
 agent_params.exploration.noise_as_percentage_from_action_space = False
+agent_params.exploration.evaluation_noise = 0  # Neta new
 #agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
+agent_params.algorithm.use_target_network_for_evaluation = True
 agent_params.algorithm.act_for_full_episodes = True
+# !!!! Enable when using branch "distiller-AMC-induced-changes"
 #agent_params.pre_network_filter = InputFilter()
 #agent_params.pre_network_filter.add_reward_filter('ewma_norm', RewardEwmaNormalizationFilter(alpha=0.5))
 ##############################
 # Gym                        #
......
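For context, below is a minimal sketch of how a Coach preset like this one is typically completed and run. The Gym section is collapsed in this diff, so the environment level string and the final wiring are assumptions based on the standard Coach preset structure (and on the GymVectorEnvironment import visible at the top of the file), not this file's actual contents.

# Hypothetical completion of the preset's collapsed "Gym" section.
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager

env_params = GymVectorEnvironment()              # imported near the top of the preset
env_params.level = 'my_env.py:MyGymEnvironment'  # placeholder path, not from this commit

# BasicRLGraphManager ties together the agent, environment, and schedule
# configured above into a single runnable training graph.
graph_manager = BasicRLGraphManager(agent_params=agent_params,
                                    env_params=env_params,
                                    schedule_params=schedule_params,
                                    vis_params=VisualizationParameters())

Coach then drives training through graph_manager.improve(), which runs the heatup, training, and evaluation phases defined by schedule_params (here: 100 heatup episodes, 800 improvement episodes, and no evaluation periods).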