Skip to content

Commit 1feffd8

Browse files
authored
Merge pull request #529 from muupan/iqn-act-deterministically
Add a deterministic mode to IQN for stable tests
2 parents b14faec + 5647606 commit 1feffd8

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

chainerrl/agents/iqn.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ class IQN(dqn.DQN):
235235
to sample from the return distribution at the next state.
236236
quantile_thresholds_K (int): Number of quantile thresholds used to
237237
compute greedy actions.
238+
act_deterministically (bool): IQN's action selection is by default
239+
stochastic as it samples quantile thresholds every time it acts,
240+
even for evaluation. If this option is set to True, it uses
241+
equally spaced quantile thresholds instead of randomly sampled ones
242+
for evaluation, making its action selection deterministic.
238243
239244
For other arguments, see chainerrl.agents.DQN.
240245
"""
@@ -246,6 +251,7 @@ def __init__(self, *args, **kwargs):
246251
self.quantile_thresholds_N_prime = kwargs.pop(
247252
'quantile_thresholds_N_prime', 64)
248253
self.quantile_thresholds_K = kwargs.pop('quantile_thresholds_K', 32)
254+
self.act_deterministically = kwargs.pop('act_deterministically', False)
249255
super().__init__(*args, **kwargs)
250256

251257
def _compute_target_values(self, exp_batch):
@@ -357,7 +363,16 @@ def _evaluate_model_and_update_recurrent_states(self, batch_obs, test):
357363
batch_xs, self.train_recurrent_states)
358364
else:
359365
tau2av = self.model(batch_xs)
360-
taus_tilde = self.xp.random.uniform(
361-
0, 1,
362-
size=(len(batch_obs), self.quantile_thresholds_K)).astype('f')
366+
if test and self.act_deterministically:
367+
# Instead of uniform sampling, use a deterministic sequence of
368+
# equally spaced numbers from 0 to 1 as quantile thresholds.
369+
taus_tilde = self.xp.broadcast_to(
370+
self.xp.linspace(
371+
0, 1, num=self.quantile_thresholds_K, dtype=np.float32),
372+
(len(batch_obs), self.quantile_thresholds_K),
373+
)
374+
else:
375+
taus_tilde = self.xp.random.uniform(
376+
0, 1,
377+
size=(len(batch_obs), self.quantile_thresholds_K)).astype('f')
363378
return tau2av(taus_tilde)

tests/agents_tests/test_double_iqn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):
4646
replay_start_size=100, target_update_interval=100,
4747
quantile_thresholds_N=self.quantile_thresholds_N,
4848
quantile_thresholds_N_prime=self.quantile_thresholds_N_prime,
49+
act_deterministically=True,
4950
)
5051

5152

@@ -76,4 +77,5 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):
7677
quantile_thresholds_N=32,
7778
quantile_thresholds_N_prime=32,
7879
recurrent=True,
80+
act_deterministically=True,
7981
)

tests/agents_tests/test_iqn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):
4949
replay_start_size=100, target_update_interval=100,
5050
quantile_thresholds_N=self.quantile_thresholds_N,
5151
quantile_thresholds_N_prime=self.quantile_thresholds_N_prime,
52+
act_deterministically=True,
5253
)
5354

5455

@@ -79,6 +80,7 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):
7980
quantile_thresholds_N=32,
8081
quantile_thresholds_N_prime=32,
8182
recurrent=True,
83+
act_deterministically=True,
8284
)
8385

8486

0 commit comments

Comments
 (0)