
Commit c4b10fc

Now, dpo.py matches dpo_tune_cache.py almost perfectly on the single GPU experiments (#1451)
* metrics fix
* Add single_gpu_cache.sh for DPO cache comparison
  Add a version of the single GPU DPO script that calls dpo_tune_cache.py instead of dpo.py, to compare metrics between the two implementations.
* Fix beaker_config UnboundLocalError in dpo_tune_cache.py
  Move beaker_config initialization outside conditional blocks so it's always defined when needed for experiment config updates.
* Add rewards_average and token_count metrics to DPO
  Adds missing metrics that are tracked in dpo_tune_cache.py:
  - train/rewards_average: average of chosen and rejected rewards
  - train/token_count: sum of non-padded tokens in chosen + rejected
* Add --no-host-networking to single GPU DPO scripts
  Avoid port 29500 conflicts on single GPU jobs by disabling host networking.
* Fix logger.info call in dpo_tune_cache.py
  Remove invalid main_process_only argument from logger.info().
* Sync build_reference_logprobs_cache call with dpo_utils.py
  Update the function call to match the current signature in dpo_utils.py.
* Fix description in single_gpu_cache.sh
  Correctly describe it as using accelerate (dpo_tune_cache.py), not OLMo-core.
* Include gradient_accumulation_steps in global_batch_size for dpo.py
  This makes dpo.py count optimizer updates as steps (like dpo_tune_cache.py) instead of counting micro-batches as steps.
* Set drop_last=False in dpo.py to match dpo_tune_cache.py
  Keep the incomplete last batch to match accelerate's default behavior.
* Add debug logging to investigate logprobs discrepancy between dpo.py and dpo_tune_cache.py
* Add attention masking support for OLMo-core DPO
  The OLMo-core Transformer doesn't support an attention_mask parameter; it uses cu_doc_lens for intra-document attention masking instead. This change adds a pack_padded_sequences helper function that converts padded batches to packed format with cumulative document lengths. Both concatenated_forward_olmo and separate_forward_olmo now properly handle padding by packing sequences on the fly.
* Use PyTorch RNG in HFDataLoader to match dpo_tune_cache.py data ordering
* Add debug logging to compare data ordering between dpo.py and dpo_tune_cache.py
* Make OLMo-core DPO use the same logprob computation as HuggingFace
  When packing=False, OLMo-core now unpacks logits back to padded format and uses _get_batch_logps (same as HuggingFace) instead of pf_get_batch_logps. This ensures consistent logprob computation between dpo.py and dpo_tune_cache.py.
* Add detailed logprob debug logging to compare dpo.py and dpo_tune_cache.py
* Add micro-batching to DPO to match dpo_tune_cache.py batch structure
  Split large batches into micro-batches of size per_device_train_batch_size and process them one at a time with gradient accumulation. This ensures dpo.py (OLMo-core) and dpo_tune_cache.py (HuggingFace) process the same number of samples per forward pass.
* Add debug logging to compare HF and OLMo-core forward passes
  Log input_ids, attention_mask/cu_doc_lens, labels, logits, and logprobs for both HuggingFace and OLMo-core forward functions to diagnose the logprob differences between dpo.py and dpo_tune_cache.py.
* Add embedding weight logging to compare HF and OLMo-core models
  Log the first 5 values and the mean of embedding weights to verify whether the model weights are identical between implementations.
* Fix embed weight logging to handle DTensor (FSDP)
  Use .detach().float().cpu() to convert a DTensor to a regular tensor before calling .tolist() for logging.
* Use full_tensor() for FSDP sharded weights
* Fix: Actually load HF weights into OLMo-core model
  load_hf_model() loads weights into the provided state_dict, but model.state_dict() returns a copy, not a reference. The modified state_dict was never loaded back into the model, leaving it with randomly initialized weights.
* Align data ordering between dpo.py and dpo_tune_cache.py
  - Change HFDataLoader to use NumPy RNG (np.random.default_rng) instead of PyTorch RNG to match HuggingFace Dataset.shuffle() behavior
  - Remove shuffle=True from DataLoader in dpo_tune_cache.py to avoid double shuffling (the dataset is already shuffled)
  - Add debug logging to verify that data indices match
* Remove double shuffling from dpo.py
  The dataset was being shuffled twice:
  1. via HF Dataset.shuffle() before passing to HFDataLoader
  2. via a NumPy permutation inside HFDataLoader._reshard()
  Now only HFDataLoader._reshard() shuffles, matching dpo_tune_cache.py behavior.
* Revert changes to dpo_tune_cache.py
  Keep the original double-shuffling behavior, as the DataLoader needs to shuffle during iteration.
* Implement double-shuffle in dpo.py to match dpo_tune_cache.py
  dpo_tune_cache.py does:
  1. dataset.shuffle(seed): HF Dataset shuffle
  2. DataLoader(shuffle=True): PyTorch DataLoader shuffle
  Now dpo.py does:
  1. dataset.shuffle(seed): HF Dataset shuffle (restored)
  2. HFDataLoader._reshard() with PyTorch RNG (torch.randperm)
* Reseed torch RNG before DataLoader creation in dpo_tune_cache.py
  The torch RNG state gets consumed by model loading between set_seed() and DataLoader creation. Reseeding ensures the DataLoader shuffle uses a fresh RNG state matching HFDataLoader's behavior.
* Revert torch.manual_seed change to dpo_tune_cache.py
  Cannot modify dpo_tune_cache.py; we need to match its behavior from the dpo.py side.
* Use seeded generator for DataLoader shuffle in dpo_tune_cache.py
  Makes the DataLoader shuffle reproducible by using a Generator seeded with args.seed, matching HFDataLoader's behavior in dpo.py.
* Add detailed logits logging at label positions
  Log logits at the first label position for chosen and rejected to help debug differences between the HF and OLMo-core implementations.
* Add input token logging at positions 445-450 for debugging
  Compare chosen vs rejected input tokens at the same positions to verify they have identical prompt content.
* Add packed logits logging at position 447 for debugging
  Compare packed_logits[447] (chosen) vs packed_logits[rejected_start+447] to debug attention masking between documents.
* Fix OLMo-core attention masking by using correct argument names
  The model expects doc_lens (individual lengths) and max_doc_lens (a list), not cu_doc_lens (cumulative) and max_doc_len (an int). This fix enables proper document boundary masking in concatenated_forward_olmo.
* Remove debug logging that causes index errors for short sequences
* Fix DataLoader shuffle to match HFDataLoader's randperm order
  Use an explicit RandomSampler instead of shuffle=True to avoid the RNG state consumption that causes different iteration order between DataLoader and HFDataLoader.
* Add debug logging to verify randperm behavior on Beaker
* Add dataset_len to debug logging for randperm verification
* Fix dpo.py epoch alignment with dpo_tune_cache.py
  The OLMo-core Trainer starts at epoch=1 by default, which causes data_loader.reshuffle(1) to be called with seed=seed+1=124 instead of seed=123. This resulted in different data ordering between dpo.py and dpo_tune_cache.py after the first 4 samples. Setting trainer.epoch=0 before fit() ensures both implementations use the same seed=123 for data shuffling.
* Apply H17, H25, H16 fixes from DPO divergence investigation
  - H17: Swap optim.step() before scheduler.set_lr() so the optimizer uses the correct LR (was advancing the LR before the optimizer step)
  - H25: Initialize optimizer LR to 0.0 (warmup start) to match HF behavior (was using the full learning_rate for the first step)
  - H16: Use Duration.steps(max_train_steps) when set instead of always Duration.epochs() (fixes step count mismatch)
  - Copy investigation doc from the old branch
* Revert H17/H25, keep H16: use Duration.steps() instead of Duration.epochs()
  The step count mismatch (96 vs 72) was caused by Duration.epochs() counting more steps than expected due to dataloader padding. Always use Duration.steps(num_training_steps) to match dpo_tune_cache.py. Reverted H17 (scheduler ordering swap) and H25 (lr=0 init) as they were incorrect: the original set_lr-before-optim.step order is correct given OLMo-core's pre-incremented global_step. Also adds a compare_wandb_runs.py script with per-step output and step-count-mismatch tolerance.
* Add validation notebook for DPO comparison
* Remove incorrect trainer.epoch = 0 (default is 1-based)
  OLMo-core uses 1-based epochs. Setting epoch=0 caused Duration.epochs(N) to compute N+1 epochs' worth of steps.
* Fix rewards_accuracy to use per-sample comparison instead of scalar
  Previously compared mean chosen vs mean rejected rewards (always 0 or 1). Now computes per-sample accuracy and averages, matching dpo_tune_cache.py. Also set drop_last=True in dpo.py data loaders.
* Revert "Fix rewards_accuracy to use per-sample comparison instead of scalar"
  This reverts commit 6143215.
* Revert "Remove incorrect trainer.epoch = 0 (default is 1-based)"
  This reverts commit 31024f1.
* Fix rewards_accuracy to use per-sample comparison instead of scalar
  Previously compared mean chosen vs mean rejected rewards (always 0 or 1). Now computes per-sample accuracy and averages, matching dpo_tune_cache.py.
* Fix HF weight loading and micro-batch splitting after TransformerTrainModule merge
  TransformerTrainModule.__init__ calls parallelize_model, which calls init_weights, reinitializing all model weights from scratch. This destroyed the HF checkpoint loaded in _setup_model. Fix by reloading HF weights after parallelization. Also fix micro-batch splitting: use sample_microbatch_size (in samples) for split_batch_dpo instead of rank_microbatch_size (in tokens), matching the main branch's DPOTrainModule pattern.
* Remove DEBUG logging from DPO forward passes and data loader
* cleaned up PR and merged to head
* Set gradient_accumulation_steps=1 in debug DPO scripts
  Both scripts keep an effective batch size of 4 but use per_device_train_batch_size=4 with gradient_accumulation_steps=1, so micro-batch averaging differences between OLMo-core (token-weighted) and HF (uniform) don't matter.
* cleaned up PR
* cleaned up PR
* cleaned up PR
* Fix PR #1451 review comments
  - Fix mock model parameter names to match production code (doc_lens/max_doc_lens)
  - Use torch RNG instead of numpy in the test to match the HFDataLoader implementation
  - Use effective_steps (max_train_steps when set) for scheduler warmup
* set drop_last
* updated code
* Remove del statement for unused params in mock model forward

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent cf4a3f5 commit c4b10fc

File tree: 9 files changed (+148, -39 lines)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ All notable changes to this project will be documented in this file.
 - Increased vLLM health check timeout from 30s to 600s (10 minutes) (https://github.com/allenai/open-instruct/pull/1452).
 - Updated vllm version to 0.14.1 (https://github.com/allenai/open-instruct/pull/1433).
 - Changed default wandb x-axis from `episode` to `training_step` for grpo_fast (https://github.com/allenai/open-instruct/pull/1437).
+- Made a bunch of changes to `dpo.py` so it matches `dpo_tune_cache.py` perfectly (https://github.com/allenai/open-instruct/pull/1451).

 ### Fixed
 - Fixed test `single_example_collator` returning raw int for index, causing `TypeError` in `_iter_batches` (https://github.com/allenai/open-instruct/pull/1477).

open_instruct/data_loader.py

Lines changed: 4 additions & 3 deletions
@@ -231,12 +231,13 @@ def _reshard(self, epoch: int) -> None:

        Uses index-based shuffling to avoid copying the dataset.
        """
-        rng = np.random.default_rng(self.seed + epoch)
-        all_indices = np.arange(len(self._full_dataset))
+        generator = torch.Generator()
+        generator.manual_seed(self.seed + epoch)
+        dataset_len = len(self._full_dataset)
+        all_indices = torch.randperm(dataset_len, generator=generator).numpy()
        if self._excluded_indices:
            mask = np.isin(all_indices, list(self._excluded_indices), invert=True)
            all_indices = all_indices[mask]
-        rng.shuffle(all_indices)

        global_size = len(all_indices)
        total_batches = global_size // self._batch_size
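The switch from NumPy to PyTorch RNG above is what lets `HFDataLoader._reshard` reproduce the order of the seeded `RandomSampler` used on the `dpo_tune_cache.py` side. A minimal standalone sketch of why the two now agree (toy dataset, not the repository's code; it relies on `RandomSampler` without replacement drawing its permutation from `torch.randperm` with the supplied generator):

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset

seed, n = 123, 8
ds = TensorDataset(torch.arange(n))

# HFDataLoader._reshard-style shuffle for epoch 0, per the diff above.
gen = torch.Generator()
gen.manual_seed(seed + 0)  # self.seed + epoch
reshard_order = torch.randperm(n, generator=gen).tolist()

# dpo_tune_cache.py-style sampler: RandomSampler (without replacement)
# draws its order from torch.randperm using the given generator.
sampler = RandomSampler(ds, generator=torch.Generator().manual_seed(seed))
sampler_order = list(iter(sampler))

assert reshard_order == sampler_order  # identical data ordering
```

With the old `np.random.default_rng` path, the permutation came from a different RNG algorithm and could not match the sampler's order for any seed.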

open_instruct/dpo.py

Lines changed: 21 additions & 12 deletions
@@ -85,7 +85,7 @@ def _load():


 def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device):
-    """Load and configure OLMo-core model."""
+    """Build OLMo-core model architecture (weights loaded after parallelization)."""
     hf_config = transformers.AutoConfig.from_pretrained(args.model_name_or_path)
     vocab_size = hf_config.vocab_size
     logger.info(f"Building OLMo-core model with vocab_size={vocab_size}")
@@ -103,10 +103,6 @@ def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device):
     )
     model = model_config.build(init_device="cpu")

-    logger.info(f"Loading HuggingFace weights from {args.model_name_or_path}")
-    load_hf_model(args.model_name_or_path, model.state_dict(), work_dir=args.output_dir)
-    model = model.to(device=device, dtype=torch.bfloat16)
-
     return model, model_config


@@ -271,7 +267,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC

     dataset = _load_dataset_distributed(args, tc, transform_fn_args, is_main_process)
     dataset = dataset.shuffle(seed=args.seed)
-    dataset.set_format(type="pt")  # Must be after shuffle (shuffle resets format)
+    dataset.set_format(type="pt")

     world_size = distributed_utils.get_world_size() if distributed_utils.is_distributed() else 1
     dp_world_size = world_size // args.tensor_parallel_degree
@@ -308,6 +304,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
         work_dir=args.output_dir,
         collator=collator,
         device=device,
+        drop_last=True,
     )
     # 4x batch size: forward-only (no backward), so no activation storage needed.
     # With packing, the collator's token budget controls the actual forward-pass size
@@ -325,7 +322,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
         work_dir=args.output_dir,
         collator=collator,
         device=device,
-        drop_last=False,
+        drop_last=True,
     )

     forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo
@@ -350,8 +347,9 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC

     data_loader.reshuffle(epoch=0)
     num_training_steps = len(data_loader) * args.num_epochs
+    effective_steps = args.max_train_steps if args.max_train_steps is not None else num_training_steps
     optim_config = AdamWConfig(lr=args.learning_rate, weight_decay=args.weight_decay, fused=args.fused_optimizer)
-    scheduler = _setup_scheduler(args, num_training_steps)
+    scheduler = _setup_scheduler(args, effective_steps)
     max_grad_norm = args.max_grad_norm if args.max_grad_norm > 0 else None
     dp_config = transformer_config.TransformerDataParallelConfig(
         name=DataParallelType.hsdp,
@@ -384,9 +382,13 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
         device=device,
     )

-    # Build reference cache after train_module init because TransformerTrainModule applies
-    # FSDP parallelism to the model, and we need the parallelized model to calculate the
-    # logprobs in case the model is too big to fit in memory.
+    # TransformerTrainModule.__init__ calls parallelize_model which calls init_weights,
+    # reinitializing all model weights from scratch. We must reload the HF checkpoint.
+    logger.info("Reloading HuggingFace weights after parallelization...")
+    sd = train_module.model.state_dict()
+    load_hf_model(args.model_name_or_path, sd, work_dir=args.output_dir)
+    train_module.model.load_state_dict(sd)
+
     logger.info("Caching reference logprobs...")
     train_module.reference_cache = dpo_utils.build_reference_logprobs_cache(model=train_module.model, **cache_kwargs)

@@ -399,14 +401,21 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC

     trainer_callbacks = _setup_callbacks(args, dp_world_size)

+    if args.max_train_steps is not None:
+        max_duration = train.Duration.steps(args.max_train_steps)
+    else:
+        max_duration = train.Duration.steps(num_training_steps)
+
     trainer = train.TrainerConfig(
         save_folder=args.output_dir,
-        max_duration=train.Duration.epochs(args.num_epochs),
+        max_duration=max_duration,
         metrics_collect_interval=args.logging_steps,
         callbacks=trainer_callbacks,
         save_overwrite=True,
     ).build(train_module, data_loader)

+    trainer.epoch = 0
+
     logger.info("Starting training...")
     trainer.fit()
     logger.info("Training complete.")

open_instruct/dpo_tune_cache.py

Lines changed: 8 additions & 3 deletions
@@ -43,7 +43,7 @@
 from huggingface_hub import HfApi
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 from rich.pretty import pprint
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, RandomSampler
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, get_scheduler

@@ -407,8 +407,9 @@ def load_model():
     else:
         collate_fn = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest")

+    train_sampler = RandomSampler(train_dataset, generator=torch.Generator().manual_seed(args.seed))
     train_dataloader = DataLoader(
-        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
+        train_dataset, sampler=train_sampler, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
     )

     # Optimizer
@@ -535,6 +536,7 @@ def load_model():
         is_main_process=accelerator.is_main_process,
         model_dims=model_dims,
         use_lora=args.use_lora,
+        disable_adapter_context=None,
     )
     logger.info("=============after cache logprobs")
     print_gpu_stats(init_gpu_memory)
@@ -573,6 +575,7 @@ def load_model():
                     average_log_prob=args.loss_type.is_average_loss,
                     output_router_logits=args.load_balancing_loss,
                 )  # `aux_loss` is only used when `args.load_balancing_loss = True`
+
                 losses, chosen_rewards, rejected_rewards = dpo_utils.compute_loss(
                     args,
                     batch,
@@ -621,7 +624,9 @@ def load_model():
                     # single all reduce to save time, avoiding per metric all reduce
                     global_metrics_tensor = accelerator.reduce(local_metrics.metrics, reduction="mean")
                     global_metrics_tensor /= args.gradient_accumulation_steps * args.logging_steps
-                    global_metrics_tensor[local_metrics.names2idx["token_count"]] *= accelerator.num_processes
+                    global_metrics_tensor[local_metrics.names2idx["token_count"]] *= (
+                        accelerator.num_processes * args.gradient_accumulation_steps * args.logging_steps
+                    )
                     global_metrics = {
                         name: global_metrics_tensor[index].item() for name, index in local_metrics.names2idx.items()
                     }
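The sampler change in the diff above matters because `shuffle=True` seeds its internal sampler from PyTorch's *global* RNG, whose state model loading has already consumed by this point, while an explicitly seeded `RandomSampler` is pinned to the seed. A standalone sketch with a toy dataset (illustrative only, not the training code):

```python
# Sketch: a RandomSampler built from a manually seeded torch.Generator
# produces the same iteration order regardless of how much global RNG
# state was consumed beforehand (e.g. by model initialization).
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

def first_epoch_order(seed: int) -> list[int]:
    sampler = RandomSampler(dataset, generator=torch.Generator().manual_seed(seed))
    loader = DataLoader(dataset, sampler=sampler, batch_size=2)
    return [int(x) for batch in loader for x in batch[0]]

torch.randn(1000)  # consume global RNG state, as model loading would
order_a = first_epoch_order(123)
torch.randn(37)    # consume a different amount of global state
order_b = first_epoch_order(123)

assert order_a == order_b  # seeded sampler is unaffected by global RNG
```

This is why the earlier "reseed torch RNG before DataLoader creation" attempt was unnecessary once the sampler carried its own generator.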

open_instruct/dpo_utils.py

Lines changed: 106 additions & 11 deletions
@@ -873,6 +873,72 @@ def concatenated_inputs(batch: dict[str, list | torch.Tensor]) -> dict[str, torc
     return concatenated_batch


+def unpack_to_padded(
+    packed_logits: torch.Tensor, cu_doc_lens: torch.Tensor, batch_size: int, max_seq_len: int, pad_value: float = 0.0
+) -> torch.Tensor:
+    """Unpack packed logits back to padded format (batch_size, max_seq_len, vocab_size).
+
+    Args:
+        packed_logits: Packed logits of shape (1, total_tokens, vocab_size).
+        cu_doc_lens: Cumulative document lengths of shape (batch_size + 1,).
+        batch_size: Number of sequences in the batch.
+        max_seq_len: Maximum sequence length for padding.
+        pad_value: Value to use for padding (default 0.0).
+
+    Returns:
+        Padded logits of shape (batch_size, max_seq_len, vocab_size).
+    """
+    vocab_size = packed_logits.shape[-1]
+    padded = torch.full(
+        (batch_size, max_seq_len, vocab_size), pad_value, dtype=packed_logits.dtype, device=packed_logits.device
+    )
+    splits = cu_doc_lens.diff().tolist()
+    packed_list = torch.split(packed_logits.squeeze(0), splits, dim=0)
+    for i, doc_logits in enumerate(packed_list):
+        padded[i, : doc_logits.shape[0]] = doc_logits
+    return padded
+
+
+def pack_padded_sequences(
+    input_ids: torch.Tensor, labels: torch.Tensor, attention_mask: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+    """Convert padded sequences to packed format with cumulative document lengths.
+
+    This is needed for OLMo-core models which don't support attention_mask but use
+    cu_doc_lens for intra-document attention masking.
+
+    Args:
+        input_ids: Padded input IDs of shape (batch_size, seq_len).
+        labels: Padded labels of shape (batch_size, seq_len).
+        attention_mask: Attention mask of shape (batch_size, seq_len), where 1 indicates
+            valid tokens and 0 indicates padding.
+
+    Returns:
+        Tuple of (packed_input_ids, packed_labels, cu_doc_lens, max_doc_len).
+        - packed_input_ids: Shape (1, total_tokens) with all sequences concatenated.
+        - packed_labels: Shape (1, total_tokens) with all labels concatenated.
+        - cu_doc_lens: Cumulative document lengths of shape (batch_size + 1,).
+        - max_doc_len: Maximum document length in the batch.
+    """
+    batch_size = input_ids.shape[0]
+    seq_lengths = attention_mask.sum(dim=1)
+    max_doc_len = int(seq_lengths.max().item())
+    cu_doc_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
+    cu_doc_lens[1:] = seq_lengths.cumsum(dim=0)
+
+    packed_input_ids_list = []
+    packed_labels_list = []
+    for i in range(batch_size):
+        length = seq_lengths[i].item()
+        packed_input_ids_list.append(input_ids[i, :length])
+        packed_labels_list.append(labels[i, :length])
+
+    packed_input_ids = torch.cat(packed_input_ids_list, dim=0).unsqueeze(0)
+    packed_labels = torch.cat(packed_labels_list, dim=0).unsqueeze(0)
+
+    return packed_input_ids, packed_labels, cu_doc_lens, max_doc_len
+
+
 def concatenated_forward(
     model: nn.Module,
     batch: dict[str, list | torch.Tensor],
@@ -905,6 +971,7 @@ def concatenated_forward(
         for k, v in concatenated_batch.items()
         if k.startswith("concatenated_") and not k.endswith("labels")
     }
+
     if output_router_logits:
         outputs = model(**inputs, output_router_logits=True)
         logits = outputs.logits.to(torch.float32)
@@ -1023,25 +1090,41 @@ def concatenated_forward_olmo(
         Tuple of (chosen_logps, rejected_logps, aux_loss). aux_loss is always None for OLMo-core.
     """
     del output_router_logits
+    bs = batch["chosen_input_ids"].shape[0]
+
     if not packing:
         concatenated_batch = concatenated_inputs(batch)
-    else:
-        concatenated_batch, bs = pf_concatenated_inputs(batch)
+        packed_input_ids, packed_labels, cu_doc_lens, max_doc_len = pack_padded_sequences(
+            concatenated_batch["concatenated_input_ids"],
+            concatenated_batch["concatenated_labels"],
+            concatenated_batch["concatenated_attention_mask"],
+        )

-    logits = model(concatenated_batch["concatenated_input_ids"]).to(torch.float32)
+        doc_lens = cu_doc_lens.diff()
+        packed_logits = model(packed_input_ids, doc_lens=doc_lens, max_doc_lens=[max_doc_len]).to(torch.float32)
+
+        batch_size = concatenated_batch["concatenated_input_ids"].shape[0]
+        max_seq_len = concatenated_batch["concatenated_input_ids"].shape[1]
+        logits = unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len)

-    if not packing:
         all_logps = _get_batch_logps(
             logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob
         )
-        bs = batch["chosen_input_ids"].shape[0]
     else:
+        concatenated_batch, bs = pf_concatenated_inputs(batch)
+        cu_doc_lens_packing = concatenated_batch["concatenated_cu_seq_lens_k"]
+        doc_lens_packing = cu_doc_lens_packing.diff()
+        max_doc_len_packing = concatenated_batch["concatenated_max_length_k"]
+        logits = model(
+            concatenated_batch["concatenated_input_ids"], doc_lens=doc_lens_packing, max_doc_lens=[max_doc_len_packing]
+        ).to(torch.float32)
         all_logps = pf_get_batch_logps(
             logits,
             concatenated_batch["concatenated_labels"],
             concatenated_batch["concatenated_cu_seq_lens_k"],
             average_log_prob=average_log_prob,
         )
+
     chosen_logps = all_logps[:bs]
     rejected_logps = all_logps[bs:]
     return chosen_logps, rejected_logps, None
@@ -1069,17 +1152,29 @@ def separate_forward_olmo(
     """
     del output_router_logits
     chosen_batch = process_batch(batch, "chosen")
-    chosen_logits = model(chosen_batch["input_ids"]).to(torch.float32)
-
+    packed_input_ids, _, cu_doc_lens, max_doc_len = pack_padded_sequences(
+        chosen_batch["input_ids"], chosen_batch["labels"], chosen_batch["attention_mask"]
+    )
+    doc_lens = cu_doc_lens.diff()
+    packed_logits = model(packed_input_ids, doc_lens=doc_lens, max_doc_lens=[max_doc_len]).to(torch.float32)
+    batch_size = chosen_batch["input_ids"].shape[0]
+    max_seq_len = chosen_batch["input_ids"].shape[1]
+    chosen_logits = unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len)
     chosen_logps = _get_batch_logps(chosen_logits, chosen_batch["labels"], average_log_prob=average_log_prob)
-    del chosen_batch, chosen_logits
+    del chosen_batch, chosen_logits, packed_input_ids, packed_logits
    torch.cuda.empty_cache()

     rejected_batch = process_batch(batch, "rejected")
-    rejected_logits = model(rejected_batch["input_ids"]).to(torch.float32)
-
+    packed_input_ids, _, cu_doc_lens, max_doc_len = pack_padded_sequences(
+        rejected_batch["input_ids"], rejected_batch["labels"], rejected_batch["attention_mask"]
+    )
+    doc_lens = cu_doc_lens.diff()
+    packed_logits = model(packed_input_ids, doc_lens=doc_lens, max_doc_lens=[max_doc_len]).to(torch.float32)
+    batch_size = rejected_batch["input_ids"].shape[0]
+    max_seq_len = rejected_batch["input_ids"].shape[1]
+    rejected_logits = unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len)
     rejected_logps = _get_batch_logps(rejected_logits, rejected_batch["labels"], average_log_prob=average_log_prob)
-    del rejected_batch, rejected_logits
+    del rejected_batch, rejected_logits, packed_input_ids, packed_logits
    torch.cuda.empty_cache()

     return chosen_logps, rejected_logps, None
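The two helpers added above form a round trip: pad → pack (for the OLMo-core forward, which wants one packed row plus document lengths) → unpack (to restore the padded layout that `_get_batch_logps` expects). The sketch below copies the helper bodies from the diff (docstrings trimmed) and runs them on a toy batch:

```python
# Round-trip check of the pack/unpack helpers from the diff above
# (bodies copied from the diff; the toy batch is illustrative).
import torch

def pack_padded_sequences(input_ids, labels, attention_mask):
    batch_size = input_ids.shape[0]
    seq_lengths = attention_mask.sum(dim=1)
    max_doc_len = int(seq_lengths.max().item())
    cu_doc_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_doc_lens[1:] = seq_lengths.cumsum(dim=0)
    packed_input_ids = torch.cat([input_ids[i, : seq_lengths[i]] for i in range(batch_size)]).unsqueeze(0)
    packed_labels = torch.cat([labels[i, : seq_lengths[i]] for i in range(batch_size)]).unsqueeze(0)
    return packed_input_ids, packed_labels, cu_doc_lens, max_doc_len

def unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len, pad_value=0.0):
    vocab_size = packed_logits.shape[-1]
    padded = torch.full((batch_size, max_seq_len, vocab_size), pad_value,
                        dtype=packed_logits.dtype, device=packed_logits.device)
    for i, doc in enumerate(torch.split(packed_logits.squeeze(0), cu_doc_lens.diff().tolist())):
        padded[i, : doc.shape[0]] = doc
    return padded

# Toy batch: two sequences of real lengths 3 and 2, padded to length 4.
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
labels = torch.tensor([[5, 6, 7, -100], [8, 9, -100, -100]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

packed_ids, packed_labels, cu_doc_lens, max_doc_len = pack_padded_sequences(input_ids, labels, attention_mask)
assert packed_ids.tolist() == [[5, 6, 7, 8, 9]]  # padding dropped, docs concatenated
assert cu_doc_lens.tolist() == [0, 3, 5]         # document boundaries for the model
assert max_doc_len == 3

# Pretend the model returned per-token logits for the packed row (vocab=2),
# then restore the (batch, seq, vocab) layout expected by _get_batch_logps.
packed_logits = torch.arange(10, dtype=torch.float32).reshape(1, 5, 2)
padded = unpack_to_padded(packed_logits, cu_doc_lens, batch_size=2, max_seq_len=4)
assert padded.shape == (2, 4, 2)
assert torch.equal(padded[1, :2], packed_logits[0, 3:5])  # second doc restored
assert torch.equal(padded[0, 3], torch.zeros(2))          # padding positions stay 0
```

The `cu_doc_lens` boundaries are what give OLMo-core intra-document attention masking: tokens never attend across a boundary, which is the same effect a padding `attention_mask` achieves in the HuggingFace path.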

open_instruct/olmo_core_train_modules.py

Lines changed: 0 additions & 3 deletions
@@ -107,9 +107,6 @@ def __init__(
        self._forward_kwargs["packing"] = True

    def pre_train(self):
-        # Override to skip batch size validation from TransformerTrainModule.
-        # DPO processes 2x sequences per batch (chosen + rejected), so the parent's
-        # validation (global_batch_size % rank_microbatch_size == 0) would fail.
        pass

    def _compute_microbatch_loss(self, micro_batch: dict[str, Any]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

open_instruct/test_data_loader_gpu.py

Lines changed: 3 additions & 4 deletions
@@ -2,7 +2,6 @@
 import unittest

 import datasets
-import numpy as np
 import parameterized
 import torch

@@ -92,9 +91,9 @@ def test_multi_rank_sampling(self, name, dp_world_size):
            union |= indices
        total_batches = num_examples // batch_size
        usable_size = total_batches * batch_size
-        rng = np.random.default_rng(42)
-        shuffled = np.arange(num_examples)
-        rng.shuffle(shuffled)
+        generator = torch.Generator()
+        generator.manual_seed(42)
+        shuffled = torch.randperm(num_examples, generator=generator).numpy()
        expected_indices = set(shuffled[:usable_size].tolist())
        self.assertEqual(union, expected_indices)

open_instruct/test_dpo_utils_gpu.py

Lines changed: 3 additions & 1 deletion
@@ -80,7 +80,9 @@ def __init__(self, vocab_size: int = 1000):
        self.embed = torch.nn.Embedding(vocab_size, 64)
        self.linear = torch.nn.Linear(64, vocab_size)

-    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, input_ids: torch.Tensor, doc_lens: torch.Tensor | None = None, max_doc_lens: list[int] | None = None
+    ) -> torch.Tensor:
        return self.linear(self.embed(input_ids))

scripts/train/debug/dpo/single_gpu.sh

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@ uv run python mason.py \
    --model_name_or_path allenai/OLMo-2-0425-1B \
    --tokenizer_name_or_path allenai/OLMo-2-0425-1B \
    --max_seq_length 1024 \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 4 \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 1 \
    --learning_rate 5e-07 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.1 \

0 commit comments