Skip to content

Commit c2eed36

Browse files
committed
moar
1 parent 4b76785 commit c2eed36

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/levanter/trainer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -484,15 +484,15 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
484484

485485
if self.config.crash_on_nan and jnp.isnan(loss):
486486
jnp.set_printoptions(threshold=sys.maxsize, linewidth=sys.maxsize)
487-
jax.debug.print(f"tokens={batch[0].tokens.array.astype(dtype=jnp.int32)}")
488-
jax.debug.print(f"loss_mask={batch[0].loss_mask.array.astype(dtype=jnp.int32)}")
487+
# jax.debug.print(f"tokens={batch[0].tokens.array.astype(dtype=jnp.int32)}")
488+
# jax.debug.print(f"loss_mask={batch[0].loss_mask.array.astype(dtype=jnp.int32)}")
489489
# print(f"batch={batch}")
490-
jax.debug.print("{result}", result=result)
491-
jax.debug.print(f"attn_mask={batch[0].attn_mask}")
492-
jax.debug.print("input_mask={x}", x=batch[0].attn_mask.input_mask.array)
493-
jax.debug.print("segment_ids={x}", x=batch[0].attn_mask.segment_ids[0].array)
494-
materialized = batch[0].attn_mask.materialize(hax.Axis(name='position', size=1024), hax.Axis(name='key_position', size=1024))
495-
jax.debug.print(f"attn_mask={materialized.array}")
490+
# jax.debug.print("{result}", result=result)
491+
# jax.debug.print(f"attn_mask={batch[0].attn_mask}")
492+
# jax.debug.print("input_mask={x}", x=batch[0].attn_mask.input_mask.array)
493+
# jax.debug.print("segment_ids={x}", x=batch[0].attn_mask.segment_ids[0].array)
494+
# materialized = batch[0].attn_mask.materialize(hax.Axis(name='position', size=1024), hax.Axis(name='key_position', size=1024))
495+
# jax.debug.print(f"attn_mask={materialized.array}")
496496
raise RuntimeError("Loss is NaN")
497497

498498
if self.config.crash_on_inf and jnp.isinf(loss):

0 commit comments

Comments
 (0)