@@ -103,7 +103,7 @@ def _torch_generate_mha(
     # Apply sinks if provided (following the model file pattern)
     if sinks is not None:
         # Concatenate sinks to attention scores
-        sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
+        sinks = sinks.reshape(-1, 1, 1)
         attn_weights = torch.cat([attn_scores, sinks], dim=-1)
         attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
         # Use only the non-sink portion for computing output (ignore sinks)
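Dropping the .expand() is safe here because the decode path scores a single query position per head, so the reshaped sink tensor already matches attn_scores in every dimension except the concat one. A minimal sketch of the new behavior, assuming a [n_heads, 1, kv_len] score layout (the exact layout is not visible in this hunk):

import torch

n_heads, kv_len = 8, 16
attn_scores = torch.randn(n_heads, 1, kv_len)  # assumed decode-path layout
sinks = torch.randn(n_heads)                   # assumed: one sink logit per head

# [n_heads, 1, 1] lines up with attn_scores everywhere except the concat dim,
# so no expand over the (length-1) query dimension is needed.
sinks_col = sinks.reshape(-1, 1, 1)
attn_weights = torch.softmax(torch.cat([attn_scores, sinks_col], dim=-1), dim=-1)
assert attn_weights.shape == (n_heads, 1, kv_len + 1)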
@@ -202,9 +202,7 @@ def _torch_context_mha(
             )  # [seq_len_i, kv_seq_len]

             # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
-            sliding_window_mask = (pos_diff < 0) | (
-                pos_diff >= sliding_window_size
-            )  # [seq_len_i, kv_seq_len]
+            sliding_window_mask = pos_diff >= sliding_window_size

             # Combine causal and sliding window masks
             combined_mask = causal_mask | sliding_window_mask
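The removed (pos_diff < 0) term was redundant: future keys are already excluded by causal_mask, which is ORed in right below, so keeping only pos_diff >= sliding_window_size produces the same combined mask. A quick equivalence check, assuming pos_diff = q_pos - k_pos and causal_mask = pos_diff < 0 as in the surrounding (unshown) code:

import torch

seq_len, window = 6, 3
q_pos = torch.arange(seq_len).unsqueeze(1)  # [seq_len, 1]
k_pos = torch.arange(seq_len).unsqueeze(0)  # [1, kv_seq_len]
pos_diff = q_pos - k_pos                    # [seq_len, kv_seq_len]
causal_mask = pos_diff < 0                  # True = masked (key lies in the future)

old_window_mask = (pos_diff < 0) | (pos_diff >= window)
new_window_mask = pos_diff >= window

# Once the causal mask is ORed in, both variants mask exactly the same entries.
assert torch.equal(causal_mask | old_window_mask, causal_mask | new_window_mask)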
@@ -219,14 +217,14 @@ def _torch_context_mha(
         # Apply sinks if provided (following the model file pattern)
         if sinks is not None:
             # Concatenate sinks to attention scores
-            sinks = sinks.reshape(1, -1, 1, 1).expand(
-                attn_scores.shape[0], -1, attn_scores.shape[-2], -1
+            new_sinks = sinks.reshape(1, -1, 1, 1).expand(
+                attn_scores.shape[0], -1, attn_scores.shape[2], 1
             )
-            attn_weights = torch.cat([attn_scores, sinks], dim=-1)
+            attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
             attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
             # Use only the non-sink portion for computing output (ignore sinks)
             attn_out = torch.matmul(
-                attn_weights[..., :-sinks.size(-1)], v_seq_t
+                attn_weights[..., :-new_sinks.size(-1)], v_seq_t
             )  # [1, n_heads, seq_len_i, v_head_dim]
         else:
             attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
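Binding the expanded tensor to new_sinks keeps the sinks argument itself untouched, which matters if this block runs once per sequence in the prefill loop (the seq_len_i naming suggests it does): reusing the name would push an already-expanded 4-D tensor into the next iteration's reshape. The switch from shape[-2] to shape[2] and from -1 to 1 in the expand is cosmetic on a 4-D tensor. A shape sketch under assumed layouts (attn_scores as [1, n_heads, seq_len_i, kv_seq_len], v_seq_t as [1, n_heads, kv_seq_len, v_head_dim], sinks as [n_heads]):

import torch

n_heads, seq_len_i, kv_seq_len, v_head_dim = 4, 5, 7, 8
attn_scores = torch.randn(1, n_heads, seq_len_i, kv_seq_len)  # assumed layout
v_seq_t = torch.randn(1, n_heads, kv_seq_len, v_head_dim)     # assumed layout
sinks = torch.randn(n_heads)                                  # assumed: one sink logit per head

# One sink column per query row, without overwriting the original sinks tensor.
new_sinks = sinks.reshape(1, -1, 1, 1).expand(
    attn_scores.shape[0], -1, attn_scores.shape[2], 1
)  # [1, n_heads, seq_len_i, 1]
attn_weights = torch.softmax(
    torch.cat([attn_scores, new_sinks], dim=-1), dim=-1, dtype=torch.float32
)
# The sink column only absorbs probability mass; slice it off before the value matmul.
attn_out = torch.matmul(attn_weights[..., : -new_sinks.size(-1)], v_seq_t)
assert attn_out.shape == (1, n_heads, seq_len_i, v_head_dim)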