Add an alternative Karras et al. stochastic scheduler for VE models #160
Conversation
The documentation is not available anymore as the PR was closed or merged.
Great work @anton-l! It looks good to me, just left some nits.
A more general comment: this scheduler goes a bit against our design of a single `step` function. But given that the scheduler requires two model evaluations, and the second eval depends on the output of the first, we can't really have a single `step` function. We also need to do some bookkeeping outside the scheduler, like storing `sigma`, `sigma_prev`, etc. (Looking at the code, I think we can probably avoid it.) cc @patrickvonplaten
Thanks a lot for working on this!
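To make the two-evaluation concern concrete, here is a minimal sketch of the sampling loop from Algorithm 2 of Karras et al. The method names (`add_noise_to_input`, `step`, `step_correct`) and signatures follow the API discussed in this PR but are assumptions, and the bare `model(...)` call stands in for the pipeline's actual model invocation (which also rescales inputs and outputs):

```python
def karras_sample_loop(model, scheduler, sample, sigmas, generator=None):
    """Sketch of the stochastic Karras sampler; not the exact merged code."""
    for sigma, sigma_prev in zip(sigmas[:-1], sigmas[1:]):
        # 1. Temporarily increase the noise level (sigma -> sigma_hat) and
        #    perturb the sample to match it.
        sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma, generator=generator)

        # 2. First model evaluation, at the increased noise level.
        model_output = model(sample_hat, sigma_hat)

        # 3. Euler step from sigma_hat down to sigma_prev.
        step_output = scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)

        # 4. Second model evaluation at sigma_prev. It depends on the output
        #    of the first step, which is why a single `step` call can't cover
        #    both evaluations.
        if sigma_prev != 0:
            model_output = model(step_output.prev_sample, sigma_prev)
            step_output = scheduler.step_correct(
                model_output, sigma_hat, sigma_prev,
                sample_hat, step_output.prev_sample, step_output.derivative,
            )
        sample = step_output.prev_sample
    return sample
```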
# 1. Select temporarily increased noise level sigma_hat
# 2. Add new noise to move from sample_i to sample_hat
sample_hat, sigma_hat = self.scheduler.get_model_inputs(sample, sigma, generator=generator)
`get_model_inputs` doesn't seem to describe the function well. Maybe let's call it `add_noise_to_input` or `churn_input`.
I would just call it `get_sigmas(...)`, but I understand where `get_model_inputs` is coming from (to make the API as generic as possible for all schedulers, which is also a good argument).
Here I would be in favor of being a bit more specific with something like `get_sigmas(...)`, which is something most continuous schedulers have and which could become a common API across continuous schedulers. So I think here it makes sense to favor intuitive, readable code over ease of use.
Also, it scares me a bit to see that `sample` is an input to the function; it gives the impression that it's used to compute the sigmas. Can we instead maybe just pass `shape` and `device`?
Note that `sample` is also used in the computation: `sample_hat = sample + ((sigma_hat**2 - sigma**2)**0.5 * eps)`
`get_sigmas` sounds good to me, but it's still not clear enough, as the function adds noise to the current sigmas rather than producing a completely new `sigma`.

> Also it scares me a bit to see that sample is an input to the function, it gives the impression that it's used to compute the sigmas. Can we instead maybe just pass shape and device?

Good point! In that case we would need to compute `sample_hat` in the pipeline, and the function would also have to return `eps`, which is needed to compute `sample_hat`. I'm still in favor of `add_noise_to_input`, as it makes clear that the function adds noise to the inputs rather than using `sample` to compute `sigma`. That's also how the paper describes this part.
Opting for `add_noise_to_input` for now :)
Ah sorry @anton-l, you're right, I missed that part. Then I think we should more or less leave it as is, and maybe rename to `compute_sigmas(...)` to make sure the reader knows that the sigmas depend on the sample and are computed (not taken from a predefined list).
Feel free to go ahead as you want though, @anton-l.
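For reference, a minimal sketch of what the renamed `add_noise_to_input` computes, following Algorithm 2 of the paper. It is written as a free function for illustration; in the PR it is a scheduler method that reads `s_churn`, `s_min`, `s_max`, and `s_noise` from the scheduler config:

```python
import torch

def add_noise_to_input(sample, sigma, s_churn, s_min, s_max, s_noise,
                       num_steps, generator=None):
    # Stochasticity is only enabled while sigma lies inside [s_min, s_max].
    if s_min <= sigma <= s_max:
        gamma = min(s_churn / num_steps, 2**0.5 - 1)
    else:
        gamma = 0

    # 1. Select the temporarily increased noise level sigma_hat.
    sigma_hat = sigma + gamma * sigma

    # 2. Add fresh noise so the sample matches the higher noise level. This is
    #    why `sample` itself must be an input, not just its shape and device.
    eps = s_noise * torch.randn(sample.shape, generator=generator, device=sample.device)
    sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps
    return sample_hat, sigma_hat
```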
s_churn=80,
s_min=0.05,
s_max=50,
(nit) Maybe rename the `s_` parameters to `scale_`, for example `scale_churn`.
+1
The S likely refers to "stochastic" in the paper (page 7, "Practical considerations": https://arxiv.org/pdf/2206.00364.pdf), so I'm not sure this would make it more explicit. I'll add detailed docstrings for them instead.
Ok for me
Looking at the function, these values are used to scale the stochastic noise, so `scale` would be good IMO.
Also, these arguments are a bit too low-level to expose in `__call__`, so maybe keep them in the scheduler. Having the scheduler self-contained would be better IMO. Or do you think there is some advantage to having them in the pipeline?
@patil-suraj on second thought, these parameters are best kept inside the scheduler, since they have to be tuned for each specific model. BTW, take a look at the docstrings so far to see if they would fit `scale_`.
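A short sketch of that decision: the `s_` parameters live in the scheduler config (tuned per model) instead of being exposed in the pipeline's `__call__`. The class name and values below follow this PR's diff and docstrings, but treat them as assumptions:

```python
from diffusers import KarrasVeScheduler  # assumed export name

scheduler = KarrasVeScheduler(
    s_churn=80,     # overall amount of stochasticity; reasonable range [0, 100]
    s_min=0.05,     # start of the sigma range where noise is added; range [0, 10]
    s_max=50,       # end of that sigma range; range [0.2, 80]
    s_noise=1.007,  # extra noise against detail loss; documented range [1.000, 1.011]
)
```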
@@ -920,3 +922,19 @@ def test_ddpm_ddim_equality_batched(self):

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

    @slow
    def test_karras_ve_pipeline(self):
Very cool!!! (Did you also try it with the 1024 one?)
Yes, the results are pretty OK :)
(Although they could be made better with a bit of grid search to find the optimal `s_churn`, `s_noise`, etc., since the paper was dealing with much smaller models and I just guessed the params.)
That's indeed a very important question and I don't really know the best answer here. Note that we could make it work by making […]. Also […]. So @anton-l and @patil-suraj, feel free to go ahead with whatever you think is best here!
Amazing work Anton! This is by far the best continuous scheduler that we have now, and it makes the `ve` models usable (maybe even on CPU!).
Also, it's great that the code is so clean and includes links to the paper!
Left some remarks, mostly nits; good to go for me!
Let's make sure to advertise this well with some nice examples (maybe linking the NVIDIA author as well).
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
    A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
    A reasonable range is [0, 100].
s_min (`float`): the start of the sigma range where we add noise (enable stochasticity).
    A reasonable range is [0, 10].
s_max (`float`): the end of the sigma range where we add noise.
    A reasonable range is [0.2, 80].
This makes it much clearer now! I agree with your comment that not all of these can be renamed to `scale_`. Some suggestions:
Current:

s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
    A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
    A reasonable range is [0, 100].
s_min (`float`): the start of the sigma range where we add noise (enable stochasticity).
    A reasonable range is [0, 10].
s_max (`float`): the end of the sigma range where we add noise.
    A reasonable range is [0.2, 80].

Suggested:

scale_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
    A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
    A reasonable range is [0, 100].
sigma_min_stochastic (`float`): the start value of the sigma range where we add noise (enable stochasticity).
    A reasonable range is [0, 10].
sigma_max_stochastic (`float`): the end value of the sigma range where we add noise.
    A reasonable range is [0.2, 80]. We add noise to the sigma range `[sigma_min_stochastic, sigma_max_stochastic]`.
The thing is that after renaming, it's harder to refer to their paper counterparts :(
Let's keep them as is for now; maybe we'll figure something out while implementing the SD scheduler.
…uggingface#160)
* karras + VE, not flexible yet
* Fix inputs incompatibility with the original unet
* Roll back sigma scaling
* Apply suggestions from code review
* Old comment
* Fix doc
The required number of inference steps went down from 2000 to 50 with comparable quality 🎉
The algorithm is slightly modified to rely on pre- and post-processing steps integrated into the VE UNet's forward pass (centering and scaling by sigma), so it's specific to VE models only.
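As a closing illustration, a hedged usage sketch of the merged pipeline. The pipeline class, output attribute, and checkpoint id are assumptions based on diffusers conventions of the time; substitute the actual exported names:

```python
import torch
from diffusers import KarrasVePipeline  # assumed export name

# Hypothetical VE (NCSN++) checkpoint id; substitute a real one.
pipe = KarrasVePipeline.from_pretrained("google/ncsnpp-celebahq-256")

generator = torch.Generator().manual_seed(0)
# 50 inference steps instead of the previous 2000, per the summary above.
image = pipe(num_inference_steps=50, generator=generator).images[0]
image.save("karras_ve_sample.png")
```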