Skip to content

Commit 45c2c6e

Browse files
committed
length is upper bound, not count
1 parent 5183479 commit 45c2c6e

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/levanter/data/ul2r.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,8 @@ def loop_body(read_pos, state):
355355
return out, write_pos, sentinel_count
356356

357357
result = jnp.zeros_like(tokens)
358-
needs_initial_sentinel = force_initial_sentinel & ~first_noise_tokens[0]
359-
i0 = jax.lax.select(needs_initial_sentinel, -1, 0)
360-
n = jax.lax.select(needs_initial_sentinel, length + 1, length)
361-
result, _, _ = jax.lax.fori_loop(i0, n, loop_body, (result, 0, 0))
358+
i0 = jax.lax.select(force_initial_sentinel & ~first_noise_tokens[0], -1, 0)
359+
result, _, _ = jax.lax.fori_loop(i0, length, loop_body, (result, 0, 0))
362360

363361
return result
364362

0 commit comments

Comments
 (0)