Skip to content
This repository was archived by the owner on Oct 7, 2024. It is now read-only.

Commit 5116216

Browse files
iosband authored and copybara-github committed
Make number of bandit arms configurable.
Default kwargs keep bsuite the same. PiperOrigin-RevId: 355511900 Change-Id: I0d674e9650ec24f5b9f9932219e66451df3e0fca
1 parent b4c9fed commit 5116216

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

bsuite/environments/bandit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,19 @@
3131
class SimpleBandit(base.Environment):
3232
"""SimpleBandit environment."""
3333

34-
def __init__(self, mapping_seed: int = None):
34+
def __init__(self, mapping_seed: int = None, num_actions: int = 11):
3535
"""Builds a simple bandit environment.
3636
3737
Args:
3838
mapping_seed: Optional integer. Seed for action mapping.
39+
num_actions: number of actions available, defaults to 11.
3940
"""
4041
super(SimpleBandit, self).__init__()
4142
self._rng = np.random.RandomState(mapping_seed)
42-
43-
self._n_actions = 11
43+
self._num_actions = num_actions
4444
action_mask = self._rng.choice(
45-
range(self._n_actions), size=self._n_actions, replace=False)
46-
self._rewards = np.linspace(0, 1, self._n_actions)[action_mask]
45+
range(self._num_actions), size=self._num_actions, replace=False)
46+
self._rewards = np.linspace(0, 1, self._num_actions)[action_mask]
4747

4848
self._total_regret = 0.
4949
self._optimal_return = 1.
@@ -66,7 +66,7 @@ def observation_spec(self):
6666
return specs.Array(shape=(1, 1), dtype=np.float32)
6767

6868
def action_spec(self):
69-
return specs.DiscreteArray(self._n_actions, name='action')
69+
return specs.DiscreteArray(self._num_actions, name='action')
7070

7171
def bsuite_info(self):
7272
return dict(total_regret=self._total_regret)

bsuite/experiments/bandit_noise/bandit_noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from bsuite.utils import wrappers
2626

2727

28-
def load(noise_scale, seed, mapping_seed):
28+
def load(noise_scale, seed, mapping_seed, num_actions=11):
2929
"""Load a bandit_noise experiment with the prescribed settings."""
3030
env = wrappers.RewardNoise(
31-
env=bandit.SimpleBandit(mapping_seed=mapping_seed),
31+
env=bandit.SimpleBandit(mapping_seed, num_actions=num_actions),
3232
noise_scale=noise_scale,
3333
seed=seed)
3434
env.bsuite_num_episodes = sweep.NUM_EPISODES

0 commit comments

Comments
 (0)