Skip to content

[feature] add VAE patch parallel to Wan2.2#1350

Closed
hadipash wants to merge 6 commits into
vllm-project:mainfrom
hadipash:wan_vae
Closed

[feature] add VAE patch parallel to Wan2.2#1350
hadipash wants to merge 6 commits into
vllm-project:mainfrom
hadipash:wan_vae

Conversation

@hadipash

@hadipash hadipash commented Feb 12, 2026

Copy link
Copy Markdown
Contributor

Purpose

This is a task from #814.
Add VAE Patch Parallelism to Wan2.2 models.

Test Result

Reproduced with 4xH800.

Model Mode TP=2 TP=2 & VAE=2 TP=4 & VAE=4
Wan-AI/Wan2.2-TI2V-5B-Diffusers T2V
t2v_output_tp2.mp4
t2v_output_tp2_vae2.mp4
t2v_output_tp4_vae4.mp4
Time 37.3 s/vid 37.1 s/vid 27.8 s/vid
Memory Usage 54.1 GB 29.1 GB 27.6 GB
Wan-AI/Wan2.2-TI2V-5B-Diffusers TI2V
ti2v_output_tp2.mp4
ti2v_output_tp2_vae2.mp4
ti2v_output_tp4_vae4.mp4
Time 17.2 s/vid 17.4 s/vid 13.9 s/vid
Memory Usage 35.3 GB 26.3 GB 24.7 GB

Run commands

T2V

python examples/offline_inference/text_to_video/text_to_video.py --model=Wan-AI/Wan2.2-TI2V-5B-Diffusers --width=1280 --height=704 --guidance-scale=5.0 --prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --output=results/t2v_output_tp2_vae2.mp4 --tensor-parallel-size=2 --vae-patch-parallel-size=2

TI2v

python examples/offline_inference/image_to_video/image_to_video.py --model=Wan-AI/Wan2.2-TI2V-5B-Diffusers --image=tmp/i2v_input.JPG --guidance-scale=5.0 --prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --output=results/ti2v_output_tp2_vae2.mp4 --tensor-parallel-size=2 --vae-patch-parallel-size=2

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
@hadipash

Copy link
Copy Markdown
Contributor Author

@dongbo910220

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3cfcb0b20a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm_omni/diffusion/distributed/vae_patch_parallel.py Outdated
Comment thread vllm_omni/diffusion/distributed/vae_patch_parallel.py
@wtomin

wtomin commented Feb 12, 2026

Copy link
Copy Markdown
Collaborator

@hadipash Hello, @Bounty-hunter is about to propose one refactoring PR for vae patch parallelism. The purpose is to allow minimal efforts for adapting a model for vae patch parallelism.

Can you stay tuned for @Bounty-hunter 's PR and give your comments then?

Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
@hadipash

Copy link
Copy Markdown
Contributor Author

@codex review

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex Review: Didn't find any major issues. Chef's kiss.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@SamitHuang

Copy link
Copy Markdown
Collaborator

how about the memory cost and gen time cost changed?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this file?

@hadipash hadipash Feb 22, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This module is not used anywhere in the project, making it confusing when developing for Wan.

"WanPipeline": (
"wan2_2",
"pipeline_wan2_2",
"Wan22Pipeline",
),
"StableAudioPipeline": (
"stable_audio",
"pipeline_stable_audio",
"StableAudioPipeline",
),
"WanImageToVideoPipeline": (
"wan2_2",
"pipeline_wan2_2_i2v",
"Wan22I2VPipeline",
),

@lishunyang12 lishunyang12 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clean generalization of VaePatchParallelism for 5D video VAEs -- the send/recv approach is a nice improvement over the gather-everything pattern. A few questions inline, mostly around edge cases in the new Wan decode path.

for i in h_starts:
for j in w_starts:
# Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile.
tile_rank = (tile_id + 1) % pp_size

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when num_tiles == pp_size exactly and tile_id 0 is assigned to rank 1 via (0+1) % pp_size? Rank 0 ends up with no tiles if pp_size evenly divides num_tiles. Probably fine since rank 0 still participates in the recv loop, but is there a scenario where rank 0 having zero local tiles causes issues with the max_count == 0 early return further down?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the case when num_tiles == pp_size, rank 0 will simply process the last tile (i.e., shift by 1). Moreover, when rank 0 has no tiles to process (i.e., num_tiles < pp_size), the logic still holds because max_count is broadcast to all ranks, ensuring the value is identical across them.

    # Gather per-rank tile counts.
    count_tensor = torch.tensor([len(local_tiles)], device=z.device, dtype=torch.int64)
    if rank == 0:
        count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)]
    else:
        count_gather = None
    dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group)
    max_count = 0
    if rank == 0:
        counts = [int(t.item()) for t in count_gather]  # type: ignore[arg-type]
        max_count = max(counts) if counts else 0
    max_count_tensor = torch.tensor([max_count], device=z.device, dtype=torch.int64)
    dist.broadcast(max_count_tensor, src=0, group=group)
    max_count = int(max_count_tensor.item())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

understand

vae._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = vae.post_quant_conv(tile)
decoded = vae.decoder(tile, feat_cache=vae._feat_map, feat_idx=vae._conv_idx, first_chunk=(k == 0))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This inner loop over num_frames calls vae.decoder frame-by-frame with feat_cache / feat_idx -- are these VAE-internal stateful caches safe across ranks? Each rank has its own VAE copy so I assume yes, but worth confirming since _feat_map and _conv_idx look like mutable instance state.

@hadipash hadipash Feb 23, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I mimicked the diffusers' implementation. The cache is reset for each iteration, that is _feat_map and _conv_idx are set back to the default None and 0 values:

https://github.com/huggingface/diffusers/blob/9380e58821092ad05e618fa8a259030136ba43c0/src/diffusers/models/autoencoders/autoencoder_kl_wan.py#L1363-L1379

count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)]
else:
count_gather = None
dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dist.gather here uses dst=0 with the full world_size group, meaning inactive ranks (rank >= pp_size) also participate with count_tensor = [0]. Correct but slightly wasteful -- would an active-only subgroup be worth it to avoid synchronization overhead on idle ranks?

@hadipash hadipash Feb 23, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible. However, the communication overhead is very small compared to the algorithm complexity when adding multiple branches for active / idle ranks.

device=z.device,
dtype=z.dtype,
)
dist.recv(buf, src=src_rank, group=group)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice switch from all-gather buffer to point-to-point send/recv -- addresses the OOM concern. dist.send and dist.recv are blocking though, so with many tiles this serializes rank 0 receive loop. Have you considered isend/irecv to overlap receives from different source ranks?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. Async communication can be a bit tricky, and since only rank 0 is responsible for stitching the patches, it's acceptable to leave other ranks waiting to send. Although, it would be a nice addition - contributions are welcome if you're interested.

# Active non-zero ranks send their tiles; inactive ranks have nothing.
for tile in local_tiles:
dist.send(tile.contiguous(), dst=0, group=group)
return torch.empty(0, device=z.device, dtype=z.dtype)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-zero ranks return torch.empty(0) here before the broadcast at the end of VaePatchParallelism.decode. decode() does dist.broadcast -- non-zero ranks need to still participate with a properly sized buffer. The shape broadcast + torch.empty allocation in decode() should cover this, but worth double-checking for the Wan 5D case.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decode() does dist.broadcast

I'm not sure what you mean by this.

# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = vae.blend_v(rows[i - 1][j], tile, blend_height)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blend_v and blend_h may mutate tiles in-place in diffusers. Since rows still holds references to tiles from tile_map, could in-place blending of a previous row tile corrupt data needed for the current row?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cascading blend is intentional and matches diffusers' original tiled_decode behavior.

tile_latent_min_size = getattr(self._vae, "tile_latent_min_size", None)
if tile_latent_min_size is None:
decoded = _distributed_tiled_decode(
if tile_latent_min_size is None or self._distributed_patch_decode_fn is None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For AutoencoderKLWan, distributed_patch_decode_fn defaults to None. So self._distributed_patch_decode_fn is None is always true for Wan, always routing to the tiled path regardless of tensor size. The fallback inside _distributed_tiled_decode_wan handles small tensors correctly, but the routing logic here is a bit misleading -- maybe a comment would help.

_VAE_PATCH_PARALLEL_ALLOWLIST = {
# Only enable for models we have validated end-to-end.
"ZImagePipeline",
"WanPipeline",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are WanPipeline and WanImageToVideoPipeline the class names from diffusers or from this repo? Just checking these match the actual pipeline class name looked up against this set.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These classes are from this repo: vllm_omni/diffusion/registry.py.

get_wan22_ti2v_post_process_func,
get_wan22_ti2v_pre_process_func,
)
from .wan2_2_transformer import WanTransformer3DModel

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SamitHuang already asked about the TI2V pipeline removal -- is TI2V now handled by Wan22I2VPipeline or diffusers WanImageToVideoPipeline? Would be good to respond to that thread with the reasoning.

@hadipash hadipash Feb 22, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This module is not used anywhere in the project, making it confusing when developing for Wan. TI2V case is handled by Wan22Pipeline.

"WanPipeline": (
"wan2_2",
"pipeline_wan2_2",
"Wan22Pipeline",
),
"StableAudioPipeline": (
"stable_audio",
"pipeline_stable_audio",
"StableAudioPipeline",
),
"WanImageToVideoPipeline": (
"wan2_2",
"pipeline_wan2_2_i2v",
"Wan22I2VPipeline",
),

Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
@hadipash

Copy link
Copy Markdown
Contributor Author

how about the memory cost and gen time cost changed?

Updated the table.

@hsliuustc0106

Copy link
Copy Markdown
Collaborator

@vllm-omni-reviewer

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Feb 23, 2026
@github-actions

Copy link
Copy Markdown

🤖 VLLM-Omni PR Review

Code Review: Add VAE Patch Parallel to Wan2.2

1. Overview

This PR adds VAE Patch Parallelism support for Wan2.2 video generation models, enabling significant memory reduction and improved performance when distributing VAE decode workloads across multiple GPUs. The implementation includes:

  • A new distributed tiled decode function (_distributed_tiled_decode_wan) for 5D video tensors
  • Updates to VaePatchParallelism class to support both 4D (image) and 5D (video) tensors
  • Removal of the standalone Wan22TI2VPipeline (consolidation)
  • Documentation and example updates

Overall Assessment: Positive - The implementation is well-structured and the test results show impressive memory improvements (54.1 GB → 29.1 GB for TP=2 & VAE=2).


2. Code Quality

Strengths

  • Clean separation between image (4D) and video (5D) tensor handling via expected_ndim parameter
  • Good use of point-to-point communication (dist.send/dist.recv) to avoid unnecessary memory allocation on inactive ranks
  • Proper fallback to original decode for edge cases

Potential Issues

vllm_omni/diffusion/distributed/vae_patch_parallel.py:368-371

# Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile.
tile_rank = (tile_id + 1) % pp_size

The comment mentions "largest tile" but tile_id=0 is simply the first tile (top-left corner), not necessarily the largest. Consider clarifying the comment to explain the actual reasoning (e.g., "Offset assignment by 1 to distribute workload more evenly across ranks").

vllm_omni/diffusion/distributed/vae_patch_parallel.py:382-389
The loop iterates through all frames for each spatial tile, calling vae.decoder per frame. This is memory-efficient but could be slow for long videos. Consider adding a comment explaining this design choice.

vllm_omni/diffusion/distributed/vae_patch_parallel.py:418-421

if rank == 0:
    meta_gather = [torch.empty_like(meta_tensor) for _ in range(world_size)]
else:
    meta_gather = None

This pattern is repeated multiple times. Consider extracting to a helper function for consistency.


3. Architecture & Design

Strengths

  • Extensible design with configurable decode functions via constructor parameters
  • Proper integration with existing VaePatchParallelism wrapper pattern
  • Clean allowlist-based feature gating in registry

Concerns

Pipeline Removal Without Migration Path
The complete removal of Wan22TI2VPipeline could break existing users. Consider:

  1. Adding a deprecation notice instead of immediate removal, OR
  2. Documenting the migration path in the PR description/release notes

vllm_omni/diffusion/distributed/vae_patch_parallel.py:647-669
The maybe_wrap_vae_decode_with_patch_parallelism function now has branching logic for different VAE types. Consider using a registry pattern for better extensibility:

_VAE_PARALLEL_CONFIGS = {
    AutoencoderKLWan: dict(
        expected_ndim=5,
        distributed_tiled_decode_fn=_distributed_tiled_decode_wan,
    ),
    AutoencoderKL: dict(
        distributed_patch_decode_fn=_distributed_patch_decode,
    ),
}

4. Security & Safety

Resource Management

  • Good: The implementation properly handles inactive ranks (rank >= pp_size) by avoiding large buffer allocations
  • Good: Uses torch.clamp(dec, min=-1.0, max=1.0) to ensure valid output range

Input Validation

vllm_omni/diffusion/distributed/vae_patch_parallel.py:376-377

if not (width > tile_latent_min_width or height > tile_latent_min_height):
    return orig_decode(z, return_dict=False)[0]

This early return for small inputs is correct, but consider adding a debug log for transparency.


5. Testing & Documentation

Documentation

  • ✅ Tables updated correctly to show VAE-Patch-Parallel support
  • ✅ Note about validated pipelines updated appropriately
  • ⚠️ The PR description mentions test commands but they reference Wan2.2-TI2V-5B-Diffusers while the table shows Wan2.2-T2V-A14B-Diffusers - consider clarifying which models are actually supported

Test Coverage Considerations

  • Consider adding unit tests for _distributed_tiled_decode_wan with various tile configurations
  • Edge cases to test:
    • pp_size > num_tiles (some ranks inactive)
    • Single tile (should fall back to original decode)
    • Non-square video dimensions

6. Specific Suggestions

vllm_omni/diffusion/distributed/vae_patch_parallel.py:352-353

world_size, rank, pp_size = _get_world_rank_pp_size(group, vae_patch_parallel_size)
if pp_size <= 1:

Consider adding a log message when falling back to non-parallel decode for debugging purposes.

vllm_omni/diffusion/distributed/vae_patch_parallel.py:543-545
The VaePatchParallelism class docstring mentions "4-D image vs 5-D video" but the implementation now handles both. Consider updating the docstring to reflect this.

examples/offline_inference/image_to_video/image_to_video.py:177-179

print(
    f"  Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
    f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
)

Good addition for debugging, but the line is getting long. Consider formatting consistently with text_to_video.py.

vllm_omni/diffusion/registry.py:118-119

"WanPipeline",
"WanImageToVideoPipeline",

Consider adding a comment explaining why these specific pipelines are validated for VAE patch parallelism.


7. Approval Status

LGTM with suggestions

The PR is well-implemented and achieves its stated goals with impressive memory improvements. The core implementation is solid. I have a few suggestions:

  1. Minor: Clarify the comment about tile assignment offset
  2. Minor: Consider documenting the removal of Wan22TI2VPipeline or providing a migration path
  3. Optional: Consider the registry pattern for VAE type detection for better extensibility

The test results demonstrate the feature works correctly, and the code follows the existing patterns in the codebase. Ready to merge after addressing the minor documentation/comment suggestions.


This review was generated automatically by the VLLM-Omni PR Reviewer Bot
using glm-5.


def main():
args = parse_args()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you test for not specific generator? if each rank run with each random generator, Will there be noticeable boundaries when merging?

@hadipash hadipash Feb 24, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's relevant to VAE. Generator affects the initial input noise only, and I believe this scenario is already handled by the framework internally.

@hsliuustc0106

Copy link
Copy Markdown
Collaborator

@vllm-omni-reviewer

@hadipash

Copy link
Copy Markdown
Contributor Author

Superseded by PR #1366.

@hadipash hadipash closed this Feb 25, 2026
@hadipash hadipash deleted the wan_vae branch February 25, 2026 01:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants