Plugin-based Sparse Attention Interface for DiT Modules
Status: Implemented — see PR #2231
Note: This RFC reflects the final implementation after maintainer review.
The original proposal used a separate ABC and independent registry;
the implementation integrates into the existing AttentionBackend infrastructure
per feedback from @gcanlin and @SamitHuang.
Motivation
vLLM-Omni currently runs DiT module attention with a standard full-attention
kernel. As model resolution and video length scale up, the quadratic cost of
full attention becomes the dominant bottleneck: a 720p video DiT with 8192
patch tokens spends more than 60% of step latency in attention alone.
Several sparse attention libraries have emerged that can cut this cost:
- SpargeAttn (thu-ml/SpargeAttn, ICML 2025): training-free, block-level
top-k prediction. Demonstrated 2–3× speedup on FLUX and HunyuanVideo at
<0.1% quality degradation.
- FlashInfer sparse kernels: BlockSparseAttentionWrapper with BSR-format masks. CUDA-graph-safe.
- RainFusion (arXiv 2505.21036): adaptive head-level classification with
spatiotemporal token permutation.
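The topk knob used throughout this RFC can be illustrated with a deliberately simplified, pure-Python sketch of block-level top-k selection. It mean-pools query and key blocks, scores each block pair by dot product, and keeps the highest-scoring fraction of key blocks for each query block. This is not SpargeAttn's actual predictor, which is more involved. block_topk_mask is a hypothetical helper, and reading topk as "fraction of key blocks kept" is an assumption.

```python
from typing import List

Matrix = List[List[float]]


def block_means(x: Matrix, block_size: int) -> Matrix:
    """Mean-pool consecutive token vectors into one vector per block."""
    means = []
    for start in range(0, len(x), block_size):
        blk = x[start:start + block_size]
        dim = len(blk[0])
        means.append([sum(v[d] for v in blk) / len(blk) for d in range(dim)])
    return means


def block_topk_mask(q: Matrix, k: Matrix, block_q: int, block_kv: int,
                    topk: float) -> List[List[bool]]:
    """mask[i][j] is True when key block j is kept for query block i.

    Hypothetical helper; topk is assumed to be the fraction of key blocks kept.
    """
    q_blocks = block_means(q, block_q)
    k_blocks = block_means(k, block_kv)
    n_keep = max(1, round(topk * len(k_blocks)))  # keep at least one block
    mask = []
    for qb in q_blocks:
        scores = [sum(a * b for a, b in zip(qb, kb)) for kb in k_blocks]
        ranked = sorted(range(len(k_blocks)), key=scores.__getitem__, reverse=True)
        kept = set(ranked[:n_keep])
        mask.append([j in kept for j in range(len(k_blocks))])
    return mask


q = k = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
print(block_topk_mask(q, k, block_q=2, block_kv=2, topk=0.5))
# [[True, False], [False, True]]
```

A sparse kernel would then compute attention only over the kept blocks. Building the mask is comparatively cheap because it operates on blocks rather than individual tokens.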
These libraries have incompatible calling conventions. This RFC proposes a
plugin-based architecture where vLLM-Omni provides only a thin integration
layer. All kernel implementations live in external packages that register via
Python entry_points.
Design
The design extends the existing AttentionBackend infrastructure rather than
building a parallel system.
┌────────────────────────────────────────────────────────────────────┐
│ External plugin packages │
│ sparge-vllm-omni · flashinfer-vllm-omni · custom-backend │
│ (pip install; register via vllm_omni.sparse_attn entry_points) │
└────────────────────────────┬───────────────────────────────────────┘
│ fn(q, k, v, params, is_causal) -> Tensor
┌────────────────────────────▼───────────────────────────────────────┐
│ vllm-omni: sparse_attention.py │
│ SparseAttentionBackend(AttentionBackend) │
│ _SparseAttentionImpl(AttentionImpl) │
│ - is_self_attention=False → SDPA fallback (cross-attention) │
│ - is_self_attention=True → resolve plugin fn → call kernel │
└────────────────────────────┬───────────────────────────────────────┘
│ integrated via existing mechanism
┌────────────────────────────▼───────────────────────────────────────┐
│ vllm-omni: existing infrastructure │
│ DiffusionAttentionBackendEnum +1 line: SPARSE_ATTENTION │
│ selector.py +entry_points fallback: --attn-backend spargeattn │
│ layer.py +is_self_attention param on Attention.__init__ │
│ WanCrossAttention +is_self_attention=False (1 line) │
└────────────────────────────────────────────────────────────────────┘
Core changes in vllm-omni
1. SparseAttentionBackend inherits AttentionBackend
A new file vllm_omni/diffusion/attention/backends/sparse_attention.py
implements the sparse dispatcher:
```python
class SparseAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "sparse_attention"

    @staticmethod
    def get_impl_cls():
        return _SparseAttentionImpl

    ...


class _SparseAttentionImpl(AttentionImpl):
    def __init__(self, num_heads, head_size, softmax_scale,
                 causal=False, num_kv_heads=None, prefix="",
                 **extra_impl_args):
        self._causal = causal  # forwarded to the plugin as is_causal
        self._is_self_attention = extra_impl_args.get("is_self_attention", True)
        self._dense = SDPABackend.get_impl_cls()(...)  # always constructed
        if not self._is_self_attention:
            self._sparse_fn = None  # cross-attention: always dense
            return
        # Resolve the plugin from the forward_context config.
        cfg = self._read_sparse_cfg()
        self._sparse_fn = self._resolve_plugin(cfg)  # fn or None
        self._params = cfg.params if cfg else 0.5
        ...

    def forward_cuda(self, query, key, value, attn_metadata=None):
        if self._sparse_fn is None:
            return self._dense.forward_cuda(query, key, value, attn_metadata)
        return self._sparse_fn(query, key, value, self._params, self._causal)
```
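The dispatch rules of _SparseAttentionImpl can be exercised without a GPU using the stand-in below. SparseDispatchSketch and the dense/sparse functions are hypothetical; tensors are replaced by plain Python values so that only the control flow is shown: cross-attention is pinned to the dense path at construction, and self-attention falls back to dense when no plugin resolves.

```python
from typing import Any, Callable, Optional


class SparseDispatchSketch:
    """Hypothetical stand-in for _SparseAttentionImpl's control flow."""

    def __init__(self, dense_fn: Callable, sparse_fn: Optional[Callable],
                 params: Any = 0.5, causal: bool = False,
                 is_self_attention: bool = True):
        self._dense_fn = dense_fn
        self._params = params
        self._causal = causal
        # Decided once at construction: cross-attention never uses the plugin.
        self._sparse_fn = sparse_fn if is_self_attention else None

    def forward(self, query, key, value):
        if self._sparse_fn is None:
            return self._dense_fn(query, key, value)
        return self._sparse_fn(query, key, value, self._params, self._causal)


def dense(q, k, v):
    return ("dense", q)


def sparse(q, k, v, params, is_causal):
    return ("sparse", params, is_causal)


self_attn = SparseDispatchSketch(dense, sparse, params=0.3)
cross_attn = SparseDispatchSketch(dense, sparse, is_self_attention=False)
no_plugin = SparseDispatchSketch(dense, None)
print(self_attn.forward("q", "k", "v"))   # ('sparse', 0.3, False)
print(cross_attn.forward("q", "k", "v"))  # ('dense', 'q')
print(no_plugin.forward("q", "k", "v"))   # ('dense', 'q')
```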
2. SPARSE_ATTENTION in DiffusionAttentionBackendEnum
One line added to registry.py:
```python
class DiffusionAttentionBackendEnum(Enum, ...):
    FLASH_ATTN = "...FlashAttentionBackend"
    TORCH_SDPA = "...SDPABackend"
    SAGE_ATTN = "...SageAttentionBackend"
    SPARSE_ATTENTION = "...backends.sparse_attention.SparseAttentionBackend"
```
No independent registry. No parallel enum.
3. --attn-backend as the unified CLI interface
selector.py gets an entry_points fallback before the enum lookup:
```python
import importlib.metadata


def _get_plugin_backend_cls(name: str):
    # entry_points(group=...) requires Python 3.10+.
    for ep in importlib.metadata.entry_points(group="vllm_omni.attn_backend"):
        if ep.name.lower() == name.lower():
            return ep.load()
    return None
```
This means plugin names work directly without modifying vllm-omni's enum:
```shell
pip install sparge-vllm-omni
vllm-omni serve Wan2.2 --attn-backend spargeattn   # direct plugin name
# or
vllm-omni serve Wan2.2 --attn-backend sparse_attention \
    --sparse-attn '{"backend":"spargeattn","topk_ratio":0.5}'
```
--sparse-attn-backend is removed. --attn-backend is the single interface.
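The lookup order described for selector.py (plugin entry points consulted before the enum) can be sketched as follows. select_backend, FakeEntryPoint and BackendEnum are hypothetical; the entry-point list is injected so the logic runs without any plugin installed, whereas the real selector would obtain it from importlib.metadata.entry_points(group="vllm_omni.attn_backend").

```python
from dataclasses import dataclass
from enum import Enum
from typing import Iterable


class BackendEnum(Enum):
    """Hypothetical mirror of DiffusionAttentionBackendEnum."""
    TORCH_SDPA = "SDPABackend"
    SPARSE_ATTENTION = "SparseAttentionBackend"


@dataclass
class FakeEntryPoint:
    """Stands in for importlib.metadata.EntryPoint in this sketch."""
    name: str
    target: object

    def load(self) -> object:
        return self.target


def select_backend(name: str, entry_points: Iterable[FakeEntryPoint]) -> object:
    # 1. Plugin entry points first (case-insensitive), so a newly installed
    #    package is selectable without editing the enum.
    for ep in entry_points:
        if ep.name.lower() == name.lower():
            return ep.load()
    # 2. Then built-in enum members, matched by member name.
    try:
        return BackendEnum[name.upper()].value
    except KeyError:
        raise ValueError(f"unknown attention backend: {name!r}") from None


eps = [FakeEntryPoint("spargeattn", "SpargeBackend")]
print(select_backend("SpargeAttn", eps))        # SpargeBackend
print(select_backend("sparse_attention", eps))  # SparseAttentionBackend
```

One consequence of checking plugins first is that a plugin registered under a built-in name (for example torch_sdpa) would shadow the built-in backend.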
4. Cross-attention isolation via is_self_attention
Attention.__init__ in layer.py accepts a new parameter (default True,
fully backward compatible):
```python
def __init__(self, ..., is_self_attention: bool = True):
    self.attention = self.attn_impl_cls(
        ...,
        is_self_attention=is_self_attention,  # passed via **extra_impl_args
    )
```
WanCrossAttention passes is_self_attention=False (one line):
```python
self.attn = Attention(..., is_self_attention=False)
```
_SparseAttentionImpl reads this flag in __init__ and permanently sets
self._sparse_fn = None for cross-attention. The decision is made once at
construction, so each forward call costs only a None check.
The per-model enable_sparse_attention() hook is removed entirely.
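The backward-compatibility claim rests on impl constructors tolerating the extra keyword. Assuming, as the _SparseAttentionImpl signature above suggests, that impl classes accept **extra_impl_args, the flag is simply ignored by backends that do not care about it. Both classes below are hypothetical stand-ins.

```python
class ExistingImplSketch:
    """Hypothetical pre-existing backend impl that tolerates extra kwargs."""

    def __init__(self, num_heads: int, head_size: int, **extra_impl_args):
        self.num_heads = num_heads
        self.head_size = head_size  # is_self_attention lands in extra_impl_args, unused


class SparseImplSketch:
    """Hypothetical sparse impl that reads the flag, defaulting to True."""

    def __init__(self, num_heads: int, head_size: int, **extra_impl_args):
        self.is_self_attention = extra_impl_args.get("is_self_attention", True)


def build_impl(impl_cls, is_self_attention: bool = True):
    # Mirrors Attention.__init__ forwarding the flag via **extra_impl_args.
    return impl_cls(40, 128, is_self_attention=is_self_attention)


print(build_impl(ExistingImplSketch, False).num_heads)        # 40
print(build_impl(SparseImplSketch, False).is_self_attention)  # False
print(SparseImplSketch(40, 128).is_self_attention)            # True
```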
5. Function-based plugin interface
Plugins register a function, not a class, via entry_points:
```toml
# pyproject.toml
[project.entry-points."vllm_omni.sparse_attn"]
spargeattn = "sparge_vllm_omni:sparge_attn_fn"

[project.entry-points."vllm_omni.attn_backend"]
spargeattn = "sparge_vllm_omni.backend:SpargeBackend"
```
Function signature:
```python
import torch


def sparge_attn_fn(
    query: torch.Tensor,  # (B, H, S, D)
    key: torch.Tensor,    # same layout as query
    value: torch.Tensor,  # same layout as query
    topk: float,
    is_causal: bool,
) -> torch.Tensor:
    ...
```
This is simpler than a class-based builder. Backends that need BSR planning
(e.g. FlashInfer) can wrap their internal state behind this functional
interface, or implement the full AttentionBackend / AttentionImpl
class hierarchy and register via vllm_omni.attn_backend.
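A backend that needs a planning step (such as building a BSR layout) can keep that state inside a callable object and still satisfy the functional interface. The sketch below is hypothetical and only loosely modeled on a plan-then-run pattern; PlannedSparseFn, plan and run are illustrative names, and the tensor is replaced by an object with a shape attribute.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple


class PlannedSparseFn:
    """Hypothetical adapter: hides a plan/run backend behind
    fn(query, key, value, topk, is_causal). One plan is built and cached per
    (query shape, topk, is_causal); the cache is unbounded in this sketch."""

    def __init__(self, plan: Callable[..., Any], run: Callable[..., Any]):
        self._plan = plan  # (shape, topk, is_causal) -> plan object
        self._run = run    # (plan, query, key, value) -> output
        self._cache: Dict[Tuple, Any] = {}

    def __call__(self, query, key, value, topk, is_causal):
        cache_key = (tuple(query.shape), topk, is_causal)
        if cache_key not in self._cache:
            self._cache[cache_key] = self._plan(query.shape, topk, is_causal)
        return self._run(self._cache[cache_key], query, key, value)


plans_built = []


def plan(shape, topk, is_causal):
    plans_built.append(shape)
    return {"shape": shape, "topk": topk}


def run(p, q, k, v):
    return p["topk"]


fn = PlannedSparseFn(plan, run)
q = SimpleNamespace(shape=(1, 40, 6630, 128))  # stand-in for a tensor
fn(q, q, q, 0.5, False)
fn(q, q, q, 0.5, False)
print(len(plans_built))  # 1
```

Because the object is callable with the same five positional arguments, it could be registered under vllm_omni.sparse_attn exactly like a plain function.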
Configuration
DiffusionSparseAttnConfig is in vllm_omni/diffusion/data.py:
```python
from dataclasses import dataclass, field


@dataclass
class DiffusionSparseAttnConfig:
    backend: str = "auto"  # plugin short name, class path, or "none"
    pattern_type: str = "dynamic_topk"
    params: dict = field(default_factory=dict)  # mutable default needs a factory
    block_size_q: int = 128
    block_size_kv: int = 64
    window_size: int | None = None
    spatial_radius: int | None = None
    schedule: str = "constant"  # "constant" | "conservative" | "aggressive"
```
Integrated into OmniDiffusionConfig.sparse_attn.
Sparse pattern descriptor (retained)
SparsePatternSpec and SpatialLayout in vllm_omni/diffusion/sparse_attn/pattern.py
are retained as optional utilities for plugin authors who want a structured
pattern descriptor. They are not required by the core dispatcher.
DiTSparseAttentionAdapter in dit_adapter.py is retained for step-aware
sparsity scheduling (topk ramp across denoising steps).
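A step-aware ramp of the kind DiTSparseAttentionAdapter provides might look like the sketch below. topk_at_step is hypothetical, and the warmup fractions assigned to "constant", "conservative" and "aggressive" are illustrative assumptions rather than the adapter's actual values; the underlying assumption is that early denoising steps are more sensitive to sparsity, so they stay denser.

```python
# Hypothetical warmup fractions per schedule; the real adapter's values may differ.
WARMUP_FRACTION = {"constant": 0.0, "conservative": 0.5, "aggressive": 0.2}


def topk_at_step(step: int, num_steps: int, target_topk: float,
                 schedule: str = "constant") -> float:
    """Kept-block fraction for a denoising step (1.0 means fully dense).

    Ramps linearly from 1.0 at step 0 down to target_topk at the end of the
    warmup window, then holds target_topk for the remaining steps.
    """
    warmup = int(WARMUP_FRACTION[schedule] * num_steps)
    if step >= warmup:
        return target_topk
    return 1.0 + (target_topk - 1.0) * (step / warmup)


print([round(topk_at_step(s, 40, 0.3, "conservative"), 3) for s in (0, 10, 20, 39)])
# [1.0, 0.65, 0.3, 0.3]
```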
File layout
vllm_omni/diffusion/
attention/
backends/
sparse_attention.py ← NEW (196 lines)
SparseAttentionBackend(AttentionBackend)
_SparseAttentionImpl(AttentionImpl)
registry.py ← +1 line: SPARSE_ATTENTION in enum
layer.py ← +4 lines: is_self_attention param
selector.py ← +18 lines: _get_plugin_backend_cls
sparse_attn/
pattern.py ← unchanged
dit_adapter.py ← unchanged
base.py ← retained (class-based ABC for plugin authors)
models/wan2_2/
wan2_2_transformer.py ← +1 line: is_self_attention=False
-25 lines: enable_sparse_attention() removed
tests/diffusion/sparse_attn/
test_protocol.py
test_config_cli.py
test_sparse_attention.py
test_wan22_sparse.py
test_wan22_gpu_integration.py
What changed from the original RFC
| Aspect | Original RFC | Implemented |
|---|---|---|
| Backend ABC | SparseAttentionBackend(ABC) | SparseAttentionBackend(AttentionBackend) |
| Registry | Independent sparse_attn/registry.py | DiffusionAttentionBackendEnum (+1 line) |
| CLI | --sparse-attn-backend | --attn-backend (existing) |
| Plugin interface | Class-based SparseAttentionImpl | Function fn(q,k,v,topk,is_causal) |
| Cross-attn | enable_sparse_attention() hook | is_self_attention=False (1 line) |
| SP/Ring | Bypassed | Via Attention.forward wrapper (auto) |
| contrib/ | In repo | External repos only |
Benchmark Results
Environment: NVIDIA A100 80GB PCIe, Wan2.2-T2V-A14B (14B), SpargeAttn plugin
End-to-end (480×832, 81 frames, 40 denoising steps):
| Backend | Total | Per-step | Speedup |
|---|---|---|---|
| Dense FA3 | 785s | 19.6s | 1.00× |
| SpargeAttn topk=0.7 | 763s | 19.1s | 1.03× |
| SpargeAttn topk=0.5 | 677s | 16.9s | 1.16× |
| SpargeAttn topk=0.3 | 590s | 14.8s | 1.33× |
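The derived columns of the end-to-end table follow from the totals: per-step time is the total divided by the 40 denoising steps, and speedup is the dense total divided by each row's total. The snippet below recomputes them, which is convenient when new runs are added.

```python
STEPS = 40
DENSE_TOTAL_S = 785
TOTALS_S = {
    "Dense FA3": 785,
    "SpargeAttn topk=0.7": 763,
    "SpargeAttn topk=0.5": 677,
    "SpargeAttn topk=0.3": 590,
}


def derive(total_s: float) -> tuple:
    """Return (per-step seconds, speedup vs. dense), rounded as in the table."""
    return round(total_s / STEPS, 1), round(DENSE_TOTAL_S / total_s, 2)


for name, total in TOTALS_S.items():
    per_step, speedup = derive(total)
    print(f"{name:<20} {per_step:>5.1f}s {speedup:.2f}x")
```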
Kernel-level (attention-only, 40 heads × 128 head_dim):
| Config | SeqLen (tokens) | Dense | topk=0.5 | topk=0.3 | topk=0.2 |
|---|---|---|---|---|---|
| 480p 17f | 6,630 | 10.8ms | 10.1ms (1.1×) | 7.6ms (1.4×) | 6.4ms (1.7×) |
| 480p 33f | 12,870 | 39.5ms | 31.5ms (1.3×) | 22.1ms (1.8×) | 17.1ms (2.3×) |
| 720p 33f | 29,700 | 206.1ms | 146.7ms (1.4×) | 95.6ms (2.2×) | 70.0ms (2.9×) |
Writing a plugin backend
Minimal function-based plugin:
```python
# sparge_vllm_omni/__init__.py
import torch


def sparge_attn_fn(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
    topk: float, is_causal: bool,
) -> torch.Tensor:
    from sparge import sparge_attn
    return sparge_attn(query, key, value, topk_ratio=topk)
```
```toml
# pyproject.toml
[project.entry-points."vllm_omni.sparse_attn"]
spargeattn = "sparge_vllm_omni:sparge_attn_fn"

[project.entry-points."vllm_omni.attn_backend"]
spargeattn = "sparge_vllm_omni.backend:SpargeBackend"
```
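Usage (the --attn-backend spargeattn form is resolved through the vllm_omni.attn_backend entry point, so the package must also provide sparge_vllm_omni.backend:SpargeBackend, which the minimal example above does not include):
```shell
pip install sparge-vllm-omni
vllm-omni serve Wan-AI/Wan2.2-T2V-A14B --attn-backend spargeattn
```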
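Before publishing, a plugin author may want to confirm that the registered function matches the calling convention. The check below is a hypothetical helper (check_sparse_fn is not part of vllm-omni) built on the standard library's inspect.signature(...).bind; it checks positional binding only, because forward_cuda calls the plugin positionally.

```python
import inspect

POSITIONAL_ARGS = ("query", "key", "value", "topk", "is_causal")


def check_sparse_fn(fn) -> None:
    """Raise TypeError unless fn can be called as fn(q, k, v, topk, is_causal).

    Parameter names are not enforced; only positional binding is checked.
    """
    try:
        inspect.signature(fn).bind(*POSITIONAL_ARGS)  # placeholder values
    except TypeError as exc:
        raise TypeError(
            f"{getattr(fn, '__name__', fn)!s} does not accept "
            f"(query, key, value, topk, is_causal): {exc}"
        ) from None


def good_fn(query, key, value, topk, is_causal):
    return query


def bad_fn(query, key, value, topk):  # missing is_causal
    return query


check_sparse_fn(good_fn)  # passes silently
try:
    check_sparse_fn(bad_fn)
except TypeError as err:
    print("rejected:", err)
```

In a plugin's own test suite, the same check could be applied to the object returned by loading its vllm_omni.sparse_attn entry point.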
What does NOT change
- The AR/LLM backbone attention path is untouched.
- Existing full-attention DiT models work unchanged when no sparse backend is configured.
- vLLM-Omni takes no hard dependency on any sparse library.
- SP/Ring communication is handled automatically by the existing
Attention.forward wrapper — sparse kernels see pre-communicated tensors.
Open questions
- Should SparsePatternSpec / SparseMetadataBuilder remain in vllm-omni core as optional utilities, or move to a separate helper package?
- Should DiTSparseAttentionAdapter (step-aware topk scheduling) be first-class public API or remain internal?
- Is vllm_omni.attn_backend the right entry_point group name, or should it be vllm_omni.diffusion_attn_backend to avoid confusion with the AR attention path?
Alternatives considered
Separate SparseAttentionBackend ABC: Rejected. Inheriting AttentionBackend reuses existing infrastructure and avoids a parallel registry. (Addressed: @gcanlin)
--sparse-attn-backend CLI flag: Rejected. --attn-backend is the existing unified interface; a second flag is unnecessary. (Addressed: @SamitHuang)
Per-model enable_sparse_attention() hook: Rejected. is_self_attention=False achieves the same cross-attention isolation with 1 line instead of a per-model method. (Addressed: @alex-jw-brooks)
Related
CC: @wtomin @ZJY0516 @hsliuustc0106 @jiangmengyu18 @gglorian @gcanlin @SamitHuang @alex-jw-brooks