diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 3c776a0c8d3c..61190f7837e5 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -741,6 +741,8 @@
title: MiniMax
- local: model_doc/minimax_m2
title: MiniMax-M2
+ - local: model_doc/minimax_m3_vl
+ title: MiniMax-M3-VL
- local: model_doc/ministral
title: Ministral
- local: model_doc/ministral3
diff --git a/docs/source/en/model_doc/minimax_m3_vl.md b/docs/source/en/model_doc/minimax_m3_vl.md
new file mode 100644
index 000000000000..5a73737b7e67
--- /dev/null
+++ b/docs/source/en/model_doc/minimax_m3_vl.md
@@ -0,0 +1,253 @@
+
+*This model was contributed to Hugging Face Transformers on 2026-06-12.*
+
+
+# MiniMax-M3-VL
+
+## Overview
+
+MiniMax-M3-VL is the vision-language member of the MiniMax-M3 family. It pairs a CLIP-style vision tower (Conv3d patch embedding with 3D rotary position embeddings) with the MiniMax-M3 text backbone, a mixed dense/sparse Mixture-of-Experts decoder that uses SwiGLU-OAI gated experts and a lightning indexer for block-sparse attention.
+
+## Architecture
+### Block-sparse attention (Lightning Indexer)
+
+Every layer is GQA (`num_key_value_heads = 4`) with per-head QK-norm and **partial RoPE** on the first
+`rotary_dim`. `config.layer_types[i]` then picks `"full_attention"` (dense causal) or
+`"minimax_m3_sparse"`, where a [`MiniMaxM3VLIndexer`] decides, per query, which block of keys the main attention may see.
+
+The indexer scores every key, then **max-poolsthose per-key scores into blocks of `index_block_size` keys**, so selection happens at the granularity of a *block
+of keys*: per query it keeps the top-`index_topk_blocks` key blocks plus the always-on `index_local_blocks`
+local-window block (under block-level causality), broadcasts the per-block `0`/`-inf` choice back onto every key in
+the block. The result is a `[B, 1, S_q, S_k]` additive bias summed onto the causal mask.
+Theoretically this means that the attention is only computed over the selected blocks of keys, but `transformers` does not support the kernels that compute this efficiently!
+We are adding it to `kernels` asap!
+
+
+
+
+### Vision tower
+
+A [`MiniMaxM3VLVisionModel`]: a `Conv3d` patch embedding over flattened `[N_patches, C·T·P·P]` input, a stack of
+CLIP-style encoder layers carrying a **3D rotary** position embedding (time / height / width bands). A [`MiniMaxM3VLPatchMerger`] groups
+`spatial_merge_size²` patches into the channel dim before the 2-layer GELU [`MiniMaxM3VLMultiModalProjector`] maps vision features into the text hidden size.
+
+## Usage examples
+
+The example below runs the model on a real image loaded with [`~transformers.image_utils.load_image`].
+
+```python
+import torch
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from transformers.image_utils import load_image
+
+
+model = AutoModelForImageTextToText.from_pretrained(
+ "MiniMaxAI/MiniMax-M3-preview", dtype=torch.bfloat16, device_map="auto",
+)
+processor = AutoProcessor.from_pretrained("MiniMaxAI/MiniMax-M3-preview")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg")
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "Describe this image briefly."},
+ ],
+ }
+]
+text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device)
+
+generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+```
+
+### Apple example
+
+This example asks the model about an image of apples, again loading a real image with
+[`~transformers.image_utils.load_image`].
+
+```python
+import torch
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from transformers.image_utils import load_image
+
+
+model = AutoModelForImageTextToText.from_pretrained(
+ "MiniMaxAI/MiniMax-M3-preview", dtype=torch.bfloat16, device_map="auto",
+)
+processor = AutoProcessor.from_pretrained("MiniMaxAI/MiniMax-M3-preview")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg")
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "How many apples are in this image, and what color are they?"},
+ ],
+ }
+]
+text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device)
+
+generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+```
+
+## Fastest inference configuration
+
+| ctx | SDPA decode | MSA decode | MSA decode adv. | SDPA prefill | MSA prefill | MSA prefill adv. |
+| --: | ----------: | ---------: | --------------: | -----------: | ----------: | ---------------: |
+| 2K | 27.8 tok/s | 31.0 | +12% | 303 ms | 257 ms | 1.18× |
+| 4K | 23.4 tok/s | 30.5 | +30% | 684 ms | 460 ms | 1.49× |
+| 8K | 17.8 tok/s | 29.6 | +66% | 1906 ms | 976 ms | 1.95× |
+| 16K | 12.0 tok/s | 27.6 | +130% | 6110 ms | 2344 ms | 2.61× |
+
+The checkpoint ships in native MXFP8. For **decode throughput**, the fastest validated configuration is
+**bf16 (dequantized at load) + the MSA block-sparse attention kernel + tensor & expert parallelism + a
+`reduce-overhead` cudagraph compile** — roughly **31 tok/s** decode on 8×B200 at a 2048-token prefill.
+
+Keeping the weights in **native FP8 is a memory-footprint option only — it is never faster on this setup**.
+The FP8 Triton experts/linear kernels lower as opaque inductor fallback kernels that cudagraph cannot
+capture on the hot expert path, so native-FP8 decode measured ~4.2 tok/s (≈7× slower than the bf16 path)
+even under `torch.compile(fullgraph=True)`. Use FP8 only when the bf16 weights do not fit.
+
+| config (sdpa baseline, TP+EP, 2048-token prefill, 8×B200) | decode |
+|---|---|
+| bf16 dequantize-at-load + **MSA** + compile/cudagraph | **~31 tok/s** |
+| bf16 dequantize-at-load + sdpa + compile/cudagraph | ~28 tok/s |
+| native FP8 + compile/cudagraph | ~4 tok/s (memory-only, not for speed) |
+
+Dequantizing to bf16 only fits with even sharding across GPUs (TP/EP), not with `device_map="auto"`
+(pipeline placement OOMs at load). Launch one process per GPU with `torchrun`:
+
+```bash
+torchrun --nproc_per_node=8 fastest_m3_vl.py
+```
+
+```python
+# fastest_m3_vl.py
+import os, sys
+import torch
+import torch.distributed as dist
+from transformers import (
+ AutoModelForImageTextToText,
+ AutoTokenizer,
+ CompileConfig,
+ FineGrainedFP8Config,
+)
+from transformers.distributed import DistributedConfig
+
+# The indexer feeds SDPA an additive float mask; the cuDNN SDP backend segfaults on it (B200).
+torch.backends.cuda.enable_cudnn_sdp(False)
+
+model = AutoModelForImageTextToText.from_pretrained(
+ "MiniMaxAI/MiniMax-M3-preview",
+ dtype=torch.bfloat16,
+ # Dequantize the native MXFP8 weights to bf16 at load (the speed win); needs even TP/EP sharding.
+ quantization_config=FineGrainedFP8Config(dequantize=True),
+ tp_plan="auto",
+ distributed_config=DistributedConfig(enable_expert_parallel=True),
+ attn_implementation="kernels-staging/msa@v0", # MSA block-sparse attention kernel
+)
+model.eval()
+
+tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M3-preview")
+messages = [{"role": "user", "content": "Summarize the history of computing."}]
+inputs = tokenizer.apply_chat_template(
+ messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
+).to(f"cuda:{os.environ.get('LOCAL_RANK', '0')}")
+
+generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=128,
+ do_sample=False,
+ # Static cache + reduce-overhead cudagraph capture is what pushes decode to ~31 tok/s.
+ cache_implementation="static",
+ compile_config=CompileConfig(mode="reduce-overhead", fullgraph=True),
+)
+if int(os.environ.get("RANK", "0")) == 0:
+ print(tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
+
+# cudagraph-captured NCCL collectives deadlock the NCCL/CUDA destructors at teardown; the output is
+# already produced, so hard-exit to skip the hanging cleanup.
+if dist.is_initialized():
+ sys.stdout.flush()
+ os._exit(0)
+```
+
+## MiniMaxM3VLConfig
+
+[[autodoc]] MiniMaxM3VLConfig
+
+## MiniMaxM3VLTextConfig
+
+[[autodoc]] MiniMaxM3VLTextConfig
+
+## MiniMaxM3VLVisionConfig
+
+[[autodoc]] MiniMaxM3VLVisionConfig
+
+## MiniMaxM3VLProcessor
+
+[[autodoc]] MiniMaxM3VLProcessor
+
+## MiniMaxM3VLImageProcessor
+
+This is a standalone (non-modular) image processor: it shares the patch-flattening idea of [`Qwen2VLImageProcessor`]
+but does not inherit from it because the two diverge in ways that touch most of the class. The resize budget is driven by
+a `max_pixels` attribute and a `{"height", "width"}` `size` rather than Qwen's `shortest_edge`/`longest_edge` scheme; the
+`smart_resize` helper clamps the initial rounding with `max(factor, ...)`; and `_preprocess` performs real temporal
+handling (5D patches, last-frame repeat to fill `temporal_patch_size`, and a `grid_t` dimension) instead of Qwen's
+`grid_t = 1` + expand. Mapping to or subclassing Qwen would therefore change behavior or require overriding nearly
+everything, so the processor is kept on its own.
+
+[[autodoc]] MiniMaxM3VLImageProcessor
+
+## MiniMaxM3VLVideoProcessor
+
+[[autodoc]] MiniMaxM3VLVideoProcessor
+
+## MiniMaxM3VLVisionModel
+
+[[autodoc]] MiniMaxM3VLVisionModel
+ - forward
+
+## MiniMaxM3VLTextModel
+
+[[autodoc]] MiniMaxM3VLTextModel
+ - forward
+
+## MiniMaxM3VLModel
+
+[[autodoc]] MiniMaxM3VLModel
+ - forward
+
+## MiniMaxM3VLForCausalLM
+
+[[autodoc]] MiniMaxM3VLForCausalLM
+ - forward
+
+## MiniMaxM3SparseForConditionalGeneration
+
+[[autodoc]] MiniMaxM3SparseForConditionalGeneration
+ - forward
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 4801b37a8558..46f4620bdd1e 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -33,6 +33,12 @@
# ``PreTrainedConfig`` (the decoder text config) as the only positional argument.
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}
+# Parallel registry for the *static* implementation of a custom layer type, consulted by
+# ``StaticCache``. A ``StaticLayer`` subclass with a ``layer_type`` registers here instead of in
+# ``LAYER_TYPE_CACHE_MAPPING`` (constructed with ``max_cache_len=...``), so a single ``layer_type``
+# can have both a dynamic and a static cache layer (e.g. M3's sparse-attention indexer cache).
+LAYER_TYPE_STATIC_CACHE_MAPPING: dict[str, type] = {}
+
class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache."""
@@ -47,7 +53,11 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
layer_type = cls.__dict__.get("layer_type", None)
if layer_type is not None:
- LAYER_TYPE_CACHE_MAPPING[layer_type] = cls
+ static_base = globals().get("StaticLayer")
+ if static_base is not None and issubclass(cls, static_base):
+ LAYER_TYPE_STATIC_CACHE_MAPPING[layer_type] = cls
+ else:
+ LAYER_TYPE_CACHE_MAPPING[layer_type] = cls
def __init__(self):
self.keys: torch.Tensor | None = None
@@ -1578,6 +1588,9 @@ def __init__(
# LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
layer = LinearAttentionLayer()
+ # Custom layer types (e.g. M3's sparse-attention indexer cache) that registered a static variant.
+ elif layer_type in LAYER_TYPE_STATIC_CACHE_MAPPING:
+ layer = LAYER_TYPE_STATIC_CACHE_MAPPING[layer_type](max_cache_len=max_cache_len)
elif layer_type == "deepseek_sparse_attention":
# Static / compile-friendly indexed layer (preallocated indexer key cache).
layer = StaticIndexedLayer(max_cache_len=max_cache_len)
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index f902994eb71b..2b6fdce06e28 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -65,6 +65,7 @@
"chunked_attention",
"compressed_sparse_attention", # CSA, used in deepseek_v4
"heavily_compressed_attention", # HCA, used in deepseek_v4
+ "minimax_m3_sparse", # lightning-index sparse attention, used in minimax_m3_vl
"linear_attention", # used in minimax
"conv", # used in LFMv2
"mamba",
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index fe1a66b7a210..a0cff222273a 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -455,6 +455,115 @@ def _build_checkpoint_conversion_mapping():
WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"),
WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
],
+ "minimax_m3_vl": [
+ # Ordering matters for save round-tripping: the reverse mapping flips the order *and* each
+ # transform (see deepseek_v4 above). We therefore split into two passes: structural prefix
+ # renames first (so they apply last on save / first on load), then specific in-prefix renames
+ # that operate on the already-prefixed keys. Every target prefix here is distinct and anchored,
+ # so no reversed source pattern is broad enough to steal keys from another namespace.
+ # ---- Pass 1: top-level + structural prefix renames ----
+ WeightRenaming(source_patterns=r"^language_model\.lm_head", target_patterns="lm_head"),
+ WeightRenaming(source_patterns=r"^language_model\.model\.", target_patterns="model.language_model."),
+ # The vision tower flattens CLIP's `vision_model.{encoder.layers,embeddings.patch_embedding,
+ # pre_layrnorm}` nesting onto `vision_tower.{layers,embeddings.proj,pre_layrnorm}`. Each rule is
+ # anchored and leaf-specific so its reverse re-inserts `vision_model` only on the right keys (a
+ # blanket `.vision_model.` -> `.` rule reverses to "match any char" and mangles every key).
+ WeightRenaming(
+ source_patterns=r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.",
+ target_patterns="model.vision_tower.embeddings.proj.",
+ ),
+ WeightRenaming(
+ source_patterns=r"^vision_tower\.vision_model\.encoder\.layers\.",
+ target_patterns="model.vision_tower.layers.",
+ ),
+ WeightRenaming(
+ source_patterns=r"^vision_tower\.vision_model\.pre_layrnorm\.",
+ target_patterns="model.vision_tower.pre_layrnorm.",
+ ),
+ # The projector hosts both the upstream `multi_modal_projector.linear_{1,2}` and the
+ # `patch_merge_mlp.linear_{1,2}` (registered as `merge_linear_{1,2}`). Spell each leaf out so the
+ # reversed `linear_*` source never also matches `merge_linear_*` (or vice versa).
+ WeightRenaming(
+ source_patterns=r"^multi_modal_projector\.linear_1\.",
+ target_patterns="model.multi_modal_projector.linear_1.",
+ ),
+ WeightRenaming(
+ source_patterns=r"^multi_modal_projector\.linear_2\.",
+ target_patterns="model.multi_modal_projector.linear_2.",
+ ),
+ WeightRenaming(
+ source_patterns=r"^patch_merge_mlp\.linear_1\.",
+ target_patterns="model.multi_modal_projector.merge_linear_1.",
+ ),
+ WeightRenaming(
+ source_patterns=r"^patch_merge_mlp\.linear_2\.",
+ target_patterns="model.multi_modal_projector.merge_linear_2.",
+ ),
+ # ---- Pass 2: specific in-prefix renames (operate on already-prefixed keys) ----
+ # MoE layers rename `block_sparse_moe.*` -> `mlp.*`, but dense layers already use `mlp.*`. A blanket
+ # `block_sparse_moe.` -> `mlp.` rule reverses to `mlp.` -> `block_sparse_moe.` and corrupts the dense
+ # layers on save, so we rename per MoE leaf (`experts` / `shared_experts` / `gate` / e-score). The
+ # reversed `mlp.experts.` / `mlp.shared_experts.` / `mlp.gate.weight` sources match only MoE
+ # sublayers, never the dense `mlp.gate_up_proj` / `mlp.down_proj`.
+ WeightRenaming(
+ source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.experts\.",
+ target_patterns=r".language_model.layers.\1.mlp.experts.",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.shared_experts\.",
+ target_patterns=r".language_model.layers.\1.mlp.shared_experts.",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.gate\.weight",
+ target_patterns=r".language_model.layers.\1.mlp.gate.weight",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.e_score_correction_bias",
+ target_patterns=r".language_model.layers.\1.mlp.gate.e_score_correction_bias",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.self_attn\.index_q_proj\.",
+ target_patterns=".self_attn.indexer.q_proj.",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.self_attn\.index_k_proj\.",
+ target_patterns=".self_attn.indexer.k_proj.",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.self_attn\.index_q_norm\.",
+ target_patterns=".self_attn.indexer.q_norm.",
+ ),
+ WeightRenaming(
+ source_patterns=r"\.self_attn\.index_k_norm\.",
+ target_patterns=".self_attn.indexer.k_norm.",
+ ),
+ WeightConverter(
+ source_patterns=[
+ "mlp.experts.*.w1.weight",
+ "mlp.experts.*.w3.weight",
+ ],
+ target_patterns="mlp.experts.gate_up_proj",
+ operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
+ ),
+ WeightConverter(
+ source_patterns="mlp.experts.*.w2.weight",
+ target_patterns="mlp.experts.down_proj",
+ operations=[MergeModulelist(dim=0)],
+ ),
+ WeightConverter(
+ source_patterns=["mlp.gate_proj.weight", "mlp.up_proj.weight"],
+ target_patterns="mlp.gate_up_proj.weight",
+ operations=[Concatenate(dim=0)],
+ ),
+ WeightConverter(
+ source_patterns=[
+ "mlp.shared_experts.gate_proj.weight",
+ "mlp.shared_experts.up_proj.weight",
+ ],
+ target_patterns="mlp.shared_experts.gate_up_proj.weight",
+ operations=[Concatenate(dim=0)],
+ ),
+ ],
"qwen2_audio": [
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 2493e1560a36..0d57b0b8d23c 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -1429,6 +1429,10 @@ def convert_and_load_state_dict_in_model(
mapping.distributed_operation = tp_layer(
device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
)
+ # Per-expert sharding (EP) needs `tensor_idx` = the expert index so the
+ # distributed op selects whole experts. The signal is a `MergeModulelist`
+ # in the chain; it isn't always `operations[0]` (e.g. an FP8 quantizer
+ # prepends a scale-decode op), so scan the whole chain rather than just the head.
shard_index = (
len(mapping.collected_tensors.get(source_pattern, []))
if isinstance(mapping, WeightConverter)
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c58a3d154e9f..03c59d7b2465 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1840,9 +1840,13 @@ def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> b
"rwkv",
"xlstm",
)
- # name clash between minimax and minimax m2, so we add this "or"
- return "minimaxm2" in cls.__name__.lower() or all(
- unsupported_name not in cls.__name__.lower() for unsupported_name in unsupported_model_names
+ # The "minimax" exclusion targets the original linear-attention MiniMax (custom cache); the later
+ # MiniMax M2 / M3 are standard attention models that use the regular Dynamic/Static caches.
+ name = cls.__name__.lower()
+ return (
+ "minimaxm2" in name
+ or "minimaxm3" in name
+ or all(unsupported_name not in name for unsupported_name in unsupported_model_names)
)
def _prepare_cache_for_generation(
@@ -1904,13 +1908,7 @@ def _prepare_cache_for_generation(
generation_config.cache_implementation = "dynamic_full"
dynamic_cache_kwargs = {}
- # linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers
- is_linear_attention = any(
- x in ("mamba", "conv", "linear_attention")
- for x in (getattr(self.config.get_text_config(decoder=True), "layer_types", []) or [])
- )
- if generation_config.cache_implementation != "dynamic_full" or is_linear_attention:
- dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True)
+ dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True)
if generation_config.cache_implementation == "offloaded":
dynamic_cache_kwargs["offloading"] = True
diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py
index d8a43c97c193..2bbda2d23025 100644
--- a/src/transformers/integrations/finegrained_fp8.py
+++ b/src/transformers/integrations/finegrained_fp8.py
@@ -593,6 +593,8 @@ def __init__(
self.activation_scheme = activation_scheme
self.num_experts = _first_attr(config, "num_local_experts", "num_experts")
self.intermediate_dim = _first_attr(config, "moe_intermediate_size", "intermediate_size")
+ self.swiglu_alpha = getattr(config, "swiglu_alpha", None)
+ self.swiglu_limit = getattr(config, "swiglu_limit", None)
self.act_fn = ACT2FN[_first_attr(config, "hidden_activation", "hidden_act")]
self.limit = getattr(config, "swiglu_limit", None)
@@ -640,7 +642,13 @@ def __init__(
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
gate, up = gate_up.chunk(2, dim=-1)
- if self.limit is not None:
+ if self.swiglu_alpha is not None:
+ # Clamped SwiGLU-OAI gate (same math as the model's non-quantized experts).
+ gate = gate.clamp(max=self.swiglu_limit)
+ up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+ glu = gate * torch.sigmoid(gate * self.swiglu_alpha)
+ return (up + 1.0) * glu
+ elif self.limit is not None:
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
return self.act_fn(gate) * up
@@ -958,10 +966,16 @@ def _dequantize_one(
output_dtype = (
scales.dtype if scales.dtype.is_floating_point and scales.element_size() >= 2 else torch.bfloat16
)
-
+ # MXFP8 checkpoints ship E8M0 exponents stored as ``torch.uint8`` (one byte per
+ # block) — the actual scale is `2 ** (byte - 127)`. Interpreting the raw bytes
+ # as scalar multipliers would be silently wrong, so unpack to fp32 here.
+ if scales.dtype == torch.uint8:
+ s_fp32 = (scales.to(torch.float32) - 127.0).exp2()
+ else:
+ s_fp32 = scales.to(torch.float32)
original_shape = quantized_fp32.shape
q = quantized_fp32.reshape(-1, scale_rows, block_m, scale_cols, block_n)
- s = scales.to(torch.float32).reshape(-1, scale_rows, scale_cols).unsqueeze(-1).unsqueeze(2)
+ s = s_fp32.reshape(-1, scale_rows, scale_cols).unsqueeze(-1).unsqueeze(2)
return (q * s).to(output_dtype).reshape(original_shape)
def _get_target_dtype(self, model: torch.nn.Module | None, full_layer_name: str | None) -> torch.dtype | None:
@@ -1023,3 +1037,29 @@ def reverse_op(self) -> ConversionOps:
# checkpoint preserves the FP8 format (weight + per-block ``weight_scale_inv``)
# whether the in-memory state stayed quantized or was dequantized for compute.
return Fp8Quantize(self.hf_quantizer)
+
+
+class Fp8DecodeScale(ConversionOps):
+ """Decode MXFP8 ``ue8m0`` per-block scales (stored as ``uint8`` exponents) into the
+ float32 multiplicative scales the FP8 compute path expects.
+
+ Native MXFP8 loading (``dequantize=False``) keeps weights in ``float8_e4m3fn`` and only
+ needs the sibling ``*.weight_scale_inv`` tensors turned from raw E8M0 bytes into real
+ scales (``2 ** (byte - 127)``). Prepended to each weight converter, this op runs before
+ any merge/concat collapses the per-expert structure: it rewrites only the ``uint8`` scale
+ entries and passes weights (and already-float scales) through untouched.
+ """
+
+ def __init__(self, hf_quantizer):
+ self.hf_quantizer = hf_quantizer
+
+ @staticmethod
+ def _decode(tensor: torch.Tensor) -> torch.Tensor:
+ # E8M0 stores one exponent byte per block; the real scale is ``2 ** (byte - 127)``.
+ return (tensor.to(torch.float32) - 127.0).exp2() if tensor.dtype == torch.uint8 else tensor
+
+ def convert(self, input_dict: dict[str, list[torch.Tensor] | torch.Tensor], **kwargs):
+ return {
+ key: [self._decode(t) for t in value] if isinstance(value, list) else self._decode(value)
+ for key, value in input_dict.items()
+ }
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 9ae52963c245..8e7bbbc24ab9 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -370,16 +370,26 @@ def load_and_register_attn_kernel(
except Exception as e:
raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.")
# correctly wrap the kernel
+ mask_implementation = "flash_attention_2"
if hasattr(kernel, "flash_attn_varlen_func"):
if attention_wrapper is None:
attention_wrapper = flash_attention_forward
kernel_function = attention_wrapper
+ elif hasattr(kernel, "sparse_atten_func"):
+ # Block-sparse kernels (e.g. `kernels-staging/msa`) expose `sparse_atten_func` instead of
+ # `flash_attn_varlen_func`; their call contract differs from the attention interface, so we
+ # bind the dedicated transformers-side wrapper that adapts the arguments and hides the
+ # prefill-kernel / decode-fallback dispatch.
+ from .msa_attention import msa_attention_forward
+
+ kernel_function = attention_wrapper if attention_wrapper is not None else msa_attention_forward
+ mask_implementation = "sdpa"
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
# Register the kernel as a valid attention
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
- ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
+ ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_implementation])
return kernel
diff --git a/src/transformers/integrations/msa_attention.py b/src/transformers/integrations/msa_attention.py
new file mode 100644
index 000000000000..87a479b3d9a4
--- /dev/null
+++ b/src/transformers/integrations/msa_attention.py
@@ -0,0 +1,259 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from ..utils import logging
+from .sdpa_attention import sdpa_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+# `sparse_atten_func` only compiles for these per-query block counts.
+MSA_SUPPORTED_TOPK = (4, 8, 16, 32)
+# `SparseK2qCsrBuilderSm100` only supports a 128-key block.
+MSA_SUPPORTED_BLOCK_SIZE = 128
+# SM100 / Blackwell head_dim 128 kernel.
+MSA_SUPPORTED_HEAD_DIM = 128
+
+_MSA_KERNEL = None
+
+
+def load_and_register_msa_kernel(attn_implementation: str):
+ """Load the MSA hub kernel once and verify the expected callables are present.
+
+ The ``attn_implementation`` string may carry a ``paged|`` prefix and/or an ``@`` pin
+ (e.g. ``kernels-staging/msa@v0``); the build currently lives on the repo's ``v0`` branch. The
+ loaded module is cached in a module-level global so registration happens once, not per call.
+ """
+ global _MSA_KERNEL
+ if _MSA_KERNEL is not None:
+ return _MSA_KERNEL
+
+ from .hub_kernels import get_kernel
+
+ repo_id = attn_implementation.split("|")[-1]
+ repo_id, _, rev = repo_id.partition("@")
+ kernel = get_kernel(repo_id, revision=rev or None, version=None if rev else 0, allow_all_kernels=True)
+
+ for fn_name in ("sparse_atten_func", "build_k2q_csr"):
+ if not callable(getattr(kernel, fn_name, None)):
+ raise ImportError(
+ f"The MSA kernel loaded from `{repo_id}` does not expose a callable `{fn_name}`. "
+ "Make sure you request a compatible build, e.g. `kernels-staging/msa@v0`."
+ )
+
+ _MSA_KERNEL = kernel
+ return _MSA_KERNEL
+
+
+@torch.library.custom_op("transformers_msa::sparse_atten", mutates_args=())
+def _msa_sparse_atten_op(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q2k: torch.Tensor,
+ cu_seqlens_q: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ topk: int,
+ block_size: int,
+ total_k: int,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ qheads_per_kv: int,
+ scaling: float,
+ impl: str,
+) -> torch.Tensor:
+ """Opaque wrapper around the CuTe-DSL CSR build + block-sparse kernel.
+
+ Registered as a ``torch.library`` custom op so ``torch.compile(fullgraph=True)`` treats the
+ whole CSR-build + attention as a single opaque node (no graph break) and ``reduce-overhead``
+ CUDA graphs can capture it. The internal ``build_k2q_csr`` output is data-dependent in shape,
+ but it never escapes this op (only the fixed-shape ``[total_q, Hq, D]`` attention output does),
+ so the fake/meta impl below is exact. The op is functional (no input mutation).
+ """
+ msa = load_and_register_msa_kernel(impl)
+ # CuTe-DSL kernel launches on the ambient ``current_device`` with no internal guard; pin context
+ # to the tensors' device so device_map (mixed-GPU) layouts don't reference the wrong context.
+ with torch.cuda.device(q.device):
+ k2q_row_ptr, k2q_q_indices = msa.build_k2q_csr(
+ q2k,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ block_size,
+ total_k=total_k,
+ max_seqlen_k=max_seqlen_k,
+ max_seqlen_q=max_seqlen_q,
+ qhead_per_kv=qheads_per_kv,
+ )
+ attn_output = msa.sparse_atten_func(
+ q,
+ k,
+ v,
+ k2q_row_ptr,
+ k2q_q_indices,
+ topk,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ blk_kv=block_size,
+ causal=True,
+ softmax_scale=scaling,
+ )
+ return attn_output.contiguous()
+
+
+@_msa_sparse_atten_op.register_fake
+def _msa_sparse_atten_fake(
+ q,
+ k,
+ v,
+ q2k,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ topk,
+ block_size,
+ total_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ qheads_per_kv,
+ scaling,
+ impl,
+):
+ # Output matches the query varlen layout [total_q, Hq, head_dim].
+ return torch.empty_like(q)
+
+
+def _validate_msa_init(module, query: torch.Tensor, dropout: float) -> None:
+ """Validate kernel capability, dropout and configured topk once per attention module.
+
+ Mirrors the flash-attention integration, which checks capability/dropout at model init rather
+ than on every forward. The check is cached on the module so the hot path never re-runs it.
+
+ There is no SDPA fallback: a sparse layer either runs the MSA kernel or this raises. Serves both
+ prefill (q_len > 1) and single-token decode (q_len == 1) -- decode is just a varlen call with one
+ query slot, so there is no context-length threshold.
+ """
+ if query.device.type != "cuda" or torch.cuda.get_device_capability(query.device)[0] != 10:
+ raise RuntimeError(
+ "MSA block-sparse attention requires an SM100 / Blackwell CUDA device. "
+ "Select a different `attn_implementation` on unsupported hardware."
+ )
+ if query.shape[-1] != MSA_SUPPORTED_HEAD_DIM:
+ raise ValueError(f"MSA block-sparse attention only supports head_dim {MSA_SUPPORTED_HEAD_DIM}.")
+ if module.indexer.block_size != MSA_SUPPORTED_BLOCK_SIZE:
+ raise ValueError(f"MSA block-sparse attention only supports block_size {MSA_SUPPORTED_BLOCK_SIZE}.")
+ if dropout != 0.0:
+ raise ValueError("MSA block-sparse attention does not support attention dropout; set `attention_dropout=0`.")
+ topk = module.indexer.topk_blocks
+ if topk not in MSA_SUPPORTED_TOPK:
+ raise ValueError(
+ f"MSA block-sparse attention only supports topk in {MSA_SUPPORTED_TOPK}, got `{topk}`. "
+ "Set `index_topk_blocks` to a supported value."
+ )
+
+
+def _sparse_attention(module, query, key, value, scaling, block_indices, block_size, cache_position):
+ bsz, num_q_heads, q_len, head_dim = query.shape
+ num_kv_heads, k_len = key.shape[1], key.shape[2]
+ qheads_per_kv = num_q_heads // num_kv_heads
+ topk = block_indices.shape[-1]
+
+ # The indexer emits `min(index_topk_blocks, num_key_blocks)` selected blocks, so on sequences with
+ # fewer key blocks than the configured budget the width lands on an arbitrary value (e.g. 12) that the
+ # `SparseK2qCsrBuilderSm100` CSR builder rejects -- it only accepts a CSR width in `MSA_SUPPORTED_TOPK`.
+ # Right-pad the selection up to the next supported width with `-1`, the same empty-slot sentinel the
+ # kernel already skips, so behaviour is unchanged and the width is always one the builder accepts. The
+ # width is a Python int (static under `torch.compile`), so this stays fullgraph / cudagraph stable.
+ padded_topk = next(t for t in MSA_SUPPORTED_TOPK if t >= topk)
+ if padded_topk != topk:
+ pad = block_indices.new_full((*block_indices.shape[:-1], padded_topk - topk), -1)
+ block_indices = torch.cat([block_indices, pad], dim=-1)
+ topk = padded_topk
+
+ # Flatten the batch dim into a packed varlen layout [total, H, head_dim] + cu_seqlens. The
+ # query boundary is a fixed stride (every row is `q_len` long), built device-side with no host
+ # sync so it stays compile/cudagraph stable.
+ q = query.transpose(1, 2).reshape(bsz * q_len, num_q_heads, head_dim).contiguous()
+ k = key.transpose(1, 2).reshape(bsz * k_len, num_kv_heads, head_dim).contiguous()
+ v = value.transpose(1, 2).reshape(bsz * k_len, num_kv_heads, head_dim).contiguous()
+ cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, q_len, device=q.device, dtype=torch.int32)
+ # Under a StaticCache `k_len` is the pre-allocated buffer (`max_cache_len`), not the valid length,
+ # so a packed `[0, k_len, 2*k_len, ...]` boundary would let the kernel's causal attend zero-padded
+ # future slots. For bsz==1 the real boundary is `[0, valid_k]` (valid_k = cache_position[-1]+1) built
+ # as a device tensor (no host sync) -- `build_k2q_csr` reads only the host shape hints `total_k`/
+ # `max_seqlen_k` (kept fixed at `bsz*k_len`/`k_len` below), so this stays compile/cudagraph stable.
+ if bsz == 1 and cache_position is not None:
+ valid_k = (cache_position[-1] + 1).to(torch.int32).reshape(1)
+ cu_seqlens_k = torch.cat([torch.zeros(1, device=q.device, dtype=torch.int32), valid_k])
+ else:
+ cu_seqlens_k = torch.arange(0, (bsz + 1) * k_len, k_len, device=q.device, dtype=torch.int32)
+
+ q2k = block_indices.to(torch.int32)
+ q2k = q2k.reshape(bsz * q_len, topk).unsqueeze(0).expand(num_kv_heads, -1, -1).contiguous()
+
+ # Opaque custom op: keeps the CuTe-DSL CSR build + block-sparse kernel as a single graph node
+ # so ``torch.compile(fullgraph=True)`` doesn't break and ``reduce-overhead`` CUDA graphs capture it.
+ attn_output = _msa_sparse_atten_op(
+ q,
+ k,
+ v,
+ q2k,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ topk,
+ block_size,
+ bsz * k_len,
+ q_len,
+ k_len,
+ qheads_per_kv,
+ scaling,
+ module.config._attn_implementation,
+ )
+ return attn_output.reshape(bsz, q_len, num_q_heads, head_dim)
+
+
+def msa_attention_forward(
+ module: torch.nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ dropout: float = 0.0,
+ scaling: float | None = None,
+ block_indices: torch.Tensor | None = None,
+ **kwargs,
+) -> tuple[torch.Tensor, None]:
+ """
+ TODO: this opens a door to per-layer attn implementation which is something we might want lalter on.
+ """
+ if scaling is None:
+ scaling = query.shape[-1] ** -0.5
+
+ # No block selection (the dense vision tower, or full-attention layers without an indexer) -> plain SDPA.
+ if block_indices is None:
+ return sdpa_attention_forward(
+ module, query, key, value, attention_mask, dropout=dropout, scaling=scaling, **kwargs
+ )
+
+ # A sparse layer always runs the MSA kernel -- there is no SDPA fallback. Capability/config is
+ # validated once per module (raises on unsupported hardware or config) and cached on the module.
+ if not getattr(module, "_msa_validated", False):
+ _validate_msa_init(module, query, dropout)
+ module._msa_validated = True
+
+ block_size = module.indexer.block_size
+ cache_position = kwargs.get("cache_position")
+ attn_output = _sparse_attention(module, query, key, value, scaling, block_indices, block_size, cache_position)
+ return attn_output, None
diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py
index 4a0022abac5f..eca20baa8918 100644
--- a/src/transformers/masking_utils.py
+++ b/src/transformers/masking_utils.py
@@ -1458,6 +1458,7 @@ def create_chunked_causal_mask(
"chunked_attention": create_chunked_causal_mask,
"compressed_sparse_attention": create_sliding_window_causal_mask,
"heavily_compressed_attention": create_sliding_window_causal_mask,
+ "minimax_m3_sparse": create_causal_mask,
"deepseek_sparse_attention": create_causal_mask,
}
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index c5958f77bb23..567f03a3bba7 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -190,6 +190,12 @@ def _lazy_imports(
flash_attn_func = getattr(kernel, "flash_attn_func", None)
flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func", None)
flash_attn_with_kvcache = getattr(kernel, "flash_attn_with_kvcache", None)
+ # Block-sparse kernels (e.g. ``kernels-staging/msa``) expose ``sparse_atten_func`` rather than
+ # ``flash_attn_varlen_func``. ``load_and_register_attn_kernel`` already registered their dedicated
+ # wrapper into ``ALL_ATTENTION_FUNCTIONS``, so they dispatch through the attention interface and
+ # never touch the flash varlen globals -- preloading them here is a no-op, not an error.
+ if flash_attn_varlen_func is None and hasattr(kernel, "sparse_atten_func"):
+ return flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache, pad_input, unpad_input
if flash_attn_varlen_func is None:
raise ValueError(
f"Could not find the currently requested flash attention implementation at `{implementation}`."
@@ -255,7 +261,9 @@ def lazy_import_flash_attention(
_flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn = _lazy_imports(
implementation, attention_wrapper, allow_all_kernels=allow_all_kernels
)
- _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+ # Block-sparse kernels register their own attention interface and expose no varlen fn to introspect;
+ # skip building the kwargs-support map (it is only consumed by the flash varlen path they never take).
+ _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) if _flash_varlen_fn else None
return (_flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f511cfb1aacb..1b8c93328f43 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1911,6 +1911,11 @@ def _check_and_adjust_attn_implementation(
"""
is_paged, base_implementation = split_attention_implementation(attn_implementation)
+ # A kernel the model explicitly lists in `_compatible_flash_implementations` is vouched for by
+ # the model author, so authorize loading it even when it lives outside the `kernels-community` org.
+ if base_implementation in (getattr(self, "_compatible_flash_implementations", None) or []):
+ allow_all_kernels = True
+
# Auto-correct model's default flash implementation if specified
if attn_implementation is not None:
compatible_flash_implementations = getattr(self, "_compatible_flash_implementations", None)
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index e5d880b9a46b..025c2ac7c0e1 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -267,6 +267,7 @@
from .minicpmv4_6 import *
from .minimax import *
from .minimax_m2 import *
+ from .minimax_m3_vl import *
from .ministral import *
from .ministral3 import *
from .mistral import *
diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py
index e3edc3f7f2e7..fe1212f230c1 100644
--- a/src/transformers/models/auto/auto_mappings.py
+++ b/src/transformers/models/auto/auto_mappings.py
@@ -362,6 +362,9 @@
("minicpmv4_6_vision", "MiniCPMV4_6VisionConfig"),
("minimax", "MiniMaxConfig"),
("minimax_m2", "MiniMaxM2Config"),
+ ("minimax_m3_vl", "MiniMaxM3VLConfig"),
+ ("minimax_m3_vl_text", "MiniMaxM3VLTextConfig"),
+ ("minimax_m3_vl_vision", "MiniMaxM3VLVisionConfig"),
("ministral", "MinistralConfig"),
("ministral3", "Ministral3Config"),
("mistral", "MistralConfig"),
@@ -791,6 +794,8 @@
("metaclip_2_vision_model", "metaclip_2"),
("mgp-str", "mgp_str"),
("minicpmv4_6_vision", "minicpmv4_6"),
+ ("minimax_m3_vl_text", "minimax_m3_vl"),
+ ("minimax_m3_vl_vision", "minimax_m3_vl"),
("mlcd_vision_model", "mlcd"),
("mllama_text_model", "mllama"),
("mllama_vision_model", "mllama"),
@@ -910,6 +915,7 @@
("llava_next_video", "LlavaNextVideoVideoProcessor"),
("llava_onevision", "LlavaOnevisionVideoProcessor"),
("minicpmv4_6", "MiniCPMV4_6VideoProcessor"),
+ ("minimax_m3_vl", "MiniMaxM3VLVideoProcessor"),
("pe_video", "PeVideoVideoProcessor"),
("perception_lm", "PerceptionLMVideoProcessor"),
("qwen2_vl", "Qwen2VLVideoProcessor"),
@@ -1025,6 +1031,7 @@
("markuplm", "MarkupLMProcessor"),
("mgp-str", "MgpstrProcessor"),
("minicpmv4_6", "MiniCPMV4_6Processor"),
+ ("minimax_m3_vl", "MiniMaxM3VLProcessor"),
("mllama", "MllamaProcessor"),
("moonshine_streaming", "MoonshineStreamingProcessor"),
("musicflamingo", "MusicFlamingoProcessor"),
@@ -1148,6 +1155,7 @@
("mask2former", {"pil": "Mask2FormerImageProcessorPil", "torchvision": "Mask2FormerImageProcessor"}),
("maskformer", {"pil": "MaskFormerImageProcessorPil", "torchvision": "MaskFormerImageProcessor"}),
("minicpmv4_6", {"pil": "MiniCPMV4_6ImageProcessorPil", "torchvision": "MiniCPMV4_6ImageProcessor"}),
+ ("minimax_m3_vl", {"torchvision": "MiniMaxM3VLImageProcessor"}),
("mllama", {"pil": "MllamaImageProcessorPil", "torchvision": "MllamaImageProcessor"}),
("mobilenet_v1", {"pil": "MobileNetV1ImageProcessorPil", "torchvision": "MobileNetV1ImageProcessor"}),
("mobilenet_v2", {"pil": "MobileNetV2ImageProcessorPil", "torchvision": "MobileNetV2ImageProcessor"}),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 3d7aeb3e7408..9dfc42d07cd0 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -304,6 +304,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("minicpmv4_6", "MiniCPMV4_6Model"),
("minimax", "MiniMaxModel"),
("minimax_m2", "MiniMaxM2Model"),
+ ("minimax_m3_vl", "MiniMaxM3VLModel"),
+ ("minimax_m3_vl_text", "MiniMaxM3VLTextModel"),
("ministral", "MinistralModel"),
("ministral3", "Ministral3Model"),
("mistral", "MistralModel"),
@@ -741,6 +743,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("mellum", "MellumForCausalLM"),
("minimax", "MiniMaxForCausalLM"),
("minimax_m2", "MiniMaxM2ForCausalLM"),
+ ("minimax_m3_vl_text", "MiniMaxM3VLForCausalLM"),
("ministral", "MinistralForCausalLM"),
("ministral3", "Ministral3ForCausalLM"),
("mistral", "MistralForCausalLM"),
@@ -1059,6 +1062,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("minicpmv4_6", "MiniCPMV4_6ForConditionalGeneration"),
+ ("minimax_m3_vl", "MiniMaxM3SparseForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mistral4", "Mistral4ForCausalLM"),
("mllama", "MllamaForConditionalGeneration"),
diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py
index 4ee9a9d53bf5..b0fc580df777 100644
--- a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py
+++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py
@@ -513,7 +513,9 @@ def forward(
layer_idx: int,
) -> torch.LongTensor:
batch, seq_len, _ = hidden_states.shape
- cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None
+ cache_layer: DeepseekV4CSACache | None = (
+ past_key_values.layers[layer_idx] if past_key_values is not None else None
+ )
kv = self.kv_proj(hidden_states)
gate = self.gate_proj(hidden_states)
@@ -627,7 +629,9 @@ def forward(
layer_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
batch, seq_len, _ = hidden_states.shape
- cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None
+ cache_layer: DeepseekV4CSACache | None = (
+ past_key_values.layers[layer_idx] if past_key_values is not None else None
+ )
kv = self.kv_proj(hidden_states)
gate = self.gate_proj(hidden_states)
diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py
index b5af24599f4b..089192d0e94f 100644
--- a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py
+++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py
@@ -449,7 +449,9 @@ def forward(
layer_idx: int,
) -> torch.LongTensor:
batch, seq_len, _ = hidden_states.shape
- cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None
+ cache_layer: DeepseekV4CSACache | None = (
+ past_key_values.layers[layer_idx] if past_key_values is not None else None
+ )
kv = self.kv_proj(hidden_states)
gate = self.gate_proj(hidden_states)
@@ -563,7 +565,9 @@ def forward(
layer_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
batch, seq_len, _ = hidden_states.shape
- cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None
+ cache_layer: DeepseekV4CSACache | None = (
+ past_key_values.layers[layer_idx] if past_key_values is not None else None
+ )
kv = self.kv_proj(hidden_states)
gate = self.gate_proj(hidden_states)
diff --git a/src/transformers/models/minimax_m3_vl/__init__.py b/src/transformers/models/minimax_m3_vl/__init__.py
new file mode 100644
index 000000000000..9c8d2fbac88f
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_minimax_m3_vl import *
+ from .image_processing_minimax_m3_vl import *
+ from .modeling_minimax_m3_vl import *
+ from .processing_minimax_m3_vl import *
+ from .video_processing_minimax_m3_vl import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py
new file mode 100644
index 000000000000..86c0009389be
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py
@@ -0,0 +1,226 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_minimax_m3_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from huggingface_hub.dataclasses import strict
+
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_rope_utils import RopeParameters
+from ...utils import auto_docstring
+from ..auto import AutoConfig
+
+
+@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview")
+@strict
+class MiniMaxM3VLTextConfig(PreTrainedConfig):
+ r"""
+ dense_intermediate_size (`int`, *optional*, defaults to 12288):
+ Intermediate size of the dense MLP used on layers whose `mlp_layer_types` entry is `"dense"`.
+ shared_intermediate_size (`int`, *optional*, defaults to 3072):
+ Intermediate size of a single shared expert in the MoE layers.
+ rotary_dim (`int`, *optional*, defaults to 64):
+ Number of head channels rotated by RoPE; the remaining channels are passed through unchanged.
+ swiglu_alpha (`float`, *optional*, defaults to 1.702):
+ Sigmoid gain of the SwiGLU-OAI activation.
+ swiglu_limit (`float`, *optional*, defaults to 7.0):
+ Clamp bound applied to the gate and up projections of the SwiGLU-OAI activation.
+ mlp_layer_types (`list[str]`, *optional*):
+ Per-layer MLP selector: `"sparse"` for a MoE block, `"dense"` for a dense MLP.
+ index_n_heads (`int`, *optional*, defaults to 4):
+ Number of heads in the lightning indexer's dot-product scoring branch.
+ index_head_dim (`int`, *optional*, defaults to 128):
+ Per-head channel dimension of the lightning indexer.
+ index_block_size (`int`, *optional*, defaults to 128):
+ Number of key tokens pooled into a single scored block.
+ index_topk_blocks (`int`, *optional*, defaults to 16):
+ Number of top-scoring key blocks each query may attend to.
+ index_local_blocks (`int`, *optional*, defaults to 1):
+ Number of key blocks immediately preceding the query always kept visible / attended to.
+ """
+
+ model_type = "minimax_m3_vl_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise_gather_output",
+ "layers.*.self_attn.k_proj": "colwise_gather_output",
+ "layers.*.self_attn.v_proj": "colwise_gather_output",
+ "layers.*.self_attn.o_proj": "rowwise_split_input",
+ "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
+ "layers.*.mlp.experts.down_proj": "rowwise",
+ "layers.*.mlp.experts": "moe_tp_experts",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ attribute_map = {
+ "num_experts": "num_local_experts",
+ }
+ default_theta = 5000000.0
+ vocab_size: int = 200064
+
+ hidden_size: int = 6144
+ intermediate_size: int = 3072
+ num_hidden_layers: int = 60
+ num_attention_heads: int = 64
+ num_key_value_heads: int = 4
+ head_dim: int = 128
+ hidden_act: str = "silu"
+ max_position_embeddings: int = 524288
+ initializer_range: float = 0.02
+ rms_norm_eps: float = 1e-06
+ use_cache: bool = True
+ pad_token_id: int | None = None
+ bos_token_id: int | None = 200034
+ eos_token_id: int | list[int] | None = 200020
+ tie_word_embeddings: bool = False
+ attention_dropout: float | int = 0.0
+ num_experts_per_tok: int = 4
+ num_local_experts: int = 128
+ output_router_logits: bool = False
+ router_aux_loss_coef: float = 0.001
+ router_jitter_noise: float = 0.0
+ rope_parameters: RopeParameters | dict | None = None
+ base_config_key = "text_config"
+ base_model_ep_plan = {
+ "layers.*.mlp.gate": "ep_router",
+ "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
+ "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm",
+ "layers.*.mlp.experts": "moe_tp_experts",
+ }
+ dense_intermediate_size: int = 12288
+ shared_intermediate_size: int = 3072
+ routed_scaling_factor: float = 2.0
+ rotary_dim: int = 64
+ swiglu_alpha: float = 1.702
+ swiglu_limit: float = 7.0
+ mlp_layer_types: list[str] | None = None
+ index_n_heads: int = 4
+ index_head_dim: int = 128
+ index_block_size: int = 128
+ index_topk_blocks: int = 16
+ index_local_blocks: int = 1
+ layer_types: list[str] | None = None
+
+ def __post_init__(self, **kwargs):
+ sparse_cfg = kwargs.pop("sparse_attention_config", None) or {}
+ moe_layer_freq = kwargs.pop("moe_layer_freq", None)
+ super().__post_init__(**kwargs)
+ # Checkpoint declares "swigluoai", but the gate is computed inline from swiglu_alpha/limit; hidden_act
+ # is only the pointwise fallback and must be a real ACT2FN key, so normalize it to silu.
+ self.hidden_act = "silu"
+
+ for flat, legacy in {
+ "index_n_heads": "sparse_num_index_heads",
+ "index_head_dim": "sparse_index_dim",
+ "index_block_size": "sparse_block_size",
+ "index_topk_blocks": "sparse_topk_blocks",
+ "index_local_blocks": "sparse_local_block",
+ }.items():
+ if legacy in sparse_cfg:
+ setattr(self, flat, sparse_cfg[legacy])
+
+ # `layer_types` is the canonical per-layer attention dispatch: it tells
+ # `DynamicCache(config=...)` which layers want the sparse cache and tells
+ # `MiniMaxM3VLAttention` which layers build a sparse Lightning Indexer.
+ if self.layer_types is None and "sparse_attention_freq" in sparse_cfg:
+ self.layer_types = [
+ "minimax_m3_sparse" if f else "full_attention" for f in sparse_cfg["sparse_attention_freq"]
+ ]
+ if self.layer_types is None:
+ self.layer_types = ["full_attention"] * self.num_hidden_layers
+
+ # `mlp_layer_types` is the per-layer MLP dispatch read by `MiniMaxM3VLDecoderLayer`:
+ if self.mlp_layer_types is None and moe_layer_freq is not None:
+ self.mlp_layer_types = ["sparse" if f else "dense" for f in moe_layer_freq]
+ if self.mlp_layer_types is None:
+ self.mlp_layer_types = ["sparse"] * self.num_hidden_layers
+
+
+@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview")
+@strict
+class MiniMaxM3VLVisionConfig(PreTrainedConfig):
+ r"""
+ rope_parameters (`RopeParameters`, *optional*):
+ Standard RoPE configuration for the vision tower's 3D rotary position embedding.
+ """
+
+ model_type = "minimax_m3_vl_vision"
+ base_config_key = "vision_config"
+ default_theta = 10000.0
+
+ hidden_size: int = 1280
+ intermediate_size: int = 5120
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 16
+ num_channels: int = 3
+ image_size: int = 2016
+ patch_size: int = 14
+ temporal_patch_size: int = 2
+ spatial_merge_size: int = 2
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-05
+ attention_dropout: float = 0.0
+ rope_parameters: RopeParameters | dict | None = None
+ initializer_range: float = 0.02
+
+
+@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview")
+@strict
+class MiniMaxM3VLConfig(PreTrainedConfig):
+ model_type = "minimax_m3_vl"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ }
+
+ vision_config: dict | PreTrainedConfig | None = None
+ text_config: dict | PreTrainedConfig | None = None
+ image_token_index: int = 200025
+ video_token_index: int = 200026
+ projector_hidden_size: int = 6144
+ tie_word_embeddings: bool = False
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.vision_config, dict):
+ self.vision_config.pop("model_type", None)
+ self.vision_config = MiniMaxM3VLVisionConfig(**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = MiniMaxM3VLVisionConfig()
+
+ if isinstance(self.text_config, dict):
+ self.text_config.pop("model_type", None)
+ self.text_config = MiniMaxM3VLTextConfig(**self.text_config)
+ elif self.text_config is None:
+ self.text_config = MiniMaxM3VLTextConfig()
+
+ if not self.tie_word_embeddings and self.text_config.tie_word_embeddings:
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
+
+ # Channel dim after grouping `spatial_merge_size**2` projected patches, consumed by the
+ # patch-merge MLP inside `MiniMaxM3VLMultiModalProjector`.
+ self.merged_hidden_size = self.text_config.hidden_size * (self.vision_config.spatial_merge_size**2)
+
+ super().__post_init__(**kwargs)
+
+
+__all__ = ["MiniMaxM3VLConfig", "MiniMaxM3VLTextConfig", "MiniMaxM3VLVisionConfig"]
diff --git a/src/transformers/models/minimax_m3_vl/image_processing_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/image_processing_minimax_m3_vl.py
new file mode 100644
index 000000000000..04766fd745a2
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/image_processing_minimax_m3_vl.py
@@ -0,0 +1,204 @@
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+from torchvision.transforms.v2 import functional as tvF
+
+from ...image_processing_backends import TorchvisionBackend
+from ...image_processing_utils import BatchFeature
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput, PILImageResampling, SizeDict
+from ...processing_utils import ImagesKwargs, Unpack
+from ...utils import TensorType, auto_docstring
+
+
+class MiniMaxM3VLImageProcessorKwargs(ImagesKwargs, total=False):
+ r"""
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The merge size of the vision encoder to llm encoder.
+ max_pixels (`int`, *optional*, defaults to 451584):
+ The max pixels of the image to resize the image.
+ """
+
+ patch_size: int
+ temporal_patch_size: int
+ merge_size: int
+ max_pixels: int
+
+
+MAX_RATIO = 200
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 4 * 28 * 28,
+ max_pixels: int = 451584,
+) -> tuple[int, int]:
+ if max(height, width) / min(height, width) > MAX_RATIO:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round(height / factor) * factor)
+ w_bar = max(factor, round(width / factor) * factor)
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = math.floor(height / beta / factor) * factor
+ w_bar = math.floor(width / beta / factor) * factor
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+@auto_docstring
+class MiniMaxM3VLImageProcessor(TorchvisionBackend):
+ do_resize = True
+ resample = PILImageResampling.BICUBIC
+ size = {"height": 672, "width": 672}
+ default_to_square = False
+ do_rescale = True
+ rescale_factor = 1 / 255
+ do_normalize = True
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ do_convert_rgb = True
+ patch_size = 14
+ temporal_patch_size = 2
+ merge_size = 2
+ max_pixels = 451584
+ valid_kwargs = MiniMaxM3VLImageProcessorKwargs
+ model_input_names = ["pixel_values", "image_grid_thw"]
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[MiniMaxM3VLImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ resample: "PILImageResampling | tvF.InterpolationMode | int | None",
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: float | list[float] | None,
+ image_std: float | list[float] | None,
+ patch_size: int,
+ temporal_patch_size: int,
+ merge_size: int,
+ max_pixels: int,
+ disable_grouping: bool | None,
+ return_tensors: str | TensorType | None,
+ **kwargs,
+ ) -> BatchFeature:
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ factor = patch_size * merge_size
+ for shape, stacked_images in grouped_images.items():
+ height, width = stacked_images.shape[-2:]
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=height,
+ width=width,
+ factor=factor,
+ max_pixels=max_pixels,
+ )
+ stacked_images = self.resize(
+ stacked_images,
+ size=SizeDict(height=resized_height, width=resized_width),
+ resample=resample,
+ )
+ resized_images_grouped[shape] = stacked_images
+
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ processed_grids = {}
+
+ for shape, stacked_images in grouped_images.items():
+ resized_height, resized_width = stacked_images.shape[-2:]
+
+ patches = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ if patches.ndim == 4: # (B, C, H, W)
+ patches = patches.unsqueeze(1) # (B, T=1, C, H, W)
+
+ if patches.shape[1] % temporal_patch_size != 0:
+ repeats = patches[:, -1:].repeat(
+ 1, temporal_patch_size - (patches.shape[1] % temporal_patch_size), 1, 1, 1
+ )
+ patches = torch.cat([patches, repeats], dim=1)
+
+ batch_size, t_len, channel = patches.shape[:3]
+ grid_t = t_len // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
+ patches = patches.view(
+ batch_size,
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+
+ flatten_patches = patches.reshape(
+ batch_size,
+ grid_t * grid_h * grid_w,
+ channel * temporal_patch_size * patch_size * patch_size,
+ )
+
+ processed_images_grouped[shape] = flatten_patches
+ processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_grids = reorder_images(processed_grids, grouped_images_index)
+
+ pixel_values = torch.cat(processed_images, dim=0)
+ image_grid_thw = torch.tensor(processed_grids, dtype=torch.long)
+
+ return BatchFeature(
+ data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
+ )
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None) -> int:
+ images_kwargs = images_kwargs or {}
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+ max_pixels = images_kwargs.get("max_pixels", self.max_pixels)
+ resized_height, resized_width = smart_resize(
+ height, width, factor=patch_size * merge_size, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * grid_w
+
+
+__all__ = ["MiniMaxM3VLImageProcessor"]
diff --git a/src/transformers/models/minimax_m3_vl/modeling_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/modeling_minimax_m3_vl.py
new file mode 100644
index 000000000000..2e825ce8e5e9
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/modeling_minimax_m3_vl.py
@@ -0,0 +1,1586 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_minimax_m3_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, DynamicLayer, StaticLayer
+from ...configuration_utils import PreTrainedConfig
+from ...generation import GenerationMixin
+from ...integrations import use_experts_implementation
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
+from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
+from ...utils.import_utils import is_torchdynamo_compiling
+from ...utils.output_capturing import OutputRecorder, capture_outputs
+from .configuration_minimax_m3_vl import MiniMaxM3VLConfig, MiniMaxM3VLTextConfig, MiniMaxM3VLVisionConfig
+
+
+class MiniMaxM3VLSparseCacheLayer(DynamicLayer):
+ layer_type = "minimax_m3_sparse"
+
+ def __init__(self, config: PreTrainedConfig | None = None):
+ super().__init__(config)
+ self.idx_keys: torch.Tensor | None = None
+
+ def update_index(self, idx_k: torch.Tensor) -> torch.Tensor:
+ """Append the new token's `idx_k` to the cache and return the full history."""
+ self.idx_keys = idx_k if self.idx_keys is None else torch.cat([self.idx_keys, idx_k], dim=-2)
+ return self.idx_keys
+
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+ super().reorder_cache(beam_idx)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device))
+
+ def batch_repeat_interleave(self, repeats: int) -> None:
+ super().batch_repeat_interleave(repeats)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys.repeat_interleave(repeats, dim=0)
+
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
+ super().batch_select_indices(indices)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys[indices, ...]
+
+ def crop(self, max_length: int) -> None:
+ super().crop(max_length)
+ if max_length < 0:
+ max_length = self.get_seq_length() - abs(max_length)
+ if self.idx_keys is not None and self.idx_keys.shape[-2] > max_length:
+ self.idx_keys = self.idx_keys[..., :max_length, :]
+
+
+class MiniMaxM3VLSparseStaticCacheLayer(StaticLayer):
+ layer_type = "minimax_m3_sparse"
+
+ def __init__(self, max_cache_len: int):
+ super().__init__(max_cache_len)
+ self.idx_keys: torch.Tensor | None = None
+ # Tensor (not int) so it can be marked as a static address for cudagraphs, like `cumulative_length`.
+ self.idx_cumulative_length = torch.tensor([0], dtype=int)
+
+ def update_index(self, idx_k: torch.Tensor) -> torch.Tensor:
+ """Write the new token's `idx_k` into the static buffer in place and return the whole buffer.
+
+ The buffer's unfilled tail holds zeros, but those slots sit at key positions ahead of every
+ current query, so the indexer's block- and token-level causal masking discards them — the
+ returned `[B, 1, max_cache_len, D]` history is therefore safe to score against directly.
+ """
+ if self.idx_keys is None:
+ self.idx_keys = torch.zeros(
+ (idx_k.shape[0], idx_k.shape[1], self.max_cache_len, idx_k.shape[-1]),
+ dtype=idx_k.dtype,
+ device=idx_k.device,
+ )
+ self.idx_cumulative_length = self.idx_cumulative_length.to(idx_k.device)
+ if not is_torchdynamo_compiling():
+ torch._dynamo.mark_static_address(self.idx_keys)
+ torch._dynamo.mark_static_address(self.idx_cumulative_length)
+
+ kv_len = idx_k.shape[-2]
+ cache_position = torch.arange(kv_len, device=self.idx_keys.device) + self.idx_cumulative_length
+ self.idx_cumulative_length.add_(kv_len)
+ try:
+ self.idx_keys.index_copy_(2, cache_position, idx_k)
+ except NotImplementedError:
+ # Fallback for devices like MPS where index_copy_ might not be supported.
+ self.idx_keys[:, :, cache_position] = idx_k
+ return self.idx_keys
+
+ def reset(self) -> None:
+ super().reset()
+ if self.idx_keys is not None:
+ self.idx_keys.zero_()
+ self.idx_cumulative_length.zero_()
+
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+ super().reorder_cache(beam_idx)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device))
+
+
+class MiniMaxM3VLRMSNorm(nn.Module):
+ """Gemma-style RMSNorm: normalizes in fp32 and scales by `weight + 1`."""
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.zeros(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float())
+ # Llama does x.to(float16) * w whilst MiniMaxM3VL is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = output * (1.0 + self.weight.float())
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class MiniMaxM3VLDenseMLP(nn.Module):
+ def __init__(self, config: MiniMaxM3VLTextConfig, intermediate_size: int | None = None):
+ super().__init__()
+ inter = intermediate_size if intermediate_size is not None else config.dense_intermediate_size
+ self.swiglu_alpha = config.swiglu_alpha
+ self.swiglu_limit = config.swiglu_limit
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * inter, bias=False)
+ self.down_proj = nn.Linear(inter, config.hidden_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_up = self.gate_up_proj(hidden_states)
+ gate, up = gate_up.chunk(2, dim=-1)
+ gate = gate.clamp(max=self.swiglu_limit)
+ up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+ glu = gate * torch.sigmoid(gate * self.swiglu_alpha)
+ return self.down_proj((up + 1.0) * glu)
+
+
+@use_experts_implementation
+class MiniMaxM3VLExperts(nn.Module):
+ """Collection of expert weights stored as 3D tensors."""
+
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__()
+ self.num_experts = config.num_local_experts
+ self.hidden_dim = config.hidden_size
+ self.intermediate_dim = config.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+ self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+ self.limit = config.swiglu_limit
+ self.swiglu_alpha = config.swiglu_alpha
+ self.swiglu_limit = config.swiglu_limit
+
+ def forward(
+ self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+ ) -> torch.Tensor:
+ final = torch.zeros_like(hidden_states)
+ with torch.no_grad():
+ mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+ hit = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in hit:
+ expert_idx = expert_idx[0]
+ if expert_idx == self.num_experts:
+ continue
+ top_k_pos, token_idx = torch.where(mask[expert_idx])
+ current = self._apply_gate(F.linear(hidden_states[token_idx], self.gate_up_proj[expert_idx]))
+ current = F.linear(current, self.down_proj[expert_idx]) * top_k_weights[token_idx, top_k_pos, None]
+ final.index_add_(0, token_idx, current.to(final.dtype))
+ return final
+
+ def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
+ # same as GPT OSS, but the weights are not interleaved
+ gate, up = gate_up.chunk(2, dim=-1)
+ gate = gate.clamp(max=self.swiglu_limit)
+ up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+ glu = gate * torch.sigmoid(gate * self.swiglu_alpha)
+ return (up + 1.0) * glu
+
+
+class MiniMaxM3VLTopKRouter(nn.Module):
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_local_experts
+ self.hidden_dim = config.hidden_size
+ self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+ self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight)
+ # Sigmoid scoring (not softmax), as in M2.
+ routing_weights = F.sigmoid(router_logits.float())
+ scores_for_choice = routing_weights + self.e_score_correction_bias
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
+ top_k_weights = routing_weights.gather(1, top_k_index)
+ top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
+ return router_logits, top_k_weights, top_k_index
+
+
+class MiniMaxM3VLSparseMoeBlock(nn.Module):
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__()
+ self.gate = MiniMaxM3VLTopKRouter(config)
+ self.experts = MiniMaxM3VLExperts(config)
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.shared_experts = MiniMaxM3VLDenseMLP(config, intermediate_size=config.shared_intermediate_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ shared_output = self.shared_experts(hidden_states)
+
+ _, routing_weights, selected_experts = self.gate(hidden_states)
+ hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
+ # Additional scaling
+ hidden_states = hidden_states * self.routed_scaling_factor
+ hidden_states = hidden_states + shared_output
+
+ hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return hidden_states
+
+
+class MiniMaxM3VLRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: MiniMaxM3VLConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: MiniMaxM3VLConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+class MiniMaxM3VLAttention(nn.Module):
+ """
+ M3 attention: per-head Gemma QK-norm + partial RoPE, optionally sparse indexer selection which require position IDs.
+ """
+
+ def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.q_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.indexer = (
+ MiniMaxM3VLIndexer(config, layer_idx) if config.layer_types[layer_idx] == "minimax_m3_sparse" else None
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None,
+ past_key_values: Cache | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+ block_indices = None
+ if self.indexer is not None:
+ position_ids = kwargs.get("position_ids")
+ if position_ids is None:
+ position_ids = torch.arange(
+ key_states.shape[2] - query_states.shape[2], key_states.shape[2], device=query_states.device
+ )
+ position_ids = (position_ids if position_ids.ndim > 1 else position_ids.unsqueeze(0)).expand(
+ query_states.shape[0], -1
+ )
+ block_indices = self.indexer(hidden_states, position_embeddings, past_key_values, position_ids)
+ if self.config._attn_implementation in ("eager", "sdpa"):
+ attention_mask = self.indexer.build_block_mask(
+ block_indices,
+ attention_mask,
+ key_states.shape[2],
+ query_states.dtype,
+ query_states.device,
+ position_ids,
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ block_indices=block_indices,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ return self.o_proj(attn_output), attn_weights
+
+
+class MiniMaxM3VLIndexer(nn.Module):
+ r"""Lightning Indexer for MiniMax M3 sparse attention.
+
+ Scores each query against every key with a small `index_n_heads`-head
+ dot-product branch, then max-pools those per-key scores into *blocks* of
+ `index_block_size` keys and keeps, per query, the top-`index_topk_blocks`
+ key blocks plus the `index_local_blocks` blocks immediately preceding the
+ query (always visible). Selection therefore happens at the granularity of a
+ *block of keys* rather than individual keys: the expensive main attention
+ only has to attend the handful of selected key blocks, which is what makes
+ it block-sparse (and cheaper) on long sequences.
+
+ The `index_local_blocks` boosting their score so they always win key slots, the
+ same way the deployment block-sparse kernel (MiniMax `topk_sparse`) does it.
+
+ `forward` returns the per-query selected key-block indices
+ `[B, S_q, index_topk_blocks]`. Valid indices are left-packed and `-1`
+ right-pads the unused slots (future/empty blocks), and the local boost makes
+ selections deduplicated -- the exact contract the block-sparse attention
+ kernel consumes (it counts the valid entries, then reads them sequentially
+ and would double-count a repeated block). The eager/SDPA path instead calls
+ `build_block_mask`, which expands the indices into the dense
+ `[B, 1, S_q, S_k]` additive mask the standard attention interface expects
+ (`0` at every allowed (query, key) pair, `-inf` elsewhere).
+
+ Like DeepSeek-V4's indexer this is purely a *selection* branch: it has no
+ value projection and produces no residual output of its own (the upstream
+ checkpoint disables the index-value path on every sparse layer).
+
+ TODO: blocks are anchored to absolute key *slots* (the contiguous reshape in
+ `forward` and `q_block = slot // block_size`), so left-padding shifts the block
+ boundaries and the selection diverges from an unpadded run -- only right-padding
+ is equivalent (same limitation as DeepSeek-V4; see `test_right_padding_does_not_leak`
+ / the skipped `test_left_padding_compatibility`). For *true* left-padding equivalence
+ we'd make blocking content-relative instead of slot-relative:
+ 1. derive block ids from `position_ids` (content positions, 0 at each row's first
+ real token) rather than from absolute slots, and
+ 2. replace the contiguous `view(..., num_key_blocks, block_size).amax(-1)` key pool
+ with a per-row position-binned pool (e.g. `scatter_reduce` over `key_position //
+ block_size`), so pad never shifts the boundaries, and
+ 3. mask padded keys' scores to `-inf` before the pool so a pad key can't win a block
+ a top-k slot.
+ """
+
+ def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = config.index_head_dim
+ self.num_heads = config.index_n_heads
+ self.block_size = config.index_block_size
+ self.topk_blocks = config.index_topk_blocks
+ self.local_blocks = config.index_local_blocks
+ self.q_proj = nn.Linear(config.hidden_size, config.index_n_heads * config.index_head_dim, bias=False)
+ self.k_proj = nn.Linear(config.hidden_size, config.index_head_dim, bias=False)
+ self.q_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps)
+ self.k_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ past_key_values: Cache | None,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ batch, q_len, _ = hidden_states.shape
+ idx_q = self.q_proj(hidden_states).view(batch, q_len, -1, self.head_dim)
+ idx_q = self.q_norm(idx_q).transpose(1, 2) # [B, H_idx, Sq, D]
+ idx_k = self.k_proj(hidden_states).view(batch, q_len, 1, self.head_dim)
+ idx_k = self.k_norm(idx_k).transpose(1, 2) # [B, 1, Sq, D]
+ cos, sin = position_embeddings
+ idx_q, idx_k = apply_rotary_pos_emb(idx_q, idx_k, cos[..., : self.head_dim], sin[..., : self.head_dim])
+
+ if past_key_values is not None:
+ idx_k = past_key_values.layers[self.layer_idx].update_index(idx_k)
+
+ k_len = idx_k.shape[2]
+ num_key_blocks = -(-k_len // self.block_size) # ceil-div
+ pad = num_key_blocks * self.block_size - k_len
+
+ scores = torch.matmul(idx_q.float(), idx_k.float().transpose(-1, -2))
+ k_positions = torch.arange(k_len, device=idx_q.device)
+ token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k]
+ scores = scores.masked_fill(token_future, float("-inf"))
+ if pad:
+ scores = F.pad(scores, (0, pad), value=float("-inf"))
+ scores = scores.view(batch, self.num_heads, q_len, num_key_blocks, self.block_size)
+ block_scores = scores.amax(dim=-1).amax(dim=1) # -> [B, S_q, num_key_blocks]
+
+ q_block = position_ids // self.block_size # [B, S_q]
+
+ if self.local_blocks > 0:
+ local = torch.arange(self.local_blocks, device=idx_q.device)
+ local_idx = (q_block[..., None] - local.view(1, 1, -1)).clamp(min=0) # [B, S_q, local]
+ block_scores.scatter_(-1, local_idx, float("inf"))
+
+ # Slots that fall on a future/empty block keep their `-inf`
+ # score, which top-k sorts to the end, so tagging them `-1` yields left-packed block indices
+ # with `-1` right-padding which is the format expect by block-sparse attention kernel.
+ topk = min(self.topk_blocks, num_key_blocks)
+ topk_scores, topk_indices = block_scores.topk(topk, dim=-1) # [B, S_q, topk]
+ return topk_indices.masked_fill(topk_scores == float("-inf"), -1)
+
+ def build_block_mask(
+ self,
+ block_indices: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ key_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ We build the full 4D attention mask (Batch, query, key, head)
+ """
+ batch, q_len, _ = block_indices.shape
+ num_key_blocks = -(-key_length // self.block_size)
+
+ # Scatter the kept blocks to `0`; `-1` slots land in a throwaway column we drop afterwards.
+ safe = block_indices.masked_fill(block_indices < 0, num_key_blocks)
+ bias = block_indices.new_full((batch, q_len, num_key_blocks + 1), float("-inf"), dtype=dtype)
+ bias.scatter_(-1, safe, 0.0)
+ bias = bias[..., :num_key_blocks]
+
+ # Broadcast the per-block keep/drop verdict back onto every key (block granularity), add head axis.
+ block_keep = (bias == 0.0).repeat_interleave(self.block_size, dim=-1)[..., :key_length].unsqueeze(1)
+
+ # Compose block-selection with the existing mask, then emit a single additive float mask.
+ if attention_mask is not None:
+ padding_mask = attention_mask if attention_mask.dtype == torch.bool else attention_mask == 0
+ keep = block_keep & padding_mask
+ else:
+ k_positions = torch.arange(key_length, device=device)
+ token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k]
+ keep = block_keep & ~token_future
+ min_dtype = torch.finfo(dtype).min
+ return torch.zeros(keep.shape, dtype=dtype, device=device).masked_fill(~keep, min_dtype)
+
+
+class MiniMaxM3VLDecoderLayer(GradientCheckpointingLayer):
+ """M3 decoder layer: per-layer dense/MoE MLP and dense/sparse attention."""
+
+ def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = MiniMaxM3VLAttention(config, layer_idx)
+ self.mlp = (
+ MiniMaxM3VLSparseMoeBlock(config)
+ if config.mlp_layer_types[layer_idx] == "sparse"
+ else MiniMaxM3VLDenseMLP(config, intermediate_size=config.dense_intermediate_size)
+ )
+ self.input_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class MiniMaxM3VLPreTrainedModel(PreTrainedModel):
+ config: MiniMaxM3VLConfig | MiniMaxM3VLTextConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MiniMaxM3VLDecoderLayer", "MiniMaxM3VLVisionEncoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = False
+ _supports_sdpa = True
+ _supports_flex_attn = False
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(MiniMaxM3VLTopKRouter, index=0),
+ "hidden_states": MiniMaxM3VLDecoderLayer,
+ "attentions": MiniMaxM3VLAttention,
+ }
+ input_modalities = ("image", "video", "text")
+ _keys_to_ignore_on_load_unexpected = [r"(^|\.)mtp\..*"]
+ _compatible_flash_implementations = ["kernels-staging/msa@v0"]
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ std = getattr(self.config, "initializer_range", 0.02)
+ if isinstance(module, MiniMaxM3VLExperts):
+ init.normal_(module.gate_up_proj, mean=0.0, std=std)
+ init.normal_(module.down_proj, mean=0.0, std=std)
+ elif isinstance(module, MiniMaxM3VLTopKRouter):
+ init.normal_(module.weight, mean=0.0, std=std)
+ init.zeros_(module.e_score_correction_bias)
+ elif isinstance(module, MiniMaxM3VLRMSNorm):
+ init.zeros_(module.weight)
+
+
+@auto_docstring
+class MiniMaxM3VLTextModel(MiniMaxM3VLPreTrainedModel):
+ config: MiniMaxM3VLTextConfig
+
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([MiniMaxM3VLDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
+ self.norm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = MiniMaxM3VLRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ position_ids = position_ids.unsqueeze(0)
+
+ if isinstance(attention_mask, dict):
+ causal_mask = next(iter(attention_mask.values()))
+ else:
+ causal_mask = create_causal_mask(
+ config=self.config,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+ # `position_ids` is threaded to every layer so the sparse layers' lightning indexer can anchor
+ # block selection to each query's content position (see `MiniMaxM3VLIndexer`).
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+def load_balancing_loss_func(
+ gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
+ num_experts: int | None = None,
+ top_k=2,
+ attention_mask: torch.Tensor | None = None,
+) -> torch.Tensor | int:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
+@auto_docstring
+class MiniMaxM3VLForCausalLM(MiniMaxM3VLPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+ _tp_plan = {"lm_head": "colwise_gather_output"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ config: MiniMaxM3VLTextConfig
+
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__(config)
+ self.model = MiniMaxM3VLTextModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_local_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_router_logits: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MiniMaxM3VLForCausalLM
+
+ >>> model = MiniMaxM3VLForCausalLM.from_pretrained("mistralai/MiniMaxM3VL-8x7B-v0.1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/MiniMaxM3VL-8x7B-v0.1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_router_logits=output_router_logits,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+class MiniMaxM3VLVisionEmbeddings(nn.Module):
+ """Patch embedding, identical to [`Qwen2_5_VisionPatchEmbed`] (reads its dims from the vision
+ config). The upstream checkpoint stores the conv as `patch_embedding`, renamed to the
+ inherited `proj` in the conversion mapping."""
+
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.num_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(
+ self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class MiniMaxM3VL3DRotaryEmbedding(nn.Module):
+ r"""3D RoPE for the vision tower: each patch is rotated by its `(T, H, W)` grid position.
+
+ `2 * (head_dim // 2)` rotary dims are split evenly across the three axes (each rounded
+ down to a multiple of 2), giving `axis_dim` dims per axis and `axis_dim // 2` frequencies::
+
+ |<------------------ rotated (3 * axis_dim) ------------------>|<- pass ->|
+ +--------------------+--------------------+--------------------+----------+
+ | T (frames) | H (rows) | W (cols) | |
+ | axis_dim | axis_dim | axis_dim | |
+ +--------------------+--------------------+--------------------+----------+
+
+ Each axis' coordinate scales its own band of frequencies; the bands are concatenated as
+ `T|H|W` and duplicated via `cat([f, f])` to pair with the half-rotation in
+ `apply_rotary_pos_emb_vision`. Any head dims past `3 * axis_dim` are left unrotated.
+ """
+
+ def __init__(self, head_dim: int, theta: float = 10000.0, spatial_merge_size: int = 1):
+ super().__init__()
+ # `2 * (head_dim // 2)` rotary dims are split evenly across T/H/W, each axis rounded
+ # down to a multiple of 2. With head_dim=80 that is 26 dims/axis (39 freqs total); the
+ # remaining `head_dim - 3 * axis_dim` dims are never rotated (they pass through).
+ rope_dims = 2 * (head_dim // 2)
+ self.axis_dim = 2 * ((rope_dims // 3) // 2)
+ self.spatial_merge_size = spatial_merge_size
+ self.theta = theta
+
+ def forward(
+ self, grid_thw: torch.Tensor, device: torch.device, dtype: torch.dtype
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ m = self.spatial_merge_size
+ coords = []
+ for t, h, w in grid_thw.tolist():
+ hi = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hi = hi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
+ wi = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wi = wi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
+ ti = torch.arange(t).repeat_interleave(h * w)
+ coords.append(torch.stack([ti, hi.repeat(t), wi.repeat(t)], dim=-1))
+ coords = torch.cat(coords).to(device=device, dtype=torch.float32)
+
+ # meta device init was having trouble when it was registered. TODO standardize?
+ inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, self.axis_dim, 2, dtype=torch.float32, device=device) / self.axis_dim)
+ )
+ freqs = torch.cat([coords[:, i : i + 1] * inv_freq for i in range(3)], dim=-1)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ return emb.cos().to(dtype), emb.sin().to(dtype)
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ # Only the first `rot_dim` head dims carry 3D RoPE; the tail passes through untouched.
+ rot_dim = cos.shape[-1]
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :]
+ q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
+ k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
+ q_rot = q_rot * cos + rotate_half(q_rot) * sin
+ k_rot = k_rot * cos + rotate_half(k_rot) * sin
+ return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)
+
+
+class MiniMaxM3VLVisionAttention(nn.Module):
+ """CLIP-style vision attention; the only difference from [`CLIPAttention`] is
+ that queries and keys are rotated by the tower's 3D RoPE before the
+ (interface-dispatched) scaled dot-product attention."""
+
+ def __init__(self, config: MiniMaxM3VLVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ # The vision tower has no grouped-query attention; the shared eager kernel
+ # still expects this attribute to drive its (no-op) `repeat_kv`.
+ self.num_key_value_groups = 1
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ queries = self.q_proj(hidden_states).view(hidden_shape)
+ keys = self.k_proj(hidden_states).view(hidden_shape)
+ values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin)
+ queries, keys = queries.transpose(1, 2), keys.transpose(1, 2)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ return self.out_proj(attn_output), attn_weights
+
+
+class MiniMaxM3VLVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MiniMaxM3VLVisionEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MiniMaxM3VLVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = MiniMaxM3VLVisionAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = MiniMaxM3VLVisionMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@auto_docstring
+class MiniMaxM3VLVisionModel(MiniMaxM3VLPreTrainedModel):
+ """CLIP-like vision tower with Conv3d patch embed + 3D RoPE."""
+
+ config: MiniMaxM3VLVisionConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {
+ "hidden_states": MiniMaxM3VLVisionEncoderLayer,
+ "attentions": MiniMaxM3VLVisionAttention,
+ }
+
+ def __init__(self, config: MiniMaxM3VLVisionConfig):
+ super().__init__(config)
+ self.embeddings = MiniMaxM3VLVisionEmbeddings(config)
+ self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layers = nn.ModuleList([MiniMaxM3VLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ head_dim = config.hidden_size // config.num_attention_heads
+ self.rotary_emb = MiniMaxM3VL3DRotaryEmbedding(
+ head_dim, theta=config.rope_parameters["rope_theta"], spatial_merge_size=config.spatial_merge_size
+ )
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
+ The temporal, height and width of feature shape of each image.
+ """
+ embeds = self.embeddings(pixel_values).to(self.pre_layrnorm.weight.dtype)
+ cos, sin = self.rotary_emb(image_grid_thw, device=embeds.device, dtype=embeds.dtype)
+ hidden_states = self.pre_layrnorm(embeds).unsqueeze(0)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, attention_mask=None, position_embeddings=(cos, sin), **kwargs)
+ return BaseModelOutputWithPooling(last_hidden_state=hidden_states, pooler_output=hidden_states[:, 0])
+
+
+class MiniMaxM3VLMultiModalProjector(nn.Module):
+ """Projects each vision patch from `vision_config.hidden_size` to `text_config.hidden_size`
+ (GELU MLP), then groups `spatial_merge_size**2` neighbouring patches into the channel dim and
+ fuses them back to a single `text_config.hidden_size` token with a second GELU MLP."""
+
+ def __init__(self, config: MiniMaxM3VLConfig):
+ super().__init__()
+ text_hidden = config.text_config.hidden_size
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.projector_hidden_size, bias=True)
+ self.act = ACT2FN["gelu"]
+ self.linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True)
+ self.merge_linear_1 = nn.Linear(config.merged_hidden_size, config.projector_hidden_size, bias=True)
+ self.merge_act = ACT2FN["gelu"]
+ self.merge_linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True)
+
+ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.linear_2(self.act(self.linear_1(image_features)))
+ hidden_states = hidden_states.reshape(hidden_states.shape[0] // (self.spatial_merge_size**2), -1)
+ return self.merge_linear_2(self.merge_act(self.merge_linear_1(hidden_states)))
+
+
+@auto_docstring(
+ custom_intro="""
+ Base class for MiniMaxM3VL outputs, with hidden states and attentions.
+ """
+)
+@dataclass
+class MiniMaxM3VLModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: torch.FloatTensor | None = None
+
+ video_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ Base class for MiniMaxM3VL causal language model (or autoregressive) outputs.
+ """
+)
+@dataclass
+class MiniMaxM3VLCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ image_hidden_states: torch.FloatTensor | None = None
+
+ video_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(custom_intro="MiniMax M3 VL backbone (vision + projector + text), without LM head.")
+class MiniMaxM3VLModel(MiniMaxM3VLPreTrainedModel):
+ config: MiniMaxM3VLConfig
+
+ def __init__(self, config: MiniMaxM3VLConfig):
+ super().__init__(config)
+ self.vision_tower = MiniMaxM3VLVisionModel(config.vision_config)
+ self.multi_modal_projector = MiniMaxM3VLMultiModalProjector(config)
+ self.language_model = MiniMaxM3VLTextModel(config.text_config)
+ self.post_init()
+
+ @merge_with_config_defaults
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
+ )
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_grid_thw: torch.Tensor,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ # Return the raw vision-tower output (so callers can inspect hidden states /
+ # attentions) while stashing the projected + spatially-merged features —
+ # ready to scatter into the text embeddings — in `pooler_output`.
+ vision_outputs = self.vision_tower(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
+ vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0))
+ return vision_outputs
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: torch.FloatTensor | None = None,
+ video_features: torch.FloatTensor | None = None,
+ ):
+ """
+ Obtains the image/video placeholder masks from `input_ids` or `inputs_embeds`, and checks that the
+ placeholder token count matches the multimodal feature length. Raises if they differ.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None:
+ torch_compilable_check(
+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
+ f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None:
+ torch_compilable_check(
+ inputs_embeds[special_video_mask].numel() == video_features.numel(),
+ f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
+ )
+ return special_image_mask, special_video_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.Tensor | None = None,
+ video_grid_thw: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | MiniMaxM3VLModelOutputWithPast:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_features = None
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values, image_grid_thw=image_grid_thw
+ ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype)
+
+ video_features = None
+ if pixel_values_videos is not None:
+ video_features = self.get_video_features(
+ pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw
+ ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype)
+
+ image_mask, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds, image_features=image_features, video_features=video_features
+ )
+ if image_features is not None:
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)
+ if video_features is not None:
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+
+ return MiniMaxM3VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=getattr(outputs, "hidden_states", None),
+ attentions=getattr(outputs, "attentions", None),
+ image_hidden_states=image_features,
+ video_hidden_states=video_features,
+ )
+
+ @merge_with_config_defaults
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="Obtains video last hidden states from the vision tower and apply multimodal projection."
+ )
+ def get_video_features(
+ self,
+ pixel_values_videos: torch.FloatTensor,
+ video_grid_thw: torch.Tensor,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ pixel_values_videos (`torch.FloatTensor`):
+ The tensors corresponding to the input video frames.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ # Video frames flow through the same vision pipeline as images (the tower is
+ # grid-agnostic); only the placeholder token they scatter into differs.
+ vision_outputs = self.vision_tower(pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, **kwargs)
+ vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0))
+ return vision_outputs
+
+
+@auto_docstring(custom_intro="MiniMax M3 VL full model with LM head (text + vision).")
+class MiniMaxM3SparseForConditionalGeneration(MiniMaxM3VLPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+ config: MiniMaxM3VLConfig
+
+ def __init__(self, config: MiniMaxM3VLConfig):
+ super().__init__(config)
+ self.model = MiniMaxM3VLModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ @auto_docstring
+ def get_image_features(self, pixel_values, image_grid_thw, **kwargs) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.Tensor | None = None,
+ video_grid_thw: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | MiniMaxM3VLCausalLMOutputWithPast:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return MiniMaxM3VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ attention_mask=None,
+ logits_to_keep=None,
+ is_first_iteration=False,
+ **kwargs,
+ ):
+ # Overwritten -- pixel inputs are merged into the cache on the first step, so we
+ # only forward them once (image and video alike).
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ logits_to_keep=logits_to_keep,
+ is_first_iteration=is_first_iteration,
+ **kwargs,
+ )
+
+ if is_first_iteration or not kwargs.get("use_cache", True):
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+
+ return model_inputs
+
+ def get_video_features(self, pixel_values_videos, video_grid_thw, **kwargs):
+ r"""
+ pixel_values_videos (`torch.FloatTensor`):
+ The tensors corresponding to the input video frames.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs)
+
+
+__all__ = [
+ "MiniMaxM3VLForCausalLM",
+ "MiniMaxM3SparseForConditionalGeneration",
+ "MiniMaxM3VLModel",
+ "MiniMaxM3VLPreTrainedModel",
+ "MiniMaxM3VLTextModel",
+ "MiniMaxM3VLVisionModel",
+]
diff --git a/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py
new file mode 100644
index 000000000000..e1c7a0962356
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py
@@ -0,0 +1,1308 @@
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MiniMax M3 VL: vision tower + M3 (mixed sparse/dense MoE) text backbone."""
+
+from collections.abc import Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub.dataclasses import strict
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, DynamicLayer, StaticLayer
+from ...configuration_utils import PreTrainedConfig
+from ...masking_utils import create_causal_mask
+from ...modeling_outputs import BaseModelOutputWithPooling, MoeModelOutputWithPast
+from ...modeling_rope_utils import RopeParameters
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, logging, torch_compilable_check
+from ...utils.generic import can_return_tuple, merge_with_config_defaults
+from ...utils.import_utils import is_torchdynamo_compiling
+from ...utils.output_capturing import capture_outputs
+from ..auto import AutoConfig
+from ..clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPEncoderLayer
+from ..deepseek_v4.modeling_deepseek_v4 import DeepseekV4Experts
+from ..gemma3.modeling_gemma3 import Gemma3RMSNorm
+from ..laguna.modeling_laguna import LagunaSparseMoeBlock
+from ..llama.modeling_llama import eager_attention_forward
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+)
+from ..minimax_m2.configuration_minimax_m2 import MiniMaxM2Config
+from ..minimax_m2.modeling_minimax_m2 import (
+ MiniMaxM2Attention,
+ MiniMaxM2ForCausalLM,
+ MiniMaxM2Model,
+ MiniMaxM2PreTrainedModel,
+ MiniMaxM2RotaryEmbedding,
+ MiniMaxM2TopKRouter,
+ apply_rotary_pos_emb,
+)
+from ..mixtral.modeling_mixtral import MixtralDecoderLayer
+from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionPatchEmbed
+from ..qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor, Qwen2VLProcessorKwargs
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview")
+@strict
+class MiniMaxM3VLTextConfig(MiniMaxM2Config):
+ r"""
+ dense_intermediate_size (`int`, *optional*, defaults to 12288):
+ Intermediate size of the dense MLP used on layers whose `mlp_layer_types` entry is `"dense"`.
+ shared_intermediate_size (`int`, *optional*, defaults to 3072):
+ Intermediate size of a single shared expert in the MoE layers.
+ rotary_dim (`int`, *optional*, defaults to 64):
+ Number of head channels rotated by RoPE; the remaining channels are passed through unchanged.
+ swiglu_alpha (`float`, *optional*, defaults to 1.702):
+ Sigmoid gain of the SwiGLU-OAI activation.
+ swiglu_limit (`float`, *optional*, defaults to 7.0):
+ Clamp bound applied to the gate and up projections of the SwiGLU-OAI activation.
+ mlp_layer_types (`list[str]`, *optional*):
+ Per-layer MLP selector: `"sparse"` for a MoE block, `"dense"` for a dense MLP.
+ index_n_heads (`int`, *optional*, defaults to 4):
+ Number of heads in the lightning indexer's dot-product scoring branch.
+ index_head_dim (`int`, *optional*, defaults to 128):
+ Per-head channel dimension of the lightning indexer.
+ index_block_size (`int`, *optional*, defaults to 128):
+ Number of key tokens pooled into a single scored block.
+ index_topk_blocks (`int`, *optional*, defaults to 16):
+ Number of top-scoring key blocks each query may attend to.
+ index_local_blocks (`int`, *optional*, defaults to 1):
+ Number of key blocks immediately preceding the query always kept visible / attended to.
+ """
+
+ model_type = "minimax_m3_vl_text"
+ base_config_key = "text_config"
+ base_model_ep_plan = {
+ "layers.*.mlp.gate": "ep_router",
+ "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
+ "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm",
+ "layers.*.mlp.experts": "moe_tp_experts",
+ }
+
+ hidden_size: int = 6144
+ intermediate_size: int = 3072
+ dense_intermediate_size: int = 12288
+ shared_intermediate_size: int = 3072
+ num_hidden_layers: int = 60
+ num_attention_heads: int = 64
+ num_key_value_heads: int = 4
+ head_dim: int = 128
+ max_position_embeddings: int = 524288
+ vocab_size: int = 200064
+ rms_norm_eps: float = 1e-06
+ num_local_experts: int = 128
+ num_experts_per_tok: int = 4
+ routed_scaling_factor: float = 2.0
+ rotary_dim: int = 64
+ swiglu_alpha: float = 1.702
+ swiglu_limit: float = 7.0
+ mlp_layer_types: list[str] | None = None
+ index_n_heads: int = 4
+ index_head_dim: int = 128
+ index_block_size: int = 128
+ index_topk_blocks: int = 16
+ index_local_blocks: int = 1
+ layer_types: list[str] | None = None
+ tie_word_embeddings: bool = False
+ pad_token_id: int | None = None
+ bos_token_id: int | None = 200034
+ eos_token_id: int | list[int] | None = 200020
+ rope_parameters: RopeParameters | dict | None = None
+
+ def __post_init__(self, **kwargs):
+ sparse_cfg = kwargs.pop("sparse_attention_config", None) or {}
+ moe_layer_freq = kwargs.pop("moe_layer_freq", None)
+ PreTrainedConfig.__post_init__(self, **kwargs)
+ # Checkpoint declares "swigluoai", but the gate is computed inline from swiglu_alpha/limit; hidden_act
+ # is only the pointwise fallback and must be a real ACT2FN key, so normalize it to silu.
+ self.hidden_act = "silu"
+
+ for flat, legacy in {
+ "index_n_heads": "sparse_num_index_heads",
+ "index_head_dim": "sparse_index_dim",
+ "index_block_size": "sparse_block_size",
+ "index_topk_blocks": "sparse_topk_blocks",
+ "index_local_blocks": "sparse_local_block",
+ }.items():
+ if legacy in sparse_cfg:
+ setattr(self, flat, sparse_cfg[legacy])
+
+ # `layer_types` is the canonical per-layer attention dispatch: it tells
+ # `DynamicCache(config=...)` which layers want the sparse cache and tells
+ # `MiniMaxM3VLAttention` which layers build a sparse Lightning Indexer.
+ if self.layer_types is None and "sparse_attention_freq" in sparse_cfg:
+ self.layer_types = [
+ "minimax_m3_sparse" if f else "full_attention" for f in sparse_cfg["sparse_attention_freq"]
+ ]
+ if self.layer_types is None:
+ self.layer_types = ["full_attention"] * self.num_hidden_layers
+
+ # `mlp_layer_types` is the per-layer MLP dispatch read by `MiniMaxM3VLDecoderLayer`:
+ if self.mlp_layer_types is None and moe_layer_freq is not None:
+ self.mlp_layer_types = ["sparse" if f else "dense" for f in moe_layer_freq]
+ if self.mlp_layer_types is None:
+ self.mlp_layer_types = ["sparse"] * self.num_hidden_layers
+
+
+@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview")
+@strict
+class MiniMaxM3VLVisionConfig(PreTrainedConfig):
+ r"""
+ rope_parameters (`RopeParameters`, *optional*):
+ Standard RoPE configuration for the vision tower's 3D rotary position embedding.
+ """
+
+ model_type = "minimax_m3_vl_vision"
+ base_config_key = "vision_config"
+ default_theta = 10000.0
+
+ hidden_size: int = 1280
+ intermediate_size: int = 5120
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 16
+ num_channels: int = 3
+ image_size: int = 2016
+ patch_size: int = 14
+ temporal_patch_size: int = 2
+ spatial_merge_size: int = 2
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-05
+ attention_dropout: float = 0.0
+ rope_parameters: RopeParameters | dict | None = None
+ initializer_range: float = 0.02
+
+
+@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview")
+@strict
+class MiniMaxM3VLConfig(PreTrainedConfig):
+ model_type = "minimax_m3_vl"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ }
+
+ vision_config: dict | PreTrainedConfig | None = None
+ text_config: dict | PreTrainedConfig | None = None
+ image_token_index: int = 200025
+ video_token_index: int = 200026
+ projector_hidden_size: int = 6144
+ tie_word_embeddings: bool = False
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.vision_config, dict):
+ self.vision_config.pop("model_type", None)
+ self.vision_config = MiniMaxM3VLVisionConfig(**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = MiniMaxM3VLVisionConfig()
+
+ if isinstance(self.text_config, dict):
+ self.text_config.pop("model_type", None)
+ self.text_config = MiniMaxM3VLTextConfig(**self.text_config)
+ elif self.text_config is None:
+ self.text_config = MiniMaxM3VLTextConfig()
+
+ if not self.tie_word_embeddings and self.text_config.tie_word_embeddings:
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
+
+ # Channel dim after grouping `spatial_merge_size**2` projected patches, consumed by the
+ # patch-merge MLP inside `MiniMaxM3VLMultiModalProjector`.
+ self.merged_hidden_size = self.text_config.hidden_size * (self.vision_config.spatial_merge_size**2)
+
+ super().__post_init__(**kwargs)
+
+
+class MiniMaxM3VLSparseCacheLayer(DynamicLayer):
+ layer_type = "minimax_m3_sparse"
+
+ def __init__(self, config: PreTrainedConfig | None = None):
+ super().__init__(config)
+ self.idx_keys: torch.Tensor | None = None
+
+ def update_index(self, idx_k: torch.Tensor) -> torch.Tensor:
+ """Append the new token's `idx_k` to the cache and return the full history."""
+ self.idx_keys = idx_k if self.idx_keys is None else torch.cat([self.idx_keys, idx_k], dim=-2)
+ return self.idx_keys
+
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+ super().reorder_cache(beam_idx)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device))
+
+ def batch_repeat_interleave(self, repeats: int) -> None:
+ super().batch_repeat_interleave(repeats)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys.repeat_interleave(repeats, dim=0)
+
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
+ super().batch_select_indices(indices)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys[indices, ...]
+
+ def crop(self, max_length: int) -> None:
+ super().crop(max_length)
+ if max_length < 0:
+ max_length = self.get_seq_length() - abs(max_length)
+ if self.idx_keys is not None and self.idx_keys.shape[-2] > max_length:
+ self.idx_keys = self.idx_keys[..., :max_length, :]
+
+
+class MiniMaxM3VLSparseStaticCacheLayer(StaticLayer):
+ layer_type = "minimax_m3_sparse"
+
+ def __init__(self, max_cache_len: int):
+ super().__init__(max_cache_len)
+ self.idx_keys: torch.Tensor | None = None
+ # Tensor (not int) so it can be marked as a static address for cudagraphs, like `cumulative_length`.
+ self.idx_cumulative_length = torch.tensor([0], dtype=int)
+
+ def update_index(self, idx_k: torch.Tensor) -> torch.Tensor:
+ """Write the new token's `idx_k` into the static buffer in place and return the whole buffer.
+
+ The buffer's unfilled tail holds zeros, but those slots sit at key positions ahead of every
+ current query, so the indexer's block- and token-level causal masking discards them — the
+ returned `[B, 1, max_cache_len, D]` history is therefore safe to score against directly.
+ """
+ if self.idx_keys is None:
+ self.idx_keys = torch.zeros(
+ (idx_k.shape[0], idx_k.shape[1], self.max_cache_len, idx_k.shape[-1]),
+ dtype=idx_k.dtype,
+ device=idx_k.device,
+ )
+ self.idx_cumulative_length = self.idx_cumulative_length.to(idx_k.device)
+ if not is_torchdynamo_compiling():
+ torch._dynamo.mark_static_address(self.idx_keys)
+ torch._dynamo.mark_static_address(self.idx_cumulative_length)
+
+ kv_len = idx_k.shape[-2]
+ cache_position = torch.arange(kv_len, device=self.idx_keys.device) + self.idx_cumulative_length
+ self.idx_cumulative_length.add_(kv_len)
+ try:
+ self.idx_keys.index_copy_(2, cache_position, idx_k)
+ except NotImplementedError:
+ # Fallback for devices like MPS where index_copy_ might not be supported.
+ self.idx_keys[:, :, cache_position] = idx_k
+ return self.idx_keys
+
+ def reset(self) -> None:
+ super().reset()
+ if self.idx_keys is not None:
+ self.idx_keys.zero_()
+ self.idx_cumulative_length.zero_()
+
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+ super().reorder_cache(beam_idx)
+ if self.idx_keys is not None:
+ self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device))
+
+
+class MiniMaxM3VLRMSNorm(Gemma3RMSNorm):
+ """Gemma-style RMSNorm: normalizes in fp32 and scales by `weight + 1`."""
+
+
+class MiniMaxM3VLDenseMLP(nn.Module):
+ def __init__(self, config: MiniMaxM3VLTextConfig, intermediate_size: int | None = None):
+ super().__init__()
+ inter = intermediate_size if intermediate_size is not None else config.dense_intermediate_size
+ self.swiglu_alpha = config.swiglu_alpha
+ self.swiglu_limit = config.swiglu_limit
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * inter, bias=False)
+ self.down_proj = nn.Linear(inter, config.hidden_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_up = self.gate_up_proj(hidden_states)
+ gate, up = gate_up.chunk(2, dim=-1)
+ gate = gate.clamp(max=self.swiglu_limit)
+ up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+ glu = gate * torch.sigmoid(gate * self.swiglu_alpha)
+ return self.down_proj((up + 1.0) * glu)
+
+
+class MiniMaxM3VLExperts(DeepseekV4Experts):
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__(config)
+ self.swiglu_alpha = config.swiglu_alpha
+ self.swiglu_limit = config.swiglu_limit
+ del self.act_fn
+
+ def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
+ # same as GPT OSS, but the weights are not interleaved
+ gate, up = gate_up.chunk(2, dim=-1)
+ gate = gate.clamp(max=self.swiglu_limit)
+ up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+ glu = gate * torch.sigmoid(gate * self.swiglu_alpha)
+ return (up + 1.0) * glu
+
+
+class MiniMaxM3VLTopKRouter(MiniMaxM2TopKRouter):
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__(config)
+ self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight)
+ # Sigmoid scoring (not softmax), as in M2.
+ routing_weights = F.sigmoid(router_logits.float())
+ scores_for_choice = routing_weights + self.e_score_correction_bias
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
+ top_k_weights = routing_weights.gather(1, top_k_index)
+ top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
+ return router_logits, top_k_weights, top_k_index
+
+
+class MiniMaxM3VLSparseMoeBlock(LagunaSparseMoeBlock):
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ nn.Module.__init__(self)
+ self.gate = MiniMaxM3VLTopKRouter(config)
+ self.experts = MiniMaxM3VLExperts(config)
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.shared_experts = MiniMaxM3VLDenseMLP(config, intermediate_size=config.shared_intermediate_size)
+
+
+class MiniMaxM3VLRotaryEmbedding(MiniMaxM2RotaryEmbedding):
+ pass
+
+
+class MiniMaxM3VLAttention(MiniMaxM2Attention):
+ """
+ M3 attention: per-head Gemma QK-norm + partial RoPE, optionally sparse indexer selection which require position IDs.
+ """
+
+ def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.q_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.indexer = (
+ MiniMaxM3VLIndexer(config, layer_idx) if config.layer_types[layer_idx] == "minimax_m3_sparse" else None
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None,
+ past_key_values: Cache | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+ block_indices = None
+ if self.indexer is not None:
+ position_ids = kwargs.get("position_ids")
+ if position_ids is None:
+ position_ids = torch.arange(
+ key_states.shape[2] - query_states.shape[2], key_states.shape[2], device=query_states.device
+ )
+ position_ids = (position_ids if position_ids.ndim > 1 else position_ids.unsqueeze(0)).expand(
+ query_states.shape[0], -1
+ )
+ block_indices = self.indexer(hidden_states, position_embeddings, past_key_values, position_ids)
+ if self.config._attn_implementation in ("eager", "sdpa"):
+ attention_mask = self.indexer.build_block_mask(
+ block_indices,
+ attention_mask,
+ key_states.shape[2],
+ query_states.dtype,
+ query_states.device,
+ position_ids,
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ block_indices=block_indices,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ return self.o_proj(attn_output), attn_weights
+
+
+class MiniMaxM3VLIndexer(nn.Module):
+ r"""Lightning Indexer for MiniMax M3 sparse attention.
+
+ Scores each query against every key with a small `index_n_heads`-head
+ dot-product branch, then max-pools those per-key scores into *blocks* of
+ `index_block_size` keys and keeps, per query, the top-`index_topk_blocks`
+ key blocks plus the `index_local_blocks` blocks immediately preceding the
+ query (always visible). Selection therefore happens at the granularity of a
+ *block of keys* rather than individual keys: the expensive main attention
+ only has to attend the handful of selected key blocks, which is what makes
+ it block-sparse (and cheaper) on long sequences.
+
+ The `index_local_blocks` boosting their score so they always win key slots, the
+ same way the deployment block-sparse kernel (MiniMax `topk_sparse`) does it.
+
+ `forward` returns the per-query selected key-block indices
+ `[B, S_q, index_topk_blocks]`. Valid indices are left-packed and `-1`
+ right-pads the unused slots (future/empty blocks), and the local boost makes
+ selections deduplicated -- the exact contract the block-sparse attention
+ kernel consumes (it counts the valid entries, then reads them sequentially
+ and would double-count a repeated block). The eager/SDPA path instead calls
+ `build_block_mask`, which expands the indices into the dense
+ `[B, 1, S_q, S_k]` additive mask the standard attention interface expects
+ (`0` at every allowed (query, key) pair, `-inf` elsewhere).
+
+ Like DeepSeek-V4's indexer this is purely a *selection* branch: it has no
+ value projection and produces no residual output of its own (the upstream
+ checkpoint disables the index-value path on every sparse layer).
+
+ TODO: blocks are anchored to absolute key *slots* (the contiguous reshape in
+ `forward` and `q_block = slot // block_size`), so left-padding shifts the block
+ boundaries and the selection diverges from an unpadded run -- only right-padding
+ is equivalent (same limitation as DeepSeek-V4; see `test_right_padding_does_not_leak`
+ / the skipped `test_left_padding_compatibility`). For *true* left-padding equivalence
+ we'd make blocking content-relative instead of slot-relative:
+ 1. derive block ids from `position_ids` (content positions, 0 at each row's first
+ real token) rather than from absolute slots, and
+ 2. replace the contiguous `view(..., num_key_blocks, block_size).amax(-1)` key pool
+ with a per-row position-binned pool (e.g. `scatter_reduce` over `key_position //
+ block_size`), so pad never shifts the boundaries, and
+ 3. mask padded keys' scores to `-inf` before the pool so a pad key can't win a block
+ a top-k slot.
+ """
+
+ def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = config.index_head_dim
+ self.num_heads = config.index_n_heads
+ self.block_size = config.index_block_size
+ self.topk_blocks = config.index_topk_blocks
+ self.local_blocks = config.index_local_blocks
+ self.q_proj = nn.Linear(config.hidden_size, config.index_n_heads * config.index_head_dim, bias=False)
+ self.k_proj = nn.Linear(config.hidden_size, config.index_head_dim, bias=False)
+ self.q_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps)
+ self.k_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ past_key_values: Cache | None,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ batch, q_len, _ = hidden_states.shape
+ idx_q = self.q_proj(hidden_states).view(batch, q_len, -1, self.head_dim)
+ idx_q = self.q_norm(idx_q).transpose(1, 2) # [B, H_idx, Sq, D]
+ idx_k = self.k_proj(hidden_states).view(batch, q_len, 1, self.head_dim)
+ idx_k = self.k_norm(idx_k).transpose(1, 2) # [B, 1, Sq, D]
+ cos, sin = position_embeddings
+ idx_q, idx_k = apply_rotary_pos_emb(idx_q, idx_k, cos[..., : self.head_dim], sin[..., : self.head_dim])
+
+ if past_key_values is not None:
+ idx_k = past_key_values.layers[self.layer_idx].update_index(idx_k)
+
+ k_len = idx_k.shape[2]
+ num_key_blocks = -(-k_len // self.block_size) # ceil-div
+ pad = num_key_blocks * self.block_size - k_len
+
+ scores = torch.matmul(idx_q.float(), idx_k.float().transpose(-1, -2))
+ k_positions = torch.arange(k_len, device=idx_q.device)
+ token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k]
+ scores = scores.masked_fill(token_future, float("-inf"))
+ if pad:
+ scores = F.pad(scores, (0, pad), value=float("-inf"))
+ scores = scores.view(batch, self.num_heads, q_len, num_key_blocks, self.block_size)
+ block_scores = scores.amax(dim=-1).amax(dim=1) # -> [B, S_q, num_key_blocks]
+
+ q_block = position_ids // self.block_size # [B, S_q]
+
+ if self.local_blocks > 0:
+ local = torch.arange(self.local_blocks, device=idx_q.device)
+ local_idx = (q_block[..., None] - local.view(1, 1, -1)).clamp(min=0) # [B, S_q, local]
+ block_scores.scatter_(-1, local_idx, float("inf"))
+
+ # Slots that fall on a future/empty block keep their `-inf`
+ # score, which top-k sorts to the end, so tagging them `-1` yields left-packed block indices
+ # with `-1` right-padding which is the format expect by block-sparse attention kernel.
+ topk = min(self.topk_blocks, num_key_blocks)
+ topk_scores, topk_indices = block_scores.topk(topk, dim=-1) # [B, S_q, topk]
+ return topk_indices.masked_fill(topk_scores == float("-inf"), -1)
+
+ def build_block_mask(
+ self,
+ block_indices: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ key_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ We build the full 4D attention mask (Batch, query, key, head)
+ """
+ batch, q_len, _ = block_indices.shape
+ num_key_blocks = -(-key_length // self.block_size)
+
+ # Scatter the kept blocks to `0`; `-1` slots land in a throwaway column we drop afterwards.
+ safe = block_indices.masked_fill(block_indices < 0, num_key_blocks)
+ bias = block_indices.new_full((batch, q_len, num_key_blocks + 1), float("-inf"), dtype=dtype)
+ bias.scatter_(-1, safe, 0.0)
+ bias = bias[..., :num_key_blocks]
+
+ # Broadcast the per-block keep/drop verdict back onto every key (block granularity), add head axis.
+ block_keep = (bias == 0.0).repeat_interleave(self.block_size, dim=-1)[..., :key_length].unsqueeze(1)
+
+ # Compose block-selection with the existing mask, then emit a single additive float mask.
+ if attention_mask is not None:
+ padding_mask = attention_mask if attention_mask.dtype == torch.bool else attention_mask == 0
+ keep = block_keep & padding_mask
+ else:
+ k_positions = torch.arange(key_length, device=device)
+ token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k]
+ keep = block_keep & ~token_future
+ min_dtype = torch.finfo(dtype).min
+ return torch.zeros(keep.shape, dtype=dtype, device=device).masked_fill(~keep, min_dtype)
+
+
+class MiniMaxM3VLDecoderLayer(MixtralDecoderLayer):
+ """M3 decoder layer: per-layer dense/MoE MLP and dense/sparse attention."""
+
+ def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.self_attn = MiniMaxM3VLAttention(config, layer_idx)
+ self.mlp = (
+ MiniMaxM3VLSparseMoeBlock(config)
+ if config.mlp_layer_types[layer_idx] == "sparse"
+ else MiniMaxM3VLDenseMLP(config, intermediate_size=config.dense_intermediate_size)
+ )
+ self.input_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class MiniMaxM3VLPreTrainedModel(MiniMaxM2PreTrainedModel):
+ config: MiniMaxM3VLConfig | MiniMaxM3VLTextConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["MiniMaxM3VLDecoderLayer", "MiniMaxM3VLVisionEncoderLayer"]
+ input_modalities = ("image", "video", "text")
+ _keys_to_ignore_on_load_unexpected = [r"(^|\.)mtp\..*"]
+ _supports_flash_attn = False
+ _supports_sdpa = True
+ _supports_flex_attn = False
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _compatible_flash_implementations = ["kernels-staging/msa@v0"]
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ std = getattr(self.config, "initializer_range", 0.02)
+ if isinstance(module, MiniMaxM3VLExperts):
+ init.normal_(module.gate_up_proj, mean=0.0, std=std)
+ init.normal_(module.down_proj, mean=0.0, std=std)
+ elif isinstance(module, MiniMaxM3VLTopKRouter):
+ init.normal_(module.weight, mean=0.0, std=std)
+ init.zeros_(module.e_score_correction_bias)
+ elif isinstance(module, MiniMaxM3VLRMSNorm):
+ init.zeros_(module.weight)
+
+
+class MiniMaxM3VLTextModel(MiniMaxM2Model):
+ config: MiniMaxM3VLTextConfig
+
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList([MiniMaxM3VLDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
+ self.norm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ position_ids = position_ids.unsqueeze(0)
+
+ if isinstance(attention_mask, dict):
+ causal_mask = next(iter(attention_mask.values()))
+ else:
+ causal_mask = create_causal_mask(
+ config=self.config,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+ # `position_ids` is threaded to every layer so the sparse layers' lightning indexer can anchor
+ # block selection to each query's content position (see `MiniMaxM3VLIndexer`).
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class MiniMaxM3VLForCausalLM(MiniMaxM2ForCausalLM):
+ config: MiniMaxM3VLTextConfig
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+
+ def __init__(self, config: MiniMaxM3VLTextConfig):
+ super().__init__(config)
+ self.model = MiniMaxM3VLTextModel(config)
+ self.post_init()
+
+
+class MiniMaxM3VLVisionEmbeddings(Qwen2_5_VisionPatchEmbed):
+ """Patch embedding, identical to [`Qwen2_5_VisionPatchEmbed`] (reads its dims from the vision
+ config). The upstream checkpoint stores the conv as `patch_embedding`, renamed to the
+ inherited `proj` in the conversion mapping."""
+
+ def __init__(self, config) -> None:
+ nn.Module.__init__(self)
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.num_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(
+ self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
+ )
+
+
+class MiniMaxM3VL3DRotaryEmbedding(nn.Module):
+ r"""3D RoPE for the vision tower: each patch is rotated by its `(T, H, W)` grid position.
+
+ `2 * (head_dim // 2)` rotary dims are split evenly across the three axes (each rounded
+ down to a multiple of 2), giving `axis_dim` dims per axis and `axis_dim // 2` frequencies::
+
+ |<------------------ rotated (3 * axis_dim) ------------------>|<- pass ->|
+ +--------------------+--------------------+--------------------+----------+
+ | T (frames) | H (rows) | W (cols) | |
+ | axis_dim | axis_dim | axis_dim | |
+ +--------------------+--------------------+--------------------+----------+
+
+ Each axis' coordinate scales its own band of frequencies; the bands are concatenated as
+ `T|H|W` and duplicated via `cat([f, f])` to pair with the half-rotation in
+ `apply_rotary_pos_emb_vision`. Any head dims past `3 * axis_dim` are left unrotated.
+ """
+
+ def __init__(self, head_dim: int, theta: float = 10000.0, spatial_merge_size: int = 1):
+ super().__init__()
+ # `2 * (head_dim // 2)` rotary dims are split evenly across T/H/W, each axis rounded
+ # down to a multiple of 2. With head_dim=80 that is 26 dims/axis (39 freqs total); the
+ # remaining `head_dim - 3 * axis_dim` dims are never rotated (they pass through).
+ rope_dims = 2 * (head_dim // 2)
+ self.axis_dim = 2 * ((rope_dims // 3) // 2)
+ self.spatial_merge_size = spatial_merge_size
+ self.theta = theta
+
+ def forward(
+ self, grid_thw: torch.Tensor, device: torch.device, dtype: torch.dtype
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ m = self.spatial_merge_size
+ coords = []
+ for t, h, w in grid_thw.tolist():
+ hi = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hi = hi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
+ wi = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wi = wi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
+ ti = torch.arange(t).repeat_interleave(h * w)
+ coords.append(torch.stack([ti, hi.repeat(t), wi.repeat(t)], dim=-1))
+ coords = torch.cat(coords).to(device=device, dtype=torch.float32)
+
+ # meta device init was having trouble when it was registered. TODO standardize?
+ inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, self.axis_dim, 2, dtype=torch.float32, device=device) / self.axis_dim)
+ )
+ freqs = torch.cat([coords[:, i : i + 1] * inv_freq for i in range(3)], dim=-1)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ return emb.cos().to(dtype), emb.sin().to(dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ # Only the first `rot_dim` head dims carry 3D RoPE; the tail passes through untouched.
+ rot_dim = cos.shape[-1]
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :]
+ q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
+ k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
+ q_rot = q_rot * cos + rotate_half(q_rot) * sin
+ k_rot = k_rot * cos + rotate_half(k_rot) * sin
+ return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)
+
+
+class MiniMaxM3VLVisionAttention(CLIPAttention):
+ """CLIP-style vision attention; the only difference from [`CLIPAttention`] is
+ that queries and keys are rotated by the tower's 3D RoPE before the
+ (interface-dispatched) scaled dot-product attention."""
+
+ def __init__(self, config: MiniMaxM3VLVisionConfig):
+ super().__init__(config)
+ # The vision tower has no grouped-query attention; the shared eager kernel
+ # still expects this attribute to drive its (no-op) `repeat_kv`.
+ self.num_key_value_groups = 1
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ queries = self.q_proj(hidden_states).view(hidden_shape)
+ keys = self.k_proj(hidden_states).view(hidden_shape)
+ values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin)
+ queries, keys = queries.transpose(1, 2), keys.transpose(1, 2)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ return self.out_proj(attn_output), attn_weights
+
+
+class MiniMaxM3VLVisionMLP(CLIPMLP):
+ pass
+
+
+# 3D-RoPE `position_embeddings` pass via `**kwargs` for simplicity
+class MiniMaxM3VLVisionEncoderLayer(CLIPEncoderLayer):
+ def __init__(self, config: MiniMaxM3VLVisionConfig):
+ super().__init__(config)
+ self.self_attn = MiniMaxM3VLVisionAttention(config)
+ self.mlp = MiniMaxM3VLVisionMLP(config)
+
+
+@auto_docstring
+class MiniMaxM3VLVisionModel(MiniMaxM3VLPreTrainedModel):
+ """CLIP-like vision tower with Conv3d patch embed + 3D RoPE."""
+
+ config: MiniMaxM3VLVisionConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {
+ "hidden_states": MiniMaxM3VLVisionEncoderLayer,
+ "attentions": MiniMaxM3VLVisionAttention,
+ }
+
+ def __init__(self, config: MiniMaxM3VLVisionConfig):
+ super().__init__(config)
+ self.embeddings = MiniMaxM3VLVisionEmbeddings(config)
+ self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layers = nn.ModuleList([MiniMaxM3VLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ head_dim = config.hidden_size // config.num_attention_heads
+ self.rotary_emb = MiniMaxM3VL3DRotaryEmbedding(
+ head_dim, theta=config.rope_parameters["rope_theta"], spatial_merge_size=config.spatial_merge_size
+ )
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
+ The temporal, height and width of feature shape of each image.
+ """
+ embeds = self.embeddings(pixel_values).to(self.pre_layrnorm.weight.dtype)
+ cos, sin = self.rotary_emb(image_grid_thw, device=embeds.device, dtype=embeds.dtype)
+ hidden_states = self.pre_layrnorm(embeds).unsqueeze(0)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, attention_mask=None, position_embeddings=(cos, sin), **kwargs)
+ return BaseModelOutputWithPooling(last_hidden_state=hidden_states, pooler_output=hidden_states[:, 0])
+
+
+class MiniMaxM3VLMultiModalProjector(nn.Module):
+ """Projects each vision patch from `vision_config.hidden_size` to `text_config.hidden_size`
+ (GELU MLP), then groups `spatial_merge_size**2` neighbouring patches into the channel dim and
+ fuses them back to a single `text_config.hidden_size` token with a second GELU MLP."""
+
+ def __init__(self, config: MiniMaxM3VLConfig):
+ super().__init__()
+ text_hidden = config.text_config.hidden_size
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.projector_hidden_size, bias=True)
+ self.act = ACT2FN["gelu"]
+ self.linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True)
+ self.merge_linear_1 = nn.Linear(config.merged_hidden_size, config.projector_hidden_size, bias=True)
+ self.merge_act = ACT2FN["gelu"]
+ self.merge_linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True)
+
+ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.linear_2(self.act(self.linear_1(image_features)))
+ hidden_states = hidden_states.reshape(hidden_states.shape[0] // (self.spatial_merge_size**2), -1)
+ return self.merge_linear_2(self.merge_act(self.merge_linear_1(hidden_states)))
+
+
+class MiniMaxM3VLModelOutputWithPast(LlavaModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ video_hidden_states: torch.FloatTensor | None = None
+
+
+class MiniMaxM3VLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ video_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(custom_intro="MiniMax M3 VL backbone (vision + projector + text), without LM head.")
+class MiniMaxM3VLModel(LlavaModel):
+ config: MiniMaxM3VLConfig
+
+ def __init__(self, config: MiniMaxM3VLConfig):
+ super().__init__(config)
+ self.vision_tower = MiniMaxM3VLVisionModel(config.vision_config)
+ self.multi_modal_projector = MiniMaxM3VLMultiModalProjector(config)
+ self.language_model = MiniMaxM3VLTextModel(config.text_config)
+ self.post_init()
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_grid_thw: torch.Tensor,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ # Return the raw vision-tower output (so callers can inspect hidden states /
+ # attentions) while stashing the projected + spatially-merged features —
+ # ready to scatter into the text embeddings — in `pooler_output`.
+ vision_outputs = self.vision_tower(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
+ vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0))
+ return vision_outputs
+
+ @merge_with_config_defaults
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="Obtains video last hidden states from the vision tower and apply multimodal projection."
+ )
+ def get_video_features(
+ self,
+ pixel_values_videos: torch.FloatTensor,
+ video_grid_thw: torch.Tensor,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ pixel_values_videos (`torch.FloatTensor`):
+ The tensors corresponding to the input video frames.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ # Video frames flow through the same vision pipeline as images (the tower is
+ # grid-agnostic); only the placeholder token they scatter into differs.
+ vision_outputs = self.vision_tower(pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, **kwargs)
+ vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0))
+ return vision_outputs
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: torch.FloatTensor | None = None,
+ video_features: torch.FloatTensor | None = None,
+ ):
+ """
+ Obtains the image/video placeholder masks from `input_ids` or `inputs_embeds`, and checks that the
+ placeholder token count matches the multimodal feature length. Raises if they differ.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None:
+ torch_compilable_check(
+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
+ f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None:
+ torch_compilable_check(
+ inputs_embeds[special_video_mask].numel() == video_features.numel(),
+ f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
+ )
+ return special_image_mask, special_video_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.Tensor | None = None,
+ video_grid_thw: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | MiniMaxM3VLModelOutputWithPast:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_features = None
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values, image_grid_thw=image_grid_thw
+ ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype)
+
+ video_features = None
+ if pixel_values_videos is not None:
+ video_features = self.get_video_features(
+ pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw
+ ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype)
+
+ image_mask, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds, image_features=image_features, video_features=video_features
+ )
+ if image_features is not None:
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)
+ if video_features is not None:
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+
+ return MiniMaxM3VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=getattr(outputs, "hidden_states", None),
+ attentions=getattr(outputs, "attentions", None),
+ image_hidden_states=image_features,
+ video_hidden_states=video_features,
+ )
+
+
+@auto_docstring(custom_intro="MiniMax M3 VL full model with LM head (text + vision).")
+class MiniMaxM3SparseForConditionalGeneration(LlavaForConditionalGeneration):
+ config: MiniMaxM3VLConfig
+
+ def get_image_features(self, pixel_values, image_grid_thw, **kwargs):
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)
+
+ def get_video_features(self, pixel_values_videos, video_grid_thw, **kwargs):
+ r"""
+ pixel_values_videos (`torch.FloatTensor`):
+ The tensors corresponding to the input video frames.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.Tensor | None = None,
+ video_grid_thw: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | MiniMaxM3VLCausalLMOutputWithPast:
+ r"""
+ image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE
+ and to merge patch features.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return MiniMaxM3VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ attention_mask=None,
+ logits_to_keep=None,
+ is_first_iteration=False,
+ **kwargs,
+ ):
+ # Overwritten -- pixel inputs are merged into the cache on the first step, so we
+ # only forward them once (image and video alike).
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ logits_to_keep=logits_to_keep,
+ is_first_iteration=is_first_iteration,
+ **kwargs,
+ )
+
+ if is_first_iteration or not kwargs.get("use_cache", True):
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+
+ return model_inputs
+
+
+class MiniMaxM3VLProcessorKwargs(Qwen2VLProcessorKwargs):
+ _defaults = {
+ "videos_kwargs": {"do_resize": False, "return_metadata": True},
+ }
+
+
+class MiniMaxM3VLProcessor(Qwen2VLProcessor):
+ """Combines tokenizer + image_processor + video_processor for MiniMax M3 VL.
+
+ Expands `IMAGE_TOKEN` / `VIDEO_TOKEN` markers in the prompt into the matching
+ number of placeholder tokens (one per merged patch), wrapped in `VISION_START_TOKEN`
+ / `VISION_END_TOKEN` brackets. Video chunks are additionally prefixed with a
+ `]<]{seconds} seconds[>[` timestamp marker per frame when metadata is available.
+ """
+
+ valid_processor_kwargs = MiniMaxM3VLProcessorKwargs
+
+ IMAGE_TOKEN = "]<]image[>["
+ VIDEO_TOKEN = "]<]video[>["
+ VISION_START_TOKEN = "]<]start of image[>["
+ VISION_END_TOKEN = "]<]end of image[>["
+
+ def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
+ self.image_token = self.IMAGE_TOKEN
+ self.video_token = self.VIDEO_TOKEN
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) if tokenizer else None
+ self.video_token_id = tokenizer.convert_tokens_to_ids(self.VIDEO_TOKEN) if tokenizer else None
+ self.vision_start_token_id = tokenizer.convert_tokens_to_ids(self.VISION_START_TOKEN) if tokenizer else None
+ self.vision_end_token_id = tokenizer.convert_tokens_to_ids(self.VISION_END_TOKEN) if tokenizer else None
+
+ def replace_image_token(self, image_inputs: dict, image_idx: int) -> str:
+ merge_length = self.image_processor.merge_size**2
+ num_image_tokens = int(image_inputs["image_grid_thw"][image_idx].prod() // merge_length)
+ return self.VISION_START_TOKEN + self.IMAGE_TOKEN * num_image_tokens + self.VISION_END_TOKEN
+
+ def replace_video_token(self, video_inputs: dict, video_idx: int) -> str:
+ merge_length = self.video_processor.merge_size**2
+ grid_thw = video_inputs["video_grid_thw"][video_idx]
+ grid_t = int(grid_thw[0])
+ frame_seqlen = int(grid_thw[1:].prod() // merge_length)
+ metadata = video_inputs.get("video_metadata", [None] * (video_idx + 1))[video_idx]
+ temporal_patch_size = self.video_processor.temporal_patch_size
+ chunk = ""
+ for frame in range(grid_t):
+ if (
+ metadata is not None
+ and getattr(metadata, "fps", None) is not None
+ and getattr(metadata, "frames_indices", None) is not None
+ ):
+ ts = (
+ metadata.frames_indices[min(frame * temporal_patch_size, len(metadata.frames_indices) - 1)]
+ / metadata.fps
+ )
+ chunk += f"]<]{ts:.1f} seconds[>["
+ chunk += self.VISION_START_TOKEN + self.VIDEO_TOKEN * frame_seqlen + self.VISION_END_TOKEN
+ return chunk
+
+
+__all__ = [
+ "MiniMaxM3VLConfig",
+ "MiniMaxM3VLTextConfig",
+ "MiniMaxM3VLVisionConfig",
+ "MiniMaxM3VLForCausalLM",
+ "MiniMaxM3SparseForConditionalGeneration",
+ "MiniMaxM3VLModel",
+ "MiniMaxM3VLPreTrainedModel",
+ "MiniMaxM3VLProcessor",
+ "MiniMaxM3VLTextModel",
+ "MiniMaxM3VLVisionModel",
+]
diff --git a/src/transformers/models/minimax_m3_vl/processing_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/processing_minimax_m3_vl.py
new file mode 100644
index 000000000000..6ab6f0517d3b
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/processing_minimax_m3_vl.py
@@ -0,0 +1,153 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_minimax_m3_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin
+from ...utils import auto_docstring
+
+
+class MiniMaxM3VLProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "videos_kwargs": {"do_resize": False, "return_metadata": True},
+ }
+
+
+@auto_docstring
+class MiniMaxM3VLProcessor(ProcessorMixin):
+ """Combines tokenizer + image_processor + video_processor for MiniMax M3 VL.
+
+ Expands `IMAGE_TOKEN` / `VIDEO_TOKEN` markers in the prompt into the matching
+ number of placeholder tokens (one per merged patch), wrapped in `VISION_START_TOKEN`
+ / `VISION_END_TOKEN` brackets. Video chunks are additionally prefixed with a
+ `]<]{seconds} seconds[>[` timestamp marker per frame when metadata is available.
+ """
+
+ valid_processor_kwargs = MiniMaxM3VLProcessorKwargs
+
+ IMAGE_TOKEN = "]<]image[>["
+ VIDEO_TOKEN = "]<]video[>["
+ VISION_START_TOKEN = "]<]start of image[>["
+ VISION_END_TOKEN = "]<]end of image[>["
+
+ def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
+ self.image_token = self.IMAGE_TOKEN
+ self.video_token = self.VIDEO_TOKEN
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) if tokenizer else None
+ self.video_token_id = tokenizer.convert_tokens_to_ids(self.VIDEO_TOKEN) if tokenizer else None
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
+ self.vision_start_token_id = tokenizer.convert_tokens_to_ids(self.VISION_START_TOKEN) if tokenizer else None
+ self.vision_end_token_id = tokenizer.convert_tokens_to_ids(self.VISION_END_TOKEN) if tokenizer else None
+
+ def replace_image_token(self, image_inputs: dict, image_idx: int) -> str:
+ merge_length = self.image_processor.merge_size**2
+ num_image_tokens = int(image_inputs["image_grid_thw"][image_idx].prod() // merge_length)
+ return self.VISION_START_TOKEN + self.IMAGE_TOKEN * num_image_tokens + self.VISION_END_TOKEN
+
+ def replace_video_token(self, video_inputs: dict, video_idx: int) -> str:
+ merge_length = self.video_processor.merge_size**2
+ grid_thw = video_inputs["video_grid_thw"][video_idx]
+ grid_t = int(grid_thw[0])
+ frame_seqlen = int(grid_thw[1:].prod() // merge_length)
+ metadata = video_inputs.get("video_metadata", [None] * (video_idx + 1))[video_idx]
+ temporal_patch_size = self.video_processor.temporal_patch_size
+ chunk = ""
+ for frame in range(grid_t):
+ if (
+ metadata is not None
+ and getattr(metadata, "fps", None) is not None
+ and getattr(metadata, "frames_indices", None) is not None
+ ):
+ ts = (
+ metadata.frames_indices[min(frame * temporal_patch_size, len(metadata.frames_indices) - 1)]
+ / metadata.fps
+ )
+ chunk += f"]<]{ts:.1f} seconds[>["
+ chunk += self.VISION_START_TOKEN + self.VIDEO_TOKEN * frame_seqlen + self.VISION_END_TOKEN
+ return chunk
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+ video_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (num_frames, height, width) per each video.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = MiniMaxM3VLProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ if video_sizes is not None:
+ videos_kwargs = MiniMaxM3VLProcessorKwargs._defaults.get("videos_kwargs", {})
+ videos_kwargs.update(kwargs)
+ num_video_patches = [
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
+ for video_size in video_sizes
+ ]
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
+ vision_data["num_video_tokens"] = num_video_tokens
+
+ return MultiModalData(**vision_data)
+
+ def post_process_image_text_to_text(
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
+ ):
+ """
+ Post-process the output of the model to decode the text.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ or `(sequence_length,)`.
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
+ **kwargs:
+ Additional arguments to be passed to the tokenizer's `batch_decode method`.
+
+ Returns:
+ `list[str]`: The decoded text.
+ """
+ return self.tokenizer.batch_decode(
+ generated_outputs,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ def model_input_names(self):
+ return super().model_input_names + ["mm_token_type_ids"]
+
+
+__all__ = ["MiniMaxM3VLProcessor"]
diff --git a/src/transformers/models/minimax_m3_vl/video_processing_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/video_processing_minimax_m3_vl.py
new file mode 100644
index 000000000000..c25729309e14
--- /dev/null
+++ b/src/transformers/models/minimax_m3_vl/video_processing_minimax_m3_vl.py
@@ -0,0 +1,137 @@
+# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+
+"""Video processor for MiniMax M3 VL."""
+
+import torch
+from torchvision.transforms import InterpolationMode
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import PILImageResampling, SizeDict
+from ...processing_utils import Unpack, VideosKwargs
+from ...utils import TensorType
+from ...video_processing_utils import BaseVideoProcessor
+from ...video_utils import group_videos_by_shape, reorder_videos
+from .image_processing_minimax_m3_vl import smart_resize
+
+
+class MiniMaxM3VLVideoProcessorKwargs(VideosKwargs, total=False):
+ patch_size: int
+ temporal_patch_size: int
+ merge_size: int
+ min_pixels: int
+ max_pixels: int
+ total_pixels: int
+ min_frames: int
+ max_frames: int
+ fps: float | int
+
+
+class MiniMaxM3VLVideoProcessor(BaseVideoProcessor):
+ do_resize = True
+ resample = PILImageResampling.BICUBIC
+ size = {"height": 672, "width": 672}
+ default_to_square = False
+ do_rescale = True
+ rescale_factor = 1 / 255
+ do_normalize = True
+ image_mean = [0.48145466, 0.4578275, 0.40821073]
+ image_std = [0.26862954, 0.26130258, 0.27577711]
+ do_convert_rgb = True
+ do_sample_frames = False
+ patch_size = 14
+ temporal_patch_size = 2
+ merge_size = 2
+ min_pixels = 4 * 28 * 28
+ max_pixels = 768 * 28 * 28
+ total_pixels = int(64000 * 28 * 28 * 0.9)
+ fps = 1.0
+ min_frames = 4
+ max_frames = 768
+ valid_kwargs = MiniMaxM3VLVideoProcessorKwargs
+ model_input_names = ["pixel_values_videos", "video_grid_thw"]
+
+ def __init__(self, **kwargs: Unpack[MiniMaxM3VLVideoProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def _preprocess(
+ self,
+ videos: list[torch.Tensor],
+ do_convert_rgb: bool,
+ do_resize: bool,
+ size: SizeDict,
+ resample: PILImageResampling | InterpolationMode | int | None,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: float | list[float] | None,
+ image_std: float | list[float] | None,
+ patch_size: int,
+ temporal_patch_size: int,
+ merge_size: int,
+ min_pixels: int,
+ max_pixels: int,
+ return_tensors: str | TensorType | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ grouped, grouped_idx = group_videos_by_shape(videos)
+ resized_grouped = {}
+ factor = patch_size * merge_size
+ for shape, stacked in grouped.items():
+ bs, nf, c, h, w = stacked.shape
+ rh, rw = (h, w)
+ if do_resize:
+ rh, rw = smart_resize(h, w, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels)
+ stacked = stacked.view(bs * nf, c, h, w)
+ stacked = self.resize(stacked, size=SizeDict(height=rh, width=rw), resample=resample)
+ stacked = stacked.view(bs, nf, c, rh, rw)
+ resized_grouped[shape] = stacked
+ resized = reorder_videos(resized_grouped, grouped_idx)
+
+ grouped, grouped_idx = group_videos_by_shape(resized)
+ processed_grouped = {}
+ grids = {}
+ for shape, stacked in grouped.items():
+ rh, rw = stacked.shape[-2:]
+ patches = self.rescale_and_normalize(
+ stacked, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ if pad := -patches.shape[1] % temporal_patch_size:
+ repeats = patches[:, -1:].expand(-1, pad, -1, -1, -1)
+ patches = torch.cat([patches, repeats], dim=1)
+ bs, grid_t, c = patches.shape[:3]
+ grid_t = grid_t // temporal_patch_size
+ grid_h, grid_w = rh // patch_size, rw // patch_size
+ patches = patches.view(
+ bs,
+ grid_t,
+ temporal_patch_size,
+ c,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+ flat = patches.reshape(bs, grid_t * grid_h * grid_w, c * temporal_patch_size * patch_size * patch_size)
+ processed_grouped[shape] = flat
+ grids[shape] = [[grid_t, grid_h, grid_w]] * bs
+
+ processed = reorder_videos(processed_grouped, grouped_idx)
+ grids = reorder_videos(grids, grouped_idx)
+ pixel_values_videos = torch.cat(processed, dim=0)
+ video_grid_thw = torch.tensor(grids, dtype=torch.long)
+ return BatchFeature(
+ data={"pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw},
+ tensor_type=return_tensors,
+ )
+
+
+__all__ = ["MiniMaxM3VLVideoProcessor"]
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index d053401be153..4fe470f8727a 100644
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -90,6 +90,10 @@
"vptq": VptqHfQuantizer,
"spqr": SpQRHfQuantizer,
"fp8": FineGrainedFP8HfQuantizer,
+ # MXFP8 = FP8 (E4M3 weights) with per-block ``[1, 32]`` E8M0 (uint8) scales —
+ # reuses the FineGrainedFP8 dequant path, with the E8M0 byte→exponent
+ # unpacking handled inside ``Fp8Dequantize._dequantize_one``.
+ "mxfp8": FineGrainedFP8HfQuantizer,
"auto-round": AutoRoundQuantizer,
"mxfp4": Mxfp4HfQuantizer,
"metal": MetalHfQuantizer,
@@ -117,6 +121,7 @@
"vptq": VptqConfig,
"spqr": SpQRConfig,
"fp8": FineGrainedFP8Config,
+ "mxfp8": FineGrainedFP8Config,
"auto-round": AutoRoundConfig,
"mxfp4": Mxfp4Config,
"metal": MetalConfig,
diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py
index 0d300e89e954..011f45e6a1a5 100644
--- a/src/transformers/quantizers/quantizer_finegrained_fp8.py
+++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -94,6 +94,27 @@ def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "
return 1
return super().param_element_size(model, param_name, param)
+ def _normalize_modules_to_not_convert(self, model: "PreTrainedModel"):
+ """Rewrite the skip-list to the model's own module tree.
+ For models that were already released, if they have a list of modules to not quantize
+ we need to apply the weight renaming / weight conversion opérations to get the actual
+ layer name of the model in `transformers`.
+ """
+ skip = self.quantization_config.modules_to_not_convert
+ if not skip:
+ return
+
+ from ..conversion_mapping import get_model_conversion_mapping
+
+ renamings = get_model_conversion_mapping(model)
+ remapped = []
+ for name in skip:
+ renamed = name
+ for rename in renamings:
+ renamed, _ = rename.rename_source_key(renamed)
+ remapped.append(renamed)
+ self.quantization_config.modules_to_not_convert = remapped
+
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -101,6 +122,7 @@ def _process_model_before_weight_loading(
):
from ..integrations.finegrained_fp8 import replace_with_fp8_linear
+ self._normalize_modules_to_not_convert(model)
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
)
@@ -182,6 +204,41 @@ def get_weight_conversions(self):
]
return []
+ def _is_mxfp8(self) -> bool:
+ """MXFP8 checkpoints ship E8M0 (uint8) per-block scales; plain FP8 ships float32."""
+ quant_method = getattr(self.quantization_config, "quant_method", None)
+ return quant_method == "mxfp8"
+
+ def _update_weight_conversions_mxfp8(self, weight_conversions):
+ """
+ Native MXFP8 path: prepend a `Fp8DecodeScale` op so the uint8 E8M0
+ scales are decoded to float32 `2 ** (byte - 127)` *before* any merge/concat op
+ and add a generic fallback converter that decodes the scales of plain `FP8Linear` weights (attention / dense projections)
+ which have no model-specific converter.
+ """
+ from ..core_model_loading import WeightConverter
+ from ..integrations.finegrained_fp8 import Fp8DecodeScale
+
+ updated: list = []
+ for conv in weight_conversions:
+ if isinstance(conv, WeightConverter) and any(p.endswith(".weight") for p in conv.source_patterns):
+ conv = WeightConverter(
+ source_patterns=conv.source_patterns,
+ target_patterns=conv._original_target_patterns,
+ operations=[Fp8DecodeScale(self)] + list(conv.operations),
+ )
+ updated.append(conv)
+ # Generic fallback for plain ``nn.Linear`` scales with no model-specific converter.
+ # Listed last so the model converters above win the first-match for expert/dense scales.
+ updated.append(
+ WeightConverter(
+ source_patterns=["weight_scale_inv"],
+ target_patterns="weight_scale_inv",
+ operations=[Fp8DecodeScale(self)],
+ )
+ )
+ return updated
+
def update_weight_conversions(self, weight_conversions):
"""When loading with ``dequantize=True``, attach an :class:`Fp8Dequantize` op to
every existing :class:`WeightConverter` so that per-block scales are folded into
@@ -213,6 +270,9 @@ def update_weight_conversions(self, weight_conversions):
weight_conversions = [scale_rename] + list(weight_conversions)
if not (self.pre_quantized and self.quantization_config.dequantize):
+ if self.pre_quantized and self._is_mxfp8():
+ # mxfp8 needs a pre-processing on the scales when not dequantizing
+ return self._update_weight_conversions_mxfp8(weight_conversions)
return weight_conversions + self.get_weight_conversions()
updated: list = []
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index e2235060b85f..ef7a18083da6 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -60,6 +60,7 @@ class QuantizationMethod(str, Enum):
FPQUANT = "fp_quant"
AUTOROUND = "auto-round"
MXFP4 = "mxfp4"
+ MXFP8 = "mxfp8"
METAL = "metal"
FOUR_OVER_SIX = "fouroversix"
SINQ = "sinq"
@@ -1662,7 +1663,10 @@ def __init__(
scale_fmt: str = "float",
**kwargs,
):
- self.quant_method = QuantizationMethod.FP8
+ self.quant_method = kwargs.pop("quant_method", QuantizationMethod.FP8)
+ # MiniMax ships the skip-list under ``ignored_layers``; accept it as an alias.
+ if modules_to_not_convert is None and "ignored_layers" in kwargs:
+ modules_to_not_convert = kwargs.pop("ignored_layers")
self.modules_to_not_convert = modules_to_not_convert
self.activation_scheme = activation_scheme
self.weight_block_size = weight_block_size
diff --git a/tests/fixtures/tests_samples/COCO/apple.jpg b/tests/fixtures/tests_samples/COCO/apple.jpg
new file mode 100644
index 000000000000..f055aecd4bb4
Binary files /dev/null and b/tests/fixtures/tests_samples/COCO/apple.jpg differ
diff --git a/tests/models/minimax_m3_vl/__init__.py b/tests/models/minimax_m3_vl/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/minimax_m3_vl/test_modeling_minimax_m3_vl.py b/tests/models/minimax_m3_vl/test_modeling_minimax_m3_vl.py
new file mode 100644
index 000000000000..359cf1f52e3a
--- /dev/null
+++ b/tests/models/minimax_m3_vl/test_modeling_minimax_m3_vl.py
@@ -0,0 +1,652 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch MiniMax-M3-VL model."""
+
+import copy
+import unittest
+
+from parameterized import parameterized
+
+from transformers import (
+ AutoTokenizer,
+ MiniMaxM3SparseForConditionalGeneration,
+ MiniMaxM3VLConfig,
+ MiniMaxM3VLImageProcessorFast,
+ MiniMaxM3VLModel,
+ MiniMaxM3VLProcessor,
+ MiniMaxM3VLVideoProcessor,
+ is_torch_available,
+ is_vision_available,
+)
+from transformers.testing_utils import (
+ require_torch,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ TEST_EAGER_MATCHES_BATCHED_AND_GROUPED_INFERENCE_PARAMETERIZATION,
+ ModelTesterMixin,
+ _test_eager_matches_batched_and_grouped_inference,
+ floats_tensor,
+ ids_tensor,
+)
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+class MiniMaxM3VLVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ seq_length=7,
+ ignore_index=-100,
+ image_token_index=4,
+ video_token_index=5,
+ is_training=True,
+ text_config={
+ "hidden_size": 32,
+ "intermediate_size": 64,
+ "dense_intermediate_size": 128,
+ "shared_intermediate_size": 32,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 2,
+ "head_dim": 32,
+ "rotary_dim": 16,
+ "hidden_act": "silu",
+ "max_position_embeddings": 512,
+ "rms_norm_eps": 1e-6,
+ "vocab_size": 99,
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "pad_token_id": 2,
+ "num_local_experts": 4,
+ "num_experts_per_tok": 2,
+ "n_shared_experts": 1,
+ "moe_layer_freq": [0, 1],
+ "layer_types": [
+ "full_attention",
+ "minimax_m3_sparse",
+ ],
+ "use_routing_bias": True,
+ "routed_scaling_factor": 2.0,
+ "swiglu_alpha": 1.702,
+ "swiglu_limit": 7.0,
+ "tie_word_embeddings": False,
+ "rope_parameters": {
+ "rope_type": "default",
+ "rope_theta": 5000000.0,
+ "partial_rotary_factor": 0.5,
+ },
+ "index_n_heads": 2,
+ "index_head_dim": 16,
+ "index_block_size": 8,
+ "index_topk_blocks": 4,
+ "index_local_blocks": 1,
+ },
+ vision_config={
+ "hidden_size": 32,
+ "intermediate_size": 64,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 4,
+ "num_channels": 3,
+ "image_size": 14,
+ "patch_size": 14,
+ "temporal_patch_size": 2,
+ "spatial_merge_size": 1,
+ "rope_theta": 10000.0,
+ },
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.ignore_index = ignore_index
+ self.image_token_index = image_token_index
+ self.video_token_index = video_token_index
+ self.is_training = is_training
+ self.text_config = text_config
+ self.vision_config = vision_config
+
+ self.pad_token_id = text_config["pad_token_id"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.hidden_size = text_config["hidden_size"]
+ self.vocab_size = text_config["vocab_size"]
+
+ self.num_channels = vision_config["num_channels"]
+ self.image_size = vision_config["image_size"]
+ self.patch_size = vision_config["patch_size"]
+ self.temporal_patch_size = vision_config["temporal_patch_size"]
+ self.spatial_merge_size = vision_config["spatial_merge_size"]
+
+ # One patch per image (grid [1, 1, 1]) so that the generation common tests, which crop
+ # all inputs along the batch dim, keep ``pixel_values`` and ``image_grid_thw`` consistent.
+ self.num_patches = 1
+ self.num_image_tokens = self.num_patches // (self.spatial_merge_size**2)
+ self.seq_length = seq_length + self.num_image_tokens
+ self.encoder_seq_length = self.seq_length
+
+ def get_config(self):
+ return MiniMaxM3VLConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ image_token_index=self.image_token_index,
+ video_token_index=self.video_token_index,
+ projector_hidden_size=self.text_config["hidden_size"],
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_dim = self.num_channels * (self.patch_size**2) * self.temporal_patch_size
+ pixel_values = floats_tensor([self.batch_size * self.num_patches, patch_dim])
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config, pixel_values = self.prepare_config_and_inputs()
+
+ input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+ input_ids[input_ids == self.image_token_index] = self.pad_token_id
+ input_ids[input_ids == self.video_token_index] = self.pad_token_id
+ input_ids[:, : self.num_image_tokens] = self.image_token_index
+
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class MiniMaxM3VLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Model tester for `MiniMaxM3SparseForConditionalGeneration`.
+ """
+
+ all_model_classes = (
+ (
+ MiniMaxM3VLModel,
+ MiniMaxM3SparseForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {
+ "image-text-to-text": MiniMaxM3SparseForConditionalGeneration,
+ }
+ if is_torch_available()
+ else {}
+ )
+ _is_composite = True
+ # The vision tower packs every image's (and video frame's) patches into a single sequence
+ # (batch dim 1), so ``last_hidden_state`` does not carry a per-item batch axis to shape-check.
+ skip_test_image_features_output_shape = True
+ skip_test_video_features_output_shape = True
+
+ # The indexer parameters only influence the argmax over compressed blocks (``topk``),
+ # which is non-differentiable — their gradients flow through a separate objective in
+ # the upstream training recipe, not the main causal-LM loss (same as DeepSeek-V4).
+ test_all_params_have_gradient = False
+
+ def setUp(self):
+ self.model_tester = MiniMaxM3VLVisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MiniMaxM3VLConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="IDK exactly why, can be adressed later")
+ def test_reverse_loading_mapping(self):
+ pass
+
+ @unittest.skip(
+ reason=(
+ "This model lists the Lightning-indexer MSA kernel in `_compatible_flash_implementations`, so "
+ "requesting `flash_attention_2` is auto-redirected to `kernels-staging/msa@v0` (the model's "
+ "native flash path) instead of raising. The base test instead expects a `ValueError` whenever "
+ "the composite sub-models don't all set `_supports_flash_attn`, an invariant that does not hold "
+ "for a model whose flash implementation *is* a custom kernel. MSA dispatch via the public "
+ "`attn_implementation` API is covered by the slow integration tests."
+ )
+ )
+ def test_flash_attn_2_can_dispatch_composite_models(self):
+ pass
+
+ @parameterized.expand(TEST_EAGER_MATCHES_BATCHED_AND_GROUPED_INFERENCE_PARAMETERIZATION)
+ def test_eager_matches_batched_and_grouped_inference(self, name, dtype):
+ # In low precision the grouped/batched/sonic expert kernels accumulate in a different order
+ # than the eager per-expert loop, so a handful of near-zero MoE outputs drift past the 1e-4
+ # tolerance. This is precision noise, not a logic mismatch (fp32 matches exactly).
+ if dtype in ("fp16", "bf16"):
+ self.skipTest("Low-precision float casting fluctuations across expert kernels exceed the 1e-4 tolerance")
+ _test_eager_matches_batched_and_grouped_inference(self, name, dtype)
+
+ @unittest.skip(
+ reason=(
+ "The lightning indexer tiles the key axis into blocks of `index_block_size` *slots* and "
+ "selects whole blocks, so its block boundaries are anchored to absolute sequence slots. "
+ "Left-padding shifts every real token by the (per-row, generally non-block-aligned) pad "
+ "width, which regroups real keys into different blocks than the unpadded run and changes "
+ "which blocks win top-k — so left-padded logits diverge by design. This is the same "
+ "block-sparse limitation as DeepSeek-V4; see `test_right_padding_does_not_leak` for the "
+ "padding direction that *is* equivalent."
+ )
+ )
+ def test_left_padding_compatibility(self):
+ pass
+
+ def test_right_padding_does_not_leak(self):
+ """Right-padding must not change a sequence's real-token logits.
+
+ Pad keys land on slots *after* every real token, so block-level causality drops them before
+ top-k selection and the indexer's folded mask zeroes the pad columns. The real tokens therefore
+ occupy the same slots and see the same key blocks as in an unpadded run -- unlike left-padding,
+ which shifts the slot-anchored block boundaries (see `test_left_padding_compatibility`).
+ """
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ config.text_config._attn_implementation = "eager"
+ vocab = config.text_config.vocab_size
+ pad_id = config.text_config.pad_token_id
+ # Text-only inputs: keep ids clear of the image/video placeholder ids so no vision tower runs.
+ low = max(self.model_tester.image_token_index, self.model_tester.video_token_index) + 1
+ torch.manual_seed(0)
+ lengths = [self.model_tester.seq_length + 6, self.model_tester.seq_length + 1]
+ seqs = [torch.randint(low, vocab - 2, (n,), device=torch_device) for n in lengths]
+ max_len = max(lengths)
+
+ model = MiniMaxM3SparseForConditionalGeneration(config).to(torch_device).eval()
+
+ per_seq_logits = []
+ for seq in seqs:
+ with torch.no_grad():
+ out = model(
+ input_ids=seq[None],
+ attention_mask=torch.ones(1, len(seq), dtype=torch.long, device=torch_device),
+ )
+ per_seq_logits.append(out.logits[0, : len(seq)])
+
+ input_ids = torch.full((len(seqs), max_len), pad_id, device=torch_device)
+ attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.long, device=torch_device)
+ for i, seq in enumerate(seqs):
+ input_ids[i, : len(seq)] = seq
+ attention_mask[i, : len(seq)] = 1
+ with torch.no_grad():
+ batched_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
+
+ for i, seq in enumerate(seqs):
+ torch.testing.assert_close(batched_logits[i, : len(seq)], per_seq_logits[i], rtol=1e-4, atol=1e-4)
+
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs raise an explicit error when the number of images doesn't match the number
+ of image tokens in the text, and that genuine multi-image cases are accepted.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ num_patches = self.model_tester.num_patches
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+ curr_input_dict = copy.deepcopy(input_dict)
+ _ = model(**curr_input_dict) # successful forward with no modifications
+
+ # remove one image but leave its image tokens in text
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][:-num_patches, ...]
+ curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][:-1, ...]
+ with self.assertRaisesRegex(ValueError, "Image features and image tokens do not match"):
+ _ = model(**curr_input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:num_patches]
+ image_grid_thw = curr_input_dict["image_grid_thw"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # two image-token groups but one image raises an error
+ with self.assertRaisesRegex(ValueError, "Image features and image tokens do not match"):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+ # two images and two image-token groups don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+ def test_video_forward(self):
+ """Video frames flow through the same vision tower as images and scatter into the video-token slots."""
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ batch_size = self.model_tester.batch_size
+ num_channels = self.model_tester.num_channels
+ temporal_patch_size = self.model_tester.temporal_patch_size
+ patch_size = self.model_tester.patch_size
+ merge_size = self.model_tester.spatial_merge_size
+
+ num_frames = 4
+ grid_t = num_frames // temporal_patch_size
+ grid_h = self.model_tester.image_size // patch_size
+ grid_w = self.model_tester.image_size // patch_size
+ patches_per_video = grid_t * grid_h * grid_w
+ # ``patch_merge`` groups ``merge_size**2`` patches, so each video yields this many tokens.
+ tokens_per_video = patches_per_video // (merge_size**2)
+
+ patch_dim = num_channels * (patch_size**2) * temporal_patch_size
+ pixel_values_videos = floats_tensor([batch_size * patches_per_video, patch_dim])
+ video_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]] * batch_size, device=torch_device)
+ # The vision tower consumes exactly ``grid_t * grid_h * grid_w`` patches per video.
+ self.assertEqual(pixel_values_videos.shape[0], int(video_grid_thw.prod(dim=1).sum()))
+
+ input_ids = ids_tensor([batch_size, self.model_tester.seq_length], config.text_config.vocab_size - 2) + 2
+ input_ids[input_ids == self.model_tester.image_token_index] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.video_token_index] = self.model_tester.pad_token_id
+ # Carve out one contiguous block of video-token slots per sequence.
+ self.assertLessEqual(tokens_per_video, self.model_tester.seq_length)
+ input_ids[:, :tokens_per_video] = self.model_tester.video_token_index
+ attention_mask = torch.ones_like(input_ids)
+
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ )
+ self.assertIsNotNone(outputs)
+ self.assertIsNotNone(outputs.video_hidden_states)
+ self.assertEqual(outputs.video_hidden_states.shape[0], batch_size * tokens_per_video)
+
+ def test_mismatching_num_video_tokens(self):
+ """VLMs must raise when the number of videos doesn't match the number of video tokens in the text."""
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ batch_size = self.model_tester.batch_size
+ num_channels = self.model_tester.num_channels
+ temporal_patch_size = self.model_tester.temporal_patch_size
+ patch_size = self.model_tester.patch_size
+ merge_size = self.model_tester.spatial_merge_size
+
+ num_frames = 4
+ grid_t = num_frames // temporal_patch_size
+ grid_h = self.model_tester.image_size // patch_size
+ grid_w = self.model_tester.image_size // patch_size
+ patches_per_video = grid_t * grid_h * grid_w
+ tokens_per_video = patches_per_video // (merge_size**2)
+
+ patch_dim = num_channels * (patch_size**2) * temporal_patch_size
+ pixel_values_videos = floats_tensor([batch_size * patches_per_video, patch_dim])
+ video_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]] * batch_size, device=torch_device)
+
+ input_ids = ids_tensor([batch_size, self.model_tester.seq_length], config.text_config.vocab_size - 2) + 2
+ input_ids[input_ids == self.model_tester.image_token_index] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.video_token_index] = self.model_tester.pad_token_id
+ # One fewer video-token slot than features -> mismatch.
+ input_ids[:, : tokens_per_video - 1] = self.model_tester.video_token_index
+
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+ with self.assertRaisesRegex(ValueError, "Video features and video tokens do not match"):
+ _ = model(
+ input_ids=input_ids,
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ )
+
+
+@slow
+@require_torch
+class MiniMaxM3VLIntegrationTest(unittest.TestCase):
+ model_id = "MiniMaxAI/Minimax-M3-preview"
+
+ def _load_model(self):
+ # The indexer feeds SDPA an additive float mask (the block-sparse bias). On B200 + this
+ # cuDNN build, the cuDNN SDPA backend segfaults in ``run_cudnn_SDP_fprop`` on such masks;
+ # disabling it routes SDPA to the mem-efficient backend, which handles additive float masks.
+ torch.backends.cuda.enable_cudnn_sdp(False)
+
+ # Out-of-the-box load: the MXFP8 ``quantization_config`` (quant_method, weight_block_size,
+ # and the ``ignored_layers`` skip-list) is read straight from the checkpoint's config.json
+ # and dispatched automatically — no hand-built quant config, no manual dtype patching.
+ model = MiniMaxM3SparseForConditionalGeneration.from_pretrained(
+ self.model_id,
+ dtype=torch.bfloat16,
+ device_map="auto",
+ local_files_only=True,
+ )
+ model.eval()
+ return model
+
+ def _load_processor(self):
+ from transformers.utils import cached_file
+
+ tokenizer = AutoTokenizer.from_pretrained(self.model_id, local_files_only=True)
+ image_processor = MiniMaxM3VLImageProcessorFast.from_pretrained(self.model_id, local_files_only=True)
+ video_processor = MiniMaxM3VLVideoProcessor.from_pretrained(self.model_id, local_files_only=True)
+ with open(cached_file(self.model_id, "chat_template.jinja", local_files_only=True)) as f:
+ chat_template = f.read()
+ return MiniMaxM3VLProcessor(
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ video_processor=video_processor,
+ chat_template=chat_template,
+ )
+
+ @staticmethod
+ def _prompt(processor, question):
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
+ return processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False, thinking_mode="disabled"
+ )
+
+ def test_image_and_text_generation(self):
+ model = self._load_model()
+ processor = self._load_processor()
+ image = Image.new("RGB", (672, 672), (127, 127, 127))
+ text = self._prompt(processor, "Describe this image briefly.")
+ inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device)
+ with torch.no_grad():
+ output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+ decoded = processor.batch_decode(output, skip_special_tokens=True)[0]
+ print(f"\n[test_image_and_text_generation] generation:\n{decoded!r}\n")
+ self.assertIsInstance(decoded, str)
+ self.assertGreater(len(decoded.strip()), 0)
+
+ def test_real_image_apple_recognition(self):
+ import os
+
+ model = self._load_model()
+ processor = self._load_processor()
+
+ apple_path = os.path.join(os.path.dirname(__file__), "../../fixtures/tests_samples/COCO/apple.jpg")
+ image = Image.open(apple_path).convert("RGB")
+ text = self._prompt(processor, "What fruit is shown in this image? Answer in one word.")
+ inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device)
+ with torch.no_grad():
+ output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
+ completion = processor.batch_decode(output[:, inputs.input_ids.size(1) :], skip_special_tokens=True)[0]
+ print(f"\n[test_real_image_apple_recognition] completion:\n{completion!r}\n")
+ self.assertIn("apple", completion.lower())
+
+ def test_batched_image_generation(self):
+ """Batch of two image+text prompts, no padding.
+
+ Mirrors the DeepSeek-V4 multi-prompt integration check: two distinct solid-color
+ images are batched with an identical prompt, so both rows tokenize to the same
+ length and no padding is needed. A correct run must describe each image with its
+ own color, proving the vision features stay aligned with the right tokens across
+ the batch and the MXFP8 MoE path.
+
+ Padding is deliberately avoided: the block-sparse indexer anchors blocks to absolute
+ key *slots* (see ``MiniMaxM3VLIndexer`` docstring), so left-padding shifts block
+ boundaries and diverges from an unpadded run — the same slot-based limitation as
+ DeepSeek-V4. ``test_left_padding_compatibility`` documents that gap.
+ """
+ model = self._load_model()
+ processor = self._load_processor()
+
+ red = Image.new("RGB", (672, 672), (200, 30, 30))
+ blue = Image.new("RGB", (672, 672), (30, 30, 200))
+ # Identical prompt + identical image geometry → identical token length → no padding.
+ question = "What is the dominant color of this image? Answer in one word."
+ texts = [self._prompt(processor, question), self._prompt(processor, question)]
+ inputs = processor(images=[red, blue], text=texts, return_tensors="pt").to(model.device)
+ self.assertEqual(inputs.input_ids.shape[0], 2)
+ with torch.no_grad():
+ output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+ completions = processor.batch_decode(output[:, inputs.input_ids.size(1) :], skip_special_tokens=True)
+ print(f"\n[test_batched_image_generation] completions:\n{completions!r}\n")
+ self.assertEqual(len(completions), 2)
+ for completion in completions:
+ self.assertGreater(len(completion.strip()), 0)
+ # Each completion should name its own image's color, not the other's.
+ self.assertIn("red", completions[0].lower())
+ self.assertIn("blue", completions[1].lower())
+
+ def test_padding_sides_text_and_image(self):
+ """Lock the batched-padding contract end to end, for both text and image batches.
+
+ Two facts, both consequences of the slot-anchored Lightning indexer (see
+ ``MiniMaxM3VLIndexer``) plus causal attention:
+
+ * RIGHT padding does not change a real token's prediction. Pad keys land on slots *after*
+ every real token, so causal attention + the folded indexer mask drop them before any real
+ query attends to them. The MoE still *runs* on the pad rows (there is no token-level mask
+ inside the router), but that wasted compute never flows back into a real token -- so the
+ well-known right-padding generation garbage is purely a "generation continues from a pad
+ slot" artifact, NOT a missing-mask bug in the MoE. We assert the greedy next-token at each
+ real position is identical to an unpadded single-sequence run.
+ * LEFT padding is therefore the side to use for batched ``generate``: every real token's
+ continuation stays anchored to the live slots, so each row decodes coherently.
+ """
+ import requests
+
+ model = self._load_model()
+ processor = self._load_processor()
+ tokenizer = processor.tokenizer
+
+ # ---- RIGHT padding: real-token greedy predictions must match an unpadded run (no MoE leak) ----
+ prompts = [
+ "The history of computing began with",
+ "Summarize the theory of relativity in a single concise sentence:",
+ ]
+ per_seq = [tokenizer(p, return_tensors="pt").to(model.device) for p in prompts]
+ per_seq_logits = []
+ for enc in per_seq:
+ with torch.no_grad():
+ per_seq_logits.append(model(**enc).logits[0, : enc.input_ids.size(1)])
+
+ tokenizer.padding_side = "right"
+ right = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+ with torch.no_grad():
+ right_logits = model(**right).logits
+ for i, enc in enumerate(per_seq):
+ n = enc.input_ids.size(1)
+ self.assertTrue(
+ torch.equal(right_logits[i, :n].argmax(-1), per_seq_logits[i].argmax(-1)),
+ "Right-padding changed a real token's greedy prediction -> padding leaked into the "
+ "MoE/attention. The MoE may compute on pad rows, but causal attention + the indexer "
+ "mask must keep that out of every real token.",
+ )
+
+ # ---- LEFT padding: batched text generation decodes each row coherently ----
+ tokenizer.padding_side = "left"
+ left = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+ with torch.no_grad():
+ gen = model.generate(**left, max_new_tokens=24, do_sample=False)
+ completions = tokenizer.batch_decode(gen[:, left.input_ids.size(1) :], skip_special_tokens=True)
+ print(f"\n[test_padding_sides_text_and_image] left-pad text completions:\n{completions!r}\n")
+ for completion in completions:
+ self.assertGreater(len(completion.strip()), 0)
+
+ # ---- LEFT padding with an image batch of differing prompt lengths (real padding) ----
+ # Two real, semantically distinct images downloaded from the hub/web (the picsum dog and the
+ # canonical COCO two-cats photo), paired with deliberately different-length questions so the
+ # batch genuinely needs padding.
+ dog = Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw).convert("RGB")
+ cats = Image.open(
+ requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+ ).convert("RGB")
+ texts = [
+ self._prompt(processor, "What animal is this? One word."),
+ self._prompt(
+ processor,
+ "Look carefully at this photograph and tell me which animal appears in it, "
+ "answering with a single word.",
+ ),
+ ]
+ processor.tokenizer.padding_side = "left"
+ img_inputs = processor(images=[dog, cats], text=texts, return_tensors="pt", padding=True).to(model.device)
+ self.assertEqual(img_inputs.input_ids.shape[0], 2)
+ with torch.no_grad():
+ img_gen = model.generate(**img_inputs, max_new_tokens=16, do_sample=False)
+ img_completions = processor.batch_decode(img_gen[:, img_inputs.input_ids.size(1) :], skip_special_tokens=True)
+ print(f"\n[test_padding_sides_text_and_image] left-pad image completions:\n{img_completions!r}\n")
+ self.assertIn("dog", img_completions[0].lower())
+ self.assertIn("cat", img_completions[1].lower())
+
+ def test_video_generation(self):
+ """End-to-end video path: the processor emits ``pixel_values_videos`` / ``video_grid_thw`` and the
+ model scatters the video features into the video-token slots before generating.
+
+ Uses a short synthetic clip rather than a network fetch: 672 is divisible by the vision tower's
+ ``patch_size * spatial_merge_size`` factor (28) and the frame count is a multiple of
+ ``temporal_patch_size``, so the video grid is well formed and the merged-patch count lines up
+ exactly with the expanded video tokens.
+ """
+ import numpy as np
+
+ model = self._load_model()
+ processor = self._load_processor()
+
+ num_frames = 4
+ video = np.zeros((num_frames, 672, 672, 3), dtype=np.uint8)
+ video[..., 0] = 200 # a solid red-ish clip
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "video"},
+ {"type": "text", "text": "What is the dominant color in this video? Answer in one word."},
+ ],
+ }
+ ]
+ text = processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False, thinking_mode="disabled"
+ )
+ inputs = processor(videos=[video], text=text, return_tensors="pt").to(model.device)
+ # The processor must have produced the video tensors the model consumes.
+ self.assertIn("pixel_values_videos", inputs)
+ self.assertIn("video_grid_thw", inputs)
+ with torch.no_grad():
+ output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+ decoded = processor.batch_decode(output[:, inputs.input_ids.size(1) :], skip_special_tokens=True)[0]
+ print(f"\n[test_video_generation] generation:\n{decoded!r}\n")
+ self.assertIsInstance(decoded, str)
+ self.assertGreater(len(decoded.strip()), 0)
+ self.assertIn("red", decoded.lower())
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 25a4f7e00edc..9b176987ccc6 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -48,6 +48,7 @@
"pooling_kernel_size",
], # Used as meta data for other attributes/properties
"MiniCPMV4_6Config": ["drop_vision_last_layer"],
+ "MiniMaxM3VLTextConfig": ["rotary_dim", "router_jitter_noise"],
"OpenAIPrivacyFilterConfig": ["classifier_dropout", "output_router_logits", "router_aux_loss_coef"],
"HYV3Config": ["output_router_logits"],
"NougatConfig": ["decoder", "encoder"],
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 778f9ecdbb42..e026cba796eb 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -113,6 +113,7 @@
"SmolVLMVisionTransformer",
"MiniCPMV4_6VisionPreTrainedModel",
"MiniCPMV4_6VisionModel",
+ "MiniMaxM3VLVisionModel",
"AriaTextForCausalLM",
"AriaTextModel",
"Phi4MultimodalAudioModel",
@@ -223,6 +224,9 @@
"PaddleOCRTextModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration.
"Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration.
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
+ "MiniMaxM3VLForCausalLM", # Building part of bigger (tested) model. Tested implicitly through MiniMaxM3SparseForConditionalGeneration.
+ "MiniMaxM3VLTextModel", # Building part of bigger (tested) model. Tested implicitly through MiniMaxM3SparseForConditionalGeneration.
+ "MiniMaxM3VLVisionModel", # Building part of bigger (tested) model. Tested implicitly through MiniMaxM3SparseForConditionalGeneration.
"Qwen3VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLForConditionalGeneration.
"Qwen3VLMoeModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLMoeForConditionalGeneration.
"Qwen3VLTextModel", # Building part of bigger (tested) model.
diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py
index 33a71e4f1fd7..92ab1206502d 100644
--- a/utils/modular_model_detector.py
+++ b/utils/modular_model_detector.py
@@ -257,6 +257,7 @@ def __init__(self, hub_dataset: str):
self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval()
self.device = self.model.device
+ self.dtype = self.model.dtype
self.index_dir: Path | None = None
# ---------- HUB IO ----------