add usage logging to prototype dispatchers / kernels #7012

Merged: 1 commit, Dec 6, 2022
6 changes: 6 additions & 0 deletions test/prototype_transforms_kernel_infos.py
@@ -57,6 +57,9 @@ def __init__(
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
float32_vs_uint8=False,
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
# manually. If set, triggers a test that makes sure this happens.
logs_usage=False,
# See InfoBase
test_marks=None,
# See InfoBase
@@ -71,6 +74,7 @@ def __init__(
if float32_vs_uint8 and not callable(float32_vs_uint8):
float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs) # noqa: E731
self.float32_vs_uint8 = float32_vs_uint8
self.logs_usage = logs_usage


def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
@@ -675,6 +679,7 @@ def reference_inputs_convert_format_bounding_box():
sample_inputs_fn=sample_inputs_convert_format_bounding_box,
reference_fn=reference_convert_format_bounding_box,
reference_inputs_fn=reference_inputs_convert_format_bounding_box,
logs_usage=True,
),
)

@@ -2100,6 +2105,7 @@ def sample_inputs_clamp_bounding_box():
KernelInfo(
F.clamp_bounding_box,
sample_inputs_fn=sample_inputs_clamp_bounding_box,
logs_usage=True,
)
)

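The `logs_usage` flag above marks kernels that must call the logging helper themselves because no dispatcher wraps them (here, `convert_format_bounding_box` and `clamp_bounding_box`). A minimal sketch of what such a kernel is expected to do, using only the guarded call pattern this PR applies throughout (the kernel name and body below are illustrative, not part of the diff):

```python
import torch
from torchvision.utils import _log_api_usage_once


def some_dispatcherless_kernel(bounding_box: torch.Tensor) -> torch.Tensor:
    # No dispatcher will log for this kernel, so it logs its own usage.
    # The guard keeps the non-scriptable helper out of TorchScript.
    if not torch.jit.is_scripting():
        _log_api_usage_once(some_dispatcherless_kernel)
    return bounding_box  # placeholder body for illustration
```

With `logs_usage=True`, the `test_logging` test added below verifies that calling the kernel actually reaches `torch._C._log_api_usage_once`.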
26 changes: 26 additions & 0 deletions test/test_prototype_transforms_functional.py
@@ -108,6 +108,19 @@ class TestKernels:
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)

@make_info_args_kwargs_parametrization(
[info for info in KERNEL_INFOS if info.logs_usage],
args_kwargs_fn=lambda info: info.sample_inputs_fn(),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_logging(self, spy_on, info, args_kwargs, device):
spy = spy_on(torch._C._log_api_usage_once)

args, kwargs = args_kwargs.load(device)
info.kernel(*args, **kwargs)

spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

@ignore_jit_warning_no_profile
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -291,6 +304,19 @@ class TestDispatchers:
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
)

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_logging(self, spy_on, info, args_kwargs, device):
spy = spy_on(torch._C._log_api_usage_once)

args, kwargs = args_kwargs.load(device)
info.dispatcher(*args, **kwargs)

spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

@ignore_jit_warning_no_profile
@image_sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
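Note that `spy_on` is a fixture defined elsewhere in the test suite, not in this diff. A rough, hypothetical equivalent built on `unittest.mock` (the real fixture may differ) could look like:

```python
import importlib
from unittest import mock

import pytest


@pytest.fixture
def spy_on():
    # Patch a function in its defining module with a wrapping MagicMock:
    # calls still reach the real function, but can be asserted on.
    patchers = []

    def factory(fn):
        module = importlib.import_module(fn.__module__)
        spy = mock.MagicMock(wraps=fn)
        patcher = mock.patch.object(module, fn.__name__, spy)
        patcher.start()
        patchers.append(patcher)
        return spy

    yield factory

    for patcher in patchers:
        patcher.stop()
```

Because the mock wraps the original, `spy.assert_any_call(...)` in `test_logging` can check the logged key without suppressing the real logging call.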
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/functional/_augment.py
@@ -5,6 +5,7 @@
import torch
from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once


def erase_image_tensor(
@@ -41,6 +42,9 @@ def erase(
v: torch.Tensor,
inplace: bool = False,
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(erase)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
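Two details recur in every dispatcher touched by this PR. First, `_log_api_usage_once` comes from `torchvision.utils`; it roughly reduces to the following (paraphrased from the torchvision source, so treat it as a sketch rather than the authoritative definition):

```python
from types import FunctionType
from typing import Any

import torch


def _log_api_usage_once(obj: Any) -> None:
    # Derive a "module.name" key from the passed function or instance
    # and hand it to PyTorch's C-level API-usage logger.
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")
```

Second, the `if not torch.jit.is_scripting():` guard exists because the helper is not scriptable; under `torch.jit.script`, `torch.jit.is_scripting()` evaluates to `True` and the logging branch is compiled out.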
35 changes: 35 additions & 0 deletions torchvision/prototype/transforms/functional/_color.py
@@ -5,6 +5,8 @@
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value

from torchvision.utils import _log_api_usage_once

from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor


@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:


def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:


def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_saturation)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:


def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:


def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_sharpness)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:


def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:


def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:


def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(posterize)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:


def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(solarize)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:


def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:


def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(equalize)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:


def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(invert)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
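Given the helper above, each dispatcher in this file logs a key of the form `<module>.<name>`, e.g. `torchvision.prototype.transforms.functional._color.adjust_brightness`. A quick standalone check, assuming the prototype namespace is importable as it was at the time of this PR:

```python
from unittest import mock

import torch
from torchvision.prototype.transforms import functional as F

# Intercept the C-level logger so the emitted key can be inspected.
with mock.patch.object(torch._C, "_log_api_usage_once") as spy:
    F.adjust_brightness(torch.rand(3, 8, 8), brightness_factor=1.5)

spy.assert_any_call("torchvision.prototype.transforms.functional._color.adjust_brightness")
```

This mirrors what `TestDispatchers.test_logging` asserts via `f"{info.dispatcher.__module__}.{info.id}"`.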
39 changes: 39 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -19,6 +19,8 @@
)
from torchvision.transforms.functional_tensor import _pad_symmetric

from torchvision.utils import _log_api_usage_once

from ._meta import convert_format_bounding_box, get_spatial_size_image_pil


@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:


def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:


def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(vertical_flip)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -231,6 +239,8 @@ def resize(
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resize)
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -730,6 +740,9 @@ def affine(
fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(affine)

# TODO: consider deprecating integers from angle and shear on the future
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
@@ -913,6 +926,9 @@ def rotate(
center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(rotate)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1120,6 +1136,9 @@ def pad(
fill: datapoints.FillTypeJIT = None,
padding_mode: str = "constant",
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(pad)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:


def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1452,6 +1474,8 @@ def perspective(
fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(perspective)
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1612,6 +1636,9 @@ def elastic(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(elastic)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:


def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(center_crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1817,6 +1847,9 @@ def resized_crop(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resized_crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
@@ -1897,6 +1930,9 @@ def five_crop_video(
def five_crop(
inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(five_crop)

# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
if isinstance(inpt, torch.Tensor) and (
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
def ten_crop(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
if not torch.jit.is_scripting():
_log_api_usage_once(ten_crop)

if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
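All of the `_geometry.py` dispatchers follow the same shape as the hunks above. Condensed into one schematic (every name here is illustrative; the real dispatchers forward to the datapoint's corresponding method and to real kernels):

```python
import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision.utils import _log_api_usage_once


def _kernel_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return image  # stand-in for the real tensor kernel


def _kernel_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return image  # stand-in for the real PIL kernel


def dispatcher_sketch(inpt):
    # Log usage first, and only in eager mode (see the _augment.py notes above).
    if not torch.jit.is_scripting():
        _log_api_usage_once(dispatcher_sketch)

    # Plain tensors (and all tensors while scripting) use the tensor kernel.
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
    ):
        return _kernel_image_tensor(inpt)
    # Datapoints dispatch to their own method in the real code.
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt  # placeholder: the real code calls e.g. inpt.horizontal_flip()
    # PIL images go to the PIL kernel.
    elif isinstance(inpt, PIL.Image.Image):
        return _kernel_image_pil(inpt)
    else:
        raise TypeError(f"unsupported input type {type(inpt).__name__}")
```

Placing the logging call before the type dispatch means every public entry point is counted exactly once, regardless of which kernel ultimately runs.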