
Commit 166949d

eyalmazuz and anton-l authored
Allow DDPM scheduler to use model's predicted variance (huggingface#132)
* Extended the ability of the DDPM scheduler to utilize models that also predict the variance.

* Update src/diffusers/schedulers/scheduling_ddpm.py

Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 6422af1 · commit 166949d

1 file changed: +17 −2 lines changed


schedulers/scheduling_ddpm.py

Lines changed: 17 additions & 2 deletions
@@ -82,6 +82,8 @@ def __init__(
         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
+        self.variance_type = variance_type
+
     def set_timesteps(self, num_inference_steps):
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
@@ -90,7 +92,7 @@ def set_timesteps(self, num_inference_steps):
         )[::-1].copy()
         self.set_format(tensor_format=self.tensor_format)
 
-    def _get_variance(self, t, variance_type=None):
+    def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
 
@@ -113,6 +115,13 @@ def _get_variance(self, t, variance_type=None):
         elif variance_type == "fixed_large_log":
             # Glide max_log
             variance = self.log(self.betas[t])
+        elif variance_type == "learned":
+            return predicted_variance
+        elif variance_type == "learned_range":
+            min_log = variance
+            max_log = self.betas[t]
+            frac = (predicted_variance + 1) / 2
+            variance = frac * max_log + (1 - frac) * min_log
 
         return variance
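For context, the "learned_range" branch follows the variance interpolation from Improved DDPM (Nichol & Dhariwal, 2021): the network emits a per-pixel value in [-1, 1] that blends between the two fixed bounds on the reverse-process variance. In the paper the blend is done in log space,

    \Sigma_\theta(x_t, t) = \exp\big( v \log \beta_t + (1 - v) \log \tilde{\beta}_t \big), \qquad v = \frac{\text{model output} + 1}{2},

where \tilde{\beta}_t is the "fixed_small" posterior variance and \beta_t the "fixed_large" one; the rescaling of the model output to v in [0, 1] corresponds to the frac computation above. Note that despite the min_log/max_log names, this version of the code assigns the raw variances rather than their logarithms, so the interpolation here is linear in variance space.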

@@ -125,6 +134,12 @@ def step(
         generator=None,
     ):
         t = timestep
+
+        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+        else:
+            predicted_variance = None
+
         # 1. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
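To make the new branch in step() concrete, here is a minimal, self-contained sketch of the split. The shapes are hypothetical; only the torch.split call mirrors the diff:

    import torch

    # Hypothetical shapes: a batch of two 3-channel 32x32 samples.
    sample = torch.randn(2, 3, 32, 32)

    # A variance-predicting model returns twice the sample's channels:
    # the first half is the noise prediction, the second half the variance signal.
    model_output = torch.randn(2, 6, 32, 32)

    # Split along the channel dimension into two equal halves.
    model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)

    print(model_output.shape)        # torch.Size([2, 3, 32, 32])
    print(predicted_variance.shape)  # torch.Size([2, 3, 32, 32])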
@@ -155,7 +170,7 @@ def step(
         variance = 0
         if t > 0:
             noise = self.randn_like(model_output, generator=generator)
-            variance = (self._get_variance(t) ** 0.5) * noise
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
 
         pred_prev_sample = pred_prev_sample + variance
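Putting the pieces together, a rough usage sketch. Everything here is an assumption based only on this diff: the constructor argument, the set_timesteps call, the positional step(model_output, timestep, sample) signature, and step() returning the previous sample are not verified against the full file:

    import torch
    from diffusers import DDPMScheduler

    # Assumption: variance_type is a constructor argument (the diff stores it in __init__).
    scheduler = DDPMScheduler(variance_type="learned_range")
    scheduler.set_timesteps(50)

    sample = torch.randn(1, 3, 32, 32)
    # Stand-in for a UNet output that carries both noise and variance (2x channels).
    model_output = torch.randn(1, 6, 32, 32)

    t = scheduler.timesteps[0]
    # step() splits model_output internally because shape[1] == 2 * sample.shape[1].
    prev_sample = scheduler.step(model_output, t, sample)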
