Teacache with MultiGPU (FSDP) for Wan2.1 #92

@mali-afridi

I see examples for running TeaCache on a single GPU, but when I try a multi-GPU run like this:

torchrun --nproc_per_node=8 teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --teacache_thresh 0.06 --use_ret_steps --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --save_file "ali_tf32.mp4"

I get the following error on every rank. Can TeaCache be supported with FSDP? If so, how? I would be happy to contribute.

[2025-09-11 18:17:00,012] INFO: Generating video ...
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank2]:     generate(args)
[rank2]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank2]:     video = wan_t2v.generate(
[rank2]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank2]:     context = self.text_encoder([input_prompt], self.device)
[rank2]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank2]:     context = self.model(ids, mask)
[rank2]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank5]: Traceback (most recent call last):
[rank5]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank5]:     generate(args)
[rank5]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank5]:     video = wan_t2v.generate(
[rank5]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank5]:     context = self.text_encoder([input_prompt], self.device)
[rank5]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank5]:     context = self.model(ids, mask)
[rank5]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank1]:     generate(args)
[rank1]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank1]:     video = wan_t2v.generate(
[rank1]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank1]:     context = self.text_encoder([input_prompt], self.device)
[rank1]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank1]:     context = self.model(ids, mask)
[rank1]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank4]: Traceback (most recent call last):
[rank4]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank4]:     generate(args)
[rank4]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank4]:     video = wan_t2v.generate(
[rank4]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank4]:     context = self.text_encoder([input_prompt], self.device)
[rank4]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank4]:     context = self.model(ids, mask)
[rank4]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank4]:     return self._call_impl(*args, **kwargs)
[rank4]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank4]:     return forward_call(*args, **kwargs)
[rank4]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank7]:     generate(args)
[rank7]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank7]:     video = wan_t2v.generate(
[rank7]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank7]:     context = self.text_encoder([input_prompt], self.device)
[rank7]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank7]:     context = self.model(ids, mask)
[rank7]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank7]:     return forward_call(*args, **kwargs)
[rank7]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank6]: Traceback (most recent call last):
[rank6]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank6]:     generate(args)
[rank6]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank6]:     video = wan_t2v.generate(
[rank6]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank6]:     context = self.text_encoder([input_prompt], self.device)
[rank6]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank6]:     context = self.model(ids, mask)
[rank6]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank6]:     return self._call_impl(*args, **kwargs)
[rank6]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank6]:     return forward_call(*args, **kwargs)
[rank6]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank0]:     generate(args)
[rank0]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank0]:     video = wan_t2v.generate(
[rank0]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank0]:     context = self.text_encoder([input_prompt], self.device)
[rank0]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank0]:     context = self.model(ids, mask)
[rank0]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank3]:     generate(args)
[rank3]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank3]:     video = wan_t2v.generate(
[rank3]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank3]:     context = self.text_encoder([input_prompt], self.device)
[rank3]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank3]:     context = self.model(ids, mask)
[rank3]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
[rank0]:[W911 18:17:03.490515524 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
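
A possible reading of the traceback, offered as a guess rather than a confirmed diagnosis: the TypeError is raised from the T5 call `self.model(ids, mask)`, yet it names `teacache_forward`, which suggests the TeaCache patch landed on a class shared by both models. If `teacache_generate.py` installs the patch with something like `model.__class__.forward = teacache_forward` (an assumption, since the script is not shown here), then with `--dit_fsdp --t5_fsdp` both models may be instances of the same FSDP wrapper class, and the patch would replace the text encoder's forward too. The sketch below reproduces that pattern in plain Python with hypothetical stand-in classes (`Wrapper`, `TextEncoder`, `DiT`); it does not use torch or the real Wan2.1 code.

```python
# Hypothetical stand-ins: none of these classes come from Wan2.1 or PyTorch.

class Wrapper:
    """Plays the role of a shared wrapper class (e.g. an FSDP-style wrapper)."""

    def __init__(self, module):
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class TextEncoder:
    def forward(self, ids, mask):
        return ("t5", ids, mask)


class DiT:
    def forward(self, x, t, context, seq_len):
        return ("dit", x, t, context, seq_len)


def teacache_forward(self, x, t, context, seq_len):
    """Placeholder for the patched forward; the real one does caching."""
    return ("teacache", x, t, context, seq_len)


text_encoder = Wrapper(TextEncoder())
dit = Wrapper(DiT())
original_forward = Wrapper.forward

# Suspected failure mode: patching the class of the *wrapped* model.
# dit.__class__ is Wrapper, so every Wrapper instance is affected.
dit.__class__.forward = teacache_forward
try:
    text_encoder("ids", "mask")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Undo the broad patch before trying the narrower one.
Wrapper.forward = original_forward

# Narrower patch: target the class of the inner module only.
inner = getattr(dit, "module", dit)
inner.__class__.forward = teacache_forward

encoder_out = text_encoder("ids", "mask")   # text encoder is untouched
dit_out = dit("x", "t", "ctx", 10)          # wrapper delegates to patched forward
```

If this is the cause, patching the class of the wrapped module (for example through the wrapper's `.module` attribute) or applying the patch before FSDP wrapping might be worth trying. Whether TeaCache's cached state behaves correctly under FSDP and Ulysses parallelism would still need to be checked separately.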
