Skip to content

Commit dbfe7fa

Browse files
authored
torch.hub: add support for DOFA and Swin models (#2052)
* torch.hub: add support for DOFA and Swin models * Fix tests * Add *args support to DOFA to match other models
1 parent 94bd5c7 commit dbfe7fa

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

hubconf.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,22 @@
77
* https://pytorch.org/docs/stable/hub.html
88
"""
99

10-
from torchgeo.models import resnet18, resnet50, vit_small_patch16_224
10+
from torchgeo.models import (
11+
dofa_base_patch16_224,
12+
dofa_large_patch16_224,
13+
resnet18,
14+
resnet50,
15+
swin_v2_b,
16+
vit_small_patch16_224,
17+
)
1118

12-
__all__ = ('resnet18', 'resnet50', 'vit_small_patch16_224')
19+
__all__ = (
20+
'dofa_base_patch16_224',
21+
'dofa_large_patch16_224',
22+
'resnet18',
23+
'resnet50',
24+
'swin_v2_b',
25+
'vit_small_patch16_224',
26+
)
1327

1428
dependencies = ['timm']

tests/models/test_api.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
from torchvision.models._api import WeightsEnum
1010

1111
from torchgeo.models import (
12+
DOFABase16_Weights,
13+
DOFALarge16_Weights,
1214
ResNet18_Weights,
1315
ResNet50_Weights,
1416
Swin_V2_B_Weights,
1517
ViTSmall16_Weights,
18+
dofa_base_patch16_224,
19+
dofa_large_patch16_224,
1620
get_model,
1721
get_model_weights,
1822
get_weight,
@@ -23,8 +27,22 @@
2327
vit_small_patch16_224,
2428
)
2529

26-
builders = [resnet18, resnet50, vit_small_patch16_224, swin_v2_b]
27-
enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights, Swin_V2_B_Weights]
30+
builders = [
31+
dofa_base_patch16_224,
32+
dofa_large_patch16_224,
33+
resnet18,
34+
resnet50,
35+
swin_v2_b,
36+
vit_small_patch16_224,
37+
]
38+
enums = [
39+
DOFABase16_Weights,
40+
DOFALarge16_Weights,
41+
ResNet18_Weights,
42+
ResNet50_Weights,
43+
Swin_V2_B_Weights,
44+
ViTSmall16_Weights,
45+
]
2846

2947

3048
@pytest.mark.parametrize('builder', builders)

torchgeo/models/api.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,38 @@
1616
import torch.nn as nn
1717
from torchvision.models._api import WeightsEnum
1818

19+
from .dofa import (
20+
DOFABase16_Weights,
21+
DOFALarge16_Weights,
22+
dofa_base_patch16_224,
23+
dofa_large_patch16_224,
24+
)
1925
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
2026
from .swin import Swin_V2_B_Weights, swin_v2_b
2127
from .vit import ViTSmall16_Weights, vit_small_patch16_224
2228

2329
_model = {
30+
'dofa_base_patch16_224': dofa_base_patch16_224,
31+
'dofa_large_patch16_224': dofa_large_patch16_224,
2432
'resnet18': resnet18,
2533
'resnet50': resnet50,
26-
'vit_small_patch16_224': vit_small_patch16_224,
2734
'swin_v2_b': swin_v2_b,
35+
'vit_small_patch16_224': vit_small_patch16_224,
2836
}
2937

3038
_model_weights = {
39+
dofa_base_patch16_224: DOFABase16_Weights,
40+
dofa_large_patch16_224: DOFALarge16_Weights,
3141
resnet18: ResNet18_Weights,
3242
resnet50: ResNet50_Weights,
33-
vit_small_patch16_224: ViTSmall16_Weights,
3443
swin_v2_b: Swin_V2_B_Weights,
44+
vit_small_patch16_224: ViTSmall16_Weights,
45+
'dofa_base_patch16_224': DOFABase16_Weights,
46+
'dofa_large_patch16_224': DOFALarge16_Weights,
3547
'resnet18': ResNet18_Weights,
3648
'resnet50': ResNet50_Weights,
37-
'vit_small_patch16_224': ViTSmall16_Weights,
3849
'swin_v2_b': Swin_V2_B_Weights,
50+
'vit_small_patch16_224': ViTSmall16_Weights,
3951
}
4052

4153

torchgeo/models/dofa.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc]
415415
)
416416

417417

418-
def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
418+
def dofa_small_patch16_224(*args: Any, **kwargs: Any) -> DOFA:
419419
"""Dynamic One-For-All (DOFA) small patch size 16 model.
420420
421421
If you use this model in your research, please cite the following paper:
@@ -425,17 +425,19 @@ def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
425425
.. versionadded:: 0.6
426426
427427
Args:
428+
*args: Additional arguments to pass to :class:`DOFA`.
428429
**kwargs: Additional keyword arguments to pass to :class:`DOFA`.
429430
430431
Returns:
431432
A DOFA small 16 model.
432433
"""
433-
model = DOFA(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
434+
kwargs |= {'patch_size': 16, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
435+
model = DOFA(*args, **kwargs)
434436
return model
435437

436438

437439
def dofa_base_patch16_224(
438-
weights: DOFABase16_Weights | None = None, **kwargs: Any
440+
weights: DOFABase16_Weights | None = None, *args: Any, **kwargs: Any
439441
) -> DOFA:
440442
"""Dynamic One-For-All (DOFA) base patch size 16 model.
441443
@@ -447,12 +449,14 @@ def dofa_base_patch16_224(
447449
448450
Args:
449451
weights: Pre-trained model weights to use.
452+
*args: Additional arguments to pass to :class:`DOFA`.
450453
**kwargs: Additional keyword arguments to pass to :class:`DOFA`.
451454
452455
Returns:
453456
A DOFA base 16 model.
454457
"""
455-
model = DOFA(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
458+
kwargs |= {'patch_size': 16, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
459+
model = DOFA(*args, **kwargs)
456460

457461
if weights:
458462
missing_keys, unexpected_keys = model.load_state_dict(
@@ -471,7 +475,7 @@ def dofa_base_patch16_224(
471475

472476

473477
def dofa_large_patch16_224(
474-
weights: DOFALarge16_Weights | None = None, **kwargs: Any
478+
weights: DOFALarge16_Weights | None = None, *args: Any, **kwargs: Any
475479
) -> DOFA:
476480
"""Dynamic One-For-All (DOFA) large patch size 16 model.
477481
@@ -483,12 +487,14 @@ def dofa_large_patch16_224(
483487
484488
Args:
485489
weights: Pre-trained model weights to use.
490+
*args: Additional arguments to pass to :class:`DOFA`.
486491
**kwargs: Additional keyword arguments to pass to :class:`DOFA`.
487492
488493
Returns:
489494
A DOFA large 16 model.
490495
"""
491-
model = DOFA(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
496+
kwargs |= {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
497+
model = DOFA(*args, **kwargs)
492498

493499
if weights:
494500
missing_keys, unexpected_keys = model.load_state_dict(
@@ -506,7 +512,7 @@ def dofa_large_patch16_224(
506512
return model
507513

508514

509-
def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
515+
def dofa_huge_patch16_224(*args: Any, **kwargs: Any) -> DOFA:
510516
"""Dynamic One-For-All (DOFA) huge patch size 16 model.
511517
512518
If you use this model in your research, please cite the following paper:
@@ -516,10 +522,12 @@ def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
516522
.. versionadded:: 0.6
517523
518524
Args:
525+
*args: Additional arguments to pass to :class:`DOFA`.
519526
**kwargs: Additional keyword arguments to pass to :class:`DOFA`.
520527
521528
Returns:
522529
A DOFA huge 16 model.
523530
"""
524-
model = DOFA(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
531+
kwargs |= {'patch_size': 14, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}
532+
model = DOFA(*args, **kwargs)
525533
return model

0 commit comments

Comments
 (0)