
InvalidArgumentError occurred when I implemented image generation using mixed float16 #2102

@y-vectorfield

Description


Current Behavior:

An InvalidArgumentError is raised when I generate images with the mixed_float16 precision policy.
Environment: following the tutorial at https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-13-80d21f9296ec> in <cell line: 2>()
      1 # Warm up model to run graph tracing before benchmarking.
----> 2 model.text_to_image("warming up the model", batch_size=3)
      3 
      4 start = time.time()
      5 images = model.text_to_image(

3 frames
/usr/local/lib/python3.10/dist-packages/keras_cv/src/models/stable_diffusion/stable_diffusion.py in text_to_image(self, prompt, negative_prompt, batch_size, num_steps, unconditional_guidance_scale, seed)
     81         encoded_text = self.encode_text(prompt)
     82 
---> 83         return self.generate_image(
     84             encoded_text,
     85             negative_prompt=negative_prompt,

/usr/local/lib/python3.10/dist-packages/keras_cv/src/models/stable_diffusion/stable_diffusion.py in generate_image(self, encoded_text, negative_prompt, batch_size, num_steps, unconditional_guidance_scale, diffusion_noise, seed)
    236             )
    237             a_t, a_prev = alphas[index], alphas_prev[index]
--> 238             pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
    239                 a_t
    240             )

/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   5886 def raise_from_not_ok_status(e, name) -> NoReturn:
   5887   e.message += (" name: " + str(name if name is not None else ""))
-> 5888   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   5889 
   5890 

InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a float tensor but is a half tensor [Op:Sub] name:
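The message says input #1 of `Sub`, the right-hand operand `math.sqrt(1 - a_t) * latent`, is float16 while the left operand is float32. The sketch below reproduces that mismatch outside keras_cv. It assumes, without having confirmed it in keras_cv's source, that `latent_prev` is still float32 (for example, the initial diffusion noise) while `latent` returned by the mixed-precision diffusion model is float16. The shapes and the `a_t` value are hypothetical placeholders. The cast at the end only illustrates why aligning the dtypes removes the error. It is not a patch for keras_cv.

```python
import math

import tensorflow as tf

# Hypothetical stand-ins for the tensors at stable_diffusion.py line 238:
# latent_prev is assumed float32, latent (model output) assumed float16.
latent_prev = tf.random.normal((1, 8, 8, 4), dtype=tf.float32)
latent = tf.cast(tf.random.normal((1, 8, 8, 4)), tf.float16)
a_t = 0.9  # placeholder alpha value, not from the real noise schedule


def pred_x0(latent_prev, latent, a_t):
    # Same expression as line 238 of generate_image().
    return (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)


try:
    pred_x0(latent_prev, latent, a_t)
    mismatch_error = None
except tf.errors.InvalidArgumentError as e:
    # Eager Sub refuses float32 - float16 instead of promoting the dtypes.
    mismatch_error = e

# Casting the float16 operand to float32 makes both Sub inputs agree.
fixed = pred_x0(latent_prev, tf.cast(latent, latent_prev.dtype), a_t)
print(type(mismatch_error).__name__, fixed.dtype)
```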

Expected Behavior:

I expect the images to be generated with mixed precision just as they are with float32.

Steps To Reproduce:

Running the following code cell from the tutorial's mixed-precision section raises the error above.

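import time

import keras_cv
from tensorflow import keras

# Setup from the tutorial's mixed-precision section (benchmark_result and
# plot_images are defined earlier in the tutorial).
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion()

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)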

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()
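Under the `mixed_float16` policy, Keras layers compute in float16 and keep their variables in float32. Tensors created outside any layer, such as those from `tf.random.normal`, still default to float32. Mixing the two is one plausible way a float32 and a float16 tensor end up in the same subtraction. In the sketch below, a standalone `Dense` layer is a hypothetical stand-in for the diffusion model. It illustrates the policy's behavior and does not inspect keras_cv itself.

```python
import tensorflow as tf
from tensorflow import keras

# The mixed_float16 policy: float16 computations, float32 variables.
policy = keras.mixed_precision.Policy("mixed_float16")
print(policy.compute_dtype, policy.variable_dtype)

# A layer using this policy casts float inputs to float16 and returns float16,
# even when it is fed a float32 tensor (stand-in for the diffusion model).
layer = keras.layers.Dense(4, dtype=policy)
out = layer(tf.ones((1, 3), dtype=tf.float32))

# Tensors made outside layers are unaffected by the policy: float32 by default.
noise = tf.random.normal((1, 4))
print(out.dtype, noise.dtype)
```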

Version:

0.6.4

Details:

Python 3.10.12
keras 2.14.0
keras-core 0.1.7
keras-cv 0.6.4
pytensor 2.14.2
tensorboard 2.14.1
tensorboard-data-server 0.7.1
tensorflow 2.14.0
tensorflow-datasets 4.9.3
tensorflow-estimator 2.14.0
tensorflow-gcs-config 2.13.0
tensorflow-hub 0.15.0
tensorflow-io-gcs-filesystem 0.34.0
tensorflow-metadata 1.14.0
tensorflow-probability 0.20.1
tensorstore 0.1.45

Labels: type:Bug (Something isn't working)
