From 330285e41eab3cb1ff9c2049074d1d6c426551ae Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 6 Dec 2022 10:13:10 +0100
Subject: [PATCH] add usage logging to prototype dispatchers / kernels

---
 test/prototype_transforms_kernel_infos.py     |  6 +++
 test/test_prototype_transforms_functional.py  | 26 +++++++++++++
 .../transforms/functional/_augment.py         |  4 ++
 .../prototype/transforms/functional/_color.py | 35 +++++++++++++++++
 .../transforms/functional/_geometry.py        | 39 +++++++++++++++++++
 .../prototype/transforms/functional/_meta.py  | 26 +++++++++++++
 .../prototype/transforms/functional/_misc.py  |  7 ++++
 .../transforms/functional/_temporal.py        |  5 +++
 8 files changed, 148 insertions(+)

diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 8849365ea85..9d97b6ca701 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -57,6 +57,9 @@ def __init__(
         # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
         # dtype.
         float32_vs_uint8=False,
+        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
+        # manually. If set, triggers a test that makes sure this happens.
+        logs_usage=False,
         # See InfoBase
         test_marks=None,
         # See InfoBase
@@ -71,6 +74,7 @@ def __init__(
         if float32_vs_uint8 and not callable(float32_vs_uint8):
             float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
         self.float32_vs_uint8 = float32_vs_uint8
+        self.logs_usage = logs_usage
 
 
 def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
@@ -675,6 +679,7 @@ def reference_inputs_convert_format_bounding_box():
         sample_inputs_fn=sample_inputs_convert_format_bounding_box,
         reference_fn=reference_convert_format_bounding_box,
         reference_inputs_fn=reference_inputs_convert_format_bounding_box,
+        logs_usage=True,
     ),
 )
 
@@ -2100,6 +2105,7 @@ def sample_inputs_clamp_bounding_box():
     KernelInfo(
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
+        logs_usage=True,
     )
 )
 
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 7cd84fbcd61..f33992234ad 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -108,6 +108,19 @@ class TestKernels:
         args_kwargs_fn=lambda info: info.reference_inputs_fn(),
     )
 
+    @make_info_args_kwargs_parametrization(
+        [info for info in KERNEL_INFOS if info.logs_usage],
+        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_logging(self, spy_on, info, args_kwargs, device):
+        spy = spy_on(torch._C._log_api_usage_once)
+
+        args, kwargs = args_kwargs.load(device)
+        info.kernel(*args, **kwargs)
+
+        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
+
     @ignore_jit_warning_no_profile
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -291,6 +304,19 @@ class TestDispatchers:
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
     )
 
+    @make_info_args_kwargs_parametrization(
+        DISPATCHER_INFOS,
+        args_kwargs_fn=lambda info: info.sample_inputs(),
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_logging(self, spy_on, info, args_kwargs, device):
+        spy = spy_on(torch._C._log_api_usage_once)
+
+        args, kwargs = args_kwargs.load(device)
+        info.dispatcher(*args, **kwargs)
+
+        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")
+
     @ignore_jit_warning_no_profile
     @image_sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index 9f4a248089d..12af2444ef0 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -5,6 +5,7 @@
 import torch
 from torchvision.prototype import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
+from torchvision.utils import _log_api_usage_once
 
 
 def erase_image_tensor(
@@ -41,6 +42,9 @@ def erase(
     v: torch.Tensor,
     inplace: bool = False,
 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(erase)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 618968cbb48..517f7457775 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -5,6 +5,8 @@
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
 
+from torchvision.utils import _log_api_usage_once
+
 from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
 
 
@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
 
 
 def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_brightness)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
 
 
 def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_saturation)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
 
 
 def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_contrast)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
 
 
 def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_sharpness)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
 
 
 def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_hue)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
 
 
 def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_gamma)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
 
 
 def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(posterize)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
 
 
 def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(solarize)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(autocontrast)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(equalize)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(invert)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index cef68d66ee9..ba417a0ce84 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -19,6 +19,8 @@
 )
 from torchvision.transforms.functional_tensor import _pad_symmetric
 
+from torchvision.utils import _log_api_usage_once
+
 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
 
@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(horizontal_flip)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(vertical_flip)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -231,6 +239,8 @@ def resize(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(resize)
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -730,6 +740,9 @@ def affine(
     fill: datapoints.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(affine)
+
     # TODO: consider deprecating integers from angle and shear on the future
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -913,6 +926,9 @@ def rotate(
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(rotate)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1120,6 +1136,9 @@ def pad(
     fill: datapoints.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(pad)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
 
 
 def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1452,6 +1474,8 @@ def perspective(
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(perspective)
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1612,6 +1636,9 @@ def elastic(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(elastic)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
 
 
 def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(center_crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1817,6 +1847,9 @@ def resized_crop(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: Optional[bool] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(resized_crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1897,6 +1930,9 @@ def five_crop_video(
 def five_crop(
     inpt: ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(five_crop)
+
     # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
     # `ten_crop`
     if isinstance(inpt, torch.Tensor) and (
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
 def ten_crop(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
 ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(ten_crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index a6b9c773891..28de0536978 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -7,6 +7,8 @@
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
 
+from torchvision.utils import _log_api_usage_once
+
 
 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
@@ -24,6 +26,9 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
 
 
 def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_dimensions)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 
 
 def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_num_channels)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
 
 
 def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_spatial_size)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
 
 
 def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_num_frames)
+
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
         return get_num_frames_video(inpt)
     elif isinstance(inpt, datapoints.Video):
@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
 def convert_format_bounding_box(
     bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(convert_format_bounding_box)
+
     if new_format == old_format:
         return bounding_box
 
@@ -199,6 +216,9 @@ def convert_format_bounding_box(
 def clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(clamp_bounding_box)
+
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
     xyxy_boxes = convert_format_bounding_box(
@@ -313,6 +333,9 @@ def convert_color_space(
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(convert_color_space)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
 def convert_dtype(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(convert_dtype)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 7799187373f..bc9408d0e2c 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -8,6 +8,8 @@
 from torchvision.prototype import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
+from torchvision.utils import _log_api_usage_once
+
 from ..utils import is_simple_tensor
 
 
@@ -57,6 +59,8 @@ def normalize(
     inplace: bool = False,
 ) -> torch.Tensor:
     if not torch.jit.is_scripting():
+        _log_api_usage_once(normalize)
+
         if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
             inpt = inpt.as_subclass(torch.Tensor)
         else:
@@ -168,6 +172,9 @@ def gaussian_blur_video(
 def gaussian_blur(
     inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(gaussian_blur)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
diff --git a/torchvision/prototype/transforms/functional/_temporal.py b/torchvision/prototype/transforms/functional/_temporal.py
index 63b3baf942e..35f4a84ce7c 100644
--- a/torchvision/prototype/transforms/functional/_temporal.py
+++ b/torchvision/prototype/transforms/functional/_temporal.py
@@ -2,6 +2,8 @@
 
 from torchvision.prototype import datapoints
 
+from torchvision.utils import _log_api_usage_once
+
 
 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
 def uniform_temporal_subsample(
     inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
 ) -> datapoints.VideoTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(uniform_temporal_subsample)
+
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
         return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
     elif isinstance(inpt, datapoints.Video):