Add Videoprism #39895
Summary of the code so far
@qubvel I'd like to request your preliminary review of the current code structure. I've left detailed comments marked with "# ?" to make the review easier; these will be removed later. Todos:
- Please note that I am currently using the preprocessing utils from the original code in my convert_weights_to_hf script. That code uses mediapy, which resizes videos with Lanczos interpolation by default. However, Lanczos interpolation is not yet supported in torch, so we can't get exactly the same outputs if a fast video processor is used.
- The VideoPrism team has not released the weights for the classification head.
- The weights released on the Hub are in npz format; safetensors need to be uploaded there.
… exact values of jax.image.resize
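Since `torch.nn.functional.interpolate` has no Lanczos mode, one option is to build the resampling weights by hand. Below is a minimal sketch of a separable Lanczos-3 resize in torch, assuming half-pixel pixel centers and a widened kernel when downsampling (antialiasing). These conventions are assumptions, and the helper names are hypothetical; the output is not guaranteed to match mediapy or `jax.image.resize` bit for bit.

```python
import torch


def lanczos_kernel(x: torch.Tensor, a: int = 3) -> torch.Tensor:
    # L(x) = sinc(x) * sinc(x / a) inside |x| < a, zero outside; torch.sinc is sin(pi x) / (pi x)
    values = torch.sinc(x) * torch.sinc(x / a)
    return torch.where(x.abs() < a, values, torch.zeros_like(values))


def lanczos_weights(in_size: int, out_size: int, a: int = 3) -> torch.Tensor:
    scale = out_size / in_size
    # widen the kernel when downsampling (antialiasing); keep it unchanged when upsampling
    kernel_scale = min(scale, 1.0)
    # output pixel centers mapped back to input coordinates (assumed half-pixel convention)
    out_coords = (torch.arange(out_size, dtype=torch.float64) + 0.5) / scale - 0.5
    in_coords = torch.arange(in_size, dtype=torch.float64)
    weights = lanczos_kernel((out_coords[:, None] - in_coords[None, :]) * kernel_scale, a)
    # normalize rows so each output pixel is a weighted average (constant frames stay constant)
    return weights / weights.sum(dim=1, keepdim=True)  # (out_size, in_size)


def lanczos_resize(frames: torch.Tensor, out_hw: tuple[int, int], a: int = 3) -> torch.Tensor:
    """Resize the last two dims (H, W) of `frames`, e.g. a (T, C, H, W) video, separably."""
    h, w = frames.shape[-2:]
    wh = lanczos_weights(h, out_hw[0], a).to(device=frames.device, dtype=frames.dtype)
    ww = lanczos_weights(w, out_hw[1], a).to(device=frames.device, dtype=frames.dtype)
    out = torch.einsum("oh,...hw->...ow", wh, frames)  # resize height
    return torch.einsum("pw,...ow->...op", ww, out)  # resize width
```

Lanczos weights have negative lobes, so resized pixel values can overshoot the input range; clamping afterwards may be needed before any uint8 conversion.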
```python
@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Base class for text encoder outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
```
This is just the common BaseModelOutput; no need to redefine it.
```python
if self.mode == "spatial":
    self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
    self.spatial_pos_emb = nn.Parameter(torch.zeros(1, self.pos_emb_shape[1] * self.pos_emb_shape[2], config.hidden_size))  # ? (1, 256, 768)

elif self.mode == "temporal":
    self.temporal_pos_emb = nn.Parameter(torch.zeros(1, self.pos_emb_shape[0], config.hidden_size))  # ? (1, 16, 768)
```
Do we have checkpoint releases for both versions? Otherwise, please leave only one.
```python
def _interpolate_emb_2d(
    self, emb: torch.Tensor, source_emb_shape: tuple[int, int], target_emb_shape: tuple[int, int]
):
```
should be defined in interpolate_pos_encoding instead, no?
I didn't use interpolate_pos_encoding because there are two different types of interpolation for the positional embeddings (spatial first, then temporal). Now that the plan is to have two separate embedding classes, I'll inherit from VivitEmbeddings for both, and each will have its own interpolate_pos_encoding method.
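A minimal sketch of what the two `interpolate_pos_encoding` variants could do, assuming the spatial table (1, H*W, D) is a row-major flattening of an H x W patch grid and that bilinear/linear interpolation is acceptable. The function names are hypothetical, and the original JAX code may use a different resize method (`jax.image.resize` also antialiases when downsampling), so exact values can differ.

```python
import torch
import torch.nn.functional as F


def interpolate_spatial_pos_emb(pos_emb: torch.Tensor, source_hw: tuple[int, int], target_hw: tuple[int, int]) -> torch.Tensor:
    # (1, H*W, D) -> (1, D, H, W) so F.interpolate sees a 2D grid with D channels
    _, num_pos, dim = pos_emb.shape
    h, w = source_hw
    assert num_pos == h * w, "spatial table size must equal H * W"
    grid = pos_emb.reshape(1, h, w, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=target_hw, mode="bilinear", align_corners=False)
    # back to (1, H'*W', D)
    return grid.permute(0, 2, 3, 1).reshape(1, target_hw[0] * target_hw[1], dim)


def interpolate_temporal_pos_emb(pos_emb: torch.Tensor, target_len: int) -> torch.Tensor:
    # (1, T, D) -> (1, D, T) so F.interpolate sees a 1D sequence with D channels
    seq = pos_emb.permute(0, 2, 1)
    seq = F.interpolate(seq, size=target_len, mode="linear", align_corners=False)
    return seq.permute(0, 2, 1)  # (1, T', D)
```

Keeping the two helpers separate mirrors the split into spatial and temporal embedding classes: each class only needs the one that matches its table.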
```python
with torch.no_grad():
    self.layernorm_before.weight += nn.Parameter(
        torch.ones(self.config.hidden_size)
    )
    self.layernorm_after.weight += nn.Parameter(
        torch.ones(self.config.hidden_size)
    )
```
That seems super strange to me. Why do we need this? Is this correct? The operation is in-place, so are we continuing to increase the weight after each forward pass?
I found it very strange too.
Here is the original code

https://github.com/google-deepmind/videoprism/blob/main/videoprism/layers.py#L182-L193
When direct_scale is set to False (which is always the case), +1 is added to the LayerNorm scale tensor.

This code looks more like an attempt to ensure the LayerNorm scale factor is not zero during training. I can open an issue in the repo to confirm whether they want this behavior during inference as well.
You are right that, in its current form, the HF code increases the LayerNorm scale by +1 on every forward pass. The JAX code instantiates the LayerNorm class on the fly just before calling it, so the problem does not occur there. Moving the +1 from forward to the __init__ of the relevant class does not work either: I believe that when a model instance is created, __init__ runs with the initialized weights and the pretrained weights are loaded afterwards, which overwrites the +1. I've only been running single forward passes, so it's been fine so far, but this issue still needs to be resolved.
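A small reproduction of the accumulation described here, using a hypothetical `nn.LayerNorm` subclass that applies the same in-place `+= 1` inside `forward` as the snippet above. It only demonstrates the pattern; it is not the PR's class.

```python
import torch
import torch.nn as nn


class InPlaceOffsetLayerNorm(nn.LayerNorm):  # hypothetical, reproduces the pattern only
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # mutates the stored parameter, so the +1 persists after the call returns
            self.weight += torch.ones_like(self.weight)
        return super().forward(input)


ln = InPlaceOffsetLayerNorm(4)  # nn.LayerNorm initializes weight to 1.0
x = torch.randn(2, 4)
for _ in range(3):
    ln(x)
# weight is now 1.0 + 3 * 1.0 = 4.0 in every position
```

With a single forward pass the output is correct, which is why single-pass comparisons against the original did not expose the problem.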
Ok, I suppose we must not modify the weight in place, and should define it as follows instead:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VideoPrismLayerNorm(nn.LayerNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight + 1, self.bias, self.eps
        )
```

That should be equivalent, right?
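A quick numerical check of that equivalence, assuming a pretrained scale `s` loaded into the offset LayerNorm should behave like a standard `nn.LayerNorm` whose weight is `s + 1`. The random tensors are stand-ins for real checkpoint values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VideoPrismLayerNorm(nn.LayerNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # the stored weight is an offset; the effective scale is weight + 1 (no in-place update)
        return F.layer_norm(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)


torch.manual_seed(0)
hidden = 8
x = torch.randn(2, 5, hidden)
scale = torch.randn(hidden)  # stand-in for a pretrained JAX LayerNorm scale
bias = torch.randn(hidden)

offset_ln = VideoPrismLayerNorm(hidden)
reference_ln = nn.LayerNorm(hidden)
with torch.no_grad():
    offset_ln.weight.copy_(scale)
    offset_ln.bias.copy_(bias)
    reference_ln.weight.copy_(scale + 1)
    reference_ln.bias.copy_(bias)

y_offset = offset_ln(x)
y_reference = reference_ln(x)
offset_ln(x)  # a second call leaves the stored weight untouched
```

An alternative design would fold the +1 into the weights at conversion time and use a plain `nn.LayerNorm`; the check above suggests the two give the same outputs, so the choice mostly affects how checkpoint values map to module weights.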
Or, if the +1 is only for initialization, it should not be in the modeling code at all, just in the _init_weights method.
Please keep in mind that the logits should match the original implementation closely (within 1e-3/1e-4); matching them should tell you whether this addition is relevant.
P.S. Just saw the message below 👍
Alright, using F.layer_norm is a very good solution; I'll rename the weights and refactor the code. Also, I got confirmation from the VideoPrism team: scale = 0 at initialization, and the +1 is expected to be there.
```python
if mode == "spatial":
    self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
elif mode == "temporal":
    self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
elif mode == "auxiliary":
    self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_auxiliary_layers)])
elif mode == "unimodal":
    self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_unimodal_layers)])
else:
    raise ValueError(f"Unknown mode: {mode}. Supported modes are: spatial, temporal, auxiliary and unimodal.")
```
I don't see any reason to split this into separate if/else branches. We could have a single num_layers attribute and that's it.
If we remove these conditionals, then in the model's __init__ we will need to set config.num_layers = config.num_spatial_layers before instantiating the spatial encoder, and config.num_layers = config.num_temporal_layers before instantiating the temporal encoder. It's slightly less explicit, but since you agreed with changing the config from python_gelu to relu in __init__, I'll go ahead with this one too, and the encoder can then be reused directly from Vivit.
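A minimal sketch of the single-`num_layers` approach, assuming each encoder receives its own copy of the config with `num_layers` set, rather than the shared config being mutated between instantiations. All classes here are hypothetical toys standing in for VideoPrismConfig and the Vivit-style encoder.

```python
import copy
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ToyConfig:  # hypothetical stand-in for VideoPrismConfig
    hidden_size: int = 16
    num_spatial_layers: int = 4
    num_temporal_layers: int = 2
    num_layers: int = 0  # the single attribute the shared encoder reads


class ToyEncoder(nn.Module):  # stands in for an encoder that only knows `num_layers`
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) for _ in range(config.num_layers)])


def encoder_config(config: ToyConfig, num_layers: int) -> ToyConfig:
    # copy so that setting num_layers for one encoder does not leak into the other
    sub_config = copy.deepcopy(config)
    sub_config.num_layers = num_layers
    return sub_config


class ToyFactorizedModel(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.spatial_encoder = ToyEncoder(encoder_config(config, config.num_spatial_layers))
        self.temporal_encoder = ToyEncoder(encoder_config(config, config.num_temporal_layers))
```

Copying avoids a subtle side effect of setting `config.num_layers` directly: the shared config would otherwise end up carrying whichever value was set last.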
```python
with torch.no_grad():
    self.layernorm.weight += nn.Parameter(torch.ones(self.config.hidden_size))
```
```python
class PositionalEmbedding(nn.Module):
    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.min_timescale = 1
        self.max_timescale = 10000

    def forward(self, seq_length):
        position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)  # ? (1, seq_length)
        num_timescales = self.hidden_size // 2

        log_timescale_increment = math.log(
            float(self.max_timescale) / float(self.min_timescale)  # ? log(10000/1) = ln(10000)
        ) / torch.maximum(torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1))

        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment
        )

        scaled_time = position.unsqueeze(-1) * inv_timescales.expand(1, 1, -1)

        embs = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=-1)

        return embs
```
align with the RoPE classes defined in transformers
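One way to follow that suggestion, sketched under assumptions: the rotary-embedding classes in transformers compute their inverse frequencies once in `__init__` and register them as a non-persistent buffer, so the same can be done with the timescales here. The class name is hypothetical; the math matches the `PositionalEmbedding` snippet (sin half followed by cos half).

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEmbedding(nn.Module):  # hypothetical name, not the merged class
    def __init__(self, hidden_size: int, min_timescale: float = 1.0, max_timescale: float = 10000.0):
        super().__init__()
        num_timescales = hidden_size // 2
        # same geometric spacing as the original: ln(max / min) / max(num_timescales - 1, 1)
        log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) * -log_increment
        )
        # computed once; non-persistent so it is not written to the state dict
        self.register_buffer("inv_timescales", inv_timescales, persistent=False)

    def forward(self, seq_length: int) -> torch.Tensor:
        position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device)
        scaled_time = position[:, None] * self.inv_timescales[None, :]  # (seq_length, num_timescales)
        embs = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=-1)
        return embs.unsqueeze(0)  # (1, seq_length, 2 * num_timescales)
```

Because the timescales live in a buffer, they follow the module across `.to(device)` calls, and the forward pass no longer rebuilds them on every call.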
```python
self.backbone = VideoPrismModel(config)
self.auxiliary_encoder = VideoPrismEncoder(config, mode="auxiliary")
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(config)
self.text_encoder = VideoPrismTextEncoder(config)
```
I assume we should have the following classes (similar to CLIP)
- VideoPrismVideoModel
- VideoPrismTextModel
Please align to this.
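For reference, CLIP in transformers keeps `CLIPTextModel` and `CLIPVisionModel` as separate towers, with `CLIPModel` combining them and exposing `get_text_features` / `get_image_features`. Below is a toy sketch of that composition for video and text. The towers are hypothetical linear/embedding stand-ins, and the plain cosine similarity is an assumption: whether VideoPrism applies a learned temperature or different pooling is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyVideoModel(nn.Module):  # hypothetical stand-in for VideoPrismVideoModel
    def __init__(self, in_dim: int = 32, embed_dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(in_dim, embed_dim)

    def forward(self, video_tokens: torch.Tensor) -> torch.Tensor:  # (batch, num_tokens, in_dim)
        return self.proj(video_tokens).mean(dim=1)  # mean-pooled (batch, embed_dim)


class ToyTextModel(nn.Module):  # hypothetical stand-in for VideoPrismTextModel
    def __init__(self, vocab_size: int = 100, embed_dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:  # (batch, seq_len)
        return self.embed(input_ids).mean(dim=1)  # mean-pooled (batch, embed_dim)


class ToyDualEncoder(nn.Module):  # hypothetical CLIP-style composite of the two towers
    def __init__(self):
        super().__init__()
        self.video_model = ToyVideoModel()
        self.text_model = ToyTextModel()

    def get_video_features(self, video_tokens: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.video_model(video_tokens), dim=-1)

    def get_text_features(self, input_ids: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.text_model(input_ids), dim=-1)

    def forward(self, video_tokens: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # cosine similarity between every video and every text: (num_videos, num_texts)
        return self.get_video_features(video_tokens) @ self.get_text_features(input_ids).T
```

Splitting the towers this way lets each one be loaded and used on its own (e.g. video-only feature extraction) while the composite model handles retrieval.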
This should be refactored to follow the mllama or dinov3_vit conversion script format: we should have a weight-mapping dict and work from that.
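A minimal sketch of the mapping-dict approach, assuming regex patterns from original (npz-style) keys to HF parameter names. Every key and name below is hypothetical and only illustrates the structure; the real VideoPrism keys differ.

```python
import re

# Hypothetical key patterns for illustration only; the real npz keys and HF names differ.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"^spatial_encoder/layer_(\d+)/layer_norm/scale$": r"spatial_encoder.layer.\1.layernorm_before.weight",
    r"^spatial_encoder/layer_(\d+)/layer_norm/bias$": r"spatial_encoder.layer.\1.layernorm_before.bias",
    r"^temporal_encoder/layer_(\d+)/mlp/dense_0/kernel$": r"temporal_encoder.layer.\1.intermediate.dense.weight",
    r"^patch_projection/kernel$": r"embeddings.patch_embeddings.projection.weight",
}


def convert_old_keys_to_new_keys(old_keys):
    """Return {old_key: new_key}; fail loudly on keys that no pattern covers."""
    mapping, unmapped = {}, []
    for old_key in old_keys:
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
            if re.match(pattern, old_key):
                mapping[old_key] = re.sub(pattern, replacement, old_key)
                break
        else:
            unmapped.append(old_key)
    if unmapped:
        raise KeyError(f"No mapping for keys: {unmapped}")
    return mapping
```

Renaming is only half of the conversion: Flax `Dense` kernels are stored as (in_features, out_features) while `nn.Linear.weight` is (out_features, in_features), so linear kernels typically need a transpose after renaming, and convolution kernels need their own permutation.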
The following notebooks are from the original repo, and the expected output values used to validate the HF implementation were taken from them. Please note that the output tensors differ slightly for the larger models when run on a TPU; my expected values are from CPU runs. Since the final output tensors of the HF code match those of the original code, the components of the current code are aligned with the original implementation, despite that strange LayerNorm weight increase.
I'm loosening the tol for that test
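A minimal sketch of the kind of output check these tests rely on, using `torch.testing.assert_close` with the 1e-4 tolerance mentioned in the review. The tensors are random stand-ins, not real VideoPrism outputs or expected values.

```python
import torch

torch.manual_seed(0)
# Stand-in for a slice of the expected CPU outputs taken from the original notebooks.
expected_slice = torch.randn(1, 3, 3)
# Stand-in for the matching slice of the HF model output, with a small numerical drift.
actual_slice = expected_slice + 5e-5

# Raises AssertionError with a mismatch report if any element exceeds atol + rtol * |expected|.
torch.testing.assert_close(actual_slice, expected_slice, atol=1e-4, rtol=1e-4)
```

Loosening the tolerance for a single test (for example a larger checkpoint) then means passing larger `atol`/`rtol` values in that test only; `assert_close` expects both to be given together.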
[For maintainers] Suggested jobs to run (before merge) run-slow: auto, videoprism
run-slow: videoprism
This comment contains models: ["models/videoprism"]
@vasqu do you remember the HF dataset repo for storing doc assets? One image link needs to be replaced; the image was already uploaded there.
Regarding this: everything uses the Google repos except for the three test files, where you will need to replace 'mhrdyn7' with 'google' (e.g. via Ctrl+F). Slow tests passed; the very last commit only fixes the image link in the docs, so no further run_slow is needed.
@MHRDYN7 let's update the checkpoints in the docs and tests; we can merge then imo
I meant in this PR pls :D, i.e. use the google ckpts with a revision for all occurrences
ok done. Had some confusion, but all good now. run_slow is needed: there was a stale tokenizer in a Google ref, so just to be safe.
run-slow: videoprism
Last sanity check 🙏
Rerunning fast CI, don't understand why it failed tbh 🤔 Otherwise, I'll check next week and merge then.
1048e9a
Haya it worked!! Nice and gz on the merge @MHRDYN7 🫡
Fixes #39893. This PR adds the VideoPrism model by Google DeepMind. Original repo