Skip to content

Commit 43c269d

Browse files
authored
[PNDM in LDM pipeline] use inspect in pipeline instead of unused kwargs (huggingface#167)
use inspect instead of unused kwargs
1 parent a78f60b commit 43c269d

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from typing import Optional, Tuple, Union
23

34
import torch
@@ -59,6 +60,12 @@ def __call__(
5960

6061
self.scheduler.set_timesteps(num_inference_steps)
6162

63+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
64+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
65+
extra_kwrags = {}
66+
if not accepts_eta:
67+
extra_kwrags["eta"] = eta
68+
6269
for t in tqdm(self.scheduler.timesteps):
6370
if guidance_scale == 1.0:
6471
# guidance_scale of 1 means no guidance
@@ -79,7 +86,7 @@ def __call__(
7986
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
8087

8188
# compute the previous noisy sample x_t -> x_t-1
82-
latents = self.scheduler.step(noise_pred, t, latents, eta=eta)["prev_sample"]
89+
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
8390

8491
# scale and decode the image latents with vae
8592
latents = 1 / 0.18215 * latents

pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
13
import torch
24

35
from tqdm.auto import tqdm
@@ -31,11 +33,17 @@ def __call__(
3133

3234
self.scheduler.set_timesteps(num_inference_steps)
3335

36+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
37+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
38+
extra_kwrags = {}
39+
if not accepts_eta:
40+
extra_kwrags["eta"] = eta
41+
3442
for t in tqdm(self.scheduler.timesteps):
3543
# predict the noise residual
3644
noise_prediction = self.unet(latents, t)["sample"]
3745
# compute the previous noisy sample x_t -> x_t-1
38-
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
46+
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
3947

4048
# decode the image latents with the VAE
4149
image = self.vqvae.decode(latents)

schedulers/scheduling_pndm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def step(
116116
model_output: Union[torch.FloatTensor, np.ndarray],
117117
timestep: int,
118118
sample: Union[torch.FloatTensor, np.ndarray],
119-
**kwargs,
120119
):
121120
if self.counter < len(self.prk_timesteps):
122121
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)

0 commit comments

Comments
 (0)