[Feat] Enable expert parallel for diffusion MoE layers#1323

Merged
hsliuustc0106 merged 2 commits into
vllm-project:mainfrom
Semmer2:EP_Enable
Mar 11, 2026

Conversation

@Semmer2

@Semmer2 Semmer2 commented Feb 11, 2026

Contributor

Purpose

Support MoE layers with Expert Parallel in diffusion inference.

Test Plan

Test HunyuanImage3.0 with and without EP, and compare the outputs.

with EP:
python examples/offline_inference/text_to_image/text_to_image.py --model /data/HunyuanImage-3.0/ --prompt "A brown and white dog is running on the grass" --output output_with_EP.png --num_inference_steps 50 --guidance_scale 5.0 --tensor_parallel_size 8 --seed 1234 --enable_expert_parallel

without EP:
python examples/offline_inference/text_to_image/text_to_image.py --model /data/HunyuanImage-3.0/ --prompt "A brown and white dog is running on the grass" --output output_without_EP.png --num_inference_steps 50 --guidance_scale 5.0 --tensor_parallel_size 8 --seed 1234

You can tell whether EP is active from the command output log:

[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 5/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->40, 1->41, 2->42, 3->43, 4->44, 5->45, 6->46, 7->47.
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 3/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->24, 1->25, 2->26, 3->27, 4->28, 5->29, 6->30, 7->31.
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 0/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->0, 1->1, 2->2, 3->3, 4->4, 5->5, 6->6, 7->7.
[Stage-0] INFO 02-10 18:15:59 [unquantized.py:82] FlashInfer CUTLASS MoE is available for EP but not enabled, consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.
[Stage-0] INFO 02-10 18:15:59 [unquantized.py:103] Using TRITON backend for Unquantized MoE
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 7/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->56, 1->57, 2->58, 3->59, 4->60, 5->61, 6->62, 7->63.
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 1/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->8, 1->9, 2->10, 3->11, 4->12, 5->13, 6->14, 7->15.
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 2/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->16, 1->17, 2->18, 3->19, 4->20, 5->21, 6->22, 7->23.
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 4/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->32, 1->33, 2->34, 3->35, 4->36, 5->37, 6->38, 7->39.
[Stage-0] INFO 02-10 18:15:59 [layer.py:475] [EP Rank 6/8] Expert parallelism is enabled. Expert placement strategy: linear. Local/global number of experts: 8/64. Experts local to global index map: 0->48, 1->49, 2->50, 3->51, 4->52, 5->53, 6->54, 7->55.
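
The local-to-global index maps in the log above are consistent with a simple contiguous ("linear") split of 64 experts over 8 EP ranks. The following is a minimal sketch of that arithmetic only; the function name `linear_expert_map` is hypothetical and is not taken from vLLM or vllm-omni.

```python
def linear_expert_map(ep_rank: int, ep_size: int, num_global_experts: int) -> dict[int, int]:
    """Map local expert index -> global expert index for a contiguous split.

    Hypothetical helper: it only reproduces the arithmetic visible in the log.
    """
    if num_global_experts % ep_size != 0:
        raise ValueError("this sketch assumes experts divide evenly across EP ranks")
    num_local = num_global_experts // ep_size
    start = ep_rank * num_local  # first global expert owned by this rank
    return {local: start + local for local in range(num_local)}


print(linear_expert_map(5, 8, 64))
# {0: 40, 1: 41, 2: 42, 3: 43, 4: 44, 5: 45, 6: 46, 7: 47}
```

The printed map matches the "EP Rank 5/8" log line.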

Test Result

image

The current EP implementation may degrade performance for now. End-to-end test result:

[enable_ep: True] 8 GPUs | baseline: 84245ms, ep: 105764ms, speedup: 0.80x
[enable_ep: True] diff: [mean=2.641083e-02, max=7.058824e-01], cos_sim: [mean=9.949654e-01, max=9.949654e-01], mse: 2.727440e-03

==========================================================================================
SUMMARY
==========================================================================================
Mode            GPUs   Size       Baseline     EP           Speedup    Status
------------------------------------------------------------------------------------------
A brown an      1      1024x1024  84245ms      105764ms        0.80x      PASS
==========================================================================================
PASSEDGPU cleanup disabled
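
For readers who want to sanity-check the numbers above, here is a minimal sketch of how a speedup ratio, cosine similarity, and MSE are conventionally computed. The test script's own implementation is not shown in this thread, so the function names and the toy inputs below are assumptions for illustration.

```python
import math


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Ratio above 1.0 means the candidate is faster than the baseline."""
    return baseline_ms / candidate_ms


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two flattened outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def mse(a: list[float], b: list[float]) -> float:
    """Mean squared error between two flattened outputs of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Timings taken from the report above.
print(f"{speedup(84245, 105764):.2f}x")  # 0.80x
```

A ratio of 0.80x therefore means the EP run took about 25% longer than the baseline.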

Please refer to issue

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 27ac16cebd

Comment thread vllm_omni/diffusion/distributed/parallel_state.py Outdated
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)

forward_context = get_forward_context()

[P1] Gate forward-context lookup to expert-parallel path

initialize_model_parallel now always requires a forward context, but existing direct callers invoke it after init_distributed_environment() without wrapping set_forward_context (for example the distributed tests call initialize_model_parallel(...) directly), so this now raises Forward context is not set before any group initialization. The context lookup should be conditional on the EP-only branch or otherwise optional to preserve prior behavior for non-EP setups.

Comment thread examples/offline_inference/text_to_image/text_to_image.py Outdated

Copilot AI left a comment

Contributor

Pull request overview

Adds an “expert parallel” (EP) switch to diffusion inference so MoE-based diffusion transformers (e.g., HunyuanImage3.0) can initialize an EP process group and plumb the flag from user entrypoints/examples into diffusion distributed setup.

Changes:

  • Plumbs enable_expert_parallel through diffusion configs, async default stage config creation, worker init, and offline inference examples.
  • Ensures OmniDiffusionConfig.tf_model_config is populated from HF config files and adds an is_moe helper for gating EP.
  • Extends diffusion distributed parallel state to optionally create/destroy an “expert” process group and updates docs to list EP support.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
vllm_omni/entrypoints/omni_diffusion.py Populate tf_model_config from config.json fallback path.
vllm_omni/entrypoints/async_omni_diffusion.py Same tf_model_config population for async diffusion entrypoint.
vllm_omni/entrypoints/async_omni.py Adds enable_expert_parallel to default diffusion stage config creation.
vllm_omni/diffusion/worker/diffusion_worker.py Propagates EP flag into VllmConfig and diffusion model-parallel init.
vllm_omni/diffusion/distributed/parallel_state.py Adds EP group creation/destruction and introduces forward-context dependency.
vllm_omni/diffusion/data.py Adds enable_expert_parallel field + is_moe property.
examples/offline_inference/text_to_video/text_to_video.py Adds CLI flag and passes EP into DiffusionParallelConfig.
examples/offline_inference/text_to_image/text_to_image.py Adds CLI flag and passes EP into DiffusionParallelConfig.
examples/offline_inference/image_to_image/image_edit.py Adds CLI flag and passes EP into DiffusionParallelConfig.
docs/user_guide/diffusion/parallelism_acceleration.md Documents EP support in the model/parallelism matrix.

Comment thread vllm_omni/diffusion/distributed/parallel_state.py Outdated
Comment thread vllm_omni/diffusion/distributed/parallel_state.py Outdated
Comment on lines +784 to +792
if enable_expert_parallel:
    assert od_config.is_moe
    vllm_parallel_state._EP = init_model_parallel_group(
        group_ranks=rank_generator.get_ranks("ep"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="expert",
    )

Copilot AI Feb 11, 2026

enable_expert_parallel adds new behavior in initialize_model_parallel, but there are existing unit tests for parallel group construction under tests/diffusion/distributed/. Please add a test that exercises initialize_model_parallel(enable_expert_parallel=True) and asserts the EP group is created with the expected world size/ranks (and that it is not created when the flag is false).

Comment thread vllm_omni/diffusion/data.py Outdated
"""Number of tensor parallel groups."""

enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""

Copilot AI Feb 11, 2026

enable_expert_parallel is documented as “Use expert parallelism instead of tensor parallelism for MoE layers,” but this change does not disable TP; it only introduces an EP flag/group in addition to the existing TP setup. Please update the docstring to reflect the actual behavior, or adjust the implementation if EP is truly meant to replace TP for MoE layers.

Suggested change
- """Use expert parallelism instead of tensor parallelism for MoE layers."""
+ """Enable expert parallelism for MoE layers in addition to tensor parallelism."""

Comment thread vllm_omni/diffusion/data.py Outdated
Comment on lines +686 to +688
forward_context = get_forward_context()
od_config = forward_context.omni_diffusion_config

Copilot AI Feb 11, 2026

initialize_model_parallel now unconditionally calls get_forward_context(), which asserts if no forward context has been set. There are multiple call sites (including existing unit tests) that call initialize_model_parallel() before entering any set_forward_context(...) scope, so this will break those flows even when enable_expert_parallel=False. Consider gating this lookup behind enable_expert_parallel (and/or using is_forward_context_available()), or pass the needed OmniDiffusionConfig explicitly as a parameter when EP is enabled.

| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ |

Contributor Author

For non-MoE models, should I use 'N/A' instead of 'X' for the Expert-Parallel column? Since most models don't have MoE layers, 'X' might be misleading—users could interpret it as a lack of feature support rather than a structural irrelevance.

Contributor Author

@hsliuustc0106 @princepride @lishunyang12 Hi guys, any suggestion about this part?

Collaborator

N/A is better here -- X suggests the feature was intentionally not supported, N/A makes it clear the dimension doesn't apply to non-MoE architectures.

Contributor Author

Thank you. That's better indeed.

Comment on lines +786 to +791
vllm_parallel_state._EP = init_model_parallel_group(
    group_ranks=rank_generator.get_ranks("ep"),
    local_rank=get_world_group().local_rank,
    backend=backend,
    parallel_mode="expert",
)

Collaborator

vllm_omni/diffusion/distributed/parallel_state.py:#L187 It seems that ep is always 1

Contributor Author

You're right, I misunderstood the previous code. Fixed.

Collaborator

Thanks for fixing the EP size. Two remaining issues in the current code:

  1. get_forward_context() on line 735 is still called unconditionally — it needs to be inside the if enable_expert_parallel: block (or passed as a parameter). Right now any caller that doesn't set a forward context will crash even when EP is disabled.

  2. Adding "ep" to the order string and name_to_size breaks generate_masked_orthogonal_rank_groups for all other tokens. ordered_size becomes [tp, sp, pp, cfg, dp, ep] where ep = tp*sp*cfg*dp, so the product no longer equals world_size. The get_ranks("ep") special-case is correct, but the mask-based path used by get_ranks("tp") / get_ranks("cfg") / etc. will produce invalid rank groups.

Fix: don't add ep to the order string or name_to_size. The get_ranks("ep") method already computes EP groups directly without needing the mask infrastructure.
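
The size mismatch described in point 2 can be checked with plain arithmetic. The snippet below is a sketch under assumed sizes (tp=8, everything else 1, world_size=8, mirroring the 8-GPU test run); `mesh_matches_world` is a hypothetical helper, not part of the rank generator.

```python
from math import prod


def mesh_matches_world(sizes: dict[str, int], world_size: int) -> bool:
    """A rank mesh is only valid if its dimension sizes multiply to world_size."""
    return prod(sizes.values()) == world_size


world_size = 8
base = {"tp": 8, "sp": 1, "pp": 1, "cfg": 1, "dp": 1}

# Appending ep = tp * sp * cfg * dp as an extra mesh dimension, as the review describes.
with_ep = dict(base, ep=base["tp"] * base["sp"] * base["cfg"] * base["dp"])

print(mesh_matches_world(base, world_size))     # True
print(mesh_matches_world(with_ep, world_size))  # False: 8 * 8 = 64, not 8
```

Keeping ep out of the ordered sizes leaves the product equal to world_size, which is why the suggested fix computes EP groups separately.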

Contributor Author

Sure. I moved get_forward_context() into the if enable_expert_parallel: block and removed the unnecessary ep code.

@lishunyang12 lishunyang12 left a comment

Collaborator

The EP plumbing is there but the EP group always has world_size=1 since self.ep = 1 is hardcoded -- so this does not actually do expert parallelism yet.

self.pp = pp
self.cfg = cfg
self.dp = dp
self.ep = 1 # no matter EP enabled, EP stride should always be 1

Collaborator

self.ep = 1 means get_ranks("ep") will always return singleton groups. So even when EP is enabled, every rank ends up in its own EP group of size 1 and no actual expert-parallel communication happens. This is the core issue -- how is the EP world size supposed to be derived? Should it equal tp (reusing the TP ranks for EP), or should there be a separate expert_parallel_size parameter?
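
To illustrate why a group size of 1 yields no expert-parallel communication, here is a simplified sketch that partitions ranks into contiguous groups. It is an assumption-laden stand-in: the project's rank generator is more general, and `contiguous_groups` is a hypothetical name.

```python
def contiguous_groups(world_size: int, group_size: int) -> list[list[int]]:
    """Split ranks 0..world_size-1 into consecutive groups of group_size."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [
        list(range(start, start + group_size))
        for start in range(0, world_size, group_size)
    ]


print(contiguous_groups(8, 1))  # [[0], [1], [2], [3], [4], [5], [6], [7]]
print(contiguous_groups(8, 8))  # [[0, 1, 2, 3, 4, 5, 6, 7]]
```

With size 1 every rank is alone in its group, so there is no peer to exchange expert inputs or outputs with.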

world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)

forward_context = get_forward_context()

Collaborator

This get_forward_context() call runs unconditionally, even when enable_expert_parallel=False. Any caller that does not set a forward context first (including existing tests) will crash here. Move this inside the if enable_expert_parallel: block below, or pass od_config as a parameter instead of pulling it from global state.
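
A minimal sketch of the two options mentioned above (gate the lookup, or accept the config as a parameter) might look like the following. Everything here is a stand-in: `_FORWARD_CONTEXT`, the dict-based config, and `init_groups_sketch` are hypothetical and only mimic the control flow, not the real vllm-omni functions.

```python
from typing import Optional

# Stand-in for the real forward-context machinery; None means "not set".
_FORWARD_CONTEXT: Optional[dict] = None


def get_forward_context() -> dict:
    """Stand-in lookup that fails when no context has been set."""
    if _FORWARD_CONTEXT is None:
        raise RuntimeError("Forward context is not set")
    return _FORWARD_CONTEXT


def init_groups_sketch(
    enable_expert_parallel: bool = False,
    od_config: Optional[dict] = None,
) -> list[str]:
    """Return the names of the groups that would be created."""
    groups = ["tp", "sp", "pp", "cfg", "dp"]
    if enable_expert_parallel:
        # Only the EP branch needs the config, so only this branch looks it up.
        if od_config is None:
            od_config = get_forward_context()["omni_diffusion_config"]
        groups.append("ep")
    return groups


print(init_groups_sketch())                        # works without any forward context
print(init_groups_sketch(True, {"is_moe": True}))  # config passed explicitly
```

Passing the config explicitly also keeps the function testable without global state.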

)

if enable_expert_parallel:
assert od_config.is_moe

Collaborator

Bare assert gets stripped under python -O. This should be a ValueError or RuntimeError with a message like "enable_expert_parallel requires a MoE model". Also, should EP be silently ignored for non-MoE models instead of raising?
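
As a sketch of the explicit-exception variant suggested here (the function name `validate_expert_parallel` is hypothetical), a check like this still runs under python -O, whereas an assert statement would be removed:

```python
def validate_expert_parallel(enable_expert_parallel: bool, is_moe: bool) -> None:
    """Raise a clear error instead of relying on a bare assert."""
    if enable_expert_parallel and not is_moe:
        raise ValueError("enable_expert_parallel requires a MoE model")


validate_expert_parallel(True, True)   # valid combination, no error
validate_expert_parallel(False, False) # EP disabled, nothing to check

try:
    validate_expert_parallel(True, False)
except ValueError as exc:
    print(exc)  # enable_expert_parallel requires a MoE model
```

Whether to raise or to silently ignore the flag for dense models is the open design question raised in the comment; this sketch shows only the raising variant.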

cfg_parallel_size,
data_parallel_size,
- "tp-sp-pp-cfg-dp",
+ "tp-sp-pp-cfg-dp-ep",

Collaborator

Adding ep to the order string here but keeping self.ep = 1 (line 187) means the rank generator always treats EP as a trivial dimension. If EP is meant to share the TP mesh, the order string change alone will not do it -- you need to actually set self.ep to the desired EP size and adjust self.tp accordingly.

Comment thread vllm_omni/diffusion/data.py Outdated

@property
def is_moe(self) -> bool:
    if self.tf_model_config.get("num_experts", None):

Collaborator

A few issues with is_moe: (1) num_experts can be a list in some configs (per-layer expert counts) -- calling > 0 on a list will throw TypeError. (2) num_experts=1 would return True here, but a single expert is effectively dense. Consider handling the list case and checking > 1 instead of > 0.
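
One way to cover both cases raised here is sketched below. It assumes the config behaves like a plain mapping and that a list value holds per-layer expert counts; `is_moe_config` is a hypothetical free function, not the merged `is_moe` property.

```python
from typing import Any, Mapping


def is_moe_config(tf_model_config: Mapping[str, Any]) -> bool:
    """Treat a model as MoE only if some layer has more than one expert."""
    num_experts = tf_model_config.get("num_experts")
    if num_experts is None:
        return False
    if isinstance(num_experts, (list, tuple)):
        # Per-layer counts: MoE if any layer has more than one expert.
        return any(isinstance(n, int) and n > 1 for n in num_experts)
    return isinstance(num_experts, int) and num_experts > 1


print(is_moe_config({"num_experts": 64}))       # True
print(is_moe_config({"num_experts": 1}))        # False
print(is_moe_config({"num_experts": [1, 64]}))  # True
print(is_moe_config({}))                        # False
```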

tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""

enable_expert_parallel: bool = False

Collaborator

The docstring says "Use expert parallelism instead of tensor parallelism" but the implementation creates an EP group in addition to TP -- TP is never disabled. Which is the intended behavior? If they coexist, the docstring should say so.

Contributor Author

I meant that MoE layers no longer use TP when EP is enabled, while the other linear layers continue to use TP. Is my wording confusing for users?

Collaborator

That's clearer. Something like "Enable expert parallelism for MoE layers (TP is still used for non-MoE layers)" would remove any ambiguity.

Contributor Author

Yes, your wording is clearer.

f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, "
f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, "
f"vae_patch_parallel_size={args.vae_patch_parallel_size}"
f"vae_patch_parallel_size={args.vae_patch_parallel_size}, enable_expert_parallel: {args.enable_expert_parallel}."

Collaborator

Nit: this uses enable_expert_parallel: (colon + trailing period), while text_to_video.py uses enable_expert_parallel= (equals, no period) and image_edit.py uses yet another variant. Would be nice to keep the format consistent across all three examples.

Contributor Author

Yes you're right, my bad. Modified.

Collaborator

Thanks.

vllm_parallel_state._TP.destroy()
vllm_parallel_state._TP = None

if vllm_parallel_state._EP:

Collaborator

Is vllm_parallel_state._EP guaranteed to be defined if initialize_model_parallel was never called with enable_expert_parallel=True? If _EP was never set as an attribute, this will raise AttributeError. You might want getattr(vllm_parallel_state, "_EP", None) or ensure _EP is initialized to None at module level in vllm.
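
The defensive getattr pattern mentioned here can be sketched as follows. The `SimpleNamespace` objects and `FakeGroup` class are stand-ins for the vllm parallel-state module and a process group; `destroy_expert_group` is a hypothetical helper.

```python
from types import SimpleNamespace


class FakeGroup:
    """Stand-in for a process group object exposing destroy()."""

    def __init__(self) -> None:
        self.destroyed = False

    def destroy(self) -> None:
        self.destroyed = True


def destroy_expert_group(state) -> bool:
    """Destroy state._EP if present; return True if something was destroyed."""
    ep_group = getattr(state, "_EP", None)  # tolerates a missing attribute
    if ep_group is None:
        return False
    ep_group.destroy()
    state._EP = None
    return True


print(destroy_expert_group(SimpleNamespace()))                 # False: attribute never set
print(destroy_expert_group(SimpleNamespace(_EP=FakeGroup())))  # True
```

If the attribute is always defined with a None default, a plain truthiness check behaves the same way.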

Contributor Author

Yes. _EP comes from vllm and defaults to None when initialize_model_parallel has not been called, so the check if vllm_parallel_state._EP will not raise AttributeError.

Collaborator

Makes sense, thanks for confirming.

@hsliuustc0106

Collaborator

@vllm-omni-reviewer

@Semmer2 Semmer2 force-pushed the EP_Enable branch 2 times, most recently from 7a96e93 to b6d078c Compare March 3, 2026 03:13
@Semmer2

Semmer2 commented Mar 3, 2026

Contributor Author

The EP plumbing is there but the EP group always has world_size=1 since self.ep = 1 is hardcoded -- so this does not actually do expert parallelism yet.

You're right, I misunderstood the previous code. Fixed.

@lishunyang12

Collaborator

Thanks for the updates. Two things still need fixing before this is ready:

  1. get_forward_context() on L735 is still called unconditionally — move it inside the if enable_expert_parallel: block.
  2. Adding "ep" to the order string and name_to_size breaks generate_masked_orthogonal_rank_groups for other tokens (tp, cfg, etc.) because product(ordered_size) != world_size. The get_ranks("ep") special-case already works without the mask infrastructure, so just remove ep from the order string and name_to_size.

See inline comments for details.

@hsliuustc0106

Collaborator

Do we have any API tests and e2e accuracy/perf tests?

Hi HS, I have tested EP in vllm-omni/tests/e2e/offline_inference/test_expert_parallel.py. Here are the results on 8x A100 80G:

[enable_ep: True] 8 GPUs | baseline: 84245ms, ep: 105764ms, speedup: 0.80x
[enable_ep: True] diff: [mean=2.641083e-02, max=7.058824e-01], cos_sim: [mean=9.949654e-01, max=9.949654e-01], mse: 2.727440e-03

==========================================================================================
SUMMARY
==========================================================================================
Mode            GPUs   Size       Baseline     EP           Speedup    Status
------------------------------------------------------------------------------------------
A brown an      1      1024x1024  84245ms      105764ms        0.80x      PASS
==========================================================================================
PASSEDGPU cleanup disabled

Hope this makes sense to you.

Do we expect EP to cause a perf regression?

@Semmer2

Semmer2 commented Mar 8, 2026

Contributor Author

No, I actually expected EP to improve performance, so this does look odd. I did see a MoE warning: [Stage-0] WARNING 03-07 17:25:12 [fused_moe.py:1087] Using default MoE config. Performance might be sub-optimal![...]. I don't know yet whether it is related. Would it be acceptable to merge this first and have me fix it next week?

@hsliuustc0106

Collaborator

Please open a new issue about this perf regression and wait for the rebase to 0.17.0 (it should be ready tomorrow morning).

@Semmer2

Semmer2 commented Mar 8, 2026

Contributor Author

OK, got you.

@lishunyang12 lishunyang12 left a comment

Collaborator

previous concerns addressed, looks good now

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Mar 9, 2026
help="Number of GPUs used for VAE patch/tile parallelism (decode).",
)
parser.add_argument(
"--enable_expert_parallel",

Collaborator

Suggested change
- "--enable_expert_parallel",
+ "--enable-expert-parallel",

apply to all please, cc @wtomin

Contributor Author

Applied the CLI format change to all example files that have the --enable-expert-parallel option.
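
Renaming the flag to use hyphens does not change how the value is read in code, because argparse derives the attribute name by replacing hyphens with underscores. A minimal sketch (the help text is illustrative, not copied from the example scripts):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-expert-parallel",
    action="store_true",
    help="Illustrative help text: enable expert parallelism for MoE layers.",
)

# The hyphenated flag is exposed as args.enable_expert_parallel.
args = parser.parse_args(["--enable-expert-parallel"])
print(args.enable_expert_parallel)  # True

# Without the flag, store_true defaults to False.
print(parser.parse_args([]).enable_expert_parallel)  # False
```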

@Semmer2 Semmer2 force-pushed the EP_Enable branch 3 times, most recently from 30274d2 to b96b713 Compare March 9, 2026 06:59

6. [HSDP](#hsdp): Hybrid Sharded Data Parallel shards model weights across GPUs using PyTorch FSDP2. This reduces per-GPU memory usage, enabling inference of large models on GPUs with limited memory.

The following table shows which models are currently supported by parallelism method:

Collaborator

please also add EP before this line

Contributor Author

Sure, added.

@hsliuustc0106

Collaborator

fix precommits

Semmer2 added 2 commits March 10, 2026 00:52
Signed-off-by: Semmer2 <semmer@live.cn>
Only HunyuanImage3.0 support EP now, choose it for testing.

Signed-off-by: Semmer2 <semmer@live.cn>
@Semmer2

Semmer2 commented Mar 10, 2026

Contributor Author

Sorry, fixed.

)
parser.add_argument(
"--enable-expert-parallel",
action="store_true",

Collaborator

Is there any text-to-video model having MoE layers now in vllm-omni?

@Semmer2 Semmer2 Mar 10, 2026

Contributor Author

As far as I know, Wan2.2 is a MoE video generation model, but the version currently supported upstream does not contain any MoE structure. I think its adaptation is still WIP.

Comment thread docs/user_guide/diffusion/parallelism_acceleration.md

@hsliuustc0106 hsliuustc0106 left a comment

Collaborator

open a new issue for the remaining to-dos

@Semmer2

Semmer2 commented Mar 11, 2026

Contributor Author

Opened a new issue that references this PR: #1801

@hsliuustc0106 hsliuustc0106 merged commit 6ac85d4 into vllm-project:main Mar 11, 2026
6 of 7 checks passed
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request May 28, 2026
quyifei23 pushed a commit to quyifei23/vllm-omni that referenced this pull request Jun 6, 2026
Labels

ready label to trigger buildkite CI

6 participants