Commit d3226ff

clarify comment
2 parents: 195a458 + 0040fe7

File tree

4 files changed: +23, -90 lines


test/test_transforms_v2.py

Lines changed: 0 additions & 35 deletions
@@ -545,38 +545,3 @@ def test_sanitize_bounding_boxes_errors():
     with pytest.raises(ValueError, match="Number of boxes"):
         different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
         transforms.SanitizeBoundingBoxes()(different_sizes)
-
-
-class TestLambda:
-    inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
-
-    @inputs
-    def test_default(self, input):
-        was_applied = False
-
-        def was_applied_fn(input):
-            nonlocal was_applied
-            was_applied = True
-            return input
-
-        transform = transforms.Lambda(was_applied_fn)
-
-        transform(input)
-
-        assert was_applied
-
-    @inputs
-    def test_with_types(self, input):
-        was_applied = False
-
-        def was_applied_fn(input):
-            nonlocal was_applied
-            was_applied = True
-            return input
-
-        types = (torch.Tensor, np.ndarray)
-        transform = transforms.Lambda(was_applied_fn, *types)
-
-        transform(input)
-
-        assert was_applied is isinstance(input, types)
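
The deleted tests check that the wrapped function actually ran by flipping a flag through a nonlocal closure. A minimal standalone sketch of that idiom outside pytest (the names run_with_probe and probe are illustrative, not from the test suite):

# Sketch of the nonlocal-flag idiom the removed tests use to detect that a callback ran.
def run_with_probe(callback, value):
    was_called = False

    def probe(x):
        nonlocal was_called  # mutate the flag in the enclosing scope
        was_called = True
        return callback(x)

    return probe(value), was_called

result, was_called = run_with_probe(lambda x: x * 2, 21)
assert result == 42 and was_called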

test/test_transforms_v2_consistency.py

Lines changed: 1 addition & 52 deletions
@@ -12,7 +12,6 @@
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
 from torchvision import transforms as legacy_transforms, tv_tensors
-from torchvision._utils import sequence_to_str

 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
@@ -70,57 +69,7 @@ def __init__(
 LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
 LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

-CONSISTENCY_CONFIGS = [
-    ConsistencyConfig(
-        v2_transforms.Lambda,
-        legacy_transforms.Lambda,
-        [
-            NotScriptableArgsKwargs(lambda image: image / 2),
-        ],
-        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
-        # images given that the transform does nothing but call it anyway.
-        supports_pil=False,
-    ),
-]
-
-
-@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
-def test_signature_consistency(config):
-    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
-    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
-
-    for param in config.removed_params:
-        legacy_params.pop(param, None)
-
-    missing = legacy_params.keys() - prototype_params.keys()
-    if missing:
-        raise AssertionError(
-            f"The prototype transform does not support the parameters "
-            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
-            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
-            f"the `ConsistencyConfig`."
-        )
-
-    extra = prototype_params.keys() - legacy_params.keys()
-    extra_without_default = {
-        param
-        for param in extra
-        if prototype_params[param].default is inspect.Parameter.empty
-        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-    }
-    if extra_without_default:
-        raise AssertionError(
-            f"The prototype transform requires the parameters "
-            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
-            f"not. Please add a default value."
-        )
-
-    legacy_signature = list(legacy_params.keys())
-    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
-    # to the same number of parameters as the legacy one
-    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
-
-    assert prototype_signature == legacy_signature
+CONSISTENCY_CONFIGS = []


 def check_call_consistency(
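
The removed test_signature_consistency compared constructor signatures via the inspect module: every legacy parameter must exist on the v2 class, and any extra v2 parameters need defaults. A rough standalone sketch of that comparison pattern, using two hypothetical callables rather than real torchvision classes:

import inspect

# Hypothetical stand-ins for a legacy constructor and its v2 counterpart.
def legacy_resize(size, interpolation=2):
    ...

def v2_resize(size, interpolation=2, antialias=True):
    ...

legacy_params = list(inspect.signature(legacy_resize).parameters)
v2_params = list(inspect.signature(v2_resize).parameters)

# Every legacy parameter must still be present.
assert not set(legacy_params) - set(v2_params)
# Extra v2 parameters are fine as long as the shared prefix keeps the legacy order.
assert v2_params[: len(legacy_params)] == legacy_params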

test/test_transforms_v2_refactored.py

Lines changed: 21 additions & 2 deletions
@@ -1906,8 +1906,9 @@ def test_random_order(self):
         input = make_image()

         actual = check_transform(transform, input)
-        # horizontal and vertical flip are commutative. Meaning, although the order in the transform is indeed random,
-        # we don't need to care here.
+        # We can't really check whether the transforms are actually applied in random order. However, horizontal and
+        # vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
+        # order, we can use a fixed order to compute the expected value.
         expected = F.vertical_flip(F.horizontal_flip(input))

         assert_equal(actual, expected)
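
The reworded comment relies on horizontal and vertical flips being commutative, so a fixed application order still produces the correct expected value. A quick sanity check of that property on a plain tensor, using torch.flip directly so the snippet stays self-contained:

import torch

# Flipping width (dim -1) and height (dim -2) commutes, so either order gives the same image.
img = torch.rand(3, 8, 8)
hflip_then_vflip = torch.flip(torch.flip(img, dims=[-1]), dims=[-2])
vflip_then_hflip = torch.flip(torch.flip(img, dims=[-2]), dims=[-1])
assert torch.equal(hflip_then_vflip, vflip_then_hflip)
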
@@ -5221,3 +5222,21 @@ def test_functional_and_transform(self, color_space, fn):
     def test_functional_error(self):
         with pytest.raises(TypeError, match="pic should be PIL Image"):
             F.pil_to_tensor(object())
+
+
+class TestLambda:
+    @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
+    @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
+    def test_transform(self, input, types):
+        was_applied = False
+
+        def was_applied_fn(input):
+            nonlocal was_applied
+            was_applied = True
+            return input
+
+        transform = transforms.Lambda(was_applied_fn, *types)
+        output = transform(input)
+
+        assert output is input
+        assert was_applied is (not types or isinstance(input, types))
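
The new test exercises the v2 Lambda transform, which calls the given function on its input and, when types are passed, only applies it to inputs that are instances of those types. A small usage sketch under that assumption (the halving lambda and variable names are illustrative, not from the commit):

import torch
from torchvision.transforms import v2

# Apply the function only to torch.Tensor inputs; anything else passes through unchanged.
halve_tensors = v2.Lambda(lambda x: x / 2, torch.Tensor)

tensor_out = halve_tensors(torch.ones(3, 4, 4))   # lambda applied: values become 0.5
string_out = halve_tensors("not an image")        # not a Tensor: returned as-is
assert float(tensor_out.max()) == 0.5
assert string_out == "not an image"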

torchvision/transforms/functional.py

Lines changed: 1 addition & 1 deletion
@@ -1270,7 +1270,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:

     Note:
         Please, note that this method supports only RGB images as input. For inputs in other color spaces,
-        please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
+        please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

     Args:
         img (PIL Image or Tensor): RGB Image to be converted to grayscale.
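
The corrected cross-reference points to to_grayscale for non-RGB inputs; rgb_to_grayscale itself takes an RGB image and a num_output_channels argument. A brief usage sketch of the tensor path (random data stands in for a real image):

import torch
from torchvision.transforms import functional as F

rgb = torch.rand(3, 32, 32)                            # 3-channel RGB stand-in
gray = F.rgb_to_grayscale(rgb, num_output_channels=1)  # -> shape [1, 32, 32]
print(gray.shape)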
