Added Atari DQN implementation as well.

Thanks for this PR! Trying to run

https://wandb.ai/costa-huang/cleanRL/runs/2ymc7qnx?workspace=user-costa-huang I did a quick run with DQN Atari, but it cannot replicate the same level of performance as the torch version. Would you mind looking into it?
Thanks for this PR. I looked into it a bit more. There are two complications:

**Image dimensions (NCHW vs NHWC)**

A pre-processed Atari game image has height H=84, width W=84, channels C=4 from the frame stack, and a batch dimension N=1. PyTorch convolutions expect NCHW input, while Flax convolutions expect NHWC. We can print out the models in the current implementation to compare and confirm this issue:

```python
# in dqn_atari.py
from torchsummary import summary
summary(q_network, (4, 84, 84))

# in dqn_atari_jax.py
print(q_network.tabulate(q_key, obs))
```

We need to fix the image input format with Flax by adding a transpose:

```python
@nn.compact
def __call__(self, x):
+   x = jnp.transpose(x, (0, 2, 3, 1))
    x = x / (255.0)
    x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4))(x)
    x = nn.relu(x)
    x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2))(x)
    x = nn.relu(x)
    x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1))(x)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(512)(x)
    x = nn.relu(x)
    x = nn.Dense(self.action_dim)(x)
    return x
```

**Padding**

After fixing the image format issue, the two models still don't match exactly. It looks like Flax's `nn.Conv` defaults to `padding='SAME'`, while PyTorch's `Conv2d` uses no padding by default (the equivalent of `'VALID'`). I looked into the reputable dqn_zoo implementation and found they use `padding='VALID'`. Implementing that:
```diff
@nn.compact
def __call__(self, x):
    x = jnp.transpose(x, (0, 2, 3, 1))
    x = x / (255.0)
-   x = nn.Conv(32, kernel_size=(8, 8), strides=4)(x)
+   x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding='VALID')(x)
    x = nn.relu(x)
-   x = nn.Conv(64, kernel_size=(4, 4), strides=2)(x)
+   x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding='VALID')(x)
    x = nn.relu(x)
-   x = nn.Conv(64, kernel_size=(3, 3), strides=1)(x)
+   x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(512)(x)
    x = nn.relu(x)
    x = nn.Dense(self.action_dim)(x)
    return x
```
The results look great so far. One other thing: would you mind modifying the script to use the `TrainState` API?
Thanks for looking into the issue. I completely missed the channel-location preprocessing; I didn't know SB3 preprocessed it for PyTorch. I've updated the code to use the `TrainState` API, but I'm not able to use a jitted apply function from `state.apply_fn`, even if I replace `apply_fn` after the jit operation as below:

```python
q_state = TrainState.create(
    apply_fn=q_network.apply,
    params=q_network.init(q_key, obs),
    target_params=q_network.init(q_key, obs),
    tx=optax.adam(learning_rate=args.learning_rate),
)
q_network.apply = jax.jit(q_network.apply)
q_state = q_state.replace(apply_fn=q_network.apply)
```

And directly applying the jit operation during the creation of `TrainState` leads to errors during the initialization of parameters.
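For what it's worth, the pattern in most Flax examples is to jit a top-level function that takes the params explicitly, so the jitted code never needs to capture the module instance or live inside `apply_fn`. A sketch under that assumption (the tiny `apply_fn` here is a hypothetical stand-in for `q_network.apply`, not the PR's actual network):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for q_network.apply: any pure function of (params, inputs).
def apply_fn(params, x):
    return x @ params["w"] + params["b"]

# Jit a top-level function that takes params explicitly, instead of
# mutating the module's bound `apply` method after TrainState.create.
@jax.jit
def greedy_action(params, obs):
    q_values = apply_fn(params, obs)
    return jnp.argmax(q_values, axis=-1)

params = {"w": jnp.zeros((4, 2)).at[0, 1].set(1.0), "b": jnp.zeros(2)}
obs = jnp.ones((1, 4))
print(greedy_action(params, obs))  # -> [1]
```

Since `TrainState` is a pytree, the same idea extends to jitting a whole update function that takes the state as an argument and returns a new one.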
Maybe not, we can probably just stick with

cc @yooceii
Scratch that. Let me just run the experiments; it shouldn't take that long.

@kinalmehta I got this error. See https://wandb.ai/openrlbenchmark/cleanrl/runs/1joyhhtw/logs?workspace=user-costa-huang

Nvm, sorry to disturb you @kinalmehta. There was a zombie process that took over the GPU memory, and once I removed that process things started to work again.
Hey @kinalmehta @yooceii, I did another round of quick benchmarks and found that jitting the action sampling slows down throughput (* see the explanation below). See the report. The reason #231 (comment) found jitting faster is that the non-jitted baseline ( Can @kinalmehta and @yooceii confirm my findings? The difference could also stem from hardware differences. Namely, please compare the SPS at 200k steps for

Thank you.
vwxyzjn left a comment:
Everything LGTM. Thanks @kinalmehta!
* Prototype JAX + DQN
* formatting changes
* bug fix: predicted q value in mse
* Prototype JAX + DQN + Atari
* formatting changes
* Fix `UNKNOWN: CUDNN_STATUS_EXECUTION`
* update mse loss calculation to be (target-pred) instead of (pred-target)
* Fix image format and Conv padding
* Adapting to the TrainState API
* Add assets
* Add my benchmark script
* fix benchmark script embed: it was pointing to c51, fixed it to point to dqn
* docs: add DQN + JAX documentation
* jit action selection and linear_schedule
* docs fix
* update docs
* change documentation addr
* add test cases
* update ci
* Add warning on installing jax on windows
* fix pre-commit
* revert back changes
* update benchmark scripts
* Add docs
* update docs

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Description

JAX implementation for DQN.
Implementation for #220

Types of changes

Checklist:

- `pre-commit run --all-files` passes (required).
- Documentation renders correctly with `mkdocs serve`.

If you are adding new algorithms or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

- Experiments tracked with the `--capture-video` flag toggled on (required).
- Documentation renders correctly with `mkdocs serve`.
- Plots added (with `width=500` and `height=300`).