@@ -818,8 +818,6 @@ def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32):
         self.sin.copy_(theta.sin())
         self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1

-flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
-
 def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
     assert cos.size(0) >= x_BTHD.size(-3)
     cos, sin = (
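For reference, the `attn_scale` update in the context above multiplies the existing attention scale by a YaRN-style log factor whenever the attention window grows. A minimal standalone sketch of that arithmetic (the window sizes are made-up examples, not values from this run):

```python
import math

def window_scale_multiplier(old_window: int, new_window: int) -> float:
    # Same rule as the context line in apply(): multiply attn_scale by
    # 0.2 * ln(new_window / old_window) + 1 when the window is enlarged.
    return 0.2 * math.log(new_window / old_window) + 1

# Example: doubling the window multiplies attn_scale by 0.2 * ln(2) + 1 ≈ 1.139.
print(window_scale_multiplier(1024, 2048))
```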
@@ -841,6 +839,8 @@ class AttnArgs:
     sin: torch.Tensor
     attn_scale: float

+flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+
 class CausalSelfAttention(nn.Module):
     def __init__(self, dim: int, head_dim: int, num_heads: int):
         super().__init__()
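These two hunks only relocate the kernel-hub load so it sits directly above `CausalSelfAttention`, its consumer, instead of above `rotary`. As a hedged sketch of what that one line involves (the `from kernels import get_kernel` import is assumed from context, and the comments about the interface are illustrative rather than taken from this diff):

```python
from kernels import get_kernel  # Hugging Face kernel-hub loader (assumed import)

# Fetch the flash-attention-3 build from the hub repo named in the diff and keep
# its Python-facing interface module, exactly as the moved line does.
flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface

# Illustrative only: the interface exposes the attention entry points used by
# CausalSelfAttention; the exact call is not shown in these hunks.
```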
@@ -1034,10 +1034,15 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_sho
             skip_connections.append(x)

         x = norm(x)
-        logits = self.lm_head(x).float()
+        logits = self.lm_head(x)
         # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
         logits = 30 * torch.sigmoid(logits / 7.5)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        logits_for_loss = logits.float() if not self.training else logits
+        loss = F.cross_entropy(
+            logits_for_loss.view(-1, logits_for_loss.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
         return loss

 # -----------------------------------------------------------------------------
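The loss hunk keeps logits in their native low-precision dtype for the training-time cross-entropy (summed), and only upcasts to float32 on the validation path (mean reduction). A toy, hedged sketch of the resulting behaviour, with made-up shapes and dtypes:

```python
import torch
import torch.nn.functional as F

B_T, V = 8, 50257                      # toy (tokens, vocab) sizes, illustrative only
logits = torch.randn(B_T, V, dtype=torch.bfloat16)
target_seq = torch.randint(V, (B_T,))

# Soft-capping as in the unchanged context line: 30*sigmoid(x/7.5) == 15*(tanh(x/15) + 1)
logits = 30 * torch.sigmoid(logits / 7.5)

training = True                        # flip to False for the validation path
logits_for_loss = logits.float() if not training else logits   # upcast only for eval
loss = F.cross_entropy(
    logits_for_loss.view(-1, logits_for_loss.size(-1)),
    target_seq,
    reduction="sum" if training else "mean",   # summed in training, averaged in eval
)
```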
@@ -1229,7 +1234,7 @@ class Hyperparameters:
     train_max_seq_len: int = 128 * 16
     val_batch_size: int = 4 * 64 * 1024 * 8
     # optimization
-    num_iterations: int = 1640 # number of iterations to run
+    num_iterations: int = 1630 # number of iterations to run
     iteration_extension = 40 # number of iterations to continue training at final cooldown and window size
     cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate
     # evaluation and logging
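This hunk only trims `num_iterations` from 1640 to 1630; `iteration_extension` and `cooldown_frac` are unchanged context. As a hedged sketch of how a `cooldown_frac`-style schedule is typically wired up in this family of scripts (the helper name and the exact shape of the decay are assumptions, not read from this diff):

```python
def lr_multiplier(step: int, num_iterations: int = 1630, cooldown_frac: float = 0.5) -> float:
    # Constant LR for the first (1 - cooldown_frac) of training, then a linear
    # cooldown toward zero over the final cooldown_frac fraction of the run.
    x = step / num_iterations
    if x < 1 - cooldown_frac:
        return 1.0
    return max((1 - x) / cooldown_frac, 0.0)

# iteration_extension = 40 means roughly 40 extra steps are run past num_iterations
# at the final cooled-down LR and window size; that logic is outside this hunk.
```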
@@ -1319,7 +1324,7 @@ def nvidia_smi():
     eps=1e-8,
     weight_decay=0.0,
 )
-optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params + gate_params, lr=0.06, momentum=0.95, weight_decay=0.0)
 optimizers = [optimizer1, optimizer2]
 for opt in optimizers:
     for group in opt.param_groups:
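Besides the Muon learning-rate bump from 0.05 to 0.06, the optimizer wiring is unchanged, and the trailing loop is cut off by the hunk. In similar training scripts that loop usually just snapshots each param group's starting LR so a multiplier schedule can rescale it later; a hedged sketch (the `initial_lr` key and the scheduling step are assumptions, not visible here):

```python
# Snapshot each param group's base LR so a multiplier-style schedule can rescale it.
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]   # assumed bookkeeping key

# Later, inside the training loop (illustrative):
#     for opt in optimizers:
#         for group in opt.param_groups:
#             group["lr"] = group["initial_lr"] * lr_multiplier(step)
```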