🐛 [Bug] CudaGraph cannot work with module with graph breaks #3755

@cehongwang

Description

Repro:

```python
import torch
import torch_tensorrt


class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# The 'torch_executed_ops' compiler option is used here to intentionally
# introduce a graph break within the module.
# Note: the Dynamo backend is required for the CUDA Graph context manager
# to handle modules in an Ahead-Of-Time (AOT) manner.
opt_with_graph_break = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    min_block_size=1,
    pass_through_build_failures=True,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
    use_python_runtime=True,
)

with torch_tensorrt.runtime.enable_cudagraphs(
    opt_with_graph_break
) as cudagraphs_module:
    out = cudagraphs_module(input)

    # Check output against eager PyTorch for correctness
    torch_out = model(input)
    print("TRT output:", out)
    print("PyTorch output:", torch_out)
    print("Difference:", (out - torch_out).abs().max())
```
The TRT output under the CUDA graphs context manager is all zeros, so the printed difference against eager PyTorch is nonzero.
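One plausible mechanism for an all-zeros result (an assumption, not a confirmed root cause from this issue): a captured CUDA graph replays against fixed static input/output buffers, so when a graph break splits the module, the torch-executed segment's fresh output must be copied into the next captured segment's static input buffer before replay. If that copy is skipped, the replay keeps reading its zero-initialized placeholder. A minimal stdlib-only toy sketch of that failure mode, with hypothetical names (`CapturedSegment`, `run_buggy`, `run_fixed`) that are not torch_tensorrt runtime internals:

```python
# Toy illustration of a stale-static-buffer failure mode (an assumed
# mechanism for this bug, not actual torch_tensorrt runtime code).

class CapturedSegment:
    """Mimics a CUDA-graph-captured segment: replay always reads and
    writes the same static buffers that existed at capture time."""

    def __init__(self, fn):
        self.fn = fn
        self.static_input = 0.0   # zero-initialized placeholder buffer
        self.static_output = 0.0

    def copy_input(self, value):
        # Explicit host-side copy into the static input buffer.
        self.static_input = value

    def replay(self):
        self.static_output = self.fn(self.static_input)
        return self.static_output


def torch_executed_mul(x):
    """The graph-break segment that runs in eager PyTorch (here: * 0.5)."""
    return x * 0.5


def make_segments():
    # Two "captured" TRT segments around the break: (x + 2) and relu(...).
    return CapturedSegment(lambda x: x + 2), CapturedSegment(lambda x: max(x, 0.0))


def run_buggy(x):
    seg_add, seg_relu = make_segments()
    seg_add.copy_input(x)
    mid = torch_executed_mul(seg_add.replay())
    # BUG: the fresh intermediate `mid` is never copied into seg_relu's
    # static input buffer, so replay reads the stale zero placeholder.
    return seg_relu.replay()


def run_fixed(x):
    seg_add, seg_relu = make_segments()
    seg_add.copy_input(x)
    mid = torch_executed_mul(seg_add.replay())
    seg_relu.copy_input(mid)  # FIX: propagate the value across the graph break
    return seg_relu.replay()


print(run_buggy(4.0))  # 0.0 -- all zeros, matching the reported symptom
print(run_fixed(4.0))  # 3.0 == relu((4 + 2) * 0.5)
```

With the copy across the break in place, the toy pipeline reproduces `relu((x + 2) * 0.5)`; without it, every replay returns the placeholder zeros.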
