Description
Current Behavior:
An InvalidArgumentError is raised when I generate images with the mixed_float16 precision policy.
Environment: https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-13-80d21f9296ec> in <cell line: 2>()
1 # Warm up model to run graph tracing before benchmarking.
----> 2 model.text_to_image("warming up the model", batch_size=3)
3
4 start = time.time()
5 images = model.text_to_image(
/usr/local/lib/python3.10/dist-packages/keras_cv/src/models/stable_diffusion/stable_diffusion.py in text_to_image(self, prompt, negative_prompt, batch_size, num_steps, unconditional_guidance_scale, seed)
81 encoded_text = self.encode_text(prompt)
82
---> 83 return self.generate_image(
84 encoded_text,
85 negative_prompt=negative_prompt,
/usr/local/lib/python3.10/dist-packages/keras_cv/src/models/stable_diffusion/stable_diffusion.py in generate_image(self, encoded_text, negative_prompt, batch_size, num_steps, unconditional_guidance_scale, diffusion_noise, seed)
236 )
237 a_t, a_prev = alphas[index], alphas_prev[index]
--> 238 pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
239 a_t
240 )
/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
5886 def raise_from_not_ok_status(e, name) -> NoReturn:
5887 e.message += (" name: " + str(name if name is not None else ""))
-> 5888 raise core._status_to_exception(e) from None # pylint: disable=protected-access
5889
5890
InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a float tensor but is a half tensor [Op:Sub] name:
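The failing line is `pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)`. Per the message, the left operand of the subtraction (input #0) is float32 and the right operand (input #1, which contains the diffusion model output `latent`) is float16, and TensorFlow does not promote between the two. The sketch below reproduces this with plain TensorFlow tensors; the shapes and the value of `a_t` are hypothetical, and casting the model output to float32 is only one possible workaround, not the KerasCV fix.

```python
import math

import tensorflow as tf

# Hypothetical stand-ins for the operands in generate_image():
# latent_prev is float32 (like the initial noise), latent is float16
# (like a model output under the mixed_float16 policy).
latent_prev = tf.random.normal((1, 8, 8, 4))                   # float32
latent = tf.cast(tf.random.normal((1, 8, 8, 4)), tf.float16)   # float16 ("half")
a_t = 0.5  # hypothetical alpha value


def predict_x0(latent_prev, latent, a_t):
    # Same arithmetic as the line shown in the traceback.
    return (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)


mismatch_error = None
try:
    predict_x0(latent_prev, latent, a_t)
except tf.errors.InvalidArgumentError as e:
    mismatch_error = e  # float32 - float16 is rejected by the Sub op
    print("fails:", e.message.splitlines()[0])

# Workaround sketch: cast the float16 operand to the other operand's dtype first.
pred_x0 = predict_x0(latent_prev, tf.cast(latent, latent_prev.dtype), a_t)
print(pred_x0.dtype)
```

Upcasting keeps the sampler arithmetic in float32; casting `latent_prev` down to float16 would also avoid the error, but at lower precision.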
Expected Behavior:
Images should be generated the same way they are with the default float32 policy.
Steps To Reproduce:
Running the following cell from the tutorial raises the error above.
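import time

import keras_cv
from tensorflow import keras

# Setup from earlier cells of the tutorial, repeated so the snippet is self-contained.
# benchmark_result (a list) and plot_images (a plotting helper) are also defined there.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion()

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)
print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()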
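Under the tutorial's global `mixed_float16` policy, Keras layers compute in float16 while tensors created directly with TensorFlow ops, such as random noise, stay float32. That is one plausible way the two operands reach the failing subtraction with different dtypes. The minimal sketch below uses a single `Dense` layer as a hypothetical stand-in for the diffusion model; it assumes, but does not confirm, that KerasCV draws its initial noise outside the Keras policy.

```python
import tensorflow as tf
from tensorflow import keras

# Same global policy the tutorial sets before building the model.
keras.mixed_precision.set_global_policy("mixed_float16")

# Hypothetical stand-in for the diffusion model: one Dense layer.
layer = keras.layers.Dense(4)
model_output = layer(tf.random.normal((2, 4)))  # inputs are auto-cast to float16

# Noise drawn directly with TensorFlow ignores the Keras policy.
noise = tf.random.normal((2, 4))

print("layer compute dtype:", layer.compute_dtype)    # float16
print("layer variable dtype:", layer.variable_dtype)  # float32
print("model output dtype:", model_output.dtype)
print("noise dtype:", noise.dtype)

# Restore the default policy so later code is unaffected.
keras.mixed_precision.set_global_policy("float32")
```

The layer keeps its weights in float32 for numerical stability but returns float16 activations, so any arithmetic that mixes its output with float32 tensors needs an explicit cast.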

Version:
0.6.4
Details:
Python 3.10.12
keras 2.14.0
keras-core 0.1.7
keras-cv 0.6.4
pytensor 2.14.2
tensorboard 2.14.1
tensorboard-data-server 0.7.1
tensorflow 2.14.0
tensorflow-datasets 4.9.3
tensorflow-estimator 2.14.0
tensorflow-gcs-config 2.13.0
tensorflow-hub 0.15.0
tensorflow-io-gcs-filesystem 0.34.0
tensorflow-metadata 1.14.0
tensorflow-probability 0.20.1
tensorstore 0.1.45