Description
Repro:

```python
import torch
import torch_tensorrt


class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# The 'torch_executed_ops' compiler option is used in this example to
# intentionally introduce graph breaks within the module.
# Note: The Dynamo backend is required for the CUDA Graph context manager
# to handle modules in an Ahead-Of-Time (AOT) manner.
opt_with_graph_break = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    min_block_size=1,
    pass_through_build_failures=True,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
    use_python_runtime=True,
)

with torch_tensorrt.runtime.enable_cudagraphs(
    opt_with_graph_break
) as cudagraphs_module:
    out = cudagraphs_module(input)

# Check output with PyTorch for correctness
torch_out = model(input)
print("TRT output:", out)
print("PyTorch output:", torch_out)
print("Difference:", (out - torch_out).abs().max())
```
The TRT output is all zeros instead of matching the PyTorch reference output.
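One pitfall worth ruling out when validating CUDA Graph outputs (independent of whether it is the root cause here): a graphed callable can return a tensor backed by an internal buffer that is overwritten on the next replay, so outputs should be cloned before comparison. The comparison pattern itself can be sketched on CPU with plain PyTorch — a minimal sketch with no TensorRT involved; the tolerance and the all-zeros check are illustrative assumptions, not part of the original repro:

```python
import torch


class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = SampleModel().eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    reference = model(x)
    # .clone() matters when the producer reuses its output buffer
    # (as CUDA Graph replay can); it is harmless for eager execution.
    candidate = model(x).clone()

# A healthy accelerated path should match eager within tolerance.
assert torch.allclose(candidate, reference, atol=1e-5), "outputs diverge"
# The bug reported above would trip this check instead: an all-zeros output.
assert candidate.abs().max().item() > 0, "output is all zeros"
print("max abs difference:", (candidate - reference).abs().max().item())
```

In the real repro, the equivalent check would be cloning `out` inside the `enable_cudagraphs` context before comparing it against `model(input)`.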