[feature] add VAE patch parallel to Wan2.2#1350
Conversation
Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3cfcb0b20a
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
|
@hadipash Hello, @Bounty-hunter is about to propose one refactoring PR for vae patch parallelism. The purpose is to allow minimal efforts for adapting a model for vae patch parallelism. Can you stay tuned for @Bounty-hunter 's PR and give your comments then? |
Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
|
@codex review |
|
Codex Review: Didn't find any major issues. Chef's kiss. ℹ️ About Codex in GitHubCodex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback". |
|
how about the memory cost and gen time cost changed? |
There was a problem hiding this comment.
This module is not used anywhere in the project, making it confusing when developing for Wan.
vllm-omni/vllm_omni/diffusion/registry.py
Lines 53 to 67 in a3f2d4c
lishunyang12
left a comment
There was a problem hiding this comment.
Clean generalization of VaePatchParallelism for 5D video VAEs -- the send/recv approach is a nice improvement over the gather-everything pattern. A few questions inline, mostly around edge cases in the new Wan decode path.
| for i in h_starts: | ||
| for j in w_starts: | ||
| # Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile. | ||
| tile_rank = (tile_id + 1) % pp_size |
There was a problem hiding this comment.
What happens when num_tiles == pp_size exactly and tile_id 0 is assigned to rank 1 via (0+1) % pp_size? Rank 0 ends up with no tiles if pp_size evenly divides num_tiles. Probably fine since rank 0 still participates in the recv loop, but is there a scenario where rank 0 having zero local tiles causes issues with the max_count == 0 early return further down?
There was a problem hiding this comment.
For the case when num_tiles == pp_size, rank 0 will simply process the last tile (i.e., shift by 1). Moreover, when rank 0 has no tiles to process (i.e., num_tiles < pp_size), the logic still holds because max_count is broadcast to all ranks, ensuring the value is identical across them.
# Gather per-rank tile counts.
count_tensor = torch.tensor([len(local_tiles)], device=z.device, dtype=torch.int64)
if rank == 0:
count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)]
else:
count_gather = None
dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group)
max_count = 0
if rank == 0:
counts = [int(t.item()) for t in count_gather] # type: ignore[arg-type]
max_count = max(counts) if counts else 0
max_count_tensor = torch.tensor([max_count], device=z.device, dtype=torch.int64)
dist.broadcast(max_count_tensor, src=0, group=group)
max_count = int(max_count_tensor.item())| vae._conv_idx = [0] | ||
| tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] | ||
| tile = vae.post_quant_conv(tile) | ||
| decoded = vae.decoder(tile, feat_cache=vae._feat_map, feat_idx=vae._conv_idx, first_chunk=(k == 0)) |
There was a problem hiding this comment.
This inner loop over num_frames calls vae.decoder frame-by-frame with feat_cache / feat_idx -- are these VAE-internal stateful caches safe across ranks? Each rank has its own VAE copy so I assume yes, but worth confirming since _feat_map and _conv_idx look like mutable instance state.
There was a problem hiding this comment.
Yes, I mimicked the diffusers' implementation. The cache is reset for each iteration, that is _feat_map and _conv_idx are set back to the default None and 0 values:
| count_gather = [torch.empty_like(count_tensor) for _ in range(world_size)] | ||
| else: | ||
| count_gather = None | ||
| dist.gather(count_tensor, gather_list=count_gather, dst=0, group=group) |
There was a problem hiding this comment.
dist.gather here uses dst=0 with the full world_size group, meaning inactive ranks (rank >= pp_size) also participate with count_tensor = [0]. Correct but slightly wasteful -- would an active-only subgroup be worth it to avoid synchronization overhead on idle ranks?
There was a problem hiding this comment.
It's possible. However, the communication overhead is very small compared to the algorithm complexity when adding multiple branches for active / idle ranks.
| device=z.device, | ||
| dtype=z.dtype, | ||
| ) | ||
| dist.recv(buf, src=src_rank, group=group) |
There was a problem hiding this comment.
Nice switch from all-gather buffer to point-to-point send/recv -- addresses the OOM concern. dist.send and dist.recv are blocking though, so with many tiles this serializes rank 0 receive loop. Have you considered isend/irecv to overlap receives from different source ranks?
There was a problem hiding this comment.
Same as above. Async communication can be a bit tricky, and since only rank 0 is responsible for stitching the patches, it's acceptable to leave other ranks waiting to send. Although, it would be a nice addition - contributions are welcome if you're interested.
| # Active non-zero ranks send their tiles; inactive ranks have nothing. | ||
| for tile in local_tiles: | ||
| dist.send(tile.contiguous(), dst=0, group=group) | ||
| return torch.empty(0, device=z.device, dtype=z.dtype) |
There was a problem hiding this comment.
Non-zero ranks return torch.empty(0) here before the broadcast at the end of VaePatchParallelism.decode. decode() does dist.broadcast -- non-zero ranks need to still participate with a properly sized buffer. The shape broadcast + torch.empty allocation in decode() should cover this, but worth double-checking for the Wan 5D case.
There was a problem hiding this comment.
decode() does dist.broadcast
I'm not sure what you mean by this.
| # blend the above tile and the left tile | ||
| # to the current tile and add the current tile to the result row | ||
| if i > 0: | ||
| tile = vae.blend_v(rows[i - 1][j], tile, blend_height) |
There was a problem hiding this comment.
blend_v and blend_h may mutate tiles in-place in diffusers. Since rows still holds references to tiles from tile_map, could in-place blending of a previous row tile corrupt data needed for the current row?
There was a problem hiding this comment.
The cascading blend is intentional and matches diffusers' original tiled_decode behavior.
| tile_latent_min_size = getattr(self._vae, "tile_latent_min_size", None) | ||
| if tile_latent_min_size is None: | ||
| decoded = _distributed_tiled_decode( | ||
| if tile_latent_min_size is None or self._distributed_patch_decode_fn is None: |
There was a problem hiding this comment.
For AutoencoderKLWan, distributed_patch_decode_fn defaults to None. So self._distributed_patch_decode_fn is None is always true for Wan, always routing to the tiled path regardless of tensor size. The fallback inside _distributed_tiled_decode_wan handles small tensors correctly, but the routing logic here is a bit misleading -- maybe a comment would help.
| _VAE_PATCH_PARALLEL_ALLOWLIST = { | ||
| # Only enable for models we have validated end-to-end. | ||
| "ZImagePipeline", | ||
| "WanPipeline", |
There was a problem hiding this comment.
Are WanPipeline and WanImageToVideoPipeline the class names from diffusers or from this repo? Just checking these match the actual pipeline class name looked up against this set.
There was a problem hiding this comment.
These classes are from this repo: vllm_omni/diffusion/registry.py.
| get_wan22_ti2v_post_process_func, | ||
| get_wan22_ti2v_pre_process_func, | ||
| ) | ||
| from .wan2_2_transformer import WanTransformer3DModel |
There was a problem hiding this comment.
SamitHuang already asked about the TI2V pipeline removal -- is TI2V now handled by Wan22I2VPipeline or diffusers WanImageToVideoPipeline? Would be good to respond to that thread with the reasoning.
There was a problem hiding this comment.
This module is not used anywhere in the project, making it confusing when developing for Wan. TI2V case is handled by Wan22Pipeline.
vllm-omni/vllm_omni/diffusion/registry.py
Lines 53 to 67 in a3f2d4c
Signed-off-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
Updated the table. |
|
@vllm-omni-reviewer |
🤖 VLLM-Omni PR ReviewCode Review: Add VAE Patch Parallel to Wan2.21. OverviewThis PR adds VAE Patch Parallelism support for Wan2.2 video generation models, enabling significant memory reduction and improved performance when distributing VAE decode workloads across multiple GPUs. The implementation includes:
Overall Assessment: Positive - The implementation is well-structured and the test results show impressive memory improvements (54.1 GB → 29.1 GB for TP=2 & VAE=2). 2. Code QualityStrengths
Potential Issuesvllm_omni/diffusion/distributed/vae_patch_parallel.py:368-371 # Offset assignment by 1 so rank0 avoids decoding the largest (tile_id=0) tile.
tile_rank = (tile_id + 1) % pp_sizeThe comment mentions "largest tile" but tile_id=0 is simply the first tile (top-left corner), not necessarily the largest. Consider clarifying the comment to explain the actual reasoning (e.g., "Offset assignment by 1 to distribute workload more evenly across ranks"). vllm_omni/diffusion/distributed/vae_patch_parallel.py:382-389 vllm_omni/diffusion/distributed/vae_patch_parallel.py:418-421 if rank == 0:
meta_gather = [torch.empty_like(meta_tensor) for _ in range(world_size)]
else:
meta_gather = NoneThis pattern is repeated multiple times. Consider extracting to a helper function for consistency. 3. Architecture & DesignStrengths
ConcernsPipeline Removal Without Migration Path
vllm_omni/diffusion/distributed/vae_patch_parallel.py:647-669 _VAE_PARALLEL_CONFIGS = {
AutoencoderKLWan: dict(
expected_ndim=5,
distributed_tiled_decode_fn=_distributed_tiled_decode_wan,
),
AutoencoderKL: dict(
distributed_patch_decode_fn=_distributed_patch_decode,
),
}4. Security & SafetyResource Management
Input Validationvllm_omni/diffusion/distributed/vae_patch_parallel.py:376-377 if not (width > tile_latent_min_width or height > tile_latent_min_height):
return orig_decode(z, return_dict=False)[0]This early return for small inputs is correct, but consider adding a debug log for transparency. 5. Testing & DocumentationDocumentation
Test Coverage Considerations
6. Specific Suggestionsvllm_omni/diffusion/distributed/vae_patch_parallel.py:352-353 world_size, rank, pp_size = _get_world_rank_pp_size(group, vae_patch_parallel_size)
if pp_size <= 1:Consider adding a log message when falling back to non-parallel decode for debugging purposes. vllm_omni/diffusion/distributed/vae_patch_parallel.py:543-545 examples/offline_inference/image_to_video/image_to_video.py:177-179 print(
f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
)Good addition for debugging, but the line is getting long. Consider formatting consistently with text_to_video.py. vllm_omni/diffusion/registry.py:118-119 "WanPipeline",
"WanImageToVideoPipeline",Consider adding a comment explaining why these specific pipelines are validated for VAE patch parallelism. 7. Approval StatusLGTM with suggestions The PR is well-implemented and achieves its stated goals with impressive memory improvements. The core implementation is solid. I have a few suggestions:
The test results demonstrate the feature works correctly, and the code follows the existing patterns in the codebase. Ready to merge after addressing the minor documentation/comment suggestions. This review was generated automatically by the VLLM-Omni PR Reviewer Bot |
|
|
||
| def main(): | ||
| args = parse_args() | ||
| generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) |
There was a problem hiding this comment.
have you test for not specific generator? if each rank run with each random generator, Will there be noticeable boundaries when merging?
There was a problem hiding this comment.
I don't think it's relevant to VAE. Generator affects the initial input noise only, and I believe this scenario is already handled by the framework internally.
|
@vllm-omni-reviewer |
|
Superseded by PR #1366. |
Purpose
This is a task from #814.
Add VAE Patch Parallelism to Wan2.2 models.
Test Result
Reproduced with 4xH800.
Wan-AI/Wan2.2-TI2V-5B-Diffuserst2v_output_tp2.mp4
t2v_output_tp2_vae2.mp4
t2v_output_tp4_vae4.mp4
Wan-AI/Wan2.2-TI2V-5B-Diffusersti2v_output_tp2.mp4
ti2v_output_tp2_vae2.mp4
ti2v_output_tp4_vae4.mp4
Run commands
T2V
python examples/offline_inference/text_to_video/text_to_video.py --model=Wan-AI/Wan2.2-TI2V-5B-Diffusers --width=1280 --height=704 --guidance-scale=5.0 --prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --output=results/t2v_output_tp2_vae2.mp4 --tensor-parallel-size=2 --vae-patch-parallel-size=2TI2v
python examples/offline_inference/image_to_video/image_to_video.py --model=Wan-AI/Wan2.2-TI2V-5B-Diffusers --image=tmp/i2v_input.JPG --guidance-scale=5.0 --prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --output=results/ti2v_output_tp2_vae2.mp4 --tensor-parallel-size=2 --vae-patch-parallel-size=2Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model. Please runmkdocs serveto sync the documentation editions to./docs.BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)