Skip to content
This repository was archived by the owner on Oct 7, 2024. It is now read-only.

Commit f9b74bf

Browse files
aslanides authored and copybara-github committed
Extract environments to their own package.
This makes the distinction between an _environment_ and an _experiment_ clearer. If users want to import individual environments for their own debugging/development:
  ✗ from bsuite.experiments.catch import catch
  ✓ from bsuite.environments import catch
This change also introduces some more formal typing of bsuite environments:
  - Add a base class which includes the bsuite_* attributes/methods.
PiperOrigin-RevId: 307575828
Change-Id: Iba2303d64a397ccef8a3f3f154e414bf343f905b
1 parent 6c12227 commit f9b74bf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1251
-948
lines changed

bsuite/baselines/experiment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def run(agent: base.Agent,
3838
"""
3939

4040
if verbose:
41-
environment = terminal_logging.wrap_environment(environment, log_every=True)
41+
environment = terminal_logging.wrap_environment(
42+
environment, log_every=True) # pytype: disable=wrong-arg-types
4243

4344
for _ in range(num_episodes):
4445
# Run an episode.

bsuite/baselines/third_party/dopamine_dqn/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def create_environment() -> gym.Env:
123123
"""Factory method for environment initialization in Dopamine."""
124124
env = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
125125
if FLAGS.verbose:
126-
env = terminal_logging.wrap_environment(env, log_every=True)
126+
env = terminal_logging.wrap_environment(env, log_every=True) # pytype: disable=wrong-arg-types
127127
env = gym_wrapper.GymFromDMEnv(env)
128128
env.game_over = False # Dopamine looks for this
129129
return env

bsuite/baselines/third_party/openai_dqn/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def run(bsuite_id: str) -> str:
7575
overwrite=FLAGS.overwrite,
7676
)
7777
if FLAGS.verbose:
78-
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
78+
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True) # pytype: disable=wrong-arg-types
7979
env = gym_wrapper.GymFromDMEnv(raw_env)
8080

8181
num_episodes = FLAGS.num_episodes or getattr(raw_env, 'bsuite_num_episodes')

bsuite/baselines/third_party/openai_ppo/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _load_env():
6666
overwrite=FLAGS.overwrite,
6767
)
6868
if FLAGS.verbose:
69-
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
69+
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True) # pytype: disable=wrong-arg-types
7070
return gym_wrapper.GymFromDMEnv(raw_env)
7171
env = dummy_vec_env.DummyVecEnv([_load_env])
7272

bsuite/bsuite.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Mapping, Tuple
2020

2121
from bsuite import sweep
22+
from bsuite.environments import base
2223
from bsuite.experiments.bandit import bandit
2324
from bsuite.experiments.bandit_noise import bandit_noise
2425
from bsuite.experiments.bandit_scale import bandit_scale
@@ -54,29 +55,29 @@
5455
# Each constructor or load function accepts keyword arguments as defined in
5556
# each experiment's sweep.py file.
5657
EXPERIMENT_NAME_TO_ENVIRONMENT = dict(
57-
bandit=bandit.SimpleBandit,
58+
bandit=bandit.load,
5859
bandit_noise=bandit_noise.load,
5960
bandit_scale=bandit_scale.load,
60-
cartpole=cartpole.Cartpole,
61+
cartpole=cartpole.load,
6162
cartpole_noise=cartpole_noise.load,
6263
cartpole_scale=cartpole_scale.load,
6364
cartpole_swingup=cartpole_swingup.CartpoleSwingup,
64-
catch=catch.Catch,
65+
catch=catch.load,
6566
catch_noise=catch_noise.load,
6667
catch_scale=catch_scale.load,
67-
deep_sea=deep_sea.DeepSea,
68+
deep_sea=deep_sea.load,
6869
deep_sea_stochastic=deep_sea_stochastic.load,
69-
discounting_chain=discounting_chain.DiscountingChain,
70+
discounting_chain=discounting_chain.load,
7071
memory_len=memory_len.load,
7172
memory_size=memory_size.load,
72-
mnist=mnist.MNISTBandit,
73+
mnist=mnist.load,
7374
mnist_noise=mnist_noise.load,
7475
mnist_scale=mnist_scale.load,
75-
mountain_car=mountain_car.MountainCar,
76+
mountain_car=mountain_car.load,
7677
mountain_car_noise=mountain_car_noise.load,
7778
mountain_car_scale=mountain_car_scale.load,
7879
umbrella_distract=umbrella_distract.load,
79-
umbrella_length=umbrella_length.UmbrellaChain,
80+
umbrella_length=umbrella_length.load,
8081
)
8182

8283

@@ -92,12 +93,12 @@ def unpack_bsuite_id(bsuite_id: str) -> Tuple[str, int]:
9293
def load(
9394
experiment_name: str,
9495
kwargs: Mapping[str, Any],
95-
) -> dm_env.Environment:
96+
) -> base.Environment:
9697
"""Returns a bsuite environment given an experiment name and settings."""
9798
return EXPERIMENT_NAME_TO_ENVIRONMENT[experiment_name](**kwargs)
9899

99100

100-
def load_from_id(bsuite_id: str) -> dm_env.Environment:
101+
def load_from_id(bsuite_id: str) -> base.Environment:
101102
"""Returns a bsuite environment given a bsuite_id."""
102103
kwargs = sweep.SETTINGS[bsuite_id]
103104
experiment_name, _ = unpack_bsuite_id(bsuite_id)

bsuite/environments/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Environments
2+
3+
This folder contains the raw *environments* used in `bsuite` experiments; we
4+
expose them here for debugging and development purposes.
5+
6+
Recall that in the context of bsuite, an *experiment* consists of three parts:
7+
1. Environments: a fixed set of environments determined by some parameters.
8+
2. Interaction: a fixed regime of agent/environment interaction (e.g. 100 episodes).
9+
3. Analysis: a fixed procedure that maps agent behaviour to results
10+
and plots.
11+
12+
Note: If you load the environment from this folder you will miss out on the
13+
interaction+analysis as specified by bsuite. In general, you should use the
14+
`bsuite_id` to load the environment via `bsuite.load_from_id(bsuite_id)` rather
15+
than the raw environment.

bsuite/experiments/memory_size/memory_size_test.py renamed to bsuite/environments/__init__.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
# ============================================================================
17-
"""Tests for bsuite.experiments.memory_len."""
17+
"""bsuite environments package."""
1818

19-
from absl.testing import absltest
20-
from bsuite.experiments.memory_size import memory_size
21-
from dm_env import test_utils
22-
import numpy as np
23-
24-
25-
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
26-
27-
def make_object_under_test(self):
28-
return memory_size.load(10)
29-
30-
def make_action_sequence(self):
31-
valid_actions = [0, 1]
32-
rng = np.random.RandomState(42)
33-
34-
for _ in range(100):
35-
yield rng.choice(valid_actions)
36-
37-
if __name__ == '__main__':
38-
absltest.main()
19+
from bsuite.environments.base import Environment

bsuite/environments/bandit.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# python3
2+
# pylint: disable=g-bad-file-header
3+
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# ============================================================================
17+
"""Simple diagnostic bandit environment.
18+
19+
Observation is a single pixel of 0 - this is an independent arm bandit problem!
20+
Rewards are [0, 0.1, .. 1] assigned randomly to 11 arms and deterministic
21+
"""
22+
23+
from bsuite.environments import base
24+
from bsuite.experiments.bandit import sweep
25+
26+
import dm_env
27+
from dm_env import specs
28+
import numpy as np
29+
30+
31+
class SimpleBandit(base.Environment):
32+
"""SimpleBandit environment."""
33+
34+
def __init__(self, seed=None):
35+
"""Builds a simple bandit environment.
36+
37+
Args:
38+
seed: Optional integer. Seed for numpy's random number generator (RNG).
39+
"""
40+
super(SimpleBandit, self).__init__()
41+
self._rng = np.random.RandomState(seed)
42+
43+
self._n_actions = 11
44+
action_mask = self._rng.choice(
45+
range(self._n_actions), size=self._n_actions, replace=False)
46+
self._rewards = np.linspace(0, 1, self._n_actions)[action_mask]
47+
48+
self._total_regret = 0.
49+
self._optimal_return = 1.
50+
self.bsuite_num_episodes = sweep.NUM_EPISODES
51+
52+
def _get_observation(self):
53+
return np.ones(shape=(1, 1), dtype=np.float32)
54+
55+
def _reset(self) -> dm_env.TimeStep:
56+
observation = self._get_observation()
57+
return dm_env.restart(observation)
58+
59+
def _step(self, action: int) -> dm_env.TimeStep:
60+
reward = self._rewards[action]
61+
self._total_regret += self._optimal_return - reward
62+
observation = self._get_observation()
63+
return dm_env.termination(reward=reward, observation=observation)
64+
65+
def observation_spec(self):
66+
return specs.Array(shape=(1, 1), dtype=np.float32)
67+
68+
def action_spec(self):
69+
return specs.DiscreteArray(self._n_actions, name='action')
70+
71+
def bsuite_info(self):
72+
return dict(total_regret=self._total_regret)

bsuite/experiments/bandit/bandit_test.py renamed to bsuite/environments/bandit_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""Tests for bsuite.environments.bandit."""
1818

1919
from absl.testing import absltest
20-
from bsuite.experiments.bandit import bandit
20+
from bsuite.environments import bandit
2121
from dm_env import test_utils
2222
import numpy as np
2323

bsuite/utils/auto_reset_environment.py renamed to bsuite/environments/base.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,65 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
# ============================================================================
17-
"""Auto-resetting environment base class.
17+
"""Base class for bsuite environments.
1818
19-
The environment API states that stepping an environment after a LAST timestep
20-
should return the first timestep of a new episode.
19+
This inherits from the dm_env base class, with two major differences:
20+
21+
- Includes bsuite-specific metadata:
22+
- `bsuite_info` returns metadata for logging, e.g. for computing regret/score.
23+
- `bsuite_num_episodes` specifies how long the experiment should run for.
24+
- Implements the auto-reset behavior specified by the environment API.
25+
That is, stepping an environment after a LAST timestep should return the
26+
first timestep of a new episode.
2127
"""
2228

2329
import abc
30+
from typing import Any, Dict
31+
2432
import dm_env
2533

2634

27-
class Base(dm_env.Environment):
28-
"""This class implements the required `step()` and `reset()` methods.
35+
class Environment(dm_env.Environment):
36+
"""Base class for bsuite environments.
37+
38+
A bsuite environment is a dm_env environment with extra metadata:
39+
- bsuite_info method.
40+
- bsuite_num_episodes attribute.
41+
42+
A bsuite environment also has auto-reset behavior.
43+
This class implements the required `step()` and `reset()` methods.
2944
3045
It instead requires users to implement `_step()` and `_reset()`. This class
3146
handles the reset behaviour automatically when it detects a LAST timestep.
3247
"""
3348

49+
# Number of episodes that this environment should be run for.
50+
bsuite_num_episodes: int
51+
3452
def __init__(self):
3553
self._reset_next_step = True
3654

37-
@abc.abstractmethod
38-
def _reset(self):
39-
"""Returns a `timestep` namedtuple as per the regular `reset()` method."""
40-
41-
@abc.abstractmethod
42-
def _step(self, action):
43-
"""Returns a `timestep` namedtuple as per the regular `step()` method."""
44-
45-
def reset(self):
55+
def reset(self) -> dm_env.TimeStep:
56+
"""Resets the environment, calling the underlying _reset() method."""
4657
self._reset_next_step = False
4758
return self._reset()
4859

49-
def step(self, action):
60+
def step(self, action: int) -> dm_env.TimeStep:
61+
"""Steps the environment and implements the auto-reset behavior."""
5062
if self._reset_next_step:
5163
return self.reset()
5264
timestep = self._step(action)
5365
self._reset_next_step = timestep.last()
5466
return timestep
67+
68+
@abc.abstractmethod
69+
def _reset(self) -> dm_env.TimeStep:
70+
"""Returns a `timestep` namedtuple as per the regular `reset()` method."""
71+
72+
@abc.abstractmethod
73+
def _step(self, action: int) -> dm_env.TimeStep:
74+
"""Returns a `timestep` namedtuple as per the regular `step()` method."""
75+
76+
@abc.abstractmethod
77+
def bsuite_info(self) -> Dict[str, Any]:
78+
"""Returns metadata specific to this environment for logging/scoring."""

0 commit comments

Comments
 (0)