-
Notifications
You must be signed in to change notification settings - Fork 225
Allow envs to send a 'needs_reset' signal #356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
19210cd
7880de6
e5d2743
bac52cd
bb03699
989f03d
194ec71
c7ef013
a20703c
2d6c298
54e1039
bf84592
80b8ad5
e27f22a
8121c74
b7aaaed
623e8e4
5958e41
af6224a
a440022
a578899
0d8f52b
c959492
ffdfa05
9dc2db0
dc1ec4c
7127b29
66165fa
924e3b2
eb9f3d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -58,7 +58,7 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None, | |
| agent.t = step_offset | ||
|
|
||
| try: | ||
| while t < steps: | ||
| while True: | ||
| # a_t | ||
| actions = agent.batch_act_and_train(obss) | ||
| # o_{t+1}, r_{t+1} | ||
|
|
@@ -71,6 +71,8 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None, | |
| resets = np.zeros(num_envs, dtype=bool) | ||
| else: | ||
| resets = (episode_len == max_episode_len) | ||
| resets = np.logical_or( | ||
| resets, [info.get('needs_reset', False) for info in infos]) | ||
| # Agent observes the consequences | ||
| agent.batch_observe_and_train(obss, rs, dones, resets) | ||
|
|
||
|
|
@@ -84,11 +86,9 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None, | |
| # 3. clear the record of rewards | ||
| # 4. clear the record of the number of steps | ||
| # 5. reset the env to start a new episode | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should these comments be revised?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think these comments are still correct except that 3-5 are skipped when training is finished. I'll clarify this in the comments. |
||
| # 3-5 are skipped when training is already finished. | ||
| episode_idx += end | ||
| recent_returns.extend(episode_r[end]) | ||
| episode_r[end] = 0 | ||
| episode_len[end] = 0 | ||
| obss = env.reset(not_end) | ||
|
|
||
| for _ in range(num_envs): | ||
| t += 1 | ||
|
|
@@ -114,6 +114,14 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None, | |
| evaluator.max_score >= successful_score): | ||
| break | ||
|
|
||
| if t >= steps: | ||
| break | ||
|
|
||
| # Start new episodes if needed | ||
| episode_r[end] = 0 | ||
| episode_len[end] = 0 | ||
| obss = env.reset(not_end) | ||
|
|
||
| except (Exception, KeyboardInterrupt): | ||
| # Save the current model before being killed | ||
| save_agent(agent, t, outdir, logger, suffix='_except') | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,8 @@ | |
|
|
||
| from gym import spaces | ||
|
|
||
| import chainerrl | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since you're changing the file, perhaps we should change the header from "This file is a fork from a MIT-licensed project" to "This file adapted from an MIT-licensed project..."
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "fork" means it already has changes, no? |
||
|
|
||
| cv2.ocl.setUseOpenCL(False) | ||
|
|
||
|
|
||
|
|
@@ -36,8 +38,8 @@ def _reset(self, **kwargs): | |
| assert noops > 0 | ||
| obs = None | ||
| for _ in range(noops): | ||
| obs, _, done, _ = self.env.step(self.noop_action) | ||
| if done: | ||
| obs, _, done, info = self.env.step(self.noop_action) | ||
| if done or info.get('needs_reset', False): | ||
| obs = self.env.reset(**kwargs) | ||
| return obs | ||
|
|
||
|
|
@@ -54,11 +56,11 @@ def __init__(self, env): | |
|
|
||
| def _reset(self, **kwargs): | ||
| self.env.reset(**kwargs) | ||
| obs, _, done, _ = self.env.step(1) | ||
| if done: | ||
| obs, _, done, info = self.env.step(1) | ||
| if done or info.get('needs_reset', False): | ||
| self.env.reset(**kwargs) | ||
| obs, _, done, _ = self.env.step(2) | ||
| if done: | ||
| obs, _, done, info = self.env.step(2) | ||
| if done or info.get('needs_reset', False): | ||
| self.env.reset(**kwargs) | ||
| return obs | ||
|
|
||
|
|
@@ -74,11 +76,11 @@ def __init__(self, env): | |
| """ | ||
| gym.Wrapper.__init__(self, env) | ||
| self.lives = 0 | ||
| self.was_real_done = True | ||
| self.needs_real_reset = True | ||
|
|
||
| def _step(self, action): | ||
| obs, reward, done, info = self.env.step(action) | ||
| self.was_real_done = done | ||
| self.needs_real_reset = done or info.get('needs_reset', False) | ||
| # check current lives, make loss of life terminal, | ||
| # then update lives to handle bonus lives | ||
| lives = self.env.unwrapped.ale.lives() | ||
|
|
@@ -97,7 +99,7 @@ def _reset(self, **kwargs): | |
| This way all states are still reachable even though lives are episodic, | ||
| and the learner need not know about any of this behind-the-scenes. | ||
| """ | ||
| if self.was_real_done: | ||
| if self.needs_real_reset: | ||
| obs = self.env.reset(**kwargs) | ||
| else: | ||
| # no-op step to advance from terminal/lost life state | ||
|
|
@@ -126,7 +128,7 @@ def _step(self, action): | |
| if i == self._skip - 1: | ||
| self._obs_buffer[1] = obs | ||
| total_reward += reward | ||
| if done: | ||
| if done or info.get('needs_reset', False): | ||
| break | ||
| # Note that the observation on the done=True frame | ||
| # doesn't matter | ||
|
|
@@ -238,9 +240,14 @@ def __array__(self, dtype=None): | |
| return out | ||
|
|
||
|
|
||
| def make_atari(env_id): | ||
| def make_atari(env_id, max_frames=30 * 60 * 60): | ||
| env = gym.make(env_id) | ||
| assert 'NoFrameskip' in env.spec.id | ||
| assert isinstance(env, gym.wrappers.TimeLimit) | ||
| # Unwrap TimeLimit wrapper because we use our own time limits | ||
| env = env.env | ||
| env = chainerrl.wrappers.ContinuingTimeLimit( | ||
| env, max_episode_steps=max_frames) | ||
| env = NoopResetEnv(env, noop_max=30) | ||
| env = MaxAndSkipEnv(env, skip=4) | ||
| return env | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| from __future__ import unicode_literals | ||
| from __future__ import print_function | ||
| from __future__ import division | ||
| from __future__ import absolute_import | ||
| from builtins import * # NOQA | ||
| from future import standard_library | ||
| standard_library.install_aliases() # NOQA | ||
|
|
||
| import gym | ||
|
|
||
|
|
||
| class ContinuingTimeLimit(gym.Wrapper): | ||
| """TimeLimit wrapper for continuing environments. | ||
|
|
||
| This is similar to gym.wrappers.TimeLimit, which sets a time limit for | ||
| each episode, except that done=False is returned and that | ||
| info['needs_reset'] is set to True when past the limit. | ||
|
|
||
| Code that calls env.step is responsible for checking the info dict, the | ||
| fourth returned value, and resetting the env if it has the 'needs_reset' | ||
| key and its value is True. | ||
|
|
||
| Args: | ||
| env (gym.Env): Env to wrap. | ||
| max_episode_steps (int): Maximum number of timesteps during an episode, | ||
| after which the env needs a reset. | ||
| """ | ||
|
|
||
| def __init__(self, env, max_episode_steps): | ||
| super(ContinuingTimeLimit, self).__init__(env) | ||
| self._max_episode_steps = max_episode_steps | ||
|
|
||
| self._elapsed_steps = None | ||
|
|
||
| def step(self, action): | ||
| assert self._elapsed_steps is not None,\ | ||
| "Cannot call env.step() before calling reset()" | ||
| observation, reward, done, info = self.env.step(action) | ||
| self._elapsed_steps += 1 | ||
|
|
||
| if self._max_episode_steps <= self._elapsed_steps: | ||
| info['needs_reset'] = True | ||
|
|
||
| return observation, reward, done, info | ||
|
|
||
| def reset(self): | ||
| self._elapsed_steps = 0 | ||
| return self.env.reset() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Later in this "if" statement (https://github.com/chainer/chainerrl/pull/356/files#diff-a2caf3ec0e2750a1d16edb375789daa5R81), you reset the environment. Why do you reset the environment if done is true? What if reset = False?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even if
`reset == False`, we need to call `env.reset()` if `done == True`. In other words, the `reset` variable can be false when we reset the env due to `done == True`. It is possible to rename `reset` as `non_done_reset` or something, but it would be verbose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this is targeted primarily at environments that reset based off of a
`max-episode-length` or via `done`, and not environments where `done` is `True` but you still do not reset the environment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. Currently, ChainerRL assumes that
`env.reset` must be called when `done == True`.