From dbb5c5a3d8745de03e8fedf2fab296df3552d7a1 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Thu, 30 Jun 2022 10:17:38 +0100
Subject: [PATCH 1/3] Handle case where window_size larger than input_size in
 swin_transformer by updating window_size

---
 torchvision/models/swin_transformer.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 2f2cfd44445..e3fec34ade6 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -160,7 +160,20 @@ def shifted_window_attention(
     return x
 
 
+def _fix_window_and_shift_size(
+    input_hw: List[int], window_size: List[int], shift_size: List[int]
+) -> Tuple[List[int], List[int]]:
+    # Handle case where window_size is larger than input tensor
+    # Reference on the original implementation: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195
+    for i in range(2):
+        if input_hw[i] <= window_size[i]:
+            window_size[i] = input_hw[i]
+            shift_size[i] = 0
+    return window_size, shift_size
+
+
 torch.fx.wrap("shifted_window_attention")
+torch.fx.wrap("_fix_window_and_shift_size")
 
 
 class ShiftedWindowAttention(nn.Module):
@@ -218,8 +231,12 @@ def forward(self, x: Tensor):
         Returns:
             Tensor with same layout as input, i.e. [B, H, W, C]
         """
+        _, H, W, _ = x.shape
+        input_hw = [H, W]
+        # Handle case where the window_size is larger than the input
+        window_size, shift_size = _fix_window_and_shift_size(input_hw, self.window_size, self.shift_size)
 
-        N = self.window_size[0] * self.window_size[1]
+        N = window_size[0] * window_size[1]
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(N, N, -1)
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
@@ -229,9 +246,9 @@ def forward(self, x: Tensor):
             self.qkv.weight,
             self.proj.weight,
             relative_position_bias,
-            self.window_size,
+            window_size,
             self.num_heads,
-            shift_size=self.shift_size,
+            shift_size=shift_size,
             attention_dropout=self.attention_dropout,
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,

From ef676399a3f70237aa67e57cbbd717bcf94c37c3 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Thu, 30 Jun 2022 10:00:30 +0000
Subject: [PATCH 2/3] Add missing import Tuple and format ufmt

---
 torchvision/models/swin_transformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index e3fec34ade6..767e498b2f6 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Optional, Callable, List, Any
+from typing import Any, Callable, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -9,7 +9,7 @@
 from ..ops.stochastic_depth import StochasticDepth
 from ..transforms._presets import ImageClassification, InterpolationMode
 from ..utils import _log_api_usage_once
-from ._api import WeightsEnum, Weights
+from ._api import Weights, WeightsEnum
 from ._meta import _IMAGENET_CATEGORIES
 from ._utils import _ovewrite_named_param
@@ -374,7 +374,7 @@ def __init__(
         # build SwinTransformer blocks
         for i_stage in range(len(depths)):
             stage: List[nn.Module] = []
-            dim = embed_dim * 2 ** i_stage
+            dim = embed_dim * 2**i_stage
             for i_layer in range(depths[i_stage]):
                 # adjust stochastic depth probability based on the depth of the stage block
                 sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)

From 417e8cd98c487d1ffee82d9fbab59b6de52e14f1 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Thu, 30 Jun 2022 20:34:42 +0100
Subject: [PATCH 3/3] Resolve comment and update correct ufmt format

---
 torchvision/models/swin_transformer.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 767e498b2f6..e7082ddd9cb 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -161,15 +161,17 @@ def shifted_window_attention(
 
 
 def _fix_window_and_shift_size(
-    input_hw: List[int], window_size: List[int], shift_size: List[int]
+    input_size: List[int], window_size: List[int], shift_size: List[int]
 ) -> Tuple[List[int], List[int]]:
     # Handle case where window_size is larger than input tensor
     # Reference on the original implementation: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195
-    for i in range(2):
-        if input_hw[i] <= window_size[i]:
-            window_size[i] = input_hw[i]
-            shift_size[i] = 0
-    return window_size, shift_size
+    updated_window_size = window_size.copy()
+    updated_shift_size = shift_size.copy()
+    for i in range(len(input_size)):
+        if input_size[i] <= window_size[i]:
+            updated_window_size[i] = input_size[i]
+            updated_shift_size[i] = 0
+    return updated_window_size, updated_shift_size
 
 
 torch.fx.wrap("shifted_window_attention")
@@ -374,7 +376,7 @@ def __init__(
         # build SwinTransformer blocks
         for i_stage in range(len(depths)):
             stage: List[nn.Module] = []
-            dim = embed_dim * 2**i_stage
+            dim = embed_dim * 2 ** i_stage
             for i_layer in range(depths[i_stage]):
                 # adjust stochastic depth probability based on the depth of the stage block
                 sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
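
Note (reviewer sketch, not part of the patch series above): the helper introduced here only rewrites copies of the window_size / shift_size lists right before they are handed to shifted_window_attention, so the attributes registered on ShiftedWindowAttention are never mutated, and wrapping it with torch.fx.wrap keeps the module traceable. Below is a minimal standalone illustration of the final (PATCH 3/3) behaviour, runnable without torchvision; the example feature-map and window sizes are made up for demonstration.

from typing import List, Tuple


def _fix_window_and_shift_size(
    input_size: List[int], window_size: List[int], shift_size: List[int]
) -> Tuple[List[int], List[int]]:
    # Clamp each window dimension to the input size, and drop the cyclic shift
    # along any dimension where the window already covers the whole input.
    updated_window_size = window_size.copy()
    updated_shift_size = shift_size.copy()
    for i in range(len(input_size)):
        if input_size[i] <= window_size[i]:
            updated_window_size[i] = input_size[i]
            updated_shift_size[i] = 0
    return updated_window_size, updated_shift_size


# A 6x10 feature map with a 7x7 window and 3x3 shift: the height is smaller
# than the window, so that axis is clamped to 6 and its shift is disabled,
# while the width keeps both its original window and shift.
print(_fix_window_and_shift_size([6, 10], [7, 7], [3, 3]))  # ([6, 7], [0, 3])

Copying the lists before updating them (the change requested in PATCH 3/3) is what prevents the helper from silently modifying self.window_size and self.shift_size in place on every forward pass.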