diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index d11f6c2a5e25..47d8d0a48523 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -258,9 +258,6 @@ def forward(
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
         joint_attention_kwargs = joint_attention_kwargs or {}
 
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
-
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
@@ -328,8 +325,6 @@ def forward(
             encoder_hidden_states, emb=temb_txt
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
 
         # Attention.
         attention_outputs = self.attn(
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index b73d17a9d28b..e257ac2795c4 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -564,6 +564,7 @@ def _prepare_attention_mask(
             dim=1,
         )
         attention_mask = attention_mask.to(dtype)
+        attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
 
         return attention_mask
 
@@ -757,6 +758,9 @@ def __call__(
             lora_scale=lora_scale,
         )
 
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
@@ -846,11 +850,13 @@ def __call__(
                 if image_embeds is not None:
                     self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
@@ -863,17 +869,8 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        attention_mask=negative_attention_mask,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-                    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
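
For context, the patch replaces the two separate transformer calls (conditional and unconditional) with the standard batched classifier-free guidance pattern. The sketch below illustrates that pattern in isolation, assuming a placeholder `denoiser` callable and made-up tensor shapes; it is not the Chroma transformer, and only the concat/chunk structure mirrors the diff.

import torch

# Placeholder for the denoising model: any callable mapping
# (latents, text embeddings) -> noise prediction of the same shape as latents.
def denoiser(hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
    return hidden_states + encoder_hidden_states.mean(dim=1, keepdim=True)

batch, seq, dim = 2, 16, 8          # arbitrary illustrative shapes
latents = torch.randn(batch, seq, dim)
prompt_embeds = torch.randn(batch, seq, dim)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
guidance_scale = 4.0

# Batched CFG: stack unconditional and conditional inputs (negative first,
# matching the diff) so one forward pass yields both predictions.
prompt_embeds_in = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
latent_model_input = torch.cat([latents] * 2, dim=0)

noise_pred = denoiser(latent_model_input, prompt_embeds_in)

# Split the stacked prediction back into its two halves and combine them.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

assert noise_pred.shape == latents.shape

The trade-off is the usual one for batched CFG: a single forward pass at twice the batch size instead of two passes, which is typically faster on GPU at the cost of higher peak memory.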