Commit dd10da7

Add an alternative Karras et al. stochastic scheduler for VE models (#160)
* karras + VE, not flexible yet
* Fix inputs incompatibility with the original unet
* Roll back sigma scaling
* Apply suggestions from code review
* Old comment
* Fix doc
1 parent 543ee1e commit dd10da7
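
For reference, a minimal usage sketch of the pipeline this commit adds (an editorial addition, not part of the commit; the checkpoint and call pattern are taken from the test added below):

import torch
from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel

# Load a VE checkpoint; "google/ncsnpp-celebahq-256" is the model exercised in the new slow test.
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
scheduler = KarrasVeScheduler(tensor_format="pt")
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

# Runs Algorithm 2 for 50 steps; output_type defaults to "pil", so "sample" is a list of PIL images.
image = pipe(num_inference_steps=50, generator=torch.manual_seed(0))["sample"][0]
image.save("karras_ve_sample.png")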

File tree

8 files changed: +238 -3 lines changed

src/diffusers/__init__.py

Lines changed: 9 additions & 2 deletions

@@ -18,8 +18,15 @@
     get_scheduler,
 )
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
-from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
+from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .schedulers import (
+    DDIMScheduler,
+    DDPMScheduler,
+    KarrasVeScheduler,
+    PNDMScheduler,
+    SchedulerMixin,
+    ScoreSdeVeScheduler,
+)
 from .training_utils import EMAModel

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@
 from .latent_diffusion_uncond import LDMPipeline
 from .pndm import PNDMPipeline
 from .score_sde_ve import ScoreSdeVePipeline
+from .stochatic_karras_ve import KarrasVePipeline


 if is_transformers_available():

src/diffusers/pipelines/stochatic_karras_ve/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .pipeline_stochastic_karras_ve import KarrasVePipeline

src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+import torch
+
+from tqdm.auto import tqdm
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import KarrasVeScheduler
+
+
+class KarrasVePipeline(DiffusionPipeline):
+    """
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
+    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
+    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    """
+
+    unet: UNet2DModel
+    scheduler: KarrasVeScheduler
+
+    def __init__(self, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        img_size = self.unet.config.sample_size
+        shape = (batch_size, 3, img_size, img_size)
+
+        model = self.unet.to(torch_device)
+
+        # sample x_0 ~ N(0, sigma_0^2 * I)
+        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = sample.to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm(self.scheduler.timesteps):
+            # here sigma_t == t_i from the paper
+            sigma = self.scheduler.schedule[t]
+            sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
+
+            # 1. Select temporarily increased noise level sigma_hat
+            # 2. Add new noise to move from sample_i to sample_hat
+            sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+            # 3. Predict the noise residual given the noise magnitude `sigma_hat`
+            # The model inputs and output are adjusted by following eq. (213) in [1].
+            model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]
+
+            # 4. Evaluate dx/dt at sigma_hat
+            # 5. Take Euler step from sigma to sigma_prev
+            step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+
+            if sigma_prev != 0:
+                # 6. Apply 2nd order correction
+                # The model inputs and output are adjusted by following eq. (213) in [1].
+                model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
+                step_output = self.scheduler.step_correct(
+                    model_output,
+                    sigma_hat,
+                    sigma_prev,
+                    sample_hat,
+                    step_output["prev_sample"],
+                    step_output["derivative"],
+                )
+            sample = step_output["prev_sample"]
+
+        sample = (sample / 2 + 0.5).clamp(0, 1)
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            sample = self.numpy_to_pil(sample)
+
+        return {"sample": sample}
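
For orientation (an editorial restatement, not part of the commit): the sampling loop above transcribes Algorithm 2 of Karras et al. [1]. In the paper's notation, one iteration from noise level t_i to t_{i+1} is

\begin{aligned}
\hat{t}_i &= t_i + \gamma_i t_i, \qquad
\hat{x}_i = x_i + \sqrt{\hat{t}_i^2 - t_i^2}\,\epsilon_i, \qquad \epsilon_i \sim \mathcal{N}\!\left(0,\, S_{\mathrm{noise}}^2 I\right) \\
d_i &= \frac{\hat{x}_i - D_\theta(\hat{x}_i;\, \hat{t}_i)}{\hat{t}_i}, \qquad
x_{i+1} = \hat{x}_i + (t_{i+1} - \hat{t}_i)\, d_i \\
d_i' &= \frac{x_{i+1} - D_\theta(x_{i+1};\, t_{i+1})}{t_{i+1}}, \qquad
x_{i+1} \leftarrow \hat{x}_i + (t_{i+1} - \hat{t}_i)\left(\tfrac{1}{2} d_i + \tfrac{1}{2} d_i'\right) \quad \text{if } t_{i+1} \neq 0.
\end{aligned}

Here the denoiser D_\theta corresponds to pred_original_sample inside the scheduler's step/step_correct, and the (sample + 1)/2 input and sigma/2 output rescalings in the loop adapt the VE model's expected ranges following eq. (213) of [1].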

src/diffusers/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@

 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
+from .scheduling_karras_ve import KarrasVeScheduler
 from .scheduling_pndm import PNDMScheduler
 from .scheduling_sde_ve import ScoreSdeVeScheduler
 from .scheduling_sde_vp import ScoreSdeVpScheduler

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 1 addition & 1 deletion

@@ -134,7 +134,7 @@ def step(
         generator=None,
     ):
         t = timestep
-
+
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
         else:

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
+    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
+    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sigma_min=0.02,
+        sigma_max=100,
+        s_noise=1.007,
+        s_churn=80,
+        s_min=0.05,
+        s_max=50,
+        tensor_format="pt",
+    ):
+        """
+        For more details on the parameters, see Appendix E of the original paper:
+        "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
+        The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
+        are described in Table 5 of the paper.
+
+        Args:
+            sigma_min (`float`): minimum noise magnitude
+            sigma_max (`float`): maximum noise magnitude
+            s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+                A reasonable range is [1.000, 1.011].
+            s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+                A reasonable range is [0, 100].
+            s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+                A reasonable range is [0, 10].
+            s_max (`float`): the end value of the sigma range where we add noise.
+                A reasonable range is [0.2, 80].
+        """
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = None
+        self.schedule = None  # sigma(t_i)
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        self.schedule = [
+            (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
+            for i in self.timesteps
+        ]
+        self.schedule = np.array(self.schedule, dtype=np.float32)
+
+        self.set_format(tensor_format=self.tensor_format)
+
+    def add_noise_to_input(self, sample, sigma, generator=None):
+        """
+        Explicit Langevin-like "churn" step of adding noise to the sample according to
+        a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i * sigma_i.
+        """
+        if self.s_min <= sigma <= self.s_max:
+            gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
+        else:
+            gamma = 0
+
+        # sample eps ~ N(0, S_noise^2 * I)
+        eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+        sigma_hat = sigma + gamma * sigma
+        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+        return sample_hat, sigma_hat
+
+    def step(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        sigma_hat: float,
+        sigma_prev: float,
+        sample_hat: Union[torch.FloatTensor, np.ndarray],
+    ):
+        pred_original_sample = sample_hat + sigma_hat * model_output
+        derivative = (sample_hat - pred_original_sample) / sigma_hat
+        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+        return {"prev_sample": sample_prev, "derivative": derivative}
+
+    def step_correct(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        sigma_hat: float,
+        sigma_prev: float,
+        sample_hat: Union[torch.FloatTensor, np.ndarray],
+        sample_prev: Union[torch.FloatTensor, np.ndarray],
+        derivative: Union[torch.FloatTensor, np.ndarray],
+    ):
+        pred_original_sample = sample_prev + sigma_prev * model_output
+        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+        return {"prev_sample": sample_prev, "derivative": derivative_corr}
+
+    def add_noise(self, original_samples, noise, timesteps):
+        raise NotImplementedError()
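
As a cross-check (restated from the code above, not part of the commit): set_timesteps builds the geometric sigma schedule of Table 1 (VE column) in [1], and add_noise_to_input implements the churn rule of Algorithm 2,

\sigma_i = \sigma_{\max} \left(\frac{\sigma_{\min}^2}{\sigma_{\max}^2}\right)^{\!i/(N-1)},
\qquad
\gamma_i =
\begin{cases}
\min\!\left(\dfrac{S_{\mathrm{churn}}}{N},\; \sqrt{2} - 1\right) & \text{if } S_{\min} \le \sigma_i \le S_{\max}, \\[4pt]
0 & \text{otherwise},
\end{cases}

so that \hat{\sigma}_i = (1 + \gamma_i)\,\sigma_i and the re-noised sample is \hat{x}_i = x_i + \sqrt{\hat{\sigma}_i^2 - \sigma_i^2}\,\epsilon with \epsilon \sim \mathcal{N}(0,\, S_{\mathrm{noise}}^2 I).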

tests/test_modeling_utils.py

Lines changed: 18 additions & 0 deletions

@@ -29,6 +29,8 @@
     DDIMScheduler,
     DDPMPipeline,
     DDPMScheduler,
+    KarrasVePipeline,
+    KarrasVeScheduler,
     LDMPipeline,
     LDMTextToImagePipeline,
     PNDMPipeline,
@@ -909,3 +911,19 @@ def test_ddpm_ddim_equality_batched(self):

         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_images - ddim_images).max() < 1e-1
+
+    @slow
+    def test_karras_ve_pipeline(self):
+        model_id = "google/ncsnpp-celebahq-256"
+        model = UNet2DModel.from_pretrained(model_id)
+        scheduler = KarrasVeScheduler(tensor_format="pt")
+
+        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
+
+        generator = torch.manual_seed(0)
+        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
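
The new test is decorated with @slow, so it is skipped in a default test run. A hedged invocation sketch (assuming the Transformers-style RUN_SLOW gate that this suite's slow decorator conventionally checks):

RUN_SLOW=1 python -m pytest tests/test_modeling_utils.py -k test_karras_ve_pipeline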
