import torch.nn as nn
import torch.nn.functional as F

- # TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.
+
+ def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
+     """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
+     if logit_cap is not None and logit_cap > 0.0:
+         return logit_cap * torch.tanh(attn_scores / logit_cap)
+     return attn_scores
+
+
+ def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+     """Convert boolean attention mask to floating point mask.
+     Args:
+         attn_mask: Boolean tensor where True allows attention, False blocks it
+         dtype: Target dtype for the output mask
+     Returns:
+         Floating point mask where True -> 1.0, False -> -inf
+     """
+     if attn_mask.dtype == torch.bool:
+         float_mask = torch.zeros_like(attn_mask, dtype=dtype)
+         float_mask = float_mask.masked_fill(attn_mask, 1.0)  # True -> 1.0
+         float_mask = float_mask.masked_fill(~attn_mask, float("-inf"))  # False -> -inf
+         return float_mask
+     return attn_mask
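As a quick reference for reviewers, here is a minimal standalone sketch (not part of the diff) of what the two helpers above compute: softcapping smoothly bounds the raw scores to (-logit_cap, logit_cap), and a boolean mask becomes an additive float mask with 1.0 for allowed positions and -inf for blocked ones.

```python
import torch

# Softcapping: large-magnitude scores are squashed toward +/- logit_cap.
logit_cap = 50.0
scores = torch.tensor([10.0, 200.0, -500.0])
print(logit_cap * torch.tanh(scores / logit_cap))  # approx. [9.9, 50.0, -50.0]

# Boolean -> additive float mask, mirroring _convert_boolean_mask_to_float.
bool_mask = torch.tensor([True, False, True])
float_mask = torch.zeros_like(bool_mask, dtype=torch.float32)
float_mask = float_mask.masked_fill(bool_mask, 1.0).masked_fill(~bool_mask, float("-inf"))
print(float_mask)  # tensor([1., -inf, 1.])
```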


@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@@ -77,19 +98,96 @@ def grouped_sdpa(
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
+     sinks: Optional[torch.Tensor] = None,
+     sliding_window: Optional[int] = None,
+     logit_cap: Optional[float] = None,
) -> torch.Tensor:
-     """SDPA attention that can handle GQA."""
+     """SDPA attention that can handle GQA. Expects bnsd format inputs."""
+     b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
+     _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
+
+     # Inputs are already in bnsd format, no need to transpose
+     query_t = query  # [b, n_heads, s_q, head_dim]
+     key_t = key  # [b, n_kv_heads, s_k, head_dim]
+     value_t = value  # [b, n_kv_heads, s_k, v_head_dim]
+
+     # Handle GQA by repeating KV if needed
+     if n_heads != n_kv_heads:
+         n_rep = n_heads // n_kv_heads
+         key_t = repeat_kv(key_t, n_rep)
+         value_t = repeat_kv(value_t, n_rep)
+
+     # Default the scale to 1/sqrt(head_dim)
+     if scale is None:
+         scale = 1.0 / math.sqrt(head_dim)
+
+     # Compute attention scores: Q @ K^T
+     attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]
+
+     # Apply attention mask if provided
+     if attn_mask is not None:
+         # Convert boolean mask to float if needed
+         attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
+         attn_scores = attn_scores + attn_mask
+
+     # Apply causal mask if specified, and only during the context phase
+     if is_causal and s_q == s_k:  # Only apply causal mask during context processing
+         causal_mask = torch.triu(
+             torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
+             diagonal=1,  # Use diagonal=1 for standard causal masking
+         )
+         attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+     # Apply sliding window mask if specified
+     if sliding_window is not None and sliding_window > 0:
+         # Handle position calculation for both context and generation phases
+         if s_q == s_k:
+             # Context phase: standard position calculation
+             query_positions = torch.arange(s_q, device=query.device)
+             key_positions = torch.arange(s_k, device=query.device)
+         else:
+             # Generation phase: queries sit at positions s_k .. s_k + s_q - 1 (after the cache)
+             query_positions = torch.arange(s_k, s_k + s_q, device=query.device)  # equals [s_k] when s_q == 1
+             key_positions = torch.arange(s_k, device=query.device)  # [0, 1, 2, ..., s_k - 1]
+
+         # Create position difference matrix: query_pos - key_pos
+         pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]
+
+         # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
+         sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
+         attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+     # Apply logit softcapping if enabled
+     attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+     # Apply attention sinks if provided
+     if sinks is not None:
+         # sinks holds one logit per head (as in the reference implementation); broadcast it to
+         # [b, n_heads, s_q, 1] so each head contributes one extra "sink" column per query position
+         sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
+             b, n_heads, s_q, 1
+         )  # [b, n_heads, s_q, 1]
+
+         # Numerically stable softmax over [attn_scores, sink]: the sink term only enters the
+         # normalizer, which is equivalent to appending a sink column and dropping it afterwards
+         logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
+         sinks = torch.exp(sinks_expanded - logits_max)
+         unnormalized_scores = torch.exp(attn_scores - logits_max)
+         normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
+         scores = unnormalized_scores / normalizer
+         # The sink column carries no value vector, so the output is computed from the real keys only
+         attn_out = torch.matmul(scores, value_t)  # [b, n_heads, s_q, v_head_dim]
+     else:
+         attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
+         attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]

-     return F.scaled_dot_product_attention(
-         query.contiguous(),
-         key.contiguous(),
-         value.contiguous(),
-         attn_mask=attn_mask,
-         dropout_p=dropout_p,
-         is_causal=is_causal,
-         scale=scale,
-         enable_gqa=True,
-     )
+     # Apply dropout if specified
+     if dropout_p > 0.0:
+         attn_out = F.dropout(attn_out, p=dropout_p, training=False)
+
+     # Return in bnsd format (same as input format)
+     return attn_out
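A brief standalone sanity check (not part of the diff; shapes chosen arbitrarily) that the sink handling above matches the "append a sink column, softmax, then drop it" formulation it is derived from. The sliding-window and softcapping branches are plain elementwise masking and can be read off directly.

```python
import torch

b, n_heads, s_q, s_k = 1, 2, 3, 5
attn_scores = torch.randn(b, n_heads, s_q, s_k)
sinks = torch.randn(n_heads)

# Manual normalizer, as in grouped_sdpa above
sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(b, n_heads, s_q, 1)
logits_max = attn_scores.max(dim=-1, keepdim=True).values
num = torch.exp(attn_scores - logits_max)
denom = num.sum(dim=-1, keepdim=True) + torch.exp(sinks_expanded - logits_max)
manual = num / denom

# Reference: append one sink logit per head, softmax, drop the sink column
ref = torch.softmax(torch.cat([attn_scores, sinks_expanded], dim=-1), dim=-1)[..., :-1]
torch.testing.assert_close(manual, ref)
```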


@grouped_sdpa.register_fake
@@ -101,16 +199,19 @@ def grouped_sdpa_fake(
    dropout_p=0.0,
    is_causal=False,
    scale=None,
+     sinks=None,
+     sliding_window=None,
+     logit_cap=None,
):
    """Fake implementation of grouped SDPA."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
-     query: torch.Tensor,  # layout: [b, n, s_q, d]
-     key: torch.Tensor,  # layout: [b, n, s_k, d]
-     value: torch.Tensor,  # layout: [b, n, s_k, d]
+     query: torch.Tensor,  # layout: [b, s_q, n, d]
+     key: torch.Tensor,  # layout: [b, s_k, n, d]
+     value: torch.Tensor,  # layout: [b, s_k, n, d]
    attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
    dropout_p: float = 0.0,
    is_causal: bool = False,
@@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
    Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
    original sdpa op!
    """
-     # let's transpose to bnsd so we can use the grouped sdpa
-     query = query.transpose(1, 2).contiguous()
-     key = key.transpose(1, 2).contiguous()
-     value = value.transpose(1, 2).contiguous()
-
-     out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
-
-     # let's transpose back to bnsd
+     # Transpose inputs to bnsd format for grouped_sdpa
+     query = query.transpose(1, 2).contiguous()  # [b, s_q, n, d] -> [b, n, s_q, d]
+     key = key.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
+     value = value.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
+
+     # Call grouped_sdpa with bnsd inputs
+     out = grouped_sdpa(
+         query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
+     )
+     # Transpose back to bsnd format
    return out.transpose(1, 2).contiguous()
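For illustration, a hypothetical call of the bsnd entry point with the new arguments. This is a sketch only: it assumes the module in this diff is importable, keeps the defaults for `dropout_p` and `scale`, and uses made-up tensor sizes.

```python
import torch
# assumes: bsnd_grouped_sdpa imported from the module in this diff

b, s, n_heads, n_kv_heads, d = 2, 16, 8, 2, 64
q = torch.randn(b, s, n_heads, d)     # [b, s_q, n, d]
k = torch.randn(b, s, n_kv_heads, d)  # [b, s_k, n_kv, d]
v = torch.randn(b, s, n_kv_heads, d)  # [b, s_k, n_kv, d]

out = bsnd_grouped_sdpa(
    q, k, v,
    attn_mask=None,
    is_causal=True,
    sinks=torch.zeros(n_heads),  # one sink logit per query head
    sliding_window=8,            # each query attends to at most the 8 most recent positions
    logit_cap=50.0,              # tanh softcapping of the raw scores
)
print(out.shape)  # torch.Size([2, 16, 8, 64]) -- bsnd, same layout as the inputs
```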