Skip to content

perf: replace degenerate Conv3d with Linear in Qwen VL patch embedding#45992

Closed
TheShreyanshiDwivedi wants to merge 2 commits into
huggingface:mainfrom
TheShreyanshiDwivedi:fresh/qwen-conv3d-linear
Closed

perf: replace degenerate Conv3d with Linear in Qwen VL patch embedding#45992
TheShreyanshiDwivedi wants to merge 2 commits into
huggingface:mainfrom
TheShreyanshiDwivedi:fresh/qwen-conv3d-linear

Conversation

@TheShreyanshiDwivedi

Copy link
Copy Markdown

What does this PR do?

Replaces nn.Conv3d(kernel_size=k, stride=k) — where kernel_size == stride — with an equivalent nn.Linear in Qwen2-VL, Qwen2.5-VL and Qwen3-VL patch embeddings.

When kernel_size == stride, Conv3d degenerates to a patch-wise projection with no overlap. This is mathematically identical to a linear layer after reshaping the input. Using nn.Linear dispatches to a GEMM kernel instead of im2col + convolution, giving a large speedup especially on modern hardware.

Changes

  • Replace nn.Conv3dnn.Linear in patch embed for Qwen2-VL, Qwen2.5-VL, Qwen3-VL
  • Add _load_from_state_dict hook to reshape legacy 5D Conv3d weights → 2D on load (backwards compatible)
  • Add equivalence test verifying fp32 and bf16 outputs match

Files changed

  • src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
  • src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
  • src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
  • src/transformers/models/qwen3_vl/modular_qwen3_vl.py
  • tests/models/qwen2_vl/test_patch_embed.py

PatchEmbed.proj in Qwen2-VL, Qwen2.5-VL and Qwen3-VL uses nn.Conv3d with
kernel_size == stride, no padding, and no dilation.  This configuration
produces disjoint, non-overlapping receptive fields, making it mathematically
equivalent to nn.Linear over the flattened patch volume.  The Linear backend
dispatches to GEMM whereas cuDNN has no optimised path for this degenerate
Conv3d case on recent GPU architectures, resulting in runtimes orders of
magnitude slower than the matmul equivalent (benchmarked at ~53 000x on
Blackwell RTX 5090 in bf16, 16 s vs 0.3 ms for N=6080 patches).

Replace nn.Conv3d with nn.Linear(in_c * kt * kh * kw, embed_dim) in __init__
and update forward() to use reshape(-1, in_features) followed by the linear
projection.  A _load_from_state_dict hook reshapes 5-D Conv3d weights
(out, in, kt, kh, kw) from existing checkpoints to the 2-D shape expected by
Linear (out, in*kt*kh*kw), making the change transparent to all published
Qwen-VL checkpoints.  Already-2D weights (new checkpoints) pass through
unchanged, so the hook is idempotent.

fp32 max abs diff between old Conv3d and new Linear paths: < 1e-7 (single-
multiplication round-off).  bf16 cosine similarity on the full 24-layer vision
tower: > 0.999 per sample.

Fixes huggingface#45750

Files changed:
- modeling_qwen2_vl.py: fix PatchEmbed (bias=False, base class for Qwen2.5-VL)
- modular_qwen3_vl.py + modeling_qwen3_vl.py: fix Qwen3VLVisionPatchEmbed (bias=True)
- modeling_qwen2_5_vl.py: regenerated from modular (inherits fixed PatchEmbed)
- tests/models/qwen2_vl/test_patch_embed.py: fp32/bf16 equivalence + compat tests
@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_5_vl, qwen2_vl, qwen3_vl

@JJJYmmm

JJJYmmm commented May 15, 2026

Copy link
Copy Markdown
Contributor

I think it's a duplicate of #45041 / https://huggingface.co/docs/transformers/fusion_mapping.

We can replace the 3d conv with a linear by enabling patch_embeddings in fusion_config:

from transformers import AutoModelForImageTextToText


model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    fusion_config={"patch_embeddings": True},
)

@TheShreyanshiDwivedi

Copy link
Copy Markdown
Author

Hey @JJJYmmm yes, sorry it looks like the same. My oversight!!

I have made some other pr's as well that i had prepared, mind having take a look at them as well?

@TheShreyanshiDwivedi

Copy link
Copy Markdown
Author

Closing as this overlaps with the opt-in approach already shipped in #45041 (inference_fusion conv3d→linear). The default-on behaviour would require updating ~10 sibling modeling files and warrants a separate, more comprehensive PR if maintainers decide to make it the default. Thanks @JJJYmmm for the pointer!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants