Skip to content

Commit b7aa4e3

Browse files
kevin-j-millercopybara-github
authored andcommitted
Update split_dataset to pass batch_size correctly in "single" mode.
In this mode the batch size of the split datasets should not match that of the parent dataset, but the number of episodes that each one ends up with. PiperOrigin-RevId: 836430128
1 parent d9c0a7e commit b7aa4e3

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

disentangled_rnns/library/rnn_utils.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,15 @@ def __init__(
183183
####################
184184
# Property setting #
185185
####################
186-
# If batch size not specified, use all episodes in the dataset
186+
# If batch size not specified, use all episodes in a single batch
187187
if batch_size is None:
188188
batch_size = xs.shape[1]
189+
# In single-batch mode, batch size must match dataset size
190+
if batch_mode == 'single' and batch_size != xs.shape[1]:
191+
raise ValueError(
192+
'In single batch mode, match size must be equal to dataset size, or',
193+
f'must be None. Instead, is {batch_size}'
194+
)
189195

190196
self.x_names = x_names
191197
self.y_names = y_names
@@ -260,14 +266,19 @@ def split_dataset(
260266
train_sessions[np.arange(eval_every_n - 1, n_sessions, eval_every_n)] = False
261267
eval_sessions = np.logical_not(train_sessions)
262268

269+
if dataset.batch_mode == 'single':
270+
batch_size = None
271+
else:
272+
batch_size = dataset.batch_size
273+
263274
dataset_train = DatasetRNN(
264275
xs[:, train_sessions, :],
265276
ys[:, train_sessions, :],
266277
x_names=dataset.x_names,
267278
y_names=dataset.y_names,
268279
y_type=dataset.y_type,
269280
n_classes=dataset.n_classes,
270-
batch_size=dataset.batch_size,
281+
batch_size=batch_size,
271282
batch_mode=dataset.batch_mode,
272283
)
273284
dataset_eval = DatasetRNN(
@@ -277,7 +288,7 @@ def split_dataset(
277288
y_names=dataset.y_names,
278289
y_type=dataset.y_type,
279290
n_classes=dataset.n_classes,
280-
batch_size=dataset.batch_size,
291+
batch_size=None,
281292
batch_mode=dataset.batch_mode,
282293
)
283294
return dataset_train, dataset_eval

disentangled_rnns/library/two_armed_bandits.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -479,27 +479,34 @@ def run_experiment(agent: Agent,
479479
choices[step] = choice
480480
rewards[step] = reward
481481

482-
experiment = SessData(choices=choices,
483-
rewards=rewards,
484-
n_trials=n_steps,
485-
reward_probs=reward_probs)
482+
experiment = SessData(
483+
choices=choices,
484+
rewards=rewards,
485+
n_trials=n_steps,
486+
reward_probs=reward_probs,
487+
)
486488
return experiment
487489

488490

489-
def create_dataset(agent: Agent,
490-
environment: EnvironmentBanditsDrift,
491-
n_steps_per_session: int,
492-
n_sessions: int,
493-
batch_size: int) -> rnn_utils.DatasetRNN:
491+
def create_dataset(
492+
agent: Agent,
493+
environment: EnvironmentBanditsDrift,
494+
n_steps_per_session: int,
495+
n_sessions: int,
496+
batch_size: int | None = None,
497+
batch_mode: Literal['single', 'rolling', 'random'] = 'single',
498+
) -> rnn_utils.DatasetRNN:
494499
"""Generates a behavioral dataset from a given agent and environment.
495500
496501
Args:
497502
agent: An agent object to generate choices
498503
environment: An environment object to generate rewards
499-
n_steps_per_session: The number of trials in each behavioral session to
500-
be generated
504+
n_steps_per_session: The number of trials in each behavioral session to be
505+
generated
501506
n_sessions: The number of sessions to generate
502507
batch_size: The size of the batches to serve from the dataset
508+
batch_mode: Batch mode to pass to DatasetRNN. Must be a type that is
509+
supported by DatasetRNN.
503510
504511
Returns:
505512
rnn_utils.DatasetRNN object
@@ -526,6 +533,7 @@ def create_dataset(agent: Agent,
526533
y_type='categorical',
527534
n_classes=2,
528535
batch_size=batch_size,
536+
batch_mode=batch_mode,
529537
)
530538
return dataset
531539

disentangled_rnns/library/two_armed_bandits_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def test_generate_dataset(self, agent):
7979
agent=agent,
8080
n_steps_per_session=10,
8181
n_sessions=10,
82-
batch_size=5,
8382
)
8483

8584
self.assertIsInstance(dataset, rnn_utils.DatasetRNN)

0 commit comments

Comments
 (0)