From 9414c7eb6653cb9bbb3d2861d170a532e7789da8 Mon Sep 17 00:00:00 2001 From: Yuhki Yano Date: Tue, 28 Nov 2023 23:33:59 +0900 Subject: [PATCH] Fix the issue(keras-team#2195): Improve readability and comprehensibility of Stable Diffusion source --- keras_cv/models/stable_diffusion/stable_diffusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py index 975788ac74..bc5df87d23 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion.py @@ -235,7 +235,9 @@ def generate_image( + unconditional_guidance_scale * (latent - unconditional_latent) ) a_t, a_prev = alphas[index], alphas_prev[index] - latent = ops.cast(latent, latent_prev.dtype) + # Keras backend array need to cast explicitly + target_dtype = latent_prev.dtype + latent = ops.cast(latent, target_dtype) pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt( a_t )