Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
19210cd
Reset if info['needs_reset'] is True
muupan Nov 11, 2018
7880de6
Support reset with done=False by frames
muupan Nov 11, 2018
e5d2743
Move ContinuingTimeLimit to a new file
muupan Nov 16, 2018
bac52cd
Clean ContinuingTimeLimit
muupan Nov 16, 2018
bb03699
Add ContinuingTimeLimit to __init__.py
muupan Nov 16, 2018
989f03d
Add tests for ContinuingTimeLimit
muupan Nov 16, 2018
194ec71
Use chainerrl.wrappers.ContinuingTimeLimit
muupan Nov 16, 2018
c7ef013
Update batch_run_evaluation_episodes to support needs_reset
muupan Nov 16, 2018
a20703c
Update train_agent_batch to support needs_reset
muupan Nov 16, 2018
2d6c298
Test train_agent with needs_reset
muupan Nov 16, 2018
54e1039
Avoid unnecessary resetting
muupan Nov 16, 2018
bf84592
Test evaluation runs with needs_reset
muupan Nov 16, 2018
80b8ad5
Avoid unnecessary resetting in batch training
muupan Nov 16, 2018
e27f22a
Update test due to avoiding unnecessary reset
muupan Nov 16, 2018
8121c74
Test train_agent_batch with needs_reset
muupan Nov 16, 2018
b7aaaed
Make logic of train_agent_async similar to train_agent
muupan Nov 16, 2018
623e8e4
Test train_agent_async with needs_reset
muupan Nov 16, 2018
5958e41
Simplify
muupan Nov 20, 2018
af6224a
Simplify
muupan Nov 20, 2018
a440022
Correct comments
muupan Nov 20, 2018
a578899
Test the case where the last state in training is terminal
muupan Nov 20, 2018
0d8f52b
Specify --max-frames instead of --max-episode-len
muupan Nov 20, 2018
c959492
Remove no longer used arg
muupan Nov 20, 2018
ffdfa05
Add comment on avoiding unnecessary env resetting
muupan Dec 7, 2018
9dc2db0
Fix typos
muupan Dec 7, 2018
dc1ec4c
Apply same changes to other examples
muupan Dec 7, 2018
7127b29
Merge branch 'master' into continuing-time-limit
muupan Dec 7, 2018
66165fa
Apply same change
muupan Dec 7, 2018
924e3b2
Merge branch 'fix-unicode-error' into continuing-time-limit
muupan Dec 9, 2018
eb9f3d2
Merge branch 'master' into continuing-time-limit
muupan Dec 9, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions chainerrl/experiments/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def run_evaluation_episodes(env, agent, n_runs, max_episode_len=None,
done = False
test_r = 0
t = 0
while not (done or t == max_episode_len):
info = {}
while not (done
or t == max_episode_len
or info.get('needs_reset', False)):
a = agent.act(obs)
obs, r, done, info = env.step(a)
test_r += r
Expand Down Expand Up @@ -101,7 +104,7 @@ def batch_run_evaluation_episodes(
obss = env.reset()
rs = np.zeros(num_envs, dtype='f')

while len(episode_returns) < n_runs:
while True:
# a_t
actions = agent.batch_act(obss)
# o_{t+1}, r_{t+1}
Expand All @@ -114,6 +117,9 @@ def batch_run_evaluation_episodes(
resets = np.zeros(num_envs, dtype=bool)
else:
resets = (episode_len == max_episode_len)
resets = np.logical_or(
resets, [info.get('needs_reset', False) for info in infos])

# Agent observes the consequences
agent.batch_observe(obss, rs, dones, resets)

Expand All @@ -123,6 +129,10 @@ def batch_run_evaluation_episodes(

episode_returns.extend(episode_r[end])
episode_lengths.extend(episode_len[end])

if len(episode_returns) >= n_runs:
break

episode_r[end] = 0
episode_len[end] = 0
obss = env.reset(not_end)
Expand Down
4 changes: 3 additions & 1 deletion chainerrl/experiments/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def train_agent(agent, env, steps, outdir, max_episode_len=None,
for hook in step_hooks:
hook(env, agent, t)

if done or episode_len == max_episode_len or t == steps:
reset = (episode_len == max_episode_len
or info.get('needs_reset', False))
if done or reset or t == steps:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later in this "if" statement (https://github.com/chainer/chainerrl/pull/356/files#diff-a2caf3ec0e2750a1d16edb375789daa5R81), you reset the environment. Why do you reset the environment if done is true? What if reset = False?

Copy link
Copy Markdown
Member Author

@muupan muupan Nov 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if reset == False, we need to call env.reset() if done == True. In other words, the reset variable can be false when we reset the env due to done == True. It is possible to rename reset as non_done_reset or something, but it would be verbose.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is targeted primarily at environments that reset based on a maximum episode length or via done, and not environments where done is True but you still do not reset the environment?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Currently, ChainerRL assumes that env.reset must be called when done==True.

agent.stop_episode_and_train(obs, r, done=done)
logger.info('outdir:%s step:%s episode:%s R:%s',
outdir, t, episode_idx, episode_r)
Expand Down
52 changes: 29 additions & 23 deletions chainerrl/experiments/train_agent_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def train_loop(process_idx, env, agent, steps, outdir, counter,

try:

total_r = 0
episode_r = 0
global_t = 0
local_t = 0
Expand All @@ -41,19 +40,34 @@ def train_loop(process_idx, env, agent, steps, outdir, counter,

while True:

total_r += r
# a_t
a = agent.act_and_train(obs, r)
# o_{t+1}, r_{t+1}
obs, r, done, info = env.step(a)
local_t += 1
episode_r += r
episode_len += 1

if done or episode_len == max_episode_len:
with episodes_counter.get_lock():
episodes_counter.value += 1
global_episodes = episodes_counter.value
# Get and increment the global counter
with counter.get_lock():
counter.value += 1
global_t = counter.value

for hook in global_step_hooks:
hook(env, agent, global_t)

reset = (episode_len == max_episode_len
or info.get('needs_reset', False))
if done or reset or global_t >= steps or training_done.value:
agent.stop_episode_and_train(obs, r, done)

if process_idx == 0:
logger.info(
'outdir:%s global_step:%s local_step:%s R:%s',
outdir, global_t, local_t, episode_r)
logger.info('statistics:%s', agent.get_statistics())

# Evaluate the current agent
if evaluator is not None:
eval_score = evaluator.evaluate_if_necessary(
t=global_t, episodes=global_episodes,
Expand All @@ -68,28 +82,20 @@ def train_loop(process_idx, env, agent, steps, outdir, counter,
# Break immediately in order to avoid an additional
# call of agent.act_and_train
break
episode_r = 0
obs = env.reset()
r = 0
done = False
episode_len = 0
else:
a = agent.act_and_train(obs, r)
obs, r, done, info = env.step(a)

# Get and increment the global counter
with counter.get_lock():
counter.value += 1
global_t = counter.value
local_t += 1
episode_len += 1

for hook in global_step_hooks:
hook(env, agent, global_t)
with episodes_counter.get_lock():
episodes_counter.value += 1
global_episodes = episodes_counter.value

if global_t >= steps or training_done.value:
break

# Start a new episode
episode_r = 0
episode_len = 0
obs = env.reset()
r = 0

except (Exception, KeyboardInterrupt):
if process_idx == 0:
# Save the current model before being killed
Expand Down
16 changes: 12 additions & 4 deletions chainerrl/experiments/train_agent_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
agent.t = step_offset

try:
while t < steps:
while True:
# a_t
actions = agent.batch_act_and_train(obss)
# o_{t+1}, r_{t+1}
Expand All @@ -71,6 +71,8 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
resets = np.zeros(num_envs, dtype=bool)
else:
resets = (episode_len == max_episode_len)
resets = np.logical_or(
resets, [info.get('needs_reset', False) for info in infos])
# Agent observes the consequences
agent.batch_observe_and_train(obss, rs, dones, resets)

Expand All @@ -84,11 +86,9 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
# 3. clear the record of rewards
# 4. clear the record of the number of steps
# 5. reset the env to start a new episode
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these comments be revised?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these comments are still correct except that 3-5 are skipped when training is finished. I'll clarify this in the comments.

# 3-5 are skipped when training is already finished.
episode_idx += end
recent_returns.extend(episode_r[end])
episode_r[end] = 0
episode_len[end] = 0
obss = env.reset(not_end)

for _ in range(num_envs):
t += 1
Expand All @@ -114,6 +114,14 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
evaluator.max_score >= successful_score):
break

if t >= steps:
break

# Start new episodes if needed
episode_r[end] = 0
episode_len[end] = 0
obss = env.reset(not_end)

except (Exception, KeyboardInterrupt):
# Save the current model before being killed
save_agent(agent, t, outdir, logger, suffix='_except')
Expand Down
2 changes: 2 additions & 0 deletions chainerrl/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from chainerrl.wrappers.cast_observation import CastObservation # NOQA
from chainerrl.wrappers.cast_observation import CastObservationToFloat32 # NOQA

from chainerrl.wrappers.continuing_time_limit import ContinuingTimeLimit # NOQA

from chainerrl.wrappers.randomize_action import RandomizeAction # NOQA

from chainerrl.wrappers.render import Render # NOQA
Expand Down
29 changes: 18 additions & 11 deletions chainerrl/wrappers/atari_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from gym import spaces

import chainerrl
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're changing the file, perhaps we should change the header from "This file is a fork from a MIT-licensed project" to "This file adapted from an MIT-licensed project..."

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"fork" means it already has changes, no?


cv2.ocl.setUseOpenCL(False)


Expand All @@ -36,8 +38,8 @@ def _reset(self, **kwargs):
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
if done:
obs, _, done, info = self.env.step(self.noop_action)
if done or info.get('needs_reset', False):
obs = self.env.reset(**kwargs)
return obs

Expand All @@ -54,11 +56,11 @@ def __init__(self, env):

def _reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
if done:
obs, _, done, info = self.env.step(1)
if done or info.get('needs_reset', False):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(2)
if done:
obs, _, done, info = self.env.step(2)
if done or info.get('needs_reset', False):
self.env.reset(**kwargs)
return obs

Expand All @@ -74,11 +76,11 @@ def __init__(self, env):
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
self.needs_real_reset = True

def _step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
self.needs_real_reset = done or info.get('needs_reset', False)
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
Expand All @@ -97,7 +99,7 @@ def _reset(self, **kwargs):
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
if self.needs_real_reset:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
Expand Down Expand Up @@ -126,7 +128,7 @@ def _step(self, action):
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += reward
if done:
if done or info.get('needs_reset', False):
break
# Note that the observation on the done=True frame
# doesn't matter
Expand Down Expand Up @@ -238,9 +240,14 @@ def __array__(self, dtype=None):
return out


def make_atari(env_id):
def make_atari(env_id, max_frames=30 * 60 * 60):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
assert isinstance(env, gym.wrappers.TimeLimit)
# Unwrap TimeLimit wrapper because we use our own time limits
env = env.env
env = chainerrl.wrappers.ContinuingTimeLimit(
env, max_episode_steps=max_frames)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
return env
Expand Down
48 changes: 48 additions & 0 deletions chainerrl/wrappers/continuing_time_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import gym


class ContinuingTimeLimit(gym.Wrapper):
    """TimeLimit wrapper for continuing environments.

    This is similar to gym.wrappers.TimeLimit, which sets a time limit for
    each episode, except that done=False is returned and that
    info['needs_reset'] is set to True when past the limit.

    Code that calls env.step is responsible for checking the info dict, the
    fourth returned value, and resetting the env if it has the 'needs_reset'
    key and its value is True.

    Args:
        env (gym.Env): Env to wrap.
        max_episode_steps (int): Maximum number of timesteps during an
            episode, after which the env needs a reset.
    """

    def __init__(self, env, max_episode_steps):
        super(ContinuingTimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps

        # None until reset() is first called; used to reject step() calls
        # on an env that has never been reset.
        self._elapsed_steps = None

    def step(self, action):
        """Step the wrapped env, flagging info['needs_reset'] at the limit.

        Returns the wrapped env's (observation, reward, done, info), with
        info['needs_reset'] set to True once the number of elapsed steps
        reaches max_episode_steps. done is deliberately left untouched so
        that continuing tasks are not treated as terminating.
        """
        assert self._elapsed_steps is not None,\
            "Cannot call env.step() before calling reset()"
        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1

        if self._max_episode_steps <= self._elapsed_steps:
            info['needs_reset'] = True

        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and the elapsed step counter.

        Keyword arguments are forwarded to the wrapped env's reset() so
        that outer wrappers which pass reset options (e.g. NoopResetEnv's
        env.reset(**kwargs)) keep working through this wrapper.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
9 changes: 4 additions & 5 deletions examples/ale/train_a2c_ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def main():
parser.add_argument('--seed', type=int, default=0,
help='Random seed [0, 2 ** 31)')
parser.add_argument('--outdir', type=str, default='results')
parser.add_argument('--max-episode-len', type=int,
default=30 * 60 * 60 // 4, # 30 minutes with 60/4 fps
help='Maximum number of steps for each episode.')
parser.add_argument('--max-frames', type=int,
default=30 * 60 * 60, # 30 minutes with 60 fps
help='Maximum number of frames for each episode.')
parser.add_argument('--steps', type=int, default=8 * 10 ** 7)
parser.add_argument('--update-steps', type=int, default=5)
parser.add_argument('--lr', type=float, default=7e-4)
Expand Down Expand Up @@ -109,7 +109,7 @@ def make_env(process_idx, test):
process_seed = process_seeds[process_idx]
env_seed = 2 ** 31 - 1 - process_seed if test else process_seed
env = atari_wrappers.wrap_deepmind(
atari_wrappers.make_atari(args.env),
atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
episode_life=not test,
clip_rewards=not test)
env.seed(int(env_seed))
Expand Down Expand Up @@ -169,7 +169,6 @@ def make_batch_env(test):
eval_interval=args.eval_interval,
outdir=args.outdir,
save_best_so_far_agent=False,
max_episode_len=args.max_episode_len,
log_interval=1000,
)

Expand Down
9 changes: 4 additions & 5 deletions examples/ale/train_a3c_ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def main():
parser.add_argument('--beta', type=float, default=1e-2)
parser.add_argument('--profile', action='store_true')
parser.add_argument('--steps', type=int, default=8 * 10 ** 7)
parser.add_argument('--max-episode-len', type=int,
default=5 * 60 * 60 // 4, # 5 minutes with 60/4 fps
help='Maximum number of steps for each episode.')
parser.add_argument('--max-frames', type=int,
default=30 * 60 * 60, # 30 minutes with 60 fps
help='Maximum number of frames for each episode.')
parser.add_argument('--lr', type=float, default=7e-4)
parser.add_argument('--eval-interval', type=int, default=10 ** 6)
parser.add_argument('--eval-n-runs', type=int, default=10)
Expand Down Expand Up @@ -151,7 +151,7 @@ def make_env(process_idx, test):
process_seed = process_seeds[process_idx]
env_seed = 2 ** 31 - 1 - process_seed if test else process_seed
env = atari_wrappers.wrap_deepmind(
atari_wrappers.make_atari(args.env),
atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
episode_life=not test,
clip_rewards=not test)
env.seed(int(env_seed))
Expand Down Expand Up @@ -190,7 +190,6 @@ def lr_setter(env, agent, value):
steps=args.steps,
eval_n_runs=args.eval_n_runs,
eval_interval=args.eval_interval,
max_episode_len=args.max_episode_len,
global_step_hooks=[lr_decay_hook],
save_best_so_far_agent=False,
)
Expand Down
Loading