@@ -22,6 +22,7 @@
 
 from diffusers import (
     AutoencoderKL,
+    LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
@@ -421,6 +422,45 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
         assert image.shape == (512, 512, 3)
         assert np.abs(expected_image - image).max() < 1e-2
 
+    def test_stable_diffusion_inpaint_pipeline_k_lms(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
+            "/yellow_cat_sitting_on_a_park_bench_k_lms.npy"
+        )
+
+        model_id = "runwayml/stable-diffusion-inpainting"
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
+        pipe.to(torch_device)
+
+        # switch to LMS
+        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            generator=generator,
+            output_type="np",
+        )
+        image = output.images[0]
+
+        assert image.shape == (512, 512, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
+
     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
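For context (not part of the diff): the new test exercises rebuilding a pipeline's scheduler from its existing config before running inpainting with the k-LMS sampler. Below is a minimal sketch of how that swap might look in user code, assuming the same runwayml/stable-diffusion-inpainting checkpoint and prompt as the test; the local image paths are placeholders, not files from the PR.

# Minimal sketch, assuming the model id and prompt from the test above;
# "input.png" / "mask.png" are placeholder paths.
import torch
from PIL import Image
from diffusers import LMSDiscreteScheduler, StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
# Rebuild the scheduler from the pipeline's current scheduler config so the
# noise-schedule settings carry over to the LMS (k-LMS) sampler.
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # trade a little speed for lower peak memory

init_image = Image.open("input.png").convert("RGB")
mask_image = Image.open("mask.png").convert("RGB")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    prompt="Face of a yellow cat, high resolution, sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    generator=generator,
).images[0]
image.save("inpainted.png")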
|