Skip to content

Commit 313d522

Browse files
authored
Multi-weight support for DeepLabV3 prototype models (#4757)
* adding multiweight support for deeplabv3 prototype models * adding default values for optional params * fixing bug * addressing PR comment
1 parent b621e38 commit 313d522

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .fcn import *
22
from .lraspp import *
3+
from .deeplabv3 import *
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
6+
from ...transforms.presets import VocEval
7+
from .._api import Weights, WeightEntry
8+
from .._meta import _VOC_CATEGORIES
9+
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
10+
from ..resnet import resnet50, resnet101
11+
from ..resnet import ResNet50Weights, ResNet101Weights
12+
13+
14+
__all__ = [
15+
"DeepLabV3",
16+
"DeepLabV3ResNet50Weights",
17+
"DeepLabV3ResNet101Weights",
18+
"DeepLabV3MobileNetV3LargeWeights",
19+
"deeplabv3_mobilenet_v3_large",
20+
"deeplabv3_resnet50",
21+
"deeplabv3_resnet101",
22+
]
23+
24+
25+
class DeepLabV3ResNet50Weights(Weights):
26+
CocoWithVocLabels_RefV1 = WeightEntry(
27+
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
28+
transforms=partial(VocEval, resize_size=520),
29+
meta={
30+
"categories": _VOC_CATEGORIES,
31+
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
32+
"mIoU": 66.4,
33+
"acc": 92.4,
34+
},
35+
)
36+
37+
38+
class DeepLabV3ResNet101Weights(Weights):
39+
CocoWithVocLabels_RefV1 = WeightEntry(
40+
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
41+
transforms=partial(VocEval, resize_size=520),
42+
meta={
43+
"categories": _VOC_CATEGORIES,
44+
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
45+
"mIoU": 67.4,
46+
"acc": 92.4,
47+
},
48+
)
49+
50+
51+
class DeepLabV3MobileNetV3LargeWeights(Weights):
52+
CocoWithVocLabels_RefV1 = WeightEntry(
53+
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
54+
transforms=partial(VocEval, resize_size=520),
55+
meta={
56+
"categories": _VOC_CATEGORIES,
57+
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
58+
"mIoU": 60.3,
59+
"acc": 91.2,
60+
},
61+
)
62+
63+
64+
def deeplabv3_resnet50(
65+
weights: Optional[DeepLabV3ResNet50Weights] = None,
66+
weights_backbone: Optional[ResNet50Weights] = None,
67+
progress: bool = True,
68+
num_classes: int = 21,
69+
aux_loss: Optional[bool] = None,
70+
**kwargs: Any,
71+
) -> DeepLabV3:
72+
if "pretrained" in kwargs:
73+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
74+
weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
75+
76+
weights = DeepLabV3ResNet50Weights.verify(weights)
77+
if "pretrained_backbone" in kwargs:
78+
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
79+
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
80+
weights_backbone = ResNet50Weights.verify(weights_backbone)
81+
82+
if weights is not None:
83+
weights_backbone = None
84+
aux_loss = True
85+
num_classes = len(weights.meta["categories"])
86+
87+
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
88+
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
89+
90+
if weights is not None:
91+
model.load_state_dict(weights.state_dict(progress=progress))
92+
93+
return model
94+
95+
96+
def deeplabv3_resnet101(
97+
weights: Optional[DeepLabV3ResNet101Weights] = None,
98+
weights_backbone: Optional[ResNet101Weights] = None,
99+
progress: bool = True,
100+
num_classes: int = 21,
101+
aux_loss: Optional[bool] = None,
102+
**kwargs: Any,
103+
) -> DeepLabV3:
104+
if "pretrained" in kwargs:
105+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
106+
weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
107+
108+
weights = DeepLabV3ResNet101Weights.verify(weights)
109+
if "pretrained_backbone" in kwargs:
110+
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
111+
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
112+
weights_backbone = ResNet101Weights.verify(weights_backbone)
113+
114+
if weights is not None:
115+
weights_backbone = None
116+
aux_loss = True
117+
num_classes = len(weights.meta["categories"])
118+
119+
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
120+
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
121+
122+
if weights is not None:
123+
model.load_state_dict(weights.state_dict(progress=progress))
124+
125+
return model
126+
127+
128+
def deeplabv3_mobilenet_v3_large(
129+
weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None,
130+
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
131+
progress: bool = True,
132+
num_classes: int = 21,
133+
aux_loss: Optional[bool] = None,
134+
**kwargs: Any,
135+
) -> DeepLabV3:
136+
if "pretrained" in kwargs:
137+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
138+
weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
139+
140+
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
141+
if "pretrained_backbone" in kwargs:
142+
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
143+
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
144+
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
145+
146+
if weights is not None:
147+
weights_backbone = None
148+
aux_loss = True
149+
num_classes = len(weights.meta["categories"])
150+
151+
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
152+
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
153+
154+
if weights is not None:
155+
model.load_state_dict(weights.state_dict(progress=progress))
156+
157+
return model

0 commit comments

Comments
 (0)