3 changes: 2 additions & 1 deletion test/test_extended_models.py
@@ -116,7 +116,8 @@ def test_schema_meta_validation(model_fn):
             incorrect_params.append(w)
         else:
             if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                incorrect_params.append(w)
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
         if not w.name.isupper():
             bad_names.append(w)

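Reviewer note (illustration, not part of the patch): the new fallback only flags a weight entry when its `num_params` meta differs from the DEFAULT variant's *and* from the actual parameter count of the instantiated model. A minimal sketch of that count check, assuming a torchvision version that exposes the weights enums under `torchvision.models`:

```python
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Count the real parameters of the built model and compare against the meta field.
w = ViT_B_16_Weights.IMAGENET1K_V1
actual_params = sum(p.numel() for p in vit_b_16(weights=w).parameters())
assert w.meta["num_params"] == actual_params
```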
35 changes: 33 additions & 2 deletions torchvision/models/vision_transformer.py
@@ -1,3 +1,4 @@
+import collections.abc as abc
 import math
 from collections import OrderedDict
 from functools import partial
@@ -284,10 +285,20 @@ def _vision_transformer(
     progress: bool,
     **kwargs: Any,
 ) -> VisionTransformer:
-    image_size = kwargs.pop("image_size", 224)
-
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "size" in weights.meta:
+            if isinstance(weights.meta["size"], int):
+                _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
+            elif isinstance(weights.meta["size"], abc.Sequence):
+                torch._assert(
+                    weights.meta["size"][0] == weights.meta["size"][1],
+                    "Currently we only support a square image where width = height",
+                )
+                _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
+            else:
+                raise ValueError('weights.meta["size"] should have type of either an int or a Sequence[int]')
+    image_size = kwargs.pop("image_size", 224)
 
     model = VisionTransformer(
         image_size=image_size,
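Reviewer note (illustration, not part of the patch): the added branch resolves `image_size` from `weights.meta["size"]`, accepting either an int or a square sequence and rejecting everything else. A standalone sketch of the same dispatch, where `resolve_image_size` is a hypothetical helper name:

```python
import collections.abc as abc
from typing import Sequence, Union

def resolve_image_size(size: Union[int, Sequence[int]]) -> int:
    # Mirrors the added branch: an int passes through, a square sequence collapses to one side.
    if isinstance(size, int):
        return size
    if isinstance(size, abc.Sequence):
        if size[0] != size[1]:
            raise AssertionError("Currently we only support a square image where width = height")
        return size[0]
    raise ValueError('weights.meta["size"] should have type of either an int or a Sequence[int]')

assert resolve_image_size(224) == 224
assert resolve_image_size((384, 384)) == 384
```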
@@ -313,6 +324,12 @@ def _vision_transformer(
"interpolation": InterpolationMode.BILINEAR,
}

_COMMON_SWAG_META = {
**COMMON_META,
"recipe": "https://github.com/facebookresearch/SWAG",
"interpolation": InterpolationMode.BICUBIC,
}


class ViT_B_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
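Reviewer note (illustration): because `_COMMON_SWAG_META` unpacks `**COMMON_META` first, keys listed afterwards win, so the SWAG dict inherits all shared metadata but overrides `interpolation` and adds its own `recipe`. A toy sketch with placeholder values:

```python
# Placeholder values; the real dicts hold InterpolationMode members and more keys.
COMMON_META = {"categories": ["tench", "goldfish"], "interpolation": "bilinear"}
_COMMON_SWAG_META = {
    **COMMON_META,                 # inherit all shared keys first
    "recipe": "https://github.com/facebookresearch/SWAG",
    "interpolation": "bicubic",    # a later key overrides the inherited one
}
assert _COMMON_SWAG_META["interpolation"] == "bicubic"
assert _COMMON_SWAG_META["categories"] == COMMON_META["categories"]
```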
@@ -328,6 +345,20 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 95.318,
},
)
IMAGENET1K_SWAG_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
transforms=partial(
ImageClassification, resize_size=384, interpolation=InterpolationMode.BICUBIC, crop_size=384
),
meta={
**_COMMON_SWAG_META,
"num_params": 86859496,
"size": (384, 384),
"min_size": (384, 384),
"acc@1": 85.304,
"acc@5": 97.650,
},
)
DEFAULT = IMAGENET1K_V1


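Usage note (not part of the patch): once released, the SWAG variant should load like any other weight, with its 384px bicubic preprocessing coming from `weights.transforms()`. A minimal sketch, assuming a torchvision version where this enum has landed:

```python
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
model = vit_b_16(weights=weights).eval()   # image_size=384 is taken from weights.meta["size"]

preprocess = weights.transforms()          # resize/crop to 384 with bicubic interpolation
batch = preprocess(torch.rand(3, 500, 500)).unsqueeze(0)
with torch.no_grad():
    logits = model(batch)
```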