
Commit 346f4cb

feat: train-time ce in bf16
1 parent 89a1f62 commit 346f4cb

File tree

1 file changed: +9 −4 lines changed


train_gpt.py

Lines changed: 9 additions & 4 deletions
@@ -1019,10 +1019,15 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_sho
             skip_connections.append(x)
 
         x = norm(x)
-        logits = self.lm_head(x).float()
+        logits = self.lm_head(x)
         # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
-        logits = 30 * torch.sigmoid(logits / 7.5)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        logits = torch.sigmoid(logits / logits.new_tensor(7.5)) * logits.new_tensor(30.0)
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
         return loss
 
 # -----------------------------------------------------------------------------
@@ -1389,7 +1394,7 @@ def get_ws(step: int):
         assert args.val_tokens % args.val_batch_size == 0
         val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
         val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
-        val_loss = 0
+        val_loss = torch.zeros((), device=device, dtype=torch.float32)
         with torch.no_grad():
             for _ in range(val_steps):
                 inputs, targets, cum_seqlens = next(val_loader)
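
The sigmoid softcap kept by this commit is the same function as the shifted tanh softcap that the in-line comment describes: with 2*sigmoid(2*x) = tanh(x) + 1, 30*sigmoid(logits/7.5) equals 15*tanh(logits/15) + 15. Below is a minimal standalone sketch, not taken from train_gpt.py, that checks this identity numerically and mirrors the commit's train/eval dtype split; the tensor shapes, variable names, and the `training` flag are illustrative assumptions only.

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)  # illustrative stand-in for lm_head output

# Sigmoid form used in the commit vs. the shifted-tanh form from the comment:
# 30*sigmoid(x/7.5) = 15*(tanh(x/15) + 1)
capped_sigmoid = 30 * torch.sigmoid(x / 7.5)
capped_tanh = 15 * torch.tanh(x / 15) + 15
print(torch.allclose(capped_sigmoid, capped_tanh, atol=1e-5))  # expected: True

# Dtype handling as in the diff: bf16 logits go straight into cross_entropy at
# train time; they are upcast to float32 only for evaluation.
logits = capped_sigmoid.to(torch.bfloat16)
targets = torch.randint(0, 8, (4,))
training = True  # illustrative flag standing in for self.training
logits_for_loss = logits.float() if not training else logits
loss = F.cross_entropy(logits_for_loss, targets,
                       reduction="sum" if training else "mean")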
