Skip to content

Commit 6d575eb

Browse files
fmassafacebook-github-bot
authored andcommitted
[fbsync] Add multi-weight support for VideoResNet (#4770)
Summary: * Add mutli-weight support for VideoResNet. * Fix linter. * Minor refactoring. * Update comments. Reviewed By: datumbox Differential Revision: D32064688 fbshipit-source-id: f66c1d321d1dbdb19384858c5d9ac757d5c93a36
1 parent e694f5b commit 6d575eb

File tree

7 files changed

+606
-8
lines changed

7 files changed

+606
-8
lines changed

test/test_prototype_models.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ def _build_model(fn, **kwargs):
2525
return model.eval()
2626

2727

28+
def get_models_with_module_names(module):
29+
module_name = module.__name__.split(".")[-1]
30+
return [(fn, module_name) for fn in TM.get_models_from_module(module)]
31+
32+
2833
def test_get_weight():
2934
fn = models.resnet50
3035
weight_name = "ImageNet1K_RefV2"
@@ -45,16 +50,35 @@ def test_segmentation_model(model_fn, dev):
4550
TM.test_segmentation_model(model_fn, dev)
4651

4752

48-
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models) + TM.get_models_from_module(models.segmentation))
53+
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video))
54+
@pytest.mark.parametrize("dev", cpu_and_gpu())
55+
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
56+
def test_video_model(model_fn, dev):
57+
TM.test_video_model(model_fn, dev)
58+
59+
60+
@pytest.mark.parametrize(
61+
"model_fn, module_name",
62+
get_models_with_module_names(models)
63+
+ get_models_with_module_names(models.segmentation)
64+
+ get_models_with_module_names(models.video),
65+
)
4966
@pytest.mark.parametrize("dev", cpu_and_gpu())
5067
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
51-
def test_old_vs_new_factory(model_fn, dev):
68+
def test_old_vs_new_factory(model_fn, module_name, dev):
5269
defaults = {
53-
"pretrained": True,
54-
"input_shape": (1, 3, 224, 224),
70+
"models": {
71+
"input_shape": (1, 3, 224, 224),
72+
},
73+
"segmentation": {
74+
"input_shape": (1, 3, 520, 520),
75+
},
76+
"video": {
77+
"input_shape": (1, 3, 4, 112, 112),
78+
},
5579
}
5680
model_name = model_fn.__name__
57-
kwargs = {**defaults, **TM._model_params.get(model_name, {})}
81+
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
5882
input_shape = kwargs.pop("input_shape")
5983
x = torch.rand(input_shape).to(device=dev)
6084

torchvision/models/video/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Optional, Callable, List, Type, Any, Union
1+
from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union
22

33
import torch.nn as nn
44
from torch import Tensor
@@ -191,7 +191,7 @@ class VideoResNet(nn.Module):
191191
def __init__(
192192
self,
193193
block: Type[Union[BasicBlock, Bottleneck]],
194-
conv_makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
194+
conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
195195
layers: List[int],
196196
stem: Callable[..., nn.Module],
197197
num_classes: int = 400,

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from . import detection
99
from . import quantization
1010
from . import segmentation
11+
from . import video

0 commit comments

Comments
 (0)