@@ -484,15 +484,15 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
484484
485485 if self.config.crash_on_nan and jnp.isnan(loss):
486486 jnp.set_printoptions(threshold=sys.maxsize, linewidth=sys.maxsize)
487- jax.debug.print(f"tokens={batch[0].tokens.array.astype(dtype=jnp.int32)}")
488- jax.debug.print(f"loss_mask={batch[0].loss_mask.array.astype(dtype=jnp.int32)}")
487+ # jax.debug.print(f"tokens={batch[0].tokens.array.astype(dtype=jnp.int32)}")
488+ # jax.debug.print(f"loss_mask={batch[0].loss_mask.array.astype(dtype=jnp.int32)}")
489489 # print(f"batch={batch}")
490- jax.debug.print("{result}", result=result)
491- jax.debug.print(f"attn_mask={batch[0].attn_mask}")
492- jax.debug.print("input_mask={x}", x=batch[0].attn_mask.input_mask.array)
493- jax.debug.print("segment_ids={x}", x=batch[0].attn_mask.segment_ids[0].array)
494- materialized = batch[0].attn_mask.materialize(hax.Axis(name='position', size=1024), hax.Axis(name='key_position', size=1024))
495- jax.debug.print(f"attn_mask={materialized.array}")
490+ # jax.debug.print("{result}", result=result)
491+ # jax.debug.print(f"attn_mask={batch[0].attn_mask}")
492+ # jax.debug.print("input_mask={x}", x=batch[0].attn_mask.input_mask.array)
493+ # jax.debug.print("segment_ids={x}", x=batch[0].attn_mask.segment_ids[0].array)
494+ # materialized = batch[0].attn_mask.materialize(hax.Axis(name='position', size=1024), hax.Axis(name='key_position', size=1024))
495+ # jax.debug.print(f"attn_mask={materialized.array}")
496496 raise RuntimeError("Loss is NaN")
497497
498498 if self.config.crash_on_inf and jnp.isinf(loss):
0 commit comments