Commit c6843ea

RuixiangMa authored and tangbinh committed
[Feat] support SP for FLUX.2-klein (vllm-project#1250)
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent 6e05630 commit c6843ea

5 files changed

Lines changed: 284 additions & 57 deletions

docs/user_guide/diffusion/parallelism_acceleration.md

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ The following table shows which models are currently supported by parallelism me
 | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` |||||| N/A |
 | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` |||| ✅ (TP=2 only) || N/A |
 | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` |||||| N/A |
-| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | | |||| N/A |
+| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | | |||| N/A |
 | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` |||||| N/A |
 | **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` |||||| N/A |
 | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` |||||||

vllm_omni/diffusion/layers/rope.py

Lines changed: 26 additions & 0 deletions
@@ -157,3 +157,29 @@ def forward_native(
             sin,
             interleaved=self.interleaved,
         )
+
+
+def apply_rope_to_qk(
+    rope: RotaryEmbedding,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Apply rotary positional embeddings to query and key tensors.
+
+    Args:
+        rope: RotaryEmbedding instance for applying position embeddings
+        query: Query tensor [B, S, H, D]
+        key: Key tensor [B, S, H, D]
+        image_rotary_emb: Tuple of (cos, sin) tensors or None
+
+    Returns:
+        Tuple of (query, key) with RoPE applied if rotary embeddings provided
+    """
+    if image_rotary_emb is not None:
+        cos, sin = image_rotary_emb
+        cos = cos.to(query.dtype)
+        sin = sin.to(query.dtype)
+        query = rope(query, cos, sin)
+        key = rope(key, cos, sin)
+    return query, key
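
For readers unfamiliar with what the new helper wraps, the following is a minimal pure-Python sketch of the rotary idea, not the repository's implementation. The names `rope_rotate`, `apply_rope_to_qk_sketch`, and `dot` are hypothetical, the sketch works on flat lists instead of `[B, S, H, D]` tensors, it takes token positions rather than precomputed `(cos, sin)` tensors, and it assumes the interleaved (even, odd) pairing with the conventional base of 10000. It mirrors the helper's control flow: rotate both query and key when positional information is given, and return them untouched when it is `None`.

```python
import math


def rope_rotate(vec, position, base=10000.0):
    """Rotate each (even, odd) pair of `vec` by a position-dependent angle."""
    dim = len(vec)
    if dim % 2 != 0:
        raise ValueError("RoPE needs an even head dimension")
    out = []
    for i in range(0, dim, 2):
        # Pair i // 2 uses frequency base ** (-i / dim).
        theta = position * base ** (-i / dim)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + 1]
        out.extend([x1 * c - x2 * s, x1 * s + x2 * c])
    return out


def apply_rope_to_qk_sketch(query, key, positions):
    """Hypothetical stand-in: rotate q and k, or pass both through on None."""
    if positions is not None:
        q_pos, k_pos = positions
        query = rope_rotate(query, q_pos)
        key = rope_rotate(key, k_pos)
    return query, key


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative offset (2) at different absolute positions.
q_a, k_a = apply_rope_to_qk_sketch(q, k, (3, 1))
q_b, k_b = apply_rope_to_qk_sketch(q, k, (7, 5))
score_a = dot(q_a, k_a)
score_b = dot(q_b, k_b)  # agrees with score_a up to floating-point rounding
```

Because each pair is rotated by a pure 2D rotation, the attention score depends only on the relative offset between the two positions, and vector norms are unchanged. That is why the query and key must always be rotated together, which the shared helper makes harder to get wrong.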

vllm_omni/diffusion/models/flux/flux_transformer.py

Lines changed: 2 additions & 7 deletions
@@ -30,7 +30,7 @@
 
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.data import OmniDiffusionConfig
-from vllm_omni.diffusion.layers.rope import RotaryEmbedding
+from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk
 
 logger = init_logger(__name__)
 
@@ -224,12 +224,7 @@ def forward(
             key = torch.cat([encoder_key, key], dim=1)
             value = torch.cat([encoder_value, value], dim=1)
 
-        if image_rotary_emb is not None:
-            cos, sin = image_rotary_emb  # [S, D/2]
-            cos = cos.to(query.dtype)
-            sin = sin.to(query.dtype)
-            query = self.rope(query, cos, sin)
-            key = self.rope(key, cos, sin)
+        query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb)  # [S, D/2]
 
         hidden_states = self.attn(
             query,
