Description
Paper: https://arxiv.org/pdf/2310.08442.pdf (this is already implemented but needs a modification)
apply_debiased_estimation uses a formulation that is specific to epsilon prediction. For v-prediction it should follow a different formula, similar to how min_snr_gamma is handled.
The timestep weighting that the authors use in the paper is 1/sqrt(snr), and this is what is currently implemented. Their calculations suggest that the ideal formula would be 1/snr, but their testing showed that this weighting (which was also tested in the Min SNR paper) does not work well in practice, most likely because the weights vary by a factor of ~100,000 between the first and last timestep. Taking the square root keeps the weights close enough in the ways that matter while reducing their range to around 500.
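To make the two ranges concrete, here is a rough sketch that computes both weightings over a noise schedule. It assumes the usual Stable Diffusion scaled-linear schedule (beta_start=0.00085, beta_end=0.012, 1000 steps), which is my assumption and not something stated in the paper; the helper name is made up for illustration.

```python
import math

def scaled_linear_snr(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Per-timestep SNR = alpha_bar / (1 - alpha_bar).

    Assumes a scaled-linear beta schedule (hypothetical helper, not
    taken from the training scripts).
    """
    snr = []
    alpha_bar = 1.0
    for i in range(num_steps):
        # Scaled-linear: interpolate linearly in sqrt(beta), then square.
        frac = i / (num_steps - 1)
        sqrt_beta = math.sqrt(beta_start) + frac * (math.sqrt(beta_end) - math.sqrt(beta_start))
        alpha_bar *= 1.0 - sqrt_beta ** 2
        snr.append(alpha_bar / (1.0 - alpha_bar))
    return snr

snr = scaled_linear_snr()

# Epsilon-prediction weightings discussed above.
w_inv = [1.0 / s for s in snr]                  # theoretical 1/snr
w_inv_sqrt = [1.0 / math.sqrt(s) for s in snr]  # the paper's 1/sqrt(snr)

# Ratio between the largest and smallest weight across all timesteps.
range_inv = max(w_inv) / min(w_inv)
range_inv_sqrt = max(w_inv_sqrt) / min(w_inv_sqrt)

print(f"1/snr       range: {range_inv:.0f}x")
print(f"1/sqrt(snr) range: {range_inv_sqrt:.0f}x")
```

Under that schedule the 1/snr spread comes out on the order of 10^5 and the square-root version is in the hundreds, which matches the rough figures above. Other schedules will give different numbers.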
For v-prediction, a different formula is required, especially with zero terminal SNR, since the epsilon formula would divide by zero at the last timestep. @feffy380 derived the theoretically correct formula for v-prediction some time ago: 1/(snr+1). Since the range of this formula is about 1100, which is acceptably close to the ~500 of the authors' weighting for epsilon prediction, I do not think the reasoning for adding the square root to the epsilon formula still holds for v-prediction, so the formula can safely be left as is. The loss weights go from 1.0 at the terminal timestep (with zero terminal SNR) down to 0.0009 at the first timestep, unlike the epsilon 1/snr formula, which goes all the way from 1e-4 to 1e+2.
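A minimal sketch of what the proposed weighting could look like, again assuming the Stable Diffusion scaled-linear schedule plus the usual zero terminal SNR rescaling (shift and scale sqrt(alpha_bar) so the last value is 0). The function names here are hypothetical and this is not the existing apply_debiased_estimation code, just an illustration of the formula.

```python
import math

def alpha_bar_scaled_linear(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative alpha product for an assumed scaled-linear beta schedule."""
    out, alpha_bar = [], 1.0
    for i in range(num_steps):
        frac = i / (num_steps - 1)
        sqrt_beta = math.sqrt(beta_start) + frac * (math.sqrt(beta_end) - math.sqrt(beta_start))
        alpha_bar *= 1.0 - sqrt_beta ** 2
        out.append(alpha_bar)
    return out

def rescale_zero_terminal_snr(alpha_bar):
    """Shift and scale sqrt(alpha_bar) so the last step is 0 and the first is unchanged."""
    sqrt_ab = [math.sqrt(a) for a in alpha_bar]
    first, last = sqrt_ab[0], sqrt_ab[-1]
    scale = first / (first - last)
    return [((s - last) * scale) ** 2 for s in sqrt_ab]

def debiased_weight(snr, v_prediction):
    """Hypothetical helper: per-timestep loss weight for a given SNR."""
    if v_prediction:
        # Proposed v-prediction form; finite even when snr == 0.
        return 1.0 / (snr + 1.0)
    # Epsilon form from the paper; undefined when snr == 0.
    return 1.0 / math.sqrt(snr)

alpha_bar = rescale_zero_terminal_snr(alpha_bar_scaled_linear())
snr = [a / (1.0 - a) for a in alpha_bar]
weights = [debiased_weight(s, v_prediction=True) for s in snr]

print(f"first timestep weight: {weights[0]:.5f}")
print(f"last timestep weight:  {weights[-1]:.5f}")
print(f"range: {max(weights) / min(weights):.0f}x")
```

Since snr = alpha_bar / (1 - alpha_bar), the weight 1/(snr+1) is the same thing as 1 - alpha_bar, which is why it stays bounded between 0 and 1 and never divides by zero.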
In my testing, I have gotten exceptionally good results from this timestep weighting strategy on v-prediction and zero terminal SNR full finetune training, and have found the results to be leagues ahead of what min_snr_gamma is capable of.