
Commit 42d21bd

PPO + JAX + EnvPool + Atari (#227)
* PPO + jax + envpool + atari
* fix bug: only report metric when lives are used up
* pre-commit
* quick fix
* Quick refactor
* push changes
* pre-commit and use EnvPool's new API
* update envpool
* update docs
* update ppo benchmark script
* update docs
* use the latest envpool interface
* update envpool to the latest version
* update pyproject.toml
* update lock files
* Quick clarification
* Update docs
* remove non-benchmarked script
* update docs
* revert poetry changes
* docs fix
* remove unnecessary code, add docs
* add a note on envpool
* update test cases
* explain `get_action_and_value`
* fix indent
* Fix weird error with `np.mean`. See #227 (comment); we got this message:

  ```
  NotImplementedError: Got <class 'jaxlib.xla_extension.DeviceArray'>, but numpy array, torch tensor, or caffe2 blob name are expected.
  ```

* update docs
* pre-commit
* add note on `charts/avg_episodic_return`
* update reproducibility script
* add note on value function clipping
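The `np.mean` fix mentioned in the commit message can be illustrated with a small sketch (the helper name `to_loggable` is hypothetical, not part of the cleanrl code): logging utilities that expect numpy arrays or torch tensors reject JAX `DeviceArray`s, so values are converted with `np.asarray` before aggregating.

```python
import numpy as np

def to_loggable(x):
    """Hypothetical helper: convert any array-like (including a JAX
    DeviceArray) to a plain Python float before handing it to a logger
    that only accepts numpy arrays or torch tensors."""
    return float(np.mean(np.asarray(x)))

print(to_loggable([1.0, 2.0, 3.0]))  # 2.0
```

`np.asarray` copies device-resident data to host memory, so the result is an ordinary numpy value that any downstream logger accepts.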
1 parent c20c799 commit 42d21bd

22 files changed: +157414 −32 lines

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion

```diff
@@ -193,7 +193,7 @@ jobs:

      # envpool tests
      - name: Install envpool dependencies
-        run: poetry install --with pytest,envpool
+        run: poetry install --with pytest,envpool,jax
      - name: Downgrade setuptools
        run: poetry run pip install setuptools==59.5.0
      - name: Run envpool tests
```

benchmark/ppo.sh

Lines changed: 18 additions & 0 deletions

```diff
@@ -71,3 +71,21 @@ xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
     --command "poetry run python cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py --track --capture-video --num-envs 8192 --num-steps 8 --update-epochs 5 --num-minibatches 4 --reward-scaler 0.01 --total-timesteps 600000000 --record-video-step-frequency 3660" \
     --num-seeds 3 \
     --workers 1
+
+
+poetry install --with envpool
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids Alien-v5 Amidar-v5 Assault-v5 Asterix-v5 Asteroids-v5 Atlantis-v5 BankHeist-v5 BattleZone-v5 BeamRider-v5 Berzerk-v5 Bowling-v5 Boxing-v5 Breakout-v5 Centipede-v5 ChopperCommand-v5 CrazyClimber-v5 Defender-v5 DemonAttack-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids DoubleDunk-v5 Enduro-v5 FishingDerby-v5 Freeway-v5 Frostbite-v5 Gopher-v5 Gravitar-v5 Hero-v5 IceHockey-v5 Jamesbond-v5 Kangaroo-v5 Krull-v5 KungFuMaster-v5 MontezumaRevenge-v5 MsPacman-v5 NameThisGame-v5 Phoenix-v5 Pitfall-v5 Pong-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids PrivateEye-v5 Qbert-v5 Riverraid-v5 RoadRunner-v5 Robotank-v5 Seaquest-v5 Skiing-v5 Solaris-v5 SpaceInvaders-v5 StarGunner-v5 Surround-v5 Tennis-v5 TimePilot-v5 Tutankham-v5 UpNDown-v5 Venture-v5 VideoPinball-v5 WizardOfWor-v5 YarsRevenge-v5 Zaxxon-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
```

cleanrl/ppo_atari_envpool.py

Lines changed: 2 additions & 14 deletions

```diff
@@ -84,13 +84,6 @@ def __init__(self, env, deque_size=100):
         self.num_envs = getattr(env, "num_envs", 1)
         self.episode_returns = None
         self.episode_lengths = None
-        # get if the env has lives
-        self.has_lives = False
-        env.reset()
-        info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
-        if info["lives"].sum() > 0:
-            self.has_lives = True
-            print("env has lives")

     def reset(self, **kwargs):
         observations = super().reset(**kwargs)
@@ -107,13 +100,8 @@ def step(self, action):
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
         self.returned_episode_lengths[:] = self.episode_lengths
-        all_lives_exhausted = infos["lives"] == 0
-        if self.has_lives:
-            self.episode_returns *= 1 - all_lives_exhausted
-            self.episode_lengths *= 1 - all_lives_exhausted
-        else:
-            self.episode_returns *= 1 - dones
-            self.episode_lengths *= 1 - dones
+        self.episode_returns *= 1 - infos["terminated"]
+        self.episode_lengths *= 1 - infos["terminated"]
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
         return (
```