
Commit dd04039

add test for cp_dataloader that doesn't require datacenter hardware
Signed-off-by: Peter St. John <[email protected]>
1 parent d1fc224 commit dd04039

File tree: 2 files changed (+90, -1 lines changed)


bionemo-recipes/recipes/esm2_native_te/dataset.py

Lines changed: 5 additions & 0 deletions
@@ -244,6 +244,11 @@ def create_cp_dataloader(
     Returns:
         A tuple of (dataloader, dataset_or_sampler).
     """
+    # Ensure pad_sequences_to_be_divisible_by is passed to create_thd_dataloader
+    if kwargs.get("pad_sequences_to_be_divisible_by", None) is None:
+        logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
+        kwargs["pad_sequences_to_be_divisible_by"] = cp_mesh.size() * 2
+
     train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)

     train_dataloader.collate_fn = DataCollatorForContextParallel(
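
Why cp_mesh.size() * 2? Load-balanced context parallelism splits every packed sequence into 2 * cp_size chunks and hands each rank two of them, so sequence lengths have to be divisible by 2 * cp_size. Below is a minimal, hypothetical sketch of that slicing, assuming the common scheme where rank r keeps chunk r and its mirror chunk (2 * cp_size - 1 - r); the helper name cp_shard is invented here and is not part of the recipe's DataCollatorForContextParallel.

import torch

def cp_shard(tokens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Return the two of the 2 * cp_size chunks that cp_rank would process."""
    assert tokens.numel() % (2 * cp_size) == 0, "pad to a multiple of 2 * cp_size first"
    chunks = tokens.chunk(2 * cp_size)
    # Rank r takes chunk r plus the mirrored chunk from the end, so every rank
    # ends up with exactly len(tokens) / cp_size tokens.
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]])

seq = torch.arange(8)                       # length 8 is divisible by 2 * cp_size = 4
print(cp_shard(seq, cp_size=2, cp_rank=0))  # tensor([0, 1, 6, 7]) -- beginning and end
print(cp_shard(seq, cp_size=2, cp_rank=1))  # tensor([2, 3, 4, 5]) -- the middle slices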

bionemo-recipes/recipes/esm2_native_te/tests/test_dataset.py

Lines changed: 85 additions & 1 deletion
@@ -16,12 +16,21 @@
 import logging
 import os
 import shutil
+import subprocess
 from dataclasses import dataclass

+import pytest
 import torch
+from torch.distributed.device_mesh import init_device_mesh

 from checkpoint import load_dataloader, save_dataloader
-from dataset import create_bshd_dataloader, create_thd_dataloader
+from dataset import DistributedConfig, create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader
+
+
+requires_multi_gpu = pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
+    reason="Test requires at least 2 GPUs",
+)


 @dataclass
@@ -880,3 +889,78 @@ def test_token_packing_dataloader():
     batch = next(iter(dataloader))
     assert batch["input_ids"].shape[1] == 8 * 1024
     assert batch["labels"].shape[1] == 8 * 1024
+
+
+@requires_multi_gpu
+def test_cp_dataloader(recipe_path):
+    import os
+
+    env = os.environ.copy()
+    env["PYTHONPATH"] = str(recipe_path)
+
+    cmd = [
+        "torchrun",
+        "--nproc_per_node=2",
+        "tests/test_dataset.py",
+    ]
+
+    result = subprocess.run(
+        cmd,
+        check=False,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        timeout=240,
+        cwd=str(recipe_path),
+        env=env,
+    )
+    if result.returncode != 0:
+        print(f"STDOUT:\n{result.stdout}")
+        print(f"STDERR:\n{result.stderr}")
+        pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+if __name__ == "__main__":
+    dist_config = DistributedConfig()
+    device = torch.device(f"cuda:{dist_config.local_rank}")
+    torch.distributed.init_process_group(backend="nccl", device_id=device)
+    torch.cuda.set_device(dist_config.local_rank)
+    device_mesh = init_device_mesh("cuda", mesh_shape=(1, 2), mesh_dim_names=("dp", "cp"))
+
+    dataloader, _ = create_cp_dataloader(
+        distributed_config=dist_config,
+        cp_mesh=device_mesh["cp"],
+        tokenizer_name="facebook/esm2_t6_8M_UR50D",
+        load_dataset_kwargs={
+            "path": "parquet",
+            "split": "train",
+            "data_files": "train.parquet",
+            "streaming": True,
+        },
+        token_micro_batch_size=8 * 1024,
+        num_workers=1,
+    )
+
+    batch = next(iter(dataloader))
+    # With CP size 2, each sequence is split into 2 * cp_world_size = 4 slices.
+    # Each rank gets 2 slices (beginning and end), so each rank gets approximately
+    # (8 * 1024) / 2 = 4096 tokens per rank
+    # Note: Sequences are padded to be divisible by pad_sequences_to_be_divisible_by
+    # (which defaults to cp_mesh.size() * 2 = 4 if not provided)
+    # The actual token count per rank can vary due to:
+    # 1. Sequence packing (variable-length sequences packed up to token_micro_batch_size)
+    # 2. Per-sequence padding to be divisible by pad_sequences_to_be_divisible_by
+    # 3. CP splitting logic that takes slices from beginning and end
+    expected_tokens_per_rank = (8 * 1024) // device_mesh["cp"].size()
+    actual_shape = batch["input_ids"].shape[1]
+    # Allow for variance due to sequence packing, padding, and CP splitting
+    # The actual shape should be close to expected_tokens_per_rank but can vary
+    # Allow up to 100 tokens of variance (both above and below) to account for
+    # sequence packing and padding effects
+    assert actual_shape >= expected_tokens_per_rank - 100, (
+        f"Expected at least {expected_tokens_per_rank - 100} tokens, got {actual_shape}"
+    )
+    assert actual_shape <= expected_tokens_per_rank + 100, (
+        f"Expected at most {expected_tokens_per_rank + 100} tokens, got {actual_shape}"
+    )
+    assert batch["labels"].shape[1] == actual_shape
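
The shape of this test is what lets it run without datacenter hardware: under pytest it only shells out to torchrun with --nproc_per_node=2 against tests/test_dataset.py, and the if __name__ == "__main__": block is what each of the two spawned ranks executes on a local pair of GPUs. A stripped-down sketch of that relaunch pattern follows; the test name test_torchrun_smoke is hypothetical and the body deliberately avoids the recipe's helpers.

import os
import subprocess
import sys

import pytest
import torch

requires_multi_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="Test requires at least 2 GPUs",
)


@requires_multi_gpu
def test_torchrun_smoke():
    # Re-launch this very file under torchrun; each rank then runs the __main__ block below.
    result = subprocess.run(
        ["torchrun", "--nproc_per_node=2", __file__],
        check=False,
        capture_output=True,
        text=True,
        timeout=240,
    )
    if result.returncode != 0:
        pytest.fail(f"torchrun exited with {result.returncode}:\n{result.stdout}\n{result.stderr}")


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK, RANK, and WORLD_SIZE for every process it spawns.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl")
    rank, world = torch.distributed.get_rank(), torch.distributed.get_world_size()
    print(f"rank {rank} of {world} is up", file=sys.stderr)
    torch.distributed.destroy_process_group()

Running pytest tests/test_dataset.py -k cp_dataloader from the recipe directory should exercise the real test the same way, and the requires_multi_gpu marker skips it cleanly on machines with fewer than two GPUs.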
