Description
Describe the bug
Here, `txt_seq_lens` is calculated:

`txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None`
It is set to the number of unmasked tokens in the text attention mask for each batch sample, yet both the attention mask and `txt_seq_lens` are passed to the transformer:
https://github.com/huggingface/diffusers/blob/efb7a299af46d739dec6a57a5d2814165fba24b5/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py#L695C38-L695C39
This is redundant in principle, but it goes further: neither of them is actually used.
We have already established that the attention mask passed to the transformer is ignored: #12294
The only place in the code where `txt_seq_lens` is used is here:

`max_len = max(txt_seq_lens)`
Therefore only the maximum value of the 1D list is used, which is the length of the longest mask. When using the entire Qwen Image pipeline, the longest mask is by definition the sequence length of the text embeddings, because they are encoded with the longest padding strategy here:

`txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"`

(`padding=True` is equivalent to `padding="longest"`.)

With this padding strategy there is always at least one batch sample whose attention mask is entirely `True`, so `max(txt_seq_lens)` equals the padded token length.
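To make this concrete, here is a small sketch with an arbitrary tokenizer (the checkpoint name is an assumption for the sketch, not necessarily what the pipeline loads):

```python
from transformers import AutoTokenizer

# Any tokenizer with a pad token demonstrates the padding semantics.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

batch = tokenizer(
    ["a cat", "a very long and detailed description of a cat sitting on a mat"],
    max_length=1024,
    padding=True,        # equivalent to padding="longest"
    truncation=True,
    return_tensors="pt",
)

txt_seq_lens = batch.attention_mask.sum(dim=1).tolist()

# The longest sample in the batch receives no padding, so its mask row is all
# ones; max(txt_seq_lens) therefore always equals the padded width of the batch.
assert max(txt_seq_lens) == batch.attention_mask.shape[1]
```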
If you do not use the entire pipeline but only the transformer, you could pass a `txt_seq_lens` whose maximum does not match the sequence length of the embeddings. In that case there is a shape mismatch here:

`x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)`

`RuntimeError: The size of tensor a (512) must match the size of tensor b (6) at non-singleton dimension 1`

In this example `max(txt_seq_lens)` was 6, while the embeddings shape was `[..., 512, ...]`.
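The mismatch can be reproduced at the tensor level without the pipeline; the shapes below are assumptions chosen only to mirror the error message, not the transformer's exact tensor layout:

```python
import torch

batch, seq_len, heads, head_dim = 1, 512, 1, 64

# Text hidden states viewed as complex pairs: [batch, 512, heads, head_dim // 2]
x_rotated = torch.view_as_complex(torch.randn(batch, seq_len, heads, head_dim // 2, 2))

# Rotary frequencies built for max(txt_seq_lens) = 6 instead of 512:
# [batch, 6, 1, head_dim // 2]
freqs_cis = torch.polar(
    torch.ones(batch, 6, 1, head_dim // 2),
    torch.zeros(batch, 6, 1, head_dim // 2),
)

# Broadcasting fails because dim 1 is 512 on one side and 6 on the other:
# RuntimeError: The size of tensor a (512) must match the size of tensor b (6)
# at non-singleton dimension 1
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
```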
Conclusion:
`txt_seq_lens` is redundant, essentially unused, and only a single value is actually valid for it: the sequence length of the embeddings.
What could be done about this depends on how it was intended. `txt_seq_lens` is in principle a comfortable and convenient way to specify an attention mask, provided the mask is always a contiguous sequence of the form `[True, True, True, ..., False, False, False]`. But then there is no need to pass an attention mask at all, and `txt_seq_lens` has to accept all valid values, not only the sequence length of the embeddings (see the sketch below).
If the attention mask is going to be respected instead (see #12294), then there is no need for `txt_seq_lens` at all.
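As a rough illustration of the first option: if masks are always contiguous, a mask can be rebuilt from the lengths alone, so passing only `txt_seq_lens` would carry the same information (the helper below is my own sketch, not part of diffusers):

```python
import torch

def mask_from_seq_lens(txt_seq_lens, max_len):
    # Hypothetical helper: rebuild a [True, ..., True, False, ..., False] mask
    # from per-sample lengths.
    positions = torch.arange(max_len).unsqueeze(0)      # [1, max_len]
    lengths = torch.tensor(txt_seq_lens).unsqueeze(1)   # [batch, 1]
    return positions < lengths                          # [batch, max_len]

print(mask_from_seq_lens([5, 3, 2], max_len=5))
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False],
#         [ True,  True, False, False, False]])
```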
Reproduction
Let me know if this needs a reproduction, and which part.
Logs
System Info
diffusers HEAD