Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
f1a80e5
Add MiniMax M3 VL: draft scaffold with modular, tiny fixtures, MXFP8 …
ArthurZucker Jun 3, 2026
379343f
Strip dead branches + port lightning indexer from deepseek_v4
ArthurZucker Jun 3, 2026
dea986c
KV-cache the sparse index branch + add minimax.py generation demo
ArthurZucker Jun 3, 2026
150f816
updates
ArthurZucker Jun 3, 2026
01a05aa
updates
ArthurZucker Jun 3, 2026
b297772
Wire video into M3 VL forward + native MXFP8 path
ArthurZucker Jun 4, 2026
fcb2c08
Stop tracking local dev scripts (minimax.py demo, build_tiny_model.py)
ArthurZucker Jun 4, 2026
61425ab
update
ArthurZucker Jun 4, 2026
082d4f4
cleanup
ArthurZucker Jun 4, 2026
4c9d1d1
up
ArthurZucker Jun 8, 2026
2072307
up
ArthurZucker Jun 8, 2026
ebd4916
update
ArthurZucker Jun 8, 2026
3fbebb6
updates
ArthurZucker Jun 8, 2026
6cb103f
nits
ArthurZucker Jun 8, 2026
b7b4f36
update
ArthurZucker Jun 8, 2026
02dcbfc
nits
ArthurZucker Jun 8, 2026
979f9e0
more cleanup
ArthurZucker Jun 8, 2026
064c3fc
add visuals
ArthurZucker Jun 8, 2026
2b56365
more cleanups
ArthurZucker Jun 8, 2026
3d714ec
up
ArthurZucker Jun 8, 2026
c907c1d
nits
ArthurZucker Jun 8, 2026
7897ac9
nits
ArthurZucker Jun 9, 2026
c5ce148
man made
ArthurZucker Jun 9, 2026
1b585f1
UP
ArthurZucker Jun 9, 2026
f73b3b4
up
ArthurZucker Jun 9, 2026
1982311
fix
ArthurZucker Jun 9, 2026
4caead2
revert some modular stuff
ArthurZucker Jun 9, 2026
8de6991
fix bug, rename
ArthurZucker Jun 11, 2026
c366c09
update
ArthurZucker Jun 11, 2026
0291279
update with kernel on the hub
ArthurZucker Jun 11, 2026
a5ee89b
comments
ArthurZucker Jun 11, 2026
441a7d0
small update
ArthurZucker Jun 11, 2026
128919e
start adressing comments
ArthurZucker Jun 12, 2026
374e0d5
?
ArthurZucker Jun 12, 2026
544ef85
nits
ArthurZucker Jun 12, 2026
a48a461
update
ArthurZucker Jun 12, 2026
833c0cb
nits
ArthurZucker Jun 12, 2026
d881c55
update
ArthurZucker Jun 12, 2026
3cbbb60
latest changes
ArthurZucker Jun 12, 2026
457ab8d
up
ArthurZucker Jun 12, 2026
aaf9c04
fix padding
ArthurZucker Jun 12, 2026
0998b47
nit to define pos ids once
ArthurZucker Jun 12, 2026
290b846
just a nit
ArthurZucker Jun 12, 2026
36e84e1
Merge remote-tracking branch 'up/main' into up/add-minimax-m3vl
ArthurZucker Jun 12, 2026
fa48718
update
ArthurZucker Jun 12, 2026
272e1d2
nits
ArthurZucker Jun 12, 2026
2007e76
date
ArthurZucker Jun 12, 2026
0b896d3
skip 1 test
ArthurZucker Jun 12, 2026
22f4677
nits
ArthurZucker Jun 12, 2026
6af44e0
extra test
ArthurZucker Jun 12, 2026
a4fce3a
delete pos ids "fix"
ArthurZucker Jun 12, 2026
ee57ebd
last nit
ArthurZucker Jun 12, 2026
7d5a5eb
fix
ArthurZucker Jun 12, 2026
a83af39
fix
ArthurZucker Jun 12, 2026
c6615c8
fix
ArthurZucker Jun 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,8 @@
title: MiniMax
- local: model_doc/minimax_m2
title: MiniMax-M2
- local: model_doc/minimax_m3_vl
title: MiniMax-M3-VL
- local: model_doc/ministral
title: Ministral
- local: model_doc/ministral3
Expand Down
253 changes: 253 additions & 0 deletions docs/source/en/model_doc/minimax_m3_vl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
<!--Copyright 2026 the HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.

-->
*This model was contributed to Hugging Face Transformers on 2026-06-12.*


# MiniMax-M3-VL

## Overview

MiniMax-M3-VL is the vision-language member of the MiniMax-M3 family. It pairs a CLIP-style vision tower (Conv3d patch embedding with 3D rotary position embeddings) with the MiniMax-M3 text backbone, a mixed dense/sparse Mixture-of-Experts decoder that uses SwiGLU-OAI gated experts and a lightning indexer for block-sparse attention.

## Architecture
### Block-sparse attention (Lightning Indexer)

Every layer is GQA (`num_key_value_heads = 4`) with per-head QK-norm and **partial RoPE** on the first
`rotary_dim`. `config.layer_types[i]` then picks `"full_attention"` (dense causal) or
`"minimax_m3_sparse"`, where a [`MiniMaxM3VLIndexer`] decides, per query, which block of keys the main attention may see.

The indexer scores every key, then **max-poolsthose per-key scores into blocks of `index_block_size` keys**, so selection happens at the granularity of a *block
of keys*: per query it keeps the top-`index_topk_blocks` key blocks plus the always-on `index_local_blocks`
local-window block (under block-level causality), broadcasts the per-block `0`/`-inf` choice back onto every key in
the block. The result is a `[B, 1, S_q, S_k]` additive bias summed onto the causal mask.
Theoretically this means that the attention is only computed over the selected blocks of keys, but `transformers` does not support the kernels that compute this efficiently!
We are adding it to `kernels` asap!

<img alt="MiniMax M3 Lightning Indexer mask" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/minimax_m3_vl_indexer_mask.svg" />


### Vision tower

A [`MiniMaxM3VLVisionModel`]: a `Conv3d` patch embedding over flattened `[N_patches, C·T·P·P]` input, a stack of
CLIP-style encoder layers carrying a **3D rotary** position embedding (time / height / width bands). A [`MiniMaxM3VLPatchMerger`] groups
`spatial_merge_size²` patches into the channel dim before the 2-layer GELU [`MiniMaxM3VLMultiModalProjector`] maps vision features into the text hidden size.

## Usage examples

The example below runs the model on a real image loaded with [`~transformers.image_utils.load_image`].

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.image_utils import load_image


model = AutoModelForImageTextToText.from_pretrained(
"MiniMaxAI/MiniMax-M3-preview", dtype=torch.bfloat16, device_map="auto",
)
processor = AutoProcessor.from_pretrained("MiniMaxAI/MiniMax-M3-preview")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg")
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Describe this image briefly."},
],
}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

### Apple example

This example asks the model about an image of apples, again loading a real image with
[`~transformers.image_utils.load_image`].

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.image_utils import load_image


model = AutoModelForImageTextToText.from_pretrained(
"MiniMaxAI/MiniMax-M3-preview", dtype=torch.bfloat16, device_map="auto",
)
processor = AutoProcessor.from_pretrained("MiniMaxAI/MiniMax-M3-preview")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg")
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "How many apples are in this image, and what color are they?"},
],
}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(images=[image], text=text, return_tensors="pt").to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

## Fastest inference configuration

| ctx | SDPA decode | MSA decode | MSA decode adv. | SDPA prefill | MSA prefill | MSA prefill adv. |
| --: | ----------: | ---------: | --------------: | -----------: | ----------: | ---------------: |
| 2K | 27.8 tok/s | 31.0 | +12% | 303 ms | 257 ms | 1.18× |
| 4K | 23.4 tok/s | 30.5 | +30% | 684 ms | 460 ms | 1.49× |
| 8K | 17.8 tok/s | 29.6 | +66% | 1906 ms | 976 ms | 1.95× |
| 16K | 12.0 tok/s | 27.6 | +130% | 6110 ms | 2344 ms | 2.61× |

The checkpoint ships in native MXFP8. For **decode throughput**, the fastest validated configuration is
**bf16 (dequantized at load) + the MSA block-sparse attention kernel + tensor & expert parallelism + a
`reduce-overhead` cudagraph compile** — roughly **31 tok/s** decode on 8×B200 at a 2048-token prefill.

Keeping the weights in **native FP8 is a memory-footprint option only — it is never faster on this setup**.
The FP8 Triton experts/linear kernels lower as opaque inductor fallback kernels that cudagraph cannot
capture on the hot expert path, so native-FP8 decode measured ~4.2 tok/s (≈7× slower than the bf16 path)
even under `torch.compile(fullgraph=True)`. Use FP8 only when the bf16 weights do not fit.

| config (sdpa baseline, TP+EP, 2048-token prefill, 8×B200) | decode |
|---|---|
| bf16 dequantize-at-load + **MSA** + compile/cudagraph | **~31 tok/s** |
| bf16 dequantize-at-load + sdpa + compile/cudagraph | ~28 tok/s |
| native FP8 + compile/cudagraph | ~4 tok/s (memory-only, not for speed) |

Dequantizing to bf16 only fits with even sharding across GPUs (TP/EP), not with `device_map="auto"`
(pipeline placement OOMs at load). Launch one process per GPU with `torchrun`:

```bash
torchrun --nproc_per_node=8 fastest_m3_vl.py
```

```python
# fastest_m3_vl.py
import os, sys
import torch
import torch.distributed as dist
from transformers import (
AutoModelForImageTextToText,
AutoTokenizer,
CompileConfig,
FineGrainedFP8Config,
)
from transformers.distributed import DistributedConfig

# The indexer feeds SDPA an additive float mask; the cuDNN SDP backend segfaults on it (B200).
torch.backends.cuda.enable_cudnn_sdp(False)

model = AutoModelForImageTextToText.from_pretrained(
"MiniMaxAI/MiniMax-M3-preview",
dtype=torch.bfloat16,
# Dequantize the native MXFP8 weights to bf16 at load (the speed win); needs even TP/EP sharding.
quantization_config=FineGrainedFP8Config(dequantize=True),
tp_plan="auto",
distributed_config=DistributedConfig(enable_expert_parallel=True),
attn_implementation="kernels-staging/msa@v0", # MSA block-sparse attention kernel
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M3-preview")
messages = [{"role": "user", "content": "Summarize the history of computing."}]
inputs = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(f"cuda:{os.environ.get('LOCAL_RANK', '0')}")

generated_ids = model.generate(
**inputs,
max_new_tokens=128,
do_sample=False,
# Static cache + reduce-overhead cudagraph capture is what pushes decode to ~31 tok/s.
cache_implementation="static",
compile_config=CompileConfig(mode="reduce-overhead", fullgraph=True),
)
if int(os.environ.get("RANK", "0")) == 0:
print(tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))

# cudagraph-captured NCCL collectives deadlock the NCCL/CUDA destructors at teardown; the output is
# already produced, so hard-exit to skip the hanging cleanup.
if dist.is_initialized():
sys.stdout.flush()
os._exit(0)
```

## MiniMaxM3VLConfig

[[autodoc]] MiniMaxM3VLConfig

## MiniMaxM3VLTextConfig

[[autodoc]] MiniMaxM3VLTextConfig

## MiniMaxM3VLVisionConfig

[[autodoc]] MiniMaxM3VLVisionConfig

## MiniMaxM3VLProcessor

[[autodoc]] MiniMaxM3VLProcessor

## MiniMaxM3VLImageProcessor

This is a standalone (non-modular) image processor: it shares the patch-flattening idea of [`Qwen2VLImageProcessor`]
but does not inherit from it because the two diverge in ways that touch most of the class. The resize budget is driven by
a `max_pixels` attribute and a `{"height", "width"}` `size` rather than Qwen's `shortest_edge`/`longest_edge` scheme; the
`smart_resize` helper clamps the initial rounding with `max(factor, ...)`; and `_preprocess` performs real temporal
handling (5D patches, last-frame repeat to fill `temporal_patch_size`, and a `grid_t` dimension) instead of Qwen's
`grid_t = 1` + expand. Mapping to or subclassing Qwen would therefore change behavior or require overriding nearly
everything, so the processor is kept on its own.

[[autodoc]] MiniMaxM3VLImageProcessor

## MiniMaxM3VLVideoProcessor

[[autodoc]] MiniMaxM3VLVideoProcessor

## MiniMaxM3VLVisionModel

[[autodoc]] MiniMaxM3VLVisionModel
- forward

## MiniMaxM3VLTextModel

[[autodoc]] MiniMaxM3VLTextModel
- forward

## MiniMaxM3VLModel

[[autodoc]] MiniMaxM3VLModel
- forward

## MiniMaxM3VLForCausalLM

[[autodoc]] MiniMaxM3VLForCausalLM
- forward

## MiniMaxM3SparseForConditionalGeneration

[[autodoc]] MiniMaxM3SparseForConditionalGeneration
- forward
15 changes: 14 additions & 1 deletion src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
# ``PreTrainedConfig`` (the decoder text config) as the only positional argument.
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}

# Parallel registry for the *static* implementation of a custom layer type, consulted by
# ``StaticCache``. A ``StaticLayer`` subclass with a ``layer_type`` registers here instead of in
# ``LAYER_TYPE_CACHE_MAPPING`` (constructed with ``max_cache_len=...``), so a single ``layer_type``
# can have both a dynamic and a static cache layer (e.g. M3's sparse-attention indexer cache).
LAYER_TYPE_STATIC_CACHE_MAPPING: dict[str, type] = {}


class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache."""
Expand All @@ -47,7 +53,11 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
layer_type = cls.__dict__.get("layer_type", None)
if layer_type is not None:
LAYER_TYPE_CACHE_MAPPING[layer_type] = cls
static_base = globals().get("StaticLayer")
if static_base is not None and issubclass(cls, static_base):
LAYER_TYPE_STATIC_CACHE_MAPPING[layer_type] = cls
else:
LAYER_TYPE_CACHE_MAPPING[layer_type] = cls

def __init__(self):
self.keys: torch.Tensor | None = None
Expand Down Expand Up @@ -1578,6 +1588,9 @@ def __init__(
# LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
layer = LinearAttentionLayer()
# Custom layer types (e.g. M3's sparse-attention indexer cache) that registered a static variant.
elif layer_type in LAYER_TYPE_STATIC_CACHE_MAPPING:
layer = LAYER_TYPE_STATIC_CACHE_MAPPING[layer_type](max_cache_len=max_cache_len)
elif layer_type == "deepseek_sparse_attention":
# Static / compile-friendly indexed layer (preallocated indexer key cache).
layer = StaticIndexedLayer(max_cache_len=max_cache_len)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"chunked_attention",
"compressed_sparse_attention", # CSA, used in deepseek_v4
"heavily_compressed_attention", # HCA, used in deepseek_v4
"minimax_m3_sparse", # lightning-index sparse attention, used in minimax_m3_vl
"linear_attention", # used in minimax
"conv", # used in LFMv2
"mamba",
Expand Down
Loading
Loading