@@ -25,6 +25,11 @@ def _build_model(fn, **kwargs):
25
25
return model .eval ()
26
26
27
27
28
+ def get_models_with_module_names (module ):
29
+ module_name = module .__name__ .split ("." )[- 1 ]
30
+ return [(fn , module_name ) for fn in TM .get_models_from_module (module )]
31
+
32
+
28
33
def test_get_weight ():
29
34
fn = models .resnet50
30
35
weight_name = "ImageNet1K_RefV2"
@@ -45,16 +50,35 @@ def test_segmentation_model(model_fn, dev):
45
50
TM .test_segmentation_model (model_fn , dev )
46
51
47
52
48
- @pytest .mark .parametrize ("model_fn" , TM .get_models_from_module (models ) + TM .get_models_from_module (models .segmentation ))
53
+ @pytest .mark .parametrize ("model_fn" , TM .get_models_from_module (models .video ))
54
+ @pytest .mark .parametrize ("dev" , cpu_and_gpu ())
55
+ @pytest .mark .skipif (os .getenv ("PYTORCH_TEST_WITH_PROTOTYPE" , "0" ) == "0" , reason = "Prototype code tests are disabled" )
56
+ def test_video_model (model_fn , dev ):
57
+ TM .test_video_model (model_fn , dev )
58
+
59
+
60
+ @pytest .mark .parametrize (
61
+ "model_fn, module_name" ,
62
+ get_models_with_module_names (models )
63
+ + get_models_with_module_names (models .segmentation )
64
+ + get_models_with_module_names (models .video ),
65
+ )
49
66
@pytest .mark .parametrize ("dev" , cpu_and_gpu ())
50
67
@pytest .mark .skipif (os .getenv ("PYTORCH_TEST_WITH_PROTOTYPE" , "0" ) == "0" , reason = "Prototype code tests are disabled" )
51
- def test_old_vs_new_factory (model_fn , dev ):
68
+ def test_old_vs_new_factory (model_fn , module_name , dev ):
52
69
defaults = {
53
- "pretrained" : True ,
54
- "input_shape" : (1 , 3 , 224 , 224 ),
70
+ "models" : {
71
+ "input_shape" : (1 , 3 , 224 , 224 ),
72
+ },
73
+ "segmentation" : {
74
+ "input_shape" : (1 , 3 , 520 , 520 ),
75
+ },
76
+ "video" : {
77
+ "input_shape" : (1 , 3 , 4 , 112 , 112 ),
78
+ },
55
79
}
56
80
model_name = model_fn .__name__
57
- kwargs = {** defaults , ** TM ._model_params .get (model_name , {})}
81
+ kwargs = {"pretrained" : True , ** defaults [ module_name ] , ** TM ._model_params .get (model_name , {})}
58
82
input_shape = kwargs .pop ("input_shape" )
59
83
x = torch .rand (input_shape ).to (device = dev )
60
84
0 commit comments