From 4049e66774e009fce388d0cad835425f5f376baf Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 12 Aug 2022 23:27:13 +0530
Subject: [PATCH 1/6] add stable diffusion pipeline

---
 src/diffusers/__init__.py                  |  3 +-
 src/diffusers/pipelines/__init__.py        |  1 +
 .../pipelines/stable_diffusion/__init__.py |  5 +
 .../pipeline_stable_diffusion.py           | 92 +++++++++++++++++++
 .../utils/dummy_transformers_objects.py    |  7 ++
 tests/test_modeling_utils.py               | 34 +++++++
 6 files changed, 141 insertions(+), 1 deletion(-)
 create mode 100644 src/diffusers/pipelines/stable_diffusion/__init__.py
 create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index f8313509eed8..32adae492d52 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -31,6 +31,7 @@

 if is_transformers_available():
-    from .pipelines import LDMTextToImagePipeline
+    from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline
+
 else:
     from .utils.dummy_transformers_objects import *

diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index c1b2068a3e23..19dafb477f9d 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -9,3 +9,4 @@

 if is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
+    from .stable_diffusion import StableDiffusionPipeline

diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 000000000000..718ae587a05b
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,5 @@
+from ...utils import is_transformers_available
+
+
+if is_transformers_available():
+    from .pipeline_stable_diffusion import StableDiffusionPipeline

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 000000000000..c9a1d0f7d1a2
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,92 @@
+import inspect
+
+import torch
+
+from tqdm.auto import tqdm
+
+from ...pipeline_utils import DiffusionPipeline
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+        output_type="pil",
+    ):
+        # eta corresponds to η in paper and should be between [0, 1]
+
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        batch_size = len(prompt)
+
+        self.unet.to(torch_device)
+        self.vae.to(torch_device)
+        self.text_encoder.to(torch_device)
+
+        # get unconditional embeddings for classifier free guidance
+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+
+        latents = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            generator=generator,
+        )
+        latents = latents.to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_kwrags = {}
+        if not accepts_eta:
+            extra_kwrags["eta"] = eta
+
+        for t in tqdm(self.scheduler.timesteps):
+            if guidance_scale == 1.0:
+                # guidance_scale of 1 means no guidance
+                latents_input = latents
+                context = text_embeddings
+            else:
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                latents_input = torch.cat([latents] * 2)
+                context = torch.cat([uncond_embeddings, text_embeddings])
+
+            # predict the noise residual
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            # perform guidance
+            if guidance_scale != 1.0:
+                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"sample": image}

diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py
index d638ec490306..34e0c8bec150 100644
--- a/src/diffusers/utils/dummy_transformers_objects.py
+++ b/src/diffusers/utils/dummy_transformers_objects.py
@@ -8,3 +8,10 @@ class LDMTextToImagePipeline(metaclass=DummyObject):

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])

diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 072109e84ce8..7c8e9be63560 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -45,6 +45,8 @@
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel

+from diffusers import StableDiffusionPipeline
+

 torch.backends.cuda.matmul.allow_tf32 = False

@@ -839,6 +841,38 @@ def test_ldm_text2img_fast(self):
         expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    @slow
+    def test_stable_diffusion(self):
+        ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
+            "sample"
+        ]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        # TODO: update the expected_slice
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    @slow
+    def test_stable_diffusion_fast(self):
+        ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        # TODO: update the expected_slice
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     @slow
     def test_score_sde_ve_pipeline(self):
         model_id = "google/ncsnpp-church-256"
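Before the follow-up patches, a minimal usage sketch of the pipeline as it stands after this first commit. The checkpoint id is the one used in the slow tests above, and the call signature (prompt passed as a list, results returned under the "sample" key) is the one this commit defines, not the final API; running it assumes the weights can actually be downloaded.

```python
import torch
from diffusers import StableDiffusionPipeline

# checkpoint id taken from the tests above
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

generator = torch.manual_seed(0)  # torch.manual_seed returns a torch.Generator
images = pipe(
    ["A painting of a squirrel eating a burger"],  # prompts must be a list at this commit
    generator=generator,
    guidance_scale=6.0,      # any value != 1.0 enables classifier-free guidance here
    num_inference_steps=50,
    output_type="pil",       # "pil" yields a list of PIL.Image objects
)["sample"]

images[0].save("squirrel.png")
```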
From cad8c981f0324ef72fadf20757dd8e893763d2b1 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Sat, 13 Aug 2022 13:39:16 +0530
Subject: [PATCH 2/6] get rid of multiple if/else

---
 .../pipeline_stable_diffusion.py | 44 ++++++++++---------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index c9a1d0f7d1a2..95b90957c945 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -35,47 +35,49 @@ def __call__(
         self.vae.to(torch_device)
         self.text_encoder.to(torch_device)

+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+
+        do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
-        if guidance_scale != 1.0:
+        if do_classifier_free_guidance:
             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

-        # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
-        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat((uncond_embeddings, text_embeddings), dim=0)

+        # get the initial random noise
         latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         latents = latents.to(torch_device)

-        self.scheduler.set_timesteps(num_inference_steps)
-
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         extra_kwrags = {}
-        if not accepts_eta:
+        if accepts_eta:
             extra_kwrags["eta"] = eta

+        self.scheduler.set_timesteps(num_inference_steps)
+
         for t in tqdm(self.scheduler.timesteps):
-            if guidance_scale == 1.0:
-                # guidance_scale of 1 means no guidance
-                latents_input = latents
-                context = text_embeddings
-            else:
-                # For classifier free guidance, we need to do two forward passes.
-                # Here we concatenate the unconditional and text embeddings into a single batch
-                # to avoid doing two forward passes
-                latents_input = torch.cat([latents] * 2)
-                context = torch.cat([uncond_embeddings, text_embeddings])
+            # expand the latents if we are doing classifier free guidance
+            if do_classifier_free_guidance:
+                latents = torch.cat((latents, latents), dim=0)

             # predict the noise residual
-            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            noise_pred = self.unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]
+
             # perform guidance
-            if guidance_scale != 1.0:
-                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
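The refactor above collapses the two guidance branches into a single batched forward pass: unconditional and text embeddings ride through the UNet together and are split again with `chunk(2)`. A toy sketch of that arithmetic, with a random tensor standing in for the UNet's noise prediction:

```python
import torch

guidance_scale = 7.5

# stand-in for self.unet(...)["sample"] on a doubled batch:
# row 0 is the unconditional prediction, row 1 the text-conditioned one
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# move from the unconditional prediction towards (and past) the text one
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# with guidance_scale == 1.0 the formula collapses to the text prediction,
# which is why `do_classifier_free_guidance` only triggers for values > 1.0
no_guidance = noise_pred_uncond + 1.0 * (noise_pred_text - noise_pred_uncond)
assert torch.allclose(no_guidance, noise_pred_text)
```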
output_type="pil", + prompt: Union[str, List[str]], + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 1.0, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + torch_device: Optional[Union[str, torch.device]] = None, + output_type: Optional[str] = "pil", ): # eta corresponds to η in paper and should be between [0, 1] From ed4954fba488332b2425ec7e134baa099f750dc7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 14 Aug 2022 13:31:07 +0200 Subject: [PATCH 5/6] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5a7b4b9aadda..b3d026bd882b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -49,6 +49,9 @@ def __call__( text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0] + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: From 6f848ba81d2da4e04dbe531db0faa177108a118d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 14 Aug 2022 12:39:21 +0000 Subject: [PATCH 6/6] fix some bugs --- .../pipeline_latent_diffusion.py | 6 +-- .../pipeline_latent_diffusion_uncond.py | 6 +-- .../pipeline_stable_diffusion.py | 38 +++++++++++-------- .../pipeline_stochastic_karras_ve.py | 9 +++-- .../schedulers/scheduling_karras_ve.py | 20 +++++----- 5 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 006fa0a96857..8c3be1db43d4 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -62,9 +62,9 @@ def __call__( # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_kwrags = {} + extra_kwargs = {} if not accepts_eta: - extra_kwrags["eta"] = eta + extra_kwargs["eta"] = eta for t in tqdm(self.scheduler.timesteps): if guidance_scale == 1.0: @@ -86,7 +86,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"] # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 3814827eea7f..2755084c8dc3 100644 --- 
From ed4954fba488332b2425ec7e134baa099f750dc7 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sun, 14 Aug 2022 13:31:07 +0200
Subject: [PATCH 5/6] Update
 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

---
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 5a7b4b9aadda..b3d026bd882b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -49,6 +49,9 @@ def __call__(
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
         text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
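The comment added by this patch can be verified with a few lines: the pipeline's update rule is an algebraic rearrangement of the Imagen paper's equation (2), eps_hat = w * eps_text - (w - 1) * eps_uncond, and w = 1 reduces to the plain text-conditional prediction. A quick numeric check with dummy tensors:

```python
import torch

w = 7.5  # guidance weight, i.e. guidance_scale
eps_uncond = torch.randn(8)  # dummy unconditional noise prediction
eps_text = torch.randn(8)    # dummy text-conditional noise prediction

pipeline_form = eps_uncond + w * (eps_text - eps_uncond)
imagen_form = w * eps_text - (w - 1) * eps_uncond

assert torch.allclose(pipeline_form, imagen_form)
```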
From 6f848ba81d2da4e04dbe531db0faa177108a118d Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sun, 14 Aug 2022 12:39:21 +0000
Subject: [PATCH 6/6] fix some bugs

---
 .../pipeline_latent_diffusion.py        |  6 +--
 .../pipeline_latent_diffusion_uncond.py |  6 +--
 .../pipeline_stable_diffusion.py        | 38 +++++++++++--------
 .../pipeline_stochastic_karras_ve.py    |  9 +++--
 .../schedulers/scheduling_karras_ve.py  | 20 +++++-----
 5 files changed, 44 insertions(+), 35 deletions(-)

diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index 006fa0a96857..8c3be1db43d4 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -62,9 +62,9 @@ def __call__(

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+        extra_kwargs = {}
         if not accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta

         for t in tqdm(self.scheduler.timesteps):
             if guidance_scale == 1.0:
@@ -86,7 +86,7 @@ def __call__(
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

         # compute the previous noisy sample x_t -> x_t-1
-        latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
+        latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents

diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index 3814827eea7f..2755084c8dc3 100644
--- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -35,15 +35,15 @@ def __call__(

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+        extra_kwargs = {}
         if not accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta

         for t in tqdm(self.scheduler.timesteps):
             # predict the noise residual
             noise_prediction = self.unet(latents, t)["sample"]
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
+            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]

         # decode the image latents with the VAE
         image = self.vqvae.decode(latents)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index b3d026bd882b..0f309625ae44 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -35,33 +35,40 @@ def __call__(
         torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
     ):
-        # eta corresponds to η in paper and should be between [0, 1]
-
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-        batch_size = len(prompt)
+
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

         self.unet.to(torch_device)
         self.vae.to(torch_device)
         self.text_encoder.to(torch_device)

         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
         text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+            )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = torch.cat((uncond_embeddings, text_embeddings), dim=0)
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

         # get the initial random noise
         latents = torch.randn(
@@ -72,20 +79,21 @@ def __call__(

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+        extra_kwargs = {}
         if accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta

         self.scheduler.set_timesteps(num_inference_steps)

         for t in tqdm(self.scheduler.timesteps):
             # expand the latents if we are doing classifier free guidance
-            if do_classifier_free_guidance:
-                latents = torch.cat((latents, latents), dim=0)
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

             # predict the noise residual
-            noise_pred = self.unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
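One of the fixes in the hunk above replaces the hard-coded 77-token padding with dynamic padding plus truncation, and pads the unconditional input to the prompt's own length. A sketch of the difference; the tokenizer checkpoint named here is the standard CLIP one commonly paired with this text encoder, an assumption rather than something pinned by the patch:

```python
from transformers import CLIPTokenizer

# assumed checkpoint; not specified anywhere in this patch series
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a cat", "a very detailed oil painting of a cat wearing a tiny hat"]

# old behaviour: every batch is padded out to 77 tokens
fixed = tokenizer(prompts, padding="max_length", max_length=77, return_tensors="pt")

# new behaviour: pad to the longest prompt in the batch, truncate over-long ones
dynamic = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")

print(fixed.input_ids.shape)    # torch.Size([2, 77])
print(dynamic.input_ids.shape)  # torch.Size([2, <longest prompt length>])
```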
diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py
index 27cb6a0e0043..25d85126fd90 100644
--- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py
+++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -10,11 +10,12 @@

 class KarrasVePipeline(DiffusionPipeline):
     """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
-    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.

-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
-    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
     """

     unet: UNet2DModel

diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py
index 9741189c8799..320c682ccb69 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve.py
@@ -24,11 +24,12 @@

 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
-    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.

-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
-    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
     """

     @register_to_config
@@ -43,10 +44,9 @@ def __init__(
         tensor_format="pt",
     ):
         """
-        For more details on the parameters, see the original paper's Appendix E.:
-        "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
-        The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
-        are described in Table 5 of the paper.
+        For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+        Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+        optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.

         Args:
             sigma_min (`float`): minimum noise magnitude
@@ -81,8 +81,8 @@ def set_timesteps(self, num_inference_steps):

     def add_noise_to_input(self, sample, sigma, generator=None):
         """
-        Explicit Langevin-like "churn" step of adding noise to the sample according to
-        a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
         """
         if self.s_min <= sigma <= self.s_max:
             gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
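For reference, the churn step named in the reflowed docstrings, sigma_hat = sigma + gamma * sigma, can be sketched numerically. This follows Algorithm 2 of Karras et al.; the constants are made-up example values rather than the scheduler's configured defaults, and the variance argument (fresh noise is scaled so that sigma**2 plus the added variance equals sigma_hat**2) is the paper's reasoning, not code copied from the scheduler:

```python
import torch

# made-up example values for {s_churn, s_noise, s_min, s_max}
s_churn, s_noise = 80.0, 1.007
s_min, s_max = 0.05, 50.0
num_inference_steps = 50

sigma = 10.0
sample = torch.randn(1, 3, 8, 8) * sigma  # a sample at noise level sigma

# churn only applies inside [s_min, s_max]; gamma is capped at sqrt(2) - 1
gamma = min(s_churn / num_inference_steps, 2**0.5 - 1) if s_min <= sigma <= s_max else 0.0
sigma_hat = sigma + gamma * sigma

# independent Gaussian variances add, so the extra std is sqrt(sigma_hat^2 - sigma^2)
eps = s_noise * torch.randn_like(sample)
sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps

print(gamma, sigma_hat)  # 0.414..., 14.14...
```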