Commit a45ac73

Correction to match logs
1 parent 9350f4d commit a45ac73

1 file changed: +11 -6 lines changed


train_gpt.py: 11 additions & 6 deletions
@@ -818,8 +818,6 @@ def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32):
         self.sin.copy_(theta.sin())
         self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1

-flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
-
 def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
     assert cos.size(0) >= x_BTHD.size(-3)
     cos, sin = (
@@ -841,6 +839,8 @@ class AttnArgs:
     sin: torch.Tensor
     attn_scale: float

+flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+
 class CausalSelfAttention(nn.Module):
     def __init__(self, dim: int, head_dim: int, num_heads: int):
         super().__init__()
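As an aside on the relocated kernel load: a minimal sketch of the pattern, assuming the Hugging Face `kernels` package provides get_kernel as used above. The try/except fallback is purely illustrative and is not in train_gpt.py.

# Sketch: fetch the FlashAttention-3 interface once at import time, next to the
# attention module that consumes it (mirrors the line added in the hunk above).
try:
    from kernels import get_kernel  # Hugging Face `kernels` package
    flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
    # Assumed fallback for offline environments; callers would have to switch to
    # torch.nn.functional.scaled_dot_product_attention instead.
    flash_attn_interface = None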
@@ -1034,10 +1034,15 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_sho
         skip_connections.append(x)

         x = norm(x)
-        logits = self.lm_head(x).float()
+        logits = self.lm_head(x)
         # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
         logits = 30 * torch.sigmoid(logits / 7.5)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
         return loss

 # -----------------------------------------------------------------------------
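Reading the hunk above: the loss now keeps the lm_head output in its compute dtype during training and casts to float32 only for the eval-time, mean-reduced cross-entropy. The softcapping line itself is untouched; per the in-code comment, 30 * sigmoid(x / 7.5) is the same curve as 15 * (tanh(x / 15) + 1), i.e. logits squashed into (0, 30). A small standalone check of that identity (sketch, not part of the diff):

import torch

x = torch.linspace(-100.0, 100.0, steps=2001)
capped_sigmoid = 30 * torch.sigmoid(x / 7.5)
capped_tanh = 15 * (torch.tanh(x / 15) + 1)  # 2*sigmoid(2*y) = tanh(y) + 1 with y = x/15
assert torch.allclose(capped_sigmoid, capped_tanh, atol=1e-4)
print(float(capped_sigmoid.min()), float(capped_sigmoid.max()))  # values stay inside (0, 30)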
@@ -1229,7 +1234,7 @@ class Hyperparameters:
     train_max_seq_len: int = 128 * 16
     val_batch_size: int = 4 * 64 * 1024 * 8
     # optimization
-    num_iterations: int = 1640 # number of iterations to run
+    num_iterations: int = 1630 # number of iterations to run
     iteration_extension = 40 # number of iterations to continue training at final cooldown and window size
     cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate
     # evaluation and logging
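num_iterations drops from 1640 to 1630; iteration_extension and cooldown_frac are only context here. A hedged sketch of how these knobs are commonly combined into an LR multiplier, based solely on the field comments above; get_lr_multiplier and its final_frac floor are hypothetical and not the actual schedule in train_gpt.py:

def get_lr_multiplier(step: int, num_iterations: int = 1630,
                      cooldown_frac: float = 0.5, final_frac: float = 0.1) -> float:
    # Hypothetical illustration: hold the full LR, then cool down linearly over the last
    # cooldown_frac of the run; steps past num_iterations (the iteration_extension window)
    # keep the final multiplier. final_frac is an assumed floor, not a value from the repo.
    cooldown_start = (1 - cooldown_frac) * num_iterations
    if step < cooldown_start:
        return 1.0
    progress = min((step - cooldown_start) / (num_iterations - cooldown_start), 1.0)
    return 1.0 + progress * (final_frac - 1.0)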
@@ -1319,7 +1324,7 @@ def nvidia_smi():
     eps=1e-8,
     weight_decay=0.0,
 )
-optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0)
 optimizers = [optimizer1, optimizer2]
 for opt in optimizers:
     for group in opt.param_groups:
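The Muon learning rate for the hidden matrices and gates moves from 0.05 to 0.06; the Adam optimizer whose trailing arguments open the hunk is unchanged. A rough, self-contained sketch of the two-optimizer split this hunk belongs to; the parameter grouping, the Adam LR, and the loop body are placeholders, and torch.optim.SGD stands in for Muon (which is defined inside train_gpt.py) just to keep the sketch runnable:

import torch
from torch import nn

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64), nn.Linear(64, 1000))

# Placeholder split: 2-D weights go to the Muon-style optimizer, everything else to Adam.
hidden_matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]

optimizer1 = torch.optim.Adam(other_params, lr=3e-4, eps=1e-8, weight_decay=0.0)
optimizer2 = torch.optim.SGD(hidden_matrix_params, lr=0.06, momentum=0.95, weight_decay=0.0)

optimizers = [optimizer1, optimizer2]
for opt in optimizers:
    for group in opt.param_groups:
        # The diff cuts off before this loop's body; stashing the initial LR is just a
        # plausible example of what such a loop might do.
        group["initial_lr"] = group["lr"]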
