Skip to content

Commit e9cd2f9

Browse files
authored
Merge pull request #427 from muupan/parallel-link
Add Branched and use it to simplify train_ppo_batch_gym.py
2 parents e07b248 + a8cc75a commit e9cd2f9

File tree

8 files changed

+179
-75
lines changed

8 files changed

+179
-75
lines changed

.travis.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ python:
44
- "2.7"
55
- "3.6"
66
env:
7-
- CHAINER_VERSION=3
7+
- CHAINER_VERSION=4
88
- CHAINER_VERSION=stable
99
# command to install dependencies
1010
install:
1111
- pip install --upgrade pip setuptools wheel
1212
- |
13-
if [[ $CHAINER_VERSION == 3 ]]; then
14-
pip install "chainer==3.1.0"
13+
if [[ $CHAINER_VERSION == 4 ]]; then
14+
pip install "chainer==4.0.0"
1515
else
1616
pip install chainer
1717
fi

chainerrl/links/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from chainerrl.links.branched import Branched # NOQA
12
from chainerrl.links.dqn_head import NatureDQNHead # NOQA
23
from chainerrl.links.dqn_head import NIPSDQNHead # NOQA
34
from chainerrl.links.empirical_normalization import EmpiricalNormalization # NOQA

chainerrl/links/branched.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import division
2+
from __future__ import print_function
3+
from __future__ import unicode_literals
4+
from __future__ import absolute_import
5+
from builtins import * # NOQA
6+
from future import standard_library
7+
standard_library.install_aliases() # NOQA
8+
9+
import chainer
10+
11+
12+
class Branched(chainer.ChainList):
13+
"""Link that calls forward functions of child links in parallel.
14+
15+
When the `__call__` method of this link is called, all the
16+
arguments are forwarded to each child link's `__call__` method.
17+
18+
The returned values from the child links are returned as a tuple.
19+
20+
Args:
21+
*links: Child links. Each link should be callable.
22+
"""
23+
24+
def __call__(self, *args, **kwargs):
25+
"""Forward the arguments to the child links.
26+
27+
Args:
28+
*args, **kwargs: Any arguments forwarded to child links. Each child
29+
link should be able to accept the arguments.
30+
31+
Returns:
32+
tuple: Tuple of the returned values from the child links.
33+
"""
34+
return tuple(link(*args, **kwargs) for link in self)

chainerrl/policies/gaussian_policy.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,53 @@ def get_var_array(shape):
248248
layers.append(lambda x: distribution.GaussianDistribution(
249249
x, get_var_array(x.shape)))
250250
super().__init__(*layers)
251+
252+
253+
class GaussianHeadWithStateIndependentCovariance(chainer.Chain):
254+
"""Gaussian head with state-independent learned covariance.
255+
256+
This link is intended to be attached to a neural network that outputs
257+
the mean of a Gaussian policy. The only learnable parameter this link has
258+
determines the variance in a state-independent way.
259+
260+
State-independent parameterization of the variance of a Gaussian policy
261+
is often used with PPO and TRPO, e.g., in https://arxiv.org/abs/1709.06560.
262+
263+
Args:
264+
action_size (int): Number of dimensions of the action space.
265+
var_type (str): Type of parameterization of variance. It must be
266+
'spherical' or 'diagonal'.
267+
var_func (callable): Callable that computes the variance from the var
268+
parameter. It should always return positive values.
269+
var_param_init (float): Initial value of the var parameter.
270+
"""
271+
272+
def __init__(
273+
self,
274+
action_size,
275+
var_type='spherical',
276+
var_func=F.softplus,
277+
var_param_init=0,
278+
):
279+
280+
self.var_func = var_func
281+
var_size = {'spherical': 1, 'diagonal': action_size}[var_type]
282+
283+
super().__init__()
284+
with self.init_scope():
285+
self.var_param = chainer.Parameter(
286+
initializer=var_param_init, shape=(var_size,))
287+
288+
def __call__(self, mean):
289+
"""Return a Gaussian with given mean.
290+
291+
Args:
292+
mean (chainer.Variable or ndarray): Mean of Gaussian.
293+
294+
Returns:
295+
chainerrl.distribution.Distribution: Gaussian whose mean is the
296+
mean argument and whose variance is computed from the parameter
297+
of this link.
298+
"""
299+
var = F.broadcast_to(self.var_func(self.var_param), mean.shape)
300+
return distribution.GaussianDistribution(mean, var)

docs/links.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Links
55
Link implementations
66
============================
77

8+
.. autoclass:: chainerrl.links.Branched
9+
810
.. autoclass:: chainerrl.links.EmpiricalNormalization
911

1012
.. autoclass:: chainerrl.links.FactorizedNoisyLinear

examples/gym/train_ppo_batch_gym.py

Lines changed: 44 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,67 +17,17 @@
1717

1818
import chainer
1919
from chainer import functions as F
20+
from chainer import links as L
2021
import gym
22+
import gym.spaces
2123
import gym.wrappers
2224
import numpy as np
2325

2426
import chainerrl
25-
from chainerrl.agents import a3c
2627
from chainerrl.agents import PPO
2728
from chainerrl import experiments
28-
from chainerrl import links
2929
from chainerrl import misc
3030
from chainerrl.optimizers.nonbias_weight_decay import NonbiasWeightDecay
31-
from chainerrl import policies
32-
33-
34-
class A3CFFSoftmax(chainer.ChainList, a3c.A3CModel):
35-
"""An example of A3C feedforward softmax policy."""
36-
37-
def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
38-
self.pi = policies.SoftmaxPolicy(
39-
model=links.MLP(ndim_obs, n_actions, hidden_sizes))
40-
self.v = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
41-
super().__init__(self.pi, self.v)
42-
43-
def pi_and_v(self, state):
44-
return self.pi(state), self.v(state)
45-
46-
47-
class A3CFFMellowmax(chainer.ChainList, a3c.A3CModel):
48-
"""An example of A3C feedforward mellowmax policy."""
49-
50-
def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
51-
self.pi = policies.MellowmaxPolicy(
52-
model=links.MLP(ndim_obs, n_actions, hidden_sizes))
53-
self.v = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
54-
super().__init__(self.pi, self.v)
55-
56-
def pi_and_v(self, state):
57-
return self.pi(state), self.v(state)
58-
59-
60-
class A3CFFGaussian(chainer.Chain, a3c.A3CModel):
61-
"""An example of A3C feedforward Gaussian policy."""
62-
63-
def __init__(self, obs_size, action_space,
64-
n_hidden_layers=2, n_hidden_channels=64,
65-
bound_mean=None):
66-
assert bound_mean in [False, True]
67-
super().__init__()
68-
hidden_sizes = (n_hidden_channels,) * n_hidden_layers
69-
with self.init_scope():
70-
self.pi = policies.FCGaussianPolicyWithStateIndependentCovariance(
71-
obs_size, action_space.low.size,
72-
n_hidden_layers, n_hidden_channels,
73-
var_type='diagonal', nonlinearity=F.tanh,
74-
bound_mean=bound_mean,
75-
min_action=action_space.low, max_action=action_space.high,
76-
mean_wscale=1e-2)
77-
self.v = links.MLP(obs_size, 1, hidden_sizes=hidden_sizes)
78-
79-
def pi_and_v(self, state):
80-
return self.pi(state), self.v(state)
8131

8232

8333
def main():
@@ -87,10 +37,6 @@ def main():
8737
parser.add_argument('--gpu', type=int, default=0)
8838
parser.add_argument('--env', type=str, default='Hopper-v2')
8939
parser.add_argument('--num-envs', type=int, default=1)
90-
parser.add_argument('--arch', type=str, default='FFGaussian',
91-
choices=('FFSoftmax', 'FFMellowmax',
92-
'FFGaussian'))
93-
parser.add_argument('--bound-mean', action='store_true')
9440
parser.add_argument('--seed', type=int, default=0,
9541
help='Random seed [0, 2 ** 32)')
9642
parser.add_argument('--outdir', type=str, default='results',
@@ -164,14 +110,49 @@ def make_batch_env(test):
164110
obs_normalizer = chainerrl.links.EmpiricalNormalization(
165111
obs_space.low.size, clip_threshold=5)
166112

113+
winit_last = chainer.initializers.LeCunNormal(1e-2)
114+
167115
# Switch policy types accordingly to action space types
168-
if args.arch == 'FFSoftmax':
169-
model = A3CFFSoftmax(obs_space.low.size, action_space.n)
170-
elif args.arch == 'FFMellowmax':
171-
model = A3CFFMellowmax(obs_space.low.size, action_space.n)
172-
elif args.arch == 'FFGaussian':
173-
model = A3CFFGaussian(obs_space.low.size, action_space,
174-
bound_mean=args.bound_mean)
116+
if isinstance(action_space, gym.spaces.Discrete):
117+
n_actions = action_space.n
118+
policy = chainer.Sequential(
119+
L.Linear(None, 64),
120+
F.tanh,
121+
L.Linear(None, 64),
122+
F.tanh,
123+
L.Linear(None, n_actions, initialW=winit_last),
124+
chainerrl.distribution.SoftmaxDistribution,
125+
)
126+
elif isinstance(action_space, gym.spaces.Box):
127+
action_size = action_space.low.size
128+
policy = chainer.Sequential(
129+
L.Linear(None, 64),
130+
F.tanh,
131+
L.Linear(None, 64),
132+
F.tanh,
133+
L.Linear(None, action_size, initialW=winit_last),
134+
chainerrl.policies.GaussianHeadWithStateIndependentCovariance(
135+
action_size=action_size,
136+
var_type='diagonal',
137+
var_func=lambda x: F.exp(2 * x), # Parameterize log std
138+
var_param_init=0, # log std = 0 => std = 1
139+
),
140+
)
141+
else:
142+
print("""\
143+
This example only supports gym.spaces.Box or gym.spaces.Discrete action spaces.""") # NOQA
144+
return
145+
146+
vf = chainer.Sequential(
147+
L.Linear(None, 64),
148+
F.tanh,
149+
L.Linear(None, 64),
150+
F.tanh,
151+
L.Linear(None, 1),
152+
)
153+
154+
# Combine a policy and a value function into a single model
155+
model = chainerrl.links.Branched(policy, vf)
175156

176157
opt = chainer.optimizers.Adam(alpha=args.lr, eps=1e-5)
177158
opt.setup(model)
@@ -208,13 +189,6 @@ def lr_setter(env, agent, value):
208189
lr_decay_hook = experiments.LinearInterpolationHook(
209190
args.steps, args.lr, 0, lr_setter)
210191

211-
# Linearly decay the clipping parameter to zero
212-
def clip_eps_setter(env, agent, value):
213-
agent.clip_eps = value
214-
215-
clip_eps_decay_hook = experiments.LinearInterpolationHook(
216-
args.steps, 0.2, 0, clip_eps_setter)
217-
218192
experiments.train_agent_batch_with_evaluation(
219193
agent=agent,
220194
env=make_batch_env(False),
@@ -230,7 +204,6 @@ def clip_eps_setter(env, agent, value):
230204
save_best_so_far_agent=False,
231205
step_hooks=[
232206
lr_decay_hook,
233-
clip_eps_decay_hook,
234207
],
235208
)
236209

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cached-property
2-
chainer>=3.1.0
2+
chainer>=4.0.0
33
fastcache; python_version<'3.2'
44
funcsigs; python_version<'3.5'
55
future

tests/links_tests/test_branched.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from __future__ import print_function
2+
from __future__ import unicode_literals
3+
from __future__ import division
4+
from __future__ import absolute_import
5+
from builtins import * # NOQA
6+
from future import standard_library
7+
standard_library.install_aliases() # NOQA
8+
9+
import unittest
10+
11+
import chainer
12+
from chainer import functions as F
13+
from chainer import links as L
14+
from chainer import testing
15+
import numpy as np
16+
17+
from chainerrl.links import Branched
18+
19+
20+
@testing.parameterize(*(
21+
testing.product({
22+
'batch_size': [1, 2],
23+
})
24+
))
25+
class TestBranched(unittest.TestCase):
26+
27+
def test_manual(self):
28+
link1 = L.Linear(2, 3)
29+
link2 = L.Linear(2, 5)
30+
link3 = chainer.Sequential(
31+
L.Linear(2, 7),
32+
F.tanh,
33+
)
34+
plink = Branched(link1, link2, link3)
35+
x = np.zeros((self.batch_size, 2), dtype=np.float32)
36+
pout = plink(x)
37+
self.assertIsInstance(pout, tuple)
38+
self.assertEqual(len(pout), 3)
39+
out1 = link1(x)
40+
out2 = link2(x)
41+
out3 = link3(x)
42+
np.testing.assert_allclose(pout[0].array, out1.array)
43+
np.testing.assert_allclose(pout[1].array, out2.array)
44+
np.testing.assert_allclose(pout[2].array, out3.array)

0 commit comments

Comments
 (0)