Add minimax m3vl #46600
…dequant

- Modular implementation reusing minimax_m2 for the M2-shaped pieces (router, decoder layer skeleton, expert layout) and adding M3 deltas: shared experts + sigmoid routing scaling, dense-MLP-for-first-N-layers, swigluoai with alpha/limit clamp, partial RoPE (rotary_dim < head_dim), per-head Gemma-style QK norm, and the sparse-attention lightning-index branch (full dense math fallback for v1; top-k block selection deferred).
- CLIP-like vision tower with Conv3d patch embed and a 3D RoPE over (T, H, W) patch grids, plus LLaVA-style projector + spatial patch merger.
- Fast image / video processors and a multimodal processor ported from the snapshot's preprocessing code; auto registrations for config / model / processor / image / video.
- Tiny model fixtures: build_tiny_model.py (transformers layout) and build_tiny_model_sglang.py (re-keyed for sglang FusedMoE), so parity checks across backends do not require the 100B-param real model.
- dequantize_mxfp8.py converts the MXFP8 ([1, 32]-blocked E4M3) checkpoint shipped by MiniMax to bf16 -- transformers has MXFP4 today but no MXFP8, so the first port dequantizes up front.

Status: text + vision forward passes work end-to-end on the tiny model; both AutoModelForImageTextToText and AutoModelForCausalLM are wired. Sparse attention is dense-math only and MTP modules are skipped.
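For readers unfamiliar with the "[1, 32]-blocked E4M3" layout, here is a minimal pure-Python sketch of what up-front dequantization involves. It is not the code in dequantize_mxfp8.py. It assumes the finite-only E4M3 variant (the `float8_e4m3fn` layout) and one power-of-two E8M0 scale byte (bias 127) per 32 values, as in the OCP microscaling formats; whether the MiniMax checkpoint stores its scales exactly this way is an assumption, and the function names are made up for illustration.

```python
import math

BLOCK = 32  # values sharing one scale, per the [1, 32] blocking


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 (finite-only variant) byte to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # the only NaN encoding; this variant has no infinities
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def decode_e8m0(byte: int) -> float:
    """Decode a block scale, ASSUMED to be E8M0: 2 ** (byte - 127)."""
    if byte == 0xFF:
        return math.nan  # reserved NaN encoding in E8M0
    return 2.0 ** (byte - 127)


def dequantize_row(values: bytes, scales: bytes) -> list[float]:
    """Dequantize one row; every BLOCK consecutive values share one scale."""
    if len(values) != len(scales) * BLOCK:
        raise ValueError("expected one scale byte per 32 value bytes")
    return [
        decode_e4m3(v) * decode_e8m0(scales[i // BLOCK])
        for i, v in enumerate(values)
    ]


# 0x38 encodes 1.0 in E4M3; scale byte 128 encodes 2.0.
row = dequantize_row(bytes([0x38] * BLOCK), bytes([128]))
```

A real converter would do the same arithmetic vectorized over whole tensors and then cast the result to bf16.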
* Drop config fields that the real M3 VL checkpoint has only a single
value for and inline the constants:
- text: use_gemma_norm, use_qk_norm, qk_norm_type, attention_output_gate,
hidden_act, scoring_func, partial_rotary_factor
- composite: multimodal_projector_bias, projector_hidden_act,
vision_feature_layer, vision_feature_select_strategy,
process_image_mode, img_token_compression_config, image_grid_pinpoints,
image_seq_length
- vision: rope_mode, vision_segment_max_frames, hidden_act
* Norms: single Gemma-style RMSNorm class (drop the `gemma:` switch and the
plain RMSNorm alias). Activation in the dense MLP / experts / projector is
always swigluoai / GELU; remove the alternate branches.
* Attention: per-head Gemma QK norm and partial RoPE are always on; remove
the per_layer / no-norm / full-RoPE branches.
* Replace the dense fallback of the sparse-attention layer with a proper
lightning indexer modeled on `DeepseekV4Indexer`:
- MiniMaxM3VLIndexer scores key blocks via idx_q · idx_k, max-reduces
over `sparse_block_size`, takes top-k blocks per query, then forces
the first `sparse_init_block` and last `sparse_local_block` blocks
visible.
- Encodes the selection as an additive `[B, 1, S, S]` block_bias
(`-inf` outside the allowed blocks), summed onto attention_mask
before SDPA -- same scatter pattern as deepseek_v4.
- Index branch still runs its own SDPA over (idx_q, idx_k, idx_v) and
adds the result through `o_proj`, matching sglang's parallel index
output path. `disable_index_value` layers skip the value side.
Tiny VLM / text-only LM forward passes still work end-to-end via
AutoModelForImageTextToText / AutoModelForCausalLM.
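The block-selection steps described for the indexer can be sketched for a single query in plain Python. This is an illustration under stated assumptions, not the MiniMaxM3VLIndexer code: the scores are taken as already reduced over heads, "last `sparse_local_block` blocks" is read as the blocks ending at the query's own block, and every function and parameter name below is hypothetical.

```python
import math

NEG_INF = float("-inf")


def select_blocks(scores, block_size, top_k, init_blocks, local_blocks, q_pos):
    """Return the sorted key-block indices one query may attend to.

    scores: per-key index scores (idx_q . idx_k) for this query.
    q_pos:  absolute position of the query; later keys are ignored.
    """
    num_blocks = math.ceil(len(scores) / block_size)
    # Max-reduce the causally visible scores inside each block.
    block_scores = []
    for b in range(num_blocks):
        keys = range(b * block_size, min((b + 1) * block_size, len(scores)))
        visible = [scores[k] for k in keys if k <= q_pos]
        block_scores.append(max(visible, default=NEG_INF))
    last_block = q_pos // block_size
    # Top-k among blocks that contain at least one visible key.
    candidates = [b for b in range(num_blocks) if block_scores[b] > NEG_INF]
    ranked = sorted(candidates, key=lambda b: block_scores[b], reverse=True)
    chosen = set(ranked[:top_k])
    # Force the first init_blocks and the last local_blocks blocks visible.
    chosen.update(range(min(init_blocks, last_block + 1)))
    chosen.update(range(max(0, last_block - local_blocks + 1), last_block + 1))
    return sorted(chosen)


def block_bias_row(allowed_blocks, block_size, num_keys):
    """Additive bias for one query: 0 inside allowed blocks, -inf elsewhere."""
    allowed = set(allowed_blocks)
    return [0.0 if k // block_size in allowed else NEG_INF for k in range(num_keys)]


scores = [0.1, 0.2, 0.9, 0.3, 0.0, 0.1, 0.4, 0.5]
blocks = select_blocks(scores, block_size=2, top_k=1,
                       init_blocks=1, local_blocks=1, q_pos=7)
bias = block_bias_row(blocks, block_size=2, num_keys=len(scores))
```

In this toy run block 1 wins the top-k, block 0 is forced as the initial block, and block 3 is forced as the local block, so only keys 4 and 5 receive `-inf`. The bias row only encodes block visibility; the causal part still comes from the attention mask it is added to.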
Cache:
* Add MiniMaxM3VLSparseCacheLayer(DynamicLayer) -- auto-registers under layer_type "minimax_m3_sparse". Stores idx_keys / idx_values alongside the main K/V so the lightning indexer can keep scoring against the full prefix during decode without recomputing it. Mirrors the DeepseekV4CSACache pattern (per-layer cache instance reached via past_key_values.layers[layer_idx]; custom update_index method analogous to DSv4's store_compression_weights / update_compressor_states).
* Register the new value in `ALLOWED_LAYER_TYPES` so the strict-dataclass validator accepts mixed ["full_attention", ..., "minimax_m3_sparse"] layer_types on the text config.
* Derive layer_types from sparse_attention_freq inside MiniMaxM3VLTextModel.__init__ so old checkpoints without a layer_types field still dispatch the right cache class.

Indexer:
* MiniMaxM3VLIndexer.forward now takes position_ids + past_key_values + layer_idx, looks up the per-layer cache, calls update_index with the new token's (idx_k, idx_v), and scores idx_q against the full cached idx_keys (not just the new token's idx_k). Block_bias and the index branch's own SDPA both use absolute query positions so causality works for Sq < Sk during decode.

Demo:
* minimax.py runs the image+text → AutoProcessor → .generate() round-trip on the tiny model. Tiny model config now uses the real 200064 vocab and the snapshot image_token_index (200025), so the snapshot tokenizer + chat_template drop in unchanged.
* Drop the no-longer-needed `image_processor_class`/`video_processor_class` attrs on MiniMaxM3VLProcessor (the AutoImageProcessor / AutoVideoProcessor registrations cover the lookup and the explicit strings emit a deprecation warning).

Verified:
* Cached prefill + decode argmax matches the no-cache full forward.
* model.generate() produces tokens end-to-end.
* minimax.py runs on the tiny model with the snapshot tokenizer.
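Why absolute query positions matter once Sq < Sk can be shown with a toy cache. The sketch below is hypothetical and torch-free: `ToyIndexCache` only imitates the idea of an `update_index` call that appends the new index keys and returns the full prefix, and does not reproduce the real MiniMaxM3VLSparseCacheLayer signature.

```python
NEG_INF = float("-inf")


class ToyIndexCache:
    """Hypothetical stand-in for a per-layer cache that keeps index keys."""

    def __init__(self):
        self.idx_keys = []

    def update_index(self, new_idx_keys):
        # Append this step's index keys and return the full cached prefix.
        self.idx_keys.extend(new_idx_keys)
        return self.idx_keys


def causal_bias(query_positions, num_keys):
    """Additive mask: a query at absolute position p sees keys 0..p."""
    return [
        [0.0 if k <= p else NEG_INF for k in range(num_keys)]
        for p in query_positions
    ]


cache = ToyIndexCache()
cache.update_index(["k0", "k1", "k2", "k3"])  # prefill: Sq = Sk = 4
keys = cache.update_index(["k4"])             # decode step: Sq = 1, Sk = 5

good = causal_bias([4], len(keys))  # absolute position of the new token
bad = causal_bias([0], len(keys))   # local index 0 would hide keys 1..4
```

With the absolute position the decoded token sees the whole cached prefix; with the local index it would see only key 0, which is the failure mode the absolute positions avoid.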
- add get_video_features / video placeholder mask and consume pixel_values_videos + video_grid_thw in the model forward, scattering video features into the video-token slots
- support the native (non-dequantized) MXFP8 compute path: keep experts in float8_e4m3fn and carry the SwiGLU-OAI gate onto the FP8 experts
- add model-level video tests (forward, token mismatch) and a slow end-to-end video generation check; add docs page + toctree entry
- realign deepseek_v4 cache_layer annotation in the modular source so the generated file matches
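The "scatter video features into the video-token slots" step, together with the token-mismatch case the tests cover, can be pictured with a small list-based sketch. It is only an illustration: the function name, the placeholder id 9 and the list-of-lists embeddings are assumptions, and the model itself works on tensors with a placeholder mask.

```python
def scatter_features(input_ids, embeds, features, placeholder_id):
    """Replace the embedding at every placeholder token with one feature row."""
    slots = [i for i, tok in enumerate(input_ids) if tok == placeholder_id]
    if len(slots) != len(features):
        # Mirrors the "token mismatch" failure: placeholders and features
        # must line up one-to-one.
        raise ValueError(
            f"{len(slots)} placeholder tokens but {len(features)} feature rows"
        )
    out = list(embeds)
    for slot, feat in zip(slots, features):
        out[slot] = feat
    return out


VIDEO_TOKEN = 9  # hypothetical placeholder id, for illustration only
ids = [1, VIDEO_TOKEN, VIDEO_TOKEN, 2]
text_embeds = [[0.0], [0.0], [0.0], [0.0]]
merged = scatter_features(ids, text_embeds, [[0.5], [0.7]], VIDEO_TOKEN)
```

Passing a feature list whose length differs from the number of placeholder tokens raises `ValueError` instead of silently misaligning the sequence.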
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
    hidden_states = decoder_layer(
        hidden_states,
        attention_mask=causal_mask,
        position_ids=position_ids,
We still need to pass them here: the top level needs them, but below we can safely ignore them.
vasqu left a comment
Preapproval 🫡 nice work, we can refactor a bit afterwards but nothing major
Only revert the pos ids one since it could break some other models (e.g. doge)
[For maintainers] Suggested jobs to run (before merge)
run-slow: auto, deepseek_v4, minimax_m3_vl, finegrained_fp8
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46600&sha=c6615c
    if pad:
        scores = F.pad(scores, (0, pad), value=float("-inf"))
    scores = scores.view(batch, self.num_heads, q_len, num_key_blocks, self.block_size)
    block_scores = scores.amax(dim=-1).amax(dim=1)  # -> [B, S_q, num_key_blocks]
Why do we do amax(dim=1) here? Does it mean that all GQA groups select the same blocks?
will have a look!
What does this PR do?