Skip to content

Add Videoprism#39895

Merged
vasqu merged 159 commits into
huggingface:mainfrom
MHRDYN7:videoprism
Jun 19, 2026
Merged

Add Videoprism#39895
vasqu merged 159 commits into
huggingface:mainfrom
MHRDYN7:videoprism

Conversation

@MHRDYN7

@MHRDYN7 MHRDYN7 commented Aug 4, 2025

Copy link
Copy Markdown
Contributor

Fixes #39893. This pr adds the VideoPrism model by google deepmind. Original repo

@MHRDYN7 MHRDYN7 marked this pull request as ready for review August 23, 2025 21:08
@MHRDYN7

MHRDYN7 commented Aug 23, 2025

Copy link
Copy Markdown
Contributor Author

Summary of the code so far

  • The VideoPrismModel has been implemented using the modular code on top of Vivit and this is only a video encoder.

  • The code design is such that the same encoder code is reused as many times as possible.

  • Batches of video segments of shape (num_frames=16, H=288, W=288) are passed into the model and then the tubelet embedding class is used to convert each frame into hidden states of shape (256, 768) and therefore the whole input becomes (B*num_frame, 256, 768).

  • These spatial embeddings are passed into a spatial encoder

  • The outputs of the spatial encoder are reshaped to (B*256, num_frame, 768) and then passed into a temporal encoder

  • The attention function has an internal attention cap implemented in the modified eager_attention, not sure if this can be somehow used along with sdpa.

  • The VideoPrismClip model uses VideoPrismModel as a backbone for the video input, passes the embeddings through an auxiliary encoder then through an attention pooling layer.

  • There is also a standard text encoder that is called inside VideoPrismClip.

  • All the exact details from the original code have been extracted and correct tensors are being returned for both the models.

@qubvel I'd request your preliminary review on the current code structure. There are detailed comments with "# ?" for ease of review, these will be removed later on.

Todos:

  1. Get started with the tests
  2. Implement the video and text processors
  3. Add code for the video classification model

** Please note that i am currently using the preprocessing utils from the original code in my convert_weights_to_hf script. The code uses mediapy, which uses lanczos interpolation for resizing videos by default. However, lanczos interpolation is not supported in torch yet and that's why we can't get the exact same outputs if a fast video processor is used.

** The videoprism team have not released the weights for the classification head.

** The weights released on hub are in npz format, safetensors need to be uploaded there.

@qubvel qubvel requested review from qubvel and removed request for ArthurZucker and Rocketknight1 August 25, 2025 10:54

@qubvel qubvel left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MHRDYN7, huge thanks for working on the model addition! You already did a great work, please see the comments to align it further to the transformers standards 🤗

Comment on lines +126 to +134
@dataclass
class TextEncoderOutput(ModelOutput):
"""
Base class for text encoder outputs.
"""

last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
attentions: Optional[tuple[torch.FloatTensor, ...]] = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a common BaseModelOutput, no need to redefine it

Comment on lines +197 to +202
if self.mode == "spatial":
self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
self.spatial_pos_emb = nn.Parameter(torch.zeros(1, self.pos_emb_shape[1] * self.pos_emb_shape[2], config.hidden_size)) # ? (1, 256, 768)

elif self.mode == "temporal":
self.temporal_pos_emb = nn.Parameter(torch.zeros(1, self.pos_emb_shape[0], config.hidden_size)) # ? (1, 16, 768)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have checkpoint releases for both versions? Otherwise, please leave only one.

Comment on lines +257 to +259
def _interpolate_emb_2d(
self, emb: torch.Tensor, source_emb_shape: tuple[int, int], target_emb_shape: tuple[int, int]
):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be defined in interpolate_pos_encoding instead, no?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't use interpolate_pos_encoding because there are two different types interpolation for the pos embeds (spatial first then temporal). Now that the plan is to have two different embedding classes, I'll inherit from VivitEmbeddings for both and they both will have interpolate_pos_encoding method.

Comment on lines +351 to +357
with torch.no_grad():
self.layernorm_before.weight += nn.Parameter(
torch.ones(self.config.hidden_size)
)
self.layernorm_after.weight += nn.Parameter(
torch.ones(self.config.hidden_size)
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems super strange to me. Why do we need this, is this correct? The operation is in-place, so after each forward pass, are we continuing to increase the weight?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it very strange too.

Here is the original code
image
https://github.com/google-deepmind/videoprism/blob/main/videoprism/layers.py#L182-L193

when direct_scale is set to False (this is always the case), +1 is added to the scale tensor of layernorm.
image

This code seems more like an attempt to ensure the layernorm scale factor is not zero during training. I can create an issue in the repo to confirm if they want this behavior during inference as well.

You are right that the hf code in the current form means that the layernorm scale will get increased by +1 for every iteration of forward pass. The jax code initializes the Layernorm class on the go just before it is called so this problem does not happen there. If this +1 portion is moved (from forward) to the init of the relevant class, then that does not work as I guess during creation of a model instance, the init methods are evoked with the initialized weights, and later the pretrained weights are placed and that's why +1 does not happen. I've been working with single forward passes of the model, so it's been fine, but this issue still needs to be resolved.

@qubvel qubvel Aug 27, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I suppose we must not modify weight inplace and define it as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoPrismLayerNorm(nn.LayerNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight + 1, self.bias, self.eps
        )

That should be equivalent, right?

@qubvel qubvel Aug 27, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or in case +1 is only for initialization, we should not have it in modeling code, just in _init_weights method.
Please keep in mind that logits should match exactly (1e-3/1e-4) with the original implementation, and matching them should give you the right answer whether this addition is relevant

P.S. Just saw the message below 👍

@MHRDYN7 MHRDYN7 Aug 28, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, using F.layer_norm is a very good solution; I'll rename the weights and refactor the code. Also I got the confirmation from the videoprism team, scale = 0 for initialization, and the +1 is expected to be there.

Comment on lines +366 to +375
if mode == "spatial":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
elif mode == "temporal":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
elif mode == "auxiliary":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_auxiliary_layers)])
elif mode == "unimodal":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_unimodal_layers)])
else:
raise ValueError(f"Unknown mode: {mode}. Supported modes are: spatial, temporal, auxiliary and unimodal.")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any reason to split it into different if/else. We might have only one attribute num_layers and that's it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove these conditionals, then we will need to set config.num_layers = config.num_spatial_layers, before instantiating the spatial encoder object and then config.num_layers = config.num_temporal_layers before temporal encoder in the model init. It's slightly less explicit, but, since you agreed with the change of config from python_gelu to relu in the init, I'll go ahead with this one too, and then the encoder can be directly used from vivit.

Comment on lines +641 to +642
with torch.no_grad():
self.layernorm.weight += nn.Parameter(torch.ones(self.config.hidden_size))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, super strange

Comment on lines +652 to +675
class PositionalEmbedding(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.min_timescale = 1
self.max_timescale = 10000

def forward(self, seq_length):
position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) # ? (1, seq_length)
num_timescales = self.hidden_size // 2

log_timescale_increment = math.log(
float(self.max_timescale) / float(self.min_timescale) # ? log(10000/1) = ln(10000)
) / torch.maximum(torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1))

inv_timescales = self.min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment
)

scaled_time = position.unsqueeze(-1) * inv_timescales.expand(1, 1, -1)

embs = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=-1)

return embs

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

align with the RoPE classes defined in transformers

Comment thread src/transformers/models/videoprism/modular_videoprism.py Outdated
Comment on lines +748 to +751
self.backbone = VideoPrismModel(config)
self.auxiliary_encoder = VideoPrismEncoder(config, mode="auxiliary")
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(config)
self.text_encoder = VideoPrismTextEncoder(config)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we should have the following classes (similar to CLIP)

  • VideoPrismVideoModel
  • VideoPrismTextModel

please, align to this

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be refactored to follow mllama or dinov3_vit conversion script format. We should have a weights mapping dict and manipulate that

@MHRDYN7

MHRDYN7 commented Aug 27, 2025

Copy link
Copy Markdown
Contributor Author

The following notebooks are from the original repo and the expected values of the output tensors used to validate the HF implementation have been taken from here.
Video encoder model notebook
Video Text model notebook

Please note that the output tensors are slightly different for the larger models when a TPU is used. My expected tensor values are for the cpu.

Since the output of the final tensors (of HF code) matches that of the original code, the elements in the current code are aligned with the original implementation, despite that strange layernorm weight increase.

@MHRDYN7

MHRDYN7 commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

Works again now 🥳

I'm loosening the tol for that test

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism

@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

run-slow: videoprism

@github-actions

Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/videoprism"]
quantizations: []

@MHRDYN7

MHRDYN7 commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

@vasqu do you remember the hf dataset repo for storing assets for docs? One image link needs to be replaced, it was already uploaded there

@github-actions

Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 3d08a2ad workflow commit (merge commit)
PR 7cebc1dd branch commit (from PR)
main e7835fba base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@MHRDYN7

MHRDYN7 commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

Yes pretty much ready to merge, checking with slow CI. Can you update the checkpoints to use the official repos with the current revision?

I talked internally, it's handled and I will update the ckpt paths / revisions when they are merged/moved so should be all good 🤗

regarding this: everything uses the google repos except for the three test files where you will need to replace 'mhrdyn7' to 'google' via ctrl+F.

Slow tests passed, the very last commit simply contains the correct image link in doc, so no more run_slow needed

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism

Comment thread docs/source/en/model_doc/videoprism.md Outdated
Comment thread tests/models/videoprism/test_processing_videoprism.py Outdated
Comment thread tests/models/videoprism/test_tokenization_videoprism.py Outdated
Comment thread tests/models/videoprism/test_modeling_videoprism.py Outdated
@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

@MHRDYN7 let's update the checkpoints on docs and tests, we can merge then imo

@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

I meant in this PR pls :D so to use the google ckpts with revision for all occurrences

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism

@MHRDYN7

MHRDYN7 commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

ok done. Had some confusion, but all good now. run_slow is needed, there was a stale tokenizer in a google ref, so just to be safe

@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

run-slow: videoprism

@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

Last sanity check 🙏

@github-actions

Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/videoprism"]
quantizations: []

@github-actions

Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN b5a5263c workflow commit (merge commit)
PR 53cc84fb branch commit (from PR)
main 95ed1b82 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

Rerunning fast CI, don't understand why it failed tbh 🤔 otherwise will check next week and merge then

@vasqu vasqu added this pull request to the merge queue Jun 19, 2026
Merged via the queue into huggingface:main with commit 1048e9a Jun 19, 2026
190 of 195 checks passed
@vasqu

vasqu commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

Haya it worked!! Nice and gz on the merge @MHRDYN7 🫡

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add VideoPrism

8 participants