[RFC]: Plugin-based Sparse Attention Interface for DiT Modules #2233

@zzhang-fr

Plugin-based Sparse Attention Interface for DiT Modules

Status: Implemented — see PR #2231
Note: This RFC reflects the final implementation after maintainer review.
The original proposal used a separate ABC and independent registry;
the implementation integrates into the existing AttentionBackend infrastructure
per feedback from @gcanlin and @SamitHuang.

Motivation

vLLM-Omni currently runs DiT module attention with a standard full-attention
kernel. As model resolution and video length scale up, the quadratic cost of
full attention becomes the dominant bottleneck: a 720p video DiT with 8192
patch tokens spends more than 60% of step latency in attention alone.

Several sparse attention libraries have emerged that can cut this cost:

  • SpargeAttn (thu-ml/SpargeAttn, ICML 2025): training-free, block-level
    top-k prediction. Demonstrated 2–3× speedup on FLUX and HunyuanVideo at
    <0.1% quality degradation.
  • FlashInfer sparse kernels: BlockSparseAttentionWrapper with BSR-format
    masks. CUDA-graph-safe.
  • RainFusion (arXiv 2505.21036): adaptive head-level classification with
    spatiotemporal token permutation.

These libraries have incompatible calling conventions. This RFC proposes a
plugin-based architecture where vLLM-Omni provides only a thin integration
layer. All kernel implementations live in external packages that register via
Python entry_points.

Design

The design extends the existing AttentionBackend infrastructure rather than
building a parallel system.

┌────────────────────────────────────────────────────────────────────┐
│  External plugin packages                                          │
│  sparge-vllm-omni · flashinfer-vllm-omni · custom-backend          │
│  (pip install; register via vllm_omni.sparse_attn entry_points)    │
└────────────────────────────┬───────────────────────────────────────┘
                             │  fn(q, k, v, params, is_causal) -> Tensor
┌────────────────────────────▼───────────────────────────────────────┐
│  vllm-omni: sparse_attention.py                                    │
│  SparseAttentionBackend(AttentionBackend)                          │
│  _SparseAttentionImpl(AttentionImpl)                               │
│    - is_self_attention=False → SDPA fallback (cross-attention)     │
│    - is_self_attention=True  → resolve plugin fn → call kernel     │
└────────────────────────────┬───────────────────────────────────────┘
                             │  integrated via existing mechanism
┌────────────────────────────▼───────────────────────────────────────┐
│  vllm-omni: existing infrastructure                                │
│  DiffusionAttentionBackendEnum  +1 line: SPARSE_ATTENTION          │
│  selector.py  +entry_points fallback: --attn-backend spargeattn    │
│  layer.py     +is_self_attention param on Attention.__init__       │
│  WanCrossAttention  +is_self_attention=False  (1 line)             │
└────────────────────────────────────────────────────────────────────┘

Core changes in vllm-omni

1. SparseAttentionBackend inherits AttentionBackend

A new file vllm_omni/diffusion/attention/backends/sparse_attention.py
implements the sparse dispatcher:

class SparseAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str: return "sparse_attention"
    @staticmethod
    def get_impl_cls(): return _SparseAttentionImpl
    ...

class _SparseAttentionImpl(AttentionImpl):
    def __init__(self, num_heads, head_size, softmax_scale,
                 causal=False, num_kv_heads=None, prefix="",
                 **extra_impl_args):
        self._causal = causal  # read again in forward_cuda
        self._is_self_attention = extra_impl_args.get("is_self_attention", True)
        self._dense = SDPABackend.get_impl_cls()(...)  # always constructed

        if not self._is_self_attention:
            self._sparse_fn = None   # cross-attention: always dense
            return

        # resolve plugin from forward_context config
        cfg = self._read_sparse_cfg()
        self._sparse_fn = self._resolve_plugin(cfg)  # fn or None
        self._params = cfg.params if cfg else 0.5
        ...

    def forward_cuda(self, query, key, value, attn_metadata=None):
        # No plugin resolved (or cross-attention): use the dense SDPA path.
        if self._sparse_fn is None:
            return self._dense.forward_cuda(query, key, value, attn_metadata)
        return self._sparse_fn(query, key, value, self._params, self._causal)
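
The routing rule above can be exercised without a GPU. The following is a minimal sketch, not the code from the PR: `DispatchSketch`, `dense_fn`, `fake_sparse_fn` and the `params` keyword are hypothetical stand-ins, and plain Python values take the place of tensors.

```python
from typing import Callable, Optional

# Hypothetical stand-in for a plugin kernel: fn(q, k, v, params, is_causal).
SparseFn = Callable[[object, object, object, float, bool], object]


class DispatchSketch:
    """Toy model of the self-/cross-attention routing described above."""

    def __init__(self, dense_fn, sparse_fn: Optional[SparseFn] = None,
                 causal: bool = False, **extra_impl_args):
        self._dense_fn = dense_fn
        self._causal = causal
        # Assumption: the sparsity parameter arrives as a keyword here.
        self._params = extra_impl_args.get("params", 0.5)
        is_self = extra_impl_args.get("is_self_attention", True)
        # Cross-attention never gets a sparse kernel; decided once, here.
        self._sparse_fn = sparse_fn if is_self else None

    def forward(self, q, k, v):
        if self._sparse_fn is None:
            return self._dense_fn(q, k, v)
        return self._sparse_fn(q, k, v, self._params, self._causal)


def dense_fn(q, k, v):
    return ("dense", q, k, v)


def fake_sparse_fn(q, k, v, params, is_causal):
    return ("sparse", params, is_causal)


self_attn = DispatchSketch(dense_fn, fake_sparse_fn, params=0.3)
cross_attn = DispatchSketch(dense_fn, fake_sparse_fn, is_self_attention=False)
no_plugin = DispatchSketch(dense_fn, None)
```

Only `self_attn` reaches the sparse function. `cross_attn` and `no_plugin` both take the dense path, and the choice costs a single `is None` check per call.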

2. SPARSE_ATTENTION in DiffusionAttentionBackendEnum

One line added to registry.py:

class DiffusionAttentionBackendEnum(Enum, ...):
    FLASH_ATTN       = "...FlashAttentionBackend"
    TORCH_SDPA       = "...SDPABackend"
    SAGE_ATTN        = "...SageAttentionBackend"
    SPARSE_ATTENTION = "...backends.sparse_attention.SparseAttentionBackend"

No independent registry. No parallel enum.

3. --attn-backend as the unified CLI interface

selector.py gets an entry_points fallback before the enum lookup:

def _get_plugin_backend_cls(name: str):
    for ep in importlib.metadata.entry_points(group="vllm_omni.attn_backend"):
        if ep.name.lower() == name.lower():
            return ep.load()
    return None

This means plugin names work directly without modifying vllm-omni's enum:

pip install sparge-vllm-omni
vllm-omni serve Wan2.2 --attn-backend spargeattn     # direct plugin name
# or
vllm-omni serve Wan2.2 --attn-backend sparse_attention \
  --sparse-attn '{"backend":"spargeattn","topk_ratio":0.5}'

--sparse-attn-backend is removed. --attn-backend is the single interface.
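
The resolution order for `--attn-backend` can be sketched as follows. This is an illustration under stated assumptions, not the selector code: `BuiltinBackend` is a shortened hypothetical stand-in for DiffusionAttentionBackendEnum, `plugin_loaders` replaces the real entry_points scan with a plain dict, and the case-insensitive enum lookup is assumed.

```python
from enum import Enum
from typing import Callable, Dict


class BuiltinBackend(Enum):
    # Hypothetical, shortened stand-in for DiffusionAttentionBackendEnum.
    FLASH_ATTN = "FlashAttentionBackend"
    TORCH_SDPA = "SDPABackend"
    SPARSE_ATTENTION = "SparseAttentionBackend"


def resolve_backend(name: str, plugin_loaders: Dict[str, Callable[[], str]]) -> str:
    """Return a backend identifier for ``name``.

    Plugin names are matched case-insensitively first, mirroring the
    entry_points lookup; built-in enum members are tried second.
    """
    wanted = name.lower()
    for plugin_name, load in plugin_loaders.items():
        if plugin_name.lower() == wanted:
            return load()
    try:
        return BuiltinBackend[name.upper()].value
    except KeyError:
        raise ValueError(f"unknown attention backend: {name!r}") from None


# Each loader plays the role of EntryPoint.load().
plugins = {"spargeattn": lambda: "SpargeBackend"}
print(resolve_backend("SpargeAttn", plugins))        # plugin name
print(resolve_backend("sparse_attention", plugins))  # built-in enum member
```

Because the plugin table is consulted by name, a newly installed package becomes selectable without any edit to the enum.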

4. Cross-attention isolation via is_self_attention

Attention.__init__ in layer.py accepts a new parameter (default True,
fully backward compatible):

def __init__(self, ..., is_self_attention: bool = True):
    self.attention = self.attn_impl_cls(
        ...,
        is_self_attention=is_self_attention,   # passed via **extra_impl_args
    )

WanCrossAttention passes is_self_attention=False (one line):

self.attn = Attention(..., is_self_attention=False)

_SparseAttentionImpl reads this in __init__ and permanently sets
self._sparse_fn = None for cross-attention — zero runtime overhead,
decided once at construction.

The per-model enable_sparse_attention() hook is removed entirely.

5. Function-based plugin interface

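Plugins register a function, not a class, under the vllm_omni.sparse_attn
entry_points group. A package may also register a backend class under
vllm_omni.attn_backend so that its name is accepted directly by --attn-backend: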

# pyproject.toml
[project.entry-points."vllm_omni.sparse_attn"]
spargeattn = "sparge_vllm_omni:sparge_attn_fn"

[project.entry-points."vllm_omni.attn_backend"]
spargeattn = "sparge_vllm_omni.backend:SpargeBackend"

Function signature:

def sparge_attn_fn(
    query: torch.Tensor,   # (B, H, S, D)
    key:   torch.Tensor,
    value: torch.Tensor,
    topk:  float,
    is_causal: bool,
) -> torch.Tensor:
    ...

This is simpler than a class-based builder. Backends that need BSR planning
(e.g. FlashInfer) can wrap their internal state behind this functional
interface, or implement the full AttentionBackend / AttentionImpl
class hierarchy and register via vllm_omni.attn_backend.
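
One way to hide planning state behind the functional interface is a callable object that caches a plan per input shape. The sketch below is hypothetical: `PlannedSparseAttn`, its plan dictionary and the block arithmetic are illustrative only, no FlashInfer call is made, and a Python list stands in for the query tensor.

```python
from typing import Dict, Tuple


class PlannedSparseAttn:
    """Callable that keeps per-shape plans behind fn(q, k, v, topk, is_causal)."""

    def __init__(self, block_size: int = 64):
        self.block_size = block_size
        self._plans: Dict[Tuple[int, float, bool], dict] = {}
        self.plan_builds = 0  # counts how often a plan was built

    def _plan(self, seq_len: int, topk: float, is_causal: bool) -> dict:
        key = (seq_len, topk, is_causal)
        if key not in self._plans:
            num_blocks = -(-seq_len // self.block_size)  # ceiling division
            kept = max(1, round(num_blocks * topk))
            self._plans[key] = {"num_blocks": num_blocks, "kept_per_row": kept}
            self.plan_builds += 1
        return self._plans[key]

    def __call__(self, query, key, value, topk: float, is_causal: bool):
        seq_len = len(query)  # stand-in for query.shape[-2]
        plan = self._plan(seq_len, topk, is_causal)
        # A real backend would launch its kernel with `plan` here;
        # the sketch returns the plan so the caching is observable.
        return plan


# Module-level instance that an entry point could reference.
planned_attn_fn = PlannedSparseAttn()
```

An entry point may reference any module-level object, so an instance such as `planned_attn_fn` can be registered the same way as a plain function. Repeated calls with the same sequence length, topk and causality reuse the cached plan.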

Configuration

DiffusionSparseAttnConfig is in vllm_omni/diffusion/data.py:

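from dataclasses import dataclass, field

@dataclass
class DiffusionSparseAttnConfig:
    backend: str = "auto"          # plugin short name, class path, or "none"
    pattern_type: str = "dynamic_topk"
    params: dict = field(default_factory=dict)  # a bare {} default raises ValueError in a dataclass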
    block_size_q: int = 128
    block_size_kv: int = 64
    window_size: int | None = None
    spatial_radius: int | None = None
    schedule: str = "constant"     # "constant" | "conservative" | "aggressive"

Integrated into OmniDiffusionConfig.sparse_attn.
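
The JSON accepted by --sparse-attn has to be mapped onto these fields somehow. The sketch below shows one plausible mapping and is not the parser from the PR: `SparseAttnConfigSketch` and `from_json` are hypothetical, and the rule that unrecognised keys (such as `topk_ratio`) are collected into `params` is an assumption.

```python
import json
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class SparseAttnConfigSketch:
    # Mirrors the fields listed above; the parsing rule below is an assumption.
    backend: str = "auto"
    pattern_type: str = "dynamic_topk"
    params: dict = field(default_factory=dict)
    block_size_q: int = 128
    block_size_kv: int = 64
    window_size: Optional[int] = None
    spatial_radius: Optional[int] = None
    schedule: str = "constant"

    @classmethod
    def from_json(cls, text: str) -> "SparseAttnConfigSketch":
        raw = json.loads(text)
        known = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in raw.items() if k in known}
        # Assumption: keys that are not config fields are backend-specific
        # and are merged into `params`.
        extra = {k: v for k, v in raw.items() if k not in known}
        kwargs["params"] = {**kwargs.get("params", {}), **extra}
        return cls(**kwargs)


cfg = SparseAttnConfigSketch.from_json('{"backend":"spargeattn","topk_ratio":0.5}')
print(cfg.backend, cfg.params, cfg.block_size_q)
```

Under this reading the core config stays small while each backend keeps its own tuning knobs in `params`.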

Sparse pattern descriptor (retained)

SparsePatternSpec and SpatialLayout in vllm_omni/diffusion/sparse_attn/pattern.py
are retained as optional utilities for plugin authors who want a structured
pattern descriptor. They are not required by the core dispatcher.

DiTSparseAttentionAdapter in dit_adapter.py is retained for step-aware
sparsity scheduling (topk ramp across denoising steps).
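
A topk ramp across denoising steps can be as simple as a linear interpolation. The function below is a sketch only: `topk_ramp` and its `start`/`end` defaults are hypothetical, and the RFC does not state how the "constant", "conservative" and "aggressive" schedules map onto such a ramp.

```python
def topk_ramp(step: int, num_steps: int, start: float = 1.0, end: float = 0.5) -> float:
    """Linearly interpolate the top-k ratio from `start` (step 0) to `end` (last step)."""
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    if not 0 <= step < num_steps:
        raise ValueError("step out of range")
    if num_steps == 1:
        return end
    frac = step / (num_steps - 1)
    return start + (end - start) * frac


# Five-step example: starts dense (1.0) and ends at the target ratio (0.5).
ramp = [round(topk_ramp(s, 5), 3) for s in range(5)]
print(ramp)
```

Starting near 1.0 keeps the early steps close to dense attention and lowers the ratio as denoising proceeds. Swapping `start` and `end` gives the opposite policy.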

File layout

vllm_omni/diffusion/
  attention/
    backends/
      sparse_attention.py    ← NEW (196 lines)
                               SparseAttentionBackend(AttentionBackend)
                               _SparseAttentionImpl(AttentionImpl)
    registry.py              ← +1 line: SPARSE_ATTENTION in enum
    layer.py                 ← +4 lines: is_self_attention param
    selector.py              ← +18 lines: _get_plugin_backend_cls
  sparse_attn/
    pattern.py               ← unchanged
    dit_adapter.py           ← unchanged
    base.py                  ← retained (class-based ABC for plugin authors)
  models/wan2_2/
    wan2_2_transformer.py    ← +1 line: is_self_attention=False
                               -25 lines: enable_sparse_attention() removed
tests/diffusion/sparse_attn/
  test_protocol.py
  test_config_cli.py
  test_sparse_attention.py
  test_wan22_sparse.py
  test_wan22_gpu_integration.py

What changed from the original RFC

|                  | Original RFC                        | Implemented                              |
|------------------|-------------------------------------|------------------------------------------|
| Backend ABC      | SparseAttentionBackend(ABC)         | SparseAttentionBackend(AttentionBackend) |
| Registry         | Independent sparse_attn/registry.py | DiffusionAttentionBackendEnum (+1 line)  |
| CLI              | --sparse-attn-backend               | --attn-backend (existing)                |
| Plugin interface | Class-based SparseAttentionImpl     | Function fn(q,k,v,topk,is_causal)        |
| Cross-attn       | enable_sparse_attention() hook      | is_self_attention=False (1 line)         |
| SP/Ring          | Bypassed                            | Via Attention.forward wrapper (auto)     |
| contrib/         | In repo                             | External repos only                      |

Benchmark Results

Environment: NVIDIA A100 80GB PCIe, Wan2.2-T2V-A14B (14B), SpargeAttn plugin

End-to-end (480×832, 81 frames, 40 denoising steps):

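| Backend             | Total | Per-step | Speedup |
|---------------------|-------|----------|---------|
| Dense FA3           | 785s  | 19.6s    | 1.00×   |
| SpargeAttn topk=0.7 | 763s  | 19.1s    | 1.03×   |
| SpargeAttn topk=0.5 | 677s  | 16.9s    | 1.16×   |
| SpargeAttn topk=0.3 | 590s  | 14.8s    | 1.33×   |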
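
The Per-step and Speedup columns are consistent with dividing each total by the 40 denoising steps and by the dense total. The snippet below recomputes them from the reported totals; `derive` is a hypothetical helper written for this check.

```python
NUM_STEPS = 40  # denoising steps in the end-to-end run

# Total wall-clock seconds as reported in the table above.
totals_s = {
    "Dense FA3": 785,
    "SpargeAttn topk=0.7": 763,
    "SpargeAttn topk=0.5": 677,
    "SpargeAttn topk=0.3": 590,
}


def derive(totals, num_steps, baseline):
    """Return {name: (seconds_per_step, speedup_vs_baseline)}."""
    base = totals[baseline]
    return {name: (total / num_steps, base / total) for name, total in totals.items()}


rows = derive(totals_s, NUM_STEPS, "Dense FA3")
for name, (per_step, speedup) in rows.items():
    print(f"{name:<22} {per_step:5.1f}s  {speedup:.2f}x")
```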

Kernel-level (attention-only, 40 heads × 128 head_dim):

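| Config   | SeqLen | Dense   | topk=0.5       | topk=0.3      | topk=0.2      |
|----------|--------|---------|----------------|---------------|---------------|
| 480p 17f | 6,630  | 10.8ms  | 10.1ms (1.1×)  | 7.6ms (1.4×)  | 6.4ms (1.7×)  |
| 480p 33f | 12,870 | 39.5ms  | 31.5ms (1.3×)  | 22.1ms (1.8×) | 17.1ms (2.3×) |
| 720p 33f | 29,700 | 206.1ms | 146.7ms (1.4×) | 95.6ms (2.2×) | 70.0ms (2.9×) |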
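
A rough way to read these numbers is to compare each measured speedup with the 1/topk bound. That bound rests on two assumptions not stated in the RFC: that topk is the fraction of key/value blocks kept, and that kernel cost scales with the number of blocks computed. The sketch below does this for the 720p 33f row; `kernel_efficiency` is a hypothetical helper.

```python
# 720p 33f row from the table above (milliseconds).
dense_ms = 206.1
sparse_ms = {0.5: 146.7, 0.3: 95.6, 0.2: 70.0}


def kernel_efficiency(dense: float, sparse: float, topk: float):
    """Return (measured speedup, ideal 1/topk speedup, measured / ideal)."""
    measured = dense / sparse
    ideal = 1.0 / topk  # assumption: cost proportional to blocks kept
    return measured, ideal, measured / ideal


for topk, ms in sparse_ms.items():
    measured, ideal, eff = kernel_efficiency(dense_ms, ms, topk)
    print(f"topk={topk}: measured {measured:.2f}x, ideal {ideal:.2f}x, ratio {eff:.2f}")
```

Under those assumptions the measured speedup is a falling share of the bound as topk decreases. One plausible explanation is that fixed costs, such as predicting which blocks to keep, weigh more once less attention work remains.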

Writing a plugin backend

Minimal function-based plugin:

# sparge_vllm_omni/__init__.py
import torch

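def sparge_attn_fn(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
    topk: float, is_causal: bool,
) -> torch.Tensor:
    # Imported lazily so the kernel library is only loaded on first call.
    from sparge import sparge_attn
    # is_causal is accepted to match the plugin signature; this minimal
    # example does not forward it to the kernel.
    return sparge_attn(query, key, value, topk_ratio=topk)

# pyproject.toml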
[project.entry-points."vllm_omni.sparse_attn"]
spargeattn = "sparge_vllm_omni:sparge_attn_fn"

[project.entry-points."vllm_omni.attn_backend"]
spargeattn = "sparge_vllm_omni.backend:SpargeBackend"

Usage:

pip install sparge-vllm-omni
vllm-omni serve Wan-AI/Wan2.2-T2V-A14B --attn-backend spargeattn
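
After installing a plugin, it can help to confirm that both entry-point groups are visible to the interpreter before starting the server. The helper below is a sketch using only the standard library; `list_plugins` is a hypothetical name and not part of vllm-omni.

```python
import importlib.metadata


def list_plugins(group: str) -> dict:
    """Map entry-point name -> target string for every plugin in `group`."""
    eps = importlib.metadata.entry_points()
    if hasattr(eps, "select"):      # Python 3.10+
        selected = eps.select(group=group)
    else:                           # Python 3.8/3.9 return a dict of lists
        selected = eps.get(group, [])
    return {ep.name: ep.value for ep in selected}


for group in ("vllm_omni.sparse_attn", "vllm_omni.attn_backend"):
    print(group, list_plugins(group))
```

An empty dict for either group means the package metadata was not picked up, for example because the plugin was installed into a different environment.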

What does NOT change

  • The AR/LLM backbone attention path is untouched.
  • Existing full-attention DiT models work unchanged when no sparse backend is configured.
  • vLLM-Omni takes no hard dependency on any sparse library.
  • SP/Ring communication is handled automatically by the existing Attention.forward wrapper — sparse kernels see pre-communicated tensors.

Open questions

  1. Should SparsePatternSpec / SparseMetadataBuilder remain in vllm-omni core as optional utilities, or move to a separate helper package?
  2. Should DiTSparseAttentionAdapter (step-aware topk scheduling) be first-class public API or remain internal?
  3. Is vllm_omni.attn_backend the right entry_point group name, or should it be vllm_omni.diffusion_attn_backend to avoid confusion with the AR attention path?

Alternatives considered

Separate SparseAttentionBackend ABC: Rejected. Inheriting AttentionBackend reuses existing infrastructure and avoids a parallel registry. (Addressed: @gcanlin)

--sparse-attn-backend CLI flag: Rejected. --attn-backend is the existing unified interface; a second flag is unnecessary. (Addressed: @SamitHuang)

Per-model enable_sparse_attention() hook: Rejected. is_self_attention=False achieves the same cross-attention isolation with 1 line instead of a per-model method. (Addressed: @alex-jw-brooks)

Related

CC: @wtomin @ZJY0516 @hsliuustc0106 @jiangmengyu18 @gglorian @gcanlin @SamitHuang @alex-jw-brooks

Labels: diffusion, good first issue, help wanted, high priority
