Merged
59 commits
6ae4146
metrics fix
finbarrtimbers Jan 30, 2026
53ef2ef
Add single_gpu_cache.sh for DPO cache comparison
finbarrtimbers Jan 30, 2026
74b1d4e
Fix beaker_config UnboundLocalError in dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
cdec970
Add rewards_average and token_count metrics to DPO
finbarrtimbers Jan 30, 2026
893772e
Add --no-host-networking to single GPU DPO scripts
finbarrtimbers Jan 30, 2026
3c477ba
Fix logger.info call in dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
075537d
Sync build_reference_logprobs_cache call with dpo_utils.py
finbarrtimbers Jan 30, 2026
a41e9a4
Fix description in single_gpu_cache.sh
finbarrtimbers Jan 30, 2026
f04f080
Include gradient_accumulation_steps in global_batch_size for dpo.py
finbarrtimbers Jan 30, 2026
d6bc330
Set drop_last=False in dpo.py to match dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
2646b44
Add debug logging to investigate logprobs discrepancy between dpo.py …
finbarrtimbers Jan 30, 2026
5e04417
Add attention masking support for OLMo-core DPO
finbarrtimbers Jan 30, 2026
0079553
Use PyTorch RNG in HFDataLoader to match dpo_tune_cache.py data ordering
finbarrtimbers Jan 30, 2026
6de7ad8
Add debug logging to compare data ordering between dpo.py and dpo_tun…
finbarrtimbers Jan 30, 2026
1974570
Make OLMo-core DPO use same logprob computation as HuggingFace
finbarrtimbers Jan 30, 2026
b4865ec
Add detailed logprob debug logging to compare dpo.py and dpo_tune_cac…
finbarrtimbers Jan 30, 2026
665707f
Add micro-batching to DPO to match dpo_tune_cache.py batch structure
finbarrtimbers Jan 30, 2026
889bfe9
Add debug logging to compare HF and OLMo-core forward passes
finbarrtimbers Jan 30, 2026
008fc86
Add embedding weight logging to compare HF and OLMo-core models
finbarrtimbers Jan 30, 2026
c3feb3b
Fix embed weight logging to handle DTensor (FSDP)
finbarrtimbers Jan 30, 2026
017c924
Use full_tensor() for FSDP sharded weights
finbarrtimbers Jan 30, 2026
ced54f6
Fix: Actually load HF weights into OLMo-core model
finbarrtimbers Jan 30, 2026
23e139e
Align data ordering between dpo.py and dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
016861c
Remove double shuffling from dpo.py
finbarrtimbers Jan 30, 2026
844aa0b
Revert changes to dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
e4575e5
Implement double-shuffle in dpo.py to match dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
fa3bd40
Reseed torch RNG before DataLoader creation in dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
223622a
Revert torch.manual_seed change to dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
4d496b1
Use seeded generator for DataLoader shuffle in dpo_tune_cache.py
finbarrtimbers Jan 30, 2026
4ea43f4
Add detailed logits logging at label positions
finbarrtimbers Jan 30, 2026
9eec226
Add input token logging at position 445-450 for debugging
finbarrtimbers Jan 30, 2026
77ef874
Add packed logits logging at position 447 for debugging
finbarrtimbers Jan 30, 2026
9361ae7
Fix OLMo-core attention masking by using correct argument names
finbarrtimbers Jan 31, 2026
1a289c2
Remove debug logging that causes index errors for short sequences
finbarrtimbers Jan 31, 2026
cbe9960
Fix DataLoader shuffle to match HFDataLoader's randperm order
finbarrtimbers Jan 31, 2026
1803432
Add debug logging to verify randperm behavior on Beaker
finbarrtimbers Jan 31, 2026
a14af8f
Add dataset_len to debug logging for randperm verification
finbarrtimbers Jan 31, 2026
1592470
Fix dpo.py epoch alignment with dpo_tune_cache.py
finbarrtimbers Feb 1, 2026
26e8dfc
Apply H17, H25, H16 fixes from DPO divergence investigation
finbarrtimbers Feb 10, 2026
22fd8e2
Revert H17/H25, keep H16: use Duration.steps() instead of Duration.ep…
finbarrtimbers Feb 10, 2026
8fd3771
Add validation notebook for DPO comparison
finbarrtimbers Feb 10, 2026
31024f1
Remove incorrect trainer.epoch = 0 (default is 1-based)
finbarrtimbers Feb 10, 2026
6143215
Fix rewards_accuracy to use per-sample comparison instead of scalar
finbarrtimbers Feb 11, 2026
5a95094
Revert "Fix rewards_accuracy to use per-sample comparison instead of …
finbarrtimbers Feb 11, 2026
72c7715
Revert "Remove incorrect trainer.epoch = 0 (default is 1-based)"
finbarrtimbers Feb 11, 2026
e01505b
Fix rewards_accuracy to use per-sample comparison instead of scalar
finbarrtimbers Feb 11, 2026
faa15a4
Merge origin/main into finbarr/dpo-match-single-gpu
finbarrtimbers Feb 12, 2026
f4631f9
Fix HF weight loading and micro-batch splitting after TransformerTrai…
finbarrtimbers Feb 12, 2026
f9ab894
Remove DEBUG logging from DPO forward passes and data loader
finbarrtimbers Feb 12, 2026
7043d69
Merge remote-tracking branch 'origin/main' into finbarr/dpo-match-sin…
finbarrtimbers Feb 18, 2026
f8d6a37
cleaned up PR and merged to head
finbarrtimbers Feb 18, 2026
0d0e930
Set gradient_accumulation_steps=1 in debug DPO scripts
finbarrtimbers Feb 18, 2026
4d55299
cleaned up PR
finbarrtimbers Feb 18, 2026
186b1f7
cleaned up PR
finbarrtimbers Feb 18, 2026
3fff585
cleaned up PR
finbarrtimbers Feb 18, 2026
a8f652a
Fix PR #1451 review comments
finbarrtimbers Feb 18, 2026
19bf948
set drop_last
finbarrtimbers Feb 18, 2026
fc694ee
updated code
finbarrtimbers Feb 18, 2026
c5cb1bb
Remove del statement for unused params in mock model forward
finbarrtimbers Feb 18, 2026
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@ All notable changes to this project will be documented in this file.
- Increased vLLM health check timeout from 30s to 600s (10 minutes) (https://github.com/allenai/open-instruct/pull/1452).
- Updated vllm version to 0.14.1 (https://github.com/allenai/open-instruct/pull/1433).
- Changed default wandb x-axis from `episode` to `training_step` for grpo_fast (https://github.com/allenai/open-instruct/pull/1437).
- Aligned `dpo.py` with `dpo_tune_cache.py` so the two scripts produce matching training behavior (https://github.com/allenai/open-instruct/pull/1451).

### Fixed
- Fixed test `single_example_collator` returning raw int for index, causing `TypeError` in `_iter_batches` (https://github.com/allenai/open-instruct/pull/1477).
7 changes: 4 additions & 3 deletions open_instruct/data_loader.py
@@ -231,12 +231,13 @@ def _reshard(self, epoch: int) -> None:

Uses index-based shuffling to avoid copying the dataset.
"""
rng = np.random.default_rng(self.seed + epoch)
all_indices = np.arange(len(self._full_dataset))
generator = torch.Generator()
generator.manual_seed(self.seed + epoch)
dataset_len = len(self._full_dataset)
all_indices = torch.randperm(dataset_len, generator=generator).numpy()
if self._excluded_indices:
mask = np.isin(all_indices, list(self._excluded_indices), invert=True)
all_indices = all_indices[mask]
rng.shuffle(all_indices)

global_size = len(all_indices)
total_batches = global_size // self._batch_size
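The hunk above swaps a NumPy RNG for a seeded `torch.Generator` feeding `torch.randperm`, so the data loader's shuffle order is reproducible and uses the same torch-based permutation as `dpo_tune_cache.py`. A minimal sketch of that idea (`shuffled_indices` is a hypothetical helper, not part of the PR):

```python
import torch


def shuffled_indices(seed: int, epoch: int, dataset_len: int):
    # Seeding with seed + epoch gives a deterministic, epoch-dependent shuffle.
    # A NumPy RNG with the same seed would yield a *different* permutation,
    # which is exactly the kind of cross-script divergence this PR chases down.
    generator = torch.Generator()
    generator.manual_seed(seed + epoch)
    return torch.randperm(dataset_len, generator=generator).numpy()


a = shuffled_indices(42, 0, 8)
b = shuffled_indices(42, 0, 8)
assert (a == b).all()  # same seed + epoch -> identical order
assert sorted(a.tolist()) == list(range(8))  # still a permutation
```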
33 changes: 21 additions & 12 deletions open_instruct/dpo.py
@@ -85,7 +85,7 @@ def _load():


def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device):
"""Load and configure OLMo-core model."""
"""Build OLMo-core model architecture (weights loaded after parallelization)."""
hf_config = transformers.AutoConfig.from_pretrained(args.model_name_or_path)
vocab_size = hf_config.vocab_size
logger.info(f"Building OLMo-core model with vocab_size={vocab_size}")
@@ -103,10 +103,6 @@ def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device):
)
model = model_config.build(init_device="cpu")

logger.info(f"Loading HuggingFace weights from {args.model_name_or_path}")
load_hf_model(args.model_name_or_path, model.state_dict(), work_dir=args.output_dir)
model = model.to(device=device, dtype=torch.bfloat16)

return model, model_config


@@ -271,7 +267,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC

dataset = _load_dataset_distributed(args, tc, transform_fn_args, is_main_process)
dataset = dataset.shuffle(seed=args.seed)
dataset.set_format(type="pt") # Must be after shuffle (shuffle resets format)
dataset.set_format(type="pt")

world_size = distributed_utils.get_world_size() if distributed_utils.is_distributed() else 1
dp_world_size = world_size // args.tensor_parallel_degree
@@ -308,6 +304,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
work_dir=args.output_dir,
collator=collator,
device=device,
drop_last=True,
)
# 4x batch size: forward-only (no backward), so no activation storage needed.
# With packing, the collator's token budget controls the actual forward-pass size
@@ -325,7 +322,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
work_dir=args.output_dir,
collator=collator,
device=device,
drop_last=False,
drop_last=True,
Cache data loader drops examples, causing RuntimeError

High Severity

The cache_data_loader was changed from drop_last=False to drop_last=True. This causes build_reference_logprobs_cache to skip caching some dataset indices (the remainder that doesn't fill a full batch). Since the cache function allocates tensors of size full_dataset_size=len(dataset) and then validates that every index was populated (raising RuntimeError for any -inf entries), this will crash whenever len(dataset) % cache_batch_size != 0.
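A toy reproduction of the failure mode this comment describes. The function below is an illustrative stand-in, not the actual `build_reference_logprobs_cache`: the cache is sized to the full dataset and sentinel-filled, then populated batch by batch, so any index a `drop_last=True` loader never yields is left at `-inf` and trips the validation check.

```python
import torch


def build_cache(dataset_len: int, batch_size: int, drop_last: bool) -> torch.Tensor:
    # Cache is allocated for the FULL dataset, initialized to a -inf sentinel.
    cache = torch.full((dataset_len,), float("-inf"))
    if drop_last:
        num_batches = dataset_len // batch_size  # trailing remainder is skipped
    else:
        num_batches = -(-dataset_len // batch_size)  # ceil division
    for b in range(num_batches):
        idx = list(range(b * batch_size, min((b + 1) * batch_size, dataset_len)))
        cache[idx] = 0.0  # stand-in for the computed reference logprobs
    # Validation: every index must have been populated.
    if torch.isinf(cache).any():
        raise RuntimeError("reference cache has unpopulated entries")
    return cache


build_cache(10, 4, drop_last=False)  # fills all 10 entries, returns cleanly
try:
    build_cache(10, 4, drop_last=True)  # indices 8 and 9 are never cached
except RuntimeError as e:
    print(e)
```

The crash only surfaces when `len(dataset) % cache_batch_size != 0`, which is why it is easy to miss in small debug runs with round-number dataset sizes.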


)

forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo
@@ -350,8 +347,9 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC

data_loader.reshuffle(epoch=0)
num_training_steps = len(data_loader) * args.num_epochs
effective_steps = args.max_train_steps if args.max_train_steps is not None else num_training_steps
optim_config = AdamWConfig(lr=args.learning_rate, weight_decay=args.weight_decay, fused=args.fused_optimizer)
scheduler = _setup_scheduler(args, num_training_steps)
scheduler = _setup_scheduler(args, effective_steps)
max_grad_norm = args.max_grad_norm if args.max_grad_norm > 0 else None
dp_config = transformer_config.TransformerDataParallelConfig(
name=DataParallelType.hsdp,
@@ -384,9 +382,13 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
device=device,
)

# Build reference cache after train_module init because TransformerTrainModule applies
# FSDP parallelism to the model, and we need the parallelized model to calculate the
# logprobs in case the model is too big to fit in memory.
# TransformerTrainModule.__init__ calls parallelize_model which calls init_weights,
# reinitializing all model weights from scratch. We must reload the HF checkpoint.
logger.info("Reloading HuggingFace weights after parallelization...")
sd = train_module.model.state_dict()
load_hf_model(args.model_name_or_path, sd, work_dir=args.output_dir)
train_module.model.load_state_dict(sd)

logger.info("Caching reference logprobs...")
train_module.reference_cache = dpo_utils.build_reference_logprobs_cache(model=train_module.model, **cache_kwargs)

@@ -399,14 +401,21 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC

trainer_callbacks = _setup_callbacks(args, dp_world_size)

if args.max_train_steps is not None:
max_duration = train.Duration.steps(args.max_train_steps)
else:
max_duration = train.Duration.steps(num_training_steps)

trainer = train.TrainerConfig(
save_folder=args.output_dir,
max_duration=train.Duration.epochs(args.num_epochs),
max_duration=max_duration,
metrics_collect_interval=args.logging_steps,
callbacks=trainer_callbacks,
save_overwrite=True,
).build(train_module, data_loader)

trainer.epoch = 0

logger.info("Starting training...")
trainer.fit()
logger.info("Training complete.")
11 changes: 8 additions & 3 deletions open_instruct/dpo_tune_cache.py
@@ -43,7 +43,7 @@
from huggingface_hub import HfApi
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from rich.pretty import pprint
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, get_scheduler

@@ -407,8 +407,9 @@ def load_model():
else:
collate_fn = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest")

train_sampler = RandomSampler(train_dataset, generator=torch.Generator().manual_seed(args.seed))
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
train_dataset, sampler=train_sampler, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
)

# Optimizer
@@ -535,6 +536,7 @@ def load_model():
is_main_process=accelerator.is_main_process,
model_dims=model_dims,
use_lora=args.use_lora,
disable_adapter_context=None,
)
logger.info("=============after cache logprobs")
print_gpu_stats(init_gpu_memory)
@@ -573,6 +575,7 @@ def load_model():
average_log_prob=args.loss_type.is_average_loss,
output_router_logits=args.load_balancing_loss,
) # `aux_loss` is only used when `args.load_balancing_loss = True`

losses, chosen_rewards, rejected_rewards = dpo_utils.compute_loss(
args,
batch,
@@ -621,7 +624,9 @@ def load_model():
# single all reduce to save time, avoiding per metric all reduce
global_metrics_tensor = accelerator.reduce(local_metrics.metrics, reduction="mean")
global_metrics_tensor /= args.gradient_accumulation_steps * args.logging_steps
global_metrics_tensor[local_metrics.names2idx["token_count"]] *= accelerator.num_processes
global_metrics_tensor[local_metrics.names2idx["token_count"]] *= (
accelerator.num_processes * args.gradient_accumulation_steps * args.logging_steps
)
global_metrics = {
name: global_metrics_tensor[index].item() for name, index in local_metrics.names2idx.items()
}
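The hunk above exists because every metric in the tensor is mean-reduced and divided by `gradient_accumulation_steps * logging_steps`, but `token_count` is a running sum, not a mean, so the divisors (and the cross-process mean) must be multiplied back to recover the true global total. A toy version of that arithmetic (all numbers below are made up for illustration):

```python
# Hypothetical per-process token totals, each accumulated over
# grad_accum * logging_steps micro-steps.
num_processes = 2
grad_accum = 4
logging_steps = 5
per_process_token_sums = [1000.0, 1400.0]

# What the generic aggregation path does to every metric:
mean_reduced = sum(per_process_token_sums) / num_processes   # cross-process mean
averaged = mean_reduced / (grad_accum * logging_steps)       # per-micro-step mean

# token_count is a sum, so undo all three averaging factors:
token_count = averaged * num_processes * grad_accum * logging_steps

assert token_count == sum(per_process_token_sums)  # 2400.0, the true global sum
```

Without the multiply-back, the logged `token_count` would be off by a factor of `num_processes * grad_accum * logging_steps`, which is the bug the earlier `*= accelerator.num_processes` line only partially fixed.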
117 changes: 106 additions & 11 deletions open_instruct/dpo_utils.py
@@ -873,6 +873,72 @@ def concatenated_inputs(batch: dict[str, list | torch.Tensor]) -> dict[str, torc
return concatenated_batch


def unpack_to_padded(
packed_logits: torch.Tensor, cu_doc_lens: torch.Tensor, batch_size: int, max_seq_len: int, pad_value: float = 0.0
) -> torch.Tensor:
"""Unpack packed logits back to padded format (batch_size, max_seq_len, vocab_size).

Args:
packed_logits: Packed logits of shape (1, total_tokens, vocab_size).
cu_doc_lens: Cumulative document lengths of shape (batch_size + 1,).
batch_size: Number of sequences in the batch.
max_seq_len: Maximum sequence length for padding.
pad_value: Value to use for padding (default 0.0).

Returns:
Padded logits of shape (batch_size, max_seq_len, vocab_size).
"""
vocab_size = packed_logits.shape[-1]
padded = torch.full(
(batch_size, max_seq_len, vocab_size), pad_value, dtype=packed_logits.dtype, device=packed_logits.device
)
splits = cu_doc_lens.diff().tolist()
packed_list = torch.split(packed_logits.squeeze(0), splits, dim=0)
for i, doc_logits in enumerate(packed_list):
padded[i, : doc_logits.shape[0]] = doc_logits
return padded


def pack_padded_sequences(
input_ids: torch.Tensor, labels: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
"""Convert padded sequences to packed format with cumulative document lengths.

This is needed for OLMo-core models which don't support attention_mask but use
cu_doc_lens for intra-document attention masking.

Args:
input_ids: Padded input IDs of shape (batch_size, seq_len).
labels: Padded labels of shape (batch_size, seq_len).
attention_mask: Attention mask of shape (batch_size, seq_len), where 1 indicates
valid tokens and 0 indicates padding.

Returns:
Tuple of (packed_input_ids, packed_labels, cu_doc_lens, max_doc_len).
- packed_input_ids: Shape (1, total_tokens) with all sequences concatenated.
- packed_labels: Shape (1, total_tokens) with all labels concatenated.
- cu_doc_lens: Cumulative document lengths of shape (batch_size + 1,).
- max_doc_len: Maximum document length in the batch.
"""
batch_size = input_ids.shape[0]
seq_lengths = attention_mask.sum(dim=1)
max_doc_len = int(seq_lengths.max().item())
cu_doc_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_doc_lens[1:] = seq_lengths.cumsum(dim=0)

packed_input_ids_list = []
packed_labels_list = []
for i in range(batch_size):
length = seq_lengths[i].item()
packed_input_ids_list.append(input_ids[i, :length])
packed_labels_list.append(labels[i, :length])

packed_input_ids = torch.cat(packed_input_ids_list, dim=0).unsqueeze(0)
packed_labels = torch.cat(packed_labels_list, dim=0).unsqueeze(0)

return packed_input_ids, packed_labels, cu_doc_lens, max_doc_len
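The `cu_doc_lens` convention these two helpers share can be checked with a tiny standalone example (the tensors below are made up): a mask with row lengths `[3, 2]` packs into 5 tokens with cumulative boundaries `[0, 3, 5]`, which is what OLMo-core's intra-document attention masking consumes.

```python
import torch

# Two padded sequences of true lengths 3 and 2.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])

# Same computation pack_padded_sequences performs:
seq_lengths = attention_mask.sum(dim=1)                  # tensor([3, 2])
cu_doc_lens = torch.zeros(3, dtype=torch.int32)
cu_doc_lens[1:] = seq_lengths.cumsum(dim=0)              # [0, 3, 5]
packed = torch.cat(
    [input_ids[i, : seq_lengths[i]] for i in range(2)]
).unsqueeze(0)                                           # shape (1, total_tokens)

assert cu_doc_lens.tolist() == [0, 3, 5]
assert packed.tolist() == [[5, 6, 7, 8, 9]]              # padding dropped
```

`unpack_to_padded` inverts this: splitting along `cu_doc_lens.diff()` recovers each document, and re-padding to `max_seq_len` restores the original `(batch_size, seq_len)` layout.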


def concatenated_forward(
model: nn.Module,
batch: dict[str, list | torch.Tensor],
@@ -905,6 +971,7 @@ def concatenated_forward(
for k, v in concatenated_batch.items()
if k.startswith("concatenated_") and not k.endswith("labels")
}

if output_router_logits:
outputs = model(**inputs, output_router_logits=True)
logits = outputs.logits.to(torch.float32)
@@ -1023,25 +1090,41 @@ def concatenated_forward_olmo(
Tuple of (chosen_logps, rejected_logps, aux_loss). aux_loss is always None for OLMo-core.
"""
del output_router_logits
bs = batch["chosen_input_ids"].shape[0]

if not packing:
concatenated_batch = concatenated_inputs(batch)
else:
concatenated_batch, bs = pf_concatenated_inputs(batch)
packed_input_ids, packed_labels, cu_doc_lens, max_doc_len = pack_padded_sequences(
concatenated_batch["concatenated_input_ids"],
concatenated_batch["concatenated_labels"],
concatenated_batch["concatenated_attention_mask"],
)

logits = model(concatenated_batch["concatenated_input_ids"]).to(torch.float32)
doc_lens = cu_doc_lens.diff()
packed_logits = model(packed_input_ids, doc_lens=doc_lens, max_doc_lens=[max_doc_len]).to(torch.float32)

batch_size = concatenated_batch["concatenated_input_ids"].shape[0]
max_seq_len = concatenated_batch["concatenated_input_ids"].shape[1]
logits = unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len)

if not packing:
all_logps = _get_batch_logps(
logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob
)
bs = batch["chosen_input_ids"].shape[0]
else:
concatenated_batch, bs = pf_concatenated_inputs(batch)
cu_doc_lens_packing = concatenated_batch["concatenated_cu_seq_lens_k"]
doc_lens_packing = cu_doc_lens_packing.diff()
max_doc_len_packing = concatenated_batch["concatenated_max_length_k"]
logits = model(
concatenated_batch["concatenated_input_ids"], doc_lens=doc_lens_packing, max_doc_lens=[max_doc_len_packing]
).to(torch.float32)
all_logps = pf_get_batch_logps(
logits,
concatenated_batch["concatenated_labels"],
concatenated_batch["concatenated_cu_seq_lens_k"],
average_log_prob=average_log_prob,
)

chosen_logps = all_logps[:bs]
rejected_logps = all_logps[bs:]
return chosen_logps, rejected_logps, None
@@ -1069,17 +1152,29 @@ def separate_forward_olmo(
"""
del output_router_logits
chosen_batch = process_batch(batch, "chosen")
chosen_logits = model(chosen_batch["input_ids"]).to(torch.float32)

packed_input_ids, _, cu_doc_lens, max_doc_len = pack_padded_sequences(
chosen_batch["input_ids"], chosen_batch["labels"], chosen_batch["attention_mask"]
)
doc_lens = cu_doc_lens.diff()
packed_logits = model(packed_input_ids, doc_lens=doc_lens, max_doc_lens=[max_doc_len]).to(torch.float32)
batch_size = chosen_batch["input_ids"].shape[0]
max_seq_len = chosen_batch["input_ids"].shape[1]
chosen_logits = unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len)
chosen_logps = _get_batch_logps(chosen_logits, chosen_batch["labels"], average_log_prob=average_log_prob)
del chosen_batch, chosen_logits
del chosen_batch, chosen_logits, packed_input_ids, packed_logits
torch.cuda.empty_cache()

rejected_batch = process_batch(batch, "rejected")
rejected_logits = model(rejected_batch["input_ids"]).to(torch.float32)

packed_input_ids, _, cu_doc_lens, max_doc_len = pack_padded_sequences(
rejected_batch["input_ids"], rejected_batch["labels"], rejected_batch["attention_mask"]
)
doc_lens = cu_doc_lens.diff()
packed_logits = model(packed_input_ids, doc_lens=doc_lens, max_doc_lens=[max_doc_len]).to(torch.float32)
batch_size = rejected_batch["input_ids"].shape[0]
max_seq_len = rejected_batch["input_ids"].shape[1]
rejected_logits = unpack_to_padded(packed_logits, cu_doc_lens, batch_size, max_seq_len)
rejected_logps = _get_batch_logps(rejected_logits, rejected_batch["labels"], average_log_prob=average_log_prob)
del rejected_batch, rejected_logits
del rejected_batch, rejected_logits, packed_input_ids, packed_logits
torch.cuda.empty_cache()

return chosen_logps, rejected_logps, None
3 changes: 0 additions & 3 deletions open_instruct/olmo_core_train_modules.py
@@ -107,9 +107,6 @@ def __init__(
self._forward_kwargs["packing"] = True

def pre_train(self):
# Override to skip batch size validation from TransformerTrainModule.
# DPO processes 2x sequences per batch (chosen + rejected), so the parent's
# validation (global_batch_size % rank_microbatch_size == 0) would fail.
pass

def _compute_microbatch_loss(self, micro_batch: dict[str, Any]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
7 changes: 3 additions & 4 deletions open_instruct/test_data_loader_gpu.py
@@ -2,7 +2,6 @@
import unittest

import datasets
import numpy as np
import parameterized
import torch

@@ -92,9 +91,9 @@ def test_multi_rank_sampling(self, name, dp_world_size):
union |= indices
total_batches = num_examples // batch_size
usable_size = total_batches * batch_size
rng = np.random.default_rng(42)
shuffled = np.arange(num_examples)
rng.shuffle(shuffled)
generator = torch.Generator()
generator.manual_seed(42)
shuffled = torch.randperm(num_examples, generator=generator).numpy()
expected_indices = set(shuffled[:usable_size].tolist())
self.assertEqual(union, expected_indices)

4 changes: 3 additions & 1 deletion open_instruct/test_dpo_utils_gpu.py
@@ -80,7 +80,9 @@ def __init__(self, vocab_size: int = 1000):
self.embed = torch.nn.Embedding(vocab_size, 64)
self.linear = torch.nn.Linear(64, vocab_size)

def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
def forward(
self, input_ids: torch.Tensor, doc_lens: torch.Tensor | None = None, max_doc_lens: list[int] | None = None
) -> torch.Tensor:
return self.linear(self.embed(input_ids))


4 changes: 2 additions & 2 deletions scripts/train/debug/dpo/single_gpu.sh
@@ -19,8 +19,8 @@ uv run python mason.py \
--model_name_or_path allenai/OLMo-2-0425-1B \
--tokenizer_name_or_path allenai/OLMo-2-0425-1B \
--max_seq_length 1024 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--learning_rate 5e-07 \
--lr_scheduler_type linear \
--warmup_ratio 0.1 \