@@ -710,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
         # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
@@ -721,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         if sigmas[i + 1] == 0:
-            # Euler method
-            d = to_d(x, sigmas[i], denoised)
-            dt = sigmas[i + 1] - sigmas[i]
-            x = x + d * dt
+            # Denoising step
+            x = denoised
         else:
             # DPM-Solver++
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
-            h = t_next - t
-            s = t + h * r
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
+            lambda_s_1 = lambda_s + r * h
             fac = 1 / (2 * r)
 
+            sigma_s_1 = sigma_fn(lambda_s_1)
+
+            alpha_s = sigmas[i] * lambda_s.exp()
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
+
             # Step 1
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
-            s_ = t_fn(sd)
-            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
-            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
-            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
+            lambda_s_1_ = sd.log().neg()
+            h_ = lambda_s_1_ - lambda_s
+            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
+            if eta > 0 and s_noise > 0:
+                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
+            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
             # Step 2
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
-            t_next_ = t_fn(sd)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
+            lambda_t_ = sd.log().neg()
+            h_ = lambda_t_ - lambda_s
             denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
-            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
     return x
 
 
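The helpers this hunk relies on (`sigma_to_half_log_snr`, `half_log_snr_to_sigma`, `offset_first_sigma_for_snr`, and `partial` from `functools`) are defined outside the hunk. As a rough sketch of the conversion math only — the `is_rf` flag and function bodies below are illustrative assumptions, not the repository's implementation — the half-logSNR is lambda = log(alpha_t / sigma_t), which reduces to -log(sigma) for an epsilon-prediction schedule (alpha = 1) and to log((1 - sigma) / sigma) for rectified flow (alpha = 1 - sigma):

```python
import torch

def sigma_to_half_log_snr(sigma, model_sampling=None, is_rf=False):
    # Half-logSNR: lambda = log(alpha / sigma).
    # Illustrative only: the real helper inspects `model_sampling` to pick the
    # schedule; `is_rf` here is a hypothetical stand-in for that dispatch.
    if is_rf:
        # Rectified flow: alpha = 1 - sigma, so lambda = log((1 - sigma) / sigma) = -logit(sigma).
        return sigma.logit().neg()
    # Epsilon-prediction / variance-exploding convention: alpha = 1, so lambda = -log(sigma).
    return sigma.log().neg()

def half_log_snr_to_sigma(half_log_snr, model_sampling=None, is_rf=False):
    # Inverse of the above.
    if is_rf:
        # sigma = 1 / (1 + exp(lambda)) = sigmoid(-lambda)
        return half_log_snr.neg().sigmoid()
    # sigma = exp(-lambda)
    return half_log_snr.neg().exp()

# Round-trip check and the alpha recovery the sampler uses:
sigma = torch.tensor([0.25, 0.5, 0.9])
lam = sigma_to_half_log_snr(sigma, is_rf=True)
assert torch.allclose(half_log_snr_to_sigma(lam, is_rf=True), sigma)
alpha = sigma * lam.exp()                 # sigma * exp(lambda) = alpha, as in `alpha_s = sigmas[i] * lambda_s.exp()`
assert torch.allclose(alpha, 1 - sigma)   # rectified flow: alpha_t = 1 - sigma_t
```

This identity sigma * exp(lambda) = sigma * (alpha / sigma) = alpha is why the hunk can recover the alphas directly from the sigmas and half-logSNR values.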
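A quick sanity check, not part of the commit, that the reparameterized DPM-Solver++ update reduces to the previous one when alpha = 1 (the epsilon-prediction case the old `sigma_fn`/`t_fn` lambdas assumed); the sigma values below are arbitrary and eta is taken as 0 so the ancestral down-step drops out:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4)
denoised = torch.randn(4)

# Old parameterization: t = -log(sigma), sigma_fn(t) = exp(-t)
sigma_s, sigma_t = torch.tensor(2.0), torch.tensor(0.7)
t, t_next = sigma_s.log().neg(), sigma_t.log().neg()
old = (sigma_t / sigma_s) * x - (t - t_next).expm1() * denoised

# New parameterization: lambda = log(alpha / sigma), here with alpha = 1
lambda_s, lambda_t = sigma_s.log().neg(), sigma_t.log().neg()
h = lambda_t - lambda_s
alpha_s = sigma_s * lambda_s.exp()  # equals 1 when alpha = 1
alpha_t = sigma_t * lambda_t.exp()  # equals 1 when alpha = 1
new = (alpha_t / alpha_s) * (-h).exp() * x - alpha_t * (-h).expm1() * denoised

assert torch.allclose(old, new)
```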