diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3c776a0c8d3c..61190f7837e5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -741,6 +741,8 @@ title: MiniMax - local: model_doc/minimax_m2 title: MiniMax-M2 + - local: model_doc/minimax_m3_vl + title: MiniMax-M3-VL - local: model_doc/ministral title: Ministral - local: model_doc/ministral3 diff --git a/docs/source/en/model_doc/minimax_m3_vl.md b/docs/source/en/model_doc/minimax_m3_vl.md new file mode 100644 index 000000000000..5a73737b7e67 --- /dev/null +++ b/docs/source/en/model_doc/minimax_m3_vl.md @@ -0,0 +1,253 @@ + +*This model was contributed to Hugging Face Transformers on 2026-06-12.* + + +# MiniMax-M3-VL + +## Overview + +MiniMax-M3-VL is the vision-language member of the MiniMax-M3 family. It pairs a CLIP-style vision tower (Conv3d patch embedding with 3D rotary position embeddings) with the MiniMax-M3 text backbone, a mixed dense/sparse Mixture-of-Experts decoder that uses SwiGLU-OAI gated experts and a lightning indexer for block-sparse attention. + +## Architecture +### Block-sparse attention (Lightning Indexer) + +Every layer is GQA (`num_key_value_heads = 4`) with per-head QK-norm and **partial RoPE** on the first +`rotary_dim` dimensions. `config.layer_types[i]` then picks `"full_attention"` (dense causal) or +`"minimax_m3_sparse"`, where a [`MiniMaxM3VLIndexer`] decides, per query, which blocks of keys the main attention may see. + +The indexer scores every key, then **max-pools those per-key scores into blocks of `index_block_size` keys**, so selection happens at the granularity of a *block +of keys*: per query it keeps the top-`index_topk_blocks` key blocks plus the always-on `index_local_blocks` +local-window blocks (under block-level causality), and broadcasts the per-block `0`/`-inf` choice back onto every key in +the block. The result is a `[B, 1, S_q, S_k]` additive bias summed onto the causal mask. 
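
The selection steps described above can be sketched in plain Python for a single head and a single sequence. This is an illustrative toy, not the `MiniMaxM3VLIndexer` implementation: the function name `block_sparse_bias`, the nested-list layout, the assumption that query `q` sits at absolute position `q` (prefill), and the choice to keep the local blocks *in addition to* the top-k picks are all assumptions made for the example.

```python
import math


def block_sparse_bias(scores, block_size, topk_blocks, local_blocks):
    """Build a [S_q, S_k] additive bias (0.0 = visible, -inf = hidden) from per-key scores.

    Hypothetical helper: `scores[q][k]` is the indexer score of key `k` for query `q`,
    and query `q` is assumed to sit at absolute position `q` (prefill, S_q == S_k).
    """
    num_keys = len(scores[0])
    num_blocks = math.ceil(num_keys / block_size)
    bias = []
    for q_pos, row in enumerate(scores):
        # 1) Max-pool the per-key scores into one score per block of `block_size` keys.
        block_scores = [max(row[b * block_size:(b + 1) * block_size]) for b in range(num_blocks)]
        # 2) Block-level causality: only blocks up to the query's own block are eligible.
        q_block = q_pos // block_size
        causal_blocks = list(range(q_block + 1))
        # 3) The most recent `local_blocks` blocks are always kept (assumed policy).
        keep = set(causal_blocks[-local_blocks:]) if local_blocks > 0 else set()
        # 4) Add the `topk_blocks` highest-scoring blocks among the remaining eligible ones.
        rest = sorted(
            (b for b in causal_blocks if b not in keep),
            key=lambda b: block_scores[b],
            reverse=True,
        )
        keep.update(rest[:topk_blocks])
        # 5) Broadcast the per-block choice back onto every key of the block.
        bias.append([0.0 if k // block_size in keep else -math.inf for k in range(num_keys)])
    return bias


# Toy input: 8 keys, key 1 scores highest for every query.
scores = [[9.0 if k == 1 else 0.1 * k for k in range(8)] for _ in range(8)]
bias = block_sparse_bias(scores, block_size=2, topk_blocks=1, local_blocks=1)
visible_for_last_query = [k for k, b in enumerate(bias[7]) if b == 0.0]
print(visible_for_last_query)
```

Note that the bias is block-granular: for query 0 it still marks key 1 as visible because both keys share block 0. The per-key causal mask, onto which this bias is summed, is what hides key 1 from query 0.
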
+In principle, attention then only needs to be computed over the selected blocks of keys, but `transformers` itself does not ship kernels that compute this efficiently. +Support is being added to `kernels`. + +MiniMax M3 Lightning Indexer mask + + +### Vision tower + +The vision tower is a [`MiniMaxM3VLVisionModel`]: a `Conv3d` patch embedding over flattened `[N_patches, C·T·P·P]` input, followed by a stack of +CLIP-style encoder layers carrying a **3D rotary** position embedding (time / height / width bands). A [`MiniMaxM3VLPatchMerger`] groups +`spatial_merge_size²` patches into the channel dim before the 2-layer GELU [`MiniMaxM3VLMultiModalProjector`] maps vision features into the text hidden size. + +## Usage examples + +The example below runs the model on a real image loaded with [`~transformers.image_utils.load_image`]. + +```python +import torch +from transformers import AutoModelForImageTextToText, AutoProcessor +from transformers.image_utils import load_image + + +model = AutoModelForImageTextToText.from_pretrained( + "MiniMaxAI/MiniMax-M3-preview", dtype=torch.bfloat16, device_map="auto", +) +processor = AutoProcessor.from_pretrained("MiniMaxAI/MiniMax-M3-preview") + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg") +messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Describe this image briefly."}, + ], + } +] +text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) +inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device) + +generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False) +print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) +``` + +### Apple example + +This example asks the model about an image of apples, again loading a real image with +[`~transformers.image_utils.load_image`]. 
+ +```python +import torch +from transformers import AutoModelForImageTextToText, AutoProcessor +from transformers.image_utils import load_image + + +model = AutoModelForImageTextToText.from_pretrained( + "MiniMaxAI/MiniMax-M3-preview", dtype=torch.bfloat16, device_map="auto", +) +processor = AutoProcessor.from_pretrained("MiniMaxAI/MiniMax-M3-preview") + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg") +messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "How many apples are in this image, and what color are they?"}, + ], + } +] +text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) +inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device) + +generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False) +print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) +``` + +## Fastest inference configuration + +| ctx | SDPA decode | MSA decode | MSA decode advantage | SDPA prefill | MSA prefill | MSA prefill advantage | +| --: | ----------: | ---------: | --------------: | -----------: | ----------: | ---------------: | +| 2K | 27.8 tok/s | 31.0 tok/s | +12% | 303 ms | 257 ms | 1.18× | +| 4K | 23.4 tok/s | 30.5 tok/s | +30% | 684 ms | 460 ms | 1.49× | +| 8K | 17.8 tok/s | 29.6 tok/s | +66% | 1906 ms | 976 ms | 1.95× | +| 16K | 12.0 tok/s | 27.6 tok/s | +130% | 6110 ms | 2344 ms | 2.61× | + +The checkpoint ships in native MXFP8. For **decode throughput**, the fastest validated configuration is +**bf16 (dequantized at load) + the MSA block-sparse attention kernel + tensor & expert parallelism + a +`reduce-overhead` cudagraph compile**, which reaches roughly **31 tok/s** decode on 8×B200 at a 2048-token prefill. + +Keeping the weights in **native FP8 is a memory-footprint option only; it is never faster on this setup**. 
+The FP8 Triton experts/linear kernels lower as opaque inductor fallback kernels that cudagraph cannot +capture on the hot expert path, so native-FP8 decode measured ~4.2 tok/s (≈7× slower than the bf16 path) +even under `torch.compile(fullgraph=True)`. Use FP8 only when the bf16 weights do not fit. + +| config (sdpa baseline, TP+EP, 2048-token prefill, 8×B200) | decode | +|---|---| +| bf16 dequantize-at-load + **MSA** + compile/cudagraph | **~31 tok/s** | +| bf16 dequantize-at-load + sdpa + compile/cudagraph | ~28 tok/s | +| native FP8 + compile/cudagraph | ~4 tok/s (memory-only, not for speed) | + +Dequantizing to bf16 only fits with even sharding across GPUs (TP/EP), not with `device_map="auto"` +(pipeline placement OOMs at load). Launch one process per GPU with `torchrun`: + +```bash +torchrun --nproc_per_node=8 fastest_m3_vl.py +``` + +```python +# fastest_m3_vl.py +import os, sys +import torch +import torch.distributed as dist +from transformers import ( + AutoModelForImageTextToText, + AutoTokenizer, + CompileConfig, + FineGrainedFP8Config, +) +from transformers.distributed import DistributedConfig + +# The indexer feeds SDPA an additive float mask; the cuDNN SDP backend segfaults on it (B200). +torch.backends.cuda.enable_cudnn_sdp(False) + +model = AutoModelForImageTextToText.from_pretrained( + "MiniMaxAI/MiniMax-M3-preview", + dtype=torch.bfloat16, + # Dequantize the native MXFP8 weights to bf16 at load (the speed win); needs even TP/EP sharding. 
+ quantization_config=FineGrainedFP8Config(dequantize=True), + tp_plan="auto", + distributed_config=DistributedConfig(enable_expert_parallel=True), + attn_implementation="kernels-staging/msa@v0", # MSA block-sparse attention kernel +) +model.eval() + +tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M3-preview") +messages = [{"role": "user", "content": "Summarize the history of computing."}] +inputs = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, return_tensors="pt", return_dict=True +).to(f"cuda:{os.environ.get('LOCAL_RANK', '0')}") + +generated_ids = model.generate( + **inputs, + max_new_tokens=128, + do_sample=False, + # Static cache + reduce-overhead cudagraph capture is what pushes decode to ~31 tok/s. + cache_implementation="static", + compile_config=CompileConfig(mode="reduce-overhead", fullgraph=True), +) +if int(os.environ.get("RANK", "0")) == 0: + print(tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)) + +# cudagraph-captured NCCL collectives deadlock the NCCL/CUDA destructors at teardown; the output is +# already produced, so hard-exit to skip the hanging cleanup. +if dist.is_initialized(): + sys.stdout.flush() + os._exit(0) +``` + +## MiniMaxM3VLConfig + +[[autodoc]] MiniMaxM3VLConfig + +## MiniMaxM3VLTextConfig + +[[autodoc]] MiniMaxM3VLTextConfig + +## MiniMaxM3VLVisionConfig + +[[autodoc]] MiniMaxM3VLVisionConfig + +## MiniMaxM3VLProcessor + +[[autodoc]] MiniMaxM3VLProcessor + +## MiniMaxM3VLImageProcessor + +This is a standalone (non-modular) image processor: it shares the patch-flattening idea of [`Qwen2VLImageProcessor`] +but does not inherit from it because the two diverge in ways that touch most of the class. 
The resize budget is driven by +a `max_pixels` attribute and a `{"height", "width"}` `size` rather than Qwen's `shortest_edge`/`longest_edge` scheme; the +`smart_resize` helper clamps the initial rounding with `max(factor, ...)`; and `_preprocess` performs real temporal +handling (5D patches, last-frame repeat to fill `temporal_patch_size`, and a `grid_t` dimension) instead of Qwen's +`grid_t = 1` + expand. Mapping to or subclassing Qwen would therefore change behavior or require overriding nearly +everything, so the processor is kept on its own. + +[[autodoc]] MiniMaxM3VLImageProcessor + +## MiniMaxM3VLVideoProcessor + +[[autodoc]] MiniMaxM3VLVideoProcessor + +## MiniMaxM3VLVisionModel + +[[autodoc]] MiniMaxM3VLVisionModel + - forward + +## MiniMaxM3VLTextModel + +[[autodoc]] MiniMaxM3VLTextModel + - forward + +## MiniMaxM3VLModel + +[[autodoc]] MiniMaxM3VLModel + - forward + +## MiniMaxM3VLForCausalLM + +[[autodoc]] MiniMaxM3VLForCausalLM + - forward + +## MiniMaxM3SparseForConditionalGeneration + +[[autodoc]] MiniMaxM3SparseForConditionalGeneration + - forward diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4801b37a8558..46f4620bdd1e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -33,6 +33,12 @@ # ``PreTrainedConfig`` (the decoder text config) as the only positional argument. LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {} +# Parallel registry for the *static* implementation of a custom layer type, consulted by +# ``StaticCache``. A ``StaticLayer`` subclass with a ``layer_type`` registers here instead of in +# ``LAYER_TYPE_CACHE_MAPPING`` (constructed with ``max_cache_len=...``), so a single ``layer_type`` +# can have both a dynamic and a static cache layer (e.g. M3's sparse-attention indexer cache). 
+LAYER_TYPE_STATIC_CACHE_MAPPING: dict[str, type] = {} + class CacheLayerMixin(ABC): """Base, abstract class for a single layer's cache.""" @@ -47,7 +53,11 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) layer_type = cls.__dict__.get("layer_type", None) if layer_type is not None: - LAYER_TYPE_CACHE_MAPPING[layer_type] = cls + static_base = globals().get("StaticLayer") + if static_base is not None and issubclass(cls, static_base): + LAYER_TYPE_STATIC_CACHE_MAPPING[layer_type] = cls + else: + LAYER_TYPE_CACHE_MAPPING[layer_type] = cls def __init__(self): self.keys: torch.Tensor | None = None @@ -1578,6 +1588,9 @@ def __init__( # LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache elif layer_type in ("mamba", "conv", "linear_attention", "moe"): layer = LinearAttentionLayer() + # Custom layer types (e.g. M3's sparse-attention indexer cache) that registered a static variant. + elif layer_type in LAYER_TYPE_STATIC_CACHE_MAPPING: + layer = LAYER_TYPE_STATIC_CACHE_MAPPING[layer_type](max_cache_len=max_cache_len) elif layer_type == "deepseek_sparse_attention": # Static / compile-friendly indexed layer (preallocated indexer key cache). 
layer = StaticIndexedLayer(max_cache_len=max_cache_len) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f902994eb71b..2b6fdce06e28 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -65,6 +65,7 @@ "chunked_attention", "compressed_sparse_attention", # CSA, used in deepseek_v4 "heavily_compressed_attention", # HCA, used in deepseek_v4 + "minimax_m3_sparse", # lightning-index sparse attention, used in minimax_m3_vl "linear_attention", # used in minimax "conv", # used in LFMv2 "mamba", diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index fe1a66b7a210..a0cff222273a 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -455,6 +455,115 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "minimax_m3_vl": [ + # Ordering matters for save round-tripping: the reverse mapping flips the order *and* each + # transform (see deepseek_v4 above). We therefore split into two passes: structural prefix + # renames first (so they apply last on save / first on load), then specific in-prefix renames + # that operate on the already-prefixed keys. Every target prefix here is distinct and anchored, + # so no reversed source pattern is broad enough to steal keys from another namespace. 
+ # ---- Pass 1: top-level + structural prefix renames ---- + WeightRenaming(source_patterns=r"^language_model\.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^language_model\.model\.", target_patterns="model.language_model."), + # The vision tower flattens CLIP's `vision_model.{encoder.layers,embeddings.patch_embedding, + # pre_layrnorm}` nesting onto `vision_tower.{layers,embeddings.proj,pre_layrnorm}`. Each rule is + # anchored and leaf-specific so its reverse re-inserts `vision_model` only on the right keys (a + # blanket `.vision_model.` -> `.` rule reverses to "match any char" and mangles every key). + WeightRenaming( + source_patterns=r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.", + target_patterns="model.vision_tower.embeddings.proj.", + ), + WeightRenaming( + source_patterns=r"^vision_tower\.vision_model\.encoder\.layers\.", + target_patterns="model.vision_tower.layers.", + ), + WeightRenaming( + source_patterns=r"^vision_tower\.vision_model\.pre_layrnorm\.", + target_patterns="model.vision_tower.pre_layrnorm.", + ), + # The projector hosts both the upstream `multi_modal_projector.linear_{1,2}` and the + # `patch_merge_mlp.linear_{1,2}` (registered as `merge_linear_{1,2}`). Spell each leaf out so the + # reversed `linear_*` source never also matches `merge_linear_*` (or vice versa). 
+ WeightRenaming( + source_patterns=r"^multi_modal_projector\.linear_1\.", + target_patterns="model.multi_modal_projector.linear_1.", + ), + WeightRenaming( + source_patterns=r"^multi_modal_projector\.linear_2\.", + target_patterns="model.multi_modal_projector.linear_2.", + ), + WeightRenaming( + source_patterns=r"^patch_merge_mlp\.linear_1\.", + target_patterns="model.multi_modal_projector.merge_linear_1.", + ), + WeightRenaming( + source_patterns=r"^patch_merge_mlp\.linear_2\.", + target_patterns="model.multi_modal_projector.merge_linear_2.", + ), + # ---- Pass 2: specific in-prefix renames (operate on already-prefixed keys) ---- + # MoE layers rename `block_sparse_moe.*` -> `mlp.*`, but dense layers already use `mlp.*`. A blanket + # `block_sparse_moe.` -> `mlp.` rule reverses to `mlp.` -> `block_sparse_moe.` and corrupts the dense + # layers on save, so we rename per MoE leaf (`experts` / `shared_experts` / `gate` / e-score). The + # reversed `mlp.experts.` / `mlp.shared_experts.` / `mlp.gate.weight` sources match only MoE + # sublayers, never the dense `mlp.gate_up_proj` / `mlp.down_proj`. 
+ WeightRenaming( + source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.experts\.", + target_patterns=r".language_model.layers.\1.mlp.experts.", + ), + WeightRenaming( + source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.shared_experts\.", + target_patterns=r".language_model.layers.\1.mlp.shared_experts.", + ), + WeightRenaming( + source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.gate\.weight", + target_patterns=r".language_model.layers.\1.mlp.gate.weight", + ), + WeightRenaming( + source_patterns=r"\.language_model\.layers\.(\d+)\.block_sparse_moe\.e_score_correction_bias", + target_patterns=r".language_model.layers.\1.mlp.gate.e_score_correction_bias", + ), + WeightRenaming( + source_patterns=r"\.self_attn\.index_q_proj\.", + target_patterns=".self_attn.indexer.q_proj.", + ), + WeightRenaming( + source_patterns=r"\.self_attn\.index_k_proj\.", + target_patterns=".self_attn.indexer.k_proj.", + ), + WeightRenaming( + source_patterns=r"\.self_attn\.index_q_norm\.", + target_patterns=".self_attn.indexer.q_norm.", + ), + WeightRenaming( + source_patterns=r"\.self_attn\.index_k_norm\.", + target_patterns=".self_attn.indexer.k_norm.", + ), + WeightConverter( + source_patterns=[ + "mlp.experts.*.w1.weight", + "mlp.experts.*.w3.weight", + ], + target_patterns="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="mlp.experts.*.w2.weight", + target_patterns="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + WeightConverter( + source_patterns=["mlp.gate_proj.weight", "mlp.up_proj.weight"], + target_patterns="mlp.gate_up_proj.weight", + operations=[Concatenate(dim=0)], + ), + WeightConverter( + source_patterns=[ + "mlp.shared_experts.gate_proj.weight", + "mlp.shared_experts.up_proj.weight", + ], + target_patterns="mlp.shared_experts.gate_up_proj.weight", + operations=[Concatenate(dim=0)], + ), + ], "qwen2_audio": [ 
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 2493e1560a36..0d57b0b8d23c 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1429,6 +1429,10 @@ def convert_and_load_state_dict_in_model( mapping.distributed_operation = tp_layer( device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone() ) + # Per-expert sharding (EP) needs `tensor_idx` = the expert index so the + # distributed op selects whole experts. The signal is a `MergeModulelist` + # in the chain; it isn't always `operations[0]` (e.g. an FP8 quantizer + # prepends a scale-decode op), so scan the whole chain rather than just the head. shard_index = ( len(mapping.collected_tensors.get(source_pattern, [])) if isinstance(mapping, WeightConverter) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c58a3d154e9f..03c59d7b2465 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1840,9 +1840,13 @@ def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> b "rwkv", "xlstm", ) - # name clash between minimax and minimax m2, so we add this "or" - return "minimaxm2" in cls.__name__.lower() or all( - unsupported_name not in cls.__name__.lower() for unsupported_name in unsupported_model_names + # The "minimax" exclusion targets the original linear-attention MiniMax (custom cache); the later + # MiniMax M2 / M3 are standard attention models that use the regular Dynamic/Static caches. 
+ name = cls.__name__.lower() + return ( + "minimaxm2" in name + or "minimaxm3" in name + or all(unsupported_name not in name for unsupported_name in unsupported_model_names) ) def _prepare_cache_for_generation( @@ -1904,13 +1908,7 @@ def _prepare_cache_for_generation( generation_config.cache_implementation = "dynamic_full" dynamic_cache_kwargs = {} - # linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers - is_linear_attention = any( - x in ("mamba", "conv", "linear_attention") - for x in (getattr(self.config.get_text_config(decoder=True), "layer_types", []) or []) - ) - if generation_config.cache_implementation != "dynamic_full" or is_linear_attention: - dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True) + dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True) if generation_config.cache_implementation == "offloaded": dynamic_cache_kwargs["offloading"] = True diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index d8a43c97c193..2bbda2d23025 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -593,6 +593,8 @@ def __init__( self.activation_scheme = activation_scheme self.num_experts = _first_attr(config, "num_local_experts", "num_experts") self.intermediate_dim = _first_attr(config, "moe_intermediate_size", "intermediate_size") + self.swiglu_alpha = getattr(config, "swiglu_alpha", None) + self.swiglu_limit = getattr(config, "swiglu_limit", None) self.act_fn = ACT2FN[_first_attr(config, "hidden_activation", "hidden_act")] self.limit = getattr(config, "swiglu_limit", None) @@ -640,7 +642,13 @@ def __init__( def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: gate, up = gate_up.chunk(2, dim=-1) - if self.limit is not None: + if self.swiglu_alpha is not None: + # Clamped SwiGLU-OAI gate (same math as the model's 
non-quantized experts). + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + glu = gate * torch.sigmoid(gate * self.swiglu_alpha) + return (up + 1.0) * glu + elif self.limit is not None: gate = gate.clamp(max=self.limit) up = up.clamp(min=-self.limit, max=self.limit) return self.act_fn(gate) * up @@ -958,10 +966,16 @@ def _dequantize_one( output_dtype = ( scales.dtype if scales.dtype.is_floating_point and scales.element_size() >= 2 else torch.bfloat16 ) - + # MXFP8 checkpoints ship E8M0 exponents stored as ``torch.uint8`` (one byte per + # block) — the actual scale is `2 ** (byte - 127)`. Interpreting the raw bytes + # as scalar multipliers would be silently wrong, so unpack to fp32 here. + if scales.dtype == torch.uint8: + s_fp32 = (scales.to(torch.float32) - 127.0).exp2() + else: + s_fp32 = scales.to(torch.float32) original_shape = quantized_fp32.shape q = quantized_fp32.reshape(-1, scale_rows, block_m, scale_cols, block_n) - s = scales.to(torch.float32).reshape(-1, scale_rows, scale_cols).unsqueeze(-1).unsqueeze(2) + s = s_fp32.reshape(-1, scale_rows, scale_cols).unsqueeze(-1).unsqueeze(2) return (q * s).to(output_dtype).reshape(original_shape) def _get_target_dtype(self, model: torch.nn.Module | None, full_layer_name: str | None) -> torch.dtype | None: @@ -1023,3 +1037,29 @@ def reverse_op(self) -> ConversionOps: # checkpoint preserves the FP8 format (weight + per-block ``weight_scale_inv``) # whether the in-memory state stayed quantized or was dequantized for compute. return Fp8Quantize(self.hf_quantizer) + + +class Fp8DecodeScale(ConversionOps): + """Decode MXFP8 ``ue8m0`` per-block scales (stored as ``uint8`` exponents) into the + float32 multiplicative scales the FP8 compute path expects. + + Native MXFP8 loading (``dequantize=False``) keeps weights in ``float8_e4m3fn`` and only + needs the sibling ``*.weight_scale_inv`` tensors turned from raw E8M0 bytes into real + scales (``2 ** (byte - 127)``). 
Prepended to each weight converter, this op runs before + any merge/concat collapses the per-expert structure: it rewrites only the ``uint8`` scale + entries and passes weights (and already-float scales) through untouched. + """ + + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + + @staticmethod + def _decode(tensor: torch.Tensor) -> torch.Tensor: + # E8M0 stores one exponent byte per block; the real scale is ``2 ** (byte - 127)``. + return (tensor.to(torch.float32) - 127.0).exp2() if tensor.dtype == torch.uint8 else tensor + + def convert(self, input_dict: dict[str, list[torch.Tensor] | torch.Tensor], **kwargs): + return { + key: [self._decode(t) for t in value] if isinstance(value, list) else self._decode(value) + for key, value in input_dict.items() + } diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 9ae52963c245..8e7bbbc24ab9 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -370,16 +370,26 @@ def load_and_register_attn_kernel( except Exception as e: raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.") # correctly wrap the kernel + mask_implementation = "flash_attention_2" if hasattr(kernel, "flash_attn_varlen_func"): if attention_wrapper is None: attention_wrapper = flash_attention_forward kernel_function = attention_wrapper + elif hasattr(kernel, "sparse_atten_func"): + # Block-sparse kernels (e.g. `kernels-staging/msa`) expose `sparse_atten_func` instead of + # `flash_attn_varlen_func`; their call contract differs from the attention interface, so we + # bind the dedicated transformers-side wrapper that adapts the arguments and hides the + # prefill-kernel / decode-fallback dispatch. 
+ from .msa_attention import msa_attention_forward + + kernel_function = attention_wrapper if attention_wrapper is not None else msa_attention_forward + mask_implementation = "sdpa" elif kernel_name is not None: kernel_function = getattr(kernel, kernel_name) # Register the kernel as a valid attention ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) - ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) + ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_implementation]) return kernel diff --git a/src/transformers/integrations/msa_attention.py b/src/transformers/integrations/msa_attention.py new file mode 100644 index 000000000000..87a479b3d9a4 --- /dev/null +++ b/src/transformers/integrations/msa_attention.py @@ -0,0 +1,259 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from ..utils import logging +from .sdpa_attention import sdpa_attention_forward + + +logger = logging.get_logger(__name__) + +# `sparse_atten_func` only compiles for these per-query block counts. +MSA_SUPPORTED_TOPK = (4, 8, 16, 32) +# `SparseK2qCsrBuilderSm100` only supports a 128-key block. +MSA_SUPPORTED_BLOCK_SIZE = 128 +# SM100 / Blackwell head_dim 128 kernel. 
+MSA_SUPPORTED_HEAD_DIM = 128 + +_MSA_KERNEL = None + + +def load_and_register_msa_kernel(attn_implementation: str): + """Load the MSA hub kernel once and verify the expected callables are present. + + The ``attn_implementation`` string may carry a ``paged|`` prefix and/or an ``@`` pin + (e.g. ``kernels-staging/msa@v0``); the build currently lives on the repo's ``v0`` branch. The + loaded module is cached in a module-level global so registration happens once, not per call. + """ + global _MSA_KERNEL + if _MSA_KERNEL is not None: + return _MSA_KERNEL + + from .hub_kernels import get_kernel + + repo_id = attn_implementation.split("|")[-1] + repo_id, _, rev = repo_id.partition("@") + kernel = get_kernel(repo_id, revision=rev or None, version=None if rev else 0, allow_all_kernels=True) + + for fn_name in ("sparse_atten_func", "build_k2q_csr"): + if not callable(getattr(kernel, fn_name, None)): + raise ImportError( + f"The MSA kernel loaded from `{repo_id}` does not expose a callable `{fn_name}`. " + "Make sure you request a compatible build, e.g. `kernels-staging/msa@v0`." + ) + + _MSA_KERNEL = kernel + return _MSA_KERNEL + + +@torch.library.custom_op("transformers_msa::sparse_atten", mutates_args=()) +def _msa_sparse_atten_op( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + topk: int, + block_size: int, + total_k: int, + max_seqlen_q: int, + max_seqlen_k: int, + qheads_per_kv: int, + scaling: float, + impl: str, +) -> torch.Tensor: + """Opaque wrapper around the CuTe-DSL CSR build + block-sparse kernel. + + Registered as a ``torch.library`` custom op so ``torch.compile(fullgraph=True)`` treats the + whole CSR-build + attention as a single opaque node (no graph break) and ``reduce-overhead`` + CUDA graphs can capture it. 
The internal ``build_k2q_csr`` output is data-dependent in shape, + but it never escapes this op (only the fixed-shape ``[total_q, Hq, D]`` attention output does), + so the fake/meta impl below is exact. The op is functional (no input mutation). + """ + msa = load_and_register_msa_kernel(impl) + # CuTe-DSL kernel launches on the ambient ``current_device`` with no internal guard; pin context + # to the tensors' device so device_map (mixed-GPU) layouts don't reference the wrong context. + with torch.cuda.device(q.device): + k2q_row_ptr, k2q_q_indices = msa.build_k2q_csr( + q2k, + cu_seqlens_q, + cu_seqlens_k, + block_size, + total_k=total_k, + max_seqlen_k=max_seqlen_k, + max_seqlen_q=max_seqlen_q, + qhead_per_kv=qheads_per_kv, + ) + attn_output = msa.sparse_atten_func( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + topk, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + blk_kv=block_size, + causal=True, + softmax_scale=scaling, + ) + return attn_output.contiguous() + + +@_msa_sparse_atten_op.register_fake +def _msa_sparse_atten_fake( + q, + k, + v, + q2k, + cu_seqlens_q, + cu_seqlens_k, + topk, + block_size, + total_k, + max_seqlen_q, + max_seqlen_k, + qheads_per_kv, + scaling, + impl, +): + # Output matches the query varlen layout [total_q, Hq, head_dim]. + return torch.empty_like(q) + + +def _validate_msa_init(module, query: torch.Tensor, dropout: float) -> None: + """Validate kernel capability, dropout and configured topk once per attention module. + + Mirrors the flash-attention integration, which checks capability/dropout at model init rather + than on every forward. The check is cached on the module so the hot path never re-runs it. + + There is no SDPA fallback: a sparse layer either runs the MSA kernel or this raises. Serves both + prefill (q_len > 1) and single-token decode (q_len == 1) -- decode is just a varlen call with one + query slot, so there is no context-length threshold. 
+ """ + if query.device.type != "cuda" or torch.cuda.get_device_capability(query.device)[0] != 10: + raise RuntimeError( + "MSA block-sparse attention requires an SM100 / Blackwell CUDA device. " + "Select a different `attn_implementation` on unsupported hardware." + ) + if query.shape[-1] != MSA_SUPPORTED_HEAD_DIM: + raise ValueError(f"MSA block-sparse attention only supports head_dim {MSA_SUPPORTED_HEAD_DIM}.") + if module.indexer.block_size != MSA_SUPPORTED_BLOCK_SIZE: + raise ValueError(f"MSA block-sparse attention only supports block_size {MSA_SUPPORTED_BLOCK_SIZE}.") + if dropout != 0.0: + raise ValueError("MSA block-sparse attention does not support attention dropout; set `attention_dropout=0`.") + topk = module.indexer.topk_blocks + if topk not in MSA_SUPPORTED_TOPK: + raise ValueError( + f"MSA block-sparse attention only supports topk in {MSA_SUPPORTED_TOPK}, got `{topk}`. " + "Set `index_topk_blocks` to a supported value." + ) + + +def _sparse_attention(module, query, key, value, scaling, block_indices, block_size, cache_position): + bsz, num_q_heads, q_len, head_dim = query.shape + num_kv_heads, k_len = key.shape[1], key.shape[2] + qheads_per_kv = num_q_heads // num_kv_heads + topk = block_indices.shape[-1] + + # The indexer emits `min(index_topk_blocks, num_key_blocks)` selected blocks, so on sequences with + # fewer key blocks than the configured budget the width lands on an arbitrary value (e.g. 12) that the + # `SparseK2qCsrBuilderSm100` CSR builder rejects -- it only accepts a CSR width in `MSA_SUPPORTED_TOPK`. + # Right-pad the selection up to the next supported width with `-1`, the same empty-slot sentinel the + # kernel already skips, so behaviour is unchanged and the width is always one the builder accepts. The + # width is a Python int (static under `torch.compile`), so this stays fullgraph / cudagraph stable. 
+ padded_topk = next(t for t in MSA_SUPPORTED_TOPK if t >= topk) + if padded_topk != topk: + pad = block_indices.new_full((*block_indices.shape[:-1], padded_topk - topk), -1) + block_indices = torch.cat([block_indices, pad], dim=-1) + topk = padded_topk + + # Flatten the batch dim into a packed varlen layout [total, H, head_dim] + cu_seqlens. The + # query boundary is a fixed stride (every row is `q_len` long), built device-side with no host + # sync so it stays compile/cudagraph stable. + q = query.transpose(1, 2).reshape(bsz * q_len, num_q_heads, head_dim).contiguous() + k = key.transpose(1, 2).reshape(bsz * k_len, num_kv_heads, head_dim).contiguous() + v = value.transpose(1, 2).reshape(bsz * k_len, num_kv_heads, head_dim).contiguous() + cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, q_len, device=q.device, dtype=torch.int32) + # Under a StaticCache `k_len` is the pre-allocated buffer length (`max_cache_len`), not the valid length, + # so a packed `[0, k_len, 2*k_len, ...]` boundary would let the causal kernel attend to zero-padded + # future slots. For bsz==1 the real boundary is `[0, valid_k]` (valid_k = cache_position[-1]+1) built + # as a device tensor (no host sync) -- `build_k2q_csr` reads only the host shape hints `total_k`/ + # `max_seqlen_k` (kept fixed at `bsz*k_len`/`k_len` below), so this stays compile/cudagraph stable. + if bsz == 1 and cache_position is not None: + valid_k = (cache_position[-1] + 1).to(torch.int32).reshape(1) + cu_seqlens_k = torch.cat([torch.zeros(1, device=q.device, dtype=torch.int32), valid_k]) + else: + cu_seqlens_k = torch.arange(0, (bsz + 1) * k_len, k_len, device=q.device, dtype=torch.int32) + + q2k = block_indices.to(torch.int32) + q2k = q2k.reshape(bsz * q_len, topk).unsqueeze(0).expand(num_kv_heads, -1, -1).contiguous() + + # Opaque custom op: keeps the CuTe-DSL CSR build + block-sparse kernel as a single graph node + # so ``torch.compile(fullgraph=True)`` doesn't break and ``reduce-overhead`` CUDA graphs capture it. 
+ attn_output = _msa_sparse_atten_op( + q, + k, + v, + q2k, + cu_seqlens_q, + cu_seqlens_k, + topk, + block_size, + bsz * k_len, + q_len, + k_len, + qheads_per_kv, + scaling, + module.config._attn_implementation, + ) + return attn_output.reshape(bsz, q_len, num_q_heads, head_dim) + + +def msa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None = None, + dropout: float = 0.0, + scaling: float | None = None, + block_indices: torch.Tensor | None = None, + **kwargs, +) -> tuple[torch.Tensor, None]: + """ + TODO: this opens the door to per-layer attention implementations, which is something we might want later on. + """ + if scaling is None: + scaling = query.shape[-1] ** -0.5 + + # No block selection (the dense vision tower, or full-attention layers without an indexer) -> plain SDPA. + if block_indices is None: + return sdpa_attention_forward( + module, query, key, value, attention_mask, dropout=dropout, scaling=scaling, **kwargs + ) + + # A sparse layer always runs the MSA kernel -- there is no SDPA fallback. Capability/config is + # validated once per module (raises on unsupported hardware or config) and cached on the module. 
+ if not getattr(module, "_msa_validated", False): + _validate_msa_init(module, query, dropout) + module._msa_validated = True + + block_size = module.indexer.block_size + cache_position = kwargs.get("cache_position") + attn_output = _sparse_attention(module, query, key, value, scaling, block_indices, block_size, cache_position) + return attn_output, None diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 4a0022abac5f..eca20baa8918 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -1458,6 +1458,7 @@ def create_chunked_causal_mask( "chunked_attention": create_chunked_causal_mask, "compressed_sparse_attention": create_sliding_window_causal_mask, "heavily_compressed_attention": create_sliding_window_causal_mask, + "minimax_m3_sparse": create_causal_mask, "deepseek_sparse_attention": create_causal_mask, } diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index c5958f77bb23..567f03a3bba7 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -190,6 +190,12 @@ def _lazy_imports( flash_attn_func = getattr(kernel, "flash_attn_func", None) flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func", None) flash_attn_with_kvcache = getattr(kernel, "flash_attn_with_kvcache", None) + # Block-sparse kernels (e.g. ``kernels-staging/msa``) expose ``sparse_atten_func`` rather than + # ``flash_attn_varlen_func``. ``load_and_register_attn_kernel`` already registered their dedicated + # wrapper into ``ALL_ATTENTION_FUNCTIONS``, so they dispatch through the attention interface and + # never touch the flash varlen globals -- preloading them here is a no-op, not an error. 
+ if flash_attn_varlen_func is None and hasattr(kernel, "sparse_atten_func"): + return flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache, pad_input, unpad_input if flash_attn_varlen_func is None: raise ValueError( f"Could not find the currently requested flash attention implementation at `{implementation}`." @@ -255,7 +261,9 @@ def lazy_import_flash_attention( _flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn = _lazy_imports( implementation, attention_wrapper, allow_all_kernels=allow_all_kernels ) - _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) + # Block-sparse kernels register their own attention interface and expose no varlen fn to introspect; + # skip building the kwargs-support map (it is only consumed by the flash varlen path they never take). + _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) if _flash_varlen_fn else None return (_flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f511cfb1aacb..1b8c93328f43 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1911,6 +1911,11 @@ def _check_and_adjust_attn_implementation( """ is_paged, base_implementation = split_attention_implementation(attn_implementation) + # A kernel the model explicitly lists in `_compatible_flash_implementations` is vouched for by + # the model author, so authorize loading it even when it lives outside the `kernels-community` org. 
+ if base_implementation in (getattr(self, "_compatible_flash_implementations", None) or []): + allow_all_kernels = True + # Auto-correct model's default flash implementation if specified if attn_implementation is not None: compatible_flash_implementations = getattr(self, "_compatible_flash_implementations", None) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e5d880b9a46b..025c2ac7c0e1 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -267,6 +267,7 @@ from .minicpmv4_6 import * from .minimax import * from .minimax_m2 import * + from .minimax_m3_vl import * from .ministral import * from .ministral3 import * from .mistral import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index e3edc3f7f2e7..fe1212f230c1 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -362,6 +362,9 @@ ("minicpmv4_6_vision", "MiniCPMV4_6VisionConfig"), ("minimax", "MiniMaxConfig"), ("minimax_m2", "MiniMaxM2Config"), + ("minimax_m3_vl", "MiniMaxM3VLConfig"), + ("minimax_m3_vl_text", "MiniMaxM3VLTextConfig"), + ("minimax_m3_vl_vision", "MiniMaxM3VLVisionConfig"), ("ministral", "MinistralConfig"), ("ministral3", "Ministral3Config"), ("mistral", "MistralConfig"), @@ -791,6 +794,8 @@ ("metaclip_2_vision_model", "metaclip_2"), ("mgp-str", "mgp_str"), ("minicpmv4_6_vision", "minicpmv4_6"), + ("minimax_m3_vl_text", "minimax_m3_vl"), + ("minimax_m3_vl_vision", "minimax_m3_vl"), ("mlcd_vision_model", "mlcd"), ("mllama_text_model", "mllama"), ("mllama_vision_model", "mllama"), @@ -910,6 +915,7 @@ ("llava_next_video", "LlavaNextVideoVideoProcessor"), ("llava_onevision", "LlavaOnevisionVideoProcessor"), ("minicpmv4_6", "MiniCPMV4_6VideoProcessor"), + ("minimax_m3_vl", "MiniMaxM3VLVideoProcessor"), ("pe_video", "PeVideoVideoProcessor"), ("perception_lm", "PerceptionLMVideoProcessor"), ("qwen2_vl", 
"Qwen2VLVideoProcessor"), @@ -1025,6 +1031,7 @@ ("markuplm", "MarkupLMProcessor"), ("mgp-str", "MgpstrProcessor"), ("minicpmv4_6", "MiniCPMV4_6Processor"), + ("minimax_m3_vl", "MiniMaxM3VLProcessor"), ("mllama", "MllamaProcessor"), ("moonshine_streaming", "MoonshineStreamingProcessor"), ("musicflamingo", "MusicFlamingoProcessor"), @@ -1148,6 +1155,7 @@ ("mask2former", {"pil": "Mask2FormerImageProcessorPil", "torchvision": "Mask2FormerImageProcessor"}), ("maskformer", {"pil": "MaskFormerImageProcessorPil", "torchvision": "MaskFormerImageProcessor"}), ("minicpmv4_6", {"pil": "MiniCPMV4_6ImageProcessorPil", "torchvision": "MiniCPMV4_6ImageProcessor"}), + ("minimax_m3_vl", {"torchvision": "MiniMaxM3VLImageProcessor"}), ("mllama", {"pil": "MllamaImageProcessorPil", "torchvision": "MllamaImageProcessor"}), ("mobilenet_v1", {"pil": "MobileNetV1ImageProcessorPil", "torchvision": "MobileNetV1ImageProcessor"}), ("mobilenet_v2", {"pil": "MobileNetV2ImageProcessorPil", "torchvision": "MobileNetV2ImageProcessor"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3d7aeb3e7408..9dfc42d07cd0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -304,6 +304,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("minicpmv4_6", "MiniCPMV4_6Model"), ("minimax", "MiniMaxModel"), ("minimax_m2", "MiniMaxM2Model"), + ("minimax_m3_vl", "MiniMaxM3VLModel"), + ("minimax_m3_vl_text", "MiniMaxM3VLTextModel"), ("ministral", "MinistralModel"), ("ministral3", "Ministral3Model"), ("mistral", "MistralModel"), @@ -741,6 +743,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mellum", "MellumForCausalLM"), ("minimax", "MiniMaxForCausalLM"), ("minimax_m2", "MiniMaxM2ForCausalLM"), + ("minimax_m3_vl_text", "MiniMaxM3VLForCausalLM"), ("ministral", "MinistralForCausalLM"), ("ministral3", "Ministral3ForCausalLM"), ("mistral", "MistralForCausalLM"), 
@@ -1059,6 +1062,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("minicpmv4_6", "MiniCPMV4_6ForConditionalGeneration"), + ("minimax_m3_vl", "MiniMaxM3SparseForConditionalGeneration"), ("mistral3", "Mistral3ForConditionalGeneration"), ("mistral4", "Mistral4ForCausalLM"), ("mllama", "MllamaForConditionalGeneration"), diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py index 4ee9a9d53bf5..b0fc580df777 100644 --- a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py @@ -513,7 +513,9 @@ def forward( layer_idx: int, ) -> torch.LongTensor: batch, seq_len, _ = hidden_states.shape - cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + cache_layer: DeepseekV4CSACache | None = ( + past_key_values.layers[layer_idx] if past_key_values is not None else None + ) kv = self.kv_proj(hidden_states) gate = self.gate_proj(hidden_states) @@ -627,7 +629,9 @@ def forward( layer_idx: int, ) -> tuple[torch.Tensor, torch.Tensor]: batch, seq_len, _ = hidden_states.shape - cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + cache_layer: DeepseekV4CSACache | None = ( + past_key_values.layers[layer_idx] if past_key_values is not None else None + ) kv = self.kv_proj(hidden_states) gate = self.gate_proj(hidden_states) diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py index b5af24599f4b..089192d0e94f 100644 --- a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py +++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py @@ -449,7 +449,9 @@ def forward( layer_idx: int, ) -> 
torch.LongTensor: batch, seq_len, _ = hidden_states.shape - cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + cache_layer: DeepseekV4CSACache | None = ( + past_key_values.layers[layer_idx] if past_key_values is not None else None + ) kv = self.kv_proj(hidden_states) gate = self.gate_proj(hidden_states) @@ -563,7 +565,9 @@ def forward( layer_idx: int, ) -> tuple[torch.Tensor, torch.Tensor]: batch, seq_len, _ = hidden_states.shape - cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None + cache_layer: DeepseekV4CSACache | None = ( + past_key_values.layers[layer_idx] if past_key_values is not None else None + ) kv = self.kv_proj(hidden_states) gate = self.gate_proj(hidden_states) diff --git a/src/transformers/models/minimax_m3_vl/__init__.py b/src/transformers/models/minimax_m3_vl/__init__.py new file mode 100644 index 000000000000..9c8d2fbac88f --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_minimax_m3_vl import * + from .image_processing_minimax_m3_vl import * + from .modeling_minimax_m3_vl import * + from .processing_minimax_m3_vl import * + from .video_processing_minimax_m3_vl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py new file mode 100644 index 000000000000..86c0009389be --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py @@ -0,0 +1,226 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring +from ..auto import AutoConfig + + +@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview") +@strict +class MiniMaxM3VLTextConfig(PreTrainedConfig): + r""" + dense_intermediate_size (`int`, *optional*, defaults to 12288): + Intermediate size of the dense MLP used on layers whose `mlp_layer_types` entry is `"dense"`. + shared_intermediate_size (`int`, *optional*, defaults to 3072): + Intermediate size of a single shared expert in the MoE layers. + rotary_dim (`int`, *optional*, defaults to 64): + Number of head channels rotated by RoPE; the remaining channels are passed through unchanged. + swiglu_alpha (`float`, *optional*, defaults to 1.702): + Sigmoid gain of the SwiGLU-OAI activation. + swiglu_limit (`float`, *optional*, defaults to 7.0): + Clamp bound applied to the gate and up projections of the SwiGLU-OAI activation. + mlp_layer_types (`list[str]`, *optional*): + Per-layer MLP selector: `"sparse"` for a MoE block, `"dense"` for a dense MLP. + index_n_heads (`int`, *optional*, defaults to 4): + Number of heads in the lightning indexer's dot-product scoring branch. + index_head_dim (`int`, *optional*, defaults to 128): + Per-head channel dimension of the lightning indexer. + index_block_size (`int`, *optional*, defaults to 128): + Number of key tokens pooled into a single scored block. + index_topk_blocks (`int`, *optional*, defaults to 16): + Number of top-scoring key blocks each query may attend to. + index_local_blocks (`int`, *optional*, defaults to 1): + Number of key blocks immediately preceding the query always kept visible / attended to. 
+ """ + + model_type = "minimax_m3_vl_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise_gather_output", + "layers.*.self_attn.k_proj": "colwise_gather_output", + "layers.*.self_attn.v_proj": "colwise_gather_output", + "layers.*.self_attn.o_proj": "rowwise_split_input", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_experts": "num_local_experts", + } + default_theta = 5000000.0 + vocab_size: int = 200064 + + hidden_size: int = 6144 + intermediate_size: int = 3072 + num_hidden_layers: int = 60 + num_attention_heads: int = 64 + num_key_value_heads: int = 4 + head_dim: int = 128 + hidden_act: str = "silu" + max_position_embeddings: int = 524288 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-06 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int | None = 200034 + eos_token_id: int | list[int] | None = 200020 + tie_word_embeddings: bool = False + attention_dropout: float | int = 0.0 + num_experts_per_tok: int = 4 + num_local_experts: int = 128 + output_router_logits: bool = False + router_aux_loss_coef: float = 0.001 + router_jitter_noise: float = 0.0 + rope_parameters: RopeParameters | dict | None = None + base_config_key = "text_config" + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } + dense_intermediate_size: int = 12288 + 
shared_intermediate_size: int = 3072 + routed_scaling_factor: float = 2.0 + rotary_dim: int = 64 + swiglu_alpha: float = 1.702 + swiglu_limit: float = 7.0 + mlp_layer_types: list[str] | None = None + index_n_heads: int = 4 + index_head_dim: int = 128 + index_block_size: int = 128 + index_topk_blocks: int = 16 + index_local_blocks: int = 1 + layer_types: list[str] | None = None + + def __post_init__(self, **kwargs): + sparse_cfg = kwargs.pop("sparse_attention_config", None) or {} + moe_layer_freq = kwargs.pop("moe_layer_freq", None) + super().__post_init__(**kwargs) + # Checkpoint declares "swigluoai", but the gate is computed inline from swiglu_alpha/limit; hidden_act + # is only the pointwise fallback and must be a real ACT2FN key, so normalize it to silu. + self.hidden_act = "silu" + + for flat, legacy in { + "index_n_heads": "sparse_num_index_heads", + "index_head_dim": "sparse_index_dim", + "index_block_size": "sparse_block_size", + "index_topk_blocks": "sparse_topk_blocks", + "index_local_blocks": "sparse_local_block", + }.items(): + if legacy in sparse_cfg: + setattr(self, flat, sparse_cfg[legacy]) + + # `layer_types` is the canonical per-layer attention dispatch: it tells + # `DynamicCache(config=...)` which layers want the sparse cache and tells + # `MiniMaxM3VLAttention` which layers build a sparse Lightning Indexer. 
+ if self.layer_types is None and "sparse_attention_freq" in sparse_cfg: + self.layer_types = [ + "minimax_m3_sparse" if f else "full_attention" for f in sparse_cfg["sparse_attention_freq"] + ] + if self.layer_types is None: + self.layer_types = ["full_attention"] * self.num_hidden_layers + + # `mlp_layer_types` is the per-layer MLP dispatch read by `MiniMaxM3VLDecoderLayer`: + if self.mlp_layer_types is None and moe_layer_freq is not None: + self.mlp_layer_types = ["sparse" if f else "dense" for f in moe_layer_freq] + if self.mlp_layer_types is None: + self.mlp_layer_types = ["sparse"] * self.num_hidden_layers + + +@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview") +@strict +class MiniMaxM3VLVisionConfig(PreTrainedConfig): + r""" + rope_parameters (`RopeParameters`, *optional*): + Standard RoPE configuration for the vision tower's 3D rotary position embedding. + """ + + model_type = "minimax_m3_vl_vision" + base_config_key = "vision_config" + default_theta = 10000.0 + + hidden_size: int = 1280 + intermediate_size: int = 5120 + num_hidden_layers: int = 32 + num_attention_heads: int = 16 + num_channels: int = 3 + image_size: int = 2016 + patch_size: int = 14 + temporal_patch_size: int = 2 + spatial_merge_size: int = 2 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-05 + attention_dropout: float = 0.0 + rope_parameters: RopeParameters | dict | None = None + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview") +@strict +class MiniMaxM3VLConfig(PreTrainedConfig): + model_type = "minimax_m3_vl" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } + + vision_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + image_token_index: int = 200025 + video_token_index: int = 200026 + projector_hidden_size: int = 6144 + tie_word_embeddings: bool = 
False + + def __post_init__(self, **kwargs): + if isinstance(self.vision_config, dict): + self.vision_config.pop("model_type", None) + self.vision_config = MiniMaxM3VLVisionConfig(**self.vision_config) + elif self.vision_config is None: + self.vision_config = MiniMaxM3VLVisionConfig() + + if isinstance(self.text_config, dict): + self.text_config.pop("model_type", None) + self.text_config = MiniMaxM3VLTextConfig(**self.text_config) + elif self.text_config is None: + self.text_config = MiniMaxM3VLTextConfig() + + if not self.tie_word_embeddings and self.text_config.tie_word_embeddings: + self.tie_word_embeddings = self.text_config.tie_word_embeddings + + # Channel dim after grouping `spatial_merge_size**2` projected patches, consumed by the + # patch-merge MLP inside `MiniMaxM3VLMultiModalProjector`. + self.merged_hidden_size = self.text_config.hidden_size * (self.vision_config.spatial_merge_size**2) + + super().__post_init__(**kwargs) + + +__all__ = ["MiniMaxM3VLConfig", "MiniMaxM3VLTextConfig", "MiniMaxM3VLVisionConfig"] diff --git a/src/transformers/models/minimax_m3_vl/image_processing_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/image_processing_minimax_m3_vl.py new file mode 100644 index 000000000000..04766fd745a2 --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/image_processing_minimax_m3_vl.py @@ -0,0 +1,204 @@ +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math + +import torch +from torchvision.transforms.v2 import functional as tvF + +from ...image_processing_backends import TorchvisionBackend +from ...image_processing_utils import BatchFeature +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput, PILImageResampling, SizeDict +from ...processing_utils import ImagesKwargs, Unpack +from ...utils import TensorType, auto_docstring + + +class MiniMaxM3VLImageProcessorKwargs(ImagesKwargs, total=False): + r""" + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + Number of adjacent patches merged along each spatial axis before vision features are passed to the language model. + max_pixels (`int`, *optional*, defaults to 451584): + Maximum number of pixels (height x width) an image may have after resizing. + """ + + patch_size: int + temporal_patch_size: int + merge_size: int + max_pixels: int + + +MAX_RATIO = 200 + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 451584, +) -> tuple[int, int]: + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round(height / factor) * factor) + w_bar = max(factor, round(width / factor) * factor) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +@auto_docstring +class 
MiniMaxM3VLImageProcessor(TorchvisionBackend): + do_resize = True + resample = PILImageResampling.BICUBIC + size = {"height": 672, "width": 672} + default_to_square = False + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = True + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + do_convert_rgb = True + patch_size = 14 + temporal_patch_size = 2 + merge_size = 2 + max_pixels = 451584 + valid_kwargs = MiniMaxM3VLImageProcessorKwargs + model_input_names = ["pixel_values", "image_grid_thw"] + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[MiniMaxM3VLImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + resample: "PILImageResampling | tvF.InterpolationMode | int | None", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + patch_size: int, + temporal_patch_size: int, + merge_size: int, + max_pixels: int, + disable_grouping: bool | None, + return_tensors: str | TensorType | None, + **kwargs, + ) -> BatchFeature: + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + factor = patch_size * merge_size + for shape, stacked_images in grouped_images.items(): + height, width = stacked_images.shape[-2:] + if do_resize: + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=factor, + max_pixels=max_pixels, + ) + stacked_images = self.resize( + stacked_images, + size=SizeDict(height=resized_height, width=resized_width), + resample=resample, + ) + resized_images_grouped[shape] = stacked_images + + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + grouped_images, grouped_images_index = group_images_by_shape(resized_images, 
disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_grids = {} + + for shape, stacked_images in grouped_images.items(): + resized_height, resized_width = stacked_images.shape[-2:] + + patches = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if patches.ndim == 4: # (B, C, H, W) + patches = patches.unsqueeze(1) # (B, T=1, C, H, W) + + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat( + 1, temporal_patch_size - (patches.shape[1] % temporal_patch_size), 1, 1, 1 + ) + patches = torch.cat([patches, repeats], dim=1) + + batch_size, t_len, channel = patches.shape[:3] + grid_t = t_len // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + + patches = patches.view( + batch_size, + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + + flatten_patches = patches.reshape( + batch_size, + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + processed_images_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_grids = reorder_images(processed_grids, grouped_images_index) + + pixel_values = torch.cat(processed_images, dim=0) + image_grid_thw = torch.tensor(processed_grids, dtype=torch.long) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors + ) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None) -> int: + images_kwargs = images_kwargs or {} + patch_size = images_kwargs.get("patch_size", self.patch_size) + merge_size = images_kwargs.get("merge_size", 
self.merge_size) + max_pixels = images_kwargs.get("max_pixels", self.max_pixels) + resized_height, resized_width = smart_resize( + height, width, factor=patch_size * merge_size, max_pixels=max_pixels + ) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + return grid_h * grid_w + + +__all__ = ["MiniMaxM3VLImageProcessor"] diff --git a/src/transformers/models/minimax_m3_vl/modeling_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/modeling_minimax_m3_vl.py new file mode 100644 index 000000000000..2e825ce8e5e9 --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/modeling_minimax_m3_vl.py @@ -0,0 +1,1586 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, DynamicLayer, StaticLayer +from ...configuration_utils import PreTrainedConfig +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check +from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.import_utils import is_torchdynamo_compiling +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_minimax_m3_vl import MiniMaxM3VLConfig, MiniMaxM3VLTextConfig, MiniMaxM3VLVisionConfig + + +class MiniMaxM3VLSparseCacheLayer(DynamicLayer): + layer_type = "minimax_m3_sparse" + + def __init__(self, config: PreTrainedConfig | None = None): + super().__init__(config) + self.idx_keys: torch.Tensor | None = None + + def update_index(self, idx_k: torch.Tensor) -> torch.Tensor: + """Append the new token's `idx_k` to the cache and return the full history.""" + self.idx_keys = idx_k if self.idx_keys is None else torch.cat([self.idx_keys, idx_k], dim=-2) + return self.idx_keys + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + super().reorder_cache(beam_idx) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device)) + + def batch_repeat_interleave(self, repeats: int) -> None: + super().batch_repeat_interleave(repeats) + if self.idx_keys is not 
None: + self.idx_keys = self.idx_keys.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + super().batch_select_indices(indices) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys[indices, ...] + + def crop(self, max_length: int) -> None: + super().crop(max_length) + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + if self.idx_keys is not None and self.idx_keys.shape[-2] > max_length: + self.idx_keys = self.idx_keys[..., :max_length, :] + + +class MiniMaxM3VLSparseStaticCacheLayer(StaticLayer): + layer_type = "minimax_m3_sparse" + + def __init__(self, max_cache_len: int): + super().__init__(max_cache_len) + self.idx_keys: torch.Tensor | None = None + # Tensor (not int) so it can be marked as a static address for cudagraphs, like `cumulative_length`. + self.idx_cumulative_length = torch.tensor([0], dtype=int) + + def update_index(self, idx_k: torch.Tensor) -> torch.Tensor: + """Write the new token's `idx_k` into the static buffer in place and return the whole buffer. + + The buffer's unfilled tail holds zeros, but those slots sit at key positions ahead of every + current query, so the indexer's block- and token-level causal masking discards them — the + returned `[B, 1, max_cache_len, D]` history is therefore safe to score against directly. 
+ """ + if self.idx_keys is None: + self.idx_keys = torch.zeros( + (idx_k.shape[0], idx_k.shape[1], self.max_cache_len, idx_k.shape[-1]), + dtype=idx_k.dtype, + device=idx_k.device, + ) + self.idx_cumulative_length = self.idx_cumulative_length.to(idx_k.device) + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.idx_keys) + torch._dynamo.mark_static_address(self.idx_cumulative_length) + + kv_len = idx_k.shape[-2] + cache_position = torch.arange(kv_len, device=self.idx_keys.device) + self.idx_cumulative_length + self.idx_cumulative_length.add_(kv_len) + try: + self.idx_keys.index_copy_(2, cache_position, idx_k) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + self.idx_keys[:, :, cache_position] = idx_k + return self.idx_keys + + def reset(self) -> None: + super().reset() + if self.idx_keys is not None: + self.idx_keys.zero_() + self.idx_cumulative_length.zero_() + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + super().reorder_cache(beam_idx) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device)) + + +class MiniMaxM3VLRMSNorm(nn.Module): + """Gemma-style RMSNorm: normalizes in fp32 and scales by `weight + 1`.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst MiniMaxM3VL is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class MiniMaxM3VLDenseMLP(nn.Module): + def __init__(self, config: MiniMaxM3VLTextConfig, intermediate_size: int | None 
= None): + super().__init__() + inter = intermediate_size if intermediate_size is not None else config.dense_intermediate_size + self.swiglu_alpha = config.swiglu_alpha + self.swiglu_limit = config.swiglu_limit + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * inter, bias=False) + self.down_proj = nn.Linear(inter, config.hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up = self.gate_up_proj(hidden_states) + gate, up = gate_up.chunk(2, dim=-1) + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + glu = gate * torch.sigmoid(gate * self.swiglu_alpha) + return self.down_proj((up + 1.0) * glu) + + +@use_experts_implementation +class MiniMaxM3VLExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.limit = config.swiglu_limit + self.swiglu_alpha = config.swiglu_alpha + self.swiglu_limit = config.swiglu_limit + + def forward( + self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + ) -> torch.Tensor: + final = torch.zeros_like(hidden_states) + with torch.no_grad(): + mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + hit = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(mask[expert_idx]) + current = self._apply_gate(F.linear(hidden_states[token_idx], self.gate_up_proj[expert_idx])) + current = F.linear(current, 
self.down_proj[expert_idx]) * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, current.to(final.dtype)) + return final + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + # same as GPT OSS, but the weights are not interleaved + gate, up = gate_up.chunk(2, dim=-1) + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + glu = gate * torch.sigmoid(gate * self.swiglu_alpha) + return (up + 1.0) * glu + + +class MiniMaxM3VLTopKRouter(nn.Module): + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight) + # Sigmoid scoring (not softmax), as in M2. 
+ routing_weights = F.sigmoid(router_logits.float()) + scores_for_choice = routing_weights + self.e_score_correction_bias + _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + top_k_weights = routing_weights.gather(1, top_k_index) + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + return router_logits, top_k_weights, top_k_index + + +class MiniMaxM3VLSparseMoeBlock(nn.Module): + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__() + self.gate = MiniMaxM3VLTopKRouter(config) + self.experts = MiniMaxM3VLExperts(config) + self.routed_scaling_factor = config.routed_scaling_factor + self.shared_experts = MiniMaxM3VLDenseMLP(config, intermediate_size=config.shared_intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = self.shared_experts(hidden_states) + + _, routing_weights, selected_experts = self.gate(hidden_states) + hidden_states = self.experts(hidden_states, selected_experts, routing_weights) + # Additional scaling + hidden_states = hidden_states * self.routed_scaling_factor + hidden_states = hidden_states + shared_output + + hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return hidden_states + + +class MiniMaxM3VLRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: MiniMaxM3VLConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + 
self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: MiniMaxM3VLConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = 
torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Partial RoPE: only the first `rotary_dim` channels are rotated, the rest pass through unchanged + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings to the rotated slice only + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class MiniMaxM3VLAttention(nn.Module): + """ + M3 attention: per-head Gemma-style QK-norm + partial RoPE, with an optional sparse indexer selection that requires position IDs. + """ + + def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.q_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.indexer = ( + 
MiniMaxM3VLIndexer(config, layer_idx) if config.layer_types[layer_idx] == "minimax_m3_sparse" else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + block_indices = None + if self.indexer is not None: + position_ids = kwargs.get("position_ids") + if position_ids is None: + position_ids = torch.arange( + key_states.shape[2] - query_states.shape[2], key_states.shape[2], device=query_states.device + ) + position_ids = (position_ids if position_ids.ndim > 1 else position_ids.unsqueeze(0)).expand( + query_states.shape[0], -1 + ) + block_indices = self.indexer(hidden_states, position_embeddings, past_key_values, position_ids) + if self.config._attn_implementation in ("eager", "sdpa"): + attention_mask = self.indexer.build_block_mask( + block_indices, + attention_mask, + key_states.shape[2], + query_states.dtype, + query_states.device, + position_ids, + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not 
self.training else self.attention_dropout, + scaling=self.scaling, + block_indices=block_indices, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output), attn_weights + + +class MiniMaxM3VLIndexer(nn.Module): + r"""Lightning Indexer for MiniMax M3 sparse attention. + + Scores each query against every key with a small `index_n_heads`-head + dot-product branch, then max-pools those per-key scores into *blocks* of + `index_block_size` keys and keeps, per query, `index_topk_blocks` key + blocks: the `index_local_blocks` local blocks ending at the query's own + block (always visible) plus the highest-scoring remaining blocks. + Selection therefore happens at the granularity of a + *block of keys* rather than individual keys: the expensive main attention + only has to attend to the handful of selected key blocks, which is what makes + it block-sparse (and cheaper) on long sequences. + + The `index_local_blocks` local blocks are kept by boosting their scores to `+inf` + so they always win top-k slots, the same way the deployment block-sparse kernel + (MiniMax `topk_sparse`) does it. + + `forward` returns the per-query selected key-block indices + `[B, S_q, index_topk_blocks]`. Valid indices are left-packed and `-1` + right-pads the unused slots (future/empty blocks), and because the local + blocks are boosted inside the same top-k, the selection contains no + duplicates -- the exact contract the block-sparse attention + kernel consumes (it counts the valid entries, then reads them sequentially + and would double-count a repeated block). The eager/SDPA path instead calls + `build_block_mask`, which expands the indices into the dense + `[B, 1, S_q, S_k]` additive mask the standard attention interface expects + (`0` at every allowed (query, key) pair, `torch.finfo(dtype).min` elsewhere). + + Like DeepSeek-V4's indexer this is purely a *selection* branch: it has no + value projection and produces no residual output of its own (the upstream + checkpoint disables the index-value path on every sparse layer). 
+ + TODO: blocks are anchored to absolute key *slots* (the contiguous reshape in + `forward` and `q_block = slot // block_size`), so left-padding shifts the block + boundaries and the selection diverges from an unpadded run -- only right-padding + is equivalent (same limitation as DeepSeek-V4; see `test_right_padding_does_not_leak` + / the skipped `test_left_padding_compatibility`). For *true* left-padding equivalence + we'd make blocking content-relative instead of slot-relative: + 1. derive block ids from `position_ids` (content positions, 0 at each row's first + real token) rather than from absolute slots, and + 2. replace the contiguous `view(..., num_key_blocks, block_size).amax(-1)` key pool + with a per-row position-binned pool (e.g. `scatter_reduce` over `key_position // + block_size`), so pad never shifts the boundaries, and + 3. mask padded keys' scores to `-inf` before the pool so a pad key can't win a block + a top-k slot. + """ + + def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = config.index_head_dim + self.num_heads = config.index_n_heads + self.block_size = config.index_block_size + self.topk_blocks = config.index_topk_blocks + self.local_blocks = config.index_local_blocks + self.q_proj = nn.Linear(config.hidden_size, config.index_n_heads * config.index_head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.index_head_dim, bias=False) + self.q_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps) + self.k_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + past_key_values: Cache | None, + position_ids: torch.Tensor, + ) -> torch.Tensor: + batch, q_len, _ = hidden_states.shape + idx_q = self.q_proj(hidden_states).view(batch, q_len, -1, self.head_dim) + idx_q = 
self.q_norm(idx_q).transpose(1, 2) # [B, H_idx, S_q, D] + idx_k = self.k_proj(hidden_states).view(batch, q_len, 1, self.head_dim) + idx_k = self.k_norm(idx_k).transpose(1, 2) # [B, 1, S_q, D] + cos, sin = position_embeddings + idx_q, idx_k = apply_rotary_pos_emb(idx_q, idx_k, cos[..., : self.head_dim], sin[..., : self.head_dim]) + + if past_key_values is not None: + idx_k = past_key_values.layers[self.layer_idx].update_index(idx_k) + + k_len = idx_k.shape[2] + num_key_blocks = -(-k_len // self.block_size) # ceil-div + pad = num_key_blocks * self.block_size - k_len + + scores = torch.matmul(idx_q.float(), idx_k.float().transpose(-1, -2)) + k_positions = torch.arange(k_len, device=idx_q.device) + token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k] + scores = scores.masked_fill(token_future, float("-inf")) + if pad: + scores = F.pad(scores, (0, pad), value=float("-inf")) + scores = scores.view(batch, self.num_heads, q_len, num_key_blocks, self.block_size) + block_scores = scores.amax(dim=-1).amax(dim=1) # -> [B, S_q, num_key_blocks] + + q_block = position_ids // self.block_size # [B, S_q] + + if self.local_blocks > 0: + local = torch.arange(self.local_blocks, device=idx_q.device) + local_idx = (q_block[..., None] - local.view(1, 1, -1)).clamp(min=0) # [B, S_q, local] + block_scores.scatter_(-1, local_idx, float("inf")) + + # Slots that fall on a future/empty block keep their `-inf` + # score, which top-k sorts to the end, so tagging them `-1` yields left-packed block indices + # with `-1` right-padding, which is the format expected by the block-sparse attention kernel. 
+ topk = min(self.topk_blocks, num_key_blocks) + topk_scores, topk_indices = block_scores.topk(topk, dim=-1) # [B, S_q, topk] + return topk_indices.masked_fill(topk_scores == float("-inf"), -1) + + def build_block_mask( + self, + block_indices: torch.Tensor, + attention_mask: torch.Tensor | None, + key_length: int, + dtype: torch.dtype, + device: torch.device, + position_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Expand the selected block indices into the dense additive attention mask of shape + `[B, 1, S_q, S_k]` (batch, head, query, key). + """ + batch, q_len, _ = block_indices.shape + num_key_blocks = -(-key_length // self.block_size) + + # Scatter the kept blocks to `0`; `-1` slots land in a throwaway column we drop afterwards. + safe = block_indices.masked_fill(block_indices < 0, num_key_blocks) + bias = block_indices.new_full((batch, q_len, num_key_blocks + 1), float("-inf"), dtype=dtype) + bias.scatter_(-1, safe, 0.0) + bias = bias[..., :num_key_blocks] + + # Broadcast the per-block keep/drop verdict back onto every key (block granularity), add head axis. + block_keep = (bias == 0.0).repeat_interleave(self.block_size, dim=-1)[..., :key_length].unsqueeze(1) + + # Compose block-selection with the existing mask, then emit a single additive float mask. 
+ if attention_mask is not None: + padding_mask = attention_mask if attention_mask.dtype == torch.bool else attention_mask == 0 + keep = block_keep & padding_mask + else: + k_positions = torch.arange(key_length, device=device) + token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k] + keep = block_keep & ~token_future + min_dtype = torch.finfo(dtype).min + return torch.zeros(keep.shape, dtype=dtype, device=device).masked_fill(~keep, min_dtype) + + +class MiniMaxM3VLDecoderLayer(GradientCheckpointingLayer): + """M3 decoder layer: per-layer dense/MoE MLP and dense/sparse attention.""" + + def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MiniMaxM3VLAttention(config, layer_idx) + self.mlp = ( + MiniMaxM3VLSparseMoeBlock(config) + if config.mlp_layer_types[layer_idx] == "sparse" + else MiniMaxM3VLDenseMLP(config, intermediate_size=config.dense_intermediate_size) + ) + self.input_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = 
self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class MiniMaxM3VLPreTrainedModel(PreTrainedModel): + config: MiniMaxM3VLConfig | MiniMaxM3VLTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MiniMaxM3VLDecoderLayer", "MiniMaxM3VLVisionEncoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = False + _supports_sdpa = True + _supports_flex_attn = False + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(MiniMaxM3VLTopKRouter, index=0), + "hidden_states": MiniMaxM3VLDecoderLayer, + "attentions": MiniMaxM3VLAttention, + } + input_modalities = ("image", "video", "text") + _keys_to_ignore_on_load_unexpected = [r"(^|\.)mtp\..*"] + _compatible_flash_implementations = ["kernels-staging/msa@v0"] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, MiniMaxM3VLExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, MiniMaxM3VLTopKRouter): + init.normal_(module.weight, mean=0.0, std=std) + init.zeros_(module.e_score_correction_bias) + elif isinstance(module, MiniMaxM3VLRMSNorm): + init.zeros_(module.weight) + + +@auto_docstring +class MiniMaxM3VLTextModel(MiniMaxM3VLPreTrainedModel): + config: MiniMaxM3VLTextConfig + + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([MiniMaxM3VLDecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + self.norm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + 
self.rotary_emb = MiniMaxM3VLRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if isinstance(attention_mask, dict): + causal_mask = next(iter(attention_mask.values())) + else: + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # `position_ids` is threaded to every layer so the sparse layers' lightning indexer can anchor + # block selection to each query's content position (see `MiniMaxM3VLIndexer`). 
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: torch.Tensor | tuple[torch.Tensor] | None, + num_experts: int | None = None, + top_k=2, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor | int: + r""" + Computes the auxiliary load balancing loss as in Switch Transformer, implemented in PyTorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts. + top_k: + The number of experts to route per token, can also be interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention mask used in the forward function, of + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each expert + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each expert + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * 
router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class MiniMaxM3VLForCausalLM(MiniMaxM3VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config: MiniMaxM3VLTextConfig + + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__(config) + self.model = MiniMaxM3VLTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_router_logits: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniMaxM3VLForCausalLM + + >>> model = MiniMaxM3VLForCausalLM.from_pretrained("MiniMaxAI/MiniMax-M3-preview") + >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M3-preview") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class MiniMaxM3VLVisionEmbeddings(nn.Module): + """Patch embedding, identical to [`Qwen2_5_VisionPatchEmbed`] (reads its dims from the vision + config). The upstream checkpoint stores the conv as `patch_embedding`, renamed to the + inherited `proj` in the conversion mapping.""" + + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.num_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class MiniMaxM3VL3DRotaryEmbedding(nn.Module): + r"""3D RoPE for the vision tower: each patch is rotated by its `(T, H, W)` grid position. 
+ + `2 * (head_dim // 2)` rotary dims are split evenly across the three axes (each rounded + down to a multiple of 2), giving `axis_dim` dims per axis and `axis_dim // 2` frequencies:: + + |<------------------ rotated (3 * axis_dim) ------------------>|<- pass ->| + +--------------------+--------------------+--------------------+----------+ + | T (frames) | H (rows) | W (cols) | | + | axis_dim | axis_dim | axis_dim | | + +--------------------+--------------------+--------------------+----------+ + + Each axis' coordinate scales its own band of frequencies; the bands are concatenated as + `T|H|W` and duplicated via `cat([f, f])` to pair with the half-rotation in + `apply_rotary_pos_emb_vision`. Any head dims past `3 * axis_dim` are left unrotated. + """ + + def __init__(self, head_dim: int, theta: float = 10000.0, spatial_merge_size: int = 1): + super().__init__() + # `2 * (head_dim // 2)` rotary dims are split evenly across T/H/W, each axis rounded + # down to a multiple of 2. With head_dim=80 that is 26 dims/axis (39 freqs total); the + # remaining `head_dim - 3 * axis_dim` dims are never rotated (they pass through). + rope_dims = 2 * (head_dim // 2) + self.axis_dim = 2 * ((rope_dims // 3) // 2) + self.spatial_merge_size = spatial_merge_size + self.theta = theta + + def forward( + self, grid_thw: torch.Tensor, device: torch.device, dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + m = self.spatial_merge_size + coords = [] + for t, h, w in grid_thw.tolist(): + hi = torch.arange(h).unsqueeze(1).expand(-1, w) + hi = hi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + wi = torch.arange(w).unsqueeze(0).expand(h, -1) + wi = wi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + ti = torch.arange(t).repeat_interleave(h * w) + coords.append(torch.stack([ti, hi.repeat(t), wi.repeat(t)], dim=-1)) + coords = torch.cat(coords).to(device=device, dtype=torch.float32) + + # meta device init was having trouble when it was registered. 
TODO standardize? + inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.axis_dim, 2, dtype=torch.float32, device=device) / self.axis_dim) + ) + freqs = torch.cat([coords[:, i : i + 1] * inv_freq for i in range(3)], dim=-1) + emb = torch.cat([freqs, freqs], dim=-1) + return emb.cos().to(dtype), emb.sin().to(dtype) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # Only the first `rot_dim` head dims carry 3D RoPE; the tail passes through untouched. + rot_dim = cos.shape[-1] + cos, sin = cos[None, :, None, :], sin[None, :, None, :] + q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:] + k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:] + q_rot = q_rot * cos + rotate_half(q_rot) * sin + k_rot = k_rot * cos + rotate_half(k_rot) * sin + return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1) + + +class MiniMaxM3VLVisionAttention(nn.Module): + """CLIP-style vision attention; the only difference from [`CLIPAttention`] is + that queries and keys are rotated by the tower's 3D RoPE before the + (interface-dispatched) scaled dot-product attention.""" + + def __init__(self, config: MiniMaxM3VLVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + # The vision tower has no grouped-query attention; the shared eager kernel + # still expects this attribute to drive its (no-op) `repeat_kv`. 
+ self.num_key_value_groups = 1 + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + queries = self.q_proj(hidden_states).view(hidden_shape) + keys = self.k_proj(hidden_states).view(hidden_shape) + values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) + queries, keys = queries.transpose(1, 2), keys.transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.out_proj(attn_output), attn_weights + + +class MiniMaxM3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MiniMaxM3VLVisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MiniMaxM3VLVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = MiniMaxM3VLVisionAttention(config) 
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MiniMaxM3VLVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class MiniMaxM3VLVisionModel(MiniMaxM3VLPreTrainedModel): + """CLIP-like vision tower with Conv3d patch embed + 3D RoPE.""" + + config: MiniMaxM3VLVisionConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": MiniMaxM3VLVisionEncoderLayer, + "attentions": MiniMaxM3VLVisionAttention, + } + + def __init__(self, config: MiniMaxM3VLVisionConfig): + super().__init__(config) + self.embeddings = MiniMaxM3VLVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layers = nn.ModuleList([MiniMaxM3VLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + head_dim = config.hidden_size // config.num_attention_heads + self.rotary_emb = MiniMaxM3VL3DRotaryEmbedding( + head_dim, theta=config.rope_parameters["rope_theta"], spatial_merge_size=config.spatial_merge_size + ) + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`): + The temporal, 
height and width of feature shape of each image. + """ + embeds = self.embeddings(pixel_values).to(self.pre_layrnorm.weight.dtype) + cos, sin = self.rotary_emb(image_grid_thw, device=embeds.device, dtype=embeds.dtype) + hidden_states = self.pre_layrnorm(embeds).unsqueeze(0) + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask=None, position_embeddings=(cos, sin), **kwargs) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states, pooler_output=hidden_states[:, 0]) + + +class MiniMaxM3VLMultiModalProjector(nn.Module): + """Projects each vision patch from `vision_config.hidden_size` to `text_config.hidden_size` + (GELU MLP), then groups `spatial_merge_size**2` neighbouring patches into the channel dim and + fuses them back to a single `text_config.hidden_size` token with a second GELU MLP.""" + + def __init__(self, config: MiniMaxM3VLConfig): + super().__init__() + text_hidden = config.text_config.hidden_size + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.projector_hidden_size, bias=True) + self.act = ACT2FN["gelu"] + self.linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True) + self.merge_linear_1 = nn.Linear(config.merged_hidden_size, config.projector_hidden_size, bias=True) + self.merge_act = ACT2FN["gelu"] + self.merge_linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_2(self.act(self.linear_1(image_features))) + hidden_states = hidden_states.reshape(hidden_states.shape[0] // (self.spatial_merge_size**2), -1) + return self.merge_linear_2(self.merge_act(self.merge_linear_1(hidden_states))) + + +@auto_docstring( + custom_intro=""" + Base class for MiniMaxM3VL outputs, with hidden states and attentions. 
+ """ +) +@dataclass +class MiniMaxM3VLModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: torch.FloatTensor | None = None + + video_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + Base class for MiniMaxM3VL causal language model (or autoregressive) outputs. + """ +) +@dataclass +class MiniMaxM3VLCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + video_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring(custom_intro="MiniMax M3 VL backbone (vision + projector + text), without LM head.") +class MiniMaxM3VLModel(MiniMaxM3VLPreTrainedModel): + config: MiniMaxM3VLConfig + + def __init__(self, config: MiniMaxM3VLConfig): + super().__init__(config) + self.vision_tower = MiniMaxM3VLVisionModel(config.vision_config) + self.multi_modal_projector = MiniMaxM3VLMultiModalProjector(config) + self.language_model = MiniMaxM3VLTextModel(config.text_config) + self.post_init() + + @merge_with_config_defaults + @can_return_tuple + @auto_docstring( + custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection." 
+ ) + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.Tensor, + **kwargs, + ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + # Return the raw vision-tower output (so callers can inspect hidden states / + # attentions) while stashing the projected + spatially-merged features — + # ready to scatter into the text embeddings — in `pooler_output`. + vision_outputs = self.vision_tower(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0)) + return vision_outputs + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor | None = None, + video_features: torch.FloatTensor | None = None, + ): + """ + Obtains the image/video placeholder masks from `input_ids` or `inputs_embeds`, and checks that the + placeholder token count matches the multimodal feature length. Raises if they differ. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None: + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None: + torch_compilable_check( + inputs_embeds[special_video_mask].numel() == video_features.numel(), + f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", + ) + return special_image_mask, special_video_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | 
MiniMaxM3VLModelOutputWithPast: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, image_grid_thw=image_grid_thw + ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype) + + video_features = None + if pixel_values_videos is not None: + video_features = self.get_video_features( + pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw + ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds, image_features=image_features, video_features=video_features + ) + if image_features is not None: + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + if video_features is not None: + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return MiniMaxM3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=getattr(outputs, "hidden_states", None), + attentions=getattr(outputs, "attentions", None), + image_hidden_states=image_features, + 
+ video_hidden_states=video_features, + ) + + @merge_with_config_defaults + @can_return_tuple + @auto_docstring( + custom_intro="Obtains video last hidden states from the vision tower and applies multimodal projection." + ) + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.Tensor, + **kwargs, + ) -> BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor`): + The tensors corresponding to the input video frames. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + # Video frames flow through the same vision pipeline as images (the tower is + # grid-agnostic); only the placeholder token they scatter into differs. + vision_outputs = self.vision_tower(pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, **kwargs) + vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0)) + return vision_outputs + + +@auto_docstring(custom_intro="MiniMax M3 VL full model with LM head (text + vision).") +class MiniMaxM3SparseForConditionalGeneration(MiniMaxM3VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + config: MiniMaxM3VLConfig + + def __init__(self, config: MiniMaxM3VLConfig): + super().__init__(config) + self.model = MiniMaxM3VLModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + @auto_docstring + def get_image_features(self, pixel_values, image_grid_thw, **kwargs) -> tuple | BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE 
+ and to merge patch features. + """ + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MiniMaxM3VLCausalLMOutputWithPast: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. 
+ """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return MiniMaxM3VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_values_videos=None, + attention_mask=None, + logits_to_keep=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- pixel inputs are merged into the cache on the first step, so we + # only forward them once (image and video alike). 
+ model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + logits_to_keep=logits_to_keep, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + if is_first_iteration or not kwargs.get("use_cache", True): + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + + return model_inputs + + def get_video_features(self, pixel_values_videos, video_grid_thw, **kwargs): + r""" + pixel_values_videos (`torch.FloatTensor`): + The tensors corresponding to the input video frames. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + + +__all__ = [ + "MiniMaxM3VLForCausalLM", + "MiniMaxM3SparseForConditionalGeneration", + "MiniMaxM3VLModel", + "MiniMaxM3VLPreTrainedModel", + "MiniMaxM3VLTextModel", + "MiniMaxM3VLVisionModel", +] diff --git a/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py new file mode 100644 index 000000000000..e1c7a0962356 --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py @@ -0,0 +1,1308 @@ +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""MiniMax M3 VL: vision tower + M3 (mixed sparse/dense MoE) text backbone.""" + +from collections.abc import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub.dataclasses import strict + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, DynamicLayer, StaticLayer +from ...configuration_utils import PreTrainedConfig +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPooling, MoeModelOutputWithPast +from ...modeling_rope_utils import RopeParameters +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging, torch_compilable_check +from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.import_utils import is_torchdynamo_compiling +from ...utils.output_capturing import capture_outputs +from ..auto import AutoConfig +from ..clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPEncoderLayer +from ..deepseek_v4.modeling_deepseek_v4 import DeepseekV4Experts +from ..gemma3.modeling_gemma3 import Gemma3RMSNorm +from ..laguna.modeling_laguna import LagunaSparseMoeBlock +from ..llama.modeling_llama import eager_attention_forward +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, +) +from ..minimax_m2.configuration_minimax_m2 import MiniMaxM2Config +from ..minimax_m2.modeling_minimax_m2 import ( + MiniMaxM2Attention, + MiniMaxM2ForCausalLM, + MiniMaxM2Model, + MiniMaxM2PreTrainedModel, + MiniMaxM2RotaryEmbedding, + MiniMaxM2TopKRouter, + apply_rotary_pos_emb, +) +from ..mixtral.modeling_mixtral import MixtralDecoderLayer +from ..qwen2_5_vl.modeling_qwen2_5_vl 
import Qwen2_5_VisionPatchEmbed +from ..qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor, Qwen2VLProcessorKwargs + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview") +@strict +class MiniMaxM3VLTextConfig(MiniMaxM2Config): + r""" + dense_intermediate_size (`int`, *optional*, defaults to 12288): + Intermediate size of the dense MLP used on layers whose `mlp_layer_types` entry is `"dense"`. + shared_intermediate_size (`int`, *optional*, defaults to 3072): + Intermediate size of a single shared expert in the MoE layers. + rotary_dim (`int`, *optional*, defaults to 64): + Number of head channels rotated by RoPE; the remaining channels are passed through unchanged. + swiglu_alpha (`float`, *optional*, defaults to 1.702): + Sigmoid gain of the SwiGLU-OAI activation. + swiglu_limit (`float`, *optional*, defaults to 7.0): + Clamp bound applied to the gate and up projections of the SwiGLU-OAI activation. + mlp_layer_types (`list[str]`, *optional*): + Per-layer MLP selector: `"sparse"` for a MoE block, `"dense"` for a dense MLP. + index_n_heads (`int`, *optional*, defaults to 4): + Number of heads in the lightning indexer's dot-product scoring branch. + index_head_dim (`int`, *optional*, defaults to 128): + Per-head channel dimension of the lightning indexer. + index_block_size (`int`, *optional*, defaults to 128): + Number of key tokens pooled into a single scored block. + index_topk_blocks (`int`, *optional*, defaults to 16): + Number of top-scoring key blocks each query may attend to. + index_local_blocks (`int`, *optional*, defaults to 1): + Number of key blocks immediately preceding the query always kept visible / attended to. 
+ """ + + model_type = "minimax_m3_vl_text" + base_config_key = "text_config" + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } + + hidden_size: int = 6144 + intermediate_size: int = 3072 + dense_intermediate_size: int = 12288 + shared_intermediate_size: int = 3072 + num_hidden_layers: int = 60 + num_attention_heads: int = 64 + num_key_value_heads: int = 4 + head_dim: int = 128 + max_position_embeddings: int = 524288 + vocab_size: int = 200064 + rms_norm_eps: float = 1e-06 + num_local_experts: int = 128 + num_experts_per_tok: int = 4 + routed_scaling_factor: float = 2.0 + rotary_dim: int = 64 + swiglu_alpha: float = 1.702 + swiglu_limit: float = 7.0 + mlp_layer_types: list[str] | None = None + index_n_heads: int = 4 + index_head_dim: int = 128 + index_block_size: int = 128 + index_topk_blocks: int = 16 + index_local_blocks: int = 1 + layer_types: list[str] | None = None + tie_word_embeddings: bool = False + pad_token_id: int | None = None + bos_token_id: int | None = 200034 + eos_token_id: int | list[int] | None = 200020 + rope_parameters: RopeParameters | dict | None = None + + def __post_init__(self, **kwargs): + sparse_cfg = kwargs.pop("sparse_attention_config", None) or {} + moe_layer_freq = kwargs.pop("moe_layer_freq", None) + PreTrainedConfig.__post_init__(self, **kwargs) + # Checkpoint declares "swigluoai", but the gate is computed inline from swiglu_alpha/limit; hidden_act + # is only the pointwise fallback and must be a real ACT2FN key, so normalize it to silu. 
+ self.hidden_act = "silu" + + for flat, legacy in { + "index_n_heads": "sparse_num_index_heads", + "index_head_dim": "sparse_index_dim", + "index_block_size": "sparse_block_size", + "index_topk_blocks": "sparse_topk_blocks", + "index_local_blocks": "sparse_local_block", + }.items(): + if legacy in sparse_cfg: + setattr(self, flat, sparse_cfg[legacy]) + + # `layer_types` is the canonical per-layer attention dispatch: it tells + # `DynamicCache(config=...)` which layers want the sparse cache and tells + # `MiniMaxM3VLAttention` which layers build a sparse Lightning Indexer. + if self.layer_types is None and "sparse_attention_freq" in sparse_cfg: + self.layer_types = [ + "minimax_m3_sparse" if f else "full_attention" for f in sparse_cfg["sparse_attention_freq"] + ] + if self.layer_types is None: + self.layer_types = ["full_attention"] * self.num_hidden_layers + + # `mlp_layer_types` is the per-layer MLP dispatch read by `MiniMaxM3VLDecoderLayer`: + if self.mlp_layer_types is None and moe_layer_freq is not None: + self.mlp_layer_types = ["sparse" if f else "dense" for f in moe_layer_freq] + if self.mlp_layer_types is None: + self.mlp_layer_types = ["sparse"] * self.num_hidden_layers + + +@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview") +@strict +class MiniMaxM3VLVisionConfig(PreTrainedConfig): + r""" + rope_parameters (`RopeParameters`, *optional*): + Standard RoPE configuration for the vision tower's 3D rotary position embedding. 
+ """ + + model_type = "minimax_m3_vl_vision" + base_config_key = "vision_config" + default_theta = 10000.0 + + hidden_size: int = 1280 + intermediate_size: int = 5120 + num_hidden_layers: int = 32 + num_attention_heads: int = 16 + num_channels: int = 3 + image_size: int = 2016 + patch_size: int = 14 + temporal_patch_size: int = 2 + spatial_merge_size: int = 2 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-05 + attention_dropout: float = 0.0 + rope_parameters: RopeParameters | dict | None = None + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="MiniMaxAI/MiniMax-M3-preview") +@strict +class MiniMaxM3VLConfig(PreTrainedConfig): + model_type = "minimax_m3_vl" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } + + vision_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + image_token_index: int = 200025 + video_token_index: int = 200026 + projector_hidden_size: int = 6144 + tie_word_embeddings: bool = False + + def __post_init__(self, **kwargs): + if isinstance(self.vision_config, dict): + self.vision_config.pop("model_type", None) + self.vision_config = MiniMaxM3VLVisionConfig(**self.vision_config) + elif self.vision_config is None: + self.vision_config = MiniMaxM3VLVisionConfig() + + if isinstance(self.text_config, dict): + self.text_config.pop("model_type", None) + self.text_config = MiniMaxM3VLTextConfig(**self.text_config) + elif self.text_config is None: + self.text_config = MiniMaxM3VLTextConfig() + + if not self.tie_word_embeddings and self.text_config.tie_word_embeddings: + self.tie_word_embeddings = self.text_config.tie_word_embeddings + + # Channel dim after grouping `spatial_merge_size**2` projected patches, consumed by the + # patch-merge MLP inside `MiniMaxM3VLMultiModalProjector`. 
+ self.merged_hidden_size = self.text_config.hidden_size * (self.vision_config.spatial_merge_size**2) + + super().__post_init__(**kwargs) + + +class MiniMaxM3VLSparseCacheLayer(DynamicLayer): + layer_type = "minimax_m3_sparse" + + def __init__(self, config: PreTrainedConfig | None = None): + super().__init__(config) + self.idx_keys: torch.Tensor | None = None + + def update_index(self, idx_k: torch.Tensor) -> torch.Tensor: + """Append the new token's `idx_k` to the cache and return the full history.""" + self.idx_keys = idx_k if self.idx_keys is None else torch.cat([self.idx_keys, idx_k], dim=-2) + return self.idx_keys + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + super().reorder_cache(beam_idx) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device)) + + def batch_repeat_interleave(self, repeats: int) -> None: + super().batch_repeat_interleave(repeats) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + super().batch_select_indices(indices) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys[indices, ...] + + def crop(self, max_length: int) -> None: + super().crop(max_length) + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + if self.idx_keys is not None and self.idx_keys.shape[-2] > max_length: + self.idx_keys = self.idx_keys[..., :max_length, :] + + +class MiniMaxM3VLSparseStaticCacheLayer(StaticLayer): + layer_type = "minimax_m3_sparse" + + def __init__(self, max_cache_len: int): + super().__init__(max_cache_len) + self.idx_keys: torch.Tensor | None = None + # Tensor (not int) so it can be marked as a static address for cudagraphs, like `cumulative_length`. 
+ self.idx_cumulative_length = torch.tensor([0], dtype=int) + + def update_index(self, idx_k: torch.Tensor) -> torch.Tensor: + """Write the new token's `idx_k` into the static buffer in place and return the whole buffer. + + The buffer's unfilled tail holds zeros, but those slots sit at key positions ahead of every + current query, so the indexer's block- and token-level causal masking discards them — the + returned `[B, 1, max_cache_len, D]` history is therefore safe to score against directly. + """ + if self.idx_keys is None: + self.idx_keys = torch.zeros( + (idx_k.shape[0], idx_k.shape[1], self.max_cache_len, idx_k.shape[-1]), + dtype=idx_k.dtype, + device=idx_k.device, + ) + self.idx_cumulative_length = self.idx_cumulative_length.to(idx_k.device) + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.idx_keys) + torch._dynamo.mark_static_address(self.idx_cumulative_length) + + kv_len = idx_k.shape[-2] + cache_position = torch.arange(kv_len, device=self.idx_keys.device) + self.idx_cumulative_length + self.idx_cumulative_length.add_(kv_len) + try: + self.idx_keys.index_copy_(2, cache_position, idx_k) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. 
+ self.idx_keys[:, :, cache_position] = idx_k + return self.idx_keys + + def reset(self) -> None: + super().reset() + if self.idx_keys is not None: + self.idx_keys.zero_() + self.idx_cumulative_length.zero_() + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + super().reorder_cache(beam_idx) + if self.idx_keys is not None: + self.idx_keys = self.idx_keys.index_select(0, beam_idx.to(self.idx_keys.device)) + + +class MiniMaxM3VLRMSNorm(Gemma3RMSNorm): + """Gemma-style RMSNorm: normalizes in fp32 and scales by `weight + 1`.""" + + +class MiniMaxM3VLDenseMLP(nn.Module): + def __init__(self, config: MiniMaxM3VLTextConfig, intermediate_size: int | None = None): + super().__init__() + inter = intermediate_size if intermediate_size is not None else config.dense_intermediate_size + self.swiglu_alpha = config.swiglu_alpha + self.swiglu_limit = config.swiglu_limit + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * inter, bias=False) + self.down_proj = nn.Linear(inter, config.hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up = self.gate_up_proj(hidden_states) + gate, up = gate_up.chunk(2, dim=-1) + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + glu = gate * torch.sigmoid(gate * self.swiglu_alpha) + return self.down_proj((up + 1.0) * glu) + + +class MiniMaxM3VLExperts(DeepseekV4Experts): + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__(config) + self.swiglu_alpha = config.swiglu_alpha + self.swiglu_limit = config.swiglu_limit + del self.act_fn + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + # same as GPT OSS, but the weights are not interleaved + gate, up = gate_up.chunk(2, dim=-1) + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + glu = gate * torch.sigmoid(gate * self.swiglu_alpha) + return (up + 1.0) * glu + + +class 
MiniMaxM3VLTopKRouter(MiniMaxM2TopKRouter): + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__(config) + self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight) + # Sigmoid scoring (not softmax), as in M2. + routing_weights = F.sigmoid(router_logits.float()) + scores_for_choice = routing_weights + self.e_score_correction_bias + _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + top_k_weights = routing_weights.gather(1, top_k_index) + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + return router_logits, top_k_weights, top_k_index + + +class MiniMaxM3VLSparseMoeBlock(LagunaSparseMoeBlock): + def __init__(self, config: MiniMaxM3VLTextConfig): + nn.Module.__init__(self) + self.gate = MiniMaxM3VLTopKRouter(config) + self.experts = MiniMaxM3VLExperts(config) + self.routed_scaling_factor = config.routed_scaling_factor + self.shared_experts = MiniMaxM3VLDenseMLP(config, intermediate_size=config.shared_intermediate_size) + + +class MiniMaxM3VLRotaryEmbedding(MiniMaxM2RotaryEmbedding): + pass + + +class MiniMaxM3VLAttention(MiniMaxM2Attention): + """ + M3 attention: per-head Gemma QK-norm + partial RoPE, optionally with sparse indexer selection, which requires position IDs.
+ """ + + def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.q_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MiniMaxM3VLRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.indexer = ( + MiniMaxM3VLIndexer(config, layer_idx) if config.layer_types[layer_idx] == "minimax_m3_sparse" else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + block_indices = None + if self.indexer is not None: + position_ids = kwargs.get("position_ids") + if position_ids is None: + position_ids = torch.arange( + key_states.shape[2] - query_states.shape[2], key_states.shape[2], device=query_states.device + ) + position_ids = (position_ids if position_ids.ndim > 1 else position_ids.unsqueeze(0)).expand( + query_states.shape[0], -1 + ) + block_indices = self.indexer(hidden_states, position_embeddings, past_key_values, position_ids) + if self.config._attn_implementation in ("eager", "sdpa"): + attention_mask = 
self.indexer.build_block_mask( + block_indices, + attention_mask, + key_states.shape[2], + query_states.dtype, + query_states.device, + position_ids, + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + block_indices=block_indices, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output), attn_weights + + +class MiniMaxM3VLIndexer(nn.Module): + r"""Lightning Indexer for MiniMax M3 sparse attention. + + Scores each query against every key with a small `index_n_heads`-head + dot-product branch, then max-pools those per-key scores into *blocks* of + `index_block_size` keys and keeps, per query, the top-`index_topk_blocks` + key blocks plus the `index_local_blocks` blocks immediately preceding the + query (always visible). Selection therefore happens at the granularity of a + *block of keys* rather than individual keys: the expensive main attention + only has to attend to the handful of selected key blocks, which is what makes + it block-sparse (and cheaper) on long sequences. + + The `index_local_blocks` local blocks are kept by boosting their score to `inf` so they always win + top-k slots, the same way the deployment block-sparse kernel (MiniMax `topk_sparse`) does it. + + `forward` returns the per-query selected key-block indices + `[B, S_q, index_topk_blocks]`. Valid indices are left-packed and `-1` + right-pads the unused slots (future/empty blocks), and the local boost makes + selections deduplicated -- the exact contract the block-sparse attention + kernel consumes (it counts the valid entries, then reads them sequentially + and would double-count a repeated block).
The eager/SDPA path instead calls + `build_block_mask`, which expands the indices into the dense + `[B, 1, S_q, S_k]` additive mask the standard attention interface expects + (`0` at every allowed (query, key) pair, `-inf` elsewhere). + + Like DeepSeek-V4's indexer this is purely a *selection* branch: it has no + value projection and produces no residual output of its own (the upstream + checkpoint disables the index-value path on every sparse layer). + + TODO: blocks are anchored to absolute key *slots* (the contiguous reshape in + `forward` and `q_block = slot // block_size`), so left-padding shifts the block + boundaries and the selection diverges from an unpadded run -- only right-padding + is equivalent (same limitation as DeepSeek-V4; see `test_right_padding_does_not_leak` + / the skipped `test_left_padding_compatibility`). For *true* left-padding equivalence + we'd make blocking content-relative instead of slot-relative: + 1. derive block ids from `position_ids` (content positions, 0 at each row's first + real token) rather than from absolute slots, and + 2. replace the contiguous `view(..., num_key_blocks, block_size).amax(-1)` key pool + with a per-row position-binned pool (e.g. `scatter_reduce` over `key_position // + block_size`), so pad never shifts the boundaries, and + 3. mask padded keys' scores to `-inf` before the pool so a pad key can't win a block + a top-k slot. 
+ """ + + def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = config.index_head_dim + self.num_heads = config.index_n_heads + self.block_size = config.index_block_size + self.topk_blocks = config.index_topk_blocks + self.local_blocks = config.index_local_blocks + self.q_proj = nn.Linear(config.hidden_size, config.index_n_heads * config.index_head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.index_head_dim, bias=False) + self.q_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps) + self.k_norm = MiniMaxM3VLRMSNorm(config.index_head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + past_key_values: Cache | None, + position_ids: torch.Tensor, + ) -> torch.Tensor: + batch, q_len, _ = hidden_states.shape + idx_q = self.q_proj(hidden_states).view(batch, q_len, -1, self.head_dim) + idx_q = self.q_norm(idx_q).transpose(1, 2) # [B, H_idx, Sq, D] + idx_k = self.k_proj(hidden_states).view(batch, q_len, 1, self.head_dim) + idx_k = self.k_norm(idx_k).transpose(1, 2) # [B, 1, Sq, D] + cos, sin = position_embeddings + idx_q, idx_k = apply_rotary_pos_emb(idx_q, idx_k, cos[..., : self.head_dim], sin[..., : self.head_dim]) + + if past_key_values is not None: + idx_k = past_key_values.layers[self.layer_idx].update_index(idx_k) + + k_len = idx_k.shape[2] + num_key_blocks = -(-k_len // self.block_size) # ceil-div + pad = num_key_blocks * self.block_size - k_len + + scores = torch.matmul(idx_q.float(), idx_k.float().transpose(-1, -2)) + k_positions = torch.arange(k_len, device=idx_q.device) + token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k] + scores = scores.masked_fill(token_future, float("-inf")) + if pad: + scores = F.pad(scores, (0, pad), value=float("-inf")) + scores = scores.view(batch, 
self.num_heads, q_len, num_key_blocks, self.block_size) + block_scores = scores.amax(dim=-1).amax(dim=1) # -> [B, S_q, num_key_blocks] + + q_block = position_ids // self.block_size # [B, S_q] + + if self.local_blocks > 0: + local = torch.arange(self.local_blocks, device=idx_q.device) + local_idx = (q_block[..., None] - local.view(1, 1, -1)).clamp(min=0) # [B, S_q, local] + block_scores.scatter_(-1, local_idx, float("inf")) + + # Slots that fall on a future/empty block keep their `-inf` + # score, which top-k sorts to the end, so tagging them `-1` yields left-packed block indices + # with `-1` right-padding, which is the format expected by the block-sparse attention kernel. + topk = min(self.topk_blocks, num_key_blocks) + topk_scores, topk_indices = block_scores.topk(topk, dim=-1) # [B, S_q, topk] + return topk_indices.masked_fill(topk_scores == float("-inf"), -1) + + def build_block_mask( + self, + block_indices: torch.Tensor, + attention_mask: torch.Tensor | None, + key_length: int, + dtype: torch.dtype, + device: torch.device, + position_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Build the dense 4D additive attention mask of shape `[B, 1, S_q, S_k]` (batch, head, query, key). + """ + batch, q_len, _ = block_indices.shape + num_key_blocks = -(-key_length // self.block_size) + + # Scatter the kept blocks to `0`; `-1` slots land in a throwaway column we drop afterwards. + safe = block_indices.masked_fill(block_indices < 0, num_key_blocks) + bias = block_indices.new_full((batch, q_len, num_key_blocks + 1), float("-inf"), dtype=dtype) + bias.scatter_(-1, safe, 0.0) + bias = bias[..., :num_key_blocks] + + # Broadcast the per-block keep/drop verdict back onto every key (block granularity), add head axis. + block_keep = (bias == 0.0).repeat_interleave(self.block_size, dim=-1)[..., :key_length].unsqueeze(1) + + # Compose block-selection with the existing mask, then emit a single additive float mask.
+ if attention_mask is not None: + padding_mask = attention_mask if attention_mask.dtype == torch.bool else attention_mask == 0 + keep = block_keep & padding_mask + else: + k_positions = torch.arange(key_length, device=device) + token_future = k_positions[None, None, None, :] > position_ids[:, None, :, None] # [B, 1, S_q, S_k] + keep = block_keep & ~token_future + min_dtype = torch.finfo(dtype).min + return torch.zeros(keep.shape, dtype=dtype, device=device).masked_fill(~keep, min_dtype) + + +class MiniMaxM3VLDecoderLayer(MixtralDecoderLayer): + """M3 decoder layer: per-layer dense/MoE MLP and dense/sparse attention.""" + + def __init__(self, config: MiniMaxM3VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = MiniMaxM3VLAttention(config, layer_idx) + self.mlp = ( + MiniMaxM3VLSparseMoeBlock(config) + if config.mlp_layer_types[layer_idx] == "sparse" + else MiniMaxM3VLDenseMLP(config, intermediate_size=config.dense_intermediate_size) + ) + self.input_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class MiniMaxM3VLPreTrainedModel(MiniMaxM2PreTrainedModel): + config: MiniMaxM3VLConfig | MiniMaxM3VLTextConfig + base_model_prefix = "model" + _no_split_modules = ["MiniMaxM3VLDecoderLayer", "MiniMaxM3VLVisionEncoderLayer"] + input_modalities = ("image", "video", "text") + _keys_to_ignore_on_load_unexpected = [r"(^|\.)mtp\..*"] + _supports_flash_attn = False + _supports_sdpa = True + _supports_flex_attn = False + _can_compile_fullgraph = True + _supports_attention_backend = True + _compatible_flash_implementations = ["kernels-staging/msa@v0"] + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, MiniMaxM3VLExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + 
init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, MiniMaxM3VLTopKRouter): + init.normal_(module.weight, mean=0.0, std=std) + init.zeros_(module.e_score_correction_bias) + elif isinstance(module, MiniMaxM3VLRMSNorm): + init.zeros_(module.weight) + + +class MiniMaxM3VLTextModel(MiniMaxM2Model): + config: MiniMaxM3VLTextConfig + + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__(config) + self.layers = nn.ModuleList([MiniMaxM3VLDecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + self.norm = MiniMaxM3VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if isinstance(attention_mask, dict): + causal_mask = next(iter(attention_mask.values())) + else: + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # `position_ids` is threaded 
to every layer so the sparse layers' lightning indexer can anchor + # block selection to each query's content position (see `MiniMaxM3VLIndexer`). + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class MiniMaxM3VLForCausalLM(MiniMaxM2ForCausalLM): + config: MiniMaxM3VLTextConfig + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config: MiniMaxM3VLTextConfig): + super().__init__(config) + self.model = MiniMaxM3VLTextModel(config) + self.post_init() + + +class MiniMaxM3VLVisionEmbeddings(Qwen2_5_VisionPatchEmbed): + """Patch embedding, identical to [`Qwen2_5_VisionPatchEmbed`] (reads its dims from the vision + config). The upstream checkpoint stores the conv as `patch_embedding`, renamed to the + inherited `proj` in the conversion mapping.""" + + def __init__(self, config) -> None: + nn.Module.__init__(self) + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.num_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False + ) + + +class MiniMaxM3VL3DRotaryEmbedding(nn.Module): + r"""3D RoPE for the vision tower: each patch is rotated by its `(T, H, W)` grid position. 
+ + `2 * (head_dim // 2)` rotary dims are split evenly across the three axes (each rounded + down to a multiple of 2), giving `axis_dim` dims per axis and `axis_dim // 2` frequencies:: + + |<------------------ rotated (3 * axis_dim) ------------------>|<- pass ->| + +--------------------+--------------------+--------------------+----------+ + | T (frames) | H (rows) | W (cols) | | + | axis_dim | axis_dim | axis_dim | | + +--------------------+--------------------+--------------------+----------+ + + Each axis' coordinate scales its own band of frequencies; the bands are concatenated as + `T|H|W` and duplicated via `cat([f, f])` to pair with the half-rotation in + `apply_rotary_pos_emb_vision`. Any head dims past `3 * axis_dim` are left unrotated. + """ + + def __init__(self, head_dim: int, theta: float = 10000.0, spatial_merge_size: int = 1): + super().__init__() + # `2 * (head_dim // 2)` rotary dims are split evenly across T/H/W, each axis rounded + # down to a multiple of 2. With head_dim=80 that is 26 dims/axis (39 freqs total); the + # remaining `head_dim - 3 * axis_dim` dims are never rotated (they pass through). + rope_dims = 2 * (head_dim // 2) + self.axis_dim = 2 * ((rope_dims // 3) // 2) + self.spatial_merge_size = spatial_merge_size + self.theta = theta + + def forward( + self, grid_thw: torch.Tensor, device: torch.device, dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + m = self.spatial_merge_size + coords = [] + for t, h, w in grid_thw.tolist(): + hi = torch.arange(h).unsqueeze(1).expand(-1, w) + hi = hi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + wi = torch.arange(w).unsqueeze(0).expand(h, -1) + wi = wi.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + ti = torch.arange(t).repeat_interleave(h * w) + coords.append(torch.stack([ti, hi.repeat(t), wi.repeat(t)], dim=-1)) + coords = torch.cat(coords).to(device=device, dtype=torch.float32) + + # meta device init was having trouble when it was registered. 
TODO standardize? + inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.axis_dim, 2, dtype=torch.float32, device=device) / self.axis_dim) + ) + freqs = torch.cat([coords[:, i : i + 1] * inv_freq for i in range(3)], dim=-1) + emb = torch.cat([freqs, freqs], dim=-1) + return emb.cos().to(dtype), emb.sin().to(dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # Only the first `rot_dim` head dims carry 3D RoPE; the tail passes through untouched. + rot_dim = cos.shape[-1] + cos, sin = cos[None, :, None, :], sin[None, :, None, :] + q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:] + k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:] + q_rot = q_rot * cos + rotate_half(q_rot) * sin + k_rot = k_rot * cos + rotate_half(k_rot) * sin + return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1) + + +class MiniMaxM3VLVisionAttention(CLIPAttention): + """CLIP-style vision attention; the only difference from [`CLIPAttention`] is + that queries and keys are rotated by the tower's 3D RoPE before the + (interface-dispatched) scaled dot-product attention.""" + + def __init__(self, config: MiniMaxM3VLVisionConfig): + super().__init__(config) + # The vision tower has no grouped-query attention; the shared eager kernel + # still expects this attribute to drive its (no-op) `repeat_kv`. 
+ self.num_key_value_groups = 1 + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + queries = self.q_proj(hidden_states).view(hidden_shape) + keys = self.k_proj(hidden_states).view(hidden_shape) + values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) + queries, keys = queries.transpose(1, 2), keys.transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.out_proj(attn_output), attn_weights + + +class MiniMaxM3VLVisionMLP(CLIPMLP): + pass + + +# 3D-RoPE `position_embeddings` pass via `**kwargs` for simplicity +class MiniMaxM3VLVisionEncoderLayer(CLIPEncoderLayer): + def __init__(self, config: MiniMaxM3VLVisionConfig): + super().__init__(config) + self.self_attn = MiniMaxM3VLVisionAttention(config) + self.mlp = MiniMaxM3VLVisionMLP(config) + + +@auto_docstring +class MiniMaxM3VLVisionModel(MiniMaxM3VLPreTrainedModel): + """CLIP-like vision tower with Conv3d patch embed + 3D RoPE.""" + + config: MiniMaxM3VLVisionConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": MiniMaxM3VLVisionEncoderLayer, + "attentions": MiniMaxM3VLVisionAttention, + } + + def __init__(self, config: MiniMaxM3VLVisionConfig): + super().__init__(config) + self.embeddings = 
MiniMaxM3VLVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layers = nn.ModuleList([MiniMaxM3VLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + head_dim = config.hidden_size // config.num_attention_heads + self.rotary_emb = MiniMaxM3VL3DRotaryEmbedding( + head_dim, theta=config.rope_parameters["rope_theta"], spatial_merge_size=config.spatial_merge_size + ) + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`): + The temporal, height and width of feature shape of each image. + """ + embeds = self.embeddings(pixel_values).to(self.pre_layrnorm.weight.dtype) + cos, sin = self.rotary_emb(image_grid_thw, device=embeds.device, dtype=embeds.dtype) + hidden_states = self.pre_layrnorm(embeds).unsqueeze(0) + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask=None, position_embeddings=(cos, sin), **kwargs) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states, pooler_output=hidden_states[:, 0]) + + +class MiniMaxM3VLMultiModalProjector(nn.Module): + """Projects each vision patch from `vision_config.hidden_size` to `text_config.hidden_size` + (GELU MLP), then groups `spatial_merge_size**2` neighbouring patches into the channel dim and + fuses them back to a single `text_config.hidden_size` token with a second GELU MLP.""" + + def __init__(self, config: MiniMaxM3VLConfig): + super().__init__() + text_hidden = config.text_config.hidden_size + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.projector_hidden_size, bias=True) + self.act = ACT2FN["gelu"] + self.linear_2 = nn.Linear(config.projector_hidden_size, 
text_hidden, bias=True) + self.merge_linear_1 = nn.Linear(config.merged_hidden_size, config.projector_hidden_size, bias=True) + self.merge_act = ACT2FN["gelu"] + self.merge_linear_2 = nn.Linear(config.projector_hidden_size, text_hidden, bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_2(self.act(self.linear_1(image_features))) + hidden_states = hidden_states.reshape(hidden_states.shape[0] // (self.spatial_merge_size**2), -1) + return self.merge_linear_2(self.merge_act(self.merge_linear_1(hidden_states))) + + +class MiniMaxM3VLModelOutputWithPast(LlavaModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + video_hidden_states: torch.FloatTensor | None = None + + +class MiniMaxM3VLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_image_patches, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(num_video_patches, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + video_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring(custom_intro="MiniMax M3 VL backbone (vision + projector + text), without LM head.") +class MiniMaxM3VLModel(LlavaModel): + config: MiniMaxM3VLConfig + + def __init__(self, config: MiniMaxM3VLConfig): + super().__init__(config) + self.vision_tower = MiniMaxM3VLVisionModel(config.vision_config) + self.multi_modal_projector = MiniMaxM3VLMultiModalProjector(config) + self.language_model = MiniMaxM3VLTextModel(config.text_config) + self.post_init() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.Tensor, + **kwargs, + ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. 
+ """ + # Return the raw vision-tower output (so callers can inspect hidden states / + # attentions) while stashing the projected + spatially-merged features, + # ready to scatter into the text embeddings, in `pooler_output`. + vision_outputs = self.vision_tower(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0)) + return vision_outputs + + @merge_with_config_defaults + @can_return_tuple + @auto_docstring( + custom_intro="Obtains video last hidden states from the vision tower and applies multimodal projection." + ) + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.Tensor, + **kwargs, + ) -> BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor`): + The tensors corresponding to the input video frames. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + # Video frames flow through the same vision pipeline as images (the tower is + # grid-agnostic); only the placeholder token they scatter into differs. + vision_outputs = self.vision_tower(pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, **kwargs) + vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state.squeeze(0)) + return vision_outputs + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor | None = None, + video_features: torch.FloatTensor | None = None, + ): + """ + Obtains the image/video placeholder masks from `input_ids` or `inputs_embeds`, and checks that the + placeholder token count matches the multimodal feature length. Raises if they differ.
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None: + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None: + torch_compilable_check( + inputs_embeds[special_video_mask].numel() == video_features.numel(), + f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", + ) + return special_image_mask, special_video_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | 
MiniMaxM3VLModelOutputWithPast: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, image_grid_thw=image_grid_thw + ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype) + + video_features = None + if pixel_values_videos is not None: + video_features = self.get_video_features( + pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw + ).pooler_output.to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds, image_features=image_features, video_features=video_features + ) + if image_features is not None: + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + if video_features is not None: + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return MiniMaxM3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=getattr(outputs, "hidden_states", None), + attentions=getattr(outputs, "attentions", None), + image_hidden_states=image_features, + 
video_hidden_states=video_features, + ) + + +@auto_docstring(custom_intro="MiniMax M3 VL full model with LM head (text + vision).") +class MiniMaxM3SparseForConditionalGeneration(LlavaForConditionalGeneration): + config: MiniMaxM3VLConfig + + def get_image_features(self, pixel_values, image_grid_thw, **kwargs): + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) + + def get_video_features(self, pixel_values_videos, video_grid_thw, **kwargs): + r""" + pixel_values_videos (`torch.FloatTensor`): + The tensors corresponding to the input video frames. + video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MiniMaxM3VLCausalLMOutputWithPast: + r""" + image_grid_thw (`torch.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of each image's feature grid, used to build the vision 3D RoPE + and to merge patch features. 
+ video_grid_thw (`torch.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of each video's feature grid, used to build the vision 3D RoPE + and to merge patch features. + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return MiniMaxM3VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_values_videos=None, + attention_mask=None, + logits_to_keep=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- pixel inputs are merged into the cache on the first step, so we + # only forward them once (image and video alike). 
+ model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + logits_to_keep=logits_to_keep, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + if is_first_iteration or not kwargs.get("use_cache", True): + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + + return model_inputs + + +class MiniMaxM3VLProcessorKwargs(Qwen2VLProcessorKwargs): + _defaults = { + "videos_kwargs": {"do_resize": False, "return_metadata": True}, + } + + +class MiniMaxM3VLProcessor(Qwen2VLProcessor): + """Combines tokenizer + image_processor + video_processor for MiniMax M3 VL. + + Expands `IMAGE_TOKEN` / `VIDEO_TOKEN` markers in the prompt into the matching + number of placeholder tokens (one per merged patch), wrapped in `VISION_START_TOKEN` + / `VISION_END_TOKEN` brackets. Video chunks are additionally prefixed with a + `]<]{seconds} seconds[>[` timestamp marker per frame when metadata is available. 
+ """ + + valid_processor_kwargs = MiniMaxM3VLProcessorKwargs + + IMAGE_TOKEN = "]<]image[>[" + VIDEO_TOKEN = "]<]video[>[" + VISION_START_TOKEN = "]<]start of image[>[" + VISION_END_TOKEN = "]<]end of image[>[" + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + self.image_token = self.IMAGE_TOKEN + self.video_token = self.VIDEO_TOKEN + self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) if tokenizer else None + self.video_token_id = tokenizer.convert_tokens_to_ids(self.VIDEO_TOKEN) if tokenizer else None + self.vision_start_token_id = tokenizer.convert_tokens_to_ids(self.VISION_START_TOKEN) if tokenizer else None + self.vision_end_token_id = tokenizer.convert_tokens_to_ids(self.VISION_END_TOKEN) if tokenizer else None + + def replace_image_token(self, image_inputs: dict, image_idx: int) -> str: + merge_length = self.image_processor.merge_size**2 + num_image_tokens = int(image_inputs["image_grid_thw"][image_idx].prod() // merge_length) + return self.VISION_START_TOKEN + self.IMAGE_TOKEN * num_image_tokens + self.VISION_END_TOKEN + + def replace_video_token(self, video_inputs: dict, video_idx: int) -> str: + merge_length = self.video_processor.merge_size**2 + grid_thw = video_inputs["video_grid_thw"][video_idx] + grid_t = int(grid_thw[0]) + frame_seqlen = int(grid_thw[1:].prod() // merge_length) + metadata = video_inputs.get("video_metadata", [None] * (video_idx + 1))[video_idx] + temporal_patch_size = self.video_processor.temporal_patch_size + chunk = "" + for frame in range(grid_t): + if ( + metadata is not None + and getattr(metadata, "fps", None) is not None + and getattr(metadata, "frames_indices", None) is not None + ): + ts = ( + metadata.frames_indices[min(frame * temporal_patch_size, len(metadata.frames_indices) - 1)] + / metadata.fps + ) + chunk += f"]<]{ts:.1f} seconds[>[" + chunk 
+= self.VISION_START_TOKEN + self.VIDEO_TOKEN * frame_seqlen + self.VISION_END_TOKEN + return chunk + + +__all__ = [ + "MiniMaxM3VLConfig", + "MiniMaxM3VLTextConfig", + "MiniMaxM3VLVisionConfig", + "MiniMaxM3VLForCausalLM", + "MiniMaxM3SparseForConditionalGeneration", + "MiniMaxM3VLModel", + "MiniMaxM3VLPreTrainedModel", + "MiniMaxM3VLProcessor", + "MiniMaxM3VLTextModel", + "MiniMaxM3VLVisionModel", +] diff --git a/src/transformers/models/minimax_m3_vl/processing_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/processing_minimax_m3_vl.py new file mode 100644 index 000000000000..6ab6f0517d3b --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/processing_minimax_m3_vl.py @@ -0,0 +1,153 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin +from ...utils import auto_docstring + + +class MiniMaxM3VLProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "videos_kwargs": {"do_resize": False, "return_metadata": True}, + } + + +@auto_docstring +class MiniMaxM3VLProcessor(ProcessorMixin): + """Combines tokenizer + image_processor + video_processor for MiniMax M3 VL. + + Expands `IMAGE_TOKEN` / `VIDEO_TOKEN` markers in the prompt into the matching + number of placeholder tokens (one per merged patch), wrapped in `VISION_START_TOKEN` + / `VISION_END_TOKEN` brackets. Video chunks are additionally prefixed with a + `]<]{seconds} seconds[>[` timestamp marker per frame when metadata is available. + """ + + valid_processor_kwargs = MiniMaxM3VLProcessorKwargs + + IMAGE_TOKEN = "]<]image[>[" + VIDEO_TOKEN = "]<]video[>[" + VISION_START_TOKEN = "]<]start of image[>[" + VISION_END_TOKEN = "]<]end of image[>[" + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + self.image_token = self.IMAGE_TOKEN + self.video_token = self.VIDEO_TOKEN + self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) if tokenizer else None + self.video_token_id = tokenizer.convert_tokens_to_ids(self.VIDEO_TOKEN) if tokenizer else None + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + self.vision_start_token_id = tokenizer.convert_tokens_to_ids(self.VISION_START_TOKEN) if tokenizer else None + self.vision_end_token_id = tokenizer.convert_tokens_to_ids(self.VISION_END_TOKEN) if tokenizer else None + + def replace_image_token(self, image_inputs: dict, image_idx: int) -> str: + merge_length = self.image_processor.merge_size**2 + num_image_tokens = int(image_inputs["image_grid_thw"][image_idx].prod() // merge_length) + return self.VISION_START_TOKEN + self.IMAGE_TOKEN * num_image_tokens + self.VISION_END_TOKEN + + def 
replace_video_token(self, video_inputs: dict, video_idx: int) -> str: + merge_length = self.video_processor.merge_size**2 + grid_thw = video_inputs["video_grid_thw"][video_idx] + grid_t = int(grid_thw[0]) + frame_seqlen = int(grid_thw[1:].prod() // merge_length) + metadata = video_inputs.get("video_metadata", [None] * (video_idx + 1))[video_idx] + temporal_patch_size = self.video_processor.temporal_patch_size + chunk = "" + for frame in range(grid_t): + if ( + metadata is not None + and getattr(metadata, "fps", None) is not None + and getattr(metadata, "frames_indices", None) is not None + ): + ts = ( + metadata.frames_indices[min(frame * temporal_patch_size, len(metadata.frames_indices) - 1)] + / metadata.fps + ) + chunk += f"]<]{ts:.1f} seconds[>[" + chunk += self.VISION_START_TOKEN + self.VIDEO_TOKEN * frame_seqlen + self.VISION_END_TOKEN + return chunk + + def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + video_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (num_frames, height, width) per each video. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = MiniMaxM3VLProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + if video_sizes is not None: + videos_kwargs = {**MiniMaxM3VLProcessorKwargs._defaults.get("videos_kwargs", {}), **kwargs} + merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size + num_video_patches = [ + self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) + for video_size in video_sizes + ] + num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] + vision_data["num_video_tokens"] = num_video_tokens + + return MultiModalData(**vision_data) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode` method. + + Returns: + `list[str]`: The decoded text.
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + return super().model_input_names + ["mm_token_type_ids"] + + +__all__ = ["MiniMaxM3VLProcessor"] diff --git a/src/transformers/models/minimax_m3_vl/video_processing_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/video_processing_minimax_m3_vl.py new file mode 100644 index 000000000000..c25729309e14 --- /dev/null +++ b/src/transformers/models/minimax_m3_vl/video_processing_minimax_m3_vl.py @@ -0,0 +1,137 @@ +# Copyright 2026 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Video processor for MiniMax M3 VL.""" + +import torch +from torchvision.transforms import InterpolationMode + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import PILImageResampling, SizeDict +from ...processing_utils import Unpack, VideosKwargs +from ...utils import TensorType +from ...video_processing_utils import BaseVideoProcessor +from ...video_utils import group_videos_by_shape, reorder_videos +from .image_processing_minimax_m3_vl import smart_resize + + +class MiniMaxM3VLVideoProcessorKwargs(VideosKwargs, total=False): + patch_size: int + temporal_patch_size: int + merge_size: int + min_pixels: int + max_pixels: int + total_pixels: int + min_frames: int + max_frames: int + fps: float | int + + +class MiniMaxM3VLVideoProcessor(BaseVideoProcessor): + do_resize = True + resample = PILImageResampling.BICUBIC + size = {"height": 672, "width": 672} + default_to_square = False + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = True + image_mean = [0.48145466, 0.4578275, 
0.40821073] + image_std = [0.26862954, 0.26130258, 0.27577711] + do_convert_rgb = True + do_sample_frames = False + patch_size = 14 + temporal_patch_size = 2 + merge_size = 2 + min_pixels = 4 * 28 * 28 + max_pixels = 768 * 28 * 28 + total_pixels = int(64000 * 28 * 28 * 0.9) + fps = 1.0 + min_frames = 4 + max_frames = 768 + valid_kwargs = MiniMaxM3VLVideoProcessorKwargs + model_input_names = ["pixel_values_videos", "video_grid_thw"] + + def __init__(self, **kwargs: Unpack[MiniMaxM3VLVideoProcessorKwargs]): + super().__init__(**kwargs) + + def _preprocess( + self, + videos: list[torch.Tensor], + do_convert_rgb: bool, + do_resize: bool, + size: SizeDict, + resample: PILImageResampling | InterpolationMode | int | None, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + patch_size: int, + temporal_patch_size: int, + merge_size: int, + min_pixels: int, + max_pixels: int, + return_tensors: str | TensorType | None = None, + **kwargs, + ) -> BatchFeature: + grouped, grouped_idx = group_videos_by_shape(videos) + resized_grouped = {} + factor = patch_size * merge_size + for shape, stacked in grouped.items(): + bs, nf, c, h, w = stacked.shape + rh, rw = (h, w) + if do_resize: + rh, rw = smart_resize(h, w, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels) + stacked = stacked.view(bs * nf, c, h, w) + stacked = self.resize(stacked, size=SizeDict(height=rh, width=rw), resample=resample) + stacked = stacked.view(bs, nf, c, rh, rw) + resized_grouped[shape] = stacked + resized = reorder_videos(resized_grouped, grouped_idx) + + grouped, grouped_idx = group_videos_by_shape(resized) + processed_grouped = {} + grids = {} + for shape, stacked in grouped.items(): + rh, rw = stacked.shape[-2:] + patches = self.rescale_and_normalize( + stacked, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if pad := -patches.shape[1] % temporal_patch_size: + repeats = 
patches[:, -1:].expand(-1, pad, -1, -1, -1) + patches = torch.cat([patches, repeats], dim=1) + bs, grid_t, c = patches.shape[:3] + grid_t = grid_t // temporal_patch_size + grid_h, grid_w = rh // patch_size, rw // patch_size + patches = patches.view( + bs, + grid_t, + temporal_patch_size, + c, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flat = patches.reshape(bs, grid_t * grid_h * grid_w, c * temporal_patch_size * patch_size * patch_size) + processed_grouped[shape] = flat + grids[shape] = [[grid_t, grid_h, grid_w]] * bs + + processed = reorder_videos(processed_grouped, grouped_idx) + grids = reorder_videos(grids, grouped_idx) + pixel_values_videos = torch.cat(processed, dim=0) + video_grid_thw = torch.tensor(grids, dtype=torch.long) + return BatchFeature( + data={"pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw}, + tensor_type=return_tensors, + ) + + +__all__ = ["MiniMaxM3VLVideoProcessor"] diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index d053401be153..4fe470f8727a 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -90,6 +90,10 @@ "vptq": VptqHfQuantizer, "spqr": SpQRHfQuantizer, "fp8": FineGrainedFP8HfQuantizer, + # MXFP8 = FP8 (E4M3 weights) with per-block ``[1, 32]`` E8M0 (uint8) scales — + # reuses the FineGrainedFP8 dequant path, with the E8M0 byte→exponent + # unpacking handled inside ``Fp8Dequantize._dequantize_one``. 
+ "mxfp8": FineGrainedFP8HfQuantizer, "auto-round": AutoRoundQuantizer, "mxfp4": Mxfp4HfQuantizer, "metal": MetalHfQuantizer, @@ -117,6 +121,7 @@ "vptq": VptqConfig, "spqr": SpQRConfig, "fp8": FineGrainedFP8Config, + "mxfp8": FineGrainedFP8Config, "auto-round": AutoRoundConfig, "mxfp4": Mxfp4Config, "metal": MetalConfig, diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 0d300e89e954..011f45e6a1a5 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -94,6 +94,27 @@ def param_element_size(self, model: "PreTrainedModel", param_name: str, param: " return 1 return super().param_element_size(model, param_name, param) + def _normalize_modules_to_not_convert(self, model: "PreTrainedModel"): + """Rewrite the skip-list to the model's own module tree. + For models that were already released, if they have a list of modules to not quantize + we need to apply the weight renaming / weight conversion operations to get the actual + layer name of the model in `transformers`. 
+ """ + skip = self.quantization_config.modules_to_not_convert + if not skip: + return + + from ..conversion_mapping import get_model_conversion_mapping + + renamings = get_model_conversion_mapping(model) + remapped = [] + for name in skip: + renamed = name + for rename in renamings: + renamed, _ = rename.rename_source_key(renamed) + remapped.append(renamed) + self.quantization_config.modules_to_not_convert = remapped + def _process_model_before_weight_loading( self, model: "PreTrainedModel", @@ -101,6 +122,7 @@ def _process_model_before_weight_loading( ): from ..integrations.finegrained_fp8 import replace_with_fp8_linear + self._normalize_modules_to_not_convert(model) self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules ) @@ -182,6 +204,41 @@ def get_weight_conversions(self): ] return [] + def _is_mxfp8(self) -> bool: + """MXFP8 checkpoints ship E8M0 (uint8) per-block scales; plain FP8 ships float32.""" + quant_method = getattr(self.quantization_config, "quant_method", None) + return quant_method == "mxfp8" + + def _update_weight_conversions_mxfp8(self, weight_conversions): + """ + Native MXFP8 path: prepend a `Fp8DecodeScale` op so the uint8 E8M0 + scales are decoded to float32 `2 ** (byte - 127)` *before* any merge/concat op + and add a generic fallback converter that decodes the scales of plain `FP8Linear` weights (attention / dense projections) + which have no model-specific converter. 
+ """ + from ..core_model_loading import WeightConverter + from ..integrations.finegrained_fp8 import Fp8DecodeScale + + updated: list = [] + for conv in weight_conversions: + if isinstance(conv, WeightConverter) and any(p.endswith(".weight") for p in conv.source_patterns): + conv = WeightConverter( + source_patterns=conv.source_patterns, + target_patterns=conv._original_target_patterns, + operations=[Fp8DecodeScale(self)] + list(conv.operations), + ) + updated.append(conv) + # Generic fallback for plain ``nn.Linear`` scales with no model-specific converter. + # Listed last so the model converters above win the first-match for expert/dense scales. + updated.append( + WeightConverter( + source_patterns=["weight_scale_inv"], + target_patterns="weight_scale_inv", + operations=[Fp8DecodeScale(self)], + ) + ) + return updated + def update_weight_conversions(self, weight_conversions): """When loading with ``dequantize=True``, attach an :class:`Fp8Dequantize` op to every existing :class:`WeightConverter` so that per-block scales are folded into @@ -213,6 +270,9 @@ def update_weight_conversions(self, weight_conversions): weight_conversions = [scale_rename] + list(weight_conversions) if not (self.pre_quantized and self.quantization_config.dequantize): + if self.pre_quantized and self._is_mxfp8(): + # mxfp8 needs a pre-processing on the scales when not dequantizing + return self._update_weight_conversions_mxfp8(weight_conversions) return weight_conversions + self.get_weight_conversions() updated: list = [] diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e2235060b85f..ef7a18083da6 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -60,6 +60,7 @@ class QuantizationMethod(str, Enum): FPQUANT = "fp_quant" AUTOROUND = "auto-round" MXFP4 = "mxfp4" + MXFP8 = "mxfp8" METAL = "metal" FOUR_OVER_SIX = "fouroversix" SINQ = "sinq" @@ -1662,7 +1663,10 @@ def 
__init__( scale_fmt: str = "float", **kwargs, ): - self.quant_method = QuantizationMethod.FP8 + self.quant_method = kwargs.pop("quant_method", QuantizationMethod.FP8) + # MiniMax ships the skip-list under ``ignored_layers``; accept it as an alias. + if modules_to_not_convert is None and "ignored_layers" in kwargs: + modules_to_not_convert = kwargs.pop("ignored_layers") self.modules_to_not_convert = modules_to_not_convert self.activation_scheme = activation_scheme self.weight_block_size = weight_block_size diff --git a/tests/fixtures/tests_samples/COCO/apple.jpg b/tests/fixtures/tests_samples/COCO/apple.jpg new file mode 100644 index 000000000000..f055aecd4bb4 Binary files /dev/null and b/tests/fixtures/tests_samples/COCO/apple.jpg differ diff --git a/tests/models/minimax_m3_vl/__init__.py b/tests/models/minimax_m3_vl/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/minimax_m3_vl/test_modeling_minimax_m3_vl.py b/tests/models/minimax_m3_vl/test_modeling_minimax_m3_vl.py new file mode 100644 index 000000000000..359cf1f52e3a --- /dev/null +++ b/tests/models/minimax_m3_vl/test_modeling_minimax_m3_vl.py @@ -0,0 +1,652 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch MiniMax-M3-VL model.""" + +import copy +import unittest + +from parameterized import parameterized + +from transformers import ( + AutoTokenizer, + MiniMaxM3SparseForConditionalGeneration, + MiniMaxM3VLConfig, + MiniMaxM3VLImageProcessorFast, + MiniMaxM3VLModel, + MiniMaxM3VLProcessor, + MiniMaxM3VLVideoProcessor, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_BATCHED_AND_GROUPED_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_batched_and_grouped_inference, + floats_tensor, + ids_tensor, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from PIL import Image + + +class MiniMaxM3VLVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + ignore_index=-100, + image_token_index=4, + video_token_index=5, + is_training=True, + text_config={ + "hidden_size": 32, + "intermediate_size": 64, + "dense_intermediate_size": 128, + "shared_intermediate_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 32, + "rotary_dim": 16, + "hidden_act": "silu", + "max_position_embeddings": 512, + "rms_norm_eps": 1e-6, + "vocab_size": 99, + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "num_local_experts": 4, + "num_experts_per_tok": 2, + "n_shared_experts": 1, + "moe_layer_freq": [0, 1], + "layer_types": [ + "full_attention", + "minimax_m3_sparse", + ], + "use_routing_bias": True, + "routed_scaling_factor": 2.0, + "swiglu_alpha": 1.702, + "swiglu_limit": 7.0, + "tie_word_embeddings": False, + "rope_parameters": { + "rope_type": "default", + "rope_theta": 5000000.0, + 
"partial_rotary_factor": 0.5, + }, + "index_n_heads": 2, + "index_head_dim": 16, + "index_block_size": 8, + "index_topk_blocks": 4, + "index_local_blocks": 1, + }, + vision_config={ + "hidden_size": 32, + "intermediate_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_channels": 3, + "image_size": 14, + "patch_size": 14, + "temporal_patch_size": 2, + "spatial_merge_size": 1, + "rope_theta": 10000.0, + }, + ): + self.parent = parent + self.batch_size = batch_size + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.is_training = is_training + self.text_config = text_config + self.vision_config = vision_config + + self.pad_token_id = text_config["pad_token_id"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + self.hidden_size = text_config["hidden_size"] + self.vocab_size = text_config["vocab_size"] + + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.patch_size = vision_config["patch_size"] + self.temporal_patch_size = vision_config["temporal_patch_size"] + self.spatial_merge_size = vision_config["spatial_merge_size"] + + # One patch per image (grid [1, 1, 1]) so that the generation common tests, which crop + # all inputs along the batch dim, keep ``pixel_values`` and ``image_grid_thw`` consistent. 
+ self.num_patches = 1 + self.num_image_tokens = self.num_patches // (self.spatial_merge_size**2) + self.seq_length = seq_length + self.num_image_tokens + self.encoder_seq_length = self.seq_length + + def get_config(self): + return MiniMaxM3VLConfig( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_index=self.image_token_index, + video_token_index=self.video_token_index, + projector_hidden_size=self.text_config["hidden_size"], + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_dim = self.num_channels * (self.patch_size**2) * self.temporal_patch_size + pixel_values = floats_tensor([self.batch_size * self.num_patches, patch_dim]) + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config, pixel_values = self.prepare_config_and_inputs() + + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2 + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + input_ids[input_ids == self.image_token_index] = self.pad_token_id + input_ids[input_ids == self.video_token_index] = self.pad_token_id + input_ids[:, : self.num_image_tokens] = self.image_token_index + + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class MiniMaxM3VLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Model tester for `MiniMaxM3SparseForConditionalGeneration`. 
+ """ + + all_model_classes = ( + ( + MiniMaxM3VLModel, + MiniMaxM3SparseForConditionalGeneration, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "image-text-to-text": MiniMaxM3SparseForConditionalGeneration, + } + if is_torch_available() + else {} + ) + _is_composite = True + # The vision tower packs every image's (and video frame's) patches into a single sequence + # (batch dim 1), so ``last_hidden_state`` does not carry a per-item batch axis to shape-check. + skip_test_image_features_output_shape = True + skip_test_video_features_output_shape = True + + # The indexer parameters only influence the argmax over compressed blocks (``topk``), + # which is non-differentiable; their gradients flow through a separate objective in + # the upstream training recipe, not the main causal-LM loss (same as DeepSeek-V4). + test_all_params_have_gradient = False + + def setUp(self): + self.model_tester = MiniMaxM3VLVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=MiniMaxM3VLConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="Unclear exactly why this fails; can be addressed later") + def test_reverse_loading_mapping(self): + pass + + @unittest.skip( + reason=( + "This model lists the Lightning-indexer MSA kernel in `_compatible_flash_implementations`, so " + "requesting `flash_attention_2` is auto-redirected to `kernels-staging/msa@v0` (the model's " + "native flash path) instead of raising. The base test instead expects a `ValueError` whenever " + "the composite sub-models don't all set `_supports_flash_attn`, an invariant that does not hold " + "for a model whose flash implementation *is* a custom kernel. MSA dispatch via the public " + "`attn_implementation` API is covered by the slow integration tests." 
+ ) + ) + def test_flash_attn_2_can_dispatch_composite_models(self): + pass + + @parameterized.expand(TEST_EAGER_MATCHES_BATCHED_AND_GROUPED_INFERENCE_PARAMETERIZATION) + def test_eager_matches_batched_and_grouped_inference(self, name, dtype): + # In low precision the grouped/batched/sonic expert kernels accumulate in a different order + # than the eager per-expert loop, so a handful of near-zero MoE outputs drift past the 1e-4 + # tolerance. This is precision noise, not a logic mismatch (fp32 matches exactly). + if dtype in ("fp16", "bf16"): + self.skipTest("Low-precision float casting fluctuations across expert kernels exceed the 1e-4 tolerance") + _test_eager_matches_batched_and_grouped_inference(self, name, dtype) + + @unittest.skip( + reason=( + "The lightning indexer tiles the key axis into blocks of `index_block_size` *slots* and " + "selects whole blocks, so its block boundaries are anchored to absolute sequence slots. " + "Left-padding shifts every real token by the (per-row, generally non-block-aligned) pad " + "width, which regroups real keys into different blocks than the unpadded run and changes " + "which blocks win top-k — so left-padded logits diverge by design. This is the same " + "block-sparse limitation as DeepSeek-V4; see `test_right_padding_does_not_leak` for the " + "padding direction that *is* equivalent." + ) + ) + def test_left_padding_compatibility(self): + pass + + def test_right_padding_does_not_leak(self): + """Right-padding must not change a sequence's real-token logits. + + Pad keys land on slots *after* every real token, so block-level causality drops them before + top-k selection and the indexer's folded mask zeroes the pad columns. The real tokens therefore + occupy the same slots and see the same key blocks as in an unpadded run -- unlike left-padding, + which shifts the slot-anchored block boundaries (see `test_left_padding_compatibility`). 
+ """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config.text_config._attn_implementation = "eager" + vocab = config.text_config.vocab_size + pad_id = config.text_config.pad_token_id + # Text-only inputs: keep ids clear of the image/video placeholder ids so no vision tower runs. + low = max(self.model_tester.image_token_index, self.model_tester.video_token_index) + 1 + torch.manual_seed(0) + lengths = [self.model_tester.seq_length + 6, self.model_tester.seq_length + 1] + seqs = [torch.randint(low, vocab - 2, (n,), device=torch_device) for n in lengths] + max_len = max(lengths) + + model = MiniMaxM3SparseForConditionalGeneration(config).to(torch_device).eval() + + per_seq_logits = [] + for seq in seqs: + with torch.no_grad(): + out = model( + input_ids=seq[None], + attention_mask=torch.ones(1, len(seq), dtype=torch.long, device=torch_device), + ) + per_seq_logits.append(out.logits[0, : len(seq)]) + + input_ids = torch.full((len(seqs), max_len), pad_id, device=torch_device) + attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.long, device=torch_device) + for i, seq in enumerate(seqs): + input_ids[i, : len(seq)] = seq + attention_mask[i, : len(seq)] = 1 + with torch.no_grad(): + batched_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits + + for i, seq in enumerate(seqs): + torch.testing.assert_close(batched_logits[i, : len(seq)], per_seq_logits[i], rtol=1e-4, atol=1e-4) + + def test_mismatching_num_image_tokens(self): + """ + Tests that VLMs raise an explicit error when the number of images doesn't match the number + of image tokens in the text, and that genuine multi-image cases are accepted. 
+ """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + num_patches = self.model_tester.num_patches + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + curr_input_dict = copy.deepcopy(input_dict) + _ = model(**curr_input_dict) # successful forward with no modifications + + # remove one image but leave its image tokens in text + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][:-num_patches, ...] + curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][:-1, ...] + with self.assertRaisesRegex(ValueError, "Image features and image tokens do not match"): + _ = model(**curr_input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:num_patches] + image_grid_thw = curr_input_dict["image_grid_thw"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # two image-token groups but one image raises an error + with self.assertRaisesRegex(ValueError, "Image features and image tokens do not match"): + _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + + # two images and two image-token groups don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) + _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + + def test_video_forward(self): + """Video frames flow through the same vision tower as images and scatter into the video-token slots.""" + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + batch_size = self.model_tester.batch_size + num_channels = self.model_tester.num_channels + temporal_patch_size = self.model_tester.temporal_patch_size + patch_size = self.model_tester.patch_size + merge_size = 
self.model_tester.spatial_merge_size + + num_frames = 4 + grid_t = num_frames // temporal_patch_size + grid_h = self.model_tester.image_size // patch_size + grid_w = self.model_tester.image_size // patch_size + patches_per_video = grid_t * grid_h * grid_w + # ``patch_merge`` groups ``merge_size**2`` patches, so each video yields this many tokens. + tokens_per_video = patches_per_video // (merge_size**2) + + patch_dim = num_channels * (patch_size**2) * temporal_patch_size + pixel_values_videos = floats_tensor([batch_size * patches_per_video, patch_dim]) + video_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]] * batch_size, device=torch_device) + # The vision tower consumes exactly ``grid_t * grid_h * grid_w`` patches per video. + self.assertEqual(pixel_values_videos.shape[0], int(video_grid_thw.prod(dim=1).sum())) + + input_ids = ids_tensor([batch_size, self.model_tester.seq_length], config.text_config.vocab_size - 2) + 2 + input_ids[input_ids == self.model_tester.image_token_index] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.video_token_index] = self.model_tester.pad_token_id + # Carve out one contiguous block of video-token slots per sequence. 
+ self.assertLessEqual(tokens_per_video, self.model_tester.seq_length) + input_ids[:, :tokens_per_video] = self.model_tester.video_token_index + attention_mask = torch.ones_like(input_ids) + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + self.assertIsNotNone(outputs) + self.assertIsNotNone(outputs.video_hidden_states) + self.assertEqual(outputs.video_hidden_states.shape[0], batch_size * tokens_per_video) + + def test_mismatching_num_video_tokens(self): + """VLMs must raise when the number of videos doesn't match the number of video tokens in the text.""" + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + batch_size = self.model_tester.batch_size + num_channels = self.model_tester.num_channels + temporal_patch_size = self.model_tester.temporal_patch_size + patch_size = self.model_tester.patch_size + merge_size = self.model_tester.spatial_merge_size + + num_frames = 4 + grid_t = num_frames // temporal_patch_size + grid_h = self.model_tester.image_size // patch_size + grid_w = self.model_tester.image_size // patch_size + patches_per_video = grid_t * grid_h * grid_w + tokens_per_video = patches_per_video // (merge_size**2) + + patch_dim = num_channels * (patch_size**2) * temporal_patch_size + pixel_values_videos = floats_tensor([batch_size * patches_per_video, patch_dim]) + video_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]] * batch_size, device=torch_device) + + input_ids = ids_tensor([batch_size, self.model_tester.seq_length], config.text_config.vocab_size - 2) + 2 + input_ids[input_ids == self.model_tester.image_token_index] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.video_token_index] = self.model_tester.pad_token_id + # One fewer video-token slot than features -> 
mismatch. + input_ids[:, : tokens_per_video - 1] = self.model_tester.video_token_index + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + with self.assertRaisesRegex(ValueError, "Video features and video tokens do not match"): + _ = model( + input_ids=input_ids, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + +@slow +@require_torch +class MiniMaxM3VLIntegrationTest(unittest.TestCase): + model_id = "MiniMaxAI/Minimax-M3-preview" + + def _load_model(self): + # The indexer feeds SDPA an additive float mask (the block-sparse bias). On B200 + this + # cuDNN build, the cuDNN SDPA backend segfaults in ``run_cudnn_SDP_fprop`` on such masks; + # disabling it routes SDPA to the mem-efficient backend, which handles additive float masks. + torch.backends.cuda.enable_cudnn_sdp(False) + + # Out-of-the-box load: the MXFP8 ``quantization_config`` (quant_method, weight_block_size, + # and the ``ignored_layers`` skip-list) is read straight from the checkpoint's config.json + # and dispatched automatically — no hand-built quant config, no manual dtype patching. 
+ model = MiniMaxM3SparseForConditionalGeneration.from_pretrained( + self.model_id, + dtype=torch.bfloat16, + device_map="auto", + local_files_only=True, + ) + model.eval() + return model + + def _load_processor(self): + from transformers.utils import cached_file + + tokenizer = AutoTokenizer.from_pretrained(self.model_id, local_files_only=True) + image_processor = MiniMaxM3VLImageProcessorFast.from_pretrained(self.model_id, local_files_only=True) + video_processor = MiniMaxM3VLVideoProcessor.from_pretrained(self.model_id, local_files_only=True) + with open(cached_file(self.model_id, "chat_template.jinja", local_files_only=True)) as f: + chat_template = f.read() + return MiniMaxM3VLProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + video_processor=video_processor, + chat_template=chat_template, + ) + + @staticmethod + def _prompt(processor, question): + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}] + return processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False, thinking_mode="disabled" + ) + + def test_image_and_text_generation(self): + model = self._load_model() + processor = self._load_processor() + image = Image.new("RGB", (672, 672), (127, 127, 127)) + text = self._prompt(processor, "Describe this image briefly.") + inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device) + with torch.no_grad(): + output = model.generate(**inputs, max_new_tokens=32, do_sample=False) + decoded = processor.batch_decode(output, skip_special_tokens=True)[0] + print(f"\n[test_image_and_text_generation] generation:\n{decoded!r}\n") + self.assertIsInstance(decoded, str) + self.assertGreater(len(decoded.strip()), 0) + + def test_real_image_apple_recognition(self): + import os + + model = self._load_model() + processor = self._load_processor() + + apple_path = os.path.join(os.path.dirname(__file__), "../../fixtures/tests_samples/COCO/apple.jpg") + image = 
Image.open(apple_path).convert("RGB") + text = self._prompt(processor, "What fruit is shown in this image? Answer in one word.") + inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device) + with torch.no_grad(): + output = model.generate(**inputs, max_new_tokens=16, do_sample=False) + completion = processor.batch_decode(output[:, inputs.input_ids.size(1) :], skip_special_tokens=True)[0] + print(f"\n[test_real_image_apple_recognition] completion:\n{completion!r}\n") + self.assertIn("apple", completion.lower()) + + def test_batched_image_generation(self): + """Batch of two image+text prompts, no padding. + + Mirrors the DeepSeek-V4 multi-prompt integration check: two distinct solid-color + images are batched with an identical prompt, so both rows tokenize to the same + length and no padding is needed. A correct run must describe each image with its + own color, proving the vision features stay aligned with the right tokens across + the batch and the MXFP8 MoE path. + + Padding is deliberately avoided: the block-sparse indexer anchors blocks to absolute + key *slots* (see ``MiniMaxM3VLIndexer`` docstring), so left-padding shifts block + boundaries and diverges from an unpadded run — the same slot-based limitation as + DeepSeek-V4. ``test_left_padding_compatibility`` documents that gap. + """ + model = self._load_model() + processor = self._load_processor() + + red = Image.new("RGB", (672, 672), (200, 30, 30)) + blue = Image.new("RGB", (672, 672), (30, 30, 200)) + # Identical prompt + identical image geometry → identical token length → no padding. + question = "What is the dominant color of this image? Answer in one word." 
+ texts = [self._prompt(processor, question), self._prompt(processor, question)] + inputs = processor(images=[red, blue], text=texts, return_tensors="pt").to(model.device) + self.assertEqual(inputs.input_ids.shape[0], 2) + with torch.no_grad(): + output = model.generate(**inputs, max_new_tokens=32, do_sample=False) + completions = processor.batch_decode(output[:, inputs.input_ids.size(1) :], skip_special_tokens=True) + print(f"\n[test_batched_image_generation] completions:\n{completions!r}\n") + self.assertEqual(len(completions), 2) + for completion in completions: + self.assertGreater(len(completion.strip()), 0) + # Each completion should name its own image's color, not the other's. + self.assertIn("red", completions[0].lower()) + self.assertIn("blue", completions[1].lower()) + + def test_padding_sides_text_and_image(self): + """Lock the batched-padding contract end to end, for both text and image batches. + + Two facts, both consequences of the slot-anchored Lightning indexer (see + ``MiniMaxM3VLIndexer``) plus causal attention: + + * RIGHT padding does not change a real token's prediction. Pad keys land on slots *after* + every real token, so causal attention + the folded indexer mask drop them before any real + query attends to them. The MoE still *runs* on the pad rows (there is no token-level mask + inside the router), but that wasted compute never flows back into a real token -- so the + well-known right-padding generation garbage is purely a "generation continues from a pad + slot" artifact, NOT a missing-mask bug in the MoE. We assert the greedy next-token at each + real position is identical to an unpadded single-sequence run. + * LEFT padding is therefore the side to use for batched ``generate``: every real token's + continuation stays anchored to the live slots, so each row decodes coherently. 
+ """ + import requests + + model = self._load_model() + processor = self._load_processor() + tokenizer = processor.tokenizer + + # ---- RIGHT padding: real-token greedy predictions must match an unpadded run (no MoE leak) ---- + prompts = [ + "The history of computing began with", + "Summarize the theory of relativity in a single concise sentence:", + ] + per_seq = [tokenizer(p, return_tensors="pt").to(model.device) for p in prompts] + per_seq_logits = [] + for enc in per_seq: + with torch.no_grad(): + per_seq_logits.append(model(**enc).logits[0, : enc.input_ids.size(1)]) + + tokenizer.padding_side = "right" + right = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + with torch.no_grad(): + right_logits = model(**right).logits + for i, enc in enumerate(per_seq): + n = enc.input_ids.size(1) + self.assertTrue( + torch.equal(right_logits[i, :n].argmax(-1), per_seq_logits[i].argmax(-1)), + "Right-padding changed a real token's greedy prediction -> padding leaked into the " + "MoE/attention. 
The MoE may compute on pad rows, but causal attention + the indexer " + "mask must keep that out of every real token.", + ) + + # ---- LEFT padding: batched text generation decodes each row coherently ---- + tokenizer.padding_side = "left" + left = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + with torch.no_grad(): + gen = model.generate(**left, max_new_tokens=24, do_sample=False) + completions = tokenizer.batch_decode(gen[:, left.input_ids.size(1) :], skip_special_tokens=True) + print(f"\n[test_padding_sides_text_and_image] left-pad text completions:\n{completions!r}\n") + for completion in completions: + self.assertGreater(len(completion.strip()), 0) + + # ---- LEFT padding with an image batch of differing prompt lengths (real padding) ---- + # Two real, semantically distinct images downloaded from the hub/web (the picsum dog and the + # canonical COCO two-cats photo), paired with deliberately different-length questions so the + # batch genuinely needs padding. + dog = Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw).convert("RGB") + cats = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ).convert("RGB") + texts = [ + self._prompt(processor, "What animal is this? 
One word."), + self._prompt( + processor, + "Look carefully at this photograph and tell me which animal appears in it, " + "answering with a single word.", + ), + ] + processor.tokenizer.padding_side = "left" + img_inputs = processor(images=[dog, cats], text=texts, return_tensors="pt", padding=True).to(model.device) + self.assertEqual(img_inputs.input_ids.shape[0], 2) + with torch.no_grad(): + img_gen = model.generate(**img_inputs, max_new_tokens=16, do_sample=False) + img_completions = processor.batch_decode(img_gen[:, img_inputs.input_ids.size(1) :], skip_special_tokens=True) + print(f"\n[test_padding_sides_text_and_image] left-pad image completions:\n{img_completions!r}\n") + self.assertIn("dog", img_completions[0].lower()) + self.assertIn("cat", img_completions[1].lower()) + + def test_video_generation(self): + """End-to-end video path: the processor emits ``pixel_values_videos`` / ``video_grid_thw`` and the + model scatters the video features into the video-token slots before generating. + + Uses a short synthetic clip rather than a network fetch: 672 is divisible by the vision tower's + ``patch_size * spatial_merge_size`` factor (28) and the frame count is a multiple of + ``temporal_patch_size``, so the video grid is well formed and the merged-patch count lines up + exactly with the expanded video tokens. + """ + import numpy as np + + model = self._load_model() + processor = self._load_processor() + + num_frames = 4 + video = np.zeros((num_frames, 672, 672, 3), dtype=np.uint8) + video[..., 0] = 200 # a solid red-ish clip + + messages = [ + { + "role": "user", + "content": [ + {"type": "video"}, + {"type": "text", "text": "What is the dominant color in this video? 
Answer in one word."}, + ], + } + ] + text = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False, thinking_mode="disabled" + ) + inputs = processor(videos=[video], text=text, return_tensors="pt").to(model.device) + # The processor must have produced the video tensors the model consumes. + self.assertIn("pixel_values_videos", inputs) + self.assertIn("video_grid_thw", inputs) + with torch.no_grad(): + output = model.generate(**inputs, max_new_tokens=32, do_sample=False) + decoded = processor.batch_decode(output[:, inputs.input_ids.size(1) :], skip_special_tokens=True)[0] + print(f"\n[test_video_generation] generation:\n{decoded!r}\n") + self.assertIsInstance(decoded, str) + self.assertGreater(len(decoded.strip()), 0) + self.assertIn("red", decoded.lower()) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 25a4f7e00edc..9b176987ccc6 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -48,6 +48,7 @@ "pooling_kernel_size", ], # Used as meta data for other attributes/properties "MiniCPMV4_6Config": ["drop_vision_last_layer"], + "MiniMaxM3VLTextConfig": ["rotary_dim", "router_jitter_noise"], "OpenAIPrivacyFilterConfig": ["classifier_dropout", "output_router_logits", "router_aux_loss_coef"], "HYV3Config": ["output_router_logits"], "NougatConfig": ["decoder", "encoder"], diff --git a/utils/check_repo.py b/utils/check_repo.py index 778f9ecdbb42..e026cba796eb 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -113,6 +113,7 @@ "SmolVLMVisionTransformer", "MiniCPMV4_6VisionPreTrainedModel", "MiniCPMV4_6VisionModel", + "MiniMaxM3VLVisionModel", "AriaTextForCausalLM", "AriaTextModel", "Phi4MultimodalAudioModel", @@ -223,6 +224,9 @@ "PaddleOCRTextModel", # Building part of bigger (tested) model. Tested implicitly through PaddleOCRVLForConditionalGeneration. "Qwen2VLModel", # Building part of bigger (tested) model. 
Tested implicitly through Qwen2VLForConditionalGeneration. "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. + "MiniMaxM3VLForCausalLM", # Building part of bigger (tested) model. Tested implicitly through MiniMaxM3SparseForConditionalGeneration. + "MiniMaxM3VLTextModel", # Building part of bigger (tested) model. Tested implicitly through MiniMaxM3SparseForConditionalGeneration. + "MiniMaxM3VLVisionModel", # Building part of bigger (tested) model. Tested implicitly through MiniMaxM3SparseForConditionalGeneration. "Qwen3VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLForConditionalGeneration. "Qwen3VLMoeModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLMoeForConditionalGeneration. "Qwen3VLTextModel", # Building part of bigger (tested) model. diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 33a71e4f1fd7..92ab1206502d 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -257,6 +257,7 @@ def __init__(self, hub_dataset: str): self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() self.device = self.model.device + self.dtype = self.model.dtype self.index_dir: Path | None = None # ---------- HUB IO ----------