Stable diffusion pipeline #168
Changes from all commits: `4049e66`, `cad8c98`, `72a27bc`, `b439b6b`, `ed4954f`, `6f848ba`, `5fe9486`
First file — the pipeline package `__init__.py`, new file (+5 lines):

```python
from ...utils import is_transformers_available


if is_transformers_available():
    from .pipeline_stable_diffusion import StableDiffusionPipeline
```
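`is_transformers_available` guards the import so the package still loads when the optional `transformers` dependency is missing. A minimal sketch of how such a check can be implemented (illustrative only, not the actual diffusers helper):

```python
import importlib.util


def is_transformers_available() -> bool:
    # True when the optional `transformers` package is importable
    return importlib.util.find_spec("transformers") is not None
```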
Second file — `pipeline_stable_diffusion.py`, new file (+115 lines):
```python
import inspect
from typing import List, Optional, Union

import torch
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        torch_device: Optional[Union[str, torch.device]] = None,
        output_type: Optional[str] = "pil",
    ):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        if isinstance(prompt, str):
```
> Review comment: @patil-suraj -> made sure a string can be passed as input
```python
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        self.unet.to(torch_device)
        self.vae.to(torch_device)
        self.text_encoder.to(torch_device)

        # get prompt text embeddings
        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
```
> Review comment: @patil-suraj removed the weird 77 max length here
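To see what that change does: with `padding=True` the batch is padded to its longest sequence rather than to a fixed 77 tokens. A small sketch (the checkpoint id is an assumption, used for illustration only):

```python
from transformers import CLIPTokenizer

# assumed checkpoint id, for illustration only
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

batch = tokenizer(["a cat", "an oil painting of a cat"], padding=True, truncation=True, return_tensors="pt")
print(batch.input_ids.shape)  # padded to the longest prompt in the batch, not to a fixed 77
```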
```python
        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
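            # resulting layout for batch_size = 2 (illustrative):
            #   text_embeddings[0:2] -> unconditional ("" prompt) embeddings
            #   text_embeddings[2:4] -> prompt embeddings
            # the latents are duplicated the same way in the denoising loop below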

        # get the initial random noise
        latents = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        latents = latents.to(torch_device)
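        # for the v1 Stable Diffusion configs this is typically (batch_size, 4, 64, 64);
        # the exact shape is an assumption about the checkpoint, read from the unet config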

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta
```
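As an aside, the `inspect.signature` check above is a generic way to probe for keyword support before calling; a self-contained sketch with hypothetical step functions (the signatures are illustrative, not the actual scheduler APIs):

```python
import inspect

def ddim_like_step(model_output, timestep, sample, eta=0.0):  # hypothetical signature
    pass

def pndm_like_step(model_output, timestep, sample):  # hypothetical signature
    pass

print("eta" in inspect.signature(ddim_like_step).parameters)  # True  -> forward `eta`
print("eta" in inspect.signature(pndm_like_step).parameters)  # False -> drop it
```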
```python
        self.scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
```
> Review comment: @patil-suraj there was a bug with the naming here -> corrected it
```python
            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
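                # the first half of the batch is the unconditional prediction and the second half
                # the text-conditioned one, matching torch.cat([uncond_embeddings, text_embeddings]) above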
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

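        # 0.18215 is the fixed factor the latents were scaled by when the autoencoder was trained
        # (a constant from the CompVis configs); dividing by it undoes that before decoding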
        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        # map from [-1, 1] to [0, 1] and move channels last for numpy / PIL
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image}
```
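For reviewers who want to try the new pipeline end to end, a minimal usage sketch; the checkpoint id and the top-level `diffusers` re-export are assumptions, not part of this diff:

```python
import torch

from diffusers import StableDiffusionPipeline

# assumed checkpoint id, for illustration only
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

generator = torch.Generator().manual_seed(0)  # reproducible initial latents
images = pipe(
    "a photograph of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.5,  # > 1.0 enables classifier free guidance
    generator=generator,
)["sample"]
images[0].save("astronaut.png")
```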