From c801bf06aa6c2b02501bf5acdd19483895974de3 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 31 Mar 2022 14:46:39 +0100 Subject: [PATCH 01/16] Add vit_b_16_swag --- torchvision/models/vision_transformer.py | 34 +++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index fb34cf3c8e1..434d82d46ab 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -1,3 +1,4 @@ +import collections.abc as abc import math from collections import OrderedDict from functools import partial @@ -284,7 +285,15 @@ def _vision_transformer( progress: bool, **kwargs: Any, ) -> VisionTransformer: - image_size = kwargs.pop("image_size", 224) + + image_size = None + if "image_size" in kwargs: + image_size = kwargs.pop("image_size", None) + if image_size is None and weights is not None and "size" in weights.meta: + image_size = weights.meta["size"] + if isinstance(image_size, abc.Sequence): + image_size = image_size[0] + image_size = image_size or 224 if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) @@ -313,6 +322,15 @@ def _vision_transformer( "interpolation": InterpolationMode.BILINEAR, } +_COMMON_SWAG_META = { + "task": "image_classification", + "architecture": "ViT", + "publication_year": 2022, + "recipe": "https://github.com/facebookresearch/SWAG", + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BICUBIC, +} + class ViT_B_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( @@ -328,6 +346,20 @@ class ViT_B_16_Weights(WeightsEnum): "acc@5": 95.318, }, ) + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth", + transforms=partial( + ImageClassification, resize_size=384, interpolation=InterpolationMode.BICUBIC, crop_size=384 + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 86859496, + "size": (384, 384), + "min_size": (384, 384), + "acc@1": 85.29, + "acc@5": 97.65, + }, + ) DEFAULT = IMAGENET1K_V1 From 9e13f79b47513bbbfd0fb3ffbf455fce05e691ed Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 31 Mar 2022 16:15:20 +0100 Subject: [PATCH 02/16] Better handling idiom for image_size, edit test_extended_model to handle case where number of param differ from default due to different image size input --- test/test_extended_models.py | 3 ++- torchvision/models/vision_transformer.py | 22 ++++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index a07b501e15b..994a3b7036a 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -116,7 +116,8 @@ def test_schema_meta_validation(model_fn): incorrect_params.append(w) else: if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"): - incorrect_params.append(w) + if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()): + incorrect_params.append(w) if not w.name.isupper(): bad_names.append(w) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 434d82d46ab..979d1b55ae1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -285,18 +285,20 @@ def _vision_transformer( progress: bool, **kwargs: Any, ) -> VisionTransformer: - - image_size = None - if "image_size" in kwargs: - image_size = 
kwargs.pop("image_size", None) - if image_size is None and weights is not None and "size" in weights.meta: - image_size = weights.meta["size"] - if isinstance(image_size, abc.Sequence): - image_size = image_size[0] - image_size = image_size or 224 - if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "size" in weights.meta: + if isinstance(weights.meta["size"], int): + _ovewrite_named_param(kwargs, "image_size", weights.meta["size"]) + elif isinstance(weights.meta["size"], abc.Sequence): + torch._assert( + weights.meta["size"][0] == weights.meta["size"][1], + "Currently we only support a square image where width = height", + ) + _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0]) + else: + raise ValueError('weights.meta["size"] should have type of either an int or a Sequence[int]') + image_size = kwargs.pop("image_size", 224) model = VisionTransformer( image_size=image_size, From 17071719cc54a726b8fd34843c1a0a10ee8cbf7b Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 31 Mar 2022 16:42:04 +0100 Subject: [PATCH 03/16] Update the accuracy to the experiment result on torchvision model --- torchvision/models/vision_transformer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 979d1b55ae1..fbb5ebed668 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -325,11 +325,8 @@ def _vision_transformer( } _COMMON_SWAG_META = { - "task": "image_classification", - "architecture": "ViT", - "publication_year": 2022, + **COMMON_META, "recipe": "https://github.com/facebookresearch/SWAG", - "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BICUBIC, } @@ -358,8 +355,8 @@ class ViT_B_16_Weights(WeightsEnum): "num_params": 86859496, "size": (384, 384), "min_size": (384, 384), - "acc@1": 85.29, - "acc@5": 97.65, + "acc@1": 85.304, + "acc@5": 97.650, }, ) DEFAULT = IMAGENET1K_V1 From bd8b1a87e13a1a822c4a03b20cb7db28be9aa6e5 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 31 Mar 2022 17:04:57 +0100 Subject: [PATCH 04/16] Fix typo missing underscore --- torchvision/models/vision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index fbb5ebed668..403d3367cdd 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -325,7 +325,7 @@ def _vision_transformer( } _COMMON_SWAG_META = { - **COMMON_META, + **_COMMON_META, "recipe": "https://github.com/facebookresearch/SWAG", "interpolation": InterpolationMode.BICUBIC, } From 6c765a5993cf7d9655c2a3a1712985be854573b5 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 31 Mar 2022 17:57:02 +0100 Subject: [PATCH 05/16] raise exception instead of torch._assert, add back publication year (accidentally deleted) --- torchvision/models/vision_transformer.py | 30 ++++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 403d3367cdd..29774b6701b 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -1,8 +1,7 @@ -import collections.abc as abc import math from collections import OrderedDict from functools import 
partial -from typing import Any, Callable, List, NamedTuple, Optional +from typing import Any, Callable, List, NamedTuple, Optional, Sequence import torch import torch.nn as nn @@ -287,17 +286,18 @@ def _vision_transformer( ) -> VisionTransformer: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "size" in weights.meta: - if isinstance(weights.meta["size"], int): - _ovewrite_named_param(kwargs, "image_size", weights.meta["size"]) - elif isinstance(weights.meta["size"], abc.Sequence): - torch._assert( - weights.meta["size"][0] == weights.meta["size"][1], - "Currently we only support a square image where width = height", + if isinstance(weights.meta["size"], int): + _ovewrite_named_param(kwargs, "image_size", weights.meta["size"]) + elif isinstance(weights.meta["size"], Sequence): + if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]: + raise ValueError( + f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height' ) - _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0]) - else: - raise ValueError('weights.meta["size"] should have type of either an int or a Sequence[int]') + _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0]) + else: + raise ValueError( + f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]' + ) image_size = kwargs.pop("image_size", 224) model = VisionTransformer( @@ -326,6 +326,7 @@ def _vision_transformer( _COMMON_SWAG_META = { **_COMMON_META, + "publication_year": 2022, "recipe": "https://github.com/facebookresearch/SWAG", "interpolation": InterpolationMode.BICUBIC, } @@ -348,7 +349,10 @@ class ViT_B_16_Weights(WeightsEnum): IMAGENET1K_SWAG_V1 = Weights( url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth", transforms=partial( - ImageClassification, resize_size=384, interpolation=InterpolationMode.BICUBIC, crop_size=384 + ImageClassification, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BICUBIC, ), meta={ **_COMMON_SWAG_META, From e444c5a1a00fe3c7b41c6f93f17c82de9d271c41 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 13:47:17 +0100 Subject: [PATCH 06/16] Add license information on meta and readme --- README.rst | 7 +++++++ torchvision/models/vision_transformer.py | 1 + 2 files changed, 8 insertions(+) diff --git a/README.rst b/README.rst index a65ffd8340a..a689bf25b91 100644 --- a/README.rst +++ b/README.rst @@ -185,3 +185,10 @@ Disclaimer on Datasets This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license. If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community! + +Disclaimer on Models +==================== + +Pretrained models provided in this library may have their own license or term and condition to use. It is your responsibility to determine whether you have permission to use the model for your use case. + +More specifically, SWAG models are released under the CC-BY-NC 4.0 license. 
See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.

diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 29774b6701b..4a550ce39bb 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -328,6 +328,7 @@ def _vision_transformer(
     **_COMMON_META,
     "publication_year": 2022,
     "recipe": "https://github.com/facebookresearch/SWAG",
+    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
     "interpolation": InterpolationMode.BICUBIC,
 }

From 54aa8cff32186d44f5ee9b397b8fb3487aa45e7e Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Fri, 1 Apr 2022 13:55:34 +0100
Subject: [PATCH 07/16] Improve wording and fix typo for pretrained model license in readme

---
 README.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.rst b/README.rst
index a689bf25b91..32ee5bb90ee 100644
--- a/README.rst
+++ b/README.rst
@@ -186,9 +186,9 @@ This is a utility library that downloads and prepares public datasets. We do not
 If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
-Disclaimer on Models
-====================
+Pre-trained Model License
+=========================
-Pretrained models provided in this library may have their own license or term and condition to use. It is your responsibility to determine whether you have permission to use the model for your use case.
+Pre-trained models provided in this library may have their own license or terms and conditions to use. It is your responsibility to determine whether you have permission to use the model for your use case.
 More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.
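A minimal usage sketch for the SWAG ViT-B/16 weights added in PATCH 01, assuming the multi-weight enum API this series builds on; the 384px bicubic preprocessing and the image_size override come from the diffs above, and the random tensor below is only a stand-in for a real decoded image::

    import torch
    from torchvision.models import vit_b_16, ViT_B_16_Weights

    # The builder reads weights.meta["size"] and overrides image_size to 384.
    weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
    model = vit_b_16(weights=weights).eval()

    # Bicubic resize + 384x384 center crop, matching the transforms partial in the diff.
    preprocess = weights.transforms()
    batch = preprocess(torch.rand(3, 500, 400)).unsqueeze(0)  # stand-in for a real image

    with torch.no_grad():
        top5 = model(batch).softmax(dim=-1).topk(5)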
From f9c32ebdfb7a2220a35b35c704d47f62a5e6bd31 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 14:37:57 +0100 Subject: [PATCH 08/16] Add vit_l_16 weight --- torchvision/models/vision_transformer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 4a550ce39bb..340852812e1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -398,6 +398,24 @@ class ViT_L_16_Weights(WeightsEnum): "acc@5": 94.638, }, ) + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth", + transforms=partial( + ImageClassification, + crop_size=512, + resize_size=512, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 305174504, + "size": (512, 512), + "min_size": (512, 512), + # Still mock: + "acc@1": 88.07, + "acc@5": 98.51, + }, + ) DEFAULT = IMAGENET1K_V1 From 4cf4eff70b89044f344ebfa5e3c642683ed2ed12 Mon Sep 17 00:00:00 2001 From: YosuaMichael Date: Fri, 1 Apr 2022 15:15:51 +0100 Subject: [PATCH 09/16] Update README.rst Co-authored-by: Vasilis Vryniotis --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 32ee5bb90ee..bf1a9b53106 100644 --- a/README.rst +++ b/README.rst @@ -189,6 +189,6 @@ If you're a dataset owner and wish to update any part of it (description, citati Pre-trained Model License ========================= -Pre-trained models provided in this library may have their own license or terms and conditions to use. It is your responsibility to determine whether you have permission to use the model for your use case. +The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case. More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE `_ for additional details. 
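The relaxed ``num_params`` check in PATCH 02 and the 305174504 figure in PATCH 08 are related: a ViT's position-embedding table grows with ``image_size``, so the 512px SWAG checkpoint carries more parameters than the 224px default weights. A rough sketch of that bookkeeping, assuming the ``image_size`` keyword accepted by the builder in this series and ViT-L/16's hidden dimension of 1024::

    from torchvision.models import vit_l_16

    def n_params(model):
        return sum(p.numel() for p in model.parameters())  # same count the test uses

    base = vit_l_16()                      # image_size=224 -> (224 // 16) ** 2 + 1 = 197 tokens
    swag_sized = vit_l_16(image_size=512)  # (512 // 16) ** 2 + 1 = 1025 tokens

    # The gap is exactly the extra positional embeddings, (1025 - 197) * 1024 = 847872,
    # which bridges the default count and the 305174504 recorded in meta["num_params"].
    print(n_params(swag_sized) - n_params(base))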
From 9230f40de89cd009ef67e155c41ae5e0408b9582 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 15:21:32 +0100 Subject: [PATCH 10/16] Update the accuracy meta on vit_l_16_swag model to result from our experiment --- torchvision/models/vision_transformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 340852812e1..59da51c1bd9 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -411,9 +411,8 @@ class ViT_L_16_Weights(WeightsEnum): "num_params": 305174504, "size": (512, 512), "min_size": (512, 512), - # Still mock: - "acc@1": 88.07, - "acc@5": 98.51, + "acc@1": 88.064, + "acc@5": 98.512, }, ) DEFAULT = IMAGENET1K_V1 From ce6eb3e2f764b389de0f5f0b66b5f640ee7e59d3 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 16:04:38 +0100 Subject: [PATCH 11/16] Add vit_h_14_swag model --- torchvision/models/vision_transformer.py | 48 ++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 59da51c1bd9..0b2e1795e78 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -20,10 +20,12 @@ "ViT_B_32_Weights", "ViT_L_16_Weights", "ViT_L_32_Weights", + "ViT_H_14_Weights", "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", + "vit_h_14", ] @@ -435,6 +437,28 @@ class ViT_L_32_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +class ViT_H_14_Weights(WeightsEnum): + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth", + transforms=partial( + ImageClassification, + crop_size=518, + resize_size=518, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 633470440, + "size": (518, 518), + "min_size": (518, 518), + # Still mock + "acc@1": 88.55, + "acc@5": 98.69, + }, + ) + DEFAULT = IMAGENET1K_SWAG_V1 + + @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ @@ -531,6 +555,30 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru ) +@handle_legacy_interface(weights=("pretrained", ViT_H_14_Weights.IMAGENET1K_SWAG_V1)) +def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_h_14 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
+
+    Args:
+        weights (ViT_H_14_Weights, optional): The pretrained weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,

From ff76a53f3b70f3f598bbe9f8128d86e264877428 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Fri, 1 Apr 2022 17:02:55 +0100
Subject: [PATCH 12/16] Add accuracy from experiments

---
 torchvision/models/vision_transformer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 0b2e1795e78..7d5ab43a774 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -451,9 +451,8 @@ class ViT_H_14_Weights(WeightsEnum):
             "num_params": 633470440,
             "size": (518, 518),
             "min_size": (518, 518),
-            # Still mock
-            "acc@1": 88.55,
-            "acc@5": 98.69,
+            "acc@1": 88.552,
+            "acc@5": 98.694,
         },
     )
     DEFAULT = IMAGENET1K_SWAG_V1

From e87454838701b8e6d52c471df5b9f1fb4b5d2d1b Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Fri, 1 Apr 2022 17:32:21 +0100
Subject: [PATCH 13/16] Add vit_h_14 model to hubconf.py

---
 hubconf.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/hubconf.py b/hubconf.py
index c3de4f2da9a..bbd5da52b13 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -67,4 +67,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )

From 2ca4ac4153b64d61609df3b583d4b0a0ede157af Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Fri, 1 Apr 2022 17:52:20 +0100
Subject: [PATCH 14/16] Add docs and expected pkl file for test

---
 docs/source/models.rst | 3 +++
 test/expect/ModelTester.test_vit_h_14_expect.pkl | Bin 0 -> 939 bytes
 2 files changed, 3 insertions(+)
 create mode 100644 test/expect/ModelTester.test_vit_h_14_expect.pkl

diff --git a/docs/source/models.rst b/docs/source/models.rst
index 16825d2b8b2..f84d9c7fd1a 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -92,6 +92,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()
     convnext_tiny = models.convnext_tiny()
     convnext_small = models.convnext_small()
     convnext_base = models.convnext_base()
@@ -213,6 +214,7 @@ vit_b_16 81.072 95.318
 vit_b_32 75.912 92.466
 vit_l_16 79.662 94.638
 vit_l_32 76.972 93.070
+vit_h_14 88.552 98.694
 convnext_tiny 82.520 96.146
 convnext_small 83.616 96.650
 convnext_base 84.062 96.870
@@ -434,6 +436,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14

 ConvNeXt
 --------

diff --git a/test/expect/ModelTester.test_vit_h_14_expect.pkl b/test/expect/ModelTester.test_vit_h_14_expect.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..1f846beb6a0bccf8b545f5a67b74482015cc878b
GIT binary patch
literal 939
[base85-encoded binary data (939 bytes) omitted]

Date: Mon, 4 Apr 2022 10:36:46 +0100
Subject: [PATCH 15/16] Remove legacy compatibility for ViT_H_14 model

Co-authored-by: Vasilis Vryniotis
---
torchvision/models/vision_transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 7d5ab43a774..de2e61c440a 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -554,7 +554,6 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru ) -@handle_legacy_interface(weights=("pretrained", ViT_H_14_Weights.IMAGENET1K_SWAG_V1)) def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_h_14 architecture from From 02be29672cfa6ea754114b28babdabd6e8f1117a Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Mon, 4 Apr 2022 17:54:12 +0100 Subject: [PATCH 16/16] Test vit_h_14 with smaller image_size to speedup the test --- test/test_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index 0fbf45b9750..5e0cc742d84 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -280,6 +280,10 @@ def _check_input_backprop(model, inputs): "rpn_pre_nms_top_n_test": 1000, "rpn_post_nms_top_n_test": 1000, }, + "vit_h_14": { + "image_size": 56, + "input_shape": (1, 3, 56, 56), + }, } # speeding up slow models: slow_models = [