perf: replace degenerate Conv3d with Linear in Qwen VL patch embedding #45992
Closed
TheShreyanshiDwivedi wants to merge 2 commits into
Conversation
PatchEmbed.proj in Qwen2-VL, Qwen2.5-VL and Qwen3-VL uses nn.Conv3d with kernel_size == stride, no padding, and no dilation. This configuration produces disjoint, non-overlapping receptive fields, making it mathematically equivalent to nn.Linear over the flattened patch volume.

The Linear backend dispatches to GEMM, whereas cuDNN has no optimised path for this degenerate Conv3d case on recent GPU architectures, resulting in runtimes orders of magnitude slower than the matmul equivalent (benchmarked at ~53 000x on Blackwell RTX 5090 in bf16: 16 s vs 0.3 ms for N=6080 patches).

Replace nn.Conv3d with nn.Linear(in_c * kt * kh * kw, embed_dim) in __init__ and update forward() to use reshape(-1, in_features) followed by the linear projection.

A _load_from_state_dict hook reshapes 5-D Conv3d weights (out, in, kt, kh, kw) from existing checkpoints to the 2-D shape expected by Linear (out, in*kt*kh*kw), making the change transparent to all published Qwen-VL checkpoints. Already-2D weights (new checkpoints) pass through unchanged, so the hook is idempotent.

fp32 max abs diff between old Conv3d and new Linear paths: < 1e-7 (single-multiplication round-off).
bf16 cosine similarity on the full 24-layer vision tower: > 0.999 per sample.

Fixes huggingface#45750

Files changed:
- modeling_qwen2_vl.py: fix PatchEmbed (bias=False, base class for Qwen2.5-VL)
- modular_qwen3_vl.py + modeling_qwen3_vl.py: fix Qwen3VLVisionPatchEmbed (bias=True)
- modeling_qwen2_5_vl.py: regenerated from modular (inherits fixed PatchEmbed)
- tests/models/qwen2_vl/test_patch_embed.py: fp32/bf16 equivalence + compat tests
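For readers who want to check the equivalence claim themselves, here is a minimal sketch in plain PyTorch. It is not the test from this PR: the sizes (`in_c`, `kt`, `kh`, `kw`, `embed_dim`, `n_patches`) are illustrative assumptions rather than values read from a Qwen config, and it runs on CPU in fp32 only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from any Qwen-VL config).
in_c, kt, kh, kw, embed_dim = 3, 2, 14, 14, 32
n_patches = 5

# "Degenerate" Conv3d: kernel_size == stride, no padding, no dilation.
conv = nn.Conv3d(in_c, embed_dim, kernel_size=(kt, kh, kw),
                 stride=(kt, kh, kw), bias=False)
linear = nn.Linear(in_c * kt * kh * kw, embed_dim, bias=False)

# Flatten each (in, kt, kh, kw) filter into one row of the Linear weight.
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))

# One patch volume per row; each volume is exactly one kernel in size.
patches = torch.randn(n_patches, in_c, kt, kh, kw)

with torch.no_grad():
    # Conv3d yields (n_patches, embed_dim, 1, 1, 1); drop the unit dims.
    out_conv = conv(patches).reshape(n_patches, embed_dim)
    # Linear sees the same values in the same (in, kt, kh, kw) order.
    out_linear = linear(patches.reshape(n_patches, -1))

max_abs_diff = (out_conv - out_linear).abs().max().item()
print(f"max abs diff: {max_abs_diff:.3e}")
```

The two paths sum the same products, possibly in a different order, so the difference should stay at floating-point round-off level. The exact figure depends on the backend and dtype.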
Contributor
[For maintainers] Suggested jobs to run (before merge)
run-slow: qwen2_5_vl, qwen2_vl, qwen3_vl
Contributor
I think it's a duplicate of #45041 / https://huggingface.co/docs/transformers/fusion_mapping. We can replace the 3D conv with a linear layer by enabling the patch-embeddings fusion:

    from transformers import AutoModelForImageTextToText

    model = AutoModelForImageTextToText.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        fusion_config={"patch_embeddings": True},
    )
Author
Hey @JJJYmmm, yes, sorry, it looks like the same change. My oversight! I have opened some other PRs that I had prepared; would you mind taking a look at them as well?
Author
Closing as this overlaps with the opt-in approach already shipped in #45041 (inference_fusion conv3d→linear). The default-on behaviour would require updating ~10 sibling modeling files and warrants a separate, more comprehensive PR if maintainers decide to make it the default. Thanks @JJJYmmm for the pointer!
What does this PR do?
Replaces nn.Conv3d(kernel_size=k, stride=k), where kernel_size == stride, with an equivalent nn.Linear in Qwen2-VL, Qwen2.5-VL and Qwen3-VL patch embeddings.

When kernel_size == stride, Conv3d degenerates to a patch-wise projection with no overlap. This is mathematically identical to a linear layer after reshaping the input. Using nn.Linear dispatches to a GEMM kernel instead of im2col + convolution, giving a large speedup, especially on modern hardware.

Changes
- nn.Conv3d → nn.Linear in patch embed for Qwen2-VL, Qwen2.5-VL, Qwen3-VL
- _load_from_state_dict hook to reshape legacy 5D Conv3d weights → 2D on load (backwards compatible)

Files changed
- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
- src/transformers/models/qwen3_vl/modular_qwen3_vl.py
- tests/models/qwen2_vl/test_patch_embed.py
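As a rough illustration of the backwards-compatibility hook listed under Changes, the sketch below shows one way such a reshape-on-load could look in plain PyTorch. `PatchEmbedSketch` and its default sizes are hypothetical stand-ins, not the code from this PR, and the sketch only covers `nn.Module.load_state_dict`. Whether the transformers checkpoint loader goes through the same hook is not something this example demonstrates.

```python
import torch
import torch.nn as nn


class PatchEmbedSketch(nn.Module):
    """Hypothetical stand-in for the patch embedding; not the PR's code."""

    def __init__(self, in_c=3, kt=2, kh=14, kw=14, embed_dim=32):
        super().__init__()
        self.in_features = in_c * kt * kh * kw
        self.proj = nn.Linear(self.in_features, embed_dim, bias=False)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Legacy checkpoints store proj.weight as (out, in, kt, kh, kw).
        key = prefix + "proj.weight"
        weight = state_dict.get(key)
        if weight is not None and weight.dim() == 5:
            # Flatten to (out, in*kt*kh*kw); 2-D weights are left untouched.
            state_dict[key] = weight.reshape(weight.shape[0], -1)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, hidden_states):
        return self.proj(hidden_states.reshape(-1, self.in_features))


# A made-up legacy 5-D Conv3d weight standing in for an old checkpoint.
legacy_weight = torch.randn(32, 3, 2, 14, 14)

model = PatchEmbedSketch()
model.load_state_dict({"proj.weight": legacy_weight})  # 5-D in, 2-D stored
model.load_state_dict(model.state_dict())  # already 2-D: passes through

out = model(torch.randn(4, 3, 2, 14, 14))
print(model.proj.weight.shape, out.shape)
```

The hook runs on the parent module before its `proj` child is loaded, so the child only ever sees the 2-D tensor. Loading the module's own state dict a second time leaves the weight unchanged, which matches the idempotence described above.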