-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Graduate ConvNeXt to main TorchVision area #5330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e77d8fe
49c4433
7e8c482
0948f50
d40ba15
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
from functools import partial | ||
from typing import Any, Callable, Dict, List, Optional, Sequence | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
from torch.nn import functional as F | ||
|
||
from .._internally_replaced_utils import load_state_dict_from_url | ||
from ..ops.misc import ConvNormActivation | ||
from ..ops.stochastic_depth import StochasticDepth | ||
from ..utils import _log_api_usage_once | ||
|
||
|
||
# Public API of this module: the model class followed by its factory functions.
__all__ = [
    "ConvNeXt",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]
|
||
|
||
# Maps each architecture name to the URL of its pretrained checkpoint.
# The value type is Optional so an architecture can be listed without weights.
_MODELS_URLS: Dict[str, Optional[str]] = {
    "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
    "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
    "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
    "convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",
}
|
||
|
||
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` applied over dimension 1 of a 4-D tensor.

    ``F.layer_norm`` normalizes over the trailing dimensions, so the input is
    permuted to put dimension 1 last, normalized, and permuted back.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Move dimension 1 to the end so it becomes the normalized axis.
        moved_last = x.permute(0, 2, 3, 1)
        normalized = F.layer_norm(moved_last, self.normalized_shape, self.weight, self.bias, self.eps)
        # Restore the original dimension order.
        return normalized.permute(0, 3, 1, 2)
|
||
|
||
class Permute(nn.Module):
    """Module form of ``torch.permute`` so a permutation can sit inside ``nn.Sequential``.

    Args:
        dims (List[int]): The desired ordering of the input's dimensions.
    """

    def __init__(self, dims: List[int]):
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        permuted = torch.permute(x, self.dims)
        return permuted
|
||
|
||
class CNBlock(nn.Module):
    """Residual block: depthwise 7x7 conv, norm, and a two-layer MLP with GELU.

    The branch output is multiplied by a learnable per-channel scale, passed
    through stochastic depth, and added to the block input.

    Args:
        dim: Number of channels; the block preserves it.
        layer_scale (float): Initial value of every entry of the learnable scale.
        stochastic_depth_prob (float): Probability given to ``StochasticDepth``
            (used in "row" mode).
        norm_layer (Optional[Callable[..., nn.Module]]): Factory called with
            ``dim`` to build the normalization layer. It runs on the permuted
            (0, 2, 3, 1) layout. Defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer

        # Layers are created in the order they run, so parameter
        # initialization consumes the RNG in a fixed, predictable order.
        depthwise_conv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True)
        to_last = Permute([0, 2, 3, 1])
        norm = make_norm(dim)
        expand = nn.Linear(in_features=dim, out_features=4 * dim, bias=True)
        activation = nn.GELU()
        project = nn.Linear(in_features=4 * dim, out_features=dim, bias=True)
        to_first = Permute([0, 3, 1, 2])

        self.block = nn.Sequential(depthwise_conv, to_last, norm, expand, activation, project, to_first)
        # Shape (dim, 1, 1) broadcasts over the two trailing dimensions.
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        branch = self.stochastic_depth(self.layer_scale * self.block(input))
        # Residual connection, accumulated in place on the branch output.
        branch += input
        return branch
|
||
|
||
class CNBlockConfig:
    """Configuration of one ConvNeXt stage (see Section 3 of the ConvNeXt paper).

    Args:
        input_channels (int): Number of channels of the blocks in this stage.
        out_channels (Optional[int]): Number of channels after the stage, or
            ``None`` when nothing follows the stage.
        num_layers (int): Number of blocks in the stage.
    """

    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"input_channels={self.input_channels}"
            f", out_channels={self.out_channels}"
            f", num_layers={self.num_layers}"
            ")"
        )
|
||
|
||
class ConvNeXt(nn.Module):
    """ConvNeXt network: a strided-conv stem, a sequence of block stages, and a classifier.

    Args:
        block_setting (List[CNBlockConfig]): One configuration per stage. A stage
            with ``out_channels`` set is followed by a norm + stride-2 conv
            downsampling layer.
        stochastic_depth_prob (float): Stochastic depth probability of the last
            block. Per-block probabilities grow linearly from 0 at the first block.
        layer_scale (float): Initial layer-scale value passed to every block.
        num_classes (int): Output size of the final linear layer.
        block (Optional[Callable[..., nn.Module]]): Block factory called as
            ``block(channels, layer_scale, sd_prob)``. Defaults to ``CNBlock``.
        norm_layer (Optional[Callable[..., nn.Module]]): Factory for the norm
            layers of the stem, downsampling layers, and classifier. Defaults to
            ``LayerNorm2d`` with ``eps=1e-6``.
        **kwargs: Accepted for signature compatibility; not used here.

    Raises:
        ValueError: If ``block_setting`` is empty.
        TypeError: If ``block_setting`` is not a sequence of ``CNBlockConfig``.
    """

    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all(isinstance(s, CNBlockConfig) for s in block_setting)):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem: non-overlapping 4x4 convolution followed by normalization.
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        # The schedule divides by (total - 1). Clamp it so a network with a
        # single block gets probability 0 instead of raising ZeroDivisionError.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # The classifier width is whatever the last stage emits.
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        # Truncated-normal weights and zero biases for every conv / linear layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. End of copy-pasted code. |
||
|
||
|
||
def _convnext(
    arch: str,
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Build a ConvNeXt and optionally load its pretrained checkpoint.

    Args:
        arch: Key used to look up the checkpoint URL in ``_MODELS_URLS``.
        block_setting: Per-stage configuration passed to ``ConvNeXt``.
        stochastic_depth_prob: Stochastic depth probability passed to ``ConvNeXt``.
        pretrained: If True, download and load the checkpoint for ``arch``.
        progress: If True, show download progress.
        **kwargs: Extra keyword arguments forwarded to ``ConvNeXt``.

    Returns:
        The constructed (and possibly weight-loaded) model.

    Raises:
        ValueError: If ``pretrained`` is True and no checkpoint URL is
            registered for ``arch``.
    """
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    if pretrained:
        # The mapping's values are Optional, so treat a missing key and a
        # None entry alike rather than handing None to the downloader.
        url = _MODELS_URLS.get(arch)
        if url is None:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    return model
|
||
|
||
def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Tiny model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ConvNeXt`` constructor.
            ``stochastic_depth_prob`` defaults to 0.1 when not given.
    """
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = (
        (96, 192, 3),
        (192, 384, 3),
        (384, 768, 9),
        (768, None, 3),
    )
    block_setting = [CNBlockConfig(in_ch, out_ch, depth) for in_ch, out_ch, depth in stage_specs]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
|
||
|
||
def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Small model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ConvNeXt`` constructor.
            ``stochastic_depth_prob`` defaults to 0.4 when not given.
    """
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = (
        (96, 192, 3),
        (192, 384, 3),
        (384, 768, 27),
        (768, None, 3),
    )
    block_setting = [CNBlockConfig(in_ch, out_ch, depth) for in_ch, out_ch, depth in stage_specs]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
|
||
|
||
def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Base model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ConvNeXt`` constructor.
            ``stochastic_depth_prob`` defaults to 0.5 when not given.
    """
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = (
        (128, 256, 3),
        (256, 512, 3),
        (512, 1024, 27),
        (1024, None, 3),
    )
    block_setting = [CNBlockConfig(in_ch, out_ch, depth) for in_ch, out_ch, depth in stage_specs]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
|
||
|
||
def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Large model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ConvNeXt`` constructor.
            ``stochastic_depth_prob`` defaults to 0.5 when not given.
    """
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = (
        (192, 384, 3),
        (384, 768, 3),
        (768, 1536, 27),
        (1536, None, 3),
    )
    block_setting = [CNBlockConfig(in_ch, out_ch, depth) for in_ch, out_ch, depth in stage_specs]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Start of copy-pasted code.