Skip to content
This repository was archived by the owner on Oct 7, 2024. It is now read-only.

Commit f2e6c21

Browse files
mtthsscopybara-github
authored andcommitted
Port boot_dqn from optix to optax
PiperOrigin-RevId: 330726981 Change-Id: Ib7fa620772c45cc4746627b271acc2dbb5455283
1 parent 828f9bf commit f2e6c21

File tree

1 file changed

+2
-2
lines changed
  • bsuite/baselines/jax/boot_dqn

1 file changed

+2
-2
lines changed

bsuite/baselines/jax/boot_dqn/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
import haiku as hk
3030
from jax import lax
31-
from jax.experimental import optix
3231
import jax.numpy as jnp
32+
import optax
3333

3434
# Internal imports.
3535

@@ -70,7 +70,7 @@ def network(inputs: jnp.ndarray) -> jnp.ndarray:
7070
x = hk.Flatten()(inputs)
7171
return net(x) + prior_scale * lax.stop_gradient(prior_net(x))
7272

73-
optimizer = optix.adam(learning_rate=1e-3)
73+
optimizer = optax.adam(learning_rate=1e-3)
7474

7575
agent = boot_dqn.BootstrappedDqn(
7676
obs_spec=env.observation_spec(),

0 commit comments

Comments
 (0)