18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -137,6 +137,24 @@ def aten_ops_sigmoid(
)


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
Collaborator @gs-olive commented on Oct 3, 2023:

Consider adding @enforce_tensor_types({0: (TRTTensor,)}) to ensure the input is a TRTTensor.
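A minimal sketch of how that suggestion could look on the converter below (assuming enforce_tensor_types is importable from the dynamo conversion utilities and accepts a dict mapping argument positions to accepted types, as the comment implies):

# Hedged sketch only; the decorator name, module path, and behavior are taken from the
# review comment, and the imports already present in aten_ops_converters.py are reused.
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@enforce_tensor_types(
    {0: (TRTTensor,)}
)  # require args[0] (the tensor being indexed) to already be a TRTTensor
def aten_ops_index(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.select.index(ctx, target, SourceIR.ATEN, name, args[0], args[1])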

def aten_ops_index(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.index(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
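
For reference, a small eager-mode sketch of the indexing pattern this converter handles; the lowering claim is an assumption inferred from how the index argument (a list with None for ":" dimensions) is consumed in impl.select.index:

import torch

x = torch.randn(10, 20, 30)
idx = torch.tensor([0, 2, 4])
# x[:, idx] typically lowers to torch.ops.aten.index.Tensor(x, [None, idx]),
# i.e. args[0] is x and args[1] is the list of optional index tensors.
y = x[:, idx]
assert y.shape == (10, 3, 30)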


@dynamo_tensorrt_converter(torch.ops.aten.tanh.default) # type: ignore[misc]
def aten_ops_tanh(
ctx: ConversionContext,
294 changes: 291 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -1,14 +1,27 @@
from typing import Optional, cast
import logging
from typing import Optional, Sequence, Union, cast

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy
from torch_tensorrt.dynamo.conversion.converter_utils import (
broadcastable,
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
Collaborator:

Switch get_positive_dim and to_numpy to the torch_tensorrt.dynamo.conversion.converter_utils versions.
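A sketch of the import block with that suggestion applied (assuming both helpers are still exported from torch_tensorrt.dynamo.conversion.converter_utils, as the import removed at the top of this hunk suggests):

from torch_tensorrt.dynamo.conversion.converter_utils import (
    broadcastable,
    get_positive_dim,
    get_trt_tensor,
    to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
)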

from torch_tensorrt.fx.types import Shape, TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)


def select(
ctx: ConversionContext,
@@ -59,3 +72,278 @@ def select(
if len(out.shape) != 1:
layer = ctx.net.add_shuffle(out)
return layer.get_output(0)


def index(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# _LOGGER.debug(f"The index shape is {index.shape}")
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)

# every index must be broadcastable against the others; this is checked as they are converted below
input = get_trt_tensor(ctx, input, name + "_input_to_tensor")

last_index = None
for i, ind in enumerate(index):
if ind is not None:
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
adv_indx_indices.append(i)
# convert torch.nn.parameter.Parameter / torch.Tensor to a TRTTensor
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
assert broadcastable(
ind, last_index
), "The indices should be broadcastable!"
last_index = ind
tensor_indices.append(ind)

if not tensor_indices:
identity_layer = ctx.net.add_identity(input)
identity_layer.set_output_type(0, trt.int32)
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
elif len(tensor_indices) == 1:
# A single advanced index only needs a gather along that axis
indices_tensor = tensor_indices[0]
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index positions are {adv_indx_indices}")
gather_layer = ctx.net.add_gather(input, indices_tensor, index)
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
else:
input_shape = input.shape
_LOGGER.debug(f"The input shape is {input.shape}")
if dynamic_shape:
input_shape = get_shape_with_dynamic_shape(
ctx.net, target, source_ir, name, input_shape, input
)
rank = len(input_shape)
adv_indx_count = len(adv_indx_indices)
dim_tensor_list = []

for i in range(rank):
dim = input_shape[i]
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
# collect each dimension size as a TRT tensor
dim_tensor_list.append(dim_tensor)

# Transpose the advanced-indexed axes to the front:
# t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
# where t is a tensor of rank m+n, {x_i} are the axes with a tensor index, and
# {y_i} are the axes indexed with ":".
# Example: x.shape = (10, 20, 30, 40, 50), ind_1 and ind_2 broadcast to shape (2, 3, 4)
#   x[:, ind_1, ind_2].shape    == (10, 2, 3, 4, 40, 50)
#   x[:, ind_1, :, ind_2].shape == (2, 3, 4, 10, 30, 50)
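# e.g., with adv_indx_indices = [1, 3] on a rank-5 input, new_order below becomes (1, 3, 0, 2, 4)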
transpose_layer = ctx.net.add_shuffle(input)
new_order = []
for i in range(adv_indx_count):
new_order.append(adv_indx_indices[i])
for i in range(rank):
if i not in adv_indx_indices:
new_order.append(i)
_LOGGER.debug(f"The new transpose order is {new_order}")
transpose_layer.second_transpose = tuple(new_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)

# Flatten [x_1, ..., x_m, y_1, ..., y_n] into [x_1 * ... * x_m, y_1 * ... * y_n]
# transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
transpose_tensor_shape = transpose_tensor.shape
_LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")
mult_d0 = 1
for i in range(adv_indx_count):
mult_d0 = mult_d0 * transpose_tensor_shape[i]
mult_d1 = 1
for i in range(adv_indx_count, rank):
mult_d1 = mult_d1 * transpose_tensor_shape[i]

concat_tensor_layer = ctx.net.add_concatenation(
[
get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
]
)
set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
concat_tensor = concat_tensor_layer.get_output(0)

reshape_layer = ctx.net.add_shuffle(transpose_tensor)
# reshape the transposed tensor to the 2-D shape (mult_d0, mult_d1) computed above
reshape_layer.set_input(1, concat_tensor)
flatten_tensor = reshape_layer.get_output(0)
_LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

# flattened index = \sum_{i=1}^{m} ind_i * \prod_{j=i+1}^{m} x_j, where ind_i is
# indices[i] and x_j is the size of the j-th advanced-indexed dimension of input x
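# e.g., for input shape (10, 20, 30) with advanced indices on dims 0 and 1, the loop
# below yields cum_adv_index = ind_1 + 20 * ind_0, gathered from the (200, 30) flattened view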
multiplier = get_trt_tensor(
ctx,
dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
name + "_dim_last",
)
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_intermediate_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
tensor_indices[i],
)
cum_adv_index = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_sum_intermediate_{i}",
trt.ElementWiseOperation.SUM,
cum_adv_index,
adv_index,
)
multiplier = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_intermediate_xj_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
dim_tensor_list[adv_indx_indices[i]],
)

gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
set_layer_name(
gather_layer_element, target, name + "_index_gather_element", source_ir
)
gather_out = gather_layer_element.get_output(0)
_LOGGER.debug(f"The shape after the cumulative gather is {gather_out.shape}")
_LOGGER.debug(f"The cumulative advanced index tensor is {cum_adv_index}")

cum_adv_index_shape_layer = ctx.net.add_shape(cum_adv_index)
set_layer_name(
cum_adv_index_shape_layer, target, name + "_cum_adv_index_shape", source_ir
)
cum_adv_index_shape_tensor = cum_adv_index_shape_layer.get_output(0)
cum_adv_index_shape = cum_adv_index.shape
_LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index_shape}")
# check if all advanced indices are consecutive
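# e.g., advanced indices on dims (1, 2, 3) are consecutive, while dims (1, 3) are not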
concat_tensor_reshape = []
if (
adv_indx_count
== adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
):
_LOGGER.debug("The advanced indices are consecutive in this case")
concat_tensor_reshape.append(
get_trt_tensor(ctx, -1, name + "_dynamic_concat")
)
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor_reshape.append(curr_dim)

concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
set_layer_name(
concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
)
concat_tensor = concat_tensor_layer.get_output(0)

regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
regular_index_shuffle_layer.set_input(1, concat_tensor)
set_layer_name(
regular_index_shuffle_layer,
target,
name + "_index_regular_index",
source_ir,
)
unfold_tensor = regular_index_shuffle_layer.get_output(0)
_LOGGER.debug("The tensor is now unfolded")
_LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

# Transpose the folded advanced-index axis back to its original location.
transpose_advanced_shuffle_layer = ctx.net.add_shuffle(unfold_tensor)
new_order = []
for i in range(1, adv_indx_indices[0] + 1):
new_order.append(i)
new_order.append(0)
for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count + 1):
new_order.append(i)
_LOGGER.debug(f"Transposing the indices to correct position {new_order}")

transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
set_layer_name(
transpose_advanced_shuffle_layer,
target,
name + "_index_advanced_shuffle_transpose",
source_ir,
)
transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

# Unfold the advanced-index axis back to the broadcast index shape
concat_final_tensor = []
for i in range(0, adv_indx_indices[0]):
current_dim = dim_tensor_list[i]
concat_final_tensor.append(current_dim)

concat_final_tensor.append(cum_adv_index_shape_tensor)
for i in range(adv_indx_indices[0], rank):
if i not in (adv_indx_indices):
current_dim = dim_tensor_list[i]
concat_final_tensor.append(current_dim)
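# e.g., for the (10, 20, 30, 40, 50) example with consecutive advanced dims 1 and 2 and a
# broadcast index shape of (2, 3, 4), this assembles the final shape (10, 2, 3, 4, 40, 50)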

concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
set_layer_name(
concat_final_shape_layer,
target,
name + "_index_continuous_concat_final_shape_layer",
source_ir,
)
concat_final_tensor = concat_final_shape_layer.get_output(0)

unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
# reshape the transposed tensor to the final output shape assembled above
unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
set_layer_name(
unfold_advanced_shuffle_layer,
target,
name + "_unfold_advanced_index",
source_ir,
)
reshape_output = unfold_advanced_shuffle_layer.get_output(0)

else:
_LOGGER.debug("The advanced indices are not consecutive in this case")
concat_final_tensor = []
concat_final_tensor.append(cum_adv_index_shape_tensor)
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_final_tensor.append(curr_dim)
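# e.g., for x[:, ind_1, :, ind_2] on the (10, 20, 30, 40, 50) example with index shape
# (2, 3, 4), this assembles the final shape (2, 3, 4, 10, 30, 50)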

concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
set_layer_name(
concat_final_shape_layer,
target,
name + "_index_non_continuous_concat_final_shape_layer",
source_ir,
)
concat_final_tensor = concat_final_shape_layer.get_output(0)

reshape_layer = ctx.net.add_shuffle(gather_out)
reshape_layer.set_input(1, concat_final_tensor)
set_layer_name(
reshape_layer,
target,
name + "_index_non_continuous_shuffle_final_shape_layer",
source_ir,
)
reshape_output = reshape_layer.get_output(0)

return reshape_output