
Commit 42b8e7d

Merge branch 'main' into completely-remove-requests
2 parents: 1a4ef77 + a7a3675

File tree

3 files changed: +303, -94 lines


test/test_ops.py

Lines changed: 121 additions & 4 deletions
@@ -7,12 +7,54 @@
 import numpy as np
 import pytest
 import torch
+import torch.fx
 from common_utils import needs_cuda, cpu_and_gpu, assert_equal
 from PIL import Image
 from torch import nn, Tensor
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
 from torchvision import models, ops
+from torchvision.models.feature_extraction import get_graph_node_names
+
+
+class RoIOpTesterModuleWrapper(nn.Module):
+    def __init__(self, obj):
+        super().__init__()
+        self.layer = obj
+        self.n_inputs = 2
+
+    def forward(self, a, b):
+        self.layer(a, b)
+
+
+class MultiScaleRoIAlignModuleWrapper(nn.Module):
+    def __init__(self, obj):
+        super().__init__()
+        self.layer = obj
+        self.n_inputs = 3
+
+    def forward(self, a, b, c):
+        self.layer(a, b, c)
+
+
+class DeformConvModuleWrapper(nn.Module):
+    def __init__(self, obj):
+        super().__init__()
+        self.layer = obj
+        self.n_inputs = 3
+
+    def forward(self, a, b, c):
+        self.layer(a, b, c)
+
+
+class StochasticDepthWrapper(nn.Module):
+    def __init__(self, obj):
+        super().__init__()
+        self.layer = obj
+        self.n_inputs = 1
+
+    def forward(self, a):
+        self.layer(a)
 
 
 class RoIOpTester(ABC):
@@ -46,6 +88,15 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar
         tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
 
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_is_leaf_node(self, device):
+        op_obj = self.make_obj(wrap=True).to(device=device)
+        graph_node_names = get_graph_node_names(op_obj)
+
+        assert len(graph_node_names) == 2
+        assert len(graph_node_names[0]) == len(graph_node_names[1])
+        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
+
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
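
For context, here is a minimal sketch (not part of this commit) of what the new test_is_leaf_node checks: get_graph_node_names returns a (train_nodes, eval_nodes) tuple, and since torchvision ops are now traced as leaf modules, the wrapped op's graph reduces to one placeholder per input plus a single call into the op.

    from torchvision import ops
    from torchvision.models.feature_extraction import get_graph_node_names

    # RoIOpTesterModuleWrapper is the test helper added above: it records
    # n_inputs = 2 and forwards both arguments to the wrapped op.
    wrapped = RoIOpTesterModuleWrapper(ops.RoIPool((5, 5), spatial_scale=1.0))

    train_nodes, eval_nodes = get_graph_node_names(wrapped)
    print(train_nodes)  # e.g. ['a', 'b', 'layer'], i.e. 1 + n_inputs entries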
@@ -91,6 +142,10 @@ def _helper_boxes_shape(self, func):
     def fn(*args, **kwargs):
         pass
 
+    @abstractmethod
+    def make_obj(*args, **kwargs):
+        pass
+
     @abstractmethod
     def get_script_fn(*args, **kwargs):
         pass
@@ -104,6 +159,10 @@ class TestRoiPool(RoIOpTester):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
 
+    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
+        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
+        return RoIOpTesterModuleWrapper(obj) if wrap else obj
+
     def get_script_fn(self, rois, pool_size):
         scripted = torch.jit.script(ops.roi_pool)
         return lambda x: scripted(x, rois, pool_size)
@@ -144,6 +203,10 @@ class TestPSRoIPool(RoIOpTester):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
 
+    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
+        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
+        return RoIOpTesterModuleWrapper(obj) if wrap else obj
+
     def get_script_fn(self, rois, pool_size):
         scripted = torch.jit.script(ops.ps_roi_pool)
         return lambda x: scripted(x, rois, pool_size)
@@ -223,6 +286,12 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligne
             (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
         )(x, rois)
 
+    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
+        obj = ops.RoIAlign(
+            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
+        )
+        return RoIOpTesterModuleWrapper(obj) if wrap else obj
+
     def get_script_fn(self, rois, pool_size):
         scripted = torch.jit.script(ops.roi_align)
         return lambda x: scripted(x, rois, pool_size)
@@ -374,6 +443,10 @@ class TestPSRoIAlign(RoIOpTester):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
 
+    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
+        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
+        return RoIOpTesterModuleWrapper(obj) if wrap else obj
+
     def get_script_fn(self, rois, pool_size):
         scripted = torch.jit.script(ops.ps_roi_align)
         return lambda x: scripted(x, rois, pool_size)
@@ -422,12 +495,18 @@ def test_boxes_shape(self):
 
 
 class TestMultiScaleRoIAlign:
+    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
+        if fmap_names is None:
+            fmap_names = ["0"]
+        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
+        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj
+
     def test_msroialign_repr(self):
         fmap_names = ["0"]
         output_size = (7, 7)
         sampling_ratio = 2
         # Pass mock feature map names
-        t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
+        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)
 
         # Check integrity of object __repr__ attribute
         expected_string = (
@@ -436,6 +515,15 @@ def test_msroialign_repr(self):
         )
         assert repr(t) == expected_string
 
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_is_leaf_node(self, device):
+        op_obj = self.make_obj(wrap=True).to(device=device)
+        graph_node_names = get_graph_node_names(op_obj)
+
+        assert len(graph_node_names) == 2
+        assert len(graph_node_names[0]) == len(graph_node_names[1])
+        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
+
 
 class TestNMS:
     def _reference_nms(self, boxes, scores, iou_threshold):
@@ -693,6 +781,21 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
 
         return x, weight, offset, mask, bias, stride, pad, dilation
 
+    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
+        obj = ops.DeformConv2d(
+            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
+        )
+        return DeformConvModuleWrapper(obj) if wrap else obj
+
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_is_leaf_node(self, device):
+        op_obj = self.make_obj(wrap=True).to(device=device)
+        graph_node_names = get_graph_node_names(op_obj)
+
+        assert len(graph_node_names) == 2
+        assert len(graph_node_names[0]) == len(graph_node_names[1])
+        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
+
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("batch_sz", (0, 33))
@@ -705,9 +808,9 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None):
         groups = 2
         tol = 2e-3 if dtype is torch.half else 1e-5
 
-        layer = ops.DeformConv2d(
-            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
-        ).to(device=x.device, dtype=dtype)
+        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
+            device=x.device, dtype=dtype
+        )
         res = layer(x, offset, mask)
 
         weight = layer.weight.data
@@ -1200,6 +1303,20 @@ def test_stochastic_depth(self, seed, mode, p):
         elif p == 1:
             assert out.equal(torch.zeros_like(x))
 
+    def make_obj(self, p, mode, wrap=False):
+        obj = ops.StochasticDepth(p, mode)
+        return StochasticDepthWrapper(obj) if wrap else obj
+
+    @pytest.mark.parametrize("p", (0, 1))
+    @pytest.mark.parametrize("mode", ["batch", "row"])
+    def test_is_leaf_node(self, p, mode):
+        op_obj = self.make_obj(p, mode, wrap=True)
+        graph_node_names = get_graph_node_names(op_obj)
+
+        assert len(graph_node_names) == 2
+        assert len(graph_node_names[0]) == len(graph_node_names[1])
+        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
+
 
 class TestUtils:
     @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])

torchvision/models/feature_extraction.py

Lines changed: 35 additions & 3 deletions
@@ -1,11 +1,14 @@
+import inspect
+import math
 import re
 import warnings
 from collections import OrderedDict
 from copy import deepcopy
 from itertools import chain
-from typing import Dict, Callable, List, Union, Optional, Tuple
+from typing import Dict, Callable, List, Union, Optional, Tuple, Any
 
 import torch
+import torchvision
 from torch import fx
 from torch import nn
 from torch.fx.graph_module import _copy_attr
@@ -172,8 +175,19 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT
     warnings.warn(msg + suggestion_msg)
 
 
+def _get_leaf_modules_for_ops() -> List[type]:
+    members = inspect.getmembers(torchvision.ops)
+    result = []
+    for _, obj in members:
+        if inspect.isclass(obj) and issubclass(obj, torch.nn.Module):
+            result.append(obj)
+    return result
+
+
 def get_graph_node_names(
-    model: nn.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False
+    model: nn.Module,
+    tracer_kwargs: Optional[Dict[str, Any]] = None,
+    suppress_diff_warning: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Dev utility to return node names in order of execution. See note on node
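
Illustrative only (the helper is private): the sketch below mirrors the logic of _get_leaf_modules_for_ops to show what it collects, namely every nn.Module subclass exported by torchvision.ops.

    import inspect

    import torch
    import torchvision

    # Same scan as _get_leaf_modules_for_ops: keep only the Module classes.
    leaf_classes = [
        obj
        for _, obj in inspect.getmembers(torchvision.ops)
        if inspect.isclass(obj) and issubclass(obj, torch.nn.Module)
    ]
    print(sorted(cls.__name__ for cls in leaf_classes))
    # e.g. ['DeformConv2d', 'FeaturePyramidNetwork', 'FrozenBatchNorm2d',
    #       'MultiScaleRoIAlign', 'RoIAlign', 'RoIPool', ...] (varies by version)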
@@ -198,6 +212,7 @@ def get_graph_node_names(
         tracer_kwargs (dict, optional): a dictionary of keyword arguments for
             ``NodePathTracer`` (they are eventually passed on to
             `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
+            By default, it is set so that all torchvision ops are wrapped and treated as leaf nodes.
         suppress_diff_warning (bool, optional): whether to suppress a warning
             when there are discrepancies between the train and eval version of
             the graph. Defaults to False.
@@ -211,6 +226,14 @@
         >>> model = torchvision.models.resnet18()
         >>> train_nodes, eval_nodes = get_graph_node_names(model)
     """
+    if tracer_kwargs is None:
+        tracer_kwargs = {
+            "autowrap_modules": (
+                math,
+                torchvision.ops,
+            ),
+            "leaf_modules": _get_leaf_modules_for_ops(),
+        }
     is_training = model.training
     train_tracer = NodePathTracer(**tracer_kwargs)
     train_tracer.trace(model.train())
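
A hedged sketch (not from this commit) of what the new default buys: with torchvision ops registered as leaf modules, an op such as ops.StochasticDepth appears as a single graph node instead of being traced into its internals.

    import torch
    from torch import nn
    from torchvision import ops
    from torchvision.models.feature_extraction import get_graph_node_names

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.sd = ops.StochasticDepth(p=0.5, mode="row")

        def forward(self, x):
            return self.sd(self.conv(x))

    train_nodes, eval_nodes = get_graph_node_names(Tiny())
    print(train_nodes)  # e.g. ['x', 'conv', 'sd']: 'sd' is a single leaf node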
@@ -294,7 +317,7 @@ def create_feature_extractor(
     return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
     train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
     eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
-    tracer_kwargs: Dict = {},
+    tracer_kwargs: Optional[Dict[str, Any]] = None,
     suppress_diff_warning: bool = False,
 ) -> fx.GraphModule:
     """
@@ -353,6 +376,7 @@ def create_feature_extractor(
         tracer_kwargs (dict, optional): a dictionary of keyword arguments for
             ``NodePathTracer`` (which passes them on to its parent class
             `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
+            By default, it is set so that all torchvision ops are wrapped and treated as leaf nodes.
         suppress_diff_warning (bool, optional): whether to suppress a warning
             when there are discrepancies between the train and eval version of
             the graph. Defaults to False.
@@ -397,6 +421,14 @@
         >>> 'autowrap_functions': [leaf_function]})
 
     """
+    if tracer_kwargs is None:
+        tracer_kwargs = {
+            "autowrap_modules": (
+                math,
+                torchvision.ops,
+            ),
+            "leaf_modules": _get_leaf_modules_for_ops(),
+        }
     is_training = model.training
 
     assert any(
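And the same default applied to create_feature_extractor, again as an illustrative sketch: a torchvision op inside a model can be requested as a return node by its module name, because it is no longer traced through.

    import torch
    from torch import nn
    from torchvision import ops
    from torchvision.models.feature_extraction import create_feature_extractor

    # ops.FrozenBatchNorm2d lives in torchvision.ops, so it is a leaf by default.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), ops.FrozenBatchNorm2d(8))
    extractor = create_feature_extractor(model, return_nodes={"1": "bn_out"})

    out = extractor(torch.rand(1, 3, 16, 16))
    print(out["bn_out"].shape)  # torch.Size([1, 8, 14, 14])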
