     preprocessor_for_format,
 )
 from levanter.data.ul2r import (
+    RX_TASK_KIND,
     TokenizedDict,
     compute_denoising_length,
     noise_span_to_unique_sentinel,
@@ -291,7 +292,7 @@ def test_noise_span_to_unique_sentinel():
     )
     noise_mask = jnp.pad(noise_mask, (0, padded_length - 10), constant_values=False)
 
-    result = noise_span_to_unique_sentinel(tokens, noise_mask, sentinel_tokens, 10, force_initial_sentinel=False)
+    result = noise_span_to_unique_sentinel(tokens, 10, noise_mask, pad_token_id, sentinel_tokens, force_initial_sentinel=False)
 
     expected = jnp.array([100, 13, 14, 15, 16, 17, 18, 19])
     np.testing.assert_array_equal(result[:8], expected)
@@ -323,7 +324,7 @@ def test_noise_span_to_unique_sentinel():
     )
     noise_mask = jnp.pad(noise_mask, (0, padded_length - 15), constant_values=False)
 
-    result = noise_span_to_unique_sentinel(tokens, noise_mask, sentinel_tokens, 15, force_initial_sentinel=True)
+    result = noise_span_to_unique_sentinel(tokens, 15, noise_mask, pad_token_id, sentinel_tokens, force_initial_sentinel=True)
 
     # Should still start with sentinel
     expected = jnp.array([100, 10, 101, 13, 14, 15, 102, 17, 18, 19, 103, 23, 24])
@@ -335,7 +336,7 @@ def test_noise_span_to_unique_sentinel():
     tokens = jnp.pad(tokens, (0, padded_length - 10), constant_values=pad_token_id)
     noise_mask = jnp.zeros(padded_length, dtype=jnp.bool_)
 
-    result = noise_span_to_unique_sentinel(tokens, noise_mask, sentinel_tokens, 10, force_initial_sentinel=False)
+    result = noise_span_to_unique_sentinel(tokens, 10, noise_mask, pad_token_id, sentinel_tokens, force_initial_sentinel=False)
 
     # Should be unchanged except for padding
     np.testing.assert_array_equal(result[:10], jnp.arange(10, 20))
@@ -405,16 +406,18 @@ def test_to_ul2r_rx_tokens():
 
     inputs = noise_span_to_unique_sentinel(
         tokens,
+        length,
         noise_mask,
+        pad_token_id,
         sentinel_tokens,
-        length,
         force_initial_sentinel=False,
     )
     targets = noise_span_to_unique_sentinel(
         tokens,
+        length,
         ~noise_mask,
+        pad_token_id,
         sentinel_tokens,
-        length,
         force_initial_sentinel=True,
     )
 
@@ -430,6 +433,7 @@ def test_to_ul2r_rx_tokens():
         mask_prob=0.3,
         mean_noise_span_length=3.0,
         random_roll=False,
+        pad_token_id=pad_token_id,
         sentinel_token_ids=sentinel_tokens,
         max_length=max_length,
     )
@@ -502,6 +506,7 @@ def test_to_ul2r_rx_tokens_roll():
         mask_prob=0.3,
         mean_noise_span_length=3.0,
         random_roll=True,
+        pad_token_id=pad_token_id,
         sentinel_token_ids=sentinel_ids,
         max_length=max_length,
     )
@@ -530,6 +535,51 @@ def test_to_ul2r_rx_tokens_roll():
     assert jnp.any(jnp.isin(sentinel_ids, targets))
 
 
+def test_compute_denoising_length_rx_random_roll():
539+ """
540+ Test that compute_denoising_length with random_roll=True reserves enough
541+ space and sets pad_token_id when it doesn't create an extra span.
542+
543+ When random_roll=True, we reserve space for an extra span. However, rolling doesn't
544+ always create an additional span, so we should see both cases:
545+ - no pad_token_id (rolling created an extra span)
546+ - 1 pad_token_id (rolling created an extra span)
547+ """
+    max_length = 16
+    pad_token_id = 999  # Use a non-zero pad token to verify it's actually being used
+    sentinel_ids = jnp.arange(100, 120)
+
+    length = 12
+    tokens = jnp.arange(1, length + 1)
+    tokens = jnp.pad(tokens, (0, max_length - length), constant_values=pad_token_id)
+
+    mask_prob = 0.3
+    mean_noise_span_length = 3.0
+    task_params = jnp.array([RX_TASK_KIND, R_TASK_TOKEN_ID, mask_prob, mean_noise_span_length])
+
+    predicted_length = compute_denoising_length(task_params, length, random_roll=True)
+
+    padding_counts = []
+    for i in range(16):
+        key = jax.random.PRNGKey(i)
+        _input_length, result = to_ul2r_rx_tokens(
+            key, tokens, length, mask_prob, mean_noise_span_length, True,
+            pad_token_id, sentinel_ids, max_length
+        )
+
+        # Subtract 1 because `result` doesn't include the task token.
+        num_padding = jnp.sum(result[:predicted_length - 1] == pad_token_id)
+        padding_counts.append(int(num_padding))
+
+        actual_length = jnp.sum(result != pad_token_id)
+        assert actual_length <= predicted_length
+
+    assert any(p == 0 for p in padding_counts)
+    assert any(p == 2 for p in padding_counts)
+    assert all(p == 0 or p == 2 for p in padding_counts)
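+    # Assumption: the reserved extra span accounts for exactly 2 positions
+    # (one sentinel on the input side, one on the target side), so when
+    # rolling doesn't create it the padding count is 0 or 2, never 1.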
+
+
 def test_ul2r_loss_mask():
     # Test case 1: Simple single segment
     input_masks = jnp.array([1, 1, 0, 0])  # First 2 are inputs
@@ -597,6 +647,7 @@ def test_to_ul2r_rx_tokens_truncates_both_sections_and_contains_sentinels():
         mask_prob=mask_prob,
         mean_noise_span_length=mean_noise_span_length,
         random_roll=random_roll,
+        pad_token_id=pad_token_id,
         sentinel_token_ids=sentinel_tokens,
         max_length=padded_length,
     )
@@ -612,6 +663,7 @@ def test_to_ul2r_rx_tokens_truncates_both_sections_and_contains_sentinels():
         mask_prob=mask_prob,
         mean_noise_span_length=mean_noise_span_length,
         random_roll=random_roll,
+        pad_token_id=pad_token_id,
         sentinel_token_ids=sentinel_tokens,
         max_length=max_length,
     )
@@ -687,9 +739,9 @@ def test_create_ul2r_example():
     in_len_s = 5
     in_len = in_len_r + in_len_x + in_len_s
 
-    out_len_r = compute_denoising_length(task_params[0], in_len_r)
-    out_len_x = compute_denoising_length(task_params[1], in_len_x)
-    out_len_s = compute_denoising_length(task_params[2], in_len_s)
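+    # The new trailing argument is random_roll; these segments are packed
+    # without rolling, so it's False here.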
+    out_len_r = compute_denoising_length(task_params[0], in_len_r, False)
+    out_len_x = compute_denoising_length(task_params[1], in_len_x, False)
+    out_len_s = compute_denoising_length(task_params[2], in_len_s, False)
 
     tokens = jnp.concatenate(
         [
@@ -754,7 +806,7 @@ def prepare_segment(id: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
         # batch independently (but in a way that matches how we computed
         # lengths for packing).
         task_idx = task_indices[id]
-        out_length = compute_denoising_length(task_params[task_idx], in_length)
+        out_length = compute_denoising_length(task_params[task_idx], in_length, False)
 
         return in_start, in_length, out_length
 
@@ -790,7 +842,7 @@ def process_segment(key, id: int) -> tuple[jnp.ndarray, jnp.ndarray, int, int]:
 
         segment = jnp.roll(tokens.array, -in_start)
         print(key, task_params[task_idx], segment, in_length, QPos.size)
-        inputs_len, denoising_tokens = to_ul2r_tokens(key, task_params[task_idx], segment, in_length, SENTINEL_TOKEN_IDS, QPos.size)
+        inputs_len, denoising_tokens = to_ul2r_tokens(key, task_params[task_idx], segment, in_length, pad_token_id, SENTINEL_TOKEN_IDS, QPos.size)
 
         n_tokens = tokens.array.shape[0]
         input_mask = jnp.arange(n_tokens) < inputs_len