Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions vllm/models/deepseek_v4/amd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@
)
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
DeepseekV4MLA,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
Expand Down Expand Up @@ -314,7 +313,7 @@ def __init__(

self.rope_parameters = config.rope_scaling

# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
Expand Down Expand Up @@ -351,7 +350,17 @@ def __init__(
prefix=f"{prefix}.indexer",
)

mla_modules = DeepseekV4MLAModules(
self.mla_attn = DeepseekV4MLA(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
Expand All @@ -365,19 +374,6 @@ def __init__(
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
)
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
mla_modules=mla_modules,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
Expand Down Expand Up @@ -618,7 +614,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.rms_norm_eps = config.rms_norm_eps

# Three aux streams: one per non-default input GEMM in
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
# DeepseekV4MLA.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
# Disable them on ROCm because of hang issues.
Expand Down
122 changes: 34 additions & 88 deletions vllm/models/deepseek_v4/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast

import torch
Expand Down Expand Up @@ -38,9 +37,8 @@
get_current_vllm_config,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
Expand Down Expand Up @@ -90,46 +88,7 @@ def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]":
return DeepseekV4FlashMLASparseImpl


@dataclass
class DeepseekV4MLAModules:
"""Modules used in DeepseekV4 MLA."""

vllm_config: VllmConfig
fused_wqa_wkv: torch.nn.Module
q_norm: torch.nn.Module
wq_b: torch.nn.Module
kv_norm: torch.nn.Module
wo_a: torch.nn.Module
wo_b: torch.nn.Module
attn_sink: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module | None
indexer_rotary_emb: torch.nn.Module
topk_indices_buffer: torch.Tensor | None
aux_stream_list: list[torch.cuda.Stream] | None = None


# --8<-- [start:multi_head_latent_attention]
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently oot platforms can still use CustomOp.register_oot to
replace MLA layer entirely, although we use PluggableLayer to register
this layer now.

This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:

1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
"""

# --8<-- [end:multi_head_latent_attention]

class DeepseekV4MLA(nn.Module):
def __init__(
self,
hidden_size: int,
Expand All @@ -142,7 +101,19 @@ def __init__(
q_lora_rank: int | None,
kv_lora_rank: int,
o_lora_rank: int | None,
mla_modules: DeepseekV4MLAModules,
vllm_config: VllmConfig,
fused_wqa_wkv: torch.nn.Module,
q_norm: torch.nn.Module,
wq_b: torch.nn.Module,
kv_norm: torch.nn.Module,
wo_a: torch.nn.Module,
wo_b: torch.nn.Module,
attn_sink: torch.nn.Module,
rotary_emb: torch.nn.Module,
indexer: torch.nn.Module | None,
indexer_rotary_emb: torch.nn.Module,
topk_indices_buffer: torch.Tensor | None,
aux_stream_list: list[torch.cuda.Stream] | None,
window_size: int,
compress_ratio: int | None,
cache_config: CacheConfig | None = None,
Expand All @@ -162,7 +133,7 @@ def __init__(
self.prefix = prefix

# Extract config from vllm_config
config = mla_modules.vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_config
tp_size = get_tensor_model_parallel_world_size()

# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
Expand All @@ -173,12 +144,12 @@ def __init__(
self.o_lora_rank = config.o_lora_rank

# Store projection modules
self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
self.q_norm = mla_modules.q_norm
self.wq_b = mla_modules.wq_b
self.fused_wqa_wkv = fused_wqa_wkv
self.q_norm = q_norm
self.wq_b = wq_b

self.kv_norm = mla_modules.kv_norm
self.wo_a = mla_modules.wo_a
self.kv_norm = kv_norm
self.wo_a = wo_a

self._wo_a_act_quant = QuantFP8(
static=False,
Expand All @@ -188,7 +159,7 @@ def __init__(
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
# INT32) so fp8_einsum can handle layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = mla_modules.wo_b
self.wo_b = wo_b

# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
Expand All @@ -198,11 +169,11 @@ def __init__(
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
self._tma_aligned_scales = cap.major >= 10

self.rotary_emb = mla_modules.rotary_emb
self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
self.topk_indices_buffer = mla_modules.topk_indices_buffer
self.rotary_emb = rotary_emb
self.indexer_rotary_emb = indexer_rotary_emb
self.topk_indices_buffer = topk_indices_buffer

self.indexer = mla_modules.indexer
self.indexer = indexer

# Per-head RMS normalization for Q (no learnable weights)
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
Expand All @@ -216,7 +187,7 @@ def __init__(
)

# Will be None on ROCm for now.
self.aux_stream_list = mla_modules.aux_stream_list
self.aux_stream_list = aux_stream_list
# [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
# before post-GEMM starts.
Expand All @@ -243,7 +214,7 @@ def __init__(
window_size=self.window_size,
head_bytes=head_bytes,
swa_cache_layer=self.swa_cache_layer,
attn_sink=mla_modules.attn_sink, # already padded with -inf
attn_sink=attn_sink, # already padded with -inf
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
Expand All @@ -253,21 +224,12 @@ def __init__(
# Mirror the inner layer's padded head count (single source of truth).
self.padded_heads = self.mla_attn.padded_heads

# Register this layer in the compilation config's static forward context
# This allows the custom op to retrieve the layer during execution
compilation_config = mla_modules.vllm_config.compilation_config
# HACK
self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
if self.layer_name in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {self.layer_name}")
compilation_config.static_forward_context[self.layer_name] = self

# Create the compressor for layers with compress_ratio > 1. This is done
# after creating the DeepseekV4MLAAttention layer, to get its cache.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
vllm_config=mla_modules.vllm_config,
vllm_config=vllm_config,
compress_ratio=self.compress_ratio,
hidden_size=self.hidden_size,
head_dim=self.head_dim,
Expand All @@ -291,15 +253,10 @@ def forward(
device=hidden_states.device,
)

# @eager_break_during_capture: this is where the breakable
# cudagraph capture breaks (the attention op runs eagerly between
# captured graph segments).
deepseek_v4_attention(
hidden_states,
positions,
o_padded,
self.layer_name,
)
# attention_impl is wrapped with @eager_break_during_capture: this is
# where the breakable cudagraph capture breaks (the attention op runs
# eagerly between captured graph segments).
self.attention_impl(hidden_states, positions, o_padded)
o = o_padded[:, : self.n_local_heads, :]

# Keep ROCm on the BF16 reference wo_a path until the kernel is ready.
Expand Down Expand Up @@ -405,6 +362,7 @@ def fused_wqa_wkv() -> torch.Tensor:

return qr_kv, kv_score, indexer_kv_score, indexer_weights

@eager_break_during_capture
def attention_impl(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -541,18 +499,6 @@ def _fused_qnorm_rope_kv_insert(
)


@eager_break_during_capture
def deepseek_v4_attention(
hidden_states: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.attention_impl(hidden_states, positions, out)


class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
def __init__(
self,
Expand Down
32 changes: 14 additions & 18 deletions vllm/models/deepseek_v4/nvidia/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.models.deepseek_v4.attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
DeepseekV4MLA,
)
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.sequence import IntermediateTensors
Expand Down Expand Up @@ -697,7 +696,7 @@ def __init__(

self.rope_parameters = config.rope_scaling

# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
Expand Down Expand Up @@ -741,7 +740,17 @@ def __init__(
aux_stream=indexer_aux_stream,
)

mla_modules = DeepseekV4MLAModules(
self.mla_attn = DeepseekV4MLA(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
Expand All @@ -755,19 +764,6 @@ def __init__(
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
)
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
mla_modules=mla_modules,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
Expand Down Expand Up @@ -955,7 +951,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.rms_norm_eps = config.rms_norm_eps

# Three aux streams: one per non-default input GEMM in
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
# DeepseekV4MLA.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
Expand Down
Loading