Commit fcfdd95
Fix/Enable all schedulers for in-painting (#1331)
* inpaint fix k lms
* onnx as well
* up
1 parent 5dcef13 commit fcfdd95

3 files changed: +42 -3 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion

@@ -408,8 +408,8 @@ def __call__(
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
             # concat latents, mask, masked_image_latnets in the channel dimension
-            latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
             latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
             latent_model_input = latent_model_input.cpu().numpy()

             # predict the noise residual
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 2 deletions

@@ -586,9 +586,8 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                 # concat latents, mask, masked_image_latents in the channel dimension
-                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
-
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

                 # predict the noise residual
                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
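
In both pipelines the change is the same: scheduler.scale_model_input() now runs before the mask and masked-image latents are concatenated, so it only ever sees the (optionally duplicated) noise latents. Schedulers such as PNDM implement scale_model_input as a no-op, which is why the old ordering happened to work, while k-diffusion-style schedulers such as LMSDiscreteScheduler actually rescale their input and must not touch the extra conditioning channels. Below is a minimal sketch of the corrected loop ordering, not the pipelines' full implementation; it assumes latents, mask, masked_image_latents, text_embeddings, guidance_scale, do_classifier_free_guidance, scheduler, and unet have already been prepared the way the pipeline prepares them.

import torch

# Sketch only: the corrected ordering of input scaling vs. channel concatenation.
for t in scheduler.timesteps:
    # duplicate the noise latents when doing classifier-free guidance
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

    # scale only the noise latents; a no-op for PNDM/DDIM, a real rescale for k-LMS
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # only now append the mask and masked-image latents in the channel dimension
    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

    # predict the noise residual, apply guidance, and step the scheduler
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample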

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 40 additions & 0 deletions

@@ -22,6 +22,7 @@

 from diffusers import (
     AutoencoderKL,
+    LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
@@ -421,6 +422,45 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
         assert image.shape == (512, 512, 3)
         assert np.abs(expected_image - image).max() < 1e-2

+    def test_stable_diffusion_inpaint_pipeline_k_lms(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
+            "/yellow_cat_sitting_on_a_park_bench_k_lms.npy"
+        )
+
+        model_id = "runwayml/stable-diffusion-inpainting"
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
+        pipe.to(torch_device)
+
+        # switch to LMS
+        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            generator=generator,
+            output_type="np",
+        )
+        image = output.images[0]
+
+        assert image.shape == (512, 512, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
+
     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
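
For end users the new test doubles as a recipe: load the inpainting pipeline, then swap in the scheduler of your choice via from_config. The following is a minimal usage sketch mirroring the test, not official documentation; it assumes the runwayml/stable-diffusion-inpainting checkpoint and a CUDA device are available, download_image is a hypothetical helper, and the output filename is arbitrary.

from io import BytesIO

import requests
from PIL import Image

from diffusers import LMSDiscreteScheduler, StableDiffusionInpaintPipeline


def download_image(url):
    # hypothetical helper: fetch an image from the Hugging Face Hub
    return Image.open(BytesIO(requests.get(url).content)).convert("RGB")


init_image = download_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = download_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)

pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
# swap the default scheduler for k-LMS, exactly as the new test does
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
image.save("yellow_cat_on_bench_k_lms.png")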
