Skip to content

Commit 9e171ba

Browse files
committed
update prelu
1 parent 90fcd2b commit 9e171ba

File tree

1 file changed

+8
-1
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+8
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/prelu.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
5+
from torch_tensorrt.dynamo.conversion import impl
56
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
67
from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name
78
from torch_tensorrt.dynamo.types import TRTTensor
@@ -15,6 +16,12 @@ def prelu(
1516
input: TRTTensor,
1617
weight: TRTTensor,
1718
) -> TRTTensor:
18-
layer = ctx.net.add_parametric_relu(input, weight)
19+
# TRT requires that the slopes tensor be unidirectionally broadcastable to the input tensor:
20+
# the rank of the two tensors must be the same, and all dimensions of the slopes tensor must
21+
# either equal the corresponding dimension of the input tensor or be 1. The output tensor has the same shape as the input tensor.
22+
input, weight = impl.elementwise.broadcast(
23+
ctx, input, weight, f"{name}_expand_input", f"{name}_expand_weight"
24+
)
25+
layer = ctx.net.add_parametric_relu(input, slopes=weight)
1926
set_layer_name(layer, target, name, source_ir)
2027
return layer.get_output(0)

0 commit comments

Comments
 (0)