
Commit d1763ab

lucaslie, nzmora-nvidia, Gal Agam, h-guo18, and nvchenghaoz authored and committed
[AutoDeploy] merge feat/ad-2025-07-22 (NVIDIA#6520)
Signed-off-by: Neta Zmora <[email protected]>
Signed-off-by: Gal Agam <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: haoguo <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: nvchenghaoz <[email protected]>
Signed-off-by: Eran Geva <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>
Co-authored-by: Neta Zmora <[email protected]>
Co-authored-by: Gal Agam <[email protected]>
Co-authored-by: h-guo18 <[email protected]>
Co-authored-by: nvchenghaoz <[email protected]>
Co-authored-by: Frida Hou <[email protected]>
Co-authored-by: Eran Geva <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 3f9053d commit d1763ab

File tree

22 files changed: +1285 -1308 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 12 additions & 0 deletions
@@ -19,3 +19,15 @@ transforms:
     stage: post_export
   cleanup_input_constraints:
     stage: post_export
+  quantize:
+    stage: pattern_matcher
+  quantize_moe:
+    stage: pattern_matcher
+  match_repeat_kv:
+    stage: pattern_matcher
+  match_eager_attention:
+    stage: pattern_matcher
+  match_grouped_attention:
+    stage: pattern_matcher
+  match_attention_layout:
+    stage: pattern_matcher
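
For orientation, the sketch below (illustrative only, not the AutoDeploy loader; the helper and variable names are hypothetical) shows how a stage-keyed transforms section like the one added above can be parsed and grouped by stage with plain PyYAML:

# Illustrative only: parses a transforms section shaped like the one added above and
# groups transform names by their "stage" field. The real AutoDeploy pipeline has its
# own registry/executor; this only demonstrates the shape of the config.
from collections import defaultdict

import yaml

config_text = """
transforms:
  quantize:
    stage: pattern_matcher
  match_repeat_kv:
    stage: pattern_matcher
  cleanup_input_constraints:
    stage: post_export
"""

config = yaml.safe_load(config_text)
by_stage = defaultdict(list)
for name, spec in config["transforms"].items():
    by_stage[spec["stage"]].append(name)

print(dict(by_stage))
# {'pattern_matcher': ['quantize', 'match_repeat_kv'], 'post_export': ['cleanup_input_constraints']}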

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 126 additions & 23 deletions
@@ -7,7 +7,28 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.
+
+def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
+    """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
+    if logit_cap is not None and logit_cap > 0.0:
+        return logit_cap * torch.tanh(attn_scores / logit_cap)
+    return attn_scores
+
+
+def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    """Convert boolean attention mask to floating point mask.
+    Args:
+        attn_mask: Boolean tensor where True allows attention, False blocks it
+        dtype: Target dtype for the output mask
+    Returns:
+        Floating point mask where True -> 1.0, False -> -inf
+    """
+    if attn_mask.dtype == torch.bool:
+        float_mask = torch.zeros_like(attn_mask, dtype=dtype)
+        float_mask = float_mask.masked_fill(attn_mask, 1.0)  # True -> 1.0
+        float_mask = float_mask.masked_fill(~attn_mask, float("-inf"))  # False -> -inf
+        return float_mask
+    return attn_mask
 
 
 @torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@@ -77,19 +98,96 @@ def grouped_sdpa(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
+    logit_cap: Optional[float] = None,
 ) -> torch.Tensor:
-    """SDPA attention that can handle GQA."""
+    """SDPA attention that can handle GQA. Expects bnsd format inputs."""
+    b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
+    _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
+
+    # Inputs are already in bnsd format, no need to transpose
+    query_t = query  # [b, n_heads, s_q, head_dim]
+    key_t = key  # [b, n_kv_heads, s_k, head_dim]
+    value_t = value  # [b, n_kv_heads, s_k, v_head_dim]
+
+    # Handle GQA by repeating KV if needed
+    if n_heads != n_kv_heads:
+        n_rep = n_heads // n_kv_heads
+        key_t = repeat_kv(key_t, n_rep)
+        value_t = repeat_kv(value_t, n_rep)
+
+    # Set scale
+    if scale is None:
+        scale = 1.0 / math.sqrt(head_dim)
+
+    # Compute attention scores: Q @ K^T
+    attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]
+
+    # Apply attention mask if provided
+    if attn_mask is not None:
+        # Convert boolean mask to float if needed
+        attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
+        attn_scores = attn_scores + attn_mask
+
+    # Apply causal mask if specified and only during the context phase
+    if is_causal and s_q == s_k:  # Only apply causal mask during context processing
+        causal_mask = torch.triu(
+            torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
+            diagonal=1,  # Use diagonal=1 for standard causal masking
+        )
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+    # Apply sliding window mask if specified
+    if sliding_window is not None and sliding_window > 0:
+        # Handle position calculation for both context and generation phases
+        if s_q == s_k:
+            # Context phase: standard position calculation
+            query_positions = torch.arange(s_q, device=query.device)
+            key_positions = torch.arange(s_k, device=query.device)
+        else:
+            # Generation phase: query is at position s_k (after the cache)
+            query_positions = torch.arange(s_k, s_k + s_q, device=query.device)  # [s_k] for s_q=1
+            key_positions = torch.arange(s_k, device=query.device)  # [0,1,2,...,s_k-1]
+
+        # Create position difference matrix: query_pos - key_pos
+        pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]
+
+        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
+        sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
+        attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

+
+    # Apply logit softcapping if enabled
+    attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+    # Apply sinks if provided
+    if sinks is not None:
+        # Concatenate sinks to attention scores following the reference implementation
+        # sinks should have n_heads elements, each head gets its own sink value
+        # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
+        sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
+            b, n_heads, s_q, 1
+        )  # [b, n_heads, s_q, 1]
+
+        # Concatenate along the key dimension (last dimension)
+        logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
+        sinks = torch.exp(sinks_expanded - logits_max)
+        unnormalized_scores = torch.exp(attn_scores - logits_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
+        scores = unnormalized_scores / normalizer
+        # Use only the non-sink portion for computing output
+        # We added exactly 1 column, so remove exactly 1 column
+        attn_out = torch.matmul(scores, value_t)  # [b, n_heads, s_q, v_head_dim]
+    else:
+        attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
+        attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]
 
-    return F.scaled_dot_product_attention(
-        query.contiguous(),
-        key.contiguous(),
-        value.contiguous(),
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=True,
-    )
+    # Apply dropout if specified
+    if dropout_p > 0.0:
+        attn_out = F.dropout(attn_out, p=dropout_p, training=False)
+
+    # Return in bnsd format (same as input format)
+    return attn_out
 
 
 @grouped_sdpa.register_fake
@@ -101,16 +199,19 @@ def grouped_sdpa_fake(
     dropout_p=0.0,
     is_causal=False,
     scale=None,
+    sinks=None,
+    sliding_window=None,
+    logit_cap=None,
 ):
     """Fake implementation of grouped SDPA."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
 
 
 @torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
 def bsnd_grouped_sdpa(
-    query: torch.Tensor,  # layout: [b, n, s_q, d]
-    key: torch.Tensor,  # layout: [b, n, s_k, d]
-    value: torch.Tensor,  # layout: [b, n, s_k, d]
+    query: torch.Tensor,  # layout: [b, s_q, n, d]
+    key: torch.Tensor,  # layout: [b, s_k, n, d]
+    value: torch.Tensor,  # layout: [b, s_k, n, d]
     attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
     dropout_p: float = 0.0,
     is_causal: bool = False,
@@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
     Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
     original sdpa op!
     """
-    # let's transpose to bnsd so we can use the grouped sdpa
-    query = query.transpose(1, 2).contiguous()
-    key = key.transpose(1, 2).contiguous()
-    value = value.transpose(1, 2).contiguous()
-
-    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
-
-    # let's transpose back to bnsd
+    # Transpose inputs to bnsd format for grouped_sdpa
+    query = query.transpose(1, 2).contiguous()  # [b, s_q, n, d] -> [b, n, s_q, d]
+    key = key.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
+    value = value.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
+
+    # Call grouped_sdpa with bnsd inputs
+    out = grouped_sdpa(
+        query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
+    )
+    # Transpose back to bsnd format
     return out.transpose(1, 2).contiguous()
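As a sanity check on the sink handling added to grouped_sdpa above, the following standalone sketch (illustrative only, not part of the custom op; tensor shapes are made up) confirms that the max-subtracted formulation matches concatenating one sink column per head, taking a regular softmax, and dropping that column. Softmax is shift-invariant, so subtracting the per-row max of the non-sink scores instead of the overall max does not change the result.

# Illustrative check: the sink math in grouped_sdpa is equivalent to appending one
# extra "sink" logit per head, running a plain softmax, and dropping the sink column.
import torch

b, n_heads, s_q, s_k = 2, 4, 3, 5
attn_scores = torch.randn(b, n_heads, s_q, s_k)
sinks = torch.randn(n_heads)  # one sink logit per head

# Path 1: formulation used in the diff (max-subtracted exponentials, sink in the normalizer)
sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(b, n_heads, s_q, 1)
logits_max = attn_scores.max(dim=-1, keepdim=True).values
sink_exp = torch.exp(sinks_expanded - logits_max)
unnormalized = torch.exp(attn_scores - logits_max)
scores_1 = unnormalized / (unnormalized.sum(dim=-1, keepdim=True) + sink_exp)

# Path 2: concatenate the sink column, softmax, then drop the sink column
logits = torch.cat([attn_scores, sinks_expanded], dim=-1)
scores_2 = torch.softmax(logits, dim=-1)[..., :-1]

print(torch.allclose(scores_1, scores_2, atol=1e-6))  # True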

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py

Lines changed: 6 additions & 8 deletions
@@ -103,7 +103,7 @@ def _torch_generate_mha(
     # Apply sinks if provided (following the model file pattern)
     if sinks is not None:
         # Concatenate sinks to attention scores
-        sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
+        sinks = sinks.reshape(-1, 1, 1)
         attn_weights = torch.cat([attn_scores, sinks], dim=-1)
         attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
         # Use only the non-sink portion for computing output (ignore sinks)
@@ -202,9 +202,7 @@ def _torch_context_mha(
         )  # [seq_len_i, kv_seq_len]
 
         # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
-        sliding_window_mask = (pos_diff < 0) | (
-            pos_diff >= sliding_window_size
-        )  # [seq_len_i, kv_seq_len]
+        sliding_window_mask = pos_diff >= sliding_window_size
 
         # Combine causal and sliding window masks
         combined_mask = causal_mask | sliding_window_mask
@@ -219,14 +217,14 @@ def _torch_context_mha(
         # Apply sinks if provided (following the model file pattern)
         if sinks is not None:
             # Concatenate sinks to attention scores
-            sinks = sinks.reshape(1, -1, 1, 1).expand(
-                attn_scores.shape[0], -1, attn_scores.shape[-2], -1
+            new_sinks = sinks.reshape(1, -1, 1, 1).expand(
+                attn_scores.shape[0], -1, attn_scores.shape[2], 1
             )
-            attn_weights = torch.cat([attn_scores, sinks], dim=-1)
+            attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
             attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
             # Use only the non-sink portion for computing output (ignore sinks)
             attn_out = torch.matmul(
-                attn_weights[..., : -sinks.size(-1)], v_seq_t
+                attn_weights[..., : -new_sinks.size(-1)], v_seq_t
             )  # [1, n_heads, seq_len_i, v_head_dim]
         else:
             attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
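
The simplified sliding-window mask above drops the (pos_diff < 0) term; this is safe because the mask is OR-combined with a causal mask that already blocks every key ahead of the query. A minimal sketch of that argument (illustrative positions, not the op's actual indexing):

# Illustrative only: dropping (pos_diff < 0) from the sliding-window mask does not
# change the combined mask, assuming the causal mask already blocks pos_diff < 0.
import torch

seq_len, kv_seq_len, sliding_window_size = 4, 6, 3
query_pos = torch.arange(kv_seq_len - seq_len, kv_seq_len)  # queries at the end of the KV range
key_pos = torch.arange(kv_seq_len)
pos_diff = query_pos.unsqueeze(1) - key_pos.unsqueeze(0)  # [seq_len, kv_seq_len]

causal_mask = pos_diff < 0  # True = blocked (key is ahead of query)
old_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size)
new_window_mask = pos_diff >= sliding_window_size

print(torch.equal(causal_mask | old_window_mask, causal_mask | new_window_mask))  # True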

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,8 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
     rank, world_size = get_rank_world_size()
     assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
     p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
+    # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
+    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
     return torch_op(tensor, all_reduce_params=all_reduce_params)
 
 
 @torch.library.custom_op(

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 23 additions & 4 deletions
@@ -76,14 +76,24 @@ class AutoModelForCausalLMFactory(ModelFactory):
         "max_position_embeddings": 1024,
     }
 
+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+        }
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         self._quant_config: Optional[Dict] = None
 
         # Ingest defaults for tokenizer and model kwargs
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
-        self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
+        self.model_kwargs = deep_merge_dicts(
+            self._model_defaults,
+            self.model_kwargs,
+            self._get_max_position_embeddings_config(),
+        )
 
         # special handling for torch_dtype in model_kwargs since HF does not correctly update
         # torch_dtype string to an actual torch.dtype object (only with default)
@@ -295,7 +305,7 @@ def _prefetch_checkpoint(self, model_name_or_path: str, skip_prefetch_weights: b
         # at this point it should be a directory (either the original one or the download dir)
         assert os.path.isdir(fetched_dir), f"Checkpoint path {fetched_dir} is not a directory."
 
-        self._load_quantization_config()
+        self._load_quantization_config(fetched_dir)
 
         return fetched_dir
 
@@ -313,13 +323,13 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
         # model-transformed weights,leading to unexpected key mismatches or format issues.
         load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
 
-    def _load_quantization_config(self):
+    def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
         if self._quant_config is not None:
             return
 
         assert self.model
-        hf_quant_config_file = os.path.join(self.model, "hf_quant_config.json")
+        hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json")
         if os.path.exists(hf_quant_config_file):
             with open(hf_quant_config_file, "r") as file:
                 quantization_config = json.load(file)
@@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
         },
     }
 
+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+            "text_config": {
+                "max_position_embeddings": self.max_seq_len,
+            },
+        }
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
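
The change above relies on deep_merge_dicts giving later arguments precedence and merging nested dicts recursively, so the max_seq_len-derived entry wins over both the defaults and user-supplied model_kwargs, including the nested text_config for image-text models. A toy stand-in (not the repo's actual helper; names and values are illustrative):

# Illustrative stand-in for deep_merge_dicts: later dicts win, and nested dicts such
# as "text_config" are merged key-by-key rather than replaced wholesale.
from typing import Any, Dict


def toy_deep_merge(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    merged: Dict[str, Any] = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = toy_deep_merge(merged[key], value)  # recurse into nested configs
            else:
                merged[key] = value
    return merged


model_defaults = {"max_position_embeddings": 1024, "text_config": {"use_cache": False}}
user_kwargs = {"text_config": {"torch_dtype": "bfloat16"}}
max_seq_len_cfg = {"max_position_embeddings": 4096, "text_config": {"max_position_embeddings": 4096}}

print(toy_deep_merge(model_defaults, user_kwargs, max_seq_len_cfg))
# {'max_position_embeddings': 4096,
#  'text_config': {'use_cache': False, 'torch_dtype': 'bfloat16', 'max_position_embeddings': 4096}}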

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 38 additions & 14 deletions
@@ -227,18 +227,26 @@ def __call__(
         # run or skip the transform
         if self.config.enabled:
             # run graph pre-cleanup
-            self._run_pre_cleanup(gm, info_last)
-
-            # run the transform in a error-handling wrapper
-            try:
-                gm, info = self._apply(gm, cm, factory)
-            except Exception as e:
-                error_msg = f"Transform {t_name} failed"
-                if self.config.skip_on_error:
+            is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last)
+
+            # run the transform in a error-handling wrapper if desired
+            if self.config.skip_on_error:
+                try:
+                    gm, info = self._apply(gm, cm, factory)
+                except Exception as e:
+                    error_msg = f"Transform {t_name} failed"
                     ad_logger.warning(f"{error_msg}: {e}")
                     info = TransformInfo(skipped=True, num_matches=0)
-                else:
-                    raise TransformError(error_msg) from e
+            else:
+                # handle this here normally to improve debugging and error message
+                gm, info = self._apply(gm, cm, factory)
+
+            # we cannot say it's clean if the previous wasn't clean even if this one is
+            # create new info object with updated cleanup status
+            info_dict = info.model_dump()
+            info_dict["is_clean"] &= is_clean_pre
+            info_dict["has_valid_shapes"] &= has_valid_shapes_pre
+            info = TransformInfo(**info_dict)
 
             # run graph post-cleanup
             info = self._run_post_cleanup(gm, info)
@@ -279,20 +287,36 @@ def _set_autodeploy_meta(self, gm: GraphModule, autodeploy_meta: AutodeployMeta)
         gm.meta[self._autodeploy_meta_key] = autodeploy_meta
 
     @final
-    def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> None:
+    def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]:
         """Run graph cleanup before the transform.
 
+        Args:
+            gm: The graph module to run cleanup on.
+            info: The last transform info.
+
+        Returns:
+            A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the
+            pre-cleanup.
+
         This is used to ensure the transform is applied to a clean graph as needed by the transform.
         """
         if not self.config.requires_clean_graph:
-            return
+            return info.is_clean, info.has_valid_shapes
+
+        is_clean = info.is_clean
+        has_valid_shapes = is_clean and info.has_valid_shapes
 
         # check if run cleanup depending on the config and info
-        if self.config.requires_shape_prop and not (info.is_clean and info.has_valid_shapes):
+        if self.config.requires_shape_prop and not has_valid_shapes:
             with lift_to_meta(gm):
                 canonicalize_graph(gm, shape_prop=True)
-        elif self.config.requires_clean_graph and not info.is_clean:
+            is_clean = True
+            has_valid_shapes = True
+        elif self.config.requires_clean_graph and not is_clean:
             canonicalize_graph(gm)
+            is_clean = True
+
+        return is_clean, has_valid_shapes
 
     @final
     def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo:
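
To make the flag propagation in __call__ concrete, here is a self-contained sketch (a toy TransformInfo assuming a pydantic v2-style model, not the real class) of how the pre-cleanup status is AND-ed into the info a transform reports:

# Illustrative only: a transform can only claim a clean graph / valid shapes if the
# graph was already clean going in, hence the &= with the pre-cleanup flags.
from pydantic import BaseModel


class ToyTransformInfo(BaseModel):  # stand-in for the real TransformInfo model
    skipped: bool = False
    num_matches: int = 0
    is_clean: bool = False
    has_valid_shapes: bool = False


def fold_pre_cleanup(info: ToyTransformInfo, is_clean_pre: bool, has_valid_shapes_pre: bool) -> ToyTransformInfo:
    info_dict = info.model_dump()
    info_dict["is_clean"] &= is_clean_pre
    info_dict["has_valid_shapes"] &= has_valid_shapes_pre
    return ToyTransformInfo(**info_dict)


info = ToyTransformInfo(num_matches=3, is_clean=True, has_valid_shapes=True)
print(fold_pre_cleanup(info, is_clean_pre=False, has_valid_shapes_pre=False))
# skipped=False num_matches=3 is_clean=False has_valid_shapes=False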
