Fix/Enable all schedulers for in-painting #1331

Merged · 4 commits · Nov 18, 2022
@@ -408,8 +408,8 @@ def __call__(
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
- latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()
+ latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)

# predict the noise residual
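For context, the ONNX pipeline keeps its latents as NumPy arrays while the schedulers operate on torch tensors, so the scaling call is wrapped in a from_numpy / .cpu().numpy() round trip and the mask is now concatenated on the NumPy side afterwards. A minimal sketch of that bridging, with illustrative shapes and LMSDiscreteScheduler standing in for whatever scheduler the pipeline is configured with:

import numpy as np
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()  # illustrative scheduler choice
scheduler.set_timesteps(50)
t = scheduler.timesteps[0]

latents = np.random.randn(2, 4, 64, 64).astype(np.float32)         # NumPy-side latents
mask = np.ones((2, 1, 64, 64), dtype=np.float32)                    # illustrative mask
masked_image_latents = np.zeros((2, 4, 64, 64), dtype=np.float32)   # illustrative conditioning

# Scale only the latents on the torch side, then convert back for the ONNX UNet.
scaled = scheduler.scale_model_input(torch.from_numpy(latents), t)
latent_model_input = scaled.cpu().numpy()

# The unscaled mask and masked-image latents are concatenated afterwards.
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)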
@@ -586,9 +586,8 @@ def __call__(
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

# concat latents, mask, masked_image_latents in the channel dimension
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
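The ordering matters because K-style schedulers actually rescale the model input: LMSDiscreteScheduler.scale_model_input divides the sample by sqrt(sigma**2 + 1). That is correct for the noisy latents, but the mask and the masked-image latents are plain conditioning and should reach the UNet unscaled, which is why the concatenation now happens after the scaling. A small illustrative check, not code from this PR, with assumed scheduler settings:

import torch
from diffusers import LMSDiscreteScheduler

# Roughly the Stable Diffusion noise schedule (assumed values, for illustration only).
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)
t = scheduler.timesteps[0]  # highest-sigma step

mask = torch.ones(1, 1, 64, 64)  # stand-in for the {0, 1} inpainting mask
scaled_mask = scheduler.scale_model_input(mask, t)

# PNDM/DDIM implement scale_model_input as a no-op, so the old order went unnoticed;
# with LMS the mask would be shrunk far below 1 at the early, high-sigma steps.
print(scaled_mask.max())  # << 1.0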
40 changes: 40 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -22,6 +22,7 @@

from diffusers import (
AutoencoderKL,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
@@ -421,6 +422,45 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2

def test_stable_diffusion_inpaint_pipeline_k_lms(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
"/yellow_cat_sitting_on_a_park_bench_k_lms.npy"
)

model_id = "runwayml/stable-diffusion-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)

# switch to LMS
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]

assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
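With the input scaling applied before the mask is concatenated, the scheduler-swap pattern exercised by the new k_lms test should carry over to other schedulers as well. A minimal usage sketch, untested here, with EulerDiscreteScheduler chosen purely as an example of another scheduler that implements scale_model_input (a CUDA device and the hub model above are assumed):

import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionInpaintPipeline
from diffusers.utils import load_image

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", safety_checker=None
)
# Swap the default scheduler; from_config reuses the pipeline's existing scheduler config.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    prompt="Face of a yellow cat, high resolution, sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    generator=generator,
).images[0]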