From e77d8fe921f15a2562cfeda289e26c799089ea5b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Feb 2022 13:15:52 +0000 Subject: [PATCH 1/3] Graduate ConvNeXt to main TorchVision area. --- docs/source/models.rst | 28 ++- hubconf.py | 1 + torchvision/models/__init__.py | 1 + torchvision/models/convnext.py | 271 +++++++++++++++++++++++ torchvision/prototype/models/convnext.py | 187 +--------------- 5 files changed, 299 insertions(+), 189 deletions(-) create mode 100644 torchvision/models/convnext.py diff --git a/docs/source/models.rst b/docs/source/models.rst index a37748b0a3b..58bd0d81cd0 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -89,6 +89,10 @@ You can construct a model with random weights by calling its constructor: vit_b_32 = models.vit_b_32() vit_l_16 = models.vit_l_16() vit_l_32 = models.vit_l_32() + convnext_tiny = models.convnext_tiny() + convnext_small = models.convnext_small() + convnext_base = models.convnext_base() + convnext_large = models.convnext_large() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -136,6 +140,10 @@ These can be constructed by passing ``pretrained=True``: vit_b_32 = models.vit_b_32(pretrained=True) vit_l_16 = models.vit_l_16(pretrained=True) vit_l_32 = models.vit_l_32(pretrained=True) + convnext_tiny = models.convnext_tiny(pretrained=True) + convnext_small = models.convnext_small(pretrained=True) + convnext_base = models.convnext_base(pretrained=True) + convnext_large = models.convnext_large(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_HOME` environment variable. 
See @@ -248,10 +256,10 @@ vit_b_16 81.072 95.318 vit_b_32 75.912 92.466 vit_l_16 79.662 94.638 vit_l_32 76.972 93.070 -convnext_tiny (prototype) 82.520 96.146 -convnext_small (prototype) 83.616 96.650 -convnext_base (prototype) 84.062 96.870 -convnext_large (prototype) 84.414 96.976 +convnext_tiny 82.520 96.146 +convnext_small 83.616 96.650 +convnext_base 84.062 96.870 +convnext_large 84.414 96.976 ================================ ============= ============= @@ -467,6 +475,18 @@ VisionTransformer vit_l_16 vit_l_32 +ConvNeXt +-------- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + convnext_tiny + convnext_small + convnext_base + convnext_large + Quantized Models ---------------- diff --git a/hubconf.py b/hubconf.py index 2b2eeb1c166..5c2ad8e9e0d 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,6 +2,7 @@ dependencies = ["torch"] from torchvision.models.alexnet import alexnet +from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 from torchvision.models.efficientnet import ( efficientnet_b0, diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 22e2e45d4ce..16495e8552e 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,4 +1,5 @@ from .alexnet import * +from .convnext import * from .resnet import * from .vgg import * from .squeezenet import * diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py new file mode 100644 index 00000000000..1fdf4c6ba55 --- /dev/null +++ b/torchvision/models/convnext.py @@ -0,0 +1,271 @@ +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from .._internally_replaced_utils import load_state_dict_from_url +from ..ops.misc import ConvNormActivation +from 
..ops.stochastic_depth import StochasticDepth +from ..utils import _log_api_usage_once + + +__all__ = [ + "ConvNeXt", + "convnext_tiny", + "convnext_small", + "convnext_base", + "convnext_large", +] + + +model_urls: Dict[str, Optional[str]] = { + "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth", + "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth", + "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth", + "convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth", +} + + +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +class Permute(nn.Module): + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.permute(x, self.dims) + + +class CNBlock(nn.Module): + def __init__( + self, + dim, + layer_scale: float, + stochastic_depth_prob: float, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.block = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True), + Permute([0, 2, 3, 1]), + norm_layer(dim), + nn.Linear(in_features=dim, out_features=4 * dim, bias=True), + nn.GELU(), + nn.Linear(in_features=4 * dim, out_features=dim, bias=True), + Permute([0, 3, 1, 2]), + ) + self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + def forward(self, input: Tensor) -> Tensor: + result = self.layer_scale * self.block(input) + result = self.stochastic_depth(result) + result += input + return result + + +class CNBlockConfig: + # Stores information listed at Section 3 of the ConvNeXt paper + def 
__init__( + self, + input_channels: int, + out_channels: Optional[int], + num_layers: int, + ) -> None: + self.input_channels = input_channels + self.out_channels = out_channels + self.num_layers = num_layers + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "input_channels={input_channels}" + s += ", out_channels={out_channels}" + s += ", num_layers={num_layers}" + s += ")" + return s.format(**self.__dict__) + + +class ConvNeXt(nn.Module): + def __init__( + self, + block_setting: List[CNBlockConfig], + stochastic_depth_prob: float = 0.0, + layer_scale: float = 1e-6, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + if not block_setting: + raise ValueError("The block_setting should not be empty") + elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): + raise TypeError("The block_setting should be List[CNBlockConfig]") + + if block is None: + block = CNBlock + + if norm_layer is None: + norm_layer = partial(LayerNorm2d, eps=1e-6) + + layers: List[nn.Module] = [] + + # Stem + firstconv_output_channels = block_setting[0].input_channels + layers.append( + ConvNormActivation( + 3, + firstconv_output_channels, + kernel_size=4, + stride=4, + padding=0, + norm_layer=norm_layer, + activation_layer=None, + bias=True, + ) + ) + + total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) + stage_block_id = 0 + for cnf in block_setting: + # Bottlenecks + stage: List[nn.Module] = [] + for _ in range(cnf.num_layers): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + stage.append(block(cnf.input_channels, layer_scale, sd_prob)) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + if cnf.out_channels is not 
None: + # Downsampling + layers.append( + nn.Sequential( + norm_layer(cnf.input_channels), + nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2), + ) + ) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + lastblock = block_setting[-1] + lastconv_output_channels = ( + lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels + ) + self.classifier = nn.Sequential( + norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes) + ) + + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + x = self.avgpool(x) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _convnext( + arch: str, + block_setting: List[CNBlockConfig], + stochastic_depth_prob: float, + pretrained: bool, + progress: bool, + **kwargs: Any, +) -> ConvNeXt: + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) + if pretrained: + if arch not in model_urls: + raise ValueError(f"No checkpoint is available for model type {arch}") + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: + r"""ConvNeXt Tiny model architecture from the + `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + block_setting = [ + CNBlockConfig(96, 192, 3), + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 9), + CNBlockConfig(768, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) + return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + + +def convnext_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: + r"""ConvNeXt Small model architecture from the + `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + block_setting = [ + CNBlockConfig(96, 192, 3), + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 27), + CNBlockConfig(768, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) + return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + + +def convnext_base(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: + r"""ConvNeXt Base model architecture from the + `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + block_setting = [ + CNBlockConfig(128, 256, 3), + CNBlockConfig(256, 512, 3), + CNBlockConfig(512, 1024, 27), + CNBlockConfig(1024, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) + return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + + +def convnext_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: + r"""ConvNeXt Large model architecture from the + `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + block_setting = [ + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 3), + CNBlockConfig(768, 1536, 27), + CNBlockConfig(1536, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) + return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) \ No newline at end of file diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index f8f91307ed1..f6a2bf48539 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -1,191 +1,15 @@ from functools import partial -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, List, Optional -import torch -from torch import nn, Tensor -from torch.nn import functional as F from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode -from ...ops.misc import ConvNormActivation -from ...ops.stochastic_depth import StochasticDepth -from ...utils import _log_api_usage_once +from ...models.convnext import ConvNeXt, CNBlockConfig from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = [ - "ConvNeXt", - "ConvNeXt_Tiny_Weights", - "ConvNeXt_Small_Weights", - "ConvNeXt_Base_Weights", - "ConvNeXt_Large_Weights", - "convnext_tiny", - "convnext_small", - "convnext_base", - "convnext_large", -] - - -class LayerNorm2d(nn.LayerNorm): - def forward(self, x: Tensor) -> Tensor: - x = x.permute(0, 2, 3, 1) - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - x = x.permute(0, 3, 1, 2) - return x - - -class Permute(nn.Module): - def __init__(self, dims: List[int]): - super().__init__() - self.dims = dims - - def forward(self, x): - return torch.permute(x, self.dims) - - -class CNBlock(nn.Module): - def __init__( - self, - dim, - layer_scale: float, - stochastic_depth_prob: float, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super().__init__() - if norm_layer is None: - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.block = nn.Sequential( - nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True), - Permute([0, 2, 3, 1]), - norm_layer(dim), - nn.Linear(in_features=dim, out_features=4 * dim, bias=True), - nn.GELU(), - nn.Linear(in_features=4 * dim, out_features=dim, bias=True), - Permute([0, 3, 1, 2]), - ) - self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) - self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") - - def forward(self, input: Tensor) -> Tensor: - result = self.layer_scale * self.block(input) - result = self.stochastic_depth(result) - result += input - return result - - -class CNBlockConfig: - # Stores information listed at Section 3 of the ConvNeXt paper - def __init__( - self, - input_channels: int, - out_channels: Optional[int], - num_layers: int, - ) -> None: - self.input_channels = input_channels - self.out_channels = out_channels - self.num_layers = num_layers - - def __repr__(self) -> str: - s = self.__class__.__name__ 
+ "(" - s += "input_channels={input_channels}" - s += ", out_channels={out_channels}" - s += ", num_layers={num_layers}" - s += ")" - return s.format(**self.__dict__) - - -class ConvNeXt(nn.Module): - def __init__( - self, - block_setting: List[CNBlockConfig], - stochastic_depth_prob: float = 0.0, - layer_scale: float = 1e-6, - num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any, - ) -> None: - super().__init__() - _log_api_usage_once(self) - - if not block_setting: - raise ValueError("The block_setting should not be empty") - elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): - raise TypeError("The block_setting should be List[CNBlockConfig]") - - if block is None: - block = CNBlock - - if norm_layer is None: - norm_layer = partial(LayerNorm2d, eps=1e-6) - - layers: List[nn.Module] = [] - - # Stem - firstconv_output_channels = block_setting[0].input_channels - layers.append( - ConvNormActivation( - 3, - firstconv_output_channels, - kernel_size=4, - stride=4, - padding=0, - norm_layer=norm_layer, - activation_layer=None, - bias=True, - ) - ) - - total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) - stage_block_id = 0 - for cnf in block_setting: - # Bottlenecks - stage: List[nn.Module] = [] - for _ in range(cnf.num_layers): - # adjust stochastic depth probability based on the depth of the stage block - sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) - stage.append(block(cnf.input_channels, layer_scale, sd_prob)) - stage_block_id += 1 - layers.append(nn.Sequential(*stage)) - if cnf.out_channels is not None: - # Downsampling - layers.append( - nn.Sequential( - norm_layer(cnf.input_channels), - nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2), - ) - ) - - self.features = nn.Sequential(*layers) - self.avgpool = nn.AdaptiveAvgPool2d(1) - - lastblock 
= block_setting[-1] - lastconv_output_channels = ( - lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels - ) - self.classifier = nn.Sequential( - norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes) - ) - - for m in self.modules(): - if isinstance(m, (nn.Conv2d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - def _forward_impl(self, x: Tensor) -> Tensor: - x = self.features(x) - x = self.avgpool(x) - x = self.classifier(x) - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - def _convnext( block_setting: List[CNBlockConfig], stochastic_depth_prob: float, @@ -274,13 +98,6 @@ class ConvNeXt_Large_Weights(WeightsEnum): @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - r"""ConvNeXt Tiny model architecture from the - `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. - - Args: - weights (ConvNeXt_Tiny_Weights, optional): The pre-trained weights of the model - progress (bool): If True, displays a progress bar of the download to stderr - """ weights = ConvNeXt_Tiny_Weights.verify(weights) block_setting = [ From 49c443312c6da31ef2dcb16729af63ae362a91b5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Feb 2022 13:18:39 +0000 Subject: [PATCH 2/3] Linter and all var.
--- torchvision/models/convnext.py | 2 +- torchvision/prototype/models/convnext.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 1fdf4c6ba55..b026807df44 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -268,4 +268,4 @@ def convnext_large(pretrained: bool = False, progress: bool = True, **kwargs: An CNBlockConfig(1536, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) \ No newline at end of file + return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index f6a2bf48539..ab9d08fbd3a 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -10,6 +10,19 @@ from ._utils import handle_legacy_interface, _ovewrite_named_param +__all__ = [ + "ConvNeXt", + "ConvNeXt_Tiny_Weights", + "ConvNeXt_Small_Weights", + "ConvNeXt_Base_Weights", + "ConvNeXt_Large_Weights", + "convnext_tiny", + "convnext_small", + "convnext_base", + "convnext_large", +] + + def _convnext( block_setting: List[CNBlockConfig], stochastic_depth_prob: float, From 7e8c48268f15dba4be8b337086e113dab242adb1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Feb 2022 14:23:59 +0000 Subject: [PATCH 3/3] Renaming var and making named params mandatory. 
--- torchvision/models/convnext.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index b026807df44..9067b6876fd 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -20,7 +20,7 @@ ] -model_urls: Dict[str, Optional[str]] = { +_MODELS_URLS: Dict[str, Optional[str]] = { "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth", "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth", "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth", @@ -196,14 +196,14 @@ def _convnext( ) -> ConvNeXt: model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) if pretrained: - if arch not in model_urls: + if arch not in _MODELS_URLS: raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) model.load_state_dict(state_dict) return model -def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Tiny model architecture from the `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. Args: @@ -220,7 +220,7 @@ def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) -def convnext_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Small model architecture from the `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Args: @@ -237,7 +237,7 @@ def convnext_small(pretrained: bool = False, progress: bool = True, **kwargs: An return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) -def convnext_base(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Base model architecture from the `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. Args: @@ -254,7 +254,7 @@ def convnext_base(pretrained: bool = False, progress: bool = True, **kwargs: Any return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) -def convnext_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Large model architecture from the `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. Args: