Commit c4b10fc
* metrics fix
* Add single_gpu_cache.sh for DPO cache comparison
Add a version of the single GPU DPO script that calls dpo_tune_cache.py
instead of dpo.py, to compare metrics between the two implementations.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix beaker_config UnboundLocalError in dpo_tune_cache.py
Move beaker_config initialization outside conditional blocks so it's
always defined when needed for experiment config updates.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
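The bug class here is easy to reproduce. A minimal sketch (function and variable names are illustrative, not the actual dpo_tune_cache.py code): a name assigned only inside a conditional is unbound on the other path, so initializing it before the branch makes it safe to reference everywhere.

```python
def get_config_buggy(is_beaker_job):
    if is_beaker_job:
        beaker_config = {"experiment": "dpo"}
    return beaker_config  # UnboundLocalError when is_beaker_job is False


def get_config_fixed(is_beaker_job):
    beaker_config = None  # initialized outside the conditional, so always bound
    if is_beaker_job:
        beaker_config = {"experiment": "dpo"}
    return beaker_config
```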
* Add rewards_average and token_count metrics to DPO
Adds missing metrics that are tracked in dpo_tune_cache.py:
- train/rewards_average: Average of chosen and rejected rewards
- train/token_count: Sum of non-padded tokens in chosen + rejected
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
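A minimal sketch of the two metrics as described above (tensor names and values are illustrative; the exact reduction in the real code may differ): the reward average combines the chosen and rejected means, and the token count sums non-padded positions across both inputs.

```python
import torch

chosen_rewards = torch.tensor([0.5, 1.0])
rejected_rewards = torch.tensor([-0.5, 0.0])
chosen_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])    # 1 = real token, 0 = pad
rejected_mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])

# train/rewards_average: average of chosen and rejected rewards
rewards_average = (chosen_rewards.mean() + rejected_rewards.mean()) / 2
# train/token_count: non-padded tokens in chosen + rejected
token_count = chosen_mask.sum() + rejected_mask.sum()
```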
* Add --no-host-networking to single GPU DPO scripts
Avoid port 29500 conflicts on single GPU jobs by disabling host networking.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix logger.info call in dpo_tune_cache.py
Remove invalid main_process_only argument from logger.info().
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Sync build_reference_logprobs_cache call with dpo_utils.py
Update the function call to match the current signature in dpo_utils.py.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix description in single_gpu_cache.sh
Correctly describe it as using accelerate (dpo_tune_cache.py), not OLMo-core.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Include gradient_accumulation_steps in global_batch_size for dpo.py
This makes dpo.py count optimizer updates as steps (like dpo_tune_cache.py)
instead of counting micro-batches as steps.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
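The arithmetic behind this change, with illustrative values: once gradient accumulation is folded into the global batch size, one "step" means one optimizer update instead of one micro-batch.

```python
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
num_gpus = 1
num_samples = 64

# global batch size now includes gradient accumulation
global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
# steps counted as optimizer updates, not micro-batches
steps_per_epoch = num_samples // global_batch_size
```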
* Set drop_last=False in dpo.py to match dpo_tune_cache.py
Keep incomplete last batch to match accelerate's default behavior.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add debug logging to investigate logprobs discrepancy between dpo.py and dpo_tune_cache.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add attention masking support for OLMo-core DPO
OLMo-core Transformer doesn't support attention_mask parameter but uses
cu_doc_lens for intra-document attention masking. This change adds a
pack_padded_sequences helper function that converts padded batches to
packed format with cumulative document lengths.
Both concatenated_forward_olmo and separate_forward_olmo now properly
handle padding by packing sequences on-the-fly.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
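A minimal sketch of what a padded-to-packed conversion looks like (the helper name comes from the commit; the body here is an assumed simplification, not the repository's implementation): padding is dropped, sequences are concatenated into one packed row, and cumulative document lengths mark the boundaries for intra-document attention.

```python
import torch

def pack_padded_sequences(input_ids, attention_mask):
    # Sketch only: real tokens per sequence, then concatenate without padding.
    doc_lens = attention_mask.sum(dim=1)
    packed = torch.cat([row[: int(n)] for row, n in zip(input_ids, doc_lens)])
    # Cumulative document lengths: [0, len_0, len_0 + len_1, ...]
    cu_doc_lens = torch.cat([torch.zeros(1, dtype=doc_lens.dtype), doc_lens.cumsum(0)])
    return packed, cu_doc_lens

batch = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
packed, cu = pack_padded_sequences(batch, mask)
```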
* Use PyTorch RNG in HFDataLoader to match dpo_tune_cache.py data ordering
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add debug logging to compare data ordering between dpo.py and dpo_tune_cache.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Make OLMo-core DPO use same logprob computation as HuggingFace
When packing=False, OLMo-core now unpacks logits back to padded format
and uses _get_batch_logps (same as HuggingFace) instead of pf_get_batch_logps.
This ensures consistent logprob computation between dpo.py and dpo_tune_cache.py.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add detailed logprob debug logging to compare dpo.py and dpo_tune_cache.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add micro-batching to DPO to match dpo_tune_cache.py batch structure
Split large batches into micro-batches of size per_device_train_batch_size
and process them one at a time with gradient accumulation. This ensures
dpo.py (OLMo-core) and dpo_tune_cache.py (HuggingFace) process the same
number of samples per forward pass.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
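The split described above can be sketched as follows (function name and sizes are illustrative): a large batch is chunked into micro-batches of the per-device size and processed one at a time, accumulating gradients across chunks.

```python
def split_into_microbatches(batch, micro_batch_size):
    # Chunk a batch into micro-batches of at most micro_batch_size samples.
    return [batch[i : i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)]

samples = list(range(8))
micro_batches = split_into_microbatches(samples, 2)
```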
* Add debug logging to compare HF and OLMo-core forward passes
Log input_ids, attention_mask/cu_doc_lens, labels, logits, and logprobs
for both HuggingFace and OLMo-core forward functions to diagnose
the logprob differences between dpo.py and dpo_tune_cache.py.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add embedding weight logging to compare HF and OLMo-core models
Log the first 5 values and mean of embedding weights to verify
whether the model weights are identical between implementations.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix embed weight logging to handle DTensor (FSDP)
Use .detach().float().cpu() to convert DTensor to regular tensor
before calling .tolist() for logging.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Use full_tensor() for FSDP sharded weights
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix: Actually load HF weights into OLMo-core model
load_hf_model() loads weights into the provided state_dict, but
model.state_dict() returns a copy, not a reference. The modified
state_dict was never loaded back into the model, leaving it with
randomly initialized weights.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
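The pitfall is reproducible with a toy module: `state_dict()` returns a fresh dict, so replacing its entries never touches the model; `load_state_dict()` must be called to push the modified dict back in. (The model below stands in for the real one; only the dict-vs-model behavior is the point.)

```python
import torch

model = torch.nn.Linear(2, 2, bias=False)
sd = model.state_dict()          # a new dict, not a live view of the model
sd["weight"] = torch.ones(2, 2)  # modifies the dict only; model weights unchanged
model.load_state_dict(sd)        # the step that was missing: load the dict back
```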
* Align data ordering between dpo.py and dpo_tune_cache.py
- Change HFDataLoader to use NumPy RNG (np.random.default_rng) instead of
PyTorch RNG to match HuggingFace Dataset.shuffle() behavior
- Remove shuffle=True from DataLoader in dpo_tune_cache.py to avoid double
shuffling (dataset is already shuffled)
- Add debug logging to verify data indices match
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Remove double shuffling from dpo.py
The dataset was being shuffled twice:
1. Via HF Dataset.shuffle() before passing to HFDataLoader
2. Via numpy permutation inside HFDataLoader._reshard()
Now only HFDataLoader._reshard() shuffles, matching dpo_tune_cache.py behavior.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Revert changes to dpo_tune_cache.py
Keep the original double-shuffling behavior as the DataLoader needs to
shuffle during iteration.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Implement double-shuffle in dpo.py to match dpo_tune_cache.py
dpo_tune_cache.py does:
1. dataset.shuffle(seed) - HF Dataset shuffle
2. DataLoader(shuffle=True) - PyTorch DataLoader shuffle
Now dpo.py does:
1. dataset.shuffle(seed) - HF Dataset shuffle (restored)
2. HFDataLoader._reshard() with PyTorch RNG (torch.randperm)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Reseed torch RNG before DataLoader creation in dpo_tune_cache.py
The torch RNG state gets consumed by model loading between set_seed()
and DataLoader creation. Reseeding ensures the DataLoader shuffle
uses a fresh RNG state matching HFDataLoader's behavior.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Revert torch.manual_seed change to dpo_tune_cache.py
Cannot modify dpo_tune_cache.py - need to match its behavior from dpo.py side.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Use seeded generator for DataLoader shuffle in dpo_tune_cache.py
Makes the DataLoader shuffle reproducible by using a Generator seeded
with args.seed, matching HFDataLoader's behavior in dpo.py.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
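A minimal sketch of the reproducibility fix (seed value and dataset are illustrative): giving the DataLoader its own seeded `torch.Generator` makes the shuffle deterministic, independent of whatever the global RNG state happens to be after model loading.

```python
import torch
from torch.utils.data import DataLoader

seed = 123
dataset = list(range(10))

def make_loader():
    g = torch.Generator()
    g.manual_seed(seed)  # shuffle order depends only on this seed
    return DataLoader(dataset, batch_size=2, shuffle=True, generator=g)

order_a = [x for batch in make_loader() for x in batch.tolist()]
order_b = [x for batch in make_loader() for x in batch.tolist()]
```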
* Add detailed logits logging at label positions
Log logits at first label position for chosen and rejected to help
debug differences between HF and OLMo-core implementations.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add input token logging at position 445-450 for debugging
Compare chosen vs rejected input tokens at the same positions to verify
they have identical prompt content.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add packed logits logging at position 447 for debugging
Compare packed_logits[447] (chosen) vs packed_logits[rejected_start+447]
to debug attention masking between documents.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix OLMo-core attention masking by using correct argument names
The model expects doc_lens (individual lengths) and max_doc_lens (list),
not cu_doc_lens (cumulative) and max_doc_len (int). This fix enables
proper document boundary masking in concatenated_forward_olmo.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Remove debug logging that causes index errors for short sequences
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix DataLoader shuffle to match HFDataLoader's randperm order
Use explicit RandomSampler instead of shuffle=True to avoid RNG state
consumption that causes different iteration order between DataLoader
and HFDataLoader.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
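A sketch of the sampler change (seed and dataset are illustrative): an explicit `RandomSampler` with its own generator draws a single `torch.randperm` from a known state, so iteration order matches a `randperm`-based loader rather than depending on incidental global RNG consumption.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler

dataset = list(range(8))

g = torch.Generator()
g.manual_seed(123)
loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset, generator=g))

# The same seed through randperm should reproduce the iteration order.
g2 = torch.Generator()
g2.manual_seed(123)
expected = torch.randperm(len(dataset), generator=g2).tolist()
observed = [x for batch in loader for x in batch.tolist()]
```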
* Add debug logging to verify randperm behavior on Beaker
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Add dataset_len to debug logging for randperm verification
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix dpo.py epoch alignment with dpo_tune_cache.py
The olmo-core Trainer starts at epoch=1 by default, which causes
data_loader.reshuffle(1) to be called with seed=seed+1=124 instead of
seed=123. This resulted in different data ordering between dpo.py and
dpo_tune_cache.py after the first 4 samples.
Setting trainer.epoch=0 before fit() ensures both implementations use
the same seed=123 for data shuffling.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Apply H17, H25, H16 fixes from DPO divergence investigation
- H17: Swap optim.step() before scheduler.set_lr() so optimizer uses
correct LR (was advancing LR before optimizer step)
- H25: Initialize optimizer LR to 0.0 (warmup start) to match HF
behavior (was using full learning_rate for first step)
- H16: Use Duration.steps(max_train_steps) when set instead of always
Duration.epochs() (fixes step count mismatch)
- Copy investigation doc from old branch
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Revert H17/H25, keep H16: use Duration.steps() instead of Duration.epochs()
The step count mismatch (96 vs 72) was caused by Duration.epochs()
counting more steps than expected due to dataloader padding. Always
use Duration.steps(num_training_steps) to match dpo_tune_cache.py.
Reverted H17 (scheduler ordering swap) and H25 (lr=0 init) as they
were incorrect — the original set_lr-before-optim.step order is
correct given OLMo-core's pre-incremented global_step.
Also adds compare_wandb_runs.py script with per-step output and
step-count-mismatch tolerance.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Add validation notebook for DPO comparison
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Remove incorrect trainer.epoch = 0 (default is 1-based)
OLMo-core uses 1-based epochs. Setting epoch=0 caused
Duration.epochs(N) to compute N+1 epochs worth of steps.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Fix rewards_accuracy to use per-sample comparison instead of scalar
Previously compared mean chosen vs mean rejected rewards (always 0 or 1).
Now computes per-sample accuracy and averages, matching dpo_tune_cache.py.
Also set drop_last=True in dpo.py data loaders.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Revert "Fix rewards_accuracy to use per-sample comparison instead of scalar"
This reverts commit 6143215.
* Revert "Remove incorrect trainer.epoch = 0 (default is 1-based)"
This reverts commit 31024f1.
* Fix rewards_accuracy to use per-sample comparison instead of scalar
Previously compared mean chosen vs mean rejected rewards (always 0 or 1).
Now computes per-sample accuracy and averages, matching dpo_tune_cache.py.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
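The difference between the two accuracy computations, with illustrative reward tensors: comparing batch means collapses to 0 or 1 every step, while the per-sample version yields the fraction of pairs where the chosen reward wins.

```python
import torch

chosen_rewards = torch.tensor([1.0, -0.2, 0.3])
rejected_rewards = torch.tensor([0.5, 0.1, 0.0])

# Old behavior: one scalar comparison per batch, always 0.0 or 1.0.
scalar_accuracy = float(chosen_rewards.mean() > rejected_rewards.mean())
# Fixed behavior: compare each pair, then average.
per_sample_accuracy = (chosen_rewards > rejected_rewards).float().mean()
```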
* Fix HF weight loading and micro-batch splitting after TransformerTrainModule merge
TransformerTrainModule.__init__ calls parallelize_model which calls
init_weights, reinitializing all model weights from scratch. This
destroyed the HF checkpoint loaded in _setup_model. Fix by reloading
HF weights after parallelization.
Also fix micro-batch splitting: use sample_microbatch_size (in samples)
for split_batch_dpo instead of rank_microbatch_size (in tokens), matching
main branch's DPOTrainModule pattern.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Remove DEBUG logging from DPO forward passes and data loader
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* cleaned up PR and merged to head
* Set gradient_accumulation_steps=1 in debug DPO scripts
Both scripts keep effective batch size=4 but use per_device_train_batch_size=4
with gradient_accumulation_steps=1, so micro-batch averaging differences
between OLMo-core (token-weighted) and HF (uniform) don't matter.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* cleaned up PR
* cleaned up PR
* cleaned up PR
* Fix PR #1451 review comments
- Fix mock model parameter names to match production code (doc_lens/max_doc_lens)
- Use torch RNG instead of numpy in test to match HFDataLoader implementation
- Use effective_steps (max_train_steps when set) for scheduler warmup
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* set drop_last
* updated code
* Remove del statement for unused params in mock model forward
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent cf4a3f5 · commit c4b10fc
File tree (9 files changed, +148 −39):
- open_instruct
- scripts/train/debug/dpo