Skip to content

Commit f92e701

Browse files
committed
reserve space for random_roll creating extra noise span
1 parent 45c2c6e commit f92e701

File tree

2 files changed

+82
-19
lines changed

2 files changed

+82
-19
lines changed

src/levanter/data/ul2r.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@ def apply_roll(m):
286286
@functools.partial(jax.jit, static_argnames=["force_initial_sentinel"])
287287
def noise_span_to_unique_sentinel(
288288
tokens: jnp.ndarray,
289+
length: int,
289290
noise_mask: jnp.ndarray,
290291
sentinel_tokens: jnp.ndarray,
291-
length: int,
292292
force_initial_sentinel: bool,
293293
) -> jnp.ndarray:
294294
"""
@@ -369,6 +369,7 @@ def to_ul2r_rx_tokens(
369369
mask_prob: float,
370370
mean_noise_span_length: float,
371371
random_roll: bool,
372+
pad_token_id: int,
372373
sentinel_token_ids: jnp.ndarray,
373374
max_length: int,
374375
) -> tuple[jnp.ndarray, jnp.ndarray]:
@@ -384,8 +385,8 @@ def to_ul2r_rx_tokens(
384385
- The length of `inputs` (before `targets`).
385386
For use when generating the loss mask / PrefixLM attention mask.
386387
- A tensor with the same shape as `tokens` containing
387-
`inputs targets 0...` where `inputs targets` is truncated to
388-
fit `max_length`. There is no leading padding.
388+
`inputs targets 0...` where `inputs targets` is truncated to fit `max_length`.
389+
There is no leading padding.
389390
"""
390391

391392
padded_length = tokens.shape[0]
@@ -407,16 +408,16 @@ def to_ul2r_rx_tokens(
407408

408409
inputs = noise_span_to_unique_sentinel(
409410
tokens,
411+
length,
410412
noise_mask,
411413
sentinel_token_ids,
412-
length,
413414
force_initial_sentinel=False,
414415
)
415416
targets = noise_span_to_unique_sentinel(
416417
tokens,
418+
target_in_len,
417419
~noise_mask,
418420
sentinel_token_ids,
419-
target_in_len,
420421
force_initial_sentinel=True,
421422
)
422423

@@ -440,7 +441,7 @@ def to_ul2r_rx_tokens(
440441
trunc_target_len = jnp.maximum(target_len - drop_targets, 0)
441442

442443
# Truncate `targets` to the new length; `inputs` are gated by `new_input_len` below
443-
targets = jnp.where(indices < trunc_target_len, targets, 0)
444+
targets = jnp.where(indices < trunc_target_len, targets, pad_token_id)
444445
targets = typing.cast(jnp.ndarray, targets)
445446

446447
targets = jnp.roll(targets, trunc_input_len)
@@ -512,6 +513,7 @@ def to_ul2r_tokens(
512513
task_params: jnp.ndarray,
513514
tokens: jnp.ndarray,
514515
length: int,
516+
pad_token_id: int,
515517
sentinel_token_ids: jnp.ndarray,
516518
# TODO maybe we don't actually need the truncation logic in
517519
# to_ul2r_rx_tokens given that we truncate while packing
@@ -547,6 +549,7 @@ def rx_tokens():
547549
noise_density,
548550
mean_noise_span_length,
549551
True,
552+
pad_token_id,
550553
sentinel_token_ids,
551554
max_length - 1,
552555
)
@@ -616,13 +619,18 @@ def ul2r_loss_mask(
616619
def compute_denoising_length(
617620
task_params: jnp.ndarray,
618621
length: jnp.ndarray,
622+
random_roll: bool,
619623
) -> jnp.ndarray:
620624
def rx_length() -> jnp.ndarray:
621625
noise_density = RXDenoisingConfig.mask_prob_from_task_params(task_params)
622626
mean_noise_span_length = RXDenoisingConfig.mean_span_length_from_task_params(task_params)
623627
_num_noise_tokens, num_noise_spans, _num_nonnoise_tokens = num_noise_spans_tokens_and_spans(
624628
length, noise_density, mean_noise_span_length
625629
)
630+
# When random_roll is True, we might create an additional noise span by
631+
# rolling a noise span so that it is cut by the beginning/end. Reserve
632+
# space for it.
633+
num_noise_spans = jax.lax.select(random_roll, num_noise_spans + 1, num_noise_spans)
626634
# [task_token] one <sentinel_0> three <sentinel_0> two
627635
return 1 + 2 * num_noise_spans + length
628636

@@ -678,7 +686,7 @@ def prepare_segment(id: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
678686
# batch independently (but in a way that matches how we computed
679687
# lengths for packing).
680688
task_idx = task_indices[id]
681-
out_length = compute_denoising_length(task_params[task_idx], in_length)
689+
out_length = compute_denoising_length(task_params[task_idx], in_length, False)
682690

683691
return in_start, in_length, out_length
684692

@@ -711,8 +719,11 @@ def process_segment(key: PRNGKeyArray, id: int) -> tuple[jnp.ndarray, jnp.ndarra
711719
out_start = typing.cast(int, jnp.squeeze(out_starts[idx]))
712720

713721
segment = jnp.roll(tokens.array, -in_start)
722+
# TODO this should return the actual length, not just out_length which
723+
# might include an extra token? Or we could just use padding. Loss
724+
# shouldn't be computed on padding anyways.
714725
inputs_len, denoising_tokens = to_ul2r_tokens(
715-
key, task_params[task_idx], segment, in_length, sentinel_token_ids, QPos.size
726+
key, task_params[task_idx], segment, in_length, pad_token_id, sentinel_token_ids, QPos.size
716727
)
717728

718729
n_tokens = tokens.array.shape[0]
@@ -825,7 +836,7 @@ def diff_offsets(offsets: np.ndarray):
825836
# to turn each input batch into a denoising batch while still staying
826837
# under the max sequence length for the model.
827838
def _compute_length(task_idx: jnp.ndarray, length: jnp.ndarray) -> int:
828-
return compute_denoising_length(task_params[task_idx], length)
839+
return compute_denoising_length(task_params[task_idx], length, False)
829840

830841
out_token_counts = jax.vmap(_compute_length)(task_indices, in_token_counts)
831842
out_lengths = {**in_lengths, "input_ids": out_token_counts}

tests/test_ul2r.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
preprocessor_for_format,
1919
)
2020
from levanter.data.ul2r import (
21+
RX_TASK_KIND,
2122
TokenizedDict,
2223
compute_denoising_length,
2324
noise_span_to_unique_sentinel,
@@ -291,7 +292,7 @@ def test_noise_span_to_unique_sentinel():
291292
)
292293
noise_mask = jnp.pad(noise_mask, (0, padded_length - 10), constant_values=False)
293294

294-
result = noise_span_to_unique_sentinel(tokens, noise_mask, sentinel_tokens, 10, force_initial_sentinel=False)
295+
result = noise_span_to_unique_sentinel(tokens, 10, noise_mask, pad_token_id, sentinel_tokens, force_initial_sentinel=False)
295296

296297
expected = jnp.array([100, 13, 14, 15, 16, 17, 18, 19])
297298
np.testing.assert_array_equal(result[:8], expected)
@@ -323,7 +324,7 @@ def test_noise_span_to_unique_sentinel():
323324
)
324325
noise_mask = jnp.pad(noise_mask, (0, padded_length - 15), constant_values=False)
325326

326-
result = noise_span_to_unique_sentinel(tokens, noise_mask, sentinel_tokens, 15, force_initial_sentinel=True)
327+
result = noise_span_to_unique_sentinel(tokens, 15, noise_mask, pad_token_id, sentinel_tokens, force_initial_sentinel=True)
327328

328329
# Should still start with sentinel
329330
expected = jnp.array([100, 10, 101, 13, 14, 15, 102, 17, 18, 19, 103, 23, 24])
@@ -335,7 +336,7 @@ def test_noise_span_to_unique_sentinel():
335336
tokens = jnp.pad(tokens, (0, padded_length - 10), constant_values=pad_token_id)
336337
noise_mask = jnp.zeros(padded_length, dtype=jnp.bool_)
337338

338-
result = noise_span_to_unique_sentinel(tokens, noise_mask, sentinel_tokens, 10, force_initial_sentinel=False)
339+
result = noise_span_to_unique_sentinel(tokens, 10, noise_mask, pad_token_id, sentinel_tokens, force_initial_sentinel=False)
339340

340341
# Should be unchanged except for padding
341342
np.testing.assert_array_equal(result[:10], jnp.arange(10, 20))
@@ -405,16 +406,18 @@ def test_to_ul2r_rx_tokens():
405406

406407
inputs = noise_span_to_unique_sentinel(
407408
tokens,
409+
length,
408410
noise_mask,
411+
pad_token_id,
409412
sentinel_tokens,
410-
length,
411413
force_initial_sentinel=False,
412414
)
413415
targets = noise_span_to_unique_sentinel(
414416
tokens,
417+
length,
415418
~noise_mask,
419+
pad_token_id,
416420
sentinel_tokens,
417-
length,
418421
force_initial_sentinel=True,
419422
)
420423

@@ -430,6 +433,7 @@ def test_to_ul2r_rx_tokens():
430433
mask_prob=0.3,
431434
mean_noise_span_length=3.0,
432435
random_roll=False,
436+
pad_token_id=pad_token_id,
433437
sentinel_token_ids=sentinel_tokens,
434438
max_length=max_length,
435439
)
@@ -502,6 +506,7 @@ def test_to_ul2r_rx_tokens_roll():
502506
mask_prob=0.3,
503507
mean_noise_span_length=3.0,
504508
random_roll=True,
509+
pad_token_id=pad_token_id,
505510
sentinel_token_ids=sentinel_ids,
506511
max_length=max_length,
507512
)
@@ -530,6 +535,51 @@ def test_to_ul2r_rx_tokens_roll():
530535
assert jnp.any(jnp.isin(sentinel_ids, targets))
531536

532537

538+
def test_compute_denoising_length_rx_random_roll():
539+
"""
540+
Test that compute_denoising_length with random_roll=True reserves enough
541+
space and sets pad_token_id when it doesn't create an extra span.
542+
543+
When random_roll=True, we reserve space for an extra span. However, rolling doesn't
544+
always create an additional span, so we should see both cases:
545+
- no pad_token_id (rolling created an extra span)
546+
- 2 pad_token_id (rolling did not create an extra span)
547+
"""
548+
max_length = 16
549+
pad_token_id = 999 # Use non-zero pad token to verify it's actually being used
550+
sentinel_ids = jnp.arange(100, 120)
551+
552+
length = 12
553+
tokens = jnp.arange(1, length + 1)
554+
tokens = jnp.pad(tokens, (0, max_length - length), constant_values=pad_token_id)
555+
556+
mask_prob = 0.3
557+
mean_noise_span_length = 3.0
558+
task_params = jnp.array([RX_TASK_KIND, R_TASK_TOKEN_ID, mask_prob, mean_noise_span_length])
559+
560+
predicted_length = compute_denoising_length(task_params, length, random_roll=True)
561+
562+
padding_counts = []
563+
for i in range(16):
564+
key = jax.random.PRNGKey(i)
565+
_input_length, result = to_ul2r_rx_tokens(
566+
key, tokens, length, mask_prob, mean_noise_span_length, True,
567+
pad_token_id, sentinel_ids, max_length
568+
)
569+
570+
# Subtract 1 because `result` doesn't include the task token.
571+
# print(result[:predicted_length - 1])
572+
num_padding = jnp.sum(result[:predicted_length - 1] == pad_token_id)
573+
padding_counts.append(int(num_padding))
574+
575+
actual_length = jnp.sum(result != pad_token_id)
576+
assert actual_length <= predicted_length
577+
578+
assert any(p == 0 for p in padding_counts)
579+
assert any(p == 2 for p in padding_counts)
580+
assert all(p == 0 or p == 2 for p in padding_counts)
581+
582+
533583
def test_ul2r_loss_mask():
534584
# Test case 1: Simple single segment
535585
input_masks = jnp.array([1, 1, 0, 0]) # First 2 are inputs
@@ -597,6 +647,7 @@ def test_to_ul2r_rx_tokens_truncates_both_sections_and_contains_sentinels():
597647
mask_prob=mask_prob,
598648
mean_noise_span_length=mean_noise_span_length,
599649
random_roll=random_roll,
650+
pad_token_id=pad_token_id,
600651
sentinel_token_ids=sentinel_tokens,
601652
max_length=padded_length,
602653
)
@@ -612,6 +663,7 @@ def test_to_ul2r_rx_tokens_truncates_both_sections_and_contains_sentinels():
612663
mask_prob=mask_prob,
613664
mean_noise_span_length=mean_noise_span_length,
614665
random_roll=random_roll,
666+
pad_token_id=pad_token_id,
615667
sentinel_token_ids=sentinel_tokens,
616668
max_length=max_length,
617669
)
@@ -687,9 +739,9 @@ def test_create_ul2r_example():
687739
in_len_s = 5
688740
in_len = in_len_r + in_len_x + in_len_s
689741

690-
out_len_r = compute_denoising_length(task_params[0], in_len_r)
691-
out_len_x = compute_denoising_length(task_params[1], in_len_x)
692-
out_len_s = compute_denoising_length(task_params[2], in_len_s)
742+
out_len_r = compute_denoising_length(task_params[0], in_len_r, False)
743+
out_len_x = compute_denoising_length(task_params[1], in_len_x, False)
744+
out_len_s = compute_denoising_length(task_params[2], in_len_s, False)
693745

694746
tokens = jnp.concatenate(
695747
[
@@ -754,7 +806,7 @@ def prepare_segment(id: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
754806
# batch independently (but in a way that matches how we computed
755807
# lengths for packing).
756808
task_idx = task_indices[id]
757-
out_length = compute_denoising_length(task_params[task_idx], in_length)
809+
out_length = compute_denoising_length(task_params[task_idx], in_length, False)
758810

759811
return in_start, in_length, out_length
760812

@@ -790,7 +842,7 @@ def process_segment(key, id: int) -> tuple[jnp.ndarray, jnp.ndarray, int, int]:
790842

791843
segment = jnp.roll(tokens.array, -in_start)
792844
print(key, task_params[task_idx], segment, in_length, QPos.size)
793-
inputs_len, denoising_tokens = to_ul2r_tokens(key, task_params[task_idx], segment, in_length, SENTINEL_TOKEN_IDS, QPos.size)
845+
inputs_len, denoising_tokens = to_ul2r_tokens(key, task_params[task_idx], segment, in_length, pad_token_id, SENTINEL_TOKEN_IDS, QPos.size)
794846

795847
n_tokens = tokens.array.shape[0]
796848
input_mask = jnp.arange(n_tokens) < inputs_len

0 commit comments

Comments
 (0)