[Feature] Support Prefill-Decode disaggregation via vLLM KV transfer#1303
[Feature] Support Prefill-Decode disaggregation via vLLM KV transfer#1303ahengljh wants to merge 23 commits into
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4fb129bceb
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
lishunyang12
left a comment
There was a problem hiding this comment.
WIP design feedback -- the PD disaggregation idea is sound but the implementation has some structural issues worth sorting out before polish.
Thank you for comments and I'll work on them soon. |
|
@vllm-omni-reviewer |
b315e6b to
606e7cf
Compare
…iew comments - Remove non-PD files: gpu_ar_model_runner.py (debug logging only), omni_ar_scheduler.py and omni_generation_scheduler.py (general compat shims, not PD-specific), pd_server_patch_guide.md (superseded by monkey_patch.py) - Downgrade all KV-DIAG logging from WARNING to DEBUG (omni_llm.py, omni_stage.py) - Strip verbose per-step/per-batch diagnostic scaffolding from omni_llm.py and omni_stage.py - patched_mooncake_connector: call super().add_new_req() instead of skipping; use copy-and-restore pattern in group_kv_pull - omni.py: refactor _detect_pd_separation to single-pass; deduplicate _kv_cfg_to_dict/_normalize_kv_transfer_params into _to_dict() - async_omni.py: unify PD routing merge semantics with sync path - qwen3_omni stage_input_processors: replace hardcoded "0"/"24" layer keys with named constants - qwen3_omni model: document zero-padding safety for PD disaggregation - omni_llm: add comment explaining why _flush_kv_connector_sends reaches into vLLM internals PR scope reduced from 15 to 11 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
hsliuustc0106
left a comment
There was a problem hiding this comment.
Summary
This PR implements Prefill-Decode (PD) disaggregation for vLLM-Omni. While the feature is architecturally sound, the implementation has several critical issues that need to be addressed.
Critical Issues:
- Memory leak:
_pd_kv_params_by_reqnever cleaned up on request failure - Silent failures in config parsing with empty dict fallbacks
- Race conditions in state management despite locks
- Fragile monkey-patching of vLLM internals
- Hardcoded defaults (bootstrap port 25201) without documentation
Moderate Issues:
- Complex state management spread across multiple dictionaries
- Inconsistent error handling (some raise, some return None)
- Missing validation for edge cases
- No version compatibility checks for vLLM
Minor Issues:
- Debug-level logging for important events
- Large PR mixing feature + tests makes review difficult
Recommendation: Request changes - address memory leak and silent failures before merge.
lishunyang12
left a comment
There was a problem hiding this comment.
Thanks for addressing the earlier feedback -- the single-pass detection rewrite, _to_dict dedup, logging downgrade, and super().add_new_req() call all look correct now. A few remaining items:
bddce52 to
df087f3
Compare
…iew comments - Remove non-PD files: gpu_ar_model_runner.py (debug logging only), omni_ar_scheduler.py and omni_generation_scheduler.py (general compat shims, not PD-specific), pd_server_patch_guide.md (superseded by monkey_patch.py) - Downgrade all KV-DIAG logging from WARNING to DEBUG (omni_llm.py, omni_stage.py) - Strip verbose per-step/per-batch diagnostic scaffolding from omni_llm.py and omni_stage.py - patched_mooncake_connector: call super().add_new_req() instead of skipping; use copy-and-restore pattern in group_kv_pull - omni.py: refactor _detect_pd_separation to single-pass; deduplicate _kv_cfg_to_dict/_normalize_kv_transfer_params into _to_dict() - async_omni.py: unify PD routing merge semantics with sync path - qwen3_omni stage_input_processors: replace hardcoded "0"/"24" layer keys with named constants - qwen3_omni model: document zero-padding safety for PD disaggregation - omni_llm: add comment explaining why _flush_kv_connector_sends reaches into vLLM internals PR scope reduced from 15 to 11 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ests, e2e - Neutralize stop/stop_token_ids in prefill sampling params to ensure finish_reason='length' (prevents MooncakeConnector KV transfer cancel) - Add _DEFAULT_MOONCAKE_BOOTSTRAP_PORT named constant - Add tensor_parallel_size validation in PD config check - Improve error messages with type info for kv_transfer_config parsing - Add defense-in-depth cleanup of _pd_kv_params_by_req after generation - Upgrade auto-duplication log to WARNING with suppression hint - Downgrade per-request PD routing/trace logs from INFO to DEBUG - Add vLLM version compatibility warning in monkey_patch.py - Use dynamic __qualname__ from original MooncakeConnector - Add padding threshold warning (512 tokens) in model zero-padding - Add clarifying comments on threading model, merge order, save-patch-restore - Add unit tests: stop neutralization, failure/leak cleanup, TP validation - Add PD e2e tests for both text and audio modalities (offline + online) - Add PD CI stage config with load_format: dummy Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
68e0f9e to
0483a26
Compare
a27763c to
2a8212d
Compare
| # multimodal_mask only selects audio/image/video token positions, | ||
| # which always lie within the prompt (prefill) portion where real | ||
| # embeddings exist. | ||
| target_len = thinker_result_ids.shape[-1] |
There was a problem hiding this comment.
Can we have a unified workflow for other models in PD disaggregation?
There was a problem hiding this comment.
Yes, that's the direction we want to go. The PD orchestration logic in omni.py/async_omni.py is already model-agnostic — it only looks at is_prefill_only/is_decode_only flags and kv_transfer_config in the YAML.
The model-specific part is only in stage_input_processors (the embedding merge in _merge_pd_embeddings with layer keys "0" and "24"). For other models, they'd need their own stage_input_processor but can reuse the PD orchestration as-is.
We can extract a common PD embedding merge base with configurable layer keys to make it easier. Will track this as a follow-up.
There was a problem hiding this comment.
it's not reasonable to change model files to support PD
|
@R2-Y PTAL |
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Jinheng Li <ahengljh@gmail.com>
Move duplicated prefill→decode routing code from omni.py and async_omni.py into PDDisaggregationMixin._prepare_pd_decode_routing() in pd_utils.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Jinheng Li <ahengljh@gmail.com>
|
I agree with @lishunyang12 to split this PR into several atomic PR |
When the talker generates a long output, the flattened codec codes (seq_len * num_quantizers) can exceed the code2wav model's max_model_len, causing a ValueError. Truncate to fit within the 65536 token limit. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Jinheng Li <ahengljh@gmail.com>
Remove qwen3_omni_moe_pd_multiconnector.yaml and restore the original qwen3_omni_moe_pd_separation.yaml stage config. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Jinheng Li <ahengljh@gmail.com>
|
PR 1 — Core PD infra @ahengljh Hi please kindly follow this instruction to scope down this pr otherwise maintainers do not have extra bandwidth to review, and i don't think this pr can be integrated in the near future. |
|
@ahengljh Another way to continue this is to present your pr and get maintainer in sync with your design ideas in our weekly meeting, refer to https://docs.google.com/document/d/1pdUBiS_7mdOUNDtdwy-9OUf7jsMWN324BIbps5olDME/edit?tab=t.0#heading=h.l9hdvzveucma. |
Thank you shunyang, actually I have presented this PR in an internal discussion with @hsliuustc0106 , but as you suggested, we also believe split this PR into small ones will be better for everyone, so I am working on it. |
Signed-off-by: Jinheng Li <ahengljh@gmail.com>
Bring the split-2 branch back in line with vllm-project#1303 by pairing the Qwen model and stage-processor changes with the PD runtime wiring they depend on. Includes the orchestrator routing changes in omni.py/async_omni.py, stage worker PD flags and KV-transfer restoration in omni_stage.py, the connector flush in omni_llm.py, and the unit-test package markers from the original branch. Co-authored-by: spencerr221 <liubingyu62@gmail.com> Signed-off-by: Jinheng Li <ahengljh@gmail.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com>
Carry only the remaining PD test coverage from vllm-project#1303 after split 1 and the corrected split 2 are accounted for. This commit contains the PD entrypoint unit tests plus the offline/online Qwen e2e coverage and the CI-only PD stage config fixture. Co-authored-by: spencerr221 <liubingyu62@gmail.com> Signed-off-by: Jinheng Li <ahengljh@gmail.com>
…rator architecture Adapts the Prefill-Decode (PD) disaggregation feature from PR vllm-project#1303 to the refactored single-process Orchestrator architecture introduced in PR vllm-project#1908. Key changes: - engine/async_omni_engine.py: Add _detect_pd_config() which detects PD stage pairs, applies MooncakeConnector monkey patch, and extracts the bootstrap address; passes pd_config to Orchestrator - engine/orchestrator.py: Add PD routing logic in _forward_to_next_stage; capture prefill KV params from outputs and inject into decode SP via _build_pd_decode_params(); clean up _pd_kv_params on request completion - entrypoints/omni_base.py: Inherit PDDisaggregationMixin; add stage_configs property; call _init_pd_state() on init - entrypoints/omni.py: Expand sampling params for PD before resolving; inject per-request prefill SP modifications - entrypoints/async_omni.py: Same sampling param expansions for async path - entrypoints/pd_utils.py: Replace stage_list -> stage_configs references - model_executor/stage_input_processors/qwen3_omni.py: Add PD embedding merge in thinker2talker(); fix talker2code2wav() dimension slicing and add truncation guard for code2wav max prompt length - model_executor/models/qwen3_omni/qwen3_omni.py: Safety zero-padding in _thinker_to_talker_prefill(); safety clamping in _get_talker_user_parts() for PD length mismatches - New YAML configs: qwen3_omni_moe_pd_separation.yaml (production), qwen3_omni_pd_ci.yaml (CI with dummy weights) - New tests: test_pd_disaggregation.py (adapted for new arch; old-arch integration tests marked xfail), test_qwen3_omni_stage_processors.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com> Signed-off-by: yiliu30 <yi4.liu@intel.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com>
…ng (vllm-project#1863) Signed-off-by: Jinheng Li <ahengljh@gmail.com>

Split Plan
This PR is being split into smaller reviewable pieces for easier review and merging:
pd_utils.py, Mooncake patch module, PD stage YAML).#1863.This umbrella PR remains the full implementation reference while the smaller split PRs are landed.
Summary
Implements Prefill-Decode (PD) disaggregation for the thinker stage in vLLM-Omni, reusing vLLM's native KV connector infrastructure (MooncakeConnector). Splits the thinker into separate prefill (KV producer) and decode (KV consumer) GPU instances, connected via RDMA/TCP KV cache transfer.
Architecture
Changes (17 files, ~4900 lines)
vllm_omni/entrypoints/omni.pyvllm_omni/entrypoints/async_omni.pyvllm_omni/entrypoints/omni_llm.py_flush_kv_connector_sends()for batch-mode KV flushvllm_omni/entrypoints/omni_stage.pyvllm_omni/distributed/kv_transfer/patched_mooncake_connector.pyvllm_omni/distributed/kv_transfer/monkey_patch.pyvllm_omni/model_executor/stage_input_processors/qwen3_omni.py_merge_pd_embeddings) for thinker→talker transitionvllm_omni/model_executor/models/qwen3_omni/qwen3_omni.pyqwen3_omni_moe_pd_separation.yamltests/entrypoints/test_pd_disaggregation.pytests/.../test_qwen3_omni_stage_processors.pytests/e2e/offline_inference/test_qwen3_omni_pd.pytests/e2e/online_serving/test_qwen3_omni_pd.pytests/e2e/stage_configs/qwen3_omni_pd_ci.yamlload_format: dummyfor test without real weightsTest Plan
Automated Tests (all passing)
Unit tests (
pytest tests/entrypoints/test_pd_disaggregation.py -v):TestDetectPDSeparation(4 tests) — PD pair detection in 2/4-stage pipelinesTestValidatePDConfig(6 tests) — config validation: mismatched connector/role/buffer errorsTestGetPDConnectorInfo(3 tests) — engine_id and bootstrap_addr extractionTestPreparePrefillSamplingParams(4 tests) — max_tokens=1, KV param injection, no mutationTestPrefillStopNeutralization(4 tests) — stop=[], stop_token_ids=[], include_stop_str_in_output=FalseTestSamplingParamsAutoDuplication(1 test) — auto-dup for 4-stage pipelineTestNormalizeKVTransferParams(3 tests) — dict/None/dataclass conversionTestKvCfgToDict(3 tests) — dict/None/dataclass with empty-dict defaultTestPDRouting(3 tests) — prefill receives max_tokens=1, decode gets original prompt, correct KV flagsTestKVParamsCleanup(4 tests) — drop/pop/fallback lifecycleTestTPSizeValidation(3 tests) — matching/mismatched/default TP sizeTestPDYAMLConfig(1 test) — production YAML loads and validatesTestMooncakeConnectorPatch(4 tests) — subclass check, remote_request_id, stage payload flagsStage input processor tests (
pytest tests/model_executor/stage_input_processors/test_qwen3_omni_stage_processors.py -v):TestMergePDEmbeddings(9 tests) — overlap, empty prefill/decode, missing keys, edge casesTestGetPrefillStage(5 tests) — PD active/inactive, no outputs, wrong sourceTestThinker2TalkerPDMode(8 tests) — PD merge, overlap, TTS fallback, graceful errorTestPDAudioPipelineIntegration(3 tests) — full PD audio chain, prompt context, non-PD fallbackE2E tests (require 3x GPUs + model):
test_pd_text_only— offline text generation through PD pipelinetest_pd_video_to_audio— offline video→audio through full 4-stage PD pipelinetest_pd_text_to_text— online text→text via OpenAI APItest_pd_mix_to_text_audio— online multimodal→text+audio via OpenAI APIManual verification
How to run E2E tests
GPU Layout (default YAML, TP=1, 3 GPUs)