
Commit dc74ec3

nvchenghaoz and Fridah-nv authored and committed
[Fix] Fix the torch backend (#108)
* attention matcher with torch._inductor pattern matcher, matching repeat kv, sdpa and group attention, update unit tests

  Signed-off-by: Frida Hou <[email protected]>

* Fix the torch backend Attention

  Signed-off-by: nvchenghaoz <[email protected]>

* Revert "attention matcher with torch._inductor pattern matcher, matching repeat kv, sdpa and group attention, update unit tests"

  This reverts commit 5743fb3.

---------

Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: nvchenghaoz <[email protected]>
Co-authored-by: Frida Hou <[email protected]>
1 parent ad5fd3b commit dc74ec3

File tree

1 file changed: +6, -8 lines changed


tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py

Lines changed: 6 additions & 8 deletions
@@ -103,7 +103,7 @@ def _torch_generate_mha(
     # Apply sinks if provided (following the model file pattern)
     if sinks is not None:
         # Concatenate sinks to attention scores
-        sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
+        sinks = sinks.reshape(-1, 1, 1)
         attn_weights = torch.cat([attn_scores, sinks], dim=-1)
         attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
         # Use only the non-sink portion for computing output (ignore sinks)
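For context, a minimal standalone sketch of the sink pattern this hunk touches in the generate path. The shapes (attn_scores as [n_heads, 1, kv_len], sinks as [n_heads]) and the v tensor are assumptions for illustration, not the file's actual signature; the point is that torch.cat only needs the non-concatenated dims to line up, so with a single query token the reshape alone suffices and the dropped expand over the query dimension was a no-op.

import torch

n_heads, kv_len, v_head_dim = 8, 16, 4
attn_scores = torch.randn(n_heads, 1, kv_len)   # assumed: one query token per head
sinks = torch.randn(n_heads)                    # assumed: one sink logit per head
v = torch.randn(n_heads, kv_len, v_head_dim)    # assumed value tensor

# Append one extra "sink" logit column per head before the softmax.
sink_col = sinks.reshape(-1, 1, 1)                          # [n_heads, 1, 1]
attn_weights = torch.cat([attn_scores, sink_col], dim=-1)   # [n_heads, 1, kv_len + 1]
attn_weights = torch.softmax(attn_weights, dim=-1)

# The sink column absorbs probability mass but is dropped for the output.
attn_out = torch.matmul(attn_weights[..., :-1], v)          # [n_heads, 1, v_head_dim]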
@@ -202,9 +202,7 @@ def _torch_context_mha(
         ) # [seq_len_i, kv_seq_len]

         # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
-        sliding_window_mask = (pos_diff < 0) | (
-            pos_diff >= sliding_window_size
-        ) # [seq_len_i, kv_seq_len]
+        sliding_window_mask = pos_diff >= sliding_window_size

         # Combine causal and sliding window masks
         combined_mask = causal_mask | sliding_window_mask
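As a quick illustration of why the pos_diff < 0 term could be dropped: negative offsets (future keys) are already excluded by the causal mask, so the sliding-window check only needs the upper bound. A self-contained sketch with assumed position tensors (the real code derives positions from the KV-cache layout):

import torch

seq_len_i, kv_seq_len, sliding_window_size = 4, 6, 3
q_pos = torch.arange(kv_seq_len - seq_len_i, kv_seq_len).unsqueeze(1)  # [seq_len_i, 1]
k_pos = torch.arange(kv_seq_len).unsqueeze(0)                          # [1, kv_seq_len]
pos_diff = q_pos - k_pos                                               # [seq_len_i, kv_seq_len]

causal_mask = pos_diff < 0                              # mask future keys
sliding_window_mask = pos_diff >= sliding_window_size   # mask keys outside the window
combined_mask = causal_mask | sliding_window_mask

attn_scores = torch.randn(seq_len_i, kv_seq_len)
attn_scores = attn_scores.masked_fill(combined_mask, float("-inf"))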
@@ -219,14 +217,14 @@ def _torch_context_mha(
         # Apply sinks if provided (following the model file pattern)
         if sinks is not None:
             # Concatenate sinks to attention scores
-            sinks = sinks.reshape(1, -1, 1, 1).expand(
-                attn_scores.shape[0], -1, attn_scores.shape[-2], -1
+            new_sinks = sinks.reshape(1, -1, 1, 1).expand(
+                attn_scores.shape[0], -1, attn_scores.shape[2], 1
             )
-            attn_weights = torch.cat([attn_scores, sinks], dim=-1)
+            attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
             attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
             # Use only the non-sink portion for computing output (ignore sinks)
             attn_out = torch.matmul(
-                attn_weights[..., : -sinks.size(-1)], v_seq_t
+                attn_weights[..., : -new_sinks.size(-1)], v_seq_t
             ) # [1, n_heads, seq_len_i, v_head_dim]
         else:
             attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
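The context-phase hunk is the same pattern in 4D: the expanded tensor is renamed to new_sinks so the original sinks argument is no longer reassigned, and exactly new_sinks.size(-1) (one) column is sliced off before the value matmul. A hedged standalone sketch follows; the shapes [1, n_heads, seq_len_i, kv_seq_len] for attn_scores and [1, n_heads, kv_seq_len, v_head_dim] for v_seq_t are assumptions inferred from the inline comments, not taken from the file.

import torch

n_heads, seq_len_i, kv_seq_len, v_head_dim = 8, 4, 6, 16
attn_scores = torch.randn(1, n_heads, seq_len_i, kv_seq_len)  # assumed shape
sinks = torch.randn(n_heads)                                  # assumed: one sink logit per head
v_seq_t = torch.randn(1, n_heads, kv_seq_len, v_head_dim)     # assumed value layout

# One sink column per head, broadcast over batch and query positions;
# the last dim stays 1 so only a single extra score column is appended.
new_sinks = sinks.reshape(1, -1, 1, 1).expand(
    attn_scores.shape[0], -1, attn_scores.shape[2], 1
)                                                           # [1, n_heads, seq_len_i, 1]
attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)  # [..., kv_seq_len + 1]
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)

# Drop the sink column again before multiplying with the values.
attn_out = torch.matmul(
    attn_weights[..., : -new_sinks.size(-1)], v_seq_t
)                                                           # [1, n_heads, seq_len_i, v_head_dim]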
