
Categorical DQN (C51)#249

Merged
toslunar merged 52 commits into chainer:master from muupan:c51
Apr 6, 2018

Conversation

@muupan
Member

@muupan muupan commented Mar 16, 2018

  • Check performance on Atari
  • Add tests
    • tests of C51
    • tests of DistributionalDiscreteActionValue
  • Clean code
    • Implement its own __init__
    • Clarify differences from DQN in the docstring

Merge #248 first.

@muupan
Member Author

muupan commented Mar 28, 2018

Below are the results of examples/ale/train_c51_ale.py {rom} and examples/ale/train_dqn_ale.py {rom} --agent DQN/DoubleDQN/PAL, each with three random seeds. C51 achieves better scores than the other DQN variants across these games.

[Score plots for each game: asterix, beam_rider, breakout, pong, qbert, seaquest, space_invaders]

@muupan muupan changed the title [WIP] C51 [WIP] Categorical DQN (C51) Apr 2, 2018
@muupan
Member Author

muupan commented Apr 2, 2018

It is unclear to me whether the categorical projection should be implemented as Algorithm 1 or as Eq. (7) in the paper. Algorithm 1 seems wrong to me when b_j is an integer, so I handled the case where b_j is an integer separately.
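For reference, the edge case can be sketched as follows. This is a plain-NumPy illustration with a hypothetical helper name, not ChainerRL's actual code: it projects a distribution (atom values y with probabilities y_probs) onto the fixed support z, handling the integer b_j case separately.

```python
import numpy as np

def project_naive(y, y_probs, z):
    # Hypothetical illustration of the categorical projection step:
    # redistribute each atom's mass (y, y_probs) onto the fixed support z.
    n_atoms = len(z)
    v_min, v_max = z[0], z[-1]
    delta_z = (v_max - v_min) / (n_atoms - 1)
    proj = np.zeros(n_atoms)
    for yi, p in zip(y, y_probs):
        bj = (np.clip(yi, v_min, v_max) - v_min) / delta_z
        l, u = int(np.floor(bj)), int(np.ceil(bj))
        if l == u:
            # bj is an integer: Algorithm 1 would split the mass using the
            # weights (u - bj) == 0 and (bj - l) == 0, losing it entirely,
            # so assign all mass to atom l instead.
            proj[l] += p
        else:
            proj[l] += p * (u - bj)
            proj[u] += p * (bj - l)
    return proj
```

With a support z = linspace(0, 10, 11), an atom exactly at 5.0 keeps all of its mass on atom index 5, while an atom at 2.5 splits its mass evenly between indices 2 and 3.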

@muupan muupan changed the title [WIP] Categorical DQN (C51) Categorical DQN (C51) Apr 3, 2018
Member

@toslunar toslunar left a comment


Thanks a lot. I reviewed.

import chainer
from chainer import cuda
from chainer import functions as F

Member


Fix style: remove this empty line

(batch_size, n_atoms).
y_probs (ndarray): Probabilities of atoms whose values are y.
Its shape must be (batch_size, n_atoms).
z (ndarray): Values of atoms before projection after projection. Its
Member


Fix typo: Values of atoms before projection after projection.

for j in range(n_atoms - 1):
    if z[j] < yi <= z[j + 1]:
        proj_probs[b, j] += (z[j + 1] - yi) / delta_z * p
        proj_probs[b, j + 1] += (yi - z[j]) / delta_z * p
Member


Could you use delta_z = z[j + 1] - z[j] in this naive implementation?

scatter_add(
    z_probs.ravel(),
    (l.astype(xp.int32) + offset).ravel(),
    (y_probs * (u - bj)).ravel())
Member


Using (y_probs * (1 - (bj - l))).ravel() instead would eliminate the special treatment of the case l == u.
The authors' use of u = ceil(bj) in the paper seems to me to be no more than a way of avoiding "z[n_atoms] += 0".
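The suggested weighting can be sketched like this. It is a NumPy illustration with hypothetical names (ChainerRL's actual code uses xp and scatter_add); np.add.at plays the role of scatter_add here.

```python
import numpy as np

def project_vectorized(y, y_probs, z):
    # Sketch of the reviewer's suggestion: weight the l-side scatter by
    # 1 - (bj - l) instead of (u - bj). When bj is an integer, this weight
    # is exactly 1 and the u-side weight (bj - l) is 0, so no separate
    # handling of the case l == u is needed.
    n_atoms = len(z)
    v_min, v_max = z[0], z[-1]
    delta_z = (v_max - v_min) / (n_atoms - 1)
    bj = (np.clip(y, v_min, v_max) - v_min) / delta_z
    l = np.floor(bj).astype(np.int32)
    u = np.minimum(l + 1, n_atoms - 1)  # clip at the top edge of the support
    proj = np.zeros(n_atoms)
    np.add.at(proj, l, y_probs * (1.0 - (bj - l)))  # scatter_add analogue
    np.add.at(proj, u, y_probs * (bj - l))
    return proj
```

At the top edge (y == v_max), l is the last atom, the l-side weight is 1 and the u-side weight is 0, so clipping u at n_atoms - 1 only ever adds zero mass.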

Member Author


Wow, your solution is definitely better than mine!

n_atoms = 51
v_max = 500
v_min = 0
z_values = np.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)
Member


Could you consider moving this line into the package? (e.g. pass v_min, v_max, n_atoms to DistributionalFCStateQFunctionWithDiscreteAction.) z_values should be linspace anyway.

"""Compute a loss of categorical DQN."""
y, t = self._compute_y_and_t(exp_batch, gamma)
# minimize the cross entropy
eltwise_loss = -t * F.log(F.clip(y, 1e-10, 1.))
Member


Could you explain why F.clip is here?

Member Author


From earlier experiments, I found clipping was necessary for training CategoricalDQN. Without clipping, some probability values converge to 0, resulting in log(0) -> NaN.

Other unofficial implementations also apply clipping.
https://github.com/Kaixhin/Rainbow/blob/master/agent.py#L85
https://github.com/floringogianu/categorical-dqn/blob/master/policy_improvement/categorical_update.py#L53

Since clipping at 1e-10 worked, I didn't tune the value further. It is possible that larger values result in better performance.
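In NumPy terms, the clipped cross-entropy amounts to something like the following. This is a sketch with a hypothetical function name, using the same 1e-10 floor; the actual loss is computed with Chainer's F.clip and F.log.

```python
import numpy as np

def clipped_cross_entropy(y_probs, target_probs, eps=1e-10):
    # Clip the predicted atom probabilities away from 0 so that log()
    # stays finite even when an atom's probability has collapsed to
    # exactly 0 during training.
    clipped = np.clip(y_probs, eps, 1.0)
    return -np.sum(target_probs * np.log(clipped), axis=-1)
```

Without the clip, any target mass on a zero-probability atom contributes target * log(0) = -inf * nonzero, making the loss (and its gradient) NaN.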

Member Author


I added a comment to explain why.

muupan added 6 commits April 5, 2018 18:24
@muupan
Member Author

muupan commented Apr 5, 2018

Thanks for the review. I fixed them.

Member

@toslunar toslunar left a comment


LGTM

@toslunar toslunar merged commit e3e2c44 into chainer:master Apr 6, 2018
@muupan muupan added this to the v0.4 milestone Jul 23, 2018
