Skip to content

[Feature]Support async computation and communication across stages by chunks#727

Merged
david6666666 merged 2 commits into
vllm-project:mainfrom
amy-why-3459:async_chunk
Jan 26, 2026
Merged

[Feature]Support async computation and communication across stages by chunks#727
david6666666 merged 2 commits into
vllm-project:mainfrom
amy-why-3459:async_chunk

Conversation

@amy-why-3459

@amy-why-3459 amy-why-3459 commented Jan 10, 2026

Copy link
Copy Markdown
Contributor

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

Reduced time-to-first-token (TTFT) for long prompts: In streaming audiovisual interaction scenarios, the first-packet latency is a critical factor affecting user experience, and the model’s concurrency capability is key to reducing service costs and improving response speed, asynchronous prefilling, when Thinker completes prefilling the current chunk, its output high-level representations are immediately used to prefill the Talker’s current chunk asynchronously, while Thinker prefills its next chunk.
Reduced first-packet-latency for audio generation: Streaming Multi-Codebook Codec Generation, These tokens are then decoded into waveform by a streaming multi-codebook codec decoder that only attends to the left context.
Supports batch inference for multiple requests: Within different stages, multiple requests can be grouped into a batch for inference.
For detailed design, please refer to the RFC. #268.
This PR mainly includes the following four features:

  1. thinker->talker pipeline: When Thinker completes prefilling the current chunk, its output high-level representations
    are immediately used to prefill the Talker's current chunk asynchronously, while Thinker continues to prefill its next
    chunk. This approach significantly reduces the Time-To-First-Token (TTFT) for both Thinker and Talker
  2. talker->code2wav pipeline: Once the talker generates the first token, the MTP module predicts the remaining tokens for the current frame. These tokens are then decoded into waveforms by a streaming multi-codebook codec decoder that only considers the left context.
  3. code2wav chunked decode: . To minimize the user's waiting time for receiving the first generated packet, we propose a left-context-only multi-codebook generation mechanism.
  4. audio streaming output: Qwen3-Omni can output the waveform immediately after the Talker generates each token, significantly reducing the latency of the first packet.

We will also add two more features later.

  1. chunked prefill for talker
  2. async put and get

Test Plan

We have added an async_chunk switch. The async_chunk and custom_process_next_stage_input_func parameters are added to the YAML file.

vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct 
     --omni 
     --port 8091 
     --stage-configs-path /vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml

Test Result

  1. Single-Request e2e Performance Comparison
async_chunk:false
{'type': 'request_level_metrics',
'e2e_time_ms': 3868.5121536254883}

async_chunk:true
{'type': 'request_level_metrics',
'e2e_time_ms': 1305.3650856018066}

e2e latency reduced by 66%

Synchronous put and get may cause a performance bottleneck in high concurrency scenarios.
#934


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1bc732be5f

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +26 to +30
model_config = self.vllm_config.model_config
if model_config.async_chunk:
connector_specs = ConnectorSpec(name=model_config.stage_connector_name,
extra=model_config.stage_connector_extra)
self.omni_connector = OmniConnectorFactory.create_connector(connector_specs)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Initialize omni_connector for non-async runs

OmniGenerationScheduler.init only sets self.omni_connector when async_chunk is true. When async_chunk is false (default in configs like stage_configs/qwen2_5_omni.yaml), schedule() still evaluates if self.omni_connector is not None and will raise an AttributeError before any requests run. Initialize self.omni_connector = None unconditionally or use getattr so non-async stages don’t crash.

Useful? React with 👍 / 👎.

Comment on lines +206 to +210
chunk_id = connector.get_requests[req_id]
connector_get_key = f"{req_id}_{target_stage_id}_{chunk_id}"
payload_data = get_through_connector(connector, target_stage_id, stage_id, req_id, connector_get_key)
if payload_data:
cached_reqs.additional_information[req_id] = payload_data

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fix cached payload field name for chunked inputs

get_chunk stores cached payloads under scheduled_cached_reqs.additional_information, but GPUModelRunner._get_additional_information only checks scheduled_cached_reqs.additional_informations (plural, vllm_omni/worker/gpu_model_runner.py:778-780). That means chunk payloads for running requests never reach model.preprocess, so later chunks reuse stale info or miss required tensors. Write to the field the runner reads, or update the runner to match.

Useful? React with 👍 / 👎.

Comment on lines +744 to +748
chunk_offset = num_processed_thinker_tokens
chunk_size = min(current_chunk_size, total_thinker_tokens)

thinker_embed_chunk = thinker_sequence_embeds[:chunk_size]
thinker_hidden_chunk = thinker_hidden_states[:chunk_size]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Slice chunked thinker data using chunk_offset

Chunked prefill computes chunk_offset but slices from the start ([:chunk_size]) for embeddings/ids. For any chunk after the first, the data represents earlier tokens while _thinker_to_talker_prefill interprets them as [chunk_offset:...], so segments are misaligned and repeated. Slice with chunk_offset:chunk_offset + chunk_size to align chunk data with the offset.

Useful? React with 👍 / 👎.

Comment on lines +50 to +53
def thinker2talker(
stage_list: list[Any],
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
) -> list[OmniTokensPrompt]:
pooling_output: dict[str, Any],
request: OmniEngineCoreRequest,
) -> list[dict[str, Any]]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep custom_process_input_func signature compatible

The qwen3_omni thinker2talker custom_process_input_func now expects (pooling_output, request) and returns a dict, but OmniStage.process_engine_inputs still calls custom_process_input_func with (stage_list, engine_input_source, prompt, requires_multimodal_data) and expects a list of OmniTokensPrompt (see vllm_omni/entrypoints/omni_stage.py:418-421). Any stage configs still pointing at this function (e.g., qwen3_omni_ci.yaml, qwen3_omni_moe_multiconnector.yaml) will now throw a TypeError. Consider keeping a backward-compatible wrapper or updating the caller/configs.

Useful? React with 👍 / 👎.

@amy-why-3459 amy-why-3459 changed the title 【WIP】Support async computation and communication across stages by chunks [WIP]Support async computation and communication across stages by chunks Jan 10, 2026
@amy-why-3459 amy-why-3459 force-pushed the async_chunk branch 2 times, most recently from 209603d to 67a2fdc Compare January 10, 2026 10:33
@hsliuustc0106

Copy link
Copy Markdown
Collaborator

please add test results

Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
return output


# TODO: need to check talker's prepare_inputs logic

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to move the new added code into a function block? Otherwise, it will make upgrading vLLm version very hard.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, try to abstract chunk related logic into saparate functions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, deleted

@Gaohan123 Gaohan123 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a huge work. Thanks for the efforts. Looking forward to more discussions.

Comment thread vllm_omni/config/model.py Outdated
Used to route outputs to appropriate processors (e.g., "image",
"audio", "latents"). If None, output type is inferred.
stage_connector_name: Stage connector name
stage_connector_extra: Extra configuration for stage connector

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config attribute is a little bit confusing

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, move connector name & extra into stage_connector_config

Comment thread vllm_omni/config/model.py Outdated
hf_config_name: str | None = None
custom_process_input_func: str | None = None
stage_connector_name: str = "SharedMemoryConnector"
stage_connector_extra: dict[str, Any] = field(default_factory=dict)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we integrate connector related config into a single stage_connector_config?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, move connector name & extra into stage_connector_config

Comment thread vllm_omni/core/sched/omni_ar_scheduler.py
from vllm.v1.spec_decode.metrics import SpecDecodingStats

from vllm_omni.core.sched.output import OmniNewRequestData
from vllm_omni.distributed.omni_connectors.adapter import get_chunk_for_generation

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar issue like above. Why generation needs a specific get?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For generation, we need to wait at least one token generated before starting get chunk, there are some specific handlings different from AR get chunk

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to set an argument in get_chunk rather than a new method? The name is also very confusing for users.

engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
enable_prefix_caching: false
max_num_batched_tokens: 32768

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for debug?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, deleted

prompt_token_ids = request.prompt_token_ids

# Convert ConstantList to regular list for OmniSerializer serialization
if hasattr(all_token_ids, "_x"):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it is necessary to modify this method? Besides, the attribute "_x" is confusing.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a common function use for convert ConstantList to regular list, "_x" is a attribute of ConstantList, not defined by us


return code2wav_inputs
talker_output = pooling_output
if "code_predictor_codes" not in talker_output:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same question.

Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
return output


# TODO: need to check talker's prepare_inputs logic

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, try to abstract chunk related logic into saparate functions.

Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
)
logits = None
else:
# Apply same fix for broadcast_pp_output path

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above.

@Gaohan123

Copy link
Copy Markdown
Collaborator

@natureofnature Please take a look at OmniConnector related logics. Thanks!

Comment thread vllm_omni/distributed/omni_connectors/omni_transfer_state.py Outdated
@amy-why-3459 amy-why-3459 force-pushed the async_chunk branch 7 times, most recently from 5055bef to f2c49b4 Compare January 14, 2026 14:08
@hsliuustc0106

Copy link
Copy Markdown
Collaborator

we can leave chunked prefill for a seperate PR

Comment thread vllm_omni/core/sched/omni_ar_scheduler.py Outdated
Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py Outdated
Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py Outdated
Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py Outdated
@amy-why-3459 amy-why-3459 force-pushed the async_chunk branch 3 times, most recently from 6d316fb to a6e86a2 Compare January 19, 2026 04:52
@R2-Y R2-Y force-pushed the async_chunk branch 2 times, most recently from 9560618 to 293ddaa Compare January 24, 2026 03:58
@amy-why-3459 amy-why-3459 force-pushed the async_chunk branch 2 times, most recently from 8d4650d to 9b64fca Compare January 24, 2026 07:51
@amy-why-3459 amy-why-3459 mentioned this pull request Jan 24, 2026
35 tasks

@tzhouam tzhouam left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZeldaHuang ZeldaHuang left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM

target_stage_id = stage_id - 1
# Handle new requests
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id[0:25]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better don't use this magic number, it's same as chunk_size and left_context_size

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after we rebase into v0.14.0, the request_id in each stage are inconsistent for same request. a string similar to UUID is appended to the end of each request_id. It's unclear whether this is a new bug in main branch or if some original concatenation logic has changed. We will temporarily avoid this issue and resolve it as soon as possible.

Comment thread vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py Outdated
Comment thread vllm_omni/worker/gpu_model_runner.py
Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
Signed-off-by: Rein Yang <ruiruyang2@gmail.com>
Co-authored-by: Rein Yang <ruiruyang2@gmail.com>
Co-authored-by: CHEN <116010019@link.cuhk.edu.cn>
Signed-off-by: Rein Yang <ruiruyang2@gmail.com>

@Gaohan123 Gaohan123 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take a look at the latest comments for your reference.

Comment thread docs/user_guide/examples/online_serving/qwen3_omni.md
--8<-- "examples/online_serving/qwen3_omni/run_gradio_demo.sh"
``````
??? abstract "server-async-debug.log"
``````log

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove your unused log files in examples folder. Then run "mkdocs build". These links will be removed automatically.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

from vllm.v1.spec_decode.metrics import SpecDecodingStats

from vllm_omni.core.sched.output import OmniNewRequestData
from vllm_omni.distributed.omni_connectors.adapter import get_chunk_for_generation

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to set an argument in get_chunk rather than a new method? The name is also very confusing for users.

from_stage: Source stage identifier
to_stage: Destination stage identifier
request_id: Unique request identifier
put_key: Unique request identifier

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If users don't know that the key should differentiate multiple chunks, that might be challenging for users if they want to integrate a new model with async chunk feature. How about we set the key just from chunk level? When async_chunk is disabled, the chunk is just the whole embedding.

data_bytes = shm_read_bytes(meta)
obj = self.deserialize_obj(data_bytes)
size = metadata.get("size", len(data_bytes))
elif "inline_bytes" in metadata:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that you remove some deserialize logic for some keys such as inline_bytes. Please check if it will influence other tasks.

_wall_start_ts,
):
yield output
else:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible that we integrate async_chunk and normal pipeline into a single logic? Here two separate implementation seems redundant. In my view, the function here is just wait and collect all outputs from each stage. There is no difference for async_chunk and normal version.

Comment thread vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
return talker_input_id, talker_input_embed, trailing_text_hidden_all

def _thinker_decode_to_talker_decode(
self,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the embedding trailing logic across the file, is it possible to integrated into upper level modules such as modelrunner, which can be shared by other models? Of course not now.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trailing logic for different models are not same, maybe hard to share with others. Here, for qwen3-omni, the trailing text hidden only generate during thinker still has decode step, after thinker finish decode, it will use eos as first output. In following talker decode, it will use tts_pad_embed as remaining steps output. I believe there are some difference between different models.

3. Package for talker with additional information
"""
all_token_ids = request.all_token_ids # prefill + decode
prompt_token_ids = request.prompt_token_ids

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you need this new input process function? It seems that these token_ids varaibles are useless and will not be returned.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input parameters used are different. The original input parameter is enginecore output, while the current input parameter is pooling output. The token_ids will be used for the preprocess of the talker.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood. Not now but we can discuss later about how to improve generalization.

Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@hsliuustc0106

Copy link
Copy Markdown
Collaborator

why the median audio ttfp is 0

@Gaohan123 Gaohan123 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. As it is an important feature and depended by other features, maybe we can merge first. Please collect left optimization points as a new RFC to be resolved in next PR. Thanks!

Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
@david6666666

david6666666 commented Jan 26, 2026

Copy link
Copy Markdown
Collaborator

LGTM, This is an exciting PR, thank you for your contribution.

@david6666666 david6666666 merged commit 52bf1e2 into vllm-project:main Jan 26, 2026
7 checks passed
majiayu000 pushed a commit to majiayu000/vllm-omni that referenced this pull request Jan 26, 2026
… chunks (vllm-project#727)

Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
Signed-off-by: Rein Yang <ruiruyang2@gmail.com>
Co-authored-by: Rein Yang <ruiruyang2@gmail.com>
Co-authored-by: CHEN <116010019@link.cuhk.edu.cn>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Signed-off-by: majiayu000 <1835304752@qq.com>
@chickeyton chickeyton mentioned this pull request Jan 28, 2026
1 task
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request May 28, 2026
… chunks (vllm-project#727)

Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
Signed-off-by: Rein Yang <ruiruyang2@gmail.com>
Co-authored-by: Rein Yang <ruiruyang2@gmail.com>
Co-authored-by: CHEN <116010019@link.cuhk.edu.cn>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

high priority high priority issue, needs to be done asap ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants