Add model and architecture for Omnivore model #43
Conversation
- some tests are failing on CI, can you take a look
- I left some comments around passing in instantiated modules, but realized you are following torchvision style, so if we plan to upstream this it's okay to leave them as is
from torchvision.ops.stochastic_depth import StochasticDepth


def _compute_pad_size_3d(
Why do you need this and the next function to be global?
For _compute_pad_size_3d: it is needed by 2 different classes, so we put it at module level but keep it private.
For _compute_attention_mask_3d: we made it a standalone function so we can cache it. We don't put it inside the function shifted_window_attention_3d because we avoid nested functions, as they are not supported by JIT script.
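For illustration, a minimal sketch of what such a module-level private helper might look like (the exact signature and body in this PR may differ):

```python
from typing import Tuple

# Hypothetical sketch of a module-level, private padding helper: compute how
# much to pad each of (D, H, W) so it becomes divisible by the window/patch
# size. Not necessarily the PR's exact implementation.
def _compute_pad_size_3d(
    size_dhw: Tuple[int, int, int], patch_size_dhw: Tuple[int, int, int]
) -> Tuple[int, int, int]:
    pad_d, pad_h, pad_w = (
        (patch - size % patch) % patch
        for size, patch in zip(size_dhw, patch_size_dhw)
    )
    return pad_d, pad_h, pad_w
```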
image = torch.randn(1, 3, 1, 112, 112)  # B C D H W
image_score = model(image, input_type="image")
self.assertEqual(image_score.size(), torch.Size((1, 1000)))
Please use our test utility to assert on values and shapes (we need to assert values too).
Thanks for the suggestion! Will do.
Use the utility for assertion:
Line 74 in e857541
def assert_expected(
Ah okay, will change self.assertEqual to assert_expected then.
@langong347 after reading the utility assert_expected, it seems to compare two tensors of float type. In this case, I think assertEqual is better for comparing the size, since it is an integer?
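As an illustration of asserting on both shape and values, here is a minimal sketch using torch.testing.assert_close as a stand-in for the repo's assert_expected utility (the fixture path and tolerances are assumptions):

```python
import torch

def check_image_output(model):
    image = torch.randn(1, 3, 1, 112, 112)  # B C D H W
    image_score = model(image, input_type="image")
    # Exact, integer-based shape check.
    assert image_score.size() == torch.Size((1, 1000))
    # Tolerance-based value check against a stored expected tensor
    # (hypothetical fixture file).
    expected = torch.load("expected_image_score.pt")
    torch.testing.assert_close(image_score, expected, rtol=1e-5, atol=1e-5)
```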
Q: how did you make sure that the order of pretrained model keys is the same as that of your model keys?
Sometimes this can be a problem, especially if we have a branching operation. For instance in Omnivore, the head that is used branches on the input type. However, from what I understand, if there is no branching (like the swin transformer encoder only; in fact the implementation here is different from the original one: https://github.com/facebookresearch/omnivore/blob/main/models/swin_transformer_3d.py), the layout will usually follow the same order in which the input is processed, and in this case it should be okay. One problem that may occur is that we might store the same data with different dimensions, like 100 x 100 (2D) versus 10000 (1D); in that case we could modify our model a bit to follow how the original one stores the data.
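A small sketch of the kind of sanity check being described, comparing the ordered parameters of the converted checkpoint against the new model (names are illustrative, not from this PR):

```python
from torch import nn

# Illustrative helper: walk both state dicts in order and flag any shape
# mismatch. Keys may be renamed after conversion, but if there is no
# branching the order and shapes should line up one-to-one.
def compare_state_dict_layout(model: nn.Module, pretrained_sd: dict) -> None:
    for (k_new, v_new), (k_old, v_old) in zip(
        model.state_dict().items(), pretrained_sd.items()
    ):
        if v_new.shape != v_old.shape:
            print(f"mismatch: {k_new} {tuple(v_new.shape)} vs {k_old} {tuple(v_old.shape)}")
```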
Adding here some notes based on an offline discussion I had with @YosuaMichael today.
# LICENSE file in the root directory of this source tree.

# Modified from 2d Swin Transformers in torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py
@YosuaMichael Copy-pasting the Swin implementation and making it 3D looks suboptimal. We should examine the possibility of adding native 3d support to TorchVision's implementation, or refactor to make more components shareable.
Hi @datumbox, I plan to upstream the SwinTransformer3d to video_classification in torchvision after finishing Omnivore first. Do you think this plan sounds okay? Or maybe it is better to put this in torchvision first?
@YosuaMichael I recommend taking the time to create a plan for this and similar models. Not only will it deliver better code quality, it will also sketch out how we will handle supporting multimodal more effectively in the future. Copy-pasting is a shortcut that kicks the problem down the line and only makes it harder to solve.
@ankitade @langong347 Sharing here what we discussed offline with @YosuaMichael and @kartikayk.
The adaptation of Swin 2d to 3d highlighted some things we can improve in the original TorchVision implementation. One of them is the use of single integers instead of tuples/lists for the sizes; another is the fact that some modules can easily be reused if we make minor adaptations (candidates are PatchMerging and SwinTransformerBlock). It would be really nice if we could make these changes in TorchVision's code now to minimize the degree of copy-pasting. The reason for the urgency is the upcoming release and the fact that Swin is a brand new class which can easily be modified without worrying about BC now. After the release, there will be all sorts of BC considerations.
The timing is tight and we definitely don't want to delay you, so we will just give it 1-2 days max to see if that's possible. If we are successful in refactoring, we will be able to simplify this implementation and reduce the copy-paste. If not, we will merge this and review improvements to TorchVision in the future.
Just to update here: the change to the torchvision 2d swin transformer is at pytorch/vision#6088
self.heads = heads
self.input_types = set(heads.keys())

def forward(self, x: torch.Tensor, input_type: str):
The idiom of passing the input_type as a string and then choosing which head to forward the data to is not very common. Why can't we instantiate Omnivore modules with the same encoder but different heads, and push to the training loop the decision of which module to use depending on the input? This would move the complexity from the nn.Module to the loop, and it would be beneficial for use-cases where the loss is different for each head.
Speaking offline with @YosuaMichael, there are some concerns about whether that whole approach will work well in a distributed setup.
Hi @datumbox, currently I plan to try doing the training first and see if there are any particular problems with the architectures (whether the current one or multiple models with a shared encoder). In particular, I need to check the behaviour of having 2 models A and B that share the same encoder, wrapped individually into DDP(A) and DDP(B).
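A rough sketch of the setup being described, assuming a generic encoder and two illustrative heads (not the PR's actual modules):

```python
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_shared_encoder_ddp_pair():
    # Assumes torch.distributed.init_process_group(...) has already been
    # called by the launcher; all module shapes here are illustrative.
    encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 112 * 112, 512))
    model_a = nn.Sequential(encoder, nn.Linear(512, 1000))  # e.g. image head
    model_b = nn.Sequential(encoder, nn.Linear(512, 400))   # e.g. video head
    # Each model is wrapped separately; whether gradients on the shared
    # encoder stay in sync in this configuration is what needs verifying.
    return DDP(model_a), DDP(model_b)
```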
This is an interesting comment. Confirming: does the ckpt include all the heads?
@ankitade yes, the original checkpoint includes all the heads
@ankitade @YosuaMichael It's worth keeping in mind that this pattern won't support FX.
FX does tracing, which means the flow of execution of the model should not depend on the input. By branching on the string, FX won't know what to execute. Even though you might not be interested right now in making it FX traceable, in the future you might want to adopt FX quantization or other FX-based utils.
This is the reason my advice is to split this module into submodules depending on the head. It's a safer idiom that would be forward compatible with future core expansions. It might require some massaging of the original weights to fix, but I think that's straightforward to do and worth it.
Up to you. :)
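A hedged sketch of the suggested idiom, where each head gets its own small module sharing the encoder and the forward no longer branches on a string (class and attribute names are illustrative):

```python
import torch
from torch import nn

class OmnivoreForTask(nn.Module):
    # One wrapper per modality; the forward is branch-free, so FX tracing
    # and other graph-based tooling can handle it.
    def __init__(self, encoder: nn.Module, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.head = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.encoder(x))

# Usage sketch: build one wrapper per head, all sharing the same encoder
# object, and let the training loop pick the wrapper for each batch.
# image_model = OmnivoreForTask(encoder, heads["image"])
# video_model = OmnivoreForTask(encoder, heads["video"])
```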
return x


class SwinTransformer3dEncoder(nn.Module):
I agree with @datumbox regarding the adaptation to 3d. I think it would be preferable to add it to torchvision directly. This main class can be a good entry point for the unification effort. Looking into it, I think we can have common logic under SwinTransformer with extra parameters in order to cater for the 2d (SwinTransformer2d) and 3d (SwinTransformer3d) versions. For BC, the default values should be for the 2d version. From this main class we could drill down and look into more unification efforts (probably not possible in all cases). For example, we have ConvNormActivationBlock and then 2d and 3d versions (see https://github.com/pytorch/vision/blob/49496c4f6201f05f2351788389a0863b087e78f4/torchvision/ops/misc.py#L68). This was a simpler example, but could be used as inspiration.
Let me know your thoughts @YosuaMichael. I might have overlooked some of the differences which make this even harder than I think.
@jdsgomes I agree with adding this to torchvision (planned to do later). But as for having a generic SwinTransformer for both 2d and 3d, I have a feeling this could be difficult, especially making it BC. Here are some reasons I can think of (let me know if you have a good solution):
- Some of the input params differ in type. For instance, window_size for 2d is int (they assume it is always a symmetric window like 16x16), whereas in 3d we use Tuple[int, int, int] since we have windows like (8, 7, 7).
- In this 3d version, I modularize the patch_embed layer because Omnivore actually uses a different patch_embed method.
There are some options. For the window_size we can use a union type for the parameter. For patch_embed we can have it as an optional pre-step.
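An illustrative sketch of how both options could look in a constructor signature (this is an assumption about a possible API, not the actual torchvision or torchmultimodal code):

```python
from typing import Callable, Optional, Tuple, Union

from torch import nn

def _to_3tuple(window_size: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
    # Accept the 2d-style single int (symmetric window) or an explicit
    # per-dimension tuple such as (8, 7, 7).
    if isinstance(window_size, int):
        return (window_size, window_size, window_size)
    return window_size

class SwinEncoderSketch(nn.Module):
    def __init__(
        self,
        window_size: Union[int, Tuple[int, int, int]] = 7,
        patch_embed: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        self.window_size = _to_3tuple(window_size)
        # Optional pre-step: fall back to an identity module when no
        # patch_embed callable is supplied.
        self.patch_embed = patch_embed() if patch_embed is not None else nn.Identity()
```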
norm_layer=nn.LayerNorm,
patch_embed=PatchEmbedOmnivore,
)
if encoder_only:
Instead of this, separate it out into a function that returns the encoder and call that function here.
Ok, I think we can create a function for the encoder only and call it here.
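A hedged sketch of that builder pattern, with illustrative names and head sizes (not the PR's final API):

```python
from torch import nn

def omnivore_swin_t_encoder() -> nn.Module:
    # In the real code this would construct SwinTransformer3dEncoder(...)
    # with patch_embed=PatchEmbedOmnivore, norm_layer=nn.LayerNorm, etc.;
    # a stand-in module is used here to keep the sketch self-contained.
    return nn.Sequential(nn.Flatten(), nn.LazyLinear(768))

def omnivore_swin_t(encoder_only: bool = False) -> nn.Module:
    encoder = omnivore_swin_t_encoder()
    if encoder_only:
        return encoder
    # Head names and output sizes are illustrative assumptions.
    heads = nn.ModuleDict(
        {"image": nn.LazyLinear(1000), "video": nn.LazyLinear(400), "rgbd": nn.LazyLinear(19)}
    )
    # The real builder would combine encoder and heads in the Omnivore model class.
    return nn.ModuleDict({"encoder": encoder, "heads": heads})
```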
Hi @ankitade, thanks for your comment and sorry for the very late feedback. I missed the notification on your comment...
I will update the code to reflect some of your comments, thanks!
Codecov Report
@@            Coverage Diff             @@
##             main      #43      +/-   ##
==========================================
+ Coverage   88.96%   89.27%   +0.31%
==========================================
  Files          33       38       +5
  Lines        1722     2061     +339
==========================================
+ Hits         1532     1840     +308
- Misses        190      221      +31
Continue to review full report at Codecov.
@YosuaMichael has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary:
Omnivore is a multimodal vision model that is able to classify RGB images, videos, and depth images (RGBD) with shared trunk parameters.
With this PR, we aim to add the model and architecture classes for Omnivore, able to load the converted pretrained weights from the original authors (https://github.com/facebookresearch/omnivore) and produce the same results.
Test plan:
A notebook, examples/omnivore/LoadOriginalPretrainedWeightAndCompare.ipynb, that loads the original Omnivore pretrained weights from torchhub and makes sure the output is the same.
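A rough sketch of the comparison done in that notebook, loading the original model via torch.hub and checking our output against it (the hub model name follows the upstream project's README; the tolerance and the torchmultimodal side are assumptions):

```python
import torch

# Load the original Omnivore model from torch.hub (per the upstream README).
original = torch.hub.load("facebookresearch/omnivore:main", model="omnivore_swinB")
original.eval()

x = torch.randn(1, 3, 1, 112, 112)  # B C D H W
with torch.no_grad():
    expected = original(x, input_type="image")

# Then compare against this PR's model on the same input, e.g.:
# ours = omnivore_model(x, input_type="image")
# torch.testing.assert_close(ours, expected, rtol=1e-4, atol=1e-4)
```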