Skip to content

Commit 5183479

Browse files
committed
hmm is it because we stopping too early? but not sure where 50399 comes from instead of 0
1 parent 6b1366a commit 5183479

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/levanter/data/ul2r.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,10 @@ def loop_body(read_pos, state):
355355
return out, write_pos, sentinel_count
356356

357357
result = jnp.zeros_like(tokens)
358-
i0 = jax.lax.select(force_initial_sentinel & ~first_noise_tokens[0], -1, 0)
359-
result, _, _ = jax.lax.fori_loop(i0, length, loop_body, (result, 0, 0))
358+
needs_initial_sentinel = force_initial_sentinel & ~first_noise_tokens[0]
359+
i0 = jax.lax.select(needs_initial_sentinel, -1, 0)
360+
n = jax.lax.select(needs_initial_sentinel, length + 1, length)
361+
result, _, _ = jax.lax.fori_loop(i0, n, loop_body, (result, 0, 0))
360362

361363
return result
362364

0 commit comments

Comments
 (0)