Add an alternative Karras et al. stochastic scheduler for VE models #160
Changes from all commits
41c5a6f
e0d17a4
ea2bfb6
0ec09e9
aea9fa2
b7d16f5
@@ -0,0 +1 @@
from .pipeline_stochastic_karras_ve import KarrasVePipeline
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
import torch

from tqdm.auto import tqdm

from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasVeScheduler


class KarrasVePipeline(DiffusionPipeline):
    """
    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
    """

    unet: UNet2DModel
    scheduler: KarrasVeScheduler

    def __init__(self, unet, scheduler):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet.to(torch_device)

        # sample x_0 ~ N(0, sigma_0^2 * I)
        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
        sample = sample.to(torch_device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(self.scheduler.timesteps):
            # here sigma_t == t_i from the paper
            sigma = self.scheduler.schedule[t]
            sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0

            # 1. Select temporarily increased noise level sigma_hat
            # 2. Add new noise to move from sample_i to sample_hat
            sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)

            # 3. Predict the noise residual given the noise magnitude `sigma_hat`
            # The model inputs and output are adjusted by following eq. (213) in [1].
            model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]

            # 4. Evaluate dx/dt at sigma_hat
            # 5. Take Euler step from sigma to sigma_prev
            step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)

            if sigma_prev != 0:
                # 6. Apply 2nd order correction
                # The model inputs and output are adjusted by following eq. (213) in [1].
                model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
                step_output = self.scheduler.step_correct(
                    model_output,
                    sigma_hat,
                    sigma_prev,
                    sample_hat,
                    step_output["prev_sample"],
                    step_output["derivative"],
                )
            sample = step_output["prev_sample"]

        sample = (sample / 2 + 0.5).clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        return {"sample": sample}
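For reviewers who want to try this end to end, here is a minimal usage sketch (not part of the diff). It assumes the top-level `KarrasVePipeline` / `KarrasVeScheduler` exports added in this PR and reuses the `google/ncsnpp-celebahq-256` checkpoint from the slow test below; the output filename is arbitrary.

```python
import torch
from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel

# Load a VE (NCSN++) model and pair it with the new stochastic scheduler
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
scheduler = KarrasVeScheduler(tensor_format="pt")
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

# With the default output_type="pil", the pipeline returns PIL images under the "sample" key
generator = torch.manual_seed(0)
images = pipe(batch_size=1, num_inference_steps=50, generator=generator)["sample"]
images[0].save("karras_ve_sample.png")
```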
@@ -0,0 +1,127 @@
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
    """

    @register_to_config
    def __init__(
        self,
        sigma_min=0.02,
        sigma_max=100,
        s_noise=1.007,
        s_churn=80,
        s_min=0.05,
        s_max=50,
        tensor_format="pt",
    ):
        """
        For more details on the parameters, see the original paper's Appendix E.:
        "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
        The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
        are described in Table 5 of the paper.

        Args:
            sigma_min (`float`): minimum noise magnitude
            sigma_max (`float`): maximum noise magnitude
            s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
                A reasonable range is [1.000, 1.011].
            s_churn (`float`): the parameter controlling the overall amount of stochasticity.
                A reasonable range is [0, 100].
            s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
                A reasonable range is [0, 10].
            s_max (`float`): the end value of the sigma range where we add noise.
                A reasonable range is [0.2, 80].
        """
        # setable values
        self.num_inference_steps = None
        self.timesteps = None
        self.schedule = None  # sigma(t_i)

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps):
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.schedule = [
            (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
            for i in self.timesteps
        ]
        self.schedule = np.array(self.schedule, dtype=np.float32)

        self.set_format(tensor_format=self.tensor_format)

    def add_noise_to_input(self, sample, sigma, generator=None):
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to
        a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
        """
        if self.s_min <= sigma <= self.s_max:
            gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I)
        eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: Union[torch.FloatTensor, np.ndarray],
    ):
        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        return {"prev_sample": sample_prev, "derivative": derivative}

    def step_correct(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: Union[torch.FloatTensor, np.ndarray],
        sample_prev: Union[torch.FloatTensor, np.ndarray],
        derivative: Union[torch.FloatTensor, np.ndarray],
    ):
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
        return {"prev_sample": sample_prev, "derivative": derivative_corr}

    def add_noise(self, original_samples, noise, timesteps):
        raise NotImplementedError()
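To make the sampling schedule concrete, the following standalone sketch (not part of the diff) reproduces the sigma schedule from `set_timesteps` and the churn factor `gamma` from `add_noise_to_input` with the default config values, for a toy run of 5 steps.

```python
import numpy as np

# Default config values from KarrasVeScheduler
sigma_min, sigma_max = 0.02, 100.0
s_churn, s_min, s_max = 80.0, 0.05, 50.0
num_inference_steps = 5

# Same geometric interpolation as set_timesteps; the list is built by iterating the
# reversed timesteps, so schedule[num_inference_steps - 1] equals sigma_max and the
# pipeline (which indexes with t = N-1, ..., 0) sees the largest sigma first.
timesteps = np.arange(0, num_inference_steps)[::-1]
schedule = np.array(
    [sigma_max * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1)) for i in timesteps],
    dtype=np.float32,
)
print(schedule)  # ≈ [4.0e-06, 2.8e-04, 2.0e-02, 1.4e+00, 1.0e+02]

# Churn (extra noise) is only applied while sigma lies in [s_min, s_max];
# here only sigma ≈ 1.41 qualifies, giving gamma = sqrt(2) - 1 and sigma_hat ≈ 2.0.
for sigma in schedule:
    gamma = min(s_churn / num_inference_steps, 2**0.5 - 1) if s_min <= sigma <= s_max else 0.0
    print(float(sigma), float(sigma + gamma * sigma))
```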
@@ -29,6 +29,8 @@
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    KarrasVePipeline,
    KarrasVeScheduler,
    LDMPipeline,
    LDMTextToImagePipeline,
    PNDMPipeline,
@@ -920,3 +922,19 @@ def test_ddpm_ddim_equality_batched(self):

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

    @slow
    def test_karras_ve_pipeline(self):
Review comment: Very cool!!! (Did you also try it with the 1024 one?)
Reply: Yes, the results are pretty ok :)
        model_id = "google/ncsnpp-celebahq-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = KarrasVeScheduler(tensor_format="pt")

        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Review comment: (nit) maybe rename the `s_` parameters to `scale_`, for example `scale_churn`.

Reply: +1

Reply: The S likely refers to "stochastic" in the paper (page 7, "Practical considerations", https://arxiv.org/pdf/2206.00364.pdf), so I'm not sure this would make it more explicit. I'll add detailed docstrings for them instead.

Reply: Ok for me

Reply: Looking at the function, these values are used to scale the stochastic noise, so `scale` would be good IMO. Also, these arguments are a bit too low-level to expose in `__call__`, so maybe keep them in the scheduler. Having the scheduler self-contained would be better IMO. Or do you think there is some advantage to having them in the pipeline?

Reply: @patil-suraj on second thought, these parameters are best to have inside the scheduler, since they have to be tuned to each specific model. BTW, take a look at the docstrings so far to see if they would fit `scale_`.
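Since these values have to be tuned per model (see Table 5 in the paper), keeping them on the scheduler means a tuned configuration can be expressed at construction time without touching the pipeline `__call__`. A hedged sketch of that usage, with placeholder values taken from the documented ranges rather than a real tuned setting:

```python
from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel

# Placeholder values inside the documented ranges; a real setup would use
# the grid-searched values for the specific checkpoint.
scheduler = KarrasVeScheduler(
    sigma_min=0.02,
    sigma_max=100,
    s_noise=1.003,  # reasonable range [1.000, 1.011]
    s_churn=40,     # reasonable range [0, 100]
    s_min=0.05,     # reasonable range [0, 10]
    s_max=50,       # reasonable range [0.2, 80]
)
pipe = KarrasVePipeline(unet=UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256"), scheduler=scheduler)
```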