Status: Open
Labels: bug
Bug Description
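When NVFP4 quantization is applied to LLMs with dynamic sequence lengths, Torch-TensorRT compilation fails because the `dynamic_block_quantize_op` converter does not support an inferred dimension (`-1`) in the last dimension of its input shape. In the log below, `_reshape_copy_3` reshapes `(1, s87, 32, 64)` with the target shape `[1, sym_size_int_3, -1]`. Converting the next node, `o_proj.input_quantizer/dynamic_block_quantize_op_2`, then fails because the extent of the blocked axis (axis 2) is not known at build time.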
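A minimal sketch of why the `-1` is a problem, assuming a simplified build-time shape model in which an int is a static extent, `None` is a dynamic extent, and `-1` means "infer". The helper `infer_static_reshape` is hypothetical and is not how Torch-TensorRT or TensorRT actually resolve shapes; it only shows that, without symbolic reasoning, a `-1` whose value depends on a dynamic extent cannot be turned into a concrete number at build time.

```python
from math import prod
from typing import Optional, Sequence

# Simplified build-time shape model (an assumption for illustration):
# an int is a static extent, None is a dynamic extent, -1 means "infer".
Dim = Optional[int]


def infer_static_reshape(src: Sequence[Dim], target: Sequence[Dim]) -> list:
    """Resolve the single -1 in `target` using only build-time-known extents.

    Returns `target` with -1 replaced by a concrete extent, or by None when
    that extent depends on a dynamic dimension.
    """
    out = list(target)
    if out.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    if -1 not in out:
        return out
    idx = out.index(-1)
    others = [d for i, d in enumerate(out) if i != idx]
    # No symbolic algebra here: any dynamic extent in the source or in the
    # other target dims makes the inferred extent unknown at build time.
    if None in src or None in others:
        out[idx] = None
        return out
    total, rest = prod(src), prod(others)
    if total % rest != 0:
        raise ValueError("target shape is incompatible with the element count")
    out[idx] = total // rest
    return out


# Static sequence length: -1 resolves to 32 * 64 = 2048.
print(infer_static_reshape([1, 10, 32, 64], [1, 10, -1]))        # [1, 10, 2048]
# Dynamic sequence length (s87 in the log): the last extent is unknown.
print(infer_static_reshape([1, None, 32, 64], [1, None, -1]))    # [1, None, None]
# Explicit last extent: stays static even with a dynamic sequence length.
print(infer_static_reshape([1, None, 32, 64], [1, None, 2048]))  # [1, None, 2048]
```

The third call hints at a possible workaround, not verified here against Torch-TensorRT: spell out the last extent (32 * 64 = 2048) instead of using `-1`.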
Debug log:

```
12:37:35 - INFO - Converted node /_reshape_copy_3 [aten._reshape_copy.default] (Inputs: (permute_2: (1, s87, 32, 64)@torch.float32, [1, sym_size_int_3, -1]) | Outputs: (_reshape_copy_3: (1, s87, 2048)@torch.float32))
12:37:35 - DEBUG - Converting node o_proj_input_quantizer__amax (kind: o_proj.input_quantizer._amax, args: ())
12:37:35 - INFO - Converted node o_proj_input_quantizer__amax [o_proj.input_quantizer._amax] (Inputs: () | Outputs: (o_proj_input_quantizer__amax: ()@torch.float32))
12:37:35 - DEBUG - Converting node o_proj.input_quantizer/dynamic_block_quantize_op_2 (kind: tensorrt.dynamic_block_quantize_op.default, args: ('_reshape_copy_3 <Node>', '16 <int>', 'o_proj_input_quantizer__amax <Node>', '4 <int>', '2 <int>', '8 <int>', '4 <int>'))
12:37:35 - DEBUG - Converter options for tensorrt.dynamic_block_quantize_op.default: 1
12:37:35 - DEBUG - Selecting converter option 0 for converting tensorrt.dynamic_block_quantize_op.default
12:37:35 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error ([DYNAMIC_QUANTIZE]-[aten_ops.dynamic_block_quantize_op.default]-[o_proj.input_quantizer/dynamic_block_quantize_op_2_dynamic_quantize] The input extent in the blocked axis should be known at build time. axis = 2.)
12:37:35 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [DYNAMIC_QUANTIZE]-[aten_ops.dynamic_block_quantize_op.default]-[o_proj.input_quantizer/dynamic_block_quantize_op_2_dynamic_quantize].)
12:37:35 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [DEQUANTIZE]-[aten_ops.dynamic_block_quantize_op.default]-[o_proj.input_quantizer/dynamic_block_quantize_op_2_dequantize_scale].)
```
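The first `ERROR` line states the constraint directly: the blocked axis (axis 2, the last axis of `_reshape_copy_3`) must have a build-time extent. Below is a minimal sketch of that check and of the per-block scale shape it makes computable. It assumes that the `16` argument in the converter call is the block size (consistent with NVFP4's 16-element blocks) and that the blocked extent must be a multiple of it; both are readings of the log, not documented converter behavior. `check_blocked_axis` is a hypothetical helper.

```python
from typing import Optional, Sequence


def check_blocked_axis(shape: Sequence[Optional[int]], axis: int, block_size: int) -> list:
    """Return the per-block scale shape for quantizing `shape` along `axis`.

    None marks an extent that is unknown at build time. Raises ValueError when
    the blocked extent is unknown or (assumed requirement) not a multiple of
    `block_size`.
    """
    extent = shape[axis]
    if extent is None:
        raise ValueError(f"blocked axis {axis} has no build-time extent")
    if extent % block_size != 0:
        raise ValueError(f"extent {extent} is not a multiple of block size {block_size}")
    scale_shape = list(shape)
    scale_shape[axis] = extent // block_size  # one scale per block_size elements
    return scale_shape


BLOCK_SIZE = 16  # the first int argument in the converter call in the log

# Failing case: the last extent came from a -1 and is unknown at build time.
try:
    check_blocked_axis([1, None, None], axis=2, block_size=BLOCK_SIZE)
except ValueError as err:
    print(err)  # blocked axis 2 has no build-time extent

# Static last extent: 2048 // 16 = 128 scales per row.
print(check_blocked_axis([1, None, 2048], axis=2, block_size=BLOCK_SIZE))  # [1, None, 128]
```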
To Reproduce
Steps to reproduce the behavior: run the following minimal reproducible example.
```python
import torch
import torch_tensorrt as torchtrt
import torch.nn as nn
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

dtype = torch.float32


class SimpleNetwork(torch.nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()
        self.hidden_size = 2048
        num_attention_heads = 32
        head_dim = self.hidden_size // num_attention_heads
        self.head_dim = head_dim
        self.q_proj = nn.Linear(
            self.hidden_size, num_attention_heads * head_dim
        )
        self.o_proj = nn.Linear(
            num_attention_heads * head_dim, self.hidden_size
        )

    def forward(self, hidden_states):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states = query_states.transpose(1, 2).contiguous()
        # input shape: [1, s, 32, 64]
        # reshape to: [1, s, -1]
        attn_output = query_states.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output


def calibrate_loop(model):
    model(dummy_inputs)


SEQ_SIZE = torch.export.Dim("SEQ_SIZE", min=1, max=128)
# dim_embed = 2048, seq_len = 10
dummy_inputs = torch.ones(1, 10, 2048, dtype=dtype).cuda()
model = SimpleNetwork().eval().cuda().to(dtype)
quant_cfg = mtq.NVFP4_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has qdq nodes at this point
with torch.no_grad():
    with export_torch_mode():
        exp_program = torch.export._trace._export(
            model, (dummy_inputs,), strict=False, dynamic_shapes=({1: SEQ_SIZE},), allow_complex_guards_as_runtime_asserts=True,
        )
        print(exp_program)
        with torchtrt.dynamo.Debugger(
            "debug",
            engine_builder_monitor=False,
        ):
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[dummy_inputs],
                min_block_size=1,
                cache_built_engines=False,
                reuse_cached_engines=False,
            )
```
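In this repro the inferred last extent does not actually depend on the sequence length: for every `s` allowed by `SEQ_SIZE` (1 to 128), reshaping `[1, s, 32, 64]` to `[1, s, -1]` gives `32 * 64 = 2048`. The pure-Python check below (no torch needed; `inferred_last_dim` is a hypothetical helper) confirms this over the whole range. It suggests a possible workaround that has not been verified here: write the last extent explicitly, e.g. `query_states.reshape(*input_shape, self.hidden_size)`, so that it may stay static after export.

```python
from math import prod


def inferred_last_dim(shape, leading):
    """Extent that -1 resolves to when reshaping `shape` to (*leading, -1)."""
    total, rest = prod(shape), prod(leading)
    if total % rest != 0:
        raise ValueError("incompatible reshape")
    return total // rest


NUM_HEADS, HEAD_DIM = 32, 64  # values from SimpleNetwork
SEQ_MIN, SEQ_MAX = 1, 128     # bounds of the SEQ_SIZE export Dim

# Collect every value the -1 takes across the dynamic sequence range.
last_dims = {
    inferred_last_dim((1, s, NUM_HEADS, HEAD_DIM), (1, s))
    for s in range(SEQ_MIN, SEQ_MAX + 1)
}
print(last_dims)  # {2048}
```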
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages.
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context