|
16 | 16 | import logging |
17 | 17 | import os |
18 | 18 | import shutil |
| 19 | +import subprocess |
19 | 20 | from dataclasses import dataclass |
20 | 21 |
|
| 22 | +import pytest |
21 | 23 | import torch |
| 24 | +from torch.distributed.device_mesh import init_device_mesh |
22 | 25 |
|
23 | 26 | from checkpoint import load_dataloader, save_dataloader |
24 | | -from dataset import create_bshd_dataloader, create_thd_dataloader |
| 27 | +from dataset import DistributedConfig, create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader |
| 28 | + |
| 29 | + |
| 30 | +requires_multi_gpu = pytest.mark.skipif( |
| 31 | + not torch.cuda.is_available() or torch.cuda.device_count() < 2, |
| 32 | + reason="Test requires at least 2 GPUs", |
| 33 | +) |
25 | 34 |
|
26 | 35 |
|
27 | 36 | @dataclass |
@@ -880,3 +889,78 @@ def test_token_packing_dataloader(): |
880 | 889 | batch = next(iter(dataloader)) |
881 | 890 | assert batch["input_ids"].shape[1] == 8 * 1024 |
882 | 891 | assert batch["labels"].shape[1] == 8 * 1024 |
| 892 | + |
| 893 | + |
| 894 | +@requires_multi_gpu |
| 895 | +def test_cp_dataloader(recipe_path): |
| 896 | + import os |
| 897 | + |
| 898 | + env = os.environ.copy() |
| 899 | + env["PYTHONPATH"] = str(recipe_path) |
| 900 | + |
| 901 | + cmd = [ |
| 902 | + "torchrun", |
| 903 | + "--nproc_per_node=2", |
| 904 | + "tests/test_dataset.py", |
| 905 | + ] |
| 906 | + |
| 907 | + result = subprocess.run( |
| 908 | + cmd, |
| 909 | + check=False, |
| 910 | + text=True, |
| 911 | + stdout=subprocess.PIPE, |
| 912 | + stderr=subprocess.PIPE, |
| 913 | + timeout=240, |
| 914 | + cwd=str(recipe_path), |
| 915 | + env=env, |
| 916 | + ) |
| 917 | + if result.returncode != 0: |
| 918 | + print(f"STDOUT:\n{result.stdout}") |
| 919 | + print(f"STDERR:\n{result.stderr}") |
| 920 | + pytest.fail(f"Command failed with exit code {result.returncode}") |
| 921 | + |
| 922 | + |
| 923 | +if __name__ == "__main__": |
| 924 | + dist_config = DistributedConfig() |
| 925 | + device = torch.device(f"cuda:{dist_config.local_rank}") |
| 926 | + torch.distributed.init_process_group(backend="nccl", device_id=device) |
| 927 | + torch.cuda.set_device(dist_config.local_rank) |
| 928 | + device_mesh = init_device_mesh("cuda", mesh_shape=(1, 2), mesh_dim_names=("dp", "cp")) |
| 929 | + |
| 930 | + dataloader, _ = create_cp_dataloader( |
| 931 | + distributed_config=dist_config, |
| 932 | + cp_mesh=device_mesh["cp"], |
| 933 | + tokenizer_name="facebook/esm2_t6_8M_UR50D", |
| 934 | + load_dataset_kwargs={ |
| 935 | + "path": "parquet", |
| 936 | + "split": "train", |
| 937 | + "data_files": "train.parquet", |
| 938 | + "streaming": True, |
| 939 | + }, |
| 940 | + token_micro_batch_size=8 * 1024, |
| 941 | + num_workers=1, |
| 942 | + ) |
| 943 | + |
| 944 | + batch = next(iter(dataloader)) |
| 945 | + # With CP size 2, each sequence is split into 2 * cp_world_size = 4 slices. |
| 946 | + # Each rank gets 2 slices (beginning and end), so each rank gets approximately |
| 947 | + # (8 * 1024) / 2 = 4096 tokens per rank |
| 948 | + # Note: Sequences are padded to be divisible by pad_sequences_to_be_divisible_by |
| 949 | + # (which defaults to cp_mesh.size() * 2 = 4 if not provided) |
| 950 | + # The actual token count per rank can vary due to: |
| 951 | + # 1. Sequence packing (variable-length sequences packed up to token_micro_batch_size) |
| 952 | + # 2. Per-sequence padding to be divisible by pad_sequences_to_be_divisible_by |
| 953 | + # 3. CP splitting logic that takes slices from beginning and end |
| 954 | + expected_tokens_per_rank = (8 * 1024) // device_mesh["cp"].size() |
| 955 | + actual_shape = batch["input_ids"].shape[1] |
| 956 | + # Allow for variance due to sequence packing, padding, and CP splitting |
| 957 | + # The actual shape should be close to expected_tokens_per_rank but can vary |
| 958 | + # Allow up to 100 tokens of variance (both above and below) to account for |
| 959 | + # sequence packing and padding effects |
| 960 | + assert actual_shape >= expected_tokens_per_rank - 100, ( |
| 961 | + f"Expected at least {expected_tokens_per_rank - 100} tokens, got {actual_shape}" |
| 962 | + ) |
| 963 | + assert actual_shape <= expected_tokens_per_rank + 100, ( |
| 964 | + f"Expected at most {expected_tokens_per_rank + 100} tokens, got {actual_shape}" |
| 965 | + ) |
| 966 | + assert batch["labels"].shape[1] == actual_shape |
0 commit comments