Add an alternative Karras et al. stochastic scheduler for VE models #160

Merged · 6 commits · Aug 9, 2022
11 changes: 9 additions & 2 deletions src/diffusers/__init__.py
@@ -18,8 +18,15 @@
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
-from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
+from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .schedulers import (
+    DDIMScheduler,
+    DDPMScheduler,
+    KarrasVeScheduler,
+    PNDMScheduler,
+    SchedulerMixin,
+    ScoreSdeVeScheduler,
+)
from .training_utils import EMAModel


1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -4,6 +4,7 @@
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochatic_karras_ve import KarrasVePipeline


if is_transformers_available():
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stochatic_karras_ve/__init__.py
@@ -0,0 +1 @@
from .pipeline_stochastic_karras_ve import KarrasVePipeline
80 changes: 80 additions & 0 deletions src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
import torch

from tqdm.auto import tqdm

from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasVeScheduler


class KarrasVePipeline(DiffusionPipeline):
"""
Stochastic sampling from Karras et al. [1] tailored to Variance Exploding (VE) models [2].
Use Algorithm 2 and the VE column of Table 1 from [1] for reference.

[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
[2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
"""

unet: UNet2DModel
scheduler: KarrasVeScheduler

def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)

@torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size)

model = self.unet.to(torch_device)

# sample x_0 ~ N(0, sigma_0^2 * I)
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = sample.to(torch_device)

self.scheduler.set_timesteps(num_inference_steps)

for t in tqdm(self.scheduler.timesteps):
# here sigma_t == t_i from the paper
sigma = self.scheduler.schedule[t]
sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0

# 1. Select temporarily increased noise level sigma_hat
# 2. Add new noise to move from sample_i to sample_hat
sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)

# 3. Predict the noise residual given the noise magnitude `sigma_hat`
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]

# 4. Evaluate dx/dt at sigma_hat
# 5. Take Euler step from sigma to sigma_prev
step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)

if sigma_prev != 0:
# 6. Apply 2nd order correction
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
step_output = self.scheduler.step_correct(
model_output,
sigma_hat,
sigma_prev,
sample_hat,
step_output["prev_sample"],
step_output["derivative"],
)
sample = step_output["prev_sample"]

sample = (sample / 2 + 0.5).clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)

return {"sample": sample}
1 change: 1 addition & 0 deletions src/diffusers/schedulers/__init__.py
@@ -18,6 +18,7 @@

from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_ddpm.py
@@ -134,7 +134,7 @@ def step(
generator=None,
):
t = timestep

if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
127 changes: 127 additions & 0 deletions src/diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,127 @@
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to Variance Exploding (VE) models [2].
Use Algorithm 2 and the VE column of Table 1 from [1] for reference.

[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
[2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
"""

@register_to_config
def __init__(
self,
sigma_min=0.02,
sigma_max=100,
s_noise=1.007,
s_churn=80,
s_min=0.05,
s_max=50,
Comment on lines +40 to +42

Contributor: (nit) Maybe rename the s_ parameters to scale_, for example scale_churn.

Contributor: +1

Member Author (@anton-l, Aug 9, 2022): The S likely refers to "stochastic" in the paper (page 7, "Practical considerations", https://arxiv.org/pdf/2206.00364.pdf), so I'm not sure this would make it more explicit. I'll add detailed docstrings for them instead.

Contributor: Ok for me.

Contributor (@patil-suraj, Aug 9, 2022): Looking at the function, these values are used to scale the stochastic noise, so scale would be good IMO.
Also, these arguments are a bit too low-level to expose in __call__, so maybe keep them in the scheduler; having the scheduler self-contained would be better IMO. Or do you think there is some advantage to having them in the pipeline?

Member Author: @patil-suraj, on second thought these parameters are best kept inside the scheduler, since they have to be tuned to each specific model. BTW, take a look at the docstrings so far to see if they would fit scale_.
tensor_format="pt",
):
"""
For more details on the parameters, see the original paper's Appendix E:
"Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
are described in Table 5 of the paper.

Args:
sigma_min (`float`): minimum noise magnitude
sigma_max (`float`): maximum noise magnitude
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
A reasonable range is [0, 100].
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
"""
# setable values
self.num_inference_steps = None
self.timesteps = None
self.schedule = None # sigma(t_i)

self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
(self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
for i in self.timesteps
]
self.schedule = np.array(self.schedule, dtype=np.float32)

self.set_format(tensor_format=self.tensor_format)

def add_noise_to_input(self, sample, sigma, generator=None):
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to
a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
"""
if self.s_min <= sigma <= self.s_max:
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
else:
gamma = 0

# sample eps ~ N(0, S_noise^2 * I)
eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

return sample_hat, sigma_hat

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
):
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

return {"prev_sample": sample_prev, "derivative": derivative}

def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
):
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
return {"prev_sample": sample_prev, "derivative": derivative_corr}

def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
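
As a side note, here is a small standalone sketch (not part of this diff) that reproduces the geometric sigma schedule from set_timesteps and the churn factor gamma from add_noise_to_input with the default config values, just to make the numbers concrete. The step count of 5 is arbitrary and chosen only for readability.

```python
import numpy as np

# Defaults taken from the KarrasVeScheduler config above.
sigma_min, sigma_max = 0.02, 100.0
s_churn, s_min, s_max = 80, 0.05, 50
num_inference_steps = 5  # small on purpose, just to inspect the values

timesteps = np.arange(num_inference_steps)[::-1]
schedule = np.array(
    [sigma_max * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1)) for i in timesteps],
    dtype=np.float32,
)

# The pipeline indexes schedule[t] with t = N-1, ..., 0, so sigma decays from sigma_max downwards.
for t in timesteps:
    sigma = schedule[t]
    # Churn (gamma > 0) is only applied while sigma lies in [s_min, s_max].
    gamma = min(s_churn / num_inference_steps, 2**0.5 - 1) if s_min <= sigma <= s_max else 0
    print(f"t={t}  sigma={sigma:.6f}  sigma_hat={sigma + gamma * sigma:.6f}")
```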
18 changes: 18 additions & 0 deletions tests/test_modeling_utils.py
@@ -29,6 +29,8 @@
DDIMScheduler,
DDPMPipeline,
DDPMScheduler,
KarrasVePipeline,
KarrasVeScheduler,
LDMPipeline,
LDMTextToImagePipeline,
PNDMPipeline,
@@ -920,3 +922,19 @@ def test_ddpm_ddim_equality_batched(self):

# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1

@slow
def test_karras_ve_pipeline(self):
Contributor: Very cool!!! (Did you also try it with the 1024 one?)

Member Author (@anton-l, Aug 9, 2022): Yes, the results are pretty ok :) (Although they can be made better with a bit of grid search to find the optimal s_churn, s_noise, etc., since the paper was dealing with much smaller models and I just guessed the params.)
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id)
scheduler = KarrasVeScheduler(tensor_format="pt")

pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

generator = torch.manual_seed(0)
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]

image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
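
Following up on the grid-search remark in the review thread above, a rough sketch of how the S_* values might be tuned for this checkpoint. The candidate values come from the "reasonable range" hints in the scheduler docstring, and the scoring step is only a placeholder; a real sweep would compare samples visually or with a metric such as FID.

```python
import itertools

import torch

from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel

model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")

# Candidate values drawn from the ranges suggested in the KarrasVeScheduler docstring.
for s_churn, s_noise in itertools.product([20, 50, 80, 100], [1.000, 1.003, 1.007, 1.011]):
    scheduler = KarrasVeScheduler(s_churn=s_churn, s_noise=s_noise, tensor_format="pt")
    pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
    generator = torch.manual_seed(0)
    images = pipe(num_inference_steps=50, generator=generator, output_type="numpy")["sample"]
    # Placeholder: inspect the samples (or score them against a reference set) to pick the best pair.
    print(s_churn, s_noise, images.shape)
```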