Commit b8c72c9

Merge pull request #416 from prabhatnagarajan/batch_ddpg
Enables batch DDPG agents to be trained.
2 parents e9cd2f9 + 27ffbad commit b8c72c9

File tree

5 files changed: +294 -8 lines changed
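
This commit turns DDPG into a BatchAgent that can act on, and learn from, several gym environments stepped in parallel; it also adds a batch training example and test coverage. As orientation before the diffs, here is a minimal sketch of the intended end-to-end usage, distilled from the new example script in this commit (examples/gym/train_ddpg_batch_gym.py); the make_env factory, the agent argument, and the evaluation settings are illustrative assumptions, not part of the commit itself:

# Hedged sketch of batch DDPG training; see examples/gym/train_ddpg_batch_gym.py
# below for the full version. `make_env` is a hypothetical factory returning a
# seeded, wrapped gym environment, and `agent` is a DDPG instance built as in
# that script.
import chainerrl
from chainerrl import experiments


def train_in_parallel(agent, make_env, num_envs, steps, outdir):
    # One subprocess per environment; the agent selects actions for all of
    # them at once through its new batch_* methods.
    train_env = chainerrl.envs.MultiprocessVectorEnv(
        [(lambda: make_env(test=False)) for _ in range(num_envs)])
    eval_env = chainerrl.envs.MultiprocessVectorEnv(
        [(lambda: make_env(test=True)) for _ in range(num_envs)])
    experiments.train_agent_batch_with_evaluation(
        agent=agent, env=train_env, eval_env=eval_env, steps=steps,
        eval_n_steps=None, eval_n_episodes=10, eval_interval=10 ** 4,
        outdir=outdir)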

chainerrl/agents/ddpg.py

Lines changed: 88 additions & 3 deletions
@@ -13,8 +13,8 @@
 from chainer import cuda
 import chainer.functions as F

-from chainerrl.agent import Agent
 from chainerrl.agent import AttributeSavingMixin
+from chainerrl.agent import BatchAgent
 from chainerrl.misc.batch_states import batch_states
 from chainerrl.misc.copy_param import synchronize_parameters
 from chainerrl.recurrent import Recurrent
@@ -40,7 +40,7 @@ def __init__(self, policy, q_func):
         super().__init__(policy=policy, q_function=q_func)


-class DDPG(AttributeSavingMixin, Agent):
+class DDPG(AttributeSavingMixin, BatchAgent):
     """Deep Deterministic Policy Gradients.

     This can be used as SVG(0) by specifying a Gaussian policy instead of a
@@ -178,7 +178,6 @@ def compute_critic_loss(self, batch):
         batch_terminal = batch['is_state_terminal']
         batch_state = batch['state']
         batch_actions = batch['action']
-        batch_next_actions = batch['next_action']
         batchsize = len(batch_rewards)

         with chainer.no_backprop_mode():
@@ -193,6 +192,7 @@ def compute_critic_loss(self, batch):

             # Target Q-function observes s_{t+1} and a_{t+1}
             if isinstance(self.target_q_function, Recurrent):
+                batch_next_actions = batch['next_action']
                 self.target_q_function.update_state(
                     batch_next_state, batch_next_actions)

@@ -344,6 +344,91 @@ def act(self, obs):
                           self.t, action.array[0], q.array)
         return cuda.to_cpu(action.array[0])

+    def batch_act(self, batch_obs):
+        """Select a batch of actions for evaluation.
+
+        Args:
+            batch_obs (Sequence of ~object): Observations.
+
+        Returns:
+            Sequence of ~object: Actions.
+        """
+
+        with chainer.using_config('train', False), chainer.no_backprop_mode():
+            batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
+            batch_action = self.policy(batch_xs).sample()
+            # Q is not needed here, but log it just for information
+            q = self.q_function(batch_xs, batch_action)
+
+        # Update stats
+        self.average_q *= self.average_q_decay
+        self.average_q += (1 - self.average_q_decay) * float(
+            q.array.mean(axis=0))
+        self.logger.debug('t:%s a:%s q:%s',
+                          self.t, batch_action.array[0], q.array)
+        return [cuda.to_cpu(action.array) for action in batch_action]
+
+    def batch_act_and_train(self, batch_obs):
+        """Select a batch of actions for training.
+
+        Args:
+            batch_obs (Sequence of ~object): Observations.
+
+        Returns:
+            Sequence of ~object: Actions.
+        """
+
+        batch_greedy_action = self.batch_act(batch_obs)
+        batch_action = [
+            self.explorer.select_action(
+                self.t, lambda: batch_greedy_action[i])
+            for i in range(len(batch_greedy_action))]
+
+        self.batch_last_obs = list(batch_obs)
+        self.batch_last_action = list(batch_action)
+
+        return batch_action
+
+    def batch_observe_and_train(
+            self, batch_obs, batch_reward, batch_done, batch_reset):
+        """Observe a batch of action consequences for training.
+
+        Args:
+            batch_obs (Sequence of ~object): Observations.
+            batch_reward (Sequence of float): Rewards.
+            batch_done (Sequence of boolean): Boolean values where True
+                indicates the current state is terminal.
+            batch_reset (Sequence of boolean): Boolean values where True
+                indicates the current episode will be reset, even if the
+                current state is not terminal.
+
+        Returns:
+            None
+        """
+        for i in range(len(batch_obs)):
+            self.t += 1
+            # Update the target network
+            if self.t % self.target_update_interval == 0:
+                self.sync_target_network()
+            if self.batch_last_obs[i] is not None:
+                assert self.batch_last_action[i] is not None
+                # Add a transition to the replay buffer
+                self.replay_buffer.append(
+                    state=self.batch_last_obs[i],
+                    action=self.batch_last_action[i],
+                    reward=batch_reward[i],
+                    next_state=batch_obs[i],
+                    next_action=None,
+                    is_state_terminal=batch_done[i],
+                )
+                if batch_reset[i] or batch_done[i]:
+                    self.batch_last_obs[i] = None
+            self.replay_updater.update_if_necessary(self.t)
+
+    def batch_observe(self, batch_obs, batch_reward,
+                      batch_done, batch_reset):
+        pass
+
     def stop_episode_and_train(self, state, reward, done=False):

         assert self.last_state is not None
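
The three batch_* methods added above are meant to be driven by a synchronous vectorized-environment loop such as the one in chainerrl.experiments.train_agent_batch. A minimal sketch of that interaction follows; vec_env is a hypothetical stand-in for something like MultiprocessVectorEnv, and the loop is an illustration, not the ChainerRL trainer itself:

# Sketch of the loop that drives batch_act_and_train / batch_observe_and_train.
# `vec_env` is assumed to expose gym-like reset()/step() over num_envs
# sub-environments; time-limit truncation is ignored here for brevity.
import numpy as np


def run_batch_training(agent, vec_env, num_envs, total_steps):
    obss = vec_env.reset()  # batch of initial observations
    for _ in range(0, total_steps, num_envs):
        # One exploratory action per sub-environment; the explorer is applied
        # inside batch_act_and_train, which also caches the obs/action pairs.
        actions = agent.batch_act_and_train(obss)
        obss, rewards, dones, _ = vec_env.step(actions)
        resets = np.zeros(num_envs, dtype=bool)  # no truncation in this sketch
        # Appends transitions to the replay buffer and runs replay updates.
        agent.batch_observe_and_train(obss, rewards, dones, resets)
        if np.logical_or(dones, resets).any():
            # The real trainer resets only the finished sub-environments;
            # resetting everything keeps the sketch simple.
            obss = vec_env.reset()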

chainerrl/q_functions/state_action_q_functions.py

Lines changed: 2 additions & 4 deletions
@@ -146,8 +146,7 @@ def __call__(self, state, action):
         return super().__call__(h)


-class FCBNLateActionSAQFunction(chainer.Chain, StateActionQFunction,
-                                RecurrentChainMixin):
+class FCBNLateActionSAQFunction(chainer.Chain, StateActionQFunction):
     """Fully-connected + BN (s,a)-input Q-function with late action input.

     Actions are not included until the second hidden layer and not normalized.
@@ -202,8 +201,7 @@ def __call__(self, state, action):
         return self.mlp(h)


-class FCLateActionSAQFunction(chainer.Chain, StateActionQFunction,
-                              RecurrentChainMixin):
+class FCLateActionSAQFunction(chainer.Chain, StateActionQFunction):
     """Fully-connected (s,a)-input Q-function with late action input.

     Actions are not included until the second hidden layer and not normalized.
examples/gym/train_ddpg_batch_gym.py (new file)

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()  # NOQA
import argparse
import sys

import chainer
from chainer import optimizers
import gym
from gym import spaces
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl.agents.ddpg import DDPG
from chainerrl.agents.ddpg import DDPGModel
from chainerrl import experiments
from chainerrl import explorers
from chainerrl import misc
from chainerrl import policy
from chainerrl import q_functions
from chainerrl import replay_buffer


def main():
    import logging
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir', type=str, default='results',
                        help='Directory path to save output files.'
                             ' If it does not exist, it will be created.')
    parser.add_argument('--env', type=str, default='Humanoid-v2')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--final-exploration-steps',
                        type=int, default=10 ** 6)
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--load', type=str, default='')
    parser.add_argument('--steps', type=int, default=10 ** 7)
    parser.add_argument('--n-hidden-channels', type=int, default=300)
    parser.add_argument('--n-hidden-layers', type=int, default=3)
    parser.add_argument('--replay-start-size', type=int, default=5000)
    parser.add_argument('--n-update-times', type=int, default=1)
    parser.add_argument('--target-update-interval',
                        type=int, default=1)
    parser.add_argument('--target-update-method',
                        type=str, default='soft', choices=['hard', 'soft'])
    parser.add_argument('--soft-update-tau', type=float, default=1e-2)
    parser.add_argument('--update-interval', type=int, default=4)
    parser.add_argument('--eval-n-runs', type=int, default=100)
    parser.add_argument('--eval-interval', type=int, default=10 ** 5)
    parser.add_argument('--gamma', type=float, default=0.995)
    parser.add_argument('--minibatch-size', type=int, default=200)
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--demo', action='store_true')
    parser.add_argument('--use-bn', action='store_true', default=False)
    parser.add_argument('--monitor', action='store_true')
    parser.add_argument('--reward-scale-factor', type=float, default=1e-2)
    parser.add_argument('--num-envs', type=int, default=1)
    args = parser.parse_args()

    args.outdir = experiments.prepare_output_dir(
        args, args.outdir, argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu,))

    def clip_action_filter(a):
        return np.clip(a, action_space.low, action_space.high)

    def reward_filter(r):
        return r * args.reward_scale_factor

    # Set different random seeds for different subprocesses.
    # If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
    # If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
    process_seeds = np.arange(args.num_envs) + args.seed * args.num_envs
    assert process_seeds.max() < 2 ** 32

    def make_env(idx, test):
        env = gym.make(args.env)
        # Use different random seeds for train and test envs
        process_seed = int(process_seeds[idx])
        env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        if args.monitor:
            env = gym.wrappers.Monitor(env, args.outdir)
        if isinstance(env.action_space, spaces.Box):
            misc.env_modifiers.make_action_filtered(env, clip_action_filter)
        if not test:
            # Scale rewards (and thus returns) to a reasonable range so that
            # training is easier
            env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
        if args.render and not test:
            env = chainerrl.wrappers.Render(env)
        return env

    def make_batch_env(test):
        return chainerrl.envs.MultiprocessVectorEnv(
            [(lambda: make_env(idx, test))
             for idx, env in enumerate(range(args.num_envs))])

    sample_env = make_env(0, test=False)
    timestep_limit = sample_env.spec.tags.get(
        'wrapper_config.TimeLimit.max_episode_steps')

    obs_size = np.asarray(sample_env.observation_space.shape).prod()
    action_space = sample_env.action_space

    action_size = np.asarray(action_space.shape).prod()
    if args.use_bn:
        q_func = q_functions.FCBNLateActionSAQFunction(
            obs_size, action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            normalize_input=True)
        pi = policy.FCBNDeterministicPolicy(
            obs_size, action_size=action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            min_action=action_space.low, max_action=action_space.high,
            bound_action=True,
            normalize_input=True)
    else:
        q_func = q_functions.FCSAQFunction(
            obs_size, action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers)
        pi = policy.FCDeterministicPolicy(
            obs_size, action_size=action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            min_action=action_space.low, max_action=action_space.high,
            bound_action=True)
    model = DDPGModel(q_func=q_func, policy=pi)
    opt_a = optimizers.Adam(alpha=args.actor_lr)
    opt_c = optimizers.Adam(alpha=args.critic_lr)
    opt_a.setup(model['policy'])
    opt_c.setup(model['q_function'])
    opt_a.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_a')
    opt_c.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_c')

    rbuf = replay_buffer.ReplayBuffer(5 * 10 ** 5)

    def random_action():
        a = action_space.sample()
        if isinstance(a, np.ndarray):
            a = a.astype(np.float32)
        return a

    ou_sigma = (action_space.high - action_space.low) * 0.2
    explorer = explorers.AdditiveOU(sigma=ou_sigma)
    agent = DDPG(model, opt_a, opt_c, rbuf, gamma=args.gamma,
                 explorer=explorer, replay_start_size=args.replay_start_size,
                 target_update_method=args.target_update_method,
                 target_update_interval=args.target_update_interval,
                 update_interval=args.update_interval,
                 soft_update_tau=args.soft_update_tau,
                 n_times_update=args.n_update_times,
                 gpu=args.gpu, minibatch_size=args.minibatch_size)

    if len(args.load) > 0:
        agent.load(args.load)

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=make_batch_env(test=True),
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=timestep_limit)
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_batch_with_evaluation(
            agent=agent, env=make_batch_env(test=False), steps=args.steps,
            eval_env=make_batch_env(test=True), eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
            outdir=args.outdir,
            max_episode_len=timestep_limit)


if __name__ == '__main__':
    main()

test_examples.sh

Lines changed: 5 additions & 0 deletions
@@ -93,6 +93,11 @@ python examples/gym/train_ddpg_gym.py --steps 100 --replay-start-size 50 --minib
 model=$(find $outdir/gym/ddpg -name "*_finish")
 python examples/gym/train_ddpg_gym.py --demo --load $model --eval-n-runs 1 --env Pendulum-v0 --outdir $outdir/temp --gpu $gpu

+# gym/ddpg batch (specify non-mujoco env to test without mujoco)
+python examples/gym/train_ddpg_batch_gym.py --steps 100 --replay-start-size 50 --minibatch-size 32 --outdir $outdir/gym/ddpg_batch --env Pendulum-v0 --gpu $gpu
+model=$(find $outdir/gym/ddpg_batch -name "*_finish")
+python examples/gym/train_ddpg_batch_gym.py --demo --load $model --eval-n-runs 1 --env Pendulum-v0 --outdir $outdir/temp --gpu $gpu
+
 # gym/reinforce
 python examples/gym/train_reinforce_gym.py --steps 100 --batchsize 1 --outdir $outdir/gym/reinforce --gpu $gpu
 model=$(find $outdir/gym/reinforce -name "*_finish")

tests/agents_tests/test_ddpg.py

Lines changed: 5 additions & 1 deletion
@@ -9,7 +9,10 @@
 import basetest_ddpg as base
 from chainerrl.agents.ddpg import DDPG

+from basetest_training import _TestBatchTrainingMixin

+
+# Batch training with recurrent models is currently not supported
 class TestDDPGOnContinuousPOABC(base._TestDDPGOnContinuousPOABC):

     def make_ddpg_agent(self, env, model, actor_opt, critic_opt, explorer,
@@ -20,7 +23,8 @@ def make_ddpg_agent(self, env, model, actor_opt, critic_opt, explorer,
                             episodic_update=True, update_interval=1)


-class TestDDPGOnContinuousABC(base._TestDDPGOnContinuousABC):
+class TestDDPGOnContinuousABC(_TestBatchTrainingMixin,
+                              base._TestDDPGOnContinuousABC):

     def make_ddpg_agent(self, env, model, actor_opt, critic_opt, explorer,
                         rbuf, gpu):
