Commit 4873914

Refactor MLMDataCollatorWithFlattening (#1382)
Updates this class to be more general for llama3 usage by making it accept a base collator and perform flattening after the fact. Essentially, we now always do the BSHD-compatible forward call. This will make it easier to implement context parallelism (CP) in the llama3 recipe.

Signed-off-by: Peter St. John <[email protected]>
1 parent 408709d commit 4873914
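
The refactor replaces the specialized MLMDataCollatorWithFlattening constructor with composition: DataCollatorWithFlattening now wraps an arbitrary Hugging Face collator and flattens its padded (BSHD) output into the THD layout. A minimal sketch of the new pattern, mirroring the test updates in this commit (the ESM-2 checkpoint name and example sequences are illustrative, not taken from the commit):

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    from esm.collator import DataCollatorWithFlattening

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # illustrative checkpoint

    # The base collator pads and applies MLM masking in the usual BSHD layout...
    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)

    # ...and the wrapper flattens that output into THD form, adding sequence-boundary
    # metadata (cu_seq_lens_*, max_length_q/k) and optional position ids.
    collator = DataCollatorWithFlattening(collator=mlm_collator, return_position_ids=True)

    batch = collator([tokenizer("MKTAYIAKQR"), tokenizer("MLKNV")])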

File tree

14 files changed (+476, -803 lines)

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 82 additions & 207 deletions
Large diffs are not rendered by default.

bionemo-recipes/models/esm2/tests/test_collator.py

Lines changed: 43 additions & 46 deletions

@@ -21,15 +21,16 @@

 from esm.collator import (
     DataCollatorWithFlattening,
-    MLMDataCollatorWithFlattening,
     TokenPackingDataset,
     _split_sample_by_num_tokens,
 )


-def test_data_collator_with_flattening_basic():
+def test_data_collator_with_flattening_basic(tokenizer):
     """Test DataCollatorWithFlattening with input_ids and attention_mask."""
-    collator = DataCollatorWithFlattening(return_position_ids=True)
+    # Use DataCollatorForLanguageModeling with mlm_probability=0.0 to disable masking
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, return_position_ids=True)

     # Create test sequences of different lengths
     features = [
@@ -70,19 +71,24 @@ def test_data_collator_with_flattening_basic():
     expected_input_ids = torch.tensor([[0, 5, 6, 7, 2, 0, 8, 9, 10, 11, 2, 0, 12, 13, 2]], dtype=torch.int64)
     torch.testing.assert_close(input_ids_tensor, expected_input_ids)

-    # Assert labels are not present when not provided in input
-    assert "labels" not in batch
+    # Assert labels are present (DataCollatorForLanguageModeling always creates them)
+    # With mlm_probability=0.0, all labels should be -100 (ignored)
+    assert "labels" in batch
+    assert (batch["labels"] == -100).all(), "With mlm_probability=0.0, all labels should be -100"


-def test_data_collator_with_flattening_with_labels():
+def test_data_collator_with_flattening_with_labels(tokenizer):
     """Test DataCollatorWithFlattening with input_ids, attention_mask, and labels."""
-    collator = DataCollatorWithFlattening()
+    # Use DataCollatorForLanguageModeling with mlm_probability=0.0 to disable masking
+    # Note: DataCollatorForLanguageModeling ignores input labels and creates its own
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)
+    collator = DataCollatorWithFlattening(collator=mlm_collator)

-    # Create test sequences with labels
+    # Create test sequences (labels will be created by DataCollatorForLanguageModeling)
     features = [
-        {"input_ids": [0, 5, 6, 7, 2], "labels": [0, 5, 6, 7, 2]},  # 5 tokens
-        {"input_ids": [0, 8, 9, 10, 11, 2], "labels": [0, 8, 9, 10, 11, 2]},  # 6 tokens
-        {"input_ids": [0, 12, 13, 2], "labels": [0, 12, 13, 2]},  # 4 tokens
+        {"input_ids": [0, 5, 6, 7, 2]},  # 5 tokens
+        {"input_ids": [0, 8, 9, 10, 11, 2]},  # 6 tokens
+        {"input_ids": [0, 12, 13, 2]},  # 4 tokens
     ]

     # Calculate expected total tokens
@@ -114,12 +120,12 @@ def test_data_collator_with_flattening_with_labels():
     assert batch["max_length_q"] == 6, f"Expected max_length_q=6, got {batch['max_length_q']}"
     assert batch["max_length_k"] == 6, f"Expected max_length_k=6, got {batch['max_length_k']}"

-    # Assert flattened input_ids and labels match concatenated original sequences
+    # Assert flattened input_ids match concatenated original sequences
     expected_input_ids = torch.tensor([[0, 5, 6, 7, 2, 0, 8, 9, 10, 11, 2, 0, 12, 13, 2]], dtype=torch.int64)
-    expected_labels = torch.tensor([[0, 5, 6, 7, 2, 0, 8, 9, 10, 11, 2, 0, 12, 13, 2]], dtype=torch.int64)
-
     torch.testing.assert_close(input_ids_tensor, expected_input_ids)
-    torch.testing.assert_close(labels_tensor, expected_labels)
+
+    # With mlm_probability=0.0, all labels should be -100 (ignored)
+    assert (labels_tensor == -100).all(), "With mlm_probability=0.0, all labels should be -100"

     # Assert that sequence boundaries are properly maintained
     # by checking that token positions match expected values
@@ -134,9 +140,11 @@ def test_data_collator_with_flattening_with_labels():
         start_idx = end_idx


-def test_data_collator_pads_to_multiple_of():
+def test_data_collator_pads_to_multiple_of(tokenizer):
     """Test DataCollatorWithFlattening with input_ids and attention_mask."""
-    collator = DataCollatorWithFlattening(pad_to_multiple_of=8, token_pad=1, label_pad=-100, return_position_ids=True)
+    # Use DataCollatorForLanguageModeling with mlm_probability=0.0 to disable masking
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, pad_to_multiple_of=8, return_position_ids=True)

     # Create test sequences with labels
     features = [
@@ -168,11 +176,8 @@ def test_data_collator_pads_to_multiple_of():

 def test_mlm_data_collator_with_flattening_basic(tokenizer):
     """Test MLMDataCollatorWithFlattening with basic input_ids and verify labels are created."""
-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        return_position_ids=True,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, return_position_ids=True)

     # Create test sequences of different lengths
     features = [
@@ -232,11 +237,8 @@ def test_mlm_data_collator_with_flattening_basic(tokenizer):
 def test_mlm_data_collator_with_flattening_masking(tokenizer, test_proteins):
     """Test MLMDataCollatorWithFlattening with reproducible masking using a seed."""
     # Use a fixed seed for reproducibility
-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        seed=42,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)
+    collator = DataCollatorWithFlattening(collator=mlm_collator)

     features = [tokenizer(protein) for protein in test_proteins]

@@ -293,11 +295,8 @@ def test_mlm_data_collator_with_flattening_pad_to_multiple_of(tokenizer, test_proteins):
     remainder = -total_tokens % 8
     assert remainder != 0, "Test assumes we need to pad to reach a multiple of 8"

-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        pad_to_multiple_of=8,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, pad_to_multiple_of=8)

     features = [tokenizer(protein) for protein in test_proteins]

@@ -334,21 +333,22 @@ def test_mlm_data_collator_with_flattening_pad_to_multiple_of(tokenizer, test_proteins):

 def test_mlm_data_collator_with_flattening_bshd_equivalent(tokenizer, test_proteins):
     """Test MLMDataCollatorWithFlattening with bshd_equivalent=True."""
-    thd_collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        seed=42,
-        pad_to_multiple_of=16,
-        bshd_equivalent=True,
-        bshd_pad_to_multiple_of=256,
-    )
-
+    # Create separate collator instances with the same seed to ensure matching masking
+    # The BSHD collator pads to 256
     bshd_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
         mlm_probability=0.15,
         seed=42,
         pad_to_multiple_of=256,
     )
+    thd_collator = DataCollatorWithFlattening(
+        collator=DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm_probability=0.15,
+            seed=42,
+            pad_to_multiple_of=256,
+        )
+    )

     features = [tokenizer(protein) for protein in test_proteins]

@@ -375,11 +375,8 @@ def test_mlm_data_collator_with_flattening_bshd_equivalent(tokenizer, test_proteins):

 def test_mlm_data_collator_with_flattening_pad_sequences_to_be_divisible_by(tokenizer, test_proteins):
     """Test MLMDataCollatorWithFlattening with pad_sequences_to_be_divisible_by."""
-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        pad_sequences_to_be_divisible_by=16,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, pad_sequences_to_be_divisible_by=16)
     features = [tokenizer(protein) for protein in test_proteins]
     batch = collator(features)
     assert batch["input_ids"].numel() % 16 == 0, (
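
Several of the updated tests disable masking by giving the wrapper a base collator with mlm_probability=0.0: Hugging Face's DataCollatorForLanguageModeling assigns the ignore label -100 to every position it does not mask, so with a zero masking probability every label comes back as -100. A standalone sketch of that behavior (the checkpoint name is illustrative):

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # illustrative checkpoint
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)

    batch = collator([tokenizer("MKTAYIAKQR"), tokenizer("MLKNV")])

    # No position is selected for masking, so every label is the ignore index -100
    # and a cross-entropy loss with ignore_index=-100 skips the whole batch.
    assert (batch["labels"] == -100).all()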

bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py

Lines changed: 35 additions & 19 deletions

@@ -208,23 +208,35 @@ def __next__(self):
         return copy.deepcopy(self._batch)


-class _DummyCPGroup:
-    def __init__(self, size: int):
+class _DummyDeviceMesh:
+    """Dummy device mesh for testing ContextParallelDataLoaderWrapper."""
+
+    def __init__(self, size: int, rank: int = 0):
         self._size = size
+        self._rank = rank
+        self._group = mock.MagicMock()  # Mock process group
+
+    def get_local_rank(self) -> int:
+        """Return the local rank within this mesh."""
+        return self._rank
+
+    def get_group(self):
+        """Return the process group."""
+        return self._group

     def size(self) -> int:
+        """Return the size of the mesh."""
         return self._size


 def _fake_get_batch(
     cu_seqlens_padded,
     input_ids_padded,
     labels_padded,
-    cp_group,
+    cp_size,
     qvk_format,
     cp_rank,
 ):
-    cp_size = cp_group.size()
     total_slices = 2 * cp_size
     seq_tokens = input_ids_padded.view(-1)
     seq_labels = labels_padded.view(-1)
@@ -250,14 +262,14 @@ def _fake_get_batch(
     )


-def _make_cp_shards(base_batch: Dict[str, torch.Tensor], cp_group: _DummyCPGroup):
+def _make_cp_shards(base_batch: Dict[str, torch.Tensor], cp_size: int):
     combined_batch = []
-    for cp_rank in range(cp_group.size()):
+    for cp_rank in range(cp_size):
         input_ids_sharded, labels_sharded = _fake_get_batch(
             cu_seqlens_padded=base_batch["cu_seq_lens_q_padded"],
             input_ids_padded=base_batch["input_ids"],
             labels_padded=base_batch["labels"],
-            cp_group=cp_group,
+            cp_size=cp_size,
             qvk_format="thd",
             cp_rank=cp_rank,
         )
@@ -368,7 +380,7 @@ def test_dataloader_scatter_nopadding():
     CP0 | 1,2,7,8 | 9, 10, 15, 16 |
     CP1 | 3,4,5,6 | 11, 12, 13, 14|
     """
-    cp_group = _DummyCPGroup(size=2)
+    cp_size = 2

     def run_roundtrip(base_batch):
         combined_batch = [
@@ -381,22 +393,24 @@ def run_roundtrip(base_batch):
                         labels_padded=base_batch["labels"],
                         qvk_format="thd",
                         cp_rank=cp_rank,
-                        cp_world_size=cp_group.size(),
+                        cp_world_size=cp_size,
                     )[0],
                     "labels": _split_batch_by_cp_rank(
                         cu_seqlens_padded=base_batch["cu_seq_lens_q_padded"],
                         input_ids_padded=base_batch["input_ids"],
                         labels_padded=base_batch["labels"],
                         qvk_format="thd",
                         cp_rank=cp_rank,
-                        cp_world_size=cp_group.size(),
+                        cp_world_size=cp_size,
                     )[1],
                 },
             )
-            for cp_rank in range(cp_group.size())
+            for cp_rank in range(cp_size)
         ]
-        loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_group, cp_rank=0)
-        loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_group, cp_rank=1)
+        cp_mesh_rank0 = _DummyDeviceMesh(size=cp_size, rank=0)
+        cp_mesh_rank1 = _DummyDeviceMesh(size=cp_size, rank=1)
+        loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank0)
+        loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank1)

         scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {}
         current_rank = {"value": None}
@@ -455,7 +469,7 @@ def test_dataloader_scatter_with_pad_between_seqs():
     CP0 | 1,<p>,5,<p> | 9, <p>, 13, <p>|
     CP1 | 2,3,6, <p> | 10, 11, 14, 15 |
     """
-    cp_group = _DummyCPGroup(size=2)
+    cp_size = 2

     def run_roundtrip(base_batch):
         combined_batch = [
@@ -468,22 +482,24 @@ def run_roundtrip(base_batch):
                         labels_padded=base_batch["labels"],
                         qvk_format="thd",
                         cp_rank=cp_rank,
-                        cp_world_size=cp_group.size(),
+                        cp_world_size=cp_size,
                     )[0],
                     "labels": _split_batch_by_cp_rank(
                         cu_seqlens_padded=base_batch["cu_seq_lens_q_padded"],
                         input_ids_padded=base_batch["input_ids"],
                         labels_padded=base_batch["labels"],
                         qvk_format="thd",
                         cp_rank=cp_rank,
-                        cp_world_size=cp_group.size(),
+                        cp_world_size=cp_size,
                     )[1],
                 },
             )
-            for cp_rank in range(cp_group.size())
+            for cp_rank in range(cp_size)
         ]
-        loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_group, cp_rank=0)
-        loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_group, cp_rank=1)
+        cp_mesh_rank0 = _DummyDeviceMesh(size=cp_size, rank=0)
+        cp_mesh_rank1 = _DummyDeviceMesh(size=cp_size, rank=1)
+        loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank0)
+        loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank1)

         scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {}
         current_rank = {"value": None}
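
The dummy mesh above duck-types the three methods the dataloader wrapper now consumes, get_local_rank(), get_group(), and size(), which is the surface a torch.distributed device mesh exposes; the explicit cp_rank argument is gone. A hedged sketch of the real-world wiring under those assumptions (launch with torchrun on 8 GPUs; ContextParallelDataLoaderWrapper's import path and train_dataloader are placeholders not shown in this diff):

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    dist.init_process_group(backend="nccl")

    # Illustrative 8-GPU layout: 4 data-parallel replicas x 2 context-parallel ranks.
    mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))
    cp_mesh = mesh["cp"]

    # The wrapper reads rank, process group, and world size from the sub-mesh:
    #   cp_mesh.get_local_rank(), cp_mesh.get_group(), cp_mesh.size()
    # loader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh)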

bionemo-recipes/models/esm2/tests/test_fp8.py

Lines changed: 8 additions & 5 deletions

@@ -20,8 +20,9 @@
 from torch.distributed.checkpoint.state_dict import get_model_state_dict
 from transformer_engine.common import recipe as recipe_module
 from transformer_engine.pytorch import fp8
+from transformers import DataCollatorForLanguageModeling

-from esm.collator import MLMDataCollatorWithFlattening
+from esm.collator import DataCollatorWithFlattening
 from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM


@@ -87,11 +88,10 @@ def parametrize_recipes_with_support(recipes):

 @pytest.fixture
 def input_data_thd(tokenizer, tokenized_proteins):
-    data_collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)
+    data_collator = DataCollatorWithFlattening(
+        collator=mlm_collator,
         pad_to_multiple_of=32,  # MXFP8 requires the sequence length to be divisible by 32, regular FP8 requires 16.
-        seed=42,
     )

     return data_collator(tokenized_proteins)
@@ -139,6 +139,9 @@ def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd,
     if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling):
         atol = 0.2
         rtol = 0.05
+    elif isinstance(fp8_recipe, recipe_module.DelayedScaling):
+        atol = 0.1
+        rtol = 0.03
     else:
         atol = None
         rtol = None

bionemo-recipes/models/esm2/tests/test_thd.py

Lines changed: 10 additions & 7 deletions

@@ -20,8 +20,9 @@
 import torch
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp
+from transformers import DataCollatorForLanguageModeling

-from esm.collator import MLMDataCollatorWithFlattening
+from esm.collator import DataCollatorWithFlattening
 from esm.modeling_esm_te import NVEsmConfig, NVEsmEmbeddings, NVEsmForMaskedLM


@@ -39,12 +40,14 @@

 @pytest.fixture
 def input_data_thd(tokenizer, tokenized_proteins):
-    data_collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        seed=42,
-        bshd_equivalent=True,
-        bshd_pad_to_multiple_of=32,
+    """The collator here needs to exactly match the one used in the `input_data` fixture for golden values to pass."""
+    data_collator = DataCollatorWithFlattening(
+        collator=DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm_probability=0.15,
+            pad_to_multiple_of=32,
+            seed=42,
+        )
     )
     return data_collator(tokenized_proteins)


bionemo-recipes/recipes/esm2_native_te/README.md

Lines changed: 3 additions & 0 deletions

@@ -292,3 +292,6 @@ training configurations, allowing for easy modification of training hyper-parameters

 Configuration parameters can be overridden from the command line, e.g.
 `python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true`.
+
+For verbose logging, use the hydra command line override `hydra.verbose=true`, see
+https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ for more details.
