Open
Labels
bugSomething isn't workingSomething isn't working
Description
Describe the bug
It is unclear which resolutions WanImageToVideoPipeline supports. The docs say 720p, but 1280x720 produces a tensor size mismatch, while 1280x704 works.
The docs also say dimensions must be multiples of 16, but there has to be more to it than that, as the case above shows (720 is a multiple of 16). Other resolutions I've tried: 640x624 also fails (624 is also a multiple of 16), while 832x480 works. Resolutions that work also work with width and height swapped, i.e. in the other orientation.
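A pattern that fits every resolution listed above is divisibility by 32 rather than 16. The sketch below checks that hypothesis. It assumes (this is not confirmed by the diffusers docs) that the Wan2.2 TI2V-5B VAE downsamples height and width by 16 and that the transformer then splits the latent into 2x2 spatial patches, giving 16 × 2 = 32. `is_valid_resolution` and `snap_resolution` are hypothetical helpers, not diffusers APIs.

```python
# Assumptions (not read from the model config): the Wan2.2 TI2V-5B VAE
# downsamples each spatial side by 16, and the transformer uses a 2x2 spatial patch.
VAE_SPATIAL_STRIDE = 16
PATCH_SIZE = 2
MULTIPLE = VAE_SPATIAL_STRIDE * PATCH_SIZE  # 32


def is_valid_resolution(width: int, height: int, multiple: int = MULTIPLE) -> bool:
    """Return True if both sides divide evenly into whole latent patches."""
    return width % multiple == 0 and height % multiple == 0


def snap_resolution(width: int, height: int, multiple: int = MULTIPLE) -> tuple[int, int]:
    """Hypothetical helper: round each side down to the nearest multiple."""
    return (width // multiple) * multiple, (height // multiple) * multiple


# Resolutions reported in this issue: (width, height) -> observed behaviour
reported = {
    (1280, 720): "fails",
    (1280, 704): "works",
    (640, 624): "fails",
    (832, 480): "works",
}

for (w, h), observed in reported.items():
    print(f"{w}x{h}: observed {observed}, multiple of {MULTIPLE}: {is_valid_resolution(w, h)}")

print(snap_resolution(1280, 720))  # (1280, 704)
```

Under this assumption, rounding 1280x720 down gives 1280x704, which is the resolution reported to work.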
Reproduction
import torch
from diffusers import AutoModel, WanImageToVideoPipeline
from diffusers.utils import load_image
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import UMT5EncoderModel


def load_wan_components(model_name="Wan-AI/Wan2.2-TI2V-5B-Diffusers"):
    # 8-bit transformer; the listed timestep/norm/output modules are skipped from int8 conversion
    transformer_quantization_config = TransformersBitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["time_embedder", "timesteps_proj", "time_proj", "norm_out", "proj_out"],
    )
    # 4-bit NF4 text encoder with bfloat16 compute
    text_encoder_quantization_config = TransformersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    text_encoder = UMT5EncoderModel.from_pretrained(
        model_name,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        quantization_config=text_encoder_quantization_config,
    )
    # VAE is kept unquantized in float32
    vae = AutoModel.from_pretrained(model_name, subfolder="vae", torch_dtype=torch.float32)
    transformer = AutoModel.from_pretrained(
        model_name,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        quantization_config=transformer_quantization_config,
    )
    return text_encoder, vae, transformer


print("loading WanI2VPipeline")
text_encoder, vae, transformer = load_wan_components()
pipeline = WanImageToVideoPipeline.from_pretrained(
"Wan-AI/Wan2.2-TI2V-5B-Diffusers",
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
output = pipeline(prompt="tacos", width=1280, height=720, image=image).frames[0]
Logs
(venv) (base) meatfucker@abyss:~/ml/avernus$ python modules/asdfasdf.py
loading WanI2VPipeline
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:20<00:00, 6.71s/it]
The config attributes {'clip_output': False} were passed to AutoencoderKLWan, but are not expected and will be ignored. Please verify your config.json configuration file.
Fetching 5 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19472.16it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00, 2.23s/it]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7.78it/s]
The module 'WanTransformer3DModel' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.
0%| | 0/50 [00:00<?, ?it/s]
/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/bitsandbytes/autograd/_functions.py:186: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
0%| | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/meatfucker/ml/avernus/modules/asdfasdf.py", line 46, in <module>
output = pipeline(prompt="tacos", width=1280, height=720, image=image).frames[0]
~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
noise_pred = current_model(
~~~~~~~~~~~~~^
hidden_states=latent_model_input,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<4 lines>...
return_dict=False,
^^^^^^^^^^^^^^^^^^
)[0]
^
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/diffusers/models/transformers/transformer_wan.py", line 663, in forward
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/meatfucker/ml/avernus/venv/lib/python3.13/site-packages/diffusers/models/transformers/transformer_wan.py", line 478, in forward
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (18480) must match the size of tensor b (19320) at non-singleton dimension 1
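The two sizes in this error can be reproduced by counting transformer tokens. A rough sketch, assuming the default `num_frames=81`, 4x temporal and 16x spatial VAE compression, and a (1, 2, 2) patch size: at a height of 720 the latent height is 720 / 16 = 45, which is odd, so a count that rounds down gives 22 patch rows while one that rounds up gives 23. Which code paths round which way is a guess (a stride-2 convolution floors, while `[::2]` slicing effectively rounds up), but the resulting products match the logged 18480 and 19320. `token_counts` is a hypothetical helper.

```python
import math

# Assumptions: default num_frames=81, VAE compression of 4x in time and 16x in
# space, transformer patch size (1, 2, 2). These mirror the Wan2.2 TI2V-5B setup
# but are hard-coded here rather than read from the model config.
TEMPORAL_STRIDE = 4
SPATIAL_STRIDE = 16
PATCH_T, PATCH_H, PATCH_W = 1, 2, 2


def token_counts(width: int, height: int, num_frames: int = 81) -> tuple[int, int]:
    """Return (round-down, round-up) sequence lengths for a given resolution."""
    latent_t = (num_frames - 1) // TEMPORAL_STRIDE + 1
    latent_h = height // SPATIAL_STRIDE
    latent_w = width // SPATIAL_STRIDE
    floor_tokens = (latent_t // PATCH_T) * (latent_h // PATCH_H) * (latent_w // PATCH_W)
    ceil_tokens = (
        math.ceil(latent_t / PATCH_T)
        * math.ceil(latent_h / PATCH_H)
        * math.ceil(latent_w / PATCH_W)
    )
    return floor_tokens, ceil_tokens


print(token_counts(1280, 720))  # (18480, 19320): counts disagree
print(token_counts(1280, 704))  # (18480, 18480): counts agree
```

At a height of 704 the latent height is 44, both counts agree, and no mismatch would be expected.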
System Info
- 🤗 Diffusers version: 0.35.1
- Platform: Linux-6.8.0-79-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.13.5
- PyTorch version (GPU?): 2.8.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.4
- Transformers version: 4.55.4
- Accelerate version: 1.10.0
- PEFT version: 0.17.1
- Bitsandbytes version: 0.47.0
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?