@@ -82,6 +82,8 @@ def __init__(
82
82
self .tensor_format = tensor_format
83
83
self .set_format (tensor_format = tensor_format )
84
84
85
+ self .variance_type = variance_type
86
+
85
87
def set_timesteps (self , num_inference_steps ):
86
88
num_inference_steps = min (self .config .num_train_timesteps , num_inference_steps )
87
89
self .num_inference_steps = num_inference_steps
@@ -90,7 +92,7 @@ def set_timesteps(self, num_inference_steps):
90
92
)[::- 1 ].copy ()
91
93
self .set_format (tensor_format = self .tensor_format )
92
94
93
- def _get_variance (self , t , variance_type = None ):
95
+ def _get_variance (self , t , predicted_variance = None , variance_type = None ):
94
96
alpha_prod_t = self .alphas_cumprod [t ]
95
97
alpha_prod_t_prev = self .alphas_cumprod [t - 1 ] if t > 0 else self .one
96
98
@@ -113,6 +115,13 @@ def _get_variance(self, t, variance_type=None):
113
115
elif variance_type == "fixed_large_log" :
114
116
# Glide max_log
115
117
variance = self .log (self .betas [t ])
118
+ elif variance_type == "learned" :
119
+ return predicted_variance
120
+ elif variance_type == "learned_range" :
121
+ min_log = variance
122
+ max_log = self .betas [t ]
123
+ frac = (predicted_variance + 1 ) / 2
124
+ variance = frac * max_log + (1 - frac ) * min_log
116
125
117
126
return variance
118
127
@@ -125,6 +134,12 @@ def step(
125
134
generator = None ,
126
135
):
127
136
t = timestep
137
+
138
+ if model_output .shape [1 ] == sample .shape [1 ] * 2 and self .variance_type in ["learned" , "learned_range" ]:
139
+ model_output , predicted_variance = torch .split (model_output , sample .shape [1 ], dim = 1 )
140
+ else :
141
+ predicted_variance = None
142
+
128
143
# 1. compute alphas, betas
129
144
alpha_prod_t = self .alphas_cumprod [t ]
130
145
alpha_prod_t_prev = self .alphas_cumprod [t - 1 ] if t > 0 else self .one
@@ -155,7 +170,7 @@ def step(
155
170
variance = 0
156
171
if t > 0 :
157
172
noise = self .randn_like (model_output , generator = generator )
158
- variance = (self ._get_variance (t ) ** 0.5 ) * noise
173
+ variance = (self ._get_variance (t , predicted_variance = predicted_variance ) ** 0.5 ) * noise
159
174
160
175
pred_prev_sample = pred_prev_sample + variance
161
176
0 commit comments