
 from esm.collator import (
     DataCollatorWithFlattening,
-    MLMDataCollatorWithFlattening,
     TokenPackingDataset,
     _split_sample_by_num_tokens,
 )


-def test_data_collator_with_flattening_basic():
+def test_data_collator_with_flattening_basic(tokenizer):
     """Test DataCollatorWithFlattening with input_ids and attention_mask."""
-    collator = DataCollatorWithFlattening(return_position_ids=True)
+    # Use DataCollatorForLanguageModeling with mlm_probability=0.0 to disable masking
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, return_position_ids=True)

     # Create test sequences of different lengths
     features = [
@@ -70,19 +71,24 @@ def test_data_collator_with_flattening_basic():
     expected_input_ids = torch.tensor([[0, 5, 6, 7, 2, 0, 8, 9, 10, 11, 2, 0, 12, 13, 2]], dtype=torch.int64)
     torch.testing.assert_close(input_ids_tensor, expected_input_ids)

-    # Assert labels are not present when not provided in input
-    assert "labels" not in batch
+    # Assert labels are present (DataCollatorForLanguageModeling always creates them)
+    # With mlm_probability=0.0, all labels should be -100 (ignored)
+    assert "labels" in batch
+    assert (batch["labels"] == -100).all(), "With mlm_probability=0.0, all labels should be -100"


-def test_data_collator_with_flattening_with_labels():
+def test_data_collator_with_flattening_with_labels(tokenizer):
     """Test DataCollatorWithFlattening with input_ids, attention_mask, and labels."""
-    collator = DataCollatorWithFlattening()
+    # Use DataCollatorForLanguageModeling with mlm_probability=0.0 to disable masking
+    # Note: DataCollatorForLanguageModeling ignores input labels and creates its own
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)
+    collator = DataCollatorWithFlattening(collator=mlm_collator)

-    # Create test sequences with labels
+    # Create test sequences (labels will be created by DataCollatorForLanguageModeling)
     features = [
-        {"input_ids": [0, 5, 6, 7, 2], "labels": [0, 5, 6, 7, 2]},  # 5 tokens
-        {"input_ids": [0, 8, 9, 10, 11, 2], "labels": [0, 8, 9, 10, 11, 2]},  # 6 tokens
-        {"input_ids": [0, 12, 13, 2], "labels": [0, 12, 13, 2]},  # 4 tokens
+        {"input_ids": [0, 5, 6, 7, 2]},  # 5 tokens
+        {"input_ids": [0, 8, 9, 10, 11, 2]},  # 6 tokens
+        {"input_ids": [0, 12, 13, 2]},  # 4 tokens
     ]

     # Calculate expected total tokens
@@ -114,12 +120,12 @@ def test_data_collator_with_flattening_with_labels():
     assert batch["max_length_q"] == 6, f"Expected max_length_q=6, got {batch['max_length_q']}"
     assert batch["max_length_k"] == 6, f"Expected max_length_k=6, got {batch['max_length_k']}"

-    # Assert flattened input_ids and labels match concatenated original sequences
+    # Assert flattened input_ids match concatenated original sequences
     expected_input_ids = torch.tensor([[0, 5, 6, 7, 2, 0, 8, 9, 10, 11, 2, 0, 12, 13, 2]], dtype=torch.int64)
-    expected_labels = torch.tensor([[0, 5, 6, 7, 2, 0, 8, 9, 10, 11, 2, 0, 12, 13, 2]], dtype=torch.int64)
-
     torch.testing.assert_close(input_ids_tensor, expected_input_ids)
-    torch.testing.assert_close(labels_tensor, expected_labels)
+
+    # With mlm_probability=0.0, all labels should be -100 (ignored)
+    assert (labels_tensor == -100).all(), "With mlm_probability=0.0, all labels should be -100"

     # Assert that sequence boundaries are properly maintained
     # by checking that token positions match expected values
@@ -134,9 +140,11 @@ def test_data_collator_with_flattening_with_labels():
         start_idx = end_idx


-def test_data_collator_pads_to_multiple_of():
+def test_data_collator_pads_to_multiple_of(tokenizer):
     """Test DataCollatorWithFlattening with input_ids and attention_mask."""
-    collator = DataCollatorWithFlattening(pad_to_multiple_of=8, token_pad=1, label_pad=-100, return_position_ids=True)
+    # Use DataCollatorForLanguageModeling with mlm_probability=0.0 to disable masking
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, pad_to_multiple_of=8, return_position_ids=True)

     # Create test sequences with labels
     features = [
@@ -168,11 +176,8 @@ def test_data_collator_pads_to_multiple_of():

 def test_mlm_data_collator_with_flattening_basic(tokenizer):
     """Test MLMDataCollatorWithFlattening with basic input_ids and verify labels are created."""
-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        return_position_ids=True,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, return_position_ids=True)

     # Create test sequences of different lengths
     features = [
@@ -232,11 +237,8 @@ def test_mlm_data_collator_with_flattening_basic(tokenizer):
 def test_mlm_data_collator_with_flattening_masking(tokenizer, test_proteins):
     """Test MLMDataCollatorWithFlattening with reproducible masking using a seed."""
     # Use a fixed seed for reproducibility
-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        seed=42,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42)
+    collator = DataCollatorWithFlattening(collator=mlm_collator)

     features = [tokenizer(protein) for protein in test_proteins]

@@ -293,11 +295,8 @@ def test_mlm_data_collator_with_flattening_pad_to_multiple_of(tokenizer, test_pr
     remainder = -total_tokens % 8
     assert remainder != 0, "Test assumes we need to pad to reach a multiple of 8"

-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        pad_to_multiple_of=8,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, pad_to_multiple_of=8)

     features = [tokenizer(protein) for protein in test_proteins]

@@ -334,21 +333,22 @@ def test_mlm_data_collator_with_flattening_pad_to_multiple_of(tokenizer, test_pr

 def test_mlm_data_collator_with_flattening_bshd_equivalent(tokenizer, test_proteins):
     """Test MLMDataCollatorWithFlattening with bshd_equivalent=True."""
-    thd_collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        seed=42,
-        pad_to_multiple_of=16,
-        bshd_equivalent=True,
-        bshd_pad_to_multiple_of=256,
-    )
-
+    # Create separate collator instances with the same seed to ensure matching masking
+    # The BSHD collator pads to 256
     bshd_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
         mlm_probability=0.15,
         seed=42,
         pad_to_multiple_of=256,
     )
+    thd_collator = DataCollatorWithFlattening(
+        collator=DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm_probability=0.15,
+            seed=42,
+            pad_to_multiple_of=256,
+        )
+    )

     features = [tokenizer(protein) for protein in test_proteins]

@@ -375,11 +375,8 @@ def test_mlm_data_collator_with_flattening_bshd_equivalent(tokenizer, test_prote

 def test_mlm_data_collator_with_flattening_pad_sequences_to_be_divisible_by(tokenizer, test_proteins):
     """Test MLMDataCollatorWithFlattening with pad_sequences_to_be_divisible_by."""
-    collator = MLMDataCollatorWithFlattening(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        pad_sequences_to_be_divisible_by=16,
-    )
+    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+    collator = DataCollatorWithFlattening(collator=mlm_collator, pad_sequences_to_be_divisible_by=16)
     features = [tokenizer(protein) for protein in test_proteins]
     batch = collator(features)
     assert batch["input_ids"].numel() % 16 == 0, (
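
Every hunk above applies the same refactor: the bespoke MLMDataCollatorWithFlattening class is removed in favor of composing DataCollatorWithFlattening with Hugging Face's DataCollatorForLanguageModeling, so MLM masking (and label creation) happens in the inner collator before sequences are flattened. A minimal sketch of the resulting pattern, assuming the esm.collator API shown in the diff (the helper name build_flattening_mlm_collator is hypothetical):

from transformers import DataCollatorForLanguageModeling

from esm.collator import DataCollatorWithFlattening


def build_flattening_mlm_collator(tokenizer, mlm_probability=0.15, seed=None):
    # Inner collator performs MLM masking and always emits labels;
    # with mlm_probability=0.0 every label stays -100, i.e. masking is disabled.
    inner = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=mlm_probability,
        seed=seed,  # a fixed seed gives reproducible masking, as in the tests above
    )
    # Outer collator concatenates the masked sequences into one packed batch
    # and returns varlen-attention metadata (position_ids, max_length_q/k).
    return DataCollatorWithFlattening(collator=inner, return_position_ids=True)

One consequence worth noting: because DataCollatorForLanguageModeling builds its own labels, the flattened batch always carries a "labels" key, which is why the tests now assert (labels == -100).all() rather than "labels" not in batch.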