diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 4d3ab91e32..c26d7a80b7 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -4,7 +4,7 @@ import logging import platform from enum import Enum -from typing import Any, Callable, List, Optional, Sequence, Set +from typing import Any, Callable, List, Optional, Sequence, Set, Union import torch import torch.fx @@ -170,7 +170,7 @@ def compile( inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, **kwargs: Any, ) -> ( torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any] @@ -213,7 +213,7 @@ def compile( """ input_list = inputs if inputs is not None else [] - enabled_precisions_set: Set[dtype | torch.dtype] = ( + enabled_precisions_set: Set[Union[torch.dtype, dtype]] = ( enabled_precisions if enabled_precisions is not None else _defaults.ENABLED_PRECISIONS @@ -308,7 +308,7 @@ def cross_compile_for_windows( inputs: Optional[Sequence[Input | torch.Tensor]] = None, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, **kwargs: Any, ) -> None: """Compile a PyTorch module using TensorRT in Linux for Inference in Windows @@ -424,7 +424,7 @@ def convert_method_to_trt_engine( arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, ir: str = "default", - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, **kwargs: Any, ) -> bytes: """Convert a TorchScript module method to a serialized TensorRT engine diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d7092f1e0f..6434afe248 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -993,9 +993,9 @@ def convert_exported_program_to_serialized_trt_engine( *, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, - enabled_precisions: ( - Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] - ) = _defaults.ENABLED_PRECISIONS, + enabled_precisions: Union[ + Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] + ] = _defaults.ENABLED_PRECISIONS, assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, workspace_size: int = _defaults.WORKSPACE_SIZE, min_block_size: int = _defaults.MIN_BLOCK_SIZE, diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 1ad118facc..85a31b9736 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -69,7 +69,7 @@ def __init__( strict: bool = True, allow_complex_guards_as_runtime_asserts: bool = False, weight_streaming_budget: Optional[int] = None, - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, **kwargs: Any, ) -> None: """ diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 5d6d27e4ad..6016fe87c5 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -1,8 +1,9 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Union +import tensorrt as trt import torch import torch_tensorrt._C.ts as _ts_C from torch_tensorrt import _C @@ -13,8 +14,6 @@ from torch_tensorrt.ts._Input import TorchScriptInput from torch_tensorrt.ts.logging import Level, log -import tensorrt as trt - def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: clone = torch.classes.tensorrt._Input() @@ -310,7 +309,7 @@ def TensorRTCompileSpec( device: Optional[torch.device | Device] = None, disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, refit: bool = False, debug: bool = False, capability: EngineCapability = EngineCapability.STANDARD, diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 3d1406c158..114398f010 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, List, Optional, Sequence, Set, Tuple +from typing import Any, List, Optional, Sequence, Set, Tuple, Union import torch import torch_tensorrt._C.ts as _C @@ -18,7 +18,7 @@ def compile( device: Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, refit: bool = False, debug: bool = False, capability: EngineCapability = EngineCapability.STANDARD, @@ -172,7 +172,7 @@ def convert_method_to_trt_engine( device: Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, + enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, refit: bool = False, debug: bool = False, capability: EngineCapability = EngineCapability.STANDARD,