Skip to content

Cleanup Models prototype implementation #4940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj)
elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
)
Expand All @@ -63,7 +63,7 @@ def from_str(cls, value: str) -> "Weights":
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

def state_dict(self, progress: bool) -> OrderedDict:
def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress)

def __repr__(self):
Expand All @@ -90,7 +90,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.")
raise ValueError("The method is missing the 'weights' parameter.")

ann = signature(fn).parameters["weights"].annotation
weights_class = None
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class AlexNetWeights(Weights):

def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = AlexNetWeights.verify(weights)
if weights is not None:
Expand All @@ -39,6 +39,6 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k
model = AlexNet(**kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
22 changes: 11 additions & 11 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

state_dict = weights.state_dict(progress=progress)
state_dict = weights.get_state_dict(progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
Expand Down Expand Up @@ -63,11 +63,11 @@ def _densenet(
return model


_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": None, # weights ported from LuaTorch
"recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
}


Expand All @@ -76,7 +76,7 @@ class DenseNet121Weights(Weights):
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 74.434,
"acc@5": 91.972,
},
Expand All @@ -88,7 +88,7 @@ class DenseNet161Weights(Weights):
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 77.138,
"acc@5": 93.560,
},
Expand All @@ -100,7 +100,7 @@ class DenseNet169Weights(Weights):
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 75.600,
"acc@5": 92.806,
},
Expand All @@ -112,7 +112,7 @@ class DenseNet201Weights(Weights):
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 76.896,
"acc@5": 93.370,
},
Expand All @@ -121,7 +121,7 @@ class DenseNet201Weights(Weights):

def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet121Weights.verify(weights)

Expand All @@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T

def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet161Weights.verify(weights)

Expand All @@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T

def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet169Weights.verify(weights)

Expand All @@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T

def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet201Weights.verify(weights)

Expand Down
24 changes: 12 additions & 12 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
]


_common_meta = {
_COMMON_META = {
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
Expand All @@ -41,7 +41,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
Expand All @@ -53,7 +53,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
Expand All @@ -65,7 +65,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
Expand All @@ -81,11 +81,11 @@ def fasterrcnn_resnet50_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -102,7 +102,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down Expand Up @@ -142,7 +142,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model

Expand All @@ -156,11 +156,11 @@ def fasterrcnn_mobilenet_v3_large_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

Expand Down Expand Up @@ -188,11 +188,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
]


_common_meta = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}


class KeypointRCNNResNet50FPNWeights(Weights):
Coco_RefV1_Legacy = WeightEntry(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/1606",
"box_map": 50.6,
"kp_map": 61.1,
Expand All @@ -40,7 +40,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"box_map": 54.6,
"kp_map": 65.0,
Expand All @@ -58,7 +58,7 @@ def keypointrcnn_resnet50_fpn(
**kwargs: Any,
) -> KeypointRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
pretrained = kwargs.pop("pretrained")
if type(pretrained) == str and pretrained == "legacy":
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
Expand All @@ -68,7 +68,7 @@ def keypointrcnn_resnet50_fpn(
weights = None
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -86,7 +86,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def maskrcnn_resnet50_fpn(
**kwargs: Any,
) -> MaskRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = MaskRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -67,7 +67,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def retinanet_resnet50_fpn(
**kwargs: Any,
) -> RetinaNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = RetinaNetResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

Expand All @@ -70,7 +70,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def ssd300_vgg16(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSD300VGG16Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
weights_backbone = VGG16Weights.verify(weights_backbone)

if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")
warnings.warn("The size of the model is already fixed; ignoring the parameter.")

if weights is not None:
weights_backbone = None
Expand Down Expand Up @@ -81,6 +81,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
8 changes: 4 additions & 4 deletions torchvision/prototype/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def ssdlite320_mobilenet_v3_large(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")
warnings.warn("The size of the model is already fixed; ignoring the parameter.")

if weights is not None:
weights_backbone = None
Expand Down Expand Up @@ -114,6 +114,6 @@ def ssdlite320_mobilenet_v3_large(
)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
Loading