Skip to content

Commit 862dcb9

Browse files
fix noise scheduler error in stable diffusion (#2171)
1 parent 40ae4ae commit 862dcb9

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

keras_cv/models/stable_diffusion/noise_scheduler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,12 @@ def add_noise(
205205
sqrt_one_minus_alpha_prod = ops.expand_dims(
206206
sqrt_one_minus_alpha_prod, axis=-1
207207
)
208-
208+
sqrt_alpha_prod = ops.cast(
209+
sqrt_alpha_prod, dtype=original_samples.dtype
210+
)
211+
sqrt_one_minus_alpha_prod = ops.cast(
212+
sqrt_one_minus_alpha_prod, dtype=noise.dtype
213+
)
209214
noisy_samples = (
210215
sqrt_alpha_prod * original_samples
211216
+ sqrt_one_minus_alpha_prod * noise

0 commit comments

Comments
 (0)