-
Notifications
You must be signed in to change notification settings - Fork 515
Expand file tree
/
Copy pathdata_loader.py
More file actions
1249 lines (1098 loc) · 52.4 KB
/
data_loader.py
File metadata and controls
1249 lines (1098 loc) · 52.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import threading
import time
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from pathlib import Path
from queue import Empty
from typing import Any, Literal
import numpy as np
import ray
import torch
import vllm
from datasets import Dataset
from olmo_core.data import data_loader
from ray.util import queue as ray_queue
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from open_instruct import data_types, padding_free_collator, utils
from open_instruct.dataset_transformation import (
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
RAW_PROMPT_KEY,
TOOLS_COLUMN_KEY,
VERIFIER_SOURCE_KEY,
)
from open_instruct.model_utils import Batch
from open_instruct.rl_utils import PackedSequences, pack_sequences, save_rollout_metadata, save_rollouts_to_disk
from open_instruct.tools.utils import ToolStatistics
from open_instruct.utils import combine_reward_metrics, repeat_each
logger = logging.getLogger(__name__)
def to_device(batch: dict[str, Any], device: torch.device | None) -> dict[str, Any]:
"""Move all tensors in a batch dictionary to the specified device.
Args:
batch: Dictionary potentially containing torch.Tensor values.
device: Target device. If None, tensors are not moved.
Returns:
Dictionary with the same keys, but tensor values moved to the target device.
"""
if device is None:
return batch
return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
class HFDataLoader(data_loader.DataLoaderBase):
    """A DataLoader that wraps a HuggingFace Dataset for use with olmo_core's Trainer.

    This class implements the DataLoaderBase interface, providing iteration over
    a HuggingFace Dataset with support for sharding across distributed workers,
    shuffling, checkpointing, and optional collation.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        seed: int,
        dp_rank: int,
        dp_world_size: int,
        work_dir: str,
        automatic_reshuffle: bool = False,
        collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None,
        device: torch.device | None = None,
        drop_last: bool = True,
        fs_local_rank: int | None = None,
    ) -> None:
        """Initialize the HFDataLoader.

        Args:
            dataset: The HuggingFace Dataset to load data from. Must have an 'index' column.
            batch_size: The global batch size.
            seed: Random seed for shuffling.
            dp_rank: The rank of the current process in the distributed setup.
            dp_world_size: Total number of data-parallel processes in the distributed setup.
            work_dir: Working directory for the data loader (required by DataLoaderBase).
            automatic_reshuffle: If True, automatically reshuffle at epoch boundaries.
            collator: Optional collation function for batching examples. If None, batches will be
                dictionaries of the form `{'examples': [example_1, example_2, ...]}`.
            device: Device to move tensors to.
            drop_last: If True, drop the last incomplete batch. If False, pad the last batch
                with repeated indices to fill a complete batch.
            fs_local_rank: File system local rank. Defaults to dp_rank when None.

        Raises:
            ValueError: If the dataset has no 'index' column, or if ``batch_size`` is
                smaller than ``dp_world_size``.

        Note:
            The dataset must have an 'index' column for tracking samples across epochs.
            This is automatically added by get_cached_dataset_tulu(). For custom datasets,
            add it with: dataset.add_column('index', range(len(dataset)))
        """
        super().__init__(
            work_dir=work_dir,
            global_batch_size=batch_size,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            fs_local_rank=fs_local_rank if fs_local_rank is not None else dp_rank,
        )
        if "index" not in dataset.column_names:
            raise ValueError(
                "Dataset must have an 'index' column. This is typically added by get_cached_dataset_tulu(). "
                "If using a custom dataset, add it with: dataset.add_column('index', range(len(dataset)))"
            )
        self._full_dataset = dataset
        self.seed = seed
        self._batch_size = batch_size
        if batch_size < dp_world_size:
            raise ValueError(
                f"Global batch size ({batch_size}) must be >= world size ({dp_world_size}). "
                f"Each rank needs at least one example per batch."
            )
        if batch_size % dp_world_size != 0:
            logger.warning(
                f"Global batch size {batch_size} is not divisible by world size {dp_world_size}. "
                f"The effective global batch size will be {batch_size // dp_world_size * dp_world_size}."
            )
        self._per_rank_batch_size = batch_size // dp_world_size
        self._collator = collator if collator is not None else (lambda x: {"examples": x})
        self._automatic_reshuffle = automatic_reshuffle
        self._drop_last = drop_last
        self._excluded_indices: set[int] = set()
        # Examples offered to the collator but not consumed by it; prepended to the next batch.
        self._overflow: list[dict[str, Any]] = []
        self._epoch: int = 0
        self._current_iter: Iterator[dict[str, Any]] | None = None
        self._device = device
        self._reshard(epoch=0)

    def __next__(self) -> dict[str, Any]:
        """Return the next batch, rolling over into a new epoch when configured to.

        Raises:
            StopIteration: At the end of an epoch when ``automatic_reshuffle`` is False.
                The epoch counter is advanced and ``batches_processed`` reset first.
            RuntimeError: If an automatic reshuffle leaves no usable examples.
        """
        if self._current_iter is None:
            self._current_iter = iter(self)
        try:
            return next(self._current_iter)
        except StopIteration:
            self._current_iter = None
            if self._automatic_reshuffle:
                self.reshuffle()
                if self.effective_size == 0:
                    raise RuntimeError("All dataset examples have been excluded. Cannot continue iteration.") from None
                self._current_iter = iter(self)
                return next(self._current_iter)
            self._epoch += 1
            self.batches_processed = 0
            raise

    @staticmethod
    def _num_consumed(batch: dict[str, Any], num_offered: int) -> int:
        """Return how many of the ``num_offered`` examples the collator packed into ``batch``.

        Collators that may consume only a prefix of their input report the consumed
        examples through an ``'index'`` entry. Collators without one (including the
        default ``{'examples': [...]}`` collator) are treated as consuming everything.
        """
        if "index" in batch:
            return len(batch["index"])
        return num_offered

    def _iter_batches(self) -> Iterable[dict[str, Any]]:
        """Return an iterable over all batches in the epoch.

        Resumes from ``batches_processed``. Each example is tagged with a
        ``prompt_id`` of the form ``"{epoch}_{index}"``.
        """
        start_example = self.batches_processed * self._per_rank_batch_size
        batch_examples: list[dict[str, Any]] = []
        for i in range(start_example, self.effective_size):
            example = self.dataset[i]
            batch_examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"})
            if len(batch_examples) == self._per_rank_batch_size:
                all_examples = self._overflow + batch_examples
                batch = to_device(self._collator(all_examples), self._device)
                # Whatever the collator did not take is carried into the next batch.
                self._overflow = all_examples[self._num_consumed(batch, len(all_examples)) :]
                yield batch
                batch_examples = []
        # Drain leftovers so no example offered to the collator is silently lost.
        while self._overflow:
            batch = to_device(self._collator(self._overflow), self._device)
            consumed = self._num_consumed(batch, len(self._overflow))
            assert consumed > 0, f"Collator consumed 0 examples from {len(self._overflow)} overflow examples"
            self._overflow = self._overflow[consumed:]
            yield batch

    @property
    def total_batches(self) -> int:
        """Return the number of full per-rank batches in an epoch.

        Extra batches emitted while draining collator overflow are not counted.
        """
        return self.effective_size // self._per_rank_batch_size

    def state_dict(self) -> dict[str, Any]:
        """Return a state dictionary for checkpointing."""
        return {
            "epoch": self._epoch,
            "batches_processed": self.batches_processed,
            "excluded_indices": list(self._excluded_indices),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load a state dictionary to restore the data loader's state."""
        self._excluded_indices = set(state_dict.get("excluded_indices", []))
        # Set epoch to one less than target since reshuffle() increments it
        self._epoch = state_dict["epoch"] - 1
        self.reshuffle()
        assert self._epoch == state_dict["epoch"]
        self.batches_processed = state_dict["batches_processed"]
        # Overflow is not checkpointed; anything held from before the restore is stale.
        self._overflow = []
        self._current_iter = None

    def exclude_index(self, index: int) -> None:
        """Exclude a dataset index from future iterations.

        Takes effect at the next reshard (i.e. the next reshuffle).

        Args:
            index: The index to exclude.
        """
        self._excluded_indices.add(index)

    def reshuffle(self, epoch: int | None = None, **kwargs: Any) -> None:
        """Reshuffle and reshard the dataset for a new epoch.

        Args:
            epoch: The epoch number to use for shuffling seed. If None, increments internal counter.
            **kwargs: Additional keyword arguments (unused, for API compatibility).
        """
        self._epoch = self._epoch + 1 if epoch is None else epoch
        self.batches_processed = 0
        self._reshard(self._epoch)

    def _reshard(self, epoch: int) -> None:
        """Reshard the dataset for a given epoch.

        Uses index-based shuffling to avoid copying the dataset. After this call
        every rank holds exactly ``total_batches * _per_rank_batch_size`` examples,
        so all ranks iterate the same number of full batches.
        """
        generator = torch.Generator()
        generator.manual_seed(self.seed + epoch)
        dataset_len = len(self._full_dataset)
        all_indices = torch.randperm(dataset_len, generator=generator).numpy()
        if self._excluded_indices:
            mask = np.isin(all_indices, list(self._excluded_indices), invert=True)
            all_indices = all_indices[mask]
        global_size = len(all_indices)
        total_batches = global_size // self._batch_size
        usable_size = total_batches * self._batch_size
        if not self._drop_last and usable_size < global_size:
            remainder = global_size - usable_size
            # np.resize repeats the indices cyclically, so the pad is always full length
            # even when fewer than `batch_size - remainder` examples are available.
            pad_indices = np.resize(all_indices, self._batch_size - remainder)
            all_indices = np.concatenate([all_indices, pad_indices])
            total_batches += 1
            usable_size = total_batches * self._batch_size
        # Distribute examples from global batches to ranks. This is a form of strided sampling where each
        # rank gets a subset of examples from each global batch, ensuring a diverse set of examples.
        # Columns beyond `per_rank * world_size` are dropped so every rank gets the same count per
        # global batch; otherwise lower ranks would get extra examples when batch_size % world_size != 0.
        per_global_batch = self._per_rank_batch_size * self.dp_world_size
        global_batches = all_indices[:usable_size].reshape(total_batches, self._batch_size)[:, :per_global_batch]
        rank_indices = global_batches[:, self.dp_rank :: self.dp_world_size].flatten()
        self.effective_size = len(rank_indices)
        self.dataset = self._full_dataset.select(rank_indices.tolist())

    def get_mock_batch(self) -> dict[str, Any]:
        """Return a batch with arbitrary data for dry-run testing.

        Used by the trainer to do a dry-run of the
        forward and backward pass before training officially starts.
        """
        num_examples = min(self._per_rank_batch_size, len(self.dataset))
        examples = [self.dataset[i] for i in range(num_examples)]
        return to_device(self._collator(examples), self._device)

    def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int:
        """Return the total number of tokens in the batch across all ranks.

        Counts tokens from all keys containing 'input_ids' that are torch tensors.
        The per-rank count is multiplied by ``dp_world_size``, i.e. this assumes
        every rank's batch has the same token count.

        Args:
            batch: A batch dictionary containing input tensors.

        Returns:
            Total number of tokens across all ranks.

        Raises:
            ValueError: If no input_ids tensors are found in the batch.
        """
        num_tokens = padding_free_collator.get_num_tokens(batch)
        return num_tokens * self.dp_world_size
@dataclass
class VLLMConfig:
    """Settings for the vLLM generation engines."""

    # Engine count and tensor-parallel degree
    vllm_num_engines: int = 1
    vllm_tensor_parallel_size: int = 1
    # Engine runtime options
    vllm_enforce_eager: bool = False
    vllm_sync_backend: str = "nccl"
    vllm_gpu_memory_utilization: float = 0.9
    vllm_enable_prefix_caching: bool = False
    # Sampling
    vllm_top_p: float = 1.0

    def __post_init__(self):
        # Extra checks apply only when the legacy v0 engine is selected via VLLM_USE_V1=0.
        running_v0 = os.environ.get("VLLM_USE_V1") == "0"
        if not running_v0:
            return
        logger.warning("When using the v0 version of vLLM, caching is broken and will never be invalidated.")
        if self.vllm_enable_prefix_caching:
            raise ValueError("Prefix caching is currently not supported for v0.")
@dataclass
class StreamingDataLoaderConfig:
    """Configuration for streaming RL rollouts into the trainer.

    Groups packing lengths, rollout batching, GRPO sampling/filtering, dataset
    selection, generation, reward settings and rollout saving. ``__post_init__``
    validates option combinations and derives ``max_possible_score``.
    """

    # Data loading/packing
    max_prompt_token_length: int = 256
    response_length: int = 256
    pack_length: int = 512
    # Batching
    async_steps: int = 1
    num_samples_per_prompt_rollout: int = 4
    num_unique_prompts_rollout: int = 16
    # GRPO sampling/filtering
    active_sampling: bool = False
    filter_zero_std_samples: bool = True
    no_resampling_pass_rate: float | None = None
    advantage_normalization_type: str = "standard"
    mask_truncated_completions: bool = False
    mask_tool_use: bool = True
    # Dataset
    dataset_mixer_list: list[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"])
    dataset_mixer_eval_list: list[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"])
    dataset_mixer_list_splits: list[str] = field(default_factory=lambda: ["train"])
    dataset_mixer_eval_list_splits: list[str] = field(default_factory=lambda: ["test"])
    dataset_transform_fn: list[str] = field(default_factory=lambda: ["rlvr_tokenize_v1", "rlvr_max_length_filter_v1"])
    dataset_cache_mode: Literal["hf", "local"] = "local"
    dataset_local_cache_dir: str = "local_dataset_cache"
    dataset_config_hash: str | None = None
    dataset_config_eval_hash: str | None = None
    dataset_skip_cache: bool = False
    shuffle_eval_dataset: bool = False
    system_prompt_override_file: str | None = None
    # Generation
    temperature: float = 0.7
    stop_strings: list[str] | None = None
    inflight_updates: bool = False
    # Reward - R1 style format reward
    apply_r1_style_format_reward: bool = False
    r1_style_format_reward: float = 1.0
    additive_format_reward: bool = False
    # Reward - Verifiable reward
    apply_verifiable_reward: bool = True
    verification_reward: float = 10.0
    remap_verifier: str | None = None
    # LLM judge verifier
    llm_judge_model: str = "azure/gpt-4o-mini-standard"
    llm_judge_max_tokens: int = 2048
    llm_judge_max_context_length: int = 8192
    llm_judge_temperature: float = 1.0
    llm_judge_timeout: int = 60
    # Code verifier
    code_api_url: str = field(
        default_factory=lambda: os.environ.get("CODE_API_URL", "http://localhost:1234") + "/test_program"
    )
    code_max_execution_time: float = 1.0
    code_pass_rate_reward_threshold: float = 0.0
    code_apply_perf_penalty: bool = False
    # Max length verifier
    max_length_verifier_max_length: int = 32768
    # Non stop penalty
    non_stop_penalty: bool = False
    non_stop_penalty_value: float = 0.0
    # Evolving rubric reward
    apply_evolving_rubric_reward: bool = False
    """Whether to generate and apply evolving rubrics for reward computation.
    When enabled, a rubric buffer is automatically maintained across training steps."""
    max_active_rubrics: int = 5
    """Maximum number of active evolving rubrics per query."""
    cache_evolving_rubric_data_dir: str | None = None
    """Directory to cache evolving rubric generation data for debugging/analysis. If set, rubric data will be saved."""
    # Rollout saving
    save_traces: bool = False
    rollouts_save_path: str = "/weka/oe-adapt-default/allennlp/deletable_rollouts/"
    # Computed at post_init
    max_possible_score: float = 1.0

    def __post_init__(self):
        """Validate option combinations and derive ``max_possible_score``.

        Also normalizes ``stop_strings=None`` to an empty list.

        Raises:
            AssertionError: For invalid length, sampling or reward combinations.
            ValueError: If ``filter_zero_std_samples`` is used with a single sample per
                prompt, if ``async_steps < 1``, or if ``save_traces`` has no save path.
        """
        assert self.pack_length >= self.max_prompt_token_length + self.response_length, (
            "The `pack_length` needs to be greater than or equal to the sum of "
            "`max_prompt_token_length` and `response_length`!"
        )
        assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
        if self.num_samples_per_prompt_rollout == 1:
            logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
        if self.active_sampling:
            assert self.async_steps > 1, (
                "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. "
                "Otherwise, your generator only generates one batch worth of prompts and a single filtered "
                "prompt will cause the trainer to stall waiting for more data."
            )
            assert self.filter_zero_std_samples, (
                "filter_zero_std_samples must be True when active_sampling is True. "
                "Active sampling requires filtering to work correctly."
            )
        if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples:
            raise ValueError(
                "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, "
                "as the reward standard deviation will always be 0, causing all samples to be filtered."
            )
        if self.async_steps < 1:
            raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.")
        assert (
            self.apply_verifiable_reward
            or self.apply_r1_style_format_reward
            or self.non_stop_penalty
            or self.apply_evolving_rubric_reward
        ), "At least one reward must be applied!"
        if self.stop_strings is None:
            self.stop_strings = []
        self.max_possible_score = 0.0
        if self.apply_verifiable_reward:
            self.max_possible_score += self.verification_reward
        if self.apply_r1_style_format_reward and self.additive_format_reward:
            self.max_possible_score += self.r1_style_format_reward
        if self.max_possible_score == 0.0:
            # No enabled reward contributes a known maximum (e.g. only the non-stop penalty or
            # evolving rubrics are on). Keep the field's default of 1.0 so that pass-rate
            # normalisation (score / max_possible_score) never divides by zero.
            self.max_possible_score = 1.0
        if self.save_traces and not self.rollouts_save_path:
            raise ValueError("`rollouts_save_path` must be provided when `save_traces` is True.")

    def build_dataloader(
        self,
        data_prep_actor_name: str,
        tokenizer: PreTrainedTokenizer,
        dp_rank: int,
        fs_local_rank: int,
        num_training_steps: int,
        work_dir: Path | str,
        dp_world_size: int,
    ) -> "StreamingDataLoader":
        """Build a thin wrapper dataloader that pulls from the DataPreparationActor singleton.

        The loader's global batch size is ``num_unique_prompts_rollout``.
        """
        return StreamingDataLoader(
            data_prep_actor_name=data_prep_actor_name,
            tokenizer=tokenizer,
            work_dir=work_dir,
            global_batch_size=self.num_unique_prompts_rollout,
            num_training_steps=num_training_steps,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            fs_local_rank=fs_local_rank,
        )
class StreamingDataLoader(data_loader.DataLoaderBase):
    """Thin wrapper dataloader that pulls pre-prepared data from the DataPreparationActor singleton.

    No data processing happens here. For each step this rank asks the named Ray
    actor for its batch and records the step, so iteration can resume from a checkpoint.
    """

    def __init__(
        self,
        *,
        data_prep_actor_name: str,
        tokenizer: PreTrainedTokenizer,
        work_dir: Path | str,
        global_batch_size: int,
        num_training_steps: int = 0,
        dp_world_size: int = 1,
        dp_rank: int = 0,
        fs_local_rank: int = 0,
    ):
        super().__init__(
            work_dir=work_dir,
            global_batch_size=global_batch_size,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            fs_local_rank=fs_local_rank,
        )
        # Resolve the named actor once; every batch request goes through this handle.
        self.data_prep_actor = ray.get_actor(data_prep_actor_name)
        self.tokenizer = tokenizer
        self.num_training_steps = num_training_steps
        # Resume position, saved and restored via state_dict / load_state_dict.
        self.training_step = 0
        self.current_epoch = 0

    @property
    def total_batches(self) -> int | None:
        """Number of batches per run: one per training step."""
        return self.num_training_steps

    def state_dict(self) -> dict[str, Any]:
        """Snapshot the step and epoch counters."""
        snapshot: dict[str, Any] = dict(training_step=self.training_step)
        snapshot["current_epoch"] = self.current_epoch
        return snapshot

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Restore counters; ``current_epoch`` defaults to 0 for older checkpoints."""
        self.training_step = state_dict["training_step"]
        self.current_epoch = state_dict.get("current_epoch", 0)

    def reshuffle(self, epoch: int | None = None, **kwargs):
        """Record ``epoch`` when one is given; a call without an epoch changes nothing."""
        if epoch is None:
            return
        self.current_epoch = epoch

    def get_mock_batch(self) -> dict[str, Any]:
        """Return a single two-token placeholder sequence for trainer dry runs.

        The sequence is ``[pad_token_id, eos_token_id]`` with an all-ones attention
        mask, positions ``0..1``, and zero advantages, response mask and logprobs.
        """
        tok = self.tokenizer
        sequence = torch.tensor([tok.pad_token_id, tok.eos_token_id], dtype=torch.long)
        seq_len = len(sequence)
        collated = data_types.CollatedBatchData(
            query_responses=[sequence],
            attention_masks=[torch.ones(seq_len, dtype=torch.long)],
            position_ids=[torch.arange(seq_len, dtype=torch.long)],
            advantages=[torch.zeros_like(sequence, dtype=torch.float)],
            response_masks=[torch.zeros_like(sequence)],
            vllm_logprobs=[torch.zeros_like(sequence, dtype=torch.float)],
        )
        return {"batch": collated, "metrics": {}}

    def _iter_batches(self) -> Iterable[dict[str, Any]]:
        """Fetch and yield this rank's batch for each remaining training step.

        ``training_step`` is advanced before each yield, so a state_dict taken while
        a batch is in use points at the following step.
        """
        step, last_step = self.training_step, self.num_training_steps
        while step < last_step:
            pending = self.data_prep_actor.get_data.remote(rank=self.dp_rank, step=step)
            batch_data = ray.get(pending)
            self.training_step = step + 1
            yield batch_data
            step += 1
def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor:
    """Stack variable-length tensors into one batch-first tensor, right-padded.

    Args:
        tensors_list: Tensors whose leading dimension may differ.
        pad_token_id: Fill value for the padded positions.
        pin_memory: Pin the result in page-locked memory when CUDA is available.

    Returns:
        A tensor of shape ``(len(tensors_list), max_len, ...)``.
    """
    batch = torch.nn.utils.rnn.pad_sequence(
        tensors_list,
        batch_first=True,
        padding_value=pad_token_id,
    )
    should_pin = pin_memory and torch.cuda.is_available()
    return batch.pin_memory() if should_pin else batch
@dataclass
class BatchStatistics:
    """Bookkeeping returned with each accumulated rollout batch.

    Field order matters because instances may be constructed positionally.
    """

    # Length of each kept prompt, and of every kept response.
    prompt_lengths: list[int]
    response_lengths: list[int]
    # Prompts dropped by zero-std reward filtering, total and split by the shared score:
    # zero, equal to the maximum possible score, or any other constant.
    filtered_prompts: int
    filtered_prompts_zero: int
    filtered_prompts_solved: int
    filtered_prompts_nonzero: int
    # Mean and raw values of per-prompt pass rate (mean score / max possible score).
    percent_solved_mean: float
    percent_solved_hist: np.ndarray
    # Prompts excluded from future resampling, and number of prompts kept in the batch.
    no_resampled_prompts: int
    total_prompts: int
def single_example_collator(examples: list[dict[str, Any]]) -> dict[str, Any]:
    """Collate a batch of exactly one example.

    Returns a new dict with the example's fields and its ``"index"`` wrapped in a
    one-element tensor. The input example is not modified.
    """
    assert len(examples) == 1, f"Expected 1 example, got {len(examples)}"
    sole_example = examples[0]
    wrapped_index = torch.tensor([sole_example["index"]])
    return sole_example | {"index": wrapped_index}
def add_prompt_to_generator(
    example: dict[str, Any], epoch_number: int, param_prompt_Q: ray_queue.Queue, generation_config, is_eval: bool
) -> None:
    """Wrap one dataset example in a ``PromptRequest`` and put it on ``param_prompt_Q``.

    The request's ``prompt_id`` is ``"{epoch_number}_{index}"``, the same format
    ``HFDataLoader`` uses when tagging examples.

    Args:
        example: Dataset row providing ``"index"`` and the tokenized prompt under
            ``INPUT_IDS_PROMPT_KEY``; tools under ``TOOLS_COLUMN_KEY`` are optional.
        epoch_number: Epoch used to build the prompt id.
        param_prompt_Q: Queue the request is placed on.
        generation_config: Passed through unchanged to the request.
        is_eval: Whether the request is for evaluation.
    """
    prompt_index = int(example["index"])
    prompt_tokens = example[INPUT_IDS_PROMPT_KEY]
    tools = example.get(TOOLS_COLUMN_KEY)
    request = data_types.PromptRequest(
        prompt=prompt_tokens,
        generation_config=generation_config,
        index=prompt_index,
        prompt_id=f"{epoch_number}_{prompt_index}",
        is_eval=is_eval,
        active_tools=tools,
    )
    param_prompt_Q.put(request)
def accumulate_inference_batches(
    inference_results_Q: ray_queue.Queue,
    generation_config: vllm.SamplingParams,
    num_prompts: int,
    model_dims: utils.ModelDims,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    actor_manager=None,
    timeout: float | None = None,
    active_sampling: bool = False,
    filter_zero_std_samples: bool = False,
    replenish_prompts: bool = False,
    no_resampling_pass_rate: float | None = None,
    iter_dataloader: HFDataLoader | None = None,
    param_prompt_Q: ray_queue.Queue | None = None,
    training_step: int | None = None,
    verbose: bool = False,
    max_possible_score: float = 1.0,
    requeue_on_timeout: bool = True,
) -> (
    tuple[data_types.GenerationResult, Batch, dict, BatchStatistics]
    | tuple[data_types.ShutdownSentinel | None, None, None, None]
):
    """Collect per-prompt generation results into one combined training batch.

    Pulls one result per prompt from ``inference_results_Q`` until ``num_prompts``
    prompts have been counted. Along the way it optionally filters prompts whose
    reward scores have zero standard deviation, then concatenates the surviving
    results.

    Args:
        inference_results_Q: Queue of per-prompt ``GenerationResult`` objects, or a
            ``ShutdownSentinel``.
        generation_config: Sampling parameters; ``generation_config.n`` responses are
            expected per prompt.
        num_prompts: Number of prompts to count before stopping.
        model_dims: Not used in this function body.
        tokenizer: Decodes responses and supplies the EOS token id.
        dataset: Indexed by ``result.index`` to recover prompt metadata.
        actor_manager: If given, receives the accumulated token statistics.
        timeout: Per-``get`` timeout on ``inference_results_Q``; None blocks.
        active_sampling: If True, filtered prompts do not count toward ``num_prompts``,
            so additional results are pulled to replace them.
        filter_zero_std_samples: Drop prompts whose reward scores are all equal.
        replenish_prompts: Enqueue a fresh prompt from ``iter_dataloader`` for every
            result received.
        no_resampling_pass_rate: If set, prompts whose pass rate is >= this value are
            excluded from ``iter_dataloader``.
        iter_dataloader: Source of replacement prompts and target of exclusions.
        param_prompt_Q: Queue that replacement prompts are put on.
        training_step: Only used in log messages.
        verbose: Show a tqdm progress bar.
        max_possible_score: Normalizer for the per-prompt pass rate.
        requeue_on_timeout: On timeout, put already-received results back on the queue
            before re-raising.

    Returns:
        ``(combined_result, batch, reward_metrics, batch_stats)`` on success;
        ``(sentinel, None, None, None)`` when a ``ShutdownSentinel`` is received; or
        ``(None, None, None, None)`` when every prompt was filtered.

    Raises:
        Empty: Re-raised when ``timeout`` elapses while waiting for a result.
    """
    if no_resampling_pass_rate is not None:
        assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed"
    if replenish_prompts:
        assert param_prompt_Q is not None and iter_dataloader is not None and dataset is not None, (
            "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset"
        )

    results = []
    all_queries = []
    all_ground_truths = []
    all_datasets = []
    all_raw_queries = []
    all_decoded_responses = []
    all_reward_metrics = []
    all_active_tools = []
    all_scores = []
    all_percent_solved = []
    total_filtered_prompts = 0
    filtered_prompt_zero = 0
    filtered_prompt_solved = 0
    filtered_prompt_nonzero = 0
    total_no_resampled = 0

    progress_bar = tqdm(
        total=num_prompts,
        desc=f"Accumulating Responses and Rewarding {num_prompts} prompts",
        bar_format="{l_bar}{bar}{r_bar}\n",
        disable=not verbose,
    )
    logger.info(
        f"[accumulate_inference_batches] Starting to accumulate {num_prompts} prompts, training_step={training_step}"
    )
    num_prompts_sampled = 0
    collected_results = []  # Track results for potential requeue on timeout
    try:
        while num_prompts_sampled < num_prompts:
            logger.info(
                f"[accumulate_inference_batches] Waiting for result {num_prompts_sampled + 1}/{num_prompts} from inference_results_Q"
            )
            try:
                result = inference_results_Q.get(timeout=timeout)
            except Empty:
                # NOTE(review): requeued results have already triggered their side effects
                # (replacement prompt enqueued, possible index exclusion, EOS padding).
                # A later call that reprocesses them will repeat the replenish step.
                if requeue_on_timeout and collected_results:
                    logger.info(
                        f"[accumulate_inference_batches] Timeout with {len(collected_results)}/{num_prompts} results, requeuing"
                    )
                    for r in collected_results:
                        inference_results_Q.put(r)
                raise
            collected_results.append(result)
            logger.info(
                f"[accumulate_inference_batches] Got result {num_prompts_sampled + 1}/{num_prompts}, type: {type(result).__name__}"
            )
            if isinstance(result, data_types.ShutdownSentinel):
                return result, None, None, None
            assert len(result.responses) == generation_config.n, (
                f"Mismatch: individual prompt result has {len(result.responses)} responses "
                f"but expected {generation_config.n} samples per prompt. "
                f"Index: {result.index}, Prompt ID: {result.prompt_id}"
            )
            example = dataset[result.index]
            query = example[INPUT_IDS_PROMPT_KEY]
            ground_truth = example[GROUND_TRUTHS_KEY]
            dataset_name = example[VERIFIER_SOURCE_KEY]
            raw_query = example[RAW_PROMPT_KEY]
            sample_active_tools = example.get(TOOLS_COLUMN_KEY)
            if replenish_prompts:
                # Keep the generator's queue topped up: one new prompt per result consumed.
                assert iter_dataloader is not None
                assert param_prompt_Q is not None
                example = next(iter_dataloader)
                add_prompt_to_generator(example, iter_dataloader._epoch, param_prompt_Q, generation_config, is_eval=False)
            # A response that stopped with no tokens gets an explicit EOS so it is non-empty.
            for i in range(len(result.finish_reasons)):
                if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
                    result.responses[i].append(tokenizer.eos_token_id)
                    result.masks[i].append(1)
                    result.logprobs[i].append(float("nan"))
            decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=False)
            k_queries = repeat_each([query], generation_config.n)
            k_ground_truths = repeat_each([ground_truth], generation_config.n)
            k_datasets = repeat_each([dataset_name], generation_config.n)
            k_raw_queries = repeat_each([raw_query], generation_config.n)
            k_active_tools = repeat_each([sample_active_tools], generation_config.n)
            percent_solved = np.mean(result.reward_scores).item() / max_possible_score
            if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate:
                assert iter_dataloader is not None
                iter_dataloader.exclude_index(result.index)
                total_no_resampled += 1
                logger.debug(
                    f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}"
                )
            if filter_zero_std_samples and np.std(result.reward_scores) == 0:
                # With active sampling a filtered prompt does not count, so another is pulled.
                if not active_sampling:
                    num_prompts_sampled += 1
                    progress_bar.update(1)
                total_filtered_prompts += 1
                if result.reward_scores[0] == 0:
                    filtered_prompt_zero += 1
                elif result.reward_scores[0] == max_possible_score:
                    filtered_prompt_solved += 1
                else:
                    filtered_prompt_nonzero += 1
                logger.debug(
                    f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}"
                )
                continue
            num_prompts_sampled += 1
            progress_bar.update(1)
            results.append(result)
            all_queries.extend(k_queries)
            all_ground_truths.extend(k_ground_truths)
            all_datasets.extend(k_datasets)
            all_raw_queries.extend(k_raw_queries)
            all_active_tools.extend(k_active_tools)
            all_decoded_responses.extend(decoded_responses)
            all_scores.extend(result.reward_scores)
            all_reward_metrics.append(result.reward_metrics)
            all_percent_solved.append(percent_solved)
    finally:
        # Release the progress bar on every exit path (normal, sentinel, or timeout).
        progress_bar.close()

    if len(results) == 0:
        logger.warning(
            "[Data Preparation Thread] All prompts were filtered during accumulation. "
            f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, "
            f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})"
        )
        return None, None, None, None

    # Flatten the per-prompt results into one GenerationResult.
    combined_responses = []
    combined_finish_reasons = []
    combined_masks = []
    combined_num_calls = []
    combined_timeouts = []
    combined_tool_errors = []
    combined_tool_outputs = []
    combined_tool_runtimes = []
    combined_tool_calleds = []
    combined_tool_call_stats = []
    combined_logprobs = []
    earliest_start_time = float("inf")
    prompt_lengths = []
    response_lengths = []
    total_prompt_tokens = 0
    total_response_tokens = 0
    max_generation_time = 0
    for i, result in enumerate(results):
        combined_responses.extend(result.responses)
        combined_finish_reasons.extend(result.finish_reasons)
        combined_masks.extend(result.masks)
        combined_num_calls.extend(result.request_info.num_calls)
        combined_timeouts.extend(result.request_info.timeouts)
        combined_tool_errors.extend(result.request_info.tool_errors)
        combined_tool_outputs.extend(result.request_info.tool_outputs)
        combined_tool_runtimes.extend(result.request_info.tool_runtimes)
        combined_tool_calleds.extend(result.request_info.tool_calleds)
        combined_tool_call_stats.extend(result.request_info.tool_call_stats)
        combined_logprobs.extend(result.logprobs)
        earliest_start_time = min(earliest_start_time, result.start_time)
        # all_queries holds n copies per kept prompt; take the first copy of prompt i.
        prompt_lengths.append(len(all_queries[i * generation_config.n]))
        for response in result.responses:
            response_lengths.append(len(response))
        total_prompt_tokens += result.token_statistics.num_prompt_tokens
        total_response_tokens += result.token_statistics.num_response_tokens
        # Results were generated concurrently, so the batch time is the slowest one.
        max_generation_time = max(max_generation_time, result.token_statistics.generation_time)

    accumulated_stats = data_types.TokenStatistics(
        num_prompt_tokens=total_prompt_tokens,
        num_response_tokens=total_response_tokens,
        generation_time=max_generation_time,
        earliest_start_time=earliest_start_time,
    )
    combined_request_info = data_types.RequestInfo(
        num_calls=combined_num_calls,
        timeouts=combined_timeouts,
        tool_errors=combined_tool_errors,
        tool_outputs=combined_tool_outputs,
        tool_runtimes=combined_tool_runtimes,
        tool_calleds=combined_tool_calleds,
        tool_call_stats=combined_tool_call_stats,
    )
    combined_result = data_types.GenerationResult(
        responses=combined_responses,
        finish_reasons=combined_finish_reasons,
        masks=combined_masks,
        request_info=combined_request_info,
        index=None,
        prompt_id=results[0].prompt_id,
        token_statistics=accumulated_stats,
        logprobs=combined_logprobs,
    )
    if actor_manager is not None:
        ray.get(actor_manager.report_token_statistics.remote(accumulated_stats))

    batch = Batch(
        queries=all_queries,
        ground_truths=all_ground_truths,
        datasets=all_datasets,
        raw_queries=all_raw_queries,
        decoded_responses=all_decoded_responses,
        indices=None,
        scores=all_scores,
        active_tools=all_active_tools if all_active_tools else None,
    )
    combined_reward_metrics = combine_reward_metrics(all_reward_metrics)
    percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0
    batch_stats = BatchStatistics(
        prompt_lengths=prompt_lengths,
        response_lengths=response_lengths,
        filtered_prompts=total_filtered_prompts,
        filtered_prompts_zero=filtered_prompt_zero,
        filtered_prompts_solved=filtered_prompt_solved,
        filtered_prompts_nonzero=filtered_prompt_nonzero,
        percent_solved_mean=percent_solved_mean,
        percent_solved_hist=np.array(all_percent_solved),
        no_resampled_prompts=total_no_resampled,
        total_prompts=len(results),
    )
    return combined_result, batch, combined_reward_metrics, batch_stats
def _collate_shuffled_field(
    items,
    order: np.ndarray,
    micro_batch_size: int,
    pad_value: int,
    pin_memory: bool,
) -> list:
    """Collate one per-worker field into micro-batches in shuffled order.

    ``order`` is cut into consecutive chunks of ``micro_batch_size`` indices (the last chunk may
    be shorter). Each chunk selects the matching entries of ``items``, and ``collate_fn`` pads
    them with ``pad_value``.
    """
    return [
        collate_fn([items[idx] for idx in order[start : start + micro_batch_size]], pad_value, pin_memory)
        for start in range(0, len(order), micro_batch_size)
    ]


def prepare_collated_data_for_workers(
    packed_sequences: PackedSequences,
    dp_world_size: int,
    per_device_train_batch_size: int,
    pad_token_id: int,
    pin_memory: bool = True,
) -> list[data_types.CollatedBatchData]:
    """Distributes and collates packed sequences for distributed training.

    The packed sequences are split into ``dp_world_size`` equal, contiguous shards. Sequences
    left over when the total is not evenly divisible are dropped, and a warning is logged.

    Each shard is shuffled with a single ``np.random.permutation`` draw, one per worker. The
    same permutation is applied to every field, so row k of each field still describes the
    same packed sequence. Each field is then collated into micro-batches of
    ``per_device_train_batch_size`` sequences; the last micro-batch of a shard may be smaller.

    Args:
        packed_sequences: Packed training sequences containing query responses,
            attention masks, position IDs, advantages, response masks,
            and vllm logprobs.
        dp_world_size: Number of distributed workers (must be >= 1).
        per_device_train_batch_size: Batch size for each device's micro-batch.
        pad_token_id: Token ID used to pad ``query_responses``. All other fields pad with 0.
        pin_memory: Whether to pin memory for faster data transfer to GPU.

    Returns:
        List of CollatedBatchData, one per worker. Each one holds collated tensors
        for query_responses, attention_masks, position_ids,
        advantages, response_masks, and vllm_logprobs.

    Raises:
        ValueError: If ``dp_world_size`` < 1, or if ``position_ids``, ``advantages`` or
            ``vllm_logprobs`` on ``packed_sequences`` is None.
    """
    if dp_world_size < 1:
        raise ValueError(f"dp_world_size must be >= 1, got {dp_world_size}")
    # Explicit checks instead of `assert`, which disappears under `python -O`.
    missing = [
        name
        for name, value in (
            ("position_ids", packed_sequences.position_ids),
            ("advantages", packed_sequences.advantages),
            ("vllm_logprobs", packed_sequences.vllm_logprobs),
        )
        if value is None
    ]
    if missing:
        raise ValueError(f"packed_sequences is missing required fields: {', '.join(missing)}")

    total_sequences = len(packed_sequences.query_responses)
    if total_sequences % dp_world_size != 0:
        new_total = (total_sequences // dp_world_size) * dp_world_size
        logger.warning(
            f"Total packed sequences ({total_sequences}) is not evenly divisible by dp_world_size ({dp_world_size}). "
            f"Truncating to {new_total} sequences (dropping {total_sequences - new_total})."
        )
    # Integer division: the remainder is never assigned to any worker, which is the truncation.
    per_worker = total_sequences // dp_world_size

    collated_data = []
    for rank in range(dp_world_size):
        shard = slice(per_worker * rank, per_worker * (rank + 1))
        query_responses = packed_sequences.query_responses[shard]
        # One permutation per worker, shared by all fields so they stay row-aligned.
        order = np.random.permutation(len(query_responses))
        collated_data.append(
            data_types.CollatedBatchData(
                query_responses=_collate_shuffled_field(
                    query_responses, order, per_device_train_batch_size, pad_token_id, pin_memory
                ),
                attention_masks=_collate_shuffled_field(
                    packed_sequences.attention_masks[shard], order, per_device_train_batch_size, 0, pin_memory
                ),
                position_ids=_collate_shuffled_field(
                    packed_sequences.position_ids[shard], order, per_device_train_batch_size, 0, pin_memory
                ),
                advantages=_collate_shuffled_field(
                    packed_sequences.advantages[shard], order, per_device_train_batch_size, 0, pin_memory
                ),
                response_masks=_collate_shuffled_field(
                    packed_sequences.response_masks[shard], order, per_device_train_batch_size, 0, pin_memory
                ),
                vllm_logprobs=_collate_shuffled_field(
                    packed_sequences.vllm_logprobs[shard], order, per_device_train_batch_size, 0, pin_memory
                ),
            )
        )
    return collated_data
@ray.remote
class DataPreparationActor:
"""Ray actor singleton that handles centralized data preparation for all ranks.
This actor runs a background thread that continuously prepares training data,
ensuring all ranks receive the same number of micro-batches (preventing deadlock
from uneven filtering).
"""
def __init__(
    self,
    dataset: Dataset,
    inference_results_Q: ray_queue.Queue,
    param_prompt_Q: ray_queue.Queue,
    tokenizer: PreTrainedTokenizer,
    config: StreamingDataLoaderConfig,
    generation_config,
    num_training_steps: int,
    seed: int,
    per_device_train_batch_size: int,
    global_batch_size: int,
    dp_world_size: int,
    max_possible_score: float,
    actor_manager,
    model_dims: utils.ModelDims,
    verbose: bool,
    work_dir: str,
    tool_names: list[str],
    run_name: str,
    model_name: str | None,
    initial_state: dict | None = None,
):
    """Set up the actor's state and start the background data-preparation loop.

    Stores the given configuration on the instance. Builds a single-rank ``HFDataLoader``
    over the whole ``dataset``: one example per batch, ``automatic_reshuffle=True``, and
    ``single_example_collator``. If ``initial_state`` is provided, restores the saved progress.
    Finally, submits ``self._data_preparation_loop`` to a one-worker ``ThreadPoolExecutor``,
    so preparation runs in the background as soon as construction finishes.

    Args:
        dataset: Dataset wrapped by the internal ``HFDataLoader``.
        inference_results_Q: Ray queue stored for the preparation loop. Presumably it
            delivers generation results; TODO confirm in ``_data_preparation_loop``.
        param_prompt_Q: Ray queue stored for the preparation loop. Presumably prompts are
            sent through it for generation; TODO confirm.
        tokenizer: Tokenizer stored on the instance.
        config: Streaming dataloader config. NOTE: mutated in place, because
            ``config.max_possible_score`` is overwritten with ``max_possible_score``.
        generation_config: Generation settings, stored as-is (type not annotated here).
        num_training_steps: Stored on the instance.
        seed: Seed passed to the ``HFDataLoader``.
        per_device_train_batch_size: Stored micro-batch size per device.
        global_batch_size: Stored on the instance.
        dp_world_size: Stored number of data-parallel ranks.
        max_possible_score: Written into ``config.max_possible_score``.
        actor_manager: Stored on the instance; its type is not annotated here.
        model_dims: Stored on the instance.
        verbose: Stored on the instance.
        work_dir: Working directory passed to the ``HFDataLoader``.
        tool_names: Stored on the instance.
        run_name: Stored on the instance.
        model_name: Stored on the instance (may be None).
        initial_state: Optional saved state. When provided it must contain the keys
            ``"training_step"`` and ``"iter_dataloader_state"``.
    """
    self.inference_results_Q = inference_results_Q
    self.param_prompt_Q = param_prompt_Q
    self.tokenizer = tokenizer
    self.config = config
    # Mutates the config object that was passed in.
    self.config.max_possible_score = max_possible_score
    self.generation_config = generation_config
    self.num_training_steps = num_training_steps
    self.per_device_train_batch_size = per_device_train_batch_size
    self.global_batch_size = global_batch_size
    self.dp_world_size = dp_world_size
    self.actor_manager = actor_manager
    self.model_dims = model_dims
    self.verbose = verbose
    self.dataset = dataset
    self.tool_names = tool_names
    self.run_name = run_name
    self.model_name = model_name
    # dp_rank=0 / dp_world_size=1: this actor reads the whole dataset itself, because it
    # prepares data centrally for all ranks instead of each rank reading its own shard.
    self.iter_dataloader = HFDataLoader(
        dataset=dataset,
        batch_size=1,
        seed=seed,
        dp_rank=0,
        dp_world_size=1,
        work_dir=work_dir,
        automatic_reshuffle=True,
        collator=single_example_collator,
    )
    # Presumably keyed by training step, with one CollatedBatchData per DP rank in each
    # list (cf. prepare_collated_data_for_workers); TODO confirm against the loop.
    self.prepared_data: dict[int, list[data_types.CollatedBatchData]] = {}
    self.metrics: dict[int, dict] = {}
    # -1 presumably means "no step prepared yet"; TODO confirm.
    self.current_prepared_step = -1
    # Presumably guards state shared between the background loop and remote callers; confirm.
    self.lock = threading.Lock()
    self.shutdown_requested = False
    self.training_step = 0
    self.total_samples_written = 0
    self.metadata_saved = False
    # Restore only after iter_dataloader exists, because its state_dict is loaded into it.
    if initial_state is not None:
        self.training_step = initial_state["training_step"]
        self.iter_dataloader.load_state_dict(initial_state["iter_dataloader_state"])
        logger.info(f"[DataPreparationActor] Restored state: training_step={self.training_step}")
    # Must stay last. submit() starts the loop on a worker thread immediately, so every
    # attribute above, including restored state, has to exist before this line runs.
    self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="DataPrepActor")
    self._prep_future = self._executor.submit(self._data_preparation_loop)