diff --git a/docs/source/models.rst b/docs/source/models.rst
index 82eb3170e78..a37748b0a3b 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -249,6 +249,9 @@ vit_b_32 75.912 92.466
 vit_l_16 79.662 94.638
 vit_l_32 76.972 93.070
 convnext_tiny (prototype) 82.520 96.146
+convnext_small (prototype) 83.616 96.650
+convnext_base (prototype) 84.062 96.870
+convnext_large (prototype) 84.414 96.976
 ================================ ============= =============
diff --git a/references/classification/README.md b/references/classification/README.md
index 0fb27eac7cc..e75336f23ca 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -201,11 +201,12 @@ and `--batch_size 64`.
 
 ### ConvNeXt
 ```
 torchrun --nproc_per_node=8 train.py\
---model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
+--model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
 --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
---train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
+--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4
 ```
+Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step; see the `Weights` entry of each variant for the exact value.
 
 Note that the above command corresponds to training on a single node with 8 GPUs.
 For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
diff --git a/test/expect/ModelTester.test_convnext_base_expect.pkl b/test/expect/ModelTester.test_convnext_base_expect.pkl
new file mode 100644
index 00000000000..09148743b10
Binary files /dev/null and b/test/expect/ModelTester.test_convnext_base_expect.pkl differ
diff --git a/test/expect/ModelTester.test_convnext_large_expect.pkl b/test/expect/ModelTester.test_convnext_large_expect.pkl
new file mode 100644
index 00000000000..98a85bc27f5
Binary files /dev/null and b/test/expect/ModelTester.test_convnext_large_expect.pkl differ
diff --git a/test/expect/ModelTester.test_convnext_small_expect.pkl b/test/expect/ModelTester.test_convnext_small_expect.pkl
new file mode 100644
index 00000000000..f5bf3b800bf
Binary files /dev/null and b/test/expect/ModelTester.test_convnext_small_expect.pkl differ
diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py
index 7fb3026b4e0..f8f91307ed1 100644
--- a/torchvision/prototype/models/convnext.py
+++ b/torchvision/prototype/models/convnext.py
@@ -15,47 +15,56 @@ from ._utils import handle_legacy_interface, _ovewrite_named_param
 
-__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]
+__all__ = [
+    "ConvNeXt",
+    "ConvNeXt_Tiny_Weights",
+    "ConvNeXt_Small_Weights",
+    "ConvNeXt_Base_Weights",
+    "ConvNeXt_Large_Weights",
+    "convnext_tiny",
+    "convnext_small",
+    "convnext_base",
+    "convnext_large",
+]
 
 
 class LayerNorm2d(nn.LayerNorm):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        self.channels_last = kwargs.pop("channels_last", False)
-        super().__init__(*args, **kwargs)
-
     def forward(self, x: Tensor) -> Tensor:
-        # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
-        if not self.channels_last:
-            x = x.permute(0, 2, 3, 1)
+        x = x.permute(0, 2, 3, 1)
         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        if not self.channels_last:
-            x = x.permute(0, 3, 1, 2)
+        x = x.permute(0, 3, 1, 2)
         return x
 
 
+class Permute(nn.Module):
+    def __init__(self, dims: List[int]):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x):
+        return torch.permute(x, self.dims)
+
+
 class CNBlock(nn.Module):
     def __init__(
-        self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
+        self,
+        dim,
+        layer_scale: float,
+        stochastic_depth_prob: float,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
         self.block = nn.Sequential(
-            ConvNormActivation(
-                dim,
-                dim,
-                kernel_size=7,
-                groups=dim,
-                norm_layer=norm_layer,
-                activation_layer=None,
-                bias=True,
-            ),
-            ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
-            ConvNormActivation(
-                4 * dim,
-                dim,
-                kernel_size=1,
-                norm_layer=None,
-                activation_layer=None,
-            ),
+            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
+            Permute([0, 2, 3, 1]),
+            norm_layer(dim),
+            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
+            nn.GELU(),
+            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
+            Permute([0, 3, 1, 2]),
         )
         self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
         self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -138,7 +147,7 @@ def __init__(
             for _ in range(cnf.num_layers):
                 # adjust stochastic depth probability based on the depth of the stage block
                 sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
-                stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
+                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                 stage_block_id += 1
             layers.append(nn.Sequential(*stage))
             if cnf.out_channels is not None:
@@ -177,20 +186,43 @@ def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
 
+def _convnext(
+    block_setting: List[CNBlockConfig],
+    stochastic_depth_prob: float,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> ConvNeXt:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
+
+_COMMON_META = {
+    "task": "image_classification",
+    "architecture": "ConvNeXt",
+    "publication_year": 2022,
+    "size": (224, 224),
+    "min_size": (32, 32),
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+}
+
+
 class ConvNeXt_Tiny_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
+        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
         transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
         meta={
-            "task": "image_classification",
-            "architecture": "ConvNeXt",
-            "publication_year": 2022,
+            **_COMMON_META,
             "num_params": 28589128,
-            "size": (224, 224),
-            "min_size": (32, 32),
-            "categories": _IMAGENET_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
             "acc@1": 82.520,
             "acc@5": 96.146,
         },
@@ -198,9 +230,51 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
     DEFAULT = IMAGENET1K_V1
 
 
+class ConvNeXt_Small_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
+        meta={
+            **_COMMON_META,
+            "num_params": 50223688,
+            "acc@1": 83.616,
+            "acc@5": 96.650,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Base_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 88591464,
+            "acc@1": 84.062,
+            "acc@5": 96.870,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Large_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 197767336,
+            "acc@1": 84.414,
+            "acc@5": 96.976,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
 @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
 def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
-    r"""ConvNeXt model architecture from the
+    r"""ConvNeXt Tiny model architecture from the
     `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
 
     Args:
@@ -209,9 +283,6 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
     """
     weights = ConvNeXt_Tiny_Weights.verify(weights)
 
-    if weights is not None:
-        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
-
     block_setting = [
         CNBlockConfig(96, 192, 3),
         CNBlockConfig(192, 384, 3),
@@ -219,9 +290,50 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
         CNBlockConfig(768, None, 3),
     ]
     stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
-    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
 
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
 
-    return model
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
+def convnext_small(
+    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    weights = ConvNeXt_Small_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 27),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
+def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    weights = ConvNeXt_Base_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(128, 256, 3),
+        CNBlockConfig(256, 512, 3),
+        CNBlockConfig(512, 1024, 27),
+        CNBlockConfig(1024, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
+def convnext_large(
+    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    weights = ConvNeXt_Large_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 3),
+        CNBlockConfig(768, 1536, 27),
+        CNBlockConfig(1536, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
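
For anyone smoke-testing this change, the sketch below shows how the newly added variants are intended to be consumed through the prototype multi-weight API. It uses only names introduced in this diff (`convnext_small`, `ConvNeXt_Small_Weights`) plus the `weights.transforms()` / `weights.meta` conventions that the existing `convnext_tiny` weights already follow; the random input tensor is a stand-in for a real image, and the snippet is illustrative rather than part of the patch.

```
import torch
from torchvision.prototype.models import convnext_small, ConvNeXt_Small_Weights

# Pick the pretrained weights added in this PR and build the model from them.
weights = ConvNeXt_Small_Weights.IMAGENET1K_V1
model = convnext_small(weights=weights)
model.eval()

# Each Weights entry bundles its own eval preprocessing
# (crop_size=224, resize_size=230 for the Small variant).
preprocess = weights.transforms()

img = torch.rand(3, 256, 256)          # stand-in for a decoded RGB image
batch = preprocess(img).unsqueeze(0)   # add the batch dimension

with torch.no_grad():
    logits = model(batch)

print(logits.shape)           # torch.Size([1, 1000]) -> ImageNet categories
print(weights.meta["acc@1"])  # 83.616, matching the table in docs/source/models.rst
```

The same pattern applies to `convnext_base` and `convnext_large`; only the weight enum and the builder function change.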