Skip to content

Commit 3a947e2

Browse files
authored
[inference_fusion] convert conv3d patch embed to linear (#45041)
* ok * fix consistency * pass qwen35 reverse mapping * update new failed test according to captured info * Revert "update new failed test according to captured info" This reverts commit 445a400. * make it optional * make fusion_mapping more general * make conv3d conversion more general * make fusion_mapping more general * better name for conversion * add fusion_mapping doc and clean tests * fix reverse mapping test follow gemma3n * chore: retrigger ci * tests: move qwen3.5 reverse mapping fix to separate branch * code clean! * ruff format and clean test to make it simple * richer doc * get converters from config rather than each module * add explict module_name check for fusion! * better isolated test and code clean * support serialized fusion_config * ruff format * config can handle unknown attributes * move fused cls out of spec by mixin * detailed comments * ruff
1 parent 282078b commit 3a947e2

7 files changed

Lines changed: 677 additions & 1 deletion

File tree

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
title: Customizing models
1919
- local: monkey_patching
2020
title: Monkey patching
21+
- local: fusion_mapping
22+
title: Fusion mapping
2123
- local: how_to_hack_models
2224
title: Customizing model components
2325
- local: model_sharing

docs/source/en/fusion_mapping.md

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
<!--Copyright 2026 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
17+
# Fusion mapping (experimental feature)
18+
19+
Fusion mapping provides an opt-in way to replace model submodules at load time while preserving the original checkpoint format.
20+
21+
It builds on:
22+
23+
- [Monkey patching](./monkey_patching) to swap module classes before model instantiation.
24+
- [Dynamic weight loading](./weightconverter) to map weights between the original and fused runtime layouts.
25+
26+
> [!WARNING]
27+
> Fusion mapping is an experimental loading feature. It changes the runtime module structure and may affect model behavior. Use it only when you explicitly want a fused runtime layout.
28+
29+
## Quick start
30+
31+
Fusion is enabled through [`~PreTrainedModel.from_pretrained`] with `fusion_config`:
32+
33+
```python
34+
from transformers import AutoModelForImageTextToText
35+
36+
37+
model = AutoModelForImageTextToText.from_pretrained(
38+
"Qwen/Qwen2-VL-2B-Instruct",
39+
fusion_config={"patch_embeddings": True},
40+
)
41+
```
42+
43+
By default, no fusion is applied.
44+
If `fusion_config` is stored in the model config, `from_pretrained()` will reuse it automatically.
45+
46+
## How it works
47+
48+
Fusion registration happens before the model is instantiated:
49+
50+
1. [`~PreTrainedModel.from_pretrained`] uses the explicit `fusion_config` argument or falls back to `config.fusion_config`.
51+
2. The fusion registry validates the requested fusion names.
52+
3. Each enabled fusion meta-initializes the target model class, optionally filters candidate modules by name, and uses `is_fusable(...)` to discover compatible module classes.
53+
4. Fused replacement classes are registered through [`~transformers.monkey_patching.register_patch_mapping`].
54+
5. Matching [`~WeightTransform`] rules are generated from the config so checkpoint loading can map weights into the fused runtime layout.
55+
6. By default, [`~PreTrainedModel.save_pretrained`] uses the reverse conversion path to restore the original checkpoint layout. Pass `save_original_format=False` to keep the converted runtime layout instead.
56+
57+
This lets a fusion use a different runtime module structure while still loading from the original checkpoint format, and by default saving back to it as well.
58+
59+
Note: With the current monkey-patching mechanism, fusion registration is class-level: one compatible module class maps to one fused replacement class.
60+
61+
## Current fusion families
62+
63+
Currently, `fusion_config` supports one fusion family:
64+
65+
- `patch_embeddings`
66+
Enable with:
67+
68+
```python
69+
fusion_config = {"patch_embeddings": True}
70+
```
71+
72+
Effect:
73+
Replaces compatible `nn.Conv3d` patch embedding projections with equivalent flattened `nn.Linear` projections at runtime.
74+
75+
## Extending fusion mapping
76+
77+
To add a new fusion family:
78+
79+
1. Add an `is_fusable` predicate.
80+
This decides whether a discovered module is compatible with the fusion.
81+
2. Optionally add `target_modules_patterns`.
82+
This makes the discovery step more explicit by pre-filtering candidate module names before `is_fusable(...)`.
83+
3. Add a `make_fused_class` factory.
84+
This returns the runtime replacement class for a compatible module class.
85+
4. Add a `make_transforms` factory if the fused layout needs checkpoint conversion.
86+
This returns the [`~WeightTransform`] rules that map weights between the original and fused layouts for a given config.
87+
5. Register the new `ModuleFusionSpec` in [`fusion_mapping.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/fusion_mapping.py).
88+
89+
Once registered, the new fusion becomes available through `fusion_config`.
90+
91+
## Internal API
92+
93+
[[autodoc]] fusion_mapping.ModuleFusionSpec
94+
95+
[[autodoc]] fusion_mapping.PatchEmbeddingsFusionSpec
96+
97+
[[autodoc]] fusion_mapping._register_module_fusion
98+
99+
[[autodoc]] fusion_mapping.register_fusion_patches

src/transformers/core_model_loading.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,72 @@ def reverse_op(self) -> ConversionOps:
306306
return Transpose(dim0=self.dim1, dim1=self.dim0, check_dims=self.check_dims)
307307

308308

309+
class Conv3dToLinear(ConversionOps):
310+
"""Conv3d weights → flattened Linear layout."""
311+
312+
def __init__(self, in_channels: int, kernel_size: tuple[int, int, int]):
313+
self.in_channels = in_channels
314+
self.kernel_size = kernel_size
315+
316+
@staticmethod
317+
def _get_target_pattern(
318+
input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str]
319+
) -> str:
320+
if len(input_dict) != 1:
321+
raise ValueError("Undefined Operation encountered!")
322+
if len(target_patterns) > 1:
323+
if len(source_patterns) == 1:
324+
return source_patterns[0]
325+
else:
326+
raise ValueError("Undefined Operation encountered!")
327+
return target_patterns[0]
328+
329+
@torch.no_grad
330+
def convert(
331+
self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
332+
) -> dict[str, torch.Tensor]:
333+
target_pattern = self._get_target_pattern(input_dict, source_patterns, target_patterns)
334+
tensors = next(iter(input_dict.values()))
335+
tensor = tensors[0] if isinstance(tensors, list) else tensors
336+
337+
if tensor.ndim == 5:
338+
tensor = tensor.reshape(tensor.shape[0], -1).contiguous()
339+
elif tensor.ndim != 2:
340+
raise ValueError(f"Conv3dToLinear expects a 5D or 2D tensor, got {tensor.ndim}D")
341+
342+
return {target_pattern: tensor}
343+
344+
@property
345+
def reverse_op(self) -> ConversionOps:
346+
return LinearToConv3d(in_channels=self.in_channels, kernel_size=self.kernel_size)
347+
348+
349+
class LinearToConv3d(ConversionOps):
350+
"""Flattened Linear weights → Conv3d layout."""
351+
352+
def __init__(self, in_channels: int, kernel_size: tuple[int, int, int]):
353+
self.in_channels = in_channels
354+
self.kernel_size = kernel_size
355+
356+
@torch.no_grad
357+
def convert(
358+
self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
359+
) -> dict[str, torch.Tensor]:
360+
target_pattern = Conv3dToLinear._get_target_pattern(input_dict, source_patterns, target_patterns)
361+
tensors = next(iter(input_dict.values()))
362+
tensor = tensors[0] if isinstance(tensors, list) else tensors
363+
364+
target_shape = (tensor.shape[0], self.in_channels, *self.kernel_size)
365+
if tensor.numel() != math.prod(target_shape):
366+
raise ValueError(f"Cannot reshape tensor with shape {tensor.shape} into {target_shape}")
367+
368+
return {target_pattern: tensor.reshape(target_shape).contiguous()}
369+
370+
@property
371+
def reverse_op(self) -> ConversionOps:
372+
return Conv3dToLinear(in_channels=self.in_channels, kernel_size=self.kernel_size)
373+
374+
309375
class PermuteForRope(ConversionOps):
310376
"""
311377
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.

0 commit comments

Comments
 (0)