
🐛 [Bug] Compilation failure with NVFP4 Quantization with dynamic shapes #3745

@keehyuna

Bug Description

When applying NVFP4 quantization to LLM models with dynamic sequence lengths, Torch-TensorRT compilation fails because the dynamic_block_quantize_op converter does not support an inferred dimension (-1) in the last (blocked) axis of its input shape. Log excerpt from the failing compile:

12:37:35 - INFO - Converted node /_reshape_copy_3 [aten._reshape_copy.default] (Inputs: (permute_2: (1, s87, 32, 64)@torch.float32, [1, sym_size_int_3, -1]) | Outputs: (_reshape_copy_3: (1, s87, 2048)@torch.float32))
12:37:35 - DEBUG - Converting node o_proj_input_quantizer__amax (kind: o_proj.input_quantizer._amax, args: ())
12:37:35 - INFO - Converted node o_proj_input_quantizer__amax [o_proj.input_quantizer._amax] (Inputs: () | Outputs: (o_proj_input_quantizer__amax: ()@torch.float32))
12:37:35 - DEBUG - Converting node o_proj.input_quantizer/dynamic_block_quantize_op_2 (kind: tensorrt.dynamic_block_quantize_op.default, args: ('_reshape_copy_3 <Node>', '16 <int>', 'o_proj_input_quantizer__amax <Node>', '4 <int>', '2 <int>', '8 <int>', '4 <int>'))
12:37:35 - DEBUG - Converter options for tensorrt.dynamic_block_quantize_op.default: 1
12:37:35 - DEBUG - Selecting converter option 0 for converting tensorrt.dynamic_block_quantize_op.default
12:37:35 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error ([DYNAMIC_QUANTIZE]-[aten_ops.dynamic_block_quantize_op.default]-[o_proj.input_quantizer/dynamic_block_quantize_op_2_dynamic_quantize] The input extent in the blocked axis should be known at build time. axis = 2.)
12:37:35 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [DYNAMIC_QUANTIZE]-[aten_ops.dynamic_block_quantize_op.default]-[o_proj.input_quantizer/dynamic_block_quantize_op_2_dynamic_quantize].)
12:37:35 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [DEQUANTIZE]-[aten_ops.dynamic_block_quantize_op.default]-[o_proj.input_quantizer/dynamic_block_quantize_op_2_dequantize_scale].)
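
Note that the shape arithmetic is static even though the trace carries a -1: axis = 2 in the error is the flattened head dimension, and with 32 heads of head_dim 64 it always resolves to 2048. A quick eager-mode check (a standalone snippet for illustration, not part of the repro):

import torch

x = torch.randn(1, 10, 32, 64)   # [batch, seq, heads, head_dim]
y = x.reshape(1, 10, -1)         # the -1 always resolves to 32 * 64 = 2048
print(y.shape)                   # torch.Size([1, 10, 2048])

TensorRT requires exactly this extent on the blocked axis at build time, so the converter would need to fold the -1 to 2048 rather than emit a wildcard dimension.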

To Reproduce

Steps to reproduce the behavior:

Minimal reproducible example:

import torch
import torch_tensorrt as torchtrt
import torch.nn as nn
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

dtype = torch.float32

class SimpleNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = 2048
        num_attention_heads = 32
        head_dim = self.hidden_size // num_attention_heads  # 64
        self.head_dim = head_dim
        self.q_proj = nn.Linear(self.hidden_size, num_attention_heads * head_dim)
        self.o_proj = nn.Linear(num_attention_heads * head_dim, self.hidden_size)

    def forward(self, hidden_states):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states = query_states.transpose(1, 2).contiguous()
        # input shape: [1, s, 32, 64]
        # reshape to [1, s, -1]; the inferred -1 triggers the converter error
        attn_output = query_states.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output

def calibrate_loop(model):
    # One calibration pass over the dummy input defined below.
    model(dummy_inputs)

SEQ_SIZE = torch.export.Dim("SEQ_SIZE", min=1, max=128)

# dim_embed = 2048, seq_len = 10
dummy_inputs = torch.ones(1, 10, 2048, dtype=dtype).cuda()

model = SimpleNetwork().eval().cuda().to(dtype)

quant_cfg = mtq.NVFP4_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has qdq nodes at this point
with torch.no_grad():
    with export_torch_mode():
        exp_program = torch.export._trace._export(
            model, (dummy_inputs,), strict=False, dynamic_shapes=({1: SEQ_SIZE},), allow_complex_guards_as_runtime_asserts=True,
        )
        print(exp_program)
        with torchtrt.dynamo.Debugger(
            "debug",
            engine_builder_monitor=False,
        ):
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[dummy_inputs],
                min_block_size=1,
                cache_built_engines=False,
                reuse_cached_engines=False,
            )
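
A possible workaround at the model level, assuming the flattened extent is statically known (as it is here, hidden_size = 2048), is to pass the explicit size to reshape instead of -1 so the exported graph carries a static last dimension. A minimal sketch of the changed forward (the general fix would still be for the converter to resolve statically derivable -1 dimensions):

    def forward(self, hidden_states):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states = query_states.transpose(1, 2).contiguous()
        # Pass the statically known extent instead of -1 so the blocked axis
        # (axis = 2) is fixed at build time; only the sequence axis stays dynamic.
        attn_output = query_states.reshape(*input_shape, self.hidden_size).contiguous()
        return self.o_proj(attn_output)

This should sidestep the converter error for this model, since only the sequence axis remains dynamic in the exported graph.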

Expected behavior

Compilation succeeds: the dynamic_block_quantize_op converter resolves the inferred (-1) last dimension to its statically known extent (2048 here) and builds an engine that supports dynamic sequence lengths.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
