Chroma Follow Up #11725

Merged
merged 26 commits into main on Jun 18, 2025

Conversation

DN6
Collaborator

@DN6 DN6 commented Jun 16, 2025

What does this PR do?

Follow up to #11698

This PR

  1. Adds an Img2Img pipeline for Chroma (see the usage sketch below)
  2. Fixes an issue where the modified attention mask was not passed to the transformer model, leading to the quality issues reported in Chroma fails to mask attention in the transformer #11724
  3. Cleans up docstrings

Fixes #11724
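
As an illustration of item 1, a minimal usage sketch of the new img2img pipeline (assuming the class is exposed as ChromaImg2ImgPipeline; the model id, input image, and parameter values are placeholders, not taken from this PR):

import torch
from diffusers import ChromaImg2ImgPipeline
from diffusers.utils import load_image

# placeholder checkpoint and input image
pipe = ChromaImg2ImgPipeline.from_pretrained("<chroma-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
init_image = load_image("input.png")

image = pipe(
    prompt="a portrait photo",
    image=init_image,
    strength=0.6,
    num_inference_steps=28,
).images[0]
image.save("output.png")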

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6
Collaborator Author

DN6 commented Jun 16, 2025

@AmericanPresidentJimmyCarter do these changes help with the quality issue you're seeing?

cc: @Ednaordinary

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@nitinmukesh

I was also getting poor results and had given up on this. Let me try again.

asian_model_portrait_1

asian_model_portrait_2

@AmericanPresidentJimmyCarter
Contributor

test_chroma0

This one is looking much better, thank you.

See the issue for another potential problem (it may need a new issue): I think there may also be issues with T5 quantization because of the way the original authors chose to train it.

@nitinmukesh

After installing
pip install git+https://github.com/huggingface/diffusers.git@refs/pull/11725/head

Better now. I used this config:

import torch
# imports added so the snippet is self-contained; the import path of
# PipelineQuantizationConfig may differ across diffusers versions
from diffusers.quantizers import PipelineQuantizationConfig

dtype = torch.bfloat16  # compute dtype used below (assumed)

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": dtype,
        "llm_int8_skip_modules": ["distilled_guidance_layer"],
    },
    components_to_quantize=["transformer", "text_encoder"],
)
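
For context, a config like this is passed to the pipeline at load time. A rough sketch, where the model id is a placeholder and pipeline_quant_config comes from the snippet above:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "<chroma-checkpoint>",                      # placeholder model id
    quantization_config=pipeline_quant_config,  # config defined above
    torch_dtype=torch.bfloat16,
).to("cuda")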

asian_model_portrait_0

@AmericanPresidentJimmyCarter
Contributor

AmericanPresidentJimmyCarter commented Jun 16, 2025

test_chroma0

Tested it again with a GGUF 8-bit T5 text encoder (see the issue). It seems that to get full quality, we're going to have to figure out what the forked transformer the reference implementation uses does to the text embeddings.

@AmericanPresidentJimmyCarter
Contributor

AmericanPresidentJimmyCarter commented Jun 16, 2025

Another comparison for this branch.

T5 8-bit quanto:

combined_int8

T5 bfloat16:

combined_bf16

T5 8-bit GGUF:

combined_gguf

It's crazy that there's such a world of difference between them, and neither quant seems quite right. BF16 appears to give the best quality images.

@DN6 DN6 requested a review from yiyixuxu June 17, 2025 03:55
Contributor

@Ednaordinary Ednaordinary left a comment

Thanks for this! Apologies I couldn't be of more help; I was busy all day. I did start on my own version, which I'll compare to this one if I can get it to work well.

@@ -256,6 +256,8 @@ def encode_prompt(
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
do_classifier_free_guidance: bool = True,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
Contributor

We should include negative_prompt_embeds (and prompt_attention_mask, negative_prompt_attention_mask) in the docs. I missed it in the pipeline PR.

@Ednaordinary
Contributor

Ednaordinary commented Jun 17, 2025

Also, now that prompt_embeds and negative_prompt_embeds consistently stay the same size, we should batch cond and uncond (also simplifying attention_mask + negative_attention_mask into just attention_mask).

Update: made a PR for this in #11729

@Ednaordinary
Contributor

Ednaordinary commented Jun 17, 2025

Hm, the attention mask lines up with what lodestone shows, so I'm a bit confused why the picture quality isn't on par with how it was before the batch refactor / how it is in ComfyUI (actually, after testing, I can't get a good image in either; maybe the prompt is just bad).
image

image

@DN6 DN6 changed the title Chroma: Pass modified attention mask to Transformer Chroma Follow Up Jun 17, 2025
Collaborator

@yiyixuxu yiyixuxu left a comment

thanks!

do we already have ip-adapter for chroma?

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 128

def _get_t5_prompt_embeds(
Collaborator

Copied from?

Collaborator Author

Here we also create a custom attention mask for Chroma that unmasks a single pad token, so the implementation differs from the existing methods.
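
A minimal, self-contained sketch of that idea with toy tensors (not the pipeline's actual code):

import torch

# toy attention mask for a batch of 2 prompts, max length 8 (1 = real token, 0 = pad)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0, 0, 0]])
seq_lengths = attention_mask.sum(dim=1)                # tokens per prompt: [3, 5]
batch_indices = torch.arange(attention_mask.shape[0])
# unmask the first pad position after each prompt (clamped so it stays in range)
attention_mask[batch_indices, seq_lengths.clamp(max=attention_mask.shape[1] - 1)] = 1
print(attention_mask)  # each row now has one extra unmasked position right after the prompt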


return image_latents

def encode_prompt(
Collaborator

copied from?

Collaborator Author

Returns attention masks, so I think it would be different, no?

Contributor

@Ednaordinary Ednaordinary Jun 17, 2025

Also has negative_prompt and negative_prompt_embeds, which Flux doesn't have.

image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds

def prepare_ip_adapter_image_embeds(
Collaborator

oh we already have ip adapter?

Collaborator Author

So the FluxIPAdapter does work with Chroma (you can load the adapter and run inference). It's just that the quality is not very good because the popular ones are trained for Flux Dev, while Chroma is based on Schnell weights.

I'm fine with removing it, but the consensus here was to leave it in case an IP-Adapter for Chroma gets trained.

f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_attention_mask is not None and negative_prompt_attention_mask is None:
Collaborator

I think if negative_prompt is None, negative_prompt_attention_mask can be None too, no?

raise ValueError(
"Cannot provide `negative_prompt_attention_mask` without also providing `prompt_attention_mask`"
)
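
One way to read that suggestion, as a sketch (an assumption about the intended check, not the final code): only require the negative mask when a negative prompt path is actually in use.

if (
    prompt_attention_mask is not None
    and negative_prompt_attention_mask is None
    and (negative_prompt is not None or negative_prompt_embeds is not None)
):
    raise ValueError(
        "`negative_prompt_attention_mask` is required when `prompt_attention_mask` is provided together with a negative prompt"
    )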

Collaborator

I think we have to pass the mask if we pass the embeddings directly, no?

if prompt_embeds is not None and prompt_attention_mask is None:
    raise ValueError(...)

@Trojaner

Trojaner commented Jun 18, 2025

With this PR in combination with #11729, I now need to pad prompt_embeds and negative_prompt_embeds to the same size when I pass them manually. Is this intentional and consistent with other pipelines? It wasn't needed before these changes.

The following error occurs now, whereas it does not on the main branch.

23:54:30-094459 ERROR    Processing: step=base args={'prompt_embeds': 'cuda:0:torch.bfloat16:torch.Size([1, 81, 4096])', 'negative_prompt_embeds': 'cuda:0:torch.bfloat16:torch.Size([1, 4, 4096])', 'guidance_scale': 4.7, 'generator':   
                         [<torch._C.Generator object at 0x7e096db5d9d0>], 'callback_on_step_end': <function diffusers_callback at 0x7e090f82c040>, 'callback_on_step_end_tensor_inputs': ['latents', 'prompt_embeds', 'noise_pred'],       
                         'num_inference_steps': 26, 'output_type': 'latent', 'width': 1024, 'height': 1024} Sizes of tensors must match except in dimension 0. Expected size 4 but got size 81 for tensor number 1 in the list.            
23:54:30-107754 ERROR    Processing: RuntimeError                                                                                                                                                                                          
╭─────────────────────────────────────────────────────────────────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────────────────────────────────────────────────────────────────╮
│ /home/ml/sdnext/modules/processing_diffusers.py:105 in process_base                                                                                                                                                                     │
│                                                                                                                                                                                                                                         │
│   104 │   │   else:                                                                                                                                                                                                                     │
│ ❱ 105 │   │   │   output = shared.sd_model(**base_args)                                                                                                                                                                                 │
│   106 │   │   if isinstance(output, dict):                                                                                                                                                                                              │
│                                                                                                                                                                                                                                         │
│ /home/ml/sdnext/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116 in decorate_context                                                                                                                                    │
│                                                                                                                                                                                                                                         │
│   115 │   │   with ctx_factory():                                                                                                                                                                                                       │
│ ❱ 116 │   │   │   return func(*args, **kwargs)                                                                                                                                                                                          │
│   117                                                                                                                                                                                                                                   │
│                                                                                                                                                                                                                                         │
│ /home/ml/diffusers/src/diffusers/pipelines/chroma/pipeline_chroma.py:762 in __call__                                                                                                                                                    │
│                                                                                                                                                                                                                                         │
│   761 │   │   if self.do_classifier_free_guidance:                                                                                                                                                                                      │
│ ❱ 762 │   │   │   prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)                                                                                                                                             │
│   763                                                                                                                                                                                                                                   │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 4 but got size 81 for tensor number 1 in the list.

This is my workaround:

import torch.nn.functional as F
from diffusers import ChromaPipeline  # imports added so the snippet is self-contained


def get_weighted_prompt_embeddings_chroma(
    pipe: ChromaPipeline,
    prompt: str = "",
    negative_prompt: str = "",
    device=None
):
    # ... do some stuff ...

    # before the two PRs, this would work:
    # return prompt_embeds, negative_prompt_embeds

    # now this is needed:
    return _pad_prompt_embeds_to_same_size(prompt_embeds, negative_prompt_embeds)


def _pad_prompt_embeds_to_same_size(prompt_embeds_a, prompt_embeds_b):
    # zero-pad the shorter sequence (dim 1) so both embeds end up the same length
    size_a = prompt_embeds_a.size(1)
    size_b = prompt_embeds_b.size(1)

    if size_a < size_b:
        pad_size = size_b - size_a
        prompt_embeds_a = F.pad(prompt_embeds_a, (0, 0, 0, pad_size))
    elif size_b < size_a:
        pad_size = size_a - size_b
        prompt_embeds_b = F.pad(prompt_embeds_b, (0, 0, 0, pad_size))

    return prompt_embeds_a, prompt_embeds_b

@Ednaordinary
Contributor

#11729 was made yesterday, before a bunch of new commits. @DN6, can we keep them padded to the same size, or is a different size preferable for batching?

@DN6
Collaborator Author

DN6 commented Jun 18, 2025

@Trojaner Padding is consistent with other pipelines in Diffusers. We need it if we want to support batch_size > 1.

#11729 introduces batched CFG, so prompt_embeds and negative_prompt_embeds need to be concatenated and hence padded to the same size.
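
A minimal sketch of the constraint, using the shapes from the error above (the random tensors are just stand-ins for real embeds):

import torch
import torch.nn.functional as F

prompt_embeds = torch.randn(1, 81, 4096)           # cond embeds: (batch, seq_len, dim)
negative_prompt_embeds = torch.randn(1, 4, 4096)   # uncond embeds with a shorter sequence

# torch.cat along dim=0 requires every other dimension to match,
# so the shorter sequence has to be zero-padded first
pad = prompt_embeds.shape[1] - negative_prompt_embeds.shape[1]
negative_prompt_embeds = F.pad(negative_prompt_embeds, (0, 0, 0, pad))
batched = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)  # shape (2, 81, 4096)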

@DN6 DN6 merged commit 66394bf into main Jun 18, 2025
16 checks passed
@Trojaner

Trojaner commented Jun 18, 2025

@Trojaner Padding is consistent with other pipelines in Diffusers. We need it if we want to support batch_size > 1.

My question was more about the user having to pad the embeds manually instead of the pipeline doing it automatically; I was not questioning the need for padding itself. I'm fairly sure the user doesn't have to do this with (at least some) other pipelines.

@hameerabbasi
Contributor

My question was more about the user having to pad the embeds manually instead of the pipeline doing it automatically.

If one doesn't pass in the attention mask when going from prompts to embeds, then yes, it is automatic. However, I expect SDnext to actually pass the attention mask in, due to e.g. weighted parts of prompts.
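
In other words, when prompt_embeds/negative_prompt_embeds are passed directly, the corresponding masks should be passed along too. A rough fragment of what that call could look like, continuing from the workaround snippet above (argument names are taken from the encode_prompt diff earlier in this thread; treat the call signature as an assumption):

# pipe, prompt_embeds, negative_prompt_embeds and the two masks are assumed to come
# from the caller's own prompt-weighting/encoding code (see the workaround above)
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    guidance_scale=4.7,
    num_inference_steps=26,
).images[0]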
