10
10
import torch
11
11
import torch .fx
12
12
import torch .nn as nn
13
- import torchvision
14
13
from _utils_internal import get_relative_path
15
14
from common_utils import map_nested_tensor_object , freeze_rng_state , set_rng_seed , cpu_and_gpu , needs_cuda
16
15
from torchvision import models
19
18
ACCEPT = os .getenv ("EXPECTTEST_ACCEPT" , "0" ) == "1"
20
19
21
20
22
- def get_available_classification_models ( ):
21
+ def get_models_from_module ( module ):
23
22
# TODO add a registration mechanism to torchvision.models
24
- return [k for k , v in models .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
25
-
26
-
27
- def get_available_segmentation_models ():
28
- # TODO add a registration mechanism to torchvision.models
29
- return [k for k , v in models .segmentation .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
30
-
31
-
32
- def get_available_detection_models ():
33
- # TODO add a registration mechanism to torchvision.models
34
- return [k for k , v in models .detection .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
35
-
36
-
37
- def get_available_video_models ():
38
- # TODO add a registration mechanism to torchvision.models
39
- return [k for k , v in models .video .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
40
-
41
-
42
- def get_available_quantizable_models ():
43
- # TODO add a registration mechanism to torchvision.models
44
- return [k for k , v in models .quantization .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
23
+ return [v for k , v in module .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
45
24
46
25
47
26
def _get_expected_file (name = None ):
@@ -314,20 +293,20 @@ def _make_sliced_model(model, stop_layer):
314
293
return new_model
315
294
316
295
317
- @pytest .mark .parametrize ("model_name " , [" densenet121" , " densenet169" , " densenet201" , " densenet161" ])
318
- def test_memory_efficient_densenet (model_name ):
296
+ @pytest .mark .parametrize ("model_fn " , [models . densenet121 , models . densenet169 , models . densenet201 , models . densenet161 ])
297
+ def test_memory_efficient_densenet (model_fn ):
319
298
input_shape = (1 , 3 , 300 , 300 )
320
299
x = torch .rand (input_shape )
321
300
322
- model1 = models . __dict__ [ model_name ] (num_classes = 50 , memory_efficient = True )
301
+ model1 = model_fn (num_classes = 50 , memory_efficient = True )
323
302
params = model1 .state_dict ()
324
303
num_params = sum ([x .numel () for x in model1 .parameters ()])
325
304
model1 .eval ()
326
305
out1 = model1 (x )
327
306
out1 .sum ().backward ()
328
307
num_grad = sum ([x .grad .numel () for x in model1 .parameters () if x .grad is not None ])
329
308
330
- model2 = models . __dict__ [ model_name ] (num_classes = 50 , memory_efficient = False )
309
+ model2 = model_fn (num_classes = 50 , memory_efficient = False )
331
310
model2 .load_state_dict (params )
332
311
model2 .eval ()
333
312
out2 = model2 (x )
@@ -344,7 +323,7 @@ def test_memory_efficient_densenet(model_name):
344
323
@pytest .mark .parametrize ("dilate_layer_4" , (True , False ))
345
324
def test_resnet_dilation (dilate_layer_2 , dilate_layer_3 , dilate_layer_4 ):
346
325
# TODO improve tests to also check that each layer has the right dimensionality
347
- model = models .__dict__ [ " resnet50" ] (replace_stride_with_dilation = (dilate_layer_2 , dilate_layer_3 , dilate_layer_4 ))
326
+ model = models .resnet50 (replace_stride_with_dilation = (dilate_layer_2 , dilate_layer_3 , dilate_layer_4 ))
348
327
model = _make_sliced_model (model , stop_layer = "layer4" )
349
328
model .eval ()
350
329
x = torch .rand (1 , 3 , 224 , 224 )
@@ -354,22 +333,22 @@ def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
354
333
355
334
356
335
def test_mobilenet_v2_residual_setting ():
357
- model = models .__dict__ [ " mobilenet_v2" ] (inverted_residual_setting = [[1 , 16 , 1 , 1 ], [6 , 24 , 2 , 2 ]])
336
+ model = models .mobilenet_v2 (inverted_residual_setting = [[1 , 16 , 1 , 1 ], [6 , 24 , 2 , 2 ]])
358
337
model .eval ()
359
338
x = torch .rand (1 , 3 , 224 , 224 )
360
339
out = model (x )
361
340
assert out .shape [- 1 ] == 1000
362
341
363
342
364
- @pytest .mark .parametrize ("model_name " , [" mobilenet_v2" , " mobilenet_v3_large" , " mobilenet_v3_small" ])
365
- def test_mobilenet_norm_layer (model_name ):
366
- model = models . __dict__ [ model_name ] ()
343
+ @pytest .mark .parametrize ("model_fn " , [models . mobilenet_v2 , models . mobilenet_v3_large , models . mobilenet_v3_small ])
344
+ def test_mobilenet_norm_layer (model_fn ):
345
+ model = model_fn ()
367
346
assert any (isinstance (x , nn .BatchNorm2d ) for x in model .modules ())
368
347
369
348
def get_gn (num_channels ):
370
349
return nn .GroupNorm (32 , num_channels )
371
350
372
- model = models . __dict__ [ model_name ] (norm_layer = get_gn )
351
+ model = model_fn (norm_layer = get_gn )
373
352
assert not (any (isinstance (x , nn .BatchNorm2d ) for x in model .modules ()))
374
353
assert any (isinstance (x , nn .GroupNorm ) for x in model .modules ())
375
354
@@ -478,18 +457,19 @@ def test_generalizedrcnn_transform_repr():
478
457
assert t .__repr__ () == expected_string
479
458
480
459
481
- @pytest .mark .parametrize ("model_name " , get_available_classification_models ( ))
460
+ @pytest .mark .parametrize ("model_fn " , get_models_from_module ( models ))
482
461
@pytest .mark .parametrize ("dev" , cpu_and_gpu ())
483
- def test_classification_model (model_name , dev ):
462
+ def test_classification_model (model_fn , dev ):
484
463
set_rng_seed (0 )
485
464
defaults = {
486
465
"num_classes" : 50 ,
487
466
"input_shape" : (1 , 3 , 224 , 224 ),
488
467
}
468
+ model_name = model_fn .__name__
489
469
kwargs = {** defaults , ** _model_params .get (model_name , {})}
490
470
input_shape = kwargs .pop ("input_shape" )
491
471
492
- model = models . __dict__ [ model_name ] (** kwargs )
472
+ model = model_fn (** kwargs )
493
473
model .eval ().to (device = dev )
494
474
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
495
475
x = torch .rand (input_shape ).to (device = dev )
@@ -510,19 +490,20 @@ def test_classification_model(model_name, dev):
510
490
_check_input_backprop (model , x )
511
491
512
492
513
- @pytest .mark .parametrize ("model_name " , get_available_segmentation_models ( ))
493
+ @pytest .mark .parametrize ("model_fn " , get_models_from_module ( models . segmentation ))
514
494
@pytest .mark .parametrize ("dev" , cpu_and_gpu ())
515
- def test_segmentation_model (model_name , dev ):
495
+ def test_segmentation_model (model_fn , dev ):
516
496
set_rng_seed (0 )
517
497
defaults = {
518
498
"num_classes" : 10 ,
519
499
"pretrained_backbone" : False ,
520
500
"input_shape" : (1 , 3 , 32 , 32 ),
521
501
}
502
+ model_name = model_fn .__name__
522
503
kwargs = {** defaults , ** _model_params .get (model_name , {})}
523
504
input_shape = kwargs .pop ("input_shape" )
524
505
525
- model = models . segmentation . __dict__ [ model_name ] (** kwargs )
506
+ model = model_fn (** kwargs )
526
507
model .eval ().to (device = dev )
527
508
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
528
509
x = torch .rand (input_shape ).to (device = dev )
@@ -571,19 +552,20 @@ def check_out(out):
571
552
_check_input_backprop (model , x )
572
553
573
554
574
- @pytest .mark .parametrize ("model_name " , get_available_detection_models ( ))
555
+ @pytest .mark .parametrize ("model_fn " , get_models_from_module ( models . detection ))
575
556
@pytest .mark .parametrize ("dev" , cpu_and_gpu ())
576
- def test_detection_model (model_name , dev ):
557
+ def test_detection_model (model_fn , dev ):
577
558
set_rng_seed (0 )
578
559
defaults = {
579
560
"num_classes" : 50 ,
580
561
"pretrained_backbone" : False ,
581
562
"input_shape" : (3 , 300 , 300 ),
582
563
}
564
+ model_name = model_fn .__name__
583
565
kwargs = {** defaults , ** _model_params .get (model_name , {})}
584
566
input_shape = kwargs .pop ("input_shape" )
585
567
586
- model = models . detection . __dict__ [ model_name ] (** kwargs )
568
+ model = model_fn (** kwargs )
587
569
model .eval ().to (device = dev )
588
570
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
589
571
x = torch .rand (input_shape ).to (device = dev )
@@ -667,10 +649,10 @@ def compute_mean_std(tensor):
667
649
_check_input_backprop (model , model_input )
668
650
669
651
670
- @pytest .mark .parametrize ("model_name " , get_available_detection_models ( ))
671
- def test_detection_model_validation (model_name ):
652
+ @pytest .mark .parametrize ("model_fn " , get_models_from_module ( models . detection ))
653
+ def test_detection_model_validation (model_fn ):
672
654
set_rng_seed (0 )
673
- model = models . detection . __dict__ [ model_name ] (num_classes = 50 , pretrained_backbone = False )
655
+ model = model_fn (num_classes = 50 , pretrained_backbone = False )
674
656
input_shape = (3 , 300 , 300 )
675
657
x = [torch .rand (input_shape )]
676
658
@@ -696,14 +678,15 @@ def test_detection_model_validation(model_name):
696
678
model (x , targets = targets )
697
679
698
680
699
- @pytest .mark .parametrize ("model_name " , get_available_video_models ( ))
681
+ @pytest .mark .parametrize ("model_fn " , get_models_from_module ( models . video ))
700
682
@pytest .mark .parametrize ("dev" , cpu_and_gpu ())
701
- def test_video_model (model_name , dev ):
683
+ def test_video_model (model_fn , dev ):
702
684
# the default input shape is
703
685
# bs * num_channels * clip_len * h *w
704
686
input_shape = (1 , 3 , 4 , 112 , 112 )
687
+ model_name = model_fn .__name__
705
688
# test both basicblock and Bottleneck
706
- model = models . video . __dict__ [ model_name ] (num_classes = 50 )
689
+ model = model_fn (num_classes = 50 )
707
690
model .eval ().to (device = dev )
708
691
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
709
692
x = torch .rand (input_shape ).to (device = dev )
@@ -727,20 +710,21 @@ def test_video_model(model_name, dev):
727
710
),
728
711
reason = "This Pytorch Build has not been built with fbgemm and qnnpack" ,
729
712
)
730
- @pytest .mark .parametrize ("model_name " , get_available_quantizable_models ( ))
731
- def test_quantized_classification_model (model_name ):
713
+ @pytest .mark .parametrize ("model_fn " , get_models_from_module ( models . quantization ))
714
+ def test_quantized_classification_model (model_fn ):
732
715
set_rng_seed (0 )
733
716
defaults = {
734
717
"num_classes" : 5 ,
735
718
"input_shape" : (1 , 3 , 224 , 224 ),
736
719
"pretrained" : False ,
737
720
"quantize" : True ,
738
721
}
722
+ model_name = model_fn .__name__
739
723
kwargs = {** defaults , ** _model_params .get (model_name , {})}
740
724
input_shape = kwargs .pop ("input_shape" )
741
725
742
726
# First check if quantize=True provides models that can run with input data
743
- model = torchvision . models . quantization . __dict__ [ model_name ] (** kwargs )
727
+ model = model_fn (** kwargs )
744
728
model .eval ()
745
729
x = torch .rand (input_shape )
746
730
out = model (x )
@@ -753,7 +737,7 @@ def test_quantized_classification_model(model_name):
753
737
754
738
kwargs ["quantize" ] = False
755
739
for eval_mode in [True , False ]:
756
- model = torchvision . models . quantization . __dict__ [ model_name ] (** kwargs )
740
+ model = model_fn (** kwargs )
757
741
if eval_mode :
758
742
model .eval ()
759
743
model .qconfig = torch .quantization .default_qconfig
@@ -777,14 +761,13 @@ def test_quantized_classification_model(model_name):
777
761
raise AssertionError (f"model cannot be scripted. Traceback = { str (tb )} " ) from e
778
762
779
763
780
- @pytest .mark .parametrize ("model_name" , get_available_detection_models ())
781
- def test_detection_model_trainable_backbone_layers (model_name ):
764
+ @pytest .mark .parametrize ("model_fn" , get_models_from_module (models .detection ))
765
+ def test_detection_model_trainable_backbone_layers (model_fn ):
766
+ model_name = model_fn .__name__
782
767
max_trainable = _model_tests_values [model_name ]["max_trainable" ]
783
768
n_trainable_params = []
784
769
for trainable_layers in range (0 , max_trainable + 1 ):
785
- model = torchvision .models .detection .__dict__ [model_name ](
786
- pretrained = False , pretrained_backbone = True , trainable_backbone_layers = trainable_layers
787
- )
770
+ model = model_fn (pretrained = False , pretrained_backbone = True , trainable_backbone_layers = trainable_layers )
788
771
789
772
n_trainable_params .append (len ([p for p in model .parameters () if p .requires_grad ]))
790
773
assert n_trainable_params == _model_tests_values [model_name ]["n_trn_params_per_layer" ]
0 commit comments